# Center Specific Normalization and Ensemble Prediction

In this notebook, we propose to investigate the advantage to use center specific normalization instead of using individual normalisation. To be able to compare the advantage of this method, we use the same baseline model and only change the train test and valid dataset. We also investigate aggregating several predictions on augmented data to reduce the variance of the predicitions.

This method gives promising results, leading to an accuracy of 93.5% on the test set (where the baseline model achieved 90.6%)

In [None]:
import h5py
import torch
import random
import numpy as np
import pandas as pd
import torchmetrics
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as TF


In [2]:
TRAIN_IMAGES_PATH = 'train.h5'
VAL_IMAGES_PATH = 'val.h5'
TEST_IMAGES_PATH = 'test.h5'
SEED = 0
GLOBAL_MEAN = np.array([0.485, 0.456, 0.406])
GLOBAL_STD = np.array([0.229, 0.224, 0.225])

In [3]:
torch.random.manual_seed(SEED)
random.seed(SEED)

## 1. Building a center normalized dataset
The datasets will leverage statistics computed in each center to be able to normalized the data according to the center it comes from

In [4]:
BATCH_SIZE = 32

Normalizing according to the center can be better than normalizing images individually because doing so would erase outlyer information. Here, we can know how the image behaves relatively to other obtained in the same condition, which could preserve outlier info.

In [5]:
def compute_center_stats(dataset_paths):
        """Compute mean and std for each center in the dataset"""
        center_stats = {}
        for dataset_path in dataset_paths:
            with h5py.File(dataset_path, 'r') as hdf:
                # First we group images by center
                center_images = {}
                for img_id in hdf.keys():
                    try:
                        center = int(np.array(hdf.get(img_id).get('metadata'))[0])
                        if center not in center_images:
                            center_images[center] = []
                        center_images[center].append(np.array(hdf.get(img_id).get('img')))
                    except:
                        continue

        # Compute mean and std for each center
        for center, images in center_images.items():
            all_images = np.vstack([img.reshape(1, *img.shape) for img in images])
            mean = np.mean(all_images, axis=(0, 2, 3))
            std = np.std(all_images, axis=(0, 2, 3))
            center_stats[center] = {'mean': mean, 'std': std}
                    
        return center_stats

class CenterNormalizedDataset(Dataset):
    def __init__(self, dataset_path, preprocessing, mode, center_stats=None):
        super(CenterNormalizedDataset, self).__init__()
        self.dataset_path = dataset_path
        self.preprocessing = preprocessing
        self.mode = mode
        
        # Pre-compute center statistics if not provided
        self.center_stats = center_stats
        if self.center_stats is None:
            self.center_stats = compute_center_stats([dataset_path])
        
        with h5py.File(self.dataset_path, 'r') as hdf:        
            self.image_ids = list(hdf.keys())

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        with h5py.File(self.dataset_path, 'r') as hdf:
            img = np.array(hdf.get(img_id).get('img'))
            center = int(np.array(hdf.get(img_id).get('metadata'))[0])
            label = np.array(hdf.get(img_id).get('label')) if self.mode == 'train' else None
        
        # Apply normalization to match global stats
        if center in self.center_stats:
            # First, standardize using center-specific stats
            center_mean = self.center_stats[center]['mean']
            center_std = self.center_stats[center]['std']
            
            # Standardize to zero mean and unit variance
            img_standardized = (img - center_mean.reshape(-1, 1, 1)) / (center_std.reshape(-1, 1, 1) + 1e-8)
            
            # Rescale to global statistics
            img = img_standardized * GLOBAL_STD.reshape(-1, 1, 1) + GLOBAL_MEAN.reshape(-1, 1, 1)
        
        img_tensor = torch.tensor(img)
        if self.preprocessing:
            img_tensor = self.preprocessing(img_tensor)
        
        if self.mode == 'train':
            return img_tensor.float(), torch.tensor(label).float()
        else:
            return img_tensor.float(), None

In [6]:
preprocessing = transforms.Resize((98, 98))

