In [None]:
# Installing required Python packages for the project
# This command works inside Jupyter notebooks specifically using the %pip magic

%pip install numpy pandas librosa scikit-learn torch torchvision torchaudio tqdm matplotlib seaborn

# ------------------------
# Libraries included:
# numpy        - for numerical computing and matrix operations
# pandas       - for working with structured tabular data (DataFrames)
# librosa      - for audio analysis and feature extraction like mel spectrograms
# scikit-learn - for ML utilities (splits, metrics, normalization, etc.)
# torch        - PyTorch for building and training neural networks
# torchvision  - for computer vision utilities; may be unused if no vision features
# torchaudio   - for PyTorch-native audio handling
# tqdm         - for nice progress bars during long loops/training
# matplotlib   - for creating visualizations
# seaborn      - for prettier statistical plots

Note: you may need to restart the kernel to use updated packages.


In [None]:
import librosa       # for audio loading and spectrogram generation
import numpy as np   # for array manipulation and padding
import torch         # for tensor handling and model input formatting

def extract_features(file_path, sample_rate=16000, duration=3, n_mels=128, hop_length=512):
    """
    Loads an audio file, forces it to a fixed duration, extracts log-Mel spectrogram features,
    and converts the result into a PyTorch tensor.

    Parameters:
        file_path (str): Path to the audio file
        sample_rate (int): Target sampling rate (default: 16kHz)
        duration (int | float): Target duration in seconds (default: 3s)
        n_mels (int): Number of Mel filterbanks (default: 128)
        hop_length (int): STFT hop size in samples (default: 512, librosa's default)

    Returns:
        torch.Tensor: float32 tensor of shape [1, n_mels, 1 + required_len // hop_length],
        where required_len = round(sample_rate * duration) -- [1, 128, 94] with the defaults.
        If loading or feature extraction fails, the error is printed and an all-zero tensor
        of that same shape is returned, so a failed file can still be batched with good ones.
    """
    # Fixed clip length in samples; rounded to an int so fractional durations work as slice
    # bounds. Computed before the try so the fallback below can always use it.
    required_len = int(round(sample_rate * duration))
    # melspectrogram centres its frames by default, giving 1 + n_samples // hop_length frames.
    # The previous fallback used sample_rate * duration / 512 (93 frames), one short of the
    # real output (94), so a single failed file made batch collation fail.
    n_frames = 1 + required_len // hop_length

    try:
        # Load and resample the audio to the desired sample rate
        signal, sr = librosa.load(file_path, sr=sample_rate)

        # Truncate or zero-pad to exactly `required_len` samples
        if len(signal) > required_len:
            signal = signal[:required_len]
        else:
            signal = np.pad(signal, (0, required_len - len(signal)))

        # Mel power spectrogram, then log scale (dB)
        mel_spec = librosa.feature.melspectrogram(
            y=signal,
            sr=sr,
            n_mels=n_mels,
            hop_length=hop_length,
        )
        log_mel_spec = librosa.power_to_db(mel_spec)

        # Convert to PyTorch tensor and add channel dimension: [1, n_mels, time]
        return torch.tensor(log_mel_spec).unsqueeze(0).float()

    except Exception as e:
        # Deliberate best-effort: report the file and substitute zeros of the normal shape
        print(f"Error processing {file_path}: {e}")
        return torch.zeros((1, n_mels, n_frames), dtype=torch.float32)


In [None]:
import pandas as pd  # for structured data loading and manipulation

def load_protocol(file_path):
    """
    Read a single-space-separated DECRO protocol file (e.g. en_train.txt) into a DataFrame.

    Parameters:
        file_path (str): Location of the protocol file.

    Returns:
        pd.DataFrame: One row per protocol line, with the five columns
                      'speaker_id', 'file_name', 'unused', 'system_id' and 'key'.
    """
    header = ('speaker_id', 'file_name', 'unused', 'system_id', 'key')
    protocol_df = pd.read_csv(file_path, names=list(header), sep=' ')
    return protocol_df

# Example usage: loading English training protocol.
# This uses the five-column load_protocol defined just above; a later cell redefines
# load_protocol so that it returns only the 'file_name' and 'label' columns.
# NOTE(review): absolute, machine-specific path — adjust to your data location.
df_en_train = load_protocol("/home/kartik-dua/python_project/DECRO/inner/en_train.txt")
df_en_train.head()  # Displays the first 5 rows for inspection


