In [3]:
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
import os
from torch.utils.data import Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.utils import save_image
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.optim as optim

In [4]:
DEVICE= "cuda:0" if torch.cuda.is_available() else "cpu"
LEARNING_RATE= 2e-4
BATCH_SIZE= 16
NUM_WORKERS= 2
IMAGE_SIZE= 256
CHANNELS_IMG= 3
L1_LAMBDA= 100
NUM_EPOCHS= 10
LOAD_MODEL= False
SAVE_MODEL= True
CHECKPOINT_DISC= "disc.pth.tar"
CHECKPOINT_GEN= "gen.path.tar"

In [5]:
class CNNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()
        self.conv= nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride, bias="False", padding_mode="reflect"),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )
    
    def forward(self, x):
        return self.conv(x)

In [6]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
        super().__init__()
        self.initial= nn.Sequential(
                nn.Conv2d(in_channels*2, features[0], kernel_size= 4, stride= 2, padding= 1, padding_mode="reflect"),
                nn.LeakyReLU(0.2),
        )
        
        layers=[]
        in_channels= features[0]
        for feature in features[1:]:
            layers.append(
                    CNNBlock(in_channels, feature, stride= 1 if feature == features[-1] else 2),
            )
            in_channels= feature
        layers.append(
            nn.Conv2d(
                in_channels, 1, kernel_size= 4, stride= 1, padding= 1, padding_mode= "reflect"
            ),
        )
        self.model= nn.Sequential(*layers)
    
    def forward(self, x, y):
        x= torch.cat([x, y], dim= 1)
        x= self.initial(x)
        return self.model(x)

# Just to test Discriminator model

In [7]:
def test():
    x= torch.randn((1, 3, 256, 256))
    y= torch.randn((1, 3, 256, 256))
    model= Discriminator()
    preds= model(x, y)
    print(preds.shape)
test()

torch.Size([1, 1, 26, 26])


In [8]:
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, down= True, act= "relu", use_dropout= False):
        super().__init__()
        self.conv= nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias= False, padding_mode= "reflect")
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act=="relu" else nn.LeakyReLU(0.2),
        )
        self.use_dropout= use_dropout
        self.dropout= nn.Dropout(0.5)
    
    def forward(self, x):
        x= self.conv(x)
        return self.dropout(x) if self.use_dropout else x

In [9]:
class Generator(nn.Module):
    def __init__(self, in_channels= 3, features= 64):
        super().__init__()
        self.initial_down= nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode= "reflect"),
            nn.LeakyReLU(0.2),
        ) # returns 128 
        
        self.down1= Block(features, features*2, down= True, act= "leaky", use_dropout= False)#64
        self.down2= Block(features*2, features*4, down= True, act= "leaky", use_dropout= False)#32
        self.down3= Block(features*4, features*8, down= True, act= "leaky", use_dropout= False)#16
        self.down4= Block(features*8, features*8, down= True, act= "leaky", use_dropout= False)#8
        self.down5= Block(features*8, features*8, down= True, act= "leaky", use_dropout= False)#4
        self.down6= Block(features*8, features*8, down= True, act= "leaky", use_dropout= False)#2
        self.bottleneck= nn.Sequential(
            nn.Conv2d(features*8, features*8, 4, 2, 1, padding_mode="reflect"), 
            nn.ReLU(),#1x1
        )
        
        self.up1= Block(features*8, features*8, down= False, act= "relu", use_dropout= True)#64
        self.up2= Block(features*8*2, features*8, down= False, act= "relu", use_dropout= True)#64
        self.up3= Block(features*8*2, features*8, down= False, act= "relu", use_dropout= True)#64
        self.up4= Block(features*8*2, features*8, down= False, act= "relu", use_dropout= False)#64
        self.up5= Block(features*8*2, features*4, down= False, act= "relu", use_dropout= False)#64
        self.up6= Block(features*4*2, features*2, down= False, act= "relu", use_dropout= False)#64
        self.up7= Block(features*2*2, features, down= False, act= "relu", use_dropout= False)#64
        self.final_up= nn.Sequential(
            nn.ConvTranspose2d(features*2, in_channels, kernel_size=4, stride= 2, padding=1),
            nn.Tanh(),   
        )                    
    
    def forward(self, x):
        d1= self.initial_down(x)
        d2= self.down1(d1)
        d3= self.down2(d2)
        d4= self.down3(d3)
        d5= self.down4(d4)
        d6= self.down5(d5)
        d7= self.down6(d6)
        bottleneck= self.bottleneck(d7)
        up1= self.up1(bottleneck)
        up2= self.up2(torch.cat([up1, d7], 1))
        up3= self.up3(torch.cat([up2, d6], 1))
        up4= self.up4(torch.cat([up3, d5], 1))
        up5= self.up5(torch.cat([up4, d4], 1))
        up6= self.up6(torch.cat([up5, d3], 1))
        up7= self.up7(torch.cat([up6, d2], 1))
        return self.final_up(torch.cat([up7, d1], 1))

