In [31]:
import torch 
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms

In [2]:
target_path = '/kaggle/input/dataset/ascii/ascii'
input_path = '/kaggle/input/dataset/images/images'

In [3]:
from PIL import Image

In [4]:
#img = Image.open(target_path+'/0.jpg')

In [5]:
from torch.utils.data import Dataset, DataLoader

In [6]:
import os

In [7]:
train_transform = transforms.Compose([
    transforms.Resize((256,256)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)),
    transforms.RandomRotation(10),
    transforms.CenterCrop(256)
])

In [8]:
class Data(Dataset):
    def __init__(self, img_path, target_path, transform):
        self.img_path = img_path
        self.transform = transform
        self.target_path = target_path
        self.data = []
        self.create_dataset()
        
    def create_dataset(self):
        for file in os.listdir(self.img_path):
            file_path = os.path.join(self.img_path, file)
            target_file = os.path.join(self.target_path, file)
            
            tar_img = Image.open(target_file).convert('RGB')
            tar_img = self.transform(tar_img)
            
            inp_img = Image.open(file_path).convert("RGB")
            inp_img = self.transform(inp_img)
            self.data.append((inp_img, tar_img))
            
    def __getitem__(self, idx):
        return self.data[idx]
    
    def __len__(self):
        return len(self.data)

In [9]:
train_dataset = Data(input_path, target_path, train_transform)

In [10]:
train_dataset[0][0].shape

torch.Size([3, 256, 256])

In [11]:
#torch.save(train_dataset, '/kaggle/working/dataset.pth')

In [12]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