Unnamed: 0,speaker_id,file_name,unused,system_id,key
0,4,4-727-124443-0055,-,baidu,spoof
1,5,5-2012-139356-0032,-,baidu,spoof
2,103,103-534-127537-0043,-,baidu,spoof
3,111,111-5266-34501-0003,-,baidu,spoof
4,111,111-203-132069-0025,-,baidu,spoof


In [None]:
import librosa  # for audio processing and spectrogram generation
import numpy as np  # for numerical ops like padding

def extract_log_mel(file_path, sr=16000, n_mels=64, duration=3.0):
    """
    Extracts a fixed-length log-Mel spectrogram from an audio file.

    Parameters:
        file_path (str): Path to the audio file (.wav, .flac, etc.)
        sr (int): Target sample rate for resampling (default: 16000 Hz)
        n_mels (int): Number of Mel filter banks (default: 64)
        duration (float): Desired duration in seconds (default: 3.0)

    Returns:
        np.ndarray: 2D array of shape [n_mels, time_frames] holding the log-Mel spectrogram (dB).
        For a given (sr, duration), every file yields the same number of frames, because
        the signal is first forced to exactly int(sr * duration) samples.
    """
    # Load audio, resample to `sr`, and read at most `duration` seconds
    audio, sr = librosa.load(file_path, sr=sr, duration=duration)

    # Force exactly target_len samples: zero-pad short clips and truncate anything longer.
    # The loader's `duration` limit is not relied on alone; nothing here guarantees the
    # resampled length lands exactly on target_len.
    target_len = int(sr * duration)
    if len(audio) < target_len:
        audio = np.pad(audio, (0, target_len - len(audio)))
    else:
        audio = audio[:target_len]

    # Mel power spectrogram, converted to log scale (dB)
    mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=n_mels)
    log_mel_spec = librosa.power_to_db(mel_spec)

    return log_mel_spec


In [None]:
import torch
from torch.utils.data import Dataset
import os

class DECRODataset(Dataset):
    """
    Map-style dataset that turns DECRO protocol rows into (spectrogram, label) samples.

    Args:
        df (pd.DataFrame): Protocol metadata; needs a 'file_name' column (name without the
            .wav extension) and a 'key' column.
        base_path (str): Directory that contains '<file_name>.wav' for each row.

    Each item is (feature, label):
        feature: extract_log_mel's output as a float32 tensor with a leading channel axis,
                 i.e. [1, n_mels, time_frames].
        label:   plain Python int, 1 when 'key' == 'spoof' and 0 for anything else.
    """

    def __init__(self, df, base_path):
        self.df = df
        self.base_path = base_path

    def __len__(self):
        """One sample per protocol row."""
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        wav_path = os.path.join(self.base_path, record['file_name'] + ".wav")

        # Binary target: spoof -> 1, everything else -> 0
        is_spoof = int(record['key'] == 'spoof')

        spectrogram = torch.tensor(extract_log_mel(wav_path)).float()
        return spectrogram.unsqueeze(0), is_spoof


In [None]:
# Core libraries for file handling and data operations
import os
import pandas as pd
import librosa
import numpy as np
import torch
from torch.utils.data import Dataset  # Utility to create custom datasets for DataLoader

# === Function to load the protocol metadata ===
def load_protocol(file_path):
    """
    Load a DECRO protocol file, keeping only the columns the datasets need.

    Each line holds five whitespace-separated fields, e.g.
    ``4 4-727-124443-0055 - baidu spoof``. The fields are: speaker id, utterance file name
    (no extension), an unused '-' placeholder, the generating system id, and the
    bonafide/spoof key.

    Parameters:
        file_path (str | path-like | file-like): Protocol file to read.

    Returns:
        pd.DataFrame: Columns ['file_name', 'label'], both read as strings.
    """
    # Column names follow the real field order: the third field is the '-' placeholder and
    # the fourth is the system id. Only 'file_name' and 'label' are returned.
    columns = ['speaker_id', 'file_name', 'unused', 'system_id', 'label']
    # r"\s+" tolerates repeated spaces or tabs between fields. Reading file_name/label as str
    # keeps purely numeric names (e.g. "0007") from being parsed as integers.
    df = pd.read_csv(file_path, sep=r"\s+", names=columns, dtype={'file_name': str, 'label': str})
    return df[['file_name', 'label']]

