In [None]:
# BC why not


import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt

# Set device
# Use the GPU when CUDA is available; otherwise everything runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
batch_size = 128
image_size = 28  # NOTE(review): not referenced elsewhere in this cell; the networks hard-code 28x28
nz = 100  # Size of latent vector z
num_epochs = 20
lr = 0.0002
beta1 = 0.5  # Beta1 hyperparameter for Adam optimizers

# Data transformation and loading
# The [-1, 1] range matches the Tanh output of the generator defined below.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1]
])

# Load the MNIST dataset
# download=True fetches the data into ./data when it is not already there.
dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Discriminator network
class Discriminator(nn.Module):
    """Convolutional critic that scores 1x28x28 images as real or fake.

    The layer stack is stored in ``self.main`` (an ``nn.Sequential``) in the
    order: one conv + LeakyReLU stem, three conv + BatchNorm + LeakyReLU
    stages, then a 4x4 conv followed by a sigmoid.
    """

    def __init__(self):
        super().__init__()

        # Stem: (batch, 1, 28, 28) -> (batch, 64, 14, 14); no BatchNorm here.
        layers = [
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        ]

        # (in_channels, out_channels, kernel, stride) for the normalised stages:
        #   64 -> 128 : 14x14 -> 7x7
        #   128 -> 256: 7x7   -> 4x4
        #   256 -> 512: 4x4   -> 4x4
        stage_specs = (
            (64, 128, 4, 2),
            (128, 256, 3, 2),
            (256, 512, 3, 1),
        )
        for in_ch, out_ch, kernel, stride in stage_specs:
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=1))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Head: collapse the 4x4 map to one probability per image.
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0))
        layers.append(nn.Sigmoid())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return a flat tensor with one sigmoid score per input image."""
        scores = self.main(input)
        return scores.view(-1)

# Generator network
class Generator(nn.Module):
    """Maps latent vectors of length ``nz`` to 1x28x28 images in [-1, 1].

    The layer stack lives in ``self.main``: a linear projection reshaped to
    256x7x7, two stride-2 transposed convolutions (7 -> 14 -> 28), one
    stride-1 refinement layer, and a final transposed convolution with Tanh.
    """

    def __init__(self):
        super().__init__()

        # Projection: (batch, nz) -> (batch, 256, 7, 7)
        layers = [
            nn.Linear(nz, 256 * 7 * 7),
            nn.BatchNorm1d(256 * 7 * 7),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 7, 7)),
        ]

        # (in_channels, out_channels, kernel, stride) for the normalised stages:
        #   256 -> 128: 7x7   -> 14x14
        #   128 -> 64 : 14x14 -> 28x28
        #   64  -> 32 : 28x28 -> 28x28
        stage_specs = (
            (256, 128, 4, 2),
            (128, 64, 4, 2),
            (64, 32, 3, 1),
        )
        for in_ch, out_ch, kernel, stride in stage_specs:
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel, stride=stride, padding=1))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))

        # Output layer: single channel squashed into [-1, 1].
        layers.append(nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1))
        layers.append(nn.Tanh())

        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the latent batch through the stack and return the images."""
        return self.main(input)

# Initialize models
# Both networks are created with their default initialisation and moved to `device`.
netD = Discriminator().to(device)
netG = Generator().to(device)

# Initialize weights
def weights_init(m):
    # Intended for use with `Module.apply`: dispatches on the class *name*, so
    # any layer whose name contains 'Conv' (including ConvTranspose) or 'Linear'
    # gets N(0, 0.02) weights, and any 'BatchNorm' layer gets N(1, 0.02) scale
    # with a zero shift. Other modules are left untouched.
    layer_name = type(m).__name__

    if 'Conv' in layer_name or 'Linear' in layer_name:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return

    if 'BatchNorm' in layer_name:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

# Apply the custom initialisation to every submodule of both networks.
netD.apply(weights_init)
netG.apply(weights_init)

# Loss function and optimizers
# Binary cross-entropy is applied to the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

# NOTE(review): this tensor is overwritten by the sampling code after the
# training loop before it is ever read.
fixed_noise = torch.randn(64, nz, device=device)  # For visualization

# Target values used to fill the label tensors in the training loop.
real_label = 1.
fake_label = 0.

