In [1]:
import numpy as np
import torch
import torch.nn as nn

In [2]:
class Block(nn.Module):
    """Discriminator stage: 4x4 reflect-padded conv -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        stride: convolution stride (2 halves the spatial size, 1 keeps it roughly the same).
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode='reflect',
        )
        norm = nn.InstanceNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2)
        self.conv = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Apply convolution, normalization and activation in order."""
        return self.conv(x)

class Discriminator(nn.Module):
    """PatchGAN-style discriminator returning a spatial map of per-patch probabilities.

    Args:
        in_channels: channels of the input image (3 for RGB).
        features: channel widths of the successive stages. The first width is used by
            the un-normalized initial conv; every later width becomes a ``Block``.
            Every ``Block`` downsamples with stride 2 except the last one, which uses
            stride 1.

    Raises:
        ValueError: if ``features`` is empty.

    For a 256x256 input with the default widths the output is (N, 1, 30, 30).
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], 4, 2, 1, padding_mode='reflect'),
            nn.LeakyReLU(0.2)
        )

        layers = []
        in_channels = features[0]
        last_idx = len(features) - 1
        for idx, feature in enumerate(features[1:], start=1):
            # Decide by position, not by value: ``feature == features[-1]`` would also
            # give stride 1 to any earlier stage that happens to share the last width.
            stride = 1 if idx == last_idx else 2
            layers.append(Block(in_channels, feature, stride=stride))
            in_channels = feature

        # Final projection to a single score channel; sigmoid is applied in forward().
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode='reflect'))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-patch scores squashed into (0, 1) by a sigmoid."""
        x = self.initial(x)
        return torch.sigmoid(self.model(x))
    
# Smoke test: a 256x256 RGB image yields a 30x30 grid of patch scores.
x = torch.randn((1, 3, 256, 256))

model = Discriminator(in_channels=3)
with torch.no_grad():  # shape check only; no autograd graph needed
    preds = model(x)
print(preds.shape)
assert preds.shape == (1, 1, 30, 30), preds.shape

torch.Size([1, 1, 30, 30])


In [3]:
class ConvBlock(nn.Module):
    """Generator building block: (transposed) conv -> InstanceNorm -> optional ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        down: True builds a reflect-padded ``nn.Conv2d``; False builds an
            ``nn.ConvTranspose2d`` (default zero padding).
        use_act: when False the ReLU is replaced by ``nn.Identity``.
        **kwargs: forwarded unchanged to the conv layer (kernel_size, stride,
            padding, output_padding, ...).
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            conv = nn.Conv2d(in_channels, out_channels, padding_mode='reflect', **kwargs)
        else:
            conv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)
        norm = nn.InstanceNorm2d(out_channels)
        activation = nn.ReLU(inplace=True) if use_act else nn.Identity()
        self.conv = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Run the conv / norm / activation stack."""
        return self.conv(x)
    
class ResidualBlock(nn.Module):
    """Residual unit: ``x + second(first(x))`` with two 3x3 ConvBlocks.

    The first ConvBlock ends in ReLU, the second has no activation. Both keep the
    channel count, and with kernel 3 / padding 1 / default stride they keep H and W,
    so the skip connection can add the branch directly onto the input.
    """

    def __init__(self, channels):
        super().__init__()
        first = ConvBlock(channels, channels, kernel_size=3, padding=1)
        second = ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1)
        self.block = nn.Sequential(first, second)

    def forward(self, x):
        """Return the input plus the two-conv residual branch."""
        residual = self.block(x)
        return x + residual
    
    
