In [1]:
import numpy as np
import pandas as pd
import os
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler

import wandb
from tqdm import tqdm

from sklearn.metrics import roc_auc_score, accuracy_score
from scipy.special import softmax

from imblearn.over_sampling import RandomOverSampler


##helper functions
import import_ipynb
import ipynbname
nb_fname = ipynbname.name()


from models import ABMIL, DSMIL, MeanPoolingMIL, MeanPoolingMILREG, IClassifier, BClassifier
from datasets import WSIFeatDataset, WSIFeatDataset_Reg
from helperfunctions import strat_k_fold, sampler_strat_kfold
from datasets import WSIFeatDataset, WSIFeatDataset_Reg


In [None]:
import os
import torch
import wandb
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from sklearn.metrics import balanced_accuracy_score, roc_auc_score, accuracy_score

from scipy.special import softmax

from imblearn.over_sampling import RandomOverSampler


# ---------------------------------------------------------------------------
# Experiment configuration. Every name below is read by the training loop.
# ---------------------------------------------------------------------------

num_classes = 3   # one output per simplified molecular consensus subtype

epochs = 10       # training epochs per fold (also the cosine-annealing period)

# Dimensionality of one UNI2-h patch embedding. A case with 500 extracted
# patches therefore has an embedding of shape 500 x 1536, which the DataLoader
# (batch size 1) delivers as 1 x 500 x 1536.
in_dim = 1536

# MIL variants trained one after another by the training loop.
mil_types = ['multipoint_regressor', 'dsmil', 'abmil', 'meanpoolmil']

accumulate = True   # accumulate gradients over several bags before each optimizer step
test = True         # run validation after every training epoch

# Directory holding one .pt file per case; each file contains all patch
# embeddings of that case.
feature_dir = "/path/to/dir"

# Regression labels: a single column holding, per case, the three regression
# values (one per simplified molecular consensus subtype), e.g.
# [0.6176298319609687, 0.5754460199978302, 0.5890559926087964].
target_column = ["combined"]

# Categorical molecular subtype label used for evaluation.
class_column = "consensusClass"

df_274_reg = pd.read_excel("/path/to/df.xlsx")

df_train = df_274_reg

# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs top-to-bottom on a machine without a GPU instead of failing at
# `model.to(device)`.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Store results
cv_results = []   # filled by the training loop
runs = 10         # independent repetitions of the whole k-fold procedure
n_splits = 5      # folds per repetition


# Number of bags whose gradients are summed before one optimizer step
# (only used when `accumulate` is True).
ACCUM_STEPS = 4

# MIL variants this training loop knows how to build, train and evaluate.
SUPPORTED_MIL_TYPES = ('multipoint_regressor', 'dsmil', 'abmil', 'meanpoolmil')


def _build_fold_loaders(fold_df, fold_idx, mil_type):
    """Build the (train_loader, val_loader) pair for one cross-validation fold.

    Parameters
    ----------
    fold_df : pandas.DataFrame
        Fold table returned by ``sampler_strat_kfold``. The columns read here
        are "ID", "consensusClass" and ``split{fold_idx}``; rows where the
        split column equals True form the validation set, rows where it equals
        False the training set.
    fold_idx : int
        Index of the fold, used to select the ``split{fold_idx}`` column.
    mil_type : str
        One of ``SUPPORTED_MIL_TYPES``.

    Returns
    -------
    tuple of DataLoader
        Training loader (shuffled) and validation loader (not shuffled), both
        with batch size 1, i.e. one bag per step.
    """
    split_col = f"split{fold_idx}"

    if mil_type == 'multipoint_regressor':
        # Attach the per-class regression targets from df_train to the fold rows.
        merged = fold_df[["ID", split_col]].merge(df_train, on="ID", how="left")
        merged["combined"] = merged[["LumAll", "Ba/Sq", "Stroma-rich"]].values.tolist()

        # Train on the continuous targets, validate against the categorical label.
        train_data = merged[merged[split_col] == False].reset_index(drop=True)[["ID", "combined"]]
        val_data = merged[merged[split_col] == True].reset_index(drop=True)[["ID", "consensusClass"]]

        train_dataset = WSIFeatDataset_Reg(train_data, feature_dir, id_column="ID",
                                           target_column="combined", phase="train")
        val_dataset = WSIFeatDataset_Reg(val_data, feature_dir, id_column="ID",
                                         class_column="consensusClass", phase="val")
    elif mil_type in ('dsmil', 'abmil', 'meanpoolmil'):
        train_data = fold_df[fold_df[split_col] == False].reset_index(drop=True)[["ID", "consensusClass"]]
        val_data = fold_df[fold_df[split_col] == True].reset_index(drop=True)[["ID", "consensusClass"]]

        train_dataset = WSIFeatDataset(train_data, feature_dir=feature_dir, id_column="ID",
                                       label_column="consensusClass")
        val_dataset = WSIFeatDataset(val_data, feature_dir=feature_dir, id_column="ID",
                                     label_column="consensusClass")
    else:
        raise ValueError(f"Unsupported MIL type: {mil_type!r}")

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4)
    return train_loader, val_loader


