In [1]:
import numpy as np
from cellpose import models, core, io, plot
from pathlib import Path
from tqdm import trange
import matplotlib.pyplot as plt
from natsort import natsorted
import tifffile
import pickle
from helpers.visualization_helper import extract_contours_from_mask, plot_image_with_clustered_contours_RGB
from helpers.features_helper import cluster_features, extract_features_and_contours
import json

imgs_20xRenamed = []
with open('save_data/3D_images_Renamed/imgs_20xRenamed.pkl', 'rb') as f:
    imgs_20xRenamed = pickle.load(f)
masks_20xRenamed = []
with open('save_data/masks/20xRenamed/masks3D_CELLPOSE_RUN_1.pkl', 'rb') as f:
    masks_20xRenamed = pickle.load(f)
final_classes = {}
with open("classes_0.json", "r") as f:
    final_classes = json.load(f)



Welcome to CellposeSAM, cellpose v
cellpose version: 	4.0.4 
platform:       	linux 
python version: 	3.11.0rc1 
torch version:  	2.6.0+cu124! The neural network component of
CPSAM is much larger than in previous versions and CPU excution is slow. 
We encourage users to use GPU/MPS if available. 




In [4]:
image_index = 13
slice = 15
mask = masks_20xRenamed[image_index][slice,:,:]
img = imgs_20xRenamed[image_index][slice, 0:2, :, :]
print(img.shape)
img_reshaped = img.transpose(1, 2, 0) 
print(img_reshaped.shape)
img_rgb = np.concatenate([img_reshaped.astype(np.uint8), np.zeros((1024, 1024, 1), dtype=np.uint8)], axis=2)
print(img_rgb.shape)


(2, 1024, 1024)
(1024, 1024, 2)
(1024, 1024, 3)


In [None]:
image_index = 0
mask = masks_20xRenamed[image_index]
classes = {}
slices = mask.shape[0]
for s in range(slices): 
    print(f"Processing slice {s} of {slices} of image {image_index}")
    contours_with_features = extract_features_and_contours(mask[s, :, :])
    contours_with_classes = cluster_features(contours_with_features, k=5)
    print("Number of contours with classes:", len(contours_with_classes))
    for label in contours_with_classes:
        if label == 0:
            continue
        obj_class = contours_with_classes[label]['class']
        
        if label not in classes:
            classes[label] = []
        classes[label].append(obj_class)

# Optionally, convert sets to lists if you need to serialize or inspect
#classes = {k: list(v) for k, v in classes.items()}
#print("Classes per object ID:")
#for obj_id, obj_classes in classes.items():
#    print(f"Object ID {obj_id}: Classes {obj_classes}")
#


In [78]:
import json
import random
from collections import Counter

final_classes = {}

for obj_id, obj_classes in classes.items():
    trimmed = obj_classes[:]
    
    # Keep trimming until â‰¤ 4 or no change
    while len(trimmed) > 4:
        trimmed = trimmed[1:-1]
        if len(trimmed) <= 4:
            break

    counter = Counter(trimmed)
    most_common = counter.most_common()

    if len(most_common) == 1:
        final_class = most_common[0][0]
    else:
        top_freq = most_common[0][1]
        tied_classes = [cls for cls, freq in most_common if freq == top_freq]
        
        if len(tied_classes) == 1:
            final_class = tied_classes[0]
        else:
            final_class = random.choice(tied_classes)
    
    final_classes[int(obj_id)] = final_class


# Save to JSON
with open("classes_0.json", "w") as f:
    json.dump(final_classes, f, indent=2)


In [12]:
    
with open("classes_0.json", "r") as f:
    final_classes = json.load(f)


In [4]:
print("Number of contours with features:", len(contours_with_features))
print("Number of contours with classes:", len(contours_with_classes))
#print("First contour with classes:", contours_with_classes[1]['class'])
#print("First contour with classes:", contours_with_classes[1].keys())
print(len(np.unique(mask)))
#for data in contours_with_classes:
#    print(contours_with_classes[data]['class'])
#    #print(data['class'])

NameError: name 'contours_with_features' is not defined

# Experiment with different models

In [None]:
from monai.networks.nets import DenseNet121, resnet18
import torch.nn as nn
from Classification.classifier_helper import ClassificatorBlobHelper

# Example with DenseNet
VERSION = 1
densenet_model = DenseNet121(spatial_dims=3, in_channels=2+VERSION, out_channels=5) 
blb = ClassificatorBlobHelper()
blb.version = VERSION

# Example with ResNet
resnet_model = resnet18(spatial_dims=3, n_input_channels=3, num_classes=5)