# Compute stats from all data
center_stats = compute_center_stats([TRAIN_IMAGES_PATH, VAL_IMAGES_PATH, TEST_IMAGES_PATH])
print("center stats computed")

# Pass the computed center statistics to the datasets
train_dataset = CenterNormalizedDataset(TRAIN_IMAGES_PATH, preprocessing, 'train', center_stats)
val_dataset = CenterNormalizedDataset(VAL_IMAGES_PATH, preprocessing, 'train', center_stats)
test_dataset = CenterNormalizedDataset(TEST_IMAGES_PATH, preprocessing, 'eval', center_stats)

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=BATCH_SIZE)

center stats computed


Since we will use the same model as the baseline proposed (Dinov2 + linear classifier), we will use their PrecomputedDataset class

In [7]:
def precompute(dataloader, model, device):
    xs, ys = [], []
    for x, y in tqdm(dataloader, leave=False):
        with torch.no_grad():
            # Make sure x has the correct shape (B, C, H, W)
            if len(x.shape) == 3:  # If missing batch dimension
                x = x.unsqueeze(0)
            # Extract features
            features = model(x.to(device)).detach().cpu()
            xs.append(features.numpy())
        ys.append(y.numpy() if y is not None else np.zeros(x.size(0)))
    xs = np.vstack(xs)
    ys = np.hstack(ys)
    return torch.tensor(xs), torch.tensor(ys)

This dataset is used to store precomputed features on which we can train a classifier only model

In [8]:
class PrecomputedDataset(Dataset):
    def __init__(self, features, labels):
        super(PrecomputedDataset, self).__init__()
        self.features = features
        self.labels = labels.unsqueeze(-1)
    
    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx].float()

## 2. Precomputing the features

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Working on {device}.')

Working on cuda.


As in the baseline model, we use dinov2 to precompute the features. Our main contribution in this notebook is to propose and test a center specific normalisation

In [10]:
feature_extractor = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(device)
feature_extractor.eval()
linear_probing = torch.nn.Sequential(torch.nn.Linear(feature_extractor.num_features, 1),
                                     torch.nn.Sigmoid()).to(device)

Using cache found in /raid/home/automatants/tabbara_pau/.cache/torch/hub/facebookresearch_dinov2_main


In [11]:
train_dataset = PrecomputedDataset(*precompute(train_dataloader, feature_extractor, device))
val_dataset = PrecomputedDataset(*precompute(val_dataloader, feature_extractor, device))

  0%|          | 0/3125 [00:00<?, ?it/s]



  0%|          | 0/1091 [00:00<?, ?it/s]

In [12]:
torch.save(train_dataset, 'train_dataset_center_normalized.pth')
torch.save(val_dataset, 'val_dataset_center_normalized.pth')

In [13]:
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)

## 3.a) Training the model on center-normalized data

In [31]:
def training_pipeline(feature_extractor, device, train_dataloader, val_dataloader, num_epochs=100, patience=10):
    
    linear_probing = torch.nn.Sequential(
        torch.nn.Linear(feature_extractor.num_features, 1),
        torch.nn.Sigmoid()
    ).to(device)
    
    optimizer = torch.optim.Adam(linear_probing.parameters(), lr=0.001)
    criterion = torch.nn.BCELoss()
    metric = torchmetrics.Accuracy('binary')
    
    min_loss, best_epoch = float('inf'), 0
    
    for epoch in range(num_epochs):
        linear_probing.train()
        train_metrics, train_losses = [], []
        
        for train_x, train_y in tqdm(train_dataloader, leave=False):
            optimizer.zero_grad()
            train_pred = linear_probing(train_x.to(device))
            loss = criterion(train_pred, train_y.to(device))
            loss.backward()
            optimizer.step()
            
            train_losses.extend([loss.item()]*len(train_y))
            train_metric = metric(train_pred.cpu(), train_y.int().cpu())
            train_metrics.extend([train_metric.item()]*len(train_y))
        
        print(f'Epoch train [{epoch+1}/{num_epochs}] | Loss {np.mean(train_losses):.4f} | Metric {np.mean(train_metrics):.4f}')
        
        linear_probing.eval()
        val_metrics, val_losses = [], []
        
        for val_x, val_y in tqdm(val_dataloader, leave=False):
            with torch.no_grad():
                val_pred = linear_probing(val_x.to(device))
            
            loss = criterion(val_pred, val_y.to(device))
            val_losses.extend([loss.item()]*len(val_y))
            val_metric = metric(val_pred.cpu(), val_y.int().cpu())
            val_metrics.extend([val_metric.item()]*len(val_y))
        
        mean_val_loss = np.mean(val_losses)
        print(f'Epoch valid [{epoch+1}/{num_epochs}] | Loss {mean_val_loss:.4f} | Metric {np.mean(val_metrics):.4f}')
        
        # Save best model
        if mean_val_loss < min_loss:
            print(f'New best loss {min_loss:.4f} -> {mean_val_loss:.4f}')
            min_loss = mean_val_loss
            best_epoch = epoch
            torch.save(linear_probing.state_dict(), 'best_model.pth')
        
        # Early stopping
        if epoch - best_epoch == patience:
            print(f'Early stopping after {patience} epochs without improvement')
            break
    
    # Load best model for final evaluation
    linear_probing.load_state_dict(torch.load('best_model.pth'))
    
    return linear_probing

