In [1]:
import torch
import torchvision
import torch.nn as nn
from PIL import Image
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torch.optim as optim
import os
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision.utils import save_image



In [2]:
class CNNBlock(nn.Module):
    def __init__(self,in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels,4,stride,bias=False,padding_mode="reflect"),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )
    def forward(self,x):
        return self.conv(x)

In [3]:
class Discriminator(nn.Module):
    def __init__(self,in_channels=3, features=[64,128,256,512]):
        super().__init__()
        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels*2,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(0.2),
        )
        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(
                CNNBlock(in_channels,feature,stride=1 if feature == features[-1] else 2),
            )
            in_channels = feature
        layers.append(
            nn.Conv2d(
                in_channels,1, kernel_size=4,stride=1,padding=1,padding_mode="reflect"
            ),
        )
        self.model = nn.Sequential(*layers)
    def forward(self,x,y):
        x = torch.cat([x,y], dim =1)
        x = self.initial(x)
        x = self.model(x)
        return x

In [4]:
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super(Block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2),
        )

        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        x = self.conv(x)
        return self.dropout(x) if self.use_dropout else x

In [5]:
class Generator(nn.Module):
    def __init__(self, in_channels=3, features=64):
        super().__init__()
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )
        self.down1 = Block(features, features * 2, down=True, act="leaky", use_dropout=False)
        self.down2 = Block(
            features * 2, features * 4, down=True, act="leaky", use_dropout=False
        )
        self.down3 = Block(
            features * 4, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.down4 = Block(
            features * 8, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.down5 = Block(
            features * 8, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.down6 = Block(
            features * 8, features * 8, down=True, act="leaky", use_dropout=False
        )
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features * 8, features * 8, 4, 2, 1), nn.ReLU()
        )

        self.up1 = Block(features * 8, features * 8, down=False, act="relu", use_dropout=True)
        self.up2 = Block(
            features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True
        )
        self.up3 = Block(
            features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True
        )
        self.up4 = Block(
            features * 8 * 2, features * 8, down=False, act="relu", use_dropout=False
        )
        self.up5 = Block(
            features * 8 * 2, features * 4, down=False, act="relu", use_dropout=False
        )
        self.up6 = Block(
            features * 4 * 2, features * 2, down=False, act="relu", use_dropout=False
        )
        self.up7 = Block(features * 2 * 2, features, down=False, act="relu", use_dropout=False)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features * 2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )
    
    def forward(self,x):
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)
        bottleneck = self.bottleneck(d7)
        up1 = self.up1(bottleneck)
        up2 = self.up2(torch.cat([up1,d7],1))
        up3 = self.up3(torch.cat([up2,d6],1))
        up4 = self.up4(torch.cat([up3,d5],1))
        up5 = self.up5(torch.cat([up4,d4],1))
        up6 = self.up6(torch.cat([up5,d3],1))
        up7 = self.up7(torch.cat([up6,d2],1))
        return self.final_up(torch.cat([up7,d1],1))

In [6]:
class Config():
    def __init__(self):
        self.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
        self.TRAIN_DIR = "/kaggle/input/25-diff-mini/images2"
        self.VAL_DIR = "/kaggle/input/25-diff-mini/images2"
        self.LEARNING_RATE = 2e-4
        self.BATCH_SIZE = 16
        self.NUM_WORKERS =0
        self.IMAGE_SIZE = 256
        self.CHANNELS_IMG = 3
        self.L1_LAMBDA = 100
        self.LAMBDA_GP = 10
        self.NUM_EPOCHS = 36
        self.LOAD_MODEL = True
        self.SAVE_MODEL = True
        

        self.both_transform = A.Compose(
            [A.Resize(width=256, height=256),], additional_targets={"image0": "image"},
        )

        self.transform_only_input = A.Compose(
            [
                A.HorizontalFlip(p=0.5),
                A.ColorJitter(p=0.2),
                A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
                ToTensorV2(),
            ]
        )

        self.transform_only_mask = A.Compose(
            [
                A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
                ToTensorV2(),
            ]
        )
