In [None]:
import os
import copy
import time
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import h5py
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.cuda.amp import autocast, GradScaler
from torchvision import transforms
from torchvision.ops import sigmoid_focal_loss

from tqdm import tqdm
from PIL import Image
from itertools import cycle

seed = 42
# Seed every random number generator the notebook relies on (Python, NumPy and
# PyTorch) so that shuffling, augmentation and weight initialisation repeat
# exactly after "Restart Kernel -> Run All".
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # seeds the CPU generator and all CUDA devices

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
# Identifiers that appear in the "emd_<id>_<slice>.png" file names used below.
used_ids = [str(entry) for entry in (28668, 28946, 29074, 29080)]

In [None]:
# Slice indices 5, 10, ..., 120 form the labelled training set; every other
# index in 1..128 is shuffled and divided into unlabelled / val / test.
labelled_indices = list(range(5, 121, 5))
total_ids = [i for i in range(1, 129) if i not in labelled_indices]

# Re-seed right before shuffling so the split does not depend on earlier cells.
random.seed(seed)
random.shuffle(total_ids)

train_size = int(0.7 * len(total_ids)) - len(labelled_indices)
val_size = int(0.15 * len(total_ids))

# The same slice indices are used for every entry in `used_ids`; file names are
# generated entry by entry, matching "emd_<id>_<slice>.png".
labelled_ids = [f"emd_{ids}_{i}.png" for ids in used_ids for i in labelled_indices]
unlabelled_ids = [f"emd_{ids}_{i}.png" for ids in used_ids for i in total_ids[:train_size]]
val_ids = [
    f"emd_{ids}_{i}.png"
    for ids in used_ids
    for i in total_ids[train_size:train_size + val_size]
]
test_ids = [
    f"emd_{ids}_{i}.png"
    for ids in used_ids
    for i in total_ids[train_size + val_size:]
]

In [None]:
import torchvision.transforms.v2 as transforms

# Augmentations used for the training datasets: random horizontal / vertical
# flips followed by a random rotation of up to +/-20 degrees. A v2 transform
# draws one set of random parameters per call and applies it to every
# image-like input of that call, which is what keeps image and mask aligned.
# NOTE(review): v2 only treats several inputs jointly when they are tv_tensors
# (or PIL images); with plain torch.Tensor arguments only the first tensor is
# transformed and the others are passed through unchanged. Confirm that the
# dataset wraps the mask accordingly.
aug_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(degrees=20),
    # Disabled experiment: random translation of up to 10 % per axis.
    # transforms.RandomAffine(
    #     degrees=0,
    #     translate=(0.1, 0.1),
    #     interpolation=transforms.InterpolationMode.NEAREST
    # ),
])

