In [21]:
import os
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
import torchvision
import torchinfo
from torch.utils.tensorboard import SummaryWriter

import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets


#### Extracted Information From the Paper

![image](https://pytorch.org/tutorials/_images/dcgan_generator.png)

The paper uses a kernel size of 5, but with stride = 2 the dimensions do not align for the discriminator (halving each dimension would require a fractional padding of 3/2).

=> Thus, a kernel size of 4 is used instead.

---

### 1. DCGAN Generator (from Figure)

- **Input**: 100-dimensional uniform distribution (Z)
- **Layers**: A series of four fractionally-strided convolutions  
  *(Note: These are sometimes mistakenly referred to as deconvolutions in recent papers)*  
  *[Configuration as per figure: stride = 2, kernel size = 4; use the formulas to calculate the padding required]* 
- **Output**: Converts the high-level representation into a 64 × 64 pixel image.


### 2. Architecture Guidelines for Stable Deep Convolutional GANs

- **Replace pooling layers**:
  - Use **strided convolutions** in the discriminator.
  - Use **fractional-strided convolutions** in the generator.  
  *(The discriminator mirrors the generator.)*
  
- **Batch normalization**:
  - Apply **batchnorm** in both the generator and discriminator.
  - DO NOT APPLY batchnorm to the generator output layer and the discriminator input layer.
  
- **Deeper architectures**:
  - Remove **fully connected hidden layers**.

- **Activation functions**:
  - Use **ReLU** activation in the generator for all layers, except the output, which uses **Tanh**.
  - Use **LeakyReLU** activation in the discriminator for all layers.  
    *(Leak slope set to 0.2 in all models.)*


### 3. Adversarial Training

- **Image preprocessing**:  
  - Resize image dimensions to 64 by 64
  - No preprocessing other than scaling to the range of the Tanh activation function $[-1, 1]$.  
    *(To match the Generator's Tanh output.)*
  
- **Batch size**: 128

- **Weight initialization**:  
  - Weights are initialized from a **zero-centered Normal Distribution** with a standard deviation of 0.02.

- **Optimizer**:  
  - **Adam Optimizer** with learning rate $lr = 0.0002$ and momentum term $\beta_1 = 0.5$

---


### 1. Create DCGAN's Generator and Discriminator Classes

#### 1.1 Generator Class

In [2]:
# Fractional Strided Convolutional Layers for the Generator
# - Configuration as per figure: stride = 2, kernel size = 4
# - Use the formula to calculate the padding required:
#       H[out] = (H[in] - 1) * stride - 2 * padding + kernel_size + output_padding
#   => output_padding - 2*padding = -2 for all cases (so that H[out] = 2 * H[in])
#   => Let output_padding = 0, padding = 1
#   => This config doubles the image dims (img_height and img_width) at each convolutional layer

# Generator:
# 1. Input: 100-dimensional uniform distribution (Z)
# 2. Projection layer: to 1024*4*4 (to be reshaped before sending to f-s convolutional layers)
# 3. A series of four fractionally-strided convolutions  (stride = 2, kernel size = 4) 
#   3.1 f-s conv: Output Channels = 512, img_dim from 4*4 to 8*8
#   3.2 f-s conv: Output Channels = 256, img_dim from 8*8 to 16*16
#   3.3 f-s conv: Output Channels = 128, img_dim from 16*16 to 32*32
#   3.4. [Output]: f-s conv: Output Channels = 3, img_dim from 32*32 to 64*64
# *Batch Norm to be applied (except last layer of generator)
# *ReLU all layers (except Output layer)
# *TanH for output layer

In [3]:
class Generator(nn.Module):
    """DCGAN generator built only from transposed (fractionally-strided) convolutions.

    Args:
        noise_channels: number of channels of the latent input ``z``.
        img_channels: number of channels of the generated image.

    For a latent of spatial size 1x1 the layer arithmetic (kernel 4, stride 2)
    yields 4x4 after the first stage (padding 0) and doubles at every later
    stage (padding 1): 8x8, 16x16, 32x32 and finally 64x64. The final Tanh
    keeps every output value inside [-1, 1].
    """

    def __init__(self, noise_channels=100, img_channels=3):
        super().__init__()

        # Projection stage: z (noise_channels x 1 x 1) -> 1024 x 4 x 4
        project = self._block(in_channels=noise_channels, out_channels=1024, kernel_size=4, stride=2, padding=0)

        # Three upsampling stages, each halving the channels and doubling H and W
        upsample_to_8 = self._block(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1)
        upsample_to_16 = self._block(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1)
        upsample_to_32 = self._block(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1)

        # Output stage: no batch norm here, and Tanh instead of ReLU
        to_image = nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1)

        self.conv_layers = nn.Sequential(
            project,
            upsample_to_8,
            upsample_to_16,
            upsample_to_32,
            to_image,
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding, batch_norm=True):
        """Return ConvTranspose2d -> [BatchNorm2d] -> ReLU wrapped in an nn.Sequential.

        The BatchNorm2d layer is included only when ``batch_norm`` is true.
        """
        stage = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        ]
        if batch_norm:
            stage.append(nn.BatchNorm2d(out_channels))
        stage.append(nn.ReLU())
        return nn.Sequential(*stage)

    def forward(self, z):
        """Run the latent tensor ``z`` through every stage and return the generated images."""
        return self.conv_layers(z)

