 This notebook is designed to facilitate the training and evaluation of a Vision Transformer (ViT) model for binary classification tasks using MIBI (Multiplexed Imaging) datasets.
 
The main functions of this notebook include:
 
1. **CUDA Availability Check**: The notebook checks if a CUDA-enabled GPU is available for training, which can significantly speed up the training process.

2. **Data Loading**: It utilizes the `MibiDataset` class to load training, validation, and testing datasets from specified HDF5 files. Data loaders are created for each dataset to facilitate batch processing during training.
 
3. **Model Training**: The notebook is set up to train a ViT model using the `train_model` function from the `model_utils` module. This function handles the training loop, loss calculation, and optimization.

4. **Model Evaluation**: After training, the model can be evaluated on the validation and test datasets to assess its performance using various metrics.
 

In [25]:
import torch
import os,sys
import polars as pl
import mlflow
import pickle
from torch_geometric.loader import DataLoader

# Check if CUDA is available
print("Is CUDA available:", torch.cuda.is_available())

print("Number of GPUs available:", torch.cuda.device_count())


if torch.cuda.is_available():
    print("GPU Name:", torch.cuda.get_device_name(0))
print("PyTorch built with CUDA Version:", torch.version.cuda)

Is CUDA available: True
Number of GPUs available: 1
GPU Name: NVIDIA GeForce RTX 4090 Laptop GPU
PyTorch built with CUDA Version: 11.8


In [26]:
notebook_path=os.getcwd()
sys.path.append(os.path.abspath(os.path.join(notebook_path,'NN_Framework')))
from NN_Framework import graph_model_train
from NN_Framework.mibi_dataset import MibiDataset
from NN_Framework.models import GraphConvClassifier
from NN_Framework.mibi_data_prep_graph import remapping, create_graph_patches
from NN_Framework.models.multichannel_transforms import *

In [27]:
expression_types = ['MelanA', 'Ki67', 'SOX10', 'COL1A1', 'SMA', 
                   'CD206', 'CD8', 'CD4', 'CD45', 'CD3', 'CD20', 'CD11c']



In [28]:
#Duplicate and make a version with patches and with full?
#Normalize data.
def filter_and_process(full_df, sub_df, expression_types, 
                       binarize=False,cell_type_col='pred',
                         fov_col='fov',radius=25,mean=None,std=None):
        fovs = full_df.filter(pl.col(fov_col).is_in(sub_df[fov_col]))
        fovs = fovs.filter(~pl.col(cell_type_col).is_in(['Unidentified', 'Immune']))  # Remove confounding cells
        fovs = remapping(df=fovs, column_name=cell_type_col)  # Remap larger cell name list to smaller one
        graphs,fov_list = create_graph_patches(fovs, expression_types,
                                stride=124,binarize=binarize,
                                cell_type_col=cell_type_col, 
                                radius=radius)
        
        assert graphs is not None, "Error: Graphs is None."
        assert fov_list is not None, "Error: fov_list is None."
        print("Graphs and fov_list have been successfully created.")
        return graphs,fov_list

In [29]:


full_df = pl.read_csv(r"D:\MIBI-TOFF\Data_For_Amos\cleaned_expression_with_both_classification_prob_spatial_30_08_24.csv")

with open(r'D:\MIBI-TOFF\Mibi-Analysis-py\data_split_train100_val58_test19.pkl', 'rb') as pickle_file:
    data_loaded = pickle.load(pickle_file)

train_data = pl.DataFrame(data_loaded['train_data'])
val_data = pl.DataFrame(data_loaded['val_data'])
test_data = pl.DataFrame(data_loaded['test_data'])

train_graphs,train_fovs=filter_and_process(full_df=full_df, sub_df=train_data, 
                   expression_types=expression_types, binarize=True,
                   cell_type_col='pred',fov_col='fov' ,radius=25)
torch.save((train_graphs, train_fovs), r"D:\MIBI-TOFF\Scratch\train_full_r25_graphs.pt")

