# Multimodal learning from WSIs & pathology reports

## Setup

### Imports

In [1]:
# Imports
import os
import numpy as np
import pandas as pd
import random
from tqdm.notebook import tqdm
import torch
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
from src.utils import *
from src.models import MLPSTILClassifier, Attention1DSTILClassifier, Attention2DSTILClassifier, MLPSTILRegressor, Attention1DSTILRegressor, Attention2DSTILRegressor
from src.train import train_mm_stil, kfold_cv

%load_ext autoreload
%autoreload 2

### Set seed & device

In [2]:
# Fix the random seed through the project's `set_seed` helper, then choose the compute device.
set_seed(42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Extract text features from pathology reports

In [3]:
%%script false --no-raise-error
# Disabled cell: the `%%script false` magic above hands the body to `false`, so nothing runs.
# Remove that magic line to re-extract the report features.
# Load the BioBERT model & tokenizer (same checkpoint for both).
biobert_name = 'dmis-lab/biobert-large-cased-v1.1-mnli'
tokenizer = AutoTokenizer.from_pretrained(biobert_name)
lm = AutoModel.from_pretrained(biobert_name)
data_dir = '/mnt/disks/ext/data/gdc/tcga/brca'
# NOTE(review): output_dir is never passed to extract_text_features below — confirm where
# the helper writes its features and whether it should receive this path.
output_dir = 'data/report_feats'
extract_text_features(lm, tokenizer, data_dir)

## Task: Predict sTILs from WSIs & reports

### Set hparams

In [3]:
bsz = 16                 # mini-batch size shared by the train/val/test DataLoaders
img_channels_in = 2048   # embedding dimension of the WSI features
text_channels_in = 1024  # embedding dimension of the report features
hidden_dim_mlp = 16      # width of the MLP hidden layer
num_classes = 10         # how many sTIL levels there are

### Load data

In [4]:
# Build the train/val/test splits from the sTIL CSV.
root_dir = './'
data_file = 'data/stils/data_stils.csv'
use_rand_splits = False  # True -> random splits, False -> predefined splits
train_data, val_data, test_data = [
    MMDataset(root_dir, data_file, split, use_rand_splits)
    for split in ('train', 'val', 'test')
]

print(f'size of train set: {len(train_data)}, val set: {len(val_data)}, test set: {len(test_data)}')

# Wrap each split in a DataLoader with shared settings; only the training loader shuffles.
loader_kwargs = dict(batch_size=bsz, num_workers=12, collate_fn=MMDataset.mm_collate_fn)
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_data, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

size of train set: 468, val set: 58, test set: 170


### Train & eval

In [None]:
# %%script false --no-raise-error
# Choose the architecture; any of MLPSTILClassifier, MLPSTILRegressor, Attention1DSTILClassifier,
# Attention1DSTILRegressor, Attention2DSTILClassifier or Attention2DSTILRegressor can be used here.
model = Attention1DSTILRegressor(mode='image')

# Training configuration handed to train_mm_stil.
num_epochs = 50
resume_ckpt = None  # NOTE(review): never forwarded via `args` — wire it in or drop it.
args = dict(num_epochs=num_epochs, ckpt_name='ckpt_best_img')

# Train with the train/val loaders, then evaluate the returned model on the test loader.
model, trainer = train_mm_stil(model, train_loader, val_loader, args)
trainer.test(model, test_loader)

### K-fold cross-validation

In [None]:
# %%script false --no-raise-error
# Run k-fold cross-validation of Attention1DSTILRegressor (mode='image').
# NOTE(review): MMDataset is built here without the root_dir/data_file/split arguments used in
# "Load data" — confirm its defaults select the intended data.
dataset = MMDataset()

model_class = Attention1DSTILRegressor
model_args = dict(mode='image')
train_args = dict(k=5, num_epochs=50, ckpt_name='ckpt_best_img_kfold_cv')

res_kfold_cv = kfold_cv(model_class, dataset, model_args, train_args)

# Average each test metric across the per-fold result dicts, rounded to 3 decimals.
metrics = ['test_loss', 'test_corr', 'test_r2']
avg_res = {
    metric: np.mean([fold_res[metric] for fold_res in res_kfold_cv]).round(3)
    for metric in metrics
}
print(f"avg res over {train_args['k']} folds: {avg_res}")

## Task: Predict cancer subtypes from multimodal data

### Generate cancer subtype annotations from pathology reports (zero-shot)

In [11]:
# Zero-shot subtype annotation: score overlapping report chunks against candidate labels
# with an MNLI-finetuned BioBERT, then pick the best label per category.
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

# Load the pretrained BioBERT-MNLI model & tokenizer
tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-large-cased-v1.1-mnli')
lm = AutoModelForSequenceClassification.from_pretrained('dmis-lab/biobert-large-cased-v1.1-mnli')
# NOTE(review): the pipeline warns that it cannot find an 'entailment' entry in this model's
# label2id and falls back to id -1 (the last logit). Confirm the checkpoint's label order and set
# lm.config.label2id / id2label accordingly, otherwise the scores may not reflect entailment.

# Candidate labels for each category
# NOTE(review): 'metastatic' appears in both label sets — confirm this is intended.
region_labels = ['ductal', 'lobular', 'metastatic']
localization_labels = ['in-situ', 'invasive', 'metastatic']

# Create a zero-shot classification pipeline
classifier = pipeline("zero-shot-classification", model=lm, tokenizer=tokenizer)

# For demonstration, classify a single report
sample_report_path = 'data/reports_processed/TCGA-3C-AALJ.265E5A9A-64FD-4B86-89BC-5E89F253C118.txt'
with open(sample_report_path, 'r', encoding='utf-8') as f:
    sample_report = f.read()

# Split the tokenized report into overlapping chunks that fit the 512-token input limit.
max_length = 512 - 2  # model input limit minus [CLS] and [SEP]
overlap = 100         # tokens shared by consecutive chunks
# Chunks stay well below max_length because the pipeline pairs each chunk with a hypothesis
# sentence and, per the warning it prints, does not truncate the combined input.
chunk_size = max_length - overlap - 2
stride = chunk_size - overlap
tokens = tokenizer.tokenize(sample_report)
# Only start a new chunk while it still reaches past the previous chunk's end (start + overlap).
# Otherwise the report tail would be re-classified as a redundant fragment and counted twice.
chunk_starts = range(0, max(len(tokens) - overlap, 1), stride) if tokens else range(0)
chunks = [tokens[i:i + chunk_size] for i in chunk_starts]


def _accumulate_scores(result, totals):
    """Add one zero-shot pipeline result's per-label scores into `totals` (mutated in place)."""
    for label, score in zip(result["labels"], result["scores"]):
        totals[label] += score


# Classify each chunk and sum every label's score over all chunks
region_scores = {label: 0 for label in region_labels}
localization_scores = {label: 0 for label in localization_labels}
for chunk in chunks:
    chunk_text = tokenizer.decode(tokenizer.convert_tokens_to_ids(chunk))
    region_result = classifier(chunk_text, region_labels)
    _accumulate_scores(region_result, region_scores)
    localization_result = classifier(chunk_text, localization_labels)
    _accumulate_scores(localization_result, localization_scores)

# The prediction for each category is the label with the highest aggregated score
predicted_region = max(region_scores, key=region_scores.get)
predicted_localization = max(localization_scores, key=localization_scores.get)

print("Predicted Region:", predicted_region)
print("Predicted Degree of Localization:", predicted_localization)

Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to -1. Define a descriptive label2id mapping in the model config to ensure correct outputs.
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Predicted Region: lobular
Predicted Degree of Localization: in-situ
