In [1]:
import torch
import torch.nn as nn

In [2]:
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from easydict import EasyDict as edict

# Experiment-wide settings, collected in a single attribute-style dictionary.
config = edict(
    # Run on the GPU whenever PyTorch can see one, otherwise fall back to CPU.
    DEVICE="cuda" if torch.cuda.is_available() else "cpu",
    TRAIN_DIR="data/train",
    VAL_DIR="data/val",
    # Optimisation hyper-parameters.
    BATCH_SIZE=1,
    LEARNING_RATE=1e-5,
    # Weights applied to the identity and cycle-consistency terms of the
    # generator loss.
    LAMBDA_IDENTITY=0.0,
    LAMBDA_CYCLE=10,
    NUM_WORKERS=0,
    NUM_EPOCHS=10,
    # Checkpoint switches and the file each network is stored in.
    LOAD_MODEL=False,
    SAVE_MODEL=True,
    CHECKPOINT_GEN_H="genh.pth.tar",
    CHECKPOINT_GEN_Z="genz.pth.tar",
    CHECKPOINT_CRITIC_H="critich.pth.tar",
    CHECKPOINT_CRITIC_Z="criticz.pth.tar",
)

# Shared preprocessing pipeline. "image0" is declared as a second image target
# so that both pictures of a pair go through the same pipeline in one call.
_half = [0.5, 0.5, 0.5]
config.transforms = A.Compose(
    [
        A.Resize(width=256, height=256),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=list(_half), std=list(_half), max_pixel_value=255),
        ToTensorV2(),
    ],
    additional_targets={"image0": "image"},
)
del _half

# Discriminator

