In [1]:
import zipfile
from google.colab import drive
drive.mount('/content/drive/')

MessageError: ignored

In [None]:
!unzip 

In [None]:
#!pip install fire
#!pip install tensorboardX
#!pip install prdc

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy

from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import MNIST, CIFAR10, CelebA, SVHN
from torchvision.datasets import ImageFolder
from torch.utils import data

import warnings
import math
import torch
import numpy as np
from scipy import linalg
import torch.nn.functional as F
from torchvision.models import inception_v3, vgg16
from prdc import compute_prdc

import os
import numpy as np
import torch

import cv2
from torch.optim import Adam
import torch.nn.functional as F
from torch.autograd import grad
from tensorboardX import SummaryWriter
from torchvision.utils import save_image
from tqdm import tqdm

import fire

In [None]:
%load_ext tensorboard

In [None]:
class NextDataLoader(data.DataLoader):
    """A ``DataLoader`` that can be advanced directly with ``next(loader)``.

    Batches come from an internal iterator. On the first call (no iterator yet)
    or when the current pass over the dataset is exhausted, a fresh pass is
    started, so ``next()`` keeps yielding batches for a non-empty dataset.
    Any other error raised while loading propagates to the caller.
    """

    def __next__(self):
        iterator = getattr(self, 'iterator', None)
        if iterator is not None:
            try:
                return next(iterator)
            except StopIteration:
                # End of the current pass -- fall through and start a new one.
                pass
        # First call, or the previous pass finished.
        self.iterator = self.__iter__()
        return next(self.iterator)

In [None]:
def cifardata(batch_size, num_workers):
    """CIFAR-10 (torchvision default split) as an endlessly cycling loader.

    Images are converted to tensors and normalised per channel with mean 0.5 /
    std 0.5, i.e. mapped from [0, 1] to [-1, 1]. Batches are not shuffled.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])
    cifar = CIFAR10('data/cifar_data', transform=pipeline, download=True)
    return NextDataLoader(cifar, batch_size=batch_size, num_workers=num_workers)


def svhndata(batch_size, num_workers):
    """SVHN (torchvision default split) as an endlessly cycling loader.

    Images are converted to tensors and normalised per channel with mean 0.5 /
    std 0.5, i.e. mapped from [0, 1] to [-1, 1]. Batches are not shuffled.
    """
    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])
    svhn = SVHN('data/svhn_data', transform=pipeline, download=True)
    return NextDataLoader(svhn, batch_size=batch_size, num_workers=num_workers)


def celebdata(batch_size, num_workers):
    """CelebA-HQ training images from ``data/celeba_hq/train`` as a shuffled loader.

    Images are resized with ``Resize(128)``, converted to tensors and mapped to
    [-1, 1] (per-channel mean 0.5 / std 0.5). A ``NextDataLoader`` is returned,
    like ``cifardata``/``svhndata``, so the loader can be consumed with
    ``next(loader)``. A plain ``DataLoader`` is not an iterator, and
    ``next()`` on it raises TypeError.
    """
    print('Loading data...')
    transform = transforms.Compose([
        transforms.Resize(128),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = ImageFolder('data/celeba_hq/train', transform=transform)
    dataloader = NextDataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    print('Data loaded.')
    return dataloader


In [None]:
class PixelNorm(nn.Module):
    """Rescale every feature vector (along dim 1) to unit root-mean-square.

    out = x / sqrt(mean(x ** 2, dim=1) + epsilon), with epsilon = 1e-8.
    """

    def __init__(self):
        super().__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        return x / (mean_square + self.epsilon).sqrt()


class InjectNoise(nn.Module):
    """Add noise scaled by a single learnable scalar to a feature map.

    ``weight`` starts at 0, so the layer is initially the identity.
    """

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(1))

    def forward(self, x, image_noise=None):
        """Return ``x + weight * noise``.

        Args:
            x: feature map, indexed as (N, C, H, W) here.
            image_noise: optional explicit noise. It must broadcast against
                ``x``. A tensor with one dimension fewer than ``x`` (e.g. the
                (N, H, W) slices GeneratorBlock passes) is treated as a single
                channel. When None, fresh Gaussian noise of shape
                (N, 1, H, W) is drawn.
        """
        if image_noise is None:
            # One noise channel per sample, shared across all feature channels.
            image_noise = torch.randn(x.shape[0], 1, x.shape[2], x.shape[3], device=x.device)
        elif image_noise.dim() == x.dim() - 1:
            image_noise = image_noise.unsqueeze(1)
        return x + self.weight * image_noise


class MiniBatchSD(nn.Module):
    """Minibatch standard-deviation layer.

    The batch is split into ``G`` interleaved groups. For each position within
    a group, the std over the G samples is averaged over C, H and W into one
    scalar. That scalar is broadcast to an extra (N, 1, H, W) channel and
    concatenated to ``x``.

    ``G`` is the largest value <= ``group_size`` that divides the batch size,
    so small or odd-sized batches no longer fail in ``reshape``.
    """

    def __init__(self, group_size=4):
        super(MiniBatchSD, self).__init__()
        self.group_size = group_size

    def forward(self, x):
        n, c, h, w = x.shape
        group = max(1, min(self.group_size, n))
        while n % group:
            group -= 1
        t = x.reshape(group, n // group, c, h, w)              # [G, N/G, C, H, W]
        t = t - t.mean(dim=0, keepdim=True)
        t = torch.sqrt((t ** 2).mean(dim=0) + 1e-8)            # [N/G, C, H, W]
        t = t.mean(dim=[1, 2, 3], keepdim=True)                # [N/G, 1, 1, 1]
        t = t.repeat(group, 1, 1, 1).expand(n, 1, h, w)
        return torch.cat((x, t), dim=1)


class ConvModulated(nn.Conv2d):
    """2-D convolution whose weight is modulated per sample by a style vector.

    An affine layer maps the latent ``v`` to one scale per input channel
    (offset by +1). The scale multiplies the weight along its input-channel
    axis. With ``demodulate=True`` every resulting output filter is divided by
    its L2 norm (plus 1e-8). All samples are convolved at once with a grouped
    convolution, where the batch is folded into the channel axis.

    Extra ``**kwargs`` are accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, kernel_size, latent_size,
                 demodulate=True, bias=True, stride=1, padding=0, dilation=1, **kwargs):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups=1,
                         bias=bias, padding_mode='zeros')
        self.demodulate = demodulate

        # Latent -> one modulation scale per input channel.
        self.style = nn.Linear(latent_size, in_channels)

        # How the (N, in_channels) style broadcasts against weight.unsqueeze(0),
        # and which axis of it indexes input channels (summed when demodulating).
        # A subclass with a transposed weight layout overrides both.
        self.s_broadcast_view = (-1, 1, self.in_channels, 1, 1)
        self.in_channels_dim = 2

    def convolve(self, x, w, groups):
        """Grouped convolution with explicit weights; overridden for transposed convs."""
        return F.conv2d(x, w, None, self.stride, self.padding, self.dilation, groups=groups)

    def forward(self, x, v):
        """Convolve ``x`` (N, C, H, W) with weights modulated by latents ``v`` (N, latent_size)."""
        batch, _, height, width = x.shape
        w = self.weight.unsqueeze(0)                  # (1, *weight.shape)
        s = self.style(v) + 1                         # (N, in_channels)
        w = s.view(self.s_broadcast_view) * w         # (N, *weight.shape)

        if self.demodulate:
            sigma = torch.sqrt((w ** 2).sum(dim=[self.in_channels_dim, 3, 4], keepdim=True) + 1e-8)
            w = w / sigma

        # Fold the batch into channels: one group per sample. reshape (rather
        # than view) also accepts non-contiguous inputs.
        x = x.reshape(1, -1, height, width)
        w = w.reshape(-1, w.shape[2], w.shape[3], w.shape[4])
        out = self.convolve(x, w, batch)
        out = out.reshape(batch, -1, out.shape[2], out.shape[3])

        if self.bias is not None:
            out = out + self.bias.view(1, -1, 1, 1)

        return out


