# Rakuten Image Classification - Production ViT Training

**Objectif:** Entraîner le modèle Vision Transformer (google/vit-base-patch16-224) sur l'ensemble complet des données.

**Matériel:** NVIDIA RTX 3060 Ti (8GB VRAM)

**Stratégie:**
- Transfer learning avec ViT pré-entraîné
- Optimiseur AdamW avec warmup scheduler
- Early stopping basé sur la validation accuracy
- Monitoring en temps réel avec graphiques

**Résultats attendus:**
- Temps d'entraînement: ~3-4 heures pour 20 epochs
- Val accuracy: ~50-60%
- Checkpoints sauvegardés dans: `/workspace/checkpoints/vit_production/`

## 1. Configuration

In [None]:
# Imports standards
import sys
import os
from pathlib import Path
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
import numpy as np
from sklearn.metrics import classification_report, accuracy_score, f1_score
from tqdm import tqdm
import json
from datetime import datetime
import matplotlib.pyplot as plt
from IPython.display import display, clear_output
from torch.cuda.amp import GradScaler
import wandb

# Imports transformers
from transformers import ViTForImageClassification, get_scheduler
from torch.optim import AdamW

# Make the repository root (two levels above the notebook's working directory)
# and its scripts/ folder importable. Each entry is pushed to the front of
# sys.path only if it is not already present.
project_root = Path.cwd().parent.parent
scripts_dir = project_root / "scripts"
for _extra_dir in (project_root, scripts_dir):
    _entry = str(_extra_dir)
    if _entry not in sys.path:
        sys.path.insert(0, _entry)

# Project-local imports, resolved through the sys.path entries added above
from src.rakuten_image.datasets import RakutenImageDataset
from load_data import split_data

# Environment report: library version and, when a GPU is visible, its name/memory
print("‚úì All modules imported successfully")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {_gpu_props.total_memory / 1e9:.2f} GB")

# Authenticate with Weights & Biases before any run is created
wandb.login()

In [2]:
# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================

# Resolved once up front so the dict below stays purely declarative.
_device = "cuda" if torch.cuda.is_available() else "cpu"
_data_root = Path("/workspace/data")

CONFIG = {
    # Data locations
    "data_dir": _data_root,
    "img_dir": _data_root / "images" / "image_train",
    "checkpoint_dir": Path("/workspace/checkpoints/vit_production"),

    # Model
    "model_name": "google/vit-base-patch16-224",
    "img_size": 224,

    # Training hyperparameters
    "batch_size": 32,  # raised to 32 thanks to AMP on the RTX 3060 Ti (8 GB)
    "num_epochs": 20,
    "learning_rate": 2e-5,
    "weight_decay": 0.05,
    "warmup_ratio": 0.1,

    # Train/validation split (85% train, 15% validation)
    "val_split": 0.15,
    "random_state": 42,

    # Early stopping
    "early_stopping_patience": 3,

    # Hardware
    "device": _device,
    "num_workers": 2,
    "use_amp": True,  # enable Automatic Mixed Precision
}

# Print a framed, one-setting-per-line recap of the configuration.
_rule = "=" * 80
_recap = [
    _rule,
    "üöÄ CONFIGURATION ViT PRODUCTION",
    _rule,
    f"Device: {CONFIG['device']}",
    f"Mod√®le: {CONFIG['model_name']}",
    f"Taille d'image: {CONFIG['img_size']}x{CONFIG['img_size']}",
    f"Batch size: {CONFIG['batch_size']}",
    f"Epochs max: {CONFIG['num_epochs']}",
    f"Learning rate: {CONFIG['learning_rate']}",
    f"Warmup ratio: {CONFIG['warmup_ratio']}",
    f"Val split: {CONFIG['val_split']} ({int(CONFIG['val_split']*100)}%)",
    f"Early stopping patience: {CONFIG['early_stopping_patience']}",
    f"AMP activ√©: {CONFIG['use_amp']}",
    f"R√©pertoire checkpoints: {CONFIG['checkpoint_dir']}",
    _rule,
]
for _line in _recap:
    print(_line)

üöÄ CONFIGURATION ViT PRODUCTION
Device: cuda
Mod√®le: google/vit-base-patch16-224
Taille d'image: 224x224
Batch size: 32
Epochs max: 20
Learning rate: 2e-05
Warmup ratio: 0.1
Val split: 0.15 (15%)
Early stopping patience: 3
AMP activ√©: True
R√©pertoire checkpoints: /workspace/checkpoints/vit_production


## 2. Chargement et Préparation des Données

In [3]:
print("\nüìÇ Loading data...")

# Development / hold-out split produced by the shared helper (the same call the
# text notebooks make, per the notes printed at the end of this cell).
X_dev, X_holdout, y_dev, y_holdout = split_data()

# Re-attach the target so each split is one self-contained dataframe.
df_dev = X_dev.assign(prdtypecode=y_dev)
df_holdout = X_holdout.assign(prdtypecode=y_holdout)

_n_dev, _n_holdout = len(df_dev), len(df_holdout)
print(f"‚úì Data loaded: {_n_dev + _n_holdout:,} total samples")
print(f"  Development: {_n_dev:,} samples (85%)")
print(f"  Hold-out:    {_n_holdout:,} samples (15%)")
print(f"  Unique classes: {df_dev['prdtypecode'].nunique()}")

# Global label encoding: product type codes -> contiguous integer ids
print("\nüîß Label encoding...")
from sklearn.preprocessing import LabelEncoder