# One Adam optimizer per network so D and G are stepped independently.
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Training loop
# Each iteration performs one discriminator update (real batch + fake batch)
# followed by one generator update on the same fake batch.
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network
        ############################
        ## Train with all-real batch
        netD.zero_grad()
        real_cpu = data[0].to(device)  # data[0] holds the images; the class labels in data[1] are unused
        b_size = real_cpu.size(0)  # the last batch of an epoch may be smaller than batch_size
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        output = netD(real_cpu)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()  # mean discriminator score on real images (logging only)

        ## Train with all-fake batch
        noise = torch.randn(b_size, nz, device=device)
        fake = netG(noise)
        label.fill_(fake_label)  # reuse the label tensor in place
        # detach() keeps this backward pass from reaching the generator.
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()  # gradients accumulate on top of the real-batch gradients
        D_G_z1 = output.mean().item()  # mean score on fakes before D is updated
        errD = errD_real + errD_fake  # only used for logging; backward() was already called on both parts
        optimizerD.step()

        ############################
        # (2) Update G network
        ############################
        netG.zero_grad()
        label.fill_(real_label)  # Fake labels are real for generator cost
        # Second pass through the just-updated discriminator, this time without
        # detach() so gradients flow back into the generator.
        output = netD(fake)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()  # mean score on fakes after D's update
        optimizerG.step()

        # Print training stats
        if i % 100 == 0:
            print(f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] '
                  f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
                  f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}')

# Generate 12 new samples
# eval() puts the generator's BatchNorm layers in inference mode for sampling.
netG.eval()
with torch.no_grad():
    # Replaces the 64-sample fixed_noise created before training with a fresh draw.
    fixed_noise = torch.randn(12, nz, device=device)
    fake_images = netG(fixed_noise).detach().cpu()

# Plot the images
# 12 images arranged 4 per row; normalize=True asks make_grid to rescale the
# pixel values for display.
grid = torchvision.utils.make_grid(fake_images, nrow=4, normalize=True)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.title('Generated Images')
# Reorder axes from (channels, height, width) to (height, width, channels) for imshow.
plt.imshow(np.transpose(grid, (1, 2, 0)))
plt.show()

# ------------------------------
# Part 1.2: GAN as a Pre-Training Framework
# ------------------------------

# Create a feature extractor from the discriminator
class FeatureExtractor(nn.Module):
    # Wraps a discriminator's layer stack, minus its last two modules, and
    # returns the flattened activations. The layers are reused, not copied, so
    # they share parameters with the discriminator that was passed in.
    def __init__(self, discriminator):
        super().__init__()
        backbone = list(discriminator.main.children())
        # Drop the final Conv2d and the Sigmoid; keep everything before them.
        self.features = nn.Sequential(*backbone[:-2])

    def forward(self, x):
        activations = self.features(x)
        batch = activations.size(0)
        return activations.view(batch, -1)  # one flat feature row per sample

# Instantiate the feature extractor
feature_extractor = FeatureExtractor(netD).to(device)

# Freeze feature extractor parameters
for param in feature_extractor.parameters():
    param.requires_grad = False

# requires_grad=False only stops gradient updates. The reused BatchNorm layers
# would still normalise with per-batch statistics and keep updating their
# running estimates while in training mode, making the "frozen" features depend
# on batch composition (including at test time). eval() makes them use the
# running statistics learned during GAN training.
feature_extractor.eval()

# Prepare 10% of the training data
# A random tenth of the MNIST training indices is selected.
# NOTE(review): no NumPy seed is set in this cell, so the subset differs between runs.
train_indices = np.arange(len(dataset))
np.random.shuffle(train_indices)
subset_indices = train_indices[:len(dataset) // 10]
train_subset = Subset(dataset, subset_indices)
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)

# Prepare test data
# Uses the same transform as the training set.
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Determine the size of the feature vector
# A single random 1x28x28 input is pushed through the extractor only to read
# off the flattened output width.
with torch.no_grad():
    dummy_input = torch.randn(1, 1, 28, 28, device=device)
    features = feature_extractor(dummy_input)
    feature_size = features.shape[1]

# Define a linear classifier
# One linear layer from the extracted features to the 10 digit classes.
classifier = nn.Linear(feature_size, 10).to(device)

# Loss function and optimizer for classifier
# Only the classifier's parameters are handed to the optimizer.
criterion_cls = nn.CrossEntropyLoss()
optimizer_cls = optim.Adam(classifier.parameters(), lr=0.001)

