# Multimodal

## Setup

### Imports

In [1]:
# Imports
import os
import numpy as np
import json
import torch
from src.utils import *
from src.models import *
from src.process_reports import *
from src.train import train_mm, kfold_cv

# autoreload: mode 2 re-imports edited modules automatically before each cell is executed
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Set seed & device

In [2]:
# call the project's set_seed helper with a fixed seed of 42
set_seed(42)

# keep tokenizers parallelism switched off
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# show numpy arrays with 3 decimals and without scientific notation
np.set_printoptions(precision=3, suppress=True)

# use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# float32 matmul precision: 'medium' setting
torch.set_float32_matmul_precision('medium')

## Process reports

### Summarize reports

In [7]:
# Remove invalid reports: every file in reports_dir whose text is shorter than
# 20 characters is printed and then deleted from disk.
reports_dir = 'data/tcga_crc/reports/'
for report in os.listdir(reports_dir):
    report_path = os.path.join(reports_dir, report)
    # skip sub-directories / non-regular entries: open() would raise on them
    if not os.path.isfile(report_path):
        continue
    with open(report_path, 'r') as f:
        report_text = f.read()
    # delete only after the handle has been closed; removing a file that is
    # still open fails on Windows
    if len(report_text) < 20:
        print(f'invalid report: {report}\ncontent: {report_text}\ndeleting...\n')
        os.remove(report_path)

In [None]:
# write a summary of every raw report into reports_sum_dir
# (per the original note, the summaries are produced with GPT-3)
summarize_reports(
    reports_dir='data/tcga_crc/reports',
    reports_sum_dir='data/tcga_crc/reports_sum',
)

### Extract text feats from reports

In [3]:
# turn the (summarized) reports into text features stored under report_feats_dir
extract_text_feats(
    reports_dir='data/tcga_crc/reports_sum',
    report_feats_dir='data/tcga_crc/report_feats',
)

100%|██████████| 607/607 [04:33<00:00,  2.22it/s]


## Extract image feats from WSIs

In [None]:
# Inputs and output location for the WSI feature-extraction script.
# NOTE(review): these paths point at the tcga_brca cohort, while the report-processing
# cells and the task params in this notebook use tcga_crc — confirm which cohort is intended.
# NOTE(review): slides_root_dir is an absolute, machine-specific path; adjust per environment.
slides_root_dir = '/mnt/disks/ext/data/tcga/brca/' # root dir of raw slides
slides_manifest_path = 'data/tcga_brca/wsi_feats_manifest.txt' # path to slides manifest
output_dir = 'data/tcga_brca/wsi_feats' # output dir for wsi feats

# run script to extract wsi feats
# ($name in a `!` shell line is replaced by IPython with the Python variable of that name)
!python src/extract_wsi_feats.py \
    --slides-root-dir $slides_root_dir \
    --slides-manifest-path $slides_manifest_path \
    --output-dir $output_dir

## Task: Predict target from WSIs & reports

### Set params

In [8]:
# --- what to predict, and from which inputs ---
subtype = 'tcga_crc'
# prediction target: 'msi' or 'stils'
target = 'msi'
# input modalities: 'text', 'img', or 'mm'
mode = 'img'

# --- data locations (all derived from the subtype) ---
wsi_feats_dir = f'data/{subtype}/wsi_feats'
report_feats_dir = f'data/{subtype}/report_feats'
data_file = f'data/{subtype}/{subtype}_{target}.csv'

# --- data loading ---
# dataset split: 'def' or 'rand'
split = 'def'
# batch size for dataloaders
bsz = 128
# resample data to balance classes
resample = True

# --- classification setup ---
# number of classes
num_classes = 2
# class weights for the loss function
class_weights = [1, 1]
# metrics to track
metrics = [
    'acc',
    'bal_acc',
    'f1',
    'auroc',
    'acc_per_class',
    'f1_per_class',
]

### Single-fold

In [15]:
# build the train / val / test dataloaders
train_loader, val_loader, test_loader = create_dataloaders(
    target,
    data_file,
    wsi_feats_dir,
    report_feats_dir,
    use_rand_splits=(split == 'rand'),
    bsz=bsz,
    resample=resample,
)

# instantiate the model (architecture: 'Attention1DRegressor' or 'Attention1DClassifier')
model = Attention1DClassifier(
    target=target,
    mode=mode,
    num_classes=num_classes,
    class_weights=class_weights,
    metrics=metrics,
)

# training arguments handed to train_mm
args = dict(
    num_epochs=100,
    ckpt_name=f'ckpt_best_{mode}_{split}_split',
    resume_ckpt=None,
    tblog_name=f'best_{mode}_{split}_split',
)

# fit the model; train_mm returns the model together with its trainer
model, trainer = train_mm(model, train_loader, val_loader, args)

# run the trained model on the held-out test set (result is shown as the cell output)
trainer.test(model, test_loader)

size of train set: 376, val set: 53, test set: 109
# samples for each class in train set: (array([0, 1]), array([326,  50]))
# samples for each class in val set: (array([0, 1]), array([46,  7]))
# samples for each class in test set: (array([0, 1]), array([94, 15]))


### K-fold CV

In [None]:
# full dataset; kfold_cv receives it together with the model class and both arg dicts
dataset = MMDataset(target, data_file, wsi_feats_dir, report_feats_dir)

# model class to train: 'Attention1DRegressor' or 'Attention1DClassifier'
model_class = Attention1DClassifier

# keyword arguments for the model
model_args = dict(
    mode=mode,
    target=target,
    num_classes=num_classes,
    class_weights=class_weights,
    metrics=metrics,
)

# training / cross-validation settings
train_args = dict(
    bsz=bsz,
    k=5,
    rand_seed=42,
    resample=resample,
    num_epochs=100,
    patience=10,
    save_top_k=0,
    tblog_name=f'best_{mode}_kfold',
    enable_progress_bar=False,
)

# run k-fold CV
res_kfold_cv = kfold_cv(model_class, dataset, model_args, train_args)

In [None]:
# Report mean +- std of every metric over the k folds, read from the saved k-fold results.
# load results
res_kfold_path = 'outputs/Attention1DClassifier/msi/kfold_img.json'
with open(res_kfold_path, 'r') as f:
    res_kfold_cv = json.load(f)['results']

# Per-fold entries only: the 'avg' entry sits in the same dict as the folds, and
# counting it as an extra sample would shrink the spread.
fold_keys = [key for key in res_kfold_cv if key != 'avg']

# standard deviation (np.std defaults) of each metric over the folds
res_std = {
    metric: np.std([res_kfold_cv[fold][metric] for fold in fold_keys])
    for metric in res_kfold_cv['0'].keys()
}

# print avg results over k folds; the fold count comes from the loaded file so
# this cell does not depend on train_args from the training cell
print(f"avg res over {len(fold_keys)} folds:")
for metric in res_kfold_cv['avg'].keys():
    print(f"{metric.replace('_epoch', '')}: {res_kfold_cv['avg'][metric]:.3f} +- {res_std[metric]:.3f}")