In [11]:
from itertools import islice

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch import nn
from torch import optim
from torch.cuda.amp import GradScaler,autocast
from torch.utils.data import DataLoader
from tqdm import tqdm

from generator import Generator
from Discriminator import Discriminator
from Dataset import CustomDataset

In [33]:
# Prefer the GPU but fall back to CPU so the notebook still runs without CUDA.
# (The torch.cuda.amp autocast / GradScaler used in training disable themselves,
# with a warning, when CUDA is unavailable.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG so weight initialization and DataLoader shuffling are reproducible.
seed = 42
torch.manual_seed(seed)

epochs = 100
batch_size = 64
lr = 2e-4           # Adam learning rate shared by generator and discriminator
lambda_pixel = 500  # weight of the L1 pixel loss relative to the adversarial loss

In [3]:
# Instantiate both networks (generator first, then discriminator) and place
# their parameters on the configured `device`.
generator, discriminator = Generator(), Discriminator()
generator, discriminator = generator.to(device), discriminator.to(device)

In [13]:

# MSE is used as the adversarial loss (scores vs. 0/1 targets in the training
# loop); L1 measures the pixel-wise distance between generated and target images.
criterion, criterion_pixelwise = nn.MSELoss(), nn.L1Loss()

# One Adam optimizer per network, both using the shared learning rate `lr`.
# NOTE(review): the pix2pix paper uses betas=(0.5, 0.999); default betas are used here.
optim_G = optim.Adam(params=generator.parameters(), lr=lr)
optim_D = optim.Adam(params=discriminator.parameters(), lr=lr)

In [14]:
# Preprocessing: bicubic resize to 256x256, convert to a float tensor in [0, 1],
# then normalize every channel to [-1, 1] via (v - 0.5) / 0.5.
# NOTE(review): presumably CustomDataset applies `transform` to both the input
# and the target image — confirm in Dataset.py.
transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
dataset = CustomDataset(root="maps/train", transform=transform)
# Shuffle so each epoch sees the training pairs in a different order; a fixed
# order gives correlated mini-batches and worse SGD behavior.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
d_scaler = GradScaler()
g_scaler = GradScaler()
checkpoint_every = 10  # save weights every N epochs, and always after the last epoch
preview_index = 2      # which sample of the last batch to visualize each epoch


def _to_preview_image(batch, index):
    """Return ``batch[index]`` as a PIL image suitable for display.

    The training transform ends with ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))``,
    so tensor values lie in [-1, 1]. They are mapped back to [0, 1] (and clamped)
    before ``ToPILImage``, which would otherwise scale by 255 and corrupt negatives.
    """
    image = batch[index].detach().float().cpu() * 0.5 + 0.5
    return transforms.ToPILImage()(image.clamp(0, 1))


for epoch in range(epochs):
    # Restore training mode each epoch: the preview below (and the later
    # inference cell) switch the generator to eval mode.
    generator.train()
    discriminator.train()
    d_loss_sum = torch.zeros((), device=device)
    g_loss_sum = torch.zeros((), device=device)
    num_batches = 0

    for x, y in tqdm(loader, desc=f"Epoch={epoch+1}/{epochs}"):
        x = x.to(device)
        y = y.to(device)

        # --- Discriminator step: real pairs -> 1, generated pairs -> 0 ---
        with autocast():
            y_fake = generator(x)
            # detach() so the discriminator loss does not backprop into the generator
            disc_fake = discriminator(x, y_fake.detach())
            d_loss_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            disc_real = discriminator(x, y)
            d_loss_real = criterion(disc_real, torch.ones_like(disc_real))
            d_loss = (d_loss_fake + d_loss_real) / 2
        optim_D.zero_grad()
        d_scaler.scale(d_loss).backward()
        d_scaler.step(optim_D)
        d_scaler.update()

        # --- Generator step: fool the discriminator + weighted L1 to the target ---
        with autocast():
            d_fake = discriminator(x, y_fake)
            g_fake_loss = criterion(d_fake, torch.ones_like(d_fake))
            l1 = criterion_pixelwise(y_fake, y) * lambda_pixel
            g_loss = g_fake_loss + l1
        optim_G.zero_grad()
        g_scaler.scale(g_loss).backward()
        g_scaler.step(optim_G)
        g_scaler.update()

        # Accumulate on-device to avoid a host synchronization every batch.
        d_loss_sum += d_loss.detach()
        g_loss_sum += g_loss.detach()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("Training loader yielded no batches.")

    # Report epoch means rather than the (noisy) last batch only.
    print(f"Epoch [{epoch+1}/{epochs}] Loss D:{d_loss_sum.item() / num_batches:.4f},Loss G:{g_loss_sum.item() / num_batches:.4f}")

    # Checkpoint every `checkpoint_every` epochs and after the final epoch, so the
    # weights loaded later reflect the end of training.
    if (epoch + 1) % checkpoint_every == 0 or epoch + 1 == epochs:
        torch.save(generator.state_dict(), "generator.pth")
        torch.save(discriminator.state_dict(), "discriminator.pth")

    # Preview one sample of the last batch, rendered as at inference time.
    generator.eval()
    with torch.no_grad():
        preview = generator(x)
    sample = min(preview_index, x.size(0) - 1)  # last batch may be smaller
    panels = [("Input Image", x), ("Ground Truth", y), ("Generated Image", preview)]
    fig, axes = plt.subplots(1, 3, figsize=(10, 5))
    for ax, (title, batch) in zip(axes, panels):
        ax.set_title(title)
        ax.imshow(_to_preview_image(batch, sample))
    plt.show()
    plt.close(fig)


