In [1]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch.utils.data import Dataset, DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2


from PIL import Image
from tqdm import tqdm

In [3]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Dataset folders, relative to the notebook's working directory.
TRAIN_DIR = "../input/anime-sketch-colorization-pair/data/train"
TEST_DIR = "../input/anime-sketch-colorization-pair/data/val"

# Optimisation / data hyper-parameters.
LR = 2e-4
BATCH_SIZE = 32
NUM_WORKERS = 2
IMG_SIZE = 256       # side length every image is resized to
IMG_CHANNELS = 3     # NOTE(review): not referenced by the visible cells
L1_LAMBDA = 100      # weight of the L1 reconstruction term in the generator loss
LAMBDA_GP = 10       # NOTE(review): not referenced by the visible cells
EPOCHS = 100

# Checkpoint switches and file locations.
LOAD_MODEL = False
SAVE_MODEL = True
DISC_CHK = "./disc.pth.tar"
GEN_CHK = "./gen.pth.tar"

# (pixel/255 - 0.5) / 0.5 maps images to [-1, 1]. That is the range of the
# generator's final Tanh and exactly what `x * 0.5 + 0.5` in save_examples
# inverts. ImageNet statistics would put targets outside [-1, 1], which the
# generator can never reproduce.
_NORM_MEAN = [0.5, 0.5, 0.5]
_NORM_STD = [0.5, 0.5, 0.5]

# Geometric transforms shared by the input ("image") and the target ("image0").
# The flip lives here, not in transform_input, so both halves of a pair are
# flipped together and stay pixel-aligned.
both_transforms = A.Compose(
    [
        A.Resize(width=IMG_SIZE, height=IMG_SIZE),
        A.HorizontalFlip(p=0.5),
    ],
    additional_targets={"image0": "image"}
)

# Photometric augmentation + normalisation applied to the network input only.
transform_input = A.Compose(
    [
        A.ColorJitter(p=0.2),
        A.Normalize(mean=_NORM_MEAN, std=_NORM_STD, max_pixel_value=255.0),
        ToTensorV2()
    ]
)

# Normalisation applied to the target image (no augmentation).
transform_mask = A.Compose(
    [
        A.Normalize(mean=_NORM_MEAN, std=_NORM_STD, max_pixel_value=255.0),
        ToTensorV2()
    ]
)

In [4]:
class MyDataset(Dataset):
    """Paired-image dataset read from a folder of side-by-side pictures.

    Every file in ``root_dir`` is expected to contain two pictures concatenated
    along the width axis. The right half is returned as the network input and
    the left half as the target.

    Args:
        root_dir: Directory whose entries are the paired image files.
    """

    def __init__(self,root_dir):
        self.root_dir = root_dir
        # os.listdir returns entries in arbitrary order; sorting makes
        # index -> file stable across runs and machines.
        self.list_files = sorted(os.listdir(self.root_dir))

    def __len__(self):
        """Return the number of files found in ``root_dir``."""
        return len(self.list_files)

    @staticmethod
    def _split_pair(image):
        """Split an (H, W, C) array at the middle column.

        Returns:
            ``(right_half, left_half)``, i.e. ``(input, target)``. For an odd
            width the extra column goes to the right half.
        """
        half = image.shape[1] // 2
        return image[:, half:, :], image[:, :half, :]

    def __getitem__(self, index):
        """Load pair ``index`` and return ``(input_image, target_image)``.

        Both halves go through ``both_transforms`` together, then the input
        through ``transform_input`` and the target through ``transform_mask``.
        """
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir,img_file)
        # convert("RGB") guarantees an (H, W, 3) array for grayscale, palette
        # and RGBA files; the context manager closes the file handle promptly.
        with Image.open(img_path) as pil_image:
            image = np.array(pil_image.convert("RGB"))

        # Split at the true midpoint instead of a hard-coded column so both
        # halves have the same width.
        input_image, target_image = self._split_pair(image)

        augmentations = both_transforms(image=input_image,image0=target_image)
        input_image = augmentations["image"]
        target_image = augmentations["image0"]

        input_image = transform_input(image=input_image)["image"]
        target_image = transform_mask(image=target_image)["image"]

        return input_image,target_image

