In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm  # Added tqdm import

# -----------------------------
# Configuration and Setup
# -----------------------------

# Device configuration: prefer CUDA, then Apple MPS, and fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Hyperparameters
batch_size = 128
latent_dim = 100       # length of the generator's noise vector
num_epochs = 50
learning_rate = 0.0002
beta1 = 0.5            # Adam beta1
beta2 = 0.999          # Adam beta2
sample_dir = 'samples'  # where per-epoch and final image grids are written
os.makedirs(sample_dir, exist_ok=True)

# Image preprocessing: ToTensor scales pixels to [0, 1]; Normalize(0.5, 0.5)
# then maps them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])  # Normalize to [-1, 1]
])

# -----------------------------
# Data Loading
# -----------------------------

# MNIST dataset
train_dataset = datasets.MNIST(root='data/',
                               train=True,
                               transform=transform,
                               download=True)

test_dataset = datasets.MNIST(root='data/',
                              train=False,
                              transform=transform,
                              download=True)

# Data loaders (the test loader is unshuffled so iteration order is stable)
train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=2)

test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False,
                         num_workers=2)

# -----------------------------
# Model Definitions
# -----------------------------

# Generator Model
class Generator(nn.Module):
    """DCGAN generator: latent vector ``z`` -> single-channel 28 x 28 image in [-1, 1].

    A linear layer projects ``z`` to a 256 x 7 x 7 feature map. Two stride-2
    transposed convolutions upsample it 7 -> 14 -> 28. Two same-size 3x3
    convolutions then reduce the channels 64 -> 32 -> 1, and a final Tanh
    bounds the output.

    Args:
        latent_dim: Length of each input noise vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.init_size = 7  # side length of the seed feature map
        side = self.init_size
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 256 * side * side))

        # NOTE(review): no activation sits between this BatchNorm and the first
        # ConvTranspose2d; many DCGAN variants apply ReLU here — confirm intended.
        stages = [nn.BatchNorm2d(256)]
        # kernel 4 / stride 2 / padding 1 doubles the side: 7 -> 14, then 14 -> 28.
        for in_ch, out_ch in ((256, 128), (128, 64)):
            stages.extend([
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ])
        # 3x3 / stride 1 / padding 1 keeps the 28 x 28 size while reducing channels.
        stages.extend([
            nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),  # output range [-1, 1]
        ])
        self.conv_blocks = nn.Sequential(*stages)

    def forward(self, z):
        side = self.init_size
        seed_map = self.l1(z).view(z.size(0), 256, side, side)
        return self.conv_blocks(seed_map)

# Discriminator Model
class Discriminator(nn.Module):
    """DCGAN discriminator for 1 x 28 x 28 images.

    ``forward`` returns an ``(N, 1)`` tensor of Sigmoid outputs, the
    discriminator's score that each image is real. ``feature_extractor``
    returns the flattened conv-stack output (``N x 512*3*3``) with gradients
    disabled, so the stack can serve as frozen features downstream.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Output side of each conv: floor((side + 2*padding - kernel) / stride) + 1
        self.conv_layers = nn.Sequential(
            # Input: 1 x 28 x 28
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1),  # 28->14
            nn.LeakyReLU(0.2, inplace=True),
            # 64 x 14 x 14
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 14->7
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # 128 x 7 x 7
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 7->3: floor((7+2-4)/2)+1 = 3
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # 256 x 3 x 3
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),  # 3->3
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # 512 x 3 x 3
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Linear(512 * 3 * 3, 1),  # input width = flattened 512 x 3 x 3 conv output
            nn.Sigmoid()
        )

    def forward(self, img):
        features = self.conv_layers(img)
        out = self.flatten(features)
        validity = self.fc(out)
        return validity

    def feature_extractor(self, img):
        """Return flattened conv features ``(N, 512*3*3)`` without tracking gradients.

        BatchNorm layers follow the module's current mode: in training mode
        they use (and update) batch statistics, so call ``eval()`` first for
        deterministic features.
        """
        with torch.no_grad():
            # Same flattening as forward(), so features line up with self.fc's input.
            features = self.flatten(self.conv_layers(img))
        return features

