In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np

In [5]:
# Custom Dataset
class PixelArtDataset(Dataset):
    """Paired image-to-image dataset built from two folders of PNG files.

    Inputs and targets are paired by position after sorting each folder's
    ``.png`` file names (extension matched case-insensitively). The two folders
    should therefore use matching file names (e.g. ``0001.png`` in both).

    Every image is loaded as 3-channel RGB, so palette ("P"), grayscale ("L")
    or RGBA PNGs all yield the same channel count (any alpha channel is
    discarded by the conversion).

    Args:
        input_folder: Directory holding the conditioning (input) PNGs.
        target_folder: Directory holding the target PNGs.
        transform: Optional callable applied separately to the input and the
            target ``PIL.Image`` (e.g. a torchvision ``Compose``).

    Raises:
        ValueError: If the folders contain different numbers of PNG files,
            since index-based pairing would then be misaligned.
    """

    def __init__(self, input_folder, target_folder, transform=None):
        self.input_folder = input_folder
        self.target_folder = target_folder
        self.transform = transform
        self.input_files = self._list_pngs(input_folder)
        self.target_files = self._list_pngs(target_folder)
        if len(self.input_files) != len(self.target_files):
            raise ValueError(
                f"Found {len(self.input_files)} PNGs in {input_folder!r} but "
                f"{len(self.target_files)} in {target_folder!r}; input and "
                "target images must be paired one-to-one.")

    @staticmethod
    def _list_pngs(folder):
        """Return the sorted names of ``.png`` files (any case) in ``folder``."""
        return sorted(f for f in os.listdir(folder) if f.lower().endswith('.png'))

    @staticmethod
    def _load_rgb(folder, name):
        """Load ``folder/name`` fully as an RGB image and close the file handle."""
        # convert() returns a new, fully loaded image, so the file can be
        # closed when the with-block exits.
        with Image.open(os.path.join(folder, name)) as img:
            return img.convert('RGB')

    def __len__(self):
        return len(self.input_files)

    def __getitem__(self, idx):
        input_img = self._load_rgb(self.input_folder, self.input_files[idx])
        target_img = self._load_rgb(self.target_folder, self.target_files[idx])

        if self.transform is not None:
            input_img = self.transform(input_img)
            target_img = self.transform(target_img)

        return input_img, target_img

