In [None]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms

# **Step 1: load the data**
class SpectrogramDataset(Dataset):
    """Dataset of spectrogram PNG files organised as one sub-directory per class.

    Labels are the positions of the class names in the sorted list of
    sub-directory names (``self.classes``).
    """

    def __init__(self, data_dir, transform=None, exclude_classes=None):
        """
        Args:
            data_dir (str): Directory containing one sub-directory per class.
            transform (callable, optional): Transform applied to each PIL image.
            exclude_classes (list, optional): Class (sub-directory) names to skip.
        """
        excluded = set(exclude_classes or [])
        self.transform = transform
        # Only sub-directories are classes; a stray file in data_dir would
        # otherwise be listed as a class and crash the os.listdir call below.
        self.classes = sorted(
            name for name in os.listdir(data_dir)
            if name not in excluded and os.path.isdir(os.path.join(data_dir, name))
        )
        self.data = []
        self.labels = []

        for label, genre in enumerate(self.classes):
            genre_dir = os.path.join(data_dir, genre)
            # os.listdir order is filesystem-dependent; sorting makes the sample
            # order (and therefore any index-based train/val/test split) reproducible.
            for file in sorted(os.listdir(genre_dir)):
                if file.endswith('.png'):
                    self.data.append(os.path.join(genre_dir, file))
                    self.labels.append(label)

    def __len__(self):
        """Number of PNG samples found across all kept classes."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)``: the RGB image (transformed if a transform is set) and its class index."""
        # The context manager closes the file handle deterministically instead
        # of relying on garbage collection.
        with Image.open(self.data[idx]) as raw:
            img = raw.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label


# Global seed so DataLoader shuffling, random flips, dropout and the freshly
# initialised classification head are reproducible under Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Training-time transforms: fixed-size resize, random flip, then scale pixels
# from [0, 1] (ToTensor) to [-1, 1] (Normalize with mean=std=0.5).
image_size = 128
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    # NOTE(review): on a spectrogram a horizontal flip reverses the time axis;
    # confirm this augmentation is meaningful for genre classification.
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Data location and the class (sub-directory) to leave out.
data_dir = "/kaggle/input/music-genres/Music_Mel"
exclude_classes = ["Unknown_MEL"]

# Build the dataset without the excluded class.
dataset = SpectrogramDataset(data_dir, transform=transform, exclude_classes=exclude_classes)
# Fail early with a clear message instead of an obscure error in train_test_split.
assert len(dataset) > 0, f"No .png spectrograms found under {data_dir}"


In [None]:

# Split indices: 70% train, then the remaining 30% halved into validation / test.
train_data, test_data = train_test_split(range(len(dataset)), test_size=0.3, random_state=42)
val_data, test_data = train_test_split(test_data, test_size=0.5, random_state=42)

# `dataset` applies RandomHorizontalFlip. Validation/test images must not be
# augmented, otherwise reported metrics are noisy and not reproducible.
# A shallow copy shares the same file list and labels (so indices stay aligned)
# but carries its own deterministic transform. Keep Resize/Normalize in sync
# with the training transform.
eval_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
eval_dataset = copy.copy(dataset)
eval_dataset.transform = eval_transform

train_subset = torch.utils.data.Subset(dataset, train_data)
val_subset = torch.utils.data.Subset(eval_dataset, val_data)
test_subset = torch.utils.data.Subset(eval_dataset, test_data)

train_loader = DataLoader(train_subset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_subset, batch_size=16, shuffle=False)

print(f"Train size: {len(train_subset)}, Val size: {len(val_subset)}, Test size: {len(test_subset)}")

# **Step 2: pretrained model (EfficientNet-B0)**
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `weights=` replaces the deprecated `pretrained=True` flag (same ImageNet weights).
model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
# torchvision's EfficientNet head is Sequential(Dropout, Linear). Replace the
# existing dropout with p=0.5 instead of stacking a second Dropout in front of
# the Linear layer (stacking gave an effective drop rate of 1 - 0.8 * 0.5 = 0.6),
# and resize the Linear layer to the number of genres.
num_classes = len(dataset.classes)
model.classifier[0] = nn.Dropout(p=0.5)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
model = model.to(device)

# **Step 3: optimizer and loss function**
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)



