In [32]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import tqdm
from data import load_data
import matplotlib.pyplot as plt
from model import Generator, Critic
import torchvision


In [33]:
def evaluate(generator: Generator, critic: Critic, test_loader, num_data, batch_size, device):
    """Estimate WGAN losses on held-out batches and show a 4x4 grid of generated samples.

    Both networks are switched to eval mode for the duration of the call and are
    always switched back to train mode afterwards, even if evaluation raises.

    Args:
        generator: Called as ``generator(n)``; expected to return ``n`` fake images.
        critic: Scores an image batch; its output is squeezed and averaged.
        test_loader: Iterable of ``(images, labels)`` batches; labels are ignored.
        num_data: Maximum number of test batches to evaluate. A value ``<= 0``
            evaluates nothing, and the returned means are then NaN.
        batch_size: Unused; kept for backward compatibility. The number of fake
            images per step follows the size of each real batch instead, so a
            short final batch is compared like-for-like.
        device: Device the real images are moved to.

    Returns:
        ``(eval_critic_mean_loss, eval_generator_mean_loss)`` as 0-d tensors. They
        use the same sign conventions as the training objective: critic loss is
        ``mean(D(fake)) - mean(D(real))`` and generator loss is ``-mean(D(fake))``.
    """
    generator.eval()
    critic.eval()
    try:
        generator_losses = []
        critic_losses = []
        with torch.no_grad(), tqdm.tqdm(test_loader) as progress_bar:
            for i, (images, _) in enumerate(progress_bar):
                # Stop *before* processing, so exactly `num_data` batches are used
                # (checking after the batch evaluated num_data + 1 of them).
                if i >= num_data:
                    break

                real_img_batch = images.to(device)
                fake_img_batch = generator(real_img_batch.size(0))

                # Discriminate real and generated images.
                real_critic_mean = critic(real_img_batch).squeeze().mean()
                fake_critic_mean = critic(fake_img_batch).squeeze().mean()

                # Same objectives that `train` minimises.
                critic_losses.append((fake_critic_mean - real_critic_mean).item())
                generator_losses.append((-fake_critic_mean).item())

        eval_generator_mean_loss = torch.tensor(generator_losses).mean()
        eval_critic_mean_loss = torch.tensor(critic_losses).mean()

        # Show a grid of generated images; no autograd graph is needed for display.
        with torch.no_grad():
            fake_samples = generator(16).cpu()
        grid = torchvision.utils.make_grid(fake_samples, nrow=4, normalize=True)
        fig, ax = plt.subplots()
        ax.imshow(grid.permute(1, 2, 0))
        ax.axis("off")
        plt.show()
    finally:
        generator.train()
        critic.train()

    return eval_critic_mean_loss, eval_generator_mean_loss
 
def _critic_step(generator, critic, critic_optimizer, real_images, clip_value=0.01):
    """Run one weight-clipped WGAN critic update and return the loss tensor.

    Minimises ``mean(D(fake)) - mean(D(real))``. The fake images are generated
    under ``torch.no_grad()``, so this update does not back-propagate into the
    generator.
    """
    with torch.no_grad():
        # Match the real batch size so a short final batch is compared like-for-like.
        fake_images = generator(real_images.size(0))

    # Discriminate real and generated images.
    critic_real = critic(real_images).squeeze()
    critic_fake = critic(fake_images).squeeze()
    critic_loss = critic_fake.mean() - critic_real.mean()

    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()

    # Weight clipping, as in the original WGAN, bounds the critic's weights.
    with torch.no_grad():
        for param in critic.parameters():
            param.clamp_(-clip_value, clip_value)
    return critic_loss


def _generator_step(generator, critic, generator_optimizer, batch_size):
    """Run one WGAN generator update (minimise ``-mean(D(fake))``) and return the loss tensor."""
    generator_optimizer.zero_grad()
    fake_images = generator(batch_size)
    generator_loss = -critic(fake_images).squeeze().mean()
    generator_loss.backward()
    generator_optimizer.step()
    return generator_loss