class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Pipeline: 7x7 stem conv + ReLU -> two stride-2 ConvBlocks (downsampling)
    -> ``num_residuals`` ResidualBlocks -> two stride-2 transposed ConvBlocks
    (upsampling) -> 7x7 conv back to ``in_channels`` -> tanh, so outputs lie in
    [-1, 1]. Spatial size is preserved when H and W are multiples of 4
    (e.g. 256x256 -> 256x256).

    Args:
        in_channels: channels of the input image; the output has the same count.
        num_features: width of the stem; the bottleneck is 4x this width.
        num_residuals: number of residual blocks at the bottleneck.
    """

    def __init__(self, in_channels, num_features=64, num_residuals=9):
        super().__init__()
        f1, f2, f4 = num_features, num_features * 2, num_features * 4

        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, f1, kernel_size=7, stride=1, padding=3, padding_mode='reflect'),
            nn.ReLU(),
        )

        # 3x3 / stride 2 / padding 1 halves H and W (rounding up) at each step.
        down_cfg = dict(kernel_size=3, stride=2, padding=1)
        self.down_block = nn.ModuleList([
            ConvBlock(f1, f2, **down_cfg),
            ConvBlock(f2, f4, **down_cfg),
        ])

        self.residuals = nn.Sequential(*(ResidualBlock(f4) for _ in range(num_residuals)))

        # output_padding=1 makes each transposed conv exactly double H and W.
        up_cfg = dict(down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.up_block = nn.ModuleList([
            ConvBlock(f4, f2, **up_cfg),
            ConvBlock(f2, f1, **up_cfg),
        ])

        self.final = nn.Conv2d(f1, in_channels, 7, 1, 3, padding_mode='reflect')

    def forward(self, x):
        """Translate an image batch; the result is squashed into [-1, 1] by tanh."""
        x = self.initial(x)
        for stage in (*self.down_block, self.residuals, *self.up_block):
            x = stage(x)
        return torch.tanh(self.final(x))

# Smoke test: the generator maps a 256x256 RGB image to an image of the same shape.
x = torch.randn((1, 3, 256, 256))

model = Generator(in_channels=3)
with torch.no_grad():  # shape check only; no autograd graph needed
    preds = model(x)
print(preds.shape)
assert preds.shape == x.shape, preds.shape

torch.Size([1, 3, 256, 256])


In [4]:
import torch
from torchvision.utils import save_image
import albumentations as A
from albumentations.pytorch import ToTensorV2

# --- Runtime / data locations ------------------------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Dataset root: trainA/ holds the summer images, trainB/ the winter images.
TRAIN_DIR = "/kaggle/input/summer2winter-yosemite"
VAL_DIR = "/kaggle/input/summer2winter-yosemite"  # not referenced by any cell yet

# --- Optimisation --------------------------------------------------------------
BATCH_SIZE = 4
LEARNING_RATE = 1e-5
NUM_EPOCHS = 200
NUM_WORKERS = 2

# --- Loss weights --------------------------------------------------------------
LAMBDA_IDENTITY = 0.5  # identity-loss weight; that term is disabled in train_fn
LAMBDA_CYCLE = 10  # weight of each cycle-consistency L1 term

# --- Checkpointing -------------------------------------------------------------
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_GEN_S = "gens.pth.tar"
CHECKPOINT_GEN_W = "genw.pth.tar"
CHECKPOINT_DISC_S = "critics.pth.tar"
CHECKPOINT_DISC_W = "criticw.pth.tar"

# Augmentation shared by a (summer, winter) pair: "image0" is registered as a second
# image target so both images pass through one Compose call.
transformsA = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        # (x / 255 - 0.5) / 0.5 maps uint8 [0, 255] to [-1, 1], the generator's tanh range.
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
        ToTensorV2(),  # HWC numpy -> CHW torch tensor
    ],
    additional_targets={"image0": "image"},
)

import random, os
import copy

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Serialize the ``model`` and ``optimizer`` state dicts to ``filename``.

    The file holds a dict with keys ``"state_dict"`` (model) and ``"optimizer"``,
    which is the layout ``load_checkpoint`` reads back.
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr, device=None):
    """Restore ``model`` and ``optimizer`` from a file written by ``save_checkpoint``.

    Args:
        checkpoint_file: path to a saved dict with ``"state_dict"`` and ``"optimizer"`` keys.
        model: module whose weights are overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        lr: learning rate forced onto every param group after loading.
        device: ``map_location`` for ``torch.load``; ``None`` means the notebook's
            module-level ``DEVICE``.
    """
    print("=> Loading checkpoint")
    if device is None:
        # There is no ``config`` module in this notebook (``config.DEVICE`` raised
        # NameError); the device constant lives at module level.
        device = DEVICE
    checkpoint = torch.load(checkpoint_file, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    # The optimizer state dict carries the lr it was saved with; override it so the
    # caller-supplied ``lr`` takes effect instead of the stale one.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr


def seed_everything(seed=42):
    """Seed every RNG the notebook relies on and force deterministic cuDNN kernels.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU, current CUDA
    device and all CUDA devices), and exports ``PYTHONHASHSEED``. Turning off
    ``cudnn.benchmark`` trades some speed for reproducibility.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False




