 This notebook is designed to facilitate the training and evaluation of a Vision Transformer (ViT) model for binary classification tasks using MIBI (Multiplexed Imaging) datasets.
 
The main functions of this notebook include:
 
1. **CUDA Availability Check**: The notebook checks if a CUDA-enabled GPU is available for training, which can significantly speed up the training process.

2. **Data Loading**: It utilizes the `MibiDataset` class to load training, validation, and testing datasets from specified HDF5 files. Data loaders are created for each dataset to facilitate batch processing during training.
 
3. **Model Training**: The notebook is set up to train a ViT model using the `train_model` function from the `model_utils` module. This function handles the training loop, loss calculation, and optimization.

4. **Model Evaluation**: After training, the model can be evaluated on the validation and test datasets to assess its performance using various metrics.
 

In [1]:
import torch

# Check if CUDA is available
print("Is CUDA available:", torch.cuda.is_available())

# If you have multiple GPUs, check how many are available
print("Number of GPUs available:", torch.cuda.device_count())

# Get the name of the current GPU device
if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))


# Prints the CUDA version PyTorch was built with
print("PyTorch built with CUDA Version:", torch.version.cuda)

Is CUDA available: True
Number of GPUs available: 1
GPU Name: NVIDIA GeForce RTX 4090 Laptop GPU
PyTorch built with CUDA Version: 12.4


In [2]:
import os,sys
from torch.utils.data import DataLoader
import torch

notebook_path=os.getcwd()
sys.path.append(os.path.abspath(os.path.join(notebook_path,'NN_Framework')))
from NN_Framework import model_utils
from NN_Framework.mibi_dataset import MibiDataset
from NN_Framework.mibi_models import ViTClassifier, DenseNet
from NN_Framework.multichannel_transforms import *

In [3]:
expression_types = ['MelanA.tif', 'Ki67.tif', 'SOX10.tif', 'COL1A1.tif', 'SMA.tif', 
                            'CD206.tif', 'CD8.tif', 'CD4.tif', 'CD45.tif', 'CD3.tif', 'CD20.tif', 'CD11c.tif']

In [4]:
train_transforms = Compose3D([
    RandomHorizontalFlip3D(p=0.5),
    RandomVerticalFlip3D(p=0.5),
    RandomRotation3D(p=0.5),
])
print(train_transforms)

<NN_Framework.multichannel_transforms.Compose3D object at 0x0000027121762190>


In [5]:
data_path=r'D:\MIBI-TOFF\Data_For_Amos'
train_dataset=MibiDataset(hdf5_path=r'D:\MIBI-TOFF\Scratch\training_512.h5',transform=train_transforms,expressions=expression_types)
train_loader=DataLoader(dataset=train_dataset,batch_size=15,shuffle=True, num_workers=4)

val_dataset=MibiDataset(hdf5_path=r'D:\MIBI-TOFF\Scratch\validation_512.h5',expressions=expression_types)
val_loader=DataLoader(dataset=val_dataset,batch_size=15,shuffle=True, num_workers=4)

test_dataset=MibiDataset(hdf5_path=r'D:\MIBI-TOFF\Scratch\testing_512.h5',expressions=expression_types)
test_loader=DataLoader(dataset=test_dataset,batch_size=15,shuffle=True, num_workers=4)

In [6]:
print(train_dataset.class_counts)
print(val_dataset.class_counts)
print(test_dataset.class_counts)

{0: 912, 1: 980}
{0: 176, 1: 96}
{0: 80, 1: 168}


In [8]:
'''
dims_scaling=2
model = ViTClassifier(
    img_size_x=512, img_size_y=512, in_channels=12, num_classes=2, 
    patch_size_x=32, patch_size_y=32, embed_dim=768*dims_scaling, num_heads=12, 
    depth=12, mlp_dim=768*dims_scaling*4, dropout_rate=0.1, weight_decay=1e-5
)
'''

model=DenseNet(num_init_features=96, growth_rate=32, block_config=(6, 12, 24, 16),
                 num_classes=2, bn_size=4, drop_rate=0.1, input_channels=12)
criterion = torch.nn.CrossEntropyLoss()  # Define the classification criterion
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4,weight_decay=1e-5)  # Define the optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Set device to GPU (0)
model.to(device)
print(torch.cuda.is_available(),device)
print(next(model.parameters()).device) 

