In [None]:
import logging
import os
import warnings

# Name of this experiment; reused for the log file name and the MLflow run name.
MODEL_NAME = "MFCC_ECAPA-TDNN_Genuine_Random"

# Work from the parent of the directory the notebook was started in
# (presumably the project root, so the `source` package imported in a later
# cell resolves -- confirm against the repository layout).
# The starting directory is captured only the first time this cell runs, so
# re-running the cell always lands in the same place instead of climbing one
# more level on every execution, as a bare relative chdir would.
if "NOTEBOOK_START_DIR" not in globals():
    NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_START_DIR, os.pardir))

# Suppress all Python warnings.
warnings.filterwarnings("ignore")

# Send INFO-and-above records to "<MODEL_NAME>.log"; the relative file name is
# resolved against the working directory selected above.
logging.basicConfig(filename=MODEL_NAME + '.log',
                    level=logging.INFO,
                    format='%(asctime)s - %(message)s')
logger = logging.getLogger()

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.nn import TripletMarginLoss
import torch.optim as optim
from source.dataloader import RandomTripletLossDataset, collate_triplet_fn
from source.Model import SpeakerClassifier
from source.Frontend import MFCCTransform
from source.extraction_utils.get_label_files import get_label_files
from tqdm.notebook import tqdm
import time
import mlflow
import mlflow.pytorch


# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU,
# and record the choice in the log.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(
    "CUDA is available! Training on GPU..."
    if device.type == "cuda"
    else "CUDA is not available. Training on CPU..."
)

def train_model(epochs, dataloader, model, loss_function, optimizer, device):
    """Train ``model`` on triplet batches and track the run with MLflow.

    One MLflow run named ``MODEL_NAME`` is opened for the whole call. Inside it
    the function logs the hyper-parameters (``epochs``, ``batch_size`` and the
    class names of the model, loss and optimizer), the per-epoch ``avg_loss``
    and ``epoch_time`` metrics, the overall ``total_training_time`` and finally
    the model itself under the artifact path ``"model"``. Per-batch and
    per-epoch timings are also written to the module-level ``logger``.

    Args:
        epochs: Number of passes to make over ``dataloader``.
        dataloader: Sized iterable exposing ``batch_size`` and yielding
            ``(anchors, positives, negatives)`` batches whose items support
            ``.to(device)``.
        model: Callable network; invoked separately on the anchor, positive
            and negative batches. Put into training mode via ``model.train()``.
        loss_function: Callable taking the three model outputs (anchor,
            positive, negative) and returning a loss that supports
            ``.backward()`` and ``.item()``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        device: Device the batches are moved to before the forward pass.

    Returns:
        None.

    Raises:
        ValueError: If ``dataloader`` contains no batches.
    """
    # Fail fast, before an MLflow run is opened: with zero batches the
    # per-epoch average below would divide by zero part-way through a run.
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader yields no batches; cannot train on an empty dataset.")

    with mlflow.start_run(run_name=MODEL_NAME):
        mlflow.log_param("epochs", epochs)
        mlflow.log_param("batch_size", dataloader.batch_size)
        mlflow.log_param("model", model.__class__.__name__)
        mlflow.log_param("loss_function", loss_function.__class__.__name__)
        mlflow.log_param("optimizer", optimizer.__class__.__name__)
        model.train()
        total_start_time = time.time()  # Start timing the whole training process
        for epoch in range(epochs):
            epoch_start_time = time.time()  # Start timing the epoch
            running_loss = 0.0
            progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
            for anchors, positives, negatives in progress_bar:
                batch_start_time = time.time()  # Start timing the batch

                anchors = anchors.to(device)
                positives = positives.to(device)
                negatives = negatives.to(device)

                optimizer.zero_grad()

                # Time the three forward passes.
                # NOTE(review): this is plain wall-clock time with no device
                # synchronisation; on CUDA, where kernels are launched
                # asynchronously, it may under-report the real compute time.
                forward_start_time = time.time()
                anchor_outputs = model(anchors)
                positive_outputs = model(positives)
                negative_outputs = model(negatives)
                forward_end_time = time.time()

                loss = loss_function(anchor_outputs, positive_outputs, negative_outputs)
                loss.backward()
                optimizer.step()

                # Read the scalar once and reuse it for both the running sum
                # and the progress bar.
                batch_loss = loss.item()
                running_loss += batch_loss
                progress_bar.set_postfix(loss=batch_loss)

                batch_end_time = time.time()
                logger.info(f"Batch processed in {batch_end_time - batch_start_time:.4f} seconds.")
                logger.info(f"Forward pass took {forward_end_time - forward_start_time:.4f} seconds.")

            avg_loss = running_loss / num_batches
            epoch_end_time = time.time()
            # Metrics are keyed by the zero-based epoch index.
            mlflow.log_metric("avg_loss", avg_loss, step=epoch)
            mlflow.log_metric("epoch_time", epoch_end_time - epoch_start_time, step=epoch)
            logger.info(f"Epoch {epoch+1} completed in {epoch_end_time - epoch_start_time:.4f} seconds. Average Loss: {avg_loss:.4f}")
            # '\r' keeps the summary on a single, continuously overwritten line.
            print(f'Epoch {epoch+1}, Average Loss: {avg_loss:.4f}', end='\r')
        total_end_time = time.time()
        mlflow.log_metric("total_training_time", total_end_time - total_start_time)
        logger.info(f"Training completed in {total_end_time - total_start_time:.4f} seconds.")
        print()  # Finish the line left open by the '\r' prints above.

        # Log the model
        mlflow.pytorch.log_model(model, "model")

