In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import torchvision.utils as vutils


In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print("using device: ", device)

using device:  mps


In [3]:
# Hyperparameters

batch_size = 128
image_size=28
nz = 100
num_epochs=10
learning_rate = 0.0002
beta1 = 0.5

# data transformation

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

#load the MNIST dataset

dataset = torchvision.datasets.MNIST(root='./data',
                                     train=True,
                                     transform=transform,
                                     download=True)

dataloader = DataLoader(dataset=dataset,batch_size=batch_size,shuffle=True)

In [4]:
# Let's now defien the network

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.main=nn.Sequential(
            # Input size is 1x28x28
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=4, stride=2, padding=1) , # 64x14x14
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            
            #seocond layer
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), # 128x7x7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            
            #third layer
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), # 256x4x4
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            
            #fourth layer
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1), # 512x4x4
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            
            #output layer
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=0), # 1x1x1
            nn.Sigmoid()
        )
        
    def forward(self, input):
        output=self.main(input)
        return output.view(-1)
    
# Generator network

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.main=nn.Sequential(
            #Input is nz - going into convolution
            nn.Linear(in_features=nz, out_features=256*7*7),
            nn.BatchNorm1d(256*7*7),
            nn.ReLU(inplace=True), 
            
            # reshape
            nn.Unflatten(dim=1, unflattened_size=(256,7,7)),
            
            # first convose transpose
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), # 128x14x14
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            # second convose transpose
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), # 64x28x28
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            # Third convse layer
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),  # Output: (batch_size, 32, 28, 28)
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            
            # Output layer
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1),  # Output: (batch_size, 1, 28, 28)
            nn.Tanh()  # Output values in [-1, 1]
        )
        
    def forward(self, input):
        output = self.main(input)
        return output
    

# Initialize the networks

netD = Discriminator().to(device)
netG = Generator().to(device)


In [5]:
# Initialize weights
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

netD.apply(weights_init)
netG.apply(weights_init)

# Note that BCE is the minmax loss function under the hood

criterion = nn.BCELoss()
# fixed_noise = torch.randn(64, nz, device=device)

real_label = 1
fake_label = 0

optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta1, 0.999))

In [6]:
# Ensure the directory exists for saving images
os.makedirs('mid_run_samples', exist_ok=True)

# Fixed noise for generating samples
fixed_noise = torch.randn(12, nz, device=device)

# Training loop
for epoch in range(num_epochs):
    # Save generated images at the start of each epoch
    with torch.no_grad():
        fake_images = netG(fixed_noise).detach().cpu()
    vutils.save_image(
        fake_images,
        f'mid_run_samples/output_epoch_{epoch}.png',
        normalize=True,
        nrow=4
    )

    # Progress bar for batches
    for i, data in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
        ############################
        # (1) Update D network
        ############################
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full(
            (b_size,), real_label, dtype=torch.float, device=device
        )
        output = netD(real_cpu)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # Train with all-fake batch
        noise = torch.randn(b_size, nz, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network
        ############################
        for _ in range(3):  # Perform 3 generator iterations so generator and discriminator are more balanced
            netG.zero_grad()
            label.fill_(real_label)  # Fake labels are real for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()
            # Generate new fake data for the next generator iteration
            noise = torch.randn(b_size, nz, device=device)
            fake = netG(noise)

        # Print training stats
        if i % 200 == 0:
            print(
                f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] '
                f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
                f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}'
            )

Epoch 1/10:   0%|          | 1/469 [00:00<05:19,  1.46it/s]

[0/10][0/469] Loss_D: 1.5260 Loss_G: 0.0690 D(x): 0.4970 D(G(z)): 0.4442/0.9360


Epoch 1/10:  43%|████▎     | 201/469 [00:45<01:00,  4.40it/s]

[0/10][200/469] Loss_D: 0.9269 Loss_G: 1.2368 D(x): 0.6276 D(G(z)): 0.3133/0.3353


Epoch 1/10:  86%|████████▌ | 401/469 [01:28<00:15,  4.37it/s]

[0/10][400/469] Loss_D: 0.9809 Loss_G: 1.0987 D(x): 0.6070 D(G(z)): 0.3243/0.3891