In [2]:
gt_classes = {}
with open("classes_0.json", "r") as f:
    gt_classes = json.load(f)

print(gt_classes["12"])

3


In [None]:
import torch
from torch.utils.data import DataLoader
from Classification.classifier_helper import ObjectPatchDataset, calculate_test_loss
from sklearn.model_selection import train_test_split
import numpy as np

from monai.transforms import (
    Compose, RandRotate90, RandFlip, RandGaussianNoise,
    EnsureType, EnsureChannelFirst
)

train_transforms = Compose([
    RandRotate90(prob=0.5, spatial_axes=(0, 1)),  # rotate in XY
    RandFlip(prob=0.5, spatial_axis=0),           # flip on Z
    RandGaussianNoise(prob=0.2, mean=0.0, std=0.01),
    EnsureType()  # Makes sure it's a torch.Tensor
])


image_index = 0 # potentially add more images as a list
label_indices = [int(i) for i in np.unique(masks_20xRenamed[image_index]) if i != 0]

#for image_index in range(len(masks_20xRenamed)):
#    all_indices = [(image_index, i) for i in label_indices]
all_indices = [(0, i) for i in label_indices]   
all_labels = [gt_classes[str(i)] for i in label_indices]

train_idx, val_idx, train_lbls, val_lbls = train_test_split(all_indices, all_labels, test_size=0.2, random_state=42)

train_dataset = ObjectPatchDataset(blb, train_idx, train_lbls, transform=train_transforms)
val_dataset = ObjectPatchDataset(blb, val_idx, val_lbls)  # no augmentation

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
print(f"Train dataset size: {len(train_dataset)}")
print(f"Train dataset 0, 0: {train_dataset[0][0].shape}")


Train dataset size: 1899
Train dataset 0, 0: torch.Size([3, 32, 64, 64])


In [None]:
import os
model_version = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

history = {"loss": [], "val_loss": [], "val_acc": []}
model = resnet_model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

checkpoint_dir = "save_data/checkpoints"

