In [1]:
# digit_generator_training.ipynb or .py
# Train a conditional generator on MNIST (PyTorch)

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np

# --- Model Definition ---
class DigitGenerator(nn.Module):
    def __init__(self):
        super(DigitGenerator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(110, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 784),
            nn.Sigmoid()
        )

    def forward(self, noise, labels):
        x = torch.cat([noise, labels], dim=1)
        return self.model(x)

# --- One-hot encode labels ---
def one_hot(labels, num_classes=10):
    return torch.eye(num_classes)[labels]

# --- Training Config ---
batch_size = 128
z_dim = 100
epochs = 10
lr = 0.001

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Data Loading ---
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))  # Flatten 28x28 to 784
])
train_loader = DataLoader(
    datasets.MNIST(root="./data", train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True
)

# --- Initialize Model, Optimizer, Loss ---
model = DigitGenerator().to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.MSELoss()

# --- Training Loop ---
model.train()
for epoch in range(epochs):
    for images, labels in train_loader:
        images = images.to(device)
        labels_onehot = one_hot(labels).to(device)
        z = torch.randn(images.size(0), z_dim).to(device)

        outputs = model(z, labels_onehot)
        loss = loss_fn(outputs, images)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")

# --- Save Trained Model ---
torch.save(model.state_dict(), "digit_gen_model.pt")
print("Model saved as digit_gen_model.pt")


100%|██████████| 9.91M/9.91M [00:00<00:00, 18.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 500kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.99MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.66MB/s]


Epoch 1/10, Loss: 0.0524
Epoch 2/10, Loss: 0.0511
Epoch 3/10, Loss: 0.0526
Epoch 4/10, Loss: 0.0518
Epoch 5/10, Loss: 0.0541
Epoch 6/10, Loss: 0.0522
Epoch 7/10, Loss: 0.0543
Epoch 8/10, Loss: 0.0544
Epoch 9/10, Loss: 0.0529
Epoch 10/10, Loss: 0.0525
Model saved as digit_gen_model.pt
