### This notebook is designed to run on Google Colab (with a GPU runtime).

It covers the training of the per-center CycleGANs and of the Multi-Cycle GAN, DINOv2 feature extraction with optional fine-tuning, MedImageInsight feature extraction, and the training of the classifier followed by inference on the test set.

In [None]:
# Clone the project repository into the current directory; it provides the
# requirements.txt installed below and the CycleGAN package used for training.
!git clone https://github.com/MANY09F4/kaggle-DL-MI.git

In [None]:
# Mount Google Drive at /content/drive: the HDF5 datasets used in this notebook are read
# from /content/drive/MyDrive/kaggle-DL-MI/data/.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Sanity check: print the current working directory
!pwd

In [None]:
# Sanity check: list the current directory (the cloned kaggle-DL-MI folder should appear)
!ls

In [None]:
!pip install -r kaggle-DL-MI/requirements.txt

In [None]:
# Make the cloned repository the working directory, so that `python -m CycleGAN...`
# and `from CycleGAN import ...` resolve in the cells below.
%cd /content/kaggle-DL-MI

## Data cleaning

In [None]:
import h5py
import torch

# Scan the training set and collect the IDs of aberrant (corrupted) images,
# i.e. images holding more than 200 values that are exactly zero.
list_aberrant_ids_train = []

with h5py.File('/content/drive/MyDrive/kaggle-DL-MI/data/train.h5', 'r') as f:
    img_ids = list(f.keys())  # every top-level key is an image ID

    for img_id in img_ids:
        # Read the image and convert it to a float tensor
        img = torch.tensor(f[img_id]['img'][()]).float()

        # Number of entries strictly equal to zero
        n_zero = (img == 0).sum().item()
        if n_zero <= 200:
            continue

        # Too many black values: report the image and remember its ID
        print(f"Aberrant image found at ID: {img_id} with 0s in the image {n_zero}")
        list_aberrant_ids_train.append(img_id)

# Summary: how many aberrant images were detected
print(f"Number of aberrant images in train set: {len(list_aberrant_ids_train)}")


In [None]:
# Scan the validation set and collect the IDs of aberrant (corrupted) images,
# i.e. images holding more than 200 values that are exactly zero.
list_aberrant_ids_val = []

with h5py.File('/content/drive/MyDrive/kaggle-DL-MI/data/val.h5', 'r') as f:
    img_ids = list(f.keys())  # every top-level key is an image ID

    for img_id in img_ids:
        # Read the image and convert it to a float tensor
        img = torch.tensor(f[img_id]['img'][()]).float()

        # Skip images whose count of zero-valued entries stays within the tolerance
        n_zero = (img == 0).sum().item()
        if n_zero <= 200:
            continue

        print(f"Aberrant image found at ID: {img_id} with 0s in the image")
        list_aberrant_ids_val.append(img_id)

# Summary: how many aberrant images were detected
print(f"Number of aberrant images in val set: {len(list_aberrant_ids_val)}")


In [None]:
# Scan the test set and collect the IDs of aberrant (corrupted) images,
# i.e. images holding more than 200 values that are exactly zero.
list_aberrant_ids_test = []

with h5py.File('/content/drive/MyDrive/kaggle-DL-MI/data/test.h5', 'r') as f:
    img_ids = list(f.keys())  # every top-level key is an image ID

    for img_id in img_ids:
        # Read the image and convert it to a float tensor
        img = torch.tensor(f[img_id]['img'][()]).float()

        # Skip images whose count of zero-valued entries stays within the tolerance
        n_zero = (img == 0).sum().item()
        if n_zero <= 200:
            continue

        print(f"Aberrant image found at ID: {img_id} with 0s in the image")
        list_aberrant_ids_test.append(img_id)

# Summary: how many aberrant images were detected
print(f"Number of aberrant images in test set: {len(list_aberrant_ids_test)}")


In [None]:
# Merge the aberrant IDs found in the training and validation sets into one list
# (plain concatenation: no de-duplication is performed)
list_aberrant_ids = [*list_aberrant_ids_train, *list_aberrant_ids_val]

# Cell output: total number of aberrant images across both sets
len(list_aberrant_ids)


In [None]:
# Serialise each list of aberrant IDs (train, val) as a single comma-separated string
str_aberrant_ids_train = ",".join(str(aberrant_id) for aberrant_id in list_aberrant_ids_train)
str_aberrant_ids_val = ",".join(str(aberrant_id) for aberrant_id in list_aberrant_ids_val)

## Multi-Cycle GAN

In [None]:
# Train the MultiStain-CycleGAN on all source centers at once (no --domain filter is passed).
# Data: train.h5 and val.h5 are given as source files, test.h5 as the target file.
# Schedule: --n_epochs 7 followed by --n_epochs_decay 3 (presumably 10 epochs in total,
#   matching the "epoch10" checkpoints loaded later -- confirm in CycleGAN.train_CycleGAN).
# --color_augment is passed, but brightness/contrast/saturation/hue are all 0, so presumably
#   no colour jitter is effectively applied -- confirm in the training script.
# --D_thresh with --D_thresh_value 0.1 enables discriminator thresholding.
# The aberrant train/val IDs computed above are forwarded so that those images are excluded;
#   IPython expands $str_aberrant_ids_train / $str_aberrant_ids_val from the Python variables.

!python -m CycleGAN.train_CycleGAN \
  --train_path "/content/drive/MyDrive/kaggle-DL-MI/data/train.h5" \
  --val_path "/content/drive/MyDrive/kaggle-DL-MI/data/val.h5" \
  --test_path "/content/drive/MyDrive/kaggle-DL-MI/data/test.h5" \
  --name multistain_run_all_domains \
  --batch_size 64 \
  --gpu_ids 0 \
  --n_epochs 7 \
  --n_epochs_decay 3 \
  --lr_G 0.0002 \
  --lr_D 0.0002 \
  --save_epoch_freq 2 \
  --display_id 0 \
  --lambda_A 10.0 \
  --lambda_B 10.0 \
  --D_thresh \
  --D_thresh_value 0.1 \
  --lambda_identity 0.5 \
  --gan_mode lsgan \
  --color_augment \
  --brightness 0 \
  --contrast 0 \
  --saturation 0 \
  --hue 0 \
  --aberrant_ids_train "$str_aberrant_ids_train" \
  --aberrant_ids_val "$str_aberrant_ids_val"


In [None]:
# Zip and download the model checkpoints from Colab.
# The whole 'checkpoints' directory is archived (presumably where the training run above
# wrote the GAN weights -- confirm the output location in CycleGAN.train_CycleGAN).
!zip -r /content/checkpoints_multi_domains.zip /content/kaggle-DL-MI/checkpoints

# Trigger a browser download of the archive to the local machine
from google.colab import files
files.download('/content/checkpoints_multi_domains.zip')


## Unique-Cycle GAN

In [None]:
# Train a single-center ("unique") CycleGAN for center 0 (--domain 0), target = test set.
# Apart from --name and --domain, the options are identical to the multi-domain run:
#   7 epochs + 3 decay epochs, lsgan loss, lambda_A = lambda_B = 10, lambda_identity = 0.5.
# --color_augment is passed, but all jitter magnitudes are 0.0, so presumably no colour
#   change is effectively applied -- confirm in CycleGAN.train_CycleGAN.
# --D_thresh with --D_thresh_value 0.1 enables discriminator thresholding.
# The aberrant train/val IDs are forwarded so that those images are excluded.

!python -m CycleGAN.train_CycleGAN \
  --train_path "/content/drive/MyDrive/kaggle-DL-MI/data/train.h5" \
  --val_path "/content/drive/MyDrive/kaggle-DL-MI/data/val.h5" \
  --test_path "/content/drive/MyDrive/kaggle-DL-MI/data/test.h5" \
  --name test_run_domain_0 \
  --batch_size 64 \
  --gpu_ids 0 \
  --n_epochs 7 \
  --n_epochs_decay 3 \
  --lr_G 0.0002 \
  --lr_D 0.0002 \
  --save_epoch_freq 2 \
  --display_id 0 \
  --lambda_A 10.0 \
  --lambda_B 10.0 \
  --D_thresh \
  --D_thresh_value 0.1 \
  --lambda_identity 0.5 \
  --gan_mode lsgan \
  --domain 0 \
  --color_augment \
  --brightness 0.0 \
  --contrast 0.0 \
  --saturation 0.0 \
  --hue 0.0 \
  --aberrant_ids_train "$str_aberrant_ids_train" \
  --aberrant_ids_val "$str_aberrant_ids_val"


