In [None]:
%%bash

# Download and unpack a CycleGAN dataset, then create the output folders used
# by the training and evaluation cells.
# Stop at the first failing command so a broken download is never unzipped.
set -euo pipefail

FILE="horse2zebra"
DATA_ROOT="/content/datasets"

echo "Specified [$FILE]"
URL="http://efrosgans.eecs.berkeley.edu/cyclegan/datasets/$FILE.zip"
ZIP_FILE="$DATA_ROOT/$FILE.zip"
TARGET_DIR="$DATA_ROOT/$FILE/"

# -p keeps the cell re-runnable and creates the same absolute directory that
# ZIP_FILE and TARGET_DIR point at (the original created a relative "datasets").
mkdir -p "$TARGET_DIR"

# -N (timestamping) has no effect together with -O, so it is dropped.
wget "$URL" -O "$ZIP_FILE"

# -o overwrites files left by a previous run instead of prompting;
# -q avoids listing every extracted image in the cell output.
unzip -q -o "$ZIP_FILE" -d "$DATA_ROOT/"
rm "$ZIP_FILE"

mkdir -p saved_images evaluation

In [None]:
#config
import torch
import os
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Run on the GPU when torch can see one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Dataset locations; "A" (horses) / "B" (zebras) is appended where they are used.
TRAIN_DIR = "/content/datasets/horse2zebra/train"
VAL_DIR = "/content/datasets/horse2zebra/test"

# Optimisation hyper-parameters.
BATCH_SIZE = 8
LEARNING_RATE = 1e-5
LAMBDA_IDENTITY = 0.0  # weight of the identity L1 term in the generator loss
LAMBDA_CYCLE = 10      # weight of the cycle-consistency L1 term
# os.cpu_count() may return None when the count cannot be determined;
# fall back to 0 (load data in the main process) so DataLoader gets an int.
NUM_WORKERS = os.cpu_count() or 0
NUM_EPOCHS = 10