config = Config()

In [7]:
class ImageDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.list_files = os.listdir(self.root_dir)

    def __len__(self):
        return len(self.list_files)

    def __getitem__(self, index):
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir, img_file)
        image = np.array(Image.open(img_path))
        input_image = image[:, :200, :]
        target_image = image[:, 200:, :]

        augmentations = config.both_transform(image=input_image, image0=target_image)
        input_image = augmentations["image"]
        target_image = augmentations["image0"]

        input_image = config.transform_only_input(image=input_image)["image"]
        target_image = config.transform_only_mask(image=target_image)["image"]

        return input_image, target_image

In [8]:
dataset = ImageDataset("/kaggle/input/25-diff-mini/images2")
loader = DataLoader(dataset,batch_size=5)

In [9]:
def save_some_examples(gen, val_loader, epoch, folder):
    x, y = next(iter(val_loader))
    x, y = x.to(config.DEVICE), y.to(config.DEVICE)
    gen.eval()
    with torch.no_grad():
        y_fake = gen(x)
        y_fake = y_fake * 0.5 + 0.5  # remove normalization#
        save_image(y_fake, folder + f"/y_gen_{epoch}.png")
        save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
        if epoch == 1:
            save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
    gen.train()

In [10]:
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)

In [11]:
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=config.DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

In [12]:
torch.backends.cudnn.benchmark = True

In [13]:
def train_fn(disc, gen, loader, opt_disc, opt_gen, l1_loss, bce, g_scaler, d_scaler):
    loop = tqdm(loader, leave=True)

    for idx, (x, y) in enumerate(loop):
        x = x.to(config.DEVICE)
        y = y.to(config.DEVICE)

        # Train Discriminator
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            D_real = disc(x, y)
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            D_fake = disc(x, y_fake.detach())
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2

        disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train generator
        with torch.cuda.amp.autocast():
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1_loss(y_fake, y) * config.L1_LAMBDA
            G_loss = G_fake_loss + L1

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 10 == 0:
            loop.set_postfix(
                D_real=torch.sigmoid(D_real).mean().item(),
                D_fake=torch.sigmoid(D_fake).mean().item(),
            )


In [14]:
disc = Discriminator(in_channels=3).to(config.DEVICE)
gen = Generator(in_channels=3, features=64).to(config.DEVICE)
opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999),)
opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999))
BCE = nn.BCEWithLogitsLoss()
L1_LOSS = nn.L1Loss()

# if config.LOAD_MODEL:
#     load_checkpoint(
#         config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
#     )
#     load_checkpoint(
#         config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
#     )

train_dataset = ImageDataset(root_dir=config.TRAIN_DIR)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    num_workers=config.NUM_WORKERS,
)
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()
val_dataset = ImageDataset(root_dir=config.VAL_DIR)
val_loader = DataLoader(val_dataset, batch_size=5, shuffle=False)

for epoch in range(config.NUM_EPOCHS):
    train_fn(
        disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler,
    )

    if config.SAVE_MODEL and epoch % 5 == 0:
        save_checkpoint(gen, opt_gen, filename=f"gen_{epoch}.pth")
        if epoch%10 == 0:
            save_checkpoint(disc, opt_disc, filename=f"disc_{epoch}.pth")

100%|██████████| 2824/2824 [18:58<00:00,  2.48it/s, D_fake=0.0381, D_real=0.971]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 2824/2824 [17:01<00:00,  2.76it/s, D_fake=0.217, D_real=0.892]
100%|██████████| 2824/2824 [17:07<00:00,  2.75it/s, D_fake=0.496, D_real=0.496]
100%|██████████| 2824/2824 [17:38<00:00,  2.67it/s, D_fake=0.462, D_real=0.505]
100%|██████████| 2824/2824 [17:05<00:00,  2.75it/s, D_fake=0.403, D_real=0.704]
100%|██████████| 2824/2824 [16:47<00:00,  2.80it/s, D_fake=0.0806, D_real=0.665]


