 This notebook is designed to facilitate the training and evaluation of a Vision Transformer (ViT) model for binary classification tasks using MIBI (Multiplexed Imaging) datasets.
 
The main functions of this notebook include:
 
1. **CUDA Availability Check**: The notebook checks if a CUDA-enabled GPU is available for training, which can significantly speed up the training process.

2. **Data Loading**: It utilizes the `MibiDataset` class to load training, validation, and testing datasets from specified HDF5 files. Data loaders are created for each dataset to facilitate batch processing during training.
 
3. **Model Training**: The notebook is set up to train a ViT model using the `train_model` function from the `model_utils` module. This function handles the training loop, loss calculation, and optimization.

4. **Model Evaluation**: After training, the model can be evaluated on the validation and test datasets to assess its performance using various metrics.
 

In [1]:
import torch
import os,sys
import polars as pl
import mlflow
import pickle
from torch_geometric.loader import DataLoader

# Check if CUDA is available
print("Is CUDA available:", torch.cuda.is_available())

print("Number of GPUs available:", torch.cuda.device_count())


if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))
print("PyTorch built with CUDA Version:", torch.version.cuda)

Is CUDA available: True
Number of GPUs available: 1
GPU Name: NVIDIA GeForce RTX 4090 Laptop GPU
PyTorch built with CUDA Version: 11.8


In [2]:
notebook_path=os.getcwd()
sys.path.append(os.path.abspath(os.path.join(notebook_path,'NN_Framework')))
from NN_Framework import graph_model_train
from NN_Framework.mibi_dataset import MibiDataset
from NN_Framework.models import GraphConvClassifier
from NN_Framework.mibi_data_prep_graph import remapping, create_graph
from NN_Framework.multichannel_transforms import *

In [3]:
expression_types = ['MelanA', 'Ki67', 'SOX10', 'COL1A1', 'SMA', 
                   'CD206', 'CD8', 'CD4', 'CD45', 'CD3', 'CD20', 'CD11c']



In [4]:
def filter_and_process(full_df, sub_df, expression_types, binarize=False,cell_type_col='pred', fov_col='fov',radius=25):
        fovs = full_df.filter(pl.col(fov_col).is_in(sub_df[fov_col]))
        fovs = fovs.filter(~pl.col(cell_type_col).is_in(['Unidentified', 'Immune']))  # Remove confounding cells
        fovs = remapping(df=fovs, column_name=cell_type_col)  # Remap larger cell name list to smaller one
        graphs = create_graph(fovs, expression_types,
                              binarize=binarize,
                               cell_type_col=cell_type_col, 
                               radius=radius)
        return graphs

In [5]:
'''
full_df = pl.read_csv(r"D:\MIBI-TOFF\Data_For_Amos\cleaned_expression_with_both_classification_prob_spatial_30_08_24.csv")

with open(r'D:\MIBI-TOFF\Mibi-Analysis-py\data_split_20241102_173851.pkl', 'rb') as pickle_file:
    data_loaded = pickle.load(pickle_file)

train_data = pl.DataFrame(data_loaded['train_data'])
val_data = pl.DataFrame(data_loaded['val_data'])
test_data = pl.DataFrame(data_loaded['test_data'])

train_graphs=filter_and_process(full_df=full_df, sub_df=train_data, 
                   expression_types=expression_types, binarize=True,
                   cell_type_col='pred',fov_col='fov' ,radius=25)
torch.save(train_graphs, r"D:\MIBI-TOFF\Scratch\train_full_r25_graphs.pt")

val_graphs=filter_and_process(full_df=full_df, sub_df=val_data, 
                   expression_types=expression_types,binarize=True, 
                   cell_type_col='pred',fov_col='fov' ,radius=25)
torch.save(val_graphs, r"D:\MIBI-TOFF\Scratch\val_full_r25_graphs.pt")

test_graphs=filter_and_process(full_df=full_df, sub_df=test_data, 
                   expression_types=expression_types,binarize=True,
                   cell_type_col='pred',fov_col='fov' ,radius=25)
torch.save(test_graphs, r"D:\MIBI-TOFF\Scratch\test_full_r25_graphs.pt")
'''


