# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models
import time
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import struct
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# ==================== DEVICE SETUP ====================
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# ==================== CUSTOM MNIST LOADER ====================
def load_mnist_images(filepath):
    """Load MNIST images from an uncompressed IDX3 ``.ubyte`` file.

    The file starts with a 16-byte big-endian header (magic number, image
    count, rows, columns) followed by one unsigned byte per pixel.

    Args:
        filepath: Path to the IDX3 image file.

    Returns:
        A ``uint8`` NumPy array of shape ``(num_images, rows, cols)``.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        ValueError: If the header is truncated, the magic number is not the
            IDX3 image magic (2051), or the pixel payload size does not match
            the dimensions announced in the header.
    """
    with open(filepath, 'rb') as f:
        header = f.read(16)
        if len(header) != 16:
            raise ValueError(
                f"{filepath}: truncated IDX3 header ({len(header)} of 16 bytes)"
            )
        magic, num_images, rows, cols = struct.unpack('>4I', header)
        # 2051 (0x00000803) identifies an unsigned-byte, 3-dimensional IDX file.
        if magic != 2051:
            raise ValueError(
                f"{filepath}: unexpected magic number {magic}, expected 2051 (IDX3 images)"
            )
        payload = f.read()

    expected_size = num_images * rows * cols
    if len(payload) != expected_size:
        raise ValueError(
            f"{filepath}: expected {expected_size} pixel bytes, found {len(payload)}"
        )
    return np.frombuffer(payload, dtype=np.uint8).reshape(num_images, rows, cols)

def load_mnist_labels(filepath):
    """Load MNIST labels from an uncompressed IDX1 ``.ubyte`` file.

    The file starts with an 8-byte big-endian header (magic number, label
    count) followed by one unsigned byte per label.

    Args:
        filepath: Path to the IDX1 label file.

    Returns:
        A one-dimensional ``uint8`` NumPy array with one entry per label.

    Raises:
        FileNotFoundError: If ``filepath`` does not exist.
        ValueError: If the header is truncated, the magic number is not the
            IDX1 label magic (2049), or the number of label bytes differs
            from the count announced in the header.
    """
    with open(filepath, 'rb') as f:
        header = f.read(8)
        if len(header) != 8:
            raise ValueError(
                f"{filepath}: truncated IDX1 header ({len(header)} of 8 bytes)"
            )
        magic, num_labels = struct.unpack('>2I', header)
        # 2049 (0x00000801) identifies an unsigned-byte, 1-dimensional IDX file.
        if magic != 2049:
            raise ValueError(
                f"{filepath}: unexpected magic number {magic}, expected 2049 (IDX1 labels)"
            )
        payload = f.read()

    if len(payload) != num_labels:
        raise ValueError(
            f"{filepath}: header announces {num_labels} labels, found {len(payload)}"
        )
    return np.frombuffer(payload, dtype=np.uint8)

