# Multimodal

## Setup

### Imports

In [1]:
# Imports
import os
import numpy as np
import pandas as pd
import random
from dotenv import load_dotenv
from tqdm.notebook import tqdm
import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, pipeline
from torch.utils.data import Dataset, DataLoader
from src.utils import *
from src.models import *
from src.process_reports import *
from src.train import train_mm, kfold_cv

%load_ext autoreload
%autoreload 2

### Set seed & device

In [12]:
# fix the random seeds for repeatable runs
# (set_seed is not defined in this notebook; presumably provided by the src.* star-imports)
set_seed(42)

# numpy printing: 3 decimals, no scientific notation
np.set_printoptions(suppress=True, precision=3)
# turn tokenizers parallelism off
os.environ.update({'TOKENIZERS_PARALLELISM': 'false'})
# use 'medium' precision for float32 matmuls in torch
torch.set_float32_matmul_precision('medium')

# run on the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Process reports

### Summarize reports

In [7]:
%%script false --no-raise-error
# Remove invalid reports: a report whose text is shorter than 20 characters
# is considered invalid and is deleted from disk.
reports_dir = 'data/tcga_crc/reports/'
for report in os.listdir(reports_dir):
    # build the path with os.path.join so it does not depend on a trailing slash
    report_path = os.path.join(reports_dir, report)
    # skip sub-directories and other non-file entries, which cannot be read as reports
    if not os.path.isfile(report_path):
        continue
    # read the whole report and close the handle before deciding what to do with it
    with open(report_path, 'r') as f:
        report_text = f.read()
    if len(report_text) < 20:
        print(f'invalid report: {report}\ncontent: {report_text}\ndeleting...\n')
        # delete only after the file has been closed: removing a file that is
        # still open raises an error on Windows
        os.remove(report_path)

In [None]:
%%script false --no-raise-error
# summarize the reports using gpt-3
# (reports are taken from reports_dir; reports_sum_dir is presumably where the summaries go)
summarize_reports(
    reports_dir='data/tcga_crc/reports',
    reports_sum_dir='data/tcga_crc/reports_sum',
)

### Extract text feats from reports

In [3]:
# %%script false --no-raise-error
# extract text features from the (summarized) reports
# (report_feats_dir is presumably where the extracted features are stored)
extract_text_feats(
    reports_dir='data/tcga_crc/reports_sum',
    report_feats_dir='data/tcga_crc/report_feats',
)

100%|██████████| 607/607 [04:33<00:00,  2.22it/s]


### Annotate subtype & grade from pathology reports

In [None]:
%%script false --no-raise-error
# language model queried for the annotation
lm_name = 'gpt-3.5-turbo'
# directory holding the reports to annotate
reports_dir = 'data/reports_distilled'
# args for generation
gen_args = dict(max_tokens=200)

# single-report path, kept disabled (build one prompt and generate for it):
# sample_report_path = 'data/reports_distilled/TCGA-WT-AB41.txt'
# prompt = create_zs_prompt(sample_report_path)
# out = gen_subtype_grade_zs(lm_name, prompt, api='openai', args=gen_args)

# classify all reports in reports_dir through the openai api, then preview the first 10 entries
df_res = classify_reports_zs(lm_name, reports_dir, api='openai', args=gen_args)
df_res.head(10)

## Task: Predict target from WSIs & reports

### Set params

In [10]:
# --- data locations ---
data_file = 'data/tcga_crc/tcga_crc_gdc_pca.csv'
wsi_feats_dir = 'data/tcga_crc/wsi_feats'
report_feats_dir = 'data/tcga_crc/report_feats'

# --- prediction task ---
# name of the target to predict
target = 'msi'
# number of classes for classification
num_classes = 2
# input modalities: 'text', 'img', or 'mm'
mode = 'img'

# --- data loading ---
# dataset split: 'def' or 'rand'
split = 'def'
# batch size for dataloaders
bsz = 64
# resample data to balance classes
resample = True

# --- loss & evaluation ---
# class weights for loss function
class_weights = [1, 1]
# metrics to track
metrics = [
    'acc',
    'bal_acc',
    'f1',
    'auroc',
    'acc_per_class',
    'f1_per_class',
]

### Single-fold

#### Load data

In [4]:
# build the train / val / test dataloaders from the WSI and report features;
# random splits are used only when split == 'rand'
train_loader, val_loader, test_loader = create_dataloaders(
    target,
    data_file,
    wsi_feats_dir,
    report_feats_dir,
    use_rand_splits=(split == 'rand'),
    bsz=bsz,
    resample=resample,
)

size of train set: 376, val set: 53, test set: 109
# samples for each class in train set: (array([0, 1]), array([326,  50]))
# samples for each class in val set: (array([0, 1]), array([46,  7]))
# samples for each class in test set: (array([0, 1]), array([94, 15]))


#### Train & eval