In [5]:
from PIL import Image
import os
from torch.utils.data import Dataset

class YosemiteDataset(Dataset):
    """Unpaired summer/winter image dataset.

    Index ``i`` yields summer image ``i % n_summer`` and winter image
    ``i % n_winter`` (file names sorted), and the length is the larger of the two
    counts, so the smaller domain is cycled through.

    Args:
        root_summer: directory whose entries are all summer images.
        root_winter: directory whose entries are all winter images.
        transform: optional callable invoked as ``transform(image=..., image0=...)``
            that returns a dict with ``"image"`` and ``"image0"`` keys
            (e.g. an albumentations ``Compose`` with ``additional_targets``).

    Raises:
        ValueError: if either directory is empty.
    """

    def __init__(self, root_summer, root_winter, transform=None):
        self.root_summer = root_summer
        self.root_winter = root_winter
        self.transform = transform

        # os.listdir order is arbitrary; sort so the index -> file mapping (and hence
        # the summer/winter pairing) is reproducible across runs and machines.
        self.summer_images = sorted(os.listdir(self.root_summer))
        self.winter_images = sorted(os.listdir(self.root_winter))

        self.summer_len = len(self.summer_images)
        self.winter_len = len(self.winter_images)
        # Fail early: an empty domain would otherwise surface as ZeroDivisionError
        # (index % 0) inside a DataLoader worker.
        if self.summer_len == 0 or self.winter_len == 0:
            raise ValueError(
                f"empty image directory: {root_summer!r} has {self.summer_len} entries, "
                f"{root_winter!r} has {self.winter_len} entries"
            )
        self.length_dataset = max(self.summer_len, self.winter_len)

    def __len__(self):
        return self.length_dataset

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as an RGB numpy array, closing the file handle afterwards."""
        with Image.open(path) as img:
            return np.array(img.convert('RGB'))

    def __getitem__(self, index):
        """Return ``(summer, winter)``: RGB numpy arrays, or the transform's outputs if set."""
        summer_name = self.summer_images[index % self.summer_len]
        winter_name = self.winter_images[index % self.winter_len]

        summer_img = self._load_rgb(os.path.join(self.root_summer, summer_name))
        winter_img = self._load_rgb(os.path.join(self.root_winter, winter_name))

        if self.transform:
            augs = self.transform(image=summer_img, image0=winter_img)
            summer_img = augs['image']
            winter_img = augs['image0']

        return summer_img, winter_img

In [6]:
from torch.utils.data import DataLoader
# Sanity check of the data pipeline: pull one (summer, winter) pair through the loader.
te_dataset = YosemiteDataset(
    root_summer=os.path.join(TRAIN_DIR, 'trainA'),
    root_winter=os.path.join(TRAIN_DIR, 'trainB'),
    transform=transformsA,
)
te_loader = DataLoader(te_dataset, batch_size=1, shuffle=True, num_workers=1)

batch = iter(te_loader)
image, image0 = next(batch)
# transformsA resizes to 256x256 and ToTensorV2 yields CHW tensors; batch_size is 1.
assert image.shape == image0.shape == (1, 3, 256, 256), (image.shape, image0.shape)

In [7]:
# Idempotent (unlike `!mkdir`, which errors on re-run). Note: no cell in this
# notebook currently writes into these directories.
os.makedirs("fake_images", exist_ok=True)
os.makedirs("real_images", exist_ok=True)

In [8]:
# %pip (not !pip) installs into the environment of the running kernel.
%pip install -q wandb
from kaggle_secrets import UserSecretsClient
import wandb

# The API key is read from Kaggle's secret store rather than written in the notebook.
wandb.login(key=UserSecretsClient().get_secret("wandb_api"))
run = wandb.init(
    # Set the project where this run will be logged
    project="my-awesome-project",
    # Track the hyperparameters that shape training so runs are comparable.
    config={
        "learning_rate": LEARNING_RATE,
        "epochs": NUM_EPOCHS,
        "batch_size": BATCH_SIZE,
        "lambda_cycle": LAMBDA_CYCLE,
    },
)