In [5]:
def save_examples(gen,val_loader,epoch,folder):
    """Write sample generator outputs for the first batch of ``val_loader``.

    Saves ``y_gen_{epoch}.png`` and ``input_{epoch}.png`` into ``folder``; the
    ground-truth ``label_{epoch}.png`` is written only for the first epoch.

    Args:
        gen: Generator module; run in eval mode and then put back into
            whatever mode it was in before the call.
        val_loader: Iterable yielding ``(input, target)`` batches.
        epoch: Epoch number used in the output file names.
        folder: Existing directory the images are written to.
    """
    x,y = next(iter(val_loader))
    x,y = x.to(DEVICE),y.to(DEVICE)

    was_training = gen.training
    gen.eval()
    try:
        with torch.no_grad():
            y_fake = gen(x)
            # x * 0.5 + 0.5 maps [-1, 1] back to [0, 1] for saving.
            # NOTE(review): for x and y this assumes the dataset normalises
            # with mean=0.5 / std=0.5 — confirm against the transforms.
            y_fake = y_fake*0.5 + 0.5
            save_image(y_fake, folder+f"/y_gen_{epoch}.png")
            save_image(x*0.5+0.5, folder+f"/input_{epoch}.png")
            # The training loop is zero-based and only calls this on multiples
            # of 10, so `epoch == 1` alone never fired and the label was never
            # written. Accept 0 as well; 1 is kept for one-based callers.
            if epoch in (0, 1):
                save_image(y*0.5+0.5, folder+f"/label_{epoch}.png")
    finally:
        # Restore the caller's mode even if saving raised.
        gen.train(was_training)

def save_checkpoint(model,optimizer, filename="./my_checkpoint.pth.tar"):
    """Persist the model and optimizer state dicts to ``filename``.

    The file holds a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"`` and is written with ``torch.save``.
    """
    print("--> Saving checkpoint")
    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        filename,
    )

def load_checkpoint(checkpoint_file,model,optimizer,lr):
    """Restore ``model`` and ``optimizer`` in place from ``checkpoint_file``.

    The file must contain the ``"state_dict"`` and ``"optimizer"`` entries.
    Tensors are mapped onto ``DEVICE`` while loading. After restoring, every
    optimizer parameter group has its learning rate overwritten with ``lr``.
    """
    print("--> Loading checkpoint")
    saved_state = torch.load(checkpoint_file, map_location=DEVICE)

    model.load_state_dict(saved_state["state_dict"])
    optimizer.load_state_dict(saved_state["optimizer"])

    # The optimizer state carries the learning rate it was saved with;
    # replace it with the one requested by the caller.
    for group in optimizer.param_groups:
        group["lr"] = lr


In [6]:
class ConvBlock(nn.Module):
    """Stride-2 resampling block: conv -> batch norm -> activation (-> dropout).

    Args:
        in_features: Number of input channels.
        out_features: Number of output channels.
        use_dropout: When True, ``forward`` applies ``Dropout(0.5)`` last.
        isEncoder: True selects a reflect-padded ``Conv2d`` with
            ``LeakyReLU(0.2)`` (halves H and W); False selects a
            ``ConvTranspose2d`` with ``ReLU`` (doubles H and W).
    """

    def __init__(self,in_features,out_features, use_dropout=False, isEncoder=True):
        super().__init__()

        if isEncoder:
            resample = nn.Conv2d(
                in_features, out_features,
                kernel_size=4, stride=2, padding=1,
                bias=False, padding_mode='reflect',
            )
            activation = nn.LeakyReLU(0.2)
        else:
            resample = nn.ConvTranspose2d(
                in_features, out_features,
                kernel_size=4, stride=2, padding=1,
                bias=False,
            )
            activation = nn.ReLU()

        # The conv bias is disabled because BatchNorm follows it directly.
        self.conv = nn.Sequential(
            resample,
            nn.BatchNorm2d(out_features),
            activation,
        )
        self.use_dropout = use_dropout
        # Always instantiated; only used by forward when use_dropout is set.
        self.dropout = nn.Dropout(0.5)
        self.isEncoder = isEncoder

    def forward(self,x):
        """Apply the block to ``x`` and return the resampled feature map."""
        features = self.conv(x)
        if self.use_dropout:
            return self.dropout(features)
        return features

