In [1]:
import kagglehub

# Download (or reuse from the local kagglehub cache) the latest version of the
# Kaggle anime images dataset; `path` is the local directory it was extracted to.
# NOTE(review): the data-loading cell below reads from a separate hard-coded
# `root_dir`, not from `path` — confirm the images were copied there.
path = kagglehub.dataset_download("diraizel/anime-images-dataset")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/diraizel/anime-images-dataset?dataset_version_number=1...


100%|███████████████████████████████████████████████████████████████████████████████| 868M/868M [06:37<00:00, 2.29MB/s]

Extracting files...





Path to dataset files: C:\Users\ARYAN PALIMKAR\.cache\kagglehub\datasets\diraizel\anime-images-dataset\versions\1


In [None]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Resize to a fixed 256x256 so every batch stacks, convert to a [0, 1] tensor,
# then map to [-1, 1] with mean/std 0.5 (matches the generator's tanh range).
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # One value per RGB channel; a single value would broadcast the same way,
    # but three makes the expected channel count explicit.
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# NOTE(review): absolute, machine-specific path — consider a configurable DATA_DIR.
root_dir = "C:/Users/ARYAN PALIMKAR/Desktop/Aryan/Jupyter/project/U-Net Upscaler/data"
IMAGE_SUFFIXES = ('.png', '.jpg')

# Filter while the dataset is being built (rather than patching `samples`
# afterwards) so `samples`, `imgs` and `targets` stay consistent. The check is
# case-insensitive so files such as "x.PNG" are not silently skipped.
dataset = datasets.ImageFolder(
    root=root_dir,
    transform=transform,
    is_valid_file=lambda p: p.lower().endswith(IMAGE_SUFFIXES),
)

dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)

# Pull one batch up front as a quick sanity check of the pipeline.
data_iter = iter(dataloader)
images, labels = next(data_iter)

In [None]:
import matplotlib.pyplot as plt
import torchvision.utils as vutils

def show_images(images, nmax=16, nrow=4):
    """Display up to `nmax` images from a batch as a single grid figure.

    Args:
        images: batch accepted by `torchvision.utils.make_grid` (presumably an
            (N, C, H, W) tensor); it may live on any device or require grad.
        nmax: maximum number of images drawn (default 16, as before).
        nrow: number of images per grid row.
    """
    # normalize=True rescales the grid to [0, 1] for display. detach().cpu()
    # is required because matplotlib cannot read CUDA tensors or tensors that
    # require grad (e.g. batches from DeviceDataLoader or generator outputs).
    grid = vutils.make_grid(images[:nmax], normalize=True, nrow=nrow).detach().cpu()
    plt.figure(figsize=(32, 18))
    plt.imshow(grid.permute(1, 2, 0))  # CHW -> HWC for imshow
    plt.axis("off")
    plt.show()

In [4]:
def show_batch(dl, nmax=64):
    """Display the first batch yielded by `dl` via `show_images`.

    `dl` is any iterable of (images, labels) pairs. Only the first `nmax`
    images of the batch are passed on (`show_images` may cap further).
    Nothing is shown when `dl` yields no batches.
    """
    for images, _ in dl:
        show_images(images[:nmax])
        break