# Checkpoint handling.
LOAD_MODEL = False
SAVE_MODEL = True
CHECKPOINT_GEN_H = "/content/genh.pth"
CHECKPOINT_GEN_Z = "/content/genz.pth"
CHECKPOINT_CRITIC_H = "/content/critich.pth"
CHECKPOINT_CRITIC_Z = "/content/criticz.pth"
# Shared preprocessing pipeline. "image0" is registered as a second image
# target so the same Compose call can be given two images at once.
transforms = A.Compose(
    [
        # Bring every image to 256x256.
        A.Resize(height=256, width=256),
        # Random left/right mirror with probability 0.5.
        A.HorizontalFlip(p=0.5),
        # (pixel / 255 - 0.5) / 0.5 per channel.
        A.Normalize(
            mean=[0.5] * 3,
            std=[0.5] * 3,
            max_pixel_value=255,
        ),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)

In [None]:
#dataset
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np

class HorseZebraDataset(Dataset):
    """Unpaired zebra/horse image dataset backed by two directories.

    Each item is a ``(zebra_img, horse_img)`` pair. The two directories may
    hold different numbers of files; the dataset length is the larger of the
    two counts and the shorter listing is cycled through with a modulo.

    Args:
        root_zebra: Directory containing the zebra images.
        root_horse: Directory containing the horse images.
        transform: Optional callable invoked as
            ``transform(image=zebra, image0=horse)`` that returns a mapping
            with the keys ``"image"`` and ``"image0"``.
    """

    def __init__(self, root_zebra, root_horse, transform=None):
        self.root_zebra = root_zebra
        self.root_horse = root_horse
        self.transform = transform

        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible between runs.
        self.zebra_images = sorted(os.listdir(root_zebra))
        self.horse_images = sorted(os.listdir(root_horse))
        self.zebra_len = len(self.zebra_images)
        self.horse_len = len(self.horse_images)
        # Compare the two file counts. The previous code used
        # len(self.root_horse), i.e. the number of characters in the path.
        self.length_dataset = max(self.zebra_len, self.horse_len)

    def __len__(self):
        """Return the size of the larger of the two image folders."""
        return self.length_dataset

    def __getitem__(self, index):
        """Load the zebra/horse pair for ``index`` as RGB arrays (or transformed outputs)."""
        # Wrap around so the shorter listing is reused once it is exhausted.
        zebra_img = self.zebra_images[index % self.zebra_len]
        horse_img = self.horse_images[index % self.horse_len]

        zebra_path = os.path.join(self.root_zebra, zebra_img)
        horse_path = os.path.join(self.root_horse, horse_img)
        zebra_img = np.array(Image.open(zebra_path).convert("RGB"))
        horse_img = np.array(Image.open(horse_path).convert("RGB"))

        if self.transform:
            augmentations = self.transform(image=zebra_img, image0=horse_img)
            zebra_img = augmentations['image']
            horse_img = augmentations['image0']
        return zebra_img, horse_img

In [None]:
#Generator
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Conv2d or ConvTranspose2d, followed by InstanceNorm2d and an optional ReLU.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        down: When True use ``nn.Conv2d`` with reflect padding; otherwise use
            ``nn.ConvTranspose2d``.
        use_act: When True append an in-place ReLU; otherwise ``nn.Identity``.
        **kwargs: Forwarded unchanged to the chosen convolution layer.
    """

    def __init__(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        super().__init__()
        if down:
            conv_layer = nn.Conv2d(in_channels, out_channels, padding_mode="reflect", **kwargs)
        else:
            conv_layer = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()

        # Kept as a 3-element Sequential named "conv" so parameter names
        # (conv.0.weight, conv.0.bias) stay the same.
        self.conv = nn.Sequential(conv_layer, nn.InstanceNorm2d(out_channels), activation)

    def forward(self, x):
        """Run ``x`` through convolution, normalisation and activation."""
        out = self.conv(x)
        return out

class ResidualBlock(nn.Module):
    """Two 3x3 ConvBlocks wrapped in a skip connection.

    The channel count is unchanged; the second ConvBlock has no activation.

    Args:
        channels: Number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_cfg = dict(kernel_size=3, padding=1)
        first = ConvBlock(channels, channels, **conv_cfg)
        second = ConvBlock(channels, channels, use_act=False, **conv_cfg)
        self.block = nn.Sequential(first, second)

    def forward(self, x):
        """Return ``x`` plus the output of the two-convolution branch."""
        branch = self.block(x)
        return x + branch

class Generator(nn.Module):
    """Image-to-image generator: 7x7 stem, two stride-2 down blocks, a stack of
    residual blocks, two stride-2 transposed-conv up blocks and a 7x7 output
    convolution followed by tanh.

    Args:
        img_channels: Channels of the input and output images.
        num_features: Channel width after the stem convolution.
        num_residuals: Number of ResidualBlocks in the middle of the network.
    """

    def __init__(self, img_channels, num_features = 64, num_residuals=9):
        super().__init__()
        nf = num_features

        # Stem: 7x7 reflect-padded convolution at full resolution.
        stem_conv = nn.Conv2d(
            img_channels, nf, kernel_size=7, stride=1, padding=3, padding_mode="reflect"
        )
        self.initial = nn.Sequential(stem_conv, nn.InstanceNorm2d(nf), nn.ReLU(inplace=True))

        # Encoder: each block doubles the channels with a stride-2 convolution.
        down_cfg = dict(kernel_size=3, stride=2, padding=1)
        self.down_blocks = nn.ModuleList(
            [
                ConvBlock(nf, nf * 2, **down_cfg),
                ConvBlock(nf * 2, nf * 4, **down_cfg),
            ]
        )

        # Bottleneck: residual blocks at the widest channel count.
        residual_stack = [ResidualBlock(nf * 4) for _ in range(num_residuals)]
        self.res_blocks = nn.Sequential(*residual_stack)

        # Decoder: stride-2 transposed convolutions, mirroring the encoder widths.
        up_cfg = dict(down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.up_blocks = nn.ModuleList(
            [
                ConvBlock(nf * 4, nf * 2, **up_cfg),
                ConvBlock(nf * 2, nf, **up_cfg),
            ]
        )

        self.last = nn.Conv2d(
            nf, img_channels, kernel_size=7, stride=1, padding=3, padding_mode="reflect"
        )

    def forward(self, x):
        """Map an image batch to a generated batch with values squashed by tanh."""
        features = self.initial(x)
        for down in self.down_blocks:
            features = down(features)
        features = self.res_blocks(features)
        for up in self.up_blocks:
            features = up(features)
        return torch.tanh(self.last(features))

def test():
    """Smoke-test the Generator by printing the output shape for a random 2x3x256x256 batch."""
    img_channels = 3
    img_size = 256
    x = torch.randn((2, img_channels, img_size, img_size))
    # num_residuals must be passed by keyword: the second positional parameter
    # of Generator is num_features, so Generator(img_channels, 9) built a
    # network with 9 features instead of one with 9 residual blocks.
    gen = Generator(img_channels, num_residuals=9)
    print(gen(x).shape)

test()

torch.Size([2, 3, 256, 256])


In [None]:
#Discriminator
import torch
import torch.nn as nn

class CNNBlock(nn.Module):
    """4x4 reflect-padded convolution -> InstanceNorm2d -> LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        stride: Stride of the convolution (defaults to 2).
    """

    def __init__(self,in_channels,out_channels,stride = 2):
        super().__init__()
        conv_layer = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode='reflect',
        )
        norm_layer = nn.InstanceNorm2d(out_channels)
        activation = nn.LeakyReLU(0.2, inplace=True)
        self.conv = nn.Sequential(conv_layer, norm_layer, activation)

    def forward(self,x):
        """Apply convolution, normalisation and activation to ``x``."""
        out = self.conv(x)
        return out

class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a grid of per-patch scores.

    The network is an initial stride-2 convolution with LeakyReLU, one
    CNNBlock per remaining entry of ``features`` (stride 2, except the final
    entry which uses stride 1), and a 1-channel 4x4 convolution. A sigmoid is
    applied to the result.

    Args:
        in_channels: Channels of the input image.
        features: Sequence of channel widths. The first entry is the width of
            the initial convolution.
    """

    def __init__(self,in_channels = 3,features = (64,128,256,512)):
        super().__init__()
        # An immutable default avoids sharing one list between instances;
        # copying lets callers keep passing any sequence (list or tuple).
        features = list(features)
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels,out_channels=features[0],kernel_size=4,stride=2,padding=1,padding_mode='reflect'),
            nn.LeakyReLU(0.2,inplace=True)
        )
        layers = []
        in_channels = features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Pick the stride-1 layer by position. Comparing the value with
            # features[-1] would also hit earlier layers of the same width.
            stride = 1 if index == last_index else 2
            layers.append(CNNBlock(in_channels,out_channels=feature,stride=stride))
            in_channels = feature
        layers.append(nn.Conv2d(in_channels,1,kernel_size=4,stride=1,padding=1,padding_mode='reflect'))
        self.model = nn.Sequential(*layers)

    def forward(self,x):
        """Return sigmoid-squashed patch scores for the image batch ``x``."""
        x = self.initial(x)
        return torch.sigmoid(self.model(x))

def test():
    """Smoke-test the Discriminator by printing the output shape for a random 5x3x256x256 batch."""
    batch = torch.randn((5, 3, 256, 256))
    critic = Discriminator(in_channels=3)
    print(critic(batch).shape)



test()

torch.Size([5, 1, 30, 30])


In [None]:
#utils
import torch
#import config

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename``.

    The saved object is a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"`` (optimizer).
    """
    print("=> Saving checkpoint")
    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        filename,
    )