In [13]:
class Generator(nn.Module):
    def __init__(self, in_channels):
        super(Generator, self).__init__()
        
        ## down
        self.encoder1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True)
        )    
        self.encoder2 = nn.Sequential(
            nn.Conv2d(64, 128, 4,2,1,bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.encoder3 = nn.Sequential(
            nn.Conv2d(128, 256, 4,2,1,bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.encoder4 = nn.Sequential(
            nn.Conv2d(256, 512, 4,2,1,bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.encoder5 = nn.Sequential(
            nn.Conv2d(512, 512, 4,2,1,bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.encoder6 = nn.Sequential(
            nn.Conv2d(512, 512, 4,2,1,bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.encoder7 = nn.Sequential(
            nn.Conv2d(512, 512, 4,2,1,bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.encoder8 = nn.Sequential(
            nn.Conv2d(512, 512, 4,2,1,bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        )

        ## up
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )
        self.decoder2 = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )
        self.decoder3 = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5)
        )
        self.decoder4 = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )
        self.decoder5= nn.Sequential(
            nn.ConvTranspose2d(1024, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        self.decoder6 = nn.Sequential(
            nn.ConvTranspose2d(512, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.decoder7 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        self.final = nn.Sequential(
            nn.ConvTranspose2d(128, 3, 4, 2, 1),
            nn.Tanh()
        )
    def forward(self,x):
        #encoder pass
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        e5 = self.encoder5(e4)
        e6 = self.encoder6(e5)
        e7 = self.encoder7(e6)
        e8 = self.encoder8(e7)

        #decoder pass
        d1 = self.decoder1(e8)
        d2 = self.decoder2(torch.cat([d1, e7], dim=1))
        d3 = self.decoder3(torch.cat([d2, e6], dim=1))
        d4 = self.decoder4(torch.cat([d3, e5], dim=1))
        d5 = self.decoder5(torch.cat([d4, e4], dim=1))
        d6 = self.decoder6(torch.cat([d5, e3], dim=1))
        d7 = self.decoder7(torch.cat([d6, e2], dim=1))
        out = self.final(torch.cat([d7, e1], dim=1))
        return out

In [14]:
#PatchGAN
class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super(Discriminator, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels*2, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True)
        )
        self.final = nn.Conv2d(512, 1, 4, 1, 1)
    def forward(self, x):
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)
        c4 = self.conv4(c3)
        return self.final(c4)

In [15]:
# def loss(x, y, generator, discriminator, lam=100):
#     #generator out
#     gen_out= generator(x)
#     #discriminator out
#     dis_real = discriminator(x, y)
#     dis_fake = discriminator(x, gen_out.detach())

#     #adverarial loss
#     valid = torch.ones_like(dis_real)
#     fake = torch.zeros_like(dis_fake)
    
#     loss_real = bce(dis_real, valid)
#     loss_fake = bce(dis_fake, fake)
#     adversarial_loss = (loss_real + loss_fake) * 0.5

#     gen_dis_out = discriminator(x, gen_out)
#     gen_adv_loss = bce(gen_dis_out, valid)

#     #reconstruction loss
#     recon_loss = lse(gen_out, y)
    
#     total_loss = adversarial_loss + lam*recon_loss
#     return total_loss, adversarial_loss, recon_loss

In [16]:
adversarial_loss = nn.MSELoss()
reconstruction_loss = nn.L1Loss()

In [17]:
def discriminator_loss(discriminator, x, y, gen_out):
    real = discriminator(torch.cat([x,y], dim=1))
    fake = discriminator(torch.cat([x, gen_out.detach()], dim=1))

    real_loss = adversarial_loss(real, torch.ones_like(real))
    fake_loss = adversarial_loss(fake, torch.zeros_like(fake))

    loss = 0.5 * (real_loss + fake_loss)
    return loss

In [18]:
def generator_loss(discriminator, x, y, gen_out, lam=100):
    fake = discriminator(torch.cat([x, gen_out], dim=1))
    adv_loss = adversarial_loss(fake, torch.ones_like(fake))
    L1 = reconstruction_loss(gen_out, y) * 100

    loss = adv_loss + L1
    return loss 

In [19]:
generator = Generator(in_channels=3).to('cuda')
discriminator = Discriminator().to('cuda')

In [32]:
gen_optimizer = optim.AdamW(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
dis_optimizer = optim.AdamW(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

In [33]:
from tqdm import tqdm

In [34]:
import torch
import math

def compute_psnr(fake, real):
    mse = torch.mean((fake - real) ** 2)
    if mse == 0:
        return 100  # perfect reconstruction
    psnr = 10 * math.log10(1 / mse.item())
    return psnr

In [35]:
for epoch in range(100):
    epoch_gen_loss=0
    epoch_disc_loss=0
    epoch_psnr=0

    for x, y in tqdm(train_loader, desc=f'Epoch: {epoch+1}'):
        x, y = x.to('cuda'), y.to('cuda')
        
        # discriminator
        dis_optimizer.zero_grad()
        gen_out = generator(x)
        disc_loss = discriminator_loss(discriminator, x, y, gen_out)
        disc_loss.backward()
        dis_optimizer.step()
        
        #generator
        gen_optimizer.zero_grad()
        gen_loss = generator_loss(discriminator, x, y, gen_out, lam=60)
        gen_loss.backward()
        gen_optimizer.step()
        
        #metrics
        psnr = compute_psnr(gen_out.detach(), y)
        epoch_gen_loss += gen_loss.item()
        epoch_disc_loss += disc_loss.item()
        epoch_psnr += psnr
        
    avg_gen_loss = epoch_gen_loss / len(train_loader)
    avg_disc_loss = epoch_disc_loss / len(train_loader)
    avg_psnr = epoch_psnr / len(train_loader)
     
    print(f"Epoch: {epoch+1}, Gen Loss: {avg_gen_loss:.2f}, Dis Loss: {avg_disc_loss:.2f}, PSNR: {avg_psnr:.2f} dB")

    #save checkpoint
    
    if (epoch+1)%20==0:
        print("SAVING CHECKPOINT!!!!!!!!!!!!!!")
        torch.save({
            'epoch': epoch,
            'generator_state_dict': generator.state_dict(),
            'gen_optimizer_state_dict': gen_optimizer.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'disc_optimizer_state_dict': dis_optimizer.state_dict(),
            'dis_loss': avg_disc_loss,
            'gen_loss': avg_gen_loss,
            'psnr': avg_psnr,
        }, '/kaggle/working/checkpoint.pth')

Epoch: 1: 100%|██████████| 46/46 [00:36<00:00,  1.27it/s]


Epoch: 1, Gen Loss: 12.21, Dis Loss: 0.10, PSNR: 13.82 dB


Epoch: 2: 100%|██████████| 46/46 [00:39<00:00,  1.16it/s]


Epoch: 2, Gen Loss: 12.10, Dis Loss: 0.09, PSNR: 13.90 dB


Epoch: 3: 100%|██████████| 46/46 [00:38<00:00,  1.20it/s]


Epoch: 3, Gen Loss: 12.13, Dis Loss: 0.10, PSNR: 13.87 dB


Epoch: 4: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 4, Gen Loss: 11.60, Dis Loss: 0.19, PSNR: 14.20 dB


Epoch: 5: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 5, Gen Loss: 11.30, Dis Loss: 0.09, PSNR: 14.29 dB


Epoch: 6: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 6, Gen Loss: 11.54, Dis Loss: 0.11, PSNR: 14.27 dB


Epoch: 7: 100%|██████████| 46/46 [00:38<00:00,  1.18it/s]


Epoch: 7, Gen Loss: 11.64, Dis Loss: 0.13, PSNR: 14.20 dB


Epoch: 8: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 8, Gen Loss: 11.65, Dis Loss: 0.12, PSNR: 14.24 dB


Epoch: 9: 100%|██████████| 46/46 [00:38<00:00,  1.20it/s]


Epoch: 9, Gen Loss: 11.42, Dis Loss: 0.12, PSNR: 14.46 dB


Epoch: 10: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 10, Gen Loss: 11.25, Dis Loss: 0.13, PSNR: 14.59 dB


Epoch: 11: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 11, Gen Loss: 11.15, Dis Loss: 0.11, PSNR: 14.70 dB


Epoch: 12: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 12, Gen Loss: 11.22, Dis Loss: 0.12, PSNR: 14.67 dB


Epoch: 13: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 13, Gen Loss: 11.11, Dis Loss: 0.13, PSNR: 14.84 dB


Epoch: 14: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 14, Gen Loss: 10.99, Dis Loss: 0.12, PSNR: 14.95 dB


Epoch: 15: 100%|██████████| 46/46 [00:38<00:00,  1.18it/s]


Epoch: 15, Gen Loss: 10.93, Dis Loss: 0.12, PSNR: 15.02 dB


Epoch: 16: 100%|██████████| 46/46 [00:38<00:00,  1.18it/s]


Epoch: 16, Gen Loss: 10.86, Dis Loss: 0.13, PSNR: 15.09 dB


Epoch: 17: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 17, Gen Loss: 10.84, Dis Loss: 0.13, PSNR: 15.12 dB


Epoch: 18: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 18, Gen Loss: 10.85, Dis Loss: 0.13, PSNR: 15.18 dB


Epoch: 19: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 19, Gen Loss: 10.69, Dis Loss: 0.14, PSNR: 15.20 dB


Epoch: 20: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 20, Gen Loss: 10.59, Dis Loss: 0.12, PSNR: 15.43 dB
SAVING CHECKPOINT!!!!!!!!!!!!!!


Epoch: 21: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 21, Gen Loss: 10.63, Dis Loss: 0.13, PSNR: 15.50 dB


Epoch: 22: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 22, Gen Loss: 10.61, Dis Loss: 0.12, PSNR: 15.49 dB


Epoch: 23: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 23, Gen Loss: 10.57, Dis Loss: 0.14, PSNR: 15.57 dB


Epoch: 24: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 24, Gen Loss: 10.51, Dis Loss: 0.12, PSNR: 15.65 dB


Epoch: 25: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 25, Gen Loss: 10.37, Dis Loss: 0.12, PSNR: 15.75 dB


Epoch: 26: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 26, Gen Loss: 10.06, Dis Loss: 0.44, PSNR: 16.01 dB


Epoch: 27: 100%|██████████| 46/46 [00:38<00:00,  1.20it/s]


Epoch: 27, Gen Loss: 8.75, Dis Loss: 0.12, PSNR: 16.86 dB


Epoch: 28: 100%|██████████| 46/46 [00:38<00:00,  1.20it/s]


Epoch: 28, Gen Loss: 9.17, Dis Loss: 0.11, PSNR: 16.60 dB


Epoch: 29: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 29, Gen Loss: 9.39, Dis Loss: 0.11, PSNR: 16.48 dB


Epoch: 30: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 30, Gen Loss: 9.58, Dis Loss: 0.15, PSNR: 16.33 dB


Epoch: 31: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 31, Gen Loss: 9.62, Dis Loss: 0.15, PSNR: 16.35 dB


Epoch: 32: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 32, Gen Loss: 9.64, Dis Loss: 0.14, PSNR: 16.31 dB


Epoch: 33: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 33, Gen Loss: 9.74, Dis Loss: 0.16, PSNR: 16.23 dB


Epoch: 34: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 34, Gen Loss: 9.79, Dis Loss: 0.15, PSNR: 16.29 dB


Epoch: 35: 100%|██████████| 46/46 [00:38<00:00,  1.18it/s]


Epoch: 35, Gen Loss: 9.69, Dis Loss: 0.15, PSNR: 16.44 dB


Epoch: 36: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 36, Gen Loss: 9.67, Dis Loss: 0.14, PSNR: 16.51 dB


Epoch: 37: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 37, Gen Loss: 9.80, Dis Loss: 0.14, PSNR: 16.42 dB


Epoch: 38: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 38, Gen Loss: 9.62, Dis Loss: 0.15, PSNR: 16.60 dB


Epoch: 39: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 39, Gen Loss: 9.61, Dis Loss: 0.14, PSNR: 16.62 dB


Epoch: 40: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 40, Gen Loss: 9.75, Dis Loss: 0.16, PSNR: 16.56 dB
SAVING CHECKPOINT!!!!!!!!!!!!!!


Epoch: 41: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 41, Gen Loss: 9.64, Dis Loss: 0.14, PSNR: 16.60 dB


Epoch: 42: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 42, Gen Loss: 9.78, Dis Loss: 0.13, PSNR: 16.51 dB


Epoch: 43: 100%|██████████| 46/46 [00:38<00:00,  1.18it/s]


Epoch: 43, Gen Loss: 9.75, Dis Loss: 0.14, PSNR: 16.57 dB


Epoch: 44: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 44, Gen Loss: 9.67, Dis Loss: 0.15, PSNR: 16.65 dB


Epoch: 45: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 45, Gen Loss: 9.69, Dis Loss: 0.14, PSNR: 16.65 dB


Epoch: 46: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 46, Gen Loss: 9.49, Dis Loss: 0.15, PSNR: 16.88 dB


Epoch: 47: 100%|██████████| 46/46 [00:38<00:00,  1.18it/s]


Epoch: 47, Gen Loss: 9.44, Dis Loss: 0.13, PSNR: 16.96 dB


Epoch: 48: 100%|██████████| 46/46 [00:38<00:00,  1.18it/s]


Epoch: 48, Gen Loss: 9.58, Dis Loss: 0.15, PSNR: 16.82 dB


Epoch: 49: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 49, Gen Loss: 9.41, Dis Loss: 0.14, PSNR: 16.95 dB


Epoch: 50: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 50, Gen Loss: 9.52, Dis Loss: 0.14, PSNR: 16.85 dB


Epoch: 51: 100%|██████████| 46/46 [00:38<00:00,  1.20it/s]


Epoch: 51, Gen Loss: 9.48, Dis Loss: 0.13, PSNR: 16.95 dB


Epoch: 52: 100%|██████████| 46/46 [00:38<00:00,  1.20it/s]


Epoch: 52, Gen Loss: 9.49, Dis Loss: 0.14, PSNR: 16.93 dB


Epoch: 53: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 53, Gen Loss: 9.55, Dis Loss: 0.15, PSNR: 16.88 dB


Epoch: 54: 100%|██████████| 46/46 [00:38<00:00,  1.18it/s]


Epoch: 54, Gen Loss: 9.50, Dis Loss: 0.13, PSNR: 16.92 dB


Epoch: 55: 100%|██████████| 46/46 [00:38<00:00,  1.18it/s]


Epoch: 55, Gen Loss: 9.49, Dis Loss: 0.15, PSNR: 16.94 dB


Epoch: 56: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 56, Gen Loss: 9.42, Dis Loss: 0.14, PSNR: 17.03 dB


Epoch: 57: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 57, Gen Loss: 9.45, Dis Loss: 0.14, PSNR: 17.00 dB


Epoch: 58: 100%|██████████| 46/46 [00:38<00:00,  1.20it/s]


Epoch: 58, Gen Loss: 9.36, Dis Loss: 0.13, PSNR: 17.08 dB


Epoch: 59: 100%|██████████| 46/46 [00:38<00:00,  1.20it/s]


Epoch: 59, Gen Loss: 9.46, Dis Loss: 0.13, PSNR: 16.97 dB


Epoch: 60: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 60, Gen Loss: 9.38, Dis Loss: 0.13, PSNR: 17.05 dB
SAVING CHECKPOINT!!!!!!!!!!!!!!


Epoch: 61: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 61, Gen Loss: 9.35, Dis Loss: 0.13, PSNR: 17.15 dB


Epoch: 62: 100%|██████████| 46/46 [00:38<00:00,  1.18it/s]


Epoch: 62, Gen Loss: 9.41, Dis Loss: 0.14, PSNR: 17.08 dB


Epoch: 63: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 63, Gen Loss: 9.45, Dis Loss: 0.12, PSNR: 17.04 dB


Epoch: 64: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 64, Gen Loss: 9.43, Dis Loss: 0.13, PSNR: 17.05 dB


Epoch: 65: 100%|██████████| 46/46 [00:38<00:00,  1.20it/s]


Epoch: 65, Gen Loss: 9.31, Dis Loss: 0.13, PSNR: 17.22 dB


Epoch: 66: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 66, Gen Loss: 9.33, Dis Loss: 0.13, PSNR: 17.20 dB


Epoch: 67: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 67, Gen Loss: 9.26, Dis Loss: 0.13, PSNR: 17.26 dB


Epoch: 68: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 68, Gen Loss: 9.33, Dis Loss: 0.12, PSNR: 17.18 dB


Epoch: 69: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 69, Gen Loss: 9.62, Dis Loss: 0.11, PSNR: 16.92 dB


Epoch: 70: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 70, Gen Loss: 9.28, Dis Loss: 0.13, PSNR: 17.26 dB


Epoch: 71: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 71, Gen Loss: 9.27, Dis Loss: 0.12, PSNR: 17.25 dB


Epoch: 72: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 72, Gen Loss: 9.24, Dis Loss: 0.12, PSNR: 17.33 dB


Epoch: 73: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 73, Gen Loss: 9.23, Dis Loss: 0.11, PSNR: 17.34 dB


Epoch: 74: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 74, Gen Loss: 9.21, Dis Loss: 0.13, PSNR: 17.35 dB


Epoch: 75: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 75, Gen Loss: 9.26, Dis Loss: 0.11, PSNR: 17.29 dB


Epoch: 76: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 76, Gen Loss: 9.34, Dis Loss: 0.12, PSNR: 17.24 dB


Epoch: 77: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 77, Gen Loss: 9.34, Dis Loss: 0.12, PSNR: 17.23 dB


Epoch: 78: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 78, Gen Loss: 9.00, Dis Loss: 0.15, PSNR: 17.57 dB


Epoch: 79: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 79, Gen Loss: 9.12, Dis Loss: 0.12, PSNR: 17.41 dB


Epoch: 80: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 80, Gen Loss: 9.11, Dis Loss: 0.12, PSNR: 17.43 dB
SAVING CHECKPOINT!!!!!!!!!!!!!!


Epoch: 81: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 81, Gen Loss: 9.11, Dis Loss: 0.11, PSNR: 17.44 dB


Epoch: 82: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 82, Gen Loss: 8.97, Dis Loss: 0.12, PSNR: 17.60 dB


Epoch: 83: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 83, Gen Loss: 9.13, Dis Loss: 0.11, PSNR: 17.42 dB


Epoch: 84: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 84, Gen Loss: 9.13, Dis Loss: 0.11, PSNR: 17.42 dB


Epoch: 85: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 85, Gen Loss: 9.17, Dis Loss: 0.12, PSNR: 17.40 dB


Epoch: 86: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 86, Gen Loss: 9.05, Dis Loss: 0.12, PSNR: 17.51 dB


Epoch: 87: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 87, Gen Loss: 9.02, Dis Loss: 0.12, PSNR: 17.60 dB


Epoch: 88: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 88, Gen Loss: 9.05, Dis Loss: 0.11, PSNR: 17.53 dB


Epoch: 89: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 89, Gen Loss: 9.10, Dis Loss: 0.11, PSNR: 17.47 dB


Epoch: 90: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 90, Gen Loss: 9.11, Dis Loss: 0.12, PSNR: 17.50 dB


Epoch: 91: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 91, Gen Loss: 9.11, Dis Loss: 0.11, PSNR: 17.46 dB


Epoch: 92: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 92, Gen Loss: 9.34, Dis Loss: 0.10, PSNR: 17.23 dB


Epoch: 93: 100%|██████████| 46/46 [00:38<00:00,  1.20it/s]


Epoch: 93, Gen Loss: 9.59, Dis Loss: 1.09, PSNR: 17.76 dB


Epoch: 94: 100%|██████████| 46/46 [00:38<00:00,  1.20it/s]


Epoch: 94, Gen Loss: 7.39, Dis Loss: 0.22, PSNR: 18.91 dB


Epoch: 95: 100%|██████████| 46/46 [00:38<00:00,  1.20it/s]


Epoch: 95, Gen Loss: 7.33, Dis Loss: 0.19, PSNR: 19.05 dB


Epoch: 96: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 96, Gen Loss: 7.38, Dis Loss: 0.16, PSNR: 19.04 dB


Epoch: 97: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 97, Gen Loss: 7.59, Dis Loss: 0.14, PSNR: 18.84 dB


Epoch: 98: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 98, Gen Loss: 7.65, Dis Loss: 0.14, PSNR: 18.82 dB


Epoch: 99: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 99, Gen Loss: 7.78, Dis Loss: 0.14, PSNR: 18.65 dB


Epoch: 100: 100%|██████████| 46/46 [00:38<00:00,  1.19it/s]


Epoch: 100, Gen Loss: 7.86, Dis Loss: 0.15, PSNR: 18.54 dB
SAVING CHECKPOINT!!!!!!!!!!!!!!


In [24]:
import torchvision

In [25]:
test_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])

In [26]:
inp_img = Image.open(input_path+'/0.jpg').convert("RGB")
inp_tensor = test_transform(inp_img).unsqueeze(0).to("cuda")

In [27]:
generator.eval()

Generator(
  (encoder1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (encoder2): Sequential(
    (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (encoder3): Sequential(
    (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (encoder4): Sequential(
    (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (encoder5): Sequential(
    (0): 

In [28]:
with torch.no_grad():
    fake_out = generator(inp_tensor)

In [29]:
def denormalize(tensor):
    mean = torch.tensor([0.485, 0.456, 0.406]).to(tensor.device)
    std = torch.tensor([0.229, 0.224, 0.225]).to(tensor.device)
    tensor = tensor * std[None, :, None, None] + mean[None, :, None, None]
    return torch.clamp(tensor, 0, 1)

# Save or visualize
output_img = denormalize(fake_out)
torchvision.utils.save_image(output_img, "/kaggle/working/test_result.png")

In [30]:
target_img = Image.open(target_path+'/0.jpg').convert("RGB")
target_tensor = test_transform(target_img).unsqueeze(0).to('cuda')

# optional: calculate pixel-level difference
mse = torch.nn.functional.mse_loss(fake_out, target_tensor)
print("Reconstruction MSE:", mse.item())

Reconstruction MSE: 0.030887028202414513


In [57]:
import math
psnr = 10 * math.log10(1 / mse)
print(f"PSNR: {psnr:.2f} dB")

PSNR: 18.23 dB
