# Multimodal

## Setup

### Imports

In [2]:
# Imports
import os
import numpy as np
import pandas as pd
import random
from tqdm.notebook import tqdm
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import torchmetrics
from src.utils import *
from src.models import MLPClassifier, WeakSTILClassifier
from src.train import train_mm, evaluate

%load_ext autoreload
%autoreload 2

### Set seed & device

In [2]:
# Seed the run via the project's set_seed helper (star-imported from src.utils;
# presumably seeds python/numpy/torch RNGs — confirm), then choose the device.
set_seed(42)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Extract text features from reports

In [3]:
# Load the BioBERT language model and its matching tokenizer; both are resolved
# from the same checkpoint name so they cannot drift apart.
biobert_name = 'dmis-lab/biobert-large-cased-v1.1-mnli'
tokenizer = AutoTokenizer.from_pretrained(biobert_name)
lm = AutoModel.from_pretrained(biobert_name)

# Input and output locations for the report feature-extraction step.
data_dir = '/mnt/disks/ext/data/gdc/tcga/brca'
# NOTE(review): output_dir is not referenced anywhere else in this notebook — confirm it is still needed.
output_dir = 'data/report_feats'

# Extraction step, currently disabled (presumably the features were already
# written by an earlier run — confirm before re-enabling):
# extract_text_features(lm, tokenizer, data_dir)

### Load data

In [20]:
# Build one MMDataset per split, in train / val / test order.
train_data, val_data, test_data = (
    MMDataset(split=split_name) for split_name in ('train', 'val', 'test')
)

# Wrap each split in a DataLoader. Only the training split is shuffled;
# batch size and worker count are shared by all three loaders.
# NOTE(review): the saved output of the shape-check cell shows a leading
# dimension of 16, which does not match bsz = 1 — confirm the intended batch size.
bsz = 1
loader_kwargs = dict(batch_size=bsz, num_workers=4)
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_data, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

In [6]:
# Sanity check: draw a single batch from the training loader and print the
# shapes of its first two elements (features at index 0, labels at index 1).
sample_batch = next(iter(train_loader))
batch_feats = sample_batch[0]
print(f'shape of mm feats: {batch_feats.shape}')
batch_labels = sample_batch[1]
print(f'shape of mm labels: {batch_labels.shape}')

shape of mm feats: torch.Size([16, 3072])
shape of mm labels: torch.Size([16])


## Train & eval

### MLP classifier

In [None]:
# Hyperparameters for the MLP baseline.
# input_dim matches the feature width printed by the shape-check cell (3072);
# presumably 2048 image + 1024 text dims, as in the MIL cell — confirm.
input_dim, hidden_dim, num_classes = 3072, 32, 10
num_epochs = 20

# Instantiate the classifier, move it to the selected device, and fit it with
# the project's train_mm helper, which returns the trained model.
mlp_model = MLPClassifier(input_dim, hidden_dim, num_classes)
mlp_model = mlp_model.to(device)
mlp_model = train_mm(mlp_model, train_loader, val_loader, num_epochs, device)

In [12]:
# Score the trained MLP on the held-out test split; the result is unpacked
# as a (loss, accuracy) pair.
mlp_test_metrics = evaluate(mlp_model, test_loader, device)
test_loss, test_acc = mlp_test_metrics

test Loss: 1.305, test Acc: 0.648


### MIL classifier

In [57]:
# Feature widths and class count passed to the MIL classifier.
img_channels_in, text_channels_in = 2048, 1024  # image / text feature dimensions
num_classes = 10

# Optimizer settings; learning rate and weight decay share the same value.
learning_rate = weight_decay = 5e-4

# Build the Lightning module from the settings above.
model = WeakSTILClassifier(
    img_channels_in,
    text_channels_in,
    num_classes,
    learning_rate,
    weight_decay,
)

# Fit for 10 epochs with a default-configured Lightning trainer, validating on
# the validation loader.
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, train_loader, val_loader)

# Final expression of the cell: run the test loop so its metrics are displayed.
# NOTE(review): the saved output reports both 'test_acc' and 'test_acc_epoch'
# with different values — check how WeakSTILClassifier logs accuracy.
trainer.test(model, test_loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type               | Params
--------------------------------------------------
0 | attention  | Sequential         | 262 K 
1 | classifier | Sequential         | 30.7 K
2 | loss       | CrossEntropyLoss   | 0     
3 | accuracy   | MulticlassAccuracy | 0     
--------------------------------------------------
293 K     Trainable params
0         Non-trainable params
293 K     Total params
1.173     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

[{'test_loss': 1.8130868673324585,
  'test_acc': 0.6481481194496155,
  'test_acc_epoch': 0.5108919143676758}]