# 🧠 Self-Supervised Learning Exercise


This notebook is based on a self-supervised learning exercise using MNIST and PyTorch.

You will:
- Build and train an **autoencoder** for representation learning.
- Use **2% labeled data** to train a classifier on the learned representations.
- Evaluate the model and visualize the latent space.

📝 Some code cells include `# TODO` comments. Complete them as you go!


## 1. Setup

In [None]:
# Import necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, random_split
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.manifold import TSNE
import warnings
warnings.filterwarnings('ignore')

# Set device and random seeds
# Use the GPU when PyTorch reports one as available, otherwise fall back to CPU.
# `device` is read as a module-level global by the training/evaluation helpers below.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Seed both global generators: later cells draw from torch (e.g. torch.randn)
# and from NumPy (np.random.choice), so the order of those calls matters for
# reproducing a run.
torch.manual_seed(42)
np.random.seed(42)

# MNIST digit classes (the integers 0-9)
MNIST_CLASSES = list(range(10))

## 2. Prepare and Visualize Dataset

In [None]:
def prepare_ssl_datasets(labeled_percentage=0.02):
    """Prepare the MNIST datasets used for self-supervised learning.

    Args:
        labeled_percentage: Fraction of the training set to keep as the
            labeled subset (0.02 means 2%). The subset is class-balanced, so
            its actual size is ``10 * (int(len(train) * fraction) // 10)``.

    Returns:
        A tuple ``(unlabeled_dataset, labeled_dataset, test_dataset)``.
        ``unlabeled_dataset`` is the full training set (callers are expected
        to ignore its labels), ``labeled_dataset`` is a ``Subset`` of it, and
        ``test_dataset`` is the MNIST test split.

    Raises:
        ValueError: Propagated from ``np.random.choice`` if more samples per
            class are requested than a class contains.
    """
    # Reference answer for the exercise TODO: ToTensor converts each PIL image
    # to a float tensor scaled to [0, 1], a convenient range for an
    # autoencoder trained with a reconstruction loss.
    transform = transforms.ToTensor()

    # Load MNIST datasets
    train_dataset = MNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = MNIST(root='./data', train=False, download=True, transform=transform)

    # Create unlabeled dataset (all training data, labels ignored)
    unlabeled_dataset = train_dataset

    # Size of the small labeled dataset
    num_classes = 10
    num_labeled = int(len(train_dataset) * labeled_percentage)
    samples_per_class = num_labeled // num_classes

    # Read the labels straight from the dataset's target tensor. Iterating the
    # dataset itself would decode and transform every image just to look at
    # its label.
    targets = np.asarray(train_dataset.targets)

    # Sample equally from each class
    labeled_indices = []
    for class_idx in range(num_classes):
        # flatnonzero keeps candidates in dataset order
        candidates = np.flatnonzero(targets == class_idx)
        chosen = np.random.choice(candidates, samples_per_class, replace=False)
        # Store plain Python ints so the Subset indices are ordinary integers
        labeled_indices.extend(chosen.tolist())

    # Create labeled subset
    labeled_dataset = Subset(train_dataset, labeled_indices)

    print(f"Unlabeled dataset size: {len(unlabeled_dataset)} (labels ignored)")
    print(f"Labeled dataset size: {len(labeled_dataset)} ({labeled_percentage*100:.1f}%)")
    print(f"Test dataset size: {len(test_dataset)}")
    print(f"Samples per class in labeled set: {samples_per_class}")

    return unlabeled_dataset, labeled_dataset, test_dataset

