# Representation Learning Assignment
This assignment covers key topics in representation learning using PyTorch datasets. You will implement tasks for Contrastive Learning and Energy-Based Models. Use `torchvision.datasets.MNIST` for Exercises 1 and 2.

**Total Points: 12**
- Exercise 1: Contrastive Learning (8 points)
- Exercise 2: Energy-Based Models (4 points)

Import necessary libraries and load datasets as needed.


In [1]:
# --- Install dependencies (CPU-only build by default) ---
!pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cpu
!pip install -q torch_geometric

# --- Imports ---
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Subset, DataLoader

print("Torch:", torch.__version__, "Torchvision:", torchvision.__version__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cpu
Torch: 2.8.0+cu126 Torchvision: 0.23.0+cu126
Using device: cuda


In [2]:
# --- MNIST (for Exercises 1, 2) ---
mnist_transform = transforms.ToTensor()

mnist_train_full = datasets.MNIST(
    root="./data", train=True, download=True, transform=mnist_transform
)
mnist_test_full = datasets.MNIST(
    root="./data", train=False, download=True, transform=mnist_transform
)

# Smaller subsets for quicker experiments
train_indices = list(range(1000))
test_indices = list(range(200))

mnist_train = Subset(mnist_train_full, train_indices)
mnist_test = Subset(mnist_test_full, test_indices)

print(f"MNIST Train Subset: {len(mnist_train)} samples")
print(f"MNIST Test Subset: {len(mnist_test)} samples")


MNIST Train Subset: 1000 samples
MNIST Test Subset: 200 samples





## Exercise 1: Contrastive Learning (8 points)

We’ll simulate the building blocks of contrastive learning on MNIST images: similarity measures, positive and negative pairs, the NT-Xent loss, image augmentation, and embedding distances.

### Task 1a: Cosine Similarity (1 point)

Compute row-wise cosine similarity between two batches of embeddings.
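Written out, the row-wise cosine similarity is $\cos(a_i, b_i) = \frac{a_i \cdot b_i}{\lVert a_i \rVert \, \lVert b_i \rVert}$. The sketch below computes it by hand and checks it against `F.cosine_similarity`. Clamping each norm at a small `eps` to avoid division by zero is a choice made here for illustration, and the helper name `cosine_similarity_manual` is hypothetical.

```python
import torch
import torch.nn.functional as F

def cosine_similarity_manual(emb1, emb2, eps=1e-8):
    # Row-wise dot products: (B, D) * (B, D) summed over D -> (B,)
    dots = (emb1 * emb2).sum(dim=1)
    # Per-row L2 norms, clamped so an all-zero row does not divide by zero
    n1 = emb1.norm(dim=1).clamp_min(eps)
    n2 = emb2.norm(dim=1).clamp_min(eps)
    return dots / (n1 * n2)

torch.manual_seed(0)
a, b = torch.randn(5, 10), torch.randn(5, 10)
manual = cosine_similarity_manual(a, b)
reference = F.cosine_similarity(a, b, dim=1, eps=1e-8)
print(torch.allclose(manual, reference, atol=1e-6))  # True
```

Identical rows give 1 and opposite rows give -1, which makes a quick sanity check for any solution.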

In [3]:
### Ex-1-Task-1
import torch
import torch.nn.functional as F

def compute_similarity(emb1, emb2):
    """
    emb1, emb2: (B, D)
    Returns: (B,) cosine similarities
    """
    ### BEGIN SOLUTION
    sim = F.cosine_similarity(emb1, emb2, dim=1, eps=1e-8)
    return sim

    ### END SOLUTION

# Test
x = torch.randn(5, 10)
y = x.clone()
print(compute_similarity(x, y))

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In [None]:
# INTENTIONALLY LEFT BLANK

### Task 1b: Positive & Negative Pairs (2 points)

Generate positive pairs (the same sample) and negative pairs (a different sample) for a batch.
Simulate negatives by randomly shuffling the batch.
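Shuffling yields exactly one negative per sample. A common extension draws several negatives per anchor. The sketch below is written under that assumption (the function name `sample_negatives` and parameter `k` are hypothetical). It ranks random scores with the diagonal pushed to the bottom, so each anchor gets `k` distinct rows other than itself.

```python
import torch

def sample_negatives(batch, k):
    """
    batch: (B, D); k: negatives per anchor, 1 <= k <= B - 1
    Returns: neg (B, k, D) and the sampled row indices (B, k)
    """
    B = batch.size(0)
    if not 1 <= k <= B - 1:
        raise ValueError("k must be between 1 and B - 1")
    # Random scores in [0, 1); the diagonal is set to -1 so an anchor never picks itself
    scores = torch.rand(B, B)
    scores.fill_diagonal_(-1.0)
    # Top-k of random scores = k distinct random indices per row, excluding the anchor
    idx = scores.topk(k, dim=1).indices
    return batch[idx], idx

torch.manual_seed(0)
batch = torch.randn(5, 8)
neg, idx = sample_negatives(batch, k=3)
print(neg.shape)  # torch.Size([5, 3, 8])
```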


In [4]:
### Ex-1-Task-2
def generate_pairs(batch):
    """
    batch: (B, D)
    Returns:
        pos: (B, D)
        neg: (B, D)
    """
    ### BEGIN SOLUTION
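    B = batch.size(0)
    if B < 2:
        raise ValueError("need at least 2 samples to form negative pairs")
    pos = batch.clone()  # positive: each sample paired with itself
    # Cyclically shift rows by a random offset in [1, B-1] so no row is paired
    # with itself (a plain torch.randperm can leave fixed points, i.e. a
    # "negative" identical to its positive)
    shift = int(torch.randint(1, B, (1,)))
    neg = torch.roll(batch, shifts=shift, dims=0)
    return pos, neg
    ### END SOLUTION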

# Test
batch = torch.randn(5, 8)
pos, neg = generate_pairs(batch)
print("Positive pairs:\n", pos)
print("Negative pairs:\n", neg)

Positive pairs:
 tensor([[-1.7623, -0.9969,  0.3559, -2.3074,  0.7234,  0.3092,  0.2846, -1.5210],
        [-0.7185,  0.5643,  0.4987,  0.1448, -0.7507, -0.6327, -0.9062, -0.5515],
        [-0.9732,  1.7625,  1.5349, -0.7370, -1.2431,  0.4434, -0.2072,  0.3668],
        [ 0.4893, -2.4248,  1.1761, -0.5551,  1.4226,  1.1877, -1.2180, -0.2018],
        [-0.3001, -1.9517,  0.1267,  0.0627,  0.0985,  0.9352,  1.0425, -1.1303]])
Negative pairs:
 tensor([[-0.7185,  0.5643,  0.4987,  0.1448, -0.7507, -0.6327, -0.9062, -0.5515],
        [-0.9732,  1.7625,  1.5349, -0.7370, -1.2431,  0.4434, -0.2072,  0.3668],
        [-1.7623, -0.9969,  0.3559, -2.3074,  0.7234,  0.3092,  0.2846, -1.5210],
        [ 0.4893, -2.4248,  1.1761, -0.5551,  1.4226,  1.1877, -1.2180, -0.2018],
        [-0.3001, -1.9517,  0.1267,  0.0627,  0.0985,  0.9352,  1.0425, -1.1303]])


In [None]:
# INTENTIONALLY LEFT BLANK

### Task 1c: NT-Xent Loss (2 points)

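Implement the normalized temperature-scaled cross-entropy (NT-Xent) loss for a batch. The $B \times B$ similarity matrix is assumed to stack two views of $N = B/2$ samples as [view A; view B], so the positive partner of row $i$ is row $(i + N) \bmod B$. With temperature $\tau$, the loss for row $i$ is $\ell_i = -\log \frac{\exp(s_{i,(i+N) \bmod B}/\tau)}{\sum_{k \ne i} \exp(s_{ik}/\tau)}$, averaged over all rows.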
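In SimCLR-style training, the similarity matrix usually holds cosine similarities between two augmented views of each sample. The sketch below is self-contained and assumes the views are stacked as [view A; view B], so row $i$'s positive is row $(i + N) \bmod 2N$. It builds that matrix from raw embeddings and applies the same masking and target convention. The helper name `nt_xent_from_views` is hypothetical.

```python
import math
import torch
import torch.nn.functional as F

def nt_xent_from_views(z_a, z_b, temp=0.5):
    """
    z_a, z_b: (N, D) embeddings of two views of the same N samples
    Returns: scalar NT-Xent loss
    """
    N = z_a.size(0)
    # Stack as [view A; view B] and L2-normalize so dot products are cosines
    z = F.normalize(torch.cat([z_a, z_b], dim=0), dim=1)  # (2N, D)
    sims = z @ z.T                                         # (2N, 2N)
    # Remove self-similarity from the softmax denominator
    self_mask = torch.eye(2 * N, dtype=torch.bool, device=z.device)
    logits = sims.masked_fill(self_mask, float("-inf")) / temp
    # Row i's positive is the other view of the same sample: (i + N) mod 2N
    targets = (torch.arange(2 * N, device=z.device) + N) % (2 * N)
    return F.cross_entropy(logits, targets)

# Two orthogonal samples whose two views agree exactly
z = torch.eye(2)
loss = nt_xent_from_views(z, z)
expected = math.log(1 + 2 * math.exp(-2))  # each row: positive logit 2, two negatives at 0
print(round(loss.item(), 4), round(expected, 4))  # 0.2395 0.2395
```

Swapping the second view's rows, so that each sample's views no longer match, should raise the loss well above this value.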

In [5]:
### Ex-1-Task-3
import torch
import torch.nn.functional as F

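def nt_xent_loss(sims, temp=0.5):
    """
    sims: (B, B) similarity matrix over B = 2N embeddings stacked as
          [view A of samples 0..N-1; view B of samples 0..N-1],
          so the positive for row i is row (i + N) % B
    temp: scalar temperature
    Returns: scalar loss
    """
    B = sims.size(0)
    device = sims.device

    # Exclude self-similarity: a sample is neither its own positive nor negative
    mask = torch.eye(B, dtype=torch.bool, device=device)
    sims = sims.masked_fill(mask, -9e15)

    # Temperature scaling sharpens (temp < 1) or flattens (temp > 1) the softmax
    logits = sims / temp

    # Index of each row's positive partner in the stacked [A; B] layout
    targets = torch.arange(B, device=device)
    targets = (targets + B // 2) % B

    # Row-wise cross-entropy = -log softmax probability of the positive, averaged
    loss = F.cross_entropy(logits, targets)
    return loss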

# Test
sims = torch.randn(3, 3)
print("NT-Xent loss:", nt_xent_loss(sims))

NT-Xent loss: tensor(0.7891)


In [6]:
# Quick visible test with identity similarity
sims = torch.eye(4)  # 4x4 identity matrix
loss = nt_xent_loss(sims)
print("NT-Xent loss:", loss.item())  # diagonal masked, 3 tied candidates: log(3) ≈ 1.0986


NT-Xent loss: 1.0986123085021973


In [None]:
# INTENTIONALLY LEFT BLANK


### Task 1d: Augment MNIST Image (2 points)

Apply a small random rotation (within ±10°) to a single MNIST image.
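To show what such a rotation does to the pixel grid, the sketch below is an independent pure-PyTorch illustration. It does not claim to reproduce how `transforms.RandomRotation` is implemented. It rotates an image about its centre with `F.affine_grid` and `F.grid_sample` and fills uncovered pixels with 0. The helper names `rotate` and `random_small_rotation` are hypothetical.

```python
import math
import torch
import torch.nn.functional as F

def rotate(img, degrees):
    """
    img: (H, W) or (C, H, W) float tensor
    Rotates about the image centre; pixels sampled from outside are filled with 0.
    """
    squeeze = img.ndim == 2
    x = img.unsqueeze(0) if squeeze else img   # (C, H, W)
    x = x.unsqueeze(0)                         # (1, C, H, W): batch dim for grid_sample
    a = math.radians(degrees)
    # 2x3 affine matrix: pure rotation, no translation
    theta = torch.tensor([[math.cos(a), -math.sin(a), 0.0],
                          [math.sin(a),  math.cos(a), 0.0]], dtype=x.dtype).unsqueeze(0)
    grid = F.affine_grid(theta, x.size(), align_corners=False)
    out = F.grid_sample(x, grid, mode="bilinear", padding_mode="zeros",
                        align_corners=False).squeeze(0)
    return out.squeeze(0) if squeeze else out

def random_small_rotation(img, max_degrees=10.0):
    # Angle drawn uniformly from [-max_degrees, max_degrees]
    angle = float(torch.empty(1).uniform_(-max_degrees, max_degrees))
    return rotate(img, angle)

torch.manual_seed(0)
img = torch.rand(28, 28)
print(random_small_rotation(img).shape)  # torch.Size([28, 28])
```

A 0° rotation leaves the image unchanged, and a 180° rotation equals flipping both spatial axes. Both are handy checks that the grid convention is right.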


In [11]:
### Ex-1-Task-4
from torchvision import transforms

augment = transforms.RandomRotation(degrees=10)

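def augment_image(img):
    """
    img: (28, 28) or (1, 28, 28) tensor
    Returns: augmented image (28, 28)
    """
    ### BEGIN SOLUTION
    # RandomRotation expects a channel dimension: (C, H, W)
    if img.ndim == 2:
        img = img.unsqueeze(0)

    # transforms.ToTensor() already yields floats in [0, 1];
    # rescale only raw uint8 input so pixel values are not divided twice
    if img.dtype == torch.uint8:
        img = img.float() / 255.0

    aug_img = augment(img)  # angle drawn uniformly from [-10, 10] degrees
    return aug_img.squeeze(0)
    ### END SOLUTION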

In [12]:
img, label = mnist_train[0]
aug_img = augment_image(img)
print("Original shape:", img.shape)
print("Augmented shape:", aug_img.shape)

Original shape: torch.Size([1, 28, 28])
Augmented shape: torch.Size([28, 28])


In [None]:
# INTENTIONALLY LEFT BLANK

### Task 1e: Contrastive Embedding Distance (1 point)

Compute Euclidean distance between two batches of embeddings row-wise.
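For L2-normalized embeddings, which are common in contrastive learning, Euclidean distance and cosine similarity carry the same information: $\lVert a - b \rVert^2 = \lVert a \rVert^2 + \lVert b \rVert^2 - 2\,a \cdot b = 2 - 2\cos(a, b)$ when $\lVert a \rVert = \lVert b \rVert = 1$. The sketch below is a quick numerical check of that identity, assuming unit-normalized inputs.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Unit-normalize each row so that ||a_i|| = ||b_i|| = 1
a = F.normalize(torch.randn(4, 5), dim=1)
b = F.normalize(torch.randn(4, 5), dim=1)

dist = torch.norm(a - b, dim=1)         # row-wise Euclidean distance
cos = F.cosine_similarity(a, b, dim=1)  # row-wise cosine similarity

# For unit vectors: ||a - b||^2 = 2 - 2 cos(a, b)
print(torch.allclose(dist ** 2, 2 - 2 * cos, atol=1e-5))  # True
```

On normalized embeddings, ranking pairs by increasing distance therefore gives the same order as ranking by decreasing cosine similarity.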

In [None]:
### Ex-1-Task-5
import torch

def embedding_distance(emb1, emb2):
    """
    emb1, emb2: (B, D)
    Returns: (B,) Euclidean distances
    """
    ### BEGIN SOLUTION

    diff = emb1 - emb2             # (B, D)
    distances = torch.norm(diff, dim=1)
    return distances
    ### END SOLUTION

# --- Test Example ---
batch_size, dim = 4, 5
emb1 = torch.randn(batch_size, dim)
emb2 = torch.randn(batch_size, dim)

distances = embedding_distance(emb1, emb2)
print("Embeddings 1:\n", emb1)
print("Embeddings 2:\n", emb2)
print("Row-wise Euclidean distances:\n", distances)

Embeddings 1:
 tensor([[-0.0466, -1.7949, -0.2168,  1.0602,  0.5315],
        [-0.9347,  0.6330,  1.2127,  0.1068,  0.0831],
        [ 0.0344, -0.3057,  2.3740, -0.4420,  1.0593],
        [-1.0693, -1.3033,  1.2530,  0.0287,  0.2702]])
Embeddings 2:
 tensor([[-0.4994,  0.0062,  2.1370, -0.7603, -0.7722],
        [-1.2245, -0.6105,  1.3387,  0.1453,  0.7843],
        [-0.0856,  1.5888,  1.3050,  0.3964,  1.4354],
        [ 1.0166,  0.6969, -0.9011,  0.2643,  0.2895]])
Row-wise Euclidean distances:
 tensor([3.7421, 1.4626, 2.3645, 3.6122])


In [None]:
# INTENTIONALLY LEFT BLANK

## Exercise 2: Energy-Based Models (4 points)



### Task 2a: Define Energy Function (2 points)

Energy function for binary classification: $E(x, y; w) = -y\,(w \cdot x)$, with $y \in \{+1, -1\}$.
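In an energy-based model, prediction means choosing the label with the lowest energy. With $y \in \{+1, -1\}$, $E(x, +1; w) \le E(x, -1; w)$ exactly when $w \cdot x \ge 0$, so the argmin reduces to the sign of $w \cdot x$. The sketch below is batched. The name `predict_min_energy` is hypothetical, and breaking ties toward $+1$ is an assumption.

```python
import torch

def predict_min_energy(X, w):
    """
    X: (N, D) inputs, w: (D,) weights
    Returns: (N,) labels in {+1, -1} minimizing E(x, y; w) = -y * (w . x)
    """
    scores = X @ w     # (N,) values of w . x
    e_pos = -scores    # energy of y = +1
    e_neg = scores     # energy of y = -1
    # Pick the lower-energy label; on a tie (score exactly 0) pick +1
    return torch.where(e_pos <= e_neg, torch.tensor(1), torch.tensor(-1))

torch.manual_seed(0)
X = torch.randn(6, 784)
w = torch.randn(784)
y_hat = predict_min_energy(X, w)
print(y_hat)
```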

In [28]:
### Ex-2-Task-1
def energy_function(x, y, w):
    """
    x: (784,), y: +1/-1, w: (784,)
    Returns: scalar energy
    """
    ### BEGIN SOLUTION
    energy = -y * torch.dot(w, x)
    return energy
    ### END SOLUTION

# Test
x = torch.randn(784)
w = torch.randn(784)
y = 1
print("Energy:", energy_function(x, y, w))

Energy: tensor(-17.2632)


In [29]:
# INTENTIONALLY LEFT BLANK

### Task 2b: Perceptron Loss (2 points)

Compute the perceptron loss: $\max\big(0,\ E(x, y_{\text{true}}) - E(x, y_{\text{pred}})\big)$
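Because $E(x, y; w) = -y\,(w \cdot x)$ is linear in $w$, the gradient of $E(x, y_{\text{true}}) - E(x, y_{\text{pred}})$ with respect to $w$ is $-(y_{\text{true}} - y_{\text{pred}})\,x$. Taking a step of size $\eta$ on every misclassified example, with $y_{\text{pred}}$ the minimum-energy label, gives the classic perceptron update $w \leftarrow w + 2\eta\, y_{\text{true}}\, x$. Below is a minimal sketch on synthetic, linearly separable data rather than MNIST. The data, helper names, and hyperparameters are all illustrative assumptions.

```python
import torch

def energy(x, y, w):
    # Same energy as Task 2a: E(x, y; w) = -y * (w . x)
    return -y * torch.dot(w, x)

def predict(x, w):
    # Minimum-energy label; ties (w . x == 0) go to +1
    return 1 if energy(x, 1, w) <= energy(x, -1, w) else -1

def train_perceptron(X, Y, epochs=500, lr=0.5):
    w = torch.zeros(X.size(1))
    for _ in range(epochs):
        mistakes = 0
        for x, y in zip(X, Y):
            y_true, y_pred = int(y), predict(x, w)
            if y_pred != y_true:
                # d/dw [E(x, y_true) - E(x, y_pred)] = -(y_true - y_pred) * x.
                # Step explicitly on mistakes: at w = 0 the clamped loss is exactly 0,
                # so relying on its value alone would never move w.
                w = w - lr * (-(y_true - y_pred) * x)
                mistakes += 1
        if mistakes == 0:  # every training point now has zero perceptron loss
            break
    return w

# Synthetic separable data: labels from a hidden weight vector, with a margin
torch.manual_seed(0)
w_hidden = torch.tensor([1.0, -2.0, 0.5, 0.0, 1.5])
X = torch.randn(400, 5)
scores = X @ w_hidden
keep = scores.abs() > 1.0
X, Y = X[keep], torch.sign(scores[keep])

w = train_perceptron(X, Y)
accuracy = sum(predict(x, w) == int(y) for x, y in zip(X, Y)) / len(Y)
print(f"training accuracy: {accuracy:.2f}")
```

On data that are not linearly separable, the loop simply stops after `epochs` passes without reaching zero mistakes.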

In [31]:
### Ex-2-Task-2
def perceptron_loss(x, y_true, y_pred, w):
    """
    x: (784,), y_true, y_pred: +1/-1
    w: (784,)
    Returns: scalar loss
    """
    ### BEGIN SOLUTION

    e_true = energy_function(x, y_true, w)
    e_pred = energy_function(x, y_pred, w)

    loss = torch.relu(e_true - e_pred)
    return loss
    ### END SOLUTION

# Test
print("Perceptron loss:", perceptron_loss(x, 1, -1, w))

Perceptron loss: tensor(0.)


In [None]:
# INTENTIONALLY LEFT BLANK