In [2]:
import torch
from torch import nn,optim
from torch.utils.data import DataLoader
from discriminator_model import Discriminator
from generator_model import Generator
from torchvision import transforms
import torchvision
from tqdm import tqdm
from Dataset import CustomDataset
import itertools
import matplotlib.pyplot as plt
import numpy as np
import os


In [3]:
# ---- Run configuration ------------------------------------------------------
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs instead of failing on the first `.to(device)` call.
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 2e-4              # learning rate passed to both Adam optimizers
batch_size = 8         # batch size of the training/validation DataLoaders
epochs = 100           # number of passes over the training loader
lambda_identity = 0.5  # weight applied to each identity-loss term
lambda_cycle = 10      # weight applied to the summed cycle-consistency loss
img_size = 256         # images are resized to img_size x img_size
img_channels = 3       # channel count handed to the generators/discriminators


In [4]:
# Generators: gen_G is applied to X-domain batches (x -> fake_y) and gen_f to
# Y-domain batches (y -> fake_x) in the training loop.
gen_G = Generator(img_channels, 32, num_residuals=9).to(device)
gen_f = Generator(img_channels, 32, num_residuals=9).to(device)

# Discriminators: disc_X scores X-domain images, disc_Y scores Y-domain images.
disc_X = Discriminator(
    in_channels=img_channels,
    features=[32, 64, 128, 256],
).to(device)
disc_Y = Discriminator(
    in_channels=img_channels,
    features=[32, 64, 128, 256],
).to(device)


In [5]:
# One Adam optimizer updates both generators together ...
optim_G = optim.Adam(
    [*gen_G.parameters(), *gen_f.parameters()],
    lr=lr,
    betas=(0.5, 0.999),
)
# ... and a second one updates both discriminators together.
optim_D = optim.Adam(
    [*disc_X.parameters(), *disc_Y.parameters()],
    lr=lr,
    betas=(0.5, 0.999),
)


In [6]:
# L1 is used for the cycle-consistency and identity terms; MSE is used for the
# adversarial (real-vs-fake) terms of both discriminators and generators.
L1, MSE = nn.L1Loss(), nn.MSELoss()


In [7]:
# Preprocessing applied to every image: resize to a square, convert to a
# tensor, then normalise each channel with mean 0.5 / std 0.5.
# The per-channel statistics are derived from `img_channels` so the pipeline
# keeps working if the channel count in the config cell is changed
# (for img_channels == 3 this is exactly (0.5, 0.5, 0.5)).
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,) * img_channels, (0.5,) * img_channels),
])


In [22]:
# Training data: the two image folders are handed to CustomDataset as
# root_A / root_B; the loop below unpacks each batch as (x, y).
dataset = CustomDataset(
    root_A="day_to_night/day",
    root_B="day_to_night/night",
    transform=transform,
)
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=6,
    pin_memory=True,
)


In [8]:
import matplotlib.pyplot as plt
import numpy as np

def _to_displayable(batch):
    """Convert the first image of a normalised batch into an ``imshow`` array.

    Assumes ``batch`` is laid out as (N, C, H, W) and was normalised with
    mean 0.5 / std 0.5 -- TODO confirm against the dataset if that changes.

    Returns:
        A float32 NumPy array, 2-D for single-channel images and
        channels-last 3-D otherwise, with values clipped to [0, 1].
    """
    # Only the first sample is shown, so convert just that one.
    image = batch[0].detach().float().cpu().numpy().astype(np.float32)
    image = (image * 0.5) + 0.5  # undo the mean/std 0.5 normalisation

    if image.shape[0] == 1:
        image = np.squeeze(image, axis=0)       # single channel -> 2-D
    else:
        image = np.transpose(image, (1, 2, 0))  # channels-first -> channels-last

    # imshow expects floats in [0, 1]; clip small overshoots explicitly
    # instead of relying on matplotlib's clipping warning.
    return np.clip(image, 0.0, 1.0)


