<a href="https://colab.research.google.com/github/Navodit-Sahai/GANIME-GAN-based-anime-face-generator/blob/main/GANIME.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install/upgrade `opendatasets`, which is used in the next cell to download the Kaggle dataset.
!pip install opendatasets --upgrade --quiet

In [None]:
import opendatasets as od
# Kaggle page of the anime-face dataset to fetch.
# NOTE(review): opendatasets may prompt for Kaggle credentials (username/API key)
# if none are configured — confirm for your environment. The data presumably ends
# up in ./animefacedataset, which is the folder the dataset cell reads from.
dataset_url="https://www.kaggle.com/datasets/splcher/animefacedataset"
od.download(dataset_url)

In [None]:
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
import torchvision.transforms as T

In [None]:
# Side length (pixels) of the square training images, and mini-batch size.
image_size = 64
batch_size = 128
# (mean, std) per RGB channel for T.Normalize: maps [0, 1] pixels to [-1, 1],
# the same range as the generator's final Tanh.
stats = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

In [None]:
# Preprocessing: resize the shorter side to image_size, center-crop a square,
# convert to a [0, 1] tensor, then normalize with `stats` to [-1, 1] so real
# images share the generator's Tanh output range.
transform = T.Compose([
    T.Resize(image_size),
    T.CenterCrop(image_size),
    T.ToTensor(),
    T.Normalize(*stats),
])

# od.download() (default data_dir='.') writes into the current working
# directory, so resolve the dataset relative to it instead of hard-coding
# Colab's /content prefix; on Colab (cwd=/content) this is the same folder.
# ImageFolder expects class sub-folders; the labels it produces are unused.
data_dir = "./animefacedataset"
train_ds = ImageFolder(data_dir, transform=transform)

# Shuffled mini-batches; pin_memory speeds up host->GPU copies in fit().
train_dl = DataLoader(train_ds, batch_size, shuffle=True, num_workers=3, pin_memory=True)

Helper functions to denormalize image tensors (undoing `T.Normalize`) and display them as image grids

In [None]:
import torch
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
def denorm(image_tensors):
    """Invert ``T.Normalize(*stats)``, mapping normalized values back to [0, 1].

    Normalization computed ``(x - mean) / std``; this returns ``x * std + mean``.
    Only the first channel's mean/std are used, which is valid because all
    channels in ``stats`` share the same values.
    """
    std = stats[1][0]
    mean = stats[0][0]
    return image_tensors * std + mean

In [None]:
def show_images(images, nmax=64):
    """Display up to ``nmax`` normalized images as a grid with 8 images per row.

    Args:
        images: batch of images normalized with ``stats`` (e.g. a DataLoader
            batch or generator output); may be on any device and may carry
            autograd history.
        nmax: maximum number of images, taken from the front of the batch.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([]); ax.set_yticks([])
    # detach() drops autograd history; cpu() is required because the grid is
    # handed to matplotlib, which cannot read CUDA tensors (e.g. generator
    # output once fit() has moved the models to the GPU).
    batch = denorm(images.detach()[:nmax].cpu())
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    ax.imshow(make_grid(batch, nrow=8).permute(1, 2, 0))


def show_batch(dl, nmax=64):
    """Show the images of the first ``(images, labels)`` batch yielded by ``dl``.

    Labels are ignored. Nothing is shown if ``dl`` yields no batches.
    """
    exhausted = object()
    first_batch = next(iter(dl), exhausted)
    if first_batch is exhausted:
        return
    images, _ = first_batch
    show_images(images, nmax)


In [None]:
# Preview one batch of real training images (denormalized for display).
show_batch(train_dl)

In [None]:
import torch.nn as nn


In [None]:
# DCGAN-style discriminator: maps a (N, 3, 64, 64) batch to (N, 1) raw logits.
# Each strided 4x4 conv halves the spatial size (64 -> 32 -> 16 -> 8 -> 4), and
# the final unpadded 4x4 conv collapses 4x4 -> 1x1. There is no Sigmoid here:
# the training code uses binary_cross_entropy_with_logits / torch.sigmoid.
def _discriminator_stage(in_channels, out_channels):
    """Return [strided 4x4 conv, BatchNorm, LeakyReLU(0.2)] for one downsampling stage."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2, inplace=True),
    ]