In [6]:
def get_default_device():
    """Return the CUDA device when PyTorch can see a GPU, else the CPU device."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def to_device(data, device):
    """Recursively move `data` onto `device`.

    Lists and tuples are walked element by element and always come back as a
    list; any other object must provide `.to(device, non_blocking=...)`.
    """
    if not isinstance(data, (list, tuple)):
        return data.to(device, non_blocking=True)
    moved = []
    for item in data:
        moved.append(to_device(item, device))
    return moved

class DeviceDataLoader:
    """Iterable wrapper around a loader that moves each batch onto `device`.

    Batches are transferred lazily, one at a time, as iteration proceeds.
    """

    def __init__(self, dl, device):
        self.dl, self.device = dl, device

    def __iter__(self):
        yield from (to_device(batch, self.device) for batch in self.dl)

    def __len__(self):
        # Same number of batches as the wrapped loader.
        return len(self.dl)

In [7]:
# Pick the compute device once; later cells move models and batches to it.
device = get_default_device()
# Bare expression: displayed as the cell's output.
device

device(type='cuda')

In [8]:
# Wrap the loader so each batch is moved to `device` as it is yielded.
# Only a raw DataLoader is wrapped, so re-running this cell does not wrap the
# wrapper again (which would otherwise nest DeviceDataLoaders on every run).
if isinstance(dataloader, DataLoader):
    dataloader = DeviceDataLoader(dataloader, device)

# Generator using U-Net

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class UnetGenerator(nn.Module):
    """U-Net mapping an `in_channels` image to an `out_channels` image in [-1, 1].

    Three encoder stages (each followed by 2x max-pooling), a bottleneck, and
    three decoder stages that upsample with a stride-2 transposed conv and
    concatenate the matching encoder features (skip connections). The output
    passes through tanh. With the defaults it maps a 1-channel grayscale image
    to a 3-channel RGB image.

    Height and width must be at least 8 so the three poolings leave at least
    one pixel. Sizes not divisible by 8 are supported: decoder features are
    resized to their skip connection's size before concatenation.
    """

    def __init__(self, in_channels=1, out_channels=3):
        super(UnetGenerator, self).__init__()

        # Encoder
        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)

        # Bottleneck
        self.bottleneck = self.conv_block(256, 512)

        # Decoder: each dec block receives upsampled features concatenated
        # with the encoder skip, hence twice the upconv's output channels.
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = self.conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = self.conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = self.conv_block(128, 64)

        # Output layer: 1x1 projection to the target channels, squashed to [-1, 1].
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=1)
        self.tanh = nn.Tanh()

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 same-padding convs, each followed by ReLU, then BatchNorm."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(out_channels)
        )

    @staticmethod
    def _match_size(x, ref):
        """Resize `x` spatially to `ref`'s H x W when they differ (odd sizes)."""
        if x.shape[-2:] != ref.shape[-2:]:
            x = F.interpolate(x, size=ref.shape[-2:], mode='nearest')
        return x

    def _up_merge(self, x, skip, upconv, dec):
        """Upsample `x`, align it with `skip`, concatenate on channels, refine."""
        up = self._match_size(upconv(x), skip)
        return dec(torch.cat([up, skip], dim=1))

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(F.max_pool2d(e1, 2))
        e3 = self.enc3(F.max_pool2d(e2, 2))
        b = self.bottleneck(F.max_pool2d(e3, 2))
        d3 = self._up_merge(b, e3, self.upconv3, self.dec3)
        d2 = self._up_merge(d3, e2, self.upconv2, self.dec2)
        d1 = self._up_merge(d2, e1, self.upconv1, self.dec1)
        return self.tanh(self.out_conv(d1))

In [10]:
# Default configuration: 1-channel (grayscale) input, 3-channel (RGB) output.
generator = UnetGenerator()

# Discriminator

In [11]:
class Discriminator(nn.Module):
    """Convolutional critic that scores overlapping patches of an image.

    Four 4x4 convolutions (strides 2, 2, 2, 1), each followed by LeakyReLU(0.2),
    with InstanceNorm on all but the first, feed a final 4x4 conv with one
    output channel. The result is a spatial map of raw logits (no sigmoid),
    not a single scalar per image.

    Args:
        in_channels: channel count of the images being judged (3 for RGB).
    """

    def __init__(self, in_channels=3):
        super().__init__()

        # (input channels, output channels, stride, apply InstanceNorm?)
        stages = (
            (in_channels, 64, 2, False),
            (64, 128, 2, True),
            (128, 256, 2, True),
            (256, 512, 1, True),
        )
        layers = []
        for c_in, c_out, stride, use_norm in stages:
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1))
            if use_norm:
                layers.append(nn.InstanceNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        # Final projection to one logit per spatial location.
        layers.append(nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1))

        # One flat Sequential: checkpoint keys are model.<index>.weight/bias.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)

In [12]:
# Critic for 3-channel (RGB) images, the default `in_channels`.
discriminator = Discriminator()

In [13]:
class PerceptualLoss(nn.Module):
    """MSE between frozen, pretrained VGG19 feature maps of `pred` and `target`.

    Args:
        feature_layer: number of leading modules of `vgg19().features` to keep
            (default 36, which for torchvision's VGG19 stops just before the
            final max-pool).

    Inputs should be normalized the way the pretrained VGG weights expect
    (the training code here applies ImageNet mean/std first).
    """

    def __init__(self, feature_layer=36):
        super(PerceptualLoss, self).__init__()
        weights_enum = getattr(models, "VGG19_Weights", None)
        if weights_enum is not None:
            # torchvision >= 0.13: `pretrained=True` is deprecated in favour of
            # the weights enum; IMAGENET1K_V1 is what `pretrained=True` loaded.
            vgg = models.vgg19(weights=weights_enum.IMAGENET1K_V1).features
        else:
            # Older torchvision only understands the legacy flag.
            vgg = models.vgg19(pretrained=True).features
        self.vgg = nn.Sequential(*list(vgg.children())[:feature_layer]).eval()
        # Frozen feature extractor: gradients flow to `pred`, never into VGG.
        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, pred, target):
        pred_features = self.vgg(pred)
        target_features = self.vgg(target)
        return F.mse_loss(pred_features, target_features)

