In [None]:
# google colab link: https://colab.research.google.com/drive/1SG9QdlxYQ3CIsR9rPTN4aRj_xRSxHVZt?usp=sharing
# streamlit website link: https://metiresponse3-vn6vhxwhfbpoqmj8vf8fzz.streamlit.app/
import torch # type: ignore
import torch.nn as nn # type: ignore
import torch.optim as optim # type: ignore
from torchvision import datasets, transforms # type: ignore
from torch.utils.data import DataLoader # type: ignore
import torch.nn.functional as F # type: ignore

# Compute device: use CUDA when PyTorch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Training hyperparameters
batch_size = 128       # samples per mini-batch drawn from the loader
latent_dim = 100       # length of the noise vector fed to the generator
num_epochs = 30        # full passes over the training split
learning_rate = 2e-4   # step size handed to the Adam optimizers

# MNIST preprocessing: ToTensor converts each image to a float tensor, then
# Normalize applies (x - 0.5) / 0.5 so pixel values end up in [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# Training split, downloaded into ./data when it is not already present.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

# Reshuffled mini-batches of (image, label) pairs.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Generator network
class Generator(nn.Module):
    """Conditional generator: maps a noise vector plus a class label to an image.

    The label is embedded, concatenated with the noise along the feature
    axis, and pushed through a fully connected stack that ends in ``Tanh``,
    so every output value lies in [-1, 1].

    Args:
        noise_dim: Length of the noise vector. When ``None`` (the default),
            the module-level ``latent_dim`` is read at construction time,
            which keeps a bare ``Generator()`` call behaving as before.
        num_classes: Number of distinct labels; also used as the width of
            the label embedding.
        img_shape: Shape of one generated image, excluding the batch axis.
            The default ``(1, 28, 28)`` gives 784 output features.
    """

    def __init__(self, noise_dim=None, num_classes=10, img_shape=(1, 28, 28)):
        super().__init__()
        if noise_dim is None:
            # Backward-compatible fallback to the script-level hyperparameter.
            noise_dim = latent_dim
        self.img_shape = tuple(img_shape)
        pixels_per_image = torch.Size(self.img_shape).numel()

        # For conditioning on the class label.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Attribute names and layer order are kept so that state-dict keys
        # ("label_emb.*", "model.0.*", "model.2.*", ...) stay unchanged.
        self.model = nn.Sequential(
            nn.Linear(noise_dim + num_classes, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, pixels_per_image),
            nn.Tanh(),  # Output range [-1, 1]
        )

    def forward(self, noise, labels):
        """Generate one image per (noise, label) pair.

        Args:
            noise: Float tensor of shape ``(batch, noise_dim)``.
            labels: Integer tensor of shape ``(batch,)`` holding class indices.

        Returns:
            Tensor of shape ``(batch, *img_shape)``.
        """
        label_vec = self.label_emb(labels)
        joint_input = torch.cat([noise, label_vec], dim=1)
        flat_img = self.model(joint_input)
        return flat_img.view(-1, *self.img_shape)

