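CycleGAN is a type of Generative Adversarial Network (GAN) for unpaired image-to-image translation: it learns to convert images from one domain to another (for example, photos to Monet-style paintings) without needing paired examples. It trains two generators, one per direction, together with a cycle-consistency loss. This loss requires that translating an image to the other domain and back reproduces the original. Common applications include style transfer, season transfer, and photo enhancement.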

**Why Add Self-Attention to CycleGAN?**


Self-attention mechanisms allow the network to weigh the importance of different regions in the input data, regardless of their position. In the context of image translation:

- **Detail preservation:** attention lets the network focus on the most relevant features. This can help preserve fine details in complex translations such as changing facial expressions, altering seasons in landscapes, or converting paintings to photographs.
- **Global context:** each output position can draw on information from the whole feature map, whereas a convolution only sees a local neighbourhood of its input.

In [None]:
import pandas as pd
import numpy as np
import os
import glob
import numpy as np
from PIL import Image
from torchvision import transforms

import cv2

import os
import glob
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import torch
from torch import nn
from torch.optim import Adam

In [None]:
IMG_HEIGHT = 256
IMG_WIDTH = 256

# Augmentation pipeline:
# - resize the shorter edge to int(1.12 * IMG_HEIGHT) with bicubic interpolation;
# - take a random IMG_HEIGHT x IMG_WIDTH crop and randomly mirror it horizontally;
# - convert to a tensor and map each channel from [0, 1] to [-1, 1].
# Interpolation uses torchvision's InterpolationMode enum; passing a PIL constant
# such as Image.BICUBIC is a deprecated integer form.
# NOTE(review): the DataLoaders below are built with `transforms_`, not this pipeline,
# and unlike `transforms_` it does not force images to RGB.
image_transforms = transforms.Compose([
    transforms.Resize(int(IMG_HEIGHT * 1.12), transforms.InterpolationMode.BICUBIC),
    transforms.RandomCrop((IMG_HEIGHT, IMG_WIDTH)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


In [None]:
class CustomDataset(Dataset):
    """Unpaired two-domain image dataset for CycleGAN.

    Domain A is read from ``<root>/monet_jpg`` and domain B from ``<root>/photo_jpg``.
    File lists are sorted *before* they are split, so the train/test partition does
    not depend on the order in which ``glob`` happens to return files:

    * ``mode='train'``: the first 250 files of each domain.
    * ``mode='test'``: domain A files from index 250 onward; domain B files 250-300.

    Each item is a dict ``{'image_A': ..., 'image_B': ...}``. The two domains can
    differ in size, so the dataset length is the larger of the two. The shorter list
    is cycled with a modulo index.

    Args:
        root: Directory containing the ``monet_jpg`` and ``photo_jpg`` sub-directories.
        mode: ``'train'`` or ``'test'``.
        transform: Optional callable applied to each PIL image.

    Raises:
        ValueError: If ``mode`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, root, mode='train', transform=None):
        self.root = root
        self.transform = transform
        self.mode = mode
        # Sort first, then slice: this gives a deterministic, reproducible split.
        files_A = sorted(glob.glob(os.path.join(root, 'monet_jpg', '*.*')))
        files_B = sorted(glob.glob(os.path.join(root, 'photo_jpg', '*.*')))
        if mode == 'train':
            self.files_A = files_A[:250]
            self.files_B = files_B[:250]
        elif mode == 'test':
            self.files_A = files_A[250:]
            self.files_B = files_B[250:301]
        else:
            # Fail fast; otherwise files_A/files_B would be unset and the first
            # len() or indexing call would raise a confusing AttributeError.
            raise ValueError(f"mode must be 'train' or 'test', got {mode!r}")

    @staticmethod
    def _load(path):
        """Open an image, copy it into memory and close the underlying file handle."""
        with Image.open(path) as img:
            return img.copy()

    def __getitem__(self, index):
        # Modulo lets the shorter domain wrap around when the domains differ in size.
        image_A = self._load(self.files_A[index % len(self.files_A)])
        image_B = self._load(self.files_B[index % len(self.files_B)])

        if self.transform:
            image_A = self.transform(image_A)
            image_B = self.transform(image_B)

        return {'image_A': image_A, 'image_B': image_B}

    def __len__(self):
        return max(len(self.files_A), len(self.files_B))

In [None]:
data_dir = '/kaggle/input/gan-getting-started'

from operator import methodcaller

from torchvision import transforms

# Preprocessing used by both DataLoaders:
# - force 3-channel RGB;
# - resize to 256x256;
# - convert to a tensor and map each channel from [0, 1] to [-1, 1], matching the
#   generators' tanh output range.
# The RGB step uses operator.methodcaller instead of a lambda. Lambdas cannot be
# pickled, and DataLoader workers (num_workers > 0) started with the 'spawn'
# multiprocessing method must pickle the dataset and its transform.
transforms_ = transforms.Compose([
    transforms.Lambda(methodcaller('convert', 'RGB')),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


In [None]:
# Training loader: one (monet, photo) pair per step, reshuffled every epoch.
# num_workers=2: adjust to the CPU count of the machine.
train_dataloader = DataLoader(
    dataset=CustomDataset(root=data_dir, mode='train', transform=transforms_),
    batch_size=1,
    num_workers=2,
    shuffle=True,
)

# Held-out loader (batches of 5) used for the shape sanity check after training.
val_dataloader = DataLoader(
    dataset=CustomDataset(root=data_dir, mode='test', transform=transforms_),
    batch_size=5,
    num_workers=2,
    shuffle=True,
)

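**Adding self-attention inside the residual blocks**

Each generator residual block places two `SelfAttention` layers between its three 3×3 convolutions. Attention compares every spatial position with every other, so its cost grows quadratically with the number of positions. Here it runs at the generator's bottleneck, after two stride-2 downsamplings (a 64×64 feature map for 256×256 inputs).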

In [None]:
from torch import nn
class SelfAttention(nn.Module):
    """SAGAN-style spatial self-attention over all H*W positions of a feature map.

    Query and key are 1x1-conv projections to ``max(1, in_dim // 8)`` channels. The
    value projection keeps ``in_dim`` channels. Every position attends to every other
    position, so the attention matrix is (H*W) x (H*W) per sample and memory grows
    quadratically with spatial size. The attended features are added back to the input.

    Args:
        in_dim: Number of input (and output) channels.
    """

    def __init__(self, in_dim):
        super(SelfAttention, self).__init__()
        # Clamp to at least one channel. For in_dim < 8, in_dim // 8 is 0: the
        # query/key projections would have zero channels, every energy would be 0,
        # and attention would degenerate into a uniform average over positions.
        key_dim = max(1, in_dim // 8)
        self.query_conv = nn.Conv2d(in_dim, key_dim, 1)
        self.key_conv = nn.Conv2d(in_dim, key_dim, 1)
        self.value_conv = nn.Conv2d(in_dim, in_dim, 1)
        self.softmax = nn.Softmax(dim=-1)  # normalise over the attended-to positions

    def forward(self, x):
        """Return ``x + attention(x)``, same shape as ``x`` (batch, channels, height, width)."""
        batch, channels, height, width = x.size()
        n = height * width
        query = self.query_conv(x).view(batch, -1, n).permute(0, 2, 1)  # (B, N, Ck)
        key = self.key_conv(x).view(batch, -1, n)  # (B, Ck, N)
        value = self.value_conv(x).view(batch, -1, n)  # (B, C, N)

        energy = torch.bmm(query, key)  # (B, N, N): score of position i attending to j
        attention = self.softmax(energy)  # each row i sums to 1 over j
        # out[:, :, i] = sum_j value[:, :, j] * attention[:, i, j]
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch, channels, height, width)

        return out + x  # residual: add the input directly to the attention output



In [None]:
class ResidualBlock(nn.Module):
    """Residual block that interleaves three 3x3 convolutions with two self-attention layers.

    Layout: conv-IN-ReLU -> attention -> conv-IN-ReLU -> attention -> conv-IN. The
    result is added to the block input. Reflection padding keeps the spatial size and
    every stage keeps ``in_features`` channels, so the skip connection lines up.
    """

    @staticmethod
    def _conv_stage(channels, with_relu):
        """ReflectionPad(1) + 3x3 conv + InstanceNorm, optionally followed by an in-place ReLU."""
        layers = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, 3),
            nn.InstanceNorm2d(channels),
        ]
        if with_relu:
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def __init__(self, in_features):
        super().__init__()
        # Attribute names and registration order define the state_dict layout.
        self.conv1 = self._conv_stage(in_features, with_relu=True)
        self.attention1 = SelfAttention(in_features)
        self.conv2 = self._conv_stage(in_features, with_relu=True)
        self.attention2 = SelfAttention(in_features)
        self.conv3 = self._conv_stage(in_features, with_relu=False)

    def forward(self, x):
        h = x
        for stage in (self.conv1, self.attention1, self.conv2, self.attention2, self.conv3):
            h = stage(h)
        return x + h  # skip connection from block input to block output

In [None]:
class GeneratorResNet(nn.Module):
    """ResNet-style CycleGAN generator.

    The layers are, in order:
    - a 7x7 conv stem;
    - two stride-2 downsampling convs, each doubling the channel count;
    - ``n_residual_blocks`` residual blocks;
    - two stride-2 transposed convs, each halving the channel count;
    - a 7x7 output conv followed by tanh, so outputs lie in [-1, 1].

    Args:
        input_shape: ``(channels, height, width)``. Only ``channels`` affects the layers.
            Spatial sizes divisible by 4 come back out at the same size.
        n_residual_blocks: Number of ``ResidualBlock`` layers at the bottleneck.
        base_features: Width of the stem convolution. The bottleneck has
            ``4 * base_features`` channels. Defaults to 64, the original architecture.
    """

    def __init__(self, input_shape, n_residual_blocks=9, base_features=64):
        super(GeneratorResNet, self).__init__()
        channels, _, _ = input_shape

        # Initial convolution block
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels, base_features, 7),
            nn.InstanceNorm2d(base_features),
            nn.ReLU(inplace=True),
        ]

        # Downsampling: halve H and W, double the channel count
        in_features = base_features
        for _ in range(2):
            out_features = in_features * 2
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Residual blocks at the bottleneck resolution
        for _ in range(n_residual_blocks):
            model.append(ResidualBlock(in_features))

        # Upsampling: double H and W, halve the channel count
        for _ in range(2):
            out_features = in_features // 2
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True),
            ]
            in_features = out_features

        # Output layer. in_features is back to base_features here; using it instead of
        # a hard-coded 64 keeps this layer consistent with the stem width.
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_features, channels, 7),
            nn.Tanh(),
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

In [None]:
class Discriminator(nn.Module):
    """PatchGAN-style discriminator that outputs a grid of real/fake scores.

    Four stride-2 4x4 conv stages (64, 128, 256, 512 channels), each followed by
    LeakyReLU(0.2). InstanceNorm is applied on every stage except the first. A final
    4x4 conv maps to one channel. A 256x256 input gives a 1x15x15 score map.

    Args:
        input_shape: ``(channels, height, width)``; only ``channels`` is used.
    """

    def __init__(self, input_shape):
        super().__init__()
        channels, _, _ = input_shape

        widths = [channels, 64, 128, 256, 512]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            layers.append(nn.Conv2d(c_in, c_out, 4, stride=2, padding=1))
            if stage > 0:  # the first stage is left unnormalised
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv2d(512, 1, 4, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

In [None]:

# Networks, losses and optimizers for 3-channel 256x256 images.
channels, img_height, img_width = 3, 256, 256
input_shape = (channels, img_height, img_width)
n_residual_blocks = 9

# Two generators (A->B and B->A) and one discriminator per domain.
G_AB = GeneratorResNet(input_shape, n_residual_blocks)
G_BA = GeneratorResNet(input_shape, n_residual_blocks)
D_A, D_B = Discriminator(input_shape), Discriminator(input_shape)

# MSE for the adversarial term; L1 for cycle-consistency and identity terms.
criterion_GAN = torch.nn.MSELoss()
criterion_cycle, criterion_identity = torch.nn.L1Loss(), torch.nn.L1Loss()

# Adam (lr=2e-4, betas=(0.5, 0.999)); both generators share a single optimizer.
optimizer_G = Adam([*G_AB.parameters(), *G_BA.parameters()], lr=2e-4, betas=(0.5, 0.999))
optimizer_D_A = Adam(D_A.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D_B = Adam(D_B.parameters(), lr=2e-4, betas=(0.5, 0.999))

In [None]:
cuda = torch.cuda.is_available()

if cuda:
    # Module.cuda() moves parameters in place and returns the same module object.
    G_AB, G_BA, D_A, D_B = (net.cuda() for net in (G_AB, G_BA, D_A, D_B))

    # These loss modules have no parameters; the calls are kept for symmetry.
    criterion_GAN.cuda()
    criterion_cycle.cuda()
    criterion_identity.cuda()

In [None]:
def weights_init_normal(m):
    """Initialise one module's weights in place; intended for ``nn.Module.apply``.

    * Modules whose class name contains ``'Conv'`` (Conv2d, ConvTranspose2d, ...):
      weights drawn from N(0, 0.02); biases, when present, set to 0.
    * Modules whose class name contains ``'BatchNorm2d'``: affine weight drawn from
      N(1, 0.02) and affine bias set to 0, when those parameters exist.

    Other modules are left untouched.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
        # Zero the bias, as in the reference DCGAN/CycleGAN initialisation;
        # otherwise PyTorch's default (uniform) bias init would be kept.
        if getattr(m, 'bias', None) is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)
    elif 'BatchNorm2d' in classname:
        # BatchNorm2d(affine=False) has weight/bias set to None; skip rather than crash.
        if m.weight is not None:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias.data, 0.0)

In [None]:
# Run weights_init_normal over every sub-module of the four networks.
for net in (G_AB, G_BA, D_A, D_B):
    net.apply(weights_init_normal)

In [None]:
class LambdaLR:
    """Linear-decay learning-rate multiplier for ``torch.optim.lr_scheduler.LambdaLR``.

    The multiplier is 1.0 up to ``decay_start_epoch``. It then decreases linearly to
    0.0 at ``n_epochs`` and stays at 0.0 afterwards, so the learning rate never goes
    negative. ``offset`` is added to the epoch index before the schedule is evaluated.

    Raises:
        ValueError: If ``decay_start_epoch`` is not strictly smaller than ``n_epochs``
            (the decay span would be zero or negative).
    """

    def __init__(self, n_epochs, offset, decay_start_epoch):
        if n_epochs <= decay_start_epoch:
            raise ValueError(
                f"decay_start_epoch ({decay_start_epoch}) must be smaller than n_epochs ({n_epochs})"
            )
        self.n_epochs = n_epochs
        self.offset = offset
        self.decay_start_epoch = decay_start_epoch

    def get_lr_lambda(self):
        """Return an ``epoch -> multiplier`` function for the ``lr_lambda`` argument."""
        def lr_lambda(epoch):
            elapsed = max(0, epoch + self.offset - self.decay_start_epoch)
            return max(0.0, 1.0 - elapsed / (self.n_epochs - self.decay_start_epoch))
        return lr_lambda


In [None]:
n_epochs = 15
epoch = 0
decay_epoch = 5

# One scheduler per optimizer (generators, D_A, D_B), all using the same LambdaLR
# schedule: constant LR until `decay_epoch`, then linear decay toward `n_epochs`.
# NOTE(review): the train_cycle_gan call below does not step these schedulers.
lr_scheduler_G, lr_scheduler_D_A, lr_scheduler_D_B = (
    torch.optim.lr_scheduler.LambdaLR(
        opt,
        lr_lambda=LambdaLR(n_epochs, epoch, decay_epoch).get_lr_lambda(),
    )
    for opt in (optimizer_G, optimizer_D_A, optimizer_D_B)
)

In [None]:
def _gan_loss(criterion_GAN, prediction, target_is_real):
    """Adversarial loss of *prediction* against an all-ones (real) or all-zeros (fake) target.

    The target has the same shape as the discriminator output. This is required for
    patch discriminators that emit a score map rather than one score per image.
    """
    make_target = torch.ones_like if target_is_real else torch.zeros_like
    return criterion_GAN(prediction, make_target(prediction))


def train_cycle_gan(dataloader, G_AB, G_BA, D_A, D_B, optimizer_G, optimizer_D_A, optimizer_D_B, criterion_GAN, criterion_cycle, criterion_identity, n_epochs=100, device='cuda', lr_schedulers=None):
    """Train a CycleGAN with generators G_AB (A->B), G_BA (B->A) and discriminators D_A, D_B.

    Each batch does three things, then prints the losses:
    - one generator update (adversarial + cycle-consistency + identity losses);
    - one update of D_A;
    - one update of D_B.

    Args:
        dataloader: Iterable of dicts with ``'image_A'`` and ``'image_B'`` tensors;
            must support ``len()`` (used in the progress message).
        G_AB, G_BA, D_A, D_B: The four networks.
        optimizer_G: Optimizer over both generators' parameters.
        optimizer_D_A, optimizer_D_B: Optimizers for each discriminator.
        criterion_GAN: Adversarial loss ``(prediction, target) -> loss``.
        criterion_cycle, criterion_identity: Reconstruction losses.
        n_epochs: Number of passes over ``dataloader``.
        device: Device each batch is moved to.
        lr_schedulers: Optional iterable of LR schedulers, each stepped once at the end
            of every epoch. ``None`` (the default) steps nothing.
    """
    lambda_cycle = 10.0  # Weight for cycle-consistency loss
    lambda_identity = 0.5 * lambda_cycle  # Weight for identity loss, half of lambda_cycle
    schedulers = list(lr_schedulers) if lr_schedulers is not None else []

    for epoch in range(n_epochs):
        for i, batch in enumerate(dataloader):
            real_A = batch['image_A'].to(device)
            real_B = batch['image_B'].to(device)

            # ---- Generators ----
            optimizer_G.zero_grad()

            # Identity: a generator fed an image already in its output domain should keep it unchanged.
            loss_id_A = criterion_identity(G_BA(real_A), real_A)
            loss_id_B = criterion_identity(G_AB(real_B), real_B)
            loss_identity = (loss_id_A + loss_id_B) / 2

            # Adversarial: translated images should be scored as real by the target-domain D.
            fake_B = G_AB(real_A)
            loss_GAN_AB = _gan_loss(criterion_GAN, D_B(fake_B), True)
            fake_A = G_BA(real_B)
            loss_GAN_BA = _gan_loss(criterion_GAN, D_A(fake_A), True)
            loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

            # Cycle consistency: A -> B -> A and B -> A -> B should reconstruct the input.
            recov_A = G_BA(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A)
            recov_B = G_AB(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B)
            loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

            total_loss_G = loss_GAN + lambda_cycle * loss_cycle + lambda_identity * loss_identity
            total_loss_G.backward()
            optimizer_G.step()

            # ---- Discriminator A ----
            # zero_grad also discards the gradients D_A accumulated during the generator backward pass.
            optimizer_D_A.zero_grad()
            loss_real_A = _gan_loss(criterion_GAN, D_A(real_A), True)
            loss_fake_A = _gan_loss(criterion_GAN, D_A(fake_A.detach()), False)
            total_loss_D_A = (loss_real_A + loss_fake_A) / 2
            total_loss_D_A.backward()
            optimizer_D_A.step()

            # ---- Discriminator B ----
            optimizer_D_B.zero_grad()
            loss_real_B = _gan_loss(criterion_GAN, D_B(real_B), True)
            loss_fake_B = _gan_loss(criterion_GAN, D_B(fake_B.detach()), False)
            total_loss_D_B = (loss_real_B + loss_fake_B) / 2
            total_loss_D_B.backward()
            optimizer_D_B.step()

            print(f'[Epoch {epoch+1}/{n_epochs}] [Batch {i+1}/{len(dataloader)}] [D loss: {total_loss_D_A.item() + total_loss_D_B.item():.6f}] [G loss: {total_loss_G.item():.6f} - (adv: {loss_GAN.item():.6f}, cycle: {lambda_cycle * loss_cycle.item():.6f}, identity: {lambda_identity * loss_identity.item():.6f})]')

        for scheduler in schedulers:
            scheduler.step()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Train on the training split with the networks, optimizers and losses built above.
train_cycle_gan(
    train_dataloader,
    G_AB=G_AB, G_BA=G_BA,
    D_A=D_A, D_B=D_B,
    optimizer_G=optimizer_G, optimizer_D_A=optimizer_D_A, optimizer_D_B=optimizer_D_B,
    criterion_GAN=criterion_GAN,
    criterion_cycle=criterion_cycle,
    criterion_identity=criterion_identity,
    n_epochs=n_epochs,
    device=device,
)

In [None]:
# Sanity check: print the tensor shapes of the first 11 validation batches.
for i, batch in enumerate(val_dataloader):
    image_A = batch['image_A'].to(device)
    image_B = batch['image_B'].to(device)
    print(f'iter : {i}  image_A.size : {image_A.size()}')
    print(f'iter : {i}  image_B.size : {image_B.size()}')

    if i == 10:
        break

# Translate every photo (domain B) into the Monet domain (A) with G_BA.
photo_dir = os.path.join(data_dir, 'photo_jpg')
files = [os.path.join(photo_dir, name) for name in os.listdir(photo_dir)]
len(files)

save_dir = '/kaggle/working/img/'
os.makedirs(save_dir, exist_ok=True)

# Empty the output directory so files from an earlier run are not mixed in.
for file in os.listdir(save_dir):
    path = os.path.join(save_dir, file)
    if os.path.isfile(path):  # os.remove cannot delete directories
        os.remove(path)

batch_size = 1

# Same [0, 1] -> [-1, 1] normalisation as training, but without resizing.
# NOTE(review): assumes the photos already match the training resolution (256x256).
generate_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

to_image = transforms.ToPILImage()


def _denormalize_to_uint8(tensor):
    """Convert a (C, H, W) tensor in [-1, 1] to an (H, W, C) uint8 array.

    Applies the fixed inverse of Normalize(mean=0.5, std=0.5), x -> (x + 1) * 127.5.
    Per-image min-max stretching would shift colours and contrast image by image,
    and would divide by zero for a constant output.
    """
    arr = tensor.permute(1, 2, 0).numpy()
    return np.clip(np.rint((arr + 1.0) * 127.5), 0, 255).astype(np.uint8)


G_BA.eval()
with torch.no_grad():  # inference only: do not build an autograd graph
    for i in range(0, len(files), batch_size):
        # Read images; keep each successfully loaded tensor paired with its path so
        # output filenames stay correct when a file in the batch fails to load.
        paths, imgs = [], []
        for path in files[i:i + batch_size]:
            try:
                with Image.open(path) as img:
                    tensor = generate_transforms(img.convert('RGB'))
            except Exception as e:  # best effort: report and skip unreadable files
                print(f"Error processing {path}: {e}")
                continue
            paths.append(path)
            imgs.append(tensor)

        if not imgs:
            continue  # nothing in this batch could be loaded; torch.stack([]) would fail

        # Generate, then save each output under its source filename.
        fake_imgs = G_BA(torch.stack(imgs, 0).to(device)).cpu()
        for path, fake in zip(paths, fake_imgs):
            out = to_image(_denormalize_to_uint8(fake))
            out.save(os.path.join(save_dir, os.path.basename(path)))