In [None]:
class SegmentationDataset(Dataset):
    """Grayscale image slices with optional binary masks, prepared for the backbone.

    Each image is read as 8-bit grayscale, scaled to [0, 1], optionally
    augmented, replicated to three channels, normalised with the ImageNet
    mean / std and resized to 448x448 with bicubic interpolation.

    Args:
        image_dir: Directory that contains the image files.
        mask_dir: Directory that contains the mask files, or ``None`` for an
            unlabelled dataset (``__getitem__`` then returns the image only).
        image_filenames: Image file names relative to ``image_dir``.
        mask_filenames: Mask file names *without* the ``"seg_"`` prefix, which
            is added when the mask is opened.
        transform: Optional callable. For labelled data it is called as
            ``transform(image, mask)`` and must return the transformed pair;
            for unlabelled data it is called as ``transform(image)``.

    Returns (per item):
        ``(image, mask)`` for labelled data, ``image`` otherwise. ``image`` is a
        float32 tensor of shape (3, 448, 448). ``mask`` is a float32 tensor of
        shape (1, H, W) with values in {0, 1}; it is not resized.
    """

    def __init__(self, image_dir, mask_dir, image_filenames, mask_filenames, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = image_filenames
        self.mask_filenames = mask_filenames
        self.transform = transform

    def __len__(self):
        return len(self.image_filenames)

    def _transform_pair(self, image_tensor, mask_tensor):
        """Apply ``self.transform`` to image and mask with shared random parameters.

        torchvision's v2 transforms only transform the first plain tensor they
        receive and pass the others through untouched, so the mask would stay
        un-augmented. Wrapping the inputs as ``Image`` / ``Mask`` makes the
        transform flip and rotate both identically (nearest-neighbour for the mask).
        """
        try:
            from torchvision import tv_tensors
        except ImportError:  # torchvision 0.15 exposed the same classes as `datapoints`
            from torchvision import datapoints as tv_tensors

        image_out, mask_out = self.transform(
            tv_tensors.Image(image_tensor), tv_tensors.Mask(mask_tensor)
        )
        # Hand plain tensors to the rest of the pipeline and to the collate function.
        return image_out.as_subclass(torch.Tensor), mask_out.as_subclass(torch.Tensor)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_filenames[idx])
        image = Image.open(img_path).convert("L")
        image_tensor = torch.from_numpy(np.array(image, dtype=np.float32) / 255.0).unsqueeze(0)  # (1, H, W)

        mask_tensor = None
        if self.mask_dir is not None:
            mask_path = os.path.join(self.mask_dir, "seg_" + self.mask_filenames[idx])
            mask = np.array(Image.open(mask_path).convert("L"), dtype=np.float32) / 255.0
            # Binarise and keep float32 (np.where with Python floats would
            # silently promote the mask to float64).
            mask = (mask > 0.5).astype(np.float32)
            mask_tensor = torch.from_numpy(mask).unsqueeze(0)  # (1, H, W)
            if self.transform is not None:
                image_tensor, mask_tensor = self._transform_pair(image_tensor, mask_tensor)
        elif self.transform is not None:
            image_tensor = self.transform(image_tensor)

        # Replicate the single gray channel so the RGB-pretrained backbone can consume it.
        bw_image = torch.cat([image_tensor, image_tensor, image_tensor], dim=0)

        # ImageNet channel statistics.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        rgb_image = (bw_image - mean) / std

        # F.interpolate expects a batch axis; add it for the resize and drop it again.
        rgb_image_interp = F.interpolate(
            rgb_image.unsqueeze(0), size=(448, 448), mode='bicubic', align_corners=False
        ).squeeze(0)

        if self.mask_dir is not None:
            return rgb_image_interp, mask_tensor
        return rgb_image_interp

In [None]:
def get_loaders(image_dir, mask_dir, labelled_files, unlabelled_files, val_files, test_files, batch_size, aug_transform, transform=None):
    """Build the four data loaders used by the notebook.

    Args:
        image_dir: Directory with the input images.
        mask_dir: Directory with the segmentation masks.
        labelled_files / unlabelled_files / val_files / test_files: File name
            lists for the respective subsets (used for images and masks alike).
        batch_size: Batch size shared by all loaders.
        aug_transform: Transform applied to both training subsets.
        transform: Transform applied to the validation and test subsets.

    Returns:
        ``(train_labelled_loader, train_unlabelled_loader, val_loader, test_loader)``.
        The training loaders shuffle; validation and test do not. The
        unlabelled dataset is created without a mask directory.
    """
    def _make_loader(files, masks_root, tfm, shuffle):
        dataset = SegmentationDataset(image_dir, masks_root, files, files, transform=tfm)
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1)

    return (
        _make_loader(labelled_files, mask_dir, aug_transform, True),
        _make_loader(unlabelled_files, None, aug_transform, True),
        _make_loader(val_files, mask_dir, transform, False),
        _make_loader(test_files, mask_dir, transform, False),
    )

In [None]:
# Instantiate the loaders for the Kaggle dataset. Validation and test use no
# transform because `transform` is left at its default of None.
train_labelled_loader, train_unlabelled_loader, val_loader, test_loader = get_loaders(
    image_dir="/kaggle/input/cryovit-data/tomogram_images",
    mask_dir="/kaggle/input/cryovit-data/segmentation_mask_images",
    labelled_files=labelled_ids,
    unlabelled_files=unlabelled_ids,
    val_files=val_ids,
    test_files=test_ids,
    batch_size=8,
    aug_transform=aug_transform,
)

In [None]:
# Sanity check: draw one slice with its segmentation mask on top.
idx = 64
img_path = f"/kaggle/input/cryovit-data/tomogram_images/emd_28668_{idx}.png"
mask_path = f"/kaggle/input/cryovit-data/segmentation_mask_images/seg_emd_28668_{idx}.png"
img = Image.open(img_path)
mask = Image.open(mask_path)