# === Function to extract log-Mel spectrogram features from an audio file ===
def extract_features(file_path, sample_rate=16000, duration=3, n_mels=128, hop_length=512):
    """
    Load an audio file, force it to `duration` seconds, and return its log-Mel spectrogram.

    Parameters:
        file_path (str): Path to the audio file.
        sample_rate (int): Target sampling rate in Hz (default 16000).
        duration (int | float): Clip length in seconds (default 3).
        n_mels (int): Number of Mel bands (default 128).
        hop_length (int): STFT hop size in samples (default 512, librosa's default).

    Returns:
        torch.Tensor: float32 tensor [1, n_mels, 1 + required_length // hop_length]
        ([1, 128, 94] with the defaults) suitable as single-channel CNN input. On any error
        the error is printed and an all-zero tensor of the same shape is returned.
    """
    # Computed before the try: the previous version assigned this after librosa.load, so a
    # load failure raised UnboundLocalError from inside the except block instead of
    # returning the fallback.
    required_length = int(round(sample_rate * duration))
    # Centred frames: 1 + n_samples // hop_length. This matches the real output (94 frames);
    # the old fallback (required_length // 512 = 93 frames) did not.
    n_frames = 1 + required_length // hop_length

    try:
        # Load audio at given sample rate
        signal, sr = librosa.load(file_path, sr=sample_rate)

        # Pad or truncate to get fixed-duration audio
        if len(signal) > required_length:
            signal = signal[:required_length]
        else:
            signal = np.pad(signal, (0, required_length - len(signal)))

        # Compute Mel spectrogram, then convert power to decibel scale (log-Mel)
        mel_spec = librosa.feature.melspectrogram(y=signal, sr=sr, n_mels=n_mels, hop_length=hop_length)
        log_mel_spec = librosa.power_to_db(mel_spec)

        # Convert to 3D tensor [1, n_mels, time] suitable for CNN input
        return torch.tensor(log_mel_spec).unsqueeze(0).float()

    except Exception as e:
        # Return a zero tensor in case of error (e.g., corrupted file)
        print(f"Error processing {file_path}: {e}")
        return torch.zeros((1, n_mels, n_frames), dtype=torch.float32)

# === Custom PyTorch Dataset class for DECRO Audio Deepfake detection ===
class DeepfakeDataset(Dataset):
    """
    DECRO deepfake-detection dataset yielding (log-Mel features, numeric label) pairs.

    Args:
        protocol_file (str): Protocol file parsed by load_protocol (must yield 'file_name'
            and 'label' columns).
        audio_dir (str): Directory containing '<file_name>.wav' for every protocol row.
        sample_rate (int): Sampling rate forwarded to extract_features.
        duration (int | float): Clip length in seconds forwarded to extract_features.
        n_mels (int): Number of Mel bands forwarded to extract_features.

    Labels: 'bonafide' -> 0 and 'spoof' -> 1; any other string also maps to 0.
    Audio is read only when an item is requested, not at construction time.
    """

    def __init__(self, protocol_file, audio_dir, sample_rate=16000, duration=3, n_mels=128):
        self.df = load_protocol(protocol_file)
        self.audio_dir = audio_dir

        # Feature-extraction settings, forwarded on every __getitem__ call
        self.sample_rate = sample_rate
        self.duration = duration
        self.n_mels = n_mels

        self.label_map = {'bonafide': 0, 'spoof': 1}

    def __len__(self):
        """Number of protocol rows."""
        return len(self.df)

    def __getitem__(self, idx):
        entry = self.df.iloc[idx]
        clip_path = os.path.join(self.audio_dir, entry['file_name'] + ".wav")

        # Unrecognised label strings fall back to 0 (bonafide)
        numeric_label = self.label_map.get(entry['label'], 0)

        extraction_kwargs = dict(
            sample_rate=self.sample_rate,
            duration=self.duration,
            n_mels=self.n_mels,
        )
        features = extract_features(clip_path, **extraction_kwargs)
        return features, torch.tensor(numeric_label)


In [None]:
# === Creating PyTorch Dataset objects for the English subset of the DECRO dataset ===
# Construction only parses the protocol file (DeepfakeDataset.__init__ calls load_protocol);
# audio is loaded and turned into features lazily, inside __getitem__.
# NOTE(review): absolute, machine-specific paths — consider a single configurable data root.

