# PANet - Few-Shot Segmentation pour CT-Scans Médicaux

**Date:** Janvier 2026  
**Score obtenu:** 0.32 (Top 3: 0.34)

## Description

Ce notebook implémente PANet (Prototype Alignment Network) pour la segmentation few-shot de CT-scans médicaux.

**Caractéristiques:**
- 105 classes de structures anatomiques
- Entraînement from scratch (sans modèle pré-entraîné)
- Configuration 5-way 5-shot
- ResNet-18 comme encodeur

**Référence:**  
Wang et al., "PANet: Few-Shot Image Semantic Segmentation with Prototype Alignment", ICCV 2019

## 1. Imports et Configuration

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, Sampler
from torch.optim.lr_scheduler import CosineAnnealingLR

import numpy as np
import pandas as pd
from PIL import Image
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, field
from tqdm import tqdm
import random
from datetime import datetime
import json

# Select the compute device: CUDA when a GPU is visible, CPU otherwise.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(f"Device: {DEVICE}")
if DEVICE.type == "cuda":
    # Report the name of the first visible GPU (index 0).
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 2. Configuration

In [None]:
@dataclass
class Config:
    """Hyper-parameters and paths for the PANet project."""
    
    # Paths (adapt to your environment)
    train_images: Path = Path("data/X_train/images")
    test_images: Path = Path("data/X_test/images")
    train_labels: Path = Path("data/Y_train.csv")
    output_dir: Path = Path("outputs")
    
    # Model
    encoder_name: str = "resnet18"
    feature_dim: int = 256
    use_multiscale: bool = True
    
    # Episodic training
    n_way: int = 5
    k_shot: int = 5  # 5-shot for better results
    n_query: int = 5
    
    # Optimisation
    learning_rate: float = 1e-3
    weight_decay: float = 1e-4
    num_epochs: int = 50
    train_episodes: int = 10000
    val_episodes: int = 500
    patience: int = 15  # Early stopping
    
    # Images
    input_size: Tuple[int, int] = (256, 256)
    original_size: Tuple[int, int] = (512, 512)
    num_classes: int = 105
    
    # Seed
    seed: int = 42

# Instantiate the default configuration and echo the episodic setting.
config = Config()
print(f"Configuration: {config.n_way}-way {config.k_shot}-shot")

## 3. Encodeur ResNet-18 (From Scratch)

In [None]:
class BasicBlock(nn.Module):
    """Two-convolution residual block used to build ResNet-18.

    The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. The input
    (optionally transformed by ``downsample``) is added to the main path
    before the final ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module applied to the input on the shortcut
            path; ``None`` means the input is added unchanged.
    """
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Settings shared by both 3x3 convolutions.
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.relu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        shortcut = x if self.downsample is None else self.downsample(x)

        residual += shortcut
        return self.relu(residual)


