In [None]:
# imports
import torch 
import torch.nn as nn

# Overview

From AnimeGAN paper: "The loss function... used in AnimeGAN can be simply expressed as follows:  $L(G, D) = \omega_{adv}L_{adv}(G,D) + \omega_{con}L_{con}(G,D) + \omega_{gra}L_{gra}(G, D) + \omega_{col}L_{col}(G, D)$" 

### Adversarial Loss $L_{adv}(G,D)$

"In order to enable AnimeGAN to generate the higher quality images and
make the training of the entire network more stable, the least squares loss function in LSGAN is employed as the adversarial loss $L_{adv}(G, D)$"

In other words...

In [None]:
# LSGAN-style adversarial term: mean squared error between the
# discriminator's prediction and the target label tensor.
adversarial_loss = nn.MSELoss(reduction="mean")

# use: pred = discriminator(generator(input))
#      adversarial_loss(pred, torch.ones_like(pred))

In [None]:
# Sanity check: random "discriminator outputs" scored against all-real labels.
# The target is created with exactly the prediction's shape; a [4, 1, 1, 1]
# target would only line up through broadcasting, which makes nn.MSELoss emit
# a size-mismatch UserWarning.
noise_outputs = torch.randn(4, 1, 64, 64)
labels = torch.ones_like(noise_outputs)
loss = adversarial_loss(noise_outputs, labels)

  return F.mse_loss(input, target, reduction=self.reduction)


### Content Loss $L_{con}(G,D)$

"$L_{con}(G,D)$ is the content loss which helps to make the generated image retain the content of the input photo"

"For the content loss $L_{con}(G,D)$... the pretrained VGG19 is used as the perceptual network to extract high-level semantic features of the images. $L_{con}(G,D)$... can be expressed as:

$L_{con}(G,D) = E_{p_{i}\sim S_{data}(x)}[||VGG_l(p_i) - VGG_l(G(p_i))||_1]$

where $VGG_l(x)$ refers to the feature map of the $l$-th layer in VGG... In our method, the $l$-th layer is conv4-4 in VGG"



In [None]:
import torchvision.models as models 

noise_inputs = torch.randn((4,3,256,256))

# Pretrained VGG19, used only as a fixed perceptual feature extractor.
# NOTE: .eval() just switches layers such as dropout / batch-norm to inference
# behaviour -- it does NOT turn off gradient tracking.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; kept as-is for compatibility with the installed version.
vgg19 = models.vgg19(pretrained=True).eval()
# vgg19(noise_inputs) returns the [4, 1000] classification logits;
# vgg19.features is the convolutional stack we slice below.

# Keep features[0..25], i.e. everything up to and including conv4_4
# (index 25, before its ReLU), following the layer layout shown in:
# https://www.researchgate.net/figure/llustration-of-the-network-architecture-of-VGG-19-model-conv-means-convolution-FC-means_fig2_325137356
VGG = nn.Sequential(*list(vgg19.features._modules.values())[:26]).eval()

# Freeze the extractor: its weights must never be updated by an optimizer.
# Gradients still flow *through* these layers back to whatever produced the
# input (e.g. the generator), because autograd tracks the input tensor, not
# just the parameters.
for param in VGG.parameters():
  param.requires_grad_(False)

out = VGG(noise_inputs) # [4 x 512 x 32 x 32] feature map
assert out.shape == (4, 512, 32, 32), out.shape

# https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html
content_loss = nn.L1Loss()

# use: takes in real photo p and generated photo g(p)
#      content_loss(VGG(p), VGG(g(p)))

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


HBox(children=(FloatProgress(value=0.0, max=574673361.0), HTML(value='')))




In [None]:
class ContentLoss(nn.Module):
  """ Perceptual content loss: mean absolute difference between the feature
      maps the wrapped network produces for two image batches.
  """

  def __init__(self, VGG):
    """ @param VGG: feature extractor applied to both inputs """
    super().__init__()
    self.VGG = VGG
    self.L1Loss = nn.L1Loss()

  def forward(self, generated, photo):
    """ @param generated: batch of generated images
        @param photo: batch of source photos
        @returns: L1 loss between the extracted features of both batches
    """
    generated_features = self.VGG(generated)
    photo_features = self.VGG(photo)
    return self.L1Loss(generated_features, photo_features)

    

