This notebook is designed to facilitate the training and evaluation of a graph convolutional network (GCN) classifier for binary classification tasks using MIBI (Multiplexed Ion Beam Imaging) datasets.
 
The main functions of this notebook include:
 
1. **CUDA Availability Check**: The notebook checks if a CUDA-enabled GPU is available for training, which can significantly speed up the training process.

2. **Data Loading**: It builds cell graphs for the training, validation, and testing splits from the expression CSV (via `remapping` and `create_graph`), saves them as `.pt` files, and reloads them. Data loaders are created for each dataset to facilitate batch processing during training.
 
3. **Model Training**: The notebook is set up to train a GCN model (`GraphConvClassifier`) using the `train_model` function from the `graph_model_train` module. This function handles the training loop, loss calculation, and optimization.

4. **Model Evaluation**: After training, the model can be evaluated on the validation and test datasets to assess its performance using various metrics.
 

In [None]:
import torch
import os,sys
import polars as pl
import mlflow
from torch_geometric.loader import DataLoader

# Report the GPU environment before any training is attempted.
cuda_ready = torch.cuda.is_available()
print("Is CUDA available:", cuda_ready)
print("Number of GPUs available:", torch.cuda.device_count())

# The device name is only queried when a CUDA device is reported as available.
if cuda_ready:
    print("GPU Name:", torch.cuda.get_device_name(0))

# CUDA version string recorded in this PyTorch build.
print("PyTorch built with CUDA Version:", torch.version.cuda)

In [6]:
# Make the NN_Framework directory under the current working directory importable.
notebook_path = os.getcwd()
framework_dir = os.path.join(notebook_path, 'NN_Framework')
sys.path.append(os.path.abspath(framework_dir))
from NN_Framework import utils, graph_model_train
from NN_Framework.mibi_dataset import MibiDataset
from NN_Framework.models import GraphConvClassifier
from NN_Framework.mibi_data_prep_graph import remapping, create_graph

from NN_Framework.multichannel_transforms import *

In [7]:
# Marker/expression columns passed to create_graph; the list length also sets the model input width.
expression_types = [
    'MelanA',
    'Ki67',
    'SOX10',
    'COL1A1',
    'SMA',
    'CD206',
    'CD8',
    'CD4',
    'CD45',
    'CD3',
    'CD20',
    'CD11c',
]



In [None]:
import pickle

# Full expression table; the 'fov' and 'pred' columns are used below.
full_df = pl.read_csv(r"D:\MIBI-TOFF\Data_For_Amos\cleaned_expression_with_both_classification_prob_spatial_30_08_24.csv")

# Load the saved train/val/test split from the pickle file.
# NOTE(review): pickle.load can execute arbitrary code from the file; only open split files you trust.
with open(r'D:\MIBI-TOFF\Mibi-Analysis-py\data_split_20241102_173851.pkl', 'rb') as pickle_file:
    data_loaded = pickle.load(pickle_file)


def _build_split_graphs(split_rows):
    """Build the graphs for one split from the rows of ``full_df`` that share its FOVs.

    Args:
        split_rows: One entry of the loaded split (``'train_data'``, ``'val_data'`` or
            ``'test_data'``). It must be convertible to a polars DataFrame that has a
            ``'fov'`` column.

    Returns:
        Whatever ``create_graph`` returns for the filtered, remapped rows.
    """
    split_fovs = pl.DataFrame(split_rows)['fov']
    # Keep only the rows of the full table whose FOV belongs to this split.
    cells = full_df.filter(pl.col('fov').is_in(split_fovs))
    # Remove confounding cells.
    cells = cells.filter(~pl.col('pred').is_in(['Unidentified', 'Immune']))
    # Remap the larger cell-name list to the smaller one; the result is read from 'remapped'.
    cells = remapping(df=cells, column_name='pred')
    return create_graph(cells, expression_types, cell_type_col='remapped', radius=25)


# Each split is processed inside the helper, so its intermediate frames are released
# as soon as that split's graphs have been built.
train_graphs = _build_split_graphs(data_loaded['train_data'])
val_graphs = _build_split_graphs(data_loaded['val_data'])
test_graphs = _build_split_graphs(data_loaded['test_data'])

torch.save(train_graphs, r"D:\MIBI-TOFF\Scratch\train_full_r25_graphs.pt")
torch.save(val_graphs, r"D:\MIBI-TOFF\Scratch\val_full_r25_graphs.pt")
torch.save(test_graphs, r"D:\MIBI-TOFF\Scratch\test_full_r25_graphs.pt")

print(f"Saved {len(train_graphs)} train graphs.")
print(f"Saved {len(val_graphs)} val graphs.")
# Report the test split's own count (previously this line reused the validation count).
print(f"Saved {len(test_graphs)} test graphs.")

In [None]:
# Reload the graph files written by torch.save in the preprocessing cell.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True, which may
# refuse pickled graph objects — confirm against the installed version.
train_graphs = torch.load(r"D:\MIBI-TOFF\Scratch\train_full_r25_graphs.pt")
val_graphs = torch.load(r"D:\MIBI-TOFF\Scratch\val_full_r25_graphs.pt")
test_graphs = torch.load(r"D:\MIBI-TOFF\Scratch\test_full_r25_graphs.pt")

# One graph per batch. Only the training loader shuffles; the evaluation loaders keep a
# fixed order so validation and test passes are repeatable from run to run.
train_loader = DataLoader(train_graphs, batch_size=1, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=1, shuffle=False)
val_loader = DataLoader(val_graphs, batch_size=1, shuffle=False)

In [None]:
model_name = 'full_GCN_12_channel_'