In [None]:
# Center 3: train a single-center ("unique") CycleGAN with --domain 3, target = test set.
# Same options as the center-0 run; only --name and --domain differ.

!python -m CycleGAN.train_CycleGAN \
  --train_path "/content/drive/MyDrive/kaggle-DL-MI/data/train.h5" \
  --val_path "/content/drive/MyDrive/kaggle-DL-MI/data/val.h5" \
  --test_path "/content/drive/MyDrive/kaggle-DL-MI/data/test.h5" \
  --name test_run_domain_3 \
  --batch_size 64 \
  --gpu_ids 0 \
  --n_epochs 7 \
  --n_epochs_decay 3 \
  --lr_G 0.0002 \
  --lr_D 0.0002 \
  --save_epoch_freq 2 \
  --display_id 0 \
  --lambda_A 10.0 \
  --lambda_B 10.0 \
  --D_thresh \
  --D_thresh_value 0.1 \
  --lambda_identity 0.5 \
  --gan_mode lsgan \
  --domain 3 \
  --color_augment \
  --brightness 0.0 \
  --contrast 0.0 \
  --saturation 0.0 \
  --hue 0.0 \
  --aberrant_ids_train "$str_aberrant_ids_train" \
  --aberrant_ids_val "$str_aberrant_ids_val"

In [None]:
# Center 4: train a single-center ("unique") CycleGAN with --domain 4, target = test set.
# Same options as the center-0 run; only --name and --domain differ.

!python -m CycleGAN.train_CycleGAN \
  --train_path "/content/drive/MyDrive/kaggle-DL-MI/data/train.h5" \
  --val_path "/content/drive/MyDrive/kaggle-DL-MI/data/val.h5" \
  --test_path "/content/drive/MyDrive/kaggle-DL-MI/data/test.h5" \
  --name test_run_domain_4 \
  --batch_size 64 \
  --gpu_ids 0 \
  --n_epochs 7 \
  --n_epochs_decay 3 \
  --lr_G 0.0002 \
  --lr_D 0.0002 \
  --save_epoch_freq 2 \
  --display_id 0 \
  --lambda_A 10.0 \
  --lambda_B 10.0 \
  --D_thresh \
  --D_thresh_value 0.1 \
  --lambda_identity 0.5 \
  --gan_mode lsgan \
  --domain 4 \
  --color_augment \
  --brightness 0.0 \
  --contrast 0.0 \
  --saturation 0.0 \
  --hue 0.0 \
  --aberrant_ids_train "$str_aberrant_ids_train" \
  --aberrant_ids_val "$str_aberrant_ids_val"

In [None]:
# Center 1: train a single-center ("unique") CycleGAN with --domain 1, target = test set.
# Same options as the center-0 run; only --name and --domain differ.

!python -m CycleGAN.train_CycleGAN \
  --train_path "/content/drive/MyDrive/kaggle-DL-MI/data/train.h5" \
  --val_path "/content/drive/MyDrive/kaggle-DL-MI/data/val.h5" \
  --test_path "/content/drive/MyDrive/kaggle-DL-MI/data/test.h5" \
  --name test_run_domain_1 \
  --batch_size 64 \
  --gpu_ids 0 \
  --n_epochs 7 \
  --n_epochs_decay 3 \
  --lr_G 0.0002 \
  --lr_D 0.0002 \
  --save_epoch_freq 2 \
  --display_id 0 \
  --lambda_A 10.0 \
  --lambda_B 10.0 \
  --D_thresh \
  --D_thresh_value 0.1 \
  --lambda_identity 0.5 \
  --gan_mode lsgan \
  --domain 1 \
  --color_augment \
  --brightness 0.0 \
  --contrast 0.0 \
  --saturation 0.0 \
  --hue 0.0 \
  --aberrant_ids_train "$str_aberrant_ids_train" \
  --aberrant_ids_val "$str_aberrant_ids_val"

In [None]:
# Zip and download the model checkpoints from Colab.
# The whole 'checkpoints' directory is archived, so the archive contains every run stored
# there (expected: the 4 single-center GANs, plus any earlier run still present).
!zip -r /content/checkpoints_all_domains.zip /content/kaggle-DL-MI/checkpoints
from google.colab import files
files.download('/content/checkpoints_all_domains.zip')


## GANs visualisation 

In [None]:
# Load and visualize a specific image from train.h5 or val.h5
# Converts the image if necessary to [C, H, W] format
# Displays the image and prints metadata (center index, min/max values)

import h5py
import torch
import matplotlib.pyplot as plt
import numpy as np

# Pick one sample of val.h5 (domain A) by its position in the list of keys
h5_path = "/content/drive/MyDrive/kaggle-DL-MI/data/val.h5"
index = 600  # position of the image to inspect

with h5py.File(h5_path, 'r') as f:
    key = list(f.keys())[index]
    img = torch.tensor(f[key]['img'][()])  # stored as [H, W, C] or [3, H, W]
    center_index = np.array(f[key]['metadata'])[0]  # first metadata entry is used as the center index

# Channels-last images are moved to channels-first ([C, H, W])
is_channels_last = img.ndim == 3 and img.shape[-1] == 3
if is_channels_last:
    img = img.permute(2, 0, 1)
    print("3 en dernier")

# float32 copy laid out as [H, W, C], as expected by matplotlib
img_np = img.permute(1, 2, 0).float().numpy()
print(img_np.shape)
print(np.max(img_np), np.min(img_np))
print(center_index)

# Show the image
plt.imshow(img_np)
plt.title(f"Image {key}")
plt.axis("off")
plt.show()


In [None]:
# Load pre-trained generators for image normalization (CycleGAN)
# net_GA: multi-domain generator (MultiStain-CycleGAN trained on all centers)
# gen_centerX: unique CycleGANs trained per center (0, 1, 3, 4)
# Each generator is loaded from its corresponding checkpoint and set to eval mode
# You need to import the weights downloaded before and place them in /content/kaggle-DL-MI/ with the right naming
# Renaming the unique domain Cycle GANs weights is necessary ex : netG_A_0_epoch10.pth for center 0

from CycleGAN import networks

def _load_generator(weights_path):
    """Build a resnet_9blocks generator, load its weights and return it in eval mode.

    Args:
        weights_path: path of the checkpoint (.pth state_dict) to load.

    Returns:
        The generator created by `networks.define_G` with the loaded weights, set to eval mode.
    """
    generator = networks.define_G(3, 3, 64, 'resnet_9blocks', 'instance', True, "normal", 0.02, [0])
    # Weights are first read on CPU, then copied into the network by load_state_dict
    generator.load_state_dict(torch.load(weights_path, map_location='cpu'))
    generator.eval()
    return generator


# Multi-domain generator (trained on all source domains → target domain)
net_GA = _load_generator("/content/kaggle-DL-MI/netG_A_epoch10.pth")

# Unique CycleGAN generators, one per center (0, 1, 3, 4)
gen_center0 = _load_generator("/content/kaggle-DL-MI/netG_A_0_epoch10.pth")
gen_center1 = _load_generator("/content/kaggle-DL-MI/netG_A_1_epoch10.pth")
gen_center3 = _load_generator("/content/kaggle-DL-MI/netG_A_3_epoch10.pth")
gen_center4 = _load_generator("/content/kaggle-DL-MI/netG_A_4_epoch10.pth")


In [None]:
# Apply MultiStain-CycleGAN generator (net_GA) to a sample image from domain A
# The input image is first normalized to [-1, 1] before being passed through the generator
# The output is then rescaled to [0, 255] and displayed with matplotlib

import numpy as np

# Build the generator input in [-1, 1] depending on the original scale of the image.
# `img` itself is deliberately left untouched: overwriting it with its normalised version
# makes any later cell that applies the same "max <= 1" test normalise it a second time.
img_float = img.float()
if img_float.max() <= 1.0:
    img_multi = img_float * 2.0 - 1.0
    print("normal -1 1")
else:
    img_multi = img_float / 127.5 - 1.0

# Add batch dimension: [1, C, H, W]
img_input = img_multi.unsqueeze(0)

# Apply the generator to translate the image to the target domain
with torch.no_grad():
    fake_B = net_GA(img_input)

