In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import torchvision.utils as vutils


In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print("using device: ", device)

using device:  mps


In [15]:
# Hyperparameters

batch_size = 128
image_size=28
nz = 100
num_epochs=50
learning_rate = 0.0002
beta1 = 0.5

# data transformation

transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])

#load the MNIST dataset

dataset = torchvision.datasets.MNIST(root='./data',
                                     train=True,
                                     transform=transform,
                                     download=True)

dataloader = DataLoader(dataset=dataset,batch_size=batch_size,shuffle=True)

In [16]:
# Let's now defien the network

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.main=nn.Sequential(
            # Input size is 1x28x28
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=4, stride=2, padding=1) , # 64x14x14
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            
            #seocond layer
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1), # 128x7x7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            
            #third layer
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1), # 256x4x4
            nn.BatchNorm2d(256),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            
            #fourth layer
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1), # 512x4x4
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            
            #output layer
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=0), # 1x1x1
            nn.Sigmoid()
        )
        
    def forward(self, input):
        output=self.main(input)
        return output.view(-1)
    
# Generator network

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        self.main=nn.Sequential(
            #Input is nz - going into convolution
            nn.Linear(in_features=nz, out_features=256*7*7),
            nn.BatchNorm1d(256*7*7),
            nn.ReLU(inplace=True), 
            
            # reshape
            nn.Unflatten(dim=1, unflattened_size=(256,7,7)),
            
            # first convose transpose
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1), # 128x14x14
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            # second convose transpose
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1), # 64x28x28
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            # Third convse layer
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1),  # Output: (batch_size, 32, 28, 28)
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            
            # Output layer
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1),  # Output: (batch_size, 1, 28, 28)
            nn.Tanh()  # Output values in [-1, 1]
        )
        
    def forward(self, input):
        output = self.main(input)
        return output
    

# Initialize the networks

netD = Discriminator().to(device)
netG = Generator().to(device)


In [17]:
# Initialize weights
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1 or classname.find('Linear') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

netD.apply(weights_init)
netG.apply(weights_init)

# Note that BCE is the minmax loss function under the hood

criterion = nn.BCELoss()
# fixed_noise = torch.randn(64, nz, device=device)

real_label = 1
fake_label = 0

optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta1, 0.999))

In [18]:
# Ensure the directory exists for saving images
os.makedirs('mid_run_samples', exist_ok=True)

# Fixed noise for generating samples
fixed_noise = torch.randn(12, nz, device=device)