# Initialize training dataset using the protocol file and corresponding audio directory
train_dataset_english = DeepfakeDataset(
    "/home/kartik-dua/python_project/DECRO/inner/en_train.txt",       # Path to training protocol file
    "/home/kartik-dua/python_project/DECRO/inner/wav/en_train"        # Path to training audio (.wav) files
)

# Initialize development (validation) dataset
development_dataset_english = DeepfakeDataset(
    "/home/kartik-dua/python_project/DECRO/inner/en_dev.txt",         # Path to dev protocol file
    "/home/kartik-dua/python_project/DECRO/inner/wav/en_dev"          # Path to dev audio files
)

# Initialize evaluation (test) dataset
evaluation_dataset_english = DeepfakeDataset(
    "/home/kartik-dua/python_project/DECRO/inner/en_eval.txt",        # Path to eval protocol file
    "/home/kartik-dua/python_project/DECRO/inner/wav/en_eval"         # Path to eval audio files
)


In [None]:
# === Testing the Dataset with a single sample ===

# Retrieve the first sample (index 0) from the training dataset.
# This decodes one audio file and computes its log-Mel spectrogram.
sample_feature, sample_label = train_dataset_english[0]

# Print the shape of the extracted features
# Expected shape: [1, 128, time]
#   - 1: Channel dimension (used for CNNs, treated like grayscale image)
#   - 128: Number of Mel frequency bins
#   - time: Number of time steps (depends on duration and hop size used by librosa;
#           the recorded output shows 94 for the default 3 s at 16 kHz)
print("Feature shape:", sample_feature.shape)

# Print the label for this sample
# Expected: 0 for 'bonafide', 1 for 'spoof' (unrecognised label strings also map to 0)
print("Label:", sample_label)


Feature shape: torch.Size([1, 128, 94])
Label: tensor(1)


In [None]:
# === SiameseDeepfakeDataset: Prepares pairs of audio samples for Siamese Network training ===

import numpy as np
import torch
from torch.utils.data import Dataset

class SiameseDeepfakeDataset(Dataset):
    """
    Pair-producing wrapper around DeepfakeDataset for Siamese / contrastive training.

    Every item pairs the sample at ``idx`` (the anchor) with a second, randomly drawn sample.
    With probability 0.5 the partner comes from the anchor's own class; otherwise it comes
    from the other class. If the chosen class is empty, any sample is used.

    Args:
        protocol_file (str): Protocol file forwarded to DeepfakeDataset.
        audio_dir (str): Directory with the .wav files, forwarded to DeepfakeDataset.
        sample_rate (int), duration (int | float), n_mels (int): Feature settings
            forwarded to DeepfakeDataset.

    Items:
        (feature1, feature2, target), where target is a float32 tensor of shape [1]:
        0.0 if both samples share a label, 1.0 otherwise.
    """

    def __init__(self, protocol_file, audio_dir, sample_rate=16000, duration=3, n_mels=128):
        # Base dataset handles protocol parsing and per-sample feature extraction
        self.base_dataset = DeepfakeDataset(
            protocol_file=protocol_file,
            audio_dir=audio_dir,
            sample_rate=sample_rate,
            duration=duration,
            n_mels=n_mels
        )

        # Positional indices per numeric label: {0: bonafide, 1: spoof}.
        # Labels are read straight from the protocol metadata, using the same mapping as
        # DeepfakeDataset.__getitem__ (unknown strings -> 0). Indexing self.base_dataset
        # here instead would load and transform every audio file just to read its label.
        label_map = self.base_dataset.label_map
        self.indices_by_label = {0: [], 1: []}
        for position, raw_label in enumerate(self.base_dataset.df['label']):
            self.indices_by_label[label_map.get(raw_label, 0)].append(position)

    def __len__(self):
        # One pair per anchor sample in the base dataset
        return len(self.base_dataset)

    def __getitem__(self, idx):
        # Anchor sample
        feature1, label1 = self.base_dataset[idx]

        # Coin flip: same-class (positive) pair or different-class (negative) pair
        same_class = np.random.rand() < 0.5

        if same_class:
            indices = self.indices_by_label[label1.item()]
        else:
            other_label = 1 - label1.item()
            indices = self.indices_by_label[other_label]

        # If the requested class has no samples, fall back to any index
        if len(indices) == 0:
            idx2 = np.random.randint(0, len(self.base_dataset))
        else:
            idx2 = np.random.choice(indices)

        feature2, label2 = self.base_dataset[idx2]

        # Similarity target from the actual labels: 0 -> same class, 1 -> different class
        target = 0.0 if label1.item() == label2.item() else 1.0

        return feature1, feature2, torch.tensor([target], dtype=torch.float32)


