In [66]:
# %%
import pandas as pd
import numpy as np
import torch
import os
import random
import lightning.pytorch as pl
import sys
from pathlib import Path
sys.path.append(str(Path('../../').resolve()))
from utils import convnext, tools
from fastai.vision.all import *
from experiment_specific_utils import data_module, transforms

# %%

seed_value = 42

# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here only affects
# Python processes launched afterwards; the running kernel's str hashing is unchanged.
os.environ['PYTHONHASHSEED'] = str(seed_value)

# Seed every RNG the notebook may touch (Python, NumPy, torch CPU, torch CUDA), in order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(seed_value)

# %%
# Per-image metadata table; the columns used below are `set`, `condition` and `label`.
metadata = pd.read_csv(
    "/home/jedrzej/projects/image_flow_cytometry_fine_tune/data/jedrzej/metadata_subset.csv.gz"
)
# Bare expression: only rendered when this `# %%` section runs as its own cell.
metadata

# %%
metadata["set"].unique()

# %%
# Keep only the -SEA / +SEA conditions and renumber the rows 0..n-1.
indx = metadata["condition"].isin(["-SEA", "+SEA"])
metadata = metadata[indx].reset_index(drop=True)

# %%
# Classes the model is trained on; their list position is the integer class id.
set_of_interesting_classes = ['B_cell',  'T_cell', 
                        'T_cell_with_signaling',
                        'T_cell_with_B_cell_fragments',
                        'B_T_cell_in_one_layer',
                        'Synapses_without_signaling', 
                        'Synapses_with_signaling',
                        'No_cell_cell_interaction', 
                        'Multiplets'] 


def _split_index(frame, split_name, classes):
    """Return the index labels of rows whose `set` equals `split_name` and whose `label` is in `classes`."""
    mask = (frame["set"] == split_name) & frame["label"].isin(classes)
    return mask[mask].index


# Combined mask over all three splits; no later cell shown here reads it.
indx = metadata.set.isin(["train", "validation", "test"]) & metadata.label.isin(set_of_interesting_classes)

# One helper instead of three copy-pasted filters keeps the splits guaranteed consistent.
train_index = _split_index(metadata, "train", set_of_interesting_classes)
validation_index = _split_index(metadata, "validation", set_of_interesting_classes)
test_index = _split_index(metadata, "test", set_of_interesting_classes)

# %%
metadata["set"].unique()

# %%
# Class name -> integer id, following the order of set_of_interesting_classes.
label_map = {cl: i for i, cl in enumerate(set_of_interesting_classes)}
# Sentinel entries: both the string '-1' and the int -1 map to -1
# (presumably for unlabeled rows — TODO confirm).
label_map.update({'-1': -1, -1: -1})

# %%
# Per imaging channel: (presumably a matplotlib colormap name, marker / display label).
# Not referenced by the code shown in this notebook.
channels = dict(
    Ch1=("Greys", "BF"),
    Ch2=("Greens", "Antibody"),
    Ch3=("Reds", "CD18"),
    Ch4=("Oranges", "F-Actin"),
    Ch6=("RdPu", "MHCII"),
    Ch7=("Purples", "CD3/CD4"),
    Ch11=("Blues", "P-CD3zeta"),
    Ch12=("Greens", "Live-Dead"),
)

# %%
# Channel positions fed to the network; len(selected_channels) becomes the model's in_chans.
selected_channels = [0, 3, 4, 5, 6]
model_dir, log_dir = "models", "logs"
# 4095 == 2**12 - 1, presumably the maximum raw intensity of 12-bit images — TODO confirm.
scaling_factor = 4095.
reshape_size = 256

train_transform = transforms.train_transform_fit_image(
    reshape_size,
    include_normalization=True,
)
# NOTE(review): the train transform is built with include_normalization=True but the
# test/val transform is not given that flag — confirm both normalize identically.
test_val_transform = transforms.test_val_transform_fit_image(reshape_size)

# %%
# Optimisation settings passed to the data module (batch_size) and ConvnextModel (lr, max_epochs).
lr, batch_size, max_epochs = 4e-4, 32, 50

# %%
print("Available cuda memory before model initialization: ")
tools.print_cuda_memory()

synapse_formation_module = data_module.SynapseFormationDataModule(
    metadata,
    train_index,
    validation_index,
    test_index,
    label_map,
    selected_channels,
    train_transform,
    # NOTE(review): passed twice, presumably as the validation and test transforms — confirm.
    test_val_transform,
    test_val_transform,
    batch_size,
    reshape_size,
)
synapse_formation_module.setup(stage='fit')