# Train the classifier
# Each epoch: one pass over the 10% training subset, then a full pass over the
# test set, then a one-line summary.
# NOTE(review): only `classifier` has its train/eval mode toggled below; the
# feature extractor runs in whatever mode it was left in before this loop.
num_epochs_cls = 10
for epoch in range(num_epochs_cls):
    classifier.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Features are computed without building a graph; only the linear
        # classifier is trained.
        with torch.no_grad():
            features = feature_extractor(inputs)
        outputs = classifier(features)
        loss = criterion_cls(outputs, labels)

        optimizer_cls.zero_grad()
        loss.backward()
        optimizer_cls.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)  # index of the largest logit = predicted class
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    # Loss is averaged over batches, accuracy over samples.
    train_loss = running_loss / len(train_loader)
    train_acc = 100. * correct / total

    # Evaluate on test set
    classifier.eval()
    test_loss = 0.0
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            features = feature_extractor(inputs)
            outputs = classifier(features)
            loss = criterion_cls(outputs, labels)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total_test += labels.size(0)
            correct_test += predicted.eq(labels).sum().item()

    test_loss /= len(test_loader)
    test_acc = 100. * correct_test / total_test

    print(f'[{epoch + 1}/{num_epochs_cls}] Train Loss: {train_loss:.4f} Train Acc: {train_acc:.2f}% '
          f'Test Loss: {test_loss:.4f} Test Acc: {test_acc:.2f}%')


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid

import matplotlib.pyplot as plt
import numpy as np

from torch.utils.data import DataLoader, Subset
import random

# Set random seeds for reproducibility
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Hyperparameters
batch_size = 128
learning_rate = 1e-3
weight_decay = 0
num_epochs = 25
k = 5  # Number of Gibbs sampling steps
M_values = [16, 64, 256]  # Hidden layer sizes

# Data transformations
# Pixels are left in [0, 1] here. Every consumer in this cell (train_rbm,
# test_rbm, compute_mse and the visualisation code) applies `data * 2 - 1`
# itself to reach [-1, 1]; also normalising here with mean 0.5 / std 0.5 would
# apply that rescale twice and push the inputs into [-3, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load KMNIST dataset
train_set = torchvision.datasets.KMNIST(root='./data', train=True, download=True, transform=transform)
test_set = torchvision.datasets.KMNIST(root='./data', train=False, download=True, transform=transform)

# The test loader is not shuffled, so its batches come in a fixed order.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# Define the Gaussian-Bernoulli RBM
class RBM(nn.Module):
    """Gaussian-Bernoulli Restricted Boltzmann Machine.

    Visible units are real-valued with unit variance; hidden units are binary.
    ``W`` has shape ``(hidden_units, visible_units)``.
    """

    def __init__(self, visible_units: int, hidden_units: int, k: int):
        """
        Initializes the RBM model.

        Args:
            visible_units (int): Number of visible units (input dimension).
            hidden_units (int): Number of hidden units.
            k (int): Number of Gibbs sampling steps.
        """
        super(RBM, self).__init__()
        self.visible = visible_units
        self.hidden = hidden_units
        self.k = k

        # Initialize weights and biases
        self.W = nn.Parameter(torch.randn(hidden_units, visible_units) * 0.01)  # Weight matrix
        self.c = nn.Parameter(torch.zeros(visible_units))  # Visible biases
        self.b = nn.Parameter(torch.zeros(hidden_units))  # Hidden biases

    def sample_bernoulli(self, probs):
        """Sample binary values from probabilities."""
        return torch.bernoulli(probs)

    def sample_gaussian(self, mean, std=1.0):
        """
        Sample continuous values from a Gaussian distribution.

        The noise is created with the same shape, dtype and device as ``mean``,
        so the method does not depend on any module-level device setting.

        Args:
            mean (torch.Tensor): Mean of the distribution.
            std (float): Standard deviation applied to the unit noise.

        Returns:
            torch.Tensor: ``mean`` plus scaled Gaussian noise.
        """
        return mean + torch.randn_like(mean) * std

    def p_h_v(self, v):
        """
        Compute the probability of hidden units given visible units.

        Args:
            v (torch.Tensor): Visible units, shape (batch, visible_units).

        Returns:
            torch.Tensor: Probability of hidden units, shape (batch, hidden_units).
        """
        return torch.sigmoid(F.linear(v, self.W, self.b))

    def p_v_h(self, h):
        """
        Compute the mean of visible units given hidden units.

        Args:
            h (torch.Tensor): Hidden units, shape (batch, hidden_units).

        Returns:
            torch.Tensor: Mean of visible units, shape (batch, visible_units).
        """
        # F.linear computes h @ weight.T + bias and expects weight of shape
        # (out_features, in_features). W is (hidden, visible), so the
        # hidden -> visible direction needs the transpose: h @ W + c.
        return F.linear(h, self.W.t(), self.c)

    def free_energy(self, v):
        """
        Compute the free energy for a batch of visible units.

        Args:
            v (torch.Tensor): Visible units, shape (batch, visible_units).

        Returns:
            torch.Tensor: Free energy, one value per sample.
        """
        vbias_term = torch.sum((v - self.c) ** 2, dim=1) / 2
        wx_b = F.linear(v, self.W, self.b)
        # softplus(x) == log(1 + exp(x)) but does not overflow for large x.
        hidden_term = torch.sum(F.softplus(wx_b), dim=1)
        return vbias_term - hidden_term

    def forward(self, v):
        """
        Run k steps of Gibbs sampling starting from ``v`` (the CD-k negative phase).

        The chain is sampled without gradient tracking: the negative sample is
        treated as a constant by contrastive divergence, so no graph should be
        built through the sampling steps.

        Args:
            v (torch.Tensor): Initial visible units.

        Returns:
            torch.Tensor: Visible sample after k Gibbs steps (no gradient attached).
            torch.Tensor: Probabilities of hidden units given that sample.
        """
        with torch.no_grad():
            v_neg = v
            for _ in range(self.k):
                # Sample hidden units
                p_h = self.p_h_v(v_neg)
                h_sample = self.sample_bernoulli(p_h)

                # Sample visible units
                v_mean = self.p_v_h(h_sample)
                v_neg = self.sample_gaussian(v_mean)

        p_h_neg = self.p_h_v(v_neg)
        return v_neg, p_h_neg

