In [1]:
#Needed installs
#%pip install scikit-learn

In [2]:
import polars as pl
from mibi_dataset import MibiDataset
from sklearn.model_selection import train_test_split,GroupShuffleSplit
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import torch


In [3]:
full_cell_sheet=pl.read_csv(r"D:\MIBI-TOFF\Data_For_Amos\cleaned_expression_with_both_classification_prob_spatial_30_08_24.csv")
print(full_cell_sheet.columns)
cell_frame = full_cell_sheet[['label', 'fov','pred','Group','patient number']]
unique_fovs = cell_frame.unique(subset='fov')
print(unique_fovs.head(5))


['Unnamed: 0', 'cell_size', '128Te', '129Xe', '12C', '130Xe', '131Xe', '132Xe', '137Ba', '138Ba', '139La', '140Ce', '181Ta', '182Empty', '197Au', '23Na', '24Mg', '25Mg', '27Al', '28Si', '31P', '32Si', '39K', '40Ca', '41K', '56Fe', '64Zn', '66Zn', 'Alexa Fluor 488', 'Bax', 'CCR7', 'CD11c', 'CD14', 'CD163', 'CD20', 'CD206', 'CD21', 'CD3', 'CD31', 'CD4', 'CD45', 'CD45RA', 'CD45RO', 'CD56', 'CD68', 'CD69', 'CD8', 'COL1A1', 'DC-SIGN', 'Foxp3', 'Granzyme B', 'HLA-DR-DP-DQ', 'HLA-class-1-A-B-C', 'IDO-1', 'Ki67', 'LAG-3', 'MECA-79', 'MelanA', 'PD-1', 'S100A9-Calprotectin', 'SMA', 'SOX10', 'TCF1TCF7', 'TIM-3', 'Tox-Tox2', 'anti-Biotin', 'dsDNA', 'label', 'area', 'eccentricity', 'major_axis_length', 'minor_axis_length', 'perimeter', 'convex_area', 'equivalent_diameter', 'centroid-0', 'centroid-1', 'major_minor_axis_ratio', 'perim_square_over_area', 'major_axis_equiv_diam_ratio', 'convex_hull_resid', 'centroid_dif', 'num_concavities', 'fov', 'pred', 'pred_prob', 'class', 'score', 'spatial', 'Grou

In [4]:
splitter=GroupShuffleSplit(n_splits=1, test_size=0.25, random_state=1995)
for train_idx, val_idx in splitter.split(unique_fovs, groups=unique_fovs['patient number']):
    train_data = unique_fovs[train_idx]
    val_data = unique_fovs[val_idx]
    print("Training Set:")
    print(train_data.head(5))
    print("\nValidation Set:")
    print(val_data.head(5))

Training Set:
shape: (5, 5)
┌───────┬────────┬───────────────┬───────┬────────────────┐
│ label ┆ fov    ┆ pred          ┆ Group ┆ patient number │
│ ---   ┆ ---    ┆ ---           ┆ ---   ┆ ---            │
│ f64   ┆ str    ┆ str           ┆ str   ┆ i64            │
╞═══════╪════════╪═══════════════╪═══════╪════════════════╡
│ 1.0   ┆ FOV356 ┆ CD4 APC       ┆ G4    ┆ 118            │
│ 1.0   ┆ FOV400 ┆ CD4 APC       ┆ G4    ┆ 94             │
│ 1.0   ┆ FOV142 ┆ CD4 APC       ┆ G2    ┆ 45             │
│ 1.0   ┆ FOV178 ┆ Collagen      ┆ G4    ┆ 90             │
│ 1.0   ┆ FOV64  ┆ blood vessels ┆ G3    ┆ 82             │
└───────┴────────┴───────────────┴───────┴────────────────┘

Validation Set:
shape: (5, 5)
┌───────┬────────┬────────────────────┬───────┬────────────────┐
│ label ┆ fov    ┆ pred               ┆ Group ┆ patient number │
│ ---   ┆ ---    ┆ ---                ┆ ---   ┆ ---            │
│ f64   ┆ str    ┆ str                ┆ str   ┆ i64            │
╞═══════╪════════╪═══

In [31]:
import os
root_dir=data_path=r'D:\MIBI-TOFF\Data_For_Amos'
df=unique_fovs
prefix='FOV'
fov_col='fov'
label_col='Group'
image_paths=[]

# The loop moves through the predefined file structure of the MIBI dataset being used to identify relevent image samples while filtering out files that we do not want
#Theoretically provides 39 expressions but some patients have varying amounts
for fov_dir in os.listdir(root_dir):
            if fov_dir.startswith(prefix):#replace with Pathlib matching later
                fov_path = os.path.join(root_dir, fov_dir)
                
                if os.path.isdir(fov_path):
                    #print(fov_dir)
                    # check if the folder name exists in the "FOV" column of the DataFrame
                    group = df.filter(pl.col(fov_col) == fov_dir)  #why oh why did I decide to decide to use polars instead of pandas for this test?
                    if group.height>0:
                        # pull the matching data from the "group" column
                        
                        group_data = group[label_col].to_list()[0]  
                        #print(group_data,binarized_data)
                        tif_path = os.path.join(fov_path, 'TIFs')
                        if os.path.exists(tif_path):
                            # Load all .tiff files ignoring those with "segmentation in the name"
                            sublist=[]
                            for image_file in os.listdir(tif_path):
                                if (image_file.endswith('.tif') or image_file.endswith('.tiff')) and ('segmentation' not in image_file) and not (image_file[0].isdigit()):
                                    sublist.append(os.path.join(tif_path, image_file))
                            

                            sublist.sort()#Should normalize the data assuming the data format is maintained true for this dataset 
                            #should create a perminant mapping
                            image_paths.append(sublist)
                        num_sublists = len(sublist)



In [32]:
#Find the expressions that are not in every file should be equal to the number of fovs 177 in this sample
expressions={}
for fov in image_paths:
    for path in fov:
        filename = os.path.basename(path)
        if filename not in expressions:
            expressions[filename]=1
        else:
            expressions[filename]+=1
print(expressions)


{'Alexa Fluor 488.tif': 177, 'Bax.tif': 177, 'CCR7.tif': 177, 'CD11c.tif': 177, 'CD14.tif': 177, 'CD163.tif': 177, 'CD20.tif': 177, 'CD206.tif': 177, 'CD21.tif': 177, 'CD3.tif': 177, 'CD31.tif': 177, 'CD4.tif': 177, 'CD45.tif': 177, 'CD45RA.tif': 177, 'CD45RO.tif': 177, 'CD56.tif': 177, 'CD68.tif': 177, 'CD69.tif': 177, 'CD8.tif': 177, 'COL1A1.tif': 177, 'DC-SIGN.tif': 177, 'Foxp3.tif': 177, 'Granzyme B.tif': 177, 'HLA-DR-DP-DQ.tif': 177, 'HLA-class-1-A-B-C.tif': 177, 'IDO-1.tif': 177, 'Ki67.tif': 177, 'LAG-3.tif': 177, 'MECA-79.tif': 177, 'MelanA.tif': 177, 'PD-1.tif': 177, 'S100A9-Calprotectin.tif': 177, 'SMA.tif': 177, 'SOX10.tif': 177, 'TCF1TCF7.tif': 177, 'TIM-3.tif': 177, 'Tox-Tox2.tif': 177, 'anti-Biotin.tif': 177, 'dsDNA.tif': 177, 'dsDNA-enhanced.tif': 9, 'dsDNA_cont.tif': 6, 'mask_fov_164.tif': 1, 'CD4_smooth.tif': 3, 'Foxp3_smooth.tif': 1, 'CD14_smooth.tif': 1, 'CD31_smooth.tif': 1, 'CD8_smooth.tif': 2, 'MelanA_enhanced.tif': 1, 'MelanA_smooth.tif': 1, 'S100A9-Calprotectin_s

In [33]:
threshold=117
consistent_expressions=[key for key, value in expressions.items() if value > threshold]
#I have no idea why > is needed to make this work
print(len(consistent_expressions))
print(consistent_expressions)

39
['Alexa Fluor 488.tif', 'Bax.tif', 'CCR7.tif', 'CD11c.tif', 'CD14.tif', 'CD163.tif', 'CD20.tif', 'CD206.tif', 'CD21.tif', 'CD3.tif', 'CD31.tif', 'CD4.tif', 'CD45.tif', 'CD45RA.tif', 'CD45RO.tif', 'CD56.tif', 'CD68.tif', 'CD69.tif', 'CD8.tif', 'COL1A1.tif', 'DC-SIGN.tif', 'Foxp3.tif', 'Granzyme B.tif', 'HLA-DR-DP-DQ.tif', 'HLA-class-1-A-B-C.tif', 'IDO-1.tif', 'Ki67.tif', 'LAG-3.tif', 'MECA-79.tif', 'MelanA.tif', 'PD-1.tif', 'S100A9-Calprotectin.tif', 'SMA.tif', 'SOX10.tif', 'TCF1TCF7.tif', 'TIM-3.tif', 'Tox-Tox2.tif', 'anti-Biotin.tif', 'dsDNA.tif']


In [34]:
image_paths=[]
for fov_dir in os.listdir(root_dir):
            if fov_dir.startswith(prefix):#replace with Pathlib matching later
                fov_path = os.path.join(root_dir, fov_dir)
                
                if os.path.isdir(fov_path):
                    #print(fov_dir)
                    # check if the folder name exists in the "FOV" column of the DataFrame
                    group = df.filter(pl.col(fov_col) == fov_dir)  #why oh why did I decide to decide to use polars instead of pandas for this test?
                    if group.height>0:
                        # pull the matching data from the "group" column
                        
                        group_data = group[label_col].to_list()[0]  
                        #print(group_data,binarized_data)
                        tif_path = os.path.join(fov_path, 'TIFs')
                        if os.path.exists(tif_path):
                            # Load all .tiff files ignoring those with "segmentation in the name"
                            sublist=[]
                            for image_file in os.listdir(tif_path):
                                if consistent_expressions:# this is a list of expression names we want to make sure are in the dataset. 
                                    if image_file in consistent_expressions:
                                        sublist.append(os.path.join(tif_path, image_file))

                            sublist.sort()#Should normalize the data assuming the data format is maintained true for this dataset 
                            #should create a perminant mapping
                            image_paths.append(sublist)
                        num_sublists = len(sublist)



In [35]:
expressions_test={}
for fov in image_paths:
    for path in fov:
        filename = os.path.basename(path)
        if filename not in expressions_test:
            expressions_test[filename]=1
        else:
            expressions_test[filename]+=1
print(expressions_test)


{'Alexa Fluor 488.tif': 177, 'Bax.tif': 177, 'CCR7.tif': 177, 'CD11c.tif': 177, 'CD14.tif': 177, 'CD163.tif': 177, 'CD20.tif': 177, 'CD206.tif': 177, 'CD21.tif': 177, 'CD3.tif': 177, 'CD31.tif': 177, 'CD4.tif': 177, 'CD45.tif': 177, 'CD45RA.tif': 177, 'CD45RO.tif': 177, 'CD56.tif': 177, 'CD68.tif': 177, 'CD69.tif': 177, 'CD8.tif': 177, 'COL1A1.tif': 177, 'DC-SIGN.tif': 177, 'Foxp3.tif': 177, 'Granzyme B.tif': 177, 'HLA-DR-DP-DQ.tif': 177, 'HLA-class-1-A-B-C.tif': 177, 'IDO-1.tif': 177, 'Ki67.tif': 177, 'LAG-3.tif': 177, 'MECA-79.tif': 177, 'MelanA.tif': 177, 'PD-1.tif': 177, 'S100A9-Calprotectin.tif': 177, 'SMA.tif': 177, 'SOX10.tif': 177, 'TCF1TCF7.tif': 177, 'TIM-3.tif': 177, 'Tox-Tox2.tif': 177, 'anti-Biotin.tif': 177, 'dsDNA.tif': 177}


In [16]:
import importlib
import model_utils
from mibi_dataset import MibiDataset
from torch.utils.data import DataLoader


data_path=r'D:\MIBI-TOFF\Data_For_Amos'
train_dataset=MibiDataset(root_dir=data_path,df=train_data,prefix='FOV',fov_col='fov',label_col='Group', transform=None)
train_loader=DataLoader(dataset=train_dataset,batch_size=1,shuffle=True)

val_dataset=MibiDataset(root_dir=data_path,df=val_data,prefix='FOV',fov_col='fov',label_col='Group', transform=None)
val_loader=DataLoader(dataset=train_dataset,batch_size=1,shuffle=True)

In [43]:
print(600/24)


25.0


In [17]:
import model_utils
from mibi_vit import ViTBinaryClassifier
model = ViTBinaryClassifier(img_size_x=128, img_size_y=128, in_channels=39, num_classes=2, patch_size_x=16, patch_size_y=16, embed_dim=768, num_heads=4, depth=6, mlp_dim=3072)
criterion = torch.nn.CrossEntropyLoss()  # Define the classification criterion
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Define the optimizer
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # Set device to GPU (0)

model_utils.train_model(model, train_loader, val_loader, criterion, optimizer, device, epochs=5)




tensor([0])
('D:\\MIBI-TOFF\\Data_For_Amos\\FOV274\\TIFs\\Alexa Fluor 488.tif',)
tensor([1])
('D:\\MIBI-TOFF\\Data_For_Amos\\FOV354\\TIFs\\Alexa Fluor 488.tif',)
tensor([0])
('D:\\MIBI-TOFF\\Data_For_Amos\\FOV304\\TIFs\\Alexa Fluor 488.tif',)
tensor([0])
('D:\\MIBI-TOFF\\Data_For_Amos\\FOV392\\TIFs\\Alexa Fluor 488.tif',)
tensor([1])
('D:\\MIBI-TOFF\\Data_For_Amos\\FOV412\\TIFs\\Alexa Fluor 488.tif',)


RuntimeError: Given groups=1, weight of size [768, 39, 16, 16], expected input[1, 40, 128, 128] to have 39 channels, but got 40 channels instead