train_dataset = torch.load('train_dataset_center_normalized.pth')
val_dataset = torch.load('val_dataset_center_normalized.pth')
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=BATCH_SIZE)

best_model = training_pipeline(
    feature_extractor=feature_extractor,
    device=device,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    num_epochs=100,
    patience=10
)

  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [1/100] | Loss 0.1842 | Metric 0.9295


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [1/100] | Loss 0.3431 | Metric 0.8563
New best loss inf -> 0.3431


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [2/100] | Loss 0.1555 | Metric 0.9415


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [2/100] | Loss 0.3224 | Metric 0.8647
New best loss 0.3431 -> 0.3224


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [3/100] | Loss 0.1506 | Metric 0.9426


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [3/100] | Loss 0.3816 | Metric 0.8499


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [4/100] | Loss 0.1475 | Metric 0.9446


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [4/100] | Loss 0.3955 | Metric 0.8449


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [5/100] | Loss 0.1461 | Metric 0.9454


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [5/100] | Loss 0.3109 | Metric 0.8724
New best loss 0.3224 -> 0.3109


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [6/100] | Loss 0.1443 | Metric 0.9459


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [6/100] | Loss 0.3332 | Metric 0.8663


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [7/100] | Loss 0.1428 | Metric 0.9464


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [7/100] | Loss 0.3201 | Metric 0.8713


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [8/100] | Loss 0.1424 | Metric 0.9464


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [8/100] | Loss 0.3451 | Metric 0.8636


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [9/100] | Loss 0.1421 | Metric 0.9467


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [9/100] | Loss 0.3212 | Metric 0.8735


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [10/100] | Loss 0.1410 | Metric 0.9472


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [10/100] | Loss 0.3276 | Metric 0.8699


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [11/100] | Loss 0.1407 | Metric 0.9472


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [11/100] | Loss 0.3447 | Metric 0.8659


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [12/100] | Loss 0.1410 | Metric 0.9467


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [12/100] | Loss 0.3177 | Metric 0.8756


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [13/100] | Loss 0.1404 | Metric 0.9477


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [13/100] | Loss 0.4696 | Metric 0.8245


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [14/100] | Loss 0.1405 | Metric 0.9477


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [14/100] | Loss 0.3090 | Metric 0.8778
New best loss 0.3109 -> 0.3090


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [15/100] | Loss 0.1396 | Metric 0.9481


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [15/100] | Loss 0.3310 | Metric 0.8739


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [16/100] | Loss 0.1405 | Metric 0.9476


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [16/100] | Loss 0.3344 | Metric 0.8640


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [17/100] | Loss 0.1399 | Metric 0.9480


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [17/100] | Loss 0.3121 | Metric 0.8784


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [18/100] | Loss 0.1394 | Metric 0.9478


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [18/100] | Loss 0.3463 | Metric 0.8666


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [19/100] | Loss 0.1391 | Metric 0.9480


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [19/100] | Loss 0.3402 | Metric 0.8699


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [20/100] | Loss 0.1400 | Metric 0.9477


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [20/100] | Loss 0.3279 | Metric 0.8689


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [21/100] | Loss 0.1392 | Metric 0.9475


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [21/100] | Loss 0.3153 | Metric 0.8757


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [22/100] | Loss 0.1392 | Metric 0.9479


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [22/100] | Loss 0.3863 | Metric 0.8536


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [23/100] | Loss 0.1390 | Metric 0.9481


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [23/100] | Loss 0.3341 | Metric 0.8737


  0%|          | 0/3125 [00:00<?, ?it/s]