In [None]:
# **Step 4: training**
def train(model, loader, criterion, optimizer, device):
    """Run one training epoch over ``loader``, updating ``model`` in place.

    Args:
        model: PyTorch module to train (switched to train mode).
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function taking ``(outputs, labels)``; assumed to return
            the batch mean (the default reduction of ``nn.CrossEntropyLoss``).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(avg_loss, top1_accuracy)`` over all samples seen this epoch.
        The loss is weighted by batch size, so a smaller final batch does not
        skew the average.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # criterion returns a batch mean; re-weight by batch size so the epoch
        # average is a true per-sample mean.
        running_loss += loss.item() * batch_size
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("train() received an empty loader")
    return running_loss / total, correct / total

def evaluate_topk(model, loader, criterion, device, k_values=(1, 3, 5)):
    """Evaluate ``model`` on ``loader`` and compute the loss and top-k accuracies.

    Args:
        model: PyTorch module (switched to eval mode; gradients are disabled).
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function taking ``(outputs, labels)``; assumed to return
            the batch mean (the default reduction of ``nn.CrossEntropyLoss``).
        device: Device the batches are moved to.
        k_values: Values of k for which top-k accuracy is reported. A k larger
            than the number of classes is capped at that number, which yields
            an accuracy of 1.0 instead of an error from ``torch.topk``.

    Returns:
        tuple: ``(avg_loss, topk_accuracies)`` where ``avg_loss`` is the
        per-sample mean loss and ``topk_accuracies`` maps each k to the fraction
        of samples whose true label is among the k highest-scoring classes.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    # Dict keys de-duplicate k_values, so a repeated k is counted only once.
    correct_topk = {k: 0 for k in k_values}
    running_loss, total = 0.0, 0

    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            batch_size = labels.size(0)
            # Re-weight the batch-mean loss so a short final batch does not skew the average.
            running_loss += criterion(outputs, labels).item() * batch_size
            total += batch_size

            if correct_topk:
                # One sorted top-k call for the largest k; smaller k reuse its
                # leading columns. Capped at the number of classes.
                max_k = min(max(correct_topk), outputs.size(1))
                _, topk_preds = outputs.topk(max_k, dim=1)
                hits = topk_preds.eq(labels.view(-1, 1))
                for k in correct_topk:
                    correct_topk[k] += hits[:, :k].sum().item()

    if total == 0:
        raise ValueError("evaluate_topk() received an empty loader")
    avg_loss = running_loss / total
    topk_accuracies = {k: hits_k / total for k, hits_k in correct_topk.items()}
    return avg_loss, topk_accuracies

class EarlyStopping:
    """Signal when a monitored validation loss has stopped improving.

    Call the instance once per epoch with the latest validation loss, then
    check ``early_stop``.
    """

    def __init__(self, patience=3, min_delta=0.0):
        """
        Args:
            patience (int): Number of consecutive non-improving epochs tolerated
                before ``early_stop`` is set.
            min_delta (float): The loss must drop by strictly more than this
                amount to count as an improvement.
        """
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        # The very first observation only establishes the reference value.
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        if self.best_loss - val_loss > self.min_delta:
            # Real improvement: move the reference and reset patience.
            self.best_loss, self.counter = val_loss, 0
            return

        # No sufficient improvement: consume one unit of patience.
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

# Per-epoch metric history, consumed by the plotting cell further down.
train_losses, val_losses = [], []
train_top1, val_top1 = [], []
train_top3, val_top3 = [], []
train_top5, val_top5 = [], []

# Stop once validation loss has not dropped by more than 0.01 for 3 epochs.
early_stopping = EarlyStopping(patience=3, min_delta=0.01)

# Snapshot of the weights with the lowest validation loss seen so far, so the
# final test evaluation uses the best model rather than the last one.
best_val_loss = float("inf")
best_state = copy.deepcopy(model.state_dict())