class MNISTDataset(Dataset):
    """In-memory MNIST dataset built from raw image and label arrays.

    At construction the images are converted to ``float32``, given a channel
    axis at position 1 and divided by 255 (which maps ``uint8`` pixel values
    to ``[0, 1]``). Labels are converted to ``int64``.

    Args:
        images: Array-like of shape ``(N, H, W)``.
        labels: Array-like of ``N`` integer class labels.
        transform: Optional callable applied to each ``(1, H, W)`` image
            tensor when it is fetched.

    Raises:
        ValueError: If ``images`` and ``labels`` do not contain the same
            number of samples.
    """
    def __init__(self, images, labels, transform=None):
        # torch.tensor copies the data, so the dataset never aliases the
        # caller's arrays.
        self.images = torch.tensor(np.asarray(images), dtype=torch.float32).unsqueeze(1) / 255.0
        self.labels = torch.tensor(np.asarray(labels), dtype=torch.long)
        if self.images.shape[0] != self.labels.shape[0]:
            raise ValueError(
                f"images and labels disagree on sample count: "
                f"{self.images.shape[0]} vs {self.labels.shape[0]}"
            )
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair at ``idx``, transformed if configured."""
        img = self.images[idx]
        label = self.labels[idx]
        # Compare against None explicitly: a transform object that defines
        # __len__ (e.g. an empty container) would otherwise be skipped.
        if self.transform is not None:
            img = self.transform(img)
        return img, label

# ==================== DATA LOADING ====================
print("\nChargement du dataset MNIST...")

# Normalization applied to image tensors that are already scaled to [0, 1].
# NOTE(review): (0.1307, 0.3081) are presumably the usual MNIST mean/std.
transform = transforms.Compose([
    transforms.Normalize((0.1307,), (0.3081,))
])

# Directories that may hold the raw IDX files, tried in this order.
kaggle_path = '/kaggle/input/mnist-dataset/mnist'
local_path = './data'

for data_dir, loaded_message in (
    (kaggle_path, "✓ Dataset chargé depuis Kaggle"),
    (local_path, "✓ Dataset chargé depuis ./data"),
):
    try:
        train_images = load_mnist_images(f'{data_dir}/train-images-idx3-ubyte')
        train_labels = load_mnist_labels(f'{data_dir}/train-labels-idx1-ubyte')
        test_images = load_mnist_images(f'{data_dir}/t10k-images-idx3-ubyte')
        test_labels = load_mnist_labels(f'{data_dir}/t10k-labels-idx1-ubyte')
    except FileNotFoundError:
        # Any missing file means this directory is unusable; try the next one.
        continue

    train_dataset = MNISTDataset(train_images, train_labels, transform=transform)
    test_dataset = MNISTDataset(test_images, test_labels, transform=transform)
    print(loaded_message)
    break
else:
    # No raw files were found anywhere: download through torchvision.
    print("Téléchargement depuis torchvision...")
    # ToTensor converts the PIL images to float tensors; chaining the same
    # Normalize keeps preprocessing identical to the raw-file code paths.
    torchvision_transform = transforms.Compose([transforms.ToTensor(), transform])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=torchvision_transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=torchvision_transform)
    print("✓ Dataset téléchargé")

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# ==================== PART 2: VISION TRANSFORMER FROM SCRATCH ====================
print("\n" + "="*70)
print("PART 2: VISION TRANSFORMER (ViT) FROM SCRATCH")
print("="*70)

class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each patch.

    A convolution whose kernel size and stride both equal ``patch_size``
    produces one ``embed_dim``-dimensional vector per patch.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        embed_dim: Size of each patch embedding.
    """
    def __init__(self, img_size=28, patch_size=4, in_channels=1, embed_dim=256):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2
        self.proj = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Map ``(B, C, H, W)`` images to ``(B, num_patches, embed_dim)`` tokens."""
        # Convolution output: (B, embed_dim, H / patch_size, W / patch_size).
        patch_grid = self.proj(x)
        # Flatten the spatial grid into a sequence and move channels last.
        return rearrange(patch_grid, 'b e h w -> b (h w) e')

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        embed_dim: Token embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """
    def __init__(self, embed_dim=256, num_heads=8, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        # Fail early: with a non-divisible embed_dim the reshape in forward()
        # could never succeed and would raise an opaque shape error.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        # 1 / sqrt(head_dim) scaling of the dot products.
        self.scale = self.head_dim ** -0.5

        # One projection produces queries, keys and values together.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape ``(B, N, C)``; output has the same shape."""
        B, N, C = x.shape
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        # Attention weights over the N tokens, per head: (B, heads, N, N).
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = self.attn_drop(attn.softmax(dim=-1))

        # Weighted sum of values, then merge the heads back into C channels.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(x))

class TransformerBlock(nn.Module):
    """Transformer encoder block: self-attention followed by an MLP.

    Each sub-layer is preceded by LayerNorm and wrapped in a residual
    connection. The same ``dropout`` probability is used for the attention
    weights, the attention output projection and both MLP dropouts.

    Args:
        embed_dim: Token embedding size.
        num_heads: Number of attention heads.
        mlp_dim: Hidden size of the MLP.
        dropout: Dropout probability.
    """
    def __init__(self, embed_dim=256, num_heads=8, mlp_dim=512, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            attn_drop=dropout,
            proj_drop=dropout,
        )
        self.norm2 = nn.LayerNorm(embed_dim)

        feed_forward_layers = [
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout),
        ]
        self.mlp = nn.Sequential(*feed_forward_layers)

    def forward(self, x):
        """Apply the attention and MLP residual branches to ``x``."""
        after_attention = x + self.attn(self.norm1(x))
        return after_attention + self.mlp(self.norm2(after_attention))