# Initialize models on the selected device (generator is built first, then discriminator)
generator, discriminator = Generator(latent_dim).to(device), Discriminator().to(device)

# -----------------------------
# Loss and Optimizer
# -----------------------------

# Binary cross-entropy applied to the discriminator's Sigmoid output.
adversarial_loss = nn.BCELoss()

# Both networks use the same Adam hyperparameters.
_adam_kwargs = {"lr": learning_rate, "betas": (beta1, beta2)}
optimizer_G = optim.Adam(generator.parameters(), **_adam_kwargs)
optimizer_D = optim.Adam(discriminator.parameters(), **_adam_kwargs)

# -----------------------------
# Training the DCGAN
# -----------------------------

print("Starting DCGAN Training...")

for epoch in range(1, num_epochs + 1):
    epoch_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}", leave=False)
    for imgs, _ in epoch_bar:

        # Adversarial ground truths (BCE targets): 1 = real, 0 = generated
        real = imgs.to(device)
        batch_size_curr = real.size(0)
        valid = torch.ones(batch_size_curr, 1, device=device)
        fake = torch.zeros(batch_size_curr, 1, device=device)

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()

        # Real images
        loss_real = adversarial_loss(discriminator(real), valid)

        # Fake images; detach() keeps the discriminator update from
        # backpropagating into the generator.
        z = torch.randn(batch_size_curr, latent_dim, device=device)
        gen_imgs = generator(z)
        loss_fake = adversarial_loss(discriminator(gen_imgs.detach()), fake)

        # Total discriminator loss
        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------
        optimizer_G.zero_grad()

        # Reuse gen_imgs from the discriminator step. Only a detached copy was
        # fed to D, and G's weights have not changed, so its graph back to G is
        # intact and a second generator forward pass is unnecessary.
        # The generator tries to make the discriminator label its output as real.
        loss_G = adversarial_loss(discriminator(gen_imgs), valid)
        loss_G.backward()
        optimizer_G.step()

        # Update progress bar
        epoch_bar.set_postfix({'Loss D': loss_D.item(), 'Loss G': loss_G.item()})

    # Save sampled images at the end of each epoch. Eval mode makes BatchNorm
    # use its running statistics, so this 12-image preview batch neither sets
    # the normalization nor updates the running stats. This matches
    # generate_samples(). Training mode is restored afterwards.
    generator.eval()
    with torch.no_grad():
        z = torch.randn(12, latent_dim, device=device)
        gen_imgs = generator(z) * 0.5 + 0.5  # Denormalize to [0,1]
        grid = make_grid(gen_imgs, nrow=4)
        save_image(grid, os.path.join(sample_dir, f"epoch_{epoch}.png"))
    generator.train()

print("DCGAN Training Completed!")

# -----------------------------
# Generate and Save 12 Samples
# -----------------------------

def generate_samples(generator, latent_dim, num_samples=12, device='cpu', save_path='generated_samples.png', nrow=4):
    """Sample ``num_samples`` images from *generator* and save them as one grid image.

    Args:
        generator: Module mapping a ``(num_samples, latent_dim)`` standard-normal
            noise tensor to a batch of images. Its output is assumed to lie in
            [-1, 1] (as with a Tanh-terminated generator) and is mapped to
            [0, 1] before saving.
        latent_dim: Length of each noise vector.
        num_samples: Number of images to generate.
        device: Device on which the noise tensor is created.
        save_path: Destination file for the image grid.
        nrow: Number of images per grid row (default 4).

    The generator runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, even if sampling or saving raises.
    """
    # Remember the caller's mode instead of unconditionally switching to train().
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            z = torch.randn(num_samples, latent_dim, device=device)
            gen_imgs = generator(z) * 0.5 + 0.5  # Denormalize to [0,1]
            grid = make_grid(gen_imgs, nrow=nrow)
            save_image(grid, save_path)
    finally:
        generator.train(was_training)

