In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.data import random_split

# Hyperparameters
batch_size = 256  # samples per mini-batch for both loaders
latent_dim = 16   # width of the CVAE latent vector
epochs = 200      # upper bound on the number of training epochs
lr = 1e-3         # learning rate handed to the optimizer

# Dataset: the MNIST training split, with each image converted by ToTensor()
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Hold out 20% of the images for validation and train on the remaining 80%.
# NOTE: `train_dataset` is rebound to the training subset by the split below.
train_size = int(0.8 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

# Training batches are reshuffled on every pass; validation batches keep a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size)

class EarlyStopping:
    """Signal when a monitored validation loss has stopped improving.

    Call the instance once per epoch with the current validation loss. A loss
    counts as an improvement only when it is at most ``best_loss - delta``;
    any other call increments ``counter``, and once ``counter`` reaches
    ``patience`` the ``early_stop`` flag is set.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            ``early_stop`` becomes ``True``.
        delta: Minimum decrease of the loss required to count as an
            improvement.
    """

    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.delta = delta
        self.counter = 0         # consecutive calls without an improvement
        self.best_loss = None    # lowest accepted loss; None until the first call
        self.early_stop = False  # becomes True once patience is exhausted

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter`` and ``early_stop``."""
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.delta:
            # Not at least `delta` better than the best so far. `best_loss` is
            # deliberately left untouched so that a slowly rising loss cannot
            # drag the reference upward and keep resetting the counter.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

# CVAE Model
class CVAE(nn.Module):
    """Conditional variational autoencoder made of fully connected layers.

    The encoder maps a 784-value input concatenated with a 10-value
    conditioning vector to the mean and log-variance of a
    ``latent_dim``-dimensional Gaussian. The decoder maps a latent sample
    concatenated with the same conditioning vector back to 784
    sigmoid-activated values.
    """

    def __init__(self):
        super().__init__()
        # Encoder: (784 + 10) -> 256 -> (mean, log-variance)
        self.fc1 = nn.Linear(28 * 28 + 10, 256)
        self.fc2_mean = nn.Linear(256, latent_dim)
        self.fc2_logvar = nn.Linear(256, latent_dim)
        # Decoder: (latent_dim + 10) -> 256 -> 784
        self.fc3 = nn.Linear(latent_dim + 10, 256)
        self.fc4 = nn.Linear(256, 28 * 28)

    def encoder(self, x, y):
        """Return ``(z_mean, z_logvar)`` for input ``x`` conditioned on ``y``."""
        hidden = torch.relu(self.fc1(torch.cat((x, y), dim=1)))
        return self.fc2_mean(hidden), self.fc2_logvar(hidden)

    def reparameterize(self, z_mean, z_logvar):
        """Draw ``z_mean + eps * std`` with ``eps`` sampled from a standard normal."""
        std = (0.5 * z_logvar).exp()
        noise = torch.randn_like(std)
        return z_mean + noise * std

    def decoder(self, z, y):
        """Map latent ``z`` conditioned on ``y`` to sigmoid-activated outputs."""
        hidden = torch.relu(self.fc3(torch.cat((z, y), dim=1)))
        return self.fc4(hidden).sigmoid()

    def forward(self, x, y):
        """Encode, sample and decode; return ``(reconstruction, z_mean, z_logvar)``."""
        mu, logvar = self.encoder(x, y)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent, y), mu, logvar

# Loss Function
def loss_function(x_reconstructed, x, z_mean, z_logvar):
    """Return the summed CVAE loss: binary cross-entropy plus KL divergence.

    Both terms are summed over every element of the batch (no averaging), so
    the result scales with the batch size.

    Args:
        x_reconstructed: Decoder output with values in [0, 1].
        x: Target tensor of the same shape as ``x_reconstructed``.
        z_mean: Mean of the approximate posterior.
        z_logvar: Log-variance of the approximate posterior.

    Returns:
        A scalar tensor holding the total loss.
    """
    bce = nn.functional.binary_cross_entropy(x_reconstructed, x, reduction='sum')
    # KL divergence between N(z_mean, exp(z_logvar)) and a standard normal.
    kl_terms = 1 + z_logvar - z_mean.pow(2) - z_logvar.exp()
    kld = -0.5 * kl_terms.sum()
    return bce + kld

# Training Loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Stop once the per-sample validation loss has failed to improve for 10 epochs in a row.
early_stopping = EarlyStopping(patience=10, delta=0.001)

for epoch in range(epochs):
    # --- Training pass ---
    model.train()
    train_loss = 0
    for data, target in train_loader:
        # Flatten each image to 784 values and one-hot encode the digit label.
        data = data.view(-1, 28 * 28).to(device)
        target = torch.nn.functional.one_hot(target, num_classes=10).float().to(device)
        optimizer.zero_grad()
        x_reconstructed, z_mean, z_logvar = model(data, target)
        loss = loss_function(x_reconstructed, data, z_mean, z_logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    # --- Validation pass (no gradient tracking) ---
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for data, target in val_loader:
            data = data.view(-1, 28 * 28).to(device)
            target = torch.nn.functional.one_hot(target, num_classes=10).float().to(device)
            x_reconstructed, z_mean, z_logvar = model(data, target)
            loss = loss_function(x_reconstructed, data, z_mean, z_logvar)
            val_loss += loss.item()

    # loss_function sums over the batch, so divide by the number of samples
    # to get per-sample averages that do not depend on the batch size.
    avg_train_loss = train_loss / len(train_loader.dataset)
    avg_val_loss = val_loss / len(val_loader.dataset)

    # Report before the early-stopping check so the final epoch is logged too.
    print(f"Epoch {epoch + 1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    # Monitor the same per-sample value that is printed, so `delta` is
    # expressed in the units of the reported validation loss.
    early_stopping(avg_val_loss)
    if early_stopping.early_stop:
        print("Early stopping")
        break

In [None]:
import matplotlib.pyplot as plt

# Set the model to evaluation mode
model.eval()


def generate_images(label, num_images=5):
    """Sample images from the decoder conditioned on ``label`` and plot them.

    Latent vectors are drawn from a standard normal distribution, decoded
    together with a one-hot encoding of ``label``, reshaped to 28x28 and shown
    side by side in a single row.

    Args:
        label: Class index to condition on; it is one-hot encoded with
            ``num_classes=10``.
        num_images: Number of images to sample and display.
    """
    with torch.no_grad():
        # Create one-hot label, repeated once per requested image
        label_tensor = torch.nn.functional.one_hot(
            torch.tensor([label] * num_images), num_classes=10
        ).float().to(device)

        # Sample random latent vectors from normal distribution
        z = torch.randn(num_images, latent_dim).to(device)

        # Generate images using the decoder and reshape back to 28x28
        generated_images = model.decoder(z, label_tensor).cpu()
        generated_images = generated_images.view(num_images, 28, 28)

    # squeeze=False keeps `axes` a 2-D array even when num_images == 1;
    # otherwise a single Axes object is returned and cannot be indexed.
    # The width allows 2 inches per image, i.e. (10, 2) for the default of five.
    fig, axes = plt.subplots(1, num_images, figsize=(2 * num_images, 2), squeeze=False)
    for ax, image in zip(axes[0], generated_images):
        ax.imshow(image, cmap="gray")
        ax.axis("off")
    plt.show()


# Example: Generate 5 images conditioned on label '1'
generate_images(label=1)