In [4]:
# Check Generator
batch_size = 32
z_dim = 100

# The generator consumes z as a (batch, z_dim, 1, 1) feature map. Reshape with z_dim
# (not a hard-coded 100) so this check stays valid when z_dim is changed.
z_dummy = torch.randn(size=(batch_size, z_dim))
z_dummy = z_dummy.view(-1, z_dim, 1, 1)
print(z_dummy.shape)

generator = Generator(img_channels=3, noise_channels=z_dim)
torchinfo.summary(model=generator, input_size=[z_dummy.shape]) # ensure output shape = [batch size * 3 * 64 * 64]

torch.Size([32, 100, 1, 1])


Layer (type:depth-idx)                   Output Shape              Param #
Generator                                [32, 3, 64, 64]           --
├─Sequential: 1-1                        [32, 3, 64, 64]           --
│    └─Sequential: 2-1                   [32, 1024, 4, 4]          --
│    │    └─ConvTranspose2d: 3-1         [32, 1024, 4, 4]          1,639,424
│    │    └─BatchNorm2d: 3-2             [32, 1024, 4, 4]          2,048
│    │    └─ReLU: 3-3                    [32, 1024, 4, 4]          --
│    └─Sequential: 2-2                   [32, 512, 8, 8]           --
│    │    └─ConvTranspose2d: 3-4         [32, 512, 8, 8]           8,389,120
│    │    └─BatchNorm2d: 3-5             [32, 512, 8, 8]           1,024
│    │    └─ReLU: 3-6                    [32, 512, 8, 8]           --
│    └─Sequential: 2-3                   [32, 256, 16, 16]         --
│    │    └─ConvTranspose2d: 3-7         [32, 256, 16, 16]         2,097,408
│    │    └─BatchNorm2d: 3-8             [32, 256, 16, 16]

#### 1.2 Discriminator Class

In [5]:
# Strided Convolutional Layer for the Discriminator
# - Configuration as per figure: stride = 2, kernel size = 4
# - Use the formula to calculate the padding required:
#       H[out] = floor( (H[in] - kernel_size + 2 * padding) / stride ) + 1

#   => padding = 1 for all cases 
#   => Let padding = 1
#   => This config halves the image dims (img_height and img_width) at each convolutional layer

# Discriminator: (the mirror image of the generator, with strided convolutional layers instead of fractionally-strided convolutional layers)
# 1. Input: (3 by 64 by 64) images
# 2. A series of four strided convolutions  (stride = 2, kernel size = 4) 
#   2.1 strided conv: Output Channels = 128, img_dim from 64*64 to 32*32
#   2.2 strided conv: Output Channels = 256, img_dim from 32*32 to 16*16 
#   2.3 strided conv: Output Channels = 512, img_dim from 16*16 to 8*8
#   2.4. [Output]: strided conv: Output Channels = 1024, img_dim from 8*8 to 4*4
# *Batch Norm to be applied (except first layer for the discriminator)
# *LeakyReLU all layers, slope set to 0.2

