In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F 
import torch.optim as optim 
torch.backends.cudnn.benchmark = True
import albumentations as A
from torch.utils.data import Dataset, DataLoader  
import os 
from PIL import Image 
from albumentations.pytorch import ToTensorV2
from torchvision.utils import save_image
import tqdm
import numpy as np 

# Hyperparameters 

In [32]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_DIR = "maps/maps/train"
VAL_DIR = "maps/maps/val"
LEARNING_RATE = 2e-4
BATCH_SIZE = 16
NUM_WORKERS = 0
IMAGE_SIZE = 256
CHANNELS_IMG = 3
L1_LAMBDA = 100
LAMBDA_GP = 10
NUM_EPOCHS = 500
LOAD_MODEL = False
SAVE_MODEL = False
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"


both_transform = A.Compose(
    [A.Resize(width=256, height=256),], additional_targets={"image0": "image"},
)

transform_only_input = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(p=0.2),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)

transform_only_mask = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ])

# Discriminator

In [33]:
class Discriminator(nn.Module):
    def __init__(self, in_channels = 3, features = [64, 128, 256, 512]):
        super(Discriminator, self).__init__()
        
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels*2, 
            features[0],
            kernel_size = 4, 
            stride = 2, 
            padding = 1, 
            padding_mode = 'reflect')
        )
        layers = []
        
        in_channel = features[0]
        for out_channel in features[1:]:
            layer = self.block(in_channel, out_channel,stride = 1 if out_channel == features[-1] else 2 )
            in_channel = out_channel 
            layers.append(layer)
            
        layers.append(
            nn.Conv2d(
                in_channel, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"
            ),
        )

            
        self.model = nn.Sequential(*layers)
            
            
        
    def block(self, in_channel, out_channel, stride = 1):
        block_layer = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 4, stride,1, bias = False, padding_mode = 'reflect'),
            nn.BatchNorm2d(out_channel),
            nn.LeakyReLU(0.2)
        )
        return block_layer 
    def forward(self, x,y):
        x = torch.cat([x, y] , dim = 1)
        x = self.initial(x)
        x = self.model(x)
        return x 
    
        

In [34]:
a = torch.ones([1,3,256,256])
b = torch.ones([1,3, 256, 256])

In [35]:
Discriminator()(a,b).shape

torch.Size([1, 1, 30, 30])

# Generator

In [36]:
class Block(nn.Module):
    def __init__(self, in_channels,out_channels, down = True, act = 'relu', use_dropout = False):
        super(Block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias = False, padding_mode = 'reflect') 
            if down else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias = False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act == 'relu' else nn.LeakyReLU(0.2),
        )
        
        self.dropout = nn.Dropout(0.5)
        self.use_dropout = use_dropout 
        self.down = down 
        
    def forward(self, x): 
        x = self.conv(x)
        return self.dropout(x) if self.use_dropout else x 

In [37]:
x = torch.ones(1,3, 60, 60)

In [38]:
class Generator(nn.Module):
    def __init__(self,in_channels=3, features = 64):
        super().__init__()
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode = 'reflect'),
            nn.LeakyReLU(0.2)
        ) # channel 64, 
        
        self.down1 = Block(features, features*2, down = True, act = 'leaky', use_dropout = False) # channel = 128
        self.down2 = Block(features*2, features*4, down = True, act = 'leaky', use_dropout = False) # channel = 256
        self.down3 = Block(features*4, features*8, down = True, act = 'leaky', use_dropout = False) # channel = 512
        self.down4 = Block(features*8, features*8, down = True, act = 'leaky', use_dropout = False) # channel = 512
        self.down5 = Block(features*8, features*8, down = True, act = 'leaky', use_dropout = False) # channel = 512
        self.down6 = Block(features*8, features*8, down = True, act = 'leaky', use_dropout = False) # channel = 512 
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features*8, features*8, 4, 2,1), nn.ReLU()
        )                                                                    # channel 512 
        
        self.up1 = Block(features*8, features*8, down = False, act = 'relu', use_dropout = True) # channel = 512
        self.up2 = Block(features*8*2, features*8, down = False, act = 'relu', use_dropout = True) # channel = 512
        self.up3 = Block(features*8*2, features*8, down = False, act = 'relu', use_dropout = True) # channel = 512 
        self.up4 = Block(features*8*2, features*8, down = False, act = 'relu', use_dropout = True) # channel 512 
        self.up5 = Block(features*8*2, features*4, down = False, act = 'relu', use_dropout = True) # channel = 256 
        self.up6 = Block(features*4*2, features*2, down = False, act = 'relu', use_dropout = True) # channel = 128
        self.up7 = Block(features*2*2, features, down = False, act = 'relu', use_dropout = True) # channel = 64
        
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features*2, in_channels, kernel_size = 4, stride = 2, padding = 1), 
            nn.Tanh(),
        )
    def forward(self,x):
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        # print(d6.shape)
        d7 = self.down6(d6)
        # print(d7.shape)
        bottleneck = self.bottleneck(d7)
        # print(bottleneck.shape)
        up1 = self.up1(bottleneck)
        #print(up1.shape)
        
        up2 = self.up2(torch.cat([up1, d7], dim = 1))
        #print(up2.shape)
        up3 = self.up3(torch.cat([up2, d6], dim = 1))
        up4 = self.up4(torch.cat([up3, d5], dim = 1))
        up5 = self.up5(torch.cat([up4, d4], dim = 1))
        up6 = self.up6(torch.cat([up5, d3], dim = 1))
        up7 = self.up7(torch.cat([up6, d2], dim = 1))
        
        
        return self.final_up(torch.cat([up7, d1], dim = 1)) 

