# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.models import vit_b_16
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns


# **Étape 1 : Charger les données**
class SpectrogramDataset(Dataset):
    """Dataset of PNG images stored in one sub-directory per class.

    Each sample is returned as ``(image, label)`` where ``label`` is the index
    of the sample's class in ``self.classes`` (class names sorted
    alphabetically).
    """

    def __init__(self, data_dir, transform=None, exclude_classes=None):
        """Index every ``.png`` file found in the class sub-directories.

        Args:
            data_dir (str): Directory containing the data organised by class.
            transform (callable, optional): Transformation applied to each
                image when it is accessed.
            exclude_classes (list, optional): Names of classes to leave out.
        """
        excluded = exclude_classes or ()
        self.data = []     # image file paths
        self.labels = []   # class index of each path, parallel to self.data
        self.transform = transform
        # Only real directories count as classes: a stray file in data_dir
        # would otherwise make the os.listdir() call below fail.
        self.classes = sorted(
            name for name in os.listdir(data_dir)
            if name not in excluded
            and os.path.isdir(os.path.join(data_dir, name))
        )

        for label, genre in enumerate(self.classes):
            genre_dir = os.path.join(data_dir, genre)
            # os.listdir() returns entries in arbitrary order; sorting keeps
            # sample indices (and thus any seeded split) reproducible.
            for file in sorted(os.listdir(genre_dir)):
                if file.endswith('.png'):
                    self.data.append(os.path.join(genre_dir, file))
                    self.labels.append(label)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply the transform, return ``(img, label)``."""
        # convert() returns a new image, so the file can be closed right away.
        with Image.open(self.data[idx]) as handle:
            img = handle.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label


# Transformations
image_size = 224
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

# Training pipeline: random augmentation, then tensor conversion.
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.RandomHorizontalFlip(p=0.5),  # random horizontal flip
    transforms.RandomRotation(degrees=15),   # random rotation of up to 15 degrees
    transforms.RandomResizedCrop(image_size, scale=(0.8, 1.0)),  # random zoom
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: no random operations, so validation loss (which drives
# early stopping) and test metrics are deterministic for a given model.
eval_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    normalize,
])

# Load the data
data_dir = "/kaggle/input/music-genres/Music_Mel"
# Class to exclude
exclude_classes = ["Unknown_MEL"]

# Training view of the data (augmented), with the excluded class removed
dataset = SpectrogramDataset(data_dir, transform=transform, exclude_classes=exclude_classes)

# Evaluation view of the same files. The file/label lists are shared with
# `dataset` so that an index refers to the same sample in both views.
eval_dataset = SpectrogramDataset(data_dir, transform=eval_transform, exclude_classes=exclude_classes)
eval_dataset.data, eval_dataset.labels = dataset.data, dataset.labels

# Split into train (70%), validation (15%) and test (15%) index sets
train_data, test_data = train_test_split(range(len(dataset)), test_size=0.3, random_state=42)
val_data, test_data = train_test_split(test_data, test_size=0.5, random_state=42)

train_subset = torch.utils.data.Subset(dataset, train_data)
val_subset = torch.utils.data.Subset(eval_dataset, val_data)
test_subset = torch.utils.data.Subset(eval_dataset, test_data)

train_loader = DataLoader(train_subset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=16, shuffle=False)
test_loader = DataLoader(test_subset, batch_size=16, shuffle=False)

print(f"Train size: {len(train_subset)}, Val size: {len(val_subset)}, Test size: {len(test_subset)}")

# **Step 2: Pretrained model (ViT)**
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Load a ViT-B/16 with ImageNet weights. The `pretrained=True` flag is
# deprecated since torchvision 0.13; `weights="IMAGENET1K_V1"` selects the
# same checkpoint through the current API.
model = vit_b_16(weights="IMAGENET1K_V1")

# Replace the classification head so it outputs one logit per dataset class
num_classes = len(dataset.classes)
in_features = model.heads.head.in_features
model.heads = nn.Sequential(
    nn.Dropout(0.5),  # dropout to limit overfitting
    nn.Linear(in_features, num_classes),
)
model = model.to(device)

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

def train(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module to train; switched to training mode.
        loader: Iterable yielding ``(imgs, labels)`` tensor batches.
        criterion: Loss callable taking ``(outputs, labels)``. Assumed to
            return the mean loss over the batch (the default reduction of
            ``nn.CrossEntropyLoss``).
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(mean loss per sample, accuracy)`` over all samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()
    loss_sum, correct, total = 0.0, 0, 0

    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight each batch mean by its size so that a smaller final batch
        # does not count as much as a full one.
        loss_sum += loss.item() * batch_size
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("loader yielded no samples")

    return loss_sum / total, correct / total

def evaluate(model, loader, criterion, device):
    """Compute loss and accuracy over ``loader`` without updating the model.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        loader: Iterable yielding ``(imgs, labels)`` tensor batches.
        criterion: Loss callable taking ``(outputs, labels)``. Assumed to
            return the mean loss over the batch (the default reduction of
            ``nn.CrossEntropyLoss``).
        device: Device the batches are moved to.

    Returns:
        tuple: ``(mean loss per sample, accuracy)`` over all samples seen.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    loss_sum, correct, total = 0.0, 0, 0

    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight each batch mean by its size so that a smaller final
            # batch does not count as much as a full one.
            loss_sum += loss.item() * batch_size
            _, predicted = outputs.max(1)
            correct += predicted.eq(labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError("loader yielded no samples")

    return loss_sum / total, correct / total

class EarlyStopping:
    """Track validation loss and flag when it has stopped improving."""

    def __init__(self, patience=3, min_delta=0.0):
        """
        Args:
            patience (int): Number of epochs without improvement before stopping.
            min_delta (float): Minimum decrease in loss that counts as an improvement.
        """
        self.early_stop = False
        self.best_loss = None
        self.counter = 0
        self.min_delta = min_delta
        self.patience = patience

    def __call__(self, val_loss):
        """Record one epoch's validation loss and update ``early_stop``."""
        # The first observation only establishes the baseline.
        if self.best_loss is None:
            self.best_loss = val_loss
            return

        gain = self.best_loss - val_loss
        if gain > self.min_delta:
            # Improvement: new baseline, restart the patience countdown.
            self.best_loss = val_loss
            self.counter = 0
            return

        # No improvement: count the epoch and stop once patience runs out.
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True


# Training loop
epochs = 250
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []

# Instantiate early stopping
early_stopping = EarlyStopping(patience=3, min_delta=0.01)

# Snapshot of the weights that achieved the lowest validation loss. Without
# it, the model left after early stopping is the one from the last
# (non-improving) epoch rather than the best one.
best_val_loss = float("inf")
best_state = None

for epoch in range(epochs):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    # Store the metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accuracies.append(train_acc)
    val_accuracies.append(val_acc)

    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # Keep the copy on the CPU so it does not take up accelerator memory.
        best_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

    # Early stopping
    early_stopping(val_loss)
    if early_stopping.early_stop:
        print("Early stopping triggered. Stopping training...")
        break

# Restore the best weights before any further evaluation
if best_state is not None:
    model.load_state_dict(best_state)

# Final evaluation on the test set
test_loss, test_acc = evaluate(model, test_loader, criterion, device)
print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

# Predictions
all_preds, all_labels = [], []
model.eval()
with torch.no_grad():
    for imgs, labels in test_loader:
        outputs = model(imgs.to(device))
        _, preds = outputs.max(1)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.tolist())