def train_rbm(model, train_loader, optimizer, epoch):
    """
    Train the RBM model for one epoch with contrastive divergence.

    The loss is the mean free energy of the data minus the mean free energy of
    the model's Gibbs sample; the Gibbs sample is treated as a constant.

    Args:
        model: RBM exposing ``free_energy`` and a forward pass that returns
            ``(v_neg, p_h_neg)``.
        train_loader: Iterable of ``(data, label)`` batches; labels are ignored.
        optimizer: Optimizer over the model's parameters.
        epoch (int): Epoch number, used only in the progress messages.

    Returns:
        float: Average loss per batch for this epoch (0.0 for an empty loader).
    """
    model.train()
    train_loss = 0
    num_batches = len(train_loader)
    # Report roughly twice per epoch; never let the interval drop to 0, which
    # would make the modulo below raise for loaders with fewer than 2 batches.
    log_every = max(1, num_batches // 2)

    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.view(data.size(0), -1).to(device)
        # Affine map x -> 2x - 1, which takes [0, 1] inputs to [-1, 1].
        # Assumes the loader yields [0, 1] pixels - TODO confirm against the dataset transform.
        data = data * 2 - 1

        # Negative phase: the sample must not carry gradients back into the
        # parameters, otherwise the update is no longer the CD estimate.
        v_neg, _ = model(data)
        v_neg = v_neg.detach()

        # Update parameters
        loss = torch.mean(model.free_energy(data)) - torch.mean(model.free_energy(v_neg))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        if (batch_idx + 1) % log_every == 0:
            print(f'Epoch [{epoch}/{num_epochs}] Batch [{batch_idx + 1}/{len(train_loader)}] '
                  f'Loss: {loss.item():.4f}')

    avg_loss = train_loss / num_batches if num_batches else 0.0
    print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f}')
    return avg_loss

def test_rbm(model, test_loader, epoch):
    """
    Evaluate the RBM on a held-out loader.

    For every batch the gap between the mean free energy of the data and the
    mean free energy of the model's Gibbs sample is recorded; the gaps are
    averaged over the number of batches, printed and returned.

    Args:
        model: RBM exposing ``free_energy`` and a forward pass returning
            ``(v_neg, p_h_neg)``.
        test_loader: Iterable of ``(data, label)`` batches; labels are ignored.
        epoch (int): Accepted for symmetry with the training function; unused.

    Returns:
        float: Average free-energy gap per batch.
    """
    model.eval()
    batch_gaps = []

    with torch.no_grad():
        for images, _ in test_loader:
            flat = images.view(images.size(0), -1).to(device)
            flat = flat * 2 - 1  # Rescale from [0,1] to [-1,1]

            sample, _ = model(flat)
            gap = torch.mean(model.free_energy(flat)) - torch.mean(model.free_energy(sample))
            batch_gaps.append(gap.item())

    avg_test_loss = sum(batch_gaps) / len(test_loader)
    print(f'====> Test set loss: {avg_test_loss:.4f}')
    return avg_test_loss

def reconstruct(model, data):
    """Run the model on ``data`` without gradients and return its visible sample.

    The model is expected to return a pair; only the first element is kept.
    """
    with torch.no_grad():
        visible_sample, _hidden_probs = model(data)
        return visible_sample

