# 🏥 MedViT-CAMIL: REAL Mode (HyperKvasir Videos)

**Context-Aware Multiple Instance Learning for Medical Video Analysis**

This notebook runs the REAL mode on Google Colab with a GPU.

- **Dataset**: HyperKvasir Labeled Videos (~25 GB of endoscopy videos)
- **Task**: Binary classification (normal vs. pathological)
- **Comparison**: Baseline (Average Pooling) vs. MedViT-CAMIL (Gated Attention)

---
⚡ **IMPORTANT**:
1. Enable the GPU: `Runtime > Change runtime type > T4 GPU`
2. The dataset is large (~25 GB), so allow time for the download

💾 **SAVING**: Results are saved automatically to Google Drive

## 0️⃣ Mounting Google Drive (PERSISTENCE)

In [None]:
# Mount Google Drive to save the results
from google.colab import drive
drive.mount('/content/drive')

# Create the results folder
import os
SAVE_DIR = '/content/drive/MyDrive/MedViT_Results/real'
os.makedirs(SAVE_DIR, exist_ok=True)
print(f"✅ Save folder: {SAVE_DIR}")
print("📁 All files will be saved here automatically!")

In [None]:
# Check the GPU and disk space
!nvidia-smi
!df -h /content

import torch
print(f"\n✅ PyTorch {torch.__version__}")
print(f"✅ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
else:
    print("⚠️ NO GPU! Enable it: Runtime > Change runtime type > T4 GPU")

In [None]:
# Install the dependencies
!pip install -q timm tqdm matplotlib seaborn einops opencv-python-headless

## 1️⃣ Downloading the HyperKvasir Videos Dataset

**HyperKvasir** is a dataset of gastrointestinal (GI) endoscopies from Simula Research Lab.
- 374 labeled videos of gastrointestinal procedures
- Classes: normal vs. pathological findings (polyps, ulcers, etc.)

In [None]:
import os
import zipfile
import glob
import shutil
import json
from datetime import datetime

DATA_DIR = '/content/data/hyperkvasir'
ZIP_URL = 'https://datasets.simula.no/downloads/hyper-kvasir/hyper-kvasir-labeled-videos.zip'
ZIP_PATH = '/content/hyper-kvasir-labeled-videos.zip'

# Progress log
def log_progress(msg):
    print(msg)
    with open(f"{SAVE_DIR}/download_log.txt", 'a') as f:
        f.write(f"{datetime.now()}: {msg}\n")

# Download if needed
if not os.path.exists(DATA_DIR) or len(os.listdir(DATA_DIR)) < 2:
    log_progress("📥 Downloading HyperKvasir Videos (~25 GB)...")
    log_progress("⏳ This can take 10-20 minutes depending on the connection...")
    !wget -q --show-progress -O $ZIP_PATH $ZIP_URL
    
    log_progress("📦 Extracting the archive...")
    os.makedirs(DATA_DIR, exist_ok=True)
    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall('/content/data/extracted')
    
    # Sort into normal/abnormal
    log_progress("📁 Organizing the data...")
    os.makedirs(f"{DATA_DIR}/normal", exist_ok=True)
    os.makedirs(f"{DATA_DIR}/abnormal", exist_ok=True)
    
    NORMAL_KEYWORDS = ['normal', 'cecum', 'pylorus', 'z-line', 'retroflex']
    ABNORMAL_KEYWORDS = ['polyp', 'ulcer', 'esophagitis', 'colitis', 'hemorrhoid', 'dyed']
    
    for video_path in glob.glob('/content/data/extracted/**/*.mp4', recursive=True):
        folder_name = os.path.dirname(video_path).lower()
        file_name = os.path.basename(video_path)
        
        is_normal = any(kw in folder_name for kw in NORMAL_KEYWORDS)
        is_abnormal = any(kw in folder_name for kw in ABNORMAL_KEYWORDS)
        
        if is_abnormal:
            dest = f"{DATA_DIR}/abnormal/{file_name}"
        elif is_normal:
            dest = f"{DATA_DIR}/normal/{file_name}"
        else:
            continue
        
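        if not os.path.exists(dest):
            # Move instead of copy: the extracted tree is deleted below, and this
            # avoids keeping a second copy of each video on disk.
            shutil.move(video_path, dest)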
    
    # Cleanup
    log_progress("🧹 Cleaning up...")
    !rm -rf /content/data/extracted
    !rm -f $ZIP_PATH

# Count the files
normal_count = len(glob.glob(f"{DATA_DIR}/normal/*.mp4"))
abnormal_count = len(glob.glob(f"{DATA_DIR}/abnormal/*.mp4"))

dataset_info = {
    'normal_videos': normal_count,
    'abnormal_videos': abnormal_count,
    'total': normal_count + abnormal_count
}

log_progress(f"✅ Dataset ready: Normal={normal_count}, Abnormal={abnormal_count}")

# Save the dataset info
with open(f"{SAVE_DIR}/dataset_info.json", 'w') as f:
    json.dump(dataset_info, f, indent=2)
print("💾 Dataset info saved")
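
The folder-name routing in the cell above can be isolated into a small function, which makes the precedence rule explicit: an abnormal keyword wins over a normal one, and folders matching neither list are skipped. The sketch below reuses the two keyword lists from the cell. The function name `classify_folder` and the example paths are hypothetical and make no claim about the real archive layout.

```python
NORMAL_KEYWORDS = ['normal', 'cecum', 'pylorus', 'z-line', 'retroflex']
ABNORMAL_KEYWORDS = ['polyp', 'ulcer', 'esophagitis', 'colitis', 'hemorrhoid', 'dyed']

def classify_folder(folder_name):
    """Return 'abnormal', 'normal', or None (skip) for a folder path."""
    name = folder_name.lower()
    # Abnormal is checked first, mirroring the if/elif order in the cell above.
    if any(kw in name for kw in ABNORMAL_KEYWORDS):
        return 'abnormal'
    if any(kw in name for kw in NORMAL_KEYWORDS):
        return 'normal'
    return None

# Hypothetical paths, for illustration only
for path in ['/x/polyps', '/x/cecum', '/x/normal-ulcer-mix', '/x/unknown']:
    print(path, '->', classify_folder(path))
```

Because matching is by substring, it is worth printing the folders that return `None` once after extraction to confirm that no labeled class is dropped silently.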

## 2️⃣ Configuration

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np
import matplotlib.pyplot as plt
import cv2
from torchvision import transforms
from tqdm.auto import tqdm

# REAL mode configuration
CONFIG = {
    'MODE': 'real',
    'DEVICE': 'cuda' if torch.cuda.is_available() else 'cpu',
    'SEQ_LEN': 32,          # Frames extracted per video
    'IMG_SIZE': 224,
    'BATCH_SIZE': 8,        # Kept small because videos are heavy
    'EPOCHS': 30,
    'LR': 1e-4,
    'WEIGHT_DECAY': 1e-5,
    'NUM_CLASSES': 2,
    'HIDDEN_DIM': 128,
    'SEED': 42,
    'DATA_DIR': DATA_DIR,
    'SAVE_DIR': SAVE_DIR
}

# Reproducibility
torch.manual_seed(CONFIG['SEED'])
np.random.seed(CONFIG['SEED'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CONFIG['SEED'])

print("📋 REAL CONFIGURATION")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

# Save the config immediately
config_path = f"{SAVE_DIR}/config.json"
with open(config_path, 'w') as f:
    json.dump({k: str(v) for k, v in CONFIG.items()}, f, indent=2)
print(f"\n💾 Config saved: {config_path}")

## 3️⃣ Dataset: Real Video Loader (OpenCV)

In [None]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

class RealVideoDataset(Dataset):
    """
    Dataset for real medical videos, read with OpenCV.
    Extracts SEQ_LEN frames evenly spaced across each video.
    """
    
    def __init__(self, data_dir, split='train', seq_len=32, img_size=224):
        self.seq_len = seq_len
        self.img_size = img_size
        
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
        ])
        
        # Scan the files
        self.samples = []
        for label, cls in enumerate(['normal', 'abnormal']):
            cls_path = os.path.join(data_dir, cls)
            if os.path.exists(cls_path):
                for f in glob.glob(os.path.join(cls_path, '*.mp4')):
                    self.samples.append((f, label))
        
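        # Deterministic shuffle, then 80/20 split. A fixed local RNG (same value
        # as CONFIG['SEED']) gives the 'train' and 'val' instances the same order,
        # so the two splits cannot overlap.
        self.samples.sort()
        np.random.RandomState(42).shuffle(self.samples)
        cut = int(0.8 * len(self.samples))
        
        if split == 'train':
            self.samples = self.samples[:cut]
        else:
            self.samples = self.samples[cut:]
        
        print(f"[{split.upper()}] {len(self.samples)} videos loaded")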
    
    def __len__(self):
        return len(self.samples)
    
    def _extract_frames(self, video_path):
        """Extracts SEQ_LEN evenly spaced frames from the video."""
        frames = []
        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        
        if total_frames == 0:
            cap.release()
            return None
        
        # Indices to extract (evenly spaced)
        indices = np.linspace(0, total_frames - 1, self.seq_len, dtype=int)
        
        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            
            if frame_idx in indices:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(self.transform(frame))
            
            frame_idx += 1
            if len(frames) >= self.seq_len:
                break
        
        cap.release()
        
        # Pad if needed
        if len(frames) == 0:
            return None
        
        while len(frames) < self.seq_len:
            frames.append(frames[-1])  # Repeat the last frame
        
        return torch.stack(frames[:self.seq_len])
    
    def __getitem__(self, idx):
        path, label = self.samples[idx]
        
        try:
            video = self._extract_frames(path)
            if video is None:
                video = torch.zeros(self.seq_len, 3, self.img_size, self.img_size)
        except Exception as e:
            print(f"[WARNING] Error {path}: {e}")
            video = torch.zeros(self.seq_len, 3, self.img_size, self.img_size)
        
        return video, torch.tensor(label, dtype=torch.long)

# Create the datasets
train_dataset = RealVideoDataset(CONFIG['DATA_DIR'], 'train', CONFIG['SEQ_LEN'], CONFIG['IMG_SIZE'])
val_dataset = RealVideoDataset(CONFIG['DATA_DIR'], 'val', CONFIG['SEQ_LEN'], CONFIG['IMG_SIZE'])

train_loader = DataLoader(train_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True, num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=2, pin_memory=True)

dataset_split = f"Train: {len(train_dataset)} | Val: {len(val_dataset)}"
print(f"\n📊 {dataset_split}")

# Log
with open(f"{SAVE_DIR}/training_log.txt", 'w') as f:
    f.write(f"MedViT-CAMIL REAL Training Log\n")
    f.write(f"Started: {datetime.now()}\n")
    f.write(f"Dataset: {dataset_split}\n")
    f.write("="*60 + "\n")
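
To see which frames `_extract_frames` keeps, it helps to look at the index selection on its own. The sketch below is a pure-Python stand-in for the `np.linspace(0, total_frames - 1, seq_len, dtype=int)` call. It uses exact integer arithmetic, so it should match NumPy's truncated values in ordinary cases, though that equivalence is an assumption rather than a guarantee. The helper name `uniform_frame_indices` is hypothetical.

```python
def uniform_frame_indices(total_frames, seq_len):
    """Evenly spaced frame indices in [0, total_frames - 1], floored to int."""
    if total_frames <= 0 or seq_len <= 0:
        return []
    if seq_len == 1:
        return [0]
    return [(i * (total_frames - 1)) // (seq_len - 1) for i in range(seq_len)]

long_video = uniform_frame_indices(100, 5)    # more frames than requested
short_video = uniform_frame_indices(3, 8)     # fewer frames than requested

print(long_video)
print(short_video)
# Distinct indices are what the read loop can actually collect;
# the loader fills the remainder by repeating the last frame.
print(len(set(short_video)))
```

When a video has fewer frames than `SEQ_LEN`, the indices repeat, so the read loop collects each distinct frame once and the padding step supplies the rest.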

## 4️⃣ Models: MobileViT Backbone + Aggregators

In [None]:
import timm

class MobileViTBackbone(nn.Module):
    """Pretrained MobileViT backbone (FROZEN)."""
    
    def __init__(self, model_name='mobilevit_s', pretrained=True):
        super().__init__()
        print(f"[INFO] Loading {model_name}...")
        self.backbone = timm.create_model(model_name, pretrained=pretrained, 
                                          num_classes=0, global_pool='avg')
        
        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            self.feature_dim = self.backbone(dummy).shape[-1]
        
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()
        print(f"[INFO] Feature dim: {self.feature_dim}, backbone FROZEN")
    
    def forward(self, x):
        B, T, C, H, W = x.shape
        x = x.view(B * T, C, H, W)
        with torch.no_grad():
            features = self.backbone(x)
        return features.view(B, T, -1)
    
    def train(self, mode=True):
        super().train(mode)
        self.backbone.eval()
        return self


class BaselineAvgPooling(nn.Module):
    """Baseline: simple temporal mean."""
    
    def __init__(self, feature_dim, hidden_dim=128, num_classes=2, dropout=0.3):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )
    
    def forward(self, features):
        B, T, D = features.shape
        projected = self.projection(features)
        aggregated = projected.mean(dim=1)
        logits = self.classifier(aggregated)
        attention = torch.ones(B, T, device=features.device) / T
        return logits, attention


class ContextAwareGatedMIL(nn.Module):
    """CAMIL: Context-Aware Gated Attention MIL."""
    
    def __init__(self, feature_dim, hidden_dim=128, num_classes=2, dropout=0.3):
        super().__init__()
        
        self.input_projection = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU()
        )
        
        self.context_conv = nn.Sequential(
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.GELU()
        )
        
        self.attention_V = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.Tanh())
        self.attention_U = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.Sigmoid())
        self.attention_w = nn.Linear(hidden_dim, 1)
        
        self.dropout = nn.Dropout(dropout)
        
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )
    
    def forward(self, features):
        B, T, D = features.shape
        
        h = self.input_projection(features)
        h_conv = self.context_conv(h.permute(0, 2, 1)).permute(0, 2, 1)
        h = h + h_conv
        h = self.dropout(h)
        
        v = self.attention_V(h)
        u = self.attention_U(h)
        gated = v * u
        
        attention_scores = self.attention_w(gated).squeeze(-1)
        attention_weights = F.softmax(attention_scores, dim=1)
        
        aggregated = torch.bmm(attention_weights.unsqueeze(1), h).squeeze(1)
        logits = self.classifier(aggregated)
        
        return logits, attention_weights


class MedViTModel(nn.Module):
    """Full model: backbone + aggregator."""
    
    def __init__(self, use_camil=True, hidden_dim=128, num_classes=2):
        super().__init__()
        self.backbone = MobileViTBackbone()
        feature_dim = self.backbone.feature_dim
        
        if use_camil:
            self.aggregator = ContextAwareGatedMIL(feature_dim, hidden_dim, num_classes)
            self.name = "MedViT-CAMIL"
        else:
            self.aggregator = BaselineAvgPooling(feature_dim, hidden_dim, num_classes)
            self.name = "Baseline-AvgPool"
    
    def forward(self, video):
        features = self.backbone(video)
        return self.aggregator(features)

def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable
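
The gated attention step in `ContextAwareGatedMIL.forward` can be followed by hand on a toy input. The sketch below is a dependency-free illustration, not the notebook's implementation: it omits the bias terms, the input projection, the temporal convolution, and dropout, and all matrices and values are hypothetical. Each frame gets a score `w · (tanh(V h) * sigmoid(U h))`, the scores are softmax-normalized over time, and the frames are averaged with those weights.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def gated_attention_pool(frames, V, U, w):
    """frames: T vectors of size D; V, U: D x D matrices; w: vector of size D."""
    scores = []
    for h in frames:
        v = [math.tanh(z) for z in matvec(V, h)]                 # tanh branch
        u = [1.0 / (1.0 + math.exp(-z)) for z in matvec(U, h)]   # sigmoid gate
        gated = [vi * ui for vi, ui in zip(v, u)]
        scores.append(sum(wi * gi for wi, gi in zip(w, gated)))
    weights = softmax(scores)                                    # sums to 1 over time
    dim = len(frames[0])
    pooled = [sum(a * h[d] for a, h in zip(weights, frames)) for d in range(dim)]
    return pooled, weights

# Toy values (hypothetical): 3 frames, 2 features each
frames = [[0.1, 0.0], [2.0, 1.5], [0.0, 0.2]]
identity = [[1.0, 0.0], [0.0, 1.0]]
pooled, weights = gated_attention_pool(frames, identity, identity, [1.0, 1.0])
print([round(a, 3) for a in weights])
print([round(p, 3) for p in pooled])
```

With a zero scoring vector `w`, every score is equal and the weights become uniform, which is the same aggregation `BaselineAvgPooling` performs.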

## 5️⃣ Training Functions (with automatic saving)

In [None]:
def log_message(msg, log_file=f"{SAVE_DIR}/training_log.txt"):
    """Prints and saves a message."""
    print(msg)
    with open(log_file, 'a') as f:
        f.write(msg + "\n")

def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss, correct, total = 0, 0, 0
    
    for videos, labels in tqdm(loader, desc="Training", leave=False):
        videos, labels = videos.to(device), labels.to(device)
        
        optimizer.zero_grad()
        logits, _ = model(videos)
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        _, pred = logits.max(1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
    
    return total_loss / len(loader), correct / total

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    all_attention = []
    
    for videos, labels in loader:
        videos, labels = videos.to(device), labels.to(device)
        logits, attention = model(videos)
        loss = criterion(logits, labels)
        
        total_loss += loss.item()
        _, pred = logits.max(1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
        all_attention.append(attention.cpu().numpy())
    
    return total_loss / len(loader), correct / total, np.concatenate(all_attention)

def train_model(model, train_loader, val_loader, config):
    device = config['DEVICE']
    save_dir = config['SAVE_DIR']
    model = model.to(device)
    
    total, trainable = count_params(model)
    log_message(f"\n{'='*60}")
    log_message(f"🔧 {model.name}")
    log_message(f"   Params: {total:,} total, {trainable:,} trainable")
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=config['LR'], weight_decay=config['WEIGHT_DECAY']
    )
    scheduler = CosineAnnealingLR(optimizer, T_max=config['EPOCHS'])
    
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_acc = 0
    
    for epoch in range(config['EPOCHS']):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, _ = evaluate(model, val_loader, criterion, device)
        scheduler.step()
        
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        msg = f"Epoch {epoch+1}/{config['EPOCHS']} | Train: {train_loss:.4f} / {train_acc*100:.1f}% | Val: {val_loss:.4f} / {val_acc*100:.1f}%"
        log_message(msg)
        
        # Save the best model
        if val_acc > best_acc:
            best_acc = val_acc
            model_path = f"{save_dir}/{model.name}_best.pth"
            torch.save(model.state_dict(), model_path)
            log_message(f"   💾 New best model: {val_acc*100:.2f}%")
        
        # Save the history every epoch (protects against disconnects)
        history_path = f"{save_dir}/{model.name}_history.json"
        with open(history_path, 'w') as f:
            json.dump(history, f)
    
    log_message(f"✅ {model.name} Best Val Accuracy: {best_acc*100:.2f}%")
    return model, history, best_acc
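
`train_model` steps a `CosineAnnealingLR` scheduler once per epoch with `T_max` equal to `EPOCHS`. As a rough guide to the learning rates this produces, the sketch below evaluates the closed-form cosine schedule in plain Python. It assumes the scheduler's default minimum of zero and no other changes to the learning rate. The function name `cosine_annealing_lr` is hypothetical.

```python
import math

def cosine_annealing_lr(base_lr, epoch, t_max, eta_min=0.0):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max))

# Values matching CONFIG above: LR = 1e-4, EPOCHS = 30
for epoch in (0, 15, 29, 30):
    print(epoch, f"{cosine_annealing_lr(1e-4, epoch, 30):.2e}")
```

Since the scheduler is stepped after each epoch, the last training epoch (index 29) still runs with a small positive learning rate. The value at index 30 is only reached after training ends.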

## 6️⃣ BASELINE Training

In [None]:
log_message(f"\n{'='*60}")
log_message("🏃 BASELINE TRAINING (Average Pooling)")
log_message(f"{'='*60}")

model_baseline = MedViTModel(use_camil=False, hidden_dim=CONFIG['HIDDEN_DIM'], num_classes=CONFIG['NUM_CLASSES'])
model_baseline, history_baseline, best_baseline = train_model(model_baseline, train_loader, val_loader, CONFIG)

print(f"\n💾 Baseline model saved in {SAVE_DIR}")

## 7️⃣ MedViT-CAMIL Training

In [None]:
log_message(f"\n{'='*60}")
log_message("🏃 MedViT-CAMIL TRAINING (Gated Attention)")
log_message(f"{'='*60}")

model_camil = MedViTModel(use_camil=True, hidden_dim=CONFIG['HIDDEN_DIM'], num_classes=CONFIG['NUM_CLASSES'])
model_camil, history_camil, best_camil = train_model(model_camil, train_loader, val_loader, CONFIG)

print(f"\n💾 CAMIL model saved in {SAVE_DIR}")

## 8️⃣ Final Results

In [None]:
criterion = nn.CrossEntropyLoss()
device = CONFIG['DEVICE']

# Load the best models
model_baseline.load_state_dict(torch.load(f"{SAVE_DIR}/Baseline-AvgPool_best.pth", map_location=device))
model_camil.load_state_dict(torch.load(f"{SAVE_DIR}/MedViT-CAMIL_best.pth", map_location=device))

# Evaluate
_, val_acc_baseline, attention_baseline = evaluate(model_baseline, val_loader, criterion, device)
_, val_acc_camil, attention_camil = evaluate(model_camil, val_loader, criterion, device)

improvement = (val_acc_camil - val_acc_baseline) * 100

log_message(f"\n{'='*60}")
log_message("📊 FINAL RESULTS (HyperKvasir Videos)")
log_message(f"{'='*60}")
log_message(f"Baseline (Avg Pool):  {val_acc_baseline*100:.2f}%")
log_message(f"MedViT-CAMIL:         {val_acc_camil*100:.2f}%")
log_message(f"Improvement:          {improvement:+.2f} pts")
log_message(f"{'='*60}")

## 9️⃣ Saving the Results

In [None]:
# Full results
results = {
    'mode': 'real',
    'dataset': 'HyperKvasir-Videos',
    'timestamp': str(datetime.now()),
    'config': {k: str(v) for k, v in CONFIG.items()},
    'dataset_info': dataset_info,
    'results': {
        'baseline': {
            'val_accuracy': val_acc_baseline,
            'best_val_accuracy': best_baseline
        },
        'camil': {
            'val_accuracy': val_acc_camil,
            'best_val_accuracy': best_camil
        },
        'improvement': val_acc_camil - val_acc_baseline
    },
    'history': {
        'baseline': history_baseline,
        'camil': history_camil
    }
}

# Save as JSON
results_path = f"{SAVE_DIR}/results_real.json"
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2)

log_message(f"\n💾 Results saved: {results_path}")
print("\n" + "="*60)
print(json.dumps(results['results'], indent=2))
print("="*60)

## 🔟 Visualizations

In [None]:
# Training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

ax = axes[0]
ax.plot(history_baseline['train_loss'], 'r-', label='Baseline Train', linewidth=2)
ax.plot(history_baseline['val_loss'], 'r--', label='Baseline Val', linewidth=2)
ax.plot(history_camil['train_loss'], 'g-', label='CAMIL Train', linewidth=2)
ax.plot(history_camil['val_loss'], 'g--', label='CAMIL Val', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('Training & Validation Loss', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)

ax = axes[1]
ax.plot([a*100 for a in history_baseline['train_acc']], 'r-', label='Baseline Train', linewidth=2)
ax.plot([a*100 for a in history_baseline['val_acc']], 'r--', label='Baseline Val', linewidth=2)
ax.plot([a*100 for a in history_camil['train_acc']], 'g-', label='CAMIL Train', linewidth=2)
ax.plot([a*100 for a in history_camil['val_acc']], 'g--', label='CAMIL Val', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title('Training & Validation Accuracy', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
curves_path = f"{SAVE_DIR}/training_curves_real.png"
plt.savefig(curves_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"💾 Saved: {curves_path}")

In [None]:
# Attention heatmap
fig, axes = plt.subplots(2, 1, figsize=(14, 8))

n_samples = min(15, len(attention_baseline))

ax = axes[0]
im = ax.imshow(attention_baseline[:n_samples], aspect='auto', cmap='Reds')
ax.set_ylabel('Video', fontsize=12)
ax.set_title('Baseline - UNIFORM attention', fontsize=14)
plt.colorbar(im, ax=ax, label='Weight')

ax = axes[1]
im = ax.imshow(attention_camil[:n_samples], aspect='auto', cmap='Greens')
ax.set_xlabel('Frame (time)', fontsize=12)
ax.set_ylabel('Video', fontsize=12)
ax.set_title('MedViT-CAMIL - LEARNED attention (peaks on key moments)', fontsize=14)
plt.colorbar(im, ax=ax, label='Weight')

plt.tight_layout()
heatmap_path = f"{SAVE_DIR}/attention_heatmap_real.png"
plt.savefig(heatmap_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"💾 Saved: {heatmap_path}")
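
A heatmap shows whether attention looks peaked, but a single number per video can make the comparison with the uniform baseline easier to report. One option, offered here as a suggestion rather than as part of the original analysis, is the entropy of each row of attention weights divided by `log(T)`: it equals 1 for uniform weights and drops toward 0 as the weights concentrate on a few frames. The function names and the example weights below are hypothetical.

```python
import math

def attention_entropy(weights):
    """Shannon entropy (in nats) of one video's attention weights."""
    return -sum(w * math.log(w) for w in weights if w > 0)

def normalized_entropy(weights):
    """Entropy divided by log(T): 1.0 for uniform weights, lower when peaked."""
    T = len(weights)
    return attention_entropy(weights) / math.log(T) if T > 1 else 0.0

uniform = [1 / 4] * 4                 # what the baseline returns for T = 4
peaked = [0.85, 0.05, 0.05, 0.05]     # made-up weights, for illustration only
print(round(normalized_entropy(uniform), 3))
print(round(normalized_entropy(peaked), 3))
```

Applied to the rows of `attention_camil`, values close to 1 would suggest the learned weights are nearly uniform, in which case the heatmap title's claim of peaks on key moments would deserve a second look.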

## 📋 Final Summary

In [None]:
log_message(f"\n{'='*60}")
log_message("📋 FINAL SUMMARY")
log_message(f"{'='*60}")
log_message(f"Finished: {datetime.now()}")
log_message("")
log_message("METRICS:")
log_message(f"  Baseline Val Accuracy: {val_acc_baseline*100:.2f}%")
log_message(f"  CAMIL Val Accuracy:    {val_acc_camil*100:.2f}%")
log_message(f"  Improvement:           {improvement:+.2f} pts")
log_message("")
log_message(f"FILES SAVED in {SAVE_DIR}:")
log_message("  📄 config.json")
log_message("  📄 results_real.json")
log_message("  📄 training_log.txt")
log_message("  📄 dataset_info.json")
log_message("  🖼️ training_curves_real.png")
log_message("  🖼️ attention_heatmap_real.png")
log_message("  🤖 Baseline-AvgPool_best.pth")
log_message("  🤖 MedViT-CAMIL_best.pth")
log_message(f"{'='*60}")

# List the files
print("\n📁 Files in Google Drive:")
!ls -la $SAVE_DIR

## 📥 How to Retrieve the Results

The files are saved automatically to **Google Drive**:
`/MyDrive/MedViT_Results/real/`

### To analyze the results with the assistant:
1. Go to Google Drive → MedViT_Results → real
2. Download the files:
   - `results_real.json` (metrics)
   - `training_curves_real.png` (curves)
   - `attention_heatmap_real.png` (attention)
3. Put them in the `results/` folder of your local project
4. Ask the assistant to analyze the results!
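
Once `results_real.json` is in the local `results/` folder, a few lines of Python are enough for a first look at the metrics. The sketch below assumes the JSON layout written by the saving cell (`results` → `baseline` / `camil` → `val_accuracy`). The helper name `summarize_results` is hypothetical, and the demo file contains made-up numbers only so the snippet can run on its own.

```python
import json
import os
import tempfile

def summarize_results(path):
    """Read a results_real.json-style file and return a one-line summary."""
    with open(path) as f:
        res = json.load(f)['results']
    base = res['baseline']['val_accuracy'] * 100
    camil = res['camil']['val_accuracy'] * 100
    # The difference is reported in percentage points.
    return f"Baseline {base:.2f}% | CAMIL {camil:.2f}% | delta {camil - base:+.2f} pts"

# Stand-in file with placeholder numbers (NOT real results)
demo = {'results': {'baseline': {'val_accuracy': 0.5}, 'camil': {'val_accuracy': 0.75}}}
demo_path = os.path.join(tempfile.mkdtemp(), 'results_real.json')
with open(demo_path, 'w') as f:
    json.dump(demo, f)

print(summarize_results(demo_path))
```

To use it on a real run, replace `demo_path` with the path to the downloaded file, for example `results/results_real.json`.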