def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from ``checkpoint_file``, then set the learning rate.

    The file is expected to hold a dict with ``"state_dict"`` and
    ``"optimizer"`` entries; tensors are mapped onto ``DEVICE`` while loading.
    """
    print("=> Loading checkpoint")
    saved = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(saved["state_dict"])
    optimizer.load_state_dict(saved["optimizer"])

    # Loading the optimizer state also restores the stored learning rate,
    # so overwrite it in every parameter group with the requested value.
    for group in optimizer.param_groups:
        group["lr"] = lr

In [None]:
#train
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from tqdm import tqdm

def train_func(disc_H,disc_Z,gen_H,gen_Z,opt_disc,opt_gen,g_scaler,d_scaler,L1,mse,loader):
    """Run one pass over ``loader``, updating both discriminators and both generators.

    Args:
        disc_H, disc_Z: Discriminators for horse and zebra images.
        gen_H, gen_Z: Generators producing horses (from zebras) and zebras
            (from horses).
        opt_disc, opt_gen: Optimizers stepped for the discriminators and the
            generators respectively.
        g_scaler, d_scaler: Gradient scalers used around the generator and
            discriminator backward/step calls.
        L1: Loss applied to the cycle and identity reconstructions.
        mse: Loss applied to discriminator outputs.
        loader: Iterable yielding ``(zebra, horse)`` batches.

    Side effects:
        On every 80th batch, writes the current real and generated images to
        ``saved_images/``. File names contain only the batch index.
    """
    loop = tqdm(loader,leave=True)
    for idx, (zebra,horse) in enumerate(loop):
        zebra = zebra.to(DEVICE)
        horse = horse.to(DEVICE)

        # ---- Discriminators H & Z ----
        with torch.cuda.amp.autocast():
            fake_horse = gen_H(zebra)
            D_H_real = disc_H(horse)
            # detach(): the discriminator loss must not back-propagate into the generator.
            D_H_fake = disc_H(fake_horse.detach())
            D_H_real_loss = mse(D_H_real,torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake,torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_zebra = gen_Z(horse)
            D_Z_real = disc_Z(zebra)
            D_Z_fake = disc_Z(fake_zebra.detach())
            D_Z_real_loss = mse(D_Z_real,torch.ones_like(D_Z_real))
            D_Z_fake_loss = mse(D_Z_fake,torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            D_loss = (D_H_loss + D_Z_loss)/2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- Generators H & Z ----
        with torch.cuda.amp.autocast():
            # Adversarial terms: the generators want the fakes scored as real (ones).
            D_H_fake = disc_H(fake_horse)
            D_Z_fake = disc_Z(fake_zebra)
            loss_G_H = mse(D_H_fake,torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # Cycle terms: translating there and back should reproduce the input.
            cycle_zebra = gen_Z(fake_horse)
            cycle_horse = gen_H(fake_zebra)
            cycle_zebra_loss = L1(zebra,cycle_zebra)
            cycle_horse_loss = L1(horse,cycle_horse)

            G_loss = (
                loss_G_Z
                + loss_G_H
                + cycle_zebra_loss * LAMBDA_CYCLE
                + cycle_horse_loss * LAMBDA_CYCLE
            )

            # Identity terms cost two extra generator passes per step, so they
            # are only computed when their weight is non-zero. A zero weight
            # contributes nothing to the loss or to its gradients.
            if LAMBDA_IDENTITY:
                identity_zebra = gen_Z(zebra)
                identity_horse = gen_H(horse)
                identity_zebra_loss = L1(zebra, identity_zebra)
                identity_horse_loss = L1(horse, identity_horse)
                G_loss = (
                    G_loss
                    + identity_horse_loss * LAMBDA_IDENTITY
                    + identity_zebra_loss * LAMBDA_IDENTITY
                )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if (idx+1) % 80 == 0:
            # x*0.5+0.5 undoes the mean=0.5/std=0.5 normalisation before saving.
            save_image(horse*0.5+0.5, f"saved_images/horse_{idx}.png")
            save_image(fake_horse*0.5+0.5, f"saved_images/fake_horse_{idx}.png")
            save_image(zebra*0.5+0.5, f"saved_images/zebra_{idx}.png")
            save_image(fake_zebra*0.5+0.5, f"saved_images/fake_zebra_{idx}.png")


def main():
    """Build the models, optimizers and training loader, then train for NUM_EPOCHS.

    When LOAD_MODEL is set, all four networks are restored from their
    checkpoint files first. When SAVE_MODEL is set, all four networks are
    checkpointed after every epoch.
    """
    disc_H = Discriminator(in_channels=3).to(DEVICE)
    disc_Z = Discriminator(in_channels=3).to(DEVICE)
    gen_H = Generator(img_channels=3,num_residuals=9).to(DEVICE)
    gen_Z = Generator(img_channels=3,num_residuals=9).to(DEVICE)

    # One optimizer per role: both discriminators share opt_disc and both
    # generators share opt_gen.
    opt_disc = optim.Adam(
        list(disc_H.parameters()) + list(disc_Z.parameters()),
        lr=LEARNING_RATE,
        betas= (0.5,0.999)
    )
    opt_gen = optim.Adam(
        list(gen_H.parameters()) + list(gen_Z.parameters()),
        lr=LEARNING_RATE,
        betas= (0.5,0.999)
    )
    L1 = nn.L1Loss()
    mse = nn.MSELoss()
    if LOAD_MODEL:
        load_checkpoint(CHECKPOINT_GEN_H,gen_H,opt_gen,LEARNING_RATE)
        load_checkpoint(CHECKPOINT_GEN_Z,gen_Z,opt_gen,LEARNING_RATE)
        load_checkpoint(CHECKPOINT_CRITIC_H,disc_H,opt_disc,LEARNING_RATE)
        load_checkpoint(CHECKPOINT_CRITIC_Z,disc_Z,opt_disc,LEARNING_RATE)

    # Only the training split is needed here. The validation dataset and
    # loader that used to be built in this function were never used.
    train_dataset = HorseZebraDataset(root_zebra=TRAIN_DIR+"B",root_horse=TRAIN_DIR+"A",transform=transforms)
    train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,pin_memory=True,shuffle=True,num_workers=NUM_WORKERS)

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train_func(disc_H,disc_Z,gen_H,gen_Z,opt_disc,opt_gen,g_scaler,d_scaler,L1,mse,train_loader)
        if SAVE_MODEL:
            save_checkpoint(gen_H, opt_gen, filename=CHECKPOINT_GEN_H)
            save_checkpoint(gen_Z, opt_gen, filename=CHECKPOINT_GEN_Z)
            save_checkpoint(disc_H, opt_disc, filename=CHECKPOINT_CRITIC_H)
            save_checkpoint(disc_Z, opt_disc, filename=CHECKPOINT_CRITIC_Z)

In [None]:
main()

100%|██████████| 167/167 [05:34<00:00,  2.01s/it]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 167/167 [05:28<00:00,  1.96s/it]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 167/167 [05:28<00:00,  1.96s/it]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 167/167 [05:27<00:00,  1.96s/it]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 167/167 [05:27<00:00,  1.96s/it]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 167/167 [05:28<00:00,  1.97s/it]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 167/167 [05:28<00:00,  1.97s/it]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 167/167 [05:28<00:00,  1.96s/it]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 167/167 [05:28<00:00,  1.96s/it]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


100%|██████████| 167/167 [05:27<00:00,  1.96s/it]


=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint
=> Saving checkpoint


In [None]:
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm

# Validation pairs, served one at a time in a fixed order.
val_dataset = HorseZebraDataset(
    root_zebra=VAL_DIR + "B",
    root_horse=VAL_DIR + "A",
    transform=transforms,
)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)

# Fresh generators; only gen_H is restored from its checkpoint below.
gen_H = Generator(img_channels=3, num_residuals=9).to(DEVICE)
gen_Z = Generator(img_channels=3, num_residuals=9).to(DEVICE)

# load_checkpoint() also restores optimizer state, so an optimizer over the
# same parameter set as in training is created for it to load into.
opt_gen = optim.Adam(
    [*gen_H.parameters(), *gen_Z.parameters()],
    lr=LEARNING_RATE,
    betas=(0.5, 0.999),
)
load_checkpoint(CHECKPOINT_GEN_H, gen_H, opt_gen, LEARNING_RATE)

def save_some_examples(gen_H, val_loader, folder):
    """Translate every zebra in ``val_loader`` with ``gen_H`` and save input/output images.

    For batch ``idx`` two files are written into ``folder``:
    ``fake_horse{idx}.png`` (generator output) and ``zebra{idx}.png`` (input).

    Args:
        gen_H: Generator applied to the zebra images.
        val_loader: Iterable yielding ``(zebra, horse)`` batches; only the
            zebra batch is used.
        folder: Output directory; created if it does not exist.
    """
    os.makedirs(folder, exist_ok=True)

    # Switch to eval mode once for the whole pass and restore the previous
    # mode afterwards, even if generating or saving an image raises.
    was_training = gen_H.training
    gen_H.eval()
    try:
        with torch.no_grad():
            loop = tqdm(val_loader,leave=True)
            for idx, (zebra, _horse) in enumerate(loop):
                zebra = zebra.to(DEVICE)
                y_fake = gen_H(zebra)
                y_fake = y_fake * 0.5 + 0.5  # remove normalization
                save_image(y_fake, folder + f"/fake_horse{idx}.png")
                save_image(zebra * 0.5 + 0.5, folder + f"/zebra{idx}.png")
    finally:
        gen_H.train(was_training)

=> Loading checkpoint


In [None]:
save_some_examples(gen_H,val_loader,folder='evaluation')

 54%|█████▍    | 76/140 [00:04<00:04, 15.80it/s]