In [3]:
class Discriminator(nn.Module):
    """Fully-convolutional (PatchGAN-style) discriminator.

    The network returns a spatial map of scores in ``[0, 1]`` (one value per
    receptive-field patch) instead of a single scalar per image. For a
    256x256 input and the default widths the map is 30x30.

    Args:
        in_channels: Number of channels of the input image.
        features: Output widths of the successive convolution stages. The
            first entry is the width of the un-normalised stem; every later
            entry adds one conv + InstanceNorm + LeakyReLU block. Any
            sequence (list or tuple) is accepted.
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super().__init__()
        # Copy into a private list: avoids a shared mutable default argument
        # and lets callers pass any sequence type.
        features = list(features)

        # Stem: strided convolution without normalisation.
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels, features[0], 4, 2, 1, padding_mode='reflect'),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        in_channels = features[0]
        last_index = len(features) - 1

        for index, out_channels in enumerate(features[1:], start=1):
            # Only the final block keeps the spatial resolution (stride 1).
            # The decision is made on the block's position, not on its width,
            # so repeated widths such as [64, 512, 512] still downsample in
            # every block except the last one.
            stride = 1 if index == last_index else 2
            layers.append(self.block(in_channels, out_channels, stride))
            in_channels = out_channels

        # Collapse to a single-channel score map.
        layers.append(nn.Conv2d(in_channels, 1, 4, 1, 1, padding_mode='reflect'))

        self.model = nn.Sequential(*layers)

    def block(self, in_channels, out_channels, stride):
        """Return a 4x4 conv -> InstanceNorm -> LeakyReLU(0.2) stage."""
        conv = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=4,
                stride=stride,
                padding=1,
                bias=True,
                padding_mode='reflect',
            ),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        )
        return conv

    def forward(self, x):
        """Score ``x`` and squash the resulting map into ``[0, 1]``."""
        x = self.initial(x)
        return torch.sigmoid(self.model(x))
    
    
        

In [4]:
# Smoke test: push a random batch of five 3x256x256 images through a freshly
# built Discriminator and report the shape of the score map it returns.
x = torch.randn(5, 3, 256, 256)
model = Discriminator()
preds = model(x)
print(preds.size())

torch.Size([5, 1, 30, 30])


In [5]:
# Display the module tree of the Discriminator instance built in the cell above.
model

Discriminator(
  (initial): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (model): Sequential(
    (0): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), padding_mode=reflect)
      (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2, inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1), padding_mode=reflect)
      (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      

# Generator 

In [6]:
class ResidualBlock(nn.Module):
    """Residual block: ``x + F(x)`` with ``F`` made of two 3x3 conv stages.

    ``F`` is conv -> InstanceNorm -> ReLU followed by conv -> InstanceNorm.
    The second stage deliberately has no activation: if the branch ended in a
    ReLU it could only ever add non-negative values to the skip connection,
    which prevents the block from lowering any feature.

    Args:
        channels: Number of input and output channels (the block preserves
            both the channel count and the spatial size).
    """

    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            self.convBlock(channels, channels, kernel_size=3, padding=1),
            # No activation on the last stage of the residual branch.
            self.convBlock(channels, channels, use_act=False, kernel_size=3, padding=1),
        )

    def convBlock(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        """Build a conv -> InstanceNorm -> (ReLU | Identity) stage.

        Args:
            in_channels: Channels entering the convolution.
            out_channels: Channels produced by the convolution.
            down: When true use ``nn.Conv2d`` with reflect padding, otherwise
                ``nn.ConvTranspose2d``.
            use_act: When true finish with an in-place ReLU, otherwise with
                ``nn.Identity``.
            **kwargs: Forwarded to the convolution constructor
                (kernel_size, stride, padding, ...).
        """
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, **kwargs, padding_mode='reflect')
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
            nn.InstanceNorm2d(out_channels),
            nn.ReLU(inplace=True) if use_act else nn.Identity(),
        )

        return block

    def forward(self, x):
        """Return the input plus the output of the residual branch."""
        return x + self.block(x)

In [7]:
class Generator(nn.Module):
    """Image-to-image generator: stem, two downsamplers, residual trunk,
    two upsamplers and an output convolution squashed with ``tanh``.

    Args:
        img_channels: Channels of both the input and the generated image.
        num_features: Width of the stem; the downsamplers double it twice.
        num_residuals: How many ``ResidualBlock`` modules form the trunk.
    """

    def __init__(self, img_channels=3, num_features=64, num_residuals=9):
        super().__init__()
        narrow = num_features
        medium = num_features * 2
        wide = num_features * 4

        # 7x7 stem that keeps the spatial resolution.
        stem_conv = nn.Conv2d(
            img_channels, narrow, 7, 1, padding=3, padding_mode='reflect'
        )
        self.initial = nn.Sequential(
            stem_conv,
            nn.InstanceNorm2d(narrow),
            nn.ReLU(inplace=True),
        )

        # Two stride-2 convolutions: each halves the resolution.
        down_args = dict(kernel_size=3, stride=2, padding=1)
        self.down_blocks = nn.ModuleList(
            [
                self.convBlock(src, dst, **down_args)
                for src, dst in ((narrow, medium), (medium, wide))
            ]
        )

        # Residual trunk operating at the lowest resolution.
        trunk = [ResidualBlock(wide) for _ in range(num_residuals)]
        self.transformer_block = nn.Sequential(*trunk)

        # Two stride-2 transposed convolutions: each doubles the resolution.
        up_args = dict(down=False, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.up_blocks = nn.ModuleList(
            [
                self.convBlock(src, dst, **up_args)
                for src, dst in ((wide, medium), (medium, narrow))
            ]
        )

        # 7x7 projection back to image channels.
        self.last = nn.Conv2d(
            narrow, img_channels, kernel_size=7, stride=1, padding=3, padding_mode='reflect'
        )

    def convBlock(self, in_channels, out_channels, down=True, use_act=True, **kwargs):
        """Assemble a convolution -> InstanceNorm -> activation stage.

        ``down`` selects a reflect-padded ``nn.Conv2d`` (true) or an
        ``nn.ConvTranspose2d`` (false); ``use_act`` selects an in-place ReLU
        (true) or ``nn.Identity`` (false). Remaining keyword arguments go to
        the convolution constructor.
        """
        if down:
            conv = nn.Conv2d(in_channels, out_channels, **kwargs, padding_mode='reflect')
        else:
            conv = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        norm = nn.InstanceNorm2d(out_channels)

        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()

        return nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Translate ``x``; the result has the input's shape and lies in [-1, 1]."""
        out = self.initial(x)
        for stage in (*self.down_blocks, self.transformer_block, *self.up_blocks):
            out = stage(out)
        return torch.tanh(self.last(out))
    

