In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import os
from PIL import Image
import numpy as np
import itertools
import random
import matplotlib.pyplot as plt

In [None]:
class AnimeDataset(Dataset):
    """Flat-folder image dataset: each regular file in ``root_dir`` is one sample.

    Despite the name, the class is used for both image domains in this notebook;
    it simply serves whatever files live in the given directory.

    Args:
        root_dir: Directory containing the image files.
        transform: Optional callable applied to the RGB ``PIL.Image`` before it
            is returned.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # os.listdir() yields entries in arbitrary order and also lists
        # sub-directories. Sorting makes the index -> file mapping reproducible
        # across runs (so a seeded random subset picks the same files), and
        # keeping regular files only prevents __getitem__ from trying to open
        # a directory as an image.
        self.images = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, name))
        )

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load the ``idx``-th file as RGB and apply ``transform`` if one was given."""
        img_path = os.path.join(self.root_dir, self.images[idx])
        # convert() forces the pixel data to be read and returns a new image, so
        # the underlying file handle can be released as soon as the block exits.
        with Image.open(img_path) as handle:
            image = handle.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image


In [None]:
class ResidualBlock(nn.Module):
    """Channel-preserving residual unit.

    The inner branch is conv3x3 -> instance norm -> ReLU -> conv3x3 -> instance
    norm; its output is added to the unmodified input, so the block keeps both
    the channel count and the spatial size of ``x``.

    Args:
        in_channels: Number of channels of the input (and of the output).
    """

    def __init__(self, in_channels):
        super().__init__()

        def conv3x3():
            # stride 1 + padding 1 with a 3x3 kernel keeps height and width unchanged
            return nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

        # Layers are created in the same order as they are applied, so the
        # Sequential indices (and therefore the state_dict keys) stay 0..4.
        branch = [
            conv3x3(),
            nn.InstanceNorm2d(in_channels),
            nn.ReLU(inplace=True),
            conv3x3(),
            nn.InstanceNorm2d(in_channels),
        ]
        self.block = nn.Sequential(*branch)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.block(x)
        return x + residual

In [None]:
class Generator(nn.Module):
    """Image-to-image generator: stem conv, 2x downsampling, residual bottleneck,
    2x upsampling and a tanh output layer.

    The network maps a 3-channel image to a 3-channel image. The two stride-2
    convolutions divide the spatial size by 4 and the two transposed
    convolutions multiply it by 4 again, so for inputs whose height and width
    are multiples of 4 the output has the same spatial size as the input.
    The final ``Tanh`` bounds the output to [-1, 1].

    Args:
        num_residual_blocks: Number of ``ResidualBlock(256)`` modules in the
            bottleneck. Defaults to 6, the value that used to be hard-coded.
            A checkpoint can only be loaded into a generator built with the
            same number of blocks, because the state_dict keys depend on it.

    Raises:
        ValueError: If ``num_residual_blocks`` is negative.
    """

    def __init__(self, num_residual_blocks=6):
        super(Generator, self).__init__()

        if num_residual_blocks < 0:
            raise ValueError(
                f"num_residual_blocks must be >= 0, got {num_residual_blocks}"
            )

        # Initial convolution block: 7x7 conv with padding 3 keeps the spatial size
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Downsampling blocks: each stride-2 conv halves height and width
        self.down_sampling = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.InstanceNorm2d(256),
            nn.ReLU(inplace=True)
        )

        # Residual blocks: operate at 256 channels on the downsampled feature map
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(256) for _ in range(num_residual_blocks)]
        )

        # Upsampling blocks: output_padding=1 makes each transposed conv exactly
        # double the spatial size, mirroring the stride-2 convs above
        self.up_sampling = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True),

            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # Output layer: back to 3 channels, squashed to [-1, 1] by tanh
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=7, stride=1, padding=3),
            nn.Tanh()
        )

    def forward(self, x):
        """Run ``x`` through stem, downsampling, bottleneck, upsampling and output."""
        x = self.conv1(x)
        x = self.down_sampling(x)
        x = self.residual_blocks(x)
        x = self.up_sampling(x)
        x = self.conv2(x)
        return x


In [None]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator that emits a single-channel score map.

    Five 4x4 convolutions take a 3-channel image to one output channel. The
    first three use stride 2 and the last two use stride 1. There is no final
    activation, so the outputs are raw scores, one per spatial location.
    """

    def __init__(self):
        super().__init__()

        def leaky():
            return nn.LeakyReLU(0.2, inplace=True)

        # First stage has no normalisation layer.
        stages = [nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1), leaky()]

        # Three conv -> instance norm -> LeakyReLU stages; only the stride changes.
        for c_in, c_out, stride in ((64, 128, 2), (128, 256, 2), (256, 512, 1)):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1))
            stages.append(nn.InstanceNorm2d(c_out))
            stages.append(leaky())

        # Projection to the 1-channel score map.
        stages.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return the score map for the image batch ``x``."""
        scores = self.model(x)
        return scores