discriminator = nn.Sequential(
    *_discriminator_stage(3, 64),
    *_discriminator_stage(64, 128),
    *_discriminator_stage(128, 256),
    *_discriminator_stage(256, 512),
    # 4x4 feature map -> single logit per image.
    nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
    nn.Flatten(),
)

Generator

In [None]:
# Channel count of the generator's (N, latent_size, 1, 1) noise input.
latent_size = 128

In [None]:
# DCGAN-style generator: maps (N, latent_size, 1, 1) noise to (N, 3, 64, 64)
# images in [-1, 1] (Tanh), the same range real images are normalized to.
# Spatial size: 1 -> 4 (unpadded 4x4 transposed conv) -> 8 -> 16 -> 32 -> 64.
generator = nn.Sequential(
    # Input channels follow `latent_size` rather than a hard-coded 128, so the
    # latent dimension is defined in exactly one place.
    nn.ConvTranspose2d(latent_size, 512, kernel_size=4, stride=1, padding=0, bias=False),
    nn.BatchNorm2d(512),
    nn.ReLU(True),

    nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(256),
    nn.ReLU(True),

    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(True),

    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
    nn.BatchNorm2d(64),
    nn.ReLU(True),

    # No BatchNorm on the output layer; Tanh bounds pixels to [-1, 1].
    nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
    nn.Tanh(),
)

In [None]:
# Sanity check before training: run the untrained generator on a batch of random
# latent vectors, print the output shape (expected: torch.Size([128, 3, 64, 64])
# with the settings above) and display the result.
xb=torch.randn(batch_size,latent_size,1,1)#random latent tensors
fake_images=generator(xb)
print(fake_images.shape)
show_images(fake_images)

In [None]:
def train_discriminator(real_images, opt_d, device):
    """Run one discriminator optimization step on a real batch plus a fake batch.

    Args:
        real_images: batch of normalized real images, already on ``device``.
        opt_d: optimizer over ``discriminator.parameters()``.
        device: device on which latent vectors and targets are placed.

    Returns:
        ``(loss, real_score, fake_score)``: the summed real+fake BCE loss as a
        float, and the mean sigmoid output (probability of "real") on the real
        and on the fake batch.
    """
    # Only the discriminator is updated here; freezing the generator avoids
    # computing gradients for its parameters.
    for param in discriminator.parameters():
        param.requires_grad = True
    for param in generator.parameters():
        param.requires_grad = False

    opt_d.zero_grad()

    # Generate exactly as many fakes as there are real images, so the last
    # (smaller) batch of an epoch stays balanced between the two classes.
    num_images = real_images.size(0)

    # Real images, with one-sided label smoothing (target 0.9 instead of 1.0).
    real_pred = discriminator(real_images)
    real_targets = torch.full((num_images, 1), 0.9, device=device)
    real_loss = F.binary_cross_entropy_with_logits(real_pred, real_targets)
    real_score = torch.sigmoid(real_pred).mean().item()

    # Fake images; detach() keeps this loss from back-propagating into the generator.
    # Latents are sampled on CPU then moved, keeping the same RNG stream as before.
    fake_latent = torch.randn(num_images, latent_size, 1, 1).to(device)
    fake_images = generator(fake_latent).detach()
    fake_pred = discriminator(fake_images)
    fake_targets = torch.zeros(num_images, 1, device=device)
    fake_loss = F.binary_cross_entropy_with_logits(fake_pred, fake_targets)
    fake_score = torch.sigmoid(fake_pred).mean().item()

    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item(), real_score, fake_score

In [None]:
def train_generator(opt_g, device):
    """Run one generator optimization step and return its loss as a float.

    A fresh batch of ``batch_size`` fakes is scored by the frozen discriminator
    against all-ones ("real") targets, so the generator is pushed to make the
    discriminator classify its output as real (non-saturating GAN loss).
    """
    for p in generator.parameters():
        p.requires_grad = True
    # Discriminator weights are frozen; gradients still flow through it to the
    # generator via the fake images.
    for p in discriminator.parameters():
        p.requires_grad = False

    opt_g.zero_grad()
    z = torch.randn(batch_size, latent_size, 1, 1).to(device)
    logits = discriminator(generator(z))
    all_real = torch.ones(batch_size, 1, device=device)
    loss = F.binary_cross_entropy_with_logits(logits, all_real)
    loss.backward()
    opt_g.step()
    return loss.item()