def _build_model(mil_type):
    """Instantiate a fresh, untrained model for ``mil_type`` (still on the CPU)."""
    if mil_type == 'abmil':
        return ABMIL(in_dim, 512, 128, num_classes)
    if mil_type == 'dsmil':
        return DSMIL(IClassifier(in_dim, num_classes), BClassifier(in_dim, num_classes))
    if mil_type == 'meanpoolmil':
        return MeanPoolingMIL(in_dim, 768, num_classes)
    if mil_type == 'multipoint_regressor':
        return MeanPoolingMILREG(in_dim, 768, num_classes)
    raise ValueError(f"Unsupported MIL type: {mil_type!r}")


def _bag_loss(model, mil_type, features, label):
    """Run one bag through ``model`` and return its (unscaled) training loss.

    ``features`` is the bag with the DataLoader batch dimension already
    squeezed away; ``label`` is the target tensor on the same device.
    """
    if mil_type == 'abmil':
        return F.cross_entropy(model(features), label)
    if mil_type == 'dsmil':
        # DSMIL: average the bag-level loss and the loss of the max-pooled
        # instance-level predictions.
        classes, bag_prediction, _, _ = model(features)
        max_prediction, _ = torch.max(classes, 0, True)
        loss_bag = F.cross_entropy(bag_prediction, label)
        loss_max = F.cross_entropy(max_prediction.view(1, -1), label)
        return 0.5 * loss_bag + 0.5 * loss_max
    if mil_type == 'meanpoolmil':
        return F.cross_entropy(model(features.unsqueeze(0)), label)
    if mil_type == 'multipoint_regressor':
        # Regression on the continuous per-class targets (mean absolute error).
        return F.l1_loss(model(features.unsqueeze(0)), label)
    raise ValueError(f"Unsupported MIL type: {mil_type!r}")


def _bag_scores(model, mil_type, features):
    """Return the per-class scores of one bag as a NumPy array.

    Classifiers return softmax probabilities (DSMIL averages the max-pooled
    instance probabilities with the bag probabilities); the regressor returns
    its raw continuous outputs, flattened. Must be called under
    ``torch.no_grad()``.
    """
    if mil_type == 'abmil':
        scores = torch.softmax(model(features), dim=1)
    elif mil_type == 'dsmil':
        classes, bag_prediction, _, _ = model(features)
        max_prediction, _ = torch.max(classes, 0, True)
        scores = 0.5 * torch.softmax(max_prediction, dim=1) + 0.5 * torch.softmax(bag_prediction, dim=1)
    elif mil_type == 'meanpoolmil':
        scores = torch.softmax(model(features.unsqueeze(0)), dim=1)
    elif mil_type == 'multipoint_regressor':
        return model(features.unsqueeze(0)).detach().cpu().numpy().flatten()
    else:
        raise ValueError(f"Unsupported MIL type: {mil_type!r}")
    return scores.detach().cpu().numpy()