'\nfull_df = pl.read_csv(r"D:\\MIBI-TOFF\\Data_For_Amos\\cleaned_expression_with_both_classification_prob_spatial_30_08_24.csv")\n\nwith open(r\'D:\\MIBI-TOFF\\Mibi-Analysis-py\\data_split_20241102_173851.pkl\', \'rb\') as pickle_file:\n    data_loaded = pickle.load(pickle_file)\n\ntrain_data = pl.DataFrame(data_loaded[\'train_data\'])\nval_data = pl.DataFrame(data_loaded[\'val_data\'])\ntest_data = pl.DataFrame(data_loaded[\'test_data\'])\n\ntrain_graphs=filter_and_process(full_df=full_df, sub_df=train_data, \n                   expression_types=expression_types, binarize=True,\n                   cell_type_col=\'pred\',fov_col=\'fov\' ,radius=25)\ntorch.save(train_graphs, r"D:\\MIBI-TOFF\\Scratch\train_full_r25_graphs.pt")\n\nval_graphs=filter_and_process(full_df=full_df, sub_df=val_data, \n                   expression_types=expression_types,binarize=True, \n                   cell_type_col=\'pred\',fov_col=\'fov\' ,radius=25)\ntorch.save(val_graphs, r"D:\\MIBI-TOFF\\Scratch\x0bal_f

In [6]:
train_graphs =torch.load( r"D:\MIBI-TOFF\Scratch\train_full_r25_graphs.pt")
val_graphs = torch.load( r"D:\MIBI-TOFF\Scratch\val_full_r25_graphs.pt")
test_graphs = torch.load( r"D:\MIBI-TOFF\Scratch\test_full_r25_graphs.pt")

train_loader = DataLoader(train_graphs, batch_size=1, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=1, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=1, shuffle=True)

  train_graphs =torch.load( r"D:\MIBI-TOFF\Scratch\train_full_r25_graphs.pt")
  val_graphs = torch.load( r"D:\MIBI-TOFF\Scratch\val_full_r25_graphs.pt")
  test_graphs = torch.load( r"D:\MIBI-TOFF\Scratch\test_full_r25_graphs.pt")


In [7]:
#Check Feature Dims Training
for i, data in enumerate(train_loader.dataset):
    print(f"Graph {i}:")
    print(f"  Node feature shape: {data.x.shape}")
    print(f"  Edge index shape: {data.edge_index.shape}")
    if data.edge_attr is not None:
        print(f"  Edge attribute shape: {data.edge_attr.shape}")

Graph 0:
  Node feature shape: torch.Size([11077, 16])
  Edge index shape: torch.Size([2, 52352])
Graph 1:
  Node feature shape: torch.Size([9285, 16])
  Edge index shape: torch.Size([2, 36448])
Graph 2:
  Node feature shape: torch.Size([9590, 16])
  Edge index shape: torch.Size([2, 40188])
Graph 3:
  Node feature shape: torch.Size([11349, 16])
  Edge index shape: torch.Size([2, 53352])
Graph 4:
  Node feature shape: torch.Size([11895, 16])
  Edge index shape: torch.Size([2, 59294])
Graph 5:
  Node feature shape: torch.Size([2604, 16])
  Edge index shape: torch.Size([2, 11074])
Graph 6:
  Node feature shape: torch.Size([10676, 16])
  Edge index shape: torch.Size([2, 49788])
Graph 7:
  Node feature shape: torch.Size([3072, 16])
  Edge index shape: torch.Size([2, 15362])
Graph 8:
  Node feature shape: torch.Size([13154, 16])
  Edge index shape: torch.Size([2, 72234])
Graph 9:
  Node feature shape: torch.Size([2233, 16])
  Edge index shape: torch.Size([2, 8532])
Graph 10:
  Node feature s

In [8]:
model_name='full_GCN_12_channel'

# Model selection based on name
if any(x in model_name.lower() for x in ['gcn']):
    print('GCN')
    model = GraphConvClassifier(input_dim=len(expression_types)+4, hidden_dim=128, num_classes=2)
    
else:
    raise ValueError(f"Model type not recognized in model name: {model_name}")




criterion = torch.nn.CrossEntropyLoss()  # Define the classification criterion
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4,weight_decay=1e-5)  # Define the optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Set device to GPU (0)
model.to(device)
criterion = criterion.to(device)
print(torch.cuda.is_available(),device)
print(next(model.parameters()).device) 

GCN
True cuda:0
cuda:0


Ending MLFlow if an issue causes it to not close correctly. 

In [9]:
mlflow.end_run()


In [10]:
#Parameter Block
params_block={'location':r'D:\MIBI-TOFF\Scratch\DL_Results',
'epochs':200,
'patience':200,
'delta':0.00000001,
'check_val_freq':5,
'num_classes':2,
'model_name':model_name,
'log_with_mlflow':True,
'mlflow_uri':"http://127.0.0.1:5000"}

