In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm

In [2]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
latent_dim = 100
img_size = 28
channels = 1
batch_size = 64
num_epochs_gan = 20
num_epochs_classifier = 5
learning_rate = 0.0002

# Image shape
img_shape = (channels, img_size, img_size)

# GAN Components: Generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, int(torch.prod(torch.tensor(img_shape)))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

# GAN Components: Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(int(torch.prod(torch.tensor(img_shape))), 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1)
        validity = self.model(img_flat)
        return validity

# MNIST Classifier
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 10),
        )

    def forward(self, x):
        return self.model(x)

# Data loader
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

mnist = datasets.MNIST(root="./data", train=True, download=True, transform=transform)
data_loader = DataLoader(mnist, batch_size=batch_size, shuffle=True)

# Initialize GAN models
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Loss function
adversarial_loss = nn.BCELoss()

# Optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)

# Training GAN
for epoch in tqdm(range(num_epochs_gan)):
    for i, (imgs, _) in enumerate(data_loader):
        # Prepare real and fake inputs
        real_imgs = imgs.to(device)
        batch_size = real_imgs.size(0)
        valid = torch.ones((batch_size, 1), device=device, dtype=torch.float32)
        fake = torch.zeros((batch_size, 1), device=device, dtype=torch.float32)

        # Train Generator
        optimizer_G.zero_grad()
        z = torch.randn(batch_size, latent_dim, device=device)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # Train Discriminator
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

    print(f"[Epoch {epoch+1}/{num_epochs_gan}] Loss D: {d_loss.item()}, Loss G: {g_loss.item()}")

# Save some generated images
z = torch.randn(64, latent_dim, device=device)
generated_imgs = generator(z)
save_image(generated_imgs.data, "generated_images.png", nrow=8, normalize=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 21.5MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 617kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 5.64MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 2.91MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



  5%|▌         | 1/20 [00:43<13:50, 43.73s/it]

[Epoch 1/20] Loss D: 0.09871074557304382, Loss G: 7.507387638092041


 10%|█         | 2/20 [01:21<12:09, 40.51s/it]

[Epoch 2/20] Loss D: 0.041166190057992935, Loss G: 7.423006057739258


 15%|█▌        | 3/20 [01:59<11:08, 39.33s/it]

[Epoch 3/20] Loss D: 0.07913760095834732, Loss G: 8.655781745910645


 20%|██        | 4/20 [02:37<10:17, 38.58s/it]

[Epoch 4/20] Loss D: 0.07188107818365097, Loss G: 3.3360321521759033


 25%|██▌       | 5/20 [03:14<09:32, 38.15s/it]

[Epoch 5/20] Loss D: 0.013869359157979488, Loss G: 8.017868041992188


 30%|███       | 6/20 [03:52<08:51, 37.99s/it]

[Epoch 6/20] Loss D: 0.10452797263860703, Loss G: 5.935298442840576


 35%|███▌      | 7/20 [04:29<08:12, 37.86s/it]

[Epoch 7/20] Loss D: 0.1394721120595932, Loss G: 2.8424363136291504


 40%|████      | 8/20 [05:06<07:29, 37.45s/it]

[Epoch 8/20] Loss D: 0.041789937764406204, Loss G: 5.946020603179932


 45%|████▌     | 9/20 [05:43<06:50, 37.29s/it]

[Epoch 9/20] Loss D: 0.050974104553461075, Loss G: 4.458655834197998


 50%|█████     | 10/20 [06:20<06:11, 37.18s/it]

[Epoch 10/20] Loss D: 0.051654741168022156, Loss G: 3.6465940475463867


 55%|█████▌    | 11/20 [06:57<05:33, 37.05s/it]

[Epoch 11/20] Loss D: 0.07393762469291687, Loss G: 5.578413009643555


 60%|██████    | 12/20 [07:34<04:57, 37.18s/it]

[Epoch 12/20] Loss D: 0.04107266291975975, Loss G: 3.8037264347076416


 65%|██████▌   | 13/20 [08:11<04:19, 37.11s/it]

[Epoch 13/20] Loss D: 0.09571036696434021, Loss G: 3.7022757530212402


 70%|███████   | 14/20 [08:47<03:40, 36.72s/it]

[Epoch 14/20] Loss D: 0.1050112247467041, Loss G: 3.7132811546325684


 75%|███████▌  | 15/20 [09:24<03:03, 36.71s/it]

[Epoch 15/20] Loss D: 0.31433606147766113, Loss G: 2.6017470359802246


 80%|████████  | 16/20 [10:01<02:27, 36.86s/it]

[Epoch 16/20] Loss D: 0.15417727828025818, Loss G: 2.9456515312194824


 85%|████████▌ | 17/20 [10:39<01:51, 37.12s/it]

[Epoch 17/20] Loss D: 0.14751377701759338, Loss G: 3.645857810974121


 90%|█████████ | 18/20 [11:16<01:14, 37.27s/it]

[Epoch 18/20] Loss D: 0.1348954439163208, Loss G: 4.629648685455322


 95%|█████████▌| 19/20 [11:54<00:37, 37.33s/it]

[Epoch 19/20] Loss D: 0.14929147064685822, Loss G: 3.368868350982666


100%|██████████| 20/20 [12:32<00:00, 37.62s/it]

[Epoch 20/20] Loss D: 0.20883461833000183, Loss G: 3.9148964881896973





In [3]:
# Train MNIST Classifier
classifier = Classifier().to(device)
criterion = nn.CrossEntropyLoss()
optimizer_C = optim.Adam(classifier.parameters(), lr=learning_rate)

# Prepare dataset combining real and synthetic images
synthetic_imgs = generator(torch.randn(len(mnist), latent_dim, device=device)).detach()
synthetic_labels = torch.randint(0, 10, (len(mnist),), device=device)  # Random labels for GAN outputs
combined_data = torch.cat([mnist.data.unsqueeze(1).float() / 255.0, synthetic_imgs])
combined_labels = torch.cat([mnist.targets, synthetic_labels])

combined_dataset = torch.utils.data.TensorDataset(combined_data, combined_labels)
combined_loader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True)