class Up_Mod_Conv(ConvModulated):
    """Modulated transposed convolution that upsamples by ``factor``.

    It reuses ConvModulated's modulation/demodulation. The weight is stored in
    the (in, out, kh, kw) layout that ``F.conv_transpose2d`` expects, and the
    style broadcast/demodulation axes are adjusted to match. The kernel size
    must be odd.
    """

    def __init__(self, in_channels, out_channels, kernel_size, latent_size,
                 demodulate=True, bias=True, factor=2):
        assert (kernel_size % 2 == 1)
        pad = (max(kernel_size - factor, 0) + 1) // 2
        super().__init__(in_channels, out_channels, kernel_size, latent_size, demodulate, bias,
                         stride=factor, padding=pad)
        # Makes the output spatial size exactly ``factor`` times the input's.
        self.output_padding = torch.nn.modules.utils._pair(factor - kernel_size + 2 * pad)
        # Swap to the transposed-conv layout (in, out, kh, kw).
        self.weight = nn.Parameter(self.weight.transpose(0, 1).contiguous())
        self.transposed = True
        # Input channels are now the first weight axis (second after unsqueeze(0)).
        self.s_broadcast_view = (-1, self.in_channels, 1, 1, 1)
        self.in_channels_dim = 1

    def convolve(self, x, w, groups):
        return F.conv_transpose2d(x, w, bias=None, stride=self.stride, padding=self.padding,
                                  output_padding=self.output_padding, groups=groups,
                                  dilation=self.dilation)


class Down_Conv2d(nn.Conv2d):
    """Strided convolution with ``kernel_size // 2`` padding that downsamples by ``factor``.

    The kernel size must be odd. ``bias`` controls whether a bias is created.
    It was previously ignored and always True.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 bias=True, factor=2):
        assert (kernel_size % 2 == 1)
        padding = kernel_size // 2
        super().__init__(in_channels, out_channels, kernel_size, factor, padding, bias=bias)

    def convolve(self, x, w):
        """Convolve ``x`` with explicit weights ``w`` using this layer's geometry (no bias)."""
        return F.conv2d(x, w, None, self.stride, self.padding, self.dilation, self.groups)