# The encoder is fitted on both splits together so no class can be missed.
all_labels = pd.concat([df_dev['prdtypecode'], df_holdout['prdtypecode']])
label_encoder = LabelEncoder()
label_encoder.fit(all_labels)

for _frame in (df_dev, df_holdout):
    _frame['encoded_label'] = label_encoder.transform(_frame['prdtypecode'])

num_classes = len(label_encoder.classes_)
print(f"‚úì Encoding complete")
print(f"  Encoded range: 0 to {num_classes - 1}")
print(f"  Total classes: {num_classes}")
# Guard against a silently truncated dataset: 27 classes are expected.
assert num_classes == 27, f"Expected 27 classes, got {num_classes}"

print("\n‚ö†Ô∏è  IMPORTANT: Using SAME split as text notebooks!")
print("‚ö†Ô∏è  This ensures consistent evaluation across modalities")


üìÇ Loading data...
‚úì Data loaded: 84,916 total samples
  Development: 72,178 samples (85%)
  Hold-out:    12,738 samples (15%)
  Unique classes: 27

üîß Label encoding...
‚úì Encoding complete
  Encoded range: 0 to 26
  Total classes: 27

‚ö†Ô∏è  IMPORTANT: Using SAME split as text notebooks!
‚ö†Ô∏è  This ensures consistent evaluation across modalities


In [4]:
print("\n" + "=" * 80)
print("Data Cleaning & Train/Val Split")
print("=" * 80)


def _missing_image_rows(frame, image_dir):
    """Return the index labels of rows in *frame* whose image file is absent.

    The expected file name is ``image_<imageid>_product_<productid>.jpg`` inside
    *image_dir*. Columns are walked directly instead of through ``iterrows()``,
    which avoids building one Series per row.
    """
    return [
        idx
        for idx, imageid, productid in zip(frame.index, frame['imageid'], frame['productid'])
        if not (image_dir / f"image_{int(imageid)}_product_{int(productid)}.jpg").exists()
    ]


# Data cleaning (development set): drop rows that have no image on disk
print("\nüîß Data cleaning (development set)...")
original_dev_size = len(df_dev)  # not used in this cell; kept in case later cells reference it

missing_images_dev = _missing_image_rows(df_dev, CONFIG["img_dir"])
if missing_images_dev:
    df_dev = df_dev.drop(missing_images_dev)
    print(f"  Removed {len(missing_images_dev)} samples with missing images")

print(f"‚úì Development set after cleaning: {len(df_dev):,} samples")

# Data cleaning (hold-out set): same rule as above
print("\nüîß Data cleaning (hold-out set)...")
original_holdout_size = len(df_holdout)  # not used in this cell; kept in case later cells reference it

missing_images_holdout = _missing_image_rows(df_holdout, CONFIG["img_dir"])
if missing_images_holdout:
    df_holdout = df_holdout.drop(missing_images_holdout)
    print(f"  Removed {len(missing_images_holdout)} samples with missing images")

print(f"‚úì Hold-out set after cleaning: {len(df_holdout):,} samples")

# Split development set into train/val for hyperparameter tuning
print("\n" + "=" * 80)
print("Development Split (Train/Val for Hyperparameter Tuning)")
print("=" * 80)

from sklearn.model_selection import train_test_split

# Stratified split on the encoded label. The validation fraction comes from
# CONFIG["val_split"] so the configured value is what is actually applied.
train_indices, val_indices, _, _ = train_test_split(
    df_dev.index,
    df_dev['encoded_label'],
    test_size=CONFIG["val_split"],
    random_state=CONFIG["random_state"],
    stratify=df_dev['encoded_label']
)

# Positional 0..n-1 indices from here on.
df_train = df_dev.loc[train_indices].reset_index(drop=True)
df_val = df_dev.loc[val_indices].reset_index(drop=True)
df_holdout = df_holdout.reset_index(drop=True)

total_samples = len(df_dev) + len(df_holdout)
print(f"‚úì Development split complete:")
print(f"  Training:   {len(df_train):,} samples (~{len(df_train)/total_samples*100:.1f}%)")
print(f"  Validation: {len(df_val):,} samples (~{len(df_val)/total_samples*100:.1f}%)")
# Share is computed from the cleaned sizes rather than printed as a constant.
print(f"  Hold-out:   {len(df_holdout):,} samples ({len(df_holdout)/total_samples*100:.1f}%)")

print("\n" + "=" * 80)
print("DATA SPLIT SUMMARY")
print("=" * 80)
print(f"Total: {total_samples:,} samples")
print(f"  1. Training:   {len(df_train):,} (for model training)")
print(f"  2. Validation: {len(df_val):,} (for hyperparameter tuning)")
print(f"  3. Hold-out:   {len(df_holdout):,} (for final evaluation)")
print()
print("‚ö†Ô∏è  CRITICAL: This split is IDENTICAL to text notebooks!")
print("‚ö†Ô∏è  Image and text models evaluated on same hold-out samples")
print("=" * 80)


Data Cleaning & Train/Val Split

üîß Data cleaning (development set)...
‚úì Development set after cleaning: 72,178 samples

üîß Data cleaning (hold-out set)...
‚úì Hold-out set after cleaning: 12,738 samples

Development Split (Train/Val for Hyperparameter Tuning)
‚úì Development split complete:
  Training:   61,351 samples (~72.2%)
  Validation: 10,827 samples (~12.8%)
  Hold-out:   12,738 samples (15.0%)