# Training loop
for epoch in range(num_epochs):
    # Save generated images at the start of each epoch
    with torch.no_grad():
        fake_images = netG(fixed_noise).detach().cpu()
    vutils.save_image(
        fake_images,
        f'mid_run_samples/output_epoch_{epoch}.png',
        normalize=True,
        nrow=4
    )

    # Progress bar for batches
    for i, data in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")):
        ############################
        # (1) Update D network
        ############################
        netD.zero_grad()
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full(
            (b_size,), real_label, dtype=torch.float, device=device
        )
        output = netD(real_cpu)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        # Train with all-fake batch
        noise = torch.randn(b_size, nz, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network
        ############################
        for _ in range(3):  # Perform 3 generator iterations so generator and discriminator are more balanced
            netG.zero_grad()
            label.fill_(real_label)  # Fake labels are real for generator cost
            output = netD(fake)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()
            # Generate new fake data for the next generator iteration
            noise = torch.randn(b_size, nz, device=device)
            fake = netG(noise)

        # Print training stats
        if i % 200 == 0:
            print(
                f'[{epoch}/{num_epochs}][{i}/{len(dataloader)}] '
                f'Loss_D: {errD.item():.4f} Loss_G: {errG.item():.4f} '
                f'D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}'
            )

Epoch 1/50:   0%|          | 1/469 [00:00<01:50,  4.25it/s]

[0/50][0/469] Loss_D: 1.5422 Loss_G: 0.4968 D(x): 0.6273 D(G(z)): 0.5722/0.6514


Epoch 1/50:  43%|████▎     | 202/469 [00:44<00:56,  4.72it/s]

[0/50][200/469] Loss_D: 1.0746 Loss_G: 0.9979 D(x): 0.4692 D(G(z)): 0.1665/0.4081


Epoch 1/50:  86%|████████▌ | 401/469 [01:27<00:15,  4.44it/s]

[0/50][400/469] Loss_D: 1.0368 Loss_G: 1.2949 D(x): 0.6768 D(G(z)): 0.4165/0.3191


Epoch 1/50: 100%|██████████| 469/469 [01:42<00:00,  4.60it/s]
Epoch 2/50:   0%|          | 1/469 [00:00<01:42,  4.55it/s]

[1/50][0/469] Loss_D: 1.0289 Loss_G: 0.8186 D(x): 0.5280 D(G(z)): 0.2602/0.4815


Epoch 2/50:  43%|████▎     | 202/469 [00:44<00:57,  4.66it/s]

[1/50][200/469] Loss_D: 1.0524 Loss_G: 0.8378 D(x): 0.5436 D(G(z)): 0.3099/0.4686


Epoch 2/50:  86%|████████▌ | 402/469 [01:28<00:14,  4.70it/s]

[1/50][400/469] Loss_D: 1.1202 Loss_G: 1.0657 D(x): 0.6960 D(G(z)): 0.4791/0.4007


Epoch 2/50: 100%|██████████| 469/469 [01:42<00:00,  4.56it/s]
Epoch 3/50:   0%|          | 1/469 [00:00<01:44,  4.47it/s]

[2/50][0/469] Loss_D: 1.2017 Loss_G: 1.1515 D(x): 0.6205 D(G(z)): 0.4657/0.3533


Epoch 3/50:  43%|████▎     | 202/469 [00:43<00:57,  4.63it/s]

[2/50][200/469] Loss_D: 1.2365 Loss_G: 1.3005 D(x): 0.6034 D(G(z)): 0.4677/0.3273


Epoch 3/50:  86%|████████▌ | 402/469 [01:27<00:14,  4.68it/s]

[2/50][400/469] Loss_D: 1.1983 Loss_G: 0.5624 D(x): 0.4667 D(G(z)): 0.2994/0.5916


Epoch 3/50: 100%|██████████| 469/469 [01:41<00:00,  4.61it/s]
Epoch 4/50:   0%|          | 1/469 [00:00<01:44,  4.49it/s]

[3/50][0/469] Loss_D: 1.1684 Loss_G: 1.0299 D(x): 0.6076 D(G(z)): 0.4528/0.3884


Epoch 4/50:  13%|█▎        | 61/469 [00:13<01:30,  4.52it/s]


KeyboardInterrupt: 

In [19]:
# ------------------------------
# Part 1.2: GAN as a Pre-Training Framework
# ------------------------------

# Create a feature extractor from the discriminator
class FeatureExtractor(nn.Module):
    def __init__(self, discriminator):
        super(FeatureExtractor, self).__init__()
        # Copy layers up to before the last Conv2d layer
        self.features = nn.Sequential(*list(discriminator.main.children())[:-2])  # Exclude last Conv2d and Sigmoid

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # Flatten
        return x

# Instantiate the feature extractor
feature_extractor = FeatureExtractor(netD).to(device)

# Freeze feature extractor parameters
for param in feature_extractor.parameters():
    param.requires_grad = False

# Prepare 10% of the training data
train_indices = np.arange(len(dataset))
np.random.shuffle(train_indices)
subset_indices = train_indices[:len(dataset) // 10]
train_subset = Subset(dataset, subset_indices)
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)

# Prepare test data
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Determine the size of the feature vector
with torch.no_grad():
    dummy_input = torch.randn(1, 1, 28, 28, device=device)
    features = feature_extractor(dummy_input)
    feature_size = features.shape[1]

# Define a linear classifier
classifier = nn.Linear(feature_size, 10).to(device)

# Loss function and optimizer for classifier
criterion_cls = nn.CrossEntropyLoss()
optimizer_cls = optim.Adam(classifier.parameters(), lr=0.001)

# Train the classifier
num_epochs_cls = 10
for epoch in range(num_epochs_cls):
    classifier.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        with torch.no_grad():
            features = feature_extractor(inputs)
        outputs = classifier(features)
        loss = criterion_cls(outputs, labels)

        optimizer_cls.zero_grad()
        loss.backward()
        optimizer_cls.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    train_loss = running_loss / len(train_loader)
    train_acc = 100. * correct / total

    # Evaluate on test set
    classifier.eval()
    test_loss = 0.0
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            features = feature_extractor(inputs)
            outputs = classifier(features)
            loss = criterion_cls(outputs, labels)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total_test += labels.size(0)
            correct_test += predicted.eq(labels).sum().item()

    test_loss /= len(test_loader)
    test_acc = 100. * correct_test / total_test

    print(f'[{epoch + 1}/{num_epochs_cls}] Train Loss: {train_loss:.4f} Train Acc: {train_acc:.2f}% '
          f'Test Loss: {test_loss:.4f} Test Acc: {test_acc:.2f}%')

[1/10] Train Loss: 0.2956 Train Acc: 90.77% Test Loss: 0.1189 Test Acc: 96.31%
[2/10] Train Loss: 0.0918 Train Acc: 97.07% Test Loss: 0.1149 Test Acc: 96.32%
[3/10] Train Loss: 0.0550 Train Acc: 98.42% Test Loss: 0.0867 Test Acc: 97.30%
[4/10] Train Loss: 0.0307 Train Acc: 99.25% Test Loss: 0.0807 Test Acc: 97.38%
[5/10] Train Loss: 0.0169 Train Acc: 99.70% Test Loss: 0.0799 Test Acc: 97.53%
[6/10] Train Loss: 0.0124 Train Acc: 99.92% Test Loss: 0.0745 Test Acc: 97.63%
[7/10] Train Loss: 0.0107 Train Acc: 99.88% Test Loss: 0.0759 Test Acc: 97.55%
[8/10] Train Loss: 0.0078 Train Acc: 99.95% Test Loss: 0.0717 Test Acc: 97.75%
[9/10] Train Loss: 0.0057 Train Acc: 99.98% Test Loss: 0.0737 Test Acc: 97.67%
[10/10] Train Loss: 0.0045 Train Acc: 100.00% Test Loss: 0.0725 Test Acc: 97.76%