for mil_type in mil_types:
    # Reject unknown variants before a W&B run is opened, so no empty run is
    # created and left unfinished.
    if mil_type not in SUPPORTED_MIL_TYPES:
        print("Something went wrong with selected MIL type.")
        continue

    # Final-epoch validation metrics of every fold of every run for this MIL type.
    fold_accuracies = []
    fold_auc_scores = []

    # W&B group shared by all runs/folds of this MIL type.
    group_name = f"{os.path.basename(os.path.normpath(feature_dir))}_{mil_type}_ep:{epochs}"

    for z in range(runs):
        torch.cuda.empty_cache()

        # New unseeded stratified k-fold split with random oversampling ("ros") for every run.
        dfs = sampler_strat_kfold(
            strat_k_fold(df_train[["ID", "consensusClass"]], n_splits=n_splits, random_state=None),
            rs="ros", random_state=None, random_state_valid=None,
        )

        for i in range(n_splits):

            wandb.init(
                settings=wandb.Settings(start_method="thread"),
                project="tcga_blca_molecular_uni2",
                group=group_name,
                job_type=f"run{z}",
                config={"mil_type": mil_type, "epochs": epochs, "num_classes": num_classes, "runs": runs},
            )

            print(f"Training {mil_type}, Fold {i+1}/{n_splits}")

            train_loader, val_loader = _build_fold_loaders(dfs[i], i, mil_type)

            # Fresh model, optimizer and schedule for every fold.
            model = _build_model(mil_type).to(device)
            optimizer = torch.optim.Adam(model.parameters(), 5e-4, weight_decay=5e-4)
            scheduler = lr_scheduler.CosineAnnealingLR(optimizer, epochs, 5e-5)

            for epoch in range(1, epochs + 1):
                loss_sum = 0.0
                n = 0

                loop = tqdm(train_loader, total=len(train_loader), desc=f'Train [{epoch}/{epochs}]')
                model.train()

                for _, features, label in loop:
                    label = label.to(device)
                    features = features.squeeze(0).to(device)

                    loss = _bag_loss(model, mil_type, features, label)

                    if accumulate:
                        # Scale only the value that is back-propagated, so the
                        # logged loss stays comparable with accumulate=False.
                        # Each bag builds its own graph, so nothing has to be
                        # retained across backward() calls.
                        (loss / ACCUM_STEPS).backward()
                        # Step every ACCUM_STEPS bags and on the last bag of the epoch.
                        if (n + 1) % ACCUM_STEPS == 0 or (n + 1) == len(train_loader):
                            optimizer.step()
                            optimizer.zero_grad()
                    else:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    n += 1
                    loss_sum += loss.item()

                    loop.set_postfix(loss=loss.item(), loss_mean=loss_sum / n)

                wandb.log({"epoch": epoch, "train_loss": loss_sum / n})

                if test:
                    y_true, y_pred = [], []
                    model.eval()

                    with torch.no_grad():
                        loop = tqdm(val_loader, total=len(val_loader), desc=f'Val [{epoch}/{epochs}]')

                        for _, features, label in loop:
                            features = features.squeeze(0).to(device)

                            y_pred.append(_bag_scores(model, mil_type, features))  # per-class scores
                            y_true.append(label.item())                            # true categorical label

                    y_pred_np = np.vstack(y_pred)  # [num_samples, num_classes] scores
                    y_true_np = np.array(y_true)   # [num_samples] categorical labels

                    # The highest-scoring class is the categorical prediction.
                    y_class_pred_np = np.argmax(y_pred_np, axis=1)

                    # Compute classification metrics
                    test_accuracy = accuracy_score(y_true_np, y_class_pred_np)
                    bal_acc = balanced_accuracy_score(y_true_np, y_class_pred_np)

                    # Compute AUC (macro-averaged, one-vs-rest)
                    if mil_type == 'multipoint_regressor':
                        # NOTE(review): y_true_np holds categorical class labels, not the
                        # continuous regression targets, so these two values compare
                        # regression outputs against class indices. Computation left
                        # unchanged; confirm whether the validation dataset can return the
                        # regression targets before relying on these numbers.
                        mae = np.mean(np.abs(y_pred_np - y_true_np.reshape(-1, 1)))
                        mse = np.mean((y_pred_np - y_true_np.reshape(-1, 1)) ** 2)

                        print(f"Validation MAE: {mae:.4f}")
                        print(f"Validation MSE: {mse:.4f}")

                        # Raw regression outputs are turned into row-normalised scores
                        # with a softmax before computing the AUC.
                        test_auc = roc_auc_score(y_true_np, softmax(y_pred_np, axis=1), multi_class="ovr", average="macro")
                    else:
                        test_auc = roc_auc_score(y_true_np, y_pred_np, multi_class="ovr", average="macro")

                    print(f"Validation Accuracy: {test_accuracy:.4f}")
                    print(f"Validation Balanced Accuracy: {bal_acc:.4f}")
                    print(f"Validation AUC (Macro-Averaged): {test_auc:.4f}")

                    wandb.log({"epoch": epoch, "acc_agg": test_accuracy, "auc_macro_agg": test_auc, "bal_acc": bal_acc})

                scheduler.step()

            if test:
                # Keep the final-epoch validation metrics of this fold.
                fold_accuracies.append(test_accuracy)
                fold_auc_scores.append(test_auc)

            # Attach the notebook itself to the very last run of this MIL type.
            if z >= runs - 1 and i >= n_splits - 1:
                print("logging code:")
                code_artifact = wandb.Artifact(type="code", name=f"{mil_type}_ep_{epochs}")
                code_artifact.add_file(f"./{nb_fname}.ipynb")
                wandb.log_artifact(code_artifact)

            wandb.finish()

    # Summarise all folds of all runs for this MIL type (skipped when validation is off).
    if fold_accuracies:
        cv_results.append({
            "mil_type": mil_type,
            "n_folds": len(fold_accuracies),
            "mean_acc": float(np.mean(fold_accuracies)),
            "std_acc": float(np.std(fold_accuracies)),
            "mean_auc": float(np.mean(fold_auc_scores)),
            "std_auc": float(np.std(fold_auc_scores)),
        })