In [None]:
# Resolve the label files for the train/dev/test splits, together with the
# list of datasets in use. Only `use_bsi_genuine` is switched on for this run.
(
    labels_text_path_list_train,
    labels_text_path_list_dev,
    labels_text_path_list_test,
    all_datasets_used,
) = get_label_files(
    # Enabled
    use_bsi_genuine=True,
    # Disabled: remaining BSI subsets
    use_bsi_tts=False,
    use_bsi_vocoder=False,
    use_bsi_vc=False,
    use_bsi_ttsvctk=False,
    use_bsi_ttslj=False,
    use_bsi_ttsother=False,
    use_bsi_vocoderlj=False,
    # Disabled: other corpora
    use_wavefake=False,
    use_LibriSeVoc=False,
    use_lj=False,
    use_asv2019=False,
)

In [None]:
# Loss: triplet margin loss using the p=2 norm and a margin of 1.0.
triplet_loss = TripletMarginLoss(margin=1.0, p=2)

# Data: random triplets built from the training label files, with MFCCTransform
# as the frontend, served in shuffled batches of 32.
audio_dataset = RandomTripletLossDataset(
    labels_text_path_list_train,
    frontend=MFCCTransform,
)
audio_dataloader = DataLoader(
    audio_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_triplet_fn,
)

# Model: created with input_size=13 and placed on the selected device.
model = SpeakerClassifier(input_size=13, device=device)
model.to(device)

# Optimizer: Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Kick off a 2-epoch training run with the objects configured above.
train_model(
    epochs=2,
    dataloader=audio_dataloader,
    model=model,
    loss_function=triplet_loss,
    optimizer=optimizer,
    device=device,
)

In [None]:
# Release cached GPU memory and restart peak-memory tracking.
# Guarded so the cell is a no-op on CPU-only machines, which this notebook
# explicitly supports when choosing `device`.
if torch.cuda.is_available():
    # Free the cached blocks first so the peak counters restart from the
    # post-cleanup state.
    torch.cuda.empty_cache()
    # A single call resets every peak statistic; the deprecated
    # torch.cuda.reset_max_memory_allocated() merely forwards to this function.
    torch.cuda.reset_peak_memory_stats()