=> Saving checkpoint


100%|██████████| 2824/2824 [16:56<00:00,  2.78it/s, D_fake=0.508, D_real=0.547]
100%|██████████| 2824/2824 [16:48<00:00,  2.80it/s, D_fake=0.254, D_real=0.66]
100%|██████████| 2824/2824 [16:45<00:00,  2.81it/s, D_fake=0.337, D_real=0.701]
100%|██████████| 2824/2824 [17:00<00:00,  2.77it/s, D_fake=0.357, D_real=0.698]
100%|██████████| 2824/2824 [17:09<00:00,  2.74it/s, D_fake=0.426, D_real=0.587]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 2824/2824 [16:56<00:00,  2.78it/s, D_fake=0.438, D_real=0.386]
100%|██████████| 2824/2824 [16:57<00:00,  2.77it/s, D_fake=0.469, D_real=0.512]
100%|██████████| 2824/2824 [17:24<00:00,  2.70it/s, D_fake=0.29, D_real=0.652]
100%|██████████| 2824/2824 [16:46<00:00,  2.81it/s, D_fake=0.256, D_real=0.603]
100%|██████████| 2824/2824 [17:16<00:00,  2.72it/s, D_fake=0.508, D_real=0.485]


=> Saving checkpoint


100%|██████████| 2824/2824 [17:18<00:00,  2.72it/s, D_fake=0.263, D_real=0.689]
100%|██████████| 2824/2824 [17:04<00:00,  2.76it/s, D_fake=0.251, D_real=0.639]
100%|██████████| 2824/2824 [16:57<00:00,  2.78it/s, D_fake=0.473, D_real=0.552]
100%|██████████| 2824/2824 [17:07<00:00,  2.75it/s, D_fake=0.5, D_real=0.405]
100%|██████████| 2824/2824 [17:24<00:00,  2.70it/s, D_fake=0.231, D_real=0.791]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 2824/2824 [17:23<00:00,  2.71it/s, D_fake=0.195, D_real=0.662]
100%|██████████| 2824/2824 [17:29<00:00,  2.69it/s, D_fake=0.153, D_real=0.809]
100%|██████████| 2824/2824 [16:43<00:00,  2.81it/s, D_fake=0.294, D_real=0.606]
100%|██████████| 2824/2824 [16:32<00:00,  2.85it/s, D_fake=0.142, D_real=0.707]
100%|██████████| 2824/2824 [17:01<00:00,  2.76it/s, D_fake=0.478, D_real=0.505]


=> Saving checkpoint


100%|██████████| 2824/2824 [16:52<00:00,  2.79it/s, D_fake=0.106, D_real=0.736]
100%|██████████| 2824/2824 [16:36<00:00,  2.83it/s, D_fake=0.474, D_real=0.512]
100%|██████████| 2824/2824 [16:52<00:00,  2.79it/s, D_fake=0.173, D_real=0.73]
100%|██████████| 2824/2824 [16:45<00:00,  2.81it/s, D_fake=0.491, D_real=0.549]
100%|██████████| 2824/2824 [16:16<00:00,  2.89it/s, D_fake=0.324, D_real=0.676]


=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 2824/2824 [16:23<00:00,  2.87it/s, D_fake=0.0975, D_real=0.87]
100%|██████████| 2824/2824 [16:39<00:00,  2.83it/s, D_fake=0.495, D_real=0.51]
100%|██████████| 2824/2824 [16:45<00:00,  2.81it/s, D_fake=0.202, D_real=0.729]
100%|██████████| 2824/2824 [18:12<00:00,  2.58it/s, D_fake=0.102, D_real=0.916]
100%|██████████| 2824/2824 [18:32<00:00,  2.54it/s, D_fake=0.466, D_real=0.52]


=> Saving checkpoint


In [15]:
save_checkpoint(gen, opt_gen, filename=f"gen_final.pth")
save_checkpoint(disc, opt_disc, filename=f"disc_final.pth")

=> Saving checkpoint
=> Saving checkpoint