In [None]:
from torch.utils.data import DataLoader

# === Prepare the Siamese Dataset DataLoader ===

# Create the dataset using the protocol and audio path for the English training set.
# Each item is (feature1, feature2, target), with target 0.0 for a same-class pair and 1.0 otherwise.
siamese_train_dataset = SiameseDeepfakeDataset(
    "/home/kartik-dua/python_project/DECRO/inner/en_train.txt",  # Protocol file with labels
    "/home/kartik-dua/python_project/DECRO/inner/wav/en_train"   # Corresponding .wav files directory
)

# Wrap the dataset in a DataLoader for batching and shuffling during training.
# main() reads this `train_loader` as a notebook-level global.
# NOTE(review): batch_size here is 16, while the BATCH_SIZE constant defined with the model is 32.
train_loader = DataLoader(
    siamese_train_dataset,  # The dataset to sample from
    batch_size=16,          # Number of pairs per batch
    shuffle=True            # Shuffle the dataset each epoch for better generalization
)


In [None]:
import torch                        # Core PyTorch library
import torch.nn as nn               # Neural Network Modules
import torch.nn.functional as F     # Functional Utilities(ReLU,pooling,distance,etc.)
import torch.optim as optim         # Optimizers like Adam
from torch.utils.data import DataLoader, Dataset    # Data Loading Utilities
import numpy as np                  # For random choice and numerical operations

# ----------------------------
# Hyperparameters & Settings
# ----------------------------
BATCH_SIZE = 32             # Number of pairs per batch. NOTE(review): not referenced by the visible cells; the DataLoaders use batch_size=16
NUM_EPOCHS = 40             # Default number of training epochs. NOTE(review): the saved training output shows 20 epochs, so it predates this value
LEARNING_RATE = 0.0005      # Learning rate for the Adam optimizer
MARGIN = 2.0                # Margin for the contrastive loss
INPUT_CHANNELS = 1          # Spectrograms are single-channel
INPUT_HEIGHT = 128          # Height of spectrograms (mel bands). NOTE(review): not referenced by the model code
INPUT_WIDTH = 128           # Width (time frames). NOTE(review): not referenced; extract_features produces 94 frames for 3 s at 16 kHz
EMBEDDING_SIZE = 256        # Length of the embedding vector produced for each input
# ----------------------------
# Residual Block Definition
# ----------------------------
class ResidualBlock(nn.Module):
    """
    Basic two-convolution residual block: relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x)).

    Args:
        in_channels (int): Channels of the incoming feature map.
        out_channels (int): Channels produced by the block.
        stride (int): Stride of the first 3x3 convolution (and of the projection shortcut).

    The shortcut is the identity (an empty nn.Sequential) when shape is preserved. Otherwise
    it is a strided 1x1 convolution followed by BatchNorm, so the two branches can be added.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main branch: 3x3 (possibly strided) conv -> BN, then 3x3 conv -> BN
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        residual = self.shortcut(x)
        branch = F.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return F.relu(branch + residual)
# ----------------------------
# Squeeze-and-Excitation Block
# ----------------------------
class SEBlock(nn.Module):
    """
    Squeeze-and-Excitation channel gating for 4D feature maps [batch, channels, H, W].

    Each channel is rescaled by a weight in (0, 1). The weight is computed from the
    channel's spatial mean by a bias-free bottleneck MLP (channels -> channels // reduction
    -> channels) with ReLU and then sigmoid.

    Args:
        channels (int): Number of input/output channels.
        reduction (int): Bottleneck reduction factor (default 16).
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        self.fc1 = nn.Linear(channels, hidden, bias=False)
        self.fc2 = nn.Linear(hidden, channels, bias=False)

    def forward(self, x):
        batch, chans, _, _ = x.size()
        # Squeeze: one descriptor per channel via global average pooling
        descriptor = F.adaptive_avg_pool2d(x, 1).view(batch, chans)
        # Excite: bottleneck MLP produces per-channel gates in (0, 1)
        gates = torch.sigmoid(self.fc2(F.relu(self.fc1(descriptor))))
        return x * gates.view(batch, chans, 1, 1)