### Grayscale Style loss $L_{gra}(G,D)$

"$L_{gra}(G,D)$ is the grayscale style loss which makes the generated images have the clear anime style on the texture and lines"

In [None]:
from dataloader import *

class GrayscaleStyleLoss(nn.Module):
  """ Grayscale style loss: compares Gram matrices of the features extracted
      from generated images and from grayscale anime images, pushing the
      generated images towards anime-style texture and lines.
  """

  def __init__(self, VGG):
    """ @param VGG: feature extractor applied to both image batches """
    super().__init__()
    self.VGG = VGG
    self.L1Loss = nn.L1Loss()

  @staticmethod
  def gram_matrix(A):
    """ @param A: feature maps of shape [N x C x H x W]
        @returns: per-sample Gram matrices of shape [N, C, C]; entry (i, j)
                  is the dot product of flattened channels i and j
    """
    batch, channels, height, width = A.shape
    flat = A.reshape(batch, channels, height * width)
    return torch.bmm(flat, flat.transpose(1, 2))

  def forward(self, generated, anime_gray):
    """ @param generated: images produced by the generator, G(photo),
                          of shape [N x C x H x W]
        @param anime_gray: grayscale anime images, of shape
                           [N x C x H x W]
        @returns: L1 loss between the two sets of Gram matrices, divided by
                  the total number of elements in `generated`
    """
    generated_features = self.VGG(generated)
    anime_features = self.VGG(anime_gray)
    gram_difference = self.L1Loss(
        GrayscaleStyleLoss.gram_matrix(generated_features),
        GrayscaleStyleLoss.gram_matrix(anime_features),
    )
    # NOTE(review): numel() counts the batch dimension too, so the scale of
    # this loss changes with the batch size.
    return gram_difference / generated.numel()

In [None]:

ANIME_PATH = '/content/drive/MyDrive/dataset/Shinkai/style/'
# SMOOTH_PATH = '/content/drive/MyDrive/dataset/Shinkai/smooth/'
PHOTOS_PATH = '/content/drive/MyDrive/dataset/train_photo/'

# Both loaders have to exist before their iterators are built: `phiter`
# below reads `photo_dataloader`, so it must be created in this cell for the
# notebook to run on a fresh kernel.
# NOTE(review): getPhotoDataloader / getAnimeDataloader are expected to come
# from the `dataloader` star-import -- confirm the names and signatures there.
photo_dataloader = getPhotoDataloader(PHOTOS_PATH)
anime_dataloader = getAnimeDataloader(ANIME_PATH, grayscale=True)
aniter = iter(anime_dataloader)
phiter = iter(photo_dataloader)


def unstandardizeImage(images):
    """ Undo per-channel standardization and rescale to the 0-255 range.

        Computes (images * std + mean) * 255 channel by channel, using the
        mean / std constants below (the usual ImageNet channel statistics).

        @param images: batch of standardized RGB images, of shape
                       [N x 3 x H x W]
        @returns: tensor of the same shape, on the same device as `images`
    """
    _mean=[0.485, 0.456, 0.406]; _std=[0.229, 0.224, 0.225]
    # Build the statistics on the input's device so the function works on
    # CPU and GPU alike; fall back to float32 for non-floating inputs so the
    # fractional constants are not truncated.
    dtype = images.dtype if images.is_floating_point() else torch.float32
    mean = torch.tensor(_mean, dtype=dtype, device=images.device).view(1, 3, 1, 1)
    std = torch.tensor(_std, dtype=dtype, device=images.device).view(1, 3, 1, 1)
    # Inverse of (x - mean) / std, then scale [0, 1] -> [0, 255].
    return (images * std + mean) * 255.0

# Build the loss module once: constructing it and moving VGG to the GPU does
# not depend on the loop iteration.
gsl = GrayscaleStyleLoss(VGG.cuda())