In [39]:
z = torch.ones(1, 3, 256, 256)

In [40]:
Generator()(z).shape

torch.Size([1, 3, 256, 256])

# DATASET Generation 

In [41]:
class MapDataset(Dataset):
    def __init__(self,root_dir):
        super(MapDataset, self).__init__()
        self.root_dir = root_dir 
        self.list_files = os.listdir(self.root_dir)
    def __len__(self):
        return len(self.list_files)
    
    def __getitem__(self, idex):
        img_file = self.list_files[idex]
        img_path = os.path.join(self.root_dir, img_file)
        image = np.array(Image.open(img_path))
        input_image = image[:, :600,:]
        target_image = image[:, 600:, :]
        
        augmentations = both_transform(image=input_image, image0=target_image)
        input_image = augmentations["image"]
        target_image = augmentations["image0"]

        input_image = transform_only_input(image=input_image)["image"]
        target_image = transform_only_mask(image=target_image)["image"]

        return input_image, target_image

# Training Phase

# Config the model and optiminzers 

In [48]:
def save_some_examples(gen, val_loader, epoch, folder):
    x, y = next(iter(val_loader))
    x, y = x.to(DEVICE), y.to(DEVICE)
    gen.eval()
    if not os.path.exists(folder):
        os.mkdir(folder)
    with torch.no_grad():
        y_fake = gen(x)
        y_fake = y_fake * 0.5 + 0.5  # remove normalization#
        save_image(y_fake, folder + f"/y_gen_{epoch}.png")
        save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
        if epoch == 1:
            save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
    gen.train()


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # If we don't do this then it will just have learning rate of old checkpoint
    # and it will lead to many hours of debugging \:
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

In [49]:
device = "cuda" if torch.cuda.is_available() else "cpu"
disc = Discriminator().to(device)
gen = Generator().to(device)
opt_disc = optim.Adam(disc.parameters(), lr = 0.0002, betas = (0.5, 0.999))
opt_gen = optim.Adam(gen.parameters(), lr = 0.0002, betas = (0.5, 0.999))

BCE = nn.BCEWithLogitsLoss()
L1_LOSS = nn.L1Loss()

In [50]:
train_dataset = MapDataset(root_dir = TRAIN_DIR)
train_loader = DataLoader(
    train_dataset, 
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = NUM_WORKERS
)


val_dataset = MapDataset(root_dir=VAL_DIR)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()

In [51]:
def train_fn(disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler):
    pbar = tqdm.tqdm(loader, leave = True)
    for idx,(x,y) in enumerate(pbar):
        x = x.to(device) # input image type 
        y = y.to(device) # target image type 

        # train discriminator 
        with torch.cuda.amp.autocast():
            y_fake = gen(x) # fake target generation
            D_real = disc(x,y) # disc pred with actual image 
            D_real_loss = bce(D_real, torch.ones_like(D_real))

            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))

            D_loss = (D_fake_loss + D_real_loss)/2 

        opt_disc.zero_grad()
        # D_loss.backward()
        # opt_disc.step()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # train generator 
        with torch.cuda.amp.autocast():
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake,y)*L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 10 == 0:
            pbar.set_postfix(
                D_real = torch.sigmoid(D_real).mean().item(),
                D_fake = torch.sigmoid(D_fake).mean().item(),
            )
        

In [52]:
SAVE_MODE = False 

for epoch in range(NUM_EPOCHS):
    train_fn(disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler)
    if SAVE_MODE and epoch %5 ==0: 
        save_checkpoint(gen, opt_gen, filename = CHECKPOINT_GEN)
        save_checkpoint(disc, opt_disc, filename = CHECKPOINT_DISC)
            
    save_some_examples(gen, val_loader, epoch, folder = 'evaluation')    