# Write a final 12-image grid from the trained generator into sample_dir.
final_samples_path = os.path.join(sample_dir, "final_samples.png")
generate_samples(generator, latent_dim, num_samples=12, device=device, save_path=final_samples_path)
print(f"Generated samples are saved in the '{sample_dir}' directory.")

# -----------------------------
# GAN as a Pre-Training Framework
# -----------------------------

print("Starting GAN as a Pre-Training Framework...")

# 1. The trained discriminator, minus its final linear layer, is used as a
#    frozen feature extractor via its `feature_extractor` method.

# 2. Draw a random 10% subset of the training set (sampled without replacement)
num_train_samples = int(0.1 * len(train_dataset))
train_subset_indices = np.random.choice(len(train_dataset), num_train_samples, replace=False)
train_subset = Subset(train_dataset, train_subset_indices)
train_subset_loader = DataLoader(dataset=train_subset,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=2)


def _extract_features(model, loader, desc):
    """Return ``(features, labels)`` for every batch of *loader*, concatenated on the CPU.

    Features come from ``model.feature_extractor`` (one flattened row per
    image); *desc* labels the progress bar.
    """
    feature_batches, label_batches = [], []
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc=desc, leave=False):
            feature_batches.append(model.feature_extractor(imgs.to(device)).cpu())
            label_batches.append(labels)
    return torch.cat(feature_batches, dim=0), torch.cat(label_batches, dim=0)


# 3./4. Extract features for the training subset and the test set. Eval mode
# makes BatchNorm use its running statistics; training mode is restored even if
# extraction fails.
discriminator.eval()
try:
    train_features, train_labels = _extract_features(
        discriminator, train_subset_loader, "Extracting Train Features")
    # test_loader is already unshuffled, so it is reused instead of building a duplicate.
    test_features, test_labels = _extract_features(
        discriminator, test_loader, "Extracting Test Features")
finally:
    discriminator.train()