# Pure evaluation -- no gradients are needed, so skip autograd bookkeeping.
with torch.no_grad():
  for i in range(10):
    photo_batch1 = next(phiter).cuda()
    photo_batch2 = next(phiter).cuda()

    anime_batch1 = next(aniter).cuda()
    anime_batch2 = next(aniter).cuda()

    # Printed per iteration: anime-vs-photo, then anime-vs-anime, then
    # photo-vs-photo, followed by a blank line.
    diff = gsl(anime_batch1, photo_batch1)
    print(diff.item())
    same = gsl(anime_batch1, anime_batch2)
    print(same.item())
    same = gsl(photo_batch1, photo_batch2)
    print(same.item())
    print()



# out = gsl(noise_inputs, noise_inputs)
# print(out) should be zero because they're the same inputs

### Color Reconstruction Loss $L_{col}(G,D)$

"$L_{col}(G,D)$ is used as the color reconstruction loss to make the generated images have the color of the original photos."

In [None]:
class ColorReconLoss(nn.Module):
  """ Color reconstruction loss: compares generated images with the source
      photos in YCbCr space, using L1 on the luma channel (Y) and Huber
      (smooth L1) on the two chroma channels (Cb, Cr).
  """

  def __init__(self):
    super(ColorReconLoss, self).__init__()
    self.L1Loss = nn.L1Loss()
    self.HuberLoss = nn.SmoothL1Loss()

  @staticmethod
  def rgb_to_ycbcr(input):
    """ Convert a batch of RGB images to YCbCr (ITU-R BT.601).

        @param input: RGB images of shape [N x 3 x H x W]; the formula is
                      stated for values in [0, 1] -- TODO confirm callers
        @returns: YCbCr images of shape [N x 3 x H x W], channel order
                  (Y, Cb, Cr), on the same device as `input`
        formula is from: https://en.wikipedia.org/wiki/YCbCr
        original idea from:
        https://discuss.pytorch.org/t/how-to-change-a-batch-rgb-images-to-ycbcr-images-during-training/3799/2
    """
    red, green, blue = input[:, 0, :, :], input[:, 1, :, :], input[:, 2, :, :]
    # The green (and, for Cr, blue) chroma coefficients are negative.
    luma = red * 65.481 + green * 128.553 + blue * 24.966 + 16.
    cb = red * -37.797 + green * -74.203 + blue * 112. + 128.
    cr = red * 112. + green * -93.786 + blue * -18.214 + 128.
    # Stacking results derived from `input` keeps the output on its device
    # instead of a freshly allocated CPU buffer.
    return torch.stack((luma, cb, cr), dim=1)

  def forward(self, generated, real_photos):
    """ @param generated: batch of generated anime images of RGB format,
                          of shape [N x 3 x H x W]
        @param real_photos: batch of real-life photos used to generate generated
                            images, of shape [N x 3 x H x W]
        @returns: L1(Y) + Huber(Cb) + Huber(Cr) between the two batches
    """
    generated_ycbcr = ColorReconLoss.rgb_to_ycbcr(generated)
    real_photos_ycbcr = ColorReconLoss.rgb_to_ycbcr(real_photos)
    y_loss = self.L1Loss(generated_ycbcr[:,0,:,:], real_photos_ycbcr[:,0,:,:])
    cb_loss = self.HuberLoss(generated_ycbcr[:,1,:,:], real_photos_ycbcr[:,1,:,:])
    cr_loss = self.HuberLoss(generated_ycbcr[:,2,:,:], real_photos_ycbcr[:,2,:,:])
    return y_loss + cb_loss + cr_loss

# Sanity check: comparing a batch with itself must give a loss of exactly 0.
crl = ColorReconLoss()
out = crl(generated=noise_inputs, real_photos=noise_inputs)
print(out)

tensor(0.)


In [None]:
# I think this basically is a Binary Cross Entropy Loss with Real Anime 
# images as 1, smoothed anime images as 0, and generated images as 0. 
# Basically, in our dataset, we just have to include all three types 
# (smooth=0, original=1, real=0) of images in the dataset.

## Edge Promoting Adversarial Loss 

$L_{adv}(G,D) = E[\log D(c_i)] + E[\log(1 - D(e_j))] + 0.1 \cdot E[\log(1 - D(G(p_k)))]$