In [8]:
def test():
    """Smoke-test the Generator: the output shape must equal the input shape.

    Feeds a random batch of two 256x256 images through a Generator with nine
    residual blocks and prints the shape of the result.
    """
    img_channels = 3
    img_size = 256
    x = torch.randn((2, img_channels, img_size, img_size))
    # Pass the residual count by keyword: the second positional parameter of
    # Generator is ``num_features``, so ``Generator(img_channels, 9)`` would
    # build a 9-channel-wide network instead of one with 9 residual blocks.
    gen = Generator(img_channels, num_residuals=9)
    print(gen(x).shape)

In [9]:
# Run the Generator smoke test defined above; it prints the output shape.
test()

torch.Size([2, 3, 256, 256])


# Dataset

In [10]:
from PIL import Image 
import os 
from torch.utils.data import Dataset
from torch.utils.data import DataLoader 
import numpy as np 

In [11]:
class HorseZebraDataset(Dataset):
    """Unpaired image dataset yielding one zebra and one horse image per item.

    The two folders may hold different numbers of files. The dataset length is
    the larger of the two counts and the shorter listing is cycled with a
    modulo, so every file of the larger domain is visited once per epoch.

    Args:
        root_zebra: Directory containing the zebra images.
        root_horse: Directory containing the horse images.
        transform: Optional callable invoked as
            ``transform(image=zebra, image0=horse)`` that returns a mapping
            with the keys ``"image"`` and ``"image0"``.

    Raises:
        ValueError: If exactly one of the two directories is empty, because
            every lookup would then fail on the modulo by zero.
    """

    def __init__(self, root_zebra, root_horse, transform=None):
        self.root_zebra = root_zebra
        self.root_horse = root_horse
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sorting makes the
        # index -> file mapping identical across runs and machines.
        self.zebra_images = sorted(os.listdir(root_zebra))
        self.horse_images = sorted(os.listdir(root_horse))
        self.zebra_len = len(self.zebra_images)
        self.horse_len = len(self.horse_images)
        self.length_dataset = max(self.zebra_len, self.horse_len)

        # Fail early with a readable message instead of a ZeroDivisionError
        # from the first __getitem__ call.
        if self.length_dataset > 0 and min(self.zebra_len, self.horse_len) == 0:
            empty_root = root_zebra if self.zebra_len == 0 else root_horse
            raise ValueError(f"No images found in {empty_root!r}")

    def __len__(self):
        """Return the size of the larger of the two image folders."""
        return self.length_dataset

    @staticmethod
    def _load_rgb(path):
        """Read ``path`` as an RGB array and close the file right away."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __getitem__(self, index):
        """Return ``(zebra_image, horse_image)`` for ``index``.

        Without a transform both elements are RGB numpy arrays; with a
        transform they are whatever the transform returns.
        """
        # Wrap the index independently for each domain.
        zebra_img = self.zebra_images[index % self.zebra_len]
        horse_img = self.horse_images[index % self.horse_len]

        zebra_image = self._load_rgb(os.path.join(self.root_zebra, zebra_img))
        horse_image = self._load_rgb(os.path.join(self.root_horse, horse_img))

        if self.transform:
            augmentations = self.transform(image=zebra_image, image0=horse_image)
            zebra_image = augmentations["image"]
            horse_image = augmentations["image0"]

        return zebra_image, horse_image
    

# Utils 

In [12]:
import torch , random, os, numpy as np 
import torch.nn as nn
#import config
import copy 

In [13]:
def save_checkpoint(model, optimizer, filename = 'mycheckpoint.pth.tar'):
    """Write the state dicts of ``model`` and ``optimizer`` to ``filename``.

    The file holds a dictionary with the keys ``"state_dict"`` (model) and
    ``"optimizer"`` and is produced with ``torch.save``.
    """
    print("=> Saving Checkpoint")
    snapshot = dict(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    torch.save(snapshot, filename)
    
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` in place from ``checkpoint_file``.

    The file is expected to contain the keys ``"state_dict"`` and
    ``"optimizer"``; tensors are mapped onto ``config.DEVICE``. After the
    optimizer state has been restored, the learning rate of every parameter
    group is overwritten with ``lr``.
    """
    print("=> Loading checkpoint")
    restored = torch.load(checkpoint_file, map_location = config.DEVICE)

    model.load_state_dict(restored['state_dict'])
    optimizer.load_state_dict(restored['optimizer'])

    # The restored optimizer state carries the learning rate it was saved
    # with; replace it with the requested one.
    for group in optimizer.param_groups:
        group.update(lr=lr)
        