def display_images(epoch, original_X, original_Y, generated_Y, generated_X):
    """Plot one sample of each domain next to its translation in a 2x2 grid.

    Layout: top row is ``original_X`` and ``generated_Y``; bottom row is
    ``original_Y`` and ``generated_X``. Only the first image of each batch is
    drawn.

    Args:
        epoch: Value shown in the figure title.
        original_X: Batch of real X-domain images.
        original_Y: Batch of real Y-domain images.
        generated_Y: Batch produced from ``original_X``.
        generated_X: Batch produced from ``original_Y``.
    """
    panels = (
        (original_X, "Original X"),
        (generated_Y, "Generated Y (G(X))"),
        (original_Y, "Original Y"),
        (generated_X, "Generated X (F(Y))"),
    )

    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    # axes.flat walks the grid row by row, matching the order of `panels`.
    for ax, (batch, title) in zip(axes.flat, panels):
        image = _to_displayable(batch)
        ax.imshow(image, cmap='gray' if image.ndim == 2 else None)
        ax.set_title(title)
        ax.axis("off")

    fig.suptitle(f"Epoch {epoch}")
    plt.show()
    # This is called once per epoch; release the figure so they do not pile up.
    plt.close(fig)


In [9]:
# Separate gradient scalers for the mixed-precision discriminator and
# generator updates (each is paired with its own optimizer in the loop below).
d_scaler, g_scaler = torch.cuda.amp.GradScaler(), torch.cuda.amp.GradScaler()

In [None]:
# Main training loop: for every batch, update both discriminators first and
# then both generators; show a preview each epoch and checkpoint every 2 epochs.
for epoch in range(epochs):
    # evaluate_model() switches the generators to eval mode, and this cell may
    # be (re-)run after it, so put every network back into training mode.
    gen_G.train()
    gen_f.train()
    disc_X.train()
    disc_Y.train()

    loop = tqdm(loader, desc=f"Epoch {epoch+1}/{epochs}", leave=True)
    for idx, (x, y) in enumerate(loop):
        x = x.to(device)
        y = y.to(device)

        # ---- Discriminator update -------------------------------------
        with torch.cuda.amp.autocast():
            # X -> Y direction: disc_Y should output 1 for real y, 0 for G(x).
            fake_y = gen_G(x)
            D_Y_real = disc_Y(y)
            # detach(): the discriminator loss must not reach the generator.
            D_Y_fake = disc_Y(fake_y.detach())
            D_Y_real_loss = MSE(D_Y_real, torch.ones_like(D_Y_real))
            D_Y_fake_loss = MSE(D_Y_fake, torch.zeros_like(D_Y_fake))
            D_Y_loss = (D_Y_real_loss + D_Y_fake_loss) / 2

            # Y -> X direction: disc_X should output 1 for real x, 0 for F(y).
            fake_x = gen_f(y)
            D_X_real = disc_X(x)
            D_X_fake = disc_X(fake_x.detach())
            D_X_real_loss = MSE(D_X_real, torch.ones_like(D_X_real))
            D_X_fake_loss = MSE(D_X_fake, torch.zeros_like(D_X_fake))
            D_X_loss = (D_X_real_loss + D_X_fake_loss) / 2

            D_loss = D_Y_loss + D_X_loss

        optim_D.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(optim_D)
        d_scaler.update()

        # ---- Generator update (G and F together) ----------------------
        with torch.cuda.amp.autocast():
            # Adversarial terms: the generators want the discriminators to
            # output 1 for the (non-detached) fakes.
            D_Y_fake = disc_Y(fake_y)
            D_X_fake = disc_X(fake_x)
            G_X_loss = MSE(D_X_fake, torch.ones_like(D_X_fake))
            G_Y_loss = MSE(D_Y_fake, torch.ones_like(D_Y_fake))

            # Cycle-consistency: F(G(x)) should match x and G(F(y)) match y.
            cycle_X = gen_f(fake_y)
            cycle_Y = gen_G(fake_x)
            cycle_loss_X = L1(x, cycle_X)
            cycle_loss_Y = L1(y, cycle_Y)
            cycle_loss = lambda_cycle * (cycle_loss_X + cycle_loss_Y)

            # Identity: F(x) should stay close to x and G(y) close to y.
            identity_x = gen_f(x)
            identity_y = gen_G(y)
            identity_loss_X = L1(x, identity_x)
            identity_loss_Y = L1(y, identity_y)

            G_loss = (
                G_X_loss
                + G_Y_loss
                + cycle_loss
                + identity_loss_X * lambda_identity
                + identity_loss_Y * lambda_identity
            )

        optim_G.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(optim_G)
        g_scaler.update()

        if idx % 10 == 0:
            loop.set_postfix(
                D_loss=D_loss.item(),
                G_loss=G_loss.item(),
            )

    # End-of-epoch preview on the last batch, using the just-updated generators.
    with torch.no_grad():
        fake_y = gen_G(x)
        # Was assigned to the misspelled name `feak_x`, so the preview showed
        # the stale fake_x left over from the last training step.
        fake_x = gen_f(y)
        # Pass the 1-based epoch so the figure title agrees with the progress
        # bar and the checkpoint message.
        display_images(epoch + 1, x, y, fake_y, fake_x)

    # Overwrite the checkpoints every second epoch.
    if (epoch + 1) % 2 == 0:
        torch.save(gen_G.state_dict(), "generator_G(9).pth")
        torch.save(gen_f.state_dict(), "generator_F(9).pth")
        torch.save(disc_X.state_dict(), "discriminator_X(9).pth")
        torch.save(disc_Y.state_dict(), "discriminator_Y(9).pth")
        print(f"Models saved at epoch {epoch + 1}.")
       