In [6]:
# Generator
class Generator(nn.Module):
    """Encoder-decoder generator mapping a 3-channel image to a 3-channel image.

    Three stride-2 convolutions downsample by 8x. Three stride-2 transposed
    convolutions upsample back. There are no skip connections, and the output
    passes through Tanh, so values lie in [-1, 1].

    With kernel 4 / stride 2 / padding 1, each encoder stage maps a side of
    length s to floor(s / 2) and each decoder stage doubles it. A side that is
    not a multiple of 8 would therefore come back smaller (52 -> 48). To keep
    output size == input size, ``forward`` replicate-pads the bottom/right
    edges up to the next multiple of 8 and crops the result back. Inputs whose
    sides are already multiples of 8 are processed without any padding.

    Args:
        channels: Width of the first encoder layer; deeper layers use 2x and 4x.
    """

    # Total spatial downsampling of the encoder (three stride-2 stages).
    _SCALE = 8

    def __init__(self, channels=64):
        super().__init__()

        # Encoder (sizes shown for a 52x52 input, padded to 56x56 in forward)
        self.encoder = nn.Sequential(
            # 56x56 -> 28x28
            nn.Conv2d(3, channels, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),

            # 28x28 -> 14x14
            nn.Conv2d(channels, channels * 2, 4, stride=2, padding=1),
            nn.BatchNorm2d(channels * 2),
            nn.LeakyReLU(0.2),

            # 14x14 -> 7x7
            nn.Conv2d(channels * 2, channels * 4, 4, stride=2, padding=1),
            nn.BatchNorm2d(channels * 4),
            nn.LeakyReLU(0.2),
        )

        # Decoder
        self.decoder = nn.Sequential(
            # 7x7 -> 14x14
            nn.ConvTranspose2d(channels * 4, channels * 2, 4, stride=2, padding=1),
            nn.BatchNorm2d(channels * 2),
            nn.ReLU(),

            # 14x14 -> 28x28
            nn.ConvTranspose2d(channels * 2, channels, 4, stride=2, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(),

            # 28x28 -> 56x56 (cropped back to 52x52 in forward)
            nn.ConvTranspose2d(channels, 3, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Translate ``x`` of shape (N, 3, H, W) into an image of the same shape."""
        height, width = x.shape[-2:]
        pad_h = (-height) % self._SCALE
        pad_w = (-width) % self._SCALE
        if pad_h or pad_w:
            # Pad tuple order for the last two dims is (left, right, top, bottom).
            x = nn.functional.pad(x, (0, pad_w, 0, pad_h), mode='replicate')
        out = self.decoder(self.encoder(x))
        return out[..., :height, :width]

In [7]:
# Discriminator
class Discriminator(nn.Module):
    """Conditional discriminator that scores (image, condition) pairs.

    The image and its condition (3 channels each) are concatenated along the
    channel axis. The result passes through three stride-2 convolutions (8x
    downsampling) and a final 7x7 valid convolution followed by Sigmoid, which
    yields a probability per output position.

    The final 7x7 convolution needs a feature map of at least 7x7, i.e. inputs
    of at least 56 px per side. Smaller inputs (e.g. 52x52, which would give a
    6x6 map) are replicate-padded on the bottom/right up to 56 px. Inputs of
    56 px or more are scored unchanged. Any input of up to 63 px per side
    therefore produces an output of shape (N, 1, 1, 1).

    Args:
        channels: Width of the first conv layer; deeper layers use 2x and 4x.
    """

    # Minimum input side: 8x downsampling feeding a 7x7 valid convolution.
    _MIN_SIDE = 7 * 8

    def __init__(self, channels=64):
        super().__init__()

        self.model = nn.Sequential(
            # 56x56 -> 28x28
            nn.Conv2d(6, channels, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2),

            # 28x28 -> 14x14
            nn.Conv2d(channels, channels * 2, 4, stride=2, padding=1),
            nn.BatchNorm2d(channels * 2),
            nn.LeakyReLU(0.2),

            # 14x14 -> 7x7
            nn.Conv2d(channels * 2, channels * 4, 4, stride=2, padding=1),
            nn.BatchNorm2d(channels * 4),
            nn.LeakyReLU(0.2),

            # 7x7 -> 1x1
            nn.Conv2d(channels * 4, 1, 7, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x, condition):
        """Return the probability that ``x`` is a real target for ``condition``.

        Args:
            x: Candidate image, shape (N, 3, H, W).
            condition: Conditioning image with the same N, H, W as ``x``.
        """
        combined = torch.cat([x, condition], dim=1)
        height, width = combined.shape[-2:]
        pad_h = max(0, self._MIN_SIDE - height)
        pad_w = max(0, self._MIN_SIDE - width)
        if pad_h or pad_w:
            # Pad tuple order for the last two dims is (left, right, top, bottom).
            combined = nn.functional.pad(combined, (0, pad_w, 0, pad_h), mode='replicate')
        return self.model(combined)

In [8]:
# Training function
def _cgan_train_step(generator, discriminator, g_optimizer, d_optimizer,
                     criterion, l1_loss, input_imgs, target_imgs, lambda_l1):
    """Run one discriminator update, then one generator update, on a batch.

    Returns:
        (d_loss, g_loss): scalar loss tensors that have already been
        back-propagated and stepped.

    Raises:
        ValueError: If the generator's output shape differs from the target's.
    """
    # A single generator forward serves both updates. The D step does not
    # modify G's parameters, so the graph behind fake_imgs stays valid for G.
    fake_imgs = generator(input_imgs)
    if fake_imgs.shape != target_imgs.shape:
        raise ValueError(
            f"Generator output shape {tuple(fake_imgs.shape)} does not match "
            f"target shape {tuple(target_imgs.shape)}; the generator must "
            "preserve the input's spatial size.")

    # Discriminator: real (target, input) pairs -> 1, generated pairs -> 0.
    d_optimizer.zero_grad()
    d_real = discriminator(target_imgs, input_imgs)
    d_fake = discriminator(fake_imgs.detach(), input_imgs)  # detach: no G grads here
    # Labels are built from D's own outputs, so shape and device always match.
    d_real_loss = criterion(d_real, torch.ones_like(d_real))
    d_fake_loss = criterion(d_fake, torch.zeros_like(d_fake))
    d_loss = (d_real_loss + d_fake_loss) * 0.5
    d_loss.backward()
    d_optimizer.step()

    # Generator: make D score fakes as real, plus a weighted L1 term to the target.
    g_optimizer.zero_grad()
    g_fake = discriminator(fake_imgs, input_imgs)
    g_gan_loss = criterion(g_fake, torch.ones_like(g_fake))
    g_l1_loss = l1_loss(fake_imgs, target_imgs) * lambda_l1
    g_loss = g_gan_loss + g_l1_loss
    g_loss.backward()
    g_optimizer.step()

    return d_loss, g_loss


def _save_checkpoint(path, epoch, generator, discriminator, g_optimizer, d_optimizer):
    """Save both models' and optimizers' state dicts plus the 1-based ``epoch`` to ``path``."""
    torch.save({
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'g_optimizer_state_dict': g_optimizer.state_dict(),
        'd_optimizer_state_dict': d_optimizer.state_dict(),
    }, path)


def train_cgan(input_folder, target_folder, num_epochs=200, batch_size=4, lr=0.0002, beta1=0.5,
               lambda_l1=100.0, checkpoint_every=10, checkpoint_dir='.', seed=None):
    """Train the conditional GAN (Generator + Discriminator) on paired PNG images.

    Args:
        input_folder: Folder of conditioning PNGs (see ``PixelArtDataset``).
        target_folder: Folder of matching target PNGs.
        num_epochs: Number of passes over the dataset.
        batch_size: Images per batch. The incomplete final batch is dropped.
        lr: Adam learning rate for both networks.
        beta1: Adam beta1 for both networks (beta2 is fixed at 0.999).
        lambda_l1: Weight of the L1 reconstruction term in the generator loss.
        checkpoint_every: Save a checkpoint every this many epochs; 0 or None
            disables checkpointing.
        checkpoint_dir: Directory for ``checkpoint_epoch_{N}.pth`` files
            (created if missing).
        seed: Optional seed for ``torch.manual_seed`` (weight init and shuffling).

    Returns:
        (generator, discriminator): The trained models, still in training mode.

    Raises:
        ValueError: If the dataset has fewer pairs than ``batch_size``, in which
            case ``drop_last=True`` would yield no batches at all.
    """
    if seed is not None:
        torch.manual_seed(seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ToTensor gives [0, 1]; Normalize maps that to [-1, 1], the range of the
    # generator's Tanh output. The dataset loads every image as 3-channel RGB.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset = PixelArtDataset(input_folder, target_folder, transform=transform)
    # drop_last discards the final incomplete batch so every batch has batch_size items.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError(
            f"Dataset has {len(dataset)} image pairs, fewer than batch_size={batch_size}; "
            "with drop_last=True no training batches would be produced.")

    generator = Generator().to(device)
    discriminator = Discriminator().to(device)
    generator.train()
    discriminator.train()

    criterion = nn.BCELoss()
    l1_loss = nn.L1Loss()

    g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

    if checkpoint_every:
        os.makedirs(checkpoint_dir, exist_ok=True)

    for epoch in range(num_epochs):
        for i, (input_imgs, target_imgs) in enumerate(dataloader):
            input_imgs = input_imgs.to(device)
            target_imgs = target_imgs.to(device)

            d_loss, g_loss = _cgan_train_step(
                generator, discriminator, g_optimizer, d_optimizer,
                criterion, l1_loss, input_imgs, target_imgs, lambda_l1)

            if i % 10 == 0:
                print(f'Epoch [{epoch}/{num_epochs}], Step [{i}/{num_batches}], '
                      f'D_loss: {d_loss.item():.4f}, G_loss: {g_loss.item():.4f}')

        if checkpoint_every and (epoch + 1) % checkpoint_every == 0:
            _save_checkpoint(
                os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch+1}.pth'),
                epoch + 1, generator, discriminator, g_optimizer, d_optimizer)

    return generator, discriminator