def plot_reconstructions(original, reconstructed, epoch, M):
    """Show the first 8 originals (top row) above their reconstructions (bottom row).

    Both inputs are reshaped to 1x28x28 images and moved to the CPU before
    being tiled into a single grid. ``epoch`` and ``M`` only appear in the title.
    """
    def _first_eight(batch):
        return batch.view(-1, 1, 28, 28).cpu()[:8]

    rows = [_first_eight(original), _first_eight(reconstructed)]
    grid = make_grid(torch.cat(rows), nrow=8, padding=2, normalize=True)

    plt.figure(figsize=(16, 4))
    plt.title(f'Original and Reconstructed Images (Epoch {epoch}, M={M})')
    # imshow wants (height, width, channels); the grid is (channels, height, width).
    plt.imshow(np.transpose(grid, (1, 2, 0)))
    plt.axis('off')
    plt.show()

def compute_mse(model, data_loader):
    """
    Compute the Mean Squared Error between data and the model's reconstruction.

    The squared error is summed over all batches and divided by the total
    number of *elements* (samples x visible units), so the result is a true
    per-element mean and is comparable across input sizes.

    Args:
        model: Model whose forward pass returns ``(v_neg, p_h_neg)``.
        data_loader: Iterable of ``(data, label)`` batches; labels are ignored.

    Returns:
        float: Per-element mean squared error, or NaN when the loader is empty.
    """
    model.eval()
    squared_error = 0.0
    num_elements = 0
    with torch.no_grad():
        for data, _ in data_loader:
            data = data.view(data.size(0), -1).to(device)
            # Affine map x -> 2x - 1, which takes [0, 1] inputs to [-1, 1].
            data = data * 2 - 1
            v_neg, _ = model(data)
            squared_error += F.mse_loss(v_neg, data, reduction='sum').item()
            num_elements += data.numel()

    if num_elements == 0:
        # Nothing was evaluated; avoid a ZeroDivisionError.
        return float('nan')
    return squared_error / num_elements

# Training and Evaluation Loop for different M values

# Fixed batch used for every reconstruction plot. test_loader is built with
# shuffle=False, so its first batch is the same on every pass; fetching it once
# avoids creating a fresh loader iterator on every epoch.
data, _ = next(iter(test_loader))
data = data[:32]
data = data.view(data.size(0), -1).to(device)
data = data * 2 - 1  # Affine map x -> 2x - 1 (takes [0, 1] inputs to [-1, 1])

for M in M_values:
    print(f'\n======================== Training RBM with M = {M} ========================\n')
    rbm = RBM(visible_units=28*28, hidden_units=M, k=k).to(device)
    optimizer = optim.Adam(rbm.parameters(), lr=learning_rate, weight_decay=weight_decay)

    train_losses = []
    test_losses = []

    for epoch in range(1, num_epochs + 1):
        epoch_train_loss = train_rbm(rbm, train_loader, optimizer, epoch)
        test_loss = test_rbm(rbm, test_loader, epoch)
        # Plot the real training loss whenever train_rbm reports one; if it
        # returns nothing, fall back to the test loss so the curve still has
        # one point per epoch.
        if epoch_train_loss is None:
            epoch_train_loss = test_loss
        train_losses.append(epoch_train_loss)
        test_losses.append(test_loss)

        # Reconstruct and visualize
        reconstructed = reconstruct(rbm, data)
        plot_reconstructions(data, reconstructed, epoch, M)

    # Compute MSE for train and test sets
    train_mse = compute_mse(rbm, train_loader)
    test_mse = compute_mse(rbm, test_loader)
    print(f'RBM with M={M}: Train MSE: {train_mse:.4f}, Test MSE: {test_mse:.4f}')

    # Plotting the loss curves
    epochs = range(1, num_epochs + 1)
    plt.figure(figsize=(10,5))
    plt.plot(epochs, train_losses, label='Train Loss')
    plt.plot(epochs, test_losses, label='Test Loss')
    plt.title(f'Loss Curves for RBM with M={M}')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()

    # Reconstruct and visualize final epoch
    # A fresh stochastic reconstruction of the same fixed batch.
    reconstructed = reconstruct(rbm, data)
    plot_reconstructions(data, reconstructed, num_epochs, M)

    print(f'Optimizer Learning rate: {optimizer.param_groups[0]["lr"]:.4f}\n')


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
import numpy as np
from sklearn.manifold import TSNE
import os

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Hyperparameters
n_in = 784  # 28x28 images
n_hid = 400  # width of the single hidden layer in encoder and decoder
z_dim = 20  # dimensionality of the latent vector
learning_rate = 1e-3
batch_size = 128
num_epochs = 25
validation_split = 0.2  # 20% for validation

# Create directory to save models
os.makedirs('models', exist_ok=True)

# Data transformations
# Only ToTensor is applied, so no further normalisation happens here; the
# binary cross-entropy reconstruction loss below needs targets in [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Load Fashion MNIST dataset
train_val_set = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform
)

test_set = torchvision.datasets.FashionMNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform
)