In [37]:
# Validation split, preprocessed with the same `transform` as training.
# NOTE(review): the name `eval` shadows Python's builtin eval(); renaming it
# (e.g. to `val_dataset`) must be done together with the inference cell that iterates it.
eval = CustomDataset(transform=transform, root="maps/val")

In [38]:
def display_images(input_image, generated_image, target_image=None, normalized=True):
    """Show the input, generated and (optionally) target images side by side.

    Args:
        input_image: image tensor; size-1 dims (e.g. a batch dim of 1) are removed
            with ``squeeze()`` and the result is permuted from CHW to HWC.
            NOTE(review): a single-channel image would lose its channel dim in
            ``squeeze()`` and fail the permute — assumes 3-channel input.
        generated_image: generator output in the same layout as ``input_image``.
        target_image: optional ground-truth tensor in the same layout; its panel is
            skipped when None.
        normalized: when True (default), tensors are assumed to be in [-1, 1] as
            produced by ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`` and are mapped
            back to [0, 1]; ``imshow`` would otherwise clip every negative value.
    """
    def _to_numpy(image):
        # Detach, drop size-1 dims, CHW -> HWC, upcast (imshow rejects float16), move to CPU.
        array = image.detach().squeeze().permute(1, 2, 0).float().cpu()
        if normalized:
            array = array * 0.5 + 0.5
        return array.clamp(0, 1).numpy()

    panels = [('Input Image', input_image), ('Generated Image', generated_image)]
    if target_image is not None:
        panels.append(('Target Image', target_image))

    plt.figure(figsize=(15, 5))
    for position, (title, image) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(_to_numpy(image))
        plt.title(title)
        plt.axis('off')
    plt.show()

In [None]:
# Restore the checkpoints written by the training cell.
# map_location lets GPU-saved weights load onto whatever `device` is active (e.g. CPU);
# weights_only=True restricts unpickling to tensors and plain containers, which is
# safer than a full pickle load.
generator.load_state_dict(torch.load('generator.pth', map_location=device, weights_only=True))
discriminator.load_state_dict(torch.load('discriminator.pth', map_location=device, weights_only=True))

In [None]:
# Put the generator in inference mode (train(False) is what .eval() does);
# this changes the behavior of layers such as dropout and batch norm.
generator.train(False)

In [None]:
# Visualize the first few validation pairs: input, generated output and target.
num_images_to_generate = 5

# Ensure inference mode even if this cell is re-run after the training cell.
generator.eval()
with torch.no_grad():
    # islice stops after exactly `num_images_to_generate` items (or fewer if the
    # dataset is shorter) — no manual counter or extra item fetch needed.
    for x, y in islice(eval, num_images_to_generate):
        x = x.to(device)
        y = y.to(device)
        # Dataset items are presumably unbatched (C, H, W); add a batch dimension.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        if y.dim() == 3:
            y = y.unsqueeze(0)
        generated_image = generator(x)
        display_images(x, generated_image, y)