val_graphs,val_fovs=filter_and_process(full_df=full_df, sub_df=val_data, 
                   expression_types=expression_types,binarize=True, 
                   cell_type_col='pred',fov_col='fov' ,radius=25)
torch.save((val_graphs,val_fovs), r"D:\MIBI-TOFF\Scratch\val_full_r25_graphs.pt")

test_graphs,test_fovs=filter_and_process(full_df=full_df, sub_df=test_data, 
                   expression_types=expression_types,binarize=True,
                   cell_type_col='pred',fov_col='fov' ,radius=25)
torch.save((test_graphs,test_fovs), r"D:\MIBI-TOFF\Scratch\test_full_r25_graphs.pt")



Skipped Region: [1240, 1364, 1984, 2108]
Skipped Region: [1612, 1736, 1984, 2108]
Skipped Region: [1984, 2108, 1860, 1984]
Skipped Region: [1984, 2108, 1984, 2108]
Skipped Region: [1364, 1488, 0, 124]
Skipped Region: [1488, 1612, 124, 248]
Skipped Region: [1612, 1736, 0, 124]
Skipped Region: [1736, 1860, 0, 124]
Skipped Region: [1736, 1860, 1984, 2108]
Skipped Region: [1860, 1984, 0, 124]
Skipped Region: [1984, 2108, 124, 248]
Skipped Region: [1984, 2108, 1984, 2108]
Skipped Region: [248, 372, 1984, 2108]
Skipped Region: [1488, 1612, 1984, 2108]
Skipped Region: [1612, 1736, 1984, 2108]
Skipped Region: [1736, 1860, 1984, 2108]
Skipped Region: [1860, 1984, 1736, 1860]
Skipped Region: [1860, 1984, 1984, 2108]
Skipped Region: [1984, 2108, 1488, 1612]
Skipped Region: [1984, 2108, 1612, 1736]
Skipped Region: [1984, 2108, 1736, 1860]
Skipped Region: [744, 868, 992, 1116]
Skipped Region: [992, 1116, 248, 372]
Skipped Region: [992, 1116, 372, 496]
Skipped Region: [992, 1116, 992, 1116]
Skipped 

In [30]:
train_graphs,fov_test =torch.load( r"D:\MIBI-TOFF\Scratch\train_full_r25_graphs.pt")
val_graphs,_ = torch.load( r"D:\MIBI-TOFF\Scratch\val_full_r25_graphs.pt")
test_graphs,_ = torch.load( r"D:\MIBI-TOFF\Scratch\test_full_r25_graphs.pt")

print(fov_test)

train_loader = DataLoader(train_graphs, batch_size=100, shuffle=True)
test_loader = DataLoader(test_graphs, batch_size=100, shuffle=True)
val_loader = DataLoader(val_graphs, batch_size=100, shuffle=True)
del train_graphs, val_graphs, test_graphs

  train_graphs,fov_test =torch.load( r"D:\MIBI-TOFF\Scratch\train_full_r25_graphs.pt")
  val_graphs,_ = torch.load( r"D:\MIBI-TOFF\Scratch\val_full_r25_graphs.pt")
  test_graphs,_ = torch.load( r"D:\MIBI-TOFF\Scratch\test_full_r25_graphs.pt")


