In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Generator,Critic,initialize_weights 
from utils import gradient_penalty


### HYPERPARAMETERS

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed torch's RNGs (weight init, noise, DataLoader shuffling) for more repeatable
# runs; CUDA kernels may still introduce some nondeterminism.
SEED = 42
torch.manual_seed(SEED)

LEARNING_RATE = 1e-4
BATCH_SIZE = 64
IMAGE_SIZE = 64
# The data is loaded with datasets.ImageFolder, whose default loader converts every
# image to RGB, so images always have 3 channels; the models and Normalize must agree.
IMG_CHANNELS = 3
Z_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GENE = 64
CRITIC_ITERATION = 5  # critic updates per generator update
# WEIGTH_CLIP = 0.01  # only for plain WGAN (weight clipping instead of gradient penalty)
LAMBDA_GP = 10  # weight of the gradient-penalty term in the critic loss

transforms = transforms.Compose(
    [
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(IMG_CHANNELS)], [0.5 for _ in range(IMG_CHANNELS)]
        )
    ]
)




In [None]:
dataset = datasets.ImageFolder(root = "celeb_dataset",transforms=transforms)
loader = DataLoader(dataset, batch_size= BATCH_SIZE, shuffle=True)

gene = Generator(Z_DIM,IMG_CHANNELS,FEATURES_GENE)
critic = Critic(IMG_CHANNELS,FEATURES_DISC)
initialize_weights(gene)
initialize_weights(critic)

In [None]:
#======== Optimizers ========
# Plain WGAN (weight clipping, no gradient penalty) uses RMSprop instead:
# opt_gene = optim.RMSprop(gene.parameters(), lr=LEARNING_RATE)
# opt_critic = optim.RMSprop(critic.parameters(), lr=LEARNING_RATE)

# WGAN-GP: Adam with betas=(0.0, 0.9) for both networks.
opt_gene = optim.Adam(gene.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.9))

# Fixed latent vectors, so the logged generator snapshots are comparable across steps.
fixed_noise = torch.randn(32, Z_DIM, 1, 1, device=DEVICE)

# One TensorBoard writer per log directory (real vs. generated image grids).
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0  # TensorBoard global step; incremented each time image grids are logged

gene.train()
critic.train()

In [None]:
for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(DEVICE)
        # Size the noise from the actual batch. The last batch of an epoch can be
        # smaller than BATCH_SIZE, and real/fake batch dims must match for the
        # gradient penalty.
        cur_batch_size = real.shape[0]

        # ---- Train the critic CRITIC_ITERATION times per generator step ----
        for _ in range(CRITIC_ITERATION):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1, device=DEVICE)
            fake = gene(noise)
            critic_real = critic(real).reshape(-1)
            critic_fake = critic(fake).reshape(-1)

            # Plain WGAN (no gradient penalty):
            # loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))

            # WGAN-GP critic loss: maximize E[critic(real)] - E[critic(fake)], written as
            # minimizing its negation, plus LAMBDA_GP * gradient penalty.
            # NOTE(review): LAMBDA_GP is applied here, so it is not also passed to
            # gradient_penalty. This assumes utils.gradient_penalty(critic, real, fake,
            # device=...) returns the unscaled penalty — confirm against utils.py.
            gp = gradient_penalty(critic, real, fake, device=DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake)) + LAMBDA_GP * gp
            )

            critic.zero_grad()
            # retain_graph: `fake` from the last critic iteration is reused below for the
            # generator loss, so its generator graph must survive this backward pass.
            loss_critic.backward(retain_graph=True)
            opt_critic.step()

            # Plain WGAN (no gradient penalty): clip the critic weights instead.
            # for p in critic.parameters():
            #     p.data.clamp_(-WEIGTH_CLIP, WEIGTH_CLIP)

        # ---- Train the generator: minimize -E[critic(gen_fake)] ----
        output = critic(fake).reshape(-1)
        loss_gene = -torch.mean(output)
        gene.zero_grad()
        loss_gene.backward()
        opt_gene.step()

        # ---- Periodic logging for visualization ----
        if batch_idx % 100 == 0:
            # Raw WGAN loss values; they are unbounded and not percentages.
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} "
                f"Loss Critic: {loss_critic.item():.4f}, Loss Gene: {loss_gene.item():.4f}"
            )

            with torch.no_grad():
                fake_fixed = gene(fixed_noise)
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake_fixed[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
            step += 1
        
        
            

SyntaxError: incomplete input (817474334.py, line 8)