In [1]:
import argparse
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from models_mel import MultimodalVAE
from src.dataset import generate_datasets
from src.functions import Log
from src.config import config as default_config
from sklearn.preprocessing import StandardScaler
import sys 
from models_mel import MultimodalVAE
from sklearn.model_selection import KFold
from torch.utils.data import Subset

# Fonction d'entraînement du modèle 

In [7]:
import pandas as pd

def train_model_with_cross_validation(datasets, model, optimizer, criterion, num_epochs=10, n_splits=10):
    from torch.utils.data import Subset
    from torch.utils.data import DataLoader
    from sklearn.model_selection import KFold
    import torch
    
    kf = KFold(n_splits=n_splits) 
    all_fold_losses = []

    for fold, (train_index, test_index) in enumerate(kf.split(datasets[0])):
        print(f'Fold {fold+1}')
        fold_losses = []

        train_datasets = [Subset(dataset, train_index) for dataset in datasets]
        test_datasets = [Subset(dataset, test_index) for dataset in datasets]

        train_loaders = [DataLoader(dataset, batch_size=32, shuffle=True) for dataset in train_datasets]
        test_loaders = [DataLoader(dataset, batch_size=32, shuffle=False) for dataset in test_datasets]

        epoch_losses_per_fold = []
        for epoch in range(num_epochs):
            model.train()
            epoch_losses = []
            for (x1, _), (x2, _), (y, _) in zip(*train_loaders):
                if x1.size(0) != 32 or x2.size(0) != 32 or y.size(0) != 32:
                    continue
                outputs = model(x1, x2)
                loss = criterion(outputs, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_losses.append(loss.item()) 

            avg_loss = sum(epoch_losses) / len(epoch_losses)  
            fold_losses.append(avg_loss)  

        all_fold_losses.append(fold_losses)  

    df = pd.DataFrame(all_fold_losses, columns=[f'Epoch {i+1}' for i in range(num_epochs)])

    df['Average Loss'] = df.mean(axis=1)
    df.index = [f'Fold {i+1}' for i in range(n_splits)]

    return df



## Générer modalité 3 à partir des modalités 1 et 2 


In [8]:
train_datasets = generate_datasets(suffix='5_diff', type='paired', train=True, test=False)
test_datasets = generate_datasets(suffix='5_diff', type='paired', train=False, test=True)


train_loaders = [DataLoader(dataset, batch_size=32, shuffle=True) for dataset in train_datasets]
test_loaders = [DataLoader(dataset, batch_size=32, shuffle=False) for dataset in test_datasets]


n_inputs1 = train_datasets[0][0][0].shape[0]  # La taille du vecteur de caractéristiques pour la modalité 1
n_inputs2 = train_datasets[1][0][0].shape[0]  # La taille du vecteur de caractéristiques pour la modalité 2
n_outputs = train_datasets[2][0][0].shape[0]  # La taille du vecteur de caractéristiques pour la modalité 3


datasets = generate_datasets(suffix='5_diff', type='paired', train=True, test=False)
datasets = [list(dataset) for dataset in datasets]

latent_dims = 10
n_hiddens = 256

model = MultimodalVAE(n_inputs1=n_inputs1, n_inputs2=n_inputs2, latent_dims=latent_dims, n_hiddens=n_hiddens, n_outputs=n_outputs)
optimizer = torch.optim.Adam(model.parameters())

criterion = nn.MSELoss()
train_model_with_cross_validation(datasets, model, optimizer, criterion, num_epochs=10, n_splits=10)

Loading paired dataset
Loading paired dataset


  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dty

Loading paired dataset
Fold 1
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])


  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):
  if not hasattr(array, "sparse") and array.dtypes.apply(is_sparse).any():
  if is_sparse(pd_dtype):
  if is_sparse(pd_dtype) or not is_extension_array_dtype(pd_dtype):


torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([32, 20])
torch.Size([3

Unnamed: 0,Epoch 1,Epoch 2,Epoch 3,Epoch 4,Epoch 5,Epoch 6,Epoch 7,Epoch 8,Epoch 9,Epoch 10,Average Loss
Fold 1,1.005586,1.006708,1.005561,1.003758,1.009859,1.009421,1.005753,1.005727,1.004389,1.003259,1.006002
Fold 2,0.996793,0.998032,0.999938,0.998213,0.995895,0.997692,0.993966,0.994266,0.998354,0.994516,0.996767
Fold 3,1.006782,1.00726,1.003765,1.005733,1.001839,1.00446,1.00305,1.005137,1.00538,1.003214,1.004662
Fold 4,0.996721,0.997209,0.994675,0.997946,0.998129,0.994099,0.99706,0.99609,0.996266,0.995094,0.996329
Fold 5,0.999896,1.00029,1.003204,1.000383,1.000587,1.001495,1.002161,1.005601,1.000986,1.00194,1.001654
Fold 6,1.00159,1.000496,0.998171,1.003997,1.001756,1.003939,1.00146,1.001849,1.003667,1.004637,1.002156
Fold 7,1.000958,1.000187,1.001925,1.003968,1.002036,1.000417,0.999729,1.001722,0.998954,1.001705,1.00116
Fold 8,1.004017,1.002201,1.000098,1.000852,1.001633,1.00442,1.003026,1.003899,0.998751,1.002852,1.002175
Fold 9,0.998049,0.996637,0.996474,0.999592,0.997171,0.998183,0.996112,0.997452,0.995975,0.99815,0.99738
Fold 10,0.994772,0.998158,0.994873,0.998605,0.999543,0.997391,0.998363,0.998423,0.994189,0.996962,0.997128