# Model selection based on name: only names containing "gcn" (case-insensitive) are accepted.
if 'gcn' not in model_name.lower():
    raise ValueError(f"Model type not recognized in model name: {model_name}")

print('GCN')
# NOTE(review): the +4 presumably accounts for extra per-node features added by
# create_graph on top of the expression channels — confirm.
model = GraphConvClassifier(
    input_dim=len(expression_types) + 4,
    hidden_dim=128,
    num_classes=2,
)

# Classification criterion and optimizer.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

# Use GPU 0 when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = criterion.to(device)

# Show which device was selected and where the model parameters now live.
print(torch.cuda.is_available(), device)
print(next(model.parameters()).device)

End any active MLflow run in case an earlier issue prevented it from closing correctly.

In [16]:
# End any MLflow run left open by an earlier execution that did not close correctly.
mlflow.end_run()


In [27]:
# Parameter block: run settings that are read key-by-key when train_model and eval_model are called.
params_block = dict(
    location=r'D:\MIBI-TOFF\Scratch\DL_Results',
    epochs=200,
    patience=200,
    delta=0.00000001,
    check_val_freq=5,
    num_classes=2,
    model_name=model_name,
    log_with_mlflow=True,
    mlflow_uri="http://127.0.0.1:5000",
)

In [None]:
# Sanity check: print the tensor shapes of every graph in the training dataset.
for graph_idx, graph in enumerate(train_loader.dataset):
    print(f"Graph {graph_idx}:")
    print(f"  Node feature shape: {graph.x.shape}")
    print(f"  Edge index shape: {graph.edge_index.shape}")
    # Edge attributes may be absent, so they are reported only when set.
    edge_attr = graph.edge_attr
    if edge_attr is None:
        continue
    print(f"  Edge attribute shape: {edge_attr.shape}")

In [None]:
# Sanity check: print the tensor shapes of every graph in the validation dataset.
for graph_idx, graph in enumerate(val_loader.dataset):
    print(f"Graph {graph_idx}:")
    print(f"  Node feature shape: {graph.x.shape}")
    print(f"  Edge index shape: {graph.edge_index.shape}")
    # Edge attributes may be absent, so they are reported only when set.
    edge_attr = graph.edge_attr
    if edge_attr is None:
        continue
    print(f"  Edge attribute shape: {edge_attr.shape}")

In [None]:
# Pull a single batch from the training loader and report its tensor shapes.
batch = next(iter(train_loader))
for label, field in (
    ("Batch node features shape:", "x"),
    ("Batch edge index shape:", "edge_index"),
):
    print(label, getattr(batch, field).shape)

# Edge attributes are reported only when the batch carries them.
batch_edge_attr = batch.edge_attr
if batch_edge_attr is not None:
    print("Batch edge attributes shape:", batch_edge_attr.shape)

In [None]:

# Collect the keyword arguments for train_model from the parameter block, by explicit key.
train_kwargs = {
    key: params_block[key]
    for key in (
        'location',
        'epochs',
        'patience',
        'delta',
        'check_val_freq',
        'num_classes',
        'model_name',
        'log_with_mlflow',
        'mlflow_uri',
    )
}

# Launch training with the model, loaders, loss, optimizer and device configured above.
graph_model_train.train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    device,
    **train_kwargs,
)

In [None]:

# Restore the "<model_name>best_model.pth" checkpoint from the results folder
# (presumably written by train_model — confirm). map_location keeps the load working
# when the checkpoint was saved on a different device than the one in use now.
best_model_path = os.path.join(params_block['location'], f"{params_block['model_name']}best_model.pth")
model.load_state_dict(torch.load(best_model_path, map_location=device))

# NOTE(review): `model_utils` is not imported in the visible cells (only `utils` and
# `graph_model_train` are) — confirm which module provides eval_model.
avg_test_loss, test_metrics = model_utils.eval_model(model, test_loader, criterion, device, params_block['num_classes'], epoch=0)

# Report the test loss followed by every metric; values are printed as returned,
# without numeric formatting.
print(f"Test Loss: {avg_test_loss:.4f}")
for metric_name, metric_value in test_metrics.items():
    print(metric_name, metric_value)


In [None]:
# Evaluate the model on the validation loader using the eval_model function.
# NOTE(review): `model_utils` is not imported in the visible cells (only `utils` and
# `graph_model_train` are) — confirm which module provides eval_model.
avg_val_loss, val_metrics = model_utils.eval_model(model, val_loader, criterion, device, params_block['num_classes'], epoch=0)

# Print all the metrics. The label and the loop variables refer to the validation
# results (previously the loss was labelled "Test Loss" and stale test-metric variables
# were printed on every iteration).
print(f"Validation Loss: {avg_val_loss:.4f}")
for metric_val_name, metric_val_value in val_metrics.items():
    print(metric_val_name, metric_val_value)

In [None]:
print(test_loader)
print(val_loader)

# Compare the contents of the two loaders.
# Each iteration of these loaders yields a single batch object (it is used as
# `batch.x` / `batch.edge_index` elsewhere in this notebook), not an (inputs, labels)
# pair, so the batches are collected as-is instead of being unpacked.
test_data = list(test_loader)
val_data = list(val_loader)

# Check if the lengths of the datasets are the same
if len(test_data) == len(val_data):
    print("The test and validation loaders have the same number of batches.")
else:
    print(f"The test loader has {len(test_data)} batches, while the validation loader has {len(val_data)} batches.")

# Compare the contents of the first batch in both loaders
if test_data and val_data:
    print("Comparing the first batch of test and validation loaders:")
    print("Test batch:", test_data[0])
    print("Validation batch:", val_data[0])
else:
    print("One of the loaders is empty.")