In [6]:
class Discriminator(nn.Module):
    """DCGAN discriminator built only from strided convolutions.

    Args:
        img_channels: number of channels of the input images.

    With kernel 4, stride 2 and padding 1 every feature stage halves H and W
    (64 -> 32 -> 16 -> 8 -> 4 for a 64x64 input). The last convolution
    (padding 0) collapses the 4x4 map to 1x1 and a Sigmoid maps it into [0, 1].
    """

    def __init__(self, img_channels=3):
        super().__init__()

        # Input stage: strided convolution WITHOUT batch norm (64*64 -> 32*32)
        stem = self._block(in_channels=img_channels, out_channels=128, kernel_size=4, stride=2, padding=1, batch_norm=False)

        # Strided convolution stages with batch norm (32*32 -> 16*16 -> 8*8 -> 4*4)
        down_to_16 = self._block(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1)
        down_to_8 = self._block(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1)
        down_to_4 = self._block(in_channels=512, out_channels=1024, kernel_size=4, stride=2, padding=1)

        # Classifier: DCGAN has no fully connected layer, so a final convolution with
        # kernel 4 and padding 0 reduces the 4x4 map to a single value per image.
        score = nn.Conv2d(1024, 1, kernel_size=4, stride=2, padding=0)

        self.conv_layers = nn.Sequential(
            stem,
            down_to_16,
            down_to_8,
            down_to_4,
            score,
            nn.Sigmoid(),  # keeps the prediction within [0, 1]
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding, batch_norm=True):
        """Return Conv2d -> [BatchNorm2d] -> LeakyReLU(0.2) wrapped in an nn.Sequential.

        The BatchNorm2d layer is included only when ``batch_norm`` is true.
        """
        stage = [nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)]
        if batch_norm:
            stage.append(nn.BatchNorm2d(out_channels))
        stage.append(nn.LeakyReLU(negative_slope=0.2))
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the discriminator output for the image batch ``x``."""
        return self.conv_layers(x)

In [7]:
# Check Discriminator
batch_size = 32
img_dim = 64
img_channels = 3

# Build the dummy batch and the model from the same img_channels value so the two
# cannot drift apart (previously img_channels was assigned but 3 was hard-coded).
x_dummy = torch.randn(size=(batch_size, img_channels, img_dim, img_dim))
print(x_dummy.shape)

discriminator = Discriminator(img_channels=img_channels)
torchinfo.summary(model=discriminator, input_size=[x_dummy.shape]) # ensure output shape = [batch size * 1 * 1 * 1]

torch.Size([32, 3, 64, 64])


Layer (type:depth-idx)                   Output Shape              Param #
Discriminator                            [32, 1, 1, 1]             --
├─Sequential: 1-1                        [32, 1, 1, 1]             --
│    └─Sequential: 2-1                   [32, 128, 32, 32]         --
│    │    └─Conv2d: 3-1                  [32, 128, 32, 32]         6,272
│    │    └─LeakyReLU: 3-2               [32, 128, 32, 32]         --
│    └─Sequential: 2-2                   [32, 256, 16, 16]         --
│    │    └─Conv2d: 3-3                  [32, 256, 16, 16]         524,544
│    │    └─BatchNorm2d: 3-4             [32, 256, 16, 16]         512
│    │    └─LeakyReLU: 3-5               [32, 256, 16, 16]         --
│    └─Sequential: 2-3                   [32, 512, 8, 8]           --
│    │    └─Conv2d: 3-6                  [32, 512, 8, 8]           2,097,664
│    │    └─BatchNorm2d: 3-7             [32, 512, 8, 8]           1,024
│    │    └─LeakyReLU: 3-8               [32, 512, 8, 8]          

### 2. Instantiate the Models

In [8]:
# Convolution weights are initialized from a Normal Distribution with mean = 0; standard deviation = 0.02.
# Batch-norm scales are drawn around 1 (not 0) so that normalised activations are not squashed to ~0.
def initialize_weights(model):
    """Re-initialise the conv and batch-norm parameters of ``model`` in place.

    * ``nn.Conv2d`` / ``nn.ConvTranspose2d`` weights: N(mean=0, std=0.02).
    * ``nn.BatchNorm2d`` weights (the per-channel scale): N(mean=1, std=0.02),
      with the bias set to 0. A scale centred on 0 would multiply every
      normalised activation by a value close to 0.
    * Batch-norm layers created with ``affine=False`` have no weight/bias and
      are skipped.

    Convolution biases are left untouched. Returns ``None``.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
        elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
            nn.init.normal_(m.weight, mean=1.0, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)