# Print output shape and value range
print("fake_B shape:", fake_B.shape)
print("fake_B range: min =", fake_B.min().item(), ", max =", fake_B.max().item())

# Convert the output from [-1, 1] to [H, W, C] uint8 in [0, 255] for visualization
fake_B_np = ((fake_B.squeeze().cpu().numpy().transpose(1, 2, 0) + 1) / 2.0 * 255.0).astype(np.uint8)

# Display the generated image
plt.imshow(fake_B_np)
plt.title("Image générée par Multi-Stain Cycle GAN")
plt.axis("off")
plt.show()


In [None]:
# Apply the unique CycleGAN generator corresponding to the image center
# Normalize input image to [-1, 1] and apply the appropriate generator from the dictionary
# Output image is rescaled to [0, 255] and displayed

import numpy as np
import matplotlib.pyplot as plt
import h5py
import torch

# Dictionary containing trained generators for each center
generators_dict = {
    0: gen_center0,
    1: gen_center1,
    3: gen_center3,
    4: gen_center4,
}

# Rebuild the input from `img_np`, the [H, W, C] float32 copy made when the sample was loaded.
# `img` may already have been rescaled to [-1, 1] by the multi-domain cell; in that case the
# "max <= 1" test below would succeed and the image would be normalised twice (range [-3, 1]).
img_unique = torch.from_numpy(img_np).permute(2, 0, 1).contiguous()  # [C, H, W]

# Normalize pixel values to [-1, 1]
if img_unique.max() <= 1.0:
    img_unique = img_unique * 2.0 - 1.0
    print("normalized [0, 1] → [-1, 1]")
else:
    img_unique = img_unique / 127.5 - 1.0
    print("normalized [0, 255] → [-1, 1]")

# Select the generator based on image center index (cast to int, as the dictionary keys are ints)
generator = generators_dict.get(int(center_index))
if generator is None:
    raise ValueError(f"No generator found for center {center_index}")

# Add batch dimension
img_input = img_unique.unsqueeze(0)

# Generate normalized image with selected generator
with torch.no_grad():
    fake_B_unique = generator(img_input)

# Output information
print("fake_B shape:", fake_B_unique.shape)
print("fake_B range: min =", fake_B_unique.min().item(), ", max =", fake_B_unique.max().item())

# Convert to [H, W, C] and scale to [0, 255] for display
fake_B_np_unique = ((fake_B_unique.squeeze().cpu().numpy().transpose(1, 2, 0) + 1) / 2.0 * 255.0).astype(np.uint8)

# Display the generated image
plt.imshow(fake_B_np_unique)
plt.title(f"Generated image for center {center_index} using unique CycleGAN")
plt.axis("off")
plt.show()


In [None]:
# Display side-by-side comparison of:
# (1) the original image from domain A (source center),
# (2) the image translated by the Multi-domain CycleGAN,
# (3) the image translated by the unique (per-center) CycleGAN

import matplotlib.pyplot as plt

# One (image, title) pair per panel, in display order:
# source image, multi-domain translation, per-center translation
panels = [
    (img_np, "Original image (Domain A)"),
    (fake_B_np, "Generated image (Domain B)\nMulti-domain CycleGAN"),
    (fake_B_np_unique, "Generated image (Domain B)\nUnique CycleGAN"),
]

fig, axes = plt.subplots(1, 3, figsize=(10, 5))

for ax, (panel_image, panel_title) in zip(axes, panels):
    ax.imshow(panel_image)
    ax.set_title(panel_title)
    ax.axis("off")

plt.tight_layout()
plt.show()


## DinoV2 framework

### Features extraction test

In [None]:
# Load the DINOv2 ViT-S/14 feature extractor ('dinov2_vits14') from the
# facebookresearch/dinov2 repository via torch.hub (downloads code and weights on first use).
# The model is moved to the GPU ("cuda" is hard-coded: a GPU runtime is required).
feature_extractor = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to("cuda")


In [None]:
# Feature extraction from transformed image test

from torchvision import transforms

# Work on a separate tensor so that `fake_B` is not modified: the cell can then be re-run
# without rescaling an already rescaled image.
# Drop the batch dimension ([1, 3, H, W] → [3, H, W]) and map values from [-1, 1] to [0, 1].
dino_input = (fake_B.squeeze(0) + 1) / 2.0

# Resize image to 98x98 (DINOv2 expects dimensions that are multiples of 14)
transform = transforms.Compose([
    transforms.Resize((98, 98))
])

# Add back the batch dimension ([1, 3, 98, 98]) and place the input on the model's device
dino_device = next(feature_extractor.parameters()).device
dino_input = transform(dino_input).unsqueeze(0).to(dino_device)

# Set the DINOv2 model to evaluation mode
feature_extractor.eval()

# Extract features without gradient tracking
with torch.no_grad():
    features = feature_extractor(dino_input)

# Output feature shape (should be [1, 384] for ViT-S/14)
features.shape


### Creation of Dataloaders 

In [None]:
from torchvision import transforms
import h5py
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import random

# IDs of the aberrant images to leave out of the training / validation datasets
aberrant_train_ids, aberrant_val_ids = list_aberrant_ids_train, list_aberrant_ids_val

# Geometric augmentations applied to training images (can help reduce overfitting).
# Colour jittering (transforms.ColorJitter) is intentionally not part of the pipeline.
_train_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
]
transform_train = transforms.Compose(_train_augmentations)