def visualize_mnist_samples(dataset, title="MNIST Samples", num_samples=20):
    """Show a grid of randomly chosen images with their labels.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs;
            ``image`` must support ``.squeeze()`` and be drawable by
            ``imshow``. ``Subset`` objects work like any other dataset because
            indexing a ``Subset`` already maps through its ``indices``.
        title: Figure title.
        num_samples: Number of images to draw. Values larger than the dataset
            are clamped to ``len(dataset)``; nothing is drawn for zero.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return

    # Size the grid from num_samples (at most 10 columns) instead of assuming
    # exactly 20 images; the default still yields a 2 x 10 grid at (15, 4).
    num_cols = min(num_samples, 10)
    num_rows = (num_samples + num_cols - 1) // num_cols
    _, axes = plt.subplots(
        num_rows,
        num_cols,
        figsize=(1.5 * num_cols, 2 * num_rows),
        squeeze=False,  # always a 2-D array, even for a single row/column
    )
    axes = axes.ravel()

    # Get random samples
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    for ax, idx in zip(axes, indices):
        image, label = dataset[int(idx)]
        ax.imshow(image.squeeze(), cmap='gray')
        ax.set_title(f'Label: {label}')

    # Hide the frame on every cell, including unused ones in a ragged last row
    for ax in axes:
        ax.axis('off')

    plt.suptitle(title)
    plt.tight_layout()
    plt.show()

def analyze_class_distribution(dataset, title="Class Distribution"):
    """Count how many samples of each digit a dataset holds, then plot and print it.

    Args:
        dataset: Iterable of ``(image, label)`` pairs with integer labels in
            0-9, or a ``Subset``-like wrapper exposing ``dataset`` and
            ``indices``.
        title: Title of the bar chart.
    """
    # Pick the label source once: wrappers are read through their index list,
    # anything else is iterated directly.
    if hasattr(dataset, 'dataset'):
        label_stream = (dataset.dataset[i][1] for i in dataset.indices)
    else:
        label_stream = (sample[1] for sample in dataset)

    class_counts = [0] * 10
    for label in label_stream:
        class_counts[label] += 1

    # Bar chart of the per-digit counts
    _, ax = plt.subplots(figsize=(10, 6))
    bars = ax.bar(range(10), class_counts, color='skyblue', alpha=0.7)
    ax.set_title(title)
    ax.set_xlabel('Digit Class')
    ax.set_ylabel('Number of Samples')
    ax.set_xticks(range(10))

    # Write each count just above its bar
    for rect, n in zip(bars, class_counts):
        x_center = rect.get_x() + rect.get_width()/2
        ax.text(x_center, rect.get_height() + 5, str(n), ha='center', va='bottom')

    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Same numbers as plain text
    print("Class distribution:")
    for digit, n in enumerate(class_counts):
        print(f"Digit {digit}: {n} samples")

# Prepare datasets
# Keep 2% of the training set as the class-balanced labeled subset; the
# returned names are module-level globals used by the later cells.
unlabeled_dataset, labeled_dataset, test_dataset = prepare_ssl_datasets(labeled_percentage=0.02)

# Visualize data
# Each call draws a fresh random sample via np.random.choice, so the images
# shown depend on the NumPy seed and on the order of these calls.
visualize_mnist_samples(unlabeled_dataset, "Unlabeled MNIST Samples")
visualize_mnist_samples(labeled_dataset, "Labeled MNIST Samples (2%)")

# Analyze distributions
# Only the labeled subset is analyzed here.
analyze_class_distribution(labeled_dataset, "Labeled Dataset Distribution")

## 3. Define Autoencoder

In [None]:
class Encoder(nn.Module):
    """Encoder network that compresses input to latent representation.

    Reference answer for the exercise TODO: a fully connected stack
    ``input_dim -> 512 -> 256 -> latent_dim`` with ReLU activations and batch
    normalization on the hidden layers.

    Args:
        input_dim: Number of features per flattened input (784 for 28x28).
        latent_dim: Size of the latent vector produced for each input.
    """

    def __init__(self, input_dim=784, latent_dim=128):
        super(Encoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            # Note: No activation on final layer - let it learn any range
            nn.Linear(256, latent_dim),
        )

    def forward(self, x):
        """Return the latent vectors for a batch.

        Args:
            x: Tensor of shape ``(batch, input_dim)`` or any higher-rank shape
                whose trailing dimensions multiply to ``input_dim`` (for
                example ``(batch, 1, 28, 28)``).

        Returns:
            Tensor of shape ``(batch, latent_dim)``.
        """
        # Flatten input if needed; reshape also copes with non-contiguous input
        if x.dim() > 2:
            x = x.reshape(x.size(0), -1)

        return self.encoder(x)

class Decoder(nn.Module):
    """Decoder network that reconstructs input from latent representation.

    Reference answer for the exercise TODO: a fully connected stack
    ``latent_dim -> 256 -> 512 -> output_dim`` with batch normalization
    applied after each hidden activation.

    Args:
        latent_dim: Size of each latent vector fed to the decoder.
        output_dim: Number of features per reconstruction (784 for 28x28).
    """

    def __init__(self, latent_dim=128, output_dim=784):
        super(Decoder, self).__init__()

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.BatchNorm1d(512),
            nn.Linear(512, output_dim),
            # Sigmoid bounds reconstructions to [0, 1].
            # NOTE(review): this assumes pixel targets are scaled to [0, 1];
            # switch to Tanh if the inputs are normalized to [-1, 1].
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map latent vectors ``(batch, latent_dim)`` to ``(batch, output_dim)``."""
        return self.decoder(x)

