 This notebook is designed to facilitate the training and evaluation of a Vision Transformer (ViT) model for binary classification tasks using MIBI (Multiplexed Imaging) datasets.
 
The main functions of this notebook include:
 
1. **CUDA Availability Check**: The notebook checks if a CUDA-enabled GPU is available for training, which can significantly speed up the training process.

2. **Data Loading**: It utilizes the `MibiDataset` class to load training, validation, and testing datasets from specified HDF5 files. Data loaders are created for each dataset to facilitate batch processing during training.
 
3. **Model Training**: The notebook is set up to train a ViT model using the `train_model` function from the `model_utils` module. This function handles the training loop, loss calculation, and optimization.

4. **Model Evaluation**: After training, the model can be evaluated on the validation and test datasets to assess its performance using various metrics.
 

In [1]:
import torch

# Check if CUDA is available
print("Is CUDA available:", torch.cuda.is_available())

print("Number of GPUs available:", torch.cuda.device_count())

if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))
print("PyTorch built with CUDA Version:", torch.version.cuda)

Is CUDA available: True
Number of GPUs available: 1
GPU Name: NVIDIA GeForce RTX 4090 Laptop GPU
PyTorch built with CUDA Version: 12.4


In [2]:
import os,sys
from torch.utils.data import DataLoader
import torch

notebook_path=os.getcwd()
sys.path.append(os.path.abspath(os.path.join(notebook_path,'NN_Framework')))
from NN_Framework import model_utils
from NN_Framework.mibi_dataset import MibiDataset
from NN_Framework.models import ViTClassifier, DenseNet, SwinTransformer
from NN_Framework.multichannel_transforms import *

In [3]:
expression_types = ['MelanA.tif', 'Ki67.tif', 'SOX10.tif', 'COL1A1.tif', 'SMA.tif', 
                            'CD206.tif', 'CD8.tif', 'CD4.tif', 'CD45.tif', 'CD3.tif', 'CD20.tif', 'CD11c.tif']

In [4]:
train_transforms = Compose3D([
    RandomHorizontalFlip3D(p=0.5),
    RandomVerticalFlip3D(p=0.5),
    RandomRotation3D(p=0.5),
])
print(train_transforms)

<NN_Framework.multichannel_transforms.Compose3D object at 0x000001F1F8E5A580>


In [5]:
data_path=r'D:\MIBI-TOFF\Data_For_Amos'
train_dataset=MibiDataset(hdf5_path=r'D:\MIBI-TOFF\Scratch\training_1024.h5',transform=train_transforms,expressions=expression_types)
train_loader=DataLoader(dataset=train_dataset,batch_size=5,shuffle=True, num_workers=4,pin_memory=True)

val_dataset=MibiDataset(hdf5_path=r'D:\MIBI-TOFF\Scratch\validation_1024.h5',expressions=expression_types)
val_loader=DataLoader(dataset=val_dataset,batch_size=5,shuffle=True, num_workers=4,pin_memory=True)

test_dataset=MibiDataset(hdf5_path=r'D:\MIBI-TOFF\Scratch\testing_1024.h5',expressions=expression_types)
test_loader=DataLoader(dataset=test_dataset,batch_size=5,shuffle=True, num_workers=4,pin_memory=True)

In [6]:
'''
import matplotlib.pyplot as plt
import numpy as np
from collections import defaultdict

def plot_expression_histogram(loader, title, expressions, num_bins=50):
    class_0_counts = [defaultdict(int) for _ in range(len(expressions))]
    class_1_counts = [defaultdict(int) for _ in range(len(expressions))]

    for patches, labels in loader:
        for i in range(len(expressions)):
            mask_0 = labels == 0
            mask_1 = labels == 1
            
            if mask_0.any():
                values_class_0 = patches[mask_0, i, :, :].flatten().numpy()
                for val in values_class_0:
                    if val != 0: 
                        class_0_counts[i][val] += 1
            if mask_1.any():
                values_class_1 = patches[mask_1, i, :, :].flatten().numpy()
                for val in values_class_1:
                    if val != 0:  
                        class_1_counts[i][val] += 11

    n_rows = (len(expressions) + 2) // 3
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows))
    fig.suptitle(f'Expression Distributions - {title}', fontsize=16)
    fig.subplots_adjust(top=0.9)
    
    axes_flat = axes.flatten() if n_rows > 1 else axes
    
    for i, (exp_name, ax) in enumerate(zip(expressions, axes_flat)):
        exp_name = exp_name.replace('.tif', '')

        if class_0_counts[i]:
            values_0, counts_0 = zip(*sorted(class_0_counts[i].items()))
            hist_0, bin_edges = np.histogram(values_0, bins=num_bins, weights=counts_0)
            ax.bar(bin_edges[:-1], hist_0, width=np.diff(bin_edges), alpha=0.5, label='Class 0')

        if class_1_counts[i]:
            values_1, counts_1 = zip(*sorted(class_1_counts[i].items()))
            hist_1, bin_edges = np.histogram(values_1, bins=bin_edges, weights=counts_1)
            ax.bar(bin_edges[:-1], hist_1, width=np.diff(bin_edges), alpha=0.5, label='Class 1')
        
        ax.set_title(exp_name)
        ax.legend()
        
    for i in range(len(expressions), len(axes_flat)):
        fig.delaxes(axes_flat[i])
        
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()

plot_expression_histogram(train_loader, "Training Set", expression_types)
plot_expression_histogram(val_loader, "Validation Set", expression_types)
plot_expression_histogram(test_loader, "Test Set", expression_types)
'''

