# 🏥 MedViT-CAMIL: PROXY Mode (NoduleMNIST3D)

**Context-Aware Multiple Instance Learning for Medical Video Analysis**

This notebook runs the PROXY mode on Google Colab with a GPU.

- **Dataset**: NoduleMNIST3D (3D CT volumes → sequences of 2D slices)
- **Task**: Binary classification (benign vs. malignant nodule)
- **Comparison**: Baseline (Average Pooling) vs. MedViT-CAMIL (Gated Attention)
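
The two aggregators differ only in how the per-slice embeddings are pooled. Writing $h_t \in \mathbb{R}^{d}$ for the projected embedding of slice $t$ ($T = 28$ slices, $d$ = `HIDDEN_DIM`), the rules implemented by `BaselineAvgPooling` and `ContextAwareGatedMIL` in section 3️⃣ can be summarized as follows (a reading of the code, with linear-layer biases and dropout omitted; $V$, $U$, $w$ are the weights of `attention_V`, `attention_U`, `attention_w`):

```latex
\begin{aligned}
\text{Baseline:}\quad & z = \frac{1}{T}\sum_{t=1}^{T} h_t \\
\text{CAMIL:}\quad & \tilde h_t = h_t + \mathrm{GELU}\big(\mathrm{BN}(\mathrm{Conv1D}_{k=3}(h))_t\big) \\
& s_t = w^\top\big(\tanh(V\tilde h_t) \odot \sigma(U\tilde h_t)\big), \qquad
a_t = \frac{\exp(s_t)}{\sum_{\tau=1}^{T}\exp(s_\tau)} \\
& z = \sum_{t=1}^{T} a_t\,\tilde h_t
\end{aligned}
```

In both cases the classifier head then maps $z$ to the two logits. Mean pooling is the special case $a_t = 1/T$, which is why the baseline's "attention" is plotted as a flat row.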

---
⚡ **IMPORTANT**: Enable the GPU before running: `Runtime > Change runtime type > T4 GPU`

💾 **PERSISTENCE**: Results are saved automatically to Google Drive

## 0️⃣ Mount Google Drive (PERSISTENCE)

In [None]:
# Mount Google Drive so results survive runtime disconnects
from google.colab import drive
drive.mount('/content/drive')

# Create the results folder
import os
SAVE_DIR = '/content/drive/MyDrive/MedViT_Results/proxy'
os.makedirs(SAVE_DIR, exist_ok=True)
print(f"✅ Save folder: {SAVE_DIR}")
print("📁 All files will be saved here automatically!")

In [None]:
# Check the GPU
!nvidia-smi

import torch
print(f"\n✅ PyTorch {torch.__version__}")
print(f"✅ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
else:
    print("⚠️ NO GPU! Enable it via: Runtime > Change runtime type > T4 GPU")

In [None]:
# Install dependencies
!pip install -q timm medmnist tqdm matplotlib seaborn einops

## 1️⃣ Configuration

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np
import matplotlib.pyplot as plt
import json
from datetime import datetime
from torchvision import transforms
from tqdm.auto import tqdm
import medmnist
from medmnist import NoduleMNIST3D

# PROXY mode configuration
CONFIG = {
    'MODE': 'proxy',
    'DEVICE': 'cuda' if torch.cuda.is_available() else 'cpu',
    'SEQ_LEN': 28,          # Slices per 3D volume
    'IMG_SIZE': 224,
    'BATCH_SIZE': 16,
    'EPOCHS': 30,
    'LR': 1e-4,
    'WEIGHT_DECAY': 1e-5,
    'NUM_CLASSES': 2,
    'HIDDEN_DIM': 128,
    'SEED': 42,
    'SAVE_DIR': SAVE_DIR
}

# Reproducibility
torch.manual_seed(CONFIG['SEED'])
np.random.seed(CONFIG['SEED'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(CONFIG['SEED'])

print("üìã CONFIGURATION PROXY")
for k, v in CONFIG.items():
    print(f"  {k}: {v}")

# Save the config immediately
config_path = f"{SAVE_DIR}/config.json"
with open(config_path, 'w') as f:
    json.dump({k: str(v) for k, v in CONFIG.items()}, f, indent=2)
print(f"\n💾 Config saved: {config_path}")

## 2️⃣ Dataset: NoduleMNIST3D

In [None]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

class ProxyDataset(Dataset):
    """
    Proxy dataset: NoduleMNIST3D converted into a sequence of 2D images.
    A 28x28x28 volume becomes a sequence of 28 images of size 224x224.
    """
    
    def __init__(self, split='train', img_size=224):
        print(f"[INFO] Loading NoduleMNIST3D ({split})...")
        self.data = NoduleMNIST3D(split=split, download=True, as_rgb=False)
        self.img_size = img_size
        
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # Gray → RGB
            transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
        ])
        print(f"[INFO] {len(self.data)} volumes loaded")
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        volume, label = self.data[idx]
        volume = np.array(volume)
        
        # Drop the leading channel axis: (1, D, H, W) → (D, H, W)
        if volume.ndim == 4:
            volume = volume[0]
        
        frames = []
        for z in range(volume.shape[0]):
            slice_2d = volume[z]
            # Min-max rescale each slice to [0, 255]; 1e-8 guards against constant slices
            slice_2d = ((slice_2d - slice_2d.min())
                        / (slice_2d.max() - slice_2d.min() + 1e-8) * 255).astype(np.uint8)
            frames.append(self.transform(slice_2d))
        
        video = torch.stack(frames)  # (T, 3, H, W), one frame per slice
        label = torch.tensor(int(label.item() > 0), dtype=torch.long)
        return video, label

# Create the datasets
train_dataset = ProxyDataset('train', CONFIG['IMG_SIZE'])
val_dataset = ProxyDataset('val', CONFIG['IMG_SIZE'])
test_dataset = ProxyDataset('test', CONFIG['IMG_SIZE'])

# num_workers=0 avoids multiprocessing errors on Colab
# pin_memory only helps when batches are copied to a GPU
PIN_MEMORY = CONFIG['DEVICE'] == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True, num_workers=0, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=0, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=False, num_workers=0, pin_memory=PIN_MEMORY)

dataset_info = f"Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}"
print(f"\nüìä {dataset_info}")

# Log
with open(f"{SAVE_DIR}/training_log.txt", 'w') as f:
    f.write(f"MedViT-CAMIL PROXY Training Log\n")
    f.write(f"Started: {datetime.now()}\n")
    f.write(f"Dataset: {dataset_info}\n")
    f.write("="*60 + "\n")
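
Before the torchvision transforms, `__getitem__` comes down to two array operations: drop the channel axis of the volume, then rescale each axial slice independently to `uint8` in `[0, 255]`. The NumPy-only sketch below reproduces that step on a synthetic volume (random data, not CT), assuming the channel-first `(1, 28, 28, 28)` layout that the `ndim == 4` branch handles. It also shows the role of the `1e-8` term: a constant slice maps to all zeros instead of triggering a division by zero.

```python
import numpy as np

def volume_to_uint8_slices(volume):
    """Split a (1, D, H, W) or (D, H, W) volume into D per-slice-normalized uint8 images."""
    volume = np.asarray(volume, dtype=np.float64)
    if volume.ndim == 4:          # drop the leading channel axis
        volume = volume[0]
    slices = []
    for z in range(volume.shape[0]):
        s = volume[z]
        # Min-max rescale each slice on its own; 1e-8 avoids 0/0 on constant slices
        s = (s - s.min()) / (s.max() - s.min() + 1e-8) * 255
        slices.append(s.astype(np.uint8))
    return np.stack(slices)       # shape (D, H, W)

rng = np.random.default_rng(0)
vol = rng.random((1, 28, 28, 28))
vol[0, 0] = 0.5                   # make the first slice constant
out = volume_to_uint8_slices(vol)
print(out.shape, out.dtype, out[0].max())  # (28, 28, 28) uint8 0
```

Note that per-slice normalization discards intensity differences between slices; each slice is stretched to its own full range before being resized to 224x224 and normalized with ImageNet statistics.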

## 3️⃣ Models: MobileViT Backbone + Aggregators

In [None]:
import timm

class MobileViTBackbone(nn.Module):
    """Pretrained MobileViT backbone (FROZEN)."""
    
    def __init__(self, model_name='mobilevit_s', pretrained=True):
        super().__init__()
        print(f"[INFO] Loading {model_name}...")
        self.backbone = timm.create_model(model_name, pretrained=pretrained, 
                                          num_classes=0, global_pool='avg')
        
        # Freeze and switch to eval mode BEFORE the dummy pass: in train mode,
        # even under no_grad, BatchNorm would update its pretrained running stats
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()
        
        # Infer the feature dimension with a dummy forward pass
        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            self.feature_dim = self.backbone(dummy).shape[-1]
        print(f"[INFO] Feature dim: {self.feature_dim}, backbone FROZEN")
    
    def forward(self, x):
        B, T, C, H, W = x.shape
        x = x.view(B * T, C, H, W)
        with torch.no_grad():
            features = self.backbone(x)
        return features.view(B, T, -1)
    
    def train(self, mode=True):
        super().train(mode)
        self.backbone.eval()
        return self


class BaselineAvgPooling(nn.Module):
    """Baseline: simple temporal average."""
    
    def __init__(self, feature_dim, hidden_dim=128, num_classes=2, dropout=0.3):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout)
        )
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )
    
    def forward(self, features):
        B, T, D = features.shape
        projected = self.projection(features)
        aggregated = projected.mean(dim=1)  # Simple average
        logits = self.classifier(aggregated)
        attention = torch.ones(B, T, device=features.device) / T  # Uniform
        return logits, attention


class ContextAwareGatedMIL(nn.Module):
    """CAMIL: Context-Aware Gated Attention MIL."""
    
    def __init__(self, feature_dim, hidden_dim=128, num_classes=2, dropout=0.3):
        super().__init__()
        
        self.input_projection = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU()
        )
        
        # Conv1D for temporal context (kernel size 3: each slice also sees its two neighbors)
        self.context_conv = nn.Sequential(
            nn.Conv1d(hidden_dim, hidden_dim, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.GELU()
        )
        
        # Gated Attention branches
        self.attention_V = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.Tanh())
        self.attention_U = nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.Sigmoid())
        self.attention_w = nn.Linear(hidden_dim, 1)
        
        self.dropout = nn.Dropout(dropout)
        
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.LayerNorm(hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_classes)
        )
    
    def forward(self, features):
        B, T, D = features.shape
        
        # Projection
        h = self.input_projection(features)
        
        # Temporal context (Conv1D over the slice axis)
        h_conv = self.context_conv(h.permute(0, 2, 1)).permute(0, 2, 1)
        h = h + h_conv  # Residual
        h = self.dropout(h)
        
        # Gated Attention
        v = self.attention_V(h)  # Tanh branch
        u = self.attention_U(h)  # Sigmoid branch
        gated = v * u            # Element-wise gating
        
        attention_scores = self.attention_w(gated).squeeze(-1)
        attention_weights = F.softmax(attention_scores, dim=1)
        
        # Weighted aggregation
        aggregated = torch.bmm(attention_weights.unsqueeze(1), h).squeeze(1)
        logits = self.classifier(aggregated)
        
        return logits, attention_weights


class MedViTModel(nn.Module):
    """Full model: backbone + aggregator."""
    
    def __init__(self, use_camil=True, hidden_dim=128, num_classes=2):
        super().__init__()
        self.backbone = MobileViTBackbone()
        feature_dim = self.backbone.feature_dim
        
        if use_camil:
            self.aggregator = ContextAwareGatedMIL(feature_dim, hidden_dim, num_classes)
            self.name = "MedViT-CAMIL"
        else:
            self.aggregator = BaselineAvgPooling(feature_dim, hidden_dim, num_classes)
            self.name = "Baseline-AvgPool"
    
    def forward(self, video):
        features = self.backbone(video)
        return self.aggregator(features)

def count_params(model):
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable
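
To sanity-check the gated-attention math outside PyTorch, here is a NumPy re-implementation of the attention head in `ContextAwareGatedMIL.forward`. It is a sketch, not the model itself: the Conv1D context, dropout and classifier are left out, and the weights are random stand-ins for trained parameters. It checks the two properties the heatmaps rely on: the weights over the `T` slices sum to 1, and the pooled vector is a convex combination of slice embeddings that reduces to mean pooling when all scores are equal.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def gated_attention_pool(h, V, bV, U, bU, w, bw):
    """h: (B, T, D) slice embeddings -> pooled (B, D) and weights (B, T).

    V, U are (D, D) and w is (D,); h @ V matches nn.Linear with weight V.T.
    """
    v = np.tanh(h @ V + bV)                      # tanh branch
    u = 1.0 / (1.0 + np.exp(-(h @ U + bU)))      # sigmoid branch
    scores = (v * u) @ w + bw                    # (B, T)
    a = softmax(scores, axis=1)
    pooled = np.einsum('bt,btd->bd', a, h)       # same as torch.bmm(a.unsqueeze(1), h)
    return pooled, a

rng = np.random.default_rng(0)
B, T, D = 2, 28, 128
h = rng.standard_normal((B, T, D))
V, U = rng.standard_normal((D, D)) * 0.1, rng.standard_normal((D, D)) * 0.1
w = rng.standard_normal(D) * 0.1
pooled, a = gated_attention_pool(h, V, np.zeros(D), U, np.zeros(D), w, 0.0)
print(pooled.shape, a.shape, np.allclose(a.sum(axis=1), 1.0))  # (2, 128) (2, 28) True

# With a zero scoring vector every slice gets weight 1/T: mean pooling
pooled0, a0 = gated_attention_pool(h, V, np.zeros(D), U, np.zeros(D), np.zeros(D), 0.0)
print(np.allclose(pooled0, h.mean(axis=1)))  # True
```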

## 4️⃣ Training Functions (with Automatic Saving)

In [None]:
def log_message(msg, log_file=f"{SAVE_DIR}/training_log.txt"):
    """Print a message and append it to the log file."""
    print(msg)
    with open(log_file, 'a') as f:
        f.write(msg + "\n")

def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss, correct, total = 0, 0, 0
    
    for videos, labels in tqdm(loader, desc="Training", leave=False):
        videos, labels = videos.to(device), labels.to(device)
        
        optimizer.zero_grad()
        logits, _ = model(videos)
        loss = criterion(logits, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        total_loss += loss.item()
        _, pred = logits.max(1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
    
    return total_loss / len(loader), correct / total

@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    all_attention = []
    
    for videos, labels in loader:
        videos, labels = videos.to(device), labels.to(device)
        logits, attention = model(videos)
        loss = criterion(logits, labels)
        
        total_loss += loss.item()
        _, pred = logits.max(1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
        all_attention.append(attention.cpu().numpy())
    
    return total_loss / len(loader), correct / total, np.concatenate(all_attention)

def train_model(model, train_loader, val_loader, config):
    device = config['DEVICE']
    save_dir = config['SAVE_DIR']
    model = model.to(device)
    
    total, trainable = count_params(model)
    log_message(f"\n{'='*60}")
    log_message(f"üîß {model.name}")
    log_message(f"   Params: {total:,} total, {trainable:,} entra√Ænables")
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=config['LR'], weight_decay=config['WEIGHT_DECAY']
    )
    scheduler = CosineAnnealingLR(optimizer, T_max=config['EPOCHS'])
    
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_acc = 0
    
    for epoch in range(config['EPOCHS']):
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, _ = evaluate(model, val_loader, criterion, device)
        scheduler.step()
        
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        
        msg = f"Epoch {epoch+1}/{config['EPOCHS']} | Train: {train_loss:.4f} / {train_acc*100:.1f}% | Val: {val_loss:.4f} / {val_acc*100:.1f}%"
        log_message(msg)
        
        # Save the best model (by validation accuracy)
        if val_acc > best_acc:
            best_acc = val_acc
            model_path = f"{save_dir}/{model.name}_best.pth"
            torch.save(model.state_dict(), model_path)
            log_message(f"   💾 New best model: {val_acc*100:.2f}%")
        
        # Save the history after every epoch (protection against Colab disconnects)
        history_path = f"{save_dir}/{model.name}_history.json"
        with open(history_path, 'w') as f:
            json.dump(history, f)
    
    log_message(f"✅ {model.name} Best Val Accuracy: {best_acc*100:.2f}%")
    return model, history, best_acc
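
Because `scheduler.step()` is called once at the end of every epoch with `T_max = EPOCHS`, epoch $e$ (0-indexed) trains with the closed-form cosine-annealing rate $\eta_e = \eta_{\min} + \tfrac{1}{2}(\eta_0 - \eta_{\min})\big(1 + \cos(\pi e / T_{\max})\big)$, where `CosineAnnealingLR` uses $\eta_{\min} = 0$ by default. The pure-Python sketch below (a hypothetical helper, not a PyTorch call) tabulates that schedule for the PROXY settings:

```python
import math

def cosine_lr(epoch, base_lr=1e-4, t_max=30, eta_min=0.0):
    """Learning rate used during 0-indexed `epoch` under cosine annealing."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

for e in (0, 15, 29):
    print(f"epoch {e + 1:2d}: lr = {cosine_lr(e):.2e}")
# epoch  1: lr = 1.00e-04
# epoch 16: lr = 5.00e-05
# epoch 30: lr = 2.74e-07
```

In other words, the last epochs run at a learning rate several hundred times smaller than the first, so late changes in validation accuracy mostly reflect fine adjustments of the aggregator head.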

## 5️⃣ BASELINE Training (Average Pooling)

In [None]:
log_message(f"\n{'='*60}")
log_message("üèÉ ENTRA√éNEMENT BASELINE (Average Pooling)")
log_message(f"{'='*60}")

model_baseline = MedViTModel(use_camil=False, hidden_dim=CONFIG['HIDDEN_DIM'], num_classes=CONFIG['NUM_CLASSES'])
model_baseline, history_baseline, best_baseline = train_model(model_baseline, train_loader, val_loader, CONFIG)

print(f"\nüíæ Mod√®le baseline sauvegard√© dans {SAVE_DIR}")

## 6️⃣ MedViT-CAMIL Training (Gated Attention)

In [None]:
log_message(f"\n{'='*60}")
log_message("üèÉ ENTRA√éNEMENT MedViT-CAMIL (Gated Attention)")
log_message(f"{'='*60}")

model_camil = MedViTModel(use_camil=True, hidden_dim=CONFIG['HIDDEN_DIM'], num_classes=CONFIG['NUM_CLASSES'])
model_camil, history_camil, best_camil = train_model(model_camil, train_loader, val_loader, CONFIG)

print(f"\nüíæ Mod√®le CAMIL sauvegard√© dans {SAVE_DIR}")

## 7️⃣ Final Evaluation on the Test Set

In [None]:
criterion = nn.CrossEntropyLoss()
device = CONFIG['DEVICE']

# Load the best checkpoints (map_location also lets this cell run on a CPU-only runtime)
model_baseline.load_state_dict(torch.load(f"{SAVE_DIR}/Baseline-AvgPool_best.pth", map_location=device))
model_camil.load_state_dict(torch.load(f"{SAVE_DIR}/MedViT-CAMIL_best.pth", map_location=device))

# Evaluate on the TEST set
log_message(f"\n{'='*60}")
log_message("üìä √âVALUATION FINALE SUR TEST SET")
log_message(f"{'='*60}")

_, test_acc_baseline, attention_baseline = evaluate(model_baseline, test_loader, criterion, device)
_, test_acc_camil, attention_camil = evaluate(model_camil, test_loader, criterion, device)

# Also recompute the validation accuracy of the best checkpoints
_, val_acc_baseline, _ = evaluate(model_baseline, val_loader, criterion, device)
_, val_acc_camil, _ = evaluate(model_camil, val_loader, criterion, device)

improvement = (test_acc_camil - test_acc_baseline) * 100  # absolute gain, in percentage points

log_message(f"\nRESULTS (NoduleMNIST3D):")
log_message(f"  Baseline (Avg Pool) - Val: {val_acc_baseline*100:.2f}% | Test: {test_acc_baseline*100:.2f}%")
log_message(f"  MedViT-CAMIL        - Val: {val_acc_camil*100:.2f}% | Test: {test_acc_camil*100:.2f}%")
log_message(f"  Test improvement: {improvement:+.2f} pp")
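
`improvement` is a difference of two accuracies multiplied by 100, i.e. an absolute gain in percentage points, not a relative gain. On a fixed test set of `n` volumes, accuracy moves in steps of `100 / n` points, so a small gain may correspond to only a handful of volumes. The hypothetical helper below reports all three views; the counts in the demo are illustrative placeholders, not results from this notebook:

```python
def compare_accuracies(correct_a, correct_b, n):
    """Compare two classifiers evaluated on the same n test samples."""
    acc_a, acc_b = correct_a / n, correct_b / n
    return {
        'gain_pp': (acc_b - acc_a) * 100,               # absolute, percentage points
        'gain_rel_pct': (acc_b - acc_a) / acc_a * 100,  # relative, percent
        'extra_correct': correct_b - correct_a,         # number of samples
        'pp_per_sample': 100 / n,                       # resolution of the metric
    }

# Illustrative counts only (not results from this notebook)
print(compare_accuracies(160, 166, 200))
```

With these placeholder counts, a 3-point absolute gain is a 3.75% relative gain and amounts to 6 extra correctly classified volumes.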

## 8️⃣ Saving the Final Results

In [None]:
# Full results
results = {
    'mode': 'proxy',
    'dataset': 'NoduleMNIST3D',
    'timestamp': str(datetime.now()),
    'config': {k: str(v) for k, v in CONFIG.items()},
    'results': {
        'baseline': {
            'val_accuracy': val_acc_baseline,
            'test_accuracy': test_acc_baseline,
            'best_val_accuracy': best_baseline
        },
        'camil': {
            'val_accuracy': val_acc_camil,
            'test_accuracy': test_acc_camil,
            'best_val_accuracy': best_camil
        },
        'improvement_test': test_acc_camil - test_acc_baseline,
        'improvement_val': val_acc_camil - val_acc_baseline
    },
    'history': {
        'baseline': history_baseline,
        'camil': history_camil
    }
}

# Save as JSON
results_path = f"{SAVE_DIR}/results_proxy.json"
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2)

log_message(f"\nüíæ R√©sultats sauvegard√©s: {results_path}")
print("\n" + "="*60)
print(json.dumps(results['results'], indent=2))
print("="*60)
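
Because the accuracies in `results_proxy.json` are stored as native floats, the file can be reloaded later (for instance after a Colab disconnect) without re-running training. A small sketch, with `summarize_results` as a hypothetical helper and a synthetic file with the same key layout standing in for the real one (its numbers are placeholders):

```python
import json
import os
import tempfile

def summarize_results(path):
    """Hypothetical helper: read results_proxy.json and format the key metrics."""
    with open(path) as f:
        results = json.load(f)
    r = results['results']
    lines = [f"Dataset: {results['dataset']}"]
    for name in ('baseline', 'camil'):
        lines.append(f"{name:>8}: val {r[name]['val_accuracy'] * 100:.2f}% | "
                     f"test {r[name]['test_accuracy'] * 100:.2f}%")
    lines.append(f"test gain: {r['improvement_test'] * 100:+.2f} pp")
    return "\n".join(lines)

# Synthetic file with the same layout (placeholder numbers, not real results)
demo = {'dataset': 'NoduleMNIST3D', 'results': {
    'baseline': {'val_accuracy': 0.5, 'test_accuracy': 0.5},
    'camil': {'val_accuracy': 0.75, 'test_accuracy': 0.625},
    'improvement_test': 0.125}}
path = os.path.join(tempfile.mkdtemp(), 'results_proxy.json')
with open(path, 'w') as f:
    json.dump(demo, f)
print(summarize_results(path))
```

On the real run, `summarize_results(f"{SAVE_DIR}/results_proxy.json")` would read the file written by the cell above.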

## 9️⃣ Visualizations (saved automatically)

In [None]:
# Training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss
ax = axes[0]
ax.plot(history_baseline['train_loss'], 'r-', label='Baseline Train', linewidth=2)
ax.plot(history_baseline['val_loss'], 'r--', label='Baseline Val', linewidth=2)
ax.plot(history_camil['train_loss'], 'g-', label='CAMIL Train', linewidth=2)
ax.plot(history_camil['val_loss'], 'g--', label='CAMIL Val', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('Training & Validation Loss', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)

# Accuracy
ax = axes[1]
ax.plot([a*100 for a in history_baseline['train_acc']], 'r-', label='Baseline Train', linewidth=2)
ax.plot([a*100 for a in history_baseline['val_acc']], 'r--', label='Baseline Val', linewidth=2)
ax.plot([a*100 for a in history_camil['train_acc']], 'g-', label='CAMIL Train', linewidth=2)
ax.plot([a*100 for a in history_camil['val_acc']], 'g--', label='CAMIL Val', linewidth=2)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Accuracy (%)', fontsize=12)
ax.set_title('Training & Validation Accuracy', fontsize=14)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
curves_path = f"{SAVE_DIR}/training_curves_proxy.png"
plt.savefig(curves_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"üíæ Sauvegard√©: {curves_path}")

In [None]:
# Attention heatmap
fig, axes = plt.subplots(2, 1, figsize=(14, 8))

n_samples = min(20, len(attention_baseline))

ax = axes[0]
im = ax.imshow(attention_baseline[:n_samples], aspect='auto', cmap='Reds')
ax.set_ylabel('Sample', fontsize=12)
ax.set_title('Baseline - UNIFORM attention (1/T for every slice)', fontsize=14)
plt.colorbar(im, ax=ax, label='Weight')

ax = axes[1]
im = ax.imshow(attention_camil[:n_samples], aspect='auto', cmap='Greens')
ax.set_xlabel('Slice (volume Z axis)', fontsize=12)
ax.set_ylabel('Sample', fontsize=12)
ax.set_title('MedViT-CAMIL - LEARNED attention (focus on informative slices)', fontsize=14)
plt.colorbar(im, ax=ax, label='Weight')

plt.tight_layout()
heatmap_path = f"{SAVE_DIR}/attention_heatmap_proxy.png"
plt.savefig(heatmap_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"üíæ Sauvegard√©: {heatmap_path}")

In [None]:
# Distribution of the attention weights
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

ax = axes[0]
mean_baseline = attention_baseline.mean(axis=0)
ax.bar(range(len(mean_baseline)), mean_baseline, color='red', alpha=0.7)
ax.axhline(y=1/len(mean_baseline), color='black', linestyle='--', label='Uniform')
ax.set_xlabel('Slice', fontsize=12)
ax.set_ylabel('Mean weight', fontsize=12)
ax.set_title('Baseline: uniform distribution', fontsize=14)
ax.legend()

ax = axes[1]
mean_camil = attention_camil.mean(axis=0)
ax.bar(range(len(mean_camil)), mean_camil, color='green', alpha=0.7)
ax.axhline(y=1/len(mean_camil), color='black', linestyle='--', label='Uniform')
ax.set_xlabel('Slice', fontsize=12)
ax.set_ylabel('Mean weight', fontsize=12)
ax.set_title('CAMIL: mean attention per slice (central slices = nodule region)', fontsize=14)
ax.legend()

plt.tight_layout()
dist_path = f"{SAVE_DIR}/attention_distribution_proxy.png"
plt.savefig(dist_path, dpi=150, bbox_inches='tight')
plt.show()
print(f"üíæ Sauvegard√©: {dist_path}")

## 📋 Final Summary

In [None]:
log_message(f"\n{'='*60}")
log_message("üìã R√âSUM√â FINAL")
log_message(f"{'='*60}")
log_message(f"Termin√©: {datetime.now()}")
log_message(f"")
log_message(f"M√âTRIQUES:")
log_message(f"  Baseline Test Accuracy: {test_acc_baseline*100:.2f}%")
log_message(f"  CAMIL Test Accuracy:    {test_acc_camil*100:.2f}%")
log_message(f"  Am√©lioration:           {improvement:+.2f}%")
log_message(f"")
log_message(f"FICHIERS SAUVEGARD√âS dans {SAVE_DIR}:")
log_message(f"  üìÑ config.json")
log_message(f"  üìÑ results_proxy.json")
log_message(f"  üìÑ training_log.txt")
log_message(f"  üñºÔ∏è training_curves_proxy.png")
log_message(f"  üñºÔ∏è attention_heatmap_proxy.png")
log_message(f"  üñºÔ∏è attention_distribution_proxy.png")
log_message(f"  ü§ñ Baseline-AvgPool_best.pth")
log_message(f"  ü§ñ MedViT-CAMIL_best.pth")
log_message(f"  üìÑ Baseline-AvgPool_history.json")
log_message(f"  üìÑ MedViT-CAMIL_history.json")
log_message(f"{'='*60}")

# List the files
print("\n📁 Files in Google Drive:")
!ls -la $SAVE_DIR
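
`ls` shows what is there but not what is missing. For an explicit check that every artifact listed in the summary was written (for example before disconnecting the runtime), here is a sketch with a hypothetical `missing_artifacts` helper; the file list is copied from the summary cell:

```python
import os
import tempfile

EXPECTED_FILES = [
    'config.json', 'results_proxy.json', 'training_log.txt',
    'training_curves_proxy.png', 'attention_heatmap_proxy.png',
    'attention_distribution_proxy.png',
    'Baseline-AvgPool_best.pth', 'MedViT-CAMIL_best.pth',
    'Baseline-AvgPool_history.json', 'MedViT-CAMIL_history.json',
]

def missing_artifacts(save_dir, expected=EXPECTED_FILES):
    """Return the expected file names that are absent from save_dir."""
    return [name for name in expected
            if not os.path.isfile(os.path.join(save_dir, name))]

# Demo on a temporary folder containing only two of the files
demo_dir = tempfile.mkdtemp()
for name in ('config.json', 'training_log.txt'):
    open(os.path.join(demo_dir, name), 'w').close()
print(len(missing_artifacts(demo_dir)))  # 8
```

After a complete run, `missing_artifacts(SAVE_DIR)` should return an empty list.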

## 📥 Retrieving the Results

All files are saved automatically to **Google Drive**:
`/MyDrive/MedViT_Results/proxy/`

### To analyze the results with the assistant:
1. Go to Google Drive → MedViT_Results → proxy
2. Download these files:
   - `results_proxy.json` (metrics)
   - `training_curves_proxy.png` (curves)
   - `attention_heatmap_proxy.png` (attention)
3. Put them in the `results/` folder of your local project
4. Ask the assistant to analyze the results!