# **This is the complete code I used to implement the contrastive loss for ASV, including the code for the layer-weight heatmap. It yielded an EER of 0.34%, but almost all of the layer weight was concentrated in layer 10.**

# **Adding the contrastive loss required changes to ResNetClassifier.forward(), Train(), and Test(). I used the MemmapDataset class to load the dataset from .npy files; you can use the LayerFeatureDataset class for your .pkl file.**

# In [None]:
import math
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader, random_split, Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_curve, auc
import seaborn as sns
import matplotlib.pyplot as plt


class LayerFeatureDataset(Dataset):
    """In-memory dataset of per-layer features loaded from a pickle file.

    The pickle must contain a dict with ``"features"`` and ``"labels"``
    entries. Features are stored as float32 and labels as int64 tensors.
    Presumably features are shaped (N, L, D) like ``MemmapDataset``; confirm
    against whatever produced the file.

    NOTE(review): ``pickle.load`` can execute arbitrary code. Only load
    files you created yourself.
    """

    def __init__(self, pkl_path):
        # Use a context manager so the file handle is always closed.
        with open(pkl_path, "rb") as fh:
            data = pickle.load(fh)

        feats = data["features"]
        labels = data["labels"]
        # Stack lists of per-example arrays with NumPy first. This avoids
        # PyTorch's slow element-wise path for a list of ndarrays.
        if not torch.is_tensor(feats):
            feats = np.asarray(feats)
        if not torch.is_tensor(labels):
            labels = np.asarray(labels)
        self.features = torch.as_tensor(feats, dtype=torch.float32)
        self.labels = torch.as_tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

class MemmapDataset(Dataset):
    """Dataset that reads feature/label pairs lazily from two ``.npy`` files.

    Both arrays are opened read-only via ``np.load(..., mmap_mode="r")``, so
    a row is only read from disk when it is indexed.

    Args:
        feat_path: path to the feature array, expected shape (N, L, D).
        lab_path: path to the label array, expected shape (N,).
    """

    def __init__(self, feat_path, lab_path):
        self.X = np.load(feat_path, mmap_mode="r")
        self.y = np.load(lab_path, mmap_mode="r")

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        # Copy the memory-mapped row into a fresh float32 tensor of shape
        # (L, D), and return the label as a plain Python int.
        row = self.X[idx]
        return torch.tensor(row, dtype=torch.float32), int(self.y[idx])

# Models
class LayerWeightedAggregator(nn.Module):
    """Collapse the layer axis with a learned, softmax-normalised weighting.

    Holds one logit per layer in ``self.w``. ``forward`` turns the logits into
    weights that sum to 1 and returns the weighted sum over dim 1.
    """

    def __init__(self, num_layers):
        super().__init__()
        # All logits are equal, so the weights start out uniform.
        init_logits = torch.ones(num_layers) / num_layers
        self.w = nn.Parameter(init_logits)

    def forward(self, x):
        # x: (batch, L, D) -> (batch, D)
        weights = self.w.softmax(dim=0)
        broadcast = weights.view(1, -1, 1)
        return (x * broadcast).sum(dim=1)

class ResNetClassifier(nn.Module):
    """Layer-weighted features viewed as a 1-channel square image, fed to ResNet-18.

    Args:
        num_layers: number of stacked layers L in the input (batch, L, D).
        hidden_dim: per-layer feature size D. Must be a perfect square so the
            aggregated (batch, D) vector can be viewed as (batch, 1, H, W).
        num_classes: output size of the replaced final linear layer.

    Raises:
        ValueError: if ``hidden_dim`` is negative or not a perfect square.

    ``forward`` returns ``(z, logits)``. ``z`` is the (batch, D) output of the
    layer aggregator, used by the contrastive loss. ``logits`` has shape
    (batch, num_classes).
    """

    def __init__(self, num_layers, hidden_dim, num_classes):
        super().__init__()
        # Validate with an exception rather than an assert, because asserts
        # are stripped under `python -O`. math.isqrt is exact for any int,
        # unlike int(np.sqrt(...)).
        side = math.isqrt(hidden_dim)
        if side * side != hidden_dim:
            raise ValueError(
                f"hidden_dim must be a perfect square, got {hidden_dim}"
            )
        self.H = self.W = side

        self.agg = LayerWeightedAggregator(num_layers)

        # `pretrained=True` is deprecated in torchvision >= 0.13; it maps to
        # the IMAGENET1K_V1 weights. Fall back for older torchvision.
        try:
            weights = models.ResNet18_Weights.IMAGENET1K_V1
        except AttributeError:
            self.resnet = models.resnet18(pretrained=True)
        else:
            self.resnet = models.resnet18(weights=weights)

        # NOTE: this replaces the pretrained RGB stem with a freshly (randomly)
        # initialised single-channel conv. Only the later layers keep
        # pretrained weights.
        self.resnet.conv1 = nn.Conv2d(1,
            self.resnet.conv1.out_channels,
            kernel_size=self.resnet.conv1.kernel_size,
            stride=self.resnet.conv1.stride,
            padding=self.resnet.conv1.padding,
            bias=False)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        z = self.agg(x)                        # (batch, D)
        b = z.size(0)
        img = z.view(b, 1, self.H, self.W)     # (batch, 1, H, W)
        logits = self.resnet(img)
        return z, logits