Epoch train [24/100] | Loss 0.1393 | Metric 0.9481


  0%|          | 0/1091 [00:00<?, ?it/s]

Epoch valid [24/100] | Loss 0.3507 | Metric 0.8676
Early stopping after 10 epochs without improvement


## 3.b) Training with simply Imagenet normalization

In [None]:
class DINOv2Dataset(Dataset):
    def __init__(self, dataset_path, mode='train'):
        super(DINOv2Dataset, self).__init__()
        self.dataset_path = dataset_path
        self.mode = mode
        
        # Standard ImageNet normalization used by DINOv2
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
        
        with h5py.File(self.dataset_path, 'r') as hdf:        
            self.image_ids = list(hdf.keys())

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        with h5py.File(self.dataset_path, 'r') as hdf:
            img = torch.tensor(np.array(hdf.get(img_id).get('img'))).float()
            label = float(np.array(hdf.get(img_id).get('label'))) if self.mode == 'train' else 0.0
        
        # Ensure image values are in [0, 1] range
        if img.max() > 1.0:
            img = img / 255.0
            
        # Apply ImageNet normalization for DINOv2
        normalized_img = self.transform(img)
        
        return normalized_img, torch.tensor(label)
    
def prepare_dinov2_dataloaders(batch_size=32):
    # Just ImageNet normalization
    train_dataset = DINOv2Dataset(TRAIN_IMAGES_PATH, mode='train')
    val_dataset = DINOv2Dataset(VAL_IMAGES_PATH, mode='train')
    test_dataset = DINOv2Dataset(TEST_IMAGES_PATH, mode='eval')

    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
    val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)
    test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)
    
    return train_dataloader, val_dataloader, test_dataloader

In [39]:
def training_pipeline_dinov2(feature_extractor, device, train_dataloader, val_dataloader, num_epochs=100, patience=10):
    
    linear_probing = torch.nn.Sequential(
        torch.nn.Linear(feature_extractor.num_features, 1),
        torch.nn.Sigmoid()
    ).to(device)
    
    # Place feature extractor in eval mode since we're not training it
    feature_extractor.eval()
    
    optimizer = torch.optim.Adam(linear_probing.parameters(), lr=0.001)
    criterion = torch.nn.BCELoss()
    metric = torchmetrics.Accuracy('binary')
    
    min_loss, best_epoch = float('inf'), 0
    
    for epoch in range(num_epochs):
        linear_probing.train()
        train_metrics, train_losses = [], []
        
        for train_x, train_y in tqdm(train_dataloader, leave=False):
            # First extract features using DinoV2
            with torch.no_grad():
                train_x = train_x.to(device)
                features = feature_extractor(train_x)
            
            # Reshape labels and move to device
            train_y = train_y.float().unsqueeze(1).to(device)
            
            optimizer.zero_grad()
            train_pred = linear_probing(features)  # Use the extracted features
            loss = criterion(train_pred, train_y)
            loss.backward()
            optimizer.step()
            
            train_losses.append(loss.item())
            train_metric = metric(train_pred.detach().cpu(), train_y.int().cpu())
            train_metrics.append(train_metric.item())
        
        print(f'Epoch train [{epoch+1}/{num_epochs}] | Loss {np.mean(train_losses):.4f} | Metric {np.mean(train_metrics):.4f}')
        
        linear_probing.eval()
        val_metrics, val_losses = [], []
        
        for val_x, val_y in tqdm(val_dataloader, leave=False):
            with torch.no_grad():
                val_x = val_x.to(device)
                features = feature_extractor(val_x)
                
                # Reshape labels and move to device
                val_y = val_y.float().unsqueeze(1).to(device)
                val_pred = linear_probing(features)
            
            loss = criterion(val_pred, val_y)
            val_losses.append(loss.item())
            val_metric = metric(val_pred.cpu(), val_y.int().cpu())
            val_metrics.append(val_metric.item())
        
        mean_val_loss = np.mean(val_losses)
        print(f'Epoch valid [{epoch+1}/{num_epochs}] | Loss {mean_val_loss:.4f} | Metric {np.mean(val_metrics):.4f}')
        
        # Save best model
        if mean_val_loss < min_loss:
            print(f'New best loss {min_loss:.4f} -> {mean_val_loss:.4f}')
            min_loss = mean_val_loss
            best_epoch = epoch
            torch.save(linear_probing.state_dict(), 'best_model.pth')
        
        # Early stopping
        if epoch - best_epoch == patience:
            print(f'Early stopping after {patience} epochs without improvement')
            break
    
    # Load best model for final evaluation
    linear_probing.load_state_dict(torch.load('best_model.pth'))
    
    return linear_probing