class Generator(nn.Module):
    """Encoder-decoder generator with concatenated skip connections.

    Eight stride-2 stages (``e1``..``e7`` plus ``bottleneck``) downsample a
    3-channel image. Eight transposed-conv stages (``d1``..``d8``) upsample
    it back, and each decoder stage after the first takes the previous
    decoder output concatenated with the mirrored encoder output. The last
    stage ends in ``Tanh``, so outputs lie in [-1, 1].
    """

    def __init__(self):
        super().__init__()

        # --- encoder -------------------------------------------------------
        # The first stage has no batch norm.
        self.e1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
            nn.LeakyReLU(0.2),
        )
        self.e2 = ConvBlock(64, 128, isEncoder=True)
        self.e3 = ConvBlock(128, 256, isEncoder=True)
        self.e4 = ConvBlock(256, 512, isEncoder=True)
        self.e5 = ConvBlock(512, 512, isEncoder=True)
        self.e6 = ConvBlock(512, 512, isEncoder=True)
        self.e7 = ConvBlock(512, 512, isEncoder=True)

        # --- bottleneck ----------------------------------------------------
        self.bottleneck = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )

        # --- decoder -------------------------------------------------------
        # From d2 onward the input width is doubled by the skip concatenation.
        # Only the three innermost decoder stages use dropout.
        self.d1 = ConvBlock(512, 512, isEncoder=False, use_dropout=True)
        self.d2 = ConvBlock(1024, 512, isEncoder=False, use_dropout=True)
        self.d3 = ConvBlock(1024, 512, isEncoder=False, use_dropout=True)
        self.d4 = ConvBlock(1024, 512, isEncoder=False)
        self.d5 = ConvBlock(1024, 256, isEncoder=False)
        self.d6 = ConvBlock(512, 128, isEncoder=False)
        self.d7 = ConvBlock(256, 64, isEncoder=False)
        self.d8 = nn.Sequential(
            nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self,x):
        """Map an input batch ``x`` to a generated 3-channel image batch."""
        # Run the encoder, remembering every stage output for the skips.
        skips = []
        hidden = x
        for encode in (self.e1, self.e2, self.e3, self.e4, self.e5, self.e6, self.e7):
            hidden = encode(hidden)
            skips.append(hidden)

        hidden = self.d1(self.bottleneck(hidden))

        # Decoder stages consume the skips deepest-first (e7 output with d2,
        # ..., e1 output with d8), concatenated along the channel axis.
        for decode in (self.d2, self.d3, self.d4, self.d5, self.d6, self.d7, self.d8):
            hidden = decode(torch.cat([hidden, skips.pop()], 1))

        return hidden


class Block(nn.Module):
    """Discriminator building block: 4x4 conv -> batch norm -> LeakyReLU(0.2).

    Args:
        in_features: Number of input channels.
        out_features: Number of output channels.
        stride: Stride of the convolution.
    """

    def __init__(self,in_features,out_features,stride):
        super().__init__()
        layers = [
            # No conv bias: the following BatchNorm has its own shift.
            nn.Conv2d(
                in_features, out_features,
                kernel_size=4, stride=stride, padding=1,
                bias=False, padding_mode="reflect",
            ),
            nn.BatchNorm2d(out_features),
            nn.LeakyReLU(0.2),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self,x):
        """Apply the block to ``x``."""
        out = self.conv(x)
        return out

class Discriminator(nn.Module):
    """Fully convolutional conditional discriminator.

    ``forward`` takes two 3-channel images, concatenates them along the
    channel axis (6 channels) and returns a one-channel map of raw logits
    (no sigmoid is applied here).
    """

    def __init__(self):
        super().__init__()

        # First stage: no batch norm.
        self.initial = nn.Sequential(
            nn.Conv2d(6, 64, kernel_size=4, stride=2, padding=1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )

        # Three normalised conv blocks followed by a 1-channel projection.
        self.model = nn.Sequential(
            Block(64, 128, stride=2),
            Block(128, 256, stride=2),
            Block(256, 512, stride=1),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1, padding_mode="reflect"),
        )

    def forward(self,x,y):
        """Score the pair ``(x, y)`` and return the logit map."""
        pair = torch.cat([x, y], 1)
        return self.model(self.initial(pair))

In [7]:
# exist_ok=True keeps this cell re-runnable: without it a second execution
# raises FileExistsError because the folder is already there.
os.makedirs('./validation', exist_ok=True)

