# Entraînement Fader Network

In [None]:
import os
import sys
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import random
from torch.utils.tensorboard import SummaryWriter


# Ajouter le chemin racine au PYTHONPATH pour pouvoir importer les modules personnalisés
# sys.path.append(os.path.abspath('.'))

from Models.fader_network import FaderNetwork
from Models.discriminator import Discriminator
from Data.preprocess import create_data_file
from Training.train import Evaluator, Trainer
from Training.losses import attr_loss

## Configuration des paramètres

In [None]:
# Attributes the Fader Network is conditioned on. They are forwarded to
# create_data_file and to both models (presumably CelebA attribute column
# names — confirm against Data.preprocess).
selected_attrs = ['Smiling', 'Male', 'Eyeglasses', 'Young', 'Mouth_Slightly_Open']
num_attributes = len(selected_attrs)
latent_dim = 512  # not referenced in the visible code; kept for later cells
batch_size = 32
epochs = 100
lr = 0.002
beta1 = 0.5
step = 202599  # chunk size given to create_data_file: the whole database in one chunk
n_img = 202599  # total number of images in the database
# Use the shared ``batch_size`` (it used to be hard-coded to 32 here) so the
# data loaders and the scheduler computation below always agree.
generator = create_data_file(selected_attrs, step=step, batch_size=batch_size, n_img=n_img)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Utilisation du device : {device}")

# Scheduler length derived from the size of the training set.
# NOTE(review): assumes the training split is half of each ``step`` chunk —
# TODO confirm against create_data_file.
# Nested integer floor division equals the original ``(step / 2) // batch_size``
# for positive ints, without going through floats.
batches_per_epoch = (step // 2) // batch_size
total_iterations = batches_per_epoch * epochs
# 30% of all training iterations; how Trainer consumes this value is defined
# in Training.train (not visible here).
scheduler_steps = int(total_iterations * 0.3)
print(f"Total iterations for scheduler: {scheduler_steps}")

# Models: the encoder/decoder and the latent attribute discriminator.
auto_encoder = FaderNetwork(attribute_dim=num_attributes, attributes=selected_attrs).to(device)
discriminator = Discriminator(n_attr=num_attributes, attributes=selected_attrs).to(device)

print("Modèles initialisés.")

# Training / evaluation drivers sharing the same two models.
trainer = Trainer(auto_encoder, discriminator, scheduler_steps)
evaluator = Evaluator(auto_encoder, discriminator)

print("Trainer/Evaluator initialisés.")

# Optimizers: one for the auto-encoder, one for the discriminator.
optimizer_enc_dec = optim.Adam(auto_encoder.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_dis = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# Loss functions: MSE is the auto-encoder reconstruction criterion,
# attr_loss is the attribute criterion used for the latent/discriminator terms.
mse_loss_fn = nn.MSELoss()
ce_loss_fn = attr_loss

## Début de l'entraînement

In [None]:
writer = SummaryWriter(log_dir='./logs/fader_networks')

# best_val_loss = float('inf')  # to track the best validation loss (currently unused)

try:
    for i, base_loader in enumerate(generator):
        # Each item yielded by ``generator`` is a (train, valid, test) loader triple.
        train_loader, valid_loader, test_loader = base_loader
        for epoch in range(1, epochs + 1):
            # Monotonic TensorBoard step. ``epoch`` restarts at 1 for every
            # chunk yielded by ``generator``, so logging with ``epoch`` alone
            # would let later chunks overwrite earlier curves. For the first
            # chunk (i == 0) this equals ``epoch``.
            global_step = i * epochs + epoch

            # Reset the per-epoch accumulators read below
            # (ae_train_loss, latent_train_loss, ae_evaluate_loss, ...).
            trainer.init_parameters()
            evaluator.init_parameters()

            # Explicit training mode (dropout / batch-norm statistics);
            # a no-op if Trainer already manages the modes itself.
            auto_encoder.train()
            discriminator.train()
            for X, y in tqdm(train_loader, desc=f"Epoch : {epoch}"):
                X, y = X.to(device), y.to(device)
                # Latent/discriminator update (optimizer_dis), then the
                # auto-encoder update (optimizer_enc_dec).
                trainer.latent_step(X=X, y=y, criterion=ce_loss_fn, optim=optimizer_dis)
                trainer.ae_step(X=X, y=y, criterion_ae=mse_loss_fn, criterion_latent=ce_loss_fn, optim=optimizer_enc_dec)

            # Accumulated losses averaged over the number of training batches.
            avg_ae_loss = trainer.ae_train_loss / len(train_loader)
            avg_dis_loss = trainer.latent_train_loss / len(train_loader)

            # TensorBoard logging
            writer.add_scalar('Loss/AE_Train', avg_ae_loss, global_step)
            writer.add_scalar('Loss/Discriminator_Train', avg_dis_loss, global_step)

            # Print the epoch's training losses
            print(f"Epoch [{epoch}/{epochs}] | AE_Train: {avg_ae_loss:.4f} | Discriminator_Train: {avg_dis_loss:.4f}")

            # Validation: evaluation mode so dropout is off and batch-norm
            # running statistics are not updated with validation data.
            auto_encoder.eval()
            discriminator.eval()
            with torch.no_grad():
                for X, y in valid_loader:
                    X, y = X.to(device), y.to(device)
                    evaluator.latent_step(X=X, y=y, criterion=ce_loss_fn)
                    evaluator.ae_step(X=X, y=y, criterion=mse_loss_fn)

                avg_val_ae_loss = evaluator.ae_evaluate_loss / len(valid_loader)
                avg_val_dis_loss = evaluator.latent_evaluate_loss / len(valid_loader)
                # Accuracy normalised per (sample, attribute) pair.
                accuracy_dis = evaluator.latent_evaluate_accuracy / (len(valid_loader.dataset) * num_attributes)

                # TensorBoard logging
                writer.add_scalar('Validation/AE_Validation', avg_val_ae_loss, global_step)
                writer.add_scalar('Validation/Discriminator_Validation', avg_val_dis_loss, global_step)
                writer.add_scalar('Validation/Accuracy_Discriminator', accuracy_dis, global_step)

            print(f"Epoch Val [{epoch}/{epochs}] | AE_Validation: {avg_val_ae_loss:.4f} | Discriminator_Validation: {avg_val_dis_loss:.4f} | accuracy_dis: {accuracy_dis:.6f}")

            # Model checkpointing (disabled). When re-enabled, create the
            # per-chunk directory the files are written into, not just its parent.
            # save_dir = f'./Models/trained_model/TrainBD_50_100_{i}'
            # os.makedirs(save_dir, exist_ok=True)
            # torch.save(auto_encoder.state_dict(), f'{save_dir}/auto_encoder_epoch_{epoch}.pth')
            # torch.save(discriminator.state_dict(), f'{save_dir}/discriminator_epoch_{epoch}.pth')
            # print(f"Modèles enregistrés pour l'époque {epoch}")
finally:
    # Flush pending events and release the event file, even if training
    # is interrupted (e.g. KeyboardInterrupt in the notebook).
    writer.close()