In [None]:
# Adversarial term: BCE applied to raw discriminator outputs. The discriminator
# ends in a plain Conv2d with no sigmoid, so the "WithLogits" variant fits.
adversarial_loss = nn.BCEWithLogitsLoss()
# Pixel-wise reconstruction term against the ground-truth RGB image.
l1_loss = nn.L1Loss()
# VGG19 feature-space MSE, moved to `device` alongside the images it compares.
# NOTE(review): constructing it may download pretrained VGG weights on first run.
perceptual_loss = PerceptualLoss().to(device)

# Training

In [15]:
# Generator and discriminator share identical Adam hyperparameters.
# NOTE(review): these optimizers are built before the models are moved to
# `device` in a later cell; the torch.optim docs recommend moving models first.
ADAM_KWARGS = dict(lr=0.0002, betas=(0.5, 0.999))
g_optimizer = torch.optim.Adam(generator.parameters(), **ADAM_KWARGS)
d_optimizer = torch.optim.Adam(discriminator.parameters(), **ADAM_KWARGS)

In [16]:
# NOTE(review): not used by the visible training/test code, which calls
# TF.rgb_to_grayscale directly — candidate for removal.
grayscale_transform = transforms.Grayscale(num_output_channels=1)
# ImageNet mean/std, applied to images just before the VGG19 perceptual loss.
# These statistics are defined for inputs in the [0, 1] range.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

In [17]:
# Move both networks to the selected device. The last expression's value (the
# discriminator module) is echoed as the cell output.
# NOTE(review): the optimizers were created in an earlier cell, before this
# move; consider moving the models before constructing the optimizers.
generator.to(device)
discriminator.to(device)

Discriminator(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), padding=(1, 1))
  )
)

In [None]:
from tqdm import tqdm
import torch
import numpy as np
import torch.nn as nn
import torchvision.transforms.functional as TF
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()  # one AMP loss scaler shared by both optimizers


def _train_step(images, l1_weight=100, perceptual_weight=0.1, perceptual_size=(128, 128)):
    """Run one discriminator update followed by one generator update.

    Args:
        images: batch of RGB images scaled to [-1, 1] by the data transform.
        l1_weight: weight of the pixel-wise L1 term in the generator loss.
        perceptual_weight: weight of the VGG perceptual term.
        perceptual_size: resolution both images are resized to before VGG.

    Returns:
        (g_loss, d_loss, grey, fake_images): the two losses as Python floats,
        plus the grayscale input and generator output for previewing.
    """
    images = images.to(device, non_blocking=True)
    grey = TF.rgb_to_grayscale(images, num_output_channels=1)

    # --- Train Discriminator ---
    d_optimizer.zero_grad()
    with autocast():
        real_output = discriminator(images)
        fake_images = generator(grey)
        # detach(): the discriminator update must not backprop into the generator.
        fake_output = discriminator(fake_images.detach())

        # ones_like / zeros_like already allocate on the logits' device.
        real_label = torch.ones_like(real_output)
        fake_label = torch.zeros_like(fake_output)

        d_real_loss = adversarial_loss(real_output, real_label)
        d_fake_loss = adversarial_loss(fake_output, fake_label)
        d_loss = (d_real_loss + d_fake_loss) / 2

    scaler.scale(d_loss).backward()
    scaler.step(d_optimizer)

    # --- Train Generator ---
    g_optimizer.zero_grad()
    with autocast():
        # Re-score the non-detached fakes with the just-updated discriminator;
        # the generator is rewarded when they are classified as real.
        fake_output = discriminator(fake_images)
        g_adv_loss = adversarial_loss(fake_output, real_label)
        g_l1_loss = l1_loss(fake_images, images)

        # Perceptual loss at reduced resolution. Images live in [-1, 1], but the
        # ImageNet `normalize` expects [0, 1] inputs, so rescale first.
        small_fake = nn.functional.interpolate(
            fake_images, size=perceptual_size, mode='bilinear', align_corners=False)
        small_real = nn.functional.interpolate(
            images, size=perceptual_size, mode='bilinear', align_corners=False)
        g_perceptual_loss = perceptual_loss(
            normalize((small_fake + 1) / 2), normalize((small_real + 1) / 2))

        g_loss = g_adv_loss + l1_weight * g_l1_loss + perceptual_weight * g_perceptual_loss

    # This backward also leaves gradients on the discriminator's parameters;
    # d_optimizer.zero_grad() clears them at the start of the next step.
    scaler.scale(g_loss).backward()
    scaler.step(g_optimizer)
    # A single update() per iteration, after both optimizers have stepped.
    scaler.update()

    return g_loss.item(), d_loss.item(), grey, fake_images