In [9]:
def test():
    """Smoke-test both networks by checking their output shapes on random input.

    Raises:
        AssertionError: if either network returns an unexpected shape.
    """
    n_samples, channels, height, width = 8, 3, 64, 64

    # Discriminator: one value per image, shaped (N, 1, 1, 1)
    images = torch.randn((n_samples, channels, height, width))
    disc = Discriminator(img_channels=3)
    scores = disc(images)
    assert scores.shape == (n_samples, 1, 1, 1), "Discriminator test failed"

    # Generator: latent (N, z_dim, 1, 1) -> image batch (N, C, H, W)
    z_dim = 100
    latent = torch.randn((n_samples, z_dim, 1, 1))
    gen = Generator(noise_channels=z_dim, img_channels=3)
    generated = gen(latent)
    assert generated.shape == (n_samples, channels, height, width), "Generator test failed"

    print("Success, tests passed!")

In [10]:
# Run the shape smoke test; it prints a success message when both assertions hold.
test()

Success, tests passed!


### 3. Training

In [11]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [12]:
# HyperParameters

# Number of full passes over the dataset
epochs = 10

# --- Values following the DCGAN paper ---
# Optimisation: Adam learning rate and first-moment coefficient (beta_1)
LEARNING_RATE = 0.0002
B1 = 0.5
BATCH_SIZE = 128

# Data / model dimensions
Z_DIM = 100
IMG_SIZE = 64
IMG_CHANNELS = 3  # can be changed to suit the images (the DCGAN paper uses 3-channel inputs)

In [13]:
# Create the Transformations
# Ensure:
# 1. Image is resized to 64*64
# 2. Ensure that Input Images are normalised such that they are within [-1, 1] (to follow generator's tanh output of [-1, 1])
#    (ToTensor scales to [0, 1]; subtracting 0.5 and dividing by 0.5 per channel maps that to [-1, 1])
#
# NOTE: this cell binds the name `transforms` to the Compose pipeline (downstream cells use it),
# which shadows the `torchvision.transforms` module alias imported at the top. The pipeline is
# therefore built through `torchvision.transforms.*` so that re-running this cell keeps working
# instead of failing with AttributeError on the already-rebound name.

transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            [0.5 for _ in range(IMG_CHANNELS)],
            [0.5 for _ in range(IMG_CHANNELS)]
        )
    ]
)

In [15]:
# Load the data
# Create the dataset
# The images are expected under <cwd>/../../data/celeb_A, laid out for ImageFolder.
data_dir = Path.cwd().parent.parent / "data"
image_folder_name = "celeb_A"

dataset = torchvision.datasets.ImageFolder(root=data_dir/image_folder_name, transform=transforms)

# Create the dataloader
# os.cpu_count() returns None when the CPU count cannot be determined; DataLoader needs an
# integer, so fall back to 0 (load batches in the main process) in that case.
NUM_WORKERS = os.cpu_count() or 0

dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    shuffle=True,
    pin_memory=True
)

In [16]:
# x_sample = x[0]
# x_sample = x_sample.permute(1, 2, 0).numpy()
# plt.imshow(x_sample)
# x, y = next(iter(dataloader))

In [50]:
# Initialise the Models

# Build both networks and move them to the selected device
generator = Generator(img_channels=IMG_CHANNELS, noise_channels=Z_DIM).to(device)
discriminator = Discriminator(img_channels=IMG_CHANNELS).to(device)

# Apply the custom weight initialisation, generator first and then discriminator
for network in (generator, discriminator):
    initialize_weights(network)

In [51]:
# Create loss function and optimizer

# Binary cross-entropy between the discriminator output and the real/fake targets
criterion = nn.BCELoss()

# One Adam optimizer per network, both with lr = LEARNING_RATE and beta_1 = B1
# (beta_2 is written out at 0.999, the PyTorch default)
optimizer_D = optim.Adam(
    discriminator.parameters(),
    lr=LEARNING_RATE,
    betas=(B1, 0.999),
)
optimizer_G = optim.Adam(
    generator.parameters(),
    lr=LEARNING_RATE,
    betas=(B1, 0.999),
)

In [52]:
# A fixed latent batch, reused at every logging step so the fake-image grids are comparable over time
fixed_noise = torch.randn(size=(BATCH_SIZE, Z_DIM, 1, 1)).to(device)

# Separate TensorBoard writers for the real and the generated image grids
writer_real = SummaryWriter("runs/MNIST/real")
writer_fake = SummaryWriter("runs/MNIST/fake")

