In [1]:
import numpy as np
from PIL import Image
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim 
from tqdm import tqdm

In [2]:
# Select the compute device used by every later cell.
# Apple's MPS backend is preferred when present; otherwise fall back to CUDA
# and finally the CPU so that `device` is always defined.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    # Allocate a tiny tensor as a sanity check that the backend really works.
    x = torch.ones(1, device=device)
    print(x)
    print(device)
else:
    print ("MPS device not found.")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

tensor([1.], device='mps:0')
mps


In [3]:
class BlockD(nn.Module):
    """Discriminator building block: 4x4 conv -> InstanceNorm -> LeakyReLU(0.2).

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        stride: stride of the convolution.
    """

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=4,
            stride=stride,
            padding=1,
            bias=True,
            padding_mode='reflect',
        )
        stages = [
            convolution,
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        ]
        # Kept under the attribute name `conv` so state_dict keys are unchanged.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply convolution, normalisation and activation to `x`."""
        out = self.conv(x)
        return out

In [4]:
class Discriminator(nn.Module):
    """Patch-based discriminator returning a map of per-patch scores.

    The network is an initial stride-2 convolution followed by one `BlockD`
    per remaining entry of `features` and a final 1-channel convolution.
    Every `BlockD` uses stride 2 except the last one, which uses stride 1.

    Args:
        in_channels: number of channels of the input image.
        features: channel widths of the successive convolution stages.
            Must contain at least one entry.

    Raises:
        ValueError: if `features` is empty.
    """

    def __init__(self, in_channels, features = (64, 128, 256, 512)):
        super().__init__()
        # Immutable default + local copy: no shared mutable default argument.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.initial = nn.Sequential(
            nn.Conv2d(in_channels,features[0],kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
            nn.LeakyReLU(0.2)
        )

        layers=[]
        in_channels=features[0]
        last_index = len(features) - 1
        for index, feature in enumerate(features[1:], start=1):
            # Pick the stride-1 stage by position, not by value: comparing
            # against features[-1] would also match earlier stages that merely
            # share the last stage's width (e.g. [64, 512, 512]).
            stride = 1 if index == last_index else 2
            layers.append(BlockD(in_channels, feature, stride=stride))
            in_channels = feature
        layers.append(nn.Conv2d(in_channels, 1, kernel_size=4, stride=1, padding=1, padding_mode='reflect'))
        self.model = nn.Sequential(*layers)

    def forward(self,x):
        """Return per-patch scores squashed into (0, 1) by a sigmoid."""
        x=self.initial(x)
        return torch.sigmoid(self.model(x))

In [5]:
def test():
    """Smoke-test the discriminator by printing its output shape for 256x256 RGB input."""
    batch = torch.randn(5, 3, 256, 256)
    critic = Discriminator(in_channels=3)
    scores = critic(batch)
    print(scores.shape)

# The discriminator is a 70x70 PatchGAN: each value of the 30x30 output map
# scores one 70x70 patch of the input image.
test()

torch.Size([5, 1, 30, 30])


In [6]:
class BlockG(nn.Module):
    """Generator building block: (transposed) conv -> InstanceNorm -> optional ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        down: when True use a reflect-padded `nn.Conv2d`; otherwise use an
            `nn.ConvTranspose2d`.
        use_act: when True finish with an in-place ReLU, otherwise with an
            identity layer.
        **kwargs: forwarded unchanged to the chosen convolution
            (kernel_size, stride, padding, ...).
    """

    def __init__(self, in_channels, out_channels, down = True, use_act = True, **kwargs):
        super().__init__()
        if down:
            convolution = nn.Conv2d(in_channels, out_channels, padding_mode='reflect', **kwargs)
        else:
            convolution = nn.ConvTranspose2d(in_channels, out_channels, **kwargs)

        normalisation = nn.InstanceNorm2d(out_channels)

        if use_act:
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.Identity()

        # Kept under the attribute name `conv` so state_dict keys are unchanged.
        self.conv = nn.Sequential(convolution, normalisation, activation)

    def forward(self, x):
        """Run `x` through convolution, normalisation and activation."""
        out = self.conv(x)
        return out