train_loader = synapse_formation_module.train_dataloader()
val_loader = synapse_formation_module.val_dataloader()

# steps_per_epoch is derived from the training loader, so the loaders must exist first.
model = convnext.ConvnextModel(
    num_classes=len(set_of_interesting_classes),
    in_chans=len(selected_channels),
    steps_per_epoch=len(train_loader),
    learning_rate=lr,
    max_epochs=max_epochs,
)


  metadata = pd.read_csv("/home/jedrzej/projects/image_flow_cytometry_fine_tune/data/jedrzej/metadata_subset.csv.gz")


Available cuda memory before model initialization: 
Device 0:
  Allocated Memory: 772.33 MB
  Reserved Memory: 1658.00 MB
  Free Memory: 13259.69 MB
  Total Memory: 14917.69 MB
Initializing datasets...
Datasets initialized successfully!


In [67]:
checkpoint_path = "/home/jedrzej/projects/image_flow_cytometry_fine_tune/6 - Appsilon/1 - Machine learning/Fine-tuning ConvNext, new Augments/.neptune/IM-58/IM-58/checkpoints/epoch=49-step=4600.ckpt"
# The checkpoint is a dict whose 'state_dict' entry holds the model weights (presumably
# alongside other Lightning training metadata). Pass weights_only explicitly rather than
# relying on torch's default, which newer releases flip to True (that can reject such
# checkpoints); the warning echoed under this cell is presumably about that default.
# SECURITY: weights_only=False unpickles arbitrary objects — only load trusted files.
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'), weights_only=False)
model.load_state_dict(checkpoint['state_dict'])

  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


<All keys matched successfully>

In [68]:
# Switch dropout/normalization layers to inference behaviour; returns the model, so its repr is shown.
model.train(False)