In [None]:
def show_images(images, title='', max_images=16, nrow=4):
    """Display a batch of images as one grid figure.

    Args:
        images: Batch of images accepted by ``torchvision.utils.make_grid``;
            only the first ``max_images`` entries are shown.
        title: Text placed above the grid.
        max_images: Maximum number of images taken from the front of the batch
            (default 16, the previously hard-coded value).
        nrow: Number of images per grid row (default 4, as before).
    """
    # Build the grid without autograd tracking: batches that come straight out
    # of a generator still carry gradient history, which .numpy() rejects.
    with torch.no_grad():
        grid = utils.make_grid(images[:max_images], nrow=nrow, padding=2, normalize=True)
    # make_grid returns channels-first data; imshow wants channels last.
    grid = grid.cpu().permute(1, 2, 0).numpy()

    fig = plt.figure(figsize=(10, 10))
    plt.title(title)
    plt.imshow(grid)
    plt.axis('off')
    plt.show()
    # This function is called repeatedly during training; close the figure so
    # figures do not pile up in memory on backends that keep them open.
    plt.close(fig)

def save_checkpoint(state, filename="checkpoint.pth.tar"):
    """Serialize ``state`` with ``torch.save`` and report where it was written.

    Args:
        state: Object to serialize (in this notebook, a dict of model and
            optimizer state_dicts plus the epoch number).
        filename: Destination path. Missing parent directories are created.
            Anything that is not a path (e.g. a file-like object) is handed to
            ``torch.save`` unchanged.
    """
    # Create the target directory up front so saving into a fresh run folder
    # does not fail after an expensive training epoch.
    if isinstance(filename, (str, os.PathLike)):
        parent = os.path.dirname(filename)
        if parent:
            os.makedirs(parent, exist_ok=True)

    torch.save(state, filename)
    print(f"Checkpoint saved: {filename}")

In [None]:
class Config:
    """Bundle of hyper-parameters and runtime settings used by training and inference.

    Attributes are plain instance fields set in ``__init__``:
    ``batch_size``, ``lr``, ``beta1``, ``beta2``, ``num_epochs``,
    ``lambda_cycle``, ``lambda_identity``, ``image_size`` and ``device``.
    """

    def __init__(self):
        # Samples per mini-batch
        self.batch_size = 4

        # Adam settings: learning rate and the two beta coefficients
        self.lr = 2e-4
        self.beta1, self.beta2 = 0.5, 0.999

        # Number of passes of the outer training loop
        self.num_epochs = 200

        # Weights applied to the cycle-consistency and identity loss terms
        self.lambda_cycle = 8.0
        self.lambda_identity = 3.0

        # Side length (in pixels) that images are resized / cropped to
        self.image_size = 128

        # Run on the GPU when one is visible to PyTorch, otherwise on the CPU
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