# Split training data into training and validation sets
# num_val is computed as the remainder so the two sizes always sum to the full set.
num_train = int((1 - validation_split) * len(train_val_set))
num_val = len(train_val_set) - num_train
train_set, val_set = random_split(train_val_set, [num_train, num_val])

# Data loaders
# Only the training loader is shuffled.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# Function to convert labels to one-hot encoding
def one_hot(labels, num_classes=10):
    """Return a float one-hot encoding of integer class labels."""
    encoded = F.one_hot(labels, num_classes=num_classes)
    return encoded.float()

# ---------------------------
# 3.1 Vanilla VAE Implementation
# ---------------------------
class VAE(nn.Module):
    """Fully connected variational autoencoder with one hidden layer per side."""

    def __init__(self, n_in, n_hid, z_dim):
        super().__init__()

        # Encoder: input -> hidden -> (mean, log-variance)
        self.fc1 = nn.Linear(n_in, n_hid)
        self.fc21 = nn.Linear(n_hid, z_dim)  # mean head
        self.fc22 = nn.Linear(n_hid, z_dim)  # log-variance head

        # Decoder: latent -> hidden -> reconstruction
        self.fc3 = nn.Linear(z_dim, n_hid)
        self.fc4 = nn.Linear(n_hid, n_in)

    def encode(self, x):
        """Return the mean and log-variance of the approximate posterior."""
        hidden = self.fc1(x).relu()
        return self.fc21(hidden), self.fc22(hidden)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + sigma * eps with eps ~ N(0, I)."""
        sigma = (logvar * 0.5).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z):
        """Map a latent vector to a reconstruction with values in [0, 1]."""
        hidden = self.fc3(z).relu()
        return self.fc4(hidden).sigmoid()

    def forward(self, x):
        """Encode, sample and decode; returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent), mu, logvar

# Initialize VAE model
# The model is built from the hyperparameters above and gets its own Adam optimizer.
vae = VAE(n_in=n_in, n_hid=n_hid, z_dim=z_dim).to(device)
optimizer_vae = Adam(vae.parameters(), lr=learning_rate)

# Loss function: ELBO (Reconstruction loss + KL Divergence)
def loss_function(recon_x, x, mu, logvar):
    """Return the summed negative ELBO for a batch.

    The result is the binary cross-entropy between reconstruction and target
    plus the KL divergence of N(mu, exp(logvar)) from the standard normal,
    both summed over every element of the batch (no averaging).
    """
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * kl_terms.sum()
    return reconstruction + divergence

# Training function for VAE
def train_vae(model, optimizer, train_loader, val_loader, epoch):
    """Run one training epoch followed by a validation pass for the VAE.

    Labels from both loaders are ignored. Losses are summed over all samples
    and divided by the dataset size; both averages are printed, nothing is
    returned.
    """
    def _batch_loss(images):
        # Flatten to (batch, n_in), run the model and score the reconstruction.
        flat = images.view(-1, n_in).to(device)
        recon, mu, logvar = model(flat)
        return loss_function(recon, flat, mu, logvar)

    # Training pass
    model.train()
    running = 0
    for images, _ in train_loader:
        optimizer.zero_grad()
        batch_loss = _batch_loss(images)
        batch_loss.backward()
        running += batch_loss.item()
        optimizer.step()

    avg_train_loss = running / len(train_loader.dataset)
    print(f'====> Epoch: {epoch} Average Train Loss: {avg_train_loss:.4f}')

    # Validation pass
    model.eval()
    running = 0
    with torch.no_grad():
        for images, _ in val_loader:
            running += _batch_loss(images).item()

    avg_val_loss = running / len(val_loader.dataset)
    print(f'====> Epoch: {epoch} Average Validation Loss: {avg_val_loss:.4f}')

# Testing function for VAE
def test_vae(model, test_loader):
    """Evaluate the VAE on ``test_loader`` and return the per-sample loss.

    The summed loss over all batches is divided by the dataset size, printed
    and returned. Labels are ignored.
    """
    model.eval()
    batch_totals = []
    with torch.no_grad():
        for images, _ in test_loader:
            flat = images.view(-1, n_in).to(device)
            recon, mu, logvar = model(flat)
            batch_totals.append(loss_function(recon, flat, mu, logvar).item())

    avg_test_loss = sum(batch_totals) / len(test_loader.dataset)
    print(f'====> Test set loss: {avg_test_loss:.4f}')
    return avg_test_loss

# Training loop for VAE
# Every epoch trains and validates, then evaluates on the test set.
# test_loss is overwritten each epoch and only reported through test_vae's print.
print("Training Vanilla VAE...")
for epoch in range(1, num_epochs + 1):
    train_vae(vae, optimizer_vae, train_loader, val_loader, epoch)
    test_loss = test_vae(vae, test_loader)