Epoch 1/10: 100%|██████████| 469/469 [01:43<00:00,  4.52it/s]
Epoch 2/10:   0%|          | 1/469 [00:00<01:44,  4.47it/s]

[1/10][0/469] Loss_D: 1.0770 Loss_G: 1.1401 D(x): 0.6780 D(G(z)): 0.4472/0.3522


Epoch 2/10:  43%|████▎     | 201/469 [00:44<00:59,  4.53it/s]

[1/10][200/469] Loss_D: 1.1752 Loss_G: 1.1238 D(x): 0.6520 D(G(z)): 0.4787/0.3659


Epoch 2/10:  86%|████████▌ | 401/469 [01:28<00:15,  4.38it/s]

[1/10][400/469] Loss_D: 1.0961 Loss_G: 0.8552 D(x): 0.5128 D(G(z)): 0.2846/0.4630


Epoch 2/10: 100%|██████████| 469/469 [01:43<00:00,  4.51it/s]
Epoch 3/10:   0%|          | 1/469 [00:00<01:46,  4.38it/s]

[2/10][0/469] Loss_D: 1.2664 Loss_G: 0.6407 D(x): 0.4241 D(G(z)): 0.2511/0.5526


Epoch 3/10:  43%|████▎     | 202/469 [00:44<00:57,  4.67it/s]

[2/10][200/469] Loss_D: 1.3116 Loss_G: 1.0966 D(x): 0.6384 D(G(z)): 0.5093/0.3748


Epoch 3/10:  86%|████████▌ | 401/469 [01:28<00:15,  4.38it/s]

[2/10][400/469] Loss_D: 1.3682 Loss_G: 0.8026 D(x): 0.4850 D(G(z)): 0.4198/0.4797


Epoch 3/10: 100%|██████████| 469/469 [01:43<00:00,  4.55it/s]
Epoch 4/10:   0%|          | 1/469 [00:00<01:43,  4.51it/s]

[3/10][0/469] Loss_D: 1.2803 Loss_G: 0.6616 D(x): 0.4177 D(G(z)): 0.2782/0.5464


Epoch 4/10:  43%|████▎     | 201/469 [00:44<01:00,  4.40it/s]

[3/10][200/469] Loss_D: 1.1791 Loss_G: 0.6774 D(x): 0.4739 D(G(z)): 0.3005/0.5362


Epoch 4/10:  86%|████████▌ | 401/469 [01:28<00:15,  4.40it/s]

[3/10][400/469] Loss_D: 1.3372 Loss_G: 0.6874 D(x): 0.4771 D(G(z)): 0.3798/0.5229


Epoch 4/10: 100%|██████████| 469/469 [01:43<00:00,  4.52it/s]
Epoch 5/10:   0%|          | 1/469 [00:00<01:44,  4.48it/s]

[4/10][0/469] Loss_D: 1.2439 Loss_G: 0.5059 D(x): 0.4405 D(G(z)): 0.2950/0.6311


Epoch 5/10:  43%|████▎     | 201/469 [00:44<01:01,  4.37it/s]

[4/10][200/469] Loss_D: 1.2479 Loss_G: 0.8128 D(x): 0.5253 D(G(z)): 0.4125/0.4756


Epoch 5/10:  86%|████████▌ | 401/469 [01:28<00:15,  4.41it/s]

[4/10][400/469] Loss_D: 1.2224 Loss_G: 0.9000 D(x): 0.5643 D(G(z)): 0.4452/0.4460


Epoch 5/10: 100%|██████████| 469/469 [01:43<00:00,  4.54it/s]
Epoch 6/10:   0%|          | 1/469 [00:00<01:43,  4.51it/s]

[5/10][0/469] Loss_D: 1.0695 Loss_G: 0.9076 D(x): 0.6750 D(G(z)): 0.4436/0.4328


Epoch 6/10:  43%|████▎     | 201/469 [00:44<01:02,  4.32it/s]

[5/10][200/469] Loss_D: 1.3218 Loss_G: 0.8426 D(x): 0.5257 D(G(z)): 0.4333/0.4573