class Equalized_Linear:
    """Equalized-learning-rate reparameterisation for a whole module tree.

    Calling an instance on a module returns a deep copy. In the copy, every
    submodule that registers a >=2-D tensor under ``name`` (as a parameter or
    a buffer) stores it instead as ``weight_orig = weight / scale``, with
    ``scale = 1 / sqrt(fan)``. A forward pre-hook recomputes
    ``name = scale * weight_orig`` before each forward. The fan is fan_in, or
    fan_out for modules flagged ``transposed``. Modules without such a tensor
    (activations, containers, 1-D norm/bias weights) are left untouched.
    """

    def __init__(self, name):
        self.name = name

    @staticmethod
    def compute_norm(module, weight):
        """Return the fan (fan_in, or fan_out for transposed modules) of ``weight``."""
        mode = 'fan_out' if getattr(module, 'transposed', False) else 'fan_in'
        return torch.nn.init._calculate_correct_fan(weight, mode)

    def scale_weight(self, module, input):
        """Forward pre-hook: materialise the scaled weight from ``weight_orig``."""
        setattr(module, self.name, module.scale * module.weight_orig)

    def fn(self, module):
        """Reparameterise ``module`` in place when it owns a suitable weight."""
        # Decide eligibility before mutating anything, so a module is never left
        # half-converted (e.g. scale set but no hook registered).
        if self.name in module._parameters:
            is_parameter = True
        elif self.name in module._buffers:
            is_parameter = False
        else:
            return
        weight = getattr(module, self.name)
        # Fan is undefined for missing or <2-D tensors (fan computation raises).
        if not isinstance(weight, torch.Tensor) or weight.dim() < 2:
            return
        module.scale = 1 / np.sqrt(Equalized_Linear.compute_norm(module, weight))
        if is_parameter:
            # Register the unscaled weight as the trainable parameter.
            module.weight_orig = nn.Parameter(weight.clone() / module.scale)
            del module._parameters[self.name]
        else:
            module.register_buffer('weight_orig', weight.clone() / module.scale)
            del module._buffers[self.name]
        module.equalize = module.register_forward_pre_hook(self.scale_weight)

    def __call__(self, module):
        new_module = deepcopy(module)
        new_module.apply(self.fn)
        return new_module


def parameters_to_buffers(m):
    """Re-register every parameter held directly by ``m`` as a buffer.

    Non-recursive; use with ``Module.apply`` to convert a whole tree. Each
    buffer shares storage with the former parameter's ``.data``. Parameters
    registered as ``None`` (e.g. an absent bias) become ``None`` buffers
    instead of raising.
    """
    params = m._parameters.copy()
    m._parameters.clear()
    for name, param in params.items():
        m.register_buffer(name, None if param is None else param.data)