In [7]:
# %%script false --no-raise-error
# build the model (model architecture: 'Attention1DRegressor' or 'Attention1DClassifier')
model = Attention1DClassifier(
    target=target,
    mode=mode,
    num_classes=num_classes,
    class_weights=class_weights,
    metrics=metrics,
)

# training configuration; checkpoint and tensorboard log names encode the modality and split
args = dict(
    num_epochs=100,
    ckpt_name=f'ckpt_best_{mode}_{split}_split',
    resume_ckpt=None,
    tblog_name=f'best_{mode}_{split}_split',
)

# fit on the train loader, validating on the val loader; returns the model and its trainer
model, trainer = train_mm(model, train_loader, val_loader, args)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type               | Params
-------------------------------------------------------
0 | attention       | Sequential         | 262 K 
1 | classifier      | Sequential         | 4.1 K 
2 | loss            | CrossEntropyLoss   | 0     
3 | acc             | MulticlassAccuracy | 0     
4 | bal_acc         | MulticlassAccuracy | 0     
5 | f1              | MulticlassF1Score  | 0     
6 | auroc           | MulticlassAUROC    | 0     
7 | acc_per_class   | MulticlassAccuracy | 0     
8 | f1_per_class    | MulticlassF1Score  | 0     
9 | auroc_per_class | MulticlassAUROC    | 0     
-------------------------------------------------------
266 K     Trainable params
0         Non-trainable params
266 K     Total params
1.066     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved. New best score: 0.705


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 0.700


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.006 >= min_delta = 0.0. New best score: 0.694


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.011 >= min_delta = 0.0. New best score: 0.683


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.020 >= min_delta = 0.0. New best score: 0.663


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.029 >= min_delta = 0.0. New best score: 0.634


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 0.634


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.014 >= min_delta = 0.0. New best score: 0.620


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.020 >= min_delta = 0.0. New best score: 0.600


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.008 >= min_delta = 0.0. New best score: 0.593


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.011 >= min_delta = 0.0. New best score: 0.582


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.031 >= min_delta = 0.0. New best score: 0.551


Validation: 0it [00:00, ?it/s]

Metric val_loss improved by 0.030 >= min_delta = 0.0. New best score: 0.520


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Monitored metric val_loss did not improve in the last 5 records. Best score: 0.520. Signaling Trainer to stop.


training on device: cpu


In [8]:
# evaluate the trained model on the test set
# (model and trainer are the objects returned by train_mm; the call is the last
#  expression of the cell, so the returned test metrics are displayed as its output)
trainer.test(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

[{'test_acc_epoch': 0.8073394298553467,
  'test_f1_epoch': 0.6869539022445679,
  'test_bal_acc_epoch': 0.7499372959136963,
  'test_auroc_epoch': 0.7396705150604248,
  'test_acc_class_0_epoch': 0.8296400904655457,
  'test_acc_class_1_epoch': 0.6702344417572021,
  'test_f1_class_0_epoch': 0.8810179829597473,
  'test_f1_class_1_epoch': 0.4928898811340332,
  'test_auroc_class_0_epoch': 0.7396705150604248,
  'test_auroc_class_1_epoch': 0.7396705150604248,
  'test_loss_epoch': 0.5155966877937317}]

### K-fold CV

In [15]:
# %%script false --no-raise-error
# dataset combining the target labels with the WSI and report features
dataset = MMDataset(target, data_file, wsi_feats_dir, report_feats_dir)

# model class to instantiate and the keyword args it is built with
model_class = Attention1DClassifier
model_args = dict(
    mode=mode,
    target=target,
    num_classes=num_classes,
    class_weights=class_weights,
    metrics=metrics,
)

# training / cross-validation settings
train_args = dict(
    bsz=bsz,
    k=5,
    resample=resample,
    num_epochs=100,
    patience=5,
    save_top_k=0,
    tblog_name=f'best_{mode}_kfold',
    enable_progress_bar=False,
)

# run k-fold CV
res_kfold_cv = kfold_cv(model_class, dataset, model_args, train_args)

# report the per-metric average over the folds
print(f"avg res over {train_args['k']} folds:")
for metric, value in res_kfold_cv['avg'].items():
    print(f"{metric.replace('_epoch', '')}: {value:.3f}")

# per-metric variance over the folds; every entry of res_kfold_cv except one
# (the 'avg' entry read above) is indexed as a fold: 0, 1, ...
num_fold_entries = len(res_kfold_cv) - 1
res_var = {}
for metric in res_kfold_cv[0].keys():
    fold_values = [res_kfold_cv[i][metric] for i in range(num_fold_entries)]
    res_var[metric] = np.var(fold_values)

print(f"variance over {train_args['k']} folds:")
for metric, value in res_var.items():
    print(f"{metric.replace('_epoch', '')}: {value:.3f}")

avg res over 5 folds:
test_acc: 0.795
test_f1: 0.672
test_bal_acc: 0.739
test_auroc: 0.783
test_acc_class_0: 0.815
test_acc_class_1: 0.662
test_f1_class_0: 0.871
test_f1_class_1: 0.473
test_loss: 0.490
