In [None]:
import os
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import h5py

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.data import Dataset, DataLoader, random_split
import torch.optim as optim
from torch.cuda.amp import autocast
from torchvision import transforms
from torchvision.ops import sigmoid_focal_loss

from tqdm import tqdm
from PIL import Image

# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [None]:
# Dataset identifiers used in this notebook; each one is interpolated into the
# "emd_<id>_<slice>.png" file names built when the train/val/test lists are created.
used_ids = ["28668", "28946", "29074", "29080"]

In [None]:
# Build the train / validation / test file lists.
#
# Slice indices 1..128 are shuffled once and the same permutation is reused for
# every id in `used_ids`, so each id contributes the same slice numbers to each
# split (70% train, 15% validation, remainder test).
#
# NOTE(review): the original cell referenced an undefined `seed` (and `random`
# was never imported), so it failed on a fresh kernel. 42 is a placeholder --
# if the checkpoint loaded later was trained with another seed, restore that
# value here so the test split does not overlap its training data.
seed = 42

total_ids = list(range(1, 129))
random.seed(seed)
random.shuffle(total_ids)

train_size = int(0.7 * len(total_ids))
val_size = int(0.15 * len(total_ids))

# Slice the shuffled indices once instead of re-slicing for every id.
train_slices = total_ids[:train_size]
val_slices = total_ids[train_size:train_size + val_size]
test_slices = total_ids[train_size + val_size:]

train_ids = []
val_ids = []
test_ids = []
for ids in used_ids:
    train_ids += [f"emd_{ids}_{i}.png" for i in train_slices]
    val_ids += [f"emd_{ids}_{i}.png" for i in val_slices]
    test_ids += [f"emd_{ids}_{i}.png" for i in test_slices]

In [None]:
class SegmentationDataset(Dataset):
    """Paired grayscale image / binary mask dataset read from two PNG folders.

    Each item is a tuple ``(image, mask)``:

    * ``image`` -- float32 tensor of shape (3, 448, 448): the grayscale image
      scaled to [0, 1], replicated to three channels, standardised with the
      usual ImageNet channel mean/std and bicubically resized to 448x448.
    * ``mask`` -- float32 tensor of shape (1, H, W) containing 0.0 / 1.0 at the
      mask's own resolution (the mask is not resized).

    Args:
        image_dir: Directory holding the input images.
        mask_dir: Directory holding the masks; each mask file is looked up as
            ``"seg_" + mask_filenames[idx]``.
        image_filenames: Image file names, one per sample.
        mask_filenames: Mask file names (without the ``seg_`` prefix), aligned
            with ``image_filenames``.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries. It
            receives 2-D float32 NumPy arrays and is applied to both *before*
            the image is normalised and resized, so the pair stays aligned.
            NOTE(review): this assumes an albumentations-style transform that
            returns 2-D arrays -- confirm against the transforms actually used.
    """

    # Channel statistics used to standardise the replicated grayscale image.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)
    # Spatial size the image is resized to before it is returned.
    _IMAGE_SIZE = (448, 448)

    def __init__(self, image_dir, mask_dir, image_filenames, mask_filenames, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = image_filenames
        self.mask_filenames = mask_filenames
        self.transform = transform

    def __len__(self):
        """Number of samples, taken from the image file list."""
        return len(self.image_filenames)

    @staticmethod
    def _binarize_mask(mask):
        """Threshold a [0, 1] mask at 0.5 and return float32 zeros/ones.

        The explicit float32 cast avoids the float64 array that
        ``np.where(cond, 1.0, 0.0)`` would produce.
        """
        return (np.asarray(mask) > 0.5).astype(np.float32)

    @classmethod
    def _to_model_tensors(cls, image, mask):
        """Convert 2-D image/mask arrays into the tensors returned by ``__getitem__``."""
        # ascontiguousarray guards against views with unusual strides coming
        # back from a transform, which torch cannot wrap directly.
        image = torch.as_tensor(np.ascontiguousarray(image, dtype=np.float32)).unsqueeze(0)  # (1, H, W)
        mask = torch.as_tensor(np.ascontiguousarray(mask, dtype=np.float32)).unsqueeze(0)    # (1, H, W)

        mean = torch.tensor(cls._MEAN).view(3, 1, 1)
        std = torch.tensor(cls._STD).view(3, 1, 1)

        # Replicate the single channel three times, then standardise per channel.
        rgb_image = (image.repeat(3, 1, 1) - mean) / std

        # F.interpolate expects a batch dimension; add it and strip it again.
        rgb_image = F.interpolate(
            rgb_image.unsqueeze(0), size=cls._IMAGE_SIZE, mode='bicubic', align_corners=False
        ).squeeze(0)
        return rgb_image, mask

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_path = os.path.join(self.mask_dir, "seg_" + self.mask_filenames[idx])

        # Context managers close the underlying files once the pixels are copied out.
        with Image.open(img_path) as img_file:
            image = np.array(img_file.convert("L"), dtype=np.float32) / 255.0
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file.convert("L"), dtype=np.float32) / 255.0
        mask = self._binarize_mask(mask)

        # Augment before tensor conversion so the returned image and mask see
        # the same transform (previously the augmented image was discarded).
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]

        return self._to_model_tensors(image, mask)

