## Import Libraries

In [1]:
import os

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

## Set Up Data Augmentation and DataLoader

In [2]:
# Set up data augmentation
image_size = 64  # side length of the square images fed to the networks
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        # One mean/std per channel: the models in this notebook use 3-channel
        # images, so spell the statistics out instead of relying on a
        # single-element tuple being broadcast across channels.
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # Normalizing the images to [-1, 1]
    ]
)

# Load your dataset
dataset = datasets.ImageFolder(root="./training/images/", transform=transform)

# Create DataLoader
# os.cpu_count() returns None when the number of CPUs cannot be determined;
# fall back to 0 (load data in the main process) instead of passing None.
num_workers = os.cpu_count() or 0
dataloader = DataLoader(
    dataset, batch_size=32, shuffle=True, num_workers=num_workers
)

## Define the Generator and Discriminator

In [3]:
# Generator model
class Generator(nn.Module):
    """DCGAN generator built from a stack of transposed convolutions.

    The first layer expands a ``(N, latent_dim, 1, 1)`` noise tensor to a
    4x4 feature map; each of the four following layers doubles the spatial
    size, so the output is ``(N, out_channels, 64, 64)``. The final ``Tanh``
    keeps the output values in ``[-1, 1]``.

    Args:
        latent_dim: Number of channels of the input noise tensor.
        feature_maps: Base channel width; the hidden layers use 8x, 4x, 2x
            and 1x this value.
        out_channels: Number of channels of the generated image.

    The defaults (100, 64, 3) reproduce the original fixed architecture,
    including the ``main.<index>`` parameter names used in the state dict.
    """

    def __init__(self, latent_dim=100, feature_maps=64, out_channels=3):
        super().__init__()
        self.latent_dim = latent_dim
        width = feature_maps
        self.main = nn.Sequential(
            # (latent_dim, 1, 1) -> (8 * width, 4, 4)
            nn.ConvTranspose2d(latent_dim, width * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(width * 8),
            nn.ReLU(True),
            # -> (4 * width, 8, 8)
            nn.ConvTranspose2d(width * 8, width * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(width * 4),
            nn.ReLU(True),
            # -> (2 * width, 16, 16)
            nn.ConvTranspose2d(width * 4, width * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(width * 2),
            nn.ReLU(True),
            # -> (width, 32, 32)
            nn.ConvTranspose2d(width * 2, width, 4, 2, 1, bias=False),
            nn.BatchNorm2d(width),
            nn.ReLU(True),
            # -> (out_channels, 64, 64)
            nn.ConvTranspose2d(width, out_channels, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, input):
        """Run the noise tensor ``input`` through the layer stack."""
        return self.main(input)


# Discriminator model
class Discriminator(nn.Module):
    """DCGAN discriminator built from a stack of strided convolutions.

    Four stride-2 convolutions halve a 64x64 input down to 4x4, and a final
    4x4 convolution reduces it to a single value per image. The closing
    ``Sigmoid`` maps that value into ``(0, 1)``, so for a
    ``(N, in_channels, 64, 64)`` input the output is ``(N, 1, 1, 1)``.

    Args:
        in_channels: Number of channels of the input images.
        feature_maps: Base channel width; the hidden layers use 1x, 2x, 4x
            and 8x this value.

    The defaults (3, 64) reproduce the original fixed architecture,
    including the ``main.<index>`` parameter names used in the state dict.
    """

    def __init__(self, in_channels=3, feature_maps=64):
        super().__init__()
        width = feature_maps
        self.main = nn.Sequential(
            # (in_channels, 64, 64) -> (width, 32, 32); no batch norm on the first layer
            nn.Conv2d(in_channels, width, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (2 * width, 16, 16)
            nn.Conv2d(width, width * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(width * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (4 * width, 8, 8)
            nn.Conv2d(width * 2, width * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(width * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (8 * width, 4, 4)
            nn.Conv2d(width * 4, width * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(width * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (1, 1, 1)
            nn.Conv2d(width * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Return the per-image score for the image batch ``input``."""
        return self.main(input)

## Initialize Models, Loss Function, and Optimizers

In [4]:
# All tensors and both networks live on the CPU.
device = torch.device("cpu")

# Build the generator first, then the discriminator, and place them on `device`.
G, D = Generator().to(device), Discriminator().to(device)

# Binary cross-entropy between the discriminator's score and the real/fake label.
criterion = nn.BCELoss()

# Both networks use Adam with the same learning rate and betas.
lr, beta1 = 0.0002, 0.5
adam_settings = {"lr": lr, "betas": (beta1, 0.999)}
G_optimizer = optim.Adam(G.parameters(), **adam_settings)
D_optimizer = optim.Adam(D.parameters(), **adam_settings)

## Training Loop

In [5]:
# G, D, criterion, G_optimizer, D_optimizer and device are created once in the
# "Initialize Models, Loss Function, and Optimizers" cell and reused here, so
# the hyperparameters are defined in a single place. Re-running this cell
# therefore continues training the existing models; re-run the initialization
# cell first to start again from freshly initialized weights.

# Training loop
num_epochs = 5  # Number of epochs. Adjust as needed
latent_dim = 100  # channels of the noise tensor; matches the generator's first layer
log_every = 50  # print the running stats every `log_every` batches

for epoch in range(num_epochs):
    for i, (images, _) in enumerate(dataloader):
        real_images = images.to(device)
        batch_size = real_images.size(0)
        real_labels = torch.ones(batch_size, dtype=torch.float, device=device)
        fake_labels = torch.zeros(batch_size, dtype=torch.float, device=device)

        # ---- Training the Discriminator ----
        D.zero_grad()

        # Real batch: the discriminator should score these as 1.
        output = D(real_images).view(-1)
        loss_real = criterion(output, real_labels)
        loss_real.backward()
        D_x = output.mean().item()

        # Generate fake images
        noise = torch.randn(batch_size, latent_dim, 1, 1, device=device)
        fake_images = G(noise)

        # Fake batch: detach so this backward pass does not reach the generator.
        output = D(fake_images.detach()).view(-1)
        loss_fake = criterion(output, fake_labels)
        loss_fake.backward()
        D_G_z1 = output.mean().item()  # score on fakes before D is updated

        # Gradients of both losses were accumulated by the two backward calls.
        loss_D = loss_real + loss_fake
        D_optimizer.step()

        # ---- Training the Generator ----
        G.zero_grad()
        # The generator's goal is to fool the discriminator, so the fakes are
        # scored against the "real" labels using the just-updated D.
        output = D(fake_images).view(-1)
        loss_G = criterion(output, real_labels)
        loss_G.backward()
        D_G_z2 = output.mean().item()  # score on the same fakes after D's update
        G_optimizer.step()

        if i % log_every == 0:
            print(
                f"[{epoch}/{num_epochs}][{i}/{len(dataloader)}] "
                f"Loss_D: {loss_D.item():.4f} Loss_G: {loss_G.item():.4f} "
                f"D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}"
            )

    # Optionally, save the models periodically or after training
    # torch.save(G.state_dict(), f'generator_epoch_{epoch}.pth')
    # torch.save(D.state_dict(), f'discriminator_epoch_{epoch}.pth')

[0/5][0/47] Loss_D: 1.5173 Loss_G: 3.3408 D(x): 0.5431 D(G(z)): 0.5782/0.0365
[1/5][0/47] Loss_D: 0.0239 Loss_G: 16.2628 D(x): 0.9774 D(G(z)): 0.0000/0.0000
[2/5][0/47] Loss_D: 0.1311 Loss_G: 10.4292 D(x): 0.8965 D(G(z)): 0.0016/0.0001
[3/5][0/47] Loss_D: 1.0643 Loss_G: 5.3339 D(x): 0.8112 D(G(z)): 0.3830/0.0092
[4/5][0/47] Loss_D: 0.5318 Loss_G: 2.4147 D(x): 0.6941 D(G(z)): 0.0711/0.1066


### Understanding the stats of our DCGAN model training 

1. **Loss_D (Discriminator Loss)**: This value indicates how well the discriminator is distinguishing between real and fake images. A lower Loss_D suggests better performance. However, if it's too low, it might mean the discriminator is overfitting or the generator is underperforming.

2. **Loss_G (Generator Loss)**: This represents how well the generator creates images that the discriminator thinks are real. A lower Loss_G is generally better, indicating that the generator is improving.

3. **D(x)**: This is the average output for the discriminator on real images. Closer to 1 is better, as it means the discriminator is correctly identifying real images as real.

4. **D(G(z))**: This represents the discriminator's average output on fake images. The first number (before the slash) is measured during the discriminator step, before the discriminator's weights are updated; the second number (after the slash) is measured on the same fake images during the generator step, after the discriminator has been updated. Values close to 0 mean the discriminator easily detects the fakes, while values close to 1 mean the generator is fooling it. The second number is usually lower than the first because the discriminator has just been trained on that batch. In a well-balanced run, both numbers (and D(x)) tend to move towards 0.5 as generated images become harder to tell apart from real ones.

## Save the Trained Models

In [None]:
# Write each network's learned parameters (its state dict) to its own file.
for trained_model, checkpoint_path in (
    (G, "generator_final.pth"),
    (D, "discriminator_final.pth"),
):
    torch.save(trained_model.state_dict(), checkpoint_path)