# Multimodal

## Setup

### Imports

In [2]:
# Imports
import os
import numpy as np
import pandas as pd
import random
from tqdm.notebook import tqdm
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import torchmetrics
from src.utils import *
from src.models import MLPClassifier, Attention1DSTILClassifier, Attention2DSTILClassifier
from src.train import train_mm, evaluate

%load_ext autoreload
%autoreload 2

### Set seed & device

In [3]:
set_seed(42)      
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Extract text feats from reports

In [3]:
# load biobert model & tokenizer
# tokenizer = AutoTokenizer.from_pretrained('dmis-lab/biobert-large-cased-v1.1-mnli')
# lm = AutoModel.from_pretrained('dmis-lab/biobert-large-cased-v1.1-mnli')
# data_dir = '/mnt/disks/ext/data/gdc/tcga/brca'
# output_dir = 'data/report_feats'
# extract_text_features(lm, tokenizer, data_dir)

## Task: Predict sTILs from WSIs & reports

### Load data

In [26]:
# Create the dataset
root_dir = './'
data_file = 'data/stils/data_stils.csv'
train_data = MMDataset(root_dir, data_file, 'train')
val_data = MMDataset(root_dir, data_file, 'val')
test_data = MMDataset(root_dir, data_file, 'test')

# Create the dataloaders
bsz = 16
train_loader = DataLoader(train_data, batch_size=bsz, shuffle=True, num_workers=4, collate_fn=MMDataset.mm_collate_fn)
val_loader = DataLoader(val_data, batch_size=bsz, shuffle=False, num_workers=4, collate_fn=MMDataset.mm_collate_fn)
test_loader = DataLoader(test_data, batch_size=bsz, shuffle=False, num_workers=4, collate_fn=MMDataset.mm_collate_fn)

### Train & eval

#### MLP classifier

In [12]:
%%script false --no-raise-error
# train model
input_dim = 3072
hidden_dim = 32
num_classes = 10
num_epochs = 20
mlp_model = MLPClassifier(input_dim, hidden_dim, num_classes).to(device)
mlp_model = train_mm(mlp_model, train_loader, val_loader, num_epochs, device)
# evaluate the trained model on the test set
test_loss, test_acc = evaluate(mlp_model, test_loader, device)

#### MIL classifier w 1D attention

In [None]:
# Initialize the model
img_channels_in = 2048  
text_channels_in = 1024 
num_classes = 10  

model = Attention1DSTILClassifier(img_channels_in, text_channels_in, num_classes)

# Initialize the trainer
trainer = pl.Trainer(max_epochs=10)  

# Train the model
trainer.fit(model, train_loader, val_loader)

# Evaluate the model
trainer.test(model, test_loader)

#### MIL classifier w 2D attention

In [37]:
# Initialize the model
model = Attention2DSTILClassifier(img_channels_in, text_channels_in, num_classes)

# Initialize the trainer
trainer = pl.Trainer(max_epochs=10)  

# Train the model
trainer.fit(model, train_loader, val_loader)

# Evaluate the model
trainer.test(model, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type                    | Params
-------------------------------------------------------
0 | attention  | EncoderDecoderAttention | 1.3 M 
1 | maps       | Sequential              | 18.4 K
2 | classifier | Sequential              | 30.7 K
3 | loss       | CrossEntropyLoss        | 0     
4 | accuracy   | MulticlassAccuracy      | 0     
-------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.287     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

[{'test_loss': 1.887076497077942,
  'test_acc': 0.5740740895271301,
  'test_acc_epoch': 0.5147058963775635}]