In [None]:
def get_loaders(image_dir, mask_dir, train_files, val_files, test_files, batch_size, transform=None):
    """Build the train, validation and test ``DataLoader`` objects.

    Every split wraps a ``SegmentationDataset`` that receives the same file
    list for images and masks and the same ``transform``. Only the training
    loader shuffles; all three use ``batch_size`` and a single worker.

    Returns:
        A ``(train_loader, val_loader, test_loader)`` tuple.
    """

    def _build(files, shuffle):
        # One split: dataset over `files`, then its loader.
        dataset = SegmentationDataset(image_dir, mask_dir, files, files, transform=transform)
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1)

    return (
        _build(train_files, True),
        _build(val_files, False),
        _build(test_files, False),
    )

In [None]:
# Create the three loaders from the Kaggle input folders and the split file lists.
train_loader, val_loader, test_loader = get_loaders(
    image_dir="/kaggle/input/cryovit-data/tomogram_images",
    mask_dir="/kaggle/input/cryovit-data/segmentation_mask_images",
    train_files=train_ids,
    val_files=val_ids,
    test_files=test_ids,
    batch_size=8,
)

In [None]:
# Sanity check: overlay one segmentation mask on its input image.
idx = 64
img_path = f"/kaggle/input/cryovit-data/tomogram_images/emd_28668_{idx}.png"
mask_path = f"/kaggle/input/cryovit-data/segmentation_mask_images/seg_emd_28668_{idx}.png"

img = Image.open(img_path)
mask = Image.open(mask_path)

plt.imshow(img, cmap="gray")             # base layer
plt.imshow(mask, cmap="jet", alpha=0.5)  # semi-transparent mask on top
plt.axis("off")
plt.show()

In [None]:
# Repeat the dataset's preprocessing on the preview image `img` so the
# individual channels can be inspected in the next cell.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)

# Scale to [0, 1] and add a leading channel axis.
image = torch.tensor(np.array(img, dtype=np.float32) / 255.0).unsqueeze(0)
# Three copies of the same channel stacked along dim 0.
bw_image = image.repeat(3, 1, 1)

# Standardise per channel, then add the batch axis F.interpolate requires.
rgb_image = ((bw_image - mean) / std).unsqueeze(0)
rgb_image_interp = F.interpolate(
    rgb_image, size=(448, 448), mode='bicubic', align_corners=False
).squeeze(0)

In [None]:
# Show the original image next to each of the three standardised channels.
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

panels = [
    (img, "gray", 'Gray Channel'),
    (rgb_image_interp[0], "Reds", 'Red Channel'),
    (rgb_image_interp[1], "Greens", 'Green Channel'),
    (rgb_image_interp[2], "Blues", 'Blue Channel'),
]
for ax, (panel_data, panel_cmap, panel_title) in zip(axes, panels):
    ax.imshow(panel_data, cmap=panel_cmap)
    ax.set_title(panel_title)
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
class FrozenDinoV2Backbone(nn.Module):
    """Gradient-free feature extractor built from a pretrained model.

    All parameters of ``original_model`` are frozen. ``forward`` only applies
    the model's ``patch_embed`` followed by each entry of ``blocks`` in order;
    nothing else from the wrapped model's own forward pass is executed.

    Args:
        original_model: Module exposing ``patch_embed`` and an iterable
            ``blocks`` attribute.
    """

    def __init__(self, original_model):
        super().__init__()
        self.backbone = original_model
        # Freeze every parameter of the wrapped model in one call.
        self.backbone.requires_grad_(False)
        self.patch_embed = original_model.patch_embed
        self.blocks = original_model.blocks

    def forward(self, x):
        """Return the output of the last block for input ``x``."""
        tokens = self.patch_embed(x)
        for layer in self.blocks:
            tokens = layer(tokens)
        return tokens