ConvnextModel(
  (model): ConvNeXt(
    (stem): Sequential(
      (0): Conv2d(5, 128, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((128,), eps=1e-06, elementwise_affine=True)
    )
    (stages): Sequential(
      (0): ConvNeXtStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): ConvNeXtBlock(
            (conv_dw): Conv2d(128, 128, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=128)
            (norm): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=128, out_features=512, bias=True)
              (act): GELU()
              (drop1): Dropout(p=0.0, inplace=False)
              (norm): Identity()
              (fc2): Linear(in_features=512, out_features=128, bias=True)
              (drop2): Dropout(p=0.0, inplace=False)
            )
            (shortcut): Identity()
            (drop_path): Identity()
          )
          (1): ConvNeXtBlock(
            (conv_

In [69]:
# Extract embeddings
def extract_embeddings(nn_model, data_loader, device=None):
    """Run `nn_model` over `data_loader` and collect backbone features and labels.

    Features come from ``nn_model.model.forward_features``. When that returns a 4-D
    tensor it is reduced with the head's ``global_pool``, ``norm`` and ``flatten``
    sub-modules; the classifier's ``fc`` layer is never called.

    Args:
        nn_model: module exposing ``.model.forward_features`` and
            ``.model.head.{global_pool, norm, flatten}``. It is left in eval mode.
        data_loader: iterable of ``(inputs, labels)`` batches; ``labels`` must
            support ``.cpu().numpy()``.
        device: device the inputs are moved to. Defaults to the device of the
            model's first parameter, so inputs always match the model; a model
            without parameters falls back to CUDA if available, else CPU.

    Returns:
        Tuple ``(embeddings, labels)``: per-batch embeddings stacked with
        ``np.vstack`` and labels concatenated with ``np.hstack``.

    Raises:
        ValueError: if `data_loader` yields no batches.
    """
    if device is None:
        first_param = next(nn_model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    nn_model.eval()
    embeddings = []
    labels = []
    with torch.no_grad():
        for inputs, batch_labels in data_loader:
            inputs = inputs.to(device)
            batch_embeddings = nn_model.model.forward_features(inputs)
            if batch_embeddings.dim() == 4:
                # Spatial feature map: apply the head's pooling stages only (no fc).
                head = nn_model.model.head
                batch_embeddings = head.global_pool(batch_embeddings)
                batch_embeddings = head.norm(batch_embeddings)
                batch_embeddings = head.flatten(batch_embeddings)
            embeddings.append(batch_embeddings.cpu().numpy())
            labels.append(batch_labels.cpu().numpy())

    if not embeddings:
        raise ValueError("data_loader yielded no batches; nothing to embed")
    return np.vstack(embeddings), np.hstack(labels)
# Loader over the test split, used for embedding extraction.
# NOTE(review): only setup(stage='fit') was called on the data module; confirm that call
# also initialises the test dataset (Lightning data modules often expect stage='test').
test_loader = synapse_formation_module.test_dataloader()

In [70]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the model to the device computed above, the same device extract_embeddings
# targets; hard-coding "cuda" crashed on CPU-only machines.
model.to(device)
# Tuple (embeddings, labels) as returned by extract_embeddings.
a = extract_embeddings(model, test_loader)

In [71]:
embeddings, labels = a
# One integer-named column per embedding dimension, followed by the class-id column.
embeddings_df = pd.DataFrame(embeddings).assign(label=labels)

In [72]:
# Inspect the classifier head: extract_embeddings reuses its global_pool / norm / flatten
# sub-modules but never calls the final fc layer.
model.model.head

NormMlpClassifierHead(
  (global_pool): SelectAdaptivePool2d(pool_type=avg, flatten=Identity())
  (norm): LayerNorm2d((1024,), eps=1e-06, elementwise_affine=True)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (pre_logits): Identity()
  (drop): Dropout(p=0, inplace=False)
  (fc): Linear(in_features=1024, out_features=9, bias=True)
)

In [73]:
# Preview the embedding table (pandas truncates the rendered rows/columns of large frames).
embeddings_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,1015,1016,1017,1018,1019,1020,1021,1022,1023,label
0,-0.805928,0.157662,0.576173,-1.388507,-0.532080,-0.541615,0.370675,-0.245025,-0.098967,-0.168570,...,0.526800,0.135303,-0.742776,-0.300688,0.490303,-0.122583,-0.406470,1.017134,0.422110,7
1,-1.231208,0.220913,0.069007,-1.778880,-0.581510,-0.988908,0.896636,-0.943685,0.320128,0.364289,...,0.653211,-0.412481,-0.439893,-0.611772,0.721221,-0.370233,-0.130313,0.814041,0.189012,7
2,-0.789709,0.021753,-0.269214,0.156545,0.484728,0.968309,-1.113086,-1.027873,-0.020747,0.834888,...,0.338997,0.428002,0.894622,-0.018873,0.605495,-0.928183,-0.155133,-0.315386,-0.502222,1
3,0.428416,0.991594,1.584693,1.542397,0.868679,0.373881,-0.386599,0.750486,-0.251176,-0.739195,...,0.376323,-0.744421,0.597434,-0.288840,-1.622254,0.041113,-0.150973,-1.568519,-0.433807,3
4,-1.588597,-0.093517,-0.193951,-1.536536,-0.823488,-0.815557,0.410563,-0.760440,0.172343,0.493134,...,1.359203,-0.409685,-0.396260,-0.475106,0.795944,-0.686296,0.292118,0.829582,0.322555,7
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1562,0.602211,-0.651645,0.424338,-0.409297,0.236796,0.232826,-0.759870,-0.214824,0.058799,-0.242292,...,-0.482324,0.198505,1.190233,0.132923,-0.917505,-0.232920,-0.175433,-0.690463,1.159616,0
1563,0.618455,0.636236,-0.199383,-0.415231,0.032799,0.348827,-0.299616,1.382883,-0.527806,0.349540,...,0.434015,0.590330,-0.245880,0.366202,0.125456,0.212768,0.664973,1.908127,1.251918,6
1564,-0.395108,-1.858959,0.874642,0.097085,-0.252257,1.345946,-0.161412,0.073800,0.019879,-1.262412,...,-0.261453,0.560484,-2.219505,-0.059975,0.185377,-0.446499,1.517568,0.045162,-0.524976,4
1565,1.171908,-0.349277,-0.423955,-0.416062,-1.117623,-0.213249,-0.573584,-0.241543,1.125976,-0.368923,...,-0.755614,-0.526437,0.245659,-1.298746,-0.052835,-0.404227,0.097928,-0.441783,0.182905,8


In [None]:
# to_parquet expects string column names, so stringify the integer embedding columns.
# NOTE: this renames embeddings_df's columns in place for the rest of the session.
embeddings_df.columns = list(map(str, embeddings_df.columns))
embeddings_df.to_parquet("test_embeddings_labels_convnext.parquet", index=False)