In [None]:
train_dataloader, val_dataloader, test_dataloader = prepare_dinov2_dataloaders(batch_size=128)

# Train with center awareness
best_model = training_pipeline_dinov2(
    feature_extractor=feature_extractor,
    device=device,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    num_epochs=100,
    patience=10
)

  0%|          | 0/782 [00:00<?, ?it/s]



Epoch train [1/100] | Loss 0.1873 | Metric 0.9298


  0%|          | 0/273 [00:00<?, ?it/s]

Epoch valid [1/100] | Loss 0.3811 | Metric 0.8337
New best loss inf -> 0.3811


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch train [2/100] | Loss 0.1348 | Metric 0.9504


  0%|          | 0/273 [00:00<?, ?it/s]

Epoch valid [2/100] | Loss 0.3166 | Metric 0.8643
New best loss 0.3811 -> 0.3166


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch train [3/100] | Loss 0.1258 | Metric 0.9541


  0%|          | 0/273 [00:00<?, ?it/s]

Epoch valid [3/100] | Loss 0.3031 | Metric 0.8741
New best loss 0.3166 -> 0.3031


  0%|          | 0/782 [00:00<?, ?it/s]

Epoch train [4/100] | Loss 0.1208 | Metric 0.9555


  0%|          | 0/273 [00:00<?, ?it/s]

Epoch valid [4/100] | Loss 0.2987 | Metric 0.8795
New best loss 0.3031 -> 0.2987


  0%|          | 0/782 [00:00<?, ?it/s]

## 4. Making the final prediction with ensemble method

To create a solutions file, you need to generate a CSV with 2 columns.
- **ID**: containing the ID of the image
- **Pred**: with the predicted class (**threshold the prediction to get either 0 or 1**)
- **Probability**: The output probability of a sample being labeled as 1.

This part was the block that leaded to the greatest results improvements. THe idea is to run several inferences on augmented versions of test images and aggregate them for the final prediction.

We keep the assigned probability to later adapt the classification threshold to gain ~1% in prediction accuracy. This is done in `change_csv.py`.

In [21]:
linear_probing = torch.nn.Sequential(
    torch.nn.Linear(feature_extractor.num_features, 1),
    torch.nn.Sigmoid()
).to(device)
linear_probing.load_state_dict(torch.load('best_model.pth', weights_only=True))
linear_probing.eval()
linear_probing.to(device)
prediction_dict = {}