Epoch 6/10:  86%|████████▌ | 402/469 [01:28<00:14,  4.69it/s]

[5/10][400/469] Loss_D: 1.2673 Loss_G: 1.0910 D(x): 0.6936 D(G(z)): 0.5596/0.3627


Epoch 6/10: 100%|██████████| 469/469 [01:42<00:00,  4.55it/s]
Epoch 7/10:   0%|          | 1/469 [00:00<01:42,  4.55it/s]

[6/10][0/469] Loss_D: 1.4921 Loss_G: 1.4873 D(x): 0.7663 D(G(z)): 0.6733/0.2507


Epoch 7/10:  43%|████▎     | 202/469 [00:44<00:57,  4.60it/s]

[6/10][200/469] Loss_D: 1.1786 Loss_G: 0.9075 D(x): 0.5619 D(G(z)): 0.4178/0.4275


Epoch 7/10:  86%|████████▌ | 401/469 [01:28<00:15,  4.35it/s]

[6/10][400/469] Loss_D: 1.2208 Loss_G: 1.0456 D(x): 0.6263 D(G(z)): 0.4947/0.3821


Epoch 7/10: 100%|██████████| 469/469 [01:43<00:00,  4.53it/s]
Epoch 8/10:   0%|          | 1/469 [00:00<01:42,  4.55it/s]

[7/10][0/469] Loss_D: 1.2661 Loss_G: 0.8924 D(x): 0.5738 D(G(z)): 0.4644/0.4396


Epoch 8/10:  43%|████▎     | 201/469 [00:44<01:00,  4.45it/s]

[7/10][200/469] Loss_D: 1.0331 Loss_G: 0.9138 D(x): 0.6473 D(G(z)): 0.4119/0.4377


Epoch 8/10:  86%|████████▌ | 402/469 [01:28<00:14,  4.68it/s]

[7/10][400/469] Loss_D: 1.1737 Loss_G: 0.8117 D(x): 0.5444 D(G(z)): 0.3932/0.4751


Epoch 8/10: 100%|██████████| 469/469 [01:43<00:00,  4.55it/s]
Epoch 9/10:   0%|          | 1/469 [00:00<01:44,  4.48it/s]

[8/10][0/469] Loss_D: 1.4433 Loss_G: 0.5797 D(x): 0.3652 D(G(z)): 0.2733/0.5824


Epoch 9/10:  43%|████▎     | 201/469 [00:44<00:59,  4.48it/s]

[8/10][200/469] Loss_D: 1.4778 Loss_G: 0.5023 D(x): 0.3515 D(G(z)): 0.2879/0.6282


Epoch 9/10:  86%|████████▌ | 401/469 [01:28<00:15,  4.37it/s]

[8/10][400/469] Loss_D: 1.0848 Loss_G: 0.8698 D(x): 0.6425 D(G(z)): 0.4444/0.4449


Epoch 9/10: 100%|██████████| 469/469 [01:43<00:00,  4.55it/s]
Epoch 10/10:   0%|          | 1/469 [00:00<01:44,  4.49it/s]

[9/10][0/469] Loss_D: 1.1577 Loss_G: 0.7779 D(x): 0.5069 D(G(z)): 0.3360/0.4852


Epoch 10/10:  43%|████▎     | 201/469 [00:44<01:01,  4.39it/s]

[9/10][200/469] Loss_D: 1.2193 Loss_G: 1.3009 D(x): 0.6103 D(G(z)): 0.4684/0.3078


Epoch 10/10:  86%|████████▌ | 401/469 [01:28<00:15,  4.39it/s]

[9/10][400/469] Loss_D: 1.1723 Loss_G: 1.2154 D(x): 0.6885 D(G(z)): 0.5119/0.3318


Epoch 10/10: 100%|██████████| 469/469 [01:43<00:00,  4.53it/s]


In [16]:
# ------------------------------
# Part 1.2: GAN as a Pre-Training Framework
# ------------------------------