def train(generator, critic, num_epochs, eval_epoch, device, batch_size, sample_size, n_critic):
    """Train a weight-clipped WGAN and log per-epoch mean losses.

    Args:
        generator: Called as ``generator(n)``; expected to return ``n`` fake images.
        critic: Scores an image batch; its output is squeezed and averaged.
        num_epochs: Number of passes over the training loader.
        eval_epoch: Run ``evaluate`` every ``eval_epoch`` epochs. ``1`` evaluates
            after every epoch; ``0``/``None`` disables evaluation.
        device: Device both networks and the real images are moved to.
        batch_size: Training batch size (forwarded to ``load_data``) and the number
            of fake images per generator update.
        sample_size: Forwarded to ``load_data``. Presumably this limits the dataset
            size; confirm in data.py.
        n_critic: Number of critic updates per generator update; must be >= 1.

    Raises:
        ValueError: If ``n_critic < 1``. Otherwise the critic loss would never be
            computed.
    """
    if n_critic < 1:
        raise ValueError(f"n_critic must be >= 1, got {n_critic}")

    eval_batch_size = 2  # shared by the test loader and `evaluate`

    generator.to(device)
    critic.to(device)
    critic_optimizer = torch.optim.AdamW(critic.parameters(), lr=1e-5, weight_decay=0.001, betas=(0.5, 0.999))
    generator_optimizer = torch.optim.AdamW(generator.parameters(), lr=5e-4, betas=(0.5, 0.999))
    train_loader, test_loader = load_data(train_batch_size=batch_size, test_batch_size=eval_batch_size, sample_size=sample_size)

    for epoch in range(num_epochs):
        generator.train()
        critic.train()
        generator_losses = []
        critic_losses = []

        progress = tqdm.tqdm(train_loader, dynamic_ncols=True)
        progress.set_description(f'Epoch: {epoch}')
        for images, _ in progress:
            real_images = images.to(device)

            for _ in range(n_critic):
                critic_loss = _critic_step(generator, critic, critic_optimizer, real_images)
            generator_loss = _generator_step(generator, critic, generator_optimizer, batch_size)

            generator_value = generator_loss.item()
            critic_value = critic_loss.item()
            generator_losses.append(generator_value)
            critic_losses.append(critic_value)
            progress.set_postfix({'generator_loss': f"{generator_value:.4f}", 'critic_loss': f"{critic_value:.4f}"})

        generator_mean_loss = torch.tensor(generator_losses).mean()
        critic_mean_loss = torch.tensor(critic_losses).mean()
        summary = f"Epoch {epoch}: generator_loss={generator_mean_loss:.5f}, critic_loss={critic_mean_loss:.5f}"

        if eval_epoch and (epoch + 1) % eval_epoch == 0:
            eval_critic_loss, eval_generator_loss = evaluate(generator, critic, test_loader, num_data=20, batch_size=eval_batch_size, device=device)
            summary += f", eval_gen_loss={eval_generator_loss:.5f}, eval_des_loss={eval_critic_loss:.4f}"

        # tqdm closes (and disables) the bar when iteration finishes, so a later
        # set_postfix would be silently dropped. Write the summary line instead.
        tqdm.tqdm.write(summary)
       
      
          
    

In [34]:
# Seed torch's RNGs (CPU and all CUDA devices) so runs are repeatable. This
# presumably covers weight initialisation and latent sampling inside
# Generator/Critic (confirm in model.py). load_data's shuffling is covered only
# if it draws from torch's global RNG.
SEED = 0
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

generator = Generator(latent_dim=512, device=device)
critic = Critic()

train(
    generator=generator,
    critic=critic,
    num_epochs=100,
    eval_epoch=20,
    device=device,
    batch_size=32,
    sample_size=1000,
    n_critic=5,
)

Epoch: 0:  22%|██▏       | 7/32 [00:07<00:28,  1.13s/it, generator_loss=0.0103, critic_loss=0.0000] 


KeyboardInterrupt: 