True cuda:0
cuda:0


Ending MLFlow if an issue causes it to not close correctly. 

In [9]:
import mlflow
mlflow.end_run()


In [10]:
#Parameter Block
params_block={'location':r'D:\MIBI-TOFF\Scratch\DL_Results',
'epochs':50,
'patience':25,
'delta':0.00000001,
'check_val_freq':5,
'num_classes':2,
'model_name':'512_DN_12_channel',
'log_with_mlflow':True,
'mlflow_uri':"http://127.0.0.1:5000"}

In [11]:

model_utils.train_model(model, train_loader, val_loader, criterion, optimizer, device, location=params_block['location'], 
    epochs=params_block['epochs'], patience=params_block['patience'], delta=params_block['delta'], check_val_freq=params_block['check_val_freq'],
    num_classes=params_block['num_classes'], model_name=params_block['model_name'], log_with_mlflow=params_block['log_with_mlflow'], mlflow_uri=params_block['mlflow_uri'])

Started MLflow run with ID: 72c0ccd9623f4ca68c6029316f796227
Epoch 1, Train Loss: 0.6595, Train Acc: 58.19%
Epoch 2, Train Loss: 0.6262, Train Acc: 64.27%


In [19]:
# Load the model
model.load_state_dict(torch.load(os.path.join(params_block['location'], f"{params_block['model_name']}best_model.pth")))



# Evaluate the model using the eval_model function
avg_test_loss, test_metrics = model_utils.eval_model(model, test_loader, criterion, device, params_block['num_classes'], epoch=0)

# Print all the metrics
print(f"Test Loss: {avg_test_loss:.4f}")
for metric_name, metric_value in test_metrics.items():
    print(metric_name,metric_value)
    #print(f"{metric_name}: {metric_value:.4f}")




  model.load_state_dict(torch.load(os.path.join(params_block['location'], f"{params_block['model_name']}best_model.pth")))


Validation Loss: 3.0725, Validation Accuracy: 31.05%
Epoch 0, Val Loss: 3.0725, Val Acc: 31.05%
Class-wise Metrics: 
Test Loss: 3.0725
accuracy 31.04838752746582
f1_score 38.70967483520508
precision 48.648651123046875
recall 32.14285659790039
sensitivity [28.75, 32.14285659790039]
specificity [32.14285659790039, 28.75]


In [20]:
# Evaluate the model using the eval_model function
avg_val_loss, val_metrics = model_utils.eval_model(model, val_loader, criterion, device, params_block['num_classes'], epoch=0)

# Print all the metrics
print(f"Test Loss: {avg_val_loss:.4f}")
for metric_val_name, metric_val_value in val_metrics.items():
    print(metric_name,metric_value)
    #print(f"{metric_name}: {metric_value:.4f}")

Validation Loss: 2.7325, Validation Accuracy: 39.34%
Epoch 0, Val Loss: 2.7325, Val Acc: 39.34%
Class-wise Metrics: 
Test Loss: 2.7325
specificity [32.14285659790039, 28.75]
specificity [32.14285659790039, 28.75]
specificity [32.14285659790039, 28.75]
specificity [32.14285659790039, 28.75]
specificity [32.14285659790039, 28.75]
specificity [32.14285659790039, 28.75]


In [13]:
print(test_loader)
print(val_loader)
# Compare the contents of the two loaders
test_data = [data for data, _ in test_loader]
val_data = [data for data, _ in val_loader]

# Check if the lengths of the datasets are the same
if len(test_data) == len(val_data):
    print("The test and validation loaders have the same number of batches.")
else:
    print(f"The test loader has {len(test_data)} batches, while the validation loader has {len(val_data)} batches.")

# Compare the contents of the first batch in both loaders
if test_data and val_data:
    print("Comparing the first batch of test and validation loaders:")
    print("Test batch:", test_data[0])
    print("Validation batch:", val_data[0])
else:
    print("One of the loaders is empty.")


<torch.utils.data.dataloader.DataLoader object at 0x00000153841FD880>
<torch.utils.data.dataloader.DataLoader object at 0x0000015383D7FB80>
The test and validation loaders have the same number of batches.
Comparing the first batch of test and validation loaders:
Test batch: tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 1., 1.,  ..., 0., 0., 0.],
    