In [1]:
# importing necessary libraires
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils

In [2]:
# Hyperparameters
latent_dim = 100
image_size = 28*28
batch_size = 128
lr = 0.0002
epochs = 50
sample_dir = 'samples'

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
class generator(nn.Module):
  def __init__(self, latent_dimension):
    super().__init__()

    self.generator = nn.Sequential(
        nn.Linear(latent_dimension, 128*7*7),
        nn.BatchNorm1d(128*7*7),
        nn.ReLU(inplace=True),
        nn.Unflatten(1, (128, 7, 7)),

        nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
        nn.ReLU(inplace=True),

        nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1),
        nn.Tanh()
    )

  def forward(self,z):
    x = self.generator(z)
    return x #.view(-1,1,28,28)



In [5]:
def lol():
  latent_dim = 100
  batch_size = 2

  z = torch.randn(batch_size, latent_dim)
  model = generator(100)
  output = model(z)

  return output.shape

print(lol())

torch.Size([2, 1, 28, 28])


In [6]:
class discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, 3),       # → (batch,32,26,26)
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),           # → (batch,32,13,13)

            nn.Conv2d(32, 64, 3),      # → (batch,64,11,11)
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),           # → (batch,64,5,5)

            nn.Flatten(),              # → (batch,64*5*5)
            nn.Linear(64 * 5 * 5, 1),
            nn.Sigmoid()               # output in [0,1]
        )

    def forward(self, x):
        return self.model(x)

In [7]:
def lol():
  x = torch.randn(3,1,28,28)
  model = discriminator()
  out = model(x)
  return out

print(lol())

tensor([[0.4977],
        [0.4202],
        [0.3982]], grad_fn=<SigmoidBackward0>)


In [8]:
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

def label_as_zero(_):
    return torch.tensor(0, dtype=torch.long)

mnist = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
    target_transform=label_as_zero
    )

In [9]:
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

In [10]:
generator = generator(100).to(device)
discriminator = discriminator().to(device)

criterion = nn.BCELoss()
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr)


In [11]:
images, labels = next(iter(dataloader))

# 3. Inspect shapes
print(images.shape)  # torch.Size([64, 1, 28, 28])
print(labels.shape)  # torch.Size([64])

torch.Size([128, 1, 28, 28])
torch.Size([128])


In [13]:
sample_dir = 'samples'
# Create the samples directory if it doesn't exist
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)


for epoch in range(epochs):
    for i, (real_imgs, _) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
        batch_size = real_imgs.size(0)

        # === Train Discriminator ===
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z)

        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)

        d_real = discriminator(real_imgs)
        d_fake = discriminator(fake_imgs.detach())

        loss_real = criterion(d_real, real_labels)
        loss_fake = criterion(d_fake, fake_labels)
        d_loss = loss_real + loss_fake

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # === Train Generator ===
        z = torch.randn(batch_size, latent_dim).to(device)
        fake_imgs = generator(z)
        g_loss = criterion(discriminator(fake_imgs), real_labels)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

        if i % 200 == 0:
            print(f"Epoch [{epoch+1}/{epochs}] Batch {i}/{len(dataloader)} \
                  Loss D: {d_loss.item():.4f}, Loss G: {g_loss.item():.4f}")

    # Save samples every epoch
    vutils.save_image(fake_imgs[:64], os.path.join(sample_dir, f"epoch_{epoch+1}.png"), normalize=True, nrow=8)

Epoch [1/50] Batch 0/469                   Loss D: 0.0011, Loss G: 6.9295
Epoch [1/50] Batch 200/469                   Loss D: 0.0008, Loss G: 7.1587
Epoch [1/50] Batch 400/469                   Loss D: 0.0007, Loss G: 7.3310
Epoch [2/50] Batch 0/469                   Loss D: 0.0006, Loss G: 7.4018
Epoch [2/50] Batch 200/469                   Loss D: 0.0005, Loss G: 7.5875
Epoch [2/50] Batch 400/469                   Loss D: 0.0007, Loss G: 7.7710
Epoch [3/50] Batch 0/469                   Loss D: 0.0005, Loss G: 7.8144
Epoch [3/50] Batch 200/469                   Loss D: 0.0004, Loss G: 7.9753
Epoch [3/50] Batch 400/469                   Loss D: 0.0003, Loss G: 8.1224
Epoch [4/50] Batch 0/469                   Loss D: 0.0003, Loss G: 8.1787
Epoch [4/50] Batch 200/469                   Loss D: 0.0003, Loss G: 8.3218
Epoch [4/50] Batch 400/469                   Loss D: 0.0006, Loss G: 8.4875
Epoch [5/50] Batch 0/469                   Loss D: 0.0002, Loss G: 8.5182
Epoch [5/50] Batch 200

KeyboardInterrupt: 