In [1]:
import argparse
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from models_mel import MultimodalAE
from src.dataset import generate_datasets
from src.functions import Log
from src.config import config as default_config
from sklearn.preprocessing import StandardScaler
import sys 
from sklearn.model_selection import KFold
from torch.utils.data import Subset

# Fonction d'entraînement du modèle 

In [3]:
import pandas as pd

def train_model_with_cross_validation(datasets, model, optimizer, criterion, num_epochs=10, n_splits=10):
    from torch.utils.data import Subset
    from torch.utils.data import DataLoader
    from sklearn.model_selection import KFold
    import torch
    
    kf = KFold(n_splits=n_splits) 
    all_fold_losses = []

    for fold, (train_index, test_index) in enumerate(kf.split(datasets[0])):
        print(f'Fold {fold+1}')
        fold_losses = []

        train_datasets = [Subset(dataset, train_index) for dataset in datasets]
        test_datasets = [Subset(dataset, test_index) for dataset in datasets]

        train_loaders = [DataLoader(dataset, batch_size=32, shuffle=True) for dataset in train_datasets]
        test_loaders = [DataLoader(dataset, batch_size=32, shuffle=False) for dataset in test_datasets]

        epoch_losses_per_fold = []
        for epoch in range(num_epochs):
            model.train()
            epoch_losses = []
            for (x1, _), (x2, _), (y, _) in zip(*train_loaders):
                if x1.size(0) != 32 or x2.size(0) != 32 or y.size(0) != 32:
                    continue
                outputs = model(x1, x2)
                loss = criterion(outputs, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_losses.append(loss.item()) 

            avg_loss = sum(epoch_losses) / len(epoch_losses)  
            fold_losses.append(avg_loss)  

        all_fold_losses.append(fold_losses)  

    df = pd.DataFrame(all_fold_losses, columns=[f'Epoch {i+1}' for i in range(num_epochs)])

    df['Average Loss'] = df.mean(axis=1)
    df.index = [f'Fold {i+1}' for i in range(n_splits)]

    return df



## Générer modalité 3 à partir des modalités 1 et 2 


In [4]:
train_datasets = generate_datasets(suffix='5_diff', type='paired', train=True, test=False)
test_datasets = generate_datasets(suffix='5_diff', type='paired', train=False, test=True)


train_loaders = [DataLoader(dataset, batch_size=32, shuffle=True) for dataset in train_datasets]
test_loaders = [DataLoader(dataset, batch_size=32, shuffle=False) for dataset in test_datasets]


n_inputs1 = train_datasets[0][0][0].shape[0]  # La taille du vecteur de caractéristiques pour la modalité 1
n_inputs2 = train_datasets[1][0][0].shape[0]  # La taille du vecteur de caractéristiques pour la modalité 2
n_outputs = train_datasets[2][0][0].shape[0]  # La taille du vecteur de caractéristiques pour la modalité 3


datasets = generate_datasets(suffix='5_diff', type='paired', train=True, test=False)
datasets = [list(dataset) for dataset in datasets]

latent_dims = 10
n_hiddens = 256

model = MultimodalAE(n_inputs1=n_inputs1, n_inputs2=n_inputs2, latent_dims=latent_dims, n_hiddens=n_hiddens, n_outputs=n_outputs)
optimizer = torch.optim.Adam(model.parameters())

criterion = nn.MSELoss()
train_model_with_cross_validation(datasets, model, optimizer, criterion, num_epochs=10, n_splits=10)

Loading paired dataset


  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dty

Loading paired dataset
Loading paired dataset


  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dty

Fold 1
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.

Unnamed: 0,Epoch 1,Epoch 2,Epoch 3,Epoch 4,Epoch 5,Epoch 6,Epoch 7,Epoch 8,Epoch 9,Epoch 10,Average Loss
Fold 1,1.008084,1.007875,1.008509,1.00853,1.00439,1.006975,1.002489,1.004539,1.003448,1.005935,1.006077
Fold 2,0.9986,0.997994,0.998802,0.995365,0.997784,0.997005,1.000262,0.996126,0.997623,0.9971,0.997666
Fold 3,1.006599,1.003246,1.00271,1.003802,1.003206,1.006387,1.004979,1.003515,1.002321,1.003857,1.004062
Fold 4,0.995712,0.996517,0.994195,0.998349,0.993626,0.995288,0.996418,0.994674,0.997599,0.994885,0.995726
Fold 5,1.004244,1.001482,0.999295,0.997376,1.002717,1.004984,1.003289,1.001962,1.001731,1.000728,1.001781
Fold 6,1.00146,1.003643,1.001713,1.00065,1.00102,1.000723,1.000821,0.999005,1.003066,1.003363,1.001546
Fold 7,1.002718,0.998814,1.001231,1.000253,1.000948,0.998363,0.999185,1.000908,1.001293,1.002558,1.000627
Fold 8,1.001451,0.999735,1.000564,1.00019,1.001789,1.00261,1.002564,1.003178,1.002534,1.002208,1.001682
Fold 9,0.99806,0.999288,0.998654,0.996985,0.997403,0.998446,0.994599,0.995597,0.996237,0.99811,0.997338
Fold 10,0.997753,0.997379,0.998382,0.997545,0.996901,0.997412,0.997639,0.997597,0.999721,0.994792,0.997512