In [7]:
class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 `BlockG` layers plus a skip connection.

    The first layer ends with a ReLU and the second has no activation; the
    channel count is preserved so the input can be added to the result.

    Args:
        channels: number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        first = BlockG(channels, channels, **conv_kwargs)
        second = BlockG(channels, channels, use_act=False, **conv_kwargs)
        # Attribute name `Block` is kept so state_dict keys are unchanged.
        self.Block = nn.Sequential(first, second)

    def forward(self, x):
        """Return `x` plus the output of the two convolution layers."""
        residual = self.Block(x)
        return x + residual

In [8]:
class Generator(nn.Module):
    """Image-to-image generator: encoder, residual bottleneck, decoder.

    Layout: a 7x7 reflect-padded convolution, two stride-2 downsampling
    blocks, `residuals` residual blocks, two stride-2 transposed-convolution
    upsampling blocks and a final 7x7 convolution followed by tanh.

    Args:
        img_channels: channels of the input and output image.
        num_features: channel width after the first convolution.
        residuals: number of residual blocks in the bottleneck.
    """

    def __init__(self, img_channels, num_features=64,residuals=9):
        super().__init__()
        width = num_features
        edge_conv = dict(kernel_size=7, stride=1, padding=3)
        resample = dict(kernel_size=3, stride=2, padding=1)

        self.initial = nn.Sequential(
            nn.Conv2d(img_channels, width, padding_mode="reflect", **edge_conv),
            nn.ReLU(inplace=True),
        )

        # Encoder: each block halves the resolution and doubles the width.
        self.down_blocks = nn.ModuleList([
            BlockG(width, width * 2, **resample),
            BlockG(width * 2, width * 4, **resample),
        ])

        bottleneck = [ResidualBlock(width * 4) for _ in range(residuals)]
        self.residual_blocks = nn.Sequential(*bottleneck)

        # Decoder: transposed convolutions mirror the encoder.
        self.up_blocks = nn.ModuleList([
            BlockG(width * 4, width * 2, down=False, output_padding=1, **resample),
            BlockG(width * 2, width, down=False, output_padding=1, **resample),
        ])

        self.last = nn.Conv2d(width, img_channels, padding_mode='reflect', **edge_conv)

    def forward(self, x):
        """Translate image batch `x`; the output is squashed by tanh."""
        out = self.initial(x)
        for stage in (self.down_blocks, self.residual_blocks, self.up_blocks):
            for layer in stage:
                out = layer(out)
        return torch.tanh(self.last(out))

In [9]:
def test():
    """Smoke-test the generator by printing its output shape for 256x256 RGB input."""
    img_channels=3
    img_size=256
    x=torch.randn((2, img_channels, img_size, img_size))
    # `residuals` must be passed by keyword: the second positional parameter
    # of Generator is `num_features`, so a bare 9 would set the channel width
    # to 9 instead of requesting nine residual blocks.
    gen = Generator(img_channels, residuals=9)
    print(gen(x).shape)

test()

torch.Size([2, 3, 256, 256])


In [10]:
# --- Data locations (relative to the notebook's working directory) ---
TRAIN_DIR, TEST_DIR = "datasets/horse2zebra/train", "datasets/horse2zebra/test"

# --- Optimisation settings ---
lr = 2e-4        # Adam learning rate shared by both optimizers
EPOCHS = 1
BATCH_SIZE = 1
NUM_WORKERS = 0  # DataLoader worker processes

# --- Generator loss weights ---
LAMBDA_CYCLE = 10
LAMBDA_IDENTITY = 0

