In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from models import Generator, Discriminator, intialize_weights


In [14]:
# HYPERPARAMETERS
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# DETERMINES WHETHER TO USE GPU (CUDA) OR CPU FOR COMPUTATIONS
# ADVANTAGES: AUTOMATICALLY USES GPU IF AVAILABLE FOR FASTER PROCESSING

LEARNING_RATE = 2e-4
# SETS THE LEARNING RATE FOR THE OPTIMIZER
# ADVANTAGES: SMALL LEARNING RATE HELPS IN STABLE TRAINING OF GANs

BATCH_SIZE = 128
# NUMBER OF IMAGES PROCESSED IN EACH TRAINING ITERATION
# ADVANTAGES: LARGER BATCH SIZE CAN LEAD TO MORE STABLE GRADIENTS

IMAGE_SIZE = 64
# DEFINES THE DIMENSIONS OF THE IMAGES (64x64 PIXELS)
# ADVANTAGES: STANDARD SIZE FOR DCGAN, BALANCES DETAIL AND COMPUTATION

CHANNELS_IMG = 1
# NUMBER OF COLOR CHANNELS (1 FOR GRAYSCALE, 3 FOR RGB)
# ADVANTAGES: ALLOWS FLEXIBILITY IN INPUT IMAGE TYPE

Z_DIM = 100
# DIMENSION OF THE LATENT SPACE VECTOR
# ADVANTAGES: PROVIDES ENOUGH VARIABILITY FOR GENERATOR INPUT

NUM_EPOCHS = 5
# NUMBER OF TIMES TO ITERATE THROUGH THE ENTIRE DATASET
# ADVANTAGES: ALLOWS FOR MULTIPLE PASSES OVER DATA FOR BETTER LEARNING

FEATURES_DISC = 64
FEATURES_GEN = 64
# STARTING NUMBER OF FEATURES FOR DISCRIMINATOR AND GENERATOR
# ADVANTAGES: ALLOWS FOR EASY SCALING OF NETWORK COMPLEXITY

NOISE_DIM = 100
# DIMENSION OF NOISE VECTOR (SAME AS Z_DIM)
# ADVANTAGES: CONSISTENCY IN GENERATOR INPUT

# TRANSFORMS
from torchvision import transforms
# IMPORT THE TRANSFORMS MODULE FROM TORCHVISION

transform = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        # RESIZES INPUT IMAGES TO THE SPECIFIED IMAGE_SIZE
        transforms.ToTensor(),
        # CONVERTS IMAGES TO PYTORCH TENSORS
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        )
        # NORMALIZES PIXEL VALUES TO RANGE [-1, 1]
        # ADVANTAGES: HELPS IN FASTER CONVERGENCE DURING TRAINING
    ]
)
# COMPOSE COMBINES MULTIPLE TRANSFORMS INTO A SINGLE TRANSFORM
# ADVANTAGES: ALLOWS FOR EASY APPLICATION OF MULTIPLE PREPROCESSING STEPS

In [15]:
# DATASET
# transform = transforms.Compose([
#     transforms.Resize(image_dim),
#     transforms.ToTensor(),
#     transforms.Normalize((0.5,), (0.5,))
# ])

dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms, download=False)

# dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)

intialize_weights(gen)
intialize_weights(disc)



In [16]:
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

criterion = nn.BCELoss()

fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

In [17]:
gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        noise = torch.randn(BATCH_SIZE, Z_DIM, 1, 1).to(device)
        fake = gen(noise)
        
        # TRAIN DISCRIMINATOR: MAX log(D(x)) + log(1 - D(G(z)))

        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()

        loss_disc.backward()
        opt_disc.step()
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()

        opt_gen.step()  
        

TypeError: 'module' object is not callable