class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    The image is split into patches, a learnable class token is prepended,
    learnable position embeddings are added, the sequence goes through
    ``depth`` Transformer blocks, and the normalized class token is fed to a
    linear classification head.

    Args:
        img_size: Side length of the (square) input image.
        patch_size: Side length of each square patch.
        num_classes: Number of output classes.
        embed_dim: Token embedding size.
        num_heads: Number of attention heads per block.
        depth: Number of Transformer blocks.
        mlp_dim: Hidden size of each block's MLP.
        dropout: Dropout probability used throughout the model.
        in_channels: Number of input image channels (defaults to 1, the
            value that was previously hard-coded).
    """
    def __init__(self, img_size=28, patch_size=4, num_classes=10,
                 embed_dim=256, num_heads=8, depth=12, mlp_dim=512, dropout=0.1,
                 in_channels=1):
        super().__init__()

        # Patch embedding
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Class token and position embedding (one extra position for the class token)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(dropout)

        # Transformer blocks
        self.transformer = nn.Sequential(*[
            TransformerBlock(embed_dim, num_heads, mlp_dim, dropout)
            for _ in range(depth)
        ])

        # Classification head
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # Replace the zero placeholders with small truncated-normal values
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)`` for images ``(B, C, H, W)``."""
        B = x.shape[0]

        # Patch embedding
        x = self.patch_embed(x)  # (B, num_patches, embed_dim)

        # Prepend one copy of the class token per sample
        cls_token = repeat(self.cls_token, '1 1 d -> b 1 d', b=B)
        x = torch.cat([cls_token, x], dim=1)  # (B, num_patches+1, embed_dim)

        # Add position embedding (broadcast over the batch)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        # Transformer blocks
        x = self.transformer(x)

        # Classification: normalize, then read out the class token (position 0)
        x = self.norm(x)
        x = x[:, 0]
        x = self.head(x)

        return x

def _run_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    The model is trained when ``optimizer`` is given and evaluated (eval
    mode, gradients disabled) otherwise. The loss is averaged per sample, so
    a smaller final batch does not skew the epoch average.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(is_training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            if is_training:
                optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)

            if is_training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # criterion returns the batch mean; weight it by the batch size.
            loss_sum += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("data loader yielded no samples")
    return loss_sum / total, correct / total


def train_vit(model, train_loader, test_loader, num_epochs=10):
    """Train the Vision Transformer and evaluate it after every epoch.

    Uses cross-entropy loss, Adam (lr=0.001, weight_decay=0.0001) and a
    cosine-annealing learning-rate schedule stepped once per epoch. A summary
    line is printed after each epoch.

    Args:
        model: Model to train; it is moved to the module-level ``device``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        test_loader: Iterable of ``(images, labels)`` evaluation batches.
        num_epochs: Number of epochs to run.

    Returns:
        ``(model, train_losses, test_losses, train_accs, test_accs,
        training_time)``. The four lists hold one value per epoch;
        accuracies are fractions in ``[0, 1]``. ``training_time`` is
        wall-clock seconds and includes the per-epoch evaluation passes.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    start_time = time.time()
    train_losses, test_losses, train_accs, test_accs = [], [], [], []

    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)
        test_loss, test_acc = _run_epoch(model, test_loader, criterion)

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)

        scheduler.step()

        print(f"Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_loss:.4f}, Train Acc: {train_acc*100:.2f}%, "
              f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc*100:.2f}%")

    training_time = time.time() - start_time
    return model, train_losses, test_losses, train_accs, test_accs, training_time

def evaluate_model(model, test_loader):
    """Compute accuracy and weighted F1 of ``model`` over ``test_loader``.

    Only the input batches are moved to the module-level ``device``; the
    model itself is not moved here. Labels are converted with ``.numpy()``
    directly, so they must be CPU tensors.

    Returns:
        ``(accuracy, f1, all_preds, all_labels)`` where the last two are
        flat lists of predicted and true class indices in loader order.
    """
    model.eval()
    predictions = []
    targets = []

    with torch.no_grad():
        for batch_images, batch_labels in test_loader:
            logits = model(batch_images.to(device))
            # Index of the largest logit along the class dimension.
            batch_preds = logits.max(dim=1).indices
            predictions.extend(batch_preds.cpu().numpy())
            targets.extend(batch_labels.numpy())

    accuracy = accuracy_score(targets, predictions)
    weighted_f1 = f1_score(targets, predictions, average='weighted')
    return accuracy, weighted_f1, predictions, targets

# Train ViT
print("\n--- Vision Transformer Training ---")
vit_model = VisionTransformer(
    img_size=28, patch_size=4, num_classes=10,
    embed_dim=256, num_heads=8, depth=12, mlp_dim=512,
    dropout=0.1,
)

(
    vit_model,
    vit_train_loss,
    vit_test_loss,
    vit_train_acc,
    vit_test_acc,
    vit_time,
) = train_vit(vit_model, train_loader, test_loader, num_epochs=10)

# Final metrics on the test set, computed separately from the per-epoch history.
vit_final_acc, vit_f1, vit_preds, vit_labels = evaluate_model(vit_model, test_loader)

print(
    "\nVision Transformer Final Results:",
    f"Accuracy: {vit_final_acc:.4f}",
    f"F1 Score: {vit_f1:.4f}",
    f"Training Time: {vit_time:.2f}s",
    f"Final Test Loss: {vit_test_loss[-1]:.4f}",
    sep="\n",
)

# ==================== PART 3: COMPARISON ALL MODELS ====================
print(
    "\n" + "=" * 70,
    "PART 3: COMPREHENSIVE COMPARISON - ALL MODELS",
    "=" * 70,
    sep="\n",
)

# Part 1 results (they can be loaded from disk or entered manually).
# For this demo, simple Part 1 models are re-created and trained below.

# Simple CNN
class CNN(nn.Module):
    """Small convolutional classifier: two conv/ReLU/max-pool stages and an MLP head.

    The first linear layer expects ``64 * 7 * 7`` features, i.e. 28x28
    inputs (each of the two max-pools halves the spatial size).

    Args:
        in_channels: Number of input image channels.
        num_classes: Number of output classes.
    """
    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            # Honor in_channels instead of hard-coding a single channel.
            nn.Conv2d(in_channels, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)``."""
        x = self.features(x)
        x = self.classifier(x)
        return x