class Autoencoder(nn.Module):
    """Complete autoencoder combining encoder and decoder.

    Args:
        input_dim: Number of features per flattened input; also used as the
            decoder's output size so reconstructions match the input shape.
        latent_dim: Size of the bottleneck representation.
    """

    def __init__(self, input_dim=784, latent_dim=128):
        super(Autoencoder, self).__init__()

        # Exposed as attributes: other cells read `.encoder` and `.decoder`
        # directly (parameter counts, classifier construction).
        self.encoder = Encoder(input_dim=input_dim, latent_dim=latent_dim)
        self.decoder = Decoder(latent_dim=latent_dim, output_dim=input_dim)

    def forward(self, x):
        """Encode then decode a batch.

        Returns:
            Tuple ``(reconstructed, latent)``: the decoder output and the
            encoder output it was computed from.
        """
        latent = self.encoder(x)
        reconstructed = self.decoder(latent)

        return reconstructed, latent

    def encode(self, x):
        """Get latent representation only"""
        return self.encoder(x)

# Test autoencoder
# Quick shape check on an untrained model; `latent_dim` and `autoencoder`
# are module-level names.
latent_dim = 128
autoencoder = Autoencoder(input_dim=784, latent_dim=latent_dim)

# Test with dummy data
# Random values (not real images) are enough to check tensor shapes.
test_input = torch.randn(32, 784)  # Batch of flattened MNIST images
reconstructed, latent = autoencoder(test_input)

print(f"Input shape: {test_input.shape}")
print(f"Latent representation shape: {latent.shape}")
print(f"Reconstructed output shape: {reconstructed.shape}")
# Ratio of input features (784) to latent features
print(f"Compression ratio: {784/latent_dim:.1f}x")

# Count parameters
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Parameters with requires_grad=False (frozen) are not counted
        if param.requires_grad:
            total += param.numel()
    return total

# Report trainable-parameter counts for the whole model and for each half;
# relies on the model exposing `encoder` and `decoder` attributes.
print(f"Autoencoder parameters: {count_parameters(autoencoder):,}")
print(f"Encoder parameters: {count_parameters(autoencoder.encoder):,}")
print(f"Decoder parameters: {count_parameters(autoencoder.decoder):,}")

## 4. Train Autoencoder

In [None]:
def train_autoencoder(model, unlabeled_loader, num_epochs=50, learning_rate=1e-3):
    """Train autoencoder on unlabeled data.

    Args:
        model: Module whose forward pass returns ``(reconstructed, latent)``
            for a batch of flattened images.
        unlabeled_loader: Iterable of ``(images, labels)`` batches; the labels
            are ignored.
        num_epochs: Number of passes over ``unlabeled_loader``.
        learning_rate: Initial Adam learning rate (halved every 15 epochs).

    Returns:
        List with the average reconstruction (MSE) loss of each epoch.

    Raises:
        ValueError: If ``unlabeled_loader`` yields no batches in an epoch.

    Side effects:
        Moves ``model`` to the module-level ``device`` and leaves it in
        training mode.
    """
    # Setup training components for autoencoder
    criterion = nn.MSELoss()  # Reconstruction loss
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)

    model.to(device)
    model.train()

    train_losses = []

    print("Starting Autoencoder pretraining...")

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        for batch_idx, (images, _) in enumerate(unlabeled_loader):  # Ignore labels!
            images = images.to(device)

            # Flatten images to (batch, features)
            images_flat = images.view(images.size(0), -1)

            optimizer.zero_grad()

            # Forward pass; the latent code is not needed for the loss
            reconstructed, _latent = model(images_flat)

            # Reconstruction loss: the target is the input itself
            loss = criterion(reconstructed, images_flat)

            # Backward pass
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

            # Print progress
            if batch_idx % 200 == 0:
                print(f'Epoch {epoch+1}/{num_epochs}, Batch {batch_idx}, Loss: {loss.item():.6f}')

        if num_batches == 0:
            # Fail with a clear message instead of a ZeroDivisionError below
            raise ValueError("unlabeled_loader yielded no batches")

        # Calculate average loss
        avg_loss = epoch_loss / num_batches
        train_losses.append(avg_loss)

        # Update learning rate
        scheduler.step()

        # The LR printed here is the one the scheduler set for the next epoch
        print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.6f}, LR: {scheduler.get_last_lr()[0]:.6f}')
        print('-' * 60)

    print("Autoencoder pretraining completed!")
    return train_losses