In [None]:
class SynthBlock(nn.Module):
    """Two dilated 3-D convolutions followed by a 2x in-plane upsampling.

    Args:
        c1: Input channels.
        c2: Channels produced by both 3x3x3 convolutions.
        c3: Output channels of the transposed convolution.
        d1: Dilation of the first convolution along the first spatial axis.
        d2: Dilation of the second convolution along the first spatial axis.

    For input ``(N, c1, D, H, W)`` the output is ``(N, c3, D, 2H, 2W)``: the
    convolutions use "same" padding and the transposed convolution has kernel
    and stride ``(1, 2, 2)``. GELU follows every layer.
    """

    def __init__(self, c1, c2, c3, d1, d2):
        super().__init__()
        self.conv1 = nn.Conv3d(c1, c2, 3, padding="same", dilation=(d1, 1, 1))
        self.conv2 = nn.Conv3d(c2, c2, 3, padding="same", dilation=(d2, 1, 1))
        self.trans1 = nn.ConvTranspose3d(c2, c3, (1, 2, 2), stride=(1, 2, 2))
        self.gelu = nn.GELU()

    def forward(self, x):
        x = self.gelu(self.conv1(x))
        x = self.gelu(self.conv2(x))
        x = self.gelu(self.trans1(x))
        return x