In [8]:
def train(disc,gen,loader,opt_disc,opt_gen,l1_loss,bce,g_scaler,d_scaler):
    """Run one pass over ``loader``, alternating discriminator/generator steps.

    Args:
        disc, gen: Discriminator and generator modules.
        loader: Iterable of ``(input, target)`` batches.
        opt_disc, opt_gen: Optimizers for ``disc`` and ``gen``.
        l1_loss: Reconstruction criterion, weighted by ``L1_LAMBDA``.
        bce: Adversarial criterion applied to the discriminator's raw output.
        g_scaler, d_scaler: Gradient scalers for the generator and
            discriminator backward passes.
    """
    progress = tqdm(loader, leave=True)

    for step, (x, y) in enumerate(progress):
        x, y = x.to(DEVICE), y.to(DEVICE)

        # ---- discriminator step ------------------------------------------
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            real_logits = disc(x, y)
            loss_on_real = bce(real_logits, torch.ones_like(real_logits))
            # detach(): the discriminator update must not reach the generator.
            fake_logits = disc(x, y_fake.detach())
            loss_on_fake = bce(fake_logits, torch.zeros_like(fake_logits))
            disc_loss = (loss_on_real + loss_on_fake) / 2

        disc.zero_grad()
        d_scaler.scale(disc_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- generator step ----------------------------------------------
        with torch.cuda.amp.autocast():
            # Re-score the fakes with the just-updated discriminator, this
            # time keeping the graph back to the generator.
            fake_logits = disc(x, y_fake)
            adversarial_term = bce(fake_logits, torch.ones_like(fake_logits))
            reconstruction_term = l1_loss(y_fake, y) * L1_LAMBDA
            gen_loss = adversarial_term + reconstruction_term

        opt_gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # Refresh the progress-bar statistics every tenth batch only.
        if step % 10 != 0:
            continue
        progress.set_postfix(
            D_real = torch.sigmoid(real_logits).mean().item(),
            D_fake = torch.sigmoid(fake_logits).mean().item(),
        )

def main():
    """Build the models, optimizers and data loaders, then train for EPOCHS.

    Checkpoints are written every 5th epoch and sample images every 10th.
    Both are also written after the final epoch so the finished model is
    never lost.
    """
    disc = Discriminator().to(DEVICE)
    gen = Generator().to(DEVICE)
    opt_disc = optim.Adam(disc.parameters(),lr = LR,betas=(0.5,0.999))
    opt_gen = optim.Adam(gen.parameters(),lr = LR,betas=(0.5,0.999))
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()

    if LOAD_MODEL:
        load_checkpoint(GEN_CHK, gen, opt_gen, LR)
        load_checkpoint(DISC_CHK, disc, opt_disc, LR)

    train_dataset = MyDataset(TRAIN_DIR)
    train_loader = DataLoader(
        train_dataset,
        batch_size = BATCH_SIZE,
        shuffle = True,
        num_workers = NUM_WORKERS
    )

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    val_dataset = MyDataset(TEST_DIR)
    val_loader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False
    )

    # save_examples writes into this folder; create it here so main() does
    # not depend on another cell having been run first.
    examples_dir = "./validation"
    os.makedirs(examples_dir, exist_ok=True)

    for epoch in range(EPOCHS):
        train(disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler)

        # Without the last-epoch case the weights trained after the last
        # multiple of 5 (epochs 96-99 when EPOCHS is 100) would be discarded.
        is_last_epoch = epoch == EPOCHS - 1

        if SAVE_MODEL and (epoch % 5 == 0 or is_last_epoch):
            save_checkpoint(gen, opt_gen,filename=GEN_CHK)
            save_checkpoint(disc, opt_disc,filename=DISC_CHK)
        if epoch % 10 == 0 or is_last_epoch:
            save_examples(gen, val_loader, epoch, folder=examples_dir)


if __name__=="__main__":
    main()

