In [1]:
import os
import warnings

# Run from one directory above the kernel's starting directory, presumably the
# repository root, so that the local `source` package imported in the next cell resolves.
# PROJECT_ROOT is computed only on the first run of this cell, so re-running the cell
# returns to the same directory instead of climbing one level higher each time
# (which a bare os.chdir("../") would do).
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath(os.pardir)
os.chdir(PROJECT_ROOT)

# NOTE(review): this silences *every* warning for the whole kernel session, including
# deprecation warnings; consider narrowing it with category=/module= filters.
warnings.filterwarnings("ignore")

In [2]:
import random

import torch
import torch.optim as optim
from torch.nn import TripletMarginLoss
from torch.utils.data import DataLoader
from tqdm import tqdm

from source.dataloader import RandomTripletLossDataset, collate_triplet_fn
from source.Model import SpeakerClassifier
from source.Frontend import MFCCTransform
from source.extraction_utils.get_label_files import get_label_files


# Pick the compute device once; later cells move the model and batches onto `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print("CUDA is available! Training on GPU...")
else:
    print("CUDA is not available. Training on CPU...")

def train_model(epochs, dataloader, model, loss_function, optimizer, device):
    """Train ``model`` with a triplet objective for ``epochs`` passes over ``dataloader``.

    Each batch yielded by ``dataloader`` is unpacked as ``(anchors, positives, negatives)``.
    All three are moved to ``device``, passed through the same ``model``, and scored with
    ``loss_function(anchor_outputs, positive_outputs, negative_outputs)``.

    Args:
        epochs: Number of full passes over ``dataloader``.
        dataloader: Iterable of ``(anchors, positives, negatives)`` batches of tensors.
        model: Module applied to each of the three inputs; switched to training mode here.
        loss_function: Callable taking the three model outputs and returning a scalar
            loss tensor (e.g. ``torch.nn.TripletMarginLoss``).
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batch tensors are moved to.

    Returns:
        list[float]: The average batch loss of each epoch, in epoch order.

    Raises:
        ValueError: If ``dataloader`` yields no batches during an epoch, since no
            average loss can be computed.
    """
    model.train()
    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        # Count processed batches rather than trusting len(dataloader), so the average
        # is correct for any iterable and an empty loader is detected explicitly.
        num_batches = 0
        # Context manager closes the bar even if a batch raises (e.g. CUDA OOM).
        with tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}") as progress_bar:
            for anchors, positives, negatives in progress_bar:
                anchors = anchors.to(device)
                positives = positives.to(device)
                negatives = negatives.to(device)

                optimizer.zero_grad()
                anchor_outputs = model(anchors)
                positive_outputs = model(positives)
                negative_outputs = model(negatives)

                loss = loss_function(anchor_outputs, positive_outputs, negative_outputs)
                loss.backward()
                optimizer.step()

                # .item() copies the value to the host (a device sync); do it once per batch.
                batch_loss = loss.item()
                running_loss += batch_loss
                num_batches += 1
                progress_bar.set_postfix(loss=batch_loss)

        if num_batches == 0:
            raise ValueError(f"dataloader yielded no batches in epoch {epoch+1}")
        avg_loss = running_loss / num_batches
        epoch_losses.append(avg_loss)
        print(f'Epoch {epoch+1}, Average Loss: {avg_loss:.4f}')
    return epoch_losses

  from .autonotebook import tqdm as notebook_tqdm


CUDA is available! Training on GPU...


In [None]:
# Dataset switches forwarded to get_label_files(); only the BSI genuine subset is enabled.
LABEL_FILE_FLAGS = {
    "use_bsi_tts": False,
    "use_bsi_vocoder": False,
    "use_bsi_vc": False,
    "use_bsi_genuine": True,
    "use_bsi_ttsvctk": False,
    "use_bsi_ttslj": False,
    "use_bsi_ttsother": False,
    "use_bsi_vocoderlj": False,
    "use_wavefake": False,
    "use_LibriSeVoc": False,
    "use_lj": False,
    "use_asv2019": False,
}

(
    labels_text_path_list_train,
    labels_text_path_list_dev,
    labels_text_path_list_test,
    all_datasets_used,
) = get_label_files(**LABEL_FILE_FLAGS)

In [None]:
# Reproducibility: seed the RNGs that dataset sampling, shuffling and weight init can use.
# NOTE(review): if RandomTripletLossDataset samples with numpy.random, seed that as well.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and CUDA generators (model weight init)

BATCH_SIZE = 32
LEARNING_RATE = 1e-3
TRIPLET_MARGIN = 1.0
N_MFCC = 13  # presumably the feature count MFCCTransform emits per frame — TODO confirm

audio_dataset = RandomTripletLossDataset(labels_text_path_list_train, frontend=MFCCTransform)
audio_dataloader = DataLoader(
    audio_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_triplet_fn,
    # Dedicated generator: shuffle order no longer depends on how many other torch
    # random ops ran earlier in the kernel session.
    generator=torch.Generator().manual_seed(SEED),
)
model = SpeakerClassifier(input_size=N_MFCC, device=device)
model.to(device)

# Optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
triplet_loss = TripletMarginLoss(margin=TRIPLET_MARGIN, p=2)

In [None]:
# Training run: number of passes over audio_dataloader.
NUM_EPOCHS = 2

train_model(
    epochs=NUM_EPOCHS,
    dataloader=audio_dataloader,
    model=model,
    loss_function=triplet_loss,
    optimizer=optimizer,
    device=device,
)

In [None]:
# Release cached GPU memory and restart peak-memory tracking. Skipped on CPU-only
# machines, where the reset call would otherwise try to initialize CUDA and raise.
if torch.cuda.is_available():
    # empty_cache() returns only *unused* cached blocks to the driver; tensors that are
    # still referenced (model, optimizer state) stay allocated.
    torch.cuda.empty_cache()
    # Reset after emptying the cache so the new baseline reflects the released memory.
    # This also resets max_memory_allocated(); the separate reset_max_memory_allocated()
    # call is deprecated and redundant.
    torch.cuda.reset_peak_memory_stats()