# ----------------------------
# Advanced Siamese CNN Model
# ----------------------------
class RefinedSiameseCNN(nn.Module):
    """
    Siamese embedding network. A ResNet-style convolutional trunk with a squeeze-and-excitation
    block is followed by a two-layer MLP head, and both inputs share the same weights.

    Args:
        embedding_size (int): Length of the embedding produced for each input.

    forward(input1, input2) takes two batches shaped [batch, INPUT_CHANNELS, height, width]
    and returns their embeddings, each shaped [batch, embedding_size].
    """

    def __init__(self, embedding_size=EMBEDDING_SIZE):
        super(RefinedSiameseCNN, self).__init__()

        self.initial_conv = nn.Sequential(
            nn.Conv2d(INPUT_CHANNELS, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU()
        )

        # Each stage's first block uses stride 2, halving height and width (rounding up)
        self.layer1 = self._make_layer(32, 64, num_blocks=2, stride=2)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)
        self.se = SEBlock(256)

        self.embedding_size = embedding_size
        # The flattened trunk size depends on the spectrogram's width, so the first projection
        # is an nn.LazyLinear that infers in_features on the first forward pass. Its parameters
        # are registered here, at construction, so an optimizer built before any forward pass
        # still updates them. Previously the head was created inside forward_once, after the
        # optimizer had already captured model.parameters(), so the head was never trained.
        self.fc = nn.Sequential(
            nn.LazyLinear(512),
            nn.ReLU(),
            nn.Linear(512, embedding_size)
        )

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        """Stack `num_blocks` residual blocks; only the first changes channels/stride."""
        layers = [ResidualBlock(in_channels, out_channels, stride)]
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward_features(self, x):
        """Convolutional trunk: [batch, C, H, W] -> [batch, 256, H', W']."""
        out = self.initial_conv(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.se(out)
        return out

    def forward_once(self, x):
        """Embed one batch: trunk features flattened per sample, then the MLP head."""
        features = self.forward_features(x)
        return self.fc(torch.flatten(features, 1))

    def forward(self, input1, input2):
        """Embed both inputs with the shared network; returns (embedding1, embedding2)."""
        output1 = self.forward_once(input1)
        output2 = self.forward_once(input2)
        return output1, output2
# ----------------------------
# Contrastive Loss Definition
# ----------------------------
import torch.nn as nn
class ContrastiveLoss(nn.Module):
    """
    Contrastive loss for a Siamese network.

    For a pair at Euclidean distance d with target y (0 = same class, 1 = different class):
        loss = mean((1 - y) * d^2 + y * max(margin - d, 0)^2)

    Args:
        margin (float): Distance beyond which different-class pairs stop contributing.
    """

    def __init__(self, margin=MARGIN):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        """
        Args:
            output1, output2: Embedding batches of shape [batch, dim].
            label: One target per pair, shaped [batch] or [batch, 1].

        Returns:
            Scalar tensor: the batch-mean loss.

        Raises:
            ValueError: if `label` does not hold exactly one value per pair.
        """
        euclidean_distance = F.pairwise_distance(output1, output2)
        # The Siamese dataset yields targets shaped [batch, 1] while the distances are [batch].
        # Without this reshape they broadcast to a [batch, batch] grid, pairing every target
        # with every distance, and the mean of that grid is not the per-pair loss.
        if label.numel() != euclidean_distance.numel():
            raise ValueError(
                f"expected one label per pair ({euclidean_distance.numel()}), got shape {tuple(label.shape)}"
            )
        label = label.reshape(euclidean_distance.shape).to(euclidean_distance.dtype)

        similar_term = (1 - label) * torch.pow(euclidean_distance, 2)
        dissimilar_term = label * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)
        return torch.mean(similar_term + dissimilar_term)
