In [None]:
import torch
import torchaudio
import torchaudio.transforms as T
import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from torchvision.transforms import v2
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import ToTensor
from torch.utils.data import random_split
from torcheval.metrics import MulticlassAccuracy, MulticlassF1Score
from torch.optim.lr_scheduler import CosineAnnealingLR, ExponentialLR,CosineAnnealingWarmRestarts
import torch
import torch.nn as nn
import torchaudio
import torchaudio.functional as F
import random
from torchaudio_augmentations import (
    Compose,
    LowPassFilter,
    HighLowPass,
    Noise,
    PitchShift,
    RandomApply,
)
import numpy as np
from tqdm import tqdm
import os
import pandas as pd
import timm
import pickle
from sklearn.model_selection import StratifiedKFold


In [None]:

print(torch.__version__)
print(torchaudio.__version__)
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
#######################################################
# Audio hyper-parameters
# ESPEC = 'mfcc'
ESPEC = 'mel'  # spectrogram type: 'mel' or 'mfcc' (selects the preprocessing functions)
SAMPLE_RATE = 16000
MAX_DURATION = 60 # seconds; shorter clips are zero-padded up to this length by the preprocess functions
N_FFT = 1024
PRETRAINED = True  # load pretrained timm weights for the backbone
N = 30  # number of trailing backbone parameter tensors to unfreeze (passed to CustomConvNeXt)

# Masking parameters and hop length depend on the FFT size (hop = N_FFT / 2 in both cases).
if N_FFT == 2048:
    TIME_MASK_PARAM =10
    FREQ_MASK_PARAM = 5
    HOP_LENGTH = 1024
else:
    TIME_MASK_PARAM = 10
    FREQ_MASK_PARAM = 5
    HOP_LENGTH = 512


N_MELS = 224 #128
N_MFCC = N_MELS
N_LFCC = N_MELS
RESIZE = False  # when True the mel pipeline resizes images to IMG_SIZE x IMG_SIZE
IMG_SIZE = 224 #384, #300 , 240
NUM_CLASSES = 3
SAMPLE_RATE = 16000  # NOTE: duplicate of the assignment above (same value)

# Target clip length in samples.
num_samples = MAX_DURATION * SAMPLE_RATE
exp_num = '43_mel_mix_up_cv'

MODEL_NAME = 'convnextv2_femto.fcmae_ft_in1k' # tf_efficientnetv2_b0.in1k,tf_efficientnetv2_b3.in21k_ft_in1k (240x240),tf_efficientnetv2_s.in21k_ft_in1k (300x300),
#convnextv2_nano.fcmae_ft_in22k_in1k, convnextv2_femto.fcmae_ft_in1k, convnextv2_nano.fcmae_ft_in22k_in1k_384

########################################################
# Training hyper-parameters
batch_size = 32
learning_rate = 1E-4
scheduler_learning = True  # enable the learning-rate scheduler
weight_decay = 1E-2

num_epochs = 50
early_stopping_patience = 10
label_smoothing = 0.1
STEP_TO_DECAY = 1  # the scheduler is stepped every STEP_TO_DECAY epochs (see train_epoch)
GAMMA_TO_DECAY = 0.95  # gamma of the ExponentialLR scheduler
WEIGHT = True  # use a class-weighted training loss
f_weighted = 1.5  # multiplier applied to every class weight
AUG= True  # use the augmented preprocessing for the training split

HIDDEN_UNITS=[256,128, 64]  # widths of the hidden layers of the MLP head
DROPOUT_RATE = 0.5
SEED = 42  # seed of the StratifiedKFold split
SEED_AUG = 1234  # seed of the torch / NumPy / random generators
exp_name = f"CNNCustom_exp_{exp_num}_{MAX_DURATION}s_{ESPEC}_N_FFT_{N_FFT}_MODEL_{MODEL_NAME}_AUG_{AUG}_WEIGHT_{WEIGHT}_NMELS_{N_MELS}"
# ########################################################
torch.manual_seed(SEED_AUG)
# Set the seed for NumPy
np.random.seed(SEED_AUG)
# Set the seed for Python's random module
random.seed(SEED_AUG)
# Waveform-level augmentations, each applied with its own probability.
aug_transforms = [
                    RandomApply([Noise(min_snr=0.05, max_snr=0.1),],p=0.5,),
                    RandomApply([PitchShift(sample_rate=SAMPLE_RATE, n_samples=num_samples, pitch_shift_max=3, pitch_shift_min=-3),],p=0.3,),
                    #RandomApply([HighLowPass(sample_rate=SAMPLE_RATE) ],p=0.3,) 
                ]
###################################################################
# Metric objects shared by the training and evaluation loops; the loops reset them between phases.
metric_acc = MulticlassAccuracy(device=device)
metric_f1 = MulticlassF1Score(device=device, num_classes=NUM_CLASSES, average='macro')


# exp_name = f"from_scratch_exp_{exp_num}"
# Snapshot of the run configuration (printed per fold and logged by the experiment tracker).
params = {'epochs': num_epochs, 'lr': learning_rate, 'weight_decay': weight_decay, 
          'batch_size': batch_size, 'early_stopping_patience': early_stopping_patience,
          'label_smoothing': label_smoothing, 'num_classes': NUM_CLASSES, 'n_fft': N_FFT,
          'hop_length': HOP_LENGTH,'n_mels': N_MELS, 'max_duration': MAX_DURATION, 
          'exp_name': exp_name, 'exp_num': exp_num, 'n_layers': N, 'augmented': AUG, 'hidden_units': HIDDEN_UNITS, 
          'audios_max_duration': MAX_DURATION, 'espectogram':ESPEC, 'Freq_Mask': FREQ_MASK_PARAM, 
          'Time_Mask': TIME_MASK_PARAM, 'scheduler': scheduler_learning, 'dropout_rate': DROPOUT_RATE, 'model_name': MODEL_NAME
          , 'weighted': WEIGHT, 'img_size': IMG_SIZE, 'f_weighted': f_weighted}