['FOV12', 'FOV36', 'FOV148', 'FOV86', 'FOV134', 'FOV82', 'FOV154', 'FOV234', 'FOV124', 'FOV24', 'FOV96', 'FOV98', 'FOV76', 'FOV176', 'FOV2', 'FOV104', 'FOV390', 'FOV152', 'FOV140', 'FOV106', 'FOV392', 'FOV42', 'FOV8', 'FOV118', 'FOV150', 'FOV146', 'FOV362', 'FOV420', 'FOV70', 'FOV288', 'FOV52', 'FOV64', 'FOV110', 'FOV108', 'FOV178', 'FOV190', 'FOV172', 'FOV364', 'FOV426', 'FOV400', 'FOV250', 'FOV88', 'FOV368', 'FOV22', 'FOV34', 'FOV116', 'FOV294', 'FOV144', 'FOV404', 'FOV298', 'FOV92', 'FOV184', 'FOV72', 'FOV304', 'FOV402', 'FOV232', 'FOV126', 'FOV342', 'FOV136', 'FOV366', 'FOV248', 'FOV114', 'FOV44', 'FOV80', 'FOV94', 'FOV84', 'FOV112', 'FOV222', 'FOV142', 'FOV138', 'FOV68', 'FOV132', 'FOV270', 'FOV428', 'FOV422', 'FOV90', 'FOV102', 'FOV74', 'FOV6', 'FOV284', 'FOV224', 'FOV306', 'FOV418', 'FOV226', 'FOV240', 'FOV412', 'FOV268', 'FOV424', 'FOV66', 'FOV50', 'FOV300', 'FOV286', 'FOV308', 'FOV10', 'FOV78', 'FOV302', 'FOV296', 'FOV252', 'FOV4', 'FOV62']


In [31]:
#Check Feature Dims Training
for i, data in enumerate(train_loader.dataset):
    print(f"Graph {i}:")
    print(f"  Node feature shape: {data.x.shape}")
    print(f"  Edge index shape: {data.edge_index.shape}")
    if data.edge_attr is not None:
        print(f"  Edge attribute shape: {data.edge_attr.shape}")

Graph 0:
  Node feature shape: torch.Size([35, 16])
  Edge index shape: torch.Size([2, 122])
Graph 1:
  Node feature shape: torch.Size([43, 16])
  Edge index shape: torch.Size([2, 154])
Graph 2:
  Node feature shape: torch.Size([52, 16])
  Edge index shape: torch.Size([2, 228])
Graph 3:
  Node feature shape: torch.Size([48, 16])
  Edge index shape: torch.Size([2, 214])
Graph 4:
  Node feature shape: torch.Size([45, 16])
  Edge index shape: torch.Size([2, 208])
Graph 5:
  Node feature shape: torch.Size([34, 16])
  Edge index shape: torch.Size([2, 112])
Graph 6:
  Node feature shape: torch.Size([44, 16])
  Edge index shape: torch.Size([2, 190])
Graph 7:
  Node feature shape: torch.Size([46, 16])
  Edge index shape: torch.Size([2, 194])
Graph 8:
  Node feature shape: torch.Size([50, 16])
  Edge index shape: torch.Size([2, 216])
Graph 9:
  Node feature shape: torch.Size([46, 16])
  Edge index shape: torch.Size([2, 168])
Graph 10:
  Node feature shape: torch.Size([41, 16])
  Edge index shap

In [32]:
model_name='full_GCN_12_channel'

# Model selection based on name
if any(x in model_name.lower() for x in ['gcn']):
    print('GCN')
    model = GraphConvClassifier(input_dim=len(expression_types)+4, hidden_dim=128, num_classes=2)
    
else:
    raise ValueError(f"Model type not recognized in model name: {model_name}")




criterion = torch.nn.CrossEntropyLoss()  # Define the classification criterion
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4,weight_decay=1e-5)  # Define the optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Set device to GPU (0)
model.to(device)
criterion = criterion.to(device)
print(torch.cuda.is_available(),device)
print(next(model.parameters()).device) 

GCN
True cuda:0
cuda:0


Ending MLFlow if an issue causes it to not close correctly. 

In [33]:
mlflow.end_run()


In [34]:
#Parameter Block
params_block={'location':r'D:\MIBI-TOFF\Scratch\DL_Results',
'epochs':2000,
'patience':200,
'delta':0.00000001,
'check_val_freq':5,
'num_classes':2,
'model_name':model_name,
'log_with_mlflow':True,
'mlflow_uri':"http://127.0.0.1:5000"}