def quick_train(model, train_loader, test_loader, epochs=5, lr=0.001):
    """Train ``model`` briefly and score it on the test set, for comparison.

    Trains with cross-entropy loss and Adam for ``epochs`` epochs, then
    evaluates once with ``evaluate_model``.

    Args:
        model: Model to train; it is moved to the module-level ``device``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        test_loader: Iterable of ``(images, labels)`` evaluation batches.
        epochs: Number of training epochs.
        lr: Adam learning rate.

    Returns:
        ``(accuracy, f1, training_time)``; ``training_time`` is wall-clock
        seconds spent in the training loop only (the final evaluation is not
        timed).
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    start_time = time.time()

    model.train()
    for _ in range(epochs):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

    training_time = time.time() - start_time

    # Reuse the shared evaluation routine rather than duplicating its loop.
    accuracy, f1, _, _ = evaluate_model(model, test_loader)

    return accuracy, f1, training_time

print("\nEntraînement des modèles de Part 1 pour comparaison...")
cnn_model = CNN()
cnn_acc, cnn_f1, cnn_time = quick_train(cnn_model, train_loader, test_loader, epochs=5)
print(f"CNN - Accuracy: {cnn_acc:.4f}, F1: {cnn_f1:.4f}, Time: {cnn_time:.2f}s")

# ==================== FINAL COMPARISON TABLE ====================
print("\n" + "="*70)
print("COMPARISON RESULTS")
print("="*70)

# Metrics per model; 'Parameters' counts every element of every parameter tensor.
comparison_results = {
    'CNN': {
        'Accuracy': cnn_acc,
        'F1 Score': cnn_f1,
        'Training Time': cnn_time,
        'Parameters': sum(p.numel() for p in cnn_model.parameters())
    },
    'Vision Transformer': {
        'Accuracy': vit_final_acc,
        'F1 Score': vit_f1,
        'Training Time': vit_time,
        'Parameters': sum(p.numel() for p in vit_model.parameters())
    }
}

print(f"\n{'Model':<20} {'Accuracy':<15} {'F1 Score':<15} {'Time(s)':<15} {'Parameters':<15}")
print("-" * 85)
for model_name, metrics in comparison_results.items():
    # Pad every cell to the header's column widths so rows stay aligned
    # whatever the magnitude of the values.
    print(f"{model_name:<20} {metrics['Accuracy']:<15.4f} {metrics['F1 Score']:<15.4f} "
          f"{metrics['Training Time']:<15.2f} {metrics['Parameters']:<15}")

# ==================== INTERPRETATION & ANALYSIS ====================
print("\n" + "="*70)
print("INTERPRETATION & ANALYSIS")
print("="*70)

print("\n1. ACCURACY COMPARISON:")
acc_diff = vit_final_acc - cnn_acc
# Report a tie explicitly instead of crediting the CNN when accuracies match.
if acc_diff > 0:
    acc_verdict = '(ViT Better)'
elif acc_diff < 0:
    acc_verdict = '(CNN Better)'
else:
    acc_verdict = '(Tie)'
print(f"   Vision Transformer: {vit_final_acc:.4f}")
print(f"   CNN: {cnn_acc:.4f}")
print(f"   Difference: {acc_diff:+.4f} {acc_verdict}")

print("\n2. COMPUTATIONAL EFFICIENCY:")
# Positive when the ViT took longer to train than the CNN.
time_diff = vit_time - cnn_time
print(f"   Vision Transformer: {vit_time:.2f}s")
print(f"   CNN: {cnn_time:.2f}s")
print(f"   Time Difference: {time_diff:+.2f}s")

print("\n3. MODEL COMPLEXITY:")
vit_params = sum(p.numel() for p in vit_model.parameters())
cnn_params = sum(p.numel() for p in cnn_model.parameters())
print(f"   Vision Transformer: {vit_params:,} parameters")
print(f"   CNN: {cnn_params:,} parameters")
print(f"   Ratio: {vit_params/cnn_params:.2f}x")

# Fixed commentary; these lines do not depend on the measured results above.
print("\n4. KEY INSIGHTS:")
print("   • Vision Transformers capture global dependencies via self-attention")
print("   • CNNs are more efficient for small images like MNIST (28x28)")
print("   • ViT requires more data and computation, benefits more from large datasets")
print("   • For MNIST: CNN likely performs better due to task simplicity")
print("   • ViT architecture is more versatile for complex vision tasks")

# ==================== VISUALIZATION ====================
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# ViT loss per epoch (train vs test)
axes[0, 0].plot(range(1, len(vit_train_loss)+1), vit_train_loss, 'b-o', label='ViT Train', linewidth=2)
axes[0, 0].plot(range(1, len(vit_test_loss)+1), vit_test_loss, 'b--s', label='ViT Test', linewidth=2)
axes[0, 0].set_title('Vision Transformer - Loss', fontsize=12, fontweight='bold')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# ViT accuracy per epoch, converted from fractions to percentages
axes[0, 1].plot(range(1, len(vit_train_acc)+1), [a*100 for a in vit_train_acc], 'g-o', label='ViT Train', linewidth=2)
axes[0, 1].plot(range(1, len(vit_test_acc)+1), [a*100 for a in vit_test_acc], 'g--s', label='ViT Test', linewidth=2)
axes[0, 1].set_title('Vision Transformer - Accuracy', fontsize=12, fontweight='bold')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Accuracy (%)')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Final accuracy per model.
# Named model_names (not `models`) so the torchvision `models` import is not shadowed.
model_names = list(comparison_results.keys())
accuracies = [comparison_results[m]['Accuracy'] for m in model_names]
colors = ['#1f77b4', '#ff7f0e']
axes[1, 0].bar(model_names, accuracies, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
axes[1, 0].set_title('Accuracy Comparison', fontsize=12, fontweight='bold')
axes[1, 0].set_ylabel('Accuracy')
axes[1, 0].set_ylim([0, 1])
for i, v in enumerate(accuracies):
    # Value label slightly above each bar
    axes[1, 0].text(i, v + 0.02, f'{v:.4f}', ha='center', fontweight='bold')
axes[1, 0].grid(True, alpha=0.3, axis='y')

# Training time per model
times = [comparison_results[m]['Training Time'] for m in model_names]
axes[1, 1].bar(model_names, times, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
axes[1, 1].set_title('Training Time Comparison', fontsize=12, fontweight='bold')
axes[1, 1].set_ylabel('Time (seconds)')
for i, v in enumerate(times):
    # Offset the label by 2% of the tallest bar so it clears the bar top
    axes[1, 1].text(i, v + max(times)*0.02, f'{v:.2f}s', ha='center', fontweight='bold')
axes[1, 1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
# Save before show(): some backends clear the figure once it has been shown.
plt.savefig('vit_comparison.png', dpi=100, bbox_inches='tight')
plt.show()

print("\n✓ Vision Transformer Analysis Complete!")
print("✓ Comparison visualization saved as 'vit_comparison.png'")