In [None]:
# Restore all four networks from the checkpoints written by the training loop.
# map_location=device lets checkpoints saved on a GPU be loaded onto whatever
# device the models currently live on (e.g. a CPU-only machine).
# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
gen_G.load_state_dict(torch.load("generator_G(9).pth", map_location=device))
gen_f.load_state_dict(torch.load("generator_F(9).pth", map_location=device))
disc_X.load_state_dict(torch.load("discriminator_X(9).pth", map_location=device))
disc_Y.load_state_dict(torch.load("discriminator_Y(9).pth", map_location=device))

In [13]:
# Manually write the current weights of all four networks to disk
# (same file names the training loop uses for its periodic checkpoints).
torch.save(gen_G.state_dict(), "generator_G(9).pth")
torch.save(gen_f.state_dict(), "generator_F(9).pth")
torch.save(disc_X.state_dict(), "discriminator_X(9).pth")
torch.save(disc_Y.state_dict(), "discriminator_Y(9).pth")

In [10]:
# Validation data, built with the same transform as the training set.
# NOTE(review): these folders live under "night_to_day/", while training reads
# from "day_to_night/" -- confirm that the A/B domain order matches.
val_dataset = CustomDataset(
    root_A="night_to_day/valA",
    root_B="night_to_day/valB",
    transform=transform,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=6,
    pin_memory=True,
)


In [11]:

def evaluate_model(gen_G, gen_F, loader, device, num_images=10):
    """Visualise both translation directions on the first batch of ``loader``.

    Two figures are shown: real ``x`` images above ``gen_G(x)``, then real
    ``y`` images above ``gen_F(y)``.

    Args:
        gen_G: Generator applied to the ``x`` half of each batch.
        gen_F: Generator applied to the ``y`` half of each batch.
        loader: Iterable yielding ``(x, y)`` batches; only the first is used.
        device: Device the batch is moved to before inference.
        num_images: Maximum number of columns to draw. It is capped at the
            batch size, so asking for more images than a batch holds no
            longer raises an IndexError.

    Note:
        Both generators are switched to eval mode and are left in eval mode.
    """

    def _as_image(tensor):
        # Undo the mean/std 0.5 normalisation with a fixed mapping (the same
        # one display_images uses) rather than per-image min-max stretching,
        # which would change the apparent brightness of each picture.
        # Assumes generator outputs share the inputs' [-1, 1] range -- TODO confirm.
        restored = (tensor.detach().float().cpu() * 0.5 + 0.5).clamp(0, 1)
        # make_grid is retained from the original so channel handling is unchanged.
        return torchvision.utils.make_grid(restored).permute(1, 2, 0).numpy()

    def _plot_pairs(real, fake, count, title):
        # squeeze=False keeps `axes` two-dimensional even when count == 1.
        fig, axes = plt.subplots(2, count, figsize=(15, 6), squeeze=False)
        for col in range(count):
            axes[0, col].imshow(_as_image(real[col]))
            axes[0, col].axis('off')
            axes[1, col].imshow(_as_image(fake[col]))
            axes[1, col].axis('off')
        fig.suptitle(title)
        plt.show()
        plt.close(fig)

    gen_G.eval()
    gen_F.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)

            fake_y = gen_G(x)
            fake_x = gen_F(y)

            # Never index past the end of the batch.
            count = min(num_images, x.size(0), y.size(0))
            if count > 0:
                _plot_pairs(
                    x, fake_y, count,
                    'Real Images (Top Row) (Day) vs. Generated Images (Bottom Row) (Night)',
                )
                _plot_pairs(
                    y, fake_x, count,
                    'Real Images (Top Row) (Night) vs. Generated Images (Bottom Row) (Day)',
                )

            break  # only the first batch is visualised

In [None]:
# Preview five validation samples per direction with the current generators.
evaluate_model(
    gen_G=gen_G,
    gen_F=gen_f,
    loader=val_loader,
    device=device,
    num_images=5,
)
