In [15]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.models import mobilenet_v2
from torch.utils.data import DataLoader, random_split
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy import func, create_engine
from collections import Counter
from PIL import Image
from io import BytesIO
from db_setup import BirdImage

# Configuration du dispositif
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Charger la base de données
DATABASE_URL = "sqlite:///./bird_data.db"
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
db = SessionLocal()

# Définir la fonction pour obtenir les 10 classes les plus représentées
def get_top_classes(db: Session, top_n=10):
    class_counts = db.query(BirdImage.class_label, func.count(BirdImage.class_label)).group_by(BirdImage.class_label).all()
    sorted_classes = sorted(class_counts, key=lambda x: x[1], reverse=True)[:top_n]
    top_classes = [cls for cls, _ in sorted_classes]
    return top_classes

# Filtrer les classes avec le plus d'images
top_classes = get_top_classes(db, top_n=10)

# Créer un Dataset personnalisé pour les images
class FilteredDataset(torch.utils.data.Dataset):
    def __init__(self, samples, class_to_idx, transform=None):
        self.samples = samples
        self.class_to_idx = class_to_idx
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_data, label = self.samples[idx]
        
        # Charger l'image depuis les données binaires
        image = Image.open(BytesIO(image_data)).convert('RGB')
        
        # Appliquer les transformations si elles sont définies
        if self.transform:
            image = self.transform(image)
        
        # Convertir l'étiquette en un indice numérique
        label_idx = self.class_to_idx[label]
        
        # Retourner l'image et l'étiquette sous forme de Tensor
        return image, label_idx

# Créer un mappage des étiquettes (chaînes) vers des indices numériques
class_to_idx = {cls: idx for idx, cls in enumerate(top_classes)}

# Transformations avec augmentation des données
transform = transforms.Compose([
    transforms.RandomResizedCrop((128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Charger les samples depuis la base de données
samples = []
for bird_image in db.query(BirdImage).filter(BirdImage.class_label.in_(top_classes)).all():
    samples.append((bird_image.image, bird_image.class_label))

# Créer le dataset filtré
filtered_dataset = FilteredDataset(samples, class_to_idx, transform)

# Division des données en ensembles d'entraînement, validation et test
train_size = int(0.7 * len(filtered_dataset))
val_size = int(0.15 * len(filtered_dataset))
test_size = len(filtered_dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(filtered_dataset, [train_size, val_size, test_size])

dataloaders = {
    'train': DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0),
    'val': DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0),
    'test': DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=0)
}

# Charger MobileNetV2 avec fine-tuning
model = mobilenet_v2(weights='IMAGENET1K_V1')
model.classifier[1] = nn.Linear(model.last_channel, len(top_classes))
model = model.to(device)

# Optimisation
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0005)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

# Fonction d'entraînement avec diagnostic
def train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=10):
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        print('-' * 10)
        model.train()

        running_loss = 0.0
        running_corrects = 0

        # Phase d'entraînement
        for inputs, labels in dataloaders['train']:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            running_corrects += (outputs.argmax(1) == labels).sum().item()

        epoch_loss = running_loss / len(dataloaders['train'].dataset)
        epoch_acc = running_corrects / len(dataloaders['train'].dataset)
        print(f"Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

        # Phase de validation après chaque époque
        model.eval()
        val_loss = 0.0
        val_corrects = 0
        with torch.no_grad():
            for inputs, labels in dataloaders['val']:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)
                val_corrects += (outputs.argmax(1) == labels).sum().item()

        val_loss = val_loss / len(dataloaders['val'].dataset)
        val_acc = val_corrects / len(dataloaders['val'].dataset)
        print(f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}")

        scheduler.step()
    return model

# Fonction d'évaluation
def evaluate_model(model, dataloader):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            preds = outputs.argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return correct / total

# Entraîner le modèle
model = train_model(model, dataloaders, criterion, optimizer, scheduler, num_epochs=10)

# Évaluer sur le test
test_acc = evaluate_model(model, dataloaders['test'])
print(f"Test Accuracy: {test_acc:.4f}")


Epoch 1/10
----------
Train Loss: 1.6834 Acc: 0.4214
Val Loss: 1.1555 Acc: 0.6111
Epoch 2/10
----------
Train Loss: 1.0666 Acc: 0.6310
Val Loss: 1.2158 Acc: 0.6222
Epoch 3/10
----------
Train Loss: 0.8978 Acc: 0.6881
Val Loss: 1.2810 Acc: 0.5222
Epoch 4/10
----------
Train Loss: 0.7754 Acc: 0.7500
Val Loss: 1.0303 Acc: 0.6000
Epoch 5/10
----------
Train Loss: 0.8210 Acc: 0.7119
Val Loss: 1.2698 Acc: 0.5889
Epoch 6/10
----------
Train Loss: 0.8117 Acc: 0.7167
Val Loss: 1.5306 Acc: 0.5667
Epoch 7/10
----------
Train Loss: 0.7670 Acc: 0.7310
Val Loss: 1.0461 Acc: 0.7111
Epoch 8/10
----------
Train Loss: 0.6754 Acc: 0.7738
Val Loss: 1.0331 Acc: 0.6667
Epoch 9/10
----------
Train Loss: 0.5965 Acc: 0.7976
Val Loss: 0.8875 Acc: 0.7111
Epoch 10/10
----------
Train Loss: 0.5505 Acc: 0.8119
Val Loss: 0.8119 Acc: 0.7667
Test Accuracy: 0.7000