In [11]:
batch = next(iter(train_loader))
print("Batch node features shape:", batch.x.shape)
print("Batch edge index shape:", batch.edge_index.shape)
if batch.edge_attr is not None:
    print("Batch edge attributes shape:", batch.edge_attr.shape)

Batch node features shape: torch.Size([2634, 16])
Batch edge index shape: torch.Size([2, 12246])


In [12]:

graph_model_train.train_model(model, train_loader, val_loader, criterion, optimizer, device, location=params_block['location'], 
    epochs=params_block['epochs'], patience=params_block['patience'], delta=params_block['delta'], check_val_freq=params_block['check_val_freq'],
    num_classes=params_block['num_classes'], model_name=params_block['model_name'], log_with_mlflow=params_block['log_with_mlflow'], mlflow_uri=params_block['mlflow_uri'])

Started MLflow run with ID: b74bd40bc56547088650aa0195badd9d
Epoch 1, Train Loss: 0.7098, Train Acc: 43.36%
Epoch 2, Train Loss: 0.6960, Train Acc: 52.45%
Epoch 3, Train Loss: 0.6944, Train Acc: 48.95%
Epoch 4, Train Loss: 0.6898, Train Acc: 51.75%
Epoch 5, Train Loss: 0.6842, Train Acc: 52.45%
Epoch 5, Val Loss: 0.7414, Val Acc: 29.41%
Model saved with improved validation loss: 0.7414
Epoch 6, Train Loss: 0.6818, Train Acc: 56.64%
Epoch 7, Train Loss: 0.6791, Train Acc: 58.04%
Epoch 8, Train Loss: 0.6816, Train Acc: 55.94%
Epoch 9, Train Loss: 0.6777, Train Acc: 55.24%
Epoch 10, Train Loss: 0.6734, Train Acc: 56.64%
Epoch 10, Val Loss: 0.7565, Val Acc: 29.41%
EarlyStopping counter: 1/200
Epoch 11, Train Loss: 0.6725, Train Acc: 53.85%
Epoch 12, Train Loss: 0.6684, Train Acc: 58.74%
Epoch 13, Train Loss: 0.6656, Train Acc: 62.24%
Epoch 14, Train Loss: 0.6648, Train Acc: 56.64%
Epoch 15, Train Loss: 0.6633, Train Acc: 56.64%
Epoch 15, Val Loss: 0.7466, Val Acc: 35.29%
EarlyStopping coun

In [13]:

model.load_state_dict(torch.load(os.path.join(params_block['location'], f"{params_block['model_name']}best_model.pth")))

avg_test_loss, test_metrics = graph_model_train.eval_model(model, test_loader, criterion, device, params_block['num_classes'], epoch=0)


print(f"Test Loss: {avg_test_loss:.4f}")
for metric_name, metric_value in test_metrics.items():
    print(metric_name,metric_value)
    #print(f"{metric_name}: {metric_value:.4f}")


  model.load_state_dict(torch.load(os.path.join(params_block['location'], f"{params_block['model_name']}best_model.pth")))


FileNotFoundError: [Errno 2] No such file or directory: 'D:\\MIBI-TOFF\\Scratch\\DL_Results\\full_GCN_12_channelbest_model.pth'

In [None]:
# Evaluate the model using the eval_model function
avg_val_loss, val_metrics = graph_model_train.eval_model(model, val_loader, criterion, device, params_block['num_classes'], epoch=0)

# Print all the metrics
print(f"Test Loss: {avg_val_loss:.4f}")
for metric_val_name, metric_val_value in val_metrics.items():
    print(metric_name,metric_value)
    #print(f"{metric_name}: {metric_value:.4f}")

In [None]:
print(test_loader)
print(val_loader)
# Compare the contents of the two loaders
test_data = [data for data, _ in test_loader]
val_data = [data for data, _ in val_loader]

# Check if the lengths of the datasets are the same
if len(test_data) == len(val_data):
    print("The test and validation loaders have the same number of batches.")
else:
    print(f"The test loader has {len(test_data)} batches, while the validation loader has {len(val_data)} batches.")

# Compare the contents of the first batch in both loaders
if test_data and val_data:
    print("Comparing the first batch of test and validation loaders:")
    print("Test batch:", test_data[0])
    print("Validation batch:", val_data[0])
else:
    print("One of the loaders is empty.")
