In [1]:
# ======================================================
# 1. IMPORT LIBRARIES
# ======================================================
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
import os
import numpy as np
from collections import Counter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ======================================================
# 2. USER INPUT PARAMETERS
# ======================================================
dataset_choice = 'mnist'        # 'mnist' or 'fashion'
epochs = 30
batch_size = 128
noise_dim = 100
lr_G = 0.0002
lr_D = 0.0001
save_interval = 5

# ======================================================
# 3. DATASET LOADING
# ======================================================
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

if dataset_choice == 'mnist':
    dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
elif dataset_choice == 'fashion':
    dataset = datasets.FashionMNIST('./data', train=True, download=True, transform=transform)
else:
    raise ValueError("Invalid dataset choice")

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
img_shape = (1, 28, 28)

# ======================================================
# 4. OUTPUT FOLDERS
# ======================================================
os.makedirs("generated_samples", exist_ok=True)
os.makedirs("final_generated_images", exist_ok=True)

# ======================================================
# 5. GENERATOR
# ======================================================
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        return img.view(img.size(0), *img_shape)

# ======================================================
# 6. DISCRIMINATOR
# ======================================================
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img = img.view(img.size(0), -1)
        return self.model(img)

G = Generator().to(device)
D = Discriminator().to(device)

# ======================================================
# 7. LOSS & OPTIMIZERS
# ======================================================
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=lr_G, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=lr_D, betas=(0.5, 0.999))

# ======================================================
# 8. TRAINING LOOP
# ======================================================
for epoch in range(1, epochs + 1):
    D_loss_total, G_loss_total = 0.0, 0.0
    correct, total = 0, 0

    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.to(device)
        batch = real_imgs.size(0)

        # Label smoothing
        real_labels = torch.full((batch, 1), 0.9).to(device)
        fake_labels = torch.zeros(batch, 1).to(device)

        # --------------------
        # Train Discriminator
        # --------------------
        optimizer_D.zero_grad()

        real_loss = criterion(D(real_imgs), real_labels)

        z = torch.randn(batch, noise_dim).to(device)
        fake_imgs = G(z)
        fake_loss = criterion(D(fake_imgs.detach()), fake_labels)

        D_loss = real_loss + fake_loss
        D_loss.backward()
        optimizer_D.step()

        # Accuracy
        preds_real = (D(real_imgs) > 0.5).float()
        preds_fake = (D(fake_imgs.detach()) < 0.5).float()
        correct += preds_real.sum().item() + preds_fake.sum().item()
        total += batch * 2

        # --------------------
        # Train Generator (TWICE, FIXED)
        # --------------------
        for _ in range(2):
            optimizer_G.zero_grad()

            z = torch.randn(batch, noise_dim).to(device)   # NEW noise
            fake_imgs = G(z)                               # NEW graph

            G_loss = criterion(D(fake_imgs), real_labels)
            G_loss.backward()
            optimizer_G.step()

        D_loss_total += D_loss.item()
        G_loss_total += G_loss.item()

    D_acc = (correct / total) * 100

    print(f"Epoch {epoch}/{epochs} | "
          f"D_loss: {D_loss_total/len(dataloader):.3f} | "
          f"D_acc: {D_acc:.2f}% | "
          f"G_loss: {G_loss_total/len(dataloader):.3f}")

    # Save generated samples
    if epoch % save_interval == 0:
        utils.save_image(fake_imgs[:25],
                         f"generated_samples/epoch_{epoch:02d}.png",
                         nrow=5,
                         normalize=True)

# ======================================================
# 9. GENERATE FINAL 100 IMAGES
# ======================================================
z = torch.randn(100, noise_dim).to(device)
final_images = G(z)

for i in range(100):
    utils.save_image(final_images[i],
                     f"final_generated_images/img_{i}.png",
                     normalize=True)

# ======================================================
# 10. SIMPLE CLASSIFIER
# ======================================================
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 128),
            nn.ReLU(),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        return self.net(x)