'\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom collections import defaultdict\n\ndef plot_expression_histogram(loader, title, expressions, num_bins=50):\n    class_0_counts = [defaultdict(int) for _ in range(len(expressions))]\n    class_1_counts = [defaultdict(int) for _ in range(len(expressions))]\n\n    for patches, labels in loader:\n        for i in range(len(expressions)):\n            mask_0 = labels == 0\n            mask_1 = labels == 1\n            \n            if mask_0.any():\n                values_class_0 = patches[mask_0, i, :, :].flatten().numpy()\n                for val in values_class_0:\n                    if val != 0: \n                        class_0_counts[i][val] += 1\n            if mask_1.any():\n                values_class_1 = patches[mask_1, i, :, :].flatten().numpy()\n                for val in values_class_1:\n                    if val != 0:  \n                        class_1_counts[i][val] += 11\n\n    n_rows = (len(expressions) + 2) // 

In [6]:
print(train_dataset.class_counts)
print(val_dataset.class_counts)
print(test_dataset.class_counts)

{0: 228, 1: 245}
{0: 44, 1: 24}
{0: 20, 1: 42}


In [7]:
model_name='1024_swin_12_channel_'
img_size=1024

# Model selection based on name
if any(x in model_name.lower() for x in ['swint', 'swin']):
    print('Swin')
    model = SwinTransformer(
        img_size=1024, in_channels=12, patch_size=16, num_classes=2,
        embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
        window_size=7, mlp_ratio=4.0, dropout_rate=0.1, weight_decay=0.05)
    
elif any(x in model_name.lower() for x in ['vit']):
    dims_scaling = 2
    model = ViTClassifier(
        img_size_x=img_size, img_size_y=img_size, in_channels=12, num_classes=2,
        patch_size_x=32, patch_size_y=32, embed_dim=768*dims_scaling, num_heads=12,
        depth=12, mlp_dim=768*dims_scaling*4, dropout_rate=0.1, weight_decay=1e-5)
    
elif any(x in model_name.lower() for x in ['dn', 'densenet']):
    model = DenseNet(
        num_init_features=96, growth_rate=32, block_config=(6, 12, 24, 16),
        num_classes=2, bn_size=4, dropout_rate=0.25, input_channels=12)
    
else:
    raise ValueError(f"Model type not recognized in model name: {model_name}")




criterion = torch.nn.CrossEntropyLoss()  # Define the classification criterion
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4,weight_decay=1e-5)  # Define the optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Set device to GPU (0)
model.to(device)
criterion = criterion.to(device)
print(torch.cuda.is_available(),device)
print(next(model.parameters()).device) 

Swin
True cuda:0
cuda:0


Ending MLFlow if an issue causes it to not close correctly. 

In [8]:
import mlflow
mlflow.end_run()


In [9]:
#Parameter Block
params_block={'location':r'D:\MIBI-TOFF\Scratch\DL_Results',
'epochs':200,
'patience':200,
'delta':0.00000001,
'check_val_freq':5,
'num_classes':2,
'model_name':model_name,
'log_with_mlflow':True,
'mlflow_uri':"http://127.0.0.1:5000"}

In [10]:

model_utils.train_model(model, train_loader, val_loader, criterion, optimizer, device, location=params_block['location'], 
    epochs=params_block['epochs'], patience=params_block['patience'], delta=params_block['delta'], check_val_freq=params_block['check_val_freq'],
    num_classes=params_block['num_classes'], model_name=params_block['model_name'], log_with_mlflow=params_block['log_with_mlflow'], mlflow_uri=params_block['mlflow_uri'])

Started MLflow run with ID: 95be5f23e5c648359a7c436de314ea39


In [12]:

model.load_state_dict(torch.load(os.path.join(params_block['location'], f"{params_block['model_name']}best_model.pth")))

avg_test_loss, test_metrics = model_utils.eval_model(model, test_loader, criterion, device, params_block['num_classes'], epoch=0)


print(f"Test Loss: {avg_test_loss:.4f}")
for metric_name, metric_value in test_metrics.items():
    print(metric_name,metric_value)
    #print(f"{metric_name}: {metric_value:.4f}")


  model.load_state_dict(torch.load(os.path.join(params_block['location'], f"{params_block['model_name']}best_model.pth")))


FileNotFoundError: [Errno 2] No such file or directory: 'D:\\MIBI-TOFF\\Scratch\\DL_Results\\1024_SwinT_12_channel_best_model.pth'

In [None]:
# Evaluate the model using the eval_model function
avg_val_loss, val_metrics = model_utils.eval_model(model, val_loader, criterion, device, params_block['num_classes'], epoch=0)

# Print all the metrics
print(f"Test Loss: {avg_val_loss:.4f}")
for metric_val_name, metric_val_value in val_metrics.items():
    print(metric_name,metric_value)
    #print(f"{metric_name}: {metric_value:.4f}")

In [None]:
print(test_loader)
print(val_loader)
# Compare the contents of the two loaders
test_data = [data for data, _ in test_loader]
val_data = [data for data, _ in val_loader]

# Check if the lengths of the datasets are the same
if len(test_data) == len(val_data):
    print("The test and validation loaders have the same number of batches.")
else:
    print(f"The test loader has {len(test_data)} batches, while the validation loader has {len(val_data)} batches.")

# Compare the contents of the first batch in both loaders
if test_data and val_data:
    print("Comparing the first batch of test and validation loaders:")
    print("Test batch:", test_data[0])
    print("Validation batch:", val_data[0])
else:
    print("One of the loaders is empty.")