DATA SPLIT SUMMARY
Total: 84,916 samples
  1. Training:   61,351 (for model training)
  2. Validation: 10,827 (for hyperparameter tuning)
  3. Hold-out:   12,738 (for final evaluation)

‚ö†Ô∏è  CRITICAL: This split is IDENTICAL to text notebooks!
‚ö†Ô∏è  Image and text models evaluated on same hold-out samples


In [5]:
print("\nüîß Cr√©ation des datasets et dataloaders...")


def _resize_step():
    """Resize to the square input size configured in CONFIG["img_size"]."""
    return transforms.Resize((CONFIG["img_size"], CONFIG["img_size"]))


def _tensor_steps():
    """Tensor conversion followed by per-channel normalisation (mean 0.5, std 0.5)."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]


def _make_dataset(frame, transform):
    """Wrap *frame* in a RakutenImageDataset that reads the encoded label column."""
    return RakutenImageDataset(
        dataframe=frame,
        image_dir=CONFIG["img_dir"],
        transform=transform,
        label_col="encoded_label",
    )


# Training pipeline: resize, then random flip + RandAugment as augmentation
train_transform = transforms.Compose([
    _resize_step(),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandAugment(num_ops=2, magnitude=15),
    *_tensor_steps(),
])

# Validation / hold-out pipeline: deterministic, no augmentation
val_transform = transforms.Compose([_resize_step(), *_tensor_steps()])

# Datasets (labels come from encoded_label rather than the raw prdtypecode)
train_dataset = _make_dataset(df_train, train_transform)
val_dataset = _make_dataset(df_val, val_transform)
test_dataset = _make_dataset(df_holdout, val_transform)

print(f"‚úì Datasets cr√©√©s")
print(f"  Training:   {len(train_dataset):,} √©chantillons")
print(f"  Validation: {len(val_dataset):,} √©chantillons")
print(f"  Hold-out:   {len(test_dataset):,} √©chantillons")
print(f"  Classes:    {train_dataset.num_classes}")

# Settings shared by the three loaders
_loader_kwargs = {
    "batch_size": CONFIG["batch_size"],
    "num_workers": CONFIG["num_workers"],
    "pin_memory": True,
}

# Training batches are shuffled, and the last incomplete batch is dropped so
# every training step sees a full batch.
train_loader = DataLoader(train_dataset, shuffle=True, drop_last=True, **_loader_kwargs)
# Evaluation loaders keep dataset order and keep every sample.
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

print(f"\n‚úì DataLoaders cr√©√©s (batch_size={CONFIG['batch_size']})")
print(f"  Train:      {len(train_loader):,} batches")
print(f"  Validation: {len(val_loader):,} batches")
print(f"  Hold-out:   {len(test_loader):,} batches")

# Sanity check: pull one training batch; any failure is reported, not raised
try:
    images, labels = next(iter(train_loader))
    print(f"\n‚úì Sanity check: Batch shape {images.shape}, Labels {labels.shape}")
    print(f"‚úÖ All DataLoaders working correctly!")
except Exception as e:
    print(f"‚ùå Error: {e}")


üîß Cr√©ation des datasets et dataloaders...
Pre-loading paths into memory...
Dataset initialized with 61351 samples.
Pre-loading paths into memory...
Dataset initialized with 10827 samples.
Pre-loading paths into memory...
Dataset initialized with 12738 samples.
‚úì Datasets cr√©√©s
  Training:   61,351 √©chantillons
  Validation: 10,827 √©chantillons
  Hold-out:   12,738 √©chantillons
  Classes:    27

‚úì DataLoaders cr√©√©s (batch_size=32)
  Train:      1,917 batches
  Validation: 339 batches
  Hold-out:   399 batches

‚úì Sanity check: Batch shape torch.Size([32, 3, 224, 224]), Labels torch.Size([32])
‚úÖ All DataLoaders working correctly!


## 3. Initialisation du Modèle

In [6]:
print("\nüèóÔ∏è Initialisation du mod√®le ViT...")

# Load the pretrained checkpoint. ignore_mismatched_sizes lets the classifier
# head be re-created with num_classes outputs (from the global label encoding)
# when its shape differs from the one stored in the checkpoint.
_vit_kwargs = {
    "num_labels": num_classes,
    "ignore_mismatched_sizes": True,
    "hidden_dropout_prob": 0.1,
    "attention_probs_dropout_prob": 0.1,
}
model = ViTForImageClassification.from_pretrained(CONFIG["model_name"], **_vit_kwargs)
model.to(CONFIG["device"])

print(f"‚úì Mod√®le charg√©: {CONFIG['model_name']}")
print(f"  Nombre de classes: {num_classes}")

# Count parameters in one pass: all of them, and those that receive gradients
total_params = 0
trainable_params = 0
for _param in model.parameters():
    _count = _param.numel()
    total_params += _count
    if _param.requires_grad:
        trainable_params += _count

print(f"  Param√®tres totaux: {total_params:,}")
print(f"  Param√®tres entra√Ænables: {trainable_params:,}")


üèóÔ∏è Initialisation du mod√®le ViT...


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([27]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([27, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


‚úì Mod√®le charg√©: google/vit-base-patch16-224
  Nombre de classes: 27
  Param√®tres totaux: 85,819,419
  Param√®tres entra√Ænables: 85,819,419


In [7]:
print("\n‚öôÔ∏è Configuration de l'optimiseur et du scheduler...")

# AdamW over every model parameter; lr and weight decay come from CONFIG.
optimizer = AdamW(
    model.parameters(),
    lr=CONFIG["learning_rate"],
    weight_decay=CONFIG["weight_decay"]
)

# "linear" schedule with warmup: the scheduler is stepped once per batch, so
# the horizon is (batches per epoch) x (max epochs); warmup takes the first
# warmup_ratio share of those steps.
total_steps = len(train_loader) * CONFIG["num_epochs"]
num_warmup_steps = int(total_steps * CONFIG["warmup_ratio"])

scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=total_steps
)

# Gradient scaler for AMP; stays None when mixed precision is turned off
scaler = torch.amp.GradScaler('cuda') if CONFIG["use_amp"] else None

print(f"‚úì Optimiseur: AdamW")
print(f"  Learning rate: {CONFIG['learning_rate']}")
print(f"  Weight decay: {CONFIG['weight_decay']}")
print(f"‚úì Scheduler: Linear warmup")
print(f"  Warmup steps: {num_warmup_steps:,}/{total_steps:,}")
if CONFIG["use_amp"]:
    print(f"‚úì GradScaler initialis√© pour AMP")

# Create the checkpoint directory
CONFIG["checkpoint_dir"].mkdir(parents=True, exist_ok=True)
print(f"‚úì R√©pertoire checkpoints: {CONFIG['checkpoint_dir']}")

# Persist the configuration next to the checkpoints. Numbers and booleans are
# written with their native JSON types; `default=str` is only invoked for values
# json cannot encode itself, which here means the Path entries.
with open(CONFIG["checkpoint_dir"] / "config.json", "w") as f:
    json.dump(CONFIG, f, indent=2, default=str)
print(f"‚úì Configuration sauvegard√©e")


‚öôÔ∏è Configuration de l'optimiseur et du scheduler...
‚úì Optimiseur: AdamW
  Learning rate: 2e-05
  Weight decay: 0.05
‚úì Scheduler: Linear warmup
  Warmup steps: 3,834/38,340
‚úì GradScaler initialis√© pour AMP
‚úì R√©pertoire checkpoints: /workspace/checkpoints/vit_production
‚úì Configuration sauvegard√©e


## 4. Entraînement avec Monitoring en Temps Réel

In [None]:
_rule = "=" * 80
print("\n" + _rule)
print("üöÄ D√âMARRAGE DE L'ENTRA√éNEMENT")
print(_rule)
# Rough estimate: total batches x 0.8, converted to minutes (0.8 is presumably
# the assumed seconds per batch — TODO confirm).
print(f"Temps estim√©: ~{len(train_loader) * CONFIG['num_epochs'] * 0.8 / 60:.0f} minutes")
print(f"AMP activ√©: {CONFIG['use_amp']}")
print(_rule + "\n")

# Open the Weights & Biases run; its name carries the launch timestamp.
_run_name = f"vit_run_{pd.Timestamp.now().strftime('%Y%m%d_%H%M')}"
wandb.init(
    project="rakuten-classification",
    entity="xiaosong-dev-formation-data-science",
    config=CONFIG,
    name=_run_name,
)

# Best-so-far metrics and the early-stopping counter
best_val_acc, best_val_f1 = 0.0, 0.0
best_val_loss = float('inf')
patience_counter = 0

# One list per tracked metric, appended to once per epoch
history = {
    metric: []
    for metric in ("train_loss", "train_acc", "val_loss", "val_acc", "val_f1")
}

# Live plotting helper, called once per epoch by the training loop
def update_plots():
    """Redraw the loss and accuracy curves in place and print the best metrics.

    Reads the notebook-level ``history``, ``best_val_acc``, ``best_val_f1``,
    ``best_val_loss``, ``patience_counter`` and ``CONFIG``; returns nothing.
    """
    clear_output(wait=True)
    fig, axes = plt.subplots(1, 2, figsize=(15, 4))

    # (train key, val key, train legend, val legend, y-axis label, title) per panel
    panels = (
        ("train_loss", "val_loss", 'Train Loss', 'Val Loss',
         'Loss', 'Training et Validation Loss'),
        ("train_acc", "val_acc", 'Train Acc', 'Val Acc',
         'Accuracy (%)', 'Training et Validation Accuracy'),
    )
    for ax, (train_key, val_key, train_label, val_label, y_label, title) in zip(axes, panels):
        ax.plot(history[train_key], label=train_label, marker='o', linewidth=2)
        ax.plot(history[val_key], label=val_label, marker='s', linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Text recap printed under the figure
    print(f"\nüìä Statistiques actuelles:")
    print(f"  Meilleure Val Acc: {best_val_acc:.2f}% | F1: {best_val_f1:.4f} (Patience: {patience_counter}/{CONFIG['early_stopping_patience']})")
    print(f"  Meilleure Val Loss: {best_val_loss:.4f}")

# Main training loop.
# Each epoch: one optimisation pass over train_loader, one evaluation pass over
# val_loader, metric logging (history + WandB), checkpointing whenever validation
# accuracy improves, a plot refresh, and early stopping once validation accuracy
# has failed to improve for CONFIG["early_stopping_patience"] epochs in a row.
for epoch in range(CONFIG["num_epochs"]):
    print(f"\nEpoch {epoch + 1}/{CONFIG['num_epochs']}")
    print("=" * 80)

    # -------------------- Training phase --------------------
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    train_pbar = tqdm(train_loader, desc=f"Training", unit="batch")
    for images, labels in train_pbar:
        images, labels = images.to(CONFIG["device"]), labels.to(CONFIG["device"])

        optimizer.zero_grad()

        if CONFIG["use_amp"]:
            # Forward pass under autocast; the model computes the loss itself
            # because labels are passed in.
            with torch.amp.autocast(device_type="cuda"):
                outputs = model(pixel_values=images, labels=labels)
                loss = outputs.loss

            # Backward pass on the scaled loss
            scaler.scale(loss).backward()

            # Gradients must be unscaled before clipping so the 1.0 threshold
            # applies to their true magnitude.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # scaler.step() silently skips optimizer.step() when it finds
            # inf/NaN gradients, and scaler.update() then lowers the scale.
            # A lower scale afterwards therefore means "no optimizer step".
            scale_before = scaler.get_scale()
            scaler.step(optimizer)
            scaler.update()
            optimizer_stepped = scaler.get_scale() >= scale_before
        else:
            # Full-precision forward/backward
            outputs = model(pixel_values=images, labels=labels)
            loss = outputs.loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer_stepped = True

        # Advance the LR schedule only after a real optimizer step, so it is
        # never stepped ahead of the optimizer.
        if optimizer_stepped:
            scheduler.step()

        # Running metrics
        train_loss += loss.item()
        predictions = torch.argmax(outputs.logits, dim=-1)
        train_correct += (predictions == labels).sum().item()
        train_total += labels.size(0)

        # Progress-bar readout: last batch loss, running accuracy
        train_pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'acc': f'{100.0 * train_correct / train_total:.2f}%'
        })

    avg_train_loss = train_loss / len(train_loader)
    train_accuracy = 100.0 * train_correct / train_total

    # -------------------- Validation phase --------------------
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    # Every prediction/label is kept so F1 can be computed over the whole set
    all_val_preds = []
    all_val_labels = []

    val_pbar = tqdm(val_loader, desc=f"Validation", unit="batch")
    with torch.no_grad():
        for images, labels in val_pbar:
            images, labels = images.to(CONFIG["device"]), labels.to(CONFIG["device"])

            # Same precision setting as training
            if CONFIG["use_amp"]:
                with torch.amp.autocast(device_type="cuda"):
                    outputs = model(pixel_values=images, labels=labels)
                    loss = outputs.loss
            else:
                outputs = model(pixel_values=images, labels=labels)
                loss = outputs.loss

            val_loss += loss.item()
            predictions = torch.argmax(outputs.logits, dim=-1)
            val_correct += (predictions == labels).sum().item()
            val_total += labels.size(0)

            all_val_preds.extend(predictions.cpu().numpy())
            all_val_labels.extend(labels.cpu().numpy())

            val_pbar.set_postfix({
                'loss': f'{loss.item():.4f}',
                'acc': f'{100.0 * val_correct / val_total:.2f}%'
            })

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = 100.0 * val_correct / val_total

    # Weighted-average F1 over all classes. Cast to a plain float so the
    # checkpoint and the history hold no NumPy scalar objects.
    val_f1 = float(f1_score(all_val_labels, all_val_preds, average='weighted'))

    # -------------------- History update --------------------
    history["train_loss"].append(avg_train_loss)
    history["train_acc"].append(train_accuracy)
    history["val_loss"].append(avg_val_loss)
    history["val_acc"].append(val_accuracy)
    history["val_f1"].append(val_f1)

    # Log the epoch's metrics (and the current LR of the first param group) to WandB
    wandb.log({
        "train_loss": avg_train_loss,
        "train_acc": train_accuracy,
        "val_loss": avg_val_loss,
        "val_acc": val_accuracy,
        "val_f1": val_f1,
        "epoch": epoch + 1,
        "learning_rate": optimizer.param_groups[0]['lr']
    })

    print(f"\nüìä R√©sultats Epoch {epoch + 1}:")
    print(f"  Train Loss: {avg_train_loss:.4f} | Train Acc: {train_accuracy:.2f}%")
    print(f"  Val Loss:   {avg_val_loss:.4f} | Val Acc:   {val_accuracy:.2f}% | F1: {val_f1:.4f}")

    # -------------------- Best-model checkpoint --------------------
    # Validation accuracy is the selection criterion; loss and F1 are recorded
    # for the epoch that achieved it.
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy
        best_val_loss = avg_val_loss
        best_val_f1 = val_f1
        patience_counter = 0

        checkpoint_path = CONFIG["checkpoint_dir"] / "best_model.pth"
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'scaler_state_dict': scaler.state_dict() if CONFIG["use_amp"] else None,
            'val_acc': val_accuracy,
            'val_loss': avg_val_loss,
            'val_f1': val_f1,
            'train_acc': train_accuracy,
            'train_loss': avg_train_loss,
        }, checkpoint_path)

        print(f"  ‚úÖ Meilleur mod√®le sauvegard√©! (Val Acc: {val_accuracy:.2f}%, F1: {val_f1:.4f})")
    else:
        patience_counter += 1
        print(f"  ‚è≥ Pas d'am√©lioration ({patience_counter}/{CONFIG['early_stopping_patience']})")

    # -------------------- Plot refresh --------------------
    update_plots()

    # -------------------- Early stopping --------------------
    if patience_counter >= CONFIG["early_stopping_patience"]:
        print(f"\n‚ö†Ô∏è Early stopping d√©clench√© apr√®s {epoch + 1} epochs")
        break

print("\n" + "=" * 80)
print("‚úÖ ENTRA√éNEMENT TERMIN√â")
print("=" * 80)

## 5. Évaluation Finale

In [None]:
print("\n" + "=" * 80)
print("üìä √âVALUATION FINALE")
print("=" * 80)


def _collect_predictions(loader, desc):
    """Run ``model`` over *loader* and return ``(predicted ids, true ids)`` as lists.

    Uses the notebook-level ``model`` and ``CONFIG``. No gradients are tracked;
    the caller is responsible for having put the model in eval mode.
    """
    preds, targets = [], []
    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc):
            images = images.to(CONFIG["device"])

            if CONFIG["use_amp"]:
                with torch.amp.autocast(device_type="cuda"):
                    outputs = model(pixel_values=images)
            else:
                outputs = model(pixel_values=images)

            predictions = torch.argmax(outputs.logits, dim=-1)
            preds.extend(predictions.cpu().numpy())
            # Labels are never moved to the device here, so no .cpu() is needed
            targets.extend(labels.numpy())
    return preds, targets


# Restore the best checkpoint written by the training loop.
# map_location lets the file load on the configured device even if it was saved
# from a different one.
# NOTE: weights_only=False unpickles arbitrary objects — only acceptable because
# this file is produced by this notebook; never point it at an untrusted file.
checkpoint = torch.load(
    CONFIG["checkpoint_dir"] / "best_model.pth",
    map_location=CONFIG["device"],
    weights_only=False,
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

print(f"\n‚úÖ Statistiques du meilleur mod√®le:")
print(f"  Epoch: {checkpoint['epoch']}")
print(f"  Val Accuracy: {checkpoint['val_acc']:.2f}%")
print(f"  Val Loss: {checkpoint.get('val_loss', 'N/A')}")

# Evaluation on the validation set
print("\n" + "="*80)
print("√âvaluation sur Validation Set")
print("="*80)
all_preds_val, all_labels_val = _collect_predictions(val_loader, "Validation")

val_acc_final = accuracy_score(all_labels_val, all_preds_val)
print(f"\n‚úì Validation Accuracy: {val_acc_final*100:.2f}%")

# Evaluation on the hold-out test set
print("\n" + "="*80)
print("√âvaluation sur Hold-out Test Set (Final Benchmark)")
print("="*80)
all_preds_test, all_labels_test = _collect_predictions(test_loader, "Hold-out Test")

test_acc_final = accuracy_score(all_labels_test, all_preds_test)
print(f"\n‚úì Hold-out Test Accuracy: {test_acc_final*100:.2f}%")

# Detailed per-class report on the hold-out set (classes shown as encoded ids)
print("\nüìã Rapport de Classification (Hold-out Test):")
print("=" * 80)
print(classification_report(all_labels_test, all_preds_test, digits=4, zero_division=0))

# Save the per-epoch training history
history_path = CONFIG["checkpoint_dir"] / "training_history.json"
with open(history_path, "w") as f:
    json.dump(history, f, indent=2)

# Save the headline results; accuracies are stored as percentages
results = {
    "best_epoch": int(checkpoint['epoch']),
    "val_acc": float(checkpoint['val_acc']),
    "final_val_acc": float(val_acc_final * 100),
    "final_test_acc": float(test_acc_final * 100),
    "num_classes": int(num_classes),
    "model": CONFIG["model_name"]
}

results_path = CONFIG["checkpoint_dir"] / "final_results.json"
with open(results_path, "w") as f:
    json.dump(results, f, indent=2)

print(f"\n‚úÖ Historique sauvegard√©: {history_path}")
print(f"‚úÖ R√©sultats finaux sauvegard√©s: {results_path}")
print(f"‚úÖ Meilleur mod√®le: {CONFIG['checkpoint_dir'] / 'best_model.pth'}")
print("\n" + "=" * 80)
print("üéâ √âVALUATION COMPL√àTE")
print("=" * 80)

## 6. Sauvegarder les Graphiques Finaux

In [None]:
# Build the final loss/accuracy figure and write it next to the checkpoints
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# (train key, val key, train legend, val legend, y-axis label, title) per panel
_panels = (
    ("train_loss", "val_loss", 'Train Loss', 'Val Loss',
     'Loss', 'Training et Validation Loss'),
    ("train_acc", "val_acc", 'Train Acc', 'Val Acc',
     'Accuracy (%)', 'Training et Validation Accuracy'),
)
for _ax, (_train_key, _val_key, _train_label, _val_label, _y_label, _title) in zip(axes, _panels):
    _ax.plot(history[_train_key], label=_train_label, marker='o', linewidth=2)
    _ax.plot(history[_val_key], label=_val_label, marker='s', linewidth=2)
    _ax.set_xlabel('Epoch', fontsize=12)
    _ax.set_ylabel(_y_label, fontsize=12)
    _ax.set_title(_title, fontsize=14, fontweight='bold')
    _ax.legend(fontsize=11)
    _ax.grid(True, alpha=0.3)

plt.tight_layout()
plot_path = CONFIG["checkpoint_dir"] / 'training_curves.png'
# Save before show() so the figure is written while it is still current
plt.savefig(plot_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"‚úì Graphiques sauvegard√©s: {plot_path}")

## 7. Résumé Final

In [None]:
_rule = "=" * 80
print(_rule)
print("üìä R√âSUM√â DE L'ENTRA√éNEMENT")
print(_rule)

# Sizes used for the percentage read-outs below
total_samples = len(df_dev) + len(df_holdout)
_n_train = len(train_dataset)
_n_val = len(val_dataset)

# --- Data ---
print(f"\nDonn√©es:")
print(f"  Total √©chantillons: {total_samples:,}")
print(f"  Training: {_n_train:,} (~{_n_train/total_samples*100:.1f}%)")
print(f"  Validation: {_n_val:,} (~{_n_val/total_samples*100:.1f}%)")
print(f"  Hold-out: {len(test_dataset):,} (15.0%)")
print(f"  Classes: {num_classes}")

# --- Model ---
print(f"\nMod√®le:")
print(f"  Architecture: {CONFIG['model_name']}")
print(f"  Param√®tres entra√Ænables: {trainable_params:,}")

# --- Training run ---
print(f"\nEntra√Ænement:")
print(f"  Epochs compl√©t√©s: {len(history['train_loss'])}")
print(f"  Batch size: {CONFIG['batch_size']}")
print(f"  Learning rate: {CONFIG['learning_rate']}")

# --- Best checkpoint, as recorded at save time ---
print(f"\nMeilleurs R√©sultats (sur Validation):")
print(f"  Best epoch: {checkpoint['epoch']}")
print(f"  Best val accuracy: {checkpoint['val_acc']:.2f}%")

# --- Accuracies re-measured by the evaluation cell ---
print(f"\nR√©sultats Finaux:")
print(f"  Final val accuracy: {val_acc_final*100:.2f}%")
print(f"  Final hold-out test accuracy: {test_acc_final*100:.2f}%")

# --- Artefacts written under the checkpoint directory ---
print(f"\nFichiers sauvegard√©s:")
_artefacts = (
    ("Checkpoint", 'best_model.pth'),
    ("Historique", 'training_history.json'),
    ("R√©sultats", 'final_results.json'),
    ("Configuration", 'config.json'),
    ("Graphiques", 'training_curves.png'),
)
for _label, _filename in _artefacts:
    print(f"  {_label}: {CONFIG['checkpoint_dir'] / _filename}")

print("\n" + _rule)
print("‚úÖ ENTRA√éNEMENT PRODUCTION ViT TERMIN√â")
print(_rule)

# Close the WandB run
wandb.finish()

In [None]:
import torch.nn as nn
import torch.optim as optim
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from transformers import get_linear_schedule_with_warmup

# Setup for the Mixup/CutMix training run ("V3").
# NOTE(review): `model` is not re-created here, so this run starts from whatever
# weights it currently holds (after the evaluation cell, the best checkpoint of
# the previous run) — confirm that continuing from those weights is intended.
# NOTE(review): the earlier summary cell calls wandb.finish(); nothing in this
# cell opens a new WandB run.

# 1. Point checkpoints at a separate directory so the previous run's files are kept
CONFIG["checkpoint_dir"] = Path("/workspace/checkpoints/vit_mixup_v3")
CONFIG["checkpoint_dir"].mkdir(parents=True, exist_ok=True)
print(f"Checkpoint directory set to: {CONFIG['checkpoint_dir']}")

# 2. Configure Mixup/CutMix
mixup_args = {
    'mixup_alpha': 0.8,       # Mixup alpha value
    'cutmix_alpha': 1.0,      # CutMix alpha value
    'cutmix_minmax': None,
    'prob': 1.0,              # Probability of applying mixup or cutmix
    'switch_prob': 0.5,       # Probability of switching to cutmix instead of mixup
    'mode': 'batch',
    'label_smoothing': 0.1,
    # Taken from the label encoder rather than hard-coded, so the soft targets
    # always match the width of the model's classifier head.
    'num_classes': num_classes
}

# Initialize Mixup
mixup_fn = Mixup(**mixup_args)
print("Mixup & CutMix initialized")

# 3. Define Loss Functions
# Training: SoftTargetCrossEntropy (Mixup yields soft, mixed label vectors)
criterion_train = SoftTargetCrossEntropy()
# Validation: standard CrossEntropy (integer class labels)
criterion_val = nn.CrossEntropyLoss()

# 4. Reset Optimizer and Scheduler
# Weight decay is read from CONFIG (0.05 there), keeping one source of truth.
optimizer = optim.AdamW(
    model.parameters(),
    lr=CONFIG["learning_rate"],
    weight_decay=CONFIG["weight_decay"],
)

# Recalculate steps: the scheduler is stepped once per batch
num_training_steps = len(train_loader) * CONFIG["num_epochs"]
num_warmup_steps = int(num_training_steps * CONFIG["warmup_ratio"])

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)
print(f"Optimizer and Scheduler reset (Weight Decay = {CONFIG['weight_decay']})")

In [None]:
from tqdm import tqdm  # Using standard text-based progress bar to avoid IProgress errors
import matplotlib.pyplot as plt
from IPython.display import clear_output
import torch

_rule = "=" * 80
print("\n" + _rule)
print(" STARTING TRAINING (V3 - MIXUP/CUTMIX)")
print(_rule)

# Fresh tracking state for this run: best accuracy, early-stopping counter and
# one empty per-epoch list for each tracked metric.
best_val_acc = 0.0
patience_counter = 0
history = {
    metric: []
    for metric in ("train_loss", "train_acc", "val_loss", "val_acc")
}

def update_plots_v3():
    """Redraw the loss and accuracy curves from the module-level ``history``.

    Clears the current cell output, draws a two-panel figure (loss on the
    left, accuracy on the right) and prints the module-level
    ``best_val_acc`` underneath. Takes no arguments and returns nothing.
    """
    clear_output(wait=True)
    _, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 4))

    # Left panel: training vs. validation loss
    loss_ax.plot(history["train_loss"], label='Train Loss', marker='o')
    loss_ax.plot(history["val_loss"], label='Val Loss', marker='s')

    # Right panel: accuracy. The train curve is dashed and faded because
    # train accuracy is not fully representative under Mixup.
    acc_ax.plot(history["train_acc"], label='Train Acc (Mixup)', marker='o', linestyle='--', alpha=0.5)
    acc_ax.plot(history["val_acc"], label='Val Acc', marker='s', linewidth=2)

    # Shared cosmetics for both panels
    panels = (
        (loss_ax, 'Loss', 'Loss'),
        (acc_ax, 'Accuracy', 'Accuracy (%)'),
    )
    for panel, title, y_label in panels:
        panel.set_title(title)
        panel.set_xlabel('Epoch')
        panel.set_ylabel(y_label)
        panel.legend()
        panel.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
    print(f" Current Best Val Acc: {best_val_acc:.2f}%")

# ============================================================================
# Main Training Loop
# ============================================================================
# Each epoch: train with Mixup/CutMix, evaluate on the validation loader,
# record metrics in `history`, checkpoint the best model, redraw the plots
# and stop early once validation accuracy stops improving.

# Consecutive epochs without a new best validation accuracy that are
# tolerated before training is stopped.
EARLY_STOPPING_PATIENCE = 3

for epoch in range(CONFIG["num_epochs"]):
    print(f"\nEpoch {epoch + 1}/{CONFIG['num_epochs']}")

    # --- 1. Training Phase ---
    model.train()
    train_loss = 0.0
    train_steps = 0  # batches that actually contributed to train_loss

    # Use ascii=True for compatibility with all terminals
    train_pbar = tqdm(train_loader, desc="Training", ascii=True)

    for images, labels in train_pbar:
        images, labels = images.to(CONFIG["device"]), labels.to(CONFIG["device"])

        # Apply Mixup / CutMix
        if mixup_fn is not None:
            # timm's Mixup asserts an even batch size, so a trailing odd
            # batch (loader without drop_last) would abort the run. Drop one
            # sample instead; a batch left empty by that is skipped.
            if images.size(0) % 2 == 1:
                images, labels = images[:-1], labels[:-1]
            if images.size(0) == 0:
                continue
            images, labels = mixup_fn(images, labels)

        optimizer.zero_grad()

        # Forward pass & loss calculation
        if CONFIG["use_amp"]:
            # Mixed precision context
            with torch.cuda.amp.autocast():
                # Labels are not passed to the model: the Mixup loss is
                # computed manually from the logits.
                outputs = model(pixel_values=images)
                loss = criterion_train(outputs.logits, labels)

            # Backward pass. `scaler` is assumed to be the GradScaler created
            # earlier in the notebook. Gradients are unscaled before clipping
            # so the 1.0 threshold applies to their true magnitude.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scale_before = scaler.get_scale()
            scaler.step(optimizer)
            scaler.update()
            # The scaler lowers its scale only when it skipped the optimizer
            # step because of inf/NaN gradients.
            optimizer_stepped = scaler.get_scale() >= scale_before
        else:
            # Standard precision (FP32)
            outputs = model(pixel_values=images)
            loss = criterion_train(outputs.logits, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer_stepped = True

        # Advance the LR schedule only after a real optimizer step.
        if optimizer_stepped:
            scheduler.step()

        batch_loss = loss.item()  # read once; reused for the total and the bar
        train_loss += batch_loss
        train_steps += 1

        # Update progress bar
        train_pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

    avg_train_loss = train_loss / max(train_steps, 1)

    # --- 2. Validation Phase ---
    model.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0

    # Validation does not use Mixup and does not need gradients
    with torch.no_grad():
        val_pbar = tqdm(val_loader, desc="Validation", ascii=True)
        for images, labels in val_pbar:
            images, labels = images.to(CONFIG["device"]), labels.to(CONFIG["device"])

            if CONFIG["use_amp"]:
                with torch.cuda.amp.autocast():
                    outputs = model(pixel_values=images)
                    # Validation uses standard CrossEntropyLoss
                    loss = criterion_val(outputs.logits, labels)
            else:
                outputs = model(pixel_values=images)
                loss = criterion_val(outputs.logits, labels)

            val_loss += loss.item()
            # Calculate accuracy
            preds = torch.argmax(outputs.logits, dim=-1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    # Guard the divisions so an empty validation loader reports 0 instead of
    # raising ZeroDivisionError.
    val_acc = 100.0 * val_correct / val_total if val_total else 0.0
    avg_val_loss = val_loss / max(len(val_loader), 1)

    # --- 3. Logging & Saving ---
    # We record 0 for train_acc as it is not meaningful under Mixup
    history["train_loss"].append(avg_train_loss)
    history["train_acc"].append(0)
    history["val_loss"].append(avg_val_loss)
    history["val_acc"].append(val_acc)

    # Save best model (strict improvement only; ties count as no improvement)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save({
            'model_state_dict': model.state_dict(),
            'val_acc': val_acc,
            'epoch': epoch + 1,
            'optimizer_state_dict': optimizer.state_dict()
        }, CONFIG["checkpoint_dir"] / "best_model.pth")
        print(f"✅ New best model saved! Val Acc: {val_acc:.2f}%")
    else:
        patience_counter += 1
        print(f"⏳ No improvement ({patience_counter}/{EARLY_STOPPING_PATIENCE}). Best: {best_val_acc:.2f}%")

    # Update plots
    update_plots_v3()

    # Early stopping check
    if patience_counter >= EARLY_STOPPING_PATIENCE:
        print("\n⚠️ Early Stopping Triggered.")
        break

print("\n" + "=" * 80)
print("✅ TRAINING COMPLETE")
print("=" * 80)