# Create a feature extractor from the discriminator
class FeatureExtractor(nn.Module):
    def __init__(self, discriminator):
        super(FeatureExtractor, self).__init__()
        # Copy layers up to before the last Conv2d layer
        self.features = nn.Sequential(*list(discriminator.main.children())[:-2])  # Exclude last Conv2d and Sigmoid

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Flatten
        return x

# Instantiate the feature extractor
feature_extractor = FeatureExtractor(netD).to(device)

# Freeze feature extractor parameters
for param in feature_extractor.parameters():
    param.requires_grad = False

# Prepare 10% of the training data
train_indices = np.arange(len(dataset))
np.random.shuffle(train_indices)
subset_indices = train_indices[:len(dataset) // 10]
train_subset = Subset(dataset, subset_indices)
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)

# Prepare test data
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Determine the size of the feature vector
with torch.no_grad():
    dummy_input = torch.randn(1, 1, 28, 28, device=device)
    features = feature_extractor(dummy_input)
    feature_size = features.shape[1]

# Define a linear classifier
classifier = nn.Linear(feature_size, 10).to(device)

# Loss function and optimizer for classifier
criterion_cls = nn.CrossEntropyLoss()
optimizer_cls = optim.Adam(classifier.parameters(), lr=0.001)

# Train the classifier
num_epochs_cls = 10
for epoch in range(num_epochs_cls):
    classifier.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            features = feature_extractor(inputs)
        outputs = classifier(features)
        loss = criterion_cls(outputs, labels)

        optimizer_cls.zero_grad()
        loss.backward()
        optimizer_cls.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    train_loss = running_loss / len(train_loader)
    train_acc = 100. * correct / total

    # Evaluate on test set
    classifier.eval()
    test_loss = 0.0
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            features = feature_extractor(inputs)
            outputs = classifier(features)
            loss = criterion_cls(outputs, labels)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total_test += labels.size(0)
            correct_test += predicted.eq(labels).sum().item()

    test_loss /= len(test_loader)
    test_acc = 100. * correct_test / total_test

    print(f'[{epoch + 1}/{num_epochs_cls}] Train Loss: {train_loss:.4f} Train Acc: {train_acc:.2f}% '
          f'Test Loss: {test_loss:.4f} Test Acc: {test_acc:.2f}%')


[1/10] Train Loss: 0.2568 Train Acc: 92.22% Test Loss: 0.0939 Test Acc: 97.18%
[2/10] Train Loss: 0.0708 Train Acc: 97.72% Test Loss: 0.0896 Test Acc: 97.08%
[3/10] Train Loss: 0.0417 Train Acc: 98.73% Test Loss: 0.0665 Test Acc: 98.04%
[4/10] Train Loss: 0.0171 Train Acc: 99.67% Test Loss: 0.0651 Test Acc: 97.95%
[5/10] Train Loss: 0.0106 Train Acc: 99.92% Test Loss: 0.0632 Test Acc: 98.04%
[6/10] Train Loss: 0.0069 Train Acc: 99.97% Test Loss: 0.0623 Test Acc: 98.13%
[7/10] Train Loss: 0.0051 Train Acc: 100.00% Test Loss: 0.0613 Test Acc: 98.05%
[8/10] Train Loss: 0.0039 Train Acc: 100.00% Test Loss: 0.0607 Test Acc: 98.17%
[9/10] Train Loss: 0.0033 Train Acc: 100.00% Test Loss: 0.0607 Test Acc: 98.11%
[10/10] Train Loss: 0.0028 Train Acc: 100.00% Test Loss: 0.0598 Test Acc: 98.15%


### Discussion of Results

The trained classifier achieved a **training accuracy of 100%** and a **test accuracy of approximately 98.15%**. These high accuracy rates indicate that the feature extractor derived from the GAN effectively captures significant discriminative features from the MNIST dataset. The minimal gap between training and testing accuracies suggests that the model generalizes well to unseen data, demonstrating robust performance without evident overfitting. 

However, the slight decrease in test accuracy from **98.17%** to **98.15%** across epochs whiel having 100% on training accuracy may hint at the beginning of overfitting, although the change is minimal and within a negligible range. It would be intersting to consider training a simpler model - perhaps this feature extractor has too many parameters!

# Note to look at samples look at mid_run_samples directory