# Pass the full list of class indices explicitly: if a class happens to be
# absent from the test split, the report would otherwise reject `target_names`
# and the confusion matrix would not line up with the tick labels.
class_ids = list(range(len(dataset.classes)))

# Classification report
print("\nClassification Report:\n", classification_report(
    all_labels, all_preds, labels=class_ids, target_names=dataset.classes, zero_division=0))

# Confusion matrix
conf_matrix = confusion_matrix(all_labels, all_preds, labels=class_ids)
plt.figure(figsize=(12, 8))
sns.heatmap(conf_matrix, annot=True, fmt='d', xticklabels=dataset.classes, yticklabels=dataset.classes, cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

# Plot the performance curves: loss on the left, accuracy on the right
plt.figure(figsize=(12, 5))

# One row per panel:
# (position, train series, train label, val series, val label, title, y label)
panels = (
    (1, train_losses, "Train Loss", val_losses, "Validation Loss", "Loss Curve", "Loss"),
    (2, train_accuracies, "Train Accuracy", val_accuracies, "Validation Accuracy", "Accuracy Curve", "Accuracy"),
)

for position, train_series, train_label, val_series, val_label, title, y_label in panels:
    plt.subplot(1, 2, position)
    # Epochs are numbered from 1 on the x axis.
    plt.plot(range(1, len(train_series) + 1), train_series, label=train_label)
    plt.plot(range(1, len(val_series) + 1), val_series, label=val_label)
    plt.title(title)
    plt.xlabel("Epochs")
    plt.ylabel(y_label)
    plt.legend()

plt.tight_layout()
plt.show()
