## Step 1: Import Necessary Libraries
We first import the required libraries for building and training the models.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.nn.parallel import DataParallel
import matplotlib.pyplot as plt

## Step 2: Set Up Device
We run on a CUDA GPU when one is available and fall back to the CPU otherwise.

In [None]:
# Select the current CUDA device when PyTorch can see a GPU; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Step 3: Define Data Transformations and Load MNIST Dataset
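MNIST is used for both the classification and the generation tasks. `ToTensor` scales pixel values to $[0, 1]$, and `Normalize((0.5,), (0.5,))` then maps them to $[-1, 1]$ via $(x - 0.5) / 0.5$. This matches the range of the generator's Tanh output.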

In [None]:
# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 / std 0.5 then maps
# them to [-1, 1], the same range as the generator's Tanh output.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)

# Both splits are cached under ./data and downloaded on first use.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Shuffle only the training split; evaluation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

## Step 4: Define Classifier Model
The classifier is a simple feedforward neural network that takes flattened MNIST images as input and predicts class labels.

In [None]:
class Classifier(nn.Module):
    """Fully connected MNIST classifier producing 10 class logits.

    Architecture: 784 -> 128 -> 64 -> 10, with ReLU between consecutive
    linear layers and no activation on the output (raw logits).
    """

    def __init__(self):
        super().__init__()
        widths = (28 * 28, 128, 64, 10)
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.ReLU())
        # The output layer emits raw logits, so drop the trailing activation.
        layers.pop()
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Flatten every dimension after the batch dimension and classify.

        Args:
            x: Batch whose per-sample size must total 28 * 28 = 784 values.

        Returns:
            Tensor of shape (batch, 10) with unnormalized class scores.
        """
        batch_size = x.size(0)
        return self.fc(x.view(batch_size, -1))

## Step 5: Define DCGAN Generator
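The generator maps a 100-dimensional random noise vector to a fake $1 \times 28 \times 28$ MNIST image. It upsamples with transposed convolutions ($1 \to 7 \to 14 \to 28$ pixels per side) and ends in a Tanh, so pixel values lie in $[-1, 1]$.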

In [None]:
class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to 1x28x28 images in [-1, 1].

    Args:
        latent_dim: Length of the input noise vector. Defaults to 100; it is
            used both for the first layer's input channels and for reshaping
            the input in ``forward``, so the two can no longer drift apart.
    """

    def __init__(self, latent_dim=100):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        self.gen = nn.Sequential(
            # (N, latent_dim, 1, 1) -> (N, 128, 7, 7)
            nn.ConvTranspose2d(latent_dim, 128, kernel_size=7, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # (N, 128, 7, 7) -> (N, 64, 14, 14)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # (N, 64, 14, 14) -> (N, 1, 28, 28); Tanh bounds pixels to [-1, 1].
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        """Generate images from noise.

        Args:
            x: Noise holding ``latent_dim`` values per sample, e.g. shaped
               (N, latent_dim) or (N, latent_dim, 1, 1).

        Returns:
            Tensor of shape (N, 1, 28, 28).
        """
        # reshape (unlike view) copies when the input layout cannot be viewed
        # as (N, latent_dim, 1, 1), instead of raising.
        x = x.reshape(-1, self.latent_dim, 1, 1)
        return self.gen(x)

## Step 6: Define DCGAN Critic (Discriminator)
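The critic maps each real or generated MNIST image to a single unbounded score using strided convolutions. Unlike a standard GAN discriminator, it ends without a sigmoid, because the Wasserstein loss compares raw scores rather than probabilities.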

In [None]:
class Critic(nn.Module):
    """WGAN-GP critic scoring each 1x28x28 image with one unbounded real value.

    Normalization is per sample (InstanceNorm2d) rather than BatchNorm2d.
    The gradient penalty is defined per sample, and batch normalization makes
    each sample's score depend on the other samples in the batch, which
    invalidates that penalty. The WGAN-GP paper recommends dropping batch
    norm from the critic for this reason.
    """

    def __init__(self):
        super(Critic, self).__init__()
        self.critic = nn.Sequential(
            # (N, 1, 28, 28) -> (N, 64, 14, 14)
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (N, 64, 14, 14) -> (N, 128, 7, 7)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            # Per-sample, per-channel normalization with learnable scale/shift.
            nn.InstanceNorm2d(128, affine=True),
            nn.LeakyReLU(0.2, inplace=True),
            # (N, 128, 7, 7) -> (N, 1, 1, 1); no sigmoid: the critic emits raw scores.
            nn.Conv2d(128, 1, kernel_size=7, stride=1, padding=0, bias=False)
        )

    def forward(self, x):
        """Return one score per input image as a tensor of shape (N,)."""
        return self.critic(x).view(-1)

## Step 7: Gradient Penalty for WGAN-GP
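To stabilize training, WGAN-GP softly enforces the critic's 1-Lipschitz constraint. We draw points on straight lines between real and fake images, $\hat{x} = \epsilon\, x_{\text{real}} + (1 - \epsilon)\, x_{\text{fake}}$, with $\epsilon \sim \mathcal{U}[0, 1]$ drawn independently per sample. We then penalize the critic $D$ whenever its gradient norm at those points deviates from 1: $\mathbb{E}_{\hat{x}}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]$. The function returns this batch mean; it is typically multiplied by a weight $\lambda$ before being added to the critic loss.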

In [None]:
def gradient_penalty(critic, real_images, fake_images):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolations.

    For every sample, draws ``epsilon`` uniformly from [0, 1) and forms
    ``x_hat = epsilon * real + (1 - epsilon) * fake``. It then measures how far
    the L2 norm of ``d critic(x_hat) / d x_hat`` is from 1.

    Args:
        critic: Callable mapping a batch to one score per sample.
        real_images: Real batch; dimension 0 is the batch dimension.
        fake_images: Generated batch with the same shape as ``real_images``.

    Returns:
        Scalar tensor: the batch mean of ``(||grad||_2 - 1) ** 2``. The graph
        is kept (``create_graph=True``) so the penalty can be backpropagated
        into the critic's parameters.
    """
    batch_size = real_images.size(0)
    # One mixing weight per sample, broadcast over all remaining dimensions.
    # It is created on the inputs' own device/dtype instead of relying on a
    # global ``device`` defined elsewhere.
    epsilon_shape = (batch_size,) + (1,) * (real_images.dim() - 1)
    epsilon = torch.rand(epsilon_shape, device=real_images.device, dtype=real_images.dtype)
    interpolated_images = epsilon * real_images + (1 - epsilon) * fake_images
    interpolated_images.requires_grad_(True)

    interpolated_outputs = critic(interpolated_images)

    gradients = torch.autograd.grad(
        outputs=interpolated_outputs,
        inputs=interpolated_images,
        grad_outputs=torch.ones_like(interpolated_outputs),
        # create_graph makes the penalty itself differentiable; retain_graph
        # keeps the critic's graph alive for the caller's backward pass.
        create_graph=True,
        retain_graph=True,
    )[0]
    # One flattened gradient vector per sample.
    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

## Step 8: Initialize Models and Optimizers

In [None]:
# Build each network (in the order Classifier, Generator, Critic), move it to
# the selected device, and wrap it in DataParallel for multi-GPU execution.
classifier, generator, critic = [
    DataParallel(net_cls().to(device)) for net_cls in (Classifier, Generator, Critic)
]

# One Adam optimizer per network: a larger step size for the classifier and
# 2e-4 for both adversarial networks.
# NOTE(review): the WGAN-GP paper trains with lr=1e-4 and betas=(0.0, 0.9);
# default Adam betas are used here — consider tuning if training is unstable.
optimizer_classifier = optim.Adam(classifier.parameters(), lr=1e-3)
optimizer_generator = optim.Adam(generator.parameters(), lr=2e-4)
optimizer_critic = optim.Adam(critic.parameters(), lr=2e-4)