fig, ax = plt.subplots()
ax.imshow(img, cmap="gray")
ax.imshow(mask, cmap="jet", alpha=0.5)  # semi-transparent so the image stays visible
ax.axis("off")
plt.show()

In [None]:
# Reproduce the dataset preprocessing by hand for the preview image `img`:
# scale to [0, 1], replicate to three channels, normalise, resize to 448x448.
image = torch.from_numpy(np.asarray(img, dtype=np.float32) / 255.0)[None]  # (1, H, W)
bw_image = image.repeat(3, 1, 1)

# ImageNet channel statistics, shaped to broadcast over (3, H, W).
mean = torch.tensor([0.485, 0.456, 0.406])[:, None, None]
std = torch.tensor([0.229, 0.224, 0.225])[:, None, None]

rgb_image = ((bw_image - mean) / std).unsqueeze(0)  # batch axis required by F.interpolate
rgb_image_interp = F.interpolate(
    rgb_image, size=(448, 448), mode='bicubic', align_corners=False
)[0]

In [None]:
# Show the raw image next to the three normalised channels of the resized tensor.
channel_panels = [
    (img, "gray", 'Gray Channel'),
    (rgb_image_interp[0], "Reds", 'Red Channel'),
    (rgb_image_interp[1], "Greens", 'Green Channel'),
    (rgb_image_interp[2], "Blues", 'Blue Channel'),
]

fig, axes = plt.subplots(1, 4, figsize=(20, 5))
for panel_ax, (panel_data, panel_cmap, panel_title) in zip(axes, channel_panels):
    panel_ax.imshow(panel_data, cmap=panel_cmap)
    panel_ax.set_title(panel_title)
    panel_ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
class FrozenDinoV2Backbone(nn.Module):
    """Feature extractor that reuses a pre-trained model with all weights frozen.

    Args:
        original_model: Module exposing ``patch_embed`` and an iterable
            ``blocks`` attribute. Every parameter of it is marked as
            non-trainable.

    ``forward`` applies only the patch embedding followed by each block in
    order and returns the resulting tensor.
    """

    def __init__(self, original_model):
        super().__init__()
        self.backbone = original_model
        # Freeze everything: the backbone is used purely as a feature extractor.
        self.backbone.requires_grad_(False)
        self.patch_embed = original_model.patch_embed
        self.blocks = original_model.blocks

    def forward(self, x):
        tokens = self.patch_embed(x)
        for layer in self.blocks:
            tokens = layer(tokens)
        return tokens