In [None]:
import os
from torchvision.utils import save_image
# Output folder for the sample grids written by save_samples() (once per epoch in fit()).
sample_dir='generated'
os.makedirs(sample_dir,exist_ok=True)

In [None]:
def save_samples(index, latent_tensors, show=True):
    """Generate images from ``latent_tensors``, save them as a PNG grid, optionally show it.

    The file is written to ``sample_dir`` as ``generated-images-NNNN.png``
    (``index`` zero-padded to 4 digits), with 8 images per row.

    Args:
        index: number embedded in the file name (fit() passes the epoch number).
        latent_tensors: latent batch; expected on the generator's device.
        show: if True, also display the grid with matplotlib.
    """
    # Sampling only: no_grad avoids building an autograd graph for this pass.
    # NOTE: the generator is never switched to eval() in this notebook, so
    # BatchNorm still uses (and updates) batch statistics here.
    with torch.no_grad():
        fake_images = generator(latent_tensors)
    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
    save_image(denorm(fake_images), os.path.join(sample_dir, fake_fname), nrow=8)
    print('Saving', fake_fname)

    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        # Denormalize for display as well: raw Tanh output lies in [-1, 1] and
        # matplotlib clips float images to [0, 1], which would flatten dark pixels.
        ax.imshow(make_grid(denorm(fake_images.cpu()), nrow=8).permute(1, 2, 0))


In [None]:
# Fixed batch of 64 latent vectors, reused by fit() for every epoch's saved
# samples so progress is comparable across epochs (fit() moves it to the device).
fixed_latent=torch.randn(64,latent_size,1,1)

In [None]:
from tqdm.notebook import tqdm
import torch.nn.functional as F

In [None]:
def fit(epochs, lr, start_index=1):
    """Train the GAN with Adam and save one sample grid per epoch.

    Args:
        epochs: number of passes over ``train_dl``.
        lr: learning rate for both Adam optimizers (betas=(0.5, 0.999)).
        start_index: added to the 0-based epoch number to form the sample file
            index (e.g. to continue numbering across several calls).

    Returns:
        ``(losses_g, losses_d, real_scores, fake_scores)``: one value per epoch,
        each the mean over that epoch's batches.

    Raises:
        ValueError: if ``train_dl`` yields no batches.
    """
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    generator.to(device)
    discriminator.to(device)

    # fixed_latent is the module-level tensor reused for every epoch's samples;
    # move it to the training device once by rebinding the global.
    global fixed_latent
    fixed_latent = fixed_latent.to(device)

    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
        # Running sums so the recorded history is an epoch average rather than
        # the values of the last (often smaller, remainder) batch only.
        total_g = total_d = total_real = total_fake = 0.0
        num_batches = 0

        for real_images, _ in tqdm(train_dl):

            # Move data to the training device (GPU when available).
            real_images = real_images.to(device)

            # One discriminator step, then one generator step, per batch.
            loss_d, real_score, fake_score = train_discriminator(real_images, opt_d, device)
            loss_g = train_generator(opt_g, device)

            total_g += loss_g
            total_d += loss_d
            total_real += real_score
            total_fake += fake_score
            num_batches += 1

        if num_batches == 0:
            raise ValueError("train_dl yielded no batches; check the dataset path")

        loss_g = total_g / num_batches
        loss_d = total_d / num_batches
        real_score = total_real / num_batches
        fake_score = total_fake / num_batches

        # Store per-epoch scalar values
        losses_g.append(loss_g)
        losses_d.append(loss_d)
        real_scores.append(real_score)
        fake_scores.append(fake_score)

        print(f"Epoch [{epoch+1}/{epochs}] | "
              f"Loss_G: {loss_g:.4f} | "
              f"Loss_D: {loss_d:.4f} | "
              f"Real Score: {real_score:.4f} | "
              f"Fake Score: {fake_score:.4f}")

        # Save samples once per epoch
        save_samples(epoch + start_index, fixed_latent, show=False)

    return losses_g, losses_d, real_scores, fake_scores

In [None]:
# Adam learning rate shared by both networks, and the number of training epochs.
lr = 2e-4
epochs = 15

In [None]:
# history = (losses_g, losses_d, real_scores, fake_scores), one entry per epoch.
history = fit(epochs=epochs, lr=lr)

In [None]:
# Persist only the learned weights (state_dicts) to the working directory; reload
# them into identically constructed networks with load_state_dict().
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')
print("Models saved as generator.pth and discriminator.pth")