In [1]:
import torch
import streamlit as st
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
from torchvision.utils import make_grid


In [2]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

mnist_train = datasets.MNIST(root='.', train=True, download=True, transform=transform)
train_loader = DataLoader(mnist_train, batch_size=128, shuffle=True)


In [3]:
class DigitGenerator(nn.Module):
    def __init__(self, latent_dim=100, label_dim=10):
        super().__init__()
        self.label_embedding = nn.Embedding(label_dim, label_dim)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + label_dim, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 784),
            nn.Tanh()
        )

    def forward(self, noise, labels):
        label_input = self.label_embedding(labels)
        x = torch.cat((noise, label_input), dim=1)
        x = self.model(x)
        return x.view(-1, 1, 28, 28)


In [4]:
latent_dim = 100
generator = DigitGenerator(latent_dim=latent_dim)
criterion = nn.MSELoss()
optimizer = optim.Adam(generator.parameters(), lr=0.0002)

epochs = 20
for epoch in range(epochs):
    for _ in range(len(train_loader)):
        z = torch.randn(64, latent_dim)
        labels = torch.randint(0, 10, (64,))
        fake_images = generator(z, labels)
        loss = criterion(fake_images, torch.randn_like(fake_images))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")


Epoch 1/20, Loss: 1.0194
Epoch 2/20, Loss: 1.0031
Epoch 3/20, Loss: 1.0033
Epoch 4/20, Loss: 1.0048
Epoch 5/20, Loss: 0.9999
Epoch 6/20, Loss: 1.0101
Epoch 7/20, Loss: 1.0054
Epoch 8/20, Loss: 0.9972
Epoch 9/20, Loss: 1.0015
Epoch 10/20, Loss: 1.0059
Epoch 11/20, Loss: 1.0024
Epoch 12/20, Loss: 0.9895
Epoch 13/20, Loss: 1.0018
Epoch 14/20, Loss: 0.9996
Epoch 15/20, Loss: 0.9990
Epoch 16/20, Loss: 1.0090
Epoch 17/20, Loss: 1.0030
Epoch 18/20, Loss: 1.0009
Epoch 19/20, Loss: 1.0047
Epoch 20/20, Loss: 0.9932


In [5]:
os.makedirs('models', exist_ok=True)
torch.save(generator.state_dict(), 'models/digit_generator.pth')
