In [None]:
# Experiment configuration: every value a reader might tune lives in this cell.
import random

import numpy as np
import torch

# Run identifier; used as the log-file name and passed to ModelTrainer.
MODEL = "WavLM-Base-frozen_ECAPA-TDNN_Genuine_Random"

# Optimisation / loss
LEARNING_RATE = 0.001
MARGIN = 1  # margin of TripletMarginLoss
NORM = 2    # p of the p-norm used by TripletMarginLoss (2 = Euclidean distance)

# Training loop
BATCH_SIZE = 8
EPOCHS = 4
VALIDATION_RATE = 1  # forwarded as ModelTrainer(validation_rate=...); presumably "validate every N epochs" — TODO confirm

# Reproducibility: triplet sampling, DataLoader shuffling and weight init are all random.
# Seed every RNG before any dataset/model is built. GPU kernels (e.g. cuDNN) may still be
# non-deterministic unless torch.use_deterministic_algorithms(True) is also enabled.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

In [None]:
# Handle warnings
# NOTE(review): this silences *every* warning (deprecations, numerical issues, audio-loading
# problems). Consider narrowing it to specific categories/modules once the noisy ones are known.
import warnings
warnings.filterwarnings("ignore")

# Logging to log file
import logging
from pathlib import Path

LOG_DIR = Path("../logs")
# logging.FileHandler does not create missing parent directories.
LOG_DIR.mkdir(parents=True, exist_ok=True)
# force=True removes handlers left over from a previous execution of this cell (or installed by
# an already-imported library). Without it basicConfig is a silent no-op on re-run, and a changed
# MODEL would keep logging into the previous run's file.
logging.basicConfig(filename=LOG_DIR / f"{MODEL}.log",
                    level=logging.INFO,
                    format='%(asctime)s - %(message)s',
                    force=True)
logger = logging.getLogger()

# MLFlow configuration
import mlflow
mlflow.set_tracking_uri("../mlruns")
# Hide mlflow's requirement-inference chatter below ERROR level.
logging.getLogger('mlflow.utils.requirements_utils').setLevel(logging.ERROR)

# Get device
from utils import get_device, load_genuine_dataset, ModelTrainer
device = get_device(logger)

# Imports
import torch.optim as optim
from torch.nn import TripletMarginLoss
from torch.utils.data import DataLoader
from dataloader import ValidationDataset, RandomTripletLossDataset, collate_triplet_wav_fn, collate_valid_fn
from models import WavLM_Base_frozen_ECAPA_TDNN

In [None]:
# Load the genuine-only label splits; test_labels is not used in this notebook.
train_labels, dev_labels, test_labels = load_genuine_dataset()


def identity_frontend(x):
    """Return the input unchanged (no feature extraction in the data pipeline).

    Shared by the training and validation datasets. Presumably the model consumes raw
    waveforms and extracts features itself — TODO confirm against models.py.
    """
    return x


# Training: random triplets, shuffled every epoch.
audio_dataset = RandomTripletLossDataset(train_labels, frontend=identity_frontend, logger=logger)
audio_dataloader = DataLoader(audio_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_triplet_wav_fn)

# Validation: a fixed order is enough for evaluation and keeps validation runs deterministic
# and comparable across epochs/runs, so no shuffling here.
validation_dataset = ValidationDataset(dev_labels, frontend=identity_frontend, logger=logger)
validation_dataloader = DataLoader(validation_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_valid_fn)

In [None]:
# Build the frozen-WavLM-Base + ECAPA-TDNN model and place it on the selected device.
model = WavLM_Base_frozen_ECAPA_TDNN(device=device)
model.to(device)

# Adam receives everything from model.parameters(); parameters that never get a gradient
# (presumably the frozen backbone) are simply skipped by Adam's update step.
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# Triplet objective: distances use the NORM-norm, and positives must beat negatives by MARGIN.
triplet_loss = TripletMarginLoss(p=NORM, margin=MARGIN)

In [None]:
# Run training: hand all components to the project's ModelTrainer and train for EPOCHS epochs.
trainer = ModelTrainer(
    model,
    audio_dataloader,       # random-triplet training batches
    validation_dataloader,  # dev-split batches
    device,
    triplet_loss,
    optimizer,
    logger,
    MODEL,                  # run identifier
    validation_rate=VALIDATION_RATE,
)
trainer.train_model(EPOCHS)