class CryoVIT(nn.Module):
    """3-D convolutional head mapping 1536-channel features to one logit map per slice.

    ``forward`` takes an un-batched tensor of shape ``(1536, D, H, W)`` and
    returns logits of shape ``(D, 16*H, 16*W)``: a 1x1x1 convolution reduces
    the channels, four ``SynthBlock`` stages each double H and W, and two 3x3x3
    convolutions produce the single output channel. No sigmoid is applied.
    """

    def __init__(self) -> None:
        super(CryoVIT, self).__init__()
        self.conv1 = nn.Conv3d(1536, 1024, 1, padding="same")
        self.synth1 = SynthBlock(c1=1024, c2=192, c3=128, d1=32, d2=24)
        self.synth2 = SynthBlock(c1=128, c2=64, c3=32, d1=16, d2=12)
        self.synth3 = SynthBlock(c1=32, c2=32, c3=32, d1=8, d2=4)
        self.synth4 = SynthBlock(c1=32, c2=16, c3=8, d1=2, d2=1)
        self.conv2 = nn.Conv3d(8, 8, 3, padding="same")
        self.conv3 = nn.Conv3d(8, 1, 3, padding="same")
        self.gelu = nn.GELU()

    def forward(self, x):
        x = x.unsqueeze(0)  # add the batch axis Conv3d expects: (1, 1536, D, H, W)
        x = self.gelu(self.conv1(x))
        x = self.synth1(x)
        x = self.synth2(x)
        x = self.synth3(x)
        x = self.synth4(x)
        x = self.gelu(self.conv2(x))
        x = self.conv3(x)  # (1, 1, D, 16H, 16W) logits
        # Drop only the batch and channel axes. A bare squeeze() would also
        # remove the depth axis when D == 1 (e.g. a final batch holding one
        # slice), and the caller's unsqueeze(1) would then yield a wrong shape.
        return x.squeeze(0).squeeze(0)

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed on sigmoid-activated logits.

    Args:
        smooth: Constant added to numerator and denominator so the ratio stays
            defined when both prediction and target are empty.

    ``forward(inputs, targets)`` expects raw logits and a target of the same
    number of elements, flattens both and returns ``1 - dice`` as a scalar.
    """

    def __init__(self, smooth=1):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        inputs = torch.sigmoid(inputs)
        # reshape (not view) so permuted / non-contiguous tensors are accepted.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + self.smooth) / (inputs.sum() + targets.sum() + self.smooth)

        return 1 - dice

class FocalLoss(nn.Module):
    """Mean sigmoid focal loss with a per-call ``alpha`` derived from the target.

    ``alpha`` is ``(numel - sum(y_true)) / numel``, which is the fraction of
    zero-valued elements when ``y_true`` is binary.

    Args:
        gamma: Focusing exponent forwarded to ``sigmoid_focal_loss``.
    """

    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma

    def forward(self, y_pred, y_true):
        n_elements = y_true.numel()
        # Recomputed from the targets of every call.
        alpha = (n_elements - y_true.sum()) / n_elements
        return sigmoid_focal_loss(y_pred, y_true, alpha=alpha, gamma=self.gamma, reduction="mean")

class CombinedLoss(nn.Module):
    """Weighted sum of ``DiceLoss`` and ``FocalLoss`` evaluated on the same logits.

    Args:
        gamma: Passed to ``FocalLoss``.
        weight_dice: Multiplier for the Dice term.
        weight_focal: Multiplier for the focal term.
        smooth: Passed to ``DiceLoss``.
    """

    def __init__(self, gamma=2, weight_dice=1.0, weight_focal=1.0, smooth=1.0):
        super().__init__()
        self.dice_loss = DiceLoss(smooth=smooth)
        self.focal_loss = FocalLoss(gamma=gamma)
        self.weight_dice = weight_dice
        self.weight_focal = weight_focal

    def forward(self, inputs, targets):
        dice_term = self.weight_dice * self.dice_loss(inputs, targets)
        focal_term = self.weight_focal * self.focal_loss(inputs, targets)
        return dice_term + focal_term

In [None]:
# Fetch the pretrained DINOv2 model from torch.hub and wrap it as a frozen
# feature extractor. The model is moved to `device` (rather than a hard-coded
# .cuda()) so the CPU fallback selected at the top of the notebook is honoured,
# and the wrapper is switched to eval mode because it is only used for inference.
dino_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitg14", verbose=True).to(device)
dino_backbone = FrozenDinoV2Backbone(dino_model).eval()
del dino_model  # the wrapper keeps its own reference; drop the extra name

In [None]:
# Instantiate the segmentation head, then move it onto the selected device.
cryo_vit_model = CryoVIT()
cryo_vit_model = cryo_vit_model.to(device)

def initialize_weights(model):
    """Re-initialise ``model`` in place: He-normal weights, zero biases.

    Parameters whose name contains ``"weight"`` and that have at least two
    dimensions receive ``kaiming_normal_`` (fan-in, ReLU gain). One-dimensional
    weights (e.g. normalisation scales) are left untouched, because fan-in is
    undefined for them and ``kaiming_normal_`` would raise. Parameters whose
    name contains ``"bias"`` are set to zero.
    """
    for name, param in model.named_parameters():
        if 'weight' in name:
            if param.dim() >= 2:
                init.kaiming_normal_(param, mode='fan_in', nonlinearity='relu')
        elif 'bias' in name:
            init.zeros_(param)

# Apply He-normal weights / zero biases to the freshly created segmentation head.
initialize_weights(model=cryo_vit_model)

In [None]:
# Loss: equally weighted Dice + focal. Optimiser: Adam over the head's parameters only.
criterion = CombinedLoss(weight_focal=1.0, weight_dice=1.0)
optimizer = optim.Adam(params=cryo_vit_model.parameters(), lr=1e-3)

In [None]:
# Train the segmentation head on frozen backbone features, validate after every
# epoch, checkpoint the best model and stop early when validation stalls.
best_loss = float('inf')
loss_history = []  # mean training loss per epoch

patience = 3  # Stop training if no improvement for 'patience' epochs
epochs_no_improve = 0  # Counter for early stopping
min_delta = 1e-4  # Minimum change to qualify as an improvement

epochs = 40


def _dino_features(images, grid=32, channels=1536):
    """Run the frozen backbone and arrange its output as (channels, B, grid, grid).

    The backbone output is reshaped to ``(-1, grid, grid, channels)`` and the
    channel axis is moved to the front, which is the layout ``cryo_vit_model``
    consumes. NOTE(review): grid=32 presumably corresponds to the 448-pixel
    input with the 14-pixel patches implied by the "dinov2_vitg14" name -- confirm.
    """
    with torch.no_grad():
        tokens = dino_backbone(images)
    return tokens.reshape(-1, grid, grid, channels).permute(3, 0, 1, 2)


for epoch in range(epochs):
    cryo_vit_model.train()
    running_loss = 0.0

    for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        # Use the notebook-wide `device` instead of a hard-coded .cuda(); masks
        # are cast to float32 so the loss is not silently promoted to float64.
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, dtype=torch.float32, non_blocking=True)

        optimizer.zero_grad()
        features = _dino_features(images)  # already on `device`, no extra copy needed
        outputs = cryo_vit_model(features)
        outputs = outputs.unsqueeze(1)  # add the channel axis to match the masks
        loss = criterion(outputs, masks)
        running_loss += loss.item()

        loss.backward()
        optimizer.step()

        del features, outputs
        torch.cuda.empty_cache()

    avg_loss = running_loss / len(train_loader)
    loss_history.append(avg_loss)
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")

    # Validation phase
    cryo_vit_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, dtype=torch.float32, non_blocking=True)
            outputs = cryo_vit_model(_dino_features(images))
            outputs = outputs.unsqueeze(1)
            val_loss += criterion(outputs, masks).item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch+1}, Validation Loss: {avg_val_loss:.4f}")

    # Early stopping check: only an improvement larger than min_delta counts.
    if avg_val_loss < best_loss - min_delta:
        best_loss = avg_val_loss
        torch.save(cryo_vit_model.state_dict(), "cryo_vit_model_best.pth")
        print(f"Model saved at epoch {epoch+1} with validation loss: {avg_val_loss:.4f}")
        epochs_no_improve = 0  # Reset counter
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print(f"Early stopping triggered at epoch {epoch+1}.")
        break

In [None]:
# Plot the per-epoch mean training loss and save the figure next to the notebook.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(loss_history, label='Training Loss', color='b', linewidth=2)
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training Loss Over Epochs')
ax.grid(True)
ax.legend(loc='upper right')

fig.savefig('training_loss_plot.png', format='png')
plt.show()

In [None]:
# Restore previously trained weights for evaluation.
# - map_location follows `device` so the cell also works without a GPU.
# - weights_only=True restricts unpickling to tensors / plain containers, which
#   is all a state_dict needs and avoids executing arbitrary pickled code.
checkpoint_path = "/kaggle/input/cryovit-trained/cryo_vit_model_best.pth"
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
cryo_vit_model.load_state_dict(state_dict)
cryo_vit_model.eval()

In [None]:
# Run the trained head over the test set, accumulating the mean loss and
# collecting inputs / raw logits for the visualisation cells below.
running_loss = 0.0

tot_output = []  # per-batch logits, stored on the CPU
tot_input = []   # per-batch preprocessed images, stored on the CPU

with torch.no_grad():
    for images, masks in tqdm(test_loader, desc="Inference"):
        # Follow the notebook-wide `device`; float32 masks keep the loss in float32.
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, dtype=torch.float32, non_blocking=True)

        # Extract features using the dino backbone and move channels first.
        features = dino_backbone(images)
        features = features.reshape(-1, 32, 32, 1536).permute(3, 0, 1, 2).contiguous()

        # Get the model's predictions (logits) and add the channel axis.
        outputs = cryo_vit_model(features)
        outputs = outputs.unsqueeze(1)

        # Calculate the loss
        loss = criterion(outputs, masks)
        running_loss += loss.item()

        # Keep CPU copies only: holding the device tensors would pin the whole
        # test set in accelerator memory and make the cleanup below ineffective.
        tot_input.append(images.cpu())
        tot_output.append(outputs.cpu())

        del features, outputs
        torch.cuda.empty_cache()

avg_loss = running_loss / len(test_loader)
print(f"Inference Loss: {avg_loss:.4f}")
tot_output = torch.cat(tot_output, dim=0)

In [None]:
# torch.save(tot_output, "inference_output.pt")

In [None]:
# `tot_input` is a list of per-batch tensors the first time this cell runs.
# Concatenate only in that case so re-running the cell does not fail on the
# already-merged tensor.
if isinstance(tot_input, (list, tuple)):
    tot_input = torch.cat(tot_input, dim=0)

# Bring the logits to the 448x448 resolution of the stored input images.
tot_output_resized = F.interpolate(tot_output, size=(448, 448), mode="bicubic", align_corners=False)

In [None]:
# Print the shapes of the stacked inputs and of the resized outputs, in that order.
for _stacked in (tot_input, tot_output_resized):
    print(_stacked.shape)

In [None]:
# Overlay the predicted mask for one test sample on its input image.
idx = 4

# The stored inputs are standardised per channel; undo that so imshow receives
# values in [0, 1] (clamped, since bicubic resizing can overshoot slightly).
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
display_image = (tot_input[idx].cpu() * std + mean).clamp(0.0, 1.0)

# The model emits logits (the losses apply the sigmoid themselves), so convert
# to probabilities before thresholding at 0.5.
pred_mask = (torch.sigmoid(tot_output_resized[idx][0]) > 0.5).float().cpu()

plt.imshow(display_image.permute(1, 2, 0))
plt.imshow(pred_mask, alpha=0.5)
plt.show()