# Custom dataset class that loads and optionally transforms GAN-normalized images
class H5UnalignedDataset(Dataset):
    """Dataset over an HDF5 file whose images are optionally translated by a CycleGAN generator.

    Each top-level key of the file is an image ID; the group stored under it provides
    'img', 'metadata' (its first entry is used as the center ID) and, when present, 'label'.

    Args:
        h5_path: path of the HDF5 file.
        transform: optional callable applied to the image tensor before the GAN.
        aberrant_ids_train: IDs excluded when `train` is True.
        aberrant_ids_val: IDs excluded when `train` is False.
        net_GA: single generator applied to every image when `multi_gens` is False.
        generators: dict {center_id: generator}, required when `multi_gens` is True.
        multi_gens: if True, pick the generator from `generators` using the image's center ID.
        train: selects which list of aberrant IDs is used for filtering.

    Raises:
        ValueError: if `multi_gens` is True but `generators` is missing or empty.
    """

    def __init__(self, h5_path, transform=None, aberrant_ids_train=None, aberrant_ids_val=None,
                 net_GA=None, generators=None, multi_gens=False, train=True):
        super().__init__()
        # Fail at construction time instead of on the first __getitem__ call
        if multi_gens and not generators:
            raise ValueError("multi_gens=True requires a non-empty `generators` dict")

        self.h5_path = h5_path
        self.transform = transform
        self.aberrant_ids_train = aberrant_ids_train or []
        self.aberrant_ids_val = aberrant_ids_val or []
        self.net_GA = net_GA  # Single GAN generator (MultiStain-CycleGAN)
        self.generators = generators  # Dict of generators (one per center) if multi_gens=True
        self.multi_gens = multi_gens
        self.train = train

        # NOTE: seeds Python's global `random` module each time a dataset is instantiated
        random.seed(42)

        # Load all image keys from the h5 file
        with h5py.File(self.h5_path, 'r') as f:
            all_ids = list(f.keys())

        # Exclude aberrant image IDs (set lookup instead of scanning a list for every key)
        excluded_ids = set(self.aberrant_ids_train if self.train else self.aberrant_ids_val)
        self.img_ids = [img_id for img_id in all_ids if img_id not in excluded_ids]

    def __len__(self):
        """Return the number of images kept after filtering."""
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return `(img, label)` for the image at position `idx`.

        The image is augmented (if a transform is set), mapped with `img * 2 - 1`, passed
        through the selected generator (if any) and mapped back with `(img + 1) / 2`.
        `label` is `np.array` of the 'label' entry; when the entry is absent (e.g. test set)
        this is `np.array(None)`.

        Raises:
            ValueError: if `multi_gens` is True and no generator exists for the image's center.
        """
        img_id = self.img_ids[idx]
        # The file is re-opened for every item, so no handle is kept on the dataset object
        with h5py.File(self.h5_path, 'r') as f:
            img = torch.tensor(f[img_id]['img'][()]).float()
            label = np.array(f[img_id].get("label"))

            # Extract center ID from metadata
            metadata = f[img_id]['metadata']
            center_id = int(np.array(metadata)[0])

        # Apply data augmentation (if any)
        if self.transform:
            img = self.transform(img)

        # Normalize to [-1, 1] before GAN input (assumes the stored image is in [0, 1])
        img = img * 2.0 - 1.0

        # Apply GAN generator (either single or by-center).
        # The generators are only used for inference here: without no_grad every returned
        # image would keep a reference to the generator's autograd graph.
        with torch.no_grad():
            if self.multi_gens:
                generator = self.generators.get(center_id)
                if generator is None:
                    raise ValueError(f"No generator found for center {center_id}")
                img = generator(img.unsqueeze(0)).squeeze(0)
            elif self.net_GA is not None:
                img = self.net_GA(img.unsqueeze(0)).squeeze(0)

        # Rescale back to [0, 1]
        img = (img + 1) / 2.0

        return img, label

# Options shared by the training and validation datasets: both are normalised with the
# single MultiStain-CycleGAN generator (multi_gens=False, so no per-center `generators` dict)
_shared_dataset_options = dict(
    aberrant_ids_train=aberrant_train_ids,
    aberrant_ids_val=aberrant_val_ids,
    multi_gens=False,
    net_GA=net_GA,
)

# Training dataset: augmented, filtered with the training aberrant IDs
train_dataset = H5UnalignedDataset(
    h5_path="/content/drive/MyDrive/kaggle-DL-MI/data/train.h5",
    transform=transform_train,
    train=True,
    **_shared_dataset_options,
)

# Validation dataset: no augmentation, filtered with the validation aberrant IDs
val_dataset = H5UnalignedDataset(
    h5_path="/content/drive/MyDrive/kaggle-DL-MI/data/val.h5",
    transform=None,
    train=False,
    **_shared_dataset_options,
)

# DataLoaders: shuffled batches for training, fixed order for validation
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)

print(len(train_dataset), len(val_dataset))

### Data visualization from datasets

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Function to visualize a few images from a dataset (e.g., train or val)
def show_images_from_dataset(dataset, dataset_name="Train", num_images=5):
    """Display the first `num_images` samples of `dataset` in a single row.

    Args:
        dataset: indexable object returning `(img, label)`, `img` being a [C, H, W] tensor.
        dataset_name: prefix used in each subplot title.
        num_images: maximum number of samples to show; capped at `len(dataset)`.

    Returns:
        None. Nothing is drawn when there is no image to show.
    """
    # Never index past the end of the dataset
    n_shown = min(num_images, len(dataset))
    if n_shown <= 0:
        print(f"{dataset_name}: no image to display")
        return

    # squeeze=False keeps `axes` two-dimensional even when a single image is requested
    fig, axes = plt.subplots(1, n_shown, figsize=(15, 5), squeeze=False)
    for i, ax in enumerate(axes[0]):
        img, label = dataset[i]  # Get the image and its label

        # Convert tensor image to numpy format [C,H,W] → [H,W,C]
        img = img.detach().cpu().numpy().transpose((1, 2, 0))

        # Plot image
        ax.imshow(img)
        ax.set_title(f"{dataset_name} Image {i+1} - Label: {label}")
        ax.axis('off')

    plt.show()  # Display the figure

# Display the first 5 images (default num_images) of the training and validation datasets.
# Training samples go through the random augmentations, so they can differ between runs.
show_images_from_dataset(train_dataset, "Train")
show_images_from_dataset(val_dataset, "Val")


### Fine-tuning of DINOv2 and training + validation of the classifier

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import h5py
from torchmetrics.classification import BinaryAccuracy
from tqdm import tqdm

# Everything in this section runs on the GPU ("cuda" is hard-coded)
device = "cuda"

# Images are resized to 98x98 before entering DINOv2 (a multiple of its patch size, 14)
transform = transforms.Compose([transforms.Resize((98, 98))])

# Pre-trained DINOv2 ViT-S/14 backbone
dino_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(device)

# Freeze every parameter whose name contains none of the substrings below.
# Note that "norm" is matched anywhere in the name, not only in the last two blocks.
_trainable_name_parts = ("blocks.10", "blocks.11", "norm", "head")
for name, param in dino_model.named_parameters():
    if any(part in name for part in _trainable_name_parts):
        continue
    param.requires_grad = False

# MLP classifier applied on top of the DINOv2 features
class Classifier(nn.Module):
    """Four-layer MLP (input → 512 → 256 → 128 → output) ending with a sigmoid.

    The three hidden layers use BatchNorm1d followed by ReLU; dropout (p=0.5) is applied
    after the first two hidden layers only.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Attribute names are part of the saved state_dict keys: do not rename them
        self.fc1 = nn.Linear(input_size, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, output_size)
        self.dropout = nn.Dropout(p=0.5)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid outputs flattened to a 1-D tensor."""
        hidden = self.dropout(torch.relu(self.bn1(self.fc1(x))))
        hidden = self.dropout(torch.relu(self.bn2(self.fc2(hidden))))
        hidden = torch.relu(self.bn3(self.fc3(hidden)))
        probabilities = self.sigmoid(self.fc4(hidden))
        return probabilities.view(-1)

# Size of the feature vector returned by DINOv2_vits14
feature_dim = 384

# Classifier head, placed on the same device as the backbone
classifier = Classifier(input_size=feature_dim, output_size=1).to(device)

# Parameters updated by the optimizer: the unfrozen DINO parameters, then the classifier's
_trainable_dino_params = [p for p in dino_model.parameters() if p.requires_grad]
params_to_optimize = _trainable_dino_params + list(classifier.parameters())

# Optimizer, loss (the classifier already outputs probabilities) and accuracy metric
optimizer = optim.Adam(params_to_optimize, lr=1e-4, weight_decay=1e-5)
criterion = nn.BCELoss()
accuracy_metric = BinaryAccuracy().to(device)

# Early-stopping state and epoch budget
NUM_EPOCHS = 50
patience = 5
counter = 0
best_val_loss = float('inf')

# Fine-tune DINOv2 and train the classifier, with early stopping on the validation loss
_finetuned_key_parts = ["blocks.10", "blocks.11", "norm", "head"]
_n_train_samples = len(train_dataloader.dataset)
_n_val_samples = len(val_dataloader.dataset)

for epoch in range(NUM_EPOCHS):
    # ---- Training phase ----
    dino_model.train()
    classifier.train()
    total_loss, total_acc = 0, 0

    for imgs, labels in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        imgs = transform(imgs.to(device))
        labels = labels.float().to(device)
        n_batch = imgs.size(0)

        features = dino_model(imgs)
        preds = classifier(features)
        loss = criterion(preds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Sums are weighted by batch size so that the epoch averages are per-sample
        acc = accuracy_metric(preds > 0.5, labels.int())
        total_loss += loss.item() * n_batch
        total_acc += acc.item() * n_batch

    avg_loss = total_loss / _n_train_samples
    avg_acc = total_acc / _n_train_samples
    print(f"Train Loss: {avg_loss:.4f} | Accuracy: {avg_acc:.4f}")

    # ---- Validation phase ----
    dino_model.eval()
    classifier.eval()
    val_loss, val_acc = 0, 0

    with torch.no_grad():
        for imgs, labels in tqdm(val_dataloader, desc="Val Phase"):
            imgs = transform(imgs.to(device))
            labels = labels.float().to(device)
            n_batch = imgs.size(0)

            features = dino_model(imgs)
            preds = classifier(features)

            loss = criterion(preds, labels)
            acc = accuracy_metric(preds > 0.5, labels.int())

            val_loss += loss.item() * n_batch
            val_acc += acc.item() * n_batch

    avg_val_loss = val_loss / _n_val_samples
    avg_val_acc = val_acc / _n_val_samples
    print(f"Val Loss: {avg_val_loss:.4f} | Accuracy: {avg_val_acc:.4f}")

    # No improvement: count the epoch towards early stopping
    if not (avg_val_loss < best_val_loss):
        counter += 1
        if counter >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break
        continue

    # Improvement: checkpoint the classifier and the fine-tuned subset of DINOv2 weights
    print("New best val_loss. Saving model.")
    best_val_loss = avg_val_loss
    counter = 0
    torch.save(classifier.state_dict(), "best_classifier.pth")
    fine_tuned_dino_weights = {}
    for k, v in dino_model.state_dict().items():
        if any(layer in k for layer in _finetuned_key_parts):
            fine_tuned_dino_weights[k] = v.cpu()
    torch.save(fine_tuned_dino_weights, "best_finetuned_dino_layers.pth")


In [None]:
import zipfile
from google.colab import files

# Archive to create and the files it must contain
# (trained classifier + fine-tuned DINOv2 layers; extend the list if needed)
zip_name = "model_outputs.zip"
_files_to_archive = ["best_classifier.pth", "best_finetuned_dino_layers.pth"]

with zipfile.ZipFile(zip_name, "w") as zipf:
    for _filename in _files_to_archive:
        zipf.write(_filename)

# Trigger download of the archive from Colab
files.download(zip_name)


### Inference on test set and submission

In [None]:
import pandas as pd

# Restore the best classifier checkpoint written by the training loop
classifier.load_state_dict(torch.load("best_classifier.pth", map_location=device))

# Only the fine-tuned DINOv2 tensors were saved: merge them into the full state_dict
finetuned_weights = torch.load("best_finetuned_dino_layers.pth", map_location=device)
state_dict = dino_model.state_dict()
state_dict.update(finetuned_weights)  # Update blocks.10, blocks.11, norm, head
dino_model.load_state_dict(state_dict)

# Set models to evaluation mode
classifier.eval()
dino_model.eval()

# Resize to match DINOv2 input constraints; built once instead of once per test image
transform = transforms.Compose([
    transforms.Resize((98, 98))
])

# Prepare the dictionary to store predictions
solutions_data = {'ID': [], 'Pred': []}

# Load and process each image from the test set
with h5py.File("/content/drive/MyDrive/kaggle-DL-MI/data/test.h5", 'r') as hdf:
    test_ids = list(hdf.keys())

    for test_id in tqdm(test_ids):
        # Load the test image as a tensor (test images are not passed through a GAN)
        img = torch.tensor(np.array(hdf.get(test_id).get('img'))).float()
        img_resized = transform(img)

        # Backbone and classifier both run without gradient tracking,
        # on the same `device` as the one the checkpoints were mapped to
        with torch.no_grad():
            features = dino_model(img_resized.unsqueeze(0).to(device))  # [1, feature_dim]
            pred = classifier(features).cpu()

        # Save binary prediction (threshold at 0.5)
        solutions_data['ID'].append(int(test_id))
        solutions_data['Pred'].append(int(pred.item() > 0.5))

# Save results to CSV file
solutions_data = pd.DataFrame(solutions_data).set_index('ID')
solutions_data.to_csv('cycleGAN_fine_tune_dino_submit.csv')

print("Submission saved to 'cycleGAN_fine_tune_dino_submit.csv'")

In [None]:
# Download the submission CSV written by the inference cell to the local machine
from google.colab import files
files.download("cycleGAN_fine_tune_dino_submit.csv")

### (Optional) Training on the train + val datasets

In [None]:
from torch.utils.data import ConcatDataset, DataLoader

# Validation set re-used as training data: same augmentation (transform_train) and same
# GAN normalisation as the training set. train=False keeps the *validation* aberrant IDs
# as the exclusion list, since the images come from val.h5.
val_dataset_train = H5UnalignedDataset(
    "/content/drive/MyDrive/kaggle-DL-MI/data/val.h5",
    transform=transform_train,
    aberrant_ids_train=aberrant_train_ids,
    aberrant_ids_val=aberrant_val_ids,
    net_GA=net_GA,
    multi_gens=False,
    train=False,
)

# Train + val merged into a single training set, served in shuffled batches of 256
_datasets_to_merge = [train_dataset, val_dataset_train]
combined_dataset = ConcatDataset(_datasets_to_merge)
combined_dataloader = DataLoader(combined_dataset, batch_size=256, shuffle=True)


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import h5py
from torchmetrics.classification import BinaryAccuracy
from tqdm import tqdm

# Everything in this section runs on the GPU ("cuda" is hard-coded)
device = "cuda"

# Images are resized to 98x98 before entering DINOv2 (a multiple of its patch size, 14)
transform = transforms.Compose([transforms.Resize((98, 98))])

# Fresh pre-trained DINOv2 ViT-S/14 backbone (replaces any previously fine-tuned instance)
dino_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(device)

# Freeze every parameter whose name contains none of the substrings below.
# Note that "norm" is matched anywhere in the name, not only in the last two blocks.
_trainable_name_parts = ("blocks.10", "blocks.11", "norm", "head")
for name, param in dino_model.named_parameters():
    if any(part in name for part in _trainable_name_parts):
        continue
    param.requires_grad = False

# Classifier head (same architecture as the one defined in the fine-tuning section above;
# repeated here so that this optional section can be run on its own)
class Classifier(nn.Module):
    """Four-layer MLP (input → 512 → 256 → 128 → output) ending with a sigmoid.

    The three hidden layers use BatchNorm1d followed by ReLU; dropout (p=0.5) is applied
    after the first two hidden layers only.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Attribute names are part of the saved state_dict keys: do not rename them
        self.fc1 = nn.Linear(input_size, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, output_size)
        self.dropout = nn.Dropout(p=0.5)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid outputs flattened to shape (batch_size,)."""
        hidden = self.dropout(torch.relu(self.bn1(self.fc1(x))))
        hidden = self.dropout(torch.relu(self.bn2(self.fc2(hidden))))
        hidden = torch.relu(self.bn3(self.fc3(hidden)))
        probabilities = self.sigmoid(self.fc4(hidden))
        return probabilities.view(-1)

# Size of the feature vector returned by DINOv2
feature_dim = 384

# Classifier head, placed on the same device as the backbone
classifier = Classifier(input_size=feature_dim, output_size=1).to(device)

# Parameters updated by the optimizer: the unfrozen DINO parameters, then the classifier's
_trainable_dino_params = [p for p in dino_model.parameters() if p.requires_grad]
params_to_optimize = _trainable_dino_params + list(classifier.parameters())

# Adam with weight decay, binary cross-entropy on probabilities, accuracy metric
optimizer = optim.Adam(params_to_optimize, lr=1e-4, weight_decay=1e-5)
criterion = nn.BCELoss()
accuracy_metric = BinaryAccuracy().to(device)

# Fixed number of epochs: there is no validation set left for early stopping
NUM_EPOCHS = 5

# Train on the merged train + val set for a fixed number of epochs (no validation phase)
_n_combined_samples = len(combined_dataloader.dataset)

for epoch in range(NUM_EPOCHS):
    dino_model.train()
    classifier.train()
    total_loss, total_acc = 0, 0

    for imgs, labels in tqdm(combined_dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}"):
        imgs = transform(imgs.to(device))  # Resize to 98x98
        labels = labels.float().to(device)
        n_batch = imgs.size(0)

        features = dino_model(imgs)
        preds = classifier(features)
        loss = criterion(preds, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Sums are weighted by batch size so that the epoch averages are per-sample
        acc = accuracy_metric(preds > 0.5, labels.int())
        total_loss += loss.item() * n_batch
        total_acc += acc.item() * n_batch

    avg_loss = total_loss / _n_combined_samples
    avg_acc = total_acc / _n_combined_samples
    print(f"Train Loss: {avg_loss:.4f} | Accuracy: {avg_acc:.4f}")

# Save the classifier weights obtained after the last epoch
torch.save(classifier.state_dict(), "best_classifier_full.pth")

# Save only the DINOv2 tensors whose name matches one of the fine-tuned parts
_finetuned_key_parts = ["blocks.10", "blocks.11", "norm", "head"]
fine_tuned_dino_weights = {}
for k, v in dino_model.state_dict().items():
    if any(layer in k for layer in _finetuned_key_parts):
        fine_tuned_dino_weights[k] = v.cpu()
torch.save(fine_tuned_dino_weights, "best_finetuned_dino_layers_full.pth")


In [None]:
import zipfile
from google.colab import files

# Bundle the train+val model artifacts into one archive and hand it to the
# browser for download (Colab-only).
zip_name = "model_full_outputs.zip"

with zipfile.ZipFile(zip_name, mode="w") as zipf:
    for artifact_path in (
        "best_classifier_full.pth",             # trained classifier weights
        "best_finetuned_dino_layers_full.pth",  # fine-tuned DINOv2 layers
    ):
        zipf.write(artifact_path)

# Trigger the download in the Colab interface
files.download(zip_name)


In [None]:
import torch
import h5py
import numpy as np
import pandas as pd
from tqdm import tqdm

# Test-set inference with the train+val fine-tuned DINOv2 + classifier.
#
# Everything below uses the notebook-wide `device` (instead of a hard-coded
# 'cuda') so that the models, the checkpoint tensors and the inputs always end
# up on the same device.

# Load the classifier trained on train + val
classifier = Classifier(input_size=feature_dim, output_size=1).to(device)
classifier.load_state_dict(torch.load("best_classifier_full.pth", map_location=device))

# Load the fine-tuned DINOv2 layers
finetuned_weights = torch.load("best_finetuned_dino_layers_full.pth", map_location=device)

# Update only the relevant layers in the DINOv2 model: start from the current
# full state dict and overwrite the entries present in the saved subset.
state_dict = dino_model.state_dict()
state_dict.update(finetuned_weights)
dino_model.load_state_dict(state_dict)

# Set both models to evaluation mode
classifier.eval()
dino_model.eval()

# Prepare dictionary to store predictions
solutions_data = {'ID': [], 'Pred': []}

# Resize to match DINOv2 input size. The transform does not depend on the
# image, so it is built once rather than on every loop iteration.
transform = transforms.Compose([
    transforms.Resize((98, 98)),
])

# Load test images and run inference
with h5py.File("/content/drive/MyDrive/kaggle-DL-MI/data/test.h5", 'r') as hdf:
    test_ids = list(hdf.keys())

    for test_id in tqdm(test_ids):
        # Load the image (already normalized to test domain)
        img = torch.tensor(np.array(hdf.get(test_id).get('img'))).float()
        img_resized = transform(img)

        # Both forward passes run without autograd: nothing is trained here
        with torch.no_grad():
            # Extract features using DINOv2 (batch of one)
            features = dino_model(img_resized.unsqueeze(0).to(device)).squeeze(0)

            # Predict using the classifier
            pred = classifier(features.unsqueeze(0)).cpu()

        # Store prediction as binary label (probability thresholded at 0.5)
        solutions_data['ID'].append(int(test_id))
        solutions_data['Pred'].append(int(pred.item() > 0.5))

# Save results to CSV for Kaggle submission
solutions_data = pd.DataFrame(solutions_data).set_index('ID')
solutions_data.to_csv('cycleGAN_fine_tune_dino_submit_full.csv')

print("Submission saved to 'cycleGAN_fine_tune_dino_submit_full.csv'")


In [None]:
# Download the DINOv2 train+val submission CSV written by the inference cell
# (Colab-only: `files.download` sends the file to the browser).
from google.colab import files
files.download('cycleGAN_fine_tune_dino_submit_full.csv')

## MedImageInsight

### Required downloads

In [None]:
!git lfs install

In [None]:
%cd kaggle-DL-MI/

In [None]:
!git clone https://huggingface.co/lion-ai/MedImageInsights

In [None]:
%cd MedImageInsights

In [None]:
!pwd

In [None]:
!uv sync

In [None]:
!pip install mup

In [None]:
!pip install fvcore

### Model loading

In [None]:
from medimageinsightmodel import MedImageInsight

# Initialize the MedImageInsight model
# The model uses a vision encoder pre-trained specifically for medical imaging
# NOTE(review): `model_dir` is a relative path, so this presumably relies on the
# working directory being the cloned MedImageInsights repo — confirm.
embedding_extractor = MedImageInsight(
    model_dir="2024.09.27",                         # Directory containing the model files
    vision_model_name="medimageinsigt-v1.0.0.pt",   # Vision encoder checkpoint
    language_model_name="language_model.pth"        # Language encoder checkpoint (not used here)
)

# Load the model weights; only image embeddings are requested from it later on
embedding_extractor.load_model()


### Embeddings extraction

In [None]:
# Image DataLoaders used for embedding extraction.
# shuffle=False keeps the sample order fixed; the extraction helper derives its
# temporary IDs from batch/position indices, so the order must be reproducible.
# NOTE(review): `train_dataset` / `val_dataset` are rebound to EmbeddingDataset
# objects in the classifier-training cell further down, so this cell must run
# before that one (or after re-creating the image datasets).
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)

In [None]:
import base64, io
from PIL import Image
import pandas as pd
from tqdm import tqdm

def _tensor_image_to_png_base64(img):
    """Encode one [C, H, W] image tensor as a base64 PNG string (UTF-8 text)."""
    # Convert image from [C, H, W] to [H, W, C] and move to CPU
    np_img = img.permute(1, 2, 0).detach().cpu().numpy()

    # Data whose maximum is <= 1.0 is treated as living in [0, 1] and rescaled
    if np_img.max() <= 1.0:
        np_img = np_img * 255

    # Clip before casting: converting out-of-range floats (e.g. negatives from
    # [-1, 1]-normalised images, or values above 255) to uint8 is not well
    # defined in NumPy and silently produces wrapped/garbage pixel values.
    np_img = np.clip(np_img, 0, 255).astype(np.uint8)

    # Encode the image as base64 PNG for model input
    buffer = io.BytesIO()
    Image.fromarray(np_img).save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def extract_embeddings_from_dataloader(dataloader, model, save_path=None, max_batches=None):
    """
    Extracts image embeddings from a dataloader using a base64-based model API (e.g., MedImageInsight).

    Each image is converted to an 8-bit PNG (values clipped to [0, 255]),
    base64-encoded and sent to the model one batch at a time. A batch whose
    encoding raises is reported on stdout and skipped, so the returned frame
    may contain fewer rows than the dataloader has samples.

    Args:
        dataloader: Iterable yielding (images, labels) batches; images are
            iterated along their first dimension and each item is expected
            to be a [C, H, W] tensor.
        model: Feature extractor with a .encode(images=[base64 strings]) method
            returning a mapping with an "image_embeddings" entry.
        save_path: Optional path to save the output DataFrame as a .pkl file.
        max_batches: Optional number of batches to process (useful for debugging or speed constraints).

    Returns:
        A pandas DataFrame with columns ['ID', 'label', 'embedding']. IDs are
        temporary, position-based strings of the form "batch{b}_img{i}".
    """

    image_ids, labels, all_embeddings = [], [], []

    for batch_idx, (imgs, lbls) in enumerate(tqdm(dataloader, desc="Extraction d'embeddings")):
        if max_batches is not None and batch_idx >= max_batches:
            break

        batch_b64 = [_tensor_image_to_png_base64(img) for img in imgs]

        try:
            # Send batch of base64 images to the model
            result = model.encode(images=batch_b64)
            embeddings = result["image_embeddings"]
        except Exception as e:
            # Deliberate best effort: report the failure and move on to the next batch
            print(f"Error while encoding batch {batch_idx}: {e}")
            continue

        all_embeddings.extend(embeddings)
        labels.extend(lbls.tolist())
        image_ids.extend([f"batch{batch_idx}_img{i}" for i in range(len(lbls))])  # Temporary ID

    # Create dataframe with extracted embeddings and labels
    df_embed = pd.DataFrame({
        "ID": image_ids,
        "label": labels,
        "embedding": all_embeddings
    })

    if save_path:
        df_embed.to_pickle(save_path)
        print(f"Embeddings saved to {save_path}")

    return df_embed


In [None]:
# Extract MedImageInsight embeddings for the train and validation images and
# persist each resulting DataFrame as a pickle next to the notebook.
df_train_embed = extract_embeddings_from_dataloader(train_dataloader, embedding_extractor, save_path="train_GAN_embed.pkl")
df_val_embed = extract_embeddings_from_dataloader(val_dataloader, embedding_extractor, save_path="val_GAN_embed.pkl")


In [None]:
import zipfile
from google.colab import files

# Archive the train/val embedding pickles and download the archive (Colab-only).
zip_name = "embeds_train_val.zip"

with zipfile.ZipFile(zip_name, mode="w") as zipf:
    # Extend this tuple to ship more files in the same archive
    for embedding_file in ("train_GAN_embed.pkl", "val_GAN_embed.pkl"):
        zipf.write(embedding_file)

# Trigger download of the zip archive
files.download(zip_name)


In [None]:
def _h5_image_to_png_base64(img_array):
    """Encode one image array as a base64 PNG string (UTF-8 text)."""
    # Ensure correct shape: [H, W, C] (a leading axis of size 3 is taken to be channels)
    if img_array.shape[0] == 3:
        img_array = np.transpose(img_array, (1, 2, 0))

    # Data whose maximum is <= 1.0 is treated as living in [0, 1] and rescaled
    if img_array.max() <= 1.0:
        img_array = img_array * 255

    # Clip before casting: converting out-of-range floats (negatives, values
    # above 255) to uint8 is not well defined in NumPy and silently produces
    # wrapped/garbage pixel values.
    img_array = np.clip(img_array, 0, 255).astype(np.uint8)

    buffer = io.BytesIO()
    Image.fromarray(img_array).save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def extract_embeddings_from_h5_batch(h5_path, model, max_images=None, batch_size=64, save_path=None):
    """
    Extract image embeddings from a .h5 dataset using a model that supports base64 input.

    Every top-level key of the file is treated as an image ID whose group holds
    an 'img' dataset and, optionally, a 'label' dataset. Images are converted
    to 8-bit PNGs (values clipped to [0, 255]), base64-encoded and sent to the
    model in batches. A batch whose encoding raises is reported on stdout and
    dropped, so the result may contain fewer rows than the file has images.

    Parameters:
    - h5_path (str): Path to the HDF5 file.
    - model: Model object with an 'encode(images=...)' method returning a
      mapping with an "image_embeddings" entry.
    - max_images (int): Optional limit on number of images to process.
    - batch_size (int): Number of images per batch sent to the model.
    - save_path (str): Optional path to save resulting DataFrame as .pkl file.

    Returns:
    - pd.DataFrame: DataFrame containing image IDs, labels, and embeddings
      (columns 'ID', 'label', 'embedding'); 'label' is None where the file
      stores no label for an image.
    """

    image_ids, labels, all_embeddings = [], [], []
    batch_b64, batch_ids, batch_labels = [], [], []

    with h5py.File(h5_path, 'r') as hdf:
        ids = list(hdf.keys())
        if max_images is not None:
            ids = ids[:max_images]

        for i, img_id in enumerate(tqdm(ids)):
            # Load image array and encode it for the model
            img_b64 = _h5_image_to_png_base64(np.array(hdf[img_id]['img']))

            # Accumulate batch info
            batch_b64.append(img_b64)
            batch_ids.append(img_id)
            batch_labels.append(int(np.array(hdf[img_id]['label'])) if 'label' in hdf[img_id] else None)

            # Once batch is full or last image → send to model
            if len(batch_b64) == batch_size or (i == len(ids) - 1):
                try:
                    result = model.encode(images=batch_b64)
                    embeddings = result["image_embeddings"]
                except Exception as e:
                    # Deliberate best effort: report, drop this batch, keep going.
                    # `i` is the index of the last image in the failed batch.
                    print(f"Error at batch {i}: {e}")
                    batch_b64, batch_ids, batch_labels = [], [], []
                    continue

                all_embeddings.extend(embeddings)
                image_ids.extend(batch_ids)
                labels.extend(batch_labels)

                # Reset batch
                batch_b64, batch_ids, batch_labels = [], [], []
                print("Batch processed.")

    # Build resulting DataFrame
    df_embed = pd.DataFrame({
        "ID": image_ids,
        "label": labels,
        "embedding": all_embeddings
    })

    if save_path:
        df_embed.to_pickle(save_path)
        print(f"✅ Embeddings saved to {save_path}")

    return df_embed


In [None]:
# Extract MedImageInsight embeddings for the test images straight from the
# HDF5 file (64 images per model call) and save them as "test_embed.pkl".
df_test_embed = extract_embeddings_from_h5_batch(
    h5_path="/content/drive/MyDrive/kaggle-DL-MI/data/test.h5",
    model=embedding_extractor,
    batch_size=64,
    save_path="test_embed.pkl"
)


In [None]:
# Download the test embeddings pickle produced by the extraction cell
# (Colab-only: `files.download` sends the file to the browser).

from google.colab import files
files.download('test_embed.pkl')


### Training classifier with training embeddings and validation on val embeddings 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchmetrics.classification import BinaryAccuracy
from tqdm import tqdm
import pandas as pd

# MLP head mapping an embedding vector to a single probability
class Classifier(nn.Module):
    """Fully connected binary classifier: input -> 512 -> 256 -> 128 -> output.

    Each hidden layer is Linear -> BatchNorm1d -> ReLU; dropout (p=0.5) follows
    the first two hidden layers only. The output passes through a sigmoid and
    is flattened to one dimension.

    Args:
        input_size: Width of the incoming feature vectors.
        output_size: Width of the final linear layer.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Attribute names double as state_dict keys; saved checkpoints rely on them.
        self.fc1 = nn.Linear(input_size, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.fc3 = nn.Linear(256, 128)
        self.bn3 = nn.BatchNorm1d(128)
        self.fc4 = nn.Linear(128, output_size)
        self.dropout = nn.Dropout(p=0.5)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid outputs for a batch `x`, flattened to 1-D."""
        hidden = self.dropout(self.bn1(self.fc1(x)).relu())
        hidden = self.dropout(self.bn2(self.fc2(hidden)).relu())
        hidden = self.bn3(self.fc3(hidden)).relu()  # no dropout before the output layer
        probabilities = self.sigmoid(self.fc4(hidden))
        return probabilities.view(-1)  # Output shape: (batch_size,) when output_size == 1

# Load training and validation embedding datasets
# (the pickles written by the embedding-extraction cells; only unpickle files
# you produced yourself — pickle can execute arbitrary code on load)
train_embedd_dataset = pd.read_pickle("train_GAN_embed.pkl")
val_embedd_dataset = pd.read_pickle("val_GAN_embed.pkl")

# Fall back to CPU when no GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"

# Dataset wrapper around an embeddings DataFrame
class EmbeddingDataset(Dataset):
    """Serve (embedding, label) pairs as float tensors.

    The "embedding" and "label" columns of `dataframe` are copied into plain
    Python lists at construction time; tensors are created on access.
    """

    def __init__(self, dataframe):
        embedding_column, label_column = dataframe["embedding"], dataframe["label"]
        self.embeddings = embedding_column.tolist()
        self.labels = label_column.tolist()

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        # Both items are cast to float32 (the labels feed a BCE loss)
        return (
            torch.tensor(self.embeddings[idx]).float(),
            torch.tensor(self.labels[idx]).float(),
        )

# Train the MLP classifier on the train embeddings, validate on the val
# embeddings, halve the learning rate when validation loss plateaus and stop
# early after `patience` epochs without improvement.

# Create DataLoaders
# NOTE: `train_dataset` / `val_dataset` now refer to embedding datasets.
train_dataset = EmbeddingDataset(train_embedd_dataset)
val_dataset = EmbeddingDataset(val_embedd_dataset)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False)

# Initialize model and training components
feature_dim = len(train_embedd_dataset.iloc[0]["embedding"])
classifier = Classifier(input_size=feature_dim, output_size=1).to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(classifier.parameters(), lr=1e-4, weight_decay=1e-5)
metric = BinaryAccuracy().to(device)

# Add learning rate scheduler to reduce LR on plateau.
# The `verbose` argument is deprecated in PyTorch's LR schedulers, so it is not
# passed; learning-rate changes are logged explicitly after each step instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

# Training loop with early stopping
best_val_loss = float('inf')
patience = 5
counter = 0
NUM_EPOCHS = 50

for epoch in range(NUM_EPOCHS):
    classifier.train()
    total_loss, total_acc = 0, 0

    for emb, lbl in tqdm(train_loader, desc=f"[Epoch {epoch+1}]"):
        emb, lbl = emb.to(device), lbl.to(device)
        preds = classifier(emb)
        loss = criterion(preds, lbl)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        acc = metric(preds > 0.5, lbl.int())
        # Weight per-batch values by batch size so epoch averages are per-sample
        total_loss += loss.item() * emb.size(0)
        total_acc += acc.item() * emb.size(0)

    avg_loss = total_loss / len(train_loader.dataset)
    avg_acc = total_acc / len(train_loader.dataset)
    print(f"Train Loss: {avg_loss:.4f} | Accuracy: {avg_acc:.4f}")

    # Validation phase
    classifier.eval()
    val_loss, val_acc = 0, 0
    with torch.no_grad():
        for emb, lbl in val_loader:
            emb, lbl = emb.to(device), lbl.to(device)
            preds = classifier(emb)
            loss = criterion(preds, lbl)
            acc = metric(preds > 0.5, lbl.int())

            val_loss += loss.item() * emb.size(0)
            val_acc += acc.item() * emb.size(0)

    avg_val_loss = val_loss / len(val_loader.dataset)
    avg_val_acc = val_acc / len(val_loader.dataset)
    print(f"Val Loss: {avg_val_loss:.4f} | Accuracy: {avg_val_acc:.4f}")

    # Step the scheduler on the validation loss and report any LR reduction
    lr_before_step = optimizer.param_groups[0]["lr"]
    scheduler.step(avg_val_loss)
    lr_after_step = optimizer.param_groups[0]["lr"]
    if lr_after_step < lr_before_step:
        print(f"Reducing learning rate to {lr_after_step:.4e}")

    # Save best model and check early stopping
    if avg_val_loss < best_val_loss:
        print("New best validation loss. Saving model.")
        best_val_loss = avg_val_loss
        counter = 0
        torch.save(classifier.state_dict(), "best_embed_classifier.pth")
    else:
        counter += 1
        if counter >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break


In [None]:
# Download the best classifier checkpoint saved by the training loop
# (Colab-only). NOTE(review): `zipfile` is imported but not used in this cell.
import zipfile
from google.colab import files
files.download("best_embed_classifier.pth")

### Inference on test embeddings and submission

In [None]:
# Load the test embeddings DataFrame written by the extraction step
# (only unpickle files you produced yourself).
test_embedd_dataset = pd.read_pickle("test_embed.pkl")


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from tqdm import tqdm

# Dataset wrapper for unlabeled test embeddings
class EmbeddingTestDataset(Dataset):
    """Serve (embedding, ID) pairs: a float tensor plus the row's raw "ID" value."""

    def __init__(self, dataframe):
        embedding_column, id_column = dataframe["embedding"], dataframe["ID"]
        self.embeddings = embedding_column.tolist()
        self.ids = id_column.tolist()

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        # The ID is passed through untouched; only the embedding becomes a tensor
        return torch.tensor(self.embeddings[idx]).float(), self.ids[idx]

# Load trained classifier (the best-validation-loss checkpoint) and rebuild it
# with the embedding width found in the test DataFrame.
device = "cuda" if torch.cuda.is_available() else "cpu"
feature_dim = len(test_embedd_dataset.iloc[0]["embedding"])

classifier = Classifier(input_size=feature_dim, output_size=1).to(device)
classifier.load_state_dict(torch.load("best_embed_classifier.pth", map_location=device))
classifier.eval()

# Create DataLoader for test embeddings (shuffle=False keeps IDs and predictions aligned with file order)
test_dataset = EmbeddingTestDataset(test_embedd_dataset)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# Perform inference
ids, probs, preds_bin = [], [], []

with torch.no_grad():
    for emb, id_ in tqdm(test_loader, desc="Inference"):
        emb = emb.to(device)
        prob = classifier(emb)  # Predicted probability
        pred = (prob > 0.5).int()  # Binary prediction

        # `id_` is the batch of IDs as collated by the DataLoader
        ids.extend(id_)
        probs.extend(prob.cpu().numpy().tolist())
        preds_bin.extend(pred.cpu().numpy().tolist())

# Create final submission DataFrame
# (`probs` is collected for inspection but not included in the submission)
df_submission = pd.DataFrame({
    "ID": ids,
    "Pred": preds_bin
})


In [None]:
# Save the submission file as CSV
df_submission.to_csv("submission_MedImg_GAN.csv", index=False)
# The CSV holds only the ID and the thresholded prediction; the message says
# so explicitly instead of also announcing probabilities that are not written.
print("submission_MedImg_GAN.csv generated with binary predictions!")

# Download the CSV file from Colab
from google.colab import files
files.download("submission_MedImg_GAN.csv")


### (Optional) Train+val training

In [None]:
# Combine the training and validation datasets into a single dataset
# (at this point `train_dataset` / `val_dataset` are the EmbeddingDataset
# objects created in the classifier-training cell)
combined_dataset = ConcatDataset([train_dataset, val_dataset])

# Create a DataLoader for the combined dataset (used for final training)
# NOTE(review): this rebinds `combined_dataloader`, a name the DINOv2 train+val
# cell also reads — re-running that cell afterwards would iterate embeddings.
combined_dataloader = DataLoader(combined_dataset, batch_size=256, shuffle=True)


In [None]:
# Final training run on train+val embeddings: fixed number of epochs, no
# validation, weights saved once after the last epoch.

# Initialize the classifier with the appropriate input size (matching the embedding dimension)
feature_dim = len(train_embedd_dataset.iloc[0]["embedding"])
classifier = Classifier(input_size=feature_dim, output_size=1).to(device)

# Define loss function and optimizer
criterion = nn.BCELoss()
optimizer = optim.Adam(classifier.parameters(), lr=1e-4, weight_decay=1e-5)
metric = BinaryAccuracy().to(device)

# Training loop parameters
# NOTE(review): best_val_loss, patience and counter are assigned but never read
# in this cell (there is no validation phase here); only NUM_EPOCHS is used.
best_val_loss = float('inf')
patience = 5
counter = 0
NUM_EPOCHS = 15

# Training loop over combined training + validation embeddings
for epoch in range(NUM_EPOCHS):
    classifier.train()
    total_loss, total_acc = 0, 0

    for emb, lbl in tqdm(combined_dataloader, desc=f"[Epoch {epoch+1}]"):
        emb, lbl = emb.to(device), lbl.to(device)

        preds = classifier(emb)
        loss = criterion(preds, lbl)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Threshold probabilities at 0.5 for the accuracy metric
        acc = metric(preds > 0.5, lbl.int())
        # Weight per-batch values by batch size so epoch averages are per-sample
        total_loss += loss.item() * emb.size(0)
        total_acc += acc.item() * emb.size(0)

    avg_loss = total_loss / len(combined_dataloader.dataset)
    avg_acc = total_acc / len(combined_dataloader.dataset)
    print(f"Train Loss: {avg_loss:.4f} | Accuracy: {avg_acc:.4f}")

# Save the trained classifier weights (state after the final epoch)
torch.save(classifier.state_dict(), "best_embed_classifier_full.pth")


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from tqdm import tqdm

# Dataset wrapper for unlabeled test embeddings
class EmbeddingTestDataset(Dataset):
    """Serve (embedding, ID) pairs: a float tensor plus the row's raw "ID" value."""

    def __init__(self, dataframe):
        embedding_column, id_column = dataframe["embedding"], dataframe["ID"]
        self.embeddings = embedding_column.tolist()
        self.ids = id_column.tolist()

    def __len__(self):
        return len(self.embeddings)

    def __getitem__(self, idx):
        # The ID is passed through untouched; only the embedding becomes a tensor
        return torch.tensor(self.embeddings[idx]).float(), self.ids[idx]

# Load trained classifier (the train+val checkpoint) and rebuild it with the
# embedding width found in the test DataFrame.
device = "cuda" if torch.cuda.is_available() else "cpu"
feature_dim = len(test_embedd_dataset.iloc[0]["embedding"])

classifier = Classifier(input_size=feature_dim, output_size=1).to(device)
classifier.load_state_dict(torch.load("best_embed_classifier_full.pth", map_location=device))
classifier.eval()

# Create test DataLoader (shuffle=False keeps IDs and predictions aligned with file order)
test_dataset = EmbeddingTestDataset(test_embedd_dataset)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

# Inference loop
ids, probs, preds_bin = [], [], []

with torch.no_grad():
    for emb, id_ in tqdm(test_loader, desc="Inference"):
        emb = emb.to(device)
        prob = classifier(emb)
        pred = (prob > 0.5).int()  # Convert probabilities to binary predictions

        # `id_` is the batch of IDs as collated by the DataLoader
        ids.extend(id_)
        probs.extend(prob.cpu().numpy().tolist())
        preds_bin.extend(pred.cpu().numpy().tolist())

# Create the submission DataFrame
# (`probs` is collected for inspection but not included in the submission)
df_submission = pd.DataFrame({
    "ID": ids,
    "Pred": preds_bin
})


In [None]:
# Save the submission DataFrame to a CSV file
# (index=False: "ID" is a regular column here, so no extra index column is written)
df_submission.to_csv("submission_MedImg_GAN_full.csv", index=False)
print("submission_MedImg_GAN_full.csv has been saved with binary predictions.")

# Download the CSV file to your local machine (Colab-only)
from google.colab import files
files.download("submission_MedImg_GAN_full.csv")