# Contrastive Loss
def contrastive_loss(z1, z2, pair_labels, margin=1.0):
    """Margin-based contrastive loss on L2-normalised embeddings.

    For each row i, with d the Euclidean distance between the normalised
    ``z1[i]`` and ``z2[i]``:
      * same pair (label 0):      0.5 * d**2
      * different pair (label 1): 0.5 * max(margin - d, 0)**2
    Returns the mean over the batch.

    Args:
        z1, z2: (batch, D) embeddings.
        pair_labels: (batch,) float tensor, 0 = same and 1 = different.
        margin: distance beyond which different pairs stop being penalised.
    """
    a = F.normalize(z1, dim=1)
    b = F.normalize(z2, dim=1)
    d = F.pairwise_distance(a, b)                     # (batch,)
    same_term = (1 - pair_labels) * d.pow(2)
    diff_term = pair_labels * (margin - d).clamp(min=0.0).pow(2)
    return (0.5 * (same_term + diff_term)).mean()


# Joint Training Loop
def Train(model, optimizer, ce_criterion, train_loader, val_loader,
          num_epochs=50, alpha=0.5, margin=1.0):
    """Jointly train ``model`` with cross-entropy plus a contrastive term.

    For every batch the loss is ``CE(logits, labels) + alpha * contrastive``.
    The contrastive term compares each embedding ``z`` with the embedding of
    a randomly chosen partner from the same batch. The pair label is 0 when
    the two class labels match and 1 otherwise. After each epoch the
    validation EER is computed by ``_validation_eer``.

    Note: ``z`` is the output of ``model.agg`` applied directly to the input
    features, so the contrastive term only sends gradients to the layer
    weights ``model.agg.w``. The ResNet parameters are trained by the CE term
    alone.

    Relies on the module-level ``device``.

    Args:
        model: module whose forward returns ``(z, logits)``.
        optimizer: optimizer over ``model``'s parameters.
        ce_criterion: classification loss applied to ``(logits, labels)``.
        train_loader, val_loader: iterables of ``(features, labels)`` batches.
        num_epochs: number of passes over ``train_loader``.
        alpha: weight of the contrastive term.
        margin: margin passed to ``contrastive_loss``.

    Returns:
        list of validation EERs, one per epoch.
    """
    eer_per_epoch = []
    model.to(device)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        n_batches = 0

        for feats, labels in train_loader:
            feats, labels = feats.to(device), labels.to(device)

            # Forward
            z, logits = model(feats)
            loss_ce = ce_criterion(logits, labels)

            # Pair each example with a random partner from the same batch.
            # Build the permutation on the batch's device so no host->device
            # index copy is needed.
            idx = torch.randperm(feats.size(0), device=feats.device)
            z2, labels2 = z[idx], labels[idx]
            pair_lbl = (labels != labels2).float()

            # Contrastive loss on embeddings
            loss_con = contrastive_loss(z, z2, pair_lbl, margin)

            # Combined loss
            loss = loss_ce + alpha * loss_con

            # Backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1

        # Guard against an empty loader (would otherwise divide by zero).
        avg_loss = running_loss / max(n_batches, 1)

        eer = _validation_eer(model, val_loader)
        eer_per_epoch.append(eer)

        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}, Val EER: {eer:.4f}")

    return eer_per_epoch


def _validation_eer(model, loader):
    """Return the EER of ``model`` on ``loader``, scoring by softmax P(class 1).

    At the ROC point where FPR and FNR are closest, they generally still
    differ on a discrete curve. The EER is reported as their midpoint rather
    than FPR alone.
    """
    model.eval()
    all_labels, all_probs = [], []
    with torch.no_grad():
        for feats, labels in loader:
            _, logits = model(feats.to(device))
            probs = torch.softmax(logits, dim=1)[:, 1]
            all_labels.append(labels.cpu().numpy())
            all_probs.append(probs.cpu().numpy())

    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)
    fpr, tpr, _ = roc_curve(all_labels, all_probs)
    fnr = 1 - tpr
    i = np.nanargmin(np.abs(fpr - fnr))
    return float((fpr[i] + fnr[i]) / 2)