# Save the trained VAE model
# Only the parameters (state_dict) are stored, not the module object itself.
torch.save(vae.state_dict(), 'models/vae_fashion_mnist.pth')
print("VAE model saved to 'models/vae_fashion_mnist.pth'")

# ---------------------------
# 3.2 Conditional VAE (C-VAE) Implementation
# ---------------------------
class CVAE(nn.Module):
    """Conditional VAE: encoder and decoder both see a one-hot class vector."""

    def __init__(self, n_in, n_hid, z_dim, n_classes=10):
        super().__init__()
        self.n_classes = n_classes

        # Encoder: [input, condition] -> hidden -> (mean, log-variance)
        self.fc1 = nn.Linear(n_in + n_classes, n_hid)
        self.fc21 = nn.Linear(n_hid, z_dim)  # mean head
        self.fc22 = nn.Linear(n_hid, z_dim)  # log-variance head

        # Decoder: [latent, condition] -> hidden -> reconstruction
        self.fc3 = nn.Linear(z_dim + n_classes, n_hid)
        self.fc4 = nn.Linear(n_hid, n_in)

    def encode(self, x, c):
        """Return posterior mean and log-variance for ``x`` given condition ``c``."""
        conditioned = torch.cat([x, c], dim=1)
        hidden = self.fc1(conditioned).relu()
        return self.fc21(hidden), self.fc22(hidden)

    def reparameterize(self, mu, logvar):
        """Draw z = mu + sigma * eps with eps ~ N(0, I)."""
        sigma = (logvar * 0.5).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def decode(self, z, c):
        """Reconstruct from latent ``z`` and condition ``c``; output is in [0, 1]."""
        conditioned = torch.cat([z, c], dim=1)
        hidden = self.fc3(conditioned).relu()
        return self.fc4(hidden).sigmoid()

    def forward(self, x, c):
        """Encode, sample and decode; returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x, c)
        latent = self.reparameterize(mu, logvar)
        return self.decode(latent, c), mu, logvar

# Initialize C-VAE model
# n_classes is left at the constructor's default of 10.
cvae = CVAE(n_in=n_in, n_hid=n_hid, z_dim=z_dim).to(device)
optimizer_cvae = Adam(cvae.parameters(), lr=learning_rate)

# Training function for C-VAE
def train_cvae(model, optimizer, train_loader, val_loader, epoch):
    """Run one training epoch followed by a validation pass for the C-VAE.

    Each batch's labels are one-hot encoded (10 classes) and passed to the
    model as the condition. Losses are summed over all samples and divided by
    the dataset size; both averages are printed, nothing is returned.
    """
    def _batch_loss(images, targets):
        # Flatten the images, build the condition and score the reconstruction.
        flat = images.view(-1, n_in).to(device)
        targets = targets.to(device)
        condition = one_hot(targets, num_classes=10).to(device)
        recon, mu, logvar = model(flat, condition)
        return loss_function(recon, flat, mu, logvar)

    # Training pass
    model.train()
    running = 0
    for images, targets in train_loader:
        optimizer.zero_grad()
        batch_loss = _batch_loss(images, targets)
        batch_loss.backward()
        running += batch_loss.item()
        optimizer.step()

    avg_train_loss = running / len(train_loader.dataset)
    print(f'====> Epoch: {epoch} C-VAE Average Train Loss: {avg_train_loss:.4f}')

    # Validation pass
    model.eval()
    running = 0
    with torch.no_grad():
        for images, targets in val_loader:
            running += _batch_loss(images, targets).item()

    avg_val_loss = running / len(val_loader.dataset)
    print(f'====> Epoch: {epoch} C-VAE Average Validation Loss: {avg_val_loss:.4f}')

# Testing function for C-VAE
def test_cvae(model, test_loader):
    """Evaluate the C-VAE on ``test_loader`` and return the per-sample loss.

    Labels are one-hot encoded (10 classes) and used as the condition. The
    summed loss is divided by the dataset size, printed and returned.
    """
    model.eval()
    batch_totals = []
    with torch.no_grad():
        for images, targets in test_loader:
            flat = images.view(-1, n_in).to(device)
            targets = targets.to(device)
            condition = one_hot(targets, num_classes=10).to(device)
            recon, mu, logvar = model(flat, condition)
            batch_totals.append(loss_function(recon, flat, mu, logvar).item())

    avg_test_loss = sum(batch_totals) / len(test_loader.dataset)
    print(f'====> C-VAE Test set loss: {avg_test_loss:.4f}')
    return avg_test_loss

# Training loop for C-VAE
# Same schedule as the vanilla VAE: train + validate, then evaluate on the test set.
# test_loss is overwritten each epoch and only reported through test_cvae's print.
print("\nTraining Conditional VAE (C-VAE)...")
for epoch in range(1, num_epochs + 1):
    train_cvae(cvae, optimizer_cvae, train_loader, val_loader, epoch)
    test_loss = test_cvae(cvae, test_loader)

# Save the trained C-VAE model
# Only the parameters (state_dict) are stored, not the module object itself.
torch.save(cvae.state_dict(), 'models/cvae_fashion_mnist.pth')
print("C-VAE model saved to 'models/cvae_fashion_mnist.pth'")

# ---------------------------
# 3.3 Manifold Comparison using t-SNE
# ---------------------------

def extract_mu(model, data_loader, conditional=False):
    """Collect the encoder's mean vectors and the labels for a whole loader.

    When ``conditional`` is true the labels are one-hot encoded (10 classes)
    and passed to ``model.encode`` as a second argument. Returns two NumPy
    arrays: the stacked means and the matching labels.
    """
    model.eval()
    mu_chunks, label_chunks = [], []

    with torch.no_grad():
        for images, targets in data_loader:
            flat = images.view(-1, n_in).to(device)
            targets = targets.to(device)

            if not conditional:
                mu, _ = model.encode(flat)
            else:
                condition = one_hot(targets, num_classes=10).to(device)
                mu, _ = model.encode(flat, condition)

            mu_chunks.append(mu.cpu().numpy())
            label_chunks.append(targets.cpu().numpy())

    stacked_mu = np.concatenate(mu_chunks, axis=0)
    stacked_labels = np.concatenate(label_chunks, axis=0)
    return stacked_mu, stacked_labels

# Extract mu vectors from test set for VAE
print("\nExtracting latent representations for VAE...")
mu_vae, labels_vae = extract_mu(vae, test_loader, conditional=False)

# Extract mu vectors from test set for C-VAE
# The C-VAE encoder also receives the one-hot labels (conditional=True).
print("Extracting latent representations for C-VAE...")
mu_cvae, labels_cvae = extract_mu(cvae, test_loader, conditional=True)

# Apply t-SNE to reduce dimensions to 2D
# Each model gets its own TSNE instance with the same random_state, so the two
# embeddings are fitted independently of each other.
print("Applying t-SNE on VAE latent representations...")
tsne_vae = TSNE(n_components=2, random_state=42)
mu_vae_2d = tsne_vae.fit_transform(mu_vae)

print("Applying t-SNE on C-VAE latent representations...")
tsne_cvae = TSNE(n_components=2, random_state=42)
mu_cvae_2d = tsne_cvae.fit_transform(mu_cvae)

# Function to plot t-SNE results
def plot_tsne(mu_2d, labels, title):
    """Scatter-plot a 2-D embedding, colouring each point by its class label."""
    xs, ys = mu_2d[:, 0], mu_2d[:, 1]

    plt.figure(figsize=(10, 8))
    points = plt.scatter(xs, ys, c=labels, cmap='tab10', alpha=0.6, s=10)
    plt.colorbar(points, ticks=range(10))  # one tick per class id 0-9
    plt.title(title)
    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.grid(True)
    plt.show()

# Plot t-SNE for VAE
plot_tsne(mu_vae_2d, labels_vae, 't-SNE of VAE Latent Space')

# Plot t-SNE for C-VAE
plot_tsne(mu_cvae_2d, labels_cvae, 't-SNE of C-VAE Latent Space')

# ---------------------------
# 3.3 Comparison and Hypothesis
# ---------------------------

# NOTE(review): the text printed below is a fixed string written ahead of time;
# it is not derived from the plots above, so check it against the actual t-SNE
# output before relying on it.
print("\nComparison and Hypothesis:")
print("""
Upon visualizing the latent spaces using t-SNE, the VAE's manifold may show overlapping clusters with less distinct boundaries between different classes. This is because the VAE is unsupervised and doesn't explicitly use class labels during training, leading to latent representations that capture general data features without class-specific separation.

In contrast, the C-VAE's manifold is expected to display more distinct and well-separated clusters corresponding to different classes. Since the C-VAE conditions on class labels during both encoding and decoding, it learns latent representations that are more discriminative with respect to the classes. This conditioning encourages the model to organize the latent space in a way that aligns with the class structure, resulting in clearer separations between different categories.

**Hypothesis:** The inclusion of class labels in the C-VAE provides additional supervised information that guides the latent space to form class-specific clusters, enhancing the model's ability to differentiate between classes. This leads to more interpretable and organized latent representations compared to the standard VAE, which lacks this supervised signal.
""")