class CustomImageDataset(Dataset):
    """Dataset that loads audio files listed in a DataFrame.

    Column 0 of the frame holds the audio file path and column 1 the label.
    Despite the parameter name, ``annotations_file`` is the DataFrame itself,
    not a path to one.

    Args:
        annotations_file: DataFrame with one row per audio clip.
        transform: optional callable applied to the ``torchaudio.load`` result.
        target_transform: optional callable applied to the label.
    """

    def __init__(self, annotations_file, transform=None, target_transform=None):
        self.audio_df = annotations_file
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of rows (clips) in the frame."""
        return len(self.audio_df)

    def __getitem__(self, idx):
        """Load clip ``idx`` and return ``(audio, label)`` after the optional transforms."""
        frame = self.audio_df
        # torchaudio.load yields a (waveform, sample_rate) pair; the transform receives it whole.
        sample = torchaudio.load(frame.iloc[idx, 0])
        target = frame.iloc[idx, 1]
        sample = self.transform(sample) if self.transform else sample
        target = self.target_transform(target) if self.target_transform else target
        return sample, target


def preprocess_mfcc_aug(audio):
    """Build an augmented, ImageNet-normalised MFCC image from a loaded clip.

    Steps, in order: waveform augmentation (``aug_transforms``), right
    zero-padding up to ``MAX_DURATION`` seconds, MFCC, time masking, frequency
    masking, dB conversion, min-max scaling, replication to 3 channels, bicubic
    resize to ``IMG_SIZE`` x ``IMG_SIZE`` and ImageNet mean/std normalisation.

    Args:
        audio: ``(waveform, sample_rate)`` pair as returned by ``torchaudio.load``.

    Returns:
        The 3-channel tensor produced by the final normalisation step.
    """
    waveform, sample_rate = audio

    # Waveform-level augmentation runs before padding.
    augmented = Compose(aug_transforms)(waveform)

    # Right-pad with zeros up to the target length; longer clips are not truncated.
    missing = (sample_rate * MAX_DURATION) - augmented.shape[1]
    padded = torch.nn.functional.pad(augmented, (0, max(missing, 0)))

    mel_options = {
        'n_fft': N_FFT,
        'win_length': None,
        'hop_length': HOP_LENGTH,
        'center': True,
        'pad_mode': "reflect",
        'power': 2.0,
        'norm': "slaney",
        'n_mels': N_MELS,
        'mel_scale': "htk",
    }
    to_mfcc = T.MFCC(
        sample_rate=sample_rate,
        n_mfcc=N_MFCC,
        log_mels=False,
        melkwargs=mel_options,
    )
    coefficients = to_mfcc(padded)

    # Spectrogram-level augmentation: time mask first, then frequency mask.
    mask_time = T.TimeMasking(time_mask_param=TIME_MASK_PARAM)
    mask_freq = T.FrequencyMasking(freq_mask_param=FREQ_MASK_PARAM)
    masked = mask_freq(mask_time(coefficients))

    in_db = torchaudio.transforms.AmplitudeToDB()(masked)

    # Min-max scale to [0, 1]; a constant input is only shifted to avoid dividing by zero.
    lowest = in_db.min()
    spread = in_db.max() - lowest
    if spread != 0:
        scaled = (in_db - lowest) / spread
    else:
        scaled = in_db - lowest

    # Keep the first channel only, then replicate it into three identical channels.
    grey = scaled[0] * 255
    rgb = grey.repeat(3, 1, 1)

    to_square = transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=InterpolationMode.BICUBIC, max_size=None)
    rgb_resized = to_square(rgb) / 255

    imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return imagenet_norm(rgb_resized)

def preprocess_mfcc(audio):
    """Build an ImageNet-normalised MFCC image from a loaded clip (no augmentation).

    Steps, in order: right zero-padding up to ``MAX_DURATION`` seconds, MFCC,
    dB conversion, min-max scaling, replication to 3 channels, bicubic resize
    to ``IMG_SIZE`` x ``IMG_SIZE`` and ImageNet mean/std normalisation.

    Args:
        audio: ``(waveform, sample_rate)`` pair as returned by ``torchaudio.load``.

    Returns:
        The 3-channel tensor produced by the final normalisation step.
    """
    waveform, sample_rate = audio

    # Right-pad with zeros up to the target length; longer clips are not truncated.
    missing = (sample_rate * MAX_DURATION) - waveform.shape[1]
    padded = torch.nn.functional.pad(waveform, (0, max(missing, 0)))

    mel_options = {
        'n_fft': N_FFT,
        'win_length': None,
        'hop_length': HOP_LENGTH,
        'center': True,
        'pad_mode': "reflect",
        'power': 2.0,
        'norm': "slaney",
        'n_mels': N_MELS,
        'mel_scale': "htk",
    }
    to_mfcc = T.MFCC(
        sample_rate=sample_rate,
        n_mfcc=N_MFCC,
        log_mels=False,
        melkwargs=mel_options,
    )
    in_db = torchaudio.transforms.AmplitudeToDB()(to_mfcc(padded))

    # Min-max scale to [0, 1]; a constant input is only shifted to avoid dividing by zero.
    lowest = in_db.min()
    spread = in_db.max() - lowest
    if spread != 0:
        scaled = (in_db - lowest) / spread
    else:
        scaled = in_db - lowest

    # Keep the first channel only, then replicate it into three identical channels.
    grey = scaled[0] * 255
    rgb = grey.repeat(3, 1, 1)

    to_square = transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=InterpolationMode.BICUBIC, max_size=None)
    rgb_resized = to_square(rgb) / 255

    imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return imagenet_norm(rgb_resized)
    
    

def preprocess_data_train_aug(audio):
    """Build an augmented, ImageNet-normalised log-mel image from a loaded clip.

    Steps, in order: waveform augmentation (``aug_transforms``), right
    zero-padding up to ``MAX_DURATION`` seconds, mel spectrogram, time masking,
    frequency masking, dB conversion (``top_db=80``), min-max scaling,
    replication to 3 channels, optional bicubic resize (only when ``RESIZE``
    is true) and ImageNet mean/std normalisation.

    Args:
        audio: ``(waveform, sample_rate)`` pair as returned by ``torchaudio.load``.

    Returns:
        The 3-channel tensor produced by the final normalisation step.
    """
    waveform, sample_rate = audio

    # Waveform-level augmentation runs before padding.
    augmented = Compose(aug_transforms)(waveform)

    # Right-pad with zeros up to the target length; longer clips are not truncated.
    missing = (sample_rate * MAX_DURATION) - augmented.shape[1]
    padded = torch.nn.functional.pad(augmented, (0, max(missing, 0)))

    to_mel = T.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=N_FFT,
        win_length=None,
        hop_length=HOP_LENGTH,
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm="slaney",
        n_mels=N_MELS,
        mel_scale="htk",
    )
    mel = to_mel(padded)

    # Spectrogram-level augmentation: time mask first, then frequency mask.
    mask_time = T.TimeMasking(time_mask_param=TIME_MASK_PARAM)
    mask_freq = T.FrequencyMasking(freq_mask_param=FREQ_MASK_PARAM)
    masked = mask_freq(mask_time(mel))

    in_db = torchaudio.transforms.AmplitudeToDB(top_db=80)(masked)

    # Min-max scale to [0, 1]; a constant input is only shifted to avoid dividing by zero.
    lowest = in_db.min()
    spread = in_db.max() - lowest
    if spread != 0:
        scaled = (in_db - lowest) / spread
    else:
        scaled = in_db - lowest

    # Keep the first channel only, then replicate it into three identical channels.
    rgb = (scaled[0] * 255).repeat(3, 1, 1)

    if RESIZE:
        to_square = transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=InterpolationMode.BICUBIC, max_size=None)
        rgb = to_square(rgb) / 255
    else:
        rgb = rgb / 255

    imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return imagenet_norm(rgb)

def preprocess_data(audio):
    """Build an ImageNet-normalised log-mel image from a loaded clip (no augmentation).

    Steps, in order: right zero-padding up to ``MAX_DURATION`` seconds, mel
    spectrogram, dB conversion (``top_db=80``), min-max scaling, replication
    to 3 channels, optional bicubic resize (only when ``RESIZE`` is true) and
    ImageNet mean/std normalisation.

    Args:
        audio: ``(waveform, sample_rate)`` pair as returned by ``torchaudio.load``.

    Returns:
        The 3-channel tensor produced by the final normalisation step.
    """
    waveform, sample_rate = audio

    # Right-pad with zeros up to the target length; longer clips are not truncated.
    missing = (sample_rate * MAX_DURATION) - waveform.shape[1]
    padded = torch.nn.functional.pad(waveform, (0, max(missing, 0)))

    to_mel = T.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=N_FFT,
        win_length=None,
        hop_length=HOP_LENGTH,
        center=True,
        pad_mode="reflect",
        power=2.0,
        norm="slaney",
        n_mels=N_MELS,
        mel_scale="htk",
    )
    in_db = torchaudio.transforms.AmplitudeToDB(top_db=80)(to_mel(padded))

    # Min-max scale to [0, 1]; a constant input is only shifted to avoid dividing by zero.
    lowest = in_db.min()
    spread = in_db.max() - lowest
    if spread != 0:
        scaled = (in_db - lowest) / spread
    else:
        scaled = in_db - lowest

    # Keep the first channel only, then replicate it into three identical channels.
    rgb = (scaled[0] * 255).repeat(3, 1, 1)

    if RESIZE:
        to_square = transforms.Resize((IMG_SIZE, IMG_SIZE), interpolation=InterpolationMode.BICUBIC, max_size=None)
        rgb = to_square(rgb) / 255
    else:
        rgb = rgb / 255

    imagenet_norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    return imagenet_norm(rgb)


class MLPModel(torch.nn.Module):
    """Fully connected classification head.

    Every hidden width in ``hidden_units`` contributes a
    Linear -> BatchNorm1d -> ReLU -> Dropout group; a final Linear layer maps
    the last hidden width to ``NUM_CLASSES`` outputs (raw logits).

    Args:
        input_size: number of input features.
        hidden_units: non-empty sequence with the width of each hidden layer.
        dropout_rate: dropout probability used after every hidden layer.
    """

    def __init__(self, input_size, hidden_units, dropout_rate):
        super().__init__()
        self.hidden_units = hidden_units
        self.dropout_rate = dropout_rate
        self.layers = torch.nn.ModuleList()

        # Each hidden layer takes the previous width as its fan-in,
        # starting from the raw input size.
        fan_in = input_size
        for width in hidden_units:
            self.layers.extend([
                torch.nn.Linear(fan_in, width),
                torch.nn.BatchNorm1d(width),
                torch.nn.ReLU(),
                torch.nn.Dropout(dropout_rate),
            ])
            fan_in = width

        # Output layer producing one logit per class.
        self.layers.append(torch.nn.Linear(hidden_units[-1], NUM_CLASSES))

    def forward(self, x):
        """Apply the layers in sequence and return the logits."""
        out = x
        for module in self.layers:
            out = module(out)
        return out

class CustomConvNeXt(nn.Module):
    """timm backbone (``MODEL_NAME``) with a partially unfrozen tail and an MLP head.

    The backbone is created without its classifier (``num_classes=0``) and
    with average pooling. All backbone parameters are frozen, then the last
    ``N`` parameter tensors whose name does not contain ``'bn'`` are made
    trainable again. An ``MLPModel`` head maps the pooled features to class
    logits.

    Args:
        N: number of trailing backbone parameter tensors to unfreeze.
    """

    def __init__(self, N=0):
        super().__init__()

        # Backbone without classifier head.
        self.pretrained_model = timm.create_model(
            MODEL_NAME, pretrained=PRETRAINED, num_classes=0, global_pool='avg'
        )
        self.n_layers = N

        # Start from a fully frozen backbone.
        for _, weight in self.pretrained_model.named_parameters():
            weight.requires_grad = False

        # Walk the parameters from the end and re-enable gradients until the budget
        # is used up. The budget counts parameter tensors (weights and biases
        # separately), not whole layers.
        budget = self.n_layers
        for name, weight in reversed(list(self.pretrained_model.named_parameters())):
            if budget <= 0:
                break
            if 'bn' in name:
                continue
            weight.requires_grad = True
            budget -= 1

        # Classification head on top of the pooled backbone features.
        self.additional_layer = MLPModel(
            self.pretrained_model.num_features,
            hidden_units=HIDDEN_UNITS,
            dropout_rate=DROPOUT_RATE,
        )

    def forward(self, x):
        """Run the backbone, then the MLP head, and return the logits."""
        features = self.pretrained_model(x)
        return self.additional_layer(features)


def convert_labels_to_binary(labels):
    """Return the index of the largest entry along dimension 1 of ``labels``.

    Turns soft / one-hot label rows into hard class indices. Despite the
    name, no threshold is applied and the result is not restricted to two
    classes.
    """
    return labels.argmax(dim=1)

def train_epoch(model,epoch,  train_loader, criterion, optimizer, scheduler=None):
    """Train ``model`` for one epoch and return ``(loss, accuracy, macro_f1)``.

    The module-level ``metric_acc`` / ``metric_f1`` objects are updated after
    every batch and are NOT reset here; the caller resets them once it has
    read the returned values.

    Args:
        model: network to train; expected to already live on ``device``.
        epoch: zero-based epoch index (used for the progress bar and the
            scheduler cadence).
        train_loader: iterable of ``(inputs, labels)`` batches. Labels may be
            soft / one-hot rows (e.g. produced by MixUp) or 1-D class indices.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer updating the model parameters.
        scheduler: optional LR scheduler, stepped every ``STEP_TO_DECAY`` epochs.

    Returns:
        Tuple ``(mean loss per sample, accuracy, macro F1)`` for the epoch.
    """
    model.train()  # enable dropout / batch-norm updates
    running_loss_train = 0.0
    tepoch = tqdm(train_loader, desc=f"Training {epoch+1}/{num_epochs}")
    for inputs, labels in tepoch:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Weight the batch loss by its size so the epoch mean is per sample.
        running_loss_train += loss.item() * inputs.size(0)

        # The metrics need hard class indices. Soft / one-hot labels are reduced
        # with an argmax; labels that are already 1-D class indices are used as-is
        # (an unconditional argmax over dim 1 would raise on them).
        if labels.dim() > 1:
            hard_labels = convert_labels_to_binary(labels)
        else:
            hard_labels = labels
        metric_acc.update(outputs, hard_labels)
        metric_f1.update(outputs, hard_labels)
        tepoch.set_postfix(loss=loss.item(), acc=metric_acc.compute().item(), f1=metric_f1.compute().item())

    # Step the scheduler once every STEP_TO_DECAY epochs.
    if scheduler and (epoch+1) % STEP_TO_DECAY == 0:
        scheduler.step()

    train_epoch_loss = running_loss_train / len(train_loader.dataset)
    train_acc = metric_acc.compute().item()
    train_f1 = metric_f1.compute().item()
    return train_epoch_loss, train_acc, train_f1


# Load the train / test tables from parquet and report their shapes.
df_audio_train = pd.read_parquet('/home/gass/audio-sensitive-content-detection/data/dfTrainVideo.parquet')
df_audio_test = pd.read_parquet('/home/gass/audio-sensitive-content-detection/data/dfTestVideo.parquet')
print(f'Training shape: {df_audio_train.shape}, Testing shape: {df_audio_test.shape}')

# Training function
def train_model(fold, model, train_loader, val_loader, criterion, criterion_val,optimizer, scheduler=None, num_epochs=10, save_path=None, early_stopping_patience=3, min_delta = 0.001, val = 'val'):
    """Train one cross-validation fold with validation-based checkpointing and early stopping.

    After every epoch the model is evaluated on ``val_loader``. Whenever the
    validation macro F1 improves by more than ``min_delta`` a checkpoint is
    written to ``save_path`` and the patience counter is reset; otherwise the
    counter is decremented and training stops when it reaches zero.

    Args:
        fold: fold identifier, used only in log messages.
        model: network to train; expected to already live on ``device``.
        train_loader: training batches, forwarded to ``train_epoch``.
        val_loader: validation batches of ``(inputs, class-index labels)``.
        criterion: training loss.
        criterion_val: validation loss.
        optimizer: optimizer updating the model parameters.
        scheduler: optional LR scheduler, forwarded to ``train_epoch``.
        num_epochs: maximum number of epochs to run.
        save_path: file the best checkpoint is written to and re-read from.
        early_stopping_patience: non-improving epochs tolerated before stopping.
        min_delta: minimum F1 gain that counts as an improvement.
        val: unused; kept so existing callers keep working.

    Returns:
        Dict with the epoch number and the train / validation loss, accuracy
        and F1 stored in the checkpoint at ``save_path``.
    """
    best_f1 = 0.0
    patience = early_stopping_patience

    # Honour the num_epochs argument (the loop used to read the global params dict,
    # so the argument only affected the log messages).
    for epoch in range(num_epochs):
        train_epoch_loss, train_acc, train_f1 = train_epoch(model, epoch, train_loader, criterion, optimizer, scheduler)
        print(f"Epoch [{epoch+1}/{num_epochs}], Fold: {fold}, Loss: {train_epoch_loss:.4f},  Acc: {train_acc:.4f}, F1-Macro: {train_f1:.4f}")

        # train_epoch leaves the training statistics in the shared metrics.
        metric_acc.reset()
        metric_f1.reset()

        # Evaluation step: no gradients are needed, so skip building the autograd graph.
        model.eval()
        running_loss_val = 0.0
        tepoch = tqdm(val_loader, desc=f"Val {epoch+1}/{num_epochs}")
        with torch.no_grad():
            for inputs, labels in tepoch:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion_val(outputs, labels)
                # Weight the batch loss by its size so the epoch mean is per sample.
                running_loss_val += loss.item() * inputs.size(0)

                metric_acc.update(outputs, labels)
                metric_f1.update(outputs, labels)
                tepoch.set_postfix(val_loss=loss.item(), val_acc=metric_acc.compute().item(), val_f1=metric_f1.compute().item())

        val_epoch_loss = running_loss_val / len(val_loader.dataset)
        val_acc = metric_acc.compute().item()
        val_f1 = metric_f1.compute().item()
        print(f"Epoch [{epoch+1}/{num_epochs}], Fold: {fold}, Val_loss: {val_epoch_loss:.4f},  Val_acc: {val_acc:.4f}, Val_f1-Macro: {val_f1:.4f}")

        # Clear the metrics for the next training epoch.
        metric_acc.reset()
        metric_f1.reset()

        # Checkpoint when the validation F1 improves by more than min_delta.
        if (val_f1 - best_f1) > min_delta:
            print("Improved val F1 from {:.4f} to {:.4f}. Saving model...".format(best_f1, val_f1))
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_epoch_loss,
                'val_loss': val_epoch_loss,
                'train_acc': train_acc,
                'val_acc': val_acc,
                'val_f1': val_f1,
                'train_f1': train_f1,
            }, save_path)
            best_f1 = val_f1
            patience = early_stopping_patience
        else:
            patience -= 1
            if patience == 0:
                print("Early stopping.")
                break
    print('Finished Training')

    # Report the metrics of the best checkpoint. The model object itself keeps the
    # weights of the last epoch; the checkpoint weights are not loaded back.
    # NOTE(review): if no epoch ever improves, nothing is written here and this
    # reads whatever file already exists at save_path (or fails if none does).
    dict_model = torch.load(save_path)
    dict_results = {'epoch':dict_model['epoch'],'train_loss': dict_model['train_loss'], 'val_loss':dict_model['val_loss'], 'train_acc':dict_model['train_acc'],
                    'val_acc':dict_model['val_acc'], 'val_f1':dict_model['val_f1'], 'train_f1':dict_model['train_f1']}
    return  dict_results
    
    
    
def train_model_for_testing( live, model, train_loader, test_loader, criterion, criterion_test,optimizer, scheduler=None, num_epochs=10, save_path=None, 
                early_stopping_patience=3, min_delta = 0.003, val = 'val'):
    """Train on the full training set while tracking metrics on the test loader.

    After every epoch the model is evaluated on ``test_loader`` and the
    train / test metrics are logged through ``live``. Whenever the test macro
    F1 improves by more than ``min_delta`` a checkpoint is written to
    ``save_path`` and the patience counter is reset; otherwise the counter is
    decremented and training stops when it reaches zero.

    Args:
        live: experiment logger exposing ``log_metric(name, value)`` and
            ``next_step()``.
        model: network to train; expected to already live on ``device``.
        train_loader: training batches, forwarded to ``train_epoch``.
        test_loader: evaluation batches of ``(inputs, class-index labels)``.
        criterion: training loss.
        criterion_test: evaluation loss.
        optimizer: optimizer updating the model parameters.
        scheduler: optional LR scheduler, forwarded to ``train_epoch``.
        num_epochs: maximum number of epochs to run.
        save_path: file the best checkpoint is written to and re-read from.
        early_stopping_patience: non-improving epochs tolerated before stopping.
        min_delta: minimum F1 gain that counts as an improvement.
        val: unused; kept so existing callers keep working.

    Returns:
        Dict with the epoch index and the train / test loss, accuracy and F1
        stored in the checkpoint at ``save_path``.
    """
    best_f1 = 0.0
    patience = early_stopping_patience

    # Honour the num_epochs argument (the loop used to read the global params dict,
    # so the argument only affected the log messages).
    for epoch in range(num_epochs):
        train_epoch_loss, train_acc, train_f1 = train_epoch(model, epoch, train_loader, criterion, optimizer, scheduler)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_epoch_loss:.4f},  Acc: {train_acc:.4f}, F1-Macro: {train_f1:.4f}")

        live.log_metric("train/loss", train_epoch_loss)
        live.log_metric("train/acc", train_acc)
        live.log_metric("train/f1", train_f1)

        # train_epoch leaves the training statistics in the shared metrics.
        metric_acc.reset()
        metric_f1.reset()

        # Evaluation step: no gradients are needed, so skip building the autograd graph.
        model.eval()
        running_loss_test = 0.0
        tepoch = tqdm(test_loader, desc=f"Test {epoch+1}/{num_epochs}")
        with torch.no_grad():
            for inputs, labels in tepoch:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion_test(outputs, labels)
                # Weight the batch loss by its size so the epoch mean is per sample.
                running_loss_test += loss.item() * inputs.size(0)

                metric_acc.update(outputs, labels)
                metric_f1.update(outputs, labels)
                tepoch.set_postfix(test_loss=loss.item(), test_acc=metric_acc.compute().item(), test_f1=metric_f1.compute().item())

        test_epoch_loss = running_loss_test / len(test_loader.dataset)
        test_acc = metric_acc.compute().item()
        test_f1 = metric_f1.compute().item()
        print(f"Epoch [{epoch+1}/{num_epochs}], test_loss: {test_epoch_loss:.4f},  test_acc: {test_acc:.4f}, test_f1-Macro: {test_f1:.4f}")

        live.log_metric("test/loss", test_epoch_loss)
        live.log_metric("test/acc", test_acc)
        live.log_metric("test/f1", test_f1)

        # Clear the metrics for the next training epoch.
        metric_acc.reset()
        metric_f1.reset()

        # Checkpoint when the test F1 improves by more than min_delta.
        # NOTE(review): the stored 'epoch' is zero-based here; confirm that
        # consumers of the result expect that.
        if (test_f1 - best_f1) > min_delta:
            print("Improved test F1 from {:.4f} to {:.4f}. Saving model...".format(best_f1, test_f1))
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_epoch_loss,
                'test_loss': test_epoch_loss,
                'train_acc': train_acc,
                'test_acc': test_acc,
                'test_f1': test_f1,
                'train_f1': train_f1,
            }, save_path)
            best_f1 = test_f1
            patience = early_stopping_patience
        else:
            patience -= 1
            if patience == 0:
                print("Early stopping.")
                break
        live.next_step()  # finish the epoch in the logger
    print('Finished Training')

    # Report the metrics of the best checkpoint. The model object itself keeps the
    # weights of the last epoch; the checkpoint weights are not loaded back.
    dict_model = torch.load(save_path)
    dict_results = {'epoch':dict_model['epoch'],'train_loss': dict_model['train_loss'], 'test_loss':dict_model['test_loss'], 'train_acc':dict_model['train_acc'],
                    'test_acc':dict_model['test_acc'], 'test_f1':dict_model['test_f1'], 'train_f1':dict_model['train_f1']}
    return  dict_results
# save_path = f'/home/gass/audio-sensitive-content-detection/models/reprocess_data_3_chans_history_torch_baseline_convnext_femto_{IMG_SIZE}_custom_exp_{exp_num}.pth'
# save_dict_results = f'/home/gass/audio-sensitive-content-detection/models/reprocess_data_3_chans_history_torch_baseline_convnext_femto_{IMG_SIZE}_custom_exp_{exp_num}.pkl'

# Output files: the best-model checkpoint (.pth) and the pickled per-fold results (.pkl).
save_path = f'/home/gass/audio-sensitive-content-detection/models/{ESPEC}_preprocess_data_3_chans_history_torch_baseline_convnext_femto_{IMG_SIZE}_custom_exp_{exp_num}.pth'
save_dict_results = f'/home/gass/audio-sensitive-content-detection/models/{ESPEC}_preprocess_data_3_chans_history_torch_baseline_convnext_femto_{IMG_SIZE}_custom_exp_{exp_num}.pkl'

# 5-fold stratified split with a fixed seed, plus module-level accumulators for the results.
skf= StratifiedKFold(n_splits=5, random_state=SEED, shuffle=True)
history_list = []
val_acc = []
val_f1 = []

test_acc = []
test_f1 = []

from torch.utils.data import default_collate
# MixUp transform applied to whole batches by collate_fn.
mixup = v2.MixUp(alpha=0.2, num_classes=NUM_CLASSES)

def collate_fn(batch):
    """Collate ``batch`` with the default collate function, then apply MixUp to the result."""
    collated = default_collate(batch)
    return mixup(*collated)

def cross_validation():
    """Run stratified K-fold cross-validation and pickle the per-fold results.

    For every fold: build the train / validation datasets with the
    preprocessing selected by ``ESPEC`` and ``AUG``, create a fresh
    ``CustomConvNeXt`` model, train it with ``train_model`` and collect the
    metrics of the best checkpoint. The per-fold results are appended to the
    module-level ``history_list``, ``val_acc`` and ``val_f1`` lists, and
    ``history_list`` is pickled to ``save_dict_results`` once all folds ran.

    Raises:
        ValueError: if ``ESPEC`` is neither ``'mfcc'`` nor ``'mel'``, or if the
            labels of a training fold are not exactly ``0 .. NUM_CLASSES - 1``.
    """
    # NOTE(review): df_video_train is not defined in the visible part of this file
    # at module level - confirm it is created before this function is called.
    for fold, (train_index, val_index) in enumerate(skf.split(df_video_train, df_video_train['video_label'])):

        # Training and validation frames for this fold.
        df_train_audios = df_video_train.iloc[train_index]
        df_val_audios = df_video_train.iloc[val_index]

        # Select the preprocessing: augmentation is only ever applied to the training split.
        if ESPEC == 'mfcc':
            eval_transform = preprocess_mfcc
            train_transform = preprocess_mfcc_aug if AUG else preprocess_mfcc
            print('MFCC ESPECTOGRAM EXPERIMENT AUGMENTED' if AUG else 'MFCC ESPECTOGRAM EXPERIMENT NO AUGMENTED')
        elif ESPEC == 'mel':
            eval_transform = preprocess_data
            train_transform = preprocess_data_train_aug if AUG else preprocess_data
            print('MEL ESPECTOGRAM EXPERIMENT AUGMENTED' if AUG else 'MEL ESPECTOGRAM EXPERIMENT NO AUGMENTED')
        else:
            raise ValueError(f"Unsupported ESPEC value {ESPEC!r}; expected 'mfcc' or 'mel'")

        ds_train_audios = CustomImageDataset(df_train_audios, transform=train_transform)
        ds_val_audios = CustomImageDataset(df_val_audios, transform=eval_transform)

        # Only the training loader applies MixUp (through collate_fn).
        dataloader_train = DataLoader(ds_train_audios, batch_size=batch_size, shuffle=True, num_workers=16, collate_fn=collate_fn)
        dataloader_val = DataLoader(ds_val_audios, batch_size=batch_size, shuffle=True, num_workers=16)

        # Fresh model for every fold.
        model = CustomConvNeXt(N=N)
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f'Params: {params}')
        print(f"Num. of trainable parameters: {total_params}")

        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay= weight_decay)
        if scheduler_learning:
            scheduler = ExponentialLR(optimizer, gamma=GAMMA_TO_DECAY)
        else:
            scheduler = None

        model.to(device)

        # Inverse-frequency class weights. CrossEntropyLoss expects weight[i] to belong
        # to class index i, so the counts must be ordered by label; value_counts()
        # orders by frequency, which is why the index is sorted explicitly.
        label_counts = df_train_audios['label'].value_counts().sort_index()
        if list(label_counts.index) != list(range(NUM_CLASSES)):
            raise ValueError(
                f"Fold {fold}: expected labels 0..{NUM_CLASSES - 1} in the training split, "
                f"found {list(label_counts.index)}"
            )
        total_samples = sum(label_counts)
        weights = [total_samples / count for count in label_counts]

        # Scale the weights by the boosting factor.
        weights_list = [weight * f_weighted for weight in weights]
        print(weights_list)
        if WEIGHT:
            print('WEIGHTED LOSS')
            weight_classes = torch.tensor(weights_list).float().to(device)
            criterion = torch.nn.CrossEntropyLoss(weight=weight_classes, label_smoothing=label_smoothing)
        else:
            criterion = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
        # Validation loss is always plain cross-entropy (no weights, no smoothing).
        criterion_val = torch.nn.CrossEntropyLoss()

        # NOTE(review): every fold writes its checkpoint to the same save_path, so only
        # the checkpoint of the last fold survives on disk.
        history = train_model(fold, model, dataloader_train, dataloader_val, criterion, criterion_val,optimizer, scheduler, num_epochs=num_epochs, 
                              early_stopping_patience=early_stopping_patience,
                              save_path=save_path)

        history_list.append(history)
        value_val = history['val_acc']
        value_val_f1 = history['val_f1']
        print(f'Fold {fold} finished, Val_accuracy: {value_val}, Val_f1: {value_val_f1}')
        val_acc.append(history['val_acc'])
        val_f1.append(history['val_f1'])

    # Summarise all folds and save history_list to a file with pickle.
    print('Mean val_acc:', np.mean(val_acc), 'Mean val_f1:', np.mean(val_f1))
    with open(save_dict_results, 'wb') as f:
        pickle.dump(history_list, f)

    print(history_list)
            
        
def _select_audios_for_videos(audio_df, video_names):
    """Return the rows of ``audio_df`` whose ``filename`` contains a video name.

    Names are matched as literal substrings: each one is escaped before being
    joined into the alternation pattern, so regex metacharacters in a name
    (``.``, ``(``, ``+`` ...) cannot alter or break the match.

    Args:
        audio_df: DataFrame with a string ``filename`` column.
        video_names: Iterable of video names to look for.

    Returns:
        The matching subset of ``audio_df``; an empty frame when
        ``video_names`` is empty.
    """
    import re

    names = [str(name) for name in video_names]
    if not names:
        # An empty alternation would match every row instead of none.
        return audio_df.iloc[0:0]
    # NOTE(review): this is still a substring match, so a name that is a
    # prefix of another (e.g. "video1" / "video10") selects both -- confirm
    # the filename scheme rules this out, otherwise train/test can overlap.
    pattern = '|'.join(re.escape(name) for name in names)
    mask = audio_df['filename'].str.contains(pattern, regex=True, na=False)
    return audio_df[mask]


def _class_weights_by_index(labels, scale=1.0):
    """Compute inverse-frequency class weights ordered by class index.

    ``CrossEntropyLoss(weight=...)`` reads position ``i`` of the weight tensor
    as the weight of class ``i``, so the counts are sorted by label value
    rather than by frequency (the default order of ``value_counts``).

    Args:
        labels: pandas Series of class labels.
        scale: Factor every weight is multiplied by.

    Returns:
        List with ``total_samples / class_count * scale`` for each class
        present in ``labels``, in ascending label order.
    """
    counts = labels.value_counts().sort_index()
    total_samples = counts.sum()
    return [total_samples / count * scale for count in counts]


def testing():
    """Train on the training videos and evaluate on the held-out test videos.

    Reads the train/test video lists from parquet, selects the matching rows
    of the module-level ``df_audio``, builds the datasets for the spectrogram
    type chosen by ``ESPEC`` (augmented for training when ``AUG`` is set),
    trains a ``CustomConvNeXt`` through ``train_model_for_testing`` and logs
    the resulting metrics to the dvclive ``Live`` session.

    The returned history is appended to ``history_list`` and its
    ``test_acc`` / ``test_f1`` entries to ``test_acc`` / ``test_f1``; these
    names, like the hyper-parameters used here, are resolved from module
    scope. ``history_list`` is finally pickled to ``save_dict_results``.
    """
    with Live(report="notebook", exp_name=exp_name) as live:
        df_video_train = pd.read_parquet('/home/gass/audio-sensitive-content-detection/data/dfTrainVideo.parquet')
        df_video_test = pd.read_parquet('/home/gass/audio-sensitive-content-detection/data/dfTestVideo.parquet')
        live.log_params(params)

        # Video names that define the training and test splits.
        train_video_names = df_video_train['video_name'].tolist()
        test_video_names = df_video_test['video_name'].tolist()

        # Keep only the audio rows that belong to each split.
        df_train_audios = _select_audios_for_videos(df_audio, train_video_names)
        df_test_audios = _select_audios_for_videos(df_audio, test_video_names)

        # Pick the transforms for the selected spectrogram type. The augmented
        # transform is only looked up when AUG is set, and the test set always
        # uses the non-augmented transform.
        if ESPEC == 'mfcc':
            spec_name = 'MFCC'
            eval_transform = preprocess_mfcc
            train_transform = preprocess_mfcc_aug if AUG else eval_transform
        elif ESPEC == 'mel':
            spec_name = 'MEL'
            eval_transform = preprocess_data
            train_transform = preprocess_data_train_aug if AUG else eval_transform
        else:
            spec_name = 'MEL/MFCC/LFCC'
            eval_transform = preprocess_mel_mfcc_lfcc
            train_transform = preprocess_mel_mfcc_lfcc_aug if AUG else eval_transform
        print(f"{spec_name} ESPECTOGRAM EXPERIMENT {'AUGMENTED' if AUG else 'NO AUGMENTED'}")

        ds_train_audios = CustomImageDataset(df_train_audios, transform=train_transform)
        ds_test_audios = CustomImageDataset(df_test_audios, transform=eval_transform)

        dataloader_train = DataLoader(ds_train_audios, batch_size=batch_size, shuffle=True,
                                      num_workers=16, collate_fn=collate_fn)
        # Evaluation does not need shuffling; a fixed order keeps runs comparable.
        dataloader_test = DataLoader(ds_test_audios, batch_size=batch_size, shuffle=False,
                                     num_workers=16)

        # Build the model and report its number of trainable parameters.
        model = CustomConvNeXt(N=N)
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        print(f"Num. of trainable parameters: {total_params}")
        print(params)

        # Optimizer and optional learning-rate schedule.
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
        if scheduler_learning:
            scheduler = ExponentialLR(optimizer, gamma=GAMMA_TO_DECAY)
        else:
            scheduler = None

        model.to(device)

        # Inverse-frequency class weights scaled by f_weighted, ordered by
        # class index so each weight lines up with its class in the loss.
        weights_list = _class_weights_by_index(df_train_audios['label'], f_weighted)
        print(weights_list)

        # The test loss is always unweighted and without label smoothing.
        criterion_test = torch.nn.CrossEntropyLoss()
        if WEIGHT:
            print('WEIGHTED LOSS')
            weight_classes = torch.tensor(weights_list).float().to(device)
            criterion = torch.nn.CrossEntropyLoss(weight=weight_classes, label_smoothing=label_smoothing)
        else:
            criterion = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)

        history = train_model_for_testing(live, model, dataloader_train, dataloader_test,
                                          criterion, criterion_test, optimizer, scheduler,
                                          num_epochs=num_epochs,
                                          early_stopping_patience=early_stopping_patience,
                                          save_path=save_path)

        history_list.append(history)
        value_test = history['test_acc']
        value_test_f1 = history['test_f1']
        print(f'Finished, test_acc: {value_test}, test_f1: {value_test_f1}')
        test_acc.append(value_test)
        test_f1.append(value_test_f1)
        live.log_metric("test/mean_test_acc", np.mean(test_acc))
        live.log_metric("test/mean_test_f1", np.mean(test_f1))
        print('Mean test_acc:', np.mean(test_acc), 'Mean test_f1:', np.mean(test_f1))

        # Persist every collected history with pickle.
        with open(save_dict_results, 'wb') as f:
            pickle.dump(history_list, f)

        print(history_list)
                
                
# Entry point of the script: only the cross-validation experiment is run.
# The train/test evaluation call is left disabled; uncomment it to run it.
#testing()
cross_validation()