# Global step used for the TensorBoard image summaries
step = 0

In [53]:
# Adversarial training loop: for every batch, update the discriminator once and then the generator once.
for epoch in tqdm(range(epochs)):

    for batch_idx, (x, _) in enumerate(dataloader):

        x = x.to(device)
        noise = torch.randn(size = (x.shape[0], Z_DIM, 1, 1)).to(device)

        g_z = generator(noise) # G(z)

        ### Train the Discriminator: Min -(log(D(x)) + log(1-D(G(Z)))) <---> Max log(D(x)) + log(1-D(G(Z)))
        # The fake batch is detached: the discriminator update does not need gradients with
        # respect to the generator, so backward() stops at the images instead of traversing
        # (and having to retain) the generator's graph.
        d_x = discriminator(x).reshape(-1) # D(x), reshape from 1*1*1 to 1
        d_g_z = discriminator(g_z.detach()).reshape(-1) # D(G(z)), reshape from 1*1*1 to 1

        loss_real_D = criterion(d_x, torch.ones_like(d_x)) # -log(D(X))
        loss_fake_D = criterion(d_g_z, torch.zeros_like(d_g_z)) # -log(1-D(G(z)))
        loss_D = loss_fake_D + loss_real_D #-(log(D(x)) + log(1-D(G(Z))))

        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        ### Train the Generator: Min -log(D(G(z))) <---> Max log(D(G(z)))
        # Fresh forward pass through the just-updated discriminator, this time WITHOUT detaching,
        # so the gradient flows back into the generator. Targets are "real" (ones).
        d_g_z_next = discriminator(g_z).reshape(-1) # after training the disc, new D(G(z)), reshape from 1*1*1 to 1
        loss_G = criterion(d_g_z_next, torch.ones_like(d_g_z_next)) # -log(D(G(z)))

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{epochs}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {loss_D:.4f}, loss G: {loss_G:.4f}"
            )

            with torch.no_grad():
                fake = generator(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(x[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

# Write any buffered image summaries to disk so TensorBoard shows the last logged step.
writer_real.flush()
writer_fake.flush()

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch [0/10] Batch 0/469                   Loss D: 1.4009, loss G: 0.9226
Epoch [0/10] Batch 100/469                   Loss D: 0.0090, loss G: 5.4320
Epoch [0/10] Batch 200/469                   Loss D: 0.0425, loss G: 4.1582
Epoch [0/10] Batch 300/469                   Loss D: 0.9446, loss G: 2.1268
Epoch [0/10] Batch 400/469                   Loss D: 1.0685, loss G: 1.3641


 10%|█         | 1/10 [03:49<34:24, 229.35s/it]

Epoch [1/10] Batch 0/469                   Loss D: 0.9675, loss G: 2.0899
Epoch [1/10] Batch 100/469                   Loss D: 0.9081, loss G: 1.0731
Epoch [1/10] Batch 200/469                   Loss D: 0.8497, loss G: 1.0624
Epoch [1/10] Batch 300/469                   Loss D: 0.9271, loss G: 0.9187
Epoch [1/10] Batch 400/469                   Loss D: 0.6022, loss G: 3.5060


 20%|██        | 2/10 [07:38<30:32, 229.10s/it]

Epoch [2/10] Batch 0/469                   Loss D: 0.7099, loss G: 3.3171
Epoch [2/10] Batch 100/469                   Loss D: 1.1726, loss G: 3.9874
Epoch [2/10] Batch 200/469                   Loss D: 0.1685, loss G: 3.1225
Epoch [2/10] Batch 300/469                   Loss D: 0.3734, loss G: 3.8617
Epoch [2/10] Batch 400/469                   Loss D: 0.4614, loss G: 0.8481


 30%|███       | 3/10 [11:26<26:41, 228.84s/it]

Epoch [3/10] Batch 0/469                   Loss D: 0.4336, loss G: 2.6499
Epoch [3/10] Batch 100/469                   Loss D: 0.5677, loss G: 3.1196
Epoch [3/10] Batch 200/469                   Loss D: 0.5899, loss G: 5.7335
Epoch [3/10] Batch 300/469                   Loss D: 1.3081, loss G: 1.6189
Epoch [3/10] Batch 400/469                   Loss D: 0.1410, loss G: 3.2208


 40%|████      | 4/10 [15:15<22:53, 228.92s/it]

Epoch [4/10] Batch 0/469                   Loss D: 0.3334, loss G: 3.2590
Epoch [4/10] Batch 100/469                   Loss D: 0.8403, loss G: 2.9569
Epoch [4/10] Batch 200/469                   Loss D: 0.2083, loss G: 3.1357
Epoch [4/10] Batch 300/469                   Loss D: 0.7014, loss G: 1.7803
Epoch [4/10] Batch 400/469                   Loss D: 0.2447, loss G: 2.9042


 50%|█████     | 5/10 [19:03<19:02, 228.44s/it]

Epoch [5/10] Batch 0/469                   Loss D: 1.1111, loss G: 1.1240
Epoch [5/10] Batch 100/469                   Loss D: 0.1598, loss G: 4.0268
Epoch [5/10] Batch 200/469                   Loss D: 2.8133, loss G: 10.2775
Epoch [5/10] Batch 300/469                   Loss D: 0.1051, loss G: 3.9182
Epoch [5/10] Batch 400/469                   Loss D: 0.1281, loss G: 3.3564


 60%|██████    | 6/10 [22:52<15:14, 228.63s/it]

Epoch [6/10] Batch 0/469                   Loss D: 0.1476, loss G: 3.1856
Epoch [6/10] Batch 100/469                   Loss D: 0.1039, loss G: 4.3778
Epoch [6/10] Batch 200/469                   Loss D: 0.1845, loss G: 2.8277
Epoch [6/10] Batch 300/469                   Loss D: 0.2607, loss G: 3.8172
Epoch [6/10] Batch 400/469                   Loss D: 0.1992, loss G: 3.4054


 70%|███████   | 7/10 [26:41<11:26, 228.67s/it]

Epoch [7/10] Batch 0/469                   Loss D: 0.0659, loss G: 4.5583
Epoch [7/10] Batch 100/469                   Loss D: 0.5290, loss G: 4.0268
Epoch [7/10] Batch 200/469                   Loss D: 0.8195, loss G: 1.6489
Epoch [7/10] Batch 300/469                   Loss D: 0.0881, loss G: 3.1335
Epoch [7/10] Batch 400/469                   Loss D: 0.2595, loss G: 4.1144


 80%|████████  | 8/10 [30:29<07:37, 228.70s/it]

Epoch [8/10] Batch 0/469                   Loss D: 1.3147, loss G: 1.7915
Epoch [8/10] Batch 100/469                   Loss D: 0.1230, loss G: 3.2608
Epoch [8/10] Batch 200/469                   Loss D: 0.0569, loss G: 4.2130
Epoch [8/10] Batch 300/469                   Loss D: 0.3214, loss G: 3.2605
Epoch [8/10] Batch 400/469                   Loss D: 0.1070, loss G: 3.8579


 90%|█████████ | 9/10 [34:19<03:48, 228.87s/it]

Epoch [9/10] Batch 0/469                   Loss D: 0.0652, loss G: 3.5935
Epoch [9/10] Batch 100/469                   Loss D: 0.2918, loss G: 2.6131
Epoch [9/10] Batch 200/469                   Loss D: 0.5380, loss G: 1.7338
Epoch [9/10] Batch 300/469                   Loss D: 2.5044, loss G: 0.6034
Epoch [9/10] Batch 400/469                   Loss D: 0.1918, loss G: 3.3484


100%|██████████| 10/10 [38:08<00:00, 228.84s/it]


Archives

In [54]:
class Generator(nn.Module):
    """Archived DCGAN generator that keeps the latent projection as its own sub-module.

    Args:
        noise_channels: number of channels of the latent input ``z``.
        img_channels: number of channels of the generated image.

    ``projection`` turns a (noise_channels x 1 x 1) latent into a 1024 x 4 x 4
    map (kernel 4, stride 2, padding 0); ``conv_layers`` then doubles H and W
    four times (padding 1) and ends with Tanh.
    """

    def __init__(self, noise_channels=100, img_channels=3):
        super().__init__()

        # Projection done with a transposed convolution instead of a fully connected layer:
        # z is fed as noise_channels x 1 x 1 and comes out as 1024 x 4 x 4.
        self.projection = nn.Sequential(
            nn.ConvTranspose2d(noise_channels, 1024, kernel_size=4, stride=2, padding=0),
            nn.BatchNorm2d(1024),
            nn.ReLU(),
        )

        # Fractionally-strided stages 4*4 -> 8*8 -> 16*16 -> 32*32, each followed by
        # batch norm and ReLU. Kept as one flat Sequential.
        stages = []
        for wide, narrow in ((1024, 512), (512, 256), (256, 128)):
            stages.extend(
                [
                    nn.ConvTranspose2d(wide, narrow, kernel_size=4, stride=2, padding=1),
                    nn.BatchNorm2d(narrow),
                    nn.ReLU(),
                ]
            )

        # Output stage (32*32 -> 64*64): no batch norm, Tanh activation
        stages.append(nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1))
        stages.append(nn.Tanh())

        self.conv_layers = nn.Sequential(*stages)

    def forward(self, z):
        """Project ``z`` and upsample the result to an image batch."""
        projected = self.projection(z)  # z_dim*1*1 -> 1024*4*4
        return self.conv_layers(projected)
    

#### DEPRECATED (SUPERSEDED TO ENSURE NO FULLY CONNECTED LAYER, EVEN FOR THE PROJECTION)
# class Generator(nn.Module):

#     def __init__(self, z_dim=100):
#         super().__init__()

#         self.projection = nn.Sequential(
#             nn.Linear(in_features=z_dim, out_features=1024*4*4),
#             nn.BatchNorm1d(num_features=1024*4*4),
#             nn.ReLU()
#         )

#         self.conv_layers = nn.Sequential(
            
#             # 1st fractional strided convolution layer (upsample from 4*4 -> 8*8)
#             nn.ConvTranspose2d(in_channels=1024, out_channels=512, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm2d(512),
#             nn.ReLU(),

#             # 2nd fractional strided convolution layer (upsample from 8*8 -> 16*16)
#             nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm2d(256),
#             nn.ReLU(),
            
#             # 3rd fractional strided convolution layer (upsample from 16*16 -> 32*32)
#             nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),
#             nn.BatchNorm2d(128),
#             nn.ReLU(),

#             # Output fractional strided convolution layer (upsample from 32*32 -> 64*64)
#             nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1),
#             nn.Tanh()
#         )
    
#     def forward(self, z):
#         z_projected = self.projection(z) # project each z into a linear vector 1 by 1024*4*4 (output shape = batch_size by 1024*4*4)
#         z_projected_reshaped = z_projected.view(-1, 1024, 4, 4) # to reshape the projection after feeding into the convolutional layers (retain batch size)
#         return self.conv_layers(z_projected_reshaped)
    

# batch_size = 32
# z_dim = 100
# fixed_z = torch.randn(batch_size, z_dim)

# generator = Generator(z_dim=z_dim)

# output = generator(fixed_z)

# print(output.shape) # ensure [batch size by 3 by 64 by 64]

In [55]:
class Discriminator(nn.Module):
    """Archived DCGAN discriminator with the classifier kept as its own sub-module.

    Args:
        img_channels: number of channels of the input images.

    ``conv_layers`` halves H and W four times (kernel 4, stride 2, padding 1);
    ``classifier`` reduces the resulting 4x4 map to 1x1 with a final
    convolution (padding 0) followed by Sigmoid.
    """

    def __init__(self, img_channels=3):
        super().__init__()

        # Input stage: strided convolution without batch norm (64*64 -> 32*32)
        features = [
            nn.Conv2d(img_channels, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2),
        ]

        # Strided convolution stages with batch norm (32*32 -> 16*16 -> 8*8 -> 4*4).
        # Kept as one flat Sequential.
        for narrow, wide in ((128, 256), (256, 512), (512, 1024)):
            features.extend(
                [
                    nn.Conv2d(narrow, wide, kernel_size=4, stride=2, padding=1),
                    nn.BatchNorm2d(wide),
                    nn.LeakyReLU(negative_slope=0.2),
                ]
            )

        self.conv_layers = nn.Sequential(*features)

        # No fully connected layer for DCGAN: a convolution with kernel 4 and padding 0
        # takes the 4x4 map down to 1x1 (instead of nn.Flatten() + nn.Linear).
        self.classifier = nn.Sequential(
            nn.Conv2d(1024, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),  # keeps the prediction within [0, 1]
        )

    def forward(self, x):
        """Extract features from ``x`` and return the classifier output."""
        feature_map = self.conv_layers(x)
        return self.classifier(feature_map)