epochs = 50
for epoch in range(epochs):
    # Optimisation pass: train() is the function that calls optimizer.step();
    # evaluate_topk() alone never updates the weights.
    train(model, train_loader, criterion, optimizer, device)

    # Measure both splits in eval mode (dropout off, no gradients) so the
    # train and validation curves are directly comparable. This costs one
    # extra forward pass over the training set per epoch.
    train_loss, train_acc_topk = evaluate_topk(model, train_loader, criterion, device, k_values=[1, 3, 5])
    val_loss, val_acc_topk = evaluate_topk(model, val_loader, criterion, device, k_values=[1, 3, 5])

    # Record losses and top-k accuracies.
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_top1.append(train_acc_topk[1])
    train_top3.append(train_acc_topk[3])
    train_top5.append(train_acc_topk[5])
    val_top1.append(val_acc_topk[1])
    val_top3.append(val_acc_topk[3])
    val_top5.append(val_acc_topk[5])

    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train Top-1: {train_acc_topk[1]:.4f}, Top-3: {train_acc_topk[3]:.4f}, Top-5: {train_acc_topk[5]:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Top-1: {val_acc_topk[1]:.4f}, Top-3: {val_acc_topk[3]:.4f}, Top-5: {val_acc_topk[5]:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = copy.deepcopy(model.state_dict())

    # Check early stopping.
    early_stopping(val_loss)
    if early_stopping.early_stop:
        print("Early stopping triggered. Stopping training...")
        break

# Roll back to the best checkpoint before evaluating on the test set.
model.load_state_dict(best_state)
print(f"Restored weights with the lowest validation loss: {best_val_loss:.4f}")

# **Step 5: final evaluation on the test set**
test_loss, topk_accuracies = evaluate_topk(model, test_loader, criterion, device, k_values=[1, 3, 5])

print(f"Test Loss: {test_loss:.4f}")
print(f"Top-1 Accuracy: {topk_accuracies[1]:.4f}")
print(f"Top-3 Accuracy: {topk_accuracies[3]:.4f}")
print(f"Top-5 Accuracy: {topk_accuracies[5]:.4f}")

# Collect top-1 predictions for the per-class report and the confusion matrix.
all_preds, all_labels = [], []
model.eval()
with torch.no_grad():
    for imgs, labels in test_loader:
        outputs = model(imgs.to(device))
        all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
        all_labels.extend(labels.tolist())

# Pass the full label list explicitly. Without it, a genre absent from the test
# split makes classification_report reject target_names (size mismatch) and
# drops a row/column from the confusion matrix, misaligning the tick labels.
class_ids = list(range(len(dataset.classes)))

# Classification report
print("\nClassification Report:\n", classification_report(
    all_labels, all_preds, labels=class_ids, target_names=dataset.classes, zero_division=0))

# Confusion matrix
conf_matrix = confusion_matrix(all_labels, all_preds, labels=class_ids)
fig, ax = plt.subplots(figsize=(12, 8))
sns.heatmap(conf_matrix, annot=True, fmt='d', xticklabels=dataset.classes,
            yticklabels=dataset.classes, cmap='Blues', ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title('Confusion Matrix')
plt.show()

In [None]:
import matplotlib.pyplot as plt

# **Plot the loss and top-k accuracy curves**
fig, (loss_ax, acc_ax) = plt.subplots(2, 1, figsize=(15, 8))

# Upper panel: training vs. validation loss per epoch.
for history, legend in ((train_losses, 'Train Loss'), (val_losses, 'Validation Loss')):
    loss_ax.plot(range(1, len(history) + 1), history, label=legend)
loss_ax.set_title('Loss Curve')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

# Lower panel: top-1 / top-3 / top-5 accuracy for both splits.
accuracy_histories = (
    (train_top1, 'Train Top-1 Accuracy'),
    (val_top1, 'Validation Top-1 Accuracy'),
    (train_top3, 'Train Top-3 Accuracy'),
    (val_top3, 'Validation Top-3 Accuracy'),
    (train_top5, 'Train Top-5 Accuracy'),
    (val_top5, 'Validation Top-5 Accuracy'),
)
for history, legend in accuracy_histories:
    acc_ax.plot(range(1, len(history) + 1), history, label=legend)
acc_ax.set_title('Top-k Accuracy Curve')
acc_ax.set_xlabel('Epochs')
acc_ax.set_ylabel('Accuracy')
acc_ax.legend()

fig.tight_layout()
plt.show()


In [None]:
# Save the whole pickled module (kept for existing consumers of this file).
# NOTE: loading it needs the same class definitions on the import path and,
# on recent PyTorch, torch.load(..., weights_only=False).
full_model_save_path = "full_model_efficientnet_b0.pth"
torch.save(model, full_model_save_path)
print(f"Modèle complet sauvegardé à : {full_model_save_path}")

# Also save a portable checkpoint: the state_dict plus the class-name order
# needed to map output indices back to genres. Reload it by rebuilding the
# architecture and calling model.load_state_dict(checkpoint["state_dict"]).
state_dict_save_path = "efficientnet_b0_state_dict.pth"
torch.save({"state_dict": model.state_dict(), "classes": list(dataset.classes)}, state_dict_save_path)
print(f"State dict and class names saved to: {state_dict_save_path}")