In [None]:
class GeneratorBlock(nn.Module):
    """One synthesis stage of the generator.

    Layers: modulated upsampling conv -> noise -> LeakyReLU -> modulated conv
    -> noise -> LeakyReLU. The stage also keeps a skip-style image output. The
    previous stage's image (if any) is bilinearly upsampled, and this stage's
    1x1 non-demodulated projection of the features is added to it.
    """

    def __init__(self, in_channel, out_channel, kernel_size, latent_size, factor=2, img_channels=3):
        super().__init__()
        mid_channels = (in_channel + out_channel) // 2
        self.upconv = Up_Mod_Conv(in_channel, mid_channels, kernel_size, latent_size, factor=factor)
        self.conv = ConvModulated(mid_channels, out_channel, kernel_size, latent_size,
                                  padding=kernel_size // 2)
        self.noise = InjectNoise()
        self.noise2 = InjectNoise()
        self.to_channels = ConvModulated(out_channel, img_channels, kernel_size=1,
                                         latent_size=latent_size, demodulate=False)
        self.upsample = nn.Upsample(scale_factor=factor, mode='bilinear', align_corners=False)
        self.act = nn.LeakyReLU(0.2)

    def forward(self, x, w, y=None, input_noises=None):
        """Return ``(features, image)`` for features ``x``, latent ``w`` and previous image ``y``."""
        first_noise = None if input_noises is None else input_noises[:, 0]
        second_noise = None if input_noises is None else input_noises[:, 1]
        x = self.act(self.noise(self.upconv(x, w), first_noise))
        x = self.act(self.noise2(self.conv(x, w), second_noise))
        previous_image = 0 if y is None else self.upsample(y)
        return x, previous_image + self.to_channels(x, w)


class Generator(nn.Module):
    """Style-based generator: mapping network, synthesis blocks and an EMA copy.

    Args:
        in_channel: channels produced by the last synthesis block.
        out_channel: channels of the learned constant input (the first block's
            input); block widths are linearly spaced between the two.
        kernel_size: kernel size of the modulated convolutions.
        latent_size: dimensionality of z and of the mapped latents.
        style_depth: number of Linear + LeakyReLU layers in the mapping network.
        img_channels: channels of the generated image.
        min_res: spatial size of the learned constant.
        max_res: output resolution. The synthesis output of size
            min_res * 2**blocks is cropped down to it when larger.
        blocks: number of synthesis blocks (each upsamples by 2).
        style_mixing_prob: per-sample probability of style mixing in training.
        dlatent_avg_beta: decay of ``dlatent_avg``, the running mean of mapped
            latents used for truncation in ``sample_images``.
        weights_avg_beta: decay of the parameter EMA kept in ``Src_Net``.
    """

    def __init__(self, in_channel, out_channel, kernel_size, latent_size, style_depth, img_channels,
                 min_res, max_res, blocks, style_mixing_prob=0.8, dlatent_avg_beta=0.995, weights_avg_beta=0.99):
        super(Generator, self).__init__()

        dres = min_res * 2 ** blocks - max_res
        assert dres >= 0
        self.latent_size = latent_size
        layers = [PixelNorm()]
        for i in range(style_depth):
            layer = nn.Linear(latent_size, latent_size)
            # NOTE: the synthesis blocks below reuse the names '0', '1', ... and
            # replace any mapping layer registered under the same name. The
            # mapping layers stay registered through ``mapping_network``. Left
            # as-is because parameter names (checkpoints, EMA matching) depend on it.
            self.add_module(str(i), layer)
            layers.append(layer)
            layers.append(nn.LeakyReLU(0.2))
        self.mapping_network = nn.Sequential(*layers)
        self.const = nn.Parameter(torch.randn(out_channel, min_res, min_res))
        fmaps = np.linspace(out_channel, in_channel, blocks + 1).astype('int')
        self.layers = []
        for i in range(blocks):
            layer = GeneratorBlock(fmaps[i], fmaps[i + 1], kernel_size, latent_size, img_channels=img_channels)
            self.add_module(str(i), layer)
            self.layers.append(layer)
        if dres > 0:
            # Remove exactly ``dres`` pixels per spatial dim. An odd extra
            # pixel comes off the right/bottom. ZeroPad2d order: (left, right, top, bottom).
            lo, hi = dres // 2, dres - dres // 2
            self.crop = torch.nn.ZeroPad2d((-lo, -hi, -lo, -hi))
        self.style_mixing_prob = style_mixing_prob
        self.dlatent_avg_beta = dlatent_avg_beta
        self.register_buffer('dlatent_avg', torch.zeros(latent_size))
        self.weights_avg_beta = weights_avg_beta
        # Frozen copy whose parameters are buffers; holds the EMA of the weights.
        self.Src_Net = deepcopy(self).apply(parameters_to_buffers)
        self.Src_Net.train(False)

    def train(self, mode=True):
        """Set training mode on this network while keeping ``Src_Net`` in eval mode.

        ``nn.Module.train`` recurses into every child, which would put the
        averaged copy into training mode. ``sample_images`` would then
        style-mix and update the copy's ``dlatent_avg``.
        """
        super().train(mode)
        src = self._modules.get('Src_Net')
        if src is not None:
            src.train(False)
        return self

    def update_avg_weights(self):
        """Blend the current parameters into ``Src_Net``'s buffers (EMA, decay ``weights_avg_beta``).

        ``Src_Net`` buffers that correspond to plain buffers here (e.g.
        ``dlatent_avg``) are copied verbatim.
        """
        params = dict(self.named_parameters())
        buffers = dict(self.named_buffers())
        with torch.no_grad():
            for name, avg in self.Src_Net.named_buffers():
                if name in params:
                    avg.copy_(self.weights_avg_beta * avg + (1 - self.weights_avg_beta) * params[name])
                else:
                    avg.copy_(buffers[name])

    def load_avg_weights(self):
        """Overwrite this network's parameters with the averaged ones from ``Src_Net``."""
        buffers = dict(self.Src_Net.named_buffers())
        for n, p in self.named_parameters():
            p.data.copy_(buffers[n])

    def _update_dlatent_avg(self, v):
        """Move ``dlatent_avg`` toward the batch mean of mapped latents ``v``."""
        self.dlatent_avg = self.dlatent_avg_beta * self.dlatent_avg + (1 - self.dlatent_avg_beta) * v.data.mean(0)

    def sample_dlatents(self, n):
        """Sample ``n`` mapped latents on this module's device.

        Returns an (n, latent_size) tensor. In training mode with style mixing
        enabled (and at least two blocks), it returns
        (n, len(self.layers), latent_size) instead: blocks after a random
        cut-off use a second latent for a Bernoulli-selected subset of samples.
        In training mode ``dlatent_avg`` is also updated.
        """
        device = self.dlatent_avg.device
        z = torch.randn(n, self.latent_size, device=device)
        v = self.mapping_network(z)
        if self.training:
            self._update_dlatent_avg(v)
        num_layers = len(self.layers)
        # Mixing needs a crossover point, i.e. at least two blocks.
        if self.training and self.style_mixing_prob > 0 and num_layers > 1:
            v = v.unsqueeze(1).expand(-1, num_layers, -1)
            cut_off = torch.randint(num_layers - 1, ())
            v2 = self.mapping_network(torch.randn(n, self.latent_size, device=device))
            self._update_dlatent_avg(v2)
            v2 = v2.unsqueeze(1).expand(-1, num_layers, -1)
            mask = torch.empty(n, dtype=torch.bool).bernoulli_(self.style_mixing_prob).view(-1, 1) \
                * (torch.arange(num_layers) > cut_off)
            v = torch.where(mask.unsqueeze(-1).to(device=v.device), v2, v)
        return v

    def generate(self, v, input_noises=None):
        """Synthesise images from mapped latents ``v``.

        Args:
            v: (N, latent_size), broadcast to every block, or
                (N, len(self.layers), latent_size) with one latent per block.
            input_noises: optional sequence indexed by block, each entry passed
                to that block as its ``input_noises``. None or empty means
                fresh random noise everywhere.
        """
        x = self.const.expand(v.shape[0], *self.const.shape).contiguous()
        # Explicit checks: truth-testing a tensor with several elements raises.
        if input_noises is None or len(input_noises) == 0:
            input_noises = [None] * len(self.layers)
        if v.ndim < 3:
            v = v.unsqueeze(1).expand(-1, len(self.layers), -1)
        y = None
        for i, layer in enumerate(self.layers):
            x, y = layer(x, v[:, i], y, input_noises[i])
        if hasattr(self, 'crop'):
            y = self.crop(y)
        return y

    def sample(self, n):
        """Sample ``n`` images from the live (non-averaged) generator."""
        dlatents = self.sample_dlatents(n)
        x = self.generate(dlatents)
        return x

    def sample_images(self, n, truncation_psi=1):
        """Sample ``n`` images from the averaged copy ``Src_Net`` without gradients.

        With ``truncation_psi < 1`` the latents are pulled toward ``dlatent_avg``.
        """
        with torch.no_grad():
            v = self.Src_Net.sample_dlatents(n)
            if truncation_psi < 1:
                v = self.dlatent_avg + truncation_psi * (v - self.dlatent_avg)
            images = self.Src_Net.generate(v)
        return images


class DiscriminatorBlock(nn.Module):
    """Residual downsampling block for the discriminator.

    Main path: conv -> LeakyReLU -> strided conv -> LeakyReLU. Shortcut:
    strided 1x1 conv. The two are summed and scaled by 1 / sqrt(2).
    """

    def __init__(self, in_channel, out_channel, kernel_size, factor=2):
        super().__init__()
        mid_channels = (in_channel + out_channel) // 2
        self.conv = nn.Conv2d(in_channel, mid_channels, kernel_size, padding=kernel_size // 2)
        self.downconv = Down_Conv2d(mid_channels, out_channel, kernel_size, factor=factor)
        self.skip = Down_Conv2d(in_channel, out_channel, kernel_size=1, factor=factor)
        self.act = nn.LeakyReLU(0.2)

    def forward(self, x):
        residual = self.act(self.downconv(self.act(self.conv(x))))
        shortcut = self.skip(x)
        return (residual + shortcut) / np.sqrt(2)


class Discriminator(nn.Module):
    """Residual discriminator producing one score per image, shape (N, 1).

    Layers: 1x1 conv from image channels, ``blocks`` downsampling residual
    blocks, minibatch-stddev channel, unpadded 3x3 conv, then two dense layers.
    The first dense layer's size assumes the blocks reduce the input to
    ``min_res`` x ``min_res``, which the 3x3 conv shrinks to
    (min_res - 2) x (min_res - 2).
    """

    def __init__(self, in_channel, out_channel, kernel_size, blocks, img_channels, min_res, max_res, dense_size=128 ):
        super(Discriminator, self).__init__()
        assert min_res * 2 ** blocks >= max_res >= (min_res - 1) * 2 ** blocks
        fmaps = np.linspace(in_channel, out_channel, blocks + 1).astype('int')
        self.from_channels = nn.Conv2d(img_channels, fmaps[0], 1)
        self.layers = []
        for idx, (c_in, c_out) in enumerate(zip(fmaps[:-1], fmaps[1:])):
            block = DiscriminatorBlock(c_in, c_out, kernel_size)
            self.add_module(str(idx), block)
            self.layers.append(block)
        self.minibatch_sttdev = MiniBatchSD()
        last_fmaps = fmaps[-1]
        self.conv = nn.Conv2d(last_fmaps + 1, last_fmaps, 3)
        self.dense = nn.Linear(last_fmaps * (min_res - 2) ** 2, dense_size)
        self.output = nn.Linear(dense_size, 1)
        self.act = nn.LeakyReLU(0.2)

    def get_score(self, imgs):
        """Return raw (pre-sigmoid) scores of shape (N, 1) for ``imgs``."""
        h = self.act(self.from_channels(imgs))
        for block in self.layers:
            h = block(h)
        h = self.act(self.conv(self.minibatch_sttdev(h)))
        h = self.act(self.dense(h.view(h.shape[0], -1)))
        return self.output(h)

In [None]:
def device(ngpu):
    """Return ``cuda:0`` when CUDA is available and ``ngpu > 0``, otherwise the CPU device.

    NOTE: the next cell rebinds the module-level name ``device`` to a
    ``torch.device``, shadowing this helper once that cell runs.
    """
    use_cuda = torch.cuda.is_available() and ngpu > 0
    return torch.device("cuda:0" if use_cuda else "cpu")

In [None]:
ngpu = 0  # > 0 selects cuda:0 when CUDA is available; 0 forces CPU
# Module-level device (this rebinds, and so shadows, the ``device`` helper).
device = torch.device("cuda:0") if (ngpu > 0 and torch.cuda.is_available()) else torch.device("cpu")


class EvaluationMetric:
    """FID and precision/recall between a batch of real and generated images.

    Features are the spatially pooled Mixed_7c activations of torchvision's
    pretrained InceptionV3 (see ``build_maps``). Precision/recall come from
    ``prdc.compute_prdc`` with ``self.k`` nearest neighbours.

    Args:
        transform_input: if True, ``build_maps`` applies the channel-wise affine
            map that converts ImageNet-normalised images
            (mean 0.485/0.456/0.406, std 0.229/0.224/0.225) into [-1, 1].
            Pass False for images that are already in [-1, 1].
    """

    def __init__(self, transform_input=True):
        self.transform_input = transform_input
        self.InceptionV3 = inception_v3(pretrained=True, transform_input=False)
        self.InceptionV3.eval()
        self.k = 3  # nearest neighbours used for precision/recall

    def evaluate(self, real_img, generated_img):
        """Print FID, precision and recall for two equally sized image batches."""
        mu1, sigma1, mu2, sigma2, precision, recall = self.calc_activation_stats(real_img, generated_img)
        fid = self.compute_fid(mu1, sigma1, mu2, sigma2)

        print('FID:', fid)
        print('Precision:', precision)
        print('Recall:', recall)

    def build_maps(self, x):
        """Return pooled InceptionV3 features, shape (N, C, 1, 1), for images ``x`` (N, 3, H, W).

        Inputs are resized to 299x299 (bilinear) if needed and moved to the
        device of the Inception weights.
        """
        x = x.to(next(self.InceptionV3.parameters()).device)
        if list(x.shape[-2:]) != [299, 299]:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                x = F.interpolate(x, size=[299, 299], mode='bilinear')
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        with torch.no_grad():
            x = self.InceptionV3.Conv2d_1a_3x3(x)
            x = self.InceptionV3.Conv2d_2a_3x3(x)
            x = self.InceptionV3.Conv2d_2b_3x3(x)
            x = F.max_pool2d(x, kernel_size=3, stride=2)
            x = self.InceptionV3.Conv2d_3b_1x1(x)
            x = self.InceptionV3.Conv2d_4a_3x3(x)
            x = F.max_pool2d(x, kernel_size=3, stride=2)
            x = self.InceptionV3.Mixed_5b(x)
            x = self.InceptionV3.Mixed_5c(x)
            x = self.InceptionV3.Mixed_5d(x)
            x = self.InceptionV3.Mixed_6a(x)
            x = self.InceptionV3.Mixed_6b(x)
            x = self.InceptionV3.Mixed_6c(x)
            x = self.InceptionV3.Mixed_6d(x)
            x = self.InceptionV3.Mixed_6e(x)
            x = self.InceptionV3.Mixed_7a(x)
            x = self.InceptionV3.Mixed_7b(x)
            x = self.InceptionV3.Mixed_7c(x)
            x = F.adaptive_avg_pool2d(x, (1, 1))
            return x

    def compute_fid(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
        """Frechet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2).

        If the matrix square root of sigma1 @ sigma2 is not finite, ``eps`` is
        added to both diagonals and it is recomputed. Raises ValueError when
        the square root has a significant imaginary component.
        """
        mu1 = np.atleast_1d(mu1)
        mu2 = np.atleast_1d(mu2)

        sigma1 = np.atleast_2d(sigma1)
        sigma2 = np.atleast_2d(sigma2)

        assert mu1.shape == mu2.shape, \
            'Training and test mean vectors have different lengths'
        assert sigma1.shape == sigma2.shape, \
            'Training and test covariances have different dimensions'

        diff = mu1 - mu2

        covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
        if not np.isfinite(covmean).all():
            msg = ('fid calculation produces singular product; '
                   'adding %s to diagonal of cov estimates') % eps
            print(msg)
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))

        if np.iscomplexobj(covmean):
            if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
                m = np.max(np.abs(covmean.imag))
                raise ValueError('Imaginary component {}'.format(m))
            covmean = covmean.real

        tr_covmean = np.trace(covmean)

        return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean

    def calc_activation_stats(self, real_img, generated_img, batch_size=64):
        """Return (mu_gen, sigma_gen, mu_real, sigma_real, precision, recall).

        Both batches are shuffled, then passed through ``build_maps`` in
        chunks of ``batch_size``. Features are gathered on the CPU so they can
        be converted to NumPy regardless of the compute device.
        """
        assert real_img.shape[0] == generated_img.shape[0]
        real_images = real_img[np.random.permutation(real_img.shape[0])]
        generated_images = generated_img[np.random.permutation(generated_img.shape[0])]
        num_images = real_images.shape[0]
        real_maps = []
        generated_maps = []
        for start in range(0, num_images, batch_size):
            sidx = np.arange(start, min(start + batch_size, num_images))
            real_maps.append(self.build_maps(real_images[sidx]).detach().cpu())
            generated_maps.append(self.build_maps(generated_images[sidx]).detach().cpu())

        # (N, C, 1, 1) -> (N, C). Unlike np.squeeze, this keeps N when N == 1.
        real_maps = torch.cat(real_maps).reshape(num_images, -1).numpy()
        generated_maps = torch.cat(generated_maps).reshape(num_images, -1).numpy()

        mu1 = np.mean(generated_maps, axis=0)
        mu2 = np.mean(real_maps, axis=0)
        sigma1 = np.cov(generated_maps, rowvar=False)
        sigma2 = np.cov(real_maps, rowvar=False)
        prec_recall = compute_prdc(real_maps, generated_maps, self.k)
        return mu1, sigma1, mu2, sigma2, prec_recall['precision'], prec_recall['recall']

    

In [None]:
# calculate inception score for cifar-10 in Keras
from math import floor
from numpy import ones
from numpy import expand_dims
from numpy import log
from numpy import mean
from numpy import std
from numpy import exp
from numpy.random import shuffle
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input
from keras.datasets import cifar10
from skimage.transform import resize
from numpy import asarray

# scale an array of images to a new size
def scale_images(images, new_shape):
    """Resize every image to ``new_shape`` using nearest-neighbour interpolation
    (skimage ``resize`` with order 0) and stack the results into one array."""
    return asarray([resize(image, new_shape, 0) for image in images])

# assumes images have any shape and pixels in [0,255]
def calculate_inception_score(images, n_split=10, eps=1E-16):
    """Inception Score of ``images`` using Keras' InceptionV3.

    Args:
        images: array of images with pixel values in [0, 255], presumably
            (N, H, W, 3); every split is resized to 299x299x3 here.
        n_split: number of equal splits. It is clamped to at most N so no
            split is empty; leftover images beyond n_split * (N // n_split)
            are ignored.
        eps: added inside the logarithms to avoid log(0).

    Returns:
        (mean, std) of the per-split scores.

    Raises:
        ValueError: if ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError('calculate_inception_score needs at least one image')
    # Building InceptionV3 is expensive; create it once and reuse it across calls.
    model = getattr(calculate_inception_score, '_model', None)
    if model is None:
        model = InceptionV3()
        calculate_inception_score._model = model
    n_split = max(1, min(n_split, images.shape[0]))
    n_part = floor(images.shape[0] / n_split)
    scores = list()
    for i in range(n_split):
        subset = images[i * n_part:(i + 1) * n_part].astype('float32')
        subset = scale_images(subset, (299, 299, 3))
        # Keras InceptionV3 preprocessing (scales pixels to [-1, 1]).
        subset = preprocess_input(subset)
        p_yx = model.predict(subset)                      # p(y|x)
        p_y = expand_dims(p_yx.mean(axis=0), 0)           # marginal p(y)
        kl_d = p_yx * (log(p_yx + eps) - log(p_y + eps))  # per-class KL terms
        avg_kl_d = mean(kl_d.sum(axis=1))                 # mean KL over images
        scores.append(exp(avg_kl_d))
    return mean(scores), std(scores)
"""
# load cifar10 images
(images, _), (_, _) = cifar10.load_data()
# shuffle images
shuffle(images)
print('loaded', images.shape)
# calculate inception score

"""

In [None]:
class StyleGANTrainer:
    """Train the StyleGAN Generator/Discriminator pair on CIFAR-10.

    Uses the non-saturating logistic GAN loss, lazy R1 regularisation on the
    discriminator and lazy path-length regularisation on the generator. Scalars
    are logged to TensorBoard. Sample grids, FID/precision/recall and
    checkpoints are produced periodically under /content/drive/MyDrive.
    """

    def __init__(self):
        self.style_depth = 8
        self.in_channel = 256
        self.out_channel = 128
        self.latent_size = 160
        self.blocks = 3
        self.ngpu = 0
        self.epochs = 6  # number of training iterations (one G update each)
        self.learning_rate = 1e-3
        self.beta = (0.0, 0.99)
        self.batch_size = 64
        self.img_channel = 3
        self.min_res = 4
        # cifardata() does not resize, so real images are 32x32. Generate at the
        # same size (min_res * 2**blocks == 32, so no crop is applied).
        self.max_res = 32
        self.kernel_size = 5
        self.num_workers = 3
        self.batch_part = 0.5  # fraction of the batch used for path-length reg.
        self.r1_interval = 16
        self.pl_weight = 20
        self.D_steps = 1
        self.r1_weight = 8
        self.pl_interval = 4
        self.decay = 0.01  # EMA rate of the path-length target ``avg``
        self.avg = 0
        self.mean = 0
        self.std = 0.0
        self.log_interval = 1
        self.image_interval = 100
        self.checkpoint_interval = 500
        self.sample_count = 64
        # Real images are normalised to [-1, 1] by cifardata and the generator is
        # trained to match them. build_maps' transform_input step assumes
        # ImageNet-normalised input, so it must not be applied here.
        self.eval = EvaluationMetric(transform_input=False)
        self.pl_batch = int(self.batch_part * self.batch_size)
        self.train_dir = './train_generated'
        self.train_dataloader = cifardata(self.batch_size, self.num_workers)
        self.device = torch.device("cuda:0" if (torch.cuda.is_available() and self.ngpu > 0) else "cpu")
        generator = Generator(self.in_channel, self.out_channel, self.kernel_size, self.latent_size,
                              self.style_depth, self.img_channel, self.min_res, self.max_res, self.blocks)
        # Move to the training device before the optimizers capture the parameters.
        self.Generator = Equalized_Linear('weight')(generator).to(self.device)

        discriminator = Discriminator(self.in_channel, self.out_channel, self.kernel_size, self.blocks,
                                      self.img_channel, self.min_res, self.max_res)
        self.Discriminator = Equalized_Linear('weight')(discriminator).to(self.device)
        self.optimizer_g = Adam(self.Generator.parameters(), lr=self.learning_rate, betas=self.beta)
        self.optimizer_d = Adam(self.Discriminator.parameters(), lr=self.learning_rate, betas=self.beta)
        self.writer = SummaryWriter('/content/drive/MyDrive/logs/StyleGAN/CIFAR10/logs')

    @staticmethod
    def _to_inception_array(images):
        """Convert (N, C, H, W) images, expected in [-1, 1] (clamped), to an
        (N, H, W, C) float32 NumPy array in [0, 255] for the Keras IS code."""
        scaled = (images.detach().clamp(-1, 1) + 1) * 127.5
        return scaled.permute(0, 2, 3, 1).cpu().numpy().astype('float32')

    def path_reg(self, dlatent, gen_out):
        """Path-length penalty: squared deviation of per-sample Jacobian norms
        from their running mean ``self.avg`` (updated here with rate ``decay``)."""
        noise = torch.randn(gen_out.shape, device=gen_out.device) / np.sqrt(np.prod(gen_out.shape[2:]))
        grads = torch.autograd.grad((gen_out * noise).sum(), dlatent, create_graph=True)[0]
        lengths = torch.sqrt((grads ** 2).mean(2).sum(1))
        self.avg = self.decay * torch.mean(lengths.detach()) + (1 - self.decay) * self.avg
        return torch.mean((lengths - self.avg) ** 2)

    def generator_loss(self, i):
        """One G update (plus EMA update); every ``pl_interval`` iterations the
        path-length penalty on the first ``pl_batch`` samples is added."""
        dlatent = self.Generator.sample_dlatents(self.batch_size)
        if i % self.pl_interval == 0:
            dlatent_1, dlatent_2 = dlatent[:self.pl_batch], dlatent[self.pl_batch:]
            fake_imgs = self.Generator.generate(torch.cat((dlatent_1, dlatent_2), 0))
            fake_scores = self.Discriminator.get_score(fake_imgs)
            g_loss = -F.logsigmoid(fake_scores).mean()
            gen_loss = g_loss + self.pl_weight * self.pl_interval * self.path_reg(dlatent_1, fake_imgs[:self.pl_batch])
        else:
            fake_imgs = self.Generator.generate(dlatent)
            fake_scores = self.Discriminator.get_score(fake_imgs)
            gen_loss = -F.logsigmoid(fake_scores).mean()
        self.optimizer_g.zero_grad()
        gen_loss.backward()
        self.optimizer_g.step()
        self.Generator.update_avg_weights()

        return gen_loss

    def discriminator_loss(self, real_data, i, n_epoch):
        """One D update. Returns (loss, mean real score, mean fake score).

        The R1 penalty is computed (with a double-backprop graph) only on the
        iterations where it is applied.
        """
        apply_r1 = i % self.r1_interval == 0 and n_epoch == self.D_steps - 1
        real_data.requires_grad = apply_r1
        # G is not updated here, so don't build (and backprop through) its graph.
        with torch.no_grad():
            fake_imgs = self.Generator.sample(real_data.shape[0])
        real_scores = self.Discriminator.get_score(real_data)
        fake_scores = self.Discriminator.get_score(fake_imgs)
        disc_loss = torch.mean(-F.logsigmoid(real_scores) + F.softplus(fake_scores))
        if apply_r1:
            grads = torch.autograd.grad(real_scores.sum(), real_data, create_graph=True)[0]
            r1_reg = torch.mean((grads ** 2).sum(dim=[1, 2, 3]))
            disc_loss = disc_loss + self.r1_weight * self.r1_interval * r1_reg
        real_data.requires_grad = False
        self.optimizer_d.zero_grad()
        disc_loss.backward()
        self.optimizer_d.step()

        return disc_loss, real_scores.mean().item(), fake_scores.mean().item()

    def run(self, checkpoint=None):
        """Train for ``self.epochs`` iterations, optionally resuming from ``checkpoint``.

        Each iteration does ``D_steps`` D updates and one G update, then logs,
        saves samples/metrics and checkpoints at the configured intervals. The
        TensorBoard writer is flushed after each log and closed when training
        ends or fails.
        """
        # makedirs also creates missing parents and tolerates existing dirs.
        os.makedirs('/content/drive/MyDrive/Generated_Images', exist_ok=True)
        os.makedirs('/content/drive/MyDrive/checkpoints', exist_ok=True)

        if checkpoint:
            self.load_checkpoint(checkpoint)
        self.Generator.train()
        self.Discriminator.train()

        try:
            for i in tqdm(range(self.epochs)):
                for n_epoch in range(self.D_steps):
                    real_data = next(self.train_dataloader)[0].to(self.device)
                    disc_loss, real_score, fake_score = self.discriminator_loss(real_data, i, n_epoch)

                gen_loss = self.generator_loss(i)
                gen = self.Generator.sample_images(self.sample_count)

                if i % self.log_interval == 0:
                    print('Epoch:', i, 'Real Score:{:.4f}'.format(real_score), 'Fake Score:{:.4f}'.format(fake_score),
                          'Generator Loss:{:.4f}'.format(gen_loss.item()),
                          'Discriminator Loss:{:.4f}'.format(disc_loss.item()))
                    # The Keras IS code needs an (N, H, W, C) array in [0, 255], not a tensor.
                    is_avg, is_std = calculate_inception_score(self._to_inception_array(gen))
                    print('score', is_avg, is_std)
                    self.writer.add_scalar('Real Score', global_step=i, scalar_value=real_score)
                    self.writer.add_scalar('Fake Score', global_step=i, scalar_value=fake_score)
                    self.writer.add_scalar('Generator Loss', global_step=i, scalar_value=gen_loss.item())
                    self.writer.add_scalar('Discriminator Loss', global_step=i, scalar_value=disc_loss.item())
                    self.writer.add_scalar('Inception Score', global_step=i, scalar_value=is_avg)
                    # flush, not close: closing every step starts a new event file each time.
                    self.writer.flush()

                if i % self.image_interval == 0:
                    # save_image expects [0, 1]; samples live in the [-1, 1] data range.
                    save_image((gen + 1) / 2, '/content/drive/MyDrive/Generated_Images/Iter{}.png'.format(i))
                    # evaluate() requires equal batch sizes; a loader batch may be shorter.
                    n = min(real_data.shape[0], gen.shape[0])
                    self.eval.evaluate(real_data[:n], gen[:n])

                if i % self.checkpoint_interval == 0:
                    self.save_checkpoint(i)
        finally:
            self.writer.close()

    def load_checkpoint(self, filename):
        """Restore networks and optimizers; tensors are mapped onto ``self.device``."""
        checkpoint = torch.load(filename, map_location=self.device)
        self.Generator.load_state_dict(checkpoint['generator'])
        self.Discriminator.load_state_dict(checkpoint['discriminator'])
        self.optimizer_g.load_state_dict(checkpoint['generator_optimizer'])
        self.optimizer_d.load_state_dict(checkpoint['discriminator_optimizer'])

    def save_checkpoint(self, epoch):
        """Save networks and optimizers to checkpoints/Model{epoch}.pth on Drive."""
        torch.save({
            'generator': self.Generator.state_dict(),
            'discriminator': self.Discriminator.state_dict(),
            'generator_optimizer': self.optimizer_g.state_dict(),
            'discriminator_optimizer': self.optimizer_d.state_dict(),
        }, '/content/drive/MyDrive/checkpoints/Model{}.pth'.format(epoch))


In [None]:
%tensorboard --logdir /content/drive/MyDrive/logs/StyleGAN/CIFAR10/logs

In [None]:
def main():
    """Build a StyleGANTrainer with its default configuration and train it."""
    trainer = StyleGANTrainer()
    trainer.run()


if __name__ == '__main__':
    import sys
    # Inside Jupyter/Colab, sys.argv carries the kernel's own flags (e.g.
    # ``-f <connection file>``). Fire would try to pass them to main() and
    # fail, so give Fire an empty command line there; parse the real CLI otherwise.
    fire.Fire(main, command=[] if 'ipykernel' in sys.modules else None)