In [16]:
# MNIST Generator Training in PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

In [17]:
# Generator network (label-conditioned MLP built from fully connected layers)
class Generator(nn.Module):
    """Label-conditioned generator that maps latent noise to 1x28x28 images.

    The class label is looked up in an embedding table of width ``latent_dim``
    and multiplied element-wise with the noise vector. The product is passed
    through two fully connected layers ending in ``Tanh``, so every output
    value lies in [-1, 1].

    Args:
        latent_dim: Size of the noise vector and of each label embedding.
        num_classes: Number of distinct labels held by the embedding table.
    """

    def __init__(self, latent_dim=100, num_classes=10):
        super(Generator, self).__init__()
        self.label_embed = nn.Embedding(num_classes, latent_dim)
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(True),
            nn.Linear(128, 784),  # 784 = 1 * 28 * 28 values per generated image
            nn.Tanh()
        )

    def forward(self, z, labels):
        """Generate a batch of images from noise conditioned on class labels.

        Args:
            z: Noise tensor whose last dimension is ``latent_dim``
                (e.g. shape ``(batch, latent_dim)``).
            labels: Integer tensor of class indices in ``[0, num_classes)``,
                one per sample.

        Returns:
            Tensor reshaped to ``(-1, 1, 28, 28)`` with values in [-1, 1].
        """
        # Element-wise multiply noise and label embedding
        x = z * self.label_embed(labels)
        out = self.net(x)
        return out.view(-1, 1, 28, 28)


In [18]:
 def forward(self, z, labels):
     """Map noise ``z`` and class ``labels`` to a batch of 1x28x28 outputs.

     ``self`` must provide ``label_embed`` (labels -> vectors) and ``net``
     (vectors -> flat outputs); the result is reshaped to ``(-1, 1, 28, 28)``.
     """
     # NOTE(review): this definition sits in its own cell, outside any class
     # body, so it is a free function rather than a method. Confirm that it
     # is bound to the intended module class before relying on it.
     label_vectors = self.label_embed(labels)
     # Condition the noise on the label by element-wise multiplication.
     conditioned = z * label_vectors
     return self.net(conditioned).view(-1, 1, 28, 28)

In [19]:
# Hyperparameters
seed = 42  # fixed RNG seed so weight initialisation and data shuffling are repeatable
latent_dim = 100
epochs = 10
batch_size = 64
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed PyTorch's random number generators before any model or loader is created.
torch.manual_seed(seed)

In [20]:
# Prepare data: convert MNIST images to tensors, then shift/scale each pixel
# with mean 0.5 and std 0.5 (i.e. (x - 0.5) / 0.5).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
train_loader = DataLoader(
    datasets.MNIST(root='.', train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True,
)

# Model, loss, optimizer
model = Generator(latent_dim=latent_dim).to(device)
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=2e-4)

100%|██████████| 9.91M/9.91M [00:00<00:00, 70.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 31.5MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 87.3MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.48MB/s]


In [24]:
class Generator(nn.Module):
    """Conditional generator: (noise, label) -> image tensor of shape 1x28x28.

    Each label is embedded as a ``latent_dim``-sized vector that scales the
    noise element-wise; the result goes through Linear -> ReLU -> Linear ->
    Tanh and is reshaped into single-channel 28x28 images.

    Args:
        latent_dim: Width of the noise vector and of the label embedding.
        num_classes: Number of rows in the label embedding table.
    """

    def __init__(self, latent_dim=100, num_classes=10):
        super().__init__()
        # One learnable vector of size latent_dim per class label.
        self.label_embed = nn.Embedding(num_embeddings=num_classes, embedding_dim=latent_dim)
        hidden_layer = nn.Linear(in_features=latent_dim, out_features=128)
        output_layer = nn.Linear(in_features=128, out_features=784)
        # Layer order is kept so state_dict keys stay net.0.* and net.2.*.
        self.net = nn.Sequential(
            hidden_layer,
            nn.ReLU(inplace=True),
            output_layer,
            nn.Tanh(),
        )

    def forward(self, z, labels):
        """Return generated images for noise ``z`` conditioned on ``labels``.

        The flat 784-value output of the network is viewed as
        ``(-1, 1, 28, 28)``.
        """
        # Scale the noise element-wise by the embedding of each label.
        conditioned_noise = z * self.label_embed(labels)
        flat_images = self.net(conditioned_noise)
        return flat_images.view(-1, 1, 28, 28)



In [25]:
# Save the model: write only the parameters (state_dict), not the module object.
os.makedirs("saved_models", exist_ok=True)  # no error if the directory already exists
generator_state = model.state_dict()
torch.save(generator_state, "saved_models/generator_mnist.pth")