def Test(model, test_loader, name="Model"):
    model.eval()
    all_labels = []
    all_preds  = []
    all_probs  = []

    with torch.no_grad():
        for features, labels in test_loader:
            features, labels = features.to(device), labels.to(device)
            emb, outputs = model(features)

            probs = torch.softmax(outputs, dim=1)[:, 1]
            preds = torch.argmax(outputs, dim=1)

            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    all_labels = np.array(all_labels)
    all_preds  = np.array(all_preds)
    all_probs  = np.array(all_probs)

    # Classification metrics
    accuracy  = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds)
    recall    = recall_score(all_labels, all_preds)
    f1        = f1_score(all_labels, all_preds)

    # Manual FPR/FNR
    TP = np.sum((all_preds == 1) & (all_labels == 1))
    TN = np.sum((all_preds == 0) & (all_labels == 0))
    FP = np.sum((all_preds == 1) & (all_labels == 0))
    FN = np.sum((all_preds == 0) & (all_labels == 1))

    fpr_manual = FP / (FP + TN) if (FP + TN) > 0 else 0.0
    fnr_manual = FN / (FN + TP) if (FN + TP) > 0 else 0.0

    # ROC & EER
    fpr, tpr, thresholds = roc_curve(all_labels, all_probs)
    fnr = 1 - tpr
    eer_index = np.nanargmin(np.abs(fpr - fnr))
    eer = fpr[eer_index]
    roc_auc = auc(fpr, tpr)

    # Print results
    print(f"=== Evaluation Metrics: {name} ===")
    print(f"Accuracy : {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall   : {recall:.4f}")
    print(f"F1 Score : {f1:.4f}")
    print(f"FPR      : {fpr_manual:.4f}")
    print(f"FNR      : {fnr_manual:.4f}")
    print(f"EER      : {eer:.4f}")
    print(f"AUC      : {roc_auc:.4f}")

    # Plot ROC
    plt.figure(figsize=(8,6))
    plt.plot(fpr, tpr, label=f'ROC (AUC = {roc_auc:.4f})')
    plt.plot([0,1], [0,1], '--', label='Random')
    plt.scatter(fpr[eer_index], tpr[eer_index], color='red',
                label=f'EER = {eer:.4f}')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'{name} ROC Curve')
    plt.legend()
    plt.grid(True)
    plt.show()

###################################################################

# Begin

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load dataset
# To use a pickled features file instead: ds = LayerFeatureDataset(features_pkl)
feat_path = "/content/features.npy"
lab_path  = "/content/labels.npy"
ds = MemmapDataset(feat_path, lab_path)

# Train/Val/Test split (80/10/10). Use a fixed seed so the split, and
# therefore the reported EER, is reproducible across runs.
SPLIT_SEED = 42
n = len(ds)
train_len = int(0.8 * n)
val_len   = int(0.1 * n)
test_len  = n - train_len - val_len
train_ds, val_ds, test_ds = random_split(
    ds, [train_len, val_len, test_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# DataLoaders
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=64, shuffle=False)
test_loader  = DataLoader(test_ds,  batch_size=64, shuffle=False)

# Model config. Shapes are inferred from one example so this works for both
# MemmapDataset (.X/.y) and LayerFeatureDataset (.features/.labels).
sample_feats, _ = ds[0]
num_layers, hidden_dim = sample_feats.shape
dataset_labels = ds.y if hasattr(ds, "y") else ds.labels.numpy()
num_classes = len(np.unique(dataset_labels))


model = ResNetClassifier(num_layers, hidden_dim, num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
ce_criterion = nn.CrossEntropyLoss()

eer_history = Train(
    model, optimizer, ce_criterion,
    train_loader, val_loader,
    num_epochs=50,
    alpha=0.5,       # weight for contrastive loss
    margin=1.0       # contrastive margin
)

####################################################################

# Testing

Test(model, test_loader, name="ResNet18 Classifier")

###################################################################

# Weights Heatmap
# Raw (pre-softmax) layer logits learned by the model's LayerWeightedAggregator.
w = model.agg.w.detach().cpu().numpy()


def softmax(x):
    """Return the softmax of a 1-D array, shifted by its max for numerical stability."""
    exp_x = np.exp(x - np.max(x))
    return exp_x / np.sum(exp_x)


# Apply the same normalisation the aggregator uses in forward(), so the plot
# shows the weights actually used to mix the layers.
data_softmax = softmax(np.asarray(w))

# One row with one cell per layer. Derive the layer count from the weights
# instead of hard-coding 24, so ticks and labels always match.
n_layers = data_softmax.size
data_reshaped = data_softmax.reshape(1, -1)

# Wide, short figure: one unit of width per layer.
plt.figure(figsize=(max(n_layers, 1), 1))
sns.heatmap(data_reshaped, annot=False, cmap='coolwarm', cbar=True, square=True, linewidths=0.5)

# Heatmap cell centres sit at 0.5, 1.5, ...; label layers starting from 1.
plt.xticks(np.arange(n_layers) + 0.5, np.arange(1, n_layers + 1), rotation=0)

plt.show()