100%|██████████████████████████████████████████████████████| 69/69 [00:21<00:00,  3.19it/s, D_fake=0.223, D_real=0.688]
100%|██████████████████████████████████████████████████████| 69/69 [00:21<00:00,  3.25it/s, D_fake=0.107, D_real=0.908]
100%|██████████████████████████████████████████████████████| 69/69 [00:21<00:00,  3.18it/s, D_fake=0.111, D_real=0.902]
100%|██████████████████████████████████████████████████████| 69/69 [00:21<00:00,  3.22it/s, D_fake=0.189, D_real=0.586]
100%|███████████████████████████████████████████████████████| 69/69 [00:21<00:00,  3.18it/s, D_fake=0.208, D_real=0.93]
100%|█████████████████████████████████████████████████████| 69/69 [00:21<00:00,  3.22it/s, D_fake=0.0742, D_real=0.937]
100%|█████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.31it/s, D_fake=0.0746, D_real=0.816]
100%|█████████████████████████████████████████████████████| 69/69 [00:21<00:00,  3.28it/s, D_fake=0.0505, D_real=0.981]
100%|███████████████████████████████████

100%|█████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.0204, D_real=0.924]
100%|██████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.0181, D_real=0.99]
100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.00642, D_real=0.975]
100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.43it/s, D_fake=0.00546, D_real=0.969]
100%|██████████████████████████████████████████████████████| 69/69 [00:19<00:00,  3.46it/s, D_fake=0.0164, D_real=0.99]
100%|█████████████████████████████████████████████████████| 69/69 [00:19<00:00,  3.46it/s, D_fake=0.0166, D_real=0.995]
100%|█████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.0655, D_real=0.878]
100%|█████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.0284, D_real=0.821]
100%|███████████████████████████████████

100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.00724, D_real=0.977]
100%|█████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.0231, D_real=0.995]
100%|█████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.0132, D_real=0.973]
100%|█████████████████████████████████████████████████████| 69/69 [00:19<00:00,  3.45it/s, D_fake=0.00571, D_real=0.97]
100%|████████████████████████████████████████████████████| 69/69 [00:19<00:00,  3.45it/s, D_fake=0.00734, D_real=0.995]
100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.42it/s, D_fake=0.00881, D_real=0.984]
100%|█████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.0071, D_real=0.992]
100%|█████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.0106, D_real=0.971]
100%|███████████████████████████████████

100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.00377, D_real=0.997]
100%|█████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.0119, D_real=0.998]
100%|█████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.0103, D_real=0.997]
100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.43it/s, D_fake=0.00839, D_real=0.998]
100%|█████████████████████████████████████████████████████| 69/69 [00:19<00:00,  3.45it/s, D_fake=0.0011, D_real=0.992]
100%|██████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.014, D_real=0.995]
100%|████████████████████████████████████████████████████| 69/69 [00:19<00:00,  3.45it/s, D_fake=0.00769, D_real=0.999]
100%|██████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.002, D_real=0.988]
100%|███████████████████████████████████

100%|████████████████████████████████████████████████████| 69/69 [00:19<00:00,  3.45it/s, D_fake=0.00845, D_real=0.769]
100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.00571, D_real=0.999]
100%|█████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.0104, D_real=0.999]
100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.00452, D_real=0.991]
100%|█████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.0248, D_real=0.977]
100%|█████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.0199, D_real=0.999]
100%|█████████████████████████████████████████████████████| 69/69 [00:19<00:00,  3.45it/s, D_fake=0.0084, D_real=0.995]
100%|████████████████████████████████████████████████████| 69/69 [00:19<00:00,  3.45it/s, D_fake=0.00169, D_real=0.998]
100%|███████████████████████████████████

100%|████████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.00378, D_real=1]
100%|████████████████████████████████████████████████████████| 69/69 [00:19<00:00,  3.46it/s, D_fake=0.00278, D_real=1]
100%|████████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.00148, D_real=1]
100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.00831, D_real=0.999]
100%|████████████████████████████████████████████████████████| 69/69 [00:19<00:00,  3.46it/s, D_fake=0.00443, D_real=1]
100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.00591, D_real=0.992]
100%|█████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.0101, D_real=0.988]
100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.00578, D_real=0.994]
100%|███████████████████████████████████

100%|████████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.00305, D_real=1]
100%|███████████████████████████████████████████████████| 69/69 [00:19<00:00,  3.46it/s, D_fake=0.000616, D_real=0.998]
100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.00168, D_real=0.999]
100%|███████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.43it/s, D_fake=0.000237, D_real=0.999]
100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.00192, D_real=0.997]
100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.00434, D_real=0.969]
100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.00797, D_real=0.999]
100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.00484, D_real=0.999]
100%|███████████████████████████████████

100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.00336, D_real=0.997]
100%|███████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.000868, D_real=0.994]
100%|██████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.45it/s, D_fake=0.178, D_real=0.827]
100%|█████████████████████████████████████████████████████| 69/69 [00:19<00:00,  3.46it/s, D_fake=0.0029, D_real=0.999]
100%|████████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.00635, D_real=1]
100%|████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.00246, D_real=0.995]
100%|████████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.00155, D_real=1]
100%|████████████████████████████████████████████████████████| 69/69 [00:20<00:00,  3.44it/s, D_fake=0.00192, D_real=1]
100%|███████████████████████████████████