def visualize_reconstructions(model, dataset, num_samples=10):
    """Visualize original images and their reconstructions.

    Draws two rows: randomly chosen originals on top and the model's
    reconstruction of each one underneath.

    Args:
        model: Module whose forward pass returns ``(reconstructed, latent)``
            for a flattened ``(1, features)`` input.
        dataset: Indexable dataset of ``(image, label)`` pairs; images are
            assumed to be single-channel tensors. ``Subset`` objects need no
            special handling because indexing them already maps through
            their ``indices``.
        num_samples: Number of image pairs to show. Values larger than the
            dataset are clamped; nothing is drawn for zero.

    Side effects:
        Puts ``model`` in eval mode and leaves it there.
    """
    model.eval()

    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        return

    # Get random samples
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    # squeeze=False keeps `axes` two-dimensional even when num_samples == 1,
    # so axes[row, col] indexing is always valid.
    _, axes = plt.subplots(2, num_samples, figsize=(15, 4), squeeze=False)

    with torch.no_grad():
        for col, idx in enumerate(indices):
            original, label = dataset[int(idx)]

            # Reconstruct image
            original_flat = original.view(1, -1).to(device)
            reconstructed, _ = model(original_flat)
            # Reshape to the original's own shape rather than assuming 28x28
            reconstructed = reconstructed.cpu().view_as(original).squeeze()

            # Plot original
            axes[0, col].imshow(original.squeeze(), cmap='gray')
            axes[0, col].set_title(f'Original: {label}')
            axes[0, col].axis('off')

            # Plot reconstruction
            axes[1, col].imshow(reconstructed, cmap='gray')
            axes[1, col].set_title('Reconstructed')
            axes[1, col].axis('off')

    plt.suptitle('Original vs Reconstructed Images')
    plt.tight_layout()
    plt.show()

