In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import random

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [2]:
def load_mnist(root='./data', train=True):
    """
    Load an MNIST split, normalize it, and return it as two tensors.

    Args:
        root (str): Directory the dataset is downloaded to / read from.
        train (bool): Load the training split when True, the test split otherwise.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: The stacked images and their labels,
        both moved to the module-level `device`.
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset = datasets.MNIST(root=root, train=train, download=True, transform=transform)

    # Walk the dataset once: each item access runs the transform, so collecting
    # images and labels in the same pass avoids transforming every image twice.
    images, targets = [], []
    for image, target in dataset:
        images.append(image)
        targets.append(target)

    data = torch.stack(images)  # [N, 1, 28, 28]
    labels = torch.tensor(targets)
    return data.to(device), labels.to(device)

def contrastive_loss(x1, x2, label, margin=1.0):
    """Mean contrastive loss; `label` weights the squared distance, `1 - label` the squared margin hinge.

    Note: a later cell in this notebook defines `contrastive_loss` again with the same formula.
    """
    gap = F.pairwise_distance(x1, x2)
    attract = label * gap.pow(2)
    repel = (1 - label) * F.relu(margin - gap).pow(2)
    return (attract + repel).mean()

def triplet_loss(anchor, pos, neg, margin=1.0):
    """Mean hinge on (anchor-positive distance) - (anchor-negative distance) + margin.

    Note: a later cell in this notebook defines `triplet_loss` again with the same formula.
    """
    positive_gap = F.pairwise_distance(anchor, pos)
    negative_gap = F.pairwise_distance(anchor, neg)
    violation = positive_gap - negative_gap + margin
    return F.relu(violation).mean()



In [4]:
class ConvNetEmbedder(nn.Module):
    """
    A convolutional neural network for learning embeddings.

    Two conv -> ReLU -> max-pool stages feed a two-layer fully connected head;
    the head's output is L2-normalized so every embedding has unit length.

    Args:
        input_channels (int): Number of channels of the input images.
        input_height (int): Height of the input images in pixels.
        input_width (int): Width of the input images in pixels.
        embedding_dim (int): Size of the output embedding (defaults to 64).
    """
    def __init__(self, input_channels=3, input_height=32, input_width=32, embedding_dim=64):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.features = nn.Sequential(
            nn.Conv2d(input_channels, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Probe the feature extractor with a dummy image so the first linear
        # layer is sized for the actual input height and width.
        with torch.no_grad():
            dummy = torch.zeros(1, input_channels, input_height, input_width)
            self.flattened_size = self.features(dummy).flatten(1).size(1)

        self.fc = nn.Sequential(
            nn.Linear(self.flattened_size, 128),
            nn.ReLU(),
            nn.Linear(128, embedding_dim),
        )

    def forward(self, x):
        """
        Forward pass through the network.

        Args:
            x (torch.Tensor): Input batch of shape [N, input_channels, input_height, input_width].

        Returns:
            torch.Tensor: The normalized embedding of shape [N, embedding_dim].
        """
        x = self.features(x)
        # flatten() copies when needed, so non-contiguous feature maps
        # (e.g. from channels_last inputs) work where view() would raise.
        x = x.flatten(1)
        x = self.fc(x)
        return F.normalize(x, dim=-1)  # L2 normalize the embeddings

# ----------------------
# Losses
# ----------------------
def contrastive_loss(x1, x2, label, margin=1.0):
    """
    Contrastive loss averaged over a batch of embedding pairs.

    Args:
        x1 (torch.Tensor): Embeddings of the first element of each pair.
        x2 (torch.Tensor): Embeddings of the second element of each pair.
        label (torch.Tensor): 1 for similar pairs, 0 for dissimilar pairs.
        margin (float): Distance below which dissimilar pairs are penalized.

    Returns:
        torch.Tensor: Scalar mean loss over all pairs.
    """
    dist = F.pairwise_distance(x1, x2)
    # Similar pairs pay their squared distance ...
    similar_term = dist.pow(2)
    # ... dissimilar pairs pay the squared shortfall to the margin, if any.
    dissimilar_term = F.relu(margin - dist).pow(2)
    per_pair = label * similar_term + (1 - label) * dissimilar_term
    return per_pair.mean()

def triplet_loss(anchor, pos, neg, margin=1.0):
    """
    Triplet margin loss averaged over a batch of triplets.

    Args:
        anchor (torch.Tensor): Embeddings of the anchor samples.
        pos (torch.Tensor): Embeddings of the positive samples.
        neg (torch.Tensor): Embeddings of the negative samples.
        margin (float): Amount by which the negative must be farther than the positive.

    Returns:
        torch.Tensor: Scalar mean loss over all triplets.
    """
    d_pos, d_neg = (F.pairwise_distance(anchor, other) for other in (pos, neg))
    # A triplet contributes only while the negative is not yet `margin` farther than the positive.
    hinge = F.relu(d_pos - d_neg + margin)
    return hinge.mean()

# ----------------------
# Training Loop
# ----------------------
_SUPPORTED_LOSS_TYPES = ('triplet', 'contrastive')


def _index_pools(batch_labels):
    """Group batch positions into same-label and different-label pools, keyed by label."""
    same_label = {}
    for position, label in enumerate(batch_labels):
        same_label.setdefault(label, []).append(position)
    different_label = {
        label: [position for position, other in enumerate(batch_labels) if other != label]
        for label in same_label
    }
    return same_label, different_label


def _sample_triplet_indices(batch_labels):
    """Draw one random in-batch positive and negative for every anchor that has both."""
    same_label, different_label = _index_pools(batch_labels)
    anchor_idx, positive_idx, negative_idx = [], [], []
    for position, label in enumerate(batch_labels):
        positive_pool = [p for p in same_label[label] if p != position]  # exclude the anchor itself
        negative_pool = different_label[label]
        if positive_pool and negative_pool:
            anchor_idx.append(position)
            positive_idx.append(random.choice(positive_pool))
            negative_idx.append(random.choice(negative_pool))
    return anchor_idx, positive_idx, negative_idx


def _sample_pair_indices(batch_labels):
    """Draw one random partner per sample: same-label (target 1.0) or different-label (target 0.0), 50/50."""
    same_label, different_label = _index_pools(batch_labels)
    first_idx, second_idx, pair_targets = [], [], []
    for position, label in enumerate(batch_labels):
        if random.random() < 0.5:  # positive pair
            pool = [p for p in same_label[label] if p != position]
            target = 1.0
        else:  # negative pair
            pool = different_label[label]
            target = 0.0
        # Samples with no eligible partner in this batch contribute no pair.
        if pool:
            first_idx.append(position)
            second_idx.append(random.choice(pool))
            pair_targets.append(target)
    return first_idx, second_idx, pair_targets


def _gather_rows(embeddings, positions):
    """Select the rows of `embeddings` at the given batch positions (keeps the autograd graph)."""
    return embeddings[torch.tensor(positions, device=embeddings.device)]


def train_embedding_model(data, labels, loss_type='triplet', epochs=20, batch_size=256):
    """
    Trains the embedding model using either contrastive or triplet loss.

    Pairs / triplets are mined at random inside each shuffled mini-batch. Labels
    are copied to the host once per batch and the mining runs on plain Python
    lists, so no per-sample device synchronization is needed. Batches that yield
    no pairs/triplets, or whose loss is exactly zero, do not trigger an optimizer
    step and are excluded from the reported epoch average.

    Args:
        data (torch.Tensor): The input data, indexed as [N, channels, height, width].
        labels (torch.Tensor): The corresponding labels (one scalar per sample).
        loss_type (str): Type of loss to use ('triplet' or 'contrastive').
        epochs (int): Number of training epochs.
        batch_size (int): Batch size for training.

    Returns:
        nn.Module: The trained embedding model in evaluation mode.

    Raises:
        ValueError: If `loss_type` is neither 'triplet' nor 'contrastive'.
    """
    # Fail fast: an unknown loss type would otherwise run every epoch without
    # ever computing a loss and return an untrained model.
    if loss_type not in _SUPPORTED_LOSS_TYPES:
        raise ValueError(
            f"Unsupported loss_type {loss_type!r}; expected one of {_SUPPORTED_LOSS_TYPES}"
        )

    channels, height, width = data.size(1), data.size(2), data.size(3)
    # Pass input_height and input_width explicitly so the linear head matches the data.
    model = ConvNetEmbedder(input_channels=channels, input_height=height, input_width=width).to(device).train()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    n = len(data)

    print(f"Starting training with {loss_type} loss for {epochs} epochs...")
    for epoch in range(epochs):
        perm = torch.randperm(n)
        total_loss = 0.0
        num_batches = 0
        for i in range(0, n, batch_size):
            batch_idx = perm[i:i + batch_size]
            x_batch, y_batch = data[batch_idx], labels[batch_idx]

            # One host transfer per batch; all mining below is pure Python.
            batch_labels = y_batch.tolist()
            if loss_type == 'triplet':
                sampled = _sample_triplet_indices(batch_labels)
            else:
                sampled = _sample_pair_indices(batch_labels)
            if not sampled[0]:
                # Nothing to learn from in this batch; skip the forward pass too.
                continue

            optimizer.zero_grad()
            x_embed = model(x_batch)

            if loss_type == 'triplet':
                anchor_idx, positive_idx, negative_idx = sampled
                current_loss = triplet_loss(
                    _gather_rows(x_embed, anchor_idx),
                    _gather_rows(x_embed, positive_idx),
                    _gather_rows(x_embed, negative_idx),
                )
            else:
                first_idx, second_idx, pair_targets = sampled
                current_loss = contrastive_loss(
                    _gather_rows(x_embed, first_idx),
                    _gather_rows(x_embed, second_idx),
                    torch.tensor(pair_targets, device=x_embed.device),
                )

            loss_value = current_loss.item()
            if loss_value > 0:  # Only backpropagate if there's an actual loss
                current_loss.backward()
                optimizer.step()
                total_loss += loss_value
                num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")

    return model.eval()  # Return model in evaluation mode


In [9]:
def embed_in_batches(model, data, batch_size=256):
    """
    Run `model` over `data` in fixed-size chunks without tracking gradients.

    Args:
        model (nn.Module): Trained embedding model.
        data (torch.Tensor): Inputs, sliced along the first dimension.
        batch_size (int): Number of samples passed to the model at once.

    Returns:
        torch.Tensor: The per-chunk outputs concatenated along the first dimension,
        on the device the model produced them on.
    """
    with torch.no_grad():
        chunks = [model(data[start:start + batch_size]) for start in range(0, len(data), batch_size)]
    return torch.cat(chunks)


data1, labels1 = load_mnist()

# Train one embedder per loss on the same data, then embed the full dataset with each.
model_con1 = train_embedding_model(data1, labels1, loss_type='contrastive')
data_c1 = embed_in_batches(model_con1, data1)

model_tri1 = train_embedding_model(data1, labels1, loss_type='triplet')
data_t1 = embed_in_batches(model_tri1, data1)


Starting training with contrastive loss for 20 epochs...
Epoch 1/20, Average Loss: 0.0575
Epoch 2/20, Average Loss: 0.0191
Epoch 3/20, Average Loss: 0.0136
Epoch 4/20, Average Loss: 0.0099
Epoch 5/20, Average Loss: 0.0087
Epoch 6/20, Average Loss: 0.0075
Epoch 7/20, Average Loss: 0.0068
Epoch 8/20, Average Loss: 0.0063
Epoch 9/20, Average Loss: 0.0057
Epoch 10/20, Average Loss: 0.0053
Epoch 11/20, Average Loss: 0.0049
Epoch 12/20, Average Loss: 0.0047
Epoch 13/20, Average Loss: 0.0034
Epoch 14/20, Average Loss: 0.0043
Epoch 15/20, Average Loss: 0.0035
Epoch 16/20, Average Loss: 0.0029
Epoch 17/20, Average Loss: 0.0031
Epoch 18/20, Average Loss: 0.0024
Epoch 19/20, Average Loss: 0.0024
Epoch 20/20, Average Loss: 0.0044
Starting training with triplet loss for 20 epochs...
Epoch 1/20, Average Loss: 0.1056
Epoch 2/20, Average Loss: 0.0283
Epoch 3/20, Average Loss: 0.0200
Epoch 4/20, Average Loss: 0.0154
Epoch 5/20, Average Loss: 0.0131
Epoch 6/20, Average Loss: 0.0110
Epoch 7/20, Average L

In [10]:
from metric import compute_intra_class_variance, compute_inter_class_variance

# Baseline: intra-class variance computed on the normalized input images.
variances_original = compute_intra_class_variance(data1, labels1)
print(variances_original)

# Use `labels1` (returned by load_mnist) throughout: no visible cell assigns a
# bare `labels`, so the previous calls only worked with leftover kernel state.
variances_contrastive = compute_intra_class_variance(data_c1, labels1)
print(variances_contrastive)

variances_triplet = compute_intra_class_variance(data_t1, labels1)
print(variances_triplet)


{0: 0.3494914770126343, 1: 0.16978716850280762, 2: 0.3682921826839447, 3: 0.33302924036979675, 4: 0.3112839162349701, 5: 0.36207902431488037, 6: 0.3122705817222595, 7: 0.2790134847164154, 8: 0.3301646411418915, 9: 0.2867642045021057}
{0: 0.000799543980974704, 1: 0.0015676108887419105, 2: 0.003525828942656517, 3: 0.0022698622196912766, 4: 0.002072482369840145, 5: 0.0033548111096024513, 6: 0.0036530648358166218, 7: 0.003849779022857547, 8: 0.0026101875118911266, 9: 0.0059341671876609325}
{0: 0.006985311396420002, 1: 0.00323380995541811, 2: 0.010066665709018707, 3: 0.004289098549634218, 4: 0.005846464075148106, 5: 0.0072196065448224545, 6: 0.005104381125420332, 7: 0.0041129146702587605, 8: 0.003503313986584544, 9: 0.00861962977796793}


In [12]:
import numpy as np

# Summarize the per-class intra-class variances of the contrastive embeddings.
intra_values = [*variances_contrastive.values()]
intra_mean, intra_std = np.mean(intra_values), np.std(intra_values)
print(f"Intra-class variance - Mean (contrastive): {intra_mean:.4f}, Std: {intra_std:.4f}")

# Same summary for the triplet embeddings (reuses the names above).
intra_values = [*variances_triplet.values()]
intra_mean, intra_std = np.mean(intra_values), np.std(intra_values)
print(f"Intra-class variance - Mean (triplet): {intra_mean:.4f}, Std: {intra_std:.4f}")

Intra-class variance - Mean (contrastive): 0.0030, Std: 0.0014
Intra-class variance - Mean (triplet): 0.0059, Std: 0.0022


In [16]:
# `labels1` is the label tensor returned by load_mnist(); no visible cell
# defines a bare `labels`, so it must not be relied on here.
inter_mean, inter_std = compute_inter_class_variance(data_c1, labels1)
print(f"Inter-class variance - Mean (Contrastive): {inter_mean:.4f}, Std: {inter_std:.4f}")
inter_mean, inter_std = compute_inter_class_variance(data_t1, labels1)
print(f"Inter-class variance - Mean (Triplet): {inter_mean:.4f}, Std: {inter_std:.4f}")


Inter-class variance - Mean (Contrastive): 1.0347, Std: 0.0800
Inter-class variance - Mean (Triplet): 1.4840, Std: 0.0685


In [17]:
# Squares the two inter-class Std values printed by the previous cell
# (0.0800 for contrastive, 0.0685 for triplet); the numbers are copied by hand.
0.0800*0.0800, 0.0685*0.0685

(0.0064, 0.004692250000000001)