# Discriminator network
class Discriminator(nn.Module):
    """Conditional discriminator: scores an (image, class label) pair.

    The image is flattened, concatenated with an embedding of its label, and
    passed through a fully connected stack ending in ``Sigmoid``, so the
    output is a value in (0, 1) per sample.

    Args:
        num_classes: Number of distinct labels; also used as the width of
            the label embedding.
        img_shape: Shape of one input image, excluding the batch axis.
            The default ``(1, 28, 28)`` gives 784 flattened features.
    """

    def __init__(self, num_classes=10, img_shape=(1, 28, 28)):
        super().__init__()
        pixels_per_image = torch.Size(tuple(img_shape)).numel()

        self.label_emb = nn.Embedding(num_classes, num_classes)

        # Attribute names and layer order are kept so that state-dict keys
        # ("label_emb.*", "model.0.*", "model.2.*", ...) stay unchanged.
        self.model = nn.Sequential(
            nn.Linear(pixels_per_image + num_classes, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img, labels):
        """Score each image in the batch given its label.

        Args:
            img: Tensor whose first axis is the batch and whose remaining
                axes hold ``prod(img_shape)`` values per sample.
            labels: Integer tensor of shape ``(batch,)`` holding class indices.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in (0, 1).
        """
        label_vec = self.label_emb(labels)
        # reshape (not view) so non-contiguous batches, e.g. transposed
        # images, are flattened instead of raising a RuntimeError.
        flat_img = img.reshape(img.size(0), -1)
        joint_input = torch.cat([flat_img, label_vec], dim=1)
        return self.model(joint_input)

# Build both networks and move them to the selected device before the
# optimizers are created from their parameters.
generator = Generator()
generator = generator.to(device)
discriminator = Discriminator()
discriminator = discriminator.to(device)

# One Adam optimizer per network, sharing the same step size and betas.
adam_betas = (0.5, 0.999)
optimizer_G = optim.Adam(
    params=generator.parameters(),
    lr=learning_rate,
    betas=adam_betas,
)
optimizer_D = optim.Adam(
    params=discriminator.parameters(),
    lr=learning_rate,
    betas=adam_betas,
)

# Binary cross-entropy between the discriminator's sigmoid output and the
# 0/1 target tensors built in the training loop.
adversarial_loss = nn.BCELoss()

# Training loop: for every mini-batch, take one generator step followed by
# one discriminator step.
num_batches = len(train_loader)  # constant across epochs; used for logging
for epoch in range(num_epochs):
    for batch_idx, (imgs, labels) in enumerate(train_loader):

        # The final batch of an epoch can be smaller than batch_size, so
        # size every per-batch tensor from the data itself.
        batch_size_curr = imgs.size(0)
        real_imgs = imgs.to(device)
        labels = labels.to(device)

        # Targets for the BCE loss: 1 for "real", 0 for "fake".
        valid = torch.ones(batch_size_curr, 1, device=device)
        fake = torch.zeros(batch_size_curr, 1, device=device)

        # ---- Train Generator ----
        optimizer_G.zero_grad()

        # Sample noise and random target digits for this batch.
        z = torch.randn(batch_size_curr, latent_dim, device=device)
        gen_labels = torch.randint(0, 10, (batch_size_curr,), device=device)

        # The generator is rewarded when the discriminator scores its
        # images as real, hence the `valid` target.
        gen_imgs = generator(z, gen_labels)
        g_loss = adversarial_loss(discriminator(gen_imgs, gen_labels), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---- Train Discriminator ----
        # zero_grad also discards the discriminator gradients that
        # g_loss.backward() accumulated above.
        optimizer_D.zero_grad()

        real_loss = adversarial_loss(discriminator(real_imgs, labels), valid)
        # detach() keeps this step from backpropagating into the generator.
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach(), gen_labels), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # Log every 400th batch. Adjacent f-string literals are used instead
        # of a backslash continuation inside the string, which embedded the
        # source indentation in the printed message.
        if batch_idx % 400 == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] "
                f"Batch {batch_idx}/{num_batches} "
                f"Loss D: {d_loss.item():.4f}, loss G: {g_loss.item():.4f}"
            )

# Write the trained generator's parameters (its state dict, not the whole
# module object) to generator.pth in the working directory.
generator_weights = generator.state_dict()
torch.save(generator_weights, "generator.pth")


100%|██████████| 9.91M/9.91M [00:00<00:00, 20.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 614kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 5.60MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 7.86MB/s]


Epoch [0/30] Batch 0/469                   Loss D: 0.7108, loss G: 0.6997
Epoch [0/30] Batch 400/469                   Loss D: 0.5014, loss G: 0.6275
Epoch [1/30] Batch 0/469                   Loss D: 0.3431, loss G: 1.3717
Epoch [1/30] Batch 400/469                   Loss D: 0.2785, loss G: 1.9023
Epoch [2/30] Batch 0/469                   Loss D: 0.1412, loss G: 1.9447
Epoch [2/30] Batch 400/469                   Loss D: 0.2188, loss G: 1.9723
Epoch [3/30] Batch 0/469                   Loss D: 0.1435, loss G: 1.7376
Epoch [3/30] Batch 400/469                   Loss D: 0.1873, loss G: 2.2029
Epoch [4/30] Batch 0/469                   Loss D: 0.2992, loss G: 1.4917
Epoch [4/30] Batch 400/469                   Loss D: 0.2373, loss G: 2.5755
Epoch [5/30] Batch 0/469                   Loss D: 0.2341, loss G: 1.3738
Epoch [5/30] Batch 400/469                   Loss D: 0.1212, loss G: 2.5346
Epoch [6/30] Batch 0/469                   Loss D: 0.2406, loss G: 2.0199
Epoch [6/30] Batch 400/469