# 55: Deep Convolutional GAN (DCGAN)

## 🎯 Objective
In the previous notebook, we built a GAN using linear (fully connected) layers. While it worked for simple MNIST digits, linear layers ignore the spatial structure of images, often resulting in noisy or incoherent outputs for more complex data.

In this tutorial, we upgrade to a **Deep Convolutional GAN (DCGAN)**. We will use **Convolutional Layers** in the Discriminator to extract features and **Transpose Convolutional Layers** in the Generator to build images up from the latent space. We will apply this to the **Fashion-MNIST** dataset to generate realistic clothing items.

## 📚 Key Concepts
* **DCGAN:** A GAN architecture built from convolutional layers. It is generally more stable to train and tends to produce higher-quality images than MLP-based GANs.
* **Transpose Convolution (`nn.ConvTranspose2d`):** Often called "deconvolution" (though technically incorrect), this operation upsamples a small feature map (or 1x1 vector) into a larger image. It is the engine of the Generator.
* **Strided Convolutions:** Instead of Max Pooling (which loses information), the Discriminator uses convolutions with `stride=2` to downsample images.
* **Batch Normalization:** Applied to the convolutional layers to stabilize the distribution of each layer's inputs. This helps gradients flow and can reduce the risk of mode collapse, although it does not eliminate it.
* **Adam Betas:** GANs are sensitive to momentum. We will adjust the $\beta$ parameters of the Adam optimizer to stabilize training.
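To make the "upsampling" idea behind transpose convolution concrete, here is a minimal pure-Python sketch of a 1-D transposed convolution (single channel, no bias, dilation 1). Each input value stamps a scaled copy of the kernel into the output, with copies spaced `stride` apart and overlaps summed; `padding` then trims that many positions from each end. This is intended to mirror PyTorch's `nn.ConvTranspose1d` for this simple case, but the function `conv_transpose_1d` is our own illustrative helper, not a library API.

```python
def conv_transpose_1d(x, w, stride=1, padding=0):
    """Naive single-channel 1-D transposed convolution (dilation 1, no output_padding)."""
    k = len(w)
    full_len = (len(x) - 1) * stride + k      # length before trimming the padding
    full = [0.0] * full_len
    for i, xi in enumerate(x):
        for j, wj in enumerate(w):
            full[i * stride + j] += xi * wj    # scatter-add a scaled copy of the kernel
    # padding removes positions from both ends of the "full" output
    return full[padding:full_len - padding]

# 2 inputs -> 5 outputs: the layer upsamples, summing where kernel copies overlap
print(conv_transpose_1d([1, 2], [1, 1, 1], stride=2))
# kernel 4, stride 2, padding 1 (the setting used in the Generator) doubles the length
print(len(conv_transpose_1d([1, 2, 3, 4], [1, 1, 1, 1], stride=2, padding=1)))
```

The first call returns `[1.0, 1.0, 3.0, 2.0, 2.0]`: the middle value is 3 because the two kernel copies overlap there.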

## 1. Import Libraries

We import the standard stack. Note the addition of `sys`, which we use to print an in-place progress message during training.

In [None]:
# import libraries
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

# for importing data
import torchvision
import torchvision.transforms as T
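from torch.utils.data import DataLoader

import sys

import matplotlib.pyplot as plt
# IPython.display.set_matplotlib_formats is deprecated; matplotlib_inline provides the replacement
from matplotlib_inline.backend_inline import set_matplotlib_formats
set_matplotlib_formats('svg')  # render inline figures as SVG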

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## 2. Data Preparation

We load Fashion-MNIST. 

### Important Transforms
1.  **Resize to 64x64:** Standard DCGAN architectures are designed around power-of-2 image sizes (typically 64x64). Since Fashion-MNIST images are 28x28, we resize them to 64x64 so that each stride-2 layer halves or doubles the size cleanly (64 -> 32 -> 16 -> 8 -> 4).
2.  **Normalize:** `T.Normalize(.5,.5)` subtracts a mean of 0.5 and divides by a standard deviation of 0.5, mapping pixel values from $[0, 1]$ (the output of `ToTensor`) to $[-1, 1]$. This matches the range of the `Tanh` output of our Generator.
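The arithmetic of these two transforms is easy to check by hand. The sketch below is plain Python; `normalize` and `unnormalize` are illustrative helper names. It applies the same per-pixel formula as `T.Normalize(.5,.5)`, `(x - mean) / std`, and its inverse, which is the `pic/2 + .5` step used later to display images. It also lists the sizes a 64-pixel side passes through when halved repeatedly.

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-pixel formula used by T.Normalize: (x - mean) / std."""
    return (x - mean) / std

def unnormalize(z, mean=0.5, std=0.5):
    """Inverse mapping; with mean=std=0.5 this equals z/2 + 0.5."""
    return z * std + mean

for pixel in [0.0, 0.25, 0.5, 1.0]:        # ToTensor outputs values in [0, 1]
    z = normalize(pixel)
    print(pixel, '->', z, '->', unnormalize(z))
# 0.0 maps to -1.0 and 1.0 maps to 1.0: the data now matches Tanh's [-1, 1] range

# a 64-pixel side halves cleanly all the way down to 4
sizes = [64]
while sizes[-1] > 4:
    sizes.append(sizes[-1] // 2)
print(sizes)  # [64, 32, 16, 8, 4]
```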

In [None]:
# Import the data

# transformations
transform = T.Compose([ T.ToTensor(),
                        T.Resize(64),
                        T.Normalize(.5,.5),
                       ])

# import the data and simultaneously apply the transform
dataset = torchvision.datasets.FashionMNIST(root='./data', download=True, transform=transform)

### Filtering the Dataset
To make training slightly easier and the results more visually distinct, we will filter the dataset to keep only 3 categories: **Trousers, Sneakers, and Pullovers**.
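The filtering step only needs the dataset indices whose label belongs to one of the kept classes. A minimal NumPy sketch on a made-up label vector shows the idea; the `targets` array is a toy stand-in, not the real Fashion-MNIST labels, although the class list matches `dataset.classes`. The notebook cell that follows does the same thing with `torch.where`, one class at a time.

```python
import numpy as np

# stand-ins for dataset.classes and dataset.targets (the targets are made up)
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
targets = np.array([0, 1, 7, 2, 9, 1, 5, 2, 7, 3])

classes2keep = ['Trouser', 'Sneaker', 'Pullover']
keep_ids = [classes.index(c) for c in classes2keep]   # class ids to keep

# positions of all samples whose label is one of the kept class ids
indices = np.flatnonzero(np.isin(targets, keep_ids))
print(indices)           # [1 2 3 5 7 8]
print(targets[indices])  # [1 7 2 1 2 7]
```

The indices come out in dataset order rather than grouped by class, but that makes no difference here because the sampler shuffles them anyway.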

In [None]:
# list the categories
print(dataset.classes)

# pick three categories (leave one line uncommented)
classes2keep = [ 'Trouser','Sneaker','Pullover' ]
# classes2keep = [ 'Trouser','Sneaker', 'Sandal'  ]



# find the corresponding data indices
idx_per_class = []
for c in classes2keep:
  classidx = dataset.classes.index(c)
  idx_per_class.append( torch.where(dataset.targets==classidx)[0] )
  print(f'Added class {c} (index {classidx})')
images2use = torch.cat(idx_per_class) # indices of all images in the kept classes

# build a dataloader that samples only from those images (in random order)
batchsize   = 100
sampler     = torch.utils.data.sampler.SubsetRandomSampler(images2use.tolist())
data_loader = DataLoader(dataset,sampler=sampler,batch_size=batchsize,drop_last=True)

In [None]:
# inspect a few random images

X,y = next(iter(data_loader))

fig,axs = plt.subplots(3,6,figsize=(10,6))

for (i,ax) in enumerate(axs.flatten()):

  # extract that image
  pic = torch.squeeze(X.data[i])
  pic = pic/2 + .5 # undo normalization

  # and its label
  label = dataset.classes[y[i]]

  # and show!
  ax.imshow(pic,cmap='gray')
  ax.text(32,0,label,ha='center',fontweight='bold',color='k',backgroundcolor='y') # x=32 centers the label on a 64-pixel-wide image
  ax.axis('off')

plt.tight_layout()
plt.show()

## 3. The Discriminator (Convolutional)

The Discriminator is a binary classifier that takes a 64x64 image and outputs the probability that the image is real (rather than generated).

**Key Features:**
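* **Strided Convolutions:** `stride=2` halves the spatial dimensions at each of the first four layers (64 -> 32 -> 16 -> 8 -> 4); a final 4x4 convolution without padding then reduces the 4x4 map to a single output value.
* **Batch Normalization:** Used on the hidden convolutional layers (not on the first layer) to stabilize training.
* **Leaky ReLU (slope 0.2):** A standard choice for GAN discriminators, since it passes a small gradient for negative inputs instead of zeroing it.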
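These shapes follow from the standard output-size formula for a convolution with dilation 1, out = floor((in + 2*padding - kernel) / stride) + 1. A small pure-Python check, with `conv_out` as an illustrative helper name, walks a 64x64 input through the Discriminator's five layers, written as `(kernel, stride, padding)` tuples copied from the `discriminatorNet` code.

```python
def conv_out(n, kernel, stride, padding):
    """Spatial output size of nn.Conv2d along one dimension (dilation=1)."""
    return (n + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) for conv1..conv5 of the discriminator
disc_layers = [(4, 2, 1)] * 4 + [(4, 1, 0)]

size, sizes = 64, [64]
for k, s, p in disc_layers:
    size = conv_out(size, k, s, p)
    sizes.append(size)
print(sizes)  # [64, 32, 16, 8, 4, 1]
```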

In [None]:
# Create classes for the discriminator and generator

# Architecture and meta-parameter choices were inspired by https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

In [None]:
class discriminatorNet(nn.Module):
  def __init__(self):
    super().__init__()

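    # convolution layers: Conv2d(in_channels, out_channels, kernel_size, stride, padding)
    self.conv1 = nn.Conv2d(  1, 64, 4, 2, 1, bias=False) # 1x64x64   -> 64x32x32
    self.conv2 = nn.Conv2d( 64,128, 4, 2, 1, bias=False) # 64x32x32  -> 128x16x16
    self.conv3 = nn.Conv2d(128,256, 4, 2, 1, bias=False) # 128x16x16 -> 256x8x8
    self.conv4 = nn.Conv2d(256,512, 4, 2, 1, bias=False) # 256x8x8   -> 512x4x4
    self.conv5 = nn.Conv2d(512,  1, 4, 1, 0, bias=False) # 512x4x4   -> 1x1x1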

    # batchnorm
    self.bn2 = nn.BatchNorm2d(128)
    self.bn3 = nn.BatchNorm2d(256)
    self.bn4 = nn.BatchNorm2d(512)

  def forward(self,x):
    x = F.leaky_relu( self.conv1(x) ,.2)
    x = F.leaky_relu( self.conv2(x) ,.2)
    x = self.bn2(x)
    x = F.leaky_relu( self.conv3(x) ,.2)
    x = self.bn3(x)
    x = F.leaky_relu( self.conv4(x) ,.2)
    x = self.bn4(x)
    return torch.sigmoid( self.conv5(x) ).view(-1,1)


# quick shape test: 10 random 1x64x64 inputs -> 10 probabilities
dnet = discriminatorNet()
y = dnet(torch.randn(10,1,64,64))
y.shape # torch.Size([10, 1])

## 4. The Generator (Transpose Convolutional)

The Generator takes a latent vector (noise) and expands it into an image.

**Key Features:**
* **Input:** Random noise vector of size 100.
* **Transpose Convolutions (`ConvTranspose2d`):** Used to upsample. Shape (channels x height x width) after each layer:
    * Input: 100 x 1 x 1
    * Conv1: 512 x 4 x 4
    * Conv2: 256 x 8 x 8
    * Conv3: 128 x 16 x 16
    * Conv4: 64 x 32 x 32
    * Conv5: 1 x 64 x 64 (The final image)
* **Activation:** `Tanh` at the end to map to $[-1, 1]$.
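Going the other way, the output size of a transpose convolution (dilation 1, no `output_padding`) is out = (in - 1)*stride - 2*padding + kernel. A pure-Python check using the Generator's `(kernel, stride, padding)` settings, with `tconv_out` as an illustrative helper name, reproduces the 1 -> 4 -> 8 -> 16 -> 32 -> 64 progression listed above.

```python
def tconv_out(n, kernel, stride, padding):
    """Spatial output size of nn.ConvTranspose2d along one dimension
    (dilation=1, output_padding=0)."""
    return (n - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) for conv1..conv5 of the generator
gen_layers = [(4, 1, 0)] + [(4, 2, 1)] * 4

size, sizes = 1, [1]
for k, s, p in gen_layers:
    size = tconv_out(size, k, s, p)
    sizes.append(size)
print(sizes)  # [1, 4, 8, 16, 32, 64]
```

With kernel 4, stride 2, and padding 1 the formula simplifies to out = 2*in, which is why those four layers exactly double the image size.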

In [None]:
class generatorNet(nn.Module):
  def __init__(self):
    super().__init__()

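    # transpose convolution layers: ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
    self.conv1 = nn.ConvTranspose2d(100,512, 4, 1, 0, bias=False) # 100x1x1   -> 512x4x4
    self.conv2 = nn.ConvTranspose2d(512,256, 4, 2, 1, bias=False) # 512x4x4   -> 256x8x8
    self.conv3 = nn.ConvTranspose2d(256,128, 4, 2, 1, bias=False) # 256x8x8   -> 128x16x16
    self.conv4 = nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False) # 128x16x16 -> 64x32x32
    self.conv5 = nn.ConvTranspose2d(64,   1, 4, 2, 1, bias=False) # 64x32x32  -> 1x64x64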

    # batchnorm
    self.bn1 = nn.BatchNorm2d(512)
    self.bn2 = nn.BatchNorm2d(256)
    self.bn3 = nn.BatchNorm2d(128)
    self.bn4 = nn.BatchNorm2d( 64)


  def forward(self,x):
    x = F.relu( self.bn1(self.conv1(x)) )
    x = F.relu( self.bn2(self.conv2(x)) )
    x = F.relu( self.bn3(self.conv3(x)) )
    x = F.relu( self.bn4(self.conv4(x)) )
    x = torch.tanh( self.conv5(x) )
    return x


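# quick shape test: 10 noise vectors in, 10 single-channel 64x64 images out
gnet = generatorNet()
y = gnet(torch.randn(10,100,1,1))
print(y.shape) # torch.Size([10, 1, 64, 64])
plt.imshow(y[0,:,:,:].squeeze().detach().numpy(), cmap='gray'); # untrained, so this is just noise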

## 5. Training Setup

We set up the loss function and optimizers. 

**Note on Adam Betas:** We use `betas=(.5, .999)`. Lowering the first beta (the momentum term) from its default of 0.9 to 0.5 is a common trick in GAN training, popularized by the original DCGAN paper, to reduce oscillation and instability.
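For intuition only: Adam's first moment is an exponential moving average of the gradients, m = beta1*m + (1 - beta1)*g, so it averages over roughly 1/(1 - beta1) recent steps (about 10 for 0.9, about 2 for 0.5). The toy pure-Python sketch below ignores Adam's bias correction and second moment, and `steps_to_flip` is an illustrative helper. It feeds a gradient that is +1 for ten steps and then flips to -1, and counts how many steps the moving average needs to change sign.

```python
def steps_to_flip(beta1, warmup=10):
    """Count steps until the first-moment EMA turns negative after the gradient flips."""
    m = 0.0
    for _ in range(warmup):          # gradient is +1 for a while
        m = beta1 * m + (1 - beta1) * 1.0
    steps = 0
    while m >= 0:                    # gradient flips to -1
        m = beta1 * m + (1 - beta1) * -1.0
        steps += 1
    return steps

for beta1 in (0.9, 0.5):
    print(beta1, steps_to_flip(beta1))
# with beta1=0.9 the average needs 5 steps to react; with 0.5 it needs 1
```

In a GAN, each network's loss landscape keeps shifting as the other network trains; a shorter optimizer memory responds faster to those shifts, which is one intuition for the lower `beta1`.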

In [None]:
# loss function, models, and optimizers
lossfun = nn.BCELoss()

dnet = discriminatorNet().to(device)
gnet = generatorNet().to(device)

d_optimizer = torch.optim.Adam(dnet.parameters(), lr=.0002, betas=(.5,.999))
g_optimizer = torch.optim.Adam(gnet.parameters(), lr=.0002, betas=(.5,.999))

## 6. Training Loop

We train for roughly 2500 batches in total; the number of epochs is derived from how many batches fit in one epoch. The loop follows the standard GAN procedure:
1.  **Discriminator Step:** Train on real images (Label=1) and fake images (Label=0).
2.  **Generator Step:** Generate fake images and try to fool the discriminator (Label=1).

Notice that the input noise to the generator is now shape `(batchsize, 100, 1, 1)` to match the expected input of `ConvTranspose2d`.
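`nn.BCELoss` computes -[y*log(p) + (1 - y)*log(1 - p)], averaged over the batch. The pure-Python sketch below (the `bce` and `gan_losses` helpers and the probability values are illustrative) shows what the two steps' losses look like for a discriminator that outputs the same probability for every real image and the same probability for every fake image. If the discriminator is maximally confused (p = 0.5 everywhere), the discriminator loss `d_loss_real + d_loss_fake` equals 2*log(2), about 1.386, and the generator loss equals log(2), about 0.693; these are useful reference levels when reading the loss plots.

```python
import math

def bce(p, y):
    """Binary cross-entropy for one prediction p in (0, 1) and label y in {0, 1}."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def gan_losses(p_real, p_fake):
    """Losses for a discriminator that outputs p_real on every real image and
    p_fake on every fake image (so the batch mean equals the single value)."""
    d_loss = bce(p_real, 1) + bce(p_fake, 0)   # real labelled 1, fake labelled 0
    g_loss = bce(p_fake, 1)                    # generator wants fakes labelled 1
    return d_loss, g_loss

print(gan_losses(0.5, 0.5))   # (1.386..., 0.693...): discriminator is guessing
print(gan_losses(0.9, 0.1))   # small d_loss, large g_loss: discriminator is winning
```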

In [None]:
# number of epochs, chosen so that training runs for about 2500 batches in total
num_epochs = int(2500/len(data_loader))

losses  = []
disDecs = []

for epochi in range(num_epochs):

  for data,_ in data_loader:

    # send data to GPU
    data = data.to(device)

    # create labels for real and fake images
    real_labels = torch.ones(batchsize,1).to(device)
    fake_labels = torch.zeros(batchsize,1).to(device)



    ### ---------------- Train the discriminator ---------------- ###

    # forward pass and loss for REAL pictures
    pred_real   = dnet(data)                     # output of discriminator
    d_loss_real = lossfun(pred_real,real_labels) # all labels are 1

    # forward pass and loss for FAKE pictures
    fake_data   = torch.randn(batchsize,100,1,1).to(device) # random numbers to seed the generator
    fake_images = gnet(fake_data)                           # output of generator
    pred_fake   = dnet(fake_images.detach())                # pass through discriminator (detach: no generator gradients in this step)
    d_loss_fake = lossfun(pred_fake,fake_labels)            # all labels are 0

    # collect loss (using combined losses)
    d_loss = d_loss_real + d_loss_fake

    # backprop
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()



    ### ---------------- Train the generator ---------------- ###

    # create fake images and compute loss
    fake_images = gnet( torch.randn(batchsize,100,1,1).to(device) )
    pred_fake   = dnet(fake_images)

    # compute loss
    g_loss = lossfun(pred_fake,real_labels)

    # backprop
    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()


    # collect losses and discriminator decisions
    losses.append([d_loss.item(),g_loss.item()])

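    # .item() returns Python floats, so np.array() below also works when tensors live on the GPU
    d1 = torch.mean((pred_real>.5).float()).item() # fraction of real images judged "real"
    d2 = torch.mean((pred_fake>.5).float()).item() # fraction of fake images judged "real"
    disDecs.append([d1,d2])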


  # print out a status message
  msg = f'Finished epoch {epochi+1}/{num_epochs}'
  sys.stdout.write('\r' + msg)


# convert performance from list to numpy array
losses  = np.array(losses)
disDecs = np.array(disDecs)

## 7. Analysis and Visualization

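We smooth the loss curves to make the trends easier to see. The third panel shows the fraction of real and fake images that the discriminator classifies as "real" (output > 0.5). At the ideal equilibrium the discriminator cannot tell real from fake, so both curves should approach similar values around 0.5.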
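One caveat about this smoothing: `np.convolve(..., mode='same')` implicitly pads the signal with zeros, so the first and last `k//2` smoothed values are pulled toward zero. The NumPy sketch below repeats the notebook's `smooth` function and applies it to a constant test signal, which makes the edge bias visible; it is worth keeping in mind before reading anything into the ends of the curves.

```python
import numpy as np

def smooth(x, k=15):
    # same moving-average filter as in the notebook
    return np.convolve(x, np.ones(k) / k, mode='same')

signal = np.ones(30)          # a perfectly flat "loss curve"
s = smooth(signal)
print(s[:8].round(3))         # edge values are biased toward zero
print(s[7:23].round(3))       # interior values match the true value, 1.0 (up to rounding)
```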

In [None]:
# create a 1D smoothing filter
def smooth(x,k=15):
  return np.convolve(x,np.ones(k)/k,mode='same')

In [None]:
fig,ax = plt.subplots(1,3,figsize=(18,5))

ax[0].plot(smooth(losses[:,0]))
ax[0].plot(smooth(losses[:,1]))
ax[0].set_xlabel('Batches')
ax[0].set_ylabel('Loss')
ax[0].set_title('Model loss')
ax[0].legend(['Discriminator','Generator'])
# ax[0].set_xlim([500,2300])
# ax[0].set_ylim([-.5,6])

ax[1].plot(losses[::5,0],losses[::5,1],'k.',alpha=.1)
ax[1].set_xlabel('Discriminator loss')
ax[1].set_ylabel('Generator loss')

ax[2].plot(smooth(disDecs[:,0]))
ax[2].plot(smooth(disDecs[:,1]))
ax[2].set_xlabel('Batches')
ax[2].set_ylabel('Proportion judged "real"')
ax[2].set_title('Discriminator output')
ax[2].legend(['Real','Fake'])

plt.show()

## 8. Generating Fashion

Let's see what our generator has learned! We switch the generator to `eval()` mode, so its batch-norm layers use their running statistics, and feed it a fresh batch of noise.
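Why does `eval()` matter here? The Generator contains batch-norm layers, which normalize with the current batch's statistics in training mode but with running averages in evaluation mode. The minimal pure-Python sketch below models a single-feature batch norm without the learnable scale and shift. It assumes PyTorch's defaults (`momentum=0.1`, `eps=1e-5`, running variance updated with the unbiased batch variance); the class name `TinyBatchNorm` is illustrative, not a PyTorch API.

```python
import math

class TinyBatchNorm:
    """Single-feature batch norm without affine parameters (illustrative only)."""
    def __init__(self, momentum=0.1, eps=1e-5):
        self.momentum, self.eps = momentum, eps
        self.running_mean, self.running_var = 0.0, 1.0
        self.training = True

    def __call__(self, batch):
        if self.training:
            n = len(batch)
            mean = sum(batch) / n
            var = sum((b - mean) ** 2 for b in batch) / n   # biased: used to normalize
            unbiased = var * n / (n - 1)                     # unbiased: used for running stats
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * unbiased
        else:
            mean, var = self.running_mean, self.running_var
        return [(b - mean) / math.sqrt(var + self.eps) for b in batch]

bn = TinyBatchNorm()
print(bn([1.0, 2.0, 3.0]))   # train mode: normalizes with this batch's mean (2.0) and variance
bn.training = False
print(bn([2.0]))             # eval mode: uses running_mean (0.2) and running_var (1.0)
```

In the notebook, `gnet.eval()` switches every `nn.BatchNorm2d` layer in the Generator to the running-statistics branch, so each generated image no longer depends on the other noise vectors in the same batch.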

In [None]:
# Let's see some fake fashion!

# generate the images from the generator network
gnet.eval()
with torch.no_grad(): # no gradients are needed just to generate images
  fake_data = gnet( torch.randn(batchsize,100,1,1).to(device) ).cpu()

# and visualize...
fig,axs = plt.subplots(3,6,figsize=(12,6))
for i,ax in enumerate(axs.flatten()):
  ax.imshow(fake_data[i].squeeze(),cmap='gray')
  ax.axis('off')

plt.suptitle(', '.join(classes2keep),y=.95,fontweight='bold')
plt.show()