In [1]:
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import eval_metrics as em
import wandb

### Configurations

In [2]:
run = wandb.init(
    project = "teamlab_deepfake",
    config={
        "learning_rate": 0.01,
        "model": "LSTM",    #NOTE: set manually
        "dataset": "ASVSpoof19_LA_original",    #NOTE: set manually
        "feature": "PITCH_HNR_scaled",   #NOTE: set manually
        "epochs": 5,
        "batch_size": 32,
        "input_dim": 2,
        "bidirectional": False,
        "hidden_dim": 128,
        "n_layers":1,
        "dropout_rate": 0,
        "loss_function": "weighted_CE",     #NOTE: set manually
    },
    name = "test-run",     #NOTE: set manually
    notes = None,
    tags = ["multi_feature","HNR", "PITCH", "scaling"],
)

config = run.config

[34m[1mwandb[0m: Currently logged in as: [33mqianyue[0m ([33mqianyue-university-of-stuttgart[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


### Padding and Data Loader

In [3]:
PITCH_COLUMN = 'PITCH'
HNR_COLUMN = 'HNR'
AUDIO_ID_COLUMN = 'AUDIO_ID'
NAN_REPLACEMENT_VALUE = 0.0  
PADDING_VALUE = 0.0         
LABEL_BONAFIDE = 1
LABEL_SPOOF = 0


# --- Load Labels from Text File ---
def load_labels_from_file(label_file_path):
    """
    Returns:
        dict: A dictionary mapping AUDIO_ID (str) to numerical label (int).
              e.g., {'LA_T_9351820': 1, 'LA_T_1004644': 0}
    """
    labels_map = {}
    try:
        with open(label_file_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                audio_id = parts[1]
                label_str = parts[-1].lower() # Get the last part as label
                
                if label_str == 'bonafide':
                    labels_map[audio_id] = LABEL_BONAFIDE
                elif label_str == 'spoof':
                    labels_map[audio_id] = LABEL_SPOOF

    except FileNotFoundError:
        print(f"Error: Label file not found at {label_file_path}")
    return labels_map

# --- Process pitch sequences and match them with labels ---
class PitchHNRDataset(Dataset):
    def __init__(self, dataframe, pitch_col, hnr_col, audio_id_col, label_file_path, nan_replacement=NAN_REPLACEMENT_VALUE):
        
        labels_map = load_labels_from_file(label_file_path)
        
        self.processed_features = []
        self.labels = []
        
        print(f"Attempting to match {len(dataframe)} entries from DataFrame with labels from '{label_file_path}'...")
        found_count = 0
        # Iterate through the DataFrame and match with loaded labels
        for index, row in dataframe.iterrows():    
            audio_id = row[audio_id_col]
            if audio_id in labels_map:
                pitch_sequence_raw = row[pitch_col]
                processed_pitch = np.nan_to_num(pitch_sequence_raw, nan=nan_replacement)
                
                hnr_sequence_raw = row[hnr_col]
                processed_hnr = np.nan_to_num(hnr_sequence_raw, nan=nan_replacement)

                ### NOTE:need to pad the two sequences to the same length
                max_length = max(len(processed_pitch), len(processed_hnr))
                if len(processed_pitch) > len(processed_hnr):
                    padding = np.zeros(max_length - len(processed_hnr), dtype=processed_hnr.dtype)
                    processed_hnr = np.concatenate((processed_hnr, padding))
                else:
                    padding = np.zeros(max_length - len(processed_pitch), dtype=processed_pitch.dtype)
                    processed_pitch = np.concatenate((processed_pitch, padding))

                combined_features = np.stack((processed_pitch, processed_hnr), axis=-1) 
                self.processed_features.append(torch.tensor(combined_features, dtype=torch.float32))

                self.labels.append(labels_map[audio_id])
                
                found_count += 1
        
        if not self.processed_features:
            raise ValueError("No samples were successfully matched and processed. Check your AUDIO_IDs and label file.")

        self.labels = torch.tensor(self.labels, dtype=torch.long) # Assuming labels are integers for classification
        print(f"Successfully matched and processed {found_count} samples out of {len(dataframe)} DataFrame entries.")


    def __len__(self):
        """Returns the total number of matched samples in the dataset."""
        return len(self.processed_features)

    def __getitem__(self, idx):
        """
        Returns one sample from the dataset: a preprocessed pitch sequence and its label.
        """
        feature_sequence = self.processed_features[idx]
        label = self.labels[idx]
        return feature_sequence, label

# --- Custom Collate Function for Dynamic Padding  ---
def collate_fn(batch, padding_value=PADDING_VALUE):
    """
    Pads sequences within a batch to the same length.
    """
    sequences = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    lengths = torch.tensor([len(seq) for seq in sequences], dtype=torch.long)
    padded_sequences = pad_sequence(sequences, batch_first=True, padding_value=padding_value)
    labels = torch.stack(labels)
    if padded_sequences.ndim == 2:
        padded_sequences = padded_sequences.unsqueeze(2)
    return padded_sequences, lengths, labels


In [4]:
### For remote server
train_features_path = '/home/users1/liqe/TeamLab_phonetics/prosody_features_train_scaled.parquet'
dev_features_path = '/home/users1/liqe/TeamLab_phonetics/prosody_features_dev_scaled.parquet'

df_train = pd.read_parquet(train_features_path, engine='pyarrow')
df_dev = pd.read_parquet(dev_features_path, engine='pyarrow')

### For local
# train_features_path = r'C:\Users\ivyap\Desktop\25SU\TEAMLAB\prosody_features\prosody_features_train.parquet'
# dev_features_path = r'C:\Users\ivyap\Desktop\25SU\TEAMLAB\prosody_features\prosody_features_dev.parquet'

# df_train = pd.read_pickle(train_features_path)
# df_dev = pd.read_pickle(dev_features_path)

In [5]:
### for remote
labels_train = '/home/users1/liqe/TeamLab_phonetics/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt'
labels_dev = '/home/users1/liqe/TeamLab_phonetics/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt'

### For local 
# labels_train = r'C:\Users\ivyap\Desktop\25SU\TEAMLAB\LA\ASVspoof2019_LA_cm_protocols\ASVspoof2019.LA.cm.train.trn.txt'
# labels_dev = r'C:\Users\ivyap\Desktop\25SU\TEAMLAB\LA\ASVspoof2019_LA_cm_protocols\ASVspoof2019.LA.cm.dev.trl.txt'


# print(df.loc[0, PITCH_COLUMN])
# print(type(df.loc[0, PITCH_COLUMN]))

In [6]:
pitch_dataset_train = PitchHNRDataset(dataframe=df_train,
                                     pitch_col=PITCH_COLUMN,
                                     hnr_col=HNR_COLUMN,
                                     audio_id_col=AUDIO_ID_COLUMN,
                                     label_file_path=labels_train,
                                     nan_replacement=NAN_REPLACEMENT_VALUE)

train_dataloader = DataLoader(
    pitch_dataset_train, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn
)

pitch_dataset_dev = PitchHNRDataset(dataframe=df_dev,
                                     pitch_col=PITCH_COLUMN,
                                     hnr_col=HNR_COLUMN,
                                     audio_id_col=AUDIO_ID_COLUMN,
                                     label_file_path=labels_dev,
                                     nan_replacement=NAN_REPLACEMENT_VALUE)

dev_dataloader = DataLoader(
    pitch_dataset_dev, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn
)

## For inspection
for i, batch_data in enumerate(train_dataloader):
    # batch_data is a tuple: (padded_sequences, lengths, labels)
    batch_sequences, batch_lengths, batch_labels = batch_data
    print(f"\n--- Batch {i+1} ---")
    print(f"  Padded Sequences Shape: {batch_sequences.shape}")
    print(f"  Original Lengths (first 5): {batch_lengths[:5]}")
    print(f"  Labels (first 5): {batch_labels[:5]}")
    

    if i == 0: # Break after the first batch for inspection
        break


Attempting to match 25379 entries from DataFrame with labels from '/home/users1/liqe/TeamLab_phonetics/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt'...
Successfully matched and processed 25379 samples out of 25379 DataFrame entries.
Attempting to match 24986 entries from DataFrame with labels from '/home/users1/liqe/TeamLab_phonetics/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt'...
Successfully matched and processed 24844 samples out of 24986 DataFrame entries.

--- Batch 1 ---
  Padded Sequences Shape: torch.Size([32, 792, 2])
  Original Lengths (first 5): tensor([652, 380, 193, 210, 232])
  Labels (first 5): tensor([0, 0, 0, 0, 1])


### Finding the weight (for weighted cross entropy)

is there different ways calculating weitghs?

In [7]:
labels = load_labels_from_file(labels_train)
total = len(labels)
count_bonafide = class_count = sum(1 for value in labels.values() if value == LABEL_BONAFIDE)
count_spoof =  total - count_bonafide
weight_bonafide = total / (count_bonafide * 2)
weight_spoof = total / (count_spoof * 2)

### LSTM classifier

In [8]:
class LSTM_Classifier(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers,
                 bidirectional, dropout):

        super().__init__()
        
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.n_layers = n_layers
        self.bidirectional = bidirectional
        
        # 1. LSTM Layer
        self.lstm = nn.LSTM(input_dim, 
                           hidden_dim, 
                           num_layers=n_layers, 
                           bidirectional=bidirectional, 
                           dropout=dropout if n_layers > 1 else 0,
                           batch_first=True) # Input/output tensors are (batch, seq, feature)
        
        # 2. Fully Connected Layer (Linear Layer)
        self.fc = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
        
        # 3. Dropout Layer (for regularization on the output of LSTM)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, sequences, sequence_lengths):
      
        # 1. Pack sequence
        packed_input = rnn_utils.pack_padded_sequence(sequences, sequence_lengths.cpu(), batch_first=True, enforce_sorted=False)
        
        # 2. Pass packed sequence through LSTM
        packed_output, (hidden, cell) = self.lstm(packed_input)
        
        # 3. Concatenate the final forward and backward hidden states (if bidirectional)
        if self.lstm.bidirectional:
            hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))
        else:
            hidden = self.dropout(hidden[-1,:,:])
         
        # 4. Pass the processed hidden state through the fully connected layer
        output = self.fc(hidden)
        
        return output

### Initiate the model

In [9]:
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

CUDA available: True
CUDA device count: 4


In [10]:
class_weights = torch.tensor([weight_bonafide, weight_spoof], dtype=torch.float32).to(DEVICE)

In [12]:
model = LSTM_Classifier(input_dim=config.input_dim, hidden_dim=config.hidden_dim, output_dim=2, n_layers=config.n_layers,
                 bidirectional=config.bidirectional, dropout=config.dropout_rate).to(DEVICE)
criterion = torch.nn.CrossEntropyLoss(reduction='mean', weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

### Evaluation

In [13]:
def evaluate_classifier(data_loader, model, criterion):

    model.eval()  # Set the model to evaluation mode (disables dropout, etc.)
    
    total_loss = 0.0
    correct_predictions = 0
    total_samples = 0

    scores_bonafide = []
    scores_spoof = []

    with torch.no_grad():  # Disable gradient calculations during evaluation
        for batch_sequences, batch_lengths, batch_labels in data_loader:
            
            batch_sequences = batch_sequences.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            # Forward pass: Get model outputs (logits)
            logits = model(batch_sequences, batch_lengths)
            
            # Calculate loss for the current batch
            loss = criterion(logits, batch_labels)
            total_loss += loss.item() * batch_sequences.size(0) # Accumulate loss, weighted by batch size

            # for EER
            probabilities = torch.softmax(logits, dim=1)
            
            for i in range(len(batch_labels)):
                current_label = batch_labels[i]
                current_score = probabilities[i]

                if current_label == LABEL_BONAFIDE:
                    scores_bonafide.append(current_score[LABEL_BONAFIDE].cpu())     ## numpy is cpu only, need to move tensor from gpu
                elif current_label == LABEL_SPOOF:
                    scores_spoof.append(current_score[LABEL_BONAFIDE].cpu())
            
            # Compare predictions with true labels
            total_samples += batch_labels.size(0) # Count number of samples in this batch

    average_loss = total_loss / total_samples if total_samples > 0 else 0.0

    scores_bonafide_np = np.array(scores_bonafide)    
    scores_spoof_np = np.array(scores_spoof)
    eer, threshold = em.compute_eer(scores_bonafide_np, scores_spoof_np)
    
    return average_loss, eer, threshold

### The training loop

In [14]:
def train_model(model, train_dataloader, dev_dataloader, criterion, optimizer, num_epochs, device,
                min_eer, best_model_filename):

    print(f"Training started on device: {device}")
    model.to(device) # Ensure model is on the correct device

    # Initial metric dictionary for the progress bar
    metric_dict = {'train_loss': 'N/A', 'val_loss': 'N/A', 'val_eer': 'N/A', 'val_threshold': 'N/A'}

    # Evaluate on validation set first to get a baseline
    print("Evaluating on validation set before training...")
    model.eval() # Set model to evaluation mode
    val_loss_initial, val_eer_initial, threshold_initial = evaluate_classifier(dev_dataloader, model, criterion)
    metric_dict.update({'val_loss': f'{val_loss_initial:.3f}', 'val_eer': f'{val_eer_initial*100:.2f}%', 'val_threshold': f'{threshold_initial*100:.2f}%'})
    print(f"Initial Validation - Loss: {val_loss_initial:.4f}, EER: {val_eer_initial*100:.2f}%, Threshold: {threshold_initial*100:.2f}%")

    # Progress bar setup
    total_steps = num_epochs * len(train_dataloader)
    pbar = tqdm(total=total_steps, initial=0, postfix=metric_dict, unit="batch")

    for epoch in range(num_epochs):
        model.train()  # Set the model to training mode (enables dropout, etc.)
        pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")
        
        running_train_loss = 0.0
        num_train_batches = 0

        for batch_sequences, batch_lengths, batch_labels in train_dataloader:
            # Move data to the specified device
            batch_sequences = batch_sequences.to(device)
            # batch_lengths are used by pack_padded_sequence which expects them on CPU
            batch_labels = batch_labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass: Get model outputs (logits)
            logits = model(batch_sequences, batch_lengths)
            
            # Calculate loss
            loss = criterion(logits, batch_labels)
            
            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Update statistics for progress bar and logging
            running_train_loss += loss.item()
            num_train_batches += 1
            
            pbar.update(1) # Increment progress bar by one batch
            metric_dict.update({'train_loss': f'{loss.item():.3f}'}) # Current batch loss
            pbar.set_postfix(metric_dict)
        
        # Calculate average training loss for the epoch
        avg_epoch_train_loss = running_train_loss / num_train_batches if num_train_batches > 0 else 0.0
        metric_dict.update({'train_loss': f'{avg_epoch_train_loss:.3f}'}) # Average epoch loss
        
        # Evaluate on validation set after each epoch
        avg_val_loss, val_eer, val_threshold = evaluate_classifier(dev_dataloader, model, criterion)
        
        metric_dict.update({'val_loss': f'{avg_val_loss:.3f}', 'val_eer': f'{val_eer*100:.2f}%', 'val_threshold': f'{val_threshold*100:.2f}%'})
        pbar.set_postfix(metric_dict) # Update with latest validation metrics
        
        # Optional: Print epoch summary
        print(f"\nEpoch {epoch+1} Summary: Avg Train Loss: {avg_epoch_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, EER: {val_eer*100:.2f}%, Threshold: {val_threshold*100:.2f}%")

        # log the train\dev loss and the eer & threshold
        run.log({"train_loss": avg_epoch_train_loss, "dev_loss": avg_val_loss, 
                   "dev_eer": val_eer, "dev_threshold":val_threshold, "epoch": epoch + 1})
        
        ### update min eer and optimal model
        if val_eer < min_eer:
            min_eer = val_eer
            torch.save(model.state_dict(), best_model_filename)
            print(f"Epoch {epoch+1}: New best model saved to '{best_model_filename}' with EER: {min_eer:.4f}")

            run.summary['best_validation_eer'] = min_eer
            run.summary['best_eer_epoch'] = epoch + 1
            run.summary['validation_loss_at_best_eer'] = avg_val_loss

    pbar.close()
    print("Training finished.")
    return min_eer

In [15]:
NUM_EPOCHS = config.epochs
min_eer = float('inf')
best_model_filename = 'best_model'  #tbc

min_eer = train_model(model, train_dataloader, dev_dataloader, criterion, optimizer, NUM_EPOCHS, DEVICE,
            min_eer, best_model_filename)

Training started on device: cuda
Evaluating on validation set before training...
Initial Validation - Loss: 0.6642, EER: 77.90%, Threshold: 48.50%


  0%|          | 0/3970 [00:00<?, ?batch/s, train_loss=N/A, val_eer=77.90%, val_loss=0.664, val_threshold=48.5…


Epoch 1 Summary: Avg Train Loss: 0.0727, Val Loss: 0.0702, EER: 91.76%, Threshold: 1.23%
Epoch 1: New best model saved to 'best_model' with EER: 0.9176

Epoch 2 Summary: Avg Train Loss: 0.0704, Val Loss: 0.0704, EER: 55.46%, Threshold: 1.55%
Epoch 2: New best model saved to 'best_model' with EER: 0.5546

Epoch 3 Summary: Avg Train Loss: 0.0701, Val Loss: 0.0703, EER: 63.90%, Threshold: 1.09%

Epoch 4 Summary: Avg Train Loss: 0.0705, Val Loss: 0.0700, EER: 57.80%, Threshold: 1.11%

Epoch 5 Summary: Avg Train Loss: 0.0702, Val Loss: 0.0707, EER: 51.14%, Threshold: 1.71%
Epoch 5: New best model saved to 'best_model' with EER: 0.5114
Training finished.


### Save the model

In [16]:
if min_eer != float('inf'):
    print(f"Logging the best model ({best_model_filename}) to W&B Artifacts...")
    best_model_artifact = wandb.Artifact(
        name=f"{run.id}-best-model", # Using run ID for uniqueness
        type="model",
        description=f"Best model according to EER ({min_eer:.4f}) achieved at epoch {run.summary.get('best_eer_epoch', 'N/A')}.",
        metadata={"best_eer": min_eer, "epoch_of_best_eer": run.summary.get('best_eer_epoch', 'N/A')}
    )
    best_model_artifact.add_file(best_model_filename) # Add the saved file
    wandb.run.log_artifact(best_model_artifact, aliases=["best_eer_model"]) # Add an alias
    print("Best model logged as W&B Artifact.")
else:
    print("No model was saved as best_eer did not improve from its initial value.")

run.finish()

print("W&B run finished.")

Logging the best model (best_model) to W&B Artifacts...
Best model logged as W&B Artifact.


0,1
dev_eer,█▂▃▂▁
dev_loss,▃▅▄▁█
dev_threshold,▃▆▁▁█
epoch,▁▃▅▆█
train_loss,█▂▁▂▁

0,1
best_eer_epoch,5.0
best_validation_eer,0.51139
dev_eer,0.51139
dev_loss,0.07067
dev_threshold,0.01712
epoch,5.0
train_loss,0.07022
validation_loss_at_best_eer,0.07067


W&B run finished.