In [None]:
def load_data(anime_dir, real_dir, config):
    """Build shuffled data loaders for the two image folders.

    Both folders share one preprocessing pipeline: resize, random square crop
    of ``config.image_size``, conversion to a tensor and per-channel
    normalisation with mean 0.5 and std 0.5. Only a random 40% of the anime
    folder is kept; the real folder is used in full.

    Args:
        anime_dir: Directory with the anime images.
        real_dir: Directory with the real images.
        config: Object providing ``image_size`` and ``batch_size``.

    Returns:
        Tuple ``(anime_loader, real_loader)``.
    """
    side = config.image_size
    pipeline = transforms.Compose([
        transforms.Resize(side),
        transforms.RandomCrop(side),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # Same dataset class and same pipeline for both domains
    anime_all = AnimeDataset(anime_dir, transform=pipeline)
    real_dataset = AnimeDataset(real_dir, transform=pipeline)

    # Keep a random 40% of the anime images
    total_anime = len(anime_all)
    keep_count = int(total_anime * 0.4)
    kept_indices = random.sample(range(total_anime), keep_count)
    anime_subset = torch.utils.data.Subset(anime_all, kept_indices)

    loader_kwargs = {'batch_size': config.batch_size, 'shuffle': True}
    anime_loader = DataLoader(anime_subset, **loader_kwargs)
    real_loader = DataLoader(real_dataset, **loader_kwargs)

    return anime_loader, real_loader

In [None]:
def _set_requires_grad(networks, flag):
    """Switch gradient tracking on or off for every parameter of ``networks``."""
    for net in networks:
        for param in net.parameters():
            param.requires_grad_(flag)


def train_style_transfer(anime_dir, real_dir):
    """Train two generators and two discriminators on unpaired real/anime images.

    Each iteration first updates both generators with an adversarial (MSE)
    loss, a cycle-consistency (L1) loss and an identity (L1) loss, then updates
    each discriminator on real images versus detached generator outputs.
    Progress is printed and a sample grid is shown every 10 batches; a
    checkpoint with all model and optimizer states is saved every 10 epochs as
    ``checkpoint_epoch_<N>.pth.tar``.

    Args:
        anime_dir: Directory with the anime-domain images.
        real_dir: Directory with the real-domain images.
    """
    config = Config()

    # Load data
    anime_loader, real_loader = load_data(anime_dir, real_dir, config)

    # Initialize networks
    G_real2anime = Generator().to(config.device)
    G_anime2real = Generator().to(config.device)
    D_anime = Discriminator().to(config.device)
    D_real = Discriminator().to(config.device)
    discriminators = (D_anime, D_real)

    # Loss functions
    criterion_GAN = nn.MSELoss()
    criterion_cycle = nn.L1Loss()
    criterion_identity = nn.L1Loss()

    # Optimizers: one shared by both generators, one per discriminator
    optimizer_G = optim.Adam(
        itertools.chain(G_real2anime.parameters(), G_anime2real.parameters()),
        lr=config.lr, betas=(config.beta1, config.beta2)
    )
    optimizer_D_A = optim.Adam(D_anime.parameters(), lr=config.lr, betas=(config.beta1, config.beta2))
    optimizer_D_B = optim.Adam(D_real.parameters(), lr=config.lr, betas=(config.beta1, config.beta2))

    # Upper bound on the number of batches processed per epoch (keeps epochs short)
    max_batches_per_epoch = 100

    # Training loop
    for epoch in range(config.num_epochs):
        # islice stops after the cap without fetching (and discarding) one extra
        # batch, which the previous "enumerate + break" form did every epoch.
        paired_batches = itertools.islice(zip(real_loader, anime_loader), max_batches_per_epoch)
        for i, (real_imgs, anime_imgs) in enumerate(paired_batches):
            real_imgs = real_imgs.to(config.device)
            anime_imgs = anime_imgs.to(config.device)

            # Generate fake images
            fake_anime = G_real2anime(real_imgs)
            fake_real = G_anime2real(anime_imgs)

            # ---------------- Train Generators ----------------
            # The discriminators only provide a training signal here; freezing
            # them avoids computing weight gradients that would be thrown away.
            _set_requires_grad(discriminators, False)
            optimizer_G.zero_grad()

            # Identity loss: a generator fed an image of its target domain
            # should return it unchanged
            loss_id_anime = criterion_identity(G_real2anime(anime_imgs), anime_imgs)
            loss_id_real = criterion_identity(G_anime2real(real_imgs), real_imgs)

            # GAN loss: each discriminator is evaluated once and the same
            # prediction is reused to shape the all-ones target
            pred_fake_anime = D_anime(fake_anime)
            pred_fake_real = D_real(fake_real)
            loss_GAN_real2anime = criterion_GAN(pred_fake_anime, torch.ones_like(pred_fake_anime))
            loss_GAN_anime2real = criterion_GAN(pred_fake_real, torch.ones_like(pred_fake_real))

            # Cycle loss: translating there and back should recover the input
            recovered_real = G_anime2real(fake_anime)
            recovered_anime = G_real2anime(fake_real)
            loss_cycle_real = criterion_cycle(recovered_real, real_imgs)
            loss_cycle_anime = criterion_cycle(recovered_anime, anime_imgs)

            # Total generator loss
            loss_G = (loss_GAN_real2anime + loss_GAN_anime2real +
                      (loss_cycle_real + loss_cycle_anime) * config.lambda_cycle +
                      (loss_id_anime + loss_id_real) * config.lambda_identity)

            loss_G.backward()
            optimizer_G.step()

            # ---------------- Train Discriminators ----------------
            _set_requires_grad(discriminators, True)

            # Discriminator A (Anime): real anime -> 1, generated anime -> 0.
            # detach() keeps this update from reaching the generator.
            optimizer_D_A.zero_grad()
            pred_real_A = D_anime(anime_imgs)
            pred_fake_A = D_anime(fake_anime.detach())
            loss_real_A = criterion_GAN(pred_real_A, torch.ones_like(pred_real_A))
            loss_fake_A = criterion_GAN(pred_fake_A, torch.zeros_like(pred_fake_A))
            loss_D_A = (loss_real_A + loss_fake_A) * 0.5
            loss_D_A.backward()
            optimizer_D_A.step()

            # Discriminator B (Real): real photos -> 1, generated photos -> 0
            optimizer_D_B.zero_grad()
            pred_real_B = D_real(real_imgs)
            pred_fake_B = D_real(fake_real.detach())
            loss_real_B = criterion_GAN(pred_real_B, torch.ones_like(pred_real_B))
            loss_fake_B = criterion_GAN(pred_fake_B, torch.zeros_like(pred_fake_B))
            loss_D_B = (loss_real_B + loss_fake_B) * 0.5
            loss_D_B.backward()
            optimizer_D_B.step()

            if i % 10 == 0:  # Report and display every 10 batches
                print(f"Epoch [{epoch}/{config.num_epochs}] Batch [{i}/{max_batches_per_epoch}] "
                      f"Loss_D: {(loss_D_A + loss_D_B).item():.4f} "
                      f"Loss_G: {loss_G.item():.4f} "
                      f"Loss_cycle: {(loss_cycle_real + loss_cycle_anime).item():.4f}")

                # Display current results: inputs, translations, reconstructions
                show_images(torch.cat([real_imgs, fake_anime, recovered_real], 0),
                          f'Epoch {epoch} - Real → Anime → Recovered Real')

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            checkpoint = {
                'epoch': epoch,
                'G_real2anime': G_real2anime.state_dict(),
                'G_anime2real': G_anime2real.state_dict(),
                'D_anime': D_anime.state_dict(),
                'D_real': D_real.state_dict(),
                'optimizer_G': optimizer_G.state_dict(),
                'optimizer_D_A': optimizer_D_A.state_dict(),
                'optimizer_D_B': optimizer_D_B.state_dict()
            }
            save_checkpoint(checkpoint, f'checkpoint_epoch_{epoch+1}.pth.tar')


In [None]:
def convert_to_anime(image_path, generator_path):
    """Translate one image file with the trained real-to-anime generator.

    Args:
        image_path: Path of the input image.
        generator_path: Path of a checkpoint written by the training loop; its
            ``'G_real2anime'`` entry is loaded into a fresh ``Generator``.

    Returns:
        A ``PIL.Image`` (RGB, uint8) with the generated image.
    """
    config = Config()

    # Load and preprocess the image (same normalisation as during training)
    transform = transforms.Compose([
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # convert() reads the pixel data, so the file can be closed right away
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    image = transform(image).unsqueeze(0).to(config.device)

    # Load the generator. map_location remaps the stored tensors onto the
    # current device; without it a checkpoint saved on a GPU cannot be loaded
    # on a CPU-only machine.
    generator = Generator().to(config.device)
    checkpoint = torch.load(generator_path, map_location=config.device)
    generator.load_state_dict(checkpoint['G_real2anime'])
    generator.eval()

    # Convert image
    with torch.no_grad():
        anime_image = generator(image)

    # Post-process: undo the (x - 0.5) / 0.5 normalisation and clip to [0, 1]
    anime_image = (anime_image * 0.5 + 0.5).clamp(0, 1)
    # Drop only the batch dimension; a bare squeeze() would also remove a
    # height or width of 1 and break the transpose below.
    anime_image = anime_image.squeeze(0).cpu().numpy()
    anime_image = np.transpose(anime_image, (1, 2, 0))
    anime_image = (anime_image * 255).astype(np.uint8)

    return Image.fromarray(anime_image)

In [None]:
# Training folders of the selfie2anime Kaggle input:
# trainB is used as the anime domain, trainA as the real domain.
anime_dir = '/kaggle/input/selfie2anime/trainB'
real_dir = '/kaggle/input/selfie2anime/trainA'

# Run the full training loop; checkpoints are written to the working directory.
train_style_transfer(anime_dir=anime_dir, real_dir=real_dir)



In [None]:

# Translate one test selfie with the generator stored in the epoch-30 checkpoint
# and write the result next to the notebook.
converted_image = convert_to_anime(
    image_path='/kaggle/input/selfie2anime/testA/female_10328.jpg',
    generator_path='/kaggle/working/checkpoint_epoch_30.pth.tar',
)
converted_image.save('anime_style_output.jpg')
