In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import tqdm
from data import load_car_data
from data import load_mnist_data
from data import load_bedroom_data
import matplotlib.pyplot as plt
from model import Generator, Critic
import torchvision


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def evaluate(generator: Generator, critic: Critic, test_loader, num_data, batch_size, device, latent_dim):
    """Estimate WGAN losses on held-out batches and display a grid of generated samples.

    Both networks are put in eval mode for the duration of the call and are
    switched to train mode before returning (regardless of the mode they were
    in on entry).

    Args:
        generator: Network called on latent noise of shape (b, latent_dim, 1, 1).
        critic: Network called on image batches to produce scores.
        test_loader: Iterable yielding image batches (objects with ``.to(device)``).
        num_data: Maximum number of batches to score.
        batch_size: Number of fake images generated for each scored batch.
        device: Device the real batches and the latent noise are moved to.
        latent_dim: Length of the latent vector fed to the generator.

    Returns:
        Tuple ``(mean critic loss, mean generator loss)`` as 0-dim tensors.
        Both are NaN when no batch was scored.
    """
    generator.eval()
    critic.eval()
    generator_losses = []
    critic_losses = []

    progress_bar = tqdm.tqdm(test_loader)
    with torch.no_grad():
        for i, images in enumerate(progress_bar):
            # Stop *before* scoring batch number `num_data`, so that at most
            # `num_data` batches contribute to the reported means.
            if i >= num_data:
                break

            real_img_batch = images.to(device)
            latent_inp = torch.randn(size=(batch_size, latent_dim)).to(device)  # (b, h)
            x = torch.unsqueeze(torch.unsqueeze(latent_inp, dim=-1), dim=-1)  # (b, h, 1, 1)
            fake_img_batch = generator(x)

            # Critic scores for real and generated images.
            D_real = critic(real_img_batch).squeeze()
            D_fake = critic(fake_img_batch).squeeze()

            # Wasserstein losses (no gradient penalty during evaluation).
            D_loss = -(D_real.mean() - D_fake.mean())
            G_loss = -(D_fake.mean())

            generator_losses.append(G_loss.item())
            critic_losses.append(D_loss.item())
    # Breaking out of the loop leaves the bar open; close it explicitly.
    progress_bar.close()

    eval_generator_mean_loss = torch.tensor(generator_losses).mean()
    eval_critic_mean_loss = torch.tensor(critic_losses).mean()

    # Show a 4x4 grid of generated images. No gradients are needed for
    # visualisation, so avoid building an autograd graph.
    with torch.no_grad():
        latent_inp = torch.randn(size=(16, latent_dim)).to(device)  # (b, h)
        x = torch.unsqueeze(torch.unsqueeze(latent_inp, dim=-1), dim=-1)  # (b, h, 1, 1)
        fake_samples = generator(x).cpu()
    grid = torchvision.utils.make_grid(fake_samples, nrow=4, normalize=True)
    plt.figure(figsize=(7, 7))
    plt.imshow(grid.permute(1, 2, 0))  # (C, H, W) -> (H, W, C) for matplotlib
    plt.axis("off")
    plt.show()

    generator.train()
    critic.train()

    return eval_critic_mean_loss, eval_generator_mean_loss