In [11]:
class HorseZebra(Dataset):
    """Unpaired horse/zebra image dataset.

    The dataset length is the size of the larger folder; the smaller folder
    is cycled with a modulo so every index yields one image from each domain.

    Args:
        horse_dir: directory containing the horse images.
        zebra_dir: directory containing the zebra images.
        transform: optional callable invoked as
            `transform(image=zebra, image0=horse)` that returns a mapping
            with the keys "image" and "image0".

    Raises:
        ValueError: if either directory contains no usable files.
    """

    def __init__(self, horse_dir, zebra_dir, transform = None):
        self.horse_dir = horse_dir
        self.zebra_dir = zebra_dir
        self.horse = self._list_files(horse_dir)
        self.zebra = self._list_files(zebra_dir)
        if not self.horse or not self.zebra:
            # Fail early with a clear message instead of a ZeroDivisionError
            # from the modulo in __getitem__.
            raise ValueError(
                f"No image files found in {horse_dir!r} and/or {zebra_dir!r}"
            )
        self.transform = transform
        self.length = max(len(self.horse),len(self.zebra))

    @staticmethod
    def _list_files(directory):
        """Return the sorted regular, non-hidden file names in `directory`."""
        # os.listdir order is arbitrary and may include entries such as
        # ".DS_Store" that cannot be opened as images.
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _load_rgb(path):
        """Load `path` as an RGB numpy array and close the file handle."""
        with Image.open(path) as img:
            return np.array(img.convert("RGB"))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Return `(zebra_img, horse_img)` for `index`, transformed if configured."""
        zebra_name = self.zebra[index % len(self.zebra)]
        horse_name = self.horse[index % len(self.horse)]

        zebra_img = self._load_rgb(os.path.join(self.zebra_dir, zebra_name))
        horse_img = self._load_rgb(os.path.join(self.horse_dir, horse_name))

        if self.transform:
            augmentations = self.transform(image = zebra_img, image0 = horse_img)
            zebra_img = augmentations["image"]
            horse_img = augmentations["image0"]

        return zebra_img,horse_img

In [12]:
# One discriminator and one generator per image domain (h = horse, z = zebra),
# created in the same order as before: disc_h, disc_z, gen_h, gen_z.
disc_h, disc_z = (Discriminator(in_channels=3).to(device) for _ in range(2))
gen_h, gen_z = (Generator(img_channels=3, residuals=9).to(device) for _ in range(2))

In [13]:
# A single Adam optimizer updates both discriminators together ...
opt_disc = optim.Adam(
    [*disc_h.parameters(), *disc_z.parameters()],
    lr=lr,
    betas=(0.5, 0.999),
)
# ... and a second one updates both generators together.
opt_gen = optim.Adam(
    [*gen_h.parameters(), *gen_z.parameters()],
    lr=lr,
    betas=(0.5, 0.999),
)

In [14]:
# Loss criteria: L1 for the cycle/identity terms, MSE for the adversarial terms.
L1, mse = nn.L1Loss(), nn.MSELoss()

# Training history filled in by train():
#   image_list - one entry of sample images per epoch
#   losses     - [discriminator losses, generator losses]
image_list = list()
losses = [[] for _ in range(2)]

In [15]:
# Augmentation pipeline applied to a (zebra, horse) pair.
# `additional_targets` registers the second image under the key "image0" so
# that it is processed as an image alongside "image".
transforms = A.Compose(
    [
        A.Resize(height=256, width=256),
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(p=0.1),
        # With mean = std = 0.5 per channel, 8-bit pixels map to [-1, 1].
        A.Normalize(
            mean=[0.5,0.5,0.5],
            std=[0.5,0.5,0.5],
            max_pixel_value=255,
        ),
        ToTensorV2(),
    ],
    additional_targets={"image0":"image"},
)

In [16]:
# Training dataset and loader for the unpaired horse/zebra folders.
dataset = HorseZebra(horse_dir=f"{TRAIN_DIR}/horse_train", zebra_dir=f"{TRAIN_DIR}/zebra_train", transform=transforms)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle = True,
    num_workers=NUM_WORKERS,
    # Pinned host memory only benefits CUDA transfers; on MPS or CPU the
    # DataLoader merely emits a warning and ignores the flag.
    pin_memory=(device.type == "cuda")
)

In [17]:
def _real_fake_loss(disc, real, fake, mse):
    """Return mse(disc(fake), 0) + mse(disc(real), 1) for one discriminator."""
    real_pred = disc(real)
    fake_pred = disc(fake)
    real_loss = mse(real_pred, torch.ones_like(real_pred))
    fake_loss = mse(fake_pred, torch.zeros_like(fake_pred))
    return fake_loss + real_loss


def train_disc(zebra, horse, disc_h, disc_z, gen_h, gen_z, mse):
    """Compute the combined discriminator loss for one batch.

    Each discriminator is pushed towards 1 on real images of its domain and
    towards 0 on images produced by the matching generator.

    Args:
        zebra: batch of real zebra images.
        horse: batch of real horse images.
        disc_h: horse discriminator.
        disc_z: zebra discriminator.
        gen_h: generator producing horses from zebras.
        gen_z: generator producing zebras from horses.
        mse: criterion comparing predictions with the 0/1 targets.

    Returns:
        The average of the two discriminators' losses. No gradient flows
        into the generators.
    """
    # The fakes are only inputs to the discriminators here, so build them
    # without an autograd graph instead of building one and detaching it.
    with torch.no_grad():
        fake_horse = gen_h(zebra)
        fake_zebra = gen_z(horse)

    d_h_loss = _real_fake_loss(disc_h, horse, fake_horse, mse)
    d_z_loss = _real_fake_loss(disc_z, zebra, fake_zebra, mse)

    return (d_z_loss + d_h_loss)/2

In [18]:
def train_gen(zebra, horse, disc_h, disc_z, gen_h, gen_z, L1, mse,
              lambda_cycle=None, lambda_identity=None):
    """Compute the combined generator loss for one batch.

    The loss is `adversarial + lambda_cycle * cycle + lambda_identity * identity`.

    Args:
        zebra: batch of real zebra images.
        horse: batch of real horse images.
        disc_h: horse discriminator.
        disc_z: zebra discriminator.
        gen_h: generator producing horses from zebras.
        gen_z: generator producing zebras from horses.
        L1: criterion for the cycle and identity terms.
        mse: criterion for the adversarial terms.
        lambda_cycle: weight of the cycle term; defaults to the notebook-level
            `LAMBDA_CYCLE` when None.
        lambda_identity: weight of the identity term; defaults to the
            notebook-level `LAMBDA_IDENTITY` when None.

    Returns:
        The weighted generator loss.
    """
    if lambda_cycle is None:
        lambda_cycle = LAMBDA_CYCLE
    if lambda_identity is None:
        lambda_identity = LAMBDA_IDENTITY

    # Adversarial loss: each generator wants its discriminator to output 1.
    fake_horse = gen_h(zebra)
    d_h_fake = disc_h(fake_horse)
    g_h_loss = mse(d_h_fake, torch.ones_like(d_h_fake))

    fake_zebra = gen_z(horse)
    d_z_fake = disc_z(fake_zebra)
    g_z_loss = mse(d_z_fake, torch.ones_like(d_z_fake))

    adv_loss = g_h_loss + g_z_loss

    # Cycle loss: translating there and back should reproduce the input.
    cycle_zebra = gen_z(fake_horse)
    cycle_horse = gen_h(fake_zebra)
    c_h_loss = L1(horse, cycle_horse)
    c_z_loss = L1(zebra, cycle_zebra)

    cycle_loss = c_h_loss + c_z_loss

    g_loss = adv_loss + cycle_loss * lambda_cycle

    # Identity loss: a generator fed an image of its own target domain should
    # leave it unchanged. A zero weight contributes nothing to the loss or the
    # gradients, so skip the two extra generator passes in that case.
    if lambda_identity:
        identity_zebra = gen_z(zebra)
        identity_horse = gen_h(horse)
        i_z_loss = L1(zebra, identity_zebra)
        i_h_loss = L1(horse, identity_horse)

        identity_loss = i_h_loss + i_z_loss
        g_loss = g_loss + identity_loss * lambda_identity

    return g_loss

In [19]:
def train(disc_h, disc_z, gen_h, gen_z, loader, opt_disc, opt_gen, L1, mse):
    """Run one training epoch and record the end-of-epoch samples and losses.

    For every batch the discriminators are updated first, then the
    generators. After the epoch, translations of the last batch are appended
    to the notebook-level `image_list` and the last batch's losses to
    `losses`.

    Args:
        disc_h, disc_z: horse and zebra discriminators.
        gen_h, gen_z: generators producing horses and zebras respectively.
        loader: iterable yielding `(zebra, horse)` batches.
        opt_disc: optimizer stepping the discriminators.
        opt_gen: optimizer stepping the generators.
        L1, mse: loss criteria forwarded to train_gen / train_disc.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    zebra = horse = d_loss = g_loss = None
    loop = tqdm(loader, leave=True)
    for zebra, horse in loop:
        zebra = zebra.to(device)
        horse = horse.to(device)

        d_loss = train_disc(zebra, horse, disc_h, disc_z, gen_h, gen_z, mse)

        opt_disc.zero_grad()
        d_loss.backward()
        opt_disc.step()

        g_loss = train_gen(zebra, horse, disc_h, disc_z, gen_h, gen_z, L1, mse)

        opt_gen.zero_grad()
        g_loss.backward()
        opt_gen.step()

    if d_loss is None:
        # Without this guard the code below would fail with an obscure
        # UnboundLocalError on the first use of `zebra`.
        raise ValueError("loader produced no batches; nothing to train on")

    # Preview images are for display only: build them without an autograd
    # graph and move them to the CPU so they do not pin accelerator memory.
    with torch.no_grad():
        fake_horse = gen_h(zebra)
        fake_zebra = gen_z(horse)
        # Undo the mean=0.5 / std=0.5 normalisation for display.
        samples = [
            (img * 0.5 + 0.5).cpu()
            for img in (horse, fake_horse, zebra, fake_zebra)
        ]
    image_list.append(samples)
    losses[0].append(d_loss.item())
    losses[1].append(g_loss.item())