In [None]:
def apply_test_augmentation(img_tensor, num_augmentations=8):
    augmented_images = [img_tensor]
    
    # Define possible augmentations
    augmentations = [
        # Rotations
        lambda x: TF.rotate(x, 90),
        lambda x: TF.rotate(x, 180),
        lambda x: TF.rotate(x, 270),
        # Flips
        lambda x: TF.hflip(x),
        lambda x: TF.vflip(x),
        # Rots + Flips
        lambda x: TF.hflip(TF.rotate(x, 90)),
        lambda x: TF.vflip(TF.rotate(x, 90)),
        lambda x: TF.hflip(TF.rotate(x, 180)),
        lambda x: TF.vflip(TF.rotate(x, 180)),
        # Color jitter
        lambda x: TF.adjust_brightness(x, 1.2),
        lambda x: TF.adjust_brightness(x, 0.8),
        lambda x: TF.adjust_contrast(x, 1.2),
        lambda x: TF.adjust_contrast(x, 0.8),
    ]
    
    selected_augmentations = augmentations[:min(num_augmentations, len(augmentations))]
    
    for aug_func in selected_augmentations:
        augmented_images.append(aug_func(img_tensor))
    
    return torch.stack(augmented_images)

def normalize_for_dinov2(img_tensor):
    if img_tensor.max() > 1.0:
        img_tensor = img_tensor / 255.0
        
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    return normalize(img_tensor)


In [None]:
def single_ensemble_prediction(model, feature_extractor, img_tensor, device, num_augmentations=5, aggregation_method='mean', threshold=0.5):
    """
    Make prediction on one sample by aggregating predictoins on test augmentation
    """
    augmented_images = apply_test_augmentation(img_tensor, num_augmentations)
    augmented_images = torch.stack([normalize_for_dinov2(img) for img in augmented_images])
    
    predictions = []
    with torch.no_grad():
        for aug_img in augmented_images:
            features = feature_extractor(aug_img.unsqueeze(0).to(device))
            pred = model(features).detach().cpu().item()
            predictions.append(pred)
    
    if aggregation_method == 'mean':
        final_prob = sum(predictions) / len(predictions)
    elif aggregation_method == 'median':
        final_prob = sorted(predictions)[len(predictions)//2]
    elif aggregation_method == 'vote':
        votes = [1 if p > threshold else 0 for p in predictions]
        final_prob = sum(votes) / len(votes)
    else:
        raise ValueError(f"Unknown aggregation method: {aggregation_method}")
    
    return int(final_prob > threshold), final_prob, predictions


def all_enemble_predictions(feature_extractor, linear_probing, test_path, device, num_augmentations=5, aggregation_method='mean'):
    """Generate predictions with ensemble method for all test images"""
    linear_probing.load_state_dict(torch.load('best_model.pth', map_location=device))
    linear_probing.eval()
    feature_extractor.eval()
    
    resize = transforms.Resize((224, 224))
    
    solutions_data = {'ID': [], 'Pred': [], 'Probability': []}
    
    with h5py.File(test_path, 'r') as hdf:
        test_ids = list(hdf.keys())
    
    with h5py.File(test_path, 'r') as hdf:
        for test_id in tqdm(test_ids):
            img = torch.tensor(np.array(hdf.get(test_id).get('img'))).float()
            img = resize(img)
            
            binary_pred, probability, all_predictions = single_ensemble_prediction(
                linear_probing, 
                feature_extractor, 
                img, 
                device, 
                num_augmentations=num_augmentations,
                aggregation_method=aggregation_method
            )
            
            solutions_data['ID'].append(int(test_id))
            solutions_data['Pred'].append(binary_pred)
            solutions_data['Probability'].append(probability)
            
    solutions_df = pd.DataFrame(solutions_data).set_index('ID')
    solutions_df.to_csv(f'tta_{aggregation_method}_{num_augmentations}.csv')
    
    return solutions_df

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
feature_extractor = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').to(device)
linear_probing = torch.nn.Sequential(
    torch.nn.Linear(feature_extractor.num_features, 1),
    torch.nn.Sigmoid()
).to(device)

solutions_df, aug_details = all_enemble_predictions(
    feature_extractor=feature_extractor,
    linear_probing=linear_probing,
    test_path=TEST_IMAGES_PATH,
    device=device,
    num_augmentations=8,  # Use 8 different augmentations
    aggregation_method='mean'
)

print(f"Predictions complete, class distributions: {solutions_df['Pred'].value_counts()}")