# Multimodal

## Setup

### Imports

In [1]:
# Imports
import os
import numpy as np
import pandas as pd
import random
from tqdm.notebook import tqdm
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
from torch.utils.data import Dataset, DataLoader
from src.utils import *
from src.models import MLPSTILClassifier, Attention1DSTILClassifier, Attention2DSTILClassifier, MLPSTILRegressor, Attention1DSTILRegressor, Attention2DSTILRegressor
from src.train import train_mm_stil, kfold_cv

%load_ext autoreload
%autoreload 2

### Set seed & device

In [2]:
# Fix the random seed through the project's `set_seed` helper
# (presumably seeds Python / NumPy / torch — verify in src.utils).
set_seed(42)

# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Extract text features from reports

In [3]:
%%script false --no-raise-error
# NOTE: the cell magic above hands this cell to the `false` program instead of
# the Python kernel, so the cell is skipped; delete that line to (re)run the
# feature extraction.

# Load the BioBERT tokenizer and model from one shared checkpoint name so the
# two can never drift apart.
biobert_name = 'dmis-lab/biobert-large-cased-v1.1-mnli'
tokenizer = AutoTokenizer.from_pretrained(biobert_name)
lm = AutoModel.from_pretrained(biobert_name)

# NOTE(review): machine-specific absolute path — adjust for your environment.
data_dir = '/mnt/disks/ext/data/gdc/tcga/brca'
# NOTE(review): `output_dir` is defined but never passed on — confirm whether
# extract_text_features expects it as an argument.
output_dir = 'data/report_feats'
extract_text_features(lm, tokenizer, data_dir)

## Task: Predict sTILs from WSIs & reports

### Set hyperparameters

In [3]:
# Hyperparameters for the sTIL experiments.
bsz = 16                                        # mini-batch size given to each DataLoader
img_channels_in, text_channels_in = 2048, 1024  # embedding dims: WSI feats, report feats
hidden_dim_mlp = 16                             # width of the MLP's hidden layer
num_classes = 10                                # how many sTIL levels there are

### Load data

In [4]:
# Dataset location and split policy.
root_dir = './'
data_file = 'data/stils/data_stils.csv'
use_rand_splits = False  # False -> predefined splits, True -> random splits

# Build one MMDataset per split, in the order train -> val -> test.
train_data, val_data, test_data = [
    MMDataset(root_dir, data_file, split_name, use_rand_splits)
    for split_name in ('train', 'val', 'test')
]

print(f'size of train set: {len(train_data)}, val set: {len(val_data)}, test set: {len(test_data)}')

# Options common to all three loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=bsz, num_workers=12, collate_fn=MMDataset.mm_collate_fn)
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_data, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

size of train set: 468, val set: 58, test set: 170


### Train & eval

In [None]:
# %%script false --no-raise-error
# Pick the architecture to fit. Available options: MLPSTILClassifier,
# MLPSTILRegressor, Attention1DSTILClassifier, Attention1DSTILRegressor,
# Attention2DSTILClassifier, Attention2DSTILRegressor.
model = Attention1DSTILRegressor(mode='image')

# Training configuration handed to train_mm_stil.
num_epochs = 50
resume_ckpt = None  # NOTE(review): not referenced anywhere else in this cell
args = dict(num_epochs=num_epochs, ckpt_name='ckpt_best_img')

# Run training with the train and validation loaders; the call returns the
# model together with a trainer object.
model, trainer = train_mm_stil(model, train_loader, val_loader, args)

# Score the returned model on the test loader.
trainer.test(model, test_loader)

### K-fold CV

In [None]:
# %%script false --no-raise-error
# K-fold cross-validation driven by a single dataset object.
# NOTE(review): MMDataset is built with no arguments here, unlike the
# four-argument calls in the data-loading cell — confirm the defaults are intended.
dataset = MMDataset()

model_class = Attention1DSTILRegressor
model_args = dict(mode='image')
train_args = dict(k=5, num_epochs=50, ckpt_name='ckpt_best_img_kfold_cv')

res_kfold_cv = kfold_cv(model_class, dataset, model_args, train_args)

# Average every reported metric over the entries of res_kfold_cv
# (presumably one entry per fold), rounded to three decimals.
metrics = ['test_loss', 'test_corr', 'test_r2']
avg_res = {
    metric: np.mean([fold_res[metric] for fold_res in res_kfold_cv]).round(3)
    for metric in metrics
}
print(f"avg res over {train_args['k']} folds: {avg_res}")

## Task: Predict cancer subtypes from multimodal data

### Generate cancer subtype annotations from path reports

In [6]:
# Annotate the pathology reports with region / localization / grade labels.
# The helper is given the name of an MNLI checkpoint and the reports folder
# (presumably zero-shot classification, per the `_zs` suffix — verify in src.utils).
lm_name = "facebook/bart-large-mnli"
reports_dir = "data/reports_distilled_sample"
df_subtypes_grades = classify_subtype_grade_zs(lm_name, reports_dir)

# Leave the frame as the last expression so the notebook renders it.
df_subtypes_grades

Unnamed: 0_level_0,region,localization,grade
case_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
TCGA-A2-A0CZ,metastatic,metastatic,grade 2
TCGA-A2-A0D0,other,other,
TCGA-A2-A0D1,other,in-situ,grade 3
TCGA-A2-A0D2,other,invasive/infiltrating,grade 2
TCGA-A2-A0D3,ductal/intraductal,invasive/infiltrating,grade 3
TCGA-A2-A0D4,metastatic,invasive/infiltrating,grade 2
TCGA-A2-A0EM,other,other,
TCGA-A2-A0EN,lobular,invasive/infiltrating,grade 2
TCGA-A2-A0EO,ductal/intraductal,in-situ,grade 1
TCGA-A2-A0EP,lobular,invasive/infiltrating,grade 3