def seed_everything(seed = 42):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) with ``seed``.

    Also exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode
    with benchmarking turned off.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Every random number generator in use gets the same seed.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    
    

In [14]:
import torch.optim as optim 
from tqdm import tqdm 
from torchvision.utils import save_image 

In [15]:
def train_fn(disc_H, disc_Z, gen_Z, gen_H, loader, opt_disc, opt_gen, l1, mse, d_scaler, g_scaler):
    """Train both discriminators and both generators for one pass over ``loader``.

    Each batch performs two optimisation steps under mixed precision:

    1. Discriminator step: ``disc_H`` and ``disc_Z`` are pushed towards 1 on
       real images and towards 0 on (detached) generated images.
    2. Generator step: ``gen_H`` and ``gen_Z`` are trained to make the
       discriminators output 1 on their fakes, plus cycle-consistency and,
       when ``config.LAMBDA_IDENTITY`` is non-zero, identity terms.

    Every 200 batches the current fake images are written to
    ``saved_images/``.

    Args:
        disc_H: Discriminator applied to horse images.
        disc_Z: Discriminator applied to zebra images.
        gen_Z: Generator producing zebras (called on horse images).
        gen_H: Generator producing horses (called on zebra images).
        loader: Iterable yielding ``(zebra, horse)`` tensor batches.
        opt_disc: Optimizer stepping both discriminators.
        opt_gen: Optimizer stepping both generators.
        l1: Criterion used for the cycle and identity terms.
        mse: Criterion used for the adversarial terms.
        d_scaler: Gradient scaler for the discriminator step.
        g_scaler: Gradient scaler for the generator step.

    Returns:
        None. Progress is reported through the tqdm bar as the running mean
        of ``disc_H`` outputs on real and fake horses.
    """
    H_reals = 0
    H_fakes = 0

    # save_image does not create missing directories.
    os.makedirs("saved_images", exist_ok=True)

    # With a zero weight the identity terms contribute nothing to the loss,
    # so the two extra generator forward passes can be skipped entirely.
    use_identity = config.LAMBDA_IDENTITY != 0

    loop = tqdm(loader, leave=True)

    for idx, (zebra, horse) in enumerate(loop):
        zebra = zebra.to(config.DEVICE)
        horse = horse.to(config.DEVICE)

        # ---- train discriminators H and Z ----
        with torch.cuda.amp.autocast():
            fake_horse = gen_H(zebra)

            D_H_real = disc_H(horse)
            # detach(): the discriminator step must not backpropagate into
            # the generator.
            D_H_fake = disc_H(fake_horse.detach())

            H_reals += D_H_real.mean().item()
            H_fakes += D_H_fake.mean().item()

            D_H_real_loss = mse(D_H_real, torch.ones_like(D_H_real))
            D_H_fake_loss = mse(D_H_fake, torch.zeros_like(D_H_fake))
            D_H_loss = D_H_real_loss + D_H_fake_loss

            fake_zebra = gen_Z(horse)
            D_Z_real = disc_Z(zebra)
            D_Z_fake = disc_Z(fake_zebra.detach())

            D_Z_real_loss = mse(D_Z_real, torch.ones_like(D_Z_real))
            # Fakes are labelled 0. Using ones here (as for the real images)
            # would train disc_Z to accept every generated zebra as real.
            D_Z_fake_loss = mse(D_Z_fake, torch.zeros_like(D_Z_fake))
            D_Z_loss = D_Z_real_loss + D_Z_fake_loss

            D_loss = (D_H_loss + D_Z_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # ---- train generators H and Z ----
        with torch.cuda.amp.autocast():
            # Adversarial loss: the generators want their fakes scored as 1.
            D_H_fake = disc_H(fake_horse)
            D_Z_fake = disc_Z(fake_zebra)
            loss_G_H = mse(D_H_fake, torch.ones_like(D_H_fake))
            loss_G_Z = mse(D_Z_fake, torch.ones_like(D_Z_fake))

            # Cycle loss: translating there and back should recover the input.
            cycle_zebra = gen_Z(fake_horse)
            cycle_horse = gen_H(fake_zebra)
            cycle_zebra_loss = l1(zebra, cycle_zebra)
            cycle_horse_loss = l1(horse, cycle_horse)

            # Add everything together.
            G_loss = (
                loss_G_Z
                + loss_G_H
                + cycle_zebra_loss * config.LAMBDA_CYCLE
                + cycle_horse_loss * config.LAMBDA_CYCLE
            )

            if use_identity:
                # Identity loss: a generator fed an image that is already in
                # its target domain should leave it unchanged.
                identity_zebra = gen_Z(zebra)
                identity_horse = gen_H(horse)
                identity_zebra_loss = l1(zebra, identity_zebra)
                identity_horse_loss = l1(horse, identity_horse)
                G_loss = (
                    G_loss
                    + identity_horse_loss * config.LAMBDA_IDENTITY
                    + identity_zebra_loss * config.LAMBDA_IDENTITY
                )

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        if idx % 200 == 0:
            # Map the tanh output range [-1, 1] back to [0, 1] before saving.
            save_image(fake_horse * 0.5 + 0.5, f"saved_images/horse_{idx}.png")
            save_image(fake_zebra * 0.5 + 0.5, f"saved_images/zebra_{idx}.png")

        loop.set_postfix(H_real=H_reals / (idx + 1), H_fake=H_fakes / (idx + 1))
            

In [16]:
def run():
    """Build the four networks, their optimizers and data loaders, then train.

    Checkpoints are loaded first when ``config.LOAD_MODEL`` is set, and all
    four networks are saved after every epoch when ``config.SAVE_MODEL`` is
    set.
    """
    device = config.DEVICE

    # Construction order is kept fixed (discriminators first, then the
    # generators) because weight initialisation consumes the global RNG.
    disc_H = Discriminator().to(device)
    disc_Z = Discriminator().to(device)
    gen_H = Generator().to(device)
    gen_Z = Generator().to(device)

    # One optimizer per role, each covering both networks of that role.
    adam_settings = dict(lr=config.LEARNING_RATE, betas=(0.5, 0.999))
    opt_disc = optim.Adam(
        [*disc_H.parameters(), *disc_Z.parameters()], **adam_settings
    )
    opt_gen = optim.Adam(
        [*gen_Z.parameters(), *gen_H.parameters()], **adam_settings
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()

    def checkpoint_plan():
        # (file, network, optimizer) triples, in load/save order. Resolved on
        # demand so the config paths are only read when actually needed.
        return (
            (config.CHECKPOINT_GEN_H, gen_H, opt_gen),
            (config.CHECKPOINT_GEN_Z, gen_Z, opt_gen),
            (config.CHECKPOINT_CRITIC_H, disc_H, opt_disc),
            (config.CHECKPOINT_CRITIC_Z, disc_Z, opt_disc),
        )

    if config.LOAD_MODEL:
        for path, network, optimizer in checkpoint_plan():
            load_checkpoint(path, network, optimizer, config.LEARNING_RATE)

    root_horse = "data/horse2zebra/horse2zebra/trainA"
    root_zebra = "data/horse2zebra/horse2zebra/trainB"
    dataset = HorseZebraDataset(
        root_zebra=root_zebra, root_horse=root_horse, transform=config.transforms
    )
    val_dataset = HorseZebraDataset(
        root_zebra="data/horse2zebra/horse2zebra/testB",
        root_horse="data/horse2zebra/horse2zebra/testA",
        transform=config.transforms,
    )

    loader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )
    # Built for parity with the training loader; nothing below consumes it.
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, pin_memory=True)

    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for _ in range(config.NUM_EPOCHS):
        train_fn(
            disc_H,
            disc_Z,
            gen_Z,
            gen_H,
            loader,
            opt_disc,
            opt_gen,
            L1,
            mse,
            d_scaler,
            g_scaler,
        )
        if not config.SAVE_MODEL:
            continue
        for path, network, optimizer in checkpoint_plan():
            save_checkpoint(network, optimizer, filename=path)


In [17]:
# Start training: run() builds the models and loaders and trains for config.NUM_EPOCHS epochs.
run()

  0%|                                                   | 1/1334 [00:19<7:15:42, 19.61s/it, H_fake=0.515, H_real=0.518]


KeyboardInterrupt: 