In [10]:
def test():
    x= torch.randn((1, 3, 256, 256))
    model= Generator(in_channels=3, features= 64)
    preds= model(x)
    print(preds.shape)
test()

torch.Size([1, 3, 256, 256])


In [11]:
both_transform= A.Compose(
    [A.Resize(width=256, height=256),], additional_targets={"image0":"image"},
)

transform_only_input= A.Compose(
    [
        A.ColorJitter(p=0.1),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value= 255.0,),
        ToTensorV2(),
    ]
)

transform_only_mask= A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value= 255.0,),
        ToTensorV2(),
    ]
)

In [12]:
class MapDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir= root_dir
        self.list_files= os.listdir(self.root_dir)
        print(self.list_files)
    
    def __len__(self):
        return len(self.list_files)
    
    def __getitem__(self, index):
        img_file= self.list_files[index]
        img_path= os.path.join(self.root_dir, img_file)
        image= np.array(Image.open(img_path))
        input_image= image[:, :600, :]
        target_image= image[:, 600:, :]
        
        augmentations= both_transform(image=input_image, image0=target_image)
        input_image, target_image= augmentations["image"], augmentations["image0"]
        
        input_image= transform_only_input(image= input_image)["image"]
        target_image= transform_only_input(image= target_image)["image"]
        
        return input_image, target_image

In [13]:
def train(disc, gen,loader, opt_disc, opt_gen, l1, bce, g_scaler, d_scaler):
    loop= tqdm(loader, leave= True)
    
    for idx, (x, y) in enumerate(loop):
        x, y= x.to(DEVICE), y.to(DEVICE)
        
        with torch.cuda.amp.autocast():
            y_fake= gen(x)
            D_real= disc(x, y)
            D_fake= disc(x, y_fake.detach())
            D_real_loss= bce(D_real, torch.ones_like(D_real))
            D_fake_loss= bce(D_fake, torch.zeros_like(D_fake))
            D_loss= (D_real_loss + D_fake_loss)/2
        
        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()
        
        with torch.cuda.amp.autocast():
            D_fake= disc(x, y_fake)
            G_fake_loss= bce(D_fake, torch.ones_like(D_fake))
            L1= l1(y_fake, y)*L1_LAMBDA
            G_loss= G_fake_loss + L1
        
        disc.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_disc)
        g_scaler.update()
        
        if idx %10 == 0:
            loop.set_postfix(
                D_real= torch.sigmoid(D_real).mean().item(),
                D_fake= torch.sigmoid(D_fake).mean().item(),
            )

In [14]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving Checkpoint")
    checkpoint= {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)

In [15]:
def save_some_examples(gen, val_loader, epoch):
    x, y= next(iter(val_loader))
    x, y= x.to(DEVICE), y.to(DEVICE)
    gen.eval()
    with torch.no_grad():
        y_fake= gen(x)
        y_fake= y_fake * 0.5+ 0.5 
        save_image(y_fake,f"/y_gen_{epoch}.png")
        save_image(x*0.5 + 0.5,f"/input_{epoch}.png")
        if epoch % 1==0:
            save_image(y * 0.5+ 0.5, f"/label_{epoch}.png")
    gen.train()