def plot_training_progress(train_losses):
    """Plot the per-epoch autoencoder loss curve and print a short summary.

    Args:
        train_losses: Sequence of per-epoch average losses, oldest first.
    """
    _, ax = plt.subplots(figsize=(10, 6))
    ax.plot(train_losses, color='blue', linewidth=2)
    ax.set_title('Autoencoder Training Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('MSE Loss')
    ax.grid(True, alpha=0.3)
    plt.show()

    # Summary: last value, then how many times smaller it is than the first
    final_loss = train_losses[-1]
    print(f"Final training loss: {final_loss:.6f}")
    reduction = train_losses[0]/final_loss
    print(f"Loss reduction: {reduction:.2f}x")

# Create data loaders
batch_size = 256

# Shuffled batches over the full training set; labels are present in each
# batch but train_autoencoder ignores them.
unlabeled_loader = DataLoader(
    unlabeled_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)

print(f"Unlabeled data batches: {len(unlabeled_loader)}")

# Create and train autoencoder
# Rebinds the module-level `autoencoder` name to a fresh, untrained model.
autoencoder = Autoencoder(input_dim=784, latent_dim=128)

# 30 epochs here overrides the function's default of 50.
train_losses = train_autoencoder(autoencoder, unlabeled_loader, num_epochs=30)

plot_training_progress(train_losses)
# Reconstructions are drawn from the same data the model was trained on.
visualize_reconstructions(autoencoder, unlabeled_dataset)

## 5. Latent Space Analysis

In [None]:
def extract_latent_representations(model, dataset, max_samples=5000):
    """Encode up to ``max_samples`` items of a dataset into latent vectors.

    Args:
        model: Object exposing ``eval()`` and ``encode(batch)``.
        dataset: Dataset of ``(image, label)`` pairs usable with ``DataLoader``.
        max_samples: Upper bound on the number of rows returned.

    Returns:
        Tuple ``(latent_vectors, labels)`` of NumPy arrays, in dataset order,
        each truncated to at most ``max_samples`` entries.

    Side effects:
        Puts ``model`` in eval mode. Inputs are moved to the module-level
        ``device`` before encoding.
    """
    model.eval()

    batch_size = 128
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    latent_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch_idx, (images, batch_labels) in enumerate(loader):
            # Stop once the batches already collected cover max_samples
            if batch_idx * batch_size >= max_samples:
                break

            flat = images.to(device)
            flat = flat.view(flat.size(0), -1)

            latent_chunks.append(model.encode(flat).cpu())
            label_chunks.append(batch_labels.numpy())

    all_latents = torch.cat(latent_chunks, dim=0).numpy()
    all_labels = np.concatenate(label_chunks)

    # The final batch may overshoot the limit, so trim both arrays
    return all_latents[:max_samples], all_labels[:max_samples]

def visualize_latent_space(latent_vectors, labels, method='tsne'):
    """Visualize latent space using dimensionality reduction.

    Projects the latent vectors to 2-D and draws one scatter series per digit.

    Args:
        latent_vectors: Array-like of shape ``(n_samples, latent_dim)``.
        labels: Integer digit labels (0-9), one per row of ``latent_vectors``.
        method: ``'tsne'`` or ``'pca'``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    # Reject unknown methods up front; previously the function fell through
    # and failed later on an unbound `latent_2d`.
    if method not in ('tsne', 'pca'):
        raise ValueError(
            f"Unsupported method {method!r}; expected 'tsne' or 'pca'"
        )

    # Boolean masks below need an array, not a plain list
    labels = np.asarray(labels)

    if method == 'tsne':
        print("Applying t-SNE to latent representations...")
        # t-SNE requires perplexity to be smaller than the number of samples
        perplexity = min(30, max(1, len(latent_vectors) - 1))
        tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
        latent_2d = tsne.fit_transform(latent_vectors)
    else:
        from sklearn.decomposition import PCA

        print("Applying PCA to latent representations...")
        pca = PCA(n_components=2, random_state=42)
        latent_2d = pca.fit_transform(latent_vectors)

        print(f"PCA explained variance ratio: {pca.explained_variance_ratio_}")

    # Plot the 2D representation
    plt.figure(figsize=(12, 10))

    # Create color map for digits
    colors = plt.cm.tab10(np.linspace(0, 1, 10))

    for digit in range(10):
        mask = labels == digit
        plt.scatter(latent_2d[mask, 0], latent_2d[mask, 1],
                    c=[colors[digit]], label=f'Digit {digit}', alpha=0.6, s=20)

    plt.title(f'Latent Space Visualization ({method.upper()})')
    plt.xlabel('Component 1')
    plt.ylabel('Component 2')
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

def analyze_latent_space_structure(latent_vectors, labels):
    """Analyze the structure of the learned latent space.

    Compares Euclidean distances between samples of the same digit
    (intra-class) with distances between samples of different digits
    (inter-class), plots both distributions and prints their means and ratio.

    Args:
        latent_vectors: NumPy array of shape ``(n_samples, latent_dim)``.
        labels: Integer digit labels (0-9), one per row of ``latent_vectors``.
    """
    from scipy.spatial.distance import cdist, pdist

    labels = np.asarray(labels)

    # Group the latent vectors by digit once instead of re-masking per pair
    latents_by_digit = [latent_vectors[labels == digit] for digit in range(10)]

    inter_class_distances = []
    intra_class_distances = []

    for digit, digit_latents in enumerate(latents_by_digit):
        if len(digit_latents) > 1:
            # Intra-class distances (all pairs within the same digit)
            intra_class_distances.extend(pdist(digit_latents))

        # Inter-class distances (each unordered pair of digits once)
        for other_latents in latents_by_digit[digit + 1:]:
            if len(digit_latents) > 0 and len(other_latents) > 0:
                # Only the first 50 samples per class are compared to bound
                # the cost; cdist computes the whole block of pairwise
                # distances in one vectorized call.
                block = cdist(digit_latents[:50], other_latents[:50])
                inter_class_distances.extend(block.ravel())

    # Plot distance distributions
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.hist(intra_class_distances, bins=50, alpha=0.7, color='blue', density=True)
    plt.title('Intra-class Distances')
    plt.xlabel('Distance')
    plt.ylabel('Density')
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    plt.hist(inter_class_distances, bins=50, alpha=0.7, color='red', density=True)
    plt.title('Inter-class Distances')
    plt.xlabel('Distance')
    plt.ylabel('Density')
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # A ratio above 1 means different digits are, on average, farther apart
    # than samples of the same digit.
    mean_intra = np.mean(intra_class_distances)
    mean_inter = np.mean(inter_class_distances)
    print(f"Average intra-class distance: {mean_intra:.4f}")
    print(f"Average inter-class distance: {mean_inter:.4f}")
    print(f"Separation ratio: {mean_inter / mean_intra:.4f}")

print("Extracting latent representations...")
# Encode test-set images (the extraction helper caps this at its default of
# 5000 samples).
latent_vectors, labels = extract_latent_representations(autoencoder, test_dataset)

print(f"Extracted {len(latent_vectors)} latent representations")
print(f"Latent dimension: {latent_vectors.shape[1]}")

# Visualize latent space with both projections
visualize_latent_space(latent_vectors, labels, method='tsne')
visualize_latent_space(latent_vectors, labels, method='pca')

# Analyze latent space structure (intra- vs inter-class distances)
analyze_latent_space_structure(latent_vectors, labels)

## 6. Classifier Fine-Tuning

In [None]:
class SupervisedClassifier(nn.Module):
    """Classifier that uses pretrained encoder + classification head.

    Args:
        encoder: Module mapping an input batch to ``(batch, latent_dim)``
            features. It is stored by reference, not copied, so fine-tuning
            this classifier also updates the encoder object passed in.
        num_classes: Number of output logits.
        freeze_encoder: If True, ``requires_grad`` is switched off on every
            encoder parameter so only the head is trained.
        latent_dim: Width of the encoder output; must match ``encoder``.
    """

    def __init__(self, encoder, num_classes=10, freeze_encoder=False, latent_dim=128):
        super(SupervisedClassifier, self).__init__()

        self.encoder = encoder

        # Freeze encoder weights if specified
        if freeze_encoder:
            for param in self.encoder.parameters():
                param.requires_grad = False
            print("Encoder weights frozen")
        else:
            print("Encoder weights will be fine-tuned")

        # Classification head: a small MLP on top of the latent features.
        # Dropout regularizes the head, which is trained on very few labels.
        self.classifier = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(64, num_classes),
        )

    def forward(self, x):
        """Return unnormalized class scores of shape ``(batch, num_classes)``."""
        # Extract features using pretrained encoder
        features = self.encoder(x)
        logits = self.classifier(features)

        return logits

def train_classifier(model, train_loader, val_loader, num_epochs=50, learning_rate=1e-3):
    """Train classifier with limited labeled data.

    Args:
        model: Module mapping flattened images ``(batch, features)`` to class
            logits.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Initial Adam learning rate (multiplied by 0.7 every
            15 epochs).

    Returns:
        Tuple ``(train_losses, train_accuracies, val_accuracies)`` with one
        entry per epoch; accuracies are percentages.

    Side effects:
        Moves ``model`` to the module-level ``device``, writes the best
        weights to ``best_ssl_classifier.pth`` in the working directory, and
        reloads those weights into ``model`` before returning.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.7)

    model.to(device)

    train_losses = []
    train_accuracies = []
    val_accuracies = []

    checkpoint_path = 'best_ssl_classifier.pth'
    best_val_acc = 0
    # Tracks whether THIS run wrote a checkpoint, so a stale file from an
    # earlier run is never loaded and a missing file never crashes the end.
    checkpoint_saved = False

    print("Starting supervised fine-tuning...")

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        epoch_loss = 0
        correct_train = 0
        total_train = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Flatten images
            images_flat = images.view(images.size(0), -1)

            optimizer.zero_grad()

            # Forward pass
            logits = model(images_flat)
            loss = criterion(logits, labels)

            # Backward pass
            loss.backward()
            optimizer.step()

            # Statistics
            epoch_loss += loss.item()
            predicted = logits.argmax(dim=1)
            total_train += labels.size(0)
            correct_train += (predicted == labels).sum().item()

        # Validation phase
        model.eval()
        correct_val = 0
        total_val = 0

        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                images_flat = images.view(images.size(0), -1)

                predicted = model(images_flat).argmax(dim=1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        # Calculate metrics
        train_acc = 100 * correct_train / total_train
        val_acc = 100 * correct_val / total_val
        avg_loss = epoch_loss / len(train_loader)

        train_losses.append(avg_loss)
        train_accuracies.append(train_acc)
        val_accuracies.append(val_acc)

        # Save best model; the first epoch always saves so a checkpoint
        # exists even if validation accuracy never rises above zero.
        if val_acc > best_val_acc or not checkpoint_saved:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True

        # Update learning rate
        scheduler.step()

        if (epoch + 1) % 5 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}]')
            print(f'Train Loss: {avg_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f'Val Acc: {val_acc:.2f}%, Best Val Acc: {best_val_acc:.2f}%')
            print('-' * 60)

    # Load best model (map_location keeps the tensors on the active device)
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    print(f"Fine-tuning completed! Best validation accuracy: {best_val_acc:.2f}%")
    return train_losses, train_accuracies, val_accuracies

def evaluate_classifier(model, test_loader):
    """Score a classifier on a labeled loader and print a per-digit report.

    Args:
        model: Module mapping flattened images ``(batch, features)`` to class
            logits.
        test_loader: Iterable of ``(images, labels)`` batches.

    Returns:
        Tuple ``(accuracy, predictions, true_labels)``: accuracy as a
        percentage, plus the per-sample predicted and true labels in loader
        order.

    Side effects:
        Puts ``model`` in eval mode; batches are moved to the module-level
        ``device``.
    """
    model.eval()

    predictions, true_labels = [], []
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            flat = images.view(images.size(0), -1)
            scores = model(flat)
            # Index of the largest logit is the predicted class
            predicted = torch.max(scores.data, 1)[1]

            num_seen += labels.size(0)
            num_correct += (predicted == labels).sum().item()

            predictions.extend(predicted.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())

    accuracy = 100 * num_correct / num_seen

    print(f"Test Accuracy: {accuracy:.2f}%")

    # Detailed classification report
    digit_names = [f'Digit {i}' for i in range(10)]
    print("\nClassification Report:")
    print(classification_report(true_labels, predictions, target_names=digit_names))

    return accuracy, predictions, true_labels

# Create data loaders for supervised training
# NOTE(review): labeled_loader is not used in the rest of this section; the
# train/val loaders created below are what train_classifier receives.
labeled_loader = DataLoader(labeled_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

# Split labeled data into train/val
# 80/20 split; random_split is called without an explicit generator.
val_split = 0.2
val_size = int(len(labeled_dataset) * val_split)
train_size = len(labeled_dataset) - val_size

labeled_train, labeled_val = random_split(labeled_dataset, [train_size, val_size])

labeled_train_loader = DataLoader(labeled_train, batch_size=32, shuffle=True)
labeled_val_loader = DataLoader(labeled_val, batch_size=32, shuffle=False)

print(f"Labeled training samples: {len(labeled_train)}")
print(f"Labeled validation samples: {len(labeled_val)}")

# Reuse the pretrained encoder object itself; with freeze_encoder=False its
# weights are updated during fine-tuning.
ssl_classifier = SupervisedClassifier(autoencoder.encoder, num_classes=10, freeze_encoder=False)

ssl_train_losses, ssl_train_accs, ssl_val_accs = train_classifier(
 ssl_classifier, labeled_train_loader, labeled_val_loader, num_epochs=50
)

# Final score on the MNIST test split
ssl_test_accuracy, ssl_predictions, ssl_true_labels = evaluate_classifier(ssl_classifier, test_loader)