class ResNetEncoder(nn.Module):
    """ResNet-18 style encoder for single-channel (grayscale) images.

    The network is a stem (7x7 conv, BN, ReLU, max-pool) followed by four
    residual stages. The outputs of stages 2, 3 and 4 are each projected to
    ``feature_dim`` channels with a 1x1 convolution.

    Args:
        in_channels: Number of channels of the input images.
        feature_dim: Number of channels of every projected feature map.
    """

    def __init__(self, in_channels=1, feature_dim=256):
        super().__init__()
        self.in_planes = 64
        self.feature_dim = feature_dim

        # Stem
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages
        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)

        # 1x1 projections to a common channel dimension
        self.proj2 = nn.Conv2d(128, feature_dim, kernel_size=1)
        self.proj3 = nn.Conv2d(256, feature_dim, kernel_size=1)
        self.proj4 = nn.Conv2d(512, feature_dim, kernel_size=1)

        self._init_weights()

    def _make_layer(self, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` BasicBlocks with ``planes`` channels."""
        downsample = None
        # The shortcut needs a 1x1 conv whenever resolution or width changes.
        if stride != 1 or self.in_planes != planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            )
        layers = [BasicBlock(self.in_planes, planes, stride, downsample)]
        self.in_planes = planes
        for _ in range(1, blocks):
            layers.append(BasicBlock(self.in_planes, planes))
        return nn.Sequential(*layers)

    def _init_weights(self):
        """Kaiming-normal init for convolutions, identity init for BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, return_all_features=False):
        """Encode ``x``.

        Each residual stage runs exactly once per call. Re-running a stage to
        feed the next one would repeat the same computation and, in training
        mode, update its BatchNorm running statistics several times per step.

        Args:
            x: Input image batch.
            return_all_features: When true, return the projected outputs of
                stages 2, 3 and 4; otherwise only the stage-4 projection.

        Returns:
            A dict with keys ``"layer2"``, ``"layer3"``, ``"layer4"`` when
            ``return_all_features`` is true, else the stage-4 feature tensor.
        """
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        c1 = self.layer1(x)
        c2 = self.layer2(c1)
        c3 = self.layer3(c2)
        c4 = self.layer4(c3)

        f4 = self.proj4(c4)
        if not return_all_features:
            return f4
        return {"layer2": self.proj2(c2), "layer3": self.proj3(c3), "layer4": f4}


def get_encoder(in_channels=1, feature_dim=256):
    """Build a ``ResNetEncoder`` with the given input channels and feature width."""
    encoder = ResNetEncoder(in_channels=in_channels, feature_dim=feature_dim)
    return encoder

## 4. PANet - Prototype Alignment Network

In [None]:
class PrototypeComputation(nn.Module):
    """Compute one prototype per class by masked average pooling."""

    def forward(self, support_features, support_masks, class_ids):
        """Average the support features over the pixels of each class.

        Args:
            support_features: 4-D feature tensor, unpacked as (N, C, H, W).
            support_masks: 3-D label tensor, unpacked as (N, H_mask, W_mask).
                It is resized to (H, W) with nearest-neighbour interpolation
                when its spatial size differs from the features.
            class_ids: Iterable of class labels to build prototypes for.

        Returns:
            Dict mapping each class id to a length-C prototype vector. A
            class with no pixel in the masks yields an all-zero vector.
        """
        _, _, feat_h, feat_w = support_features.shape
        _, mask_h, mask_w = support_masks.shape

        if (mask_h, mask_w) != (feat_h, feat_w):
            resized = F.interpolate(
                support_masks.unsqueeze(1).float(),
                size=(feat_h, feat_w), mode="nearest"
            )
            support_masks = resized.squeeze(1).long()

        def masked_average(label):
            # 1.0 where the pixel belongs to `label`, 0.0 elsewhere.
            selector = (support_masks == label).float()
            pooled = (support_features * selector.unsqueeze(1)).sum(dim=(0, 2, 3))
            # The epsilon keeps the division finite for absent classes.
            return pooled / (selector.sum() + 1e-6)

        return {label: masked_average(label) for label in class_ids}


class PrototypeMatching(nn.Module):
    """Score every query pixel against each prototype with cosine similarity."""

    def __init__(self, temperature=1.0):
        super().__init__()
        # Similarities are divided by this value before being returned.
        self.temperature = temperature

    def forward(self, query_features, prototypes):
        """Return per-class score maps and the class order they follow.

        Args:
            query_features: Feature tensor unpacked as (B, C, H, W).
            prototypes: Dict mapping class id to a length-C vector.

        Returns:
            ``(scores, class_ids)`` where ``scores`` has shape
            (B, num_classes, H, W) and ``class_ids`` is the sorted list of
            prototype keys giving the meaning of the class axis.
        """
        batch, channels, height, width = query_features.shape
        ordered_ids = sorted(prototypes.keys())
        num_protos = len(ordered_ids)

        # L2-normalise both sides so the dot product is a cosine similarity.
        stacked = torch.stack([prototypes[k] for k in ordered_ids], dim=0)
        pixels = query_features.permute(0, 2, 3, 1).reshape(batch, height * width, channels)
        unit_pixels = F.normalize(pixels, p=2, dim=-1)
        unit_protos = F.normalize(stacked, p=2, dim=-1)

        similarity = torch.matmul(unit_pixels, unit_protos.t())
        similarity = similarity / self.temperature
        score_maps = similarity.permute(0, 2, 1).reshape(batch, num_protos, height, width)

        return score_maps, ordered_ids


class PANet(nn.Module):
    """Prototype-based few-shot segmentation model.

    Support and query images go through the same encoder. One prototype per
    episode class is pooled from the support features, and each query
    location is scored against every prototype.

    Args:
        encoder: Module called as ``encoder(images, return_all_features=True)``
            and returning a dict with keys ``"layer2"``, ``"layer3"`` and
            ``"layer4"``.
        feature_dim: Channel count of each encoder feature map.
        use_multiscale: When true, the three feature maps are resized to the
            ``"layer2"`` resolution, concatenated and fused by a 1x1 conv;
            otherwise only ``"layer4"`` is used.
    """

    def __init__(self, encoder, feature_dim=256, use_multiscale=True):
        super().__init__()
        self.encoder = encoder
        self.feature_dim = feature_dim
        self.use_multiscale = use_multiscale

        self.prototype_computation = PrototypeComputation()
        self.prototype_matching = PrototypeMatching(temperature=1.0)

        if use_multiscale:
            self.fusion = nn.Sequential(
                nn.Conv2d(feature_dim * 3, feature_dim, kernel_size=1, bias=False),
                nn.BatchNorm2d(feature_dim),
                nn.ReLU(inplace=True),
            )

    def _extract_features(self, images):
        """Encode ``images`` into a single feature map (fused or stage-4 only)."""
        features = self.encoder(images, return_all_features=True)
        if self.use_multiscale:
            target_size = features["layer2"].shape[2:]
            upsampled = []
            for name in ["layer2", "layer3", "layer4"]:
                feat = features[name]
                if feat.shape[2:] != target_size:
                    feat = F.interpolate(feat, size=target_size, mode="bilinear", align_corners=False)
                upsampled.append(feat)
            return self.fusion(torch.cat(upsampled, dim=1))
        return features["layer4"]

    @staticmethod
    def _indices_to_labels(pred_indices, ordered_classes):
        """Translate score-channel indices back to the original class ids."""
        lookup = torch.as_tensor(
            [int(c) for c in ordered_classes],
            dtype=pred_indices.dtype, device=pred_indices.device,
        )
        return lookup[pred_indices]

    def forward(self, support_images, support_masks, query_images, query_masks, class_ids):
        """Run one episode and return ``(loss, predictions)``.

        The loss is a cross-entropy over the episode classes, computed at the
        resolution of the score maps. Query masks are resized there with
        nearest-neighbour interpolation, and pixels whose label is not in
        ``class_ids`` are ignored.

        Args:
            support_images: Support image batch.
            support_masks: Integer label masks of the support images.
            query_images: Query image batch.
            query_masks: Integer label masks of the query images.
            class_ids: Class labels of this episode.

        Returns:
            ``loss``: scalar tensor. It is exactly zero, but still attached to
            the graph, when no query pixel belongs to an episode class.
            ``predictions``: integer label map with the spatial size of
            ``query_masks``; every pixel receives one of the episode classes.
        """
        support_features = self._extract_features(support_images)
        query_features = self._extract_features(query_images)

        prototypes = self.prototype_computation(support_features, support_masks, class_ids)
        scores, ordered_classes = self.prototype_matching(query_features, prototypes)

        # Segmentation loss
        target_size = scores.shape[2:]
        query_masks_resized = F.interpolate(
            query_masks.unsqueeze(1).float(), size=target_size, mode="nearest"
        ).squeeze(1).long()

        # Remap original labels to score-channel indices; -1 marks ignored pixels.
        remapped_masks = torch.full_like(query_masks_resized, -1)
        for idx, class_id in enumerate(ordered_classes):
            remapped_masks[query_masks_resized == class_id] = idx

        if bool((remapped_masks != -1).any()):
            loss = F.cross_entropy(scores, remapped_masks, ignore_index=-1)
        else:
            # With every target ignored the mean-reduced cross-entropy is NaN,
            # which would corrupt the weights on backward. Use a zero loss
            # that is still connected to the graph instead.
            loss = scores.sum() * 0.0

        # Predictions
        predictions = self._indices_to_labels(scores.argmax(dim=1), ordered_classes)

        if predictions.shape[1:] != query_masks.shape[1:]:
            predictions = F.interpolate(
                predictions.unsqueeze(1).float(), size=query_masks.shape[1:], mode="nearest"
            ).squeeze(1).long()

        return loss, predictions

    @torch.no_grad()
    def predict(self, support_images, support_masks, query_images, class_ids, output_size=None):
        """Segment ``query_images`` from a support set, without gradients.

        This switches the module to eval mode and leaves it there.

        Args:
            support_images: Support image batch.
            support_masks: Integer label masks of the support images.
            query_images: Query image batch.
            class_ids: Class labels to build prototypes for.
            output_size: Optional ``(H, W)``; when given, the label map is
                resized to it with nearest-neighbour interpolation.

        Returns:
            Integer label map, one per query image.
        """
        self.eval()
        support_features = self._extract_features(support_images)
        query_features = self._extract_features(query_images)

        prototypes = self.prototype_computation(support_features, support_masks, class_ids)
        scores, ordered_classes = self.prototype_matching(query_features, prototypes)

        predictions = self._indices_to_labels(scores.argmax(dim=1), ordered_classes)

        if output_size is not None:
            predictions = F.interpolate(
                predictions.unsqueeze(1).float(), size=output_size, mode="nearest"
            ).squeeze(1).long()

        return predictions

## 5. Dataset CT-Scan

In [None]:
def load_annotations(csv_path, limit=None):
    """Load segmentation masks from the annotation CSV.

    The CSV is read with its first column as the index. Every remaining
    column is one image, and its values are that image's flattened square
    mask.

    Args:
        csv_path: Path of the CSV file.
        limit: When truthy, keep only the first ``limit`` columns.

    Returns:
        ``(masks, classes)`` where ``masks`` maps each column name to a 2-D
        ``int16`` array and ``classes`` is the sorted list of all labels
        found, as plain Python ints.

    Raises:
        ValueError: If a column's length is not a perfect square.
    """
    print(f"Chargement des annotations depuis {csv_path}...")
    df = pd.read_csv(csv_path, index_col=0)

    if limit:
        df = df.iloc[:, :limit]

    masks = {}
    class_set = set()

    for col in df.columns:
        values = df[col].to_numpy()
        side = int(round(float(np.sqrt(values.size))))
        if side * side != values.size:
            raise ValueError(
                f"Column {col!r} has {values.size} values, which is not a square mask"
            )
        mask = values.reshape(side, side).astype(np.int16)
        masks[col] = mask
        # Store plain ints so the class list does not carry NumPy scalars.
        class_set.update(int(c) for c in np.unique(mask))

    print(f"  {len(masks)} images chargées")
    print(f"  {len(class_set)} classes trouvées")

    return masks, sorted(class_set)


class CTScanDataset(Dataset):
    """CT-scan images paired with their segmentation masks.

    Args:
        image_dir: Directory containing the image files; each key of
            ``masks`` is used as a file name inside it.
        masks: Dict mapping image name to a 2-D integer label array.
        target_size: Size passed to PIL's ``resize`` for both image and mask.

    Attributes set at construction:
        image_names: Sorted list of the mask keys; defines the index order.
        class_to_images: Dict mapping each label > 0 to the names of the
            images that contain it (label 0 is skipped).
        available_classes: Sorted list of the labels in ``class_to_images``.
    """

    def __init__(self, image_dir, masks, target_size=(256, 256)):
        self.image_dir = Path(image_dir)
        self.masks = masks
        self.target_size = target_size
        self.image_names = sorted(masks.keys())

        # Index: class -> images containing it
        self.class_to_images = {}
        for name, mask in masks.items():
            for c in np.unique(mask):
                if c > 0:
                    # Use plain ints as keys so class ids handed to samplers
                    # and models are not NumPy scalars.
                    self.class_to_images.setdefault(int(c), []).append(name)

        self.available_classes = sorted(self.class_to_images.keys())
        print(f"Dataset initialisé: {len(self.image_names)} images, {len(self.available_classes)} classes")

    def __len__(self):
        return len(self.image_names)

    def __getitem__(self, idx):
        """Return ``{'image', 'mask', 'name'}`` for the image at ``idx``.

        The image is loaded as grayscale, resized bilinearly, scaled to
        [0, 1] and then normalised with mean 0.5 and std 0.5. The mask is
        resized with nearest-neighbour so labels are never blended.
        """
        name = self.image_names[idx]

        # Load image
        img_path = self.image_dir / name
        image = Image.open(img_path).convert('L')
        image = image.resize(self.target_size, Image.BILINEAR)
        image_np = np.array(image, dtype=np.float32) / 255.0
        image_np = (image_np - 0.5) / 0.5

        # Load mask
        mask = self.masks[name]
        mask_pil = Image.fromarray(mask.astype(np.int32))
        mask_pil = mask_pil.resize(self.target_size, Image.NEAREST)
        mask_np = np.array(mask_pil, dtype=np.int64)

        return {
            'image': torch.from_numpy(image_np).unsqueeze(0),
            'mask': torch.from_numpy(mask_np),
            'name': name,
        }

## 6. Sampler Épisodique

In [None]:
class EpisodicSampler(Sampler):
    """Sampler pour entraînement épisodique N-way K-shot."""

    def __init__(self, dataset, n_way=5, k_shot=5, n_query=5, episodes=1000):
        super().__init__()
        self.dataset = dataset
        self.n_way = n_way
        self.k_shot = k_shot
        self.n_query = n_query
        self.episodes = episodes
        
        # Filtrer classes avec assez d'images
        min_images = k_shot + n_query
        self.valid_classes = [
            c for c, imgs in dataset.class_to_images.items()
            if len(imgs) >= min_images
        ]
        print(f"EpisodicSampler: {len(self.valid_classes)} classes valides, {n_way}-way {k_shot}-shot")

    def __iter__(self):
        for _ in range(self.episodes):
            if len(self.valid_classes) < self.n_way:
                classes = random.choices(self.valid_classes, k=self.n_way)
            else:
                classes = random.sample(self.valid_classes, self.n_way)
            
            support_indices, query_indices = [], []
            
            for c in classes:
                images = self.dataset.class_to_images[c]
                selected = random.sample(images, min(len(images), self.k_shot + self.n_query))
                
                for img_name in selected[:self.k_shot]:
                    support_indices.append(self.dataset.image_names.index(img_name))
                for img_name in selected[self.k_shot:self.k_shot + self.n_query]:
                    query_indices.append(self.dataset.image_names.index(img_name))
            
            yield {'support': support_indices, 'query': query_indices, 'classes': classes}

    def __len__(self):
        return self.episodes

## 7. Métriques

In [None]:
def compute_iou(pred, target, num_classes=105):
    """Return the mean IoU over the labels ``0 .. num_classes - 1``.

    Labels absent from both ``pred`` and ``target`` are skipped. The result
    is 0.0 when no label is present at all.
    """
    per_class = []
    for label in range(num_classes):
        in_pred = pred == label
        in_target = target == label
        union = (in_pred | in_target).sum().float()
        if not (union > 0):
            continue
        overlap = (in_pred & in_target).sum().float()
        per_class.append((overlap / union).item())
    if not per_class:
        return 0.0
    return np.mean(per_class)

## 8. Entraînement

In [None]:
def train_epoch(model, dataset, sampler, optimizer, device):
    """Train ``model`` for one pass over ``sampler``.

    Each episode is one optimisation step: the support and query samples are
    fetched from ``dataset``, stacked into batches on ``device``, and the
    model loss is back-propagated.

    Returns:
        ``(mean_loss, mean_iou)`` averaged over ``len(sampler)`` episodes.
    """
    model.train()
    loss_sum = 0
    iou_sum = 0

    def to_batch(samples, key):
        # Stack one field of every sample and move it to the target device.
        return torch.stack([sample[key] for sample in samples]).to(device)

    for episode in tqdm(sampler, desc="Training"):
        support = [dataset[i] for i in episode['support']]
        query = [dataset[i] for i in episode['query']]

        support_images = to_batch(support, 'image')
        support_masks = to_batch(support, 'mask')
        query_images = to_batch(query, 'image')
        query_masks = to_batch(query, 'mask')

        optimizer.zero_grad()
        loss, predictions = model(
            support_images, support_masks, query_images, query_masks, episode['classes']
        )
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        iou_sum += compute_iou(predictions.cpu(), query_masks.cpu())

    num_episodes = len(sampler)
    return loss_sum / num_episodes, iou_sum / num_episodes


@torch.no_grad()
def validate(model, dataset, sampler, device):
    """Evaluate ``model`` on the episodes of ``sampler`` without gradients.

    Returns:
        ``(mean_loss, mean_iou)`` averaged over ``len(sampler)`` episodes.
    """
    model.eval()
    loss_sum = 0
    iou_sum = 0

    def to_batch(samples, key):
        # Stack one field of every sample and move it to the target device.
        return torch.stack([sample[key] for sample in samples]).to(device)

    for episode in tqdm(sampler, desc="Validation"):
        support = [dataset[i] for i in episode['support']]
        query = [dataset[i] for i in episode['query']]

        support_images = to_batch(support, 'image')
        support_masks = to_batch(support, 'mask')
        query_images = to_batch(query, 'image')
        query_masks = to_batch(query, 'mask')

        loss, predictions = model(
            support_images, support_masks, query_images, query_masks, episode['classes']
        )

        loss_sum += loss.item()
        iou_sum += compute_iou(predictions.cpu(), query_masks.cpu())

    num_episodes = len(sampler)
    return loss_sum / num_episodes, iou_sum / num_episodes

## 9. Boucle d'Entraînement Principale

In [None]:
def train_model(config):
    """Train PANet end to end and return ``(model, output_dir)``.

    Steps: seed the RNGs, build the model on ``DEVICE``, load the
    annotations, split the images 80/20 into train and validation, then run
    episodic training with Adam and a cosine learning-rate schedule. The
    checkpoint with the best validation IoU is written to
    ``output_dir / 'best_model.pt'`` and the metric history to
    ``output_dir / 'history.json'`` after every epoch. Training stops early
    after ``config.patience`` epochs without improvement.

    Args:
        config: Object providing the fields of ``Config``.

    Returns:
        The model in its state after the last epoch run (not necessarily the
        best checkpoint) and the run's output directory.
    """

    # Seed
    torch.manual_seed(config.seed)
    np.random.seed(config.seed)
    random.seed(config.seed)

    # Build the model
    encoder = get_encoder(in_channels=1, feature_dim=config.feature_dim)
    model = PANet(encoder, feature_dim=config.feature_dim, use_multiscale=config.use_multiscale)
    model = model.to(DEVICE)

    print(f"Paramètres: {sum(p.numel() for p in model.parameters()):,}")

    # Load the data
    masks, _ = load_annotations(config.train_labels)

    # Train/val split
    all_names = list(masks.keys())
    random.shuffle(all_names)
    split = int(0.8 * len(all_names))

    train_masks = {n: masks[n] for n in all_names[:split]}
    val_masks = {n: masks[n] for n in all_names[split:]}

    train_dataset = CTScanDataset(config.train_images, train_masks, config.input_size)
    val_dataset = CTScanDataset(config.train_images, val_masks, config.input_size)

    train_sampler = EpisodicSampler(train_dataset, config.n_way, config.k_shot, config.n_query, config.train_episodes)
    val_sampler = EpisodicSampler(val_dataset, config.n_way, config.k_shot, config.n_query, config.val_episodes)

    # Optimisation
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=config.num_epochs)

    # Training
    best_iou = 0
    patience_counter = 0
    history = {'train_loss': [], 'train_iou': [], 'val_loss': [], 'val_iou': []}

    output_dir = config.output_dir / f"run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    output_dir.mkdir(parents=True, exist_ok=True)

    for epoch in range(1, config.num_epochs + 1):
        print(f"\nEpoch {epoch}/{config.num_epochs}")

        train_loss, train_iou = train_epoch(model, train_dataset, train_sampler, optimizer, DEVICE)
        val_loss, val_iou = validate(model, val_dataset, val_sampler, DEVICE)
        scheduler.step()

        history['train_loss'].append(train_loss)
        history['train_iou'].append(train_iou)
        history['val_loss'].append(val_loss)
        history['val_iou'].append(val_iou)

        print(f"  Train Loss: {train_loss:.4f}, Train IoU: {train_iou*100:.2f}%")
        print(f"  Val Loss: {val_loss:.4f}, Val IoU: {val_iou*100:.2f}%")

        # Write the history before the early-stopping check, so the epoch
        # that triggers the stop is recorded too.
        with open(output_dir / 'history.json', 'w') as f:
            json.dump(history, f, indent=2)

        # Save the best model
        if val_iou > best_iou:
            best_iou = val_iou
            patience_counter = 0
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                # Plain float: a NumPy scalar in the checkpoint is rejected by
                # torch.load when it runs in weights-only mode.
                'val_iou': float(val_iou),
            }, output_dir / 'best_model.pt')
            print(f"  -> Meilleur modèle sauvegardé (IoU: {val_iou*100:.2f}%)")
        else:
            patience_counter += 1
            if patience_counter >= config.patience:
                print(f"\nEarly stopping après {epoch} epochs")
                break

    print(f"\nEntraînement terminé! Meilleur IoU: {best_iou*100:.2f}%")
    return model, output_dir


# Décommenter pour lancer l'entraînement
# model, output_dir = train_model(config)

## 10. Prédiction et Soumission

In [None]:
class TestDataset(Dataset):
    """Unlabelled test images read from a directory.

    The ``*.png`` and ``*.jpg`` files of ``image_dir`` are ordered by the
    integer formed from the digits of their stem (0 when the stem has none).
    """

    def __init__(self, image_dir, target_size=(256, 256)):
        self.image_dir = Path(image_dir)
        self.target_size = target_size

        def numeric_key(path):
            # Concatenate every digit of the stem; a stem without digits sorts as 0.
            digits = "".join(ch for ch in path.stem if ch.isdigit())
            return int(digits or 0)

        candidates = [*self.image_dir.glob("*.png"), *self.image_dir.glob("*.jpg")]
        self.image_paths = sorted(candidates, key=numeric_key)
        print(f"TestDataset: {len(self.image_paths)} images")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``{'image', 'name'}`` for the file at ``idx``.

        The image is loaded as grayscale, resized bilinearly, scaled to
        [0, 1] and then normalised with mean 0.5 and std 0.5.
        """
        path = self.image_paths[idx]
        resized = Image.open(path).convert('L').resize(self.target_size, Image.BILINEAR)
        pixels = np.array(resized, dtype=np.float32) / 255.0
        pixels = (pixels - 0.5) / 0.5
        sample = {
            'image': torch.from_numpy(pixels).unsqueeze(0),
            'name': path.name,
        }
        return sample


def generate_submission(model, test_dataset, support_images, support_masks, class_ids,
                        output_path, device, output_size=(512, 512)):
    """Predict every test image and write the submission CSV.

    Each test image is segmented with ``model.predict`` using the given
    support set. The CSV has one column per image, ordered by the integer
    formed from the digits of its name, and one row per pixel labelled
    ``"Pixel i"``, holding the flattened label map.

    Args:
        model: Model exposing ``predict(support_images, support_masks,
            query_images, class_ids, output_size=...)``.
        test_dataset: Indexable dataset whose items have ``'image'`` and
            ``'name'`` entries.
        support_images: Support image batch.
        support_masks: Support label masks.
        class_ids: Class labels the model may predict.
        output_path: Destination of the CSV file.
        device: Device on which inference runs.
        output_size: ``(H, W)`` of the predicted label maps; it also fixes
            the number of rows (``H * W``). Defaults to ``(512, 512)``.
    """
    model.eval()
    predictions = {}

    support_images = support_images.to(device)
    support_masks = support_masks.to(device)

    for idx in tqdm(range(len(test_dataset)), desc="Prédiction"):
        sample = test_dataset[idx]
        query_image = sample['image'].unsqueeze(0).to(device)

        pred = model.predict(support_images, support_masks, query_image, class_ids, output_size=output_size)
        predictions[sample['name']] = pred[0].cpu().numpy().astype(np.int16)

    # Build the CSV
    sorted_names = sorted(predictions.keys(), key=lambda n: int("".join(c for c in n if c.isdigit()) or 0))
    data = {name: predictions[name].flatten() for name in sorted_names}

    # Passing the index to the constructor also works when there are no
    # images; assigning it afterwards would fail on the length mismatch.
    pixel_index = [f"Pixel {i}" for i in range(output_size[0] * output_size[1])]
    df = pd.DataFrame(data, index=pixel_index)
    df.to_csv(output_path)

    print(f"Soumission sauvegardée: {output_path}")


# Exemple d'utilisation:
# test_dataset = TestDataset(config.test_images, config.input_size)
# generate_submission(model, test_dataset, support_images, support_masks, class_ids, 'submission.csv', DEVICE)

## 11. Chargement d'un Modèle Existant

In [None]:
def load_model(model_path, device, feature_dim=256, use_multiscale=True):
    """Rebuild a PANet and load weights from a checkpoint file.

    Args:
        model_path: Path of a checkpoint holding a ``'model_state_dict'``
            entry and, optionally, ``'epoch'`` and ``'val_iou'``.
        device: Device the checkpoint is mapped to and the model moved to.
        feature_dim: Feature width the checkpoint was trained with.
        use_multiscale: Whether the checkpoint was trained with multi-scale
            fusion. Both settings must match training, or loading the state
            dict fails.

    Returns:
        The model on ``device``, in eval mode.
    """
    encoder = get_encoder(in_channels=1, feature_dim=feature_dim)
    model = PANet(encoder, feature_dim=feature_dim, use_multiscale=use_multiscale)

    # NOTE(review): recent PyTorch releases appear to load in weights-only
    # mode by default, which rejects checkpoints containing non-tensor
    # objects such as NumPy scalars — confirm for the installed version.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    model.eval()

    print(f"Modèle chargé: {model_path}")
    print(f"  Epoch: {checkpoint.get('epoch', '?')}")
    print(f"  Val IoU: {checkpoint.get('val_iou', 0)*100:.2f}%")

    return model


# Exemple:
# model = load_model('outputs/best_model.pt', DEVICE)

---

## Résultats Obtenus

| Configuration | Val IoU | Score Test |
|---------------|---------|------------|
| 5-way 1-shot | 24.13% | 18.5% |
| **5-way 5-shot** | **38.30%** | **32%** |

**Top 3 du challenge: 34%**