def gradient_penalty(critic, real_images, fake_images, device):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolations.

    For each sample a point is drawn uniformly on the segment between the real
    and the fake input, the critic is evaluated there, and the squared
    deviation of the input-gradient L2 norm from 1 is averaged over the batch.

    Args:
        critic: Callable mapping a batch of inputs to a batch of scores.
        real_images: Tensor whose first dimension is the batch dimension.
        fake_images: Tensor broadcast-compatible with ``real_images``.
        device: Device on which the interpolation weights are created.

    Returns:
        0-dim tensor ``mean((||grad||_2 - 1) ** 2)``; it stays attached to the
        critic's graph (``create_graph=True``) so it can be backpropagated.
    """
    batch_size = real_images.size(0)

    # One interpolation weight per sample, broadcast over every non-batch
    # dimension (works for inputs of any rank, not only 4-D image batches).
    epsilon_shape = (batch_size,) + (1,) * (real_images.dim() - 1)
    epsilon = torch.rand(epsilon_shape, device=device)

    # Detach the inputs: the penalty regularises the critic only, so no
    # gradient should flow back into whatever produced the images.
    interpolated = epsilon * real_images.detach() + (1 - epsilon) * fake_images.detach()
    interpolated.requires_grad_(True)

    # Critic score for interpolated images
    interpolated_score = critic(interpolated)

    # Gradients of the scores w.r.t. the interpolated images. create_graph=True
    # keeps the result differentiable so the penalty can train the critic.
    gradients = torch.autograd.grad(
        outputs=interpolated_score,
        inputs=interpolated,
        grad_outputs=torch.ones_like(interpolated_score),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Flatten gradients to (batch_size, -1); reshape also handles
    # non-contiguous gradient tensors, which .view would reject.
    gradients = gradients.reshape(batch_size, -1)

    # Penalty: (||grad||_2 - 1)^2 averaged over the batch
    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return penalty
 
def _sample_latent(num_samples, latent_dim, device):
    """Draw standard-normal latent noise shaped (num_samples, latent_dim, 1, 1) on `device`."""
    return torch.randn(num_samples, latent_dim, 1, 1, device=device)


def train(generator, critic, num_epochs, eval_epoch, device, batch_size, sample_size, n_critic, latent_dim, lambda_g):
    """Train a WGAN-GP generator/critic pair on the bedroom dataset.

    For every batch the critic is updated ``n_critic`` times (Wasserstein loss
    plus ``lambda_g`` times the gradient penalty), then the generator is
    updated once. Every 1000 steps ``evaluate`` is run on 20 test batches, and
    both networks are saved at the end of each epoch.

    Args:
        generator: Generator network; called on noise of shape (b, latent_dim, 1, 1).
        critic: Critic network; called on image batches.
        num_epochs: Number of passes over the training loader.
        eval_epoch: Accepted for interface compatibility; not read by this function.
        device: Device both networks and all batches are moved to.
        batch_size: Training batch size passed to the data loader and used for
            the generator update.
        sample_size: Forwarded to ``load_bedroom_data``.
        n_critic: Critic updates per generator update; must be at least 1.
        latent_dim: Length of the latent vector fed to the generator.
        lambda_g: Weight of the gradient penalty in the critic loss.

    Raises:
        ValueError: If ``n_critic`` is smaller than 1.
    """
    if n_critic < 1:
        # Without a critic step there is no critic loss to log below.
        raise ValueError("n_critic must be at least 1")

    eval_every_steps = 1000  # steps between intermediate evaluations
    eval_batches = 20        # test batches scored per evaluation
    test_batch_size = 2

    generator.to(device)
    critic.to(device)
    D_optimizer = torch.optim.AdamW(critic.parameters(), lr=1e-4, betas=(0.0, 0.9))
    G_optimizer = torch.optim.AdamW(generator.parameters(), lr=1e-4, betas=(0.0, 0.9))
    train_loader, test_loader = load_bedroom_data(
        train_batch_size=batch_size, test_batch_size=test_batch_size, sample_size=sample_size
    )

    for epoch in range(num_epochs):
        progress = tqdm.tqdm(train_loader, dynamic_ncols=True)
        generator.train()
        critic.train()
        progress.set_description(f'Epoch: {epoch}')
        generator_losses = []
        critic_losses = []

        for step, images in enumerate(progress):
            real_images = images.to(device)
            # The loader's last batch may be smaller than `batch_size`; the
            # fake batch must match it or the real/fake interpolation inside
            # gradient_penalty fails with a shape mismatch.
            current_batch = real_images.size(0)

            # ---- Critic updates ----
            for _ in range(n_critic):
                # The generator is not trained here, so skip building its graph.
                with torch.no_grad():
                    fake_images = generator(_sample_latent(current_batch, latent_dim, device))

                D_real = critic(real_images).squeeze()
                D_fake = critic(fake_images).squeeze()

                gp = gradient_penalty(critic, real_images, fake_images, device)

                D_loss = -(D_real.mean() - D_fake.mean()) + lambda_g * gp

                D_optimizer.zero_grad()
                D_loss.backward()
                D_optimizer.step()

            # ---- Generator update ----
            G_optimizer.zero_grad()
            fake_images = generator(_sample_latent(batch_size, latent_dim, device))
            D_fake = critic(fake_images).squeeze()
            G_loss = -torch.mean(D_fake)
            G_loss.backward()
            G_optimizer.step()

            generator_losses.append(G_loss.item())
            critic_losses.append(D_loss.item())

            progress.set_postfix({'generator_loss': f"{G_loss.item():.4f}", 'critic_loss': f"{D_loss.item():.4f}"})

            if (step + 1) % eval_every_steps == 0:
                eval_critic_loss, eval_generator_loss = evaluate(
                    generator, critic, test_loader,
                    num_data=eval_batches, batch_size=test_batch_size,
                    device=device, latent_dim=latent_dim,
                )
                # Report the evaluation result instead of discarding it;
                # progress.write keeps the progress bar intact.
                progress.write(
                    f"[epoch {epoch} step {step + 1}] "
                    f"eval critic_loss={eval_critic_loss.item():.4f} "
                    f"eval generator_loss={eval_generator_loss.item():.4f}"
                )

        # Whole-module checkpoints, one pair per epoch.
        torch.save(generator, f"generator_{epoch}.pt")
        torch.save(critic, f"critic_{epoch}.pt")
            
        
       
      

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_dim = 100

# Resume from whole-module checkpoints.
# NOTE: weights_only=False unpickles arbitrary Python objects, so only load
# checkpoint files you created yourself or otherwise trust.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads on a CPU-only machine.
generator = torch.load("generator.pt", map_location=device, weights_only=False)
critic = torch.load("critic.pt", map_location=device, weights_only=False)

# To start from freshly initialised networks instead, use:
# generator = Generator(latent_dim=latent_dim, device=device)
# critic = Critic()

In [8]:
# torch.save(generator, "generator.pt")
# torch.save(critic, "critic.pt")

In [None]:
# Launch training; every hyper-parameter is passed by keyword.
train(
    generator=generator,
    critic=critic,
    num_epochs=100,
    eval_epoch=20,
    device=device,
    batch_size=32,
    sample_size=250000,
    n_critic=5,
    latent_dim=latent_dim,
    lambda_g=10,
)

Epoch: 0:  10%|▉         | 779/7812 [06:20<57:58,  2.02it/s, generator_loss=-50.7978, critic_loss=-4.7890]    

In [27]:
# Render one generated image from a single random latent vector.
latent_inp = torch.randn(size=(1, latent_dim)).to(device)  # (b, h)
x = latent_inp[:, :, None, None]  # (b, h, 1, 1)

generator.eval()
with torch.no_grad():
    fake_samples = generator(x)
    fake_samples = fake_samples.detach().cpu()

grid = torchvision.utils.make_grid(fake_samples, nrow=4, normalize=True)
plt.figure(figsize=(5, 5))
plt.imshow(grid.movedim(0, -1))  # channels-first -> channels-last for matplotlib
plt.axis("off")
plt.show()
    

AcceleratorError: CUDA error: an illegal instruction was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