classifier = Classifier().to(device)
optimizer_C = optim.Adam(classifier.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(3):
    for imgs, labels in dataloader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer_C.zero_grad()
        loss = loss_fn(classifier(imgs), labels)
        loss.backward()
        optimizer_C.step()

# ======================================================
# 11. LABEL PREDICTION
# ======================================================
with torch.no_grad():
    preds = classifier(final_images).argmax(dim=1).cpu().numpy()

label_counts = Counter(preds)

print("\nLabel Distribution of Generated Images:")
for label, count in sorted(label_counts.items()):
    print(f"Label {label}: {count}")

Using device: cuda


100%|██████████| 9.91M/9.91M [00:01<00:00, 5.02MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 132kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.22MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 13.7MB/s]


Epoch 1/30 | D_loss: 1.414 | D_acc: 50.92% | G_loss: 0.679
Epoch 2/30 | D_loss: 1.390 | D_acc: 51.07% | G_loss: 0.731
Epoch 3/30 | D_loss: 1.388 | D_acc: 54.79% | G_loss: 0.743
Epoch 4/30 | D_loss: 1.386 | D_acc: 58.38% | G_loss: 0.756
Epoch 5/30 | D_loss: 1.381 | D_acc: 62.63% | G_loss: 0.769
Epoch 6/30 | D_loss: 1.361 | D_acc: 63.97% | G_loss: 0.787
Epoch 7/30 | D_loss: 1.360 | D_acc: 65.76% | G_loss: 0.804
Epoch 8/30 | D_loss: 1.361 | D_acc: 65.80% | G_loss: 0.803
Epoch 9/30 | D_loss: 1.353 | D_acc: 68.81% | G_loss: 0.822
Epoch 10/30 | D_loss: 1.336 | D_acc: 68.06% | G_loss: 0.838
Epoch 11/30 | D_loss: 1.348 | D_acc: 69.66% | G_loss: 0.835
Epoch 12/30 | D_loss: 1.313 | D_acc: 70.77% | G_loss: 0.878
Epoch 13/30 | D_loss: 1.295 | D_acc: 74.19% | G_loss: 0.923
Epoch 14/30 | D_loss: 1.314 | D_acc: 72.77% | G_loss: 0.897
Epoch 15/30 | D_loss: 1.250 | D_acc: 75.65% | G_loss: 0.973
Epoch 16/30 | D_loss: 1.275 | D_acc: 76.65% | G_loss: 0.965
Epoch 17/30 | D_loss: 1.212 | D_acc: 77.17% | G_l

In [2]:
!zip -r /content/final_generated_images.zip /content/final_generated_images


  adding: content/final_generated_images/ (stored 0%)
  adding: content/final_generated_images/img_49.png (stored 0%)
  adding: content/final_generated_images/img_59.png (stored 0%)
  adding: content/final_generated_images/img_46.png (stored 0%)
  adding: content/final_generated_images/img_90.png (stored 0%)
  adding: content/final_generated_images/img_53.png (stored 0%)
  adding: content/final_generated_images/img_35.png (stored 0%)
  adding: content/final_generated_images/img_0.png (stored 0%)
  adding: content/final_generated_images/img_27.png (stored 0%)
  adding: content/final_generated_images/img_73.png (stored 0%)
  adding: content/final_generated_images/img_64.png (stored 0%)
  adding: content/final_generated_images/img_80.png (stored 0%)
  adding: content/final_generated_images/img_22.png (stored 0%)
  adding: content/final_generated_images/img_61.png (stored 0%)
  adding: content/final_generated_images/img_15.png (stored 0%)
  adding: content/final_generated_images/img_70.png

In [3]:
!zip -r /content/generated_samples.zip /content/generated_samples

  adding: content/generated_samples/ (stored 0%)
  adding: content/generated_samples/epoch_15.png (stored 0%)
  adding: content/generated_samples/epoch_10.png (deflated 1%)
  adding: content/generated_samples/epoch_30.png (deflated 5%)
  adding: content/generated_samples/epoch_25.png (deflated 14%)
  adding: content/generated_samples/epoch_20.png (deflated 0%)
  adding: content/generated_samples/epoch_05.png (deflated 2%)
