In [1]:
import torch
import torch.nn as nn

In [2]:
class ConvolutionalBlock(nn.Module):
    def __init__(
        self,
        in_channel=int,
        out_channel=int,
        kernel_size=int,
        stride=1,
        padding=0,
        is_downsample=True,
        is_activation=True,
        **kwargs
    ):
        super().__init__()
        if is_downsample:
            self.main = nn.Sequential(
                nn.Conv2d(in_channel, out_channel,kernel_size = kernel_size, padding_mode="reflect", **kwargs),
                nn.InstanceNorm2d(out_channel),
            )
            if is_activation:
                self.main.append(nn.ReLU(inplace=True))
            else:
                self.main.append(nn.Identity())
        else:
            self.main = nn.Sequential(
                 nn.ConvTranspose2d(in_channel, out_channel,kernel_size = kernel_size,**kwargs),
                 nn.InstanceNorm2d(out_channel)
            )
            if is_activation:
                self.main.append(nn.ReLU(inplace=True))
            else:
                self.main.append(nn.Identity())

    def forward(self,x):
        return self.main(x)

In [3]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channel: int,out_channel=256 ):
        super().__init__()
        self.block = nn.Sequential(
            ConvolutionalBlock(in_channel, out_channel, is_activation=True, kernel_size=3, padding=1),
            ConvolutionalBlock(in_channel, out_channel, is_activation=False, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.block(x)

In [4]:
class Generator(nn.Module):
    def __init__(
        self,
        in_channel=3,
    ):
        super().__init__()
        channel = [64, 128, 256, 128, 64, 3]
        self.layers_1 = nn.Sequential(
            nn.Conv2d(in_channel, channel[0],kernel_size=7, stride=1, padding=3, padding_mode="reflect"),
            nn.InstanceNorm2d(channel[0]),
            nn.ReLU(inplace=True)
        )
        self.layers_2 =nn.ModuleList(
            [ConvolutionalBlock(channel[0], channel[1],kernel_size=3, stride=2, padding=1, is_downsample=True, is_activation=True),
             ConvolutionalBlock(channel[1], channel[2],kernel_size=3, stride=2, padding=1, is_downsample=True, is_activation=True)]
        )
        self.layers_3 = nn.Sequential(
            *[ResidualBlock(channel[2]) for _ in range(9)]
        )
        self.layers_4 = nn.ModuleList(
            [ConvolutionalBlock(channel[2], channel[3], kernel_size=3, stride=2, padding=1, is_downsample=False, is_activation=True),
             ConvolutionalBlock(channel[3], channel[4], kernel_size=3, stride=2, padding=1, is_downsample=False, is_activation=True)]
        )
        self.layers_5 = nn.Sequential(
            nn.Conv2d(channel[4], channel[5], kernel_size=7, stride=1, padding=3, padding_mode="reflect")
        )
    def forward(self,x):
        x = self.layers_1(x)
        for layer in self.layers_2:
            x = layer(x)

        x = self.layers_3(x)

        for layer in self.layers_4:
            x = layer(x)
        return torch.tanh(self.layers_5(x))

In [5]:
class Discriminator(nn.Module):

    def __init__(self, in_channel=3):
        super().__init__()

        channels = [64, 128, 256, 512]

        def ConvInstanceNormLeakyReLUBlock(
                in_channel,
                out_channel,
                normalize=True,
                kernel_size=4,
                stride=2,
                padding=1,
                activation=None
        ):

            layers =nn.ModuleList(
                [nn.Conv2d(
                    in_channel,
                    out_channel,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=False if normalize else True)]
            )

            if normalize:
                layers.append(nn.BatchNorm2d(out_channel))

            layers.append(nn.LeakyReLU(0.2, inplace=True) if activation is None else activation)

            return layers

        self.main = nn.Sequential(
            *ConvInstanceNormLeakyReLUBlock(in_channel, channels[0], normalize=False),
            *ConvInstanceNormLeakyReLUBlock(channels[0], channels[1]),
            *ConvInstanceNormLeakyReLUBlock(channels[1], channels[2]),
            *ConvInstanceNormLeakyReLUBlock(channels[2], channels[3], stride=1),
            *ConvInstanceNormLeakyReLUBlock(channels[3], 1, normalize=False, stride=1, activation=nn.Sigmoid())
        )

    def forward(self,x):

        return self.main(x)


In [6]:
import os
import glob

from PIL import Image
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms


class MyDataset(Dataset):
    def __init__(self, transform, mode='train'):
        self.files_A = glob.glob("/kaggle/input/gan-getting-started/monet_jpg/*jpg")
        self.files_B = glob.glob("/kaggle/input/gan-getting-started/photo_jpg/*jpg")
        self.transform = transform
        self.len_A = len(self.files_A)
        self.len_B = len(self.files_B)

    def __len__(self):
        return max(self.len_A, self.len_B)

    def __getitem__(self, idx):
        i = idx % self.len_A
        j = idx % self.len_B

        img_A = Image.open(self.files_A[i])
        img_B = Image.open(self.files_B[j])

        return {"A": self.transform(img_A), "B": self.transform(img_B)}

In [7]:
import numpy as np
from torch.utils import data
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms

from tqdm import tqdm
import torch.optim as optim

In [8]:
LEARNING_RATE =2e-5
BATCH_SIZE = 1
NUM_EPOCHS = 30
NUM_WORKERS = 2
LAMBDA_CYCLE = 10
DEVICE = torch.device("cuda" if torch.cuda else "cpu")

In [9]:
def train(disc_monet,
          disc_photo,
          monet_generator,
          photo_generator,
          loader,
          opt_disc,
          opt_gen,
          L1,mse,
          d_scaler
          ):
    
    torch.cuda.empty_cache
    monet_reals = 0
    monet_fakes = 0
    loop = tqdm(loader, leave=True)

    for images in loop:
        monet = images['A'].to(DEVICE)
        photo = images['B'].to(DEVICE)

        with torch.cuda.amp.autocast():
            fake_monet = monet_generator(photo)
            D_M_R = disc_monet(monet)
            D_M_F = disc_monet(fake_monet.detach())
            D_M_R_loss = mse(D_M_R, torch.ones_like(D_M_R))
            D_M_F_loss = mse(D_M_F, torch.zeros_like(D_M_F))
            D_M_loss = D_M_R_loss + D_M_F_loss

            fake_photo = photo_generator(monet)
            D_P_R = disc_photo(photo)
            D_P_F = disc_photo(fake_photo.detach())
            D_P_R_loss = mse(D_P_R, torch.ones_like(D_P_R))
            D_P_F_loss = mse(D_P_F, torch.zeros_like(D_P_F))
            D_P_loss = D_P_R_loss + D_P_F_loss

            D_loss = (D_P_loss + D_M_loss) / 2

        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        with torch.cuda.amp.autocast():
            D_M_F = disc_monet(fake_monet)
            D_P_F = disc_photo(fake_photo)
            loss_G_M = mse(D_M_F, torch.ones_like(D_M_F))
            loss_G_P = mse(D_P_F, torch.ones_like(D_P_F))

            cycle_photo = photo_generator(fake_monet)
            cycle_monet = monet_generator(fake_photo)
            cycle_photo_loss = L1(photo,cycle_photo)
            cycle_monet_loss = L1(monet,cycle_monet)

            G_loss = (loss_G_M+loss_G_P+
                      cycle_photo_loss * LAMBDA_CYCLE+
                      cycle_monet_loss * LAMBDA_CYCLE)

        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()


In [10]:
def main():
    disc_monet = Discriminator().to(DEVICE)
    disc_photo = Discriminator().to(DEVICE)
    monet_generator = Generator().to(DEVICE)
    photo_generator = Generator().to(DEVICE)

    opt_disc = optim.Adam(
        list(disc_monet.parameters()) + list(disc_photo.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    opt_gen = optim.Adam(
        list(monet_generator.parameters()) + list(photo_generator.parameters()),
        lr=LEARNING_RATE,
        betas=(0.5, 0.999),
    )

    L1 = nn.L1Loss()
    mse = nn.MSELoss()
    
    train_transform = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), 
    std=(0.5, 0.5, 0.5))
    ]
    
    train_transform = transforms.Compose(train_transform)
    train_data = MyDataset(train_transform)

    loader = data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True,
                                   num_workers=NUM_WORKERS)
    
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()

    for epoch in range(NUM_EPOCHS):
        train(disc_monet,
          disc_photo,
          monet_generator,
          photo_generator,
          loader,
          opt_disc,
          opt_gen,
          L1,mse,
          d_scaler)

In [11]:
main()

ValueError: num_samples should be a positive integer value, but got num_samples=0