In [1]:
# import os
# os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm.auto import tqdm
import pandas as pd
import numpy as np
import eval_metrics as em
import wandb
from imblearn.over_sampling import RandomOverSampler
from collections import Counter
from sklearn.metrics import classification_report

### Configurations

In [2]:
# Start a Weights & Biases run. Every hyperparameter of the experiment lives in the
# `config` dict below and is read back through `config` (= run.config) by later cells.
run = wandb.init(
    project = "teamlab_deepfake",
    name = "Training_08_CNN",     # run name, pattern: Training_XX
    notes = None,
    tags = ["ALL_FEATURE", "COMBINED_MODEL", "HNR", "PITCH", "JITTER&SHIMMER", "MFCC"],
    config={
        # NOTE: the following descriptive fields must be set manually for each run
        "model": "CNN_classifier",   # one of: SpoofEnsemble/LSTM_FFN_classifier/CNN_classifier/SpoofEnsemble_attention
        "dataset": "ASVSpoof19_LA",    
        "feature": "MFCC&Prosody",
        "attack_type": "all",   # one of: all/A01/A02/A03/A04/A05/A06 (used to filter df_train/df_dev)
        "loss_function": "weighted_CE",
        # optimisation settings
        # NOTE(review): train_model builds a ReduceLROnPlateau scheduler unconditionally;
        # confirm that this flag is actually checked before the scheduler is stepped.
        "scheduler": False,
        "scheduler_factor": 0.5,
        "scheduler_patience": 4,
        "epochs": 70,
        "batch_size": 32,
        "oversampling": True,   # if True, the training set is resampled with RandomOverSampler
        "learning_rate": 5e-4,
        "dropout_rate": 0.3,
        # lstm layer
        "lstm_input_dim": 2,    # two stacked frame-level features: pitch and HNR
        "lstm_hidden_dim": 64,
        "bidirectional": False,
        "lstm_n_layers":1,
        # ffn layer
        "ffn_dims": [11, 64], # layer widths, first = input size (jitter+shimmer vector), last = output size
        # cnn layer
        "cnn_channels": [1, 32, 64, 128],   # channel sizes, first = input (single-channel MFCC map), last = output
        "conv_kernel": (3,3),
        "pool_kernel": (2,2),
        "cnn_padding": 1,
        # random seeds
        "seeds": [0,7,42]
    },
)

# Shorthand used by every later cell to read hyperparameters
config = run.config