In [35]:
batch = next(iter(train_loader))
print("Batch node features shape:", batch.x.shape)
print("Batch edge index shape:", batch.edge_index.shape)
if batch.edge_attr is not None:
    print("Batch edge attributes shape:", batch.edge_attr.shape)

Batch node features shape: torch.Size([3695, 16])
Batch edge index shape: torch.Size([2, 14648])


Comments from the BOSS- Markov Clustering on the cell types. broad clusters based on the cell types. 

In [36]:

graph_model_train.train_model(model, train_loader, val_loader, criterion, optimizer, device, location=params_block['location'], 
    epochs=params_block['epochs'], patience=params_block['patience'], delta=params_block['delta'], check_val_freq=params_block['check_val_freq'],
    num_classes=params_block['num_classes'], model_name=params_block['model_name'], log_with_mlflow=params_block['log_with_mlflow'], mlflow_uri=params_block['mlflow_uri'])

Started MLflow run with ID: ff32f8c8ae4c4a298ba453530c225ffe
Epoch 1, Train Loss: 0.6610, Train Acc: 59.83%, Train F1: 52.524620056152344
Epoch 2, Train Loss: 0.6208, Train Acc: 66.63%, Train F1: 58.327369689941406
Epoch 3, Train Loss: 0.5985, Train Acc: 68.58%, Train F1: 61.92155456542969
Epoch 4, Train Loss: 0.5828, Train Acc: 69.64%, Train F1: 63.30935287475586
Epoch 5, Train Loss: 0.5715, Train Acc: 70.39%, Train F1: 64.46150207519531
Epoch 5, Val Loss: 0.8602, Val Acc: 44.21%, Val F1: 50.4608268737793
Model saved with improved validation loss: 0.8602
Epoch 6, Train Loss: 0.5637, Train Acc: 71.05%, Train F1: 65.61943817138672
Epoch 7, Train Loss: 0.5560, Train Acc: 71.70%, Train F1: 66.42251586914062
Epoch 8, Train Loss: 0.5485, Train Acc: 72.50%, Train F1: 67.4530258178711
Epoch 9, Train Loss: 0.5436, Train Acc: 72.78%, Train F1: 67.81375122070312
Epoch 10, Train Loss: 0.5393, Train Acc: 73.01%, Train F1: 68.11458587646484
Epoch 10, Val Loss: 0.9299, Val Acc: 43.61%, Val F1: 49.57

In [None]:
model.load_state_dict(torch.load(os.path.join(params_block['location'], f"{params_block['model_name']}best_model.pth")))

avg_test_loss, test_metrics = graph_model_train.eval_model(model, test_loader, criterion, device, params_block['num_classes'], epoch=0)


print(f"Test Loss: {avg_test_loss:.4f}")
for metric_name, metric_value in test_metrics.items():
    print(metric_name,metric_value)
    #print(f"{metric_name}: {metric_value:.4f}")


In [None]:
# Evaluate the model using the eval_model function
avg_val_loss, val_metrics = graph_model_train.eval_model(model, val_loader, criterion, device, params_block['num_classes'], epoch=0)

# Print all the metrics
print(f"Test Loss: {avg_val_loss:.4f}")
for metric_val_name, metric_val_value in val_metrics.items():
    print(metric_name,metric_value)
    #print(f"{metric_name}: {metric_value:.4f}")

In [None]:
print(test_loader)
print(val_loader)
# Compare the contents of the two loaders
test_data = [data for data, _ in test_loader]
val_data = [data for data, _ in val_loader]

# Check if the lengths of the datasets are the same
if len(test_data) == len(val_data):
    print("The test and validation loaders have the same number of batches.")
else:
    print(f"The test loader has {len(test_data)} batches, while the validation loader has {len(val_data)} batches.")

# Compare the contents of the first batch in both loaders
if test_data and val_data:
    print("Comparing the first batch of test and validation loaders:")
    print("Test batch:", test_data[0])
    print("Validation batch:", val_data[0])
else:
    print("One of the loaders is empty.")