100%|██████████| 445/445 [04:37<00:00,  1.60it/s, D_fake=0.00406, D_real=0.997]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:27<00:00,  1.66it/s, D_fake=0.00486, D_real=0.997]
100%|██████████| 445/445 [04:25<00:00,  1.68it/s, D_fake=0.00505, D_real=0.993]
100%|██████████| 445/445 [04:27<00:00,  1.66it/s, D_fake=0.0018, D_real=0.989]
100%|██████████| 445/445 [04:25<00:00,  1.67it/s, D_fake=0.000823, D_real=0.997]
100%|██████████| 445/445 [04:24<00:00,  1.68it/s, D_fake=0.00082, D_real=0.999]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:23<00:00,  1.69it/s, D_fake=0.00684, D_real=0.998]
100%|██████████| 445/445 [04:25<00:00,  1.67it/s, D_fake=0.000936, D_real=0.999]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=0.000286, D_real=1]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=0.000862, D_real=0.999]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=0.000313, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:24<00:00,  1.68it/s, D_fake=0.000333, D_real=1]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=8.98e-5, D_real=1]
100%|██████████| 445/445 [04:25<00:00,  1.67it/s, D_fake=9.55e-5, D_real=1]
100%|██████████| 445/445 [04:28<00:00,  1.66it/s, D_fake=8.46e-5, D_real=1]
100%|██████████| 445/445 [04:28<00:00,  1.66it/s, D_fake=4.78e-5, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:25<00:00,  1.68it/s, D_fake=0.00138, D_real=0.997]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=0.000404, D_real=0.999]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=0.00212, D_real=0.998]
100%|██████████| 445/445 [04:24<00:00,  1.68it/s, D_fake=0.000461, D_real=1]
100%|██████████| 445/445 [04:23<00:00,  1.69it/s, D_fake=0.000261, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:27<00:00,  1.66it/s, D_fake=0.000255, D_real=1]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=8.39e-5, D_real=1]
100%|██████████| 445/445 [04:25<00:00,  1.67it/s, D_fake=4.73e-5, D_real=1]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=8.42e-5, D_real=1]
100%|██████████| 445/445 [04:25<00:00,  1.68it/s, D_fake=4.09e-5, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:25<00:00,  1.68it/s, D_fake=4.21e-5, D_real=1]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=2.65e-5, D_real=1]
100%|██████████| 445/445 [04:27<00:00,  1.66it/s, D_fake=1.91e-5, D_real=1]
100%|██████████| 445/445 [04:27<00:00,  1.67it/s, D_fake=1.45e-5, D_real=1]
100%|██████████| 445/445 [04:27<00:00,  1.67it/s, D_fake=1.04e-5, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:25<00:00,  1.68it/s, D_fake=1.08e-5, D_real=1]
100%|██████████| 445/445 [04:27<00:00,  1.66it/s, D_fake=5.9e-6, D_real=1]
100%|██████████| 445/445 [04:27<00:00,  1.67it/s, D_fake=5.6e-6, D_real=1]
100%|██████████| 445/445 [04:29<00:00,  1.65it/s, D_fake=4.83e-6, D_real=1]
100%|██████████| 445/445 [04:29<00:00,  1.65it/s, D_fake=0.000935, D_real=0.997]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:28<00:00,  1.66it/s, D_fake=0.000348, D_real=1]
100%|██████████| 445/445 [04:31<00:00,  1.64it/s, D_fake=0.000377, D_real=1]
100%|██████████| 445/445 [04:30<00:00,  1.64it/s, D_fake=0.00018, D_real=1]
100%|██████████| 445/445 [04:33<00:00,  1.63it/s, D_fake=0.000157, D_real=1]
100%|██████████| 445/445 [04:32<00:00,  1.64it/s, D_fake=3.48e-5, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:32<00:00,  1.63it/s, D_fake=0.000111, D_real=1]
100%|██████████| 445/445 [04:33<00:00,  1.63it/s, D_fake=0.000896, D_real=1]
100%|██████████| 445/445 [04:33<00:00,  1.63it/s, D_fake=0.000201, D_real=1]
100%|██████████| 445/445 [04:30<00:00,  1.65it/s, D_fake=4.65e-5, D_real=1]
100%|██████████| 445/445 [04:33<00:00,  1.63it/s, D_fake=4.63e-5, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:33<00:00,  1.63it/s, D_fake=4.64e-5, D_real=1]
100%|██████████| 445/445 [04:33<00:00,  1.63it/s, D_fake=3.61e-5, D_real=1]
100%|██████████| 445/445 [04:35<00:00,  1.61it/s, D_fake=4.56e-5, D_real=1]
100%|██████████| 445/445 [04:36<00:00,  1.61it/s, D_fake=2.22e-5, D_real=1]
100%|██████████| 445/445 [04:36<00:00,  1.61it/s, D_fake=2.7e-5, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:36<00:00,  1.61it/s, D_fake=9.24e-6, D_real=1]
100%|██████████| 445/445 [04:32<00:00,  1.63it/s, D_fake=1.34e-5, D_real=1]
100%|██████████| 445/445 [04:35<00:00,  1.61it/s, D_fake=1.19e-5, D_real=1]
100%|██████████| 445/445 [04:35<00:00,  1.61it/s, D_fake=0.000757, D_real=1]
100%|██████████| 445/445 [04:41<00:00,  1.58it/s, D_fake=0.000245, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:47<00:00,  1.55it/s, D_fake=0.000178, D_real=1]
100%|██████████| 445/445 [04:39<00:00,  1.59it/s, D_fake=0.000175, D_real=1]
100%|██████████| 445/445 [04:36<00:00,  1.61it/s, D_fake=4.14e-5, D_real=1]
100%|██████████| 445/445 [04:35<00:00,  1.62it/s, D_fake=5.75e-5, D_real=1]
100%|██████████| 445/445 [04:32<00:00,  1.63it/s, D_fake=6.58e-5, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:32<00:00,  1.63it/s, D_fake=4.76e-5, D_real=1]
100%|██████████| 445/445 [04:31<00:00,  1.64it/s, D_fake=1.74e-5, D_real=1]
100%|██████████| 445/445 [04:33<00:00,  1.63it/s, D_fake=1.11e-5, D_real=1]
100%|██████████| 445/445 [04:33<00:00,  1.63it/s, D_fake=0.000278, D_real=1]
100%|██████████| 445/445 [04:29<00:00,  1.65it/s, D_fake=9.41e-5, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:29<00:00,  1.65it/s, D_fake=0.000159, D_real=1]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=0.000137, D_real=1]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=4.57e-5, D_real=1]
100%|██████████| 445/445 [04:25<00:00,  1.67it/s, D_fake=3.4e-5, D_real=1]
100%|██████████| 445/445 [04:30<00:00,  1.64it/s, D_fake=2.32e-5, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:35<00:00,  1.61it/s, D_fake=2.09e-5, D_real=1]
100%|██████████| 445/445 [04:35<00:00,  1.61it/s, D_fake=9.18e-6, D_real=1]
100%|██████████| 445/445 [04:29<00:00,  1.65it/s, D_fake=9.18e-6, D_real=1]
100%|██████████| 445/445 [04:29<00:00,  1.65it/s, D_fake=1.16e-5, D_real=1]
100%|██████████| 445/445 [04:27<00:00,  1.66it/s, D_fake=1.58e-5, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:27<00:00,  1.67it/s, D_fake=6.97e-6, D_real=1]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=6.79e-6, D_real=1]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=3.93e-6, D_real=1]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=2.44e-6, D_real=1]
100%|██████████| 445/445 [04:27<00:00,  1.67it/s, D_fake=5.13e-6, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:28<00:00,  1.66it/s, D_fake=2.74e-6, D_real=1]
100%|██████████| 445/445 [04:35<00:00,  1.61it/s, D_fake=0.00179, D_real=0.998]
100%|██████████| 445/445 [04:35<00:00,  1.61it/s, D_fake=0.000664, D_real=1]
100%|██████████| 445/445 [04:28<00:00,  1.66it/s, D_fake=0.000199, D_real=1]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=0.000105, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=8.75e-5, D_real=1]
100%|██████████| 445/445 [04:25<00:00,  1.67it/s, D_fake=0.000121, D_real=1]
100%|██████████| 445/445 [04:24<00:00,  1.68it/s, D_fake=5.04e-5, D_real=1]
100%|██████████| 445/445 [04:25<00:00,  1.68it/s, D_fake=2.1e-5, D_real=1]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=3.41e-5, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:27<00:00,  1.66it/s, D_fake=3.65e-5, D_real=1]
100%|██████████| 445/445 [04:27<00:00,  1.66it/s, D_fake=3.42e-5, D_real=1]
100%|██████████| 445/445 [04:26<00:00,  1.67it/s, D_fake=1.16e-5, D_real=1]
100%|██████████| 445/445 [04:37<00:00,  1.60it/s, D_fake=0.00212, D_real=1]
100%|██████████| 445/445 [04:40<00:00,  1.58it/s, D_fake=0.00017, D_real=1]


--> Saving checkpoint
--> Saving checkpoint


100%|██████████| 445/445 [04:29<00:00,  1.65it/s, D_fake=7.48e-5, D_real=1]
100%|██████████| 445/445 [04:28<00:00,  1.66it/s, D_fake=0.000113, D_real=1]
100%|██████████| 445/445 [04:29<00:00,  1.65it/s, D_fake=6.1e-5, D_real=1]
100%|██████████| 445/445 [04:28<00:00,  1.66it/s, D_fake=0.0014, D_real=0.998]