# Training the classifier
for epoch in tqdm(range(20)):
    for imgs, labels in combined_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer_C.zero_grad()
        outputs = classifier(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer_C.step()

    print(f"[Epoch {epoch+1}/{num_epochs_classifier}] Loss: {loss.item()}")

# Evaluate classifier on real MNIST test set
test_mnist = datasets.MNIST(root="./data", train=False, transform=transform)
test_loader = DataLoader(test_mnist, batch_size=batch_size, shuffle=False)

correct = 0
total = 0
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = classifier(imgs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Classifier accuracy on real MNIST test set: {100 * correct / total:.2f}%")

  5%|▌         | 1/20 [00:09<03:03,  9.64s/it]

[Epoch 1/5] Loss: 1.3152340650558472


 10%|█         | 2/20 [00:19<02:58,  9.92s/it]

[Epoch 2/5] Loss: 1.0518131256103516


 15%|█▌        | 3/20 [00:30<02:52, 10.12s/it]

[Epoch 3/5] Loss: 1.3984044790267944


 20%|██        | 4/20 [00:40<02:45, 10.33s/it]

[Epoch 4/5] Loss: 1.5138798952102661


 25%|██▌       | 5/20 [00:50<02:32, 10.20s/it]

[Epoch 5/5] Loss: 1.024999737739563


 30%|███       | 6/20 [01:01<02:25, 10.41s/it]

[Epoch 6/5] Loss: 0.8751224279403687


 35%|███▌      | 7/20 [01:12<02:17, 10.56s/it]

[Epoch 7/5] Loss: 1.2639003992080688


 40%|████      | 8/20 [01:23<02:08, 10.68s/it]

[Epoch 8/5] Loss: 1.1872507333755493


 45%|████▌     | 9/20 [01:34<01:58, 10.79s/it]

[Epoch 9/5] Loss: 0.9603437185287476


 50%|█████     | 10/20 [01:45<01:48, 10.82s/it]

[Epoch 10/5] Loss: 0.9733055233955383


 55%|█████▌    | 11/20 [01:56<01:37, 10.83s/it]

[Epoch 11/5] Loss: 0.9468473792076111


 60%|██████    | 12/20 [02:06<01:25, 10.72s/it]

[Epoch 12/5] Loss: 1.166015386581421


 65%|██████▌   | 13/20 [02:17<01:15, 10.81s/it]

[Epoch 13/5] Loss: 1.0195095539093018


 70%|███████   | 14/20 [02:28<01:05, 10.85s/it]

[Epoch 14/5] Loss: 1.0780526399612427


 75%|███████▌  | 15/20 [02:39<00:54, 10.97s/it]

[Epoch 15/5] Loss: 0.9465162754058838


 80%|████████  | 16/20 [02:51<00:44, 11.07s/it]

[Epoch 16/5] Loss: 0.9808140397071838


 85%|████████▌ | 17/20 [03:02<00:33, 11.12s/it]

[Epoch 17/5] Loss: 1.1669838428497314


 90%|█████████ | 18/20 [03:13<00:22, 11.16s/it]

[Epoch 18/5] Loss: 1.297971248626709


 95%|█████████▌| 19/20 [03:25<00:11, 11.25s/it]

[Epoch 19/5] Loss: 1.0142955780029297


100%|██████████| 20/20 [03:36<00:00, 10.83s/it]

[Epoch 20/5] Loss: 1.364072561264038





Classifier accuracy on real MNIST test set: 16.87%
