In [3]:
!pip install torch torchvision matplotlib numpy scikit-learn




In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import numpy as np
import os


In [5]:
dataset_choice = input("Enter dataset (mnist / fashion): ").lower()
epochs = int(input("Enter number of epochs: "))
batch_size = int(input("Enter batch size: "))
noise_dim = int(input("Enter noise dimension: "))
learning_rate = float(input("Enter learning rate: "))
save_interval = int(input("Enter save interval: "))


Enter dataset (mnist / fashion): mnist
Enter number of epochs: 50
Enter batch size: 64
Enter noise dimension: 100
Enter learning rate: 0.0002
Enter save interval: 5


In [6]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])   # [-1, 1]
])

if dataset_choice == "mnist":
    dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
elif dataset_choice == "fashion":
    dataset = datasets.FashionMNIST(root="./data", train=True, download=True, transform=transform)
else:
    raise ValueError("Invalid dataset choice")

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)


100%|██████████| 9.91M/9.91M [00:00<00:00, 19.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 492kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.54MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.9MB/s]


In [7]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 784),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        return img.view(img.size(0), 1, 28, 28)


In [8]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)


In [9]:
device = "cuda"
G = Generator().to(device)
D = Discriminator().to(device)

criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(D.parameters(), lr=learning_rate)


In [10]:
os.makedirs("generated_samples", exist_ok=True)
os.makedirs("final_generated_images", exist_ok=True)


In [13]:
for epoch in range(1, epochs + 1):
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.to(device)
        batch = real_imgs.size(0)

        valid = torch.ones(batch, 1).to(device)
        fake = torch.zeros(batch, 1).to(device)

        # ---- Train Discriminator ----
        z = torch.randn(batch, noise_dim).to(device)
        fake_imgs = G(z)

        real_loss = criterion(D(real_imgs), valid)
        fake_loss = criterion(D(fake_imgs.detach()), fake)
        d_loss = real_loss + fake_loss

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # ---- Train Generator ----
        g_loss = criterion(D(fake_imgs), valid)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    print(f"Epoch {epoch}/{epochs} | D_loss: {d_loss.item():.2f} | G_loss: {g_loss.item():.2f}")

    if epoch % save_interval == 0:
        save_image(fake_imgs[:25], f"generated_samples/epoch_{epoch:02d}.png", nrow=5, normalize=True)


Epoch 1/50 | D_loss: 0.66 | G_loss: 2.12
Epoch 2/50 | D_loss: 0.86 | G_loss: 1.51
Epoch 3/50 | D_loss: 0.40 | G_loss: 4.00
Epoch 4/50 | D_loss: 0.78 | G_loss: 2.17
Epoch 5/50 | D_loss: 0.41 | G_loss: 3.38
Epoch 6/50 | D_loss: 0.90 | G_loss: 1.68
Epoch 7/50 | D_loss: 0.50 | G_loss: 2.19
Epoch 8/50 | D_loss: 0.50 | G_loss: 2.26
Epoch 9/50 | D_loss: 0.67 | G_loss: 2.48
Epoch 10/50 | D_loss: 0.60 | G_loss: 2.27
Epoch 11/50 | D_loss: 0.74 | G_loss: 2.60
Epoch 12/50 | D_loss: 0.76 | G_loss: 2.44
Epoch 13/50 | D_loss: 0.67 | G_loss: 2.30
Epoch 14/50 | D_loss: 0.95 | G_loss: 1.48
Epoch 15/50 | D_loss: 0.43 | G_loss: 2.63
Epoch 16/50 | D_loss: 0.62 | G_loss: 1.90
Epoch 17/50 | D_loss: 0.76 | G_loss: 2.01
Epoch 18/50 | D_loss: 0.58 | G_loss: 2.67
Epoch 19/50 | D_loss: 0.66 | G_loss: 2.39
Epoch 20/50 | D_loss: 0.79 | G_loss: 2.27
Epoch 21/50 | D_loss: 0.62 | G_loss: 3.32
Epoch 22/50 | D_loss: 0.63 | G_loss: 1.95
Epoch 23/50 | D_loss: 0.58 | G_loss: 1.81
Epoch 24/50 | D_loss: 0.80 | G_loss: 2.08
E

In [14]:
z = torch.randn(100, noise_dim).to(device)
final_imgs = G(z)

for i in range(100):
    save_image(final_imgs[i], f"final_generated_images/img_{i+1}.png", normalize=True)


In [16]:
from torchvision.models import resnet18
from collections import Counter

classifier = resnet18(pretrained=True)
classifier.fc = nn.Linear(512, 10)
classifier.to(device) # Move classifier to the same device as the input
classifier.eval()

labels = []
for img in final_imgs:
    img = img.repeat(3, 1, 1).unsqueeze(0)
    pred = classifier(img).argmax().item()
    labels.append(pred)

print("Label Distribution:", Counter(labels))



Label Distribution: Counter({2: 82, 8: 9, 3: 5, 5: 3, 9: 1})


In [17]:
!ls


data  final_generated_images  generated_samples  sample_data


In [18]:
!zip -r generated_samples.zip generated_samples
!zip -r final_generated_images.zip final_generated_images


  adding: generated_samples/ (stored 0%)
  adding: generated_samples/epoch_40.png (deflated 2%)
  adding: generated_samples/epoch_05.png (deflated 3%)
  adding: generated_samples/epoch_15.png (deflated 2%)
  adding: generated_samples/epoch_35.png (deflated 2%)
  adding: generated_samples/epoch_30.png (deflated 3%)
  adding: generated_samples/epoch_45.png (deflated 3%)
  adding: generated_samples/epoch_25.png (deflated 2%)
  adding: generated_samples/epoch_20.png (deflated 3%)
  adding: generated_samples/epoch_10.png (deflated 3%)
  adding: generated_samples/epoch_50.png (deflated 3%)
  adding: final_generated_images/ (stored 0%)
  adding: final_generated_images/img_81.png (stored 0%)
  adding: final_generated_images/img_11.png (stored 0%)
  adding: final_generated_images/img_58.png (stored 0%)
  adding: final_generated_images/img_74.png (stored 0%)
  adding: final_generated_images/img_26.png (stored 0%)
  adding: final_generated_images/img_86.png (stored 0%)
  adding: final_generated_i