[34m[1mwandb[0m: Currently logged in as: [33mqianyue[0m ([33mqianyue-university-of-stuttgart[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


> NOTE: attack types are evenly distributed in the training and dev datasets, and each attack type has more samples than the genuine (bona fide) voices, so no further balancing is needed.

In [3]:
# Column names of the feature DataFrames
PITCH_COLUMN = 'PITCH'
HNR_COLUMN = 'HNR'
JITTER_COLUMN = 'JITTER'
SHIMMER_COLUMN = 'SHIMMER'
MFCC_COLUMN = 'MFCC'
LABEL_COLUMN = 'LABEL'      
                           
NAN_REPLACEMENT_VALUE = 0.0  # value substituted for NaN entries inside feature sequences
PADDING_VALUE = 0.0         # value used when padding sequences to a common length within a batch
# Integer class indices; these are also the indices into the model's 2-way output
LABEL_BONAFIDE = 1
LABEL_SPOOF = 0

# NOTE(review): absolute, user-specific paths; consider a configurable data directory
train_features_path = '/home/users1/liqe/TeamLab_phonetics/merged_train_com.pkl'
dev_features_path = '/home/users1/liqe/TeamLab_phonetics/merged_dev_com.pkl'

# NOTE(review): unpickling executes arbitrary code; only load pickle files from a trusted source
df_train = pd.read_pickle(train_features_path)
df_dev = pd.read_pickle(dev_features_path)

# NOTE: if training on a specific attack type, keep only that attack plus the
# rows whose ATTACK_TYPE is '-' (the non-attack rows).
VALID_ATTACK_TYPES = ("A01", "A02", "A03", "A04", "A05", "A06")

if config.attack_type == "all":
    # keep every attack type
    pass
elif config.attack_type in VALID_ATTACK_TYPES:
    df_train = df_train[df_train['ATTACK_TYPE'].isin([config.attack_type,'-'])]
    df_dev = df_dev[df_dev['ATTACK_TYPE'].isin([config.attack_type,'-'])]
else:
    # Previously this branch was unreachable, so an unknown attack type silently
    # filtered out every spoofed row. Now the data is left unfiltered and the
    # problem is reported.
    print("WARNING: invalid attack type.")

# # inspect
# print(df_train.head())
# print(df_train.groupby('ATTACK_TYPE').count())

# print("\n")
# print(df_dev.head())
# print(df_dev.groupby('ATTACK_TYPE').count())

#### Set the random seeds for replicability

In [4]:
# def set_seed(seed):
#     torch.manual_seed(seed)
#     torch.cuda.manual_seed_all(seed)
#     np.random.seed(seed)
#     random.seed(seed)
#     torch.backends.cudnn.deterministic = True
#     torch.backends.cudnn.benchmark = False

### Training Data Oversampling

> NOTE: training audio and labels are matched; the dev set originally was not (solved: the excess rows were deleted beforehand).

In [5]:
# Optionally rebalance the training set by randomly duplicating rows with RandomOverSampler.
# NOTE: this cell overwrites `df_train`, so re-running it without reloading the data
# resamples the already-resampled frame.
if config.oversampling:
    # Split into features (every column except the label) and the label itself
    X = df_train.drop('LABEL', axis=1)
    y = df_train['LABEL']

    # Seeded with the first configured seed so the resampling is repeatable
    over = RandomOverSampler(random_state=config.seeds[0])
    X_resampled_np, y_resampled_np = over.fit_resample(X, y) 

    # Re-wrap the resampler output so column names and the label name are preserved
    X_resampled_df = pd.DataFrame(X_resampled_np, columns=X.columns)
    y_resampled_series = pd.Series(y_resampled_np, name=y.name)

    print("\nResampled X (DataFrame) head:")
    print(X_resampled_df.head())
    print("\nResampled y (Series) head:")
    print(y_resampled_series.head())
    print("\nResampled class distribution (from y_resampled_series):")
    print(Counter(y_resampled_series))

    # Re-attach the label column; from here on df_train is the resampled training set
    df_train = pd.concat([X_resampled_df, y_resampled_series], axis=1)

    print("\nCombined Resampled DataFrame head:")
    print(df_train.head())
    print("\nCombined Resampled DataFrame info:")
    df_train.info()
    print("\nCombined Resampled DataFrame class distribution:")
    print(Counter(df_train['LABEL'])) # Verify target column in the new DataFrame


Resampled X (DataFrame) head:
       AUDIO_ID ATTACK_TYPE  \
0  LA_T_1000137         A04   
1  LA_T_1000406           -   
2  LA_T_1000648         A01   
3  LA_T_1000824         A04   
4  LA_T_1001074         A03   

                                               PITCH  \
0  [nan, nan, nan, nan, nan, nan, nan, nan, nan, ...   
1  [nan, nan, nan, nan, nan, nan, nan, nan, nan, ...   
2  [nan, nan, nan, nan, nan, 0.35835335, 0.350411...   
3  [nan, nan, nan, nan, nan, nan, nan, nan, nan, ...   
4  [nan, nan, nan, nan, nan, nan, nan, nan, nan, ...   

                                                 HNR  \
0  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
1  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
2  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7987432, 0.79...   
3  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   
4  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...   

                                              JITTER  \
0  [0.30073947, 0.23022015, 0.26707897, 0.2538646...

### Padding and Data Loader

In [6]:
class ASVDataset(Dataset):
    """Dataset that converts every usable DataFrame row into tensors at construction time.

    For each row whose label is not NaN the following are stored (same index in every list):
      * ``labels``             - the label (converted to one ``torch.long`` tensor at the end),
      * ``processed_pitchhnr`` - float32 tensor stacking pitch and HNR along the last axis,
      * ``global_features``    - float32 tensor with jitter values followed by shimmer values,
      * ``processed_mfcc``     - float32 tensor of the transposed MFCC matrix.

    Rows with a NaN label are skipped.
    """

    def __init__(self, dataframe, pitch_col, hnr_col, jitter_col, shimmer_col, mfcc_col, label_col, nan_replacement=NAN_REPLACEMENT_VALUE):
        """
        Args:
            dataframe: DataFrame with one row per utterance.
            pitch_col, hnr_col: names of the columns holding the frame-level sequences.
            jitter_col, shimmer_col: names of the columns holding the utterance-level values.
            mfcc_col: name of the column holding the MFCC matrix.
            label_col: name of the label column.
            nan_replacement: value substituted for NaN entries in pitch/HNR/jitter/shimmer.
        """
        self.labels = []
        self.processed_pitchhnr = []
        self.global_features = []
        self.processed_mfcc = []

        print(f"Attempting to process {len(dataframe)} entries from DataFrame")
        found_count = 0

        # Iterate over the needed columns in lockstep; this avoids building a
        # pandas Series for every row, which is what DataFrame.iterrows() does.
        needed_columns = (label_col, pitch_col, hnr_col, jitter_col, shimmer_col, mfcc_col)
        column_values = [dataframe[name] for name in needed_columns]
        for label, pitch_raw, hnr_raw, jitter_raw, shimmer_raw, mfcc in zip(*column_values):
            if np.isnan(label):
                continue  # unlabeled row: nothing to learn from or evaluate on

            self.labels.append(label)

            # Frame-level features: replace NaNs, align lengths, stack as (time, 2)
            processed_pitch = np.nan_to_num(pitch_raw, nan=nan_replacement)
            processed_hnr = np.nan_to_num(hnr_raw, nan=nan_replacement)
            processed_pitch, processed_hnr = self._pad_to_same_length(processed_pitch, processed_hnr)
            combined_features = np.stack((processed_pitch, processed_hnr), axis=-1)
            self.processed_pitchhnr.append(torch.tensor(combined_features, dtype=torch.float32))

            # Utterance-level features: jitter and shimmer concatenated into one vector
            processed_jitter = np.nan_to_num(jitter_raw, nan=nan_replacement)
            processed_shimmer = np.nan_to_num(shimmer_raw, nan=nan_replacement)
            jitter_shimmer = np.concatenate((processed_jitter, processed_shimmer))
            self.global_features.append(torch.tensor(jitter_shimmer, dtype=torch.float32))

            # NOTE: transposed so the first axis is the one padded later (time, feature_dim)
            self.processed_mfcc.append(torch.tensor(mfcc, dtype=torch.float32).T)

            found_count += 1

        self.labels = torch.tensor(self.labels, dtype=torch.long)
        print(f"Successfully processed {found_count} samples out of {len(dataframe)} DataFrame entries.")

    @staticmethod
    def _pad_to_same_length(first, second):
        """Zero-pad the shorter of two 1-D arrays at its end so both have equal length."""
        target_length = max(len(first), len(second))
        if len(first) < target_length:
            first = np.concatenate((first, np.zeros(target_length - len(first), dtype=first.dtype)))
        if len(second) < target_length:
            second = np.concatenate((second, np.zeros(target_length - len(second), dtype=second.dtype)))
        return first, second

    def __len__(self):
        """Returns the number of samples that were kept (rows with a non-NaN label)."""
        return len(self.labels)

    def __getitem__(self, idx):
        """
        Returns one sample as the tuple
        (label, pitch/HNR sequence, jitter+shimmer vector, transposed MFCC matrix).
        """
        label = self.labels[idx]
        pitch_hnr = self.processed_pitchhnr[idx]
        global_feature = self.global_features[idx]
        mfcc = self.processed_mfcc[idx]
        return label, pitch_hnr, global_feature, mfcc

In [7]:
# --- for Dynamic Padding  ---
def collate_fn(batch, padding_value=PADDING_VALUE):
    """
    Collates dataset samples into one batch with dynamic (per-batch) padding.

    Each element of ``batch`` is indexed as (label, pitch/HNR sequence,
    global feature vector, MFCC sequence). Variable-length sequences are padded
    with ``padding_value`` up to the longest one in the batch.

    Returns:
        labels, pitch/HNR lengths before padding, padded pitch/HNR sequences,
        stacked global features, padded MFCC sequences.
    """
    label_items, pitchhnr_items, global_items, mfcc_items = [], [], [], []
    for sample in batch:
        label_items.append(sample[0])
        pitchhnr_items.append(sample[1])
        global_items.append(sample[2])
        mfcc_items.append(sample[3])

    stacked_labels = torch.stack(label_items)

    # Lengths are recorded before padding so the LSTM can ignore the padded steps
    original_lengths = torch.tensor([len(sequence) for sequence in pitchhnr_items], dtype=torch.long)
    padded_pitchhnr = pad_sequence(pitchhnr_items, batch_first=True, padding_value=padding_value)
    # lstm expects: [batch_size, sequence_length, feature_size]
    if padded_pitchhnr.ndim == 2:
        padded_pitchhnr = padded_pitchhnr.unsqueeze(2)

    stacked_globals = torch.stack(global_items)

    padded_mfcc = pad_sequence(mfcc_items, batch_first=True, padding_value=padding_value)

    return stacked_labels, original_lengths, padded_pitchhnr, stacked_globals, padded_mfcc

In [8]:
# Build the training dataset / loader (shuffled every epoch)
pitch_dataset_train = ASVDataset(dataframe=df_train,   
                                    pitch_col=PITCH_COLUMN,
                                    hnr_col=HNR_COLUMN,
                                    jitter_col=JITTER_COLUMN,
                                    shimmer_col=SHIMMER_COLUMN,
                                    mfcc_col=MFCC_COLUMN,
                                    label_col=LABEL_COLUMN,
                                    nan_replacement=NAN_REPLACEMENT_VALUE)

train_dataloader = DataLoader(
    pitch_dataset_train, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn, num_workers=8
)

# Build the dev dataset / loader
pitch_dataset_dev = ASVDataset(dataframe=df_dev,   
                                    pitch_col=PITCH_COLUMN,
                                    hnr_col=HNR_COLUMN,
                                    jitter_col=JITTER_COLUMN,
                                    shimmer_col=SHIMMER_COLUMN,
                                    mfcc_col=MFCC_COLUMN,
                                    label_col=LABEL_COLUMN,
                                    nan_replacement=NAN_REPLACEMENT_VALUE)

# Evaluation does not benefit from shuffling; a fixed order keeps the dev batches
# (and anything logged per batch) identical from one evaluation to the next.
dev_dataloader = DataLoader(
    pitch_dataset_dev, batch_size=config.batch_size, shuffle=False, collate_fn=collate_fn, num_workers=8
)

## For inspection: print the shapes of the first training batch only
for i, batch_data in enumerate(train_dataloader):
    # batch_data is a tuple in the order returned by collate_fn
    batch_labels, batch_lengths, batch_pitchhnr, batch_global, batch_mfcc = batch_data
    print(f"\n--- Batch {i+1} ---")
    print(f"  Labels (first 5): {batch_labels[:5]}")
    print(f"  Padded Sequences Shape: {batch_pitchhnr.shape}")
    print(f"  Original Lengths (first 5): {batch_lengths[:5]}")
    print(f"  Global Shape: {batch_global.shape}")
    print(f"  MFCC Shape: {batch_mfcc.shape}")
    

    if i == 0: # Break after the first batch for inspection
        break


Attempting to process 45598 entries from DataFrame
Successfully processed 45598 samples out of 45598 DataFrame entries.
Attempting to process 24844 entries from DataFrame
Successfully processed 24844 samples out of 24844 DataFrame entries.

--- Batch 1 ---
  Labels (first 5): tensor([0, 1, 1, 0, 1])
  Padded Sequences Shape: torch.Size([32, 568, 2])
  Original Lengths (first 5): tensor([152, 270, 351, 568, 249])
  Global Shape: torch.Size([32, 11])
  MFCC Shape: torch.Size([32, 179, 60])


### Finding the weight (for weighted cross entropy)

Are there different ways of calculating the weights?

In [9]:

# Class weights for weighted cross entropy: weight_c = total / (2 * count_c),
# so a perfectly balanced training set yields a weight of 1.0 for both classes.
# NOTE(review): if oversampling balanced the classes in df_train, both weights come out
# as 1.0 and the "weighted" loss behaves like plain cross entropy; confirm this is intended.
labels = df_train[LABEL_COLUMN]
total = len(labels)
count_bonafide = labels.value_counts().get(LABEL_BONAFIDE, 0)
count_spoof =  total - count_bonafide
if count_bonafide == 0 or count_spoof == 0:
    # Without this guard a missing class gives a ZeroDivisionError or an infinite weight
    raise ValueError(
        f"Cannot compute class weights: bonafide={count_bonafide}, spoof={count_spoof} "
        f"(both classes must be present in the training data)."
    )
weight_bonafide = total / (count_bonafide * 2)
weight_spoof = total / (count_spoof * 2)

### Classifier

#### LSTM&FFN

In [10]:
class LSTM_FFN_branch(nn.Module):
    """Feature branch combining an LSTM over padded sequences with an FFN over fixed-size vectors.

    The LSTM's final hidden state (batch-normalised) is concatenated with the FFN output;
    the width of that concatenation is exposed as ``self.lstm_ffn_dim``.
    """

    def __init__(self, lstm_input_dim, lstm_hidden_dim, lstm_n_layers, bidirectional, 
                 ffn_dims):
        """
        Args:
            lstm_input_dim: feature size of each time step fed to the LSTM.
            lstm_hidden_dim: hidden size of the LSTM (per direction).
            lstm_n_layers: number of stacked LSTM layers.
            bidirectional: whether the LSTM runs in both directions.
            ffn_dims: layer widths of the FFN; first entry is the input size, last the output size.
        """
        super().__init__()

        # A bidirectional LSTM yields one final hidden state per direction
        lstm_summary_dim = lstm_hidden_dim * 2 if bidirectional else lstm_hidden_dim

        self.lstm_ffn_dim = lstm_summary_dim + ffn_dims[-1]
        self.ffn_layers = nn.ModuleList()

        # 1. lstm layer; input/output tensors are (batch, seq, feature)
        self.lstm = nn.LSTM(
            lstm_input_dim,
            lstm_hidden_dim,
            num_layers=lstm_n_layers,
            bidirectional=bidirectional,
            batch_first=True,
        )
        # BN layer for stabilization of the LSTM summary vector
        self.bn_lstm = nn.BatchNorm1d(lstm_summary_dim)

        # 2. ffn layers: one Linear -> BatchNorm -> ReLU block per consecutive pair of widths
        for in_width, out_width in zip(ffn_dims[:-1], ffn_dims[1:]):
            self.ffn_layers.append(
                nn.Sequential(
                    nn.Linear(in_width, out_width),
                    nn.BatchNorm1d(out_width),
                    nn.ReLU(),
                )
            )

    def forward(self, pitch_hnrs, pitchhnr_lengths, global_features):
        """Returns the concatenated LSTM summary and FFN output, one row per batch element."""
        # Pack so the LSTM only processes real time steps and ignores the padding
        packed_sequences = rnn_utils.pack_padded_sequence(
            pitch_hnrs, pitchhnr_lengths.cpu(), batch_first=True, enforce_sorted=False
        )

        # Only the final hidden state is used; per-step outputs and the cell state are discarded
        _, (final_hidden, _) = self.lstm(packed_sequences)

        if self.lstm.bidirectional:
            # Last two entries are the top layer's forward and backward states
            sequence_summary = torch.cat((final_hidden[-2, :, :], final_hidden[-1, :, :]), dim=1)
        else:
            sequence_summary = final_hidden[-1, :, :]
        sequence_summary = self.bn_lstm(sequence_summary)

        # Global features (jitter and shimmer) go through the FFN blocks in order
        ffn_output = global_features
        for block in self.ffn_layers:
            ffn_output = block(ffn_output)

        return torch.cat((sequence_summary, ffn_output), dim=1)

In [11]:
# for LSTM_FFN training alone 
class LSTM_FFN_classifer(nn.Module):
    """Stand-alone classifier on top of the LSTM+FFN branch: branch -> dropout -> linear layer."""

    def __init__(self, lstm_ffn_out, output_dim, dropout):
        """
        Args:
            lstm_ffn_out: branch module exposing ``lstm_ffn_dim`` (width of its output).
            output_dim: number of output logits.
            dropout: dropout probability applied to the branch output.
        """
        super().__init__()

        self.lstm_ffn_layer = lstm_ffn_out
        self.fc = nn.Linear(lstm_ffn_out.lstm_ffn_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, pitch_hnrs, pitchhnr_lengths, global_features, _a):
        """Returns logits; ``_a`` is accepted only so all models share one call signature and is ignored."""
        branch_features = self.lstm_ffn_layer(pitch_hnrs, pitchhnr_lengths, global_features)
        return self.fc(self.dropout(branch_features))

#### CNN

In [12]:
class CNN_branch(nn.Module):
    """Stack of Conv2d -> BatchNorm2d -> ReLU -> pooling blocks over a single-channel input map.

    Every block but the last ends in ``MaxPool2d``; the last ends in a global
    ``AdaptiveMaxPool2d((1, 1))`` so the output is one value per channel.
    The output width is exposed as ``self.cnn_dim`` (= ``cnn_channels[-1]``).
    """

    def __init__(self, cnn_channels, conv_kernel, pool_kernel, cnn_padding):
        """
        Args:
            cnn_channels: channel sizes; first is the input channel count, last the output.
            conv_kernel: kernel size of every convolution.
            pool_kernel: kernel size of the intermediate max-pooling layers.
            cnn_padding: padding of every convolution.
        """
        super().__init__()

        self.cnn_dim = cnn_channels[-1]

        self.conv_layers = nn.ModuleList()

        def build_block(in_channels, out_channels, pooling):
            # One convolutional stage; only the pooling layer differs between stages
            return nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=conv_kernel, padding=cnn_padding),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
                pooling,
            )

        # Intermediate stages: every consecutive channel pair except the last one
        for in_channels, out_channels in zip(cnn_channels[:-2], cnn_channels[1:-1]):
            self.conv_layers.append(
                build_block(in_channels, out_channels, nn.MaxPool2d(kernel_size=pool_kernel))
            )

        # Final stage: global max pooling, output size [batch, cnn_channels[-1], 1, 1]
        self.conv_layers.append(
            build_block(cnn_channels[-2], cnn_channels[-1], nn.AdaptiveMaxPool2d((1, 1)))
        )

    def forward(self, mfccs):
        """Returns one flattened feature vector of width ``cnn_dim`` per batch element."""
        # Conv2d expects (batch_size, in_channel, height, width) -> add the channel axis
        feature_map = mfccs.unsqueeze(1)

        for block in self.conv_layers:
            feature_map = block(feature_map)

        return feature_map.view(feature_map.size(0), -1)

In [13]:
# for CNN training alone
class CNN_classifer(nn.Module):
    """Stand-alone classifier on top of the CNN branch: branch -> dropout -> linear layer."""

    def __init__(self, cnn_out, output_dim, dropout):
        """
        Args:
            cnn_out: branch module exposing ``cnn_dim`` (width of its output).
            output_dim: number of output logits.
            dropout: dropout probability applied to the branch output.
        """
        super().__init__()

        self.cnn_layer = cnn_out
        self.fc = nn.Linear(cnn_out.cnn_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, _a, _b, _c, mfccs):
        """Returns logits; ``_a``, ``_b`` and ``_c`` exist only to share one call signature and are ignored."""
        branch_features = self.cnn_layer(mfccs)
        return self.fc(self.dropout(branch_features))

#### Attention

In [14]:
class BranchAttention(nn.Module):
    """Learns two per-sample weights (summing to 1) and rescales two branch outputs with them."""

    def __init__(self, branch1_dim, branch2_dim):
        """
        Args:
            branch1_dim: feature width of the first branch output.
            branch2_dim: feature width of the second branch output.
        """
        super().__init__()
        # One score per branch, predicted from both branch outputs together
        self.attention_net = nn.Linear(branch1_dim + branch2_dim, 2)

        # Start with a zero bias so neither branch is favoured by the bias term
        with torch.no_grad():
            self.attention_net.bias.zero_()

    def forward(self, branch1_out, branch2_out):
        """
        Returns:
            (weighted features, attention weights) where the weighted features are the
            two branch outputs, each multiplied by its weight, concatenated along dim 1,
            and the attention weights have one column per branch.
        """
        joint_features = torch.cat((branch1_out, branch2_out), dim=1)

        # Softmax over the two scores gives weights that sum to 1 (e.g., [0.7, 0.3])
        attention_weights = F.softmax(self.attention_net(joint_features), dim=1)

        # Slicing with a range keeps a trailing axis of size 1 so the weights broadcast over features
        first_weight = attention_weights[:, 0:1]
        second_weight = attention_weights[:, 1:2]

        weighted_combined_features = torch.cat(
            (branch1_out * first_weight, branch2_out * second_weight), dim=1
        )

        # The weights are returned as well so callers can inspect them
        return weighted_combined_features, attention_weights

#### Ensemble

In [15]:
class SpoofEnsemble(nn.Module):
    """Two-branch classifier: concatenates the LSTM+FFN and CNN branch outputs, then dropout -> linear."""

    def __init__(self, lstm_ffn_branch, cnn_branch, output_dim, dropout):
        """
        Args:
            lstm_ffn_branch: module exposing ``lstm_ffn_dim``; called with (sequences, lengths, global features).
            cnn_branch: module exposing ``cnn_dim``; called with the MFCC batch.
            output_dim: number of output logits.
            dropout: dropout probability applied to the concatenated features.
        """
        super().__init__()

        self.lstm_ffn_branch = lstm_ffn_branch
        self.cnn_branch = cnn_branch

        # The classifier sees both branch outputs side by side
        fused_dim = lstm_ffn_branch.lstm_ffn_dim + cnn_branch.cnn_dim
        self.fc = nn.Linear(fused_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, pitch_hnrs, pitchhnr_lengths, global_features, mfccs):
        """Returns the logits computed from the concatenated branch features."""
        prosody_features = self.lstm_ffn_branch(pitch_hnrs, pitchhnr_lengths, global_features)
        spectral_features = self.cnn_branch(mfccs)

        fused = torch.cat((prosody_features, spectral_features), dim=1)
        fused = self.dropout(fused)

        return self.fc(fused)

#### Ensemble with attention

In [16]:
class SpoofEnsemble_attention(nn.Module):
    """Two-branch classifier whose branch outputs are re-weighted by BranchAttention before classification."""

    def __init__(self, lstm_ffn_branch, cnn_branch, output_dim, dropout):
        """
        Args:
            lstm_ffn_branch: module exposing ``lstm_ffn_dim``; called with (sequences, lengths, global features).
            cnn_branch: module exposing ``cnn_dim``; called with the MFCC batch.
            output_dim: number of output logits.
            dropout: dropout probability applied to the attention-weighted features.
        """
        super().__init__()

        self.lstm_ffn_branch = lstm_ffn_branch
        self.cnn_branch = cnn_branch

        prosody_dim = lstm_ffn_branch.lstm_ffn_dim
        spectral_dim = cnn_branch.cnn_dim

        # Learns one weight per branch for every sample
        self.attention = BranchAttention(prosody_dim, spectral_dim)

        # Holds the attention weights of the most recent forward pass (None until
        # the first call) so they can be inspected afterwards.
        self.attention_weights = None

        self.fc = nn.Linear(prosody_dim + spectral_dim, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, pitch_hnrs, pitchhnr_lengths, global_features, mfccs):
        """Returns the logits; also stores the attention weights in ``self.attention_weights``."""
        prosody_features = self.lstm_ffn_branch(pitch_hnrs, pitchhnr_lengths, global_features)
        spectral_features = self.cnn_branch(mfccs)

        # Attention rescales each branch and returns them already concatenated
        fused, self.attention_weights = self.attention(prosody_features, spectral_features)

        fused = self.dropout(fused)

        return self.fc(fused)

### Initiate the model

#### find the device

In [17]:
# Select the compute device. GPU 1 is preferred, but the index falls back to 0
# on machines with a single GPU, where set_device(1) would raise.
if torch.cuda.is_available():
    device_index = 1 if torch.cuda.device_count() > 1 else 0
    torch.cuda.set_device(device_index)
    # 'cuda' without an index refers to the current device selected above
    DEVICE = torch.device('cuda')
    print(f"Using CUDA device: {torch.cuda.get_device_name(DEVICE)}")
else:
    print("CUDA is not available. Using CPU.")
    DEVICE = torch.device('cpu')

Using CUDA device: NVIDIA GeForce GTX TITAN X


#### find the class weights for WCE & set the criterion

In [18]:
# CrossEntropyLoss applies weight[i] to class index i, so each weight must sit at
# the index of its own label. LABEL_SPOOF is 0 and LABEL_BONAFIDE is 1, so a
# literal [weight_bonafide, weight_spoof] list would swap the two weights.
class_weights = torch.empty(2, dtype=torch.float32)
class_weights[LABEL_BONAFIDE] = weight_bonafide
class_weights[LABEL_SPOOF] = weight_spoof
class_weights = class_weights.to(DEVICE)
criterion = torch.nn.CrossEntropyLoss(reduction='mean', weight=class_weights)

#### Initiation

In [19]:
def initiate_model():
    """Builds the model selected by ``config.model`` and moves it to ``DEVICE``.

    Both feature branches are always constructed (in the same order), whichever
    model is selected; single-branch classifiers simply use only one of them.

    Returns:
        The constructed model on ``DEVICE``.

    Raises:
        ValueError: if ``config.model`` is not one of the supported model names.
    """
    lstm_ffn_out= LSTM_FFN_branch(lstm_input_dim=config.lstm_input_dim, lstm_hidden_dim=config.lstm_hidden_dim, lstm_n_layers=config.lstm_n_layers, bidirectional=config.bidirectional,
                    ffn_dims=config.ffn_dims).to(DEVICE)
    cnn_out = CNN_branch(cnn_channels=config.cnn_channels, conv_kernel=config.conv_kernel, pool_kernel=config.pool_kernel, cnn_padding=config.cnn_padding).to(DEVICE)

    if config.model=="SpoofEnsemble":
        model = SpoofEnsemble(lstm_ffn_branch=lstm_ffn_out, cnn_branch=cnn_out, output_dim=2, dropout=config.dropout_rate).to(DEVICE)
    elif config.model=="LSTM_FFN_classifier":
        model = LSTM_FFN_classifer(lstm_ffn_out=lstm_ffn_out, output_dim=2, dropout=config.dropout_rate).to(DEVICE)
    elif config.model=="CNN_classifier":
        model = CNN_classifer(cnn_out=cnn_out, output_dim=2, dropout=config.dropout_rate).to(DEVICE)
    elif config.model=="SpoofEnsemble_attention":
        model = SpoofEnsemble_attention(lstm_ffn_branch=lstm_ffn_out, cnn_branch=cnn_out, output_dim=2, dropout=config.dropout_rate).to(DEVICE)
    else:
        # Previously this only printed a warning and then failed on `return model`
        # with an UnboundLocalError; fail with an explicit message instead.
        raise ValueError(
            f"Invalid model name {config.model!r}; expected one of: "
            "SpoofEnsemble, LSTM_FFN_classifier, CNN_classifier, SpoofEnsemble_attention."
        )
    return model


# print(f"DEBUG: Initial FFN_Linear WEIGHTS:\n{model.ffn_linear.weight.detach().cpu().numpy()}")
# print(f"DEBUG: Initial FFN_Linear BIAS:\n{model.ffn_linear.bias.detach().cpu().numpy()}")

### Evaluation

In [20]:
def evaluate_classifier(data_loader, model, criterion):
    """Evaluates ``model`` on ``data_loader`` and computes the loss and the equal error rate.

    The score of every sample is the softmax probability of the bona fide class.
    The model is left in evaluation mode when this function returns.

    Args:
        data_loader: yields (labels, lengths, pitch/HNR batch, global features, MFCC batch).
        model: called as ``model(pitch/HNR batch, lengths, global features, MFCC batch)``.
        criterion: loss taking (logits, labels).

    Returns:
        average_loss: per-batch losses averaged with batch-size weighting (0.0 if no samples).
        eer, threshold: values returned by ``em.compute_eer`` for the bona fide / spoof scores.
        labels_true: 1 for every bona fide sample followed by 0 for every spoof sample.
        labels_pred: 1 where the score is >= threshold, else 0 (same order as labels_true).
    """
    model.eval()  # Set the model to evaluation mode (disables dropout, etc.)

    total_loss = 0.0
    total_samples = 0

    # One tensor of scores per batch and class; concatenated once after the loop
    bonafide_score_chunks = []
    spoof_score_chunks = []

    with torch.no_grad():  # Disable gradient calculations during evaluation
        for batch_labels, batch_lengths, batch_pitchhnr, batch_global, batch_mfcc in data_loader:

            # batch_lengths stays where it is; only the tensors the model computes on are moved
            batch_labels = batch_labels.to(DEVICE)
            batch_pitchhnr = batch_pitchhnr.to(DEVICE)
            batch_global = batch_global.to(DEVICE)
            batch_mfcc = batch_mfcc.to(DEVICE)

            # Forward pass: Get model outputs (logits)
            logits = model(batch_pitchhnr, batch_lengths, batch_global, batch_mfcc)

            # Accumulate the loss, weighted by batch size
            batch_size = batch_labels.size(0)
            loss = criterion(logits, batch_labels)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            # for EER: select the scores of each class with a boolean mask and move
            # them to the CPU once per batch instead of once per sample
            bonafide_probability = torch.softmax(logits, dim=1)[:, LABEL_BONAFIDE]
            bonafide_score_chunks.append(bonafide_probability[batch_labels == LABEL_BONAFIDE].cpu())
            spoof_score_chunks.append(bonafide_probability[batch_labels == LABEL_SPOOF].cpu())

    average_loss = total_loss / total_samples if total_samples > 0 else 0.0

    # numpy is cpu only; the chunks were already moved to the CPU above
    scores_bonafide_np = torch.cat(bonafide_score_chunks).numpy() if bonafide_score_chunks else np.array([])
    scores_spoof_np = torch.cat(spoof_score_chunks).numpy() if spoof_score_chunks else np.array([])
    eer, threshold = em.compute_eer(scores_bonafide_np, scores_spoof_np)

    all_scores = np.concatenate((scores_bonafide_np, scores_spoof_np))
    labels_true = np.concatenate((np.ones_like(scores_bonafide_np), np.zeros_like(scores_spoof_np)))
    labels_pred = (all_scores >= threshold).astype(int)

    return average_loss, eer, threshold, labels_true, labels_pred

### The training loop

> NOTE: in wandb, scalars are logged for every epoch; plots get overwritten (but are still saved in artifacts?)

In [21]:
def train_model(criterion, train_dataloader, dev_dataloader, num_epochs,
                min_eer, best_model_filename):
    """Train a freshly initialised model and checkpoint it whenever the dev EER improves.

    A new model is built with ``initiate_model()`` and optimised with Adam.
    After every epoch the model is scored on ``dev_dataloader`` via
    ``evaluate_classifier``; the scalar metrics are logged to the global
    W&B ``run``. When the dev EER drops below ``min_eer`` the state dict is
    written to ``best_model_filename`` and a confusion matrix plus a
    classification report are logged as well.

    Args:
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        train_dataloader: Iterable yielding
            ``(labels, lengths, pitchhnr, global, mfcc)`` batches.
        dev_dataloader: Loader handed to ``evaluate_classifier``.
        num_epochs: Number of passes over ``train_dataloader``.
        min_eer: Best dev EER seen so far (e.g. from an earlier trial); a
            checkpoint is only written when an epoch beats this value.
        best_model_filename: Path the best ``state_dict`` is saved to.

    Returns:
        The (possibly lowered) best dev EER.
    """
    model = initiate_model()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=config.scheduler_factor, patience=config.scheduler_patience)
    print(f"Training started on device: {DEVICE}")
    model.to(DEVICE)

    # Initial metric dictionary for the progress bar
    metric_dict = {'train_loss': 'N/A', 'val_loss': 'N/A', 'val_eer': 'N/A', 'val_threshold': 'N/A'}

    # Evaluate on validation set first to get a baseline
    print("Evaluating on validation set before training...")
    model.eval()  # Set model to evaluation mode
    # Only the scalar metrics of the baseline are needed; the label arrays are discarded.
    val_loss_initial, val_eer_initial, threshold_initial, _, _ = evaluate_classifier(dev_dataloader, model, criterion)
    metric_dict.update({'val_loss': f'{val_loss_initial:.3f}', 'val_eer': f'{val_eer_initial*100:.2f}%', 'val_threshold': f'{threshold_initial*100:.2f}%'})
    print(f"Initial Validation - Loss: {val_loss_initial:.4f}, EER: {val_eer_initial*100:.2f}%, Threshold: {threshold_initial*100:.2f}%")

    # Progress bar setup: one tick per training batch across all epochs
    total_steps = num_epochs * len(train_dataloader)
    pbar = tqdm(total=total_steps, initial=0, postfix=metric_dict, unit="batch")

    try:
        for epoch in range(num_epochs):
            model.train()  # Set the model to training mode (enables dropout, etc.)
            pbar.set_description(f"Epoch {epoch + 1}/{num_epochs}")

            running_train_loss = 0.0
            num_train_batches = 0

            for batch_labels, batch_lengths, batch_pitchhnr, batch_global, batch_mfcc in train_dataloader:
                # Move data to the specified device
                # batch_lengths are used by pack_padded_sequence which expects them on CPU
                batch_labels = batch_labels.to(DEVICE)
                batch_pitchhnr = batch_pitchhnr.to(DEVICE)
                batch_global = batch_global.to(DEVICE)
                batch_mfcc = batch_mfcc.to(DEVICE)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass: Get model outputs (logits)
                logits = model(batch_pitchhnr, batch_lengths, batch_global, batch_mfcc)

                # Calculate loss
                loss = criterion(logits, batch_labels)

                # Backward pass and optimize
                loss.backward()
                # --- FOR GRADIENT CLIPPING ---
                # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()

                # Read the scalar once: every .item() call copies the value
                # from the device to the host.
                batch_loss = loss.item()

                # Update statistics for progress bar and logging
                running_train_loss += batch_loss
                num_train_batches += 1

                pbar.update(1)  # Increment progress bar by one batch
                metric_dict.update({'train_loss': f'{batch_loss:.3f}'})  # Current batch loss
                pbar.set_postfix(metric_dict)

            # Calculate average training loss for the epoch (mean over batches)
            avg_epoch_train_loss = running_train_loss / num_train_batches if num_train_batches > 0 else 0.0
            metric_dict.update({'train_loss': f'{avg_epoch_train_loss:.3f}'})  # Average epoch loss

            # Evaluate on validation set after each epoch. Switch to eval mode
            # explicitly, as is done for the pre-training baseline, so dropout
            # is off regardless of what evaluate_classifier does internally
            # (the call is idempotent). model.train() is re-applied at the top
            # of the next epoch.
            model.eval()
            avg_val_loss, val_eer, val_threshold, labels_true, labels_pred = evaluate_classifier(dev_dataloader, model, criterion)

            # for reduce on plateau (monitors the dev EER, lower is better)
            if config.scheduler:
                scheduler.step(val_eer)

            # Update with latest validation metrics
            metric_dict.update({'val_loss': f'{avg_val_loss:.3f}', 'val_eer': f'{val_eer*100:.2f}%', 'val_threshold': f'{val_threshold*100:.2f}%'})
            pbar.set_postfix(metric_dict)

            # Optional: Print epoch summary
            print(f"\nEpoch {epoch+1} Summary: Avg Train Loss: {avg_epoch_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, EER: {val_eer*100:.2f}%, Threshold: {val_threshold*100:.2f}%")

            # log the train\dev loss and the eer & threshold
            run.log({"train_loss": avg_epoch_train_loss, "dev_loss": avg_val_loss,
                       "dev_eer": val_eer, "dev_threshold":val_threshold, "epoch": epoch + 1})

            # update min eer and optimal model
            if val_eer < min_eer:
                min_eer = val_eer
                torch.save(model.state_dict(), best_model_filename)
                print(f"Epoch {epoch+1}: New best model saved to '{best_model_filename}' with EER: {min_eer:.4f}")

                run.summary['best_validation_eer'] = min_eer
                run.summary['best_eer_epoch'] = epoch + 1
                run.summary['validation_loss_at_best_eer'] = avg_val_loss

                # log the report and confusion matrix
                class_names = ['SPOOF', 'BONAFIDE']     #NOTE: the order matters, need to match labels (0 = spoof, 1 = bonafide)
                report_columns =  ["Class", "Precision", "Recall", "F1-score", "Support"]
                class_report = classification_report(labels_true, labels_pred, labels=[0, 1],
                                            target_names=class_names).splitlines()
                # Lines 0-1 of the text report are the header and a blank line;
                # the next len(class_names) lines hold one row per class.
                report_table = []
                for line in class_report[2:(len(class_names)+2)]:
                    report_table.append(line.split())
                # NOTE(review): the key "Confusion Matix" is misspelled; it is left
                # unchanged here so the name of the logged W&B key does not change.
                run.log({"Confusion Matix": wandb.plot.confusion_matrix(y_true=labels_true, preds=labels_pred, class_names=class_names),
                        "Classification Report": wandb.Table(data=report_table, columns=report_columns)})
    finally:
        # Always release the progress bar, even if training is interrupted or raises.
        pbar.close()

    print("Training finished.")
    return min_eer

#### Start the training

>note: training is only partially deterministic, because adaptive max pooling does not support deterministic mode yet

In [22]:
# Run one training trial per configured seed. `min_eer` is carried across
# trials, so `best_model_filename` always holds the best checkpoint seen in
# ANY trial so far.
NUM_EPOCHS = config.epochs
min_eer = float('inf')
best_model_filename = 'best_model'

for seed in config.seeds:
    print(f"\n--- Starting Trial with Seed: {seed} ---")
    # Actually apply the announced seed: this fixes the RNG streams used by
    # torch (weight init, dropout, default DataLoader shuffling) and numpy.
    torch.manual_seed(int(seed))
    np.random.seed(int(seed))
    # Full determinism is left disabled, as before (see the note above this cell).
    # set_seed(seed)
    # torch.use_deterministic_algorithms(True, warn_only=True)

    eer_before_trial = min_eer
    min_eer = train_model(criterion, train_dataloader, dev_dataloader, NUM_EPOCHS, min_eer, best_model_filename)

    # 'best_eer_epoch' alone is ambiguous across trials; also record which
    # seed produced the current best checkpoint.
    if min_eer < eer_before_trial:
        run.summary['best_eer_seed'] = seed


--- Starting Trial with Seed: 0 ---
Training started on device: cuda
Evaluating on validation set before training...
Initial Validation - Loss: 0.7381, EER: 59.81%, Threshold: 52.69%


  0%|          | 0/99750 [00:00<?, ?batch/s, train_loss=N/A, val_eer=59.81%, val_loss=0.738, val_threshold=52.…


Epoch 1 Summary: Avg Train Loss: 0.2975, Val Loss: 0.1130, EER: 5.65%, Threshold: 36.42%
Epoch 1: New best model saved to 'best_model' with EER: 0.0565

Epoch 2 Summary: Avg Train Loss: 0.1129, Val Loss: 0.0630, EER: 3.72%, Threshold: 21.34%
Epoch 2: New best model saved to 'best_model' with EER: 0.0372

Epoch 3 Summary: Avg Train Loss: 0.0655, Val Loss: 0.0476, EER: 2.43%, Threshold: 3.63%
Epoch 3: New best model saved to 'best_model' with EER: 0.0243

Epoch 4 Summary: Avg Train Loss: 0.0460, Val Loss: 0.0337, EER: 2.12%, Threshold: 20.08%
Epoch 4: New best model saved to 'best_model' with EER: 0.0212

Epoch 5 Summary: Avg Train Loss: 0.0291, Val Loss: 0.0552, EER: 1.92%, Threshold: 56.95%
Epoch 5: New best model saved to 'best_model' with EER: 0.0192

Epoch 6 Summary: Avg Train Loss: 0.0241, Val Loss: 0.4135, EER: 2.24%, Threshold: 99.11%

Epoch 7 Summary: Avg Train Loss: 0.0195, Val Loss: 0.0271, EER: 1.45%, Threshold: 1.93%
Epoch 7: New best model saved to 'best_model' with EER: 0

  0%|          | 0/99750 [00:00<?, ?batch/s, train_loss=N/A, val_eer=31.01%, val_loss=0.775, val_threshold=54.…


Epoch 1 Summary: Avg Train Loss: 0.3307, Val Loss: 0.0888, EER: 5.45%, Threshold: 13.42%

Epoch 2 Summary: Avg Train Loss: 0.1260, Val Loss: 0.0483, EER: 2.94%, Threshold: 12.49%

Epoch 3 Summary: Avg Train Loss: 0.0642, Val Loss: 0.0352, EER: 2.36%, Threshold: 15.35%

Epoch 4 Summary: Avg Train Loss: 0.0452, Val Loss: 0.0283, EER: 1.88%, Threshold: 6.64%

Epoch 5 Summary: Avg Train Loss: 0.0287, Val Loss: 0.0294, EER: 1.53%, Threshold: 32.84%

Epoch 6 Summary: Avg Train Loss: 0.0241, Val Loss: 0.0304, EER: 1.37%, Threshold: 1.27%

Epoch 7 Summary: Avg Train Loss: 0.0183, Val Loss: 0.0730, EER: 1.57%, Threshold: 82.23%

Epoch 8 Summary: Avg Train Loss: 0.0144, Val Loss: 0.0432, EER: 1.34%, Threshold: 0.14%

Epoch 9 Summary: Avg Train Loss: 0.0119, Val Loss: 0.0307, EER: 1.22%, Threshold: 0.39%

Epoch 10 Summary: Avg Train Loss: 0.0114, Val Loss: 0.0339, EER: 1.53%, Threshold: 35.75%

Epoch 11 Summary: Avg Train Loss: 0.0095, Val Loss: 0.0279, EER: 1.42%, Threshold: 0.61%

Epoch 12 Sum

  0%|          | 0/99750 [00:00<?, ?batch/s, train_loss=N/A, val_eer=38.74%, val_loss=0.692, val_threshold=49.…


Epoch 1 Summary: Avg Train Loss: 0.2619, Val Loss: 0.0776, EER: 5.14%, Threshold: 14.00%

Epoch 2 Summary: Avg Train Loss: 0.0826, Val Loss: 0.0435, EER: 2.68%, Threshold: 21.80%

Epoch 3 Summary: Avg Train Loss: 0.0407, Val Loss: 0.0339, EER: 1.96%, Threshold: 22.09%

Epoch 4 Summary: Avg Train Loss: 0.0326, Val Loss: 0.1064, EER: 1.57%, Threshold: 87.66%

Epoch 5 Summary: Avg Train Loss: 0.0241, Val Loss: 0.0312, EER: 1.73%, Threshold: 1.62%

Epoch 6 Summary: Avg Train Loss: 0.0216, Val Loss: 0.0247, EER: 1.38%, Threshold: 1.69%

Epoch 7 Summary: Avg Train Loss: 0.0152, Val Loss: 0.0245, EER: 1.57%, Threshold: 3.50%

Epoch 8 Summary: Avg Train Loss: 0.0116, Val Loss: 0.0267, EER: 1.02%, Threshold: 0.52%

Epoch 9 Summary: Avg Train Loss: 0.0089, Val Loss: 0.0227, EER: 1.14%, Threshold: 22.60%

Epoch 10 Summary: Avg Train Loss: 0.0103, Val Loss: 0.0291, EER: 1.45%, Threshold: 1.44%

Epoch 11 Summary: Avg Train Loss: 0.0058, Val Loss: 0.0288, EER: 1.06%, Threshold: 0.33%

Epoch 12 Summ

### Save the model

In [23]:
# Upload the best checkpoint to W&B, then close the run.
# `min_eer` still equals its initial value (inf) only when no epoch ever
# improved on it, i.e. no checkpoint file was written.
if min_eer == float('inf'):
    print("No model was saved as best_eer did not improve from its initial value.")
else:
    # Epoch recorded in the run summary when the best EER was reached.
    best_epoch = run.summary.get('best_eer_epoch', 'N/A')

    print(f"Logging the best model ({best_model_filename}) to W&B Artifacts...")
    best_model_artifact = wandb.Artifact(
        name=f"{run.id}-best-model",  # run ID keeps the artifact name unique
        type="model",
        description=f"Best model according to EER ({min_eer:.4f}) achieved at epoch {best_epoch}.",
        metadata={"best_eer": min_eer, "epoch_of_best_eer": best_epoch},
    )
    best_model_artifact.add_file(best_model_filename)  # attach the saved state dict
    run.log_artifact(best_model_artifact, aliases=["best_eer_model"])  # alias for easy lookup
    print("Best model logged as W&B Artifact.")

run.finish()

print("W&B run finished.")

Logging the best model (best_model) to W&B Artifacts...
Best model logged as W&B Artifact.


0,1
dev_eer,▃▂▂▂▂▁▂▁▁▂▁█▄▃▂▂▂▂▂▂▂▁▁▁▂▁▄▃▂▂▂▃▂▁▂▁▁▁▁▂
dev_loss,▂▁▂█▁▁▂▁▂▁▃▁▁▁▂▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▂▁▂▂▁▁▂▁▂
dev_threshold,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁
epoch,▃▄▄▅▅▆▆▇▇▇██▁▂▂▃▄▄▄▄▆▇▇██▁▂▂▃▃▄▄▅▅▅▆▆▆▇█
train_loss,▃▁▂▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
best_eer_epoch,58.0
best_validation_eer,0.00668
dev_eer,0.01373
dev_loss,0.06477
dev_threshold,0.0
epoch,70.0
train_loss,0.00101
validation_loss_at_best_eer,0.02216


W&B run finished.
