# Import things

In [None]:
!git clone https://github.com/Omid-Nejati/MedViT

In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim

import torchvision.utils
from torchvision import models
import torchvision.datasets as dsets
import torchvision.transforms as transforms

In [None]:
!pip install timm
!pip install einops

In [None]:
package_path = "/kaggle/input/medvit-for-brain-tumor/MedViT"
import sys 
sys.path.append(package_path)

In [None]:
from MedViT import MedViT_base

In [None]:
!pip install config

In [None]:
MedViT_base

In [None]:
package_path = "../input/efficientnet-pytorch/EfficientNet-PyTorch/EfficientNet-PyTorch-master/"
import sys 
sys.path.append(package_path)

import os
import glob
import time
import random

import numpy as np
import pandas as pd

import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import cv2
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils import data as torch_data
from torch.nn import functional as F
from torch.utils.data import Dataset
from torch.utils.data import Dataset, DataLoader
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut

import efficientnet_pytorch

from sklearn.model_selection import StratifiedKFold

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Global seed used for every random number generator below.
seed = 123


def seed_everything(seed):
    """Seed all random number generators used by this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device)
    and puts cuDNN into deterministic mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker
    # subprocesses); the hash seed of the running process is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode auto-tunes convolution algorithms per input shape and can
    # pick different kernels between runs, which defeats the deterministic
    # flag above, so it has to be off for reproducible results.
    torch.backends.cudnn.benchmark = False


seed_everything(seed)

class CFG:
    """Hyper-parameters shared by the model, the data pipeline and training.

    Each name is defined exactly once. Earlier revisions assigned
    ``cnn_features``, ``n_heads`` and ``proj_dim`` twice, and only the last
    assignment took effect; the values below are those effective values.
    """

    # Not referenced by the code visible in this notebook; kept so that any
    # external reader of CFG.lstm_hidden keeps working.
    lstm_hidden = 32
    # Number of StratifiedKFold splits.
    n_fold = 4
    # Maximum number of epochs per fold.
    n_epochs = 20
    # Side length (pixels) every slice is resized to.
    img_size = 256
    # Number of slices every volume is padded / subsampled to.
    n_frames = 40
    # Output width of each transformer branch and input width of the embedding.
    cnn_features = 512
    # Attention heads per transformer encoder layer.
    n_heads = 16
    # Output width of the per-branch embedding.
    proj_dim = 128
    # Samples per DataLoader batch.
    batch_size = 8

# Model

In [None]:
def load_medvit_weights(model, weight_path):
    """Load weights from ``weight_path`` into ``model`` (non-strict).

    The file is always deserialized onto the CPU so that weights saved on a
    GPU machine can be loaded on a CPU-only host; ``load_state_dict`` then
    copies the values onto whatever device the model's parameters live on.

    Checkpoints that wrap the weights in an outer dict (for example the
    ``"model_state_dict"`` entry written by ``Trainer.save_model``) are
    unwrapped first. Without that, ``strict=False`` would silently match no
    key at all and leave the model untouched.

    Args:
        model: ``nn.Module`` to receive the weights.
        weight_path: Path of a file written with ``torch.save``.

    Returns:
        The same ``model`` instance, for call chaining.
    """
    checkpoint = torch.load(weight_path, map_location="cpu")

    if isinstance(checkpoint, dict):
        for wrapper_key in ("model_state_dict", "state_dict", "model"):
            nested = checkpoint.get(wrapper_key)
            if isinstance(nested, dict):
                checkpoint = nested
                break

    result = model.load_state_dict(checkpoint, strict=False)
    # Non-strict loading ignores mismatches, so surface them instead of
    # letting a partially initialized model go unnoticed.
    if result.missing_keys or result.unexpected_keys:
        print(
            f"load_medvit_weights: {len(result.missing_keys)} missing and "
            f"{len(result.unexpected_keys)} unexpected keys in '{weight_path}'"
        )
    return model


In [None]:
class MedViT3D(nn.Module):
    """Patch-based transformer over a 3-D multi-channel volume.

    The volume is cut into non-overlapping cubic patches by a strided
    ``nn.Conv3d``, the patch embeddings are passed through a stack of
    ``nn.TransformerEncoder`` layers, averaged over the patches and projected
    to ``num_classes`` outputs. Despite its name, this class does not use the
    imported MedViT package.

    Args:
        num_classes: Width of the output vector.
        patch_size: Edge length of the cubic patches (kernel and stride).
        in_channels: Channels of the input volume (default 4).
        hidden_dim: Embedding width of every patch token (default 768).
        num_heads: Attention heads; ``None`` (default) reads ``CFG.n_heads``.
        num_layers: Number of encoder layers (default 2).
        dropout_rate: Dropout used in the encoder and before the head.
    """

    def __init__(self, num_classes, patch_size, in_channels=4, hidden_dim=768,
                 num_heads=None, num_layers=2, dropout_rate=0.001):
        super().__init__()
        self.num_classes = num_classes
        self.patch_size = patch_size
        self.hidden_dim = hidden_dim
        self.num_heads = CFG.n_heads if num_heads is None else num_heads
        self.dropout_rate = dropout_rate

        self.patch_embeddings = nn.Conv3d(
            in_channels=in_channels,
            out_channels=self.hidden_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

        # forward() feeds the encoder (batch, tokens, hidden_dim). Without
        # batch_first=True the encoder reads that as (tokens, batch, dim) and
        # attends across the *samples of the batch*, so every prediction
        # would depend on which other samples share its batch.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.hidden_dim,
            nhead=self.num_heads,
            dropout=self.dropout_rate,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fc = nn.Linear(self.hidden_dim, self.num_classes)
        self.dropout = nn.Dropout(self.dropout_rate)

    def forward(self, x):
        """Map a ``(batch, in_channels, d1, d2, d3)`` volume to ``(batch, num_classes)``."""
        x = self.patch_embeddings(x)     # (batch, hidden_dim, p1, p2, p3)
        x = x.flatten(2)                 # (batch, hidden_dim, tokens)
        x = x.transpose(1, 2)            # (batch, tokens, hidden_dim)
        x = self.transformer_encoder(x)  # attention across tokens of one sample
        x = x.mean(dim=1)                # average over tokens
        x = self.fc(self.dropout(x))
        return x

class MedViTModel(nn.Module):
    """Ensemble of four ``MedViT3D`` branches run on the same input.

    The branches alternate between patch sizes 16 and 32 and each emits a
    ``CFG.cnn_features``-wide vector; ``forward`` stacks the four vectors
    along a new dimension 1.
    """

    def __init__(self):
        super().__init__()
        # Convert 4 channels to 3 channels.
        # NOTE(review): forward() never calls this layer; it is kept so the
        # module's parameters and state_dict keys stay as they were.
        self.map = nn.Conv2d(in_channels=4, out_channels=3, kernel_size=1)
        branch_patch_sizes = (16, 32, 16, 32)
        self.net = nn.ModuleList(
            [
                MedViT3D(num_classes=CFG.cnn_features, patch_size=size)
                for size in branch_patch_sizes
            ]
        )

    def forward(self, x):
        """Return the branch outputs stacked as ``(batch, 4, CFG.cnn_features)``."""
        branch_outputs = [branch(x) for branch in self.net]
        return torch.stack(branch_outputs, dim=1)

class SeparableEmbedding(nn.Module):
    """Linear projection followed by a ReLU.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Project ``x`` over its last dimension and clamp negatives to zero."""
        projected = self.fc(x)
        return F.relu(projected)

class Model(nn.Module):
    """Full classifier: four transformer branches -> embedding -> one logit."""

    def __init__(self):
        super().__init__()
        self.medvit = MedViTModel()
        self.embedding = SeparableEmbedding(CFG.cnn_features, CFG.proj_dim)
        # The four embedded branch vectors are concatenated before the head.
        self.fc = nn.Linear(CFG.proj_dim * 4, 1, bias=True)

    def forward(self, x):
        """Return one logit per sample, shape ``(batch, 1)``.

        ``x`` must be 5-dimensional: the unpacking below raises ``ValueError``
        otherwise. Only the leading (batch) size is used here.
        NOTE(review): the dataset in this notebook yields samples laid out as
        (channels, H, W, D), so x is presumably (batch, 4, H, W, D) - confirm.
        """
        batch_size, _, _, _, _ = x.size()
        branch_features = self.medvit(x)
        projected = self.embedding(branch_features)
        flattened = projected.view(batch_size, -1)
        return self.fc(flattened)


# Data Processing

In [None]:
def load_image(path):
    """Read one slice from disk as a ``float32`` image scaled to ``[0, 1]``.

    ``.png`` files are read as 8-bit grayscale and divided by 255. Every other
    file is treated as DICOM: the VOI LUT is applied and the result is
    min-max normalized (a constant image becomes all zeros).

    Args:
        path: File path of the slice.

    Returns:
        ``numpy.ndarray`` of shape ``(CFG.img_size, CFG.img_size)`` and dtype
        ``float32``. An all-zero image is returned when a PNG cannot be
        decoded or the DICOM read raises ``InvalidDicomError`` /
        ``AttributeError``.
    """
    size = (CFG.img_size, CFG.img_size)

    ext = os.path.splitext(path)[-1].lower()
    if ext == '.png':
        image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if image is None:
            # float32 so placeholder frames match the dtype of real frames.
            return np.zeros(size, dtype=np.float32)
        image = cv2.resize(image, size)
        return image.astype('float32') / 255

    try:
        dicom = pydicom.dcmread(path, force=True)
        image = apply_voi_lut(dicom.pixel_array, dicom)
        # Convert before doing arithmetic: DICOM pixels are usually 16-bit
        # integers, and `image - min` can overflow in the integer dtype when
        # the value range exceeds what that dtype can hold.
        image = np.asarray(image).astype(np.float32)
        image = cv2.resize(image, size)
    except (pydicom.errors.InvalidDicomError, AttributeError):
        print(f"Error reading DICOM file: {path}")
        return np.zeros(size, dtype=np.float32)

    image = image - np.min(image)
    peak = np.max(image)
    if peak > 0:
        image = image / peak
    return image.astype('float32')

def load_3d_image(dicom_paths):
    """Load a series of slices into a fixed-depth ``(H, W, D)`` volume.

    Slices are loaded with ``load_image`` in the order given and stacked
    along the last axis. The depth is then forced to ``CFG.n_frames``:
    shorter series are zero-padded at the end, longer ones are subsampled at
    evenly spaced indices.

    Args:
        dicom_paths: Iterable of slice file paths, already in slice order.

    Returns:
        ``float32`` ``numpy.ndarray`` of shape
        ``(CFG.img_size, CFG.img_size, CFG.n_frames)``; all zeros when
        ``dicom_paths`` is empty.
    """
    # The previous version wrapped list.append in a bare `except` that
    # returned None, a value no caller can handle; appending cannot fail.
    slices = [load_image(path) for path in dicom_paths]

    if not slices:
        return np.zeros((CFG.img_size, CFG.img_size, CFG.n_frames), dtype=np.float32)

    # One dtype regardless of whether placeholder frames were mixed in.
    volume = np.stack(slices, axis=-1).astype(np.float32, copy=False)
    depth = volume.shape[-1]

    if depth < CFG.n_frames:
        pad_width = CFG.n_frames - depth
        volume = np.pad(volume, ((0, 0), (0, 0), (0, pad_width)), mode='constant')
    elif depth > CFG.n_frames:
        indices = np.linspace(0, depth - 1, CFG.n_frames).astype(int)
        volume = volume[:, :, indices]
    return volume


def uniform_temporal_subsample(x, num_samples):
    """Pick ``num_samples`` items of ``x`` at evenly spaced positions.

    Positions are ``linspace(0, len(x) - 1, num_samples)`` truncated to
    integers, so items repeat when ``num_samples > len(x)``.

    Args:
        x: Indexable sequence with a length.
        num_samples: Number of items to return.

    Returns:
        List of ``num_samples`` items of ``x``; an empty list when ``x`` is
        empty (previously this raised ``IndexError``).
    """
    t = len(x)
    if t == 0:
        return []
    indices = torch.linspace(0, t - 1, num_samples)
    # Plain Python ints index any sequence type, not only tensor-aware ones.
    indices = torch.clamp(indices, 0, t - 1).long().tolist()
    return [x[i] for i in indices]

In [None]:
class DataRetriever(Dataset):
    """Dataset yielding one four-modality MRI volume and its label per patient.

    Args:
        paths: Sequence of patient ids; each id is zero-padded to 5 digits to
            form the patient directory name.
        targets: Sequence of labels aligned with ``paths``.
        transform: Optional callable invoked as
            ``transform(image=array)["image"]``. Earlier revisions stored it
            but never applied it in ``__getitem__``; it is now applied once to
            all four modalities stacked along the last axis, so that every
            modality receives the same random parameters.
            NOTE(review): assumes the transform accepts an (H, W, 4 * D)
            float32 array and returns an array of the same shape - confirm
            for the installed albumentations version.
        data_dir: Directory containing one sub-directory per patient. ``None``
            (default) uses ``DataRetriever.DATA_DIR``.
    """

    # Default dataset location (previously hard-coded inside __getitem__).
    DATA_DIR = "/kaggle/input/rsna-miccai-brain-tumor-radiogenomic-classification/train"
    # Sub-directory names, in the order they are stacked as channels.
    MRI_TYPES = ("FLAIR", "T1w", "T1wCE", "T2w")

    def __init__(self, paths, targets, transform=None, data_dir=None):
        self.paths = paths
        self.targets = targets
        self.transform = transform
        self.data_dir = self.DATA_DIR if data_dir is None else data_dir

    def __len__(self):
        return len(self.paths)

    @staticmethod
    def _slice_number(path):
        """Return the integer between the last '-' and the first '.' of the file name."""
        return int(os.path.basename(path).split("-")[-1].split(".")[0])

    def read_video(self, vid_paths):
        """Load one series as a ``(1, H, W, D)`` float32 tensor, transformed if configured."""
        volume = load_3d_image(vid_paths)
        if self.transform:
            volume = self.transform(image=volume)["image"]
        return torch.tensor(np.ascontiguousarray(volume), dtype=torch.float32).unsqueeze(0)

    def __getitem__(self, index):
        """Return ``{"X": (4, H, W, D) float32 tensor, "y": float32 scalar tensor}``."""
        _id = self.paths[index]
        patient_dir = os.path.join(self.data_dir, str(_id).zfill(5))

        volumes = []
        for mri_type in self.MRI_TYPES:
            slice_paths = sorted(
                glob.glob(os.path.join(patient_dir, mri_type, "*")),
                key=self._slice_number,
            )
            if not slice_paths:
                # load_3d_image substitutes an all-zero volume in this case.
                print(f"Empty channel detected for patient {_id}, type {mri_type}")
            volumes.append(load_3d_image(slice_paths))

        if self.transform:
            # Every volume has the same depth, so they can be concatenated,
            # transformed in a single call and split back evenly.
            stacked = np.concatenate(volumes, axis=-1).astype(np.float32, copy=False)
            stacked = self.transform(image=stacked)["image"]
            volumes = np.split(stacked, len(self.MRI_TYPES), axis=-1)

        channels = torch.stack(
            [
                # ascontiguousarray: split/flip results may be strided views.
                torch.tensor(np.ascontiguousarray(volume), dtype=torch.float32)
                for volume in volumes
            ]
        )  # (channels, H, W, D)
        y = torch.tensor(self.targets[index], dtype=torch.float32)
        return {"X": channels, "y": y}

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Data augmentation
# Training pipeline: each transform below fires independently with
# probability 0.5.
train_transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=10, p=0.5),
        A.RandomBrightnessContrast(p=0.5),
    ]
)
# Validation pipeline: contains no transforms.
valid_transform = A.Compose([])

In [None]:
# Load the competition label table. The cells further down read its
# "BraTS21ID" and "MGMT_value" columns.
df = pd.read_csv("../input/rsna-miccai-brain-tumor-radiogenomic-classification/train_labels.csv")
# Notebook display: preview the first 10 rows.
df.head(10)

# Training

In [None]:
class LossMeter:
    """Running arithmetic mean of the values passed to ``update``.

    Attributes are ``avg`` (current mean) and ``n`` (number of updates).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything seen so far."""
        self.avg = 0
        self.n = 0

    def update(self, val):
        """Fold one more value into the mean without storing past values."""
        count = self.n + 1
        # new_mean = val / count + (count - 1) / count * old_mean
        self.avg = val / count + (count - 1) / count * self.avg
        self.n = count

class AccMeter:
    """Running accuracy of logits against binary labels.

    A logit ``>= 0`` counts as a positive prediction. Attributes are ``avg``
    (accuracy so far) and ``n`` (number of samples seen).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything seen so far."""
        self.avg = 0
        self.n = 0

    def update(self, y_true, y_pred):
        """Fold one batch into the running accuracy.

        Args:
            y_true: Tensor of labels; cast to ``int`` before comparison.
            y_pred: Tensor of logits, thresholded at 0.
        """
        labels = y_true.cpu().numpy().astype(int)
        decisions = y_pred.detach().cpu().numpy() >= 0
        seen_before = self.n
        self.n = seen_before + len(labels)
        hits = np.sum(labels == decisions)
        # Weighted merge of this batch's hits with the accuracy so far.
        self.avg = hits / self.n + seen_before / self.n * self.avg



In [None]:
class Trainer:
    """Train/validate loop with mixed precision, checkpointing and early stopping.

    Args:
        model: Network to train; called as ``model(X)`` and expected to return
            a tensor whose dimension 1 can be squeezed.
        device: Device (``torch.device`` or string) the batches are moved to.
        optimizer: Optimizer over the model's parameters.
        criterion: Loss called as ``criterion(outputs, targets)``.
        loss_meter: Object with ``reset()``, ``update(value)`` and ``avg``.
        score_meter: Object with ``reset()``, ``update(targets, outputs)`` and
            ``avg``.
        accumulation_steps: Number of batches whose gradients are accumulated
            before one optimizer step.

    Raises:
        ValueError: If ``accumulation_steps`` is smaller than 1.
    """

    def __init__(self, model, device, optimizer, criterion, loss_meter, score_meter, accumulation_steps=1):
        if accumulation_steps < 1:
            raise ValueError("accumulation_steps must be >= 1")

        self.model = model
        self.device = device
        self.optimizer = optimizer
        self.criterion = criterion
        self.loss_meter = loss_meter
        self.score_meter = score_meter
        self.hist = {
            'val_loss': [],
            'val_score': [],
            'train_loss': [],
            'train_score': []
        }

        self.best_valid_score = -np.inf
        self.best_valid_loss = np.inf
        self.best_train_score = -np.inf
        self.n_patience = 0

        self.messages = {
            "epoch": "[Epoch {}: {}] loss: {:.9f}, score: {:.9f}, time: {} s",
            "checkpoint": "The score improved from {:.9f} to {:.9f}. Save model to '{}'",
            "patience": "\nValid score didn't improve last {} epochs."
        }
        self.accumulation_steps = accumulation_steps

        # Mixed precision only makes sense on CUDA; on CPU both autocast and
        # the scaler are created disabled and behave as pass-throughs.
        self.use_amp = torch.cuda.is_available() and torch.device(device).type == "cuda"
        # train_epoch() used self.scaler without it ever being created,
        # which raised AttributeError on the first training step.
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.use_amp)

        # Raw targets / outputs of every batch, appended across all epochs.
        self.train_targets = []
        self.train_preds = []
        self.valid_targets = []
        self.valid_preds = []

    def fit(self, epochs, train_loader, valid_loader, save_path, patience):
        """Run up to ``epochs`` epochs and checkpoint on validation-score gains.

        Args:
            epochs: Maximum number of epochs.
            train_loader: Iterable of ``{"X": ..., "y": ...}`` training batches.
            valid_loader: Iterable of validation batches, same layout.
            save_path: File the best checkpoint is written to.
            patience: Stop after this many consecutive epochs without a
                validation-score improvement.

        Returns:
            ``(best_valid_loss, best_valid_score)``, both taken from the epoch
            with the best validation score.
        """
        for n_epoch in range(1, epochs + 1):
            self.info_message("EPOCH: {}", n_epoch)

            train_loss, train_score, train_time = self.train_epoch(train_loader)
            valid_loss, valid_score, valid_time = self.valid_epoch(valid_loader)

            if self.best_train_score < train_score:
                self.best_train_score = train_score

            self.hist['val_loss'].append(valid_loss)
            self.hist['train_loss'].append(train_loss)
            self.hist['val_score'].append(valid_score)
            self.hist['train_score'].append(train_score)

            self.info_message(
                self.messages["epoch"], "Train", n_epoch, train_loss, train_score, train_time
            )

            self.info_message(
                self.messages["epoch"], "Valid", n_epoch, valid_loss, valid_score, valid_time
            )

            if self.best_valid_score < valid_score:
                self.info_message(
                    self.messages["checkpoint"], self.best_valid_score, valid_score, save_path
                )
                # Update before saving so the checkpoint records the new best.
                self.best_valid_score = valid_score
                self.best_valid_loss = valid_loss
                self.save_model(n_epoch, save_path)
                self.n_patience = 0
            else:
                self.n_patience += 1

            if self.n_patience >= patience:
                self.info_message(self.messages["patience"], patience)
                break

        return self.best_valid_loss, self.best_valid_score

    def _optimizer_step(self):
        """Apply the accumulated gradients and clear them."""
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad()

    def train_epoch(self, train_loader):
        """Train for one pass over ``train_loader``.

        Returns:
            ``(mean_loss, score, elapsed_seconds)`` read from the meters.
        """
        self.model.train()
        t = time.time()
        self.loss_meter.reset()
        self.score_meter.reset()

        self.optimizer.zero_grad()

        step = 0
        for step, batch in enumerate(train_loader, 1):
            X = batch["X"].to(self.device)
            targets = batch["y"].to(self.device)

            with torch.cuda.amp.autocast(enabled=self.use_amp):  # Mixed precision
                outputs = self.model(X).squeeze(1)
                # Divide so the accumulated gradient matches one large batch.
                loss = self.criterion(outputs, targets) / self.accumulation_steps

            self.scaler.scale(loss).backward()

            if step % self.accumulation_steps == 0:
                self._optimizer_step()

            # Undo the division so the meter tracks the un-scaled loss.
            self.loss_meter.update(loss.detach().item() * self.accumulation_steps)
            self.score_meter.update(targets, outputs)

            self.train_targets.extend(targets.cpu().numpy())
            self.train_preds.extend(outputs.detach().cpu().numpy())

            message = 'Train Step {}/{}, train_loss: {:.9f}, train_score: {:.9f}'
            self.info_message(
                message, step, len(train_loader), self.loss_meter.avg, self.score_meter.avg, end="\r"
            )

        # Gradients of a trailing, incomplete accumulation window were
        # previously discarded by the next epoch's zero_grad(); apply them.
        if step % self.accumulation_steps != 0:
            self._optimizer_step()

        torch.cuda.empty_cache()
        # Reading the meters directly also works for an empty loader.
        return self.loss_meter.avg, self.score_meter.avg, int(time.time() - t)

    def valid_epoch(self, valid_loader):
        """Evaluate on ``valid_loader`` without gradient tracking.

        Returns:
            ``(mean_loss, score, elapsed_seconds)`` read from the meters.
        """
        self.model.eval()
        t = time.time()
        self.loss_meter.reset()
        self.score_meter.reset()

        for step, batch in enumerate(valid_loader, 1):
            with torch.no_grad():
                X = batch["X"].to(self.device)
                targets = batch["y"].to(self.device)

                with torch.cuda.amp.autocast(enabled=self.use_amp):  # Mixed precision
                    outputs = self.model(X).squeeze(1)
                    loss = self.criterion(outputs, targets)

                self.loss_meter.update(loss.detach().item())
                self.score_meter.update(targets, outputs)

                self.valid_targets.extend(targets.cpu().numpy())
                self.valid_preds.extend(outputs.detach().cpu().numpy())

            message = 'Valid Step {}/{}, valid_loss: {:.9f}, valid_score: {:.9f}'
            self.info_message(
                message, step, len(valid_loader), self.loss_meter.avg, self.score_meter.avg, end="\r"
            )

        torch.cuda.empty_cache()
        return self.loss_meter.avg, self.score_meter.avg, int(time.time() - t)

    def plot_loss(self):
        """Plot the per-epoch training and validation loss."""
        plt.title("Loss")
        plt.xlabel("Training Epochs")
        plt.ylabel("Loss")

        plt.plot(self.hist['train_loss'], label="Train")
        plt.plot(self.hist['val_loss'], label="Validation")
        plt.legend()
        plt.show()

    def plot_score(self):
        """Plot the per-epoch training and validation score."""
        plt.title("Score")
        plt.xlabel("Training Epochs")
        plt.ylabel("Acc")

        plt.plot(self.hist['train_score'], label="Train")
        plt.plot(self.hist['val_score'], label="Validation")
        plt.legend()
        plt.show()

    def save_model(self, n_epoch, save_path):
        """Write model and optimizer state plus the best score to ``save_path``."""
        torch.save(
            {
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "best_valid_score": self.best_valid_score,
                "n_epoch": n_epoch,
            },
            save_path,
        )

    @staticmethod
    def info_message(message, *args, end="\n"):
        """Print ``message.format(*args)`` terminated by ``end``."""
        print(message.format(*args), end=end)


In [None]:
print(len(df))

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the labelled patients, stratified on MGMT_value, with a
# fixed random_state so the split is the same on every run.
# NOTE(review): the cross-validation loop further down splits train_df only;
# valid_df is not read by any code visible in this notebook - confirm whether
# it is meant to serve as a final test set.
train_df, valid_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['MGMT_value'])

print(f"Train size: {len(train_df)}")
print(f"Validation size: {len(valid_df)}")

In [None]:
# Re-join both splits into one frame with a fresh 0..n-1 index.
# NOTE(review): df_train_valid is not referenced by any later code visible
# in this notebook - confirm whether it is still needed.
df_train_valid = pd.concat([train_df, valid_df]).reset_index(drop=True)


In [None]:
import time
import numpy as np
from sklearn.model_selection import StratifiedKFold
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader


# Stratified K-fold cross-validation over train_df: every fold trains a fresh
# model and contributes its best validation loss / score to the averages.
skf = StratifiedKFold(n_splits=CFG.n_fold)

start_time = time.time()

losses = []
scores = []

for fold, (train_index, val_index) in enumerate(skf.split(train_df, train_df['MGMT_value']), 1):
    # This line used to be indented by 3 spaces while the rest of the loop
    # body used 4, which made the whole cell fail with an IndentationError.
    print('-' * 30)
    print(f"Fold {fold}")

    train_fold_df = train_df.iloc[train_index]
    val_fold_df = train_df.iloc[val_index]

    # Only the training retriever is given the augmentation pipeline.
    train_retriever = DataRetriever(
        train_fold_df["BraTS21ID"].values,
        train_fold_df["MGMT_value"].values,
        train_transform
    )

    val_retriever = DataRetriever(
        val_fold_df["BraTS21ID"].values,
        val_fold_df["MGMT_value"].values
    )

    train_loader = DataLoader(
        train_retriever,
        batch_size=CFG.batch_size,
        shuffle=True,
        num_workers=4,
    )
    valid_loader = DataLoader(
        val_retriever,
        batch_size=CFG.batch_size,
        shuffle=False,
        num_workers=4,
    )

    # Fresh model, optimizer and meters so folds do not share any state.
    model = Model()
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.0001)
    criterion = nn.BCEWithLogitsLoss()

    loss_meter = LossMeter()
    score_meter = AccMeter()

    trainer = Trainer(
        model,
        device,
        optimizer,
        criterion,
        loss_meter,
        score_meter
    )

    # A patience of 100 exceeds CFG.n_epochs (20), so early stopping cannot
    # trigger here and every fold runs for the full number of epochs.
    loss, score = trainer.fit(
        CFG.n_epochs,
        train_loader,
        valid_loader,
        f"best-model-{fold}.pth",
        100,
    )

    losses.append(loss)
    scores.append(score)

    trainer.plot_loss()
    trainer.plot_score()

elapsed_time = time.time() - start_time
print('\nTraining complete in {:.0f}m {:.0f}s'.format(elapsed_time // 60, elapsed_time % 60))
print('Avg loss {}'.format(np.mean(losses)))
print('Avg score {}'.format(np.mean(scores)))