def _show_preview(grey, images, fake_images):
    """Plot grayscale input, ground truth and generator output for sample 0."""
    with torch.no_grad():
        plt.figure(figsize=(15, 5))

        plt.subplot(1, 3, 1)
        plt.title("Grayscale Input")
        plt.imshow(grey[0, 0].cpu().numpy(), cmap='gray')

        # Ground Truth RGB: map [-1, 1] back to [0, 1] for display.
        plt.subplot(1, 3, 2)
        plt.title("Ground Truth")
        img_rgb = (images[0].permute(1, 2, 0).cpu().numpy() + 1) / 2
        plt.imshow(np.clip(img_rgb, 0, 1))

        # Colorized Output; .float() because autocast may have produced float16.
        plt.subplot(1, 3, 3)
        plt.title("Colorized Output")
        fake_img = (fake_images[0].detach().float().permute(1, 2, 0).cpu().numpy() + 1) / 2
        plt.imshow(np.clip(fake_img, 0, 1))

        plt.show()


def train(epochs, preview_every=20):
    """Adversarially train `generator` against `discriminator` for `epochs` epochs.

    Uses the notebook-level models, optimizers, losses, `dataloader`, `device`
    and `scaler`. After each epoch prints the mean per-batch G and D losses.
    Whenever the 0-based epoch index is divisible by `preview_every` (so the
    first epoch is included), plots one sample from the epoch's last batch.
    Pass `preview_every=0` to disable previews.

    Raises:
        ValueError: if `dataloader` has no batches (the per-epoch means would
            divide by zero and there would be nothing to preview).
    """
    num_batches = len(dataloader)
    if epochs > 0 and num_batches == 0:
        raise ValueError("dataloader is empty; nothing to train on")

    generator.train()
    discriminator.train()

    for epoch in range(epochs):
        epoch_g_loss, epoch_d_loss = 0.0, 0.0

        for images, _ in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            g_loss, d_loss, grey, fake_images = _train_step(images)
            epoch_g_loss += g_loss
            epoch_d_loss += d_loss

        epoch_g_loss /= num_batches
        epoch_d_loss /= num_batches
        print(f"Epoch [{epoch+1}/{epochs}], G Loss: {epoch_g_loss:.4f}, D Loss: {epoch_d_loss:.4f}")

        if preview_every and epoch % preview_every == 0:
            _show_preview(grey, images, fake_images)


In [None]:
# Full training run; train() prints per-epoch losses and plots periodic previews.
train(epochs=100)

In [None]:
import numpy as np
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt
import torchvision.utils as vutils

def test_model(images, generator, max_images=None):
    """Colorize a batch and plot grayscale input / ground truth / output per image.

    The generator runs in eval mode without gradient tracking, so calling this
    does not update BatchNorm running statistics or build an autograd graph.
    Its previous train/eval mode is restored afterwards, even on error.

    Args:
        images: batch of RGB images, presumably scaled to [-1, 1] like the
            training data; moved to `device` here.
        generator: model mapping 1-channel grayscale to 3-channel output.
        max_images: plot at most this many images (default: the whole batch).
    """
    images = images.to(device)
    grey = TF.rgb_to_grayscale(images, num_output_channels=1)

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            fake_images = generator(grey)
    finally:
        generator.train(was_training)

    count = len(images) if max_images is None else min(max_images, len(images))
    for i in range(count):
        fig = plt.figure(figsize=(15, 5))

        # Grayscale Input
        plt.subplot(1, 3, 1)
        plt.title("Grayscale Input")
        plt.imshow(grey[i, 0].cpu().numpy(), cmap='gray')

        # Ground Truth RGB: map [-1, 1] back to [0, 1] for display.
        plt.subplot(1, 3, 2)
        plt.title("Ground Truth RGB")
        img_rgb = (images[i].permute(1, 2, 0).cpu().numpy() + 1) / 2
        plt.imshow(np.clip(img_rgb, 0, 1))

        # Colorized Output
        plt.subplot(1, 3, 3)
        plt.title("Colorized Output")
        fake_img = (fake_images[i].detach().permute(1, 2, 0).cpu().numpy() + 1) / 2
        plt.imshow(np.clip(fake_img, 0, 1))

        plt.show()
        plt.close(fig)  # avoid accumulating open figures in long loops