# ----------------------------
# Training Loop Function
# ----------------------------
def train(model, dataloader, criterion, optimizer, device, num_epochs=NUM_EPOCHS):
    """
    Train a Siamese model with a pairwise loss.

    Args:
        model (nn.Module): Network called as model(x1, x2) -> (embedding1, embedding2).
        dataloader (iterable): Yields (x1, x2, target) batches. Targets may be shaped [batch]
            or [batch, 1]; SiameseDeepfakeDataset produces [batch, 1].
        criterion (callable): Loss called as criterion(embedding1, embedding2, target).
        optimizer (torch.optim.Optimizer): Optimizer over the model's parameters.
        device (torch.device): Device that batches are moved onto.
        num_epochs (int): Number of passes over the dataloader.

    Returns:
        list[float]: Mean batch loss per epoch (NaN for an epoch with no batches).
    """
    model.train()
    history = []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for img1, img2, label in dataloader:
            img1, img2 = img1.to(device), img2.to(device)
            # One target per pair so targets line up with per-pair distances. A [batch, 1]
            # target against [batch] distances would otherwise broadcast to a [batch, batch]
            # grid inside an elementwise loss.
            label = label.to(device).reshape(-1)
            optimizer.zero_grad()
            output1, output2 = model(img1, img2)
            loss = criterion(output1, output2, label)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        avg_loss = running_loss / num_batches if num_batches else float("nan")
        history.append(avg_loss)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")
    return history
        