In [17]:
def main():
    disc= Discriminator(in_channels=3).to(DEVICE)
    gen= Generator(in_channels=3).to(DEVICE)
    opt_disc= optim.Adam(disc.parameters(), lr= LEARNING_RATE, betas=(0.5, 0.999))
    opt_gen= optim.Adam(gen.parameters(), lr= LEARNING_RATE, betas=(0.5, 0.999))
    BCE=nn.BCEWithLogitsLoss()
    L1_LOSS= nn.L1Loss()
    
    train_dataset= MapDataset(root_dir="/kaggle/input/pix2pix-dataset/maps/maps/train")
    train_loader= DataLoader(train_dataset, batch_size= BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    g_scaler= torch.cuda.amp.GradScaler()
    d_scaler= torch.cuda.amp.GradScaler()
    val_dataset= MapDataset(root_dir="/kaggle/input/pix2pix-dataset/maps/maps/val")
    val_loader= DataLoader(val_dataset, batch_size=1, shuffle= True)
    
    for epoch in range(NUM_EPOCHS):
        train(disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler)
      
        if SAVE_MODEL and epoch%1==0:
            save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)
            
            save_some_examples(gen, val_loader, epoch)
        
main()

['623.jpg', '764.jpg', '1075.jpg', '771.jpg', '208.jpg', '820.jpg', '473.jpg', '1031.jpg', '333.jpg', '1024.jpg', '537.jpg', '45.jpg', '369.jpg', '56.jpg', '654.jpg', '89.jpg', '20.jpg', '275.jpg', '785.jpg', '212.jpg', '239.jpg', '792.jpg', '1009.jpg', '58.jpg', '150.jpg', '6.jpg', '109.jpg', '149.jpg', '187.jpg', '521.jpg', '436.jpg', '76.jpg', '539.jpg', '355.jpg', '516.jpg', '71.jpg', '708.jpg', '474.jpg', '501.jpg', '915.jpg', '815.jpg', '760.jpg', '342.jpg', '817.jpg', '429.jpg', '1055.jpg', '646.jpg', '682.jpg', '544.jpg', '377.jpg', '1026.jpg', '272.jpg', '795.jpg', '270.jpg', '182.jpg', '215.jpg', '489.jpg', '576.jpg', '185.jpg', '613.jpg', '930.jpg', '243.jpg', '1010.jpg', '153.jpg', '703.jpg', '189.jpg', '143.jpg', '1025.jpg', '476.jpg', '717.jpg', '327.jpg', '253.jpg', '343.jpg', '115.jpg', '131.jpg', '1058.jpg', '446.jpg', '626.jpg', '425.jpg', '5.jpg', '824.jpg', '366.jpg', '850.jpg', '885.jpg', '151.jpg', '426.jpg', '732.jpg', '503.jpg', '8.jpg', '641.jpg', '892.jpg', '9

100%|██████████| 69/69 [00:12<00:00,  5.50it/s, D_fake=0.719, D_real=0.831]


=> Saving Checkpoint
=> Saving Checkpoint


100%|██████████| 69/69 [00:10<00:00,  6.49it/s, D_fake=0.688, D_real=0.893]


=> Saving Checkpoint
=> Saving Checkpoint


100%|██████████| 69/69 [00:10<00:00,  6.36it/s, D_fake=0.679, D_real=0.951]


=> Saving Checkpoint
=> Saving Checkpoint


100%|██████████| 69/69 [00:10<00:00,  6.30it/s, D_fake=0.694, D_real=0.957]


=> Saving Checkpoint
=> Saving Checkpoint


100%|██████████| 69/69 [00:11<00:00,  6.20it/s, D_fake=0.668, D_real=0.973]


=> Saving Checkpoint
=> Saving Checkpoint


100%|██████████| 69/69 [00:11<00:00,  6.14it/s, D_fake=0.691, D_real=0.98] 


=> Saving Checkpoint
=> Saving Checkpoint


100%|██████████| 69/69 [00:11<00:00,  6.06it/s, D_fake=0.658, D_real=0.984]


=> Saving Checkpoint
=> Saving Checkpoint


100%|██████████| 69/69 [00:11<00:00,  6.03it/s, D_fake=0.698, D_real=0.984]


=> Saving Checkpoint
=> Saving Checkpoint


100%|██████████| 69/69 [00:11<00:00,  6.14it/s, D_fake=0.652, D_real=0.987]


=> Saving Checkpoint
=> Saving Checkpoint


100%|██████████| 69/69 [00:11<00:00,  6.19it/s, D_fake=0.663, D_real=0.987]


=> Saving Checkpoint
=> Saving Checkpoint
