# DCGAN Training

#### Libraries

In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from dcgan import Discriminator, Generator, initialize_weights

In [17]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

#### Hyperparameters

In [18]:
learning_rate = 3e-4
batch_size = 128
image_size = 64
img_channels = 1
z_dim = 100
epochs = 5
discriminator_features = 64
generator_features = 64

#### Transforms

In [25]:
# Resize to a square image_size x image_size, convert to a tensor, then normalize
# every channel with mean 0.5 and std 0.5. This maps ToTensor's [0, 1] output
# to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * img_channels, [0.5] * img_channels),
])

#### Dataset

In [26]:
# FashionMNIST training split, stored under "dataset/" relative to the working
# directory; download=True fetches it if it is not already there.
dataset = datasets.FashionMNIST(
    root="dataset/",
    train=True,
    download=True,
    transform=transform,
)

In [27]:
# Reshuffle each epoch. drop_last keeps its default (False), so the final batch
# of an epoch may hold fewer than batch_size samples.
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

#### Initializing the generator and discriminator

In [22]:
# Build both networks and move them to the selected device. The constructor
# argument order is defined in dcgan.py (not visible here). Presumably it is
# (latent dim, image channels, feature width) for Generator and
# (image channels, feature width) for Discriminator — confirm against dcgan.py.
generator = Generator(z_dim, img_channels, generator_features).to(device)
discriminator = Discriminator(img_channels, discriminator_features).to(device)

# Apply dcgan.initialize_weights to each network. Its return value is ignored,
# so it presumably modifies the modules in place.
# NOTE(review): if the initialization is random, keep this call order; swapping it
# would change the drawn weights under a fixed seed.
initialize_weights(generator)
initialize_weights(discriminator)

#### Training loop

In [30]:
def train(
    generator: nn.Module,
    discriminator: nn.Module,
    loader: DataLoader,
    batch_size: int,
    learning_rate: float,
    epochs: int,
    device: torch.device,
    z_dim: int = 100,
) -> None:
    """Train a DCGAN generator/discriminator pair with the standard BCE objective.

    Each iteration first updates the discriminator on a real batch and a batch
    of generated images. It then updates the generator so that the freshly
    updated discriminator labels the generated images as real.

    Args:
        generator: Maps noise of shape ``(N, z_dim, 1, 1)`` to images.
        discriminator: Maps images to probabilities. ``nn.BCELoss`` is used, so
            its outputs must already lie in ``[0, 1]``.
        loader: Iterable of ``(images, labels)`` batches; labels are ignored.
        batch_size: Kept only for backward compatibility. The noise batch size
            follows each real batch, so a short final batch is handled correctly.
        learning_rate: Adam learning rate for both networks.
        epochs: Number of passes over ``loader``.
        device: Device the real batches and the noise are placed on.
        z_dim: Length of the latent noise vector. Defaults to 100, the value
            used in the notebook's hyperparameter cell.
    """
    # betas=(0.5, 0.999) are the Adam settings recommended for DCGAN.
    opt_generator = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    opt_discriminator = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
    criterion = nn.BCELoss()

    generator.train()
    discriminator.train()

    for _ in range(epochs):
        for real, _labels in loader:
            real = real.to(device)
            # Size the noise from the actual batch: a DataLoader without
            # drop_last can yield a final batch smaller than batch_size.
            noise = torch.randn(real.size(0), z_dim, 1, 1, device=device)
            fake = generator(noise)

            # --- Discriminator step: real -> 1, fake -> 0 ---
            disc_real = discriminator(real).reshape(-1)
            loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
            # detach() keeps this step from backpropagating into the generator.
            disc_fake = discriminator(fake.detach()).reshape(-1)
            loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            loss_disc = (loss_disc_real + loss_disc_fake) / 2

            opt_discriminator.zero_grad()
            loss_disc.backward()
            opt_discriminator.step()

            # --- Generator step: make the discriminator output 1 on fakes ---
            # Re-run the discriminator rather than reusing disc_fake. Its weights
            # were just updated in place, so backpropagating through the old graph
            # would raise an autograd in-place-modification error and would score
            # the fakes with stale weights.
            output = discriminator(fake).reshape(-1)
            loss_gen = criterion(output, torch.ones_like(output))
            opt_generator.zero_grad()
            loss_gen.backward()
            opt_generator.step()

In [None]:
# Run the full training loop with the objects and hyperparameters defined above.
train(
    generator=generator,
    discriminator=discriminator,
    loader=loader,
    batch_size=batch_size,
    learning_rate=learning_rate,
    epochs=epochs,
    device=device,
)