In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import time
import matplotlib.cm as cm

# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Select the compute device (GPU when CUDA is available)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Preprocessing pipelines: both resize to 224x224, convert to tensors and apply
# the same per-channel normalisation; only the training pipeline adds a random
# horizontal flip.
_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
])

test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    _normalize,
])

# Load the CIFAR-10 train and test splits (downloaded into ./data if missing)
trainset, testset = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=pipeline)
    for is_train, pipeline in ((True, train_transforms), (False, test_transforms))
)

# Only the training batches are shuffled
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=False)

class_names = trainset.classes

# Modules d'attention
class SEBlock(nn.Module):
    """Channel attention: rescale each channel by a gate learned from its global average.

    The input ``(batch, channels, H, W)`` is average-pooled to one value per
    channel, passed through a two-layer bottleneck MLP ending in a sigmoid, and
    the resulting per-channel gate multiplies the input.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Divisor applied to ``channels`` to size the bottleneck.
            The bottleneck width is ``channels // reduction`` but never less
            than 1.
    """

    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        # Clamp to 1: with channels < reduction the integer division yields 0,
        # which would create weight-less Linear layers and a constant 0.5 gate.
        hidden = max(1, channels // reduction)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` with every channel multiplied by its learned gate."""
        b, c, _, _ = x.size()
        # (b, c, 1, 1) -> (b, c) so the Linear layers see one vector per sample
        y = self.global_avg_pool(x).view(b, c)
        # Back to (b, c, 1, 1); broadcasting applies the gate over H and W
        y = self.fc(y).view(b, c, 1, 1)
        return x * y

class CBAMBlock(nn.Module):
    """Channel attention followed by spatial attention.

    The channel stage is delegated to ``SEBlock``. The spatial stage builds a
    two-channel map (per-pixel max and mean over channels), runs it through a
    7x7 convolution plus sigmoid, and multiplies the result into the features.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Reduction factor forwarded to ``SEBlock``.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        self.se_block = SEBlock(channels, reduction)
        # padding=3 keeps the spatial size unchanged for the 7x7 kernel
        self.spatial_att = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Apply the channel gate, then the spatial mask, to ``x``."""
        refined = self.se_block(x)
        channel_max, _ = refined.max(dim=1, keepdim=True)
        channel_mean = refined.mean(dim=1, keepdim=True)
        # Order matters: max map first, mean map second
        descriptor = torch.cat((channel_max, channel_mean), dim=1)
        mask = self.spatial_att(descriptor)
        return refined * mask.expand_as(refined)

# Modification des modèles
def modify_model(base_model, model_name, use_senet=False, use_cbam=False):
    """Rebuild a backbone with optional attention blocks and a 10-class head.

    Every layer of ``base_model.features`` is kept in order. After each
    ``nn.Conv2d`` an ``SEBlock`` and/or a ``CBAMBlock`` sized to that
    convolution's output channels is inserted, depending on the flags. The
    result is followed by global average pooling, flattening and a linear
    classifier, and is moved to the module-level ``device``.

    Args:
        base_model: Model exposing a ``features`` container of layers.
        model_name: Architecture identifier; only ``"VGG16"`` is accepted.
        use_senet: Insert an ``SEBlock`` after every convolution.
        use_cbam: Insert a ``CBAMBlock`` after every convolution.

    Returns:
        An ``nn.Sequential`` model on ``device``.

    Raises:
        ValueError: If ``model_name`` is not ``"VGG16"``.
    """
    if model_name != "VGG16":
        raise ValueError("Modèle non supporté")

    channels = 512  # channel count of VGG16's last layer, feeds the classifier

    backbone = []
    for layer in base_model.features.children():
        backbone.append(layer)
        if not isinstance(layer, nn.Conv2d):
            continue
        if use_senet:
            backbone.append(SEBlock(layer.out_channels))
        if use_cbam:
            backbone.append(CBAMBlock(layer.out_channels))

    head = [
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(channels, 10),  # 10 classes for CIFAR-10
    ]
    return nn.Sequential(nn.Sequential(*backbone), *head).to(device)

# Build the model variants to compare; each one gets its own VGG16 instance
# loaded with torchvision's default weights.
models_dict = {
    variant: modify_model(models.vgg16(weights=models.VGG16_Weights.DEFAULT), "VGG16", **flags)
    for variant, flags in (
        ("VGG16", {}),
        ("VGG16+SENet", {"use_senet": True}),
        ("VGG16+CBAM", {"use_cbam": True}),
        ("VGG16+CBAM+SENet", {"use_senet": True, "use_cbam": True}),
    )
}

# Fonction d'entraînement avec validation et suivi des métriques
def train_model_with_validation(model, trainloader, testloader, criterion, optimizer, scheduler, epochs=30, patience=5):
    """Train ``model`` with per-epoch validation, LR scheduling and early stopping.

    Each epoch runs one optimisation pass over ``trainloader`` followed by a
    gradient-free pass over ``testloader``. The validation loss drives both
    ``scheduler.step`` and the early-stopping counter. The weights of the epoch
    with the lowest validation loss are remembered and loaded back into
    ``model`` before returning, so the returned predictions describe the model
    the caller ends up with rather than a later, worse epoch.

    Args:
        model: Network to optimise in place. Batches are moved to the device of
            its first parameter (CPU if it has no parameters).
        trainloader: Sized iterable of ``(images, labels)`` training batches.
        testloader: Sized iterable of ``(images, labels)`` validation batches.
        criterion: Callable ``criterion(outputs, labels)`` returning a scalar loss.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        scheduler: Object whose ``step(val_loss)`` is called after every epoch
            that does not trigger early stopping.
        epochs: Maximum number of epochs to run.
        patience: Number of consecutive epochs without a new best validation
            loss after which training stops.

    Returns:
        Tuple ``(train_losses, val_losses, train_accuracies, val_accuracies,
        all_preds, all_labels)``. The first four are per-epoch lists (losses are
        means over batches, accuracies are fractions of correct samples).
        ``all_preds`` / ``all_labels`` are the validation predictions and
        targets of the best epoch; if no epoch ever improved on the initial
        infinite loss they are those of the last epoch, and if no epoch ran
        they are empty.
    """
    # Follow the model's own device instead of depending on a notebook global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []

    best_val_loss = float('inf')
    best_state = None  # CPU copy of the best weights seen so far
    best_preds, best_labels = [], []
    # Bound up front so that epochs=0 returns empty lists instead of raising.
    all_preds, all_labels = [], []
    trigger_times = 0  # early-stopping counter

    for epoch in range(epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        correct_train = 0
        total_train = 0

        for images, labels in trainloader:
            images, labels = images.to(run_device), labels.to(run_device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.detach().argmax(dim=1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        train_loss = running_loss / len(trainloader)
        train_accuracy = correct_train / total_train
        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)

        # Validation phase
        model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0
        # Fresh lists every epoch, so a saved reference to an earlier epoch's
        # predictions is never mutated.
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(run_device), labels.to(run_device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                predicted = outputs.argmax(dim=1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()
                all_preds.extend(predicted.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())

        val_loss /= len(testloader)
        val_accuracy = correct_val / total_val
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        # Report the epoch's metrics
        print(f"Epoch [{epoch+1}/{epochs}]")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")
        print("-" * 50)

        # Early stopping, keeping a snapshot of the best epoch
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            trigger_times = 0
            # copy=True forces a real copy even when the model is already on CPU
            best_state = {
                key: tensor.detach().to("cpu", copy=True)
                for key, tensor in model.state_dict().items()
            }
            best_preds, best_labels = all_preds, all_labels
        else:
            trigger_times += 1
            if trigger_times >= patience:
                print("Early stopping triggered!")
                break

        # Let the scheduler react to the validation loss
        scheduler.step(val_loss)

    if best_state is not None:
        # Hand back the best checkpoint together with its own predictions.
        model.load_state_dict(best_state)
        all_preds, all_labels = best_preds, best_labels

    return train_losses, val_losses, train_accuracies, val_accuracies, all_preds, all_labels

# Fonction pour afficher les métriques de classification
def print_classification_metrics(all_labels, all_preds, class_names):
    """Print summary classification scores and show the confusion matrix.

    Accuracy is printed first, followed by macro-averaged precision, recall and
    F1. The confusion matrix is then drawn as an annotated heatmap with
    ``class_names`` on both axes.

    Args:
        all_labels: Ground-truth class labels.
        all_preds: Predicted class labels, aligned with ``all_labels``.
        class_names: Tick labels for the heatmap axes.
    """
    print("\nClassification Metrics:")
    print(f"Accuracy: {accuracy_score(all_labels, all_preds):.4f}")
    macro_scorers = (
        ("Precision", precision_score),
        ("Recall", recall_score),
        ("F1 Score", f1_score),
    )
    for title, scorer in macro_scorers:
        print(f"{title}: {scorer(all_labels, all_preds, average='macro'):.4f}")

    # Confusion matrix (named so it does not shadow the matplotlib.cm alias)
    conf_matrix = confusion_matrix(all_labels, all_preds)
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title('Confusion Matrix')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    plt.show()

# Fonction pour tracer les courbes d'apprentissage
def plot_learning_curves(train_losses, val_losses, train_accuracies, val_accuracies, model_name):
    """Show loss and accuracy curves side by side for one model.

    Args:
        train_losses: Per-epoch training losses.
        val_losses: Per-epoch validation losses.
        train_accuracies: Per-epoch training accuracies.
        val_accuracies: Per-epoch validation accuracies.
        model_name: Name used as the prefix of both panel titles.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # (axes, train series, train label, validation series, validation label, quantity)
    panels = (
        (loss_ax, train_losses, 'Train Loss', val_losses, 'Validation Loss', 'Loss'),
        (acc_ax, train_accuracies, 'Train Accuracy', val_accuracies, 'Validation Accuracy', 'Accuracy'),
    )
    for ax, train_series, train_label, val_series, val_label, quantity in panels:
        ax.plot(train_series, label=train_label)
        ax.plot(val_series, label=val_label)
        ax.set_title(f'{model_name} - {quantity} Curves')
        ax.set_xlabel('Epochs')
        ax.set_ylabel(quantity)
        ax.legend()

    fig.tight_layout()
    plt.show()

# Train and evaluate every model variant in turn
criterion = nn.CrossEntropyLoss()
epochs = 30
patience = 5
results = {}

for name, model in models_dict.items():
    print(f"\nEntraînement et évaluation de {name}...")

    # Each model gets its own Adam optimizer (small learning rate, weight decay)
    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)

    # Halve the learning rate once the validation loss has stalled for 3 epochs
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    # Train with validation; the unpacked names stay available at module level
    (train_losses, val_losses,
     train_accuracies, val_accuracies,
     all_preds, all_labels) = train_model_with_validation(
        model, trainloader, testloader, criterion, optimizer, scheduler,
        epochs=epochs, patience=patience)

    # Keep the full history for later comparison
    results[name] = dict(
        train_losses=train_losses,
        val_losses=val_losses,
        train_accuracies=train_accuracies,
        val_accuracies=val_accuracies,
        predictions=all_preds,
        labels=all_labels,
    )

    # Report classification metrics, then the learning curves
    print_classification_metrics(all_labels, all_preds, class_names)
    plot_learning_curves(train_losses, val_losses, train_accuracies, val_accuracies, name)


Entraînement et évaluation de VGG16...


KeyboardInterrupt: 