In [20]:
# Train for EPOCHS epochs, reporting the losses recorded at the end of each one.
for epoch in range(EPOCHS):
    train(disc_h, disc_z, gen_h, gen_z, loader, opt_disc, opt_gen, L1, mse)
    print(f"Epoch [{epoch+1}/{EPOCHS}],  D_loss : {losses[0][-1]:.4f},   G_loss : {losses[1][-1]:.4f}")

 21%|██        | 278/1334 [20:51<1:19:15,  4.50s/it]  


KeyboardInterrupt: 

In [None]:
# Loss curves: one point per epoch for each network.
plt.plot(losses[0], '-', label='Discriminator')
plt.plot(losses[1], '-', label='Generator')
plt.title('Losses')
plt.xlabel('epoch')
plt.ylabel('loss')
plt.legend();

In [None]:
def plot_images(image_list, nrows = EPOCHS, ncols = 4):
    """Show the recorded sample images as a grid, one row per entry.

    Args:
        image_list: sequence of rows; each row is a sequence of image tensors
            shaped (C, H, W) or (N, C, H, W). For batched tensors only the
            first sample is shown.
        nrows: minimum number of grid rows. The grid grows automatically when
            `image_list` holds more entries, e.g. after re-running training.
        ncols: number of images shown per row.
    """
    # Size the grid from the data so an image_list longer than `nrows` cannot
    # index past the last subplot.
    nrows = max(nrows, len(image_list), 1)
    fig, axes = plt.subplots(nrows, ncols, figsize=(ncols*3 ,nrows*3), squeeze=False)
    for ax in axes.flat:
        ax.axis("off")
    for row_axes, imgs in zip(axes, image_list):
        for ax, img in zip(row_axes, imgs):
            img = img.detach().cpu()
            if img.dim() == 4:
                # Take the first sample explicitly; squeeze() only worked for
                # batch size 1 and would also drop a single-channel axis.
                img = img[0]
            # Channels-first -> channels-last, clamped to imshow's float range.
            ax.imshow(img.clamp(0, 1).permute(1, 2, 0).numpy())
    plt.tight_layout()
    plt.show()

In [None]:
# Display the sample images recorded during training.
plot_images(image_list=image_list)