# 5. Define a Linear Classifier
class LinearClassifier(nn.Module):
    """Single fully-connected layer mapping feature vectors to class scores.

    Args:
        input_dim: Width of each input feature vector.
        num_classes: Number of output scores per sample (default 10).
    """

    def __init__(self, input_dim, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        # Raw, unnormalized logits — no softmax is applied here.
        return self.fc(x)

# Input width comes from the extracted features rather than a hard-coded
# constant, so it always matches Discriminator.feature_extractor's output.
classifier = LinearClassifier(input_dim=train_features.size(1)).to(device)

# 6. Training the Linear Classifier
classifier_optimizer = optim.Adam(classifier.parameters(), lr=0.001)
classifier_criterion = nn.CrossEntropyLoss()

# Create DataLoader for classifier training. The features are already
# in-memory tensors, so worker processes would only add start-up and pickling
# overhead. Non-persistent workers are re-created every epoch, so load in the
# main process instead.
classifier_train_loader = DataLoader(dataset=torch.utils.data.TensorDataset(train_features, train_labels),
                                     batch_size=batch_size,
                                     shuffle=True,
                                     num_workers=0)

# Training loop for classifier
num_classifier_epochs = 20
for epoch in range(1, num_classifier_epochs + 1):
    classifier.train()
    epoch_loss = 0.0
    correct = 0
    total = 0
    classifier_bar = tqdm(classifier_train_loader, desc=f"Classifier Epoch {epoch}/{num_classifier_epochs}", leave=False)
    for features, labels in classifier_bar:
        features = features.to(device)
        labels = labels.to(device)

        classifier_optimizer.zero_grad()
        outputs = classifier(features)
        loss = classifier_criterion(outputs, labels)
        loss.backward()
        classifier_optimizer.step()

        # Weight by batch size so the epoch mean stays exact when the last batch is smaller.
        epoch_loss += loss.item() * features.size(0)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss /= total
    accuracy = 100 * correct / total
    print(f"Classifier Epoch [{epoch}/{num_classifier_epochs}] Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.2f}%")

# 7./8. Evaluate the trained classifier on the test set and the training subset
classifier.eval()


def _accuracy_percent(model, features, labels):
    """Return the percentage of rows of *features* that *model* assigns to *labels*.

    *features* and *labels* must already be on the model's device.
    """
    with torch.no_grad():
        predicted = model(features).argmax(dim=1)
    return 100 * (predicted == labels).sum().item() / labels.size(0)


test_features = test_features.to(device)
test_labels = test_labels.to(device)
test_accuracy = _accuracy_percent(classifier, test_features, test_labels)
print(f"Test Accuracy: {test_accuracy:.2f}%")

train_features = train_features.to(device)
train_labels = train_labels.to(device)
train_accuracy = _accuracy_percent(classifier, train_features, train_labels)
print(f"Training Subset Accuracy: {train_accuracy:.2f}%")

print("GAN Pre-Training Framework Completed!")

# -----------------------------
# Optional: Visualize Generated Samples
# -----------------------------

def show_generated_images(save_path):
    """Display the image file at *save_path* in an 8 x 8 inch figure with axes hidden."""
    image_array = plt.imread(save_path)  # read first, so a bad path fails before a figure opens
    plt.figure(figsize=(8, 8))
    plt.imshow(image_array)
    plt.axis('off')
    plt.show()

# Uncomment the line below to visualize the final generated samples
# show_generated_images(os.path.join(sample_dir, "final_samples.png"))

Using device: mps
Starting DCGAN Training...


                                                                                            

DCGAN Training Completed!
Generated samples are saved in the 'samples' directory.
Starting GAN as a Pre-Training Framework...


                                                                          

Classifier Epoch [1/20] Loss: 0.4111, Accuracy: 89.25%


                                                                      

Classifier Epoch [2/20] Loss: 0.0803, Accuracy: 97.88%


                                                                     

Classifier Epoch [3/20] Loss: 0.0489, Accuracy: 98.97%


                                                                      

Classifier Epoch [4/20] Loss: 0.0353, Accuracy: 99.43%


                                                                     

Classifier Epoch [5/20] Loss: 0.0247, Accuracy: 99.73%


                                                                      

Classifier Epoch [6/20] Loss: 0.0192, Accuracy: 99.82%


                                                                     

Classifier Epoch [7/20] Loss: 0.0151, Accuracy: 99.90%


                                                                      

Classifier Epoch [8/20] Loss: 0.0118, Accuracy: 99.98%


                                                                      

Classifier Epoch [9/20] Loss: 0.0096, Accuracy: 100.00%


                                                                       

Classifier Epoch [10/20] Loss: 0.0077, Accuracy: 100.00%


                                                                      

Classifier Epoch [11/20] Loss: 0.0068, Accuracy: 100.00%


                                                                      

Classifier Epoch [12/20] Loss: 0.0059, Accuracy: 100.00%


                                                                      

Classifier Epoch [13/20] Loss: 0.0051, Accuracy: 100.00%


                                                                      

Classifier Epoch [14/20] Loss: 0.0046, Accuracy: 100.00%


                                                                       

Classifier Epoch [15/20] Loss: 0.0042, Accuracy: 100.00%


                                                                      

Classifier Epoch [16/20] Loss: 0.0037, Accuracy: 100.00%


                                                                      

Classifier Epoch [17/20] Loss: 0.0033, Accuracy: 100.00%


                                                                      

Classifier Epoch [18/20] Loss: 0.0031, Accuracy: 100.00%


                                                                       

Classifier Epoch [19/20] Loss: 0.0028, Accuracy: 100.00%


                                                                      

Classifier Epoch [20/20] Loss: 0.0026, Accuracy: 100.00%
Test Accuracy: 97.91%
Training Subset Accuracy: 100.00%
GAN Pre-Training Framework Completed!