# ----------------------------
# Example Main Function (Dataset part to be added later)
# ----------------------------
def main(dataloader=None, num_epochs=NUM_EPOCHS):
    """
    Build the Siamese model, loss and optimizer, train, and return the trained model.

    Args:
        dataloader: Iterable of (x1, x2, target) batches. Defaults to the notebook-level
            `train_loader` created in an earlier cell.
        num_epochs (int): Number of training epochs (default: NUM_EPOCHS).

    Returns:
        RefinedSiameseCNN: The trained model. Previously main() returned nothing, so the
        trained weights were lost when it finished.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Fall back to the global loader so existing `main()` calls keep working
    if dataloader is None:
        dataloader = train_loader

    model = RefinedSiameseCNN(embedding_size=EMBEDDING_SIZE).to(device)
    criterion = ContrastiveLoss(margin=MARGIN)

    # Part of the model is sized on its first forward pass. Push one pair through it before
    # creating the optimizer so that model.parameters() already lists every trainable weight;
    # weights created after the optimizer would never be updated by it.
    first_batch = next(iter(dataloader), None)
    if first_batch is not None:
        probe1, probe2, _ = first_batch
        model.eval()  # leave BatchNorm running statistics untouched during the probe
        with torch.no_grad():
            model(probe1[:1].to(device), probe2[:1].to(device))

    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # train() puts the model back into training mode before the first epoch
    train(model, dataloader, criterion, optimizer, device, num_epochs=num_epochs)
    return model


In [None]:
# Entry point for the training run.
# NOTE(review): inside a Jupyter kernel `__name__` is "__main__", so this guard does not stop
# the cell from training when it is executed (the saved output below comes from such a run).
# NOTE(review): main()'s result is not kept here, so later cells cannot access the model it trains.
if __name__ == "__main__":
    main()

Epoch [1/20], Loss: 1.0131
Epoch [2/20], Loss: 1.0050
Epoch [3/20], Loss: 1.0023
Epoch [4/20], Loss: 0.9950
Epoch [5/20], Loss: 0.9922
Epoch [6/20], Loss: 0.9875
Epoch [7/20], Loss: 0.9877
Epoch [8/20], Loss: 0.9902
Epoch [9/20], Loss: 0.9874
Epoch [10/20], Loss: 0.9865
Epoch [11/20], Loss: 0.9846
Epoch [12/20], Loss: 0.9886
Epoch [13/20], Loss: 0.9856
Epoch [14/20], Loss: 0.9864
Epoch [15/20], Loss: 0.9851
Epoch [16/20], Loss: 0.9827
Epoch [17/20], Loss: 0.9818
Epoch [18/20], Loss: 0.9832
Epoch [19/20], Loss: 0.9849
Epoch [20/20], Loss: 0.9836


In [None]:
# Instantiate the development (validation) dataset and dataloader.
# This will be used to monitor model performance during training (optional).
# NOTE(review): no visible cell consumes dev_loader yet.
siamese_dev_dataset = SiameseDeepfakeDataset(
    "/home/kartik-dua/python_project/DECRO/inner/en_dev.txt",     # Path to development label file
    "/home/kartik-dua/python_project/DECRO/inner/wav/en_dev"      # Path to corresponding WAV files
)
dev_loader = DataLoader(
    siamese_dev_dataset,
    batch_size=16,    # Number of sample pairs per batch
    shuffle=True      # Enable shuffling for generalization
)

# Instantiate the evaluation (test) dataset and dataloader.
# This will be used for final model evaluation after training.
# NOTE(review): SiameseDeepfakeDataset draws each item's partner with np.random, and no seed is
# set in the visible cells, so the evaluation pairs (and the reported metrics) vary between runs.
siamese_eval_dataset = SiameseDeepfakeDataset(
    "/home/kartik-dua/python_project/DECRO/inner/en_eval.txt",    # Path to evaluation label file
    "/home/kartik-dua/python_project/DECRO/inner/wav/en_eval"     # Path to corresponding WAV files
)
eval_loader = DataLoader(
    siamese_eval_dataset,
    batch_size=16,    # Number of sample pairs per batch
    shuffle=True      # Shuffle may be used here for batch variability
)


In [None]:
# ----------------------------
# Device Setup and Model Initialization
# ----------------------------

# Set the computation device — GPU if available, otherwise fallback to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the refined Siamese CNN model and move it to the selected device.
# NOTE(review): this is a brand-new instance with freshly initialised weights. The training cell
# calls main() without keeping its result, so the model trained there is not this one. Evaluating
# `model` without training it first measures random weights.
model = RefinedSiameseCNN(embedding_size=EMBEDDING_SIZE).to(device)

# Define the contrastive loss function with a specified margin
criterion = ContrastiveLoss(margin=MARGIN)

# Adam optimizer over the model's parameters.
# NOTE(review): not used by any visible cell after this point.
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)


# ----------------------------
# Validation Function
# ----------------------------

def validate(model, dataloader, criterion, device, threshold=1.0):
    """
    Evaluate the Siamese model on a validation or evaluation dataset.

    Args:
        model (nn.Module): Siamese network; model(x1, x2) must return two embedding batches.
        dataloader (iterable): Yields (x1, x2, labels) batches. Labels hold 0 (same) or
            1 (different) per pair, shaped [batch] or [batch, 1].
        criterion (callable): Loss called as criterion(output1, output2, labels).
        device (torch.device): Device on which to perform computation.
        threshold (float): Pairs at Euclidean distance >= threshold are predicted "different" (1);
            closer pairs are predicted "same" (0).

    Returns:
        Tuple[float, float]: Mean per-batch loss and pair accuracy. Returns (0.0, 0) if the
        dataloader yields no batches.

    The model's train/eval mode is restored on exit. This means validate() can be called
    between training epochs without leaving the model in eval mode.
    """
    was_training = model.training
    model.eval()  # Evaluation mode: BatchNorm uses running statistics
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    num_batches = 0

    try:
        with torch.no_grad():  # Disable gradient tracking for validation
            for x1, x2, labels in dataloader:
                # Move inputs and labels to the target device
                x1, x2, labels = x1.to(device), x2.to(device), labels.to(device)

                # Flatten labels from shape [batch_size, 1] to [batch_size]
                labels = labels.reshape(-1)

                output1, output2 = model(x1, x2)

                loss = criterion(output1, output2, labels)
                total_loss += loss.item()

                # Predict: same (0) if distance < threshold, different (1) if >= threshold
                distances = F.pairwise_distance(output1, output2)
                preds = (distances >= threshold).float()

                total_correct += (preds == labels).float().sum().item()
                total_samples += labels.size(0)
                num_batches += 1
    finally:
        # Restore whatever mode the caller had the model in
        model.train(was_training)

    # Average per batch; guard against an empty dataloader
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    accuracy = total_correct / total_samples if total_samples > 0 else 0

    print(f"Validation Loss: {avg_loss:.4f}, Accuracy: {accuracy * 100:.2f}%")
    return avg_loss, accuracy

In [None]:
# ----------------------------
# Run Evaluation on the Evaluation Set
# ----------------------------

# Evaluate the Siamese model on the evaluation pairs.
# A lower threshold (e.g., 0.8) classifies more pairs as "different" (distance >= threshold).
# NOTE(review): `model` is whichever instance was last bound to that name. In the visible cells
# that is the freshly constructed model from the setup cell, not the one trained inside main().
validate(
    model=model,
    dataloader=eval_loader,
    criterion=criterion,
    device=device,
    threshold=0.8  # Distance threshold for classification: same vs different
)

Validation Loss: 0.9843, Accuracy: 62.36%


(0.984286413192749, 0.6235539343408025)