 This notebook is designed to facilitate the training and evaluation of a Vision Transformer (ViT) model for binary classification tasks using MIBI (Multiplexed Imaging) datasets.
 
The main functions of this notebook include:
 
1. **CUDA Availability Check**: The notebook checks if a CUDA-enabled GPU is available for training, which can significantly speed up the training process.

2. **Data Loading**: It utilizes the `MibiDataset` class to load training, validation, and testing datasets from specified HDF5 files. Data loaders are created for each dataset to facilitate batch processing during training.
 
3. **Model Training**: The notebook is set up to train a ViT model using the `train_model` function from the `model_utils` module. This function handles the training loop, loss calculation, and optimization.

4. **Model Evaluation**: After training, the model can be evaluated on the validation and test datasets to assess its performance using various metrics.
 

In [1]:
import torch

# Check if CUDA is available
print("Is CUDA available:", torch.cuda.is_available())

# If you have multiple GPUs, check how many are available
print("Number of GPUs available:", torch.cuda.device_count())

# Get the name of the current GPU device
if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))


# Prints the CUDA version PyTorch was built with
print("PyTorch built with CUDA Version:", torch.version.cuda)

Is CUDA available: True
Number of GPUs available: 1
GPU Name: NVIDIA GeForce RTX 4090 Laptop GPU
PyTorch built with CUDA Version: 12.4


In [2]:
import os,sys
from torch.utils.data import DataLoader
import torch

notebook_path=os.getcwd()
sys.path.append(os.path.abspath(os.path.join(notebook_path,'NN_Framework')))
from NN_Framework import model_utils
from NN_Framework.mibi_dataset import MibiDataset
from NN_Framework.mibi_models import ViTBinaryClassifier

In [4]:
expression_types = ['MelanA.tif', 'Ki67.tif', 'SOX10.tif', 'COL1A1.tif', 'SMA.tif', 
                            'CD206.tif', 'CD8.tif', 'CD4.tif', 'CD45.tif', 'CD3.tif', 'CD20.tif', 'CD11c.tif']

In [5]:
data_path=r'D:\MIBI-TOFF\Data_For_Amos'
train_dataset=MibiDataset(hdf5_path=r'D:\MIBI-TOFF\Scratch\training_512.h5',expressions=expression_types)
train_loader=DataLoader(dataset=train_dataset,batch_size=15,shuffle=True, num_workers=4)

val_dataset=MibiDataset(hdf5_path=r'D:\MIBI-TOFF\Scratch\validation_512.h5',expressions=expression_types)
val_loader=DataLoader(dataset=train_dataset,batch_size=15,shuffle=True, num_workers=4)

test_dataset=MibiDataset(hdf5_path=r'D:\MIBI-TOFF\Scratch\testing_512.h5',expressions=expression_types)
test_loader=DataLoader(dataset=train_dataset,batch_size=15,shuffle=True, num_workers=4)

In [6]:
dims_scaling=2
model = ViTBinaryClassifier(
    img_size_x=512, img_size_y=512, in_channels=12, num_classes=2, 
    patch_size_x=32, patch_size_y=32, embed_dim=768*dims_scaling, num_heads=12, 
    depth=12, mlp_dim=768*dims_scaling*4, dropout_rate=0.1, weight_decay=1e-5
)
criterion = torch.nn.CrossEntropyLoss()  # Define the classification criterion
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4,weight_decay=1e-5)  # Define the optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Set device to GPU (0)
model.to(device)
print(torch.cuda.is_available(),device)
print(next(model.parameters()).device) 

True cuda:0
cuda:0


In [5]:
print(train_dataset.class_counts)
print(val_dataset.class_counts)
print(test_dataset.class_counts)

{0: 912, 1: 980}
{0: 176, 1: 96}
{0: 80, 1: 168}


Ending MLFlow if an issue causes it to not close correctly. 

In [7]:
import mlflow
mlflow.end_run()


In [10]:
#Parameter Block
params_block={'location':r'D:\MIBI-TOFF\Scratch\DL_Results',
'epochs':50,
'patience':25,
'delta':0.00000001,
'check_val_freq':5,
'num_classes':2,
'model_name':'512_vit_12_channel',
'log_with_mlflow':True,
'mlflow_uri':"http://127.0.0.1:5000"}

In [11]:

model_utils.train_model(model, train_loader, val_loader, criterion, optimizer, device, location=params_block['location'], 
    epochs=params_block['epochs'], patience=params_block['patience'], delta=params_block['delta'], check_val_freq=params_block['check_val_freq'],
    num_classes=params_block['num_classes'], model_name=params_block['model_name'], log_with_mlflow=params_block['log_with_mlflow'], mlflow_uri=params_block['mlflow_uri'])

Started MLflow run with ID: bd34d0f4a51e4bb09be2dde6d900b36b
Epoch 1, Train Loss: 0.8156, Train Acc: 50.63%
Epoch 2, Train Loss: 0.7196, Train Acc: 50.21%
Epoch 3, Train Loss: 0.7076, Train Acc: 51.06%
Epoch 4, Train Loss: 0.7100, Train Acc: 48.68%
Epoch 5, Train Loss: 0.7140, Train Acc: 50.85%
Validation Loss: 0.6927, Validation Accuracy: 51.80%
Epoch 5, Val Loss: 0.6927, Val Acc: 51.80%
Class-wise Metrics: 
Model saved with improved validation loss: 0.6927
Epoch 6, Train Loss: 0.7025, Train Acc: 50.32%