num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch_idx, (X, y) in enumerate(train_loader):
        X, y = X.to(device), y.to(device)

        # Debug shapes
        print(f"Epoch {epoch}, Batch {batch_idx}: X = {X.shape}, y = {y.shape}")

        # Forward
        pred = model(X)
        loss = criterion(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    val_loss, val_acc = calculate_test_loss(model, val_loader, criterion, device)
    print(f"[Epoch {epoch+1}/{num_epochs}]: Avg_loss = {avg_loss}, Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.4f}")
    avg_loss = running_loss / len(train_loader)
    history["loss"].append(avg_loss)
    history["val_loss"].append(val_loss)    
    history["val_acc"].append(val_acc)    
    #print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {avg_loss:.4f}")

    # ðŸ”½ Save model checkpoint (overwrite to save space)
    checkpoint_path = os.path.join(checkpoint_dir, f"model_v{model_version}.pt")
    torch.save({
        "epoch": epoch + 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": avg_loss,
    }, checkpoint_path)
    print(f"Saved checkpoint: {checkpoint_path}")


In [10]:
import gc
import torch

# Move model to CPU and delete
model = model.to("cpu")
del model

# Empty cache and collect garbage
torch.cuda.empty_cache()
gc.collect()


3096

In [81]:
history = {"loss": [], "val_loss": [], "val_acc": []}
history["loss"].append(1)
history["val_loss"].append(1)    
history["val_acc"].append(1)    
history["loss"].append(2)
history["val_loss"].append(2)    
history["val_acc"].append(2)    
history["loss"].append(3)
history["val_loss"].append(3)    
history["val_acc"].append(3)

import os
checkpoint_dir = "save_data/checkpoints"
history_path = os.path.join(checkpoint_dir, f"history_v{VERSION}.json")
import json
with open(history_path, "w") as f:
    json.dump(history, f)    

In [None]:
import matplotlib.pyplot as plt

plt.plot(history["loss"], label="Train Loss")
plt.plot(history["val_loss"], label="Val Loss")
plt.legend()
plt.title("Training Curve")


(1.5255147602683619, 0.36)


In [None]:
image_index = 0
mask = masks_20xRenamed[image_index]
img = imgs_20xRenamed[image_index][:, 0:2, :, :]
print(img.shape)
img_reshaped = img.transpose(1, 2, 3, 0) 
mask_reshaped = mask.transpose(1, 2, 0) 
print(img_reshaped.shape)
print(mask_reshaped.shape)
zeros = np.zeros((1, 1024, 1024, img_reshaped.shape[3]), dtype=np.uint8)
print(zeros.shape)
img_rgb = np.concatenate([img_reshaped.astype(np.uint8), zeros], axis=0)
print(img_rgb.shape)
img_rgb[2,:,:,:] = (mask_reshaped > 0) * 255

(28, 2, 1024, 1024)
(2, 1024, 1024, 28)
(1024, 1024, 28)
(1, 1024, 1024, 28)
(3, 1024, 1024, 28)


In [None]:
import importlib
import Classification.classifier_helper
importlib.reload(Classification.classifier_helper)
from Classification.classifier_helper import ClassificatorBlobHelper, preprocess_blob

blb = ClassificatorBlobHelper()
blob, oob = blb.get_blob(0,3)

self.blob.shape: (28, 1024, 1024)
final_blob.shape: (3, 64, 64, 28)


In [None]:
print(blob.shape)

(3, 64, 64, 28)


In [None]:
import matplotlib.pyplot as plt

blob, oob = blb.get_blob(0,4, in_channel_dist=False, gaus_exp_nuc=20.0)
plt.imshow(blob[(0,1,3),:,:,10].transpose(1, 2, 0))
plt.show()
plt.imshow(blob[0,:,:,10])
plt.show()
plt.imshow(blob[1,:,:,10])
plt.show()
plt.imshow(blob[2,:,:,10])
plt.show()
plt.imshow(blob[3,:,:,10])
plt.show()


In [None]:
blob, oob = blb.get_blob(0,4, in_channel_dist=True, gaus_exp_nuc=10.0, gaus_exp_myo=20.0)
#plt.imshow(blob[:,:,:,10].transpose(1, 2, 0))
#plt.show()
plt.imshow(blob[0,:,:,10])
plt.show()
plt.imshow(blob[1,:,:,10])
plt.show()
plt.imshow(blob[2,:,:,10])
plt.show()


In [1]:
import importlib
import classifier_helper #import SAMClassifier3D_CENTER_AWARE, get_resnet18_encoder, get_swinl_encoder, get_convnextxl_encoder, get_efficientnetv2l_encoder, get_resnet101_encoder

importlib.reload(classifier_helper)

<module 'classifier_helper' from '/Users/davidexler/Documents/Masterarbeit/repo/Masterarbeit/Classification/classifier_helper.py'>

In [1]:
from classifier_helper import SAMClassifier3D_CENTER_AWARE, get_resnet18_encoder, get_swinl_encoder, get_convnextxl_encoder, get_efficientnetv2l_encoder, get_resnet101_encoder
import torch
#enc = get_resnet18_encoder() #WORKING
#enc = get_resnet101_encoder() #WORKING
#enc = get_swinl_encoder() #WORKING
#enc = get_convnextxl_encoder() #WORKING
enc = get_efficientnetv2l_encoder() #WORKING

In [2]:
dim = 256
X_random_3D = torch.randn((1,3,32,dim,dim))
X_random_2D = torch.randn((1,3,dim,dim))
test_model = SAMClassifier3D_CENTER_AWARE(enc, 5, False, False)
cls = test_model(X_random_3D)
print(cls.shape)
#[DEBUG] Input shape: torch.Size([1, 3, 32, 256, 256])
#[DEBUG] Ecoder-Input shape: torch.Size([32, 3, 256, 256])
#[DEBUG] Ecoder-Output shape: torch.Size([32, 3, 256, 256])
#torch.Size([1, 5])

[DEBUG] Input shape: torch.Size([1, 3, 32, 256, 256])
[DEBUG] Ecoder-Input shape: torch.Size([32, 3, 256, 256])
[DEBUG] Ecoder-Output shape: torch.Size([32, 3, 256, 256])
torch.Size([1, 5])


In [4]:
print(test_model.out_feats.shape)

torch.Size([1, 64, 32, 8, 8])


SwinTransformer(
  (features): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 128, kernel_size=(4, 4), stride=(4, 4))
      (1): Permute()
      (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (1): Sequential(
      (0): SwinTransformerBlockV2(
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (attn): ShiftedWindowAttentionV2(
          (qkv): Linear(in_features=128, out_features=384, bias=True)
          (proj): Linear(in_features=128, out_features=128, bias=True)
          (cpb_mlp): Sequential(
            (0): Linear(in_features=2, out_features=512, bias=True)
            (1): ReLU(inplace=True)
            (2): Linear(in_features=512, out_features=4, bias=False)
          )
        )
        (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (0): Linear(in_features=128, out_features=512, bias=True)
          (1): GELU(appro