In [None]:
class SynthBlock(nn.Module):
    """Two dilated 3-D convolutions followed by a 2x in-plane upsampling.

    Args:
        c1: Input channels.
        c2: Channels produced by both convolutions.
        c3: Output channels of the transposed convolution.
        d1: Dilation of the first convolution along the first spatial axis.
        d2: Dilation of the second convolution along the first spatial axis.

    The transposed convolution uses kernel and stride (1, 2, 2), so the last
    two axes double in size while the first spatial axis keeps its length.
    A GELU follows each of the three layers.
    """

    def __init__(self, c1, c2, c3, d1, d2):
        super().__init__()
        self.conv1 = nn.Conv3d(c1, c2, kernel_size=3, padding="same", dilation=(d1, 1, 1))
        self.conv2 = nn.Conv3d(c2, c2, kernel_size=3, padding="same", dilation=(d2, 1, 1))
        self.trans1 = nn.ConvTranspose3d(c2, c3, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.gelu = nn.GELU()

    def forward(self, x):
        for layer in (self.conv1, self.conv2, self.trans1):
            x = self.gelu(layer(x))
        return x

class CryoVIT(nn.Module):
    """Segmentation head mapping backbone features to a per-pixel probability.

    ``forward`` expects a tensor of shape (1536, D, h, w): 1536 feature
    channels for a stack of D slices on an h x w grid. A batch axis is added
    internally. The four ``SynthBlock`` stages each double h and w, so the
    output has shape (D, 16*h, 16*w) with sigmoid probabilities in [0, 1].
    """

    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv3d(1536, 1024, 1, padding="same")
        self.synth1 = SynthBlock(c1=1024, c2=192, c3=128, d1=32, d2=24)
        self.synth2 = SynthBlock(c1=128, c2=64, c3=32, d1=16, d2=12)
        self.synth3 = SynthBlock(c1=32, c2=32, c3=32, d1=8, d2=4)
        self.synth4 = SynthBlock(c1=32, c2=16, c3=8, d1=2, d2=1)
        self.conv2 = nn.Conv3d(8, 8, 3, padding="same")
        self.conv3 = nn.Conv3d(8, 1, 3, padding="same")
        self.gelu = nn.GELU()
        # Created once here instead of on every forward pass. Sigmoid has no
        # parameters, so the state_dict layout is unchanged.
        self.final_act = nn.Sigmoid()

    def forward(self, x):
        x = x.unsqueeze(0)  # (C, D, h, w) -> (1, C, D, h, w) as required by Conv3d
        x = self.gelu(self.conv1(x))
        x = self.synth1(x)
        x = self.synth2(x)
        x = self.synth3(x)
        x = self.synth4(x)
        x = self.gelu(self.conv2(x))
        x = self.final_act(self.conv3(x))  # (1, 1, D, H, W)
        # Remove only the batch and channel axes. A bare squeeze() would also
        # drop the slice axis when D == 1 and break downstream broadcasting.
        return x.squeeze(0).squeeze(0)

In [None]:
# Fetch the DINOv2 model through torch.hub and wrap it as a frozen feature
# extractor. The model is placed on the notebook-wide `device` (instead of an
# unconditional .cuda()) so the cell also runs on a CPU-only machine.
dino_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14", verbose=True).to(device)
dino_backbone = FrozenDinoV2Backbone(dino_model)
del dino_model  # the wrapper keeps the only reference that is needed

In [None]:
def initialize_weights(m):
    """Initialise one sub-module; intended for use with ``nn.Module.apply``.

    3-D (transposed) convolution weights receive Kaiming-normal values in
    fan-out mode. Any module that owns a non-``None`` ``bias`` gets that bias
    filled with the constant 0.01.
    """
    if isinstance(m, (nn.Conv3d, nn.ConvTranspose3d)):
        nn.init.kaiming_normal_(m.weight, a=0.1, mode='fan_out', nonlinearity='relu')
    bias = getattr(m, 'bias', None)
    if bias is not None:
        nn.init.constant_(bias, 0.01)  # Small nonzero bias


cryo_vit_model = CryoVIT().to(device)
cryo_vit_model.apply(initialize_weights)

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss, ``1 - (2*|X∩Y| + smooth) / (|X| + |Y| + smooth)``.

    Args:
        smooth: Additive smoothing term in numerator and denominator.
        conf: Predictions not strictly above this value are set to zero before
            the overlap is computed.
        unlabelled: When true, targets are hardened to {0, 1} at 0.5 first.
    """

    def __init__(self, smooth=1, conf=0.0, unlabelled=False):
        super().__init__()
        self.smooth, self.conf, self.unlabelled = smooth, conf, unlabelled

    def forward(self, inputs, targets):
        if self.unlabelled:
            # Harden soft targets into a binary mask of the same dtype.
            targets = (targets > 0.5).to(targets.dtype)

        # Suppress low-confidence predictions.
        confident = torch.where(inputs > self.conf, inputs, torch.zeros_like(inputs))

        flat_pred = confident.view(-1)
        flat_true = targets.view(-1)

        overlap = (flat_pred * flat_true).sum()
        total = flat_pred.sum() + flat_true.sum()
        return 1 - (2. * overlap + self.smooth) / (total + self.smooth)

class FocalLoss(nn.Module):
    """Sigmoid focal loss whose ``alpha`` is derived from the batch itself.

    Args:
        gamma: Focusing exponent forwarded to ``sigmoid_focal_loss``.

    ``forward(y_pred, y_true)`` clamps ``y_pred`` to [-10, 10], sets ``alpha``
    to ``1 - mean(y_true)`` clamped to [0.1, 0.9], and returns the mean loss.
    """

    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma

    def forward(self, y_pred, y_true):
        bounded_pred = y_pred.clamp(min=-10, max=10)  # Avoid extreme values
        # The rarer the positive class in this batch, the larger its weight.
        alpha = (1 - y_true.mean()).clamp(0.1, 0.9)
        return sigmoid_focal_loss(
            bounded_pred, y_true, alpha=alpha, gamma=self.gamma, reduction="mean"
        )

class CombinedLoss(nn.Module):
    """Weighted sum of Dice, focal and binary cross-entropy losses.

    Args:
        gamma: Focusing exponent of the focal term.
        weight_dice / weight_focal / weight_bce: Weights of the three terms.
        smooth, conf, unlabelled: Forwarded to ``DiceLoss``.

    ``forward(inputs, targets)``:
        inputs: Foreground *probabilities* (the model ends in a sigmoid) without
            a channel axis, e.g. (B, H, W); a channel axis is inserted at dim 1.
        targets: Tensor matching the unsqueezed input, e.g. (B, 1, H, W).

    The focal and BCE terms are defined on logits, so they are recovered from
    the probabilities with ``torch.logit``. Feeding the probabilities directly
    would apply a second sigmoid and squeeze every prediction into (0.5, 0.73).
    """

    def __init__(self, gamma=3, weight_dice=1.0, weight_focal=1.0, weight_bce=1.0,
                 smooth=1.0, conf=0.0, unlabelled=False):
        super().__init__()
        self.dice_loss = DiceLoss(smooth=smooth, conf=conf, unlabelled=unlabelled)
        self.focal_loss = FocalLoss(gamma=gamma)
        self.bce_loss = nn.BCEWithLogitsLoss()
        self.weight_dice = weight_dice
        self.weight_focal = weight_focal
        self.weight_bce = weight_bce
        self.unlabelled = unlabelled

    def forward(self, inputs, targets):
        probs = inputs.unsqueeze(1)
        # Keep both operands in one dtype; masks may arrive as float64.
        targets = targets.to(probs.dtype)

        # Inverse sigmoid; eps clamps the probabilities away from exactly 0 / 1
        # so the logits stay finite.
        logits = torch.logit(probs, eps=1e-6)

        dice = self.dice_loss(probs, targets)    # defined on probabilities
        focal = self.focal_loss(logits, targets) # defined on logits
        bce = self.bce_loss(logits, targets)     # defined on logits

        # Replace NaNs so one degenerate batch cannot poison the optimiser state.
        dice = torch.nan_to_num(dice, 0.0)
        focal = torch.nan_to_num(focal, 0.0)
        bce = torch.nan_to_num(bce, 0.0)

        return (self.weight_dice * dice +
                self.weight_focal * focal +
                self.weight_bce * bce)

In [None]:
class EMA:
    """Exponential moving average ("teacher") copy of a model.

    Args:
        model: Model to copy. The copy is put in eval mode and frozen.
        decay: Fraction of the previous average kept at every update.

    Attributes:
        model: The averaged copy.
        decay: The decay factor.
    """

    def __init__(self, model, decay=0.999):
        self.model = copy.deepcopy(model)
        self.model.eval()
        for frozen in self.model.parameters():
            frozen.requires_grad = False
        self.decay = decay

    def update(self, student_model):
        """Blend the student's current parameters into the average."""
        keep, take = self.decay, 1 - self.decay
        pairs = zip(self.model.parameters(), student_model.parameters())
        with torch.no_grad():
            for averaged, current in pairs:
                averaged.copy_(keep * averaged.data + take * current.data)

def sharpen(probs, temperature=0.5):
    """Sharpen a distribution stored along dim 1.

    Every entry is raised to the power ``1 / temperature`` and the result is
    renormalised so that it sums to one along dim 1.
    """
    powered = torch.pow(probs, 1 / temperature)
    normaliser = powered.sum(dim=1, keepdim=True)
    return powered / normaliser

def train_model(cryo_vit_model, dino_backbone, train_labelled_loader, train_unlabelled_loader, val_loader, device, epochs=70, patience=10):
    """Train the segmentation head with an EMA teacher and early stopping.

    Args:
        cryo_vit_model: Student model to optimise. On return it holds the
            weights that achieved the best validation loss.
        dino_backbone: Frozen feature extractor applied to the images.
        train_labelled_loader: Loader yielding ``(images, masks)``.
        train_unlabelled_loader: Loader yielding ``images`` only.
        val_loader: Loader yielding ``(images, masks)`` for validation.
        device: Device the batches are moved to.
        epochs: Maximum number of epochs.
        patience: Epochs without validation improvement before stopping.

    Returns:
        ``{'train': [...], 'val': [...]}`` with one loss value per epoch.

    Side effects:
        Writes the best weights to ``best_cryo_vit_model.pth``.
    """
    criterion = CombinedLoss(weight_dice=1.0, weight_focal=1.0, weight_bce=1.0)
    criterion_unlabelled = CombinedLoss(weight_dice=1.0, weight_focal=1.0, weight_bce=1.0, conf=0.4, unlabelled=True)
    optimizer = optim.AdamW(cryo_vit_model.parameters(), lr=5e-3, weight_decay=3e-4)
    scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=10,        # scheduler steps (here: epochs) until the first restart
        T_mult=2,      # each cycle is twice as long as the previous one
        eta_min=1e-5,  # minimum learning rate
    )

    ema_model = EMA(cryo_vit_model)  # teacher: moving average of the student

    best_loss = float('inf')
    best_weights = None
    epochs_no_improve = 0
    history = {'train': [], 'val': []}

    for epoch in range(epochs):
        train_loss = train_epoch(cryo_vit_model, ema_model, dino_backbone, train_labelled_loader, train_unlabelled_loader, criterion, criterion_unlabelled, optimizer, device)
        val_loss = validate_epoch(cryo_vit_model, ema_model, dino_backbone, val_loader, criterion, device)
        # CosineAnnealingWarmRestarts.step() takes an optional *epoch index*,
        # not a metric. Passing the validation loss would set the position in
        # the cosine cycle to the loss value.
        scheduler.step()

        history['train'].append(train_loss)
        history['val'].append(val_loss)

        print(f"Epoch {epoch+1}/{epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

        if val_loss < best_loss:
            best_loss = val_loss
            # The validation loss is measured on the EMA model, so checkpoint
            # exactly those weights. The student's weights at this point were
            # never evaluated.
            best_weights = copy.deepcopy(ema_model.model.state_dict())
            epochs_no_improve = 0
            torch.save(best_weights, "best_cryo_vit_model.pth")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping at epoch {epoch+1} with val loss: {best_loss:.4f}")
            break

    if best_weights is not None:
        cryo_vit_model.load_state_dict(best_weights)

    return history


def train_epoch(model, ema_model, dino_backbone, labelled_loader, unlabelled_loader,
                criterion, criterion_unlabelled, optimizer, device):
    """Run one semi-supervised training epoch and return the mean step loss.

    Every optimisation step combines one labelled batch (supervised loss) with
    two unlabelled batches whose targets are sharpened predictions of the EMA
    teacher. The three losses are summed before a single backward pass.

    Args:
        model: Student network being optimised.
        ema_model: Object with ``.model`` (teacher) and ``.update(student)``.
        dino_backbone: Feature extractor; its output is reshaped to
            (-1, 32, 32, 1536) and permuted to channel-first.
        labelled_loader: Yields ``(images, masks)``.
        unlabelled_loader: Yields ``images``.
        criterion: Loss for the labelled batch.
        criterion_unlabelled: Loss for the pseudo-labelled batches.
        optimizer: Optimiser of ``model``.
        device: Device the batches are moved to.

    Returns:
        Summed loss averaged over the steps that actually completed
        (0.0 when no step could be run).
    """
    model.train()
    total_loss = 0.0
    completed_steps = 0
    unlabelled_per_step = 2

    labelled_iter = iter(labelled_loader)
    unlabelled_iter = iter(unlabelled_loader)

    def _extract_features(images):
        # The backbone is frozen, so no graph is needed for it.
        with torch.no_grad():
            return dino_backbone(images).reshape(-1, 32, 32, 1536).permute(3, 0, 1, 2)

    total_iterations = min(len(labelled_loader), len(unlabelled_loader) // unlabelled_per_step)
    for _ in tqdm(range(total_iterations), leave=False):
        # Fetch all batches of this step up front, so an exhausted loader ends
        # the epoch cleanly instead of discarding a half-built loss.
        try:
            labelled_images, labelled_masks = next(labelled_iter)
            unlabelled_batches = []
            for _ in range(unlabelled_per_step):
                unlabelled_batches.append(next(unlabelled_iter))
        except StopIteration:
            break

        labelled_images = labelled_images.to(device)
        labelled_masks = labelled_masks.to(device)

        optimizer.zero_grad()
        loss = criterion(model(_extract_features(labelled_images)), labelled_masks)

        for unlabelled_images in unlabelled_batches:
            features_unlabelled = _extract_features(unlabelled_images.to(device))

            with torch.no_grad():
                teacher_probs = ema_model.model(features_unlabelled)
                # Guarantee a (B, H, W) layout even for a single-slice batch.
                teacher_probs = teacher_probs.reshape(-1, *teacher_probs.shape[-2:])
                # The teacher emits one foreground probability per pixel.
                # Build the two-class distribution [background, foreground]
                # on dim 1 and sharpen that; a softmax over dim 1 of the raw
                # (B, H, W) map would normalise across image rows instead.
                two_class = torch.stack([1 - teacher_probs, teacher_probs], dim=1)
                pseudo_labels = sharpen(two_class, temperature=0.5)[:, 1:2]  # (B, 1, H, W)

            loss = loss + criterion_unlabelled(model(features_unlabelled), pseudo_labels)

        loss.backward()
        optimizer.step()
        ema_model.update(model)

        total_loss += loss.item()
        completed_steps += 1

    return total_loss / completed_steps if completed_steps else 0.0

def validate_epoch(model, ema_model, backbone, val_loader, criterion, device):
    """Return the mean validation loss of the EMA teacher.

    Args:
        model: Student model; accepted for signature symmetry but not used.
        ema_model: Object whose ``.model`` is evaluated.
        backbone: Feature extractor applied to every image batch.
        val_loader: Yields ``(images, masks)``.
        criterion: Loss comparing teacher output and masks.
        device: Device the batches are moved to.
    """
    teacher = ema_model.model
    teacher.eval()  # Use EMA model for validation
    batch_losses = []

    with torch.no_grad():
        for images, masks in tqdm(val_loader, leave=False):
            images = images.to(device)
            masks = masks.to(device)

            tokens = backbone(images)
            features = tokens.reshape(-1, 32, 32, 1536).permute(3, 0, 1, 2)
            predictions = teacher(features)
            batch_losses.append(criterion(predictions, masks).item())

    return sum(batch_losses, 0.0) / len(val_loader)


In [None]:
# Run the full training loop with the default number of epochs and patience.
history = train_model(
    cryo_vit_model=cryo_vit_model,
    dino_backbone=dino_backbone,
    train_labelled_loader=train_labelled_loader,
    train_unlabelled_loader=train_unlabelled_loader,
    val_loader=val_loader,
    device=device,
)

In [None]:
# Persist the per-epoch loss history so it can be plotted without retraining.
serialised_history = json.dumps(history, indent=4)
with open("data.json", "w") as f:
    f.write(serialised_history)

In [None]:
# Plot training and validation loss on twin y-axes, since the two curves can
# live on different scales.
with open("/kaggle/working/data.json", "r") as f:
    data = json.load(f)

fig, ax1 = plt.subplots()

ax1.set_xlabel("Epochs")
ax1.set_ylabel("Train Loss", color="tab:blue")
ax1.plot(data["train"], label="Train Loss", color="tab:blue")
ax1.tick_params(axis="y", labelcolor="tab:blue")

ax2 = ax1.twinx()
ax2.set_ylabel("Validation Loss", color="tab:red")
ax2.plot(data["val"], label="Validation Loss", color="tab:red")
# Tick labels share the colour of the axis label and the curve (was tab:orange).
ax2.tick_params(axis="y", labelcolor="tab:red")

fig.suptitle("Loss Curve")
fig.tight_layout()
plt.show()

In [None]:
# Evaluation criterion: with the focal and BCE weights set to zero, the
# weighted sum returned by CombinedLoss reduces to the Dice term alone.
criterion = CombinedLoss(weight_dice=1.0, weight_focal=0.0, weight_bce=0.0)
# Disabled: no optimiser is needed for inference.
# optimizer = optim.Adam(cryo_vit_model.parameters(), lr=0.001)

In [None]:
# Restore the best checkpoint written during training and switch to eval mode.
# map_location follows the notebook-wide `device` instead of a hard-coded
# "cuda", so the cell also works on a CPU-only machine.
checkpoint_path = "/kaggle/working/best_cryo_vit_model.pth"
cryo_vit_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
cryo_vit_model.eval()

In [None]:
def calculate_iou(mask_gt, mask_pred, threshold=0.5):
    """Intersection over union between a ground-truth and a predicted mask.

    Args:
        mask_gt: Ground-truth tensor; every value > 0 counts as foreground.
        mask_pred: Predicted tensor, either probabilities or an already binary
            mask; values strictly above ``threshold`` count as foreground.
        threshold: Decision threshold for ``mask_pred``. The default of 0.5
            works for probabilities and for {0, 1} masks alike. Thresholding at
            0 would mark every sigmoid output as foreground.

    Returns:
        0-dim float tensor with the IoU, or 0 when the union is empty.
    """
    # A prediction without channel axis, e.g. (B, H, W), against a ground truth
    # of (B, 1, H, W) would broadcast to (B, B, H, W) and pair every prediction
    # with every mask. Align the layouts whenever the element counts agree.
    if mask_pred.shape != mask_gt.shape and mask_pred.numel() == mask_gt.numel():
        mask_pred = mask_pred.reshape(mask_gt.shape)

    gt_fg = mask_gt > 0
    pred_fg = mask_pred > threshold

    intersection = (gt_fg & pred_fg).sum().float()
    union = (gt_fg | pred_fg).sum().float()

    # Keep the fallback on the same device as the regular result.
    return intersection / union if union > 0 else union.new_zeros(())

def calculate_mean_iou(masks_gt, masks_pred):
    """Average ``calculate_iou`` over paired entries of two sequences.

    Args:
        masks_gt: Sequence of ground-truth tensors.
        masks_pred: Sequence of predicted tensors, aligned with ``masks_gt``.

    Returns:
        Sum of the per-entry IoU values divided by ``len(masks_gt)``.
    """
    num_images = len(masks_gt)
    per_entry = (calculate_iou(masks_gt[k], masks_pred[k]) for k in range(num_images))
    return sum(per_entry, 0.0) / num_images

In [None]:
# Evaluate the restored model on the test split and keep predictions and
# ground truth for the IoU computation and the visualisations below.
running_loss = 0.0

tot_output = []
tot_gt = []

with torch.no_grad():
    for images, masks in tqdm(test_loader, desc="Inference"):
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        # Backbone tokens -> channel-first feature volume expected by the head.
        features = dino_backbone(images)
        features = features.reshape(-1, 32, 32, 1536).permute(3, 0, 1, 2).contiguous()

        outputs = cryo_vit_model(features)

        loss = criterion(outputs, masks)
        running_loss += loss.item()

        # Store results on the CPU. Keeping references to GPU tensors would pin
        # their memory and make the cache clean-up below ineffective.
        tot_output.append(outputs.cpu())
        tot_gt.append(masks.cpu())

        del features, outputs
        if device.type == "cuda":
            torch.cuda.empty_cache()

avg_loss = running_loss / len(test_loader)
iou = calculate_mean_iou(tot_gt, tot_output)

print(f"Inference Loss: {avg_loss:.4f}")
# `criterion` is Dice-only here, so 1 - loss is the (soft) Dice score.
print(f"Dice Score: {1 - avg_loss:.4f}")
print(f"IoU: {iou:.4f}")

In [None]:
# Conversion used for display only: PIL image -> float tensor.
transform = transforms.Compose([transforms.ToTensor()])


def load_images_as_tensors(test_ids, image_dir):
    """Load the named images as RGB tensors.

    Args:
        test_ids: File names relative to ``image_dir``.
        image_dir: Directory containing the images.

    Returns:
        List with one tensor per file, in the order of ``test_ids``, produced
        by the module-level ``transform``.
    """
    return [
        transform(Image.open(os.path.join(image_dir, file_name)).convert("RGB"))
        for file_name in test_ids
    ]


tot_input = load_images_as_tensors(test_ids, "/kaggle/input/cryovit-data/tomogram_images")

In [None]:
# Merge the per-batch results into one tensor each, stacked along the first
# axis, so a single test sample can be addressed by index below.
tot_output_cat = torch.cat(tot_output)
tot_gt_cat = torch.cat(tot_gt)

In [None]:
# Display the shape of the concatenated ground-truth masks as a sanity check.
tot_gt_cat.shape

In [None]:
# Compare prediction and ground truth for one test sample, each drawn as a
# semi-transparent overlay on the input image.
idx = 3
background = tot_input[idx].permute(1, 2, 0).cpu()  # channel-last for imshow
overlays = [
    ("Predicted", (tot_output_cat[idx] > 0.5).cpu()),  # binarised prediction
    ("Ground Truth", tot_gt_cat[idx][0].cpu()),
]

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for panel, (panel_title, overlay) in zip(ax, overlays):
    panel.imshow(background)
    panel.imshow(overlay, alpha=0.5, cmap="jet")
    panel.set_title(panel_title)
    panel.axis("off")

plt.show()