[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mhenry-laur[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
from tqdm import tqdm
import torch.optim as optim
from torchvision import transforms
# Tensor -> PIL conversion plus upscaling (shorter side -> 1024 px) applied to
# example images before they are logged to W&B.
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(1024),
])


def resize(x):
    """Convert every image in batch ``x`` into an upscaled PIL image.

    Args:
        x: batch of image tensors; callers pass values already mapped to [0, 1]
           (``t * 0.5 + 0.5`` on [-1, 1] data).
    Returns:
        list with one PIL image per batch element.
    """
    images = []
    for sample in x:
        images.append(transform(sample))
    return images

def train_fn(disc_S, disc_W, gen_S, gen_W, yosemite_loader, opt_disc, opt_gen, L1, mse, g_scaler, d_scaler, epoch):
    """Run one CycleGAN training epoch under CUDA autocast.

    ``gen_S`` maps winter -> summer and ``gen_W`` maps summer -> winter;
    ``disc_S`` / ``disc_W`` score summer / winter images. Each step first updates
    both discriminators (MSE against 1 for real images and 0 for detached fakes),
    then both generators (MSE adversarial term plus ``LAMBDA_CYCLE``-weighted L1
    cycle-consistency loss).

    The tqdm bar shows running mean losses. On the last batch, the epoch index,
    the epoch-mean losses and example images (real summer, real winter, fake
    winter, fake summer, mapped from [-1, 1] to [0, 1]) are logged to W&B.

    Args:
        disc_S, disc_W: summer / winter discriminators, stepped by ``opt_disc``.
        gen_S, gen_W: winter->summer / summer->winter generators, stepped by ``opt_gen``.
        yosemite_loader: iterable of ``(summer, winter)`` batches.
        opt_disc, opt_gen: optimizers for the discriminator / generator pairs.
        L1: loss used for the cycle-consistency terms.
        mse: loss used for the adversarial terms.
        g_scaler, d_scaler: AMP gradient scalers for the generator / discriminator steps.
        epoch: epoch index, included in the W&B log.
    """
    loop = tqdm(yosemite_loader, leave=True)
    d_loss_total = 0.0
    g_loss_total = 0.0

    for idx, (summer, winter) in enumerate(loop):
        summer = summer.to(DEVICE)
        winter = winter.to(DEVICE)

        # ---- Discriminators: real -> 1, fake -> 0 ---------------------------------
        with torch.cuda.amp.autocast():
            fake_summer = gen_S(winter)
            D_S_real = disc_S(summer)
            # detach(): this step must not push gradients into the generators.
            D_S_fake = disc_S(fake_summer.detach())
            D_S_real_loss = mse(D_S_real, torch.ones_like(D_S_real))
            D_S_fake_loss = mse(D_S_fake, torch.zeros_like(D_S_fake))
            D_S_loss = D_S_real_loss + D_S_fake_loss

            fake_winter = gen_W(summer)
            D_W_real = disc_W(winter)
            D_W_fake = disc_W(fake_winter.detach())
            D_W_real_loss = mse(D_W_real, torch.ones_like(D_W_real))
            D_W_fake_loss = mse(D_W_fake, torch.zeros_like(D_W_fake))
            D_W_loss = D_W_real_loss + D_W_fake_loss

            D_loss = (D_W_loss + D_S_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generators: fool the just-updated discriminators + cycle consistency --
        with torch.cuda.amp.autocast():
            # Adversarial loss: the generators want their fakes scored as real (1).
            D_S_fake = disc_S(fake_summer)
            D_W_fake = disc_W(fake_winter)
            loss_G_S = mse(D_S_fake, torch.ones_like(D_S_fake))
            loss_G_W = mse(D_W_fake, torch.ones_like(D_W_fake))

            # Cycle loss: summer -> winter -> summer (and winter -> summer -> winter)
            # should reconstruct the original input.
            cycle_summer = gen_S(fake_winter)
            cycle_winter = gen_W(fake_summer)
            cycle_summer_loss = L1(summer, cycle_summer)
            cycle_winter_loss = L1(winter, cycle_winter)

            # Identity loss (L1(winter, gen_W(winter)) and L1(summer, gen_S(summer)),
            # weighted by LAMBDA_IDENTITY) is currently disabled and not part of G_loss.

            G_loss = (
                loss_G_S
                + loss_G_W
                + cycle_summer_loss * LAMBDA_CYCLE
                + cycle_winter_loss * LAMBDA_CYCLE
            )

        # G_loss.backward() also writes grads into the discriminators; those are
        # cleared by opt_disc.zero_grad() before the next discriminator step.
        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        # Track epoch-mean losses for the progress bar and the W&B log.
        d_loss_total += D_loss.item()
        g_loss_total += G_loss.item()
        n_seen = idx + 1
        loop.set_postfix(D_loss=d_loss_total / n_seen, G_loss=g_loss_total / n_seen)

        if idx == len(loop) - 1:
            img = [
                *resize(summer * 0.5 + 0.5),
                *resize(winter * 0.5 + 0.5),
                *resize(fake_winter * 0.5 + 0.5),
                *resize(fake_summer * 0.5 + 0.5),
            ]
            wandb.log({
                "examples": [wandb.Image(image) for image in img],
                "epoch": epoch,
                "D_loss": d_loss_total / n_seen,
                "G_loss": g_loss_total / n_seen,
            })
                    
def main():
    """Build both generator/discriminator pairs and train them for NUM_EPOCHS epochs.

    When LOAD_MODEL is set, every model/optimizer is first restored from its
    CHECKPOINT_* file. When SAVE_MODEL is set, every checkpoint is rewritten
    after each epoch and the files are uploaded with ``wandb.save``.
    """
    disc_S = Discriminator(in_channels=3).to(DEVICE)
    disc_W = Discriminator(in_channels=3).to(DEVICE)
    gen_S = Generator(in_channels=3).to(DEVICE)
    gen_W = Generator(in_channels=3).to(DEVICE)

    opt_disc = optim.Adam(
        list(disc_S.parameters()) + list(disc_W.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999)
    )
    opt_gen = optim.Adam(
        list(gen_S.parameters()) + list(gen_W.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999)
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    # One table drives both save and load, so a model can never be written to one
    # file and read back from another. NOTE: checkpoints written before this fix
    # stored disc_S under CHECKPOINT_DISC_W and disc_W under CHECKPOINT_DISC_S.
    checkpoints = [
        (CHECKPOINT_GEN_W, gen_W, opt_gen),
        (CHECKPOINT_GEN_S, gen_S, opt_gen),
        (CHECKPOINT_DISC_S, disc_S, opt_disc),
        (CHECKPOINT_DISC_W, disc_W, opt_disc),
    ]

    if LOAD_MODEL:
        for filename, model, optimizer in checkpoints:
            load_checkpoint(filename, model, optimizer, LEARNING_RATE)

    yosemite_dataset = YosemiteDataset(
        root_summer=os.path.join(TRAIN_DIR, 'trainA'),
        root_winter=os.path.join(TRAIN_DIR, 'trainB'),
        transform=transformsA,
    )
    yosemite_loader = DataLoader(yosemite_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_fn(disc_S, disc_W, gen_S, gen_W, yosemite_loader, opt_disc, opt_gen, L1, mse, g_scaler, d_scaler, epoch)

        if SAVE_MODEL:
            for filename, model, optimizer in checkpoints:
                save_checkpoint(model, optimizer, filename=filename)
            wandb.save('/kaggle/working/*pth*')


main()

100%|██████████| 308/308 [04:19<00:00,  1.19it/s]


=> Saving checkpoint
=> Saving checkpoint




=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 308/308 [04:19<00:00,  1.19it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 308/308 [04:19<00:00,  1.19it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 308/308 [04:19<00:00,  1.19it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 308/308 [04:19<00:00,  1.19it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 308/308 [04:19<00:00,  1.19it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 308/308 [04:19<00:00,  1.19it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 308/308 [04:18<00:00,  1.19it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 308/308 [04:18<00:00,  1.19it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 308/308 [04:19<00:00,  1.19it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 308/308 [04:19<00:00,  1.19it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 308/308 [04:18<00:00,  1.19it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 308/308 [04:19<00:00,  1.19it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 308/308 [04:18<00:00,  1.19it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 308/308 [04:19<00:00,  1.19it/s]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


 32%|███▏      | 100/308 [01:21<02:48,  1.24it/s]