In [None]:
import torch
from torch import nn

In [None]:
# def pixel_norm(x):
#     eps= 10e-8
#     return x/torch.sqrt(torch.mean(x,dim=1)+eps)

class Pixel_norm(nn.Module):
    """Pixelwise feature-vector normalization (as in progressive GANs).

    Every spatial position's channel vector is divided by its root-mean-square,
    so each pixel's features end up with unit average power.

    Args:
        eps: small constant added under the square root to avoid division by
            zero. Defaults to 1e-8 (10^-8); note that the literal ``10e-8``
            would actually be 1e-7.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        # dim=1 is the channel axis, assuming NCHW input as produced by the convs in
        # this file; keepdim lets the per-pixel norm broadcast back over x.
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.eps)


class WConv(nn.Module):
    """2-D convolution with an equalized learning rate.

    The weights are stored at unit variance (N(0, 1) init, zero bias). The He
    constant ``sqrt(2 / fan_in)``, with ``fan_in = in_chan * kernel**2``, is
    applied at run time. Because convolution is linear, scaling the input is
    equivalent to scaling the weights. The bias is added after the scaling and
    is not affected by it.

    Args:
        in_chan: input channels.
        out_chan: output channels.
        kernel: square kernel size (int).
        stride: convolution stride.
        padding: zero padding; the defaults (3, 1, 1) preserve spatial size.
    """

    def __init__(self, in_chan, out_chan, kernel=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.Conv2d(in_chan, out_chan, kernel, stride, padding)
        self.equalized_weights = (2 / (in_chan * kernel ** 2)) ** 0.5
        # The run-time constant only equalizes the learning rate if the stored
        # weights start at unit variance. PyTorch's default Conv2d init is already
        # scaled down by fan_in, which the constant above would shrink a second time.
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        return self.conv(x * self.equalized_weights)


class WConvTrans(nn.Module):
    """Transposed 2-D convolution with an equalized learning rate.

    The weights are stored at unit variance (N(0, 1) init, zero bias). The
    constant ``sqrt(2 / (out_chan * kernel**2))`` is applied to the input at
    run time, which is equivalent to scaling the weights because the op is
    linear. ``out_chan * kernel**2`` is the fan_in that PyTorch's init helpers
    compute for a ConvTranspose2d weight, whose shape is
    ``(in_chan, out_chan, k, k)``.

    Args:
        in_chan: input channels.
        out_chan: output channels.
        kernel: square kernel size (int).
        stride: transposed-convolution stride.
        padding: transposed-convolution padding.
    """

    def __init__(self, in_chan, out_chan, kernel=3, stride=1, padding=1):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_chan, out_chan, kernel, stride, padding)
        # NOTE(review): for a stride-1 transposed conv each output is fed by
        # in_chan * kernel**2 inputs; the two fan-in choices agree whenever
        # in_chan == out_chan.
        self.equalized_weights = (2 / (out_chan * kernel ** 2)) ** 0.5
        # Unit-variance storage is what makes the run-time constant equalize the
        # learning rate; the default init is already fan-scaled.
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.conv.bias)

    def forward(self, x):
        return self.conv(x * self.equalized_weights)


class ConvBlock(nn.Module):
    """Generator block of two equalized 3x3 convolutions.

    Each convolution is followed by LeakyReLU(0.2) and then pixel norm. The
    first conv maps ``in_chan -> out_chan``, the second ``out_chan -> out_chan``.
    With the WConv defaults (kernel 3, stride 1, padding 1) spatial size is
    unchanged.
    """

    def __init__(self, in_chan, out_chan):
        super().__init__()
        self.conv = WConv(in_chan, out_chan)
        self.conv1 = WConv(out_chan, out_chan)
        self.pix = Pixel_norm()
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        out = x
        for layer in (self.conv, self.conv1):
            out = self.pix(self.leaky(layer(out)))
        return out

class DisBlock(nn.Module):
    """Critic block of two equalized 3x3 convolutions.

    The first conv keeps ``in_chan`` channels and the second maps
    ``in_chan -> out_chan``. Each is followed by LeakyReLU(0.2). There is no
    normalization, and no downsampling happens inside the block.
    """

    def __init__(self, in_chan, out_chan):
        super().__init__()
        self.conv = WConv(in_chan, in_chan)
        self.conv1 = WConv(in_chan, out_chan)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, x):
        h = x
        for layer in (self.conv, self.conv1):
            h = self.leaky(layer(h))
        return h

In [None]:
class Generator(nn.Module):
    """Progressively grown generator.

    Layout:
      * ``blocks[0]`` turns a (N, in_dim, 1, 1) input into a 4x4 map. It is a
        4x4 transposed conv followed by a 3x3 conv, each followed by
        LeakyReLU(0.2) and pixel norm.
      * ``blocks[1..3]`` keep ``in_dim`` channels.
      * Further blocks halve the channel count until it is at most 16.

    ``blocks[k]`` runs at resolution ``4 * 2**k``. ``rgb_layers[k + 1]`` is the
    1x1 projection of ``blocks[k]``'s output to ``img_channel`` channels.
    ``rgb_layers[0]`` is only used as the fade-in source at the 4x4 stage.

    Args:
        in_dim: channel count of the input and of the lowest-resolution blocks.
        img_channel: number of output image channels.
    """

    def __init__(self, in_dim, img_channel):
        super().__init__()
        self.blocks = nn.ModuleList()
        self.rgb_layers = nn.ModuleList()

        def to_rgb(channels):
            # 1x1 equalized conv from `channels` feature maps to image channels.
            return WConv(channels, img_channel, 1, 1, 0)

        self.rgb_layers.append(to_rgb(in_dim))
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.blocks.append(nn.Sequential(
            WConvTrans(in_dim, in_dim, 4, 1, 0),
            nn.LeakyReLU(0.2),
            Pixel_norm(),
            WConv(in_dim, in_dim),
            nn.LeakyReLU(0.2),
            Pixel_norm(),
        ))

        for _ in range(3):
            self.blocks.append(ConvBlock(in_dim, in_dim))
            self.rgb_layers.append(to_rgb(in_dim))

        channels = in_dim
        while channels > 16:
            self.blocks.append(ConvBlock(channels, channels // 2))
            self.rgb_layers.append(to_rgb(channels))
            channels //= 2
        self.rgb_layers.append(to_rgb(channels))

    def forward(self, x, out_size, alpha):
        """Generate images at resolution ``2**steps``, with ``2**steps >= out_size``.

        ``steps`` is the smallest such exponent. With the default channel plan,
        valid sizes are 4..1024.

        ``alpha`` blends two paths. The old path is the previous stage's toRGB
        applied to the upsampled lower-resolution features. The new path is the
        newest block's toRGB. ``alpha`` = 0 gives the old path and 1 the new one.
        The blend is passed through tanh.
        """
        steps = 0
        while 2 ** steps < out_size:
            steps += 1

        features = self.blocks[0](x)
        upsampled = features
        for stage in range(1, steps - 1):
            upsampled = self.up(features)
            features = self.blocks[stage](upsampled)

        old_rgb = self.rgb_layers[steps - 2](upsampled)
        new_rgb = self.rgb_layers[steps - 1](features)
        return torch.tanh((1 - alpha) * old_rgb + alpha * new_rgb)

In [None]:
class Discriminator(nn.Module):
    """Progressively grown critic that outputs an unbounded score per image.

    Layout:
      * Blocks grow from 16 channels, doubling up to ``in_chan``.
      * Then come three ``in_chan -> in_chan`` blocks.
      * Finally a 4x4 head consumes the features plus one minibatch-std channel.

    Modules are addressed from the end. ``blocks[-1]`` is the head.
    ``blocks[-k]`` (k >= 2) and ``rgb_layers[-k]`` (k >= 1) operate at
    resolution ``2**(k + 1)``.

    Args:
        in_chan: widest channel count (used at the low resolutions).
        img_channel: number of input image channels.
    """

    def __init__(self, in_chan, img_channel):
        super().__init__()
        self.pool = nn.AvgPool2d(2, 2)
        self.blocks = nn.ModuleList()
        self.rgb_layers = nn.ModuleList()
        n = 16
        while n < in_chan:
            self.blocks.append(DisBlock(n, n * 2))
            self.rgb_layers.append(WConv(img_channel, n, 1, 1, 0))
            n = n * 2

        for _ in range(3):
            self.blocks.append(DisBlock(in_chan, in_chan))
            self.rgb_layers.append(WConv(img_channel, in_chan, 1, 1, 0))
        # The head ends in a linear 1x1 conv with no Sigmoid. The training loss
        # is Wasserstein + gradient penalty + a drift term on the raw score, and
        # it expects an unbounded critic output. Removing the parameterless
        # Sigmoid does not change any state_dict key.
        self.blocks.append(nn.Sequential(
            WConv(in_chan + 1, in_chan),
            nn.LeakyReLU(0.2),
            WConv(in_chan, in_chan, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WConv(in_chan, 1, 1, 1, 0),
        ))
        self.rgb_layers.append(WConv(img_channel, in_chan, 1, 1, 0))

    def minibatch_std(self, x, eps=1e-8):
        """Append one channel holding the mean per-feature std across the batch.

        Uses the population variance plus ``eps`` under the square root. The
        result stays finite, with a finite gradient, for a batch of size 1 or
        for identical samples. Unbiased std returns NaN when the batch has a
        single element.
        """
        std = torch.sqrt(torch.var(x, dim=0, unbiased=False) + eps)
        batch_statistics = std.mean().repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        return torch.cat([x, batch_statistics], dim=1)

    def forward(self, x, size, alpha=0.5):
        """Score images at resolution ``size`` and return a tensor of shape (N, 1).

        ``alpha`` fades in the newest resolution, with the same convention as
        the generator:
          * 0 uses only the old path (image downsampled and fed to the previous
            stage's fromRGB);
          * 1 uses only the new path (full-resolution fromRGB, the new block,
            then downsampling).
        """
        i = 0
        while 2 ** i < size:
            i += 1
        x_rgb = self.rgb_layers[-(i - 1)](x)
        if i == 2:
            h = self.minibatch_std(x_rgb)
            return self.blocks[-1](h).view(h.shape[0], -1)

        x_new = self.pool(self.blocks[-(i - 1)](x_rgb))
        x_old = self.rgb_layers[-(i - 2)](self.pool(x))
        # New path weighted by alpha: once the fade-in finishes (alpha == 1), the
        # full-resolution layers are the only ones used.
        h = alpha * x_new + (1 - alpha) * x_old

        for j in range(i - 2, 1, -1):
            h = self.pool(self.blocks[-j](h))
        h = self.minibatch_std(h)
        return self.blocks[-1](h).view(h.shape[0], -1)




In [None]:
# gen = Generator(512,3)
# x= torch.randn((32,3,4,4))
# dis = Discriminator(512,3)
# dis(x,4,1e-5)
# for i in range(2,11):
#     im=gen(x,2**i,0.5)
#     print(im.shape)
#     print(dis(im,2**i))

In [None]:
# Training configuration: everything a run depends on is set here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DATASET = 'dataset/'            # torchvision ImageFolder root
LEARNING_RATE = 1e-3
BATCH_SIZES = [128, 128, 64, 32, 16, 16, 8, 8, 4]  # one per resolution: 4, 8, ..., 1024
IMG_START_SIZE = 4
CHANNELS_IMG = 3
IN_CHANNELS = 512               # latent size and widest feature-map channel count
DISC_ITERATIONS = 1             # NOTE(review): not read by the training loop in this notebook
LAMBDA_GP = 10                  # gradient-penalty weight
epochs = [10] * len(BATCH_SIZES)  # epochs per resolution stage
SEED = 0
# Seed before drawing FIXED_NOISE. This makes the "fixed" preview latents, and
# the later weight init and shuffling, identical across kernel restarts and
# resumed (LOAD_MODEL) runs.
torch.manual_seed(SEED)
FIXED_NOISE = torch.randn(16, IN_CHANNELS, 1, 1).to(DEVICE)
NUM_WORKERS = 4
CHECKPOINT_GEN = "gen.pth"
CHECKPOINT_DISC = "disc.pth"
SAVE_MODEL = True
LOAD_MODEL = False

In [None]:
import random
import numpy as np
import os
import torchvision
import torch.nn as nn
from torchvision.utils import save_image
from scipy.stats import truncnorm

def plot_to_tensorboard(
    writer, loss_disc, loss_gen, real, fake, tensorboard_step
):
    """Log both losses and image grids of real and generated samples.

    Args:
        writer: TensorBoard writer (uses ``add_scalar`` and ``add_image``).
        loss_disc: scalar critic loss.
        loss_gen: scalar generator loss.
        real: batch of real images; only the first 16 are shown.
        fake: batch of generated images; only the first 16 are shown.
        tensorboard_step: global step used for every logged value.
    """
    writer.add_scalar("Loss Critic", loss_disc, global_step=tensorboard_step)
    # Callers already compute and pass the generator loss; log it alongside the critic's.
    writer.add_scalar("Loss Generator", loss_gen, global_step=tensorboard_step)

    with torch.no_grad():
        img_grid_real = torchvision.utils.make_grid(real[:16], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:16], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)

def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename``.

    The saved object is a dict with keys ``"state_dict"`` (model) and
    ``"optimizer"``. This is the layout load_checkpoint reads back.
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state, then override the learning rate.

    Args:
        checkpoint_file: path written by save_checkpoint; must contain
            ``"state_dict"`` and ``"optimizer"`` entries.
        model: module to load into. Tensors are mapped onto the device of its
            first parameter, or CPU if it has none.
        optimizer: optimizer to load into.
        lr: learning rate written into every param group after loading.
    """
    print("=> Loading checkpoint")
    # A hard-coded "cuda" map_location fails on CPU-only machines. Load straight
    # onto whatever device the model already lives on.
    first_param = next(model.parameters(), None)
    map_location = first_param.device if first_param is not None else "cpu"
    checkpoint = torch.load(checkpoint_file, map_location=map_location)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr



In [None]:
import torch
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from math import log2
from tqdm import tqdm

# Let cuDNN autotune convolution algorithms; input shapes repeat within a resolution stage.
# (The flag is `benchmark`; assigning `benchmarks` only creates an unused attribute.)
torch.backends.cudnn.benchmark = True

def gradient_penalty(disc, real, fake, alpha, size, device=DEVICE):
    """WGAN-GP gradient penalty on random real/fake interpolations.

    Each sample gets its own mixing weight ``beta ~ U(0, 1)``. The critic is
    evaluated on ``beta * real + (1 - beta) * fake`` and the result is the
    batch mean of ``(||grad||_2 - 1) ** 2``. ``create_graph=True`` keeps the
    penalty differentiable so it can be backpropagated into the critic.

    Args:
        disc: critic, called as ``disc(images, size, alpha)``.
        real: 4-D image batch.
        fake: image batch with the same shape as ``real``; it is detached.
        alpha: forwarded to ``disc``.
        size: forwarded to ``disc``.
        device: device the mixing weights are moved to.

    Returns:
        Scalar penalty tensor.
    """
    batch, _c, _h, _w = real.shape
    # One weight per sample; broadcasting spreads it over channels and pixels.
    beta = torch.rand((batch, 1, 1, 1)).to(device)
    mixed = real * beta + fake.detach() * (1 - beta)
    mixed.requires_grad_(True)
    scores = disc(mixed, size, alpha)

    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=mixed,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )
    per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()

def get_loader(image_size):
    """Build the shuffled DataLoader for one training resolution.

    Images from the ImageFolder at DATASET are:
      * resized to ``image_size x image_size``;
      * flipped horizontally with probability 0.5;
      * mapped to [-1, 1] per channel.

    The batch size comes from BATCH_SIZES at stage index ``log2(image_size / 4)``
    (4 -> 0, 8 -> 1, ...).

    Returns:
        ``(loader, dataset)``.
    """
    preprocess = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    stage_batch = BATCH_SIZES[int(log2(image_size / 4))]
    dataset = datasets.ImageFolder(root=DATASET, transform=preprocess)
    loader = DataLoader(
        dataset,
        batch_size=stage_batch,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    return loader, dataset

In [None]:
def train_fn(size, disc, gen, loader, dataset, alpha, opt_disc, opt_gen, tensorboard_step, writer):
    """Train critic and generator for one epoch at resolution ``size``.

    Each batch does two steps:
      1. One critic step. The loss is the Wasserstein loss, plus
         LAMBDA_GP * gradient penalty, plus a 0.001 drift term on real scores.
      2. One generator step.

    ``alpha`` is advanced after every batch so the fade-in completes after half
    of the epochs scheduled for this resolution. It is then clamped at 1.

    Args:
        size: current resolution (4, 8, 16, ...).
        disc: critic, called as ``disc(x, size, alpha)``.
        gen: generator, called as ``gen(noise, size, alpha)``.
        loader: data loader for this resolution.
        dataset: dataset behind ``loader``; its length sets the fade-in speed.
        alpha: current fade-in weight.
        opt_disc: critic optimizer.
        opt_gen: generator optimizer.
        tensorboard_step: running TensorBoard step counter.
        writer: TensorBoard writer.

    Returns:
        ``(tensorboard_step, alpha)`` after the epoch.
    """
    # Stage index into the per-resolution schedules (4 -> 0, 8 -> 1, ...). This
    # is the same formula get_loader uses for BATCH_SIZES. `size // 4` is not a
    # stage index: at size 64 it is 16, past the end of `epochs`.
    stage = int(log2(size / 4))
    fade_samples = (epochs[stage] * 0.5) * len(dataset)

    loop = tqdm(loader, leave=True)
    for batch_idx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]
        noise = torch.randn(cur_batch_size, IN_CHANNELS, 1, 1).to(DEVICE)
        fake = gen(noise, size, alpha)

        # Critic loss
        critic_real = disc(real, size, alpha)
        critic_fake = disc(fake.detach(), size, alpha)
        gp = gradient_penalty(disc, real, fake, alpha, size, device=DEVICE)

        loss_disc = (
            -(torch.mean(critic_real) - torch.mean(critic_fake))
            + LAMBDA_GP * gp
            + (0.001 * torch.mean(critic_real ** 2))
        )
        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # Generator loss (critic gradients accumulated here are cleared by the
        # next opt_disc.zero_grad()).
        gen_fake = disc(fake, size, alpha)
        loss_gen = -torch.mean(gen_fake)

        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        alpha += cur_batch_size / fade_samples
        alpha = min(alpha, 1)

        if batch_idx % 500 == 0:
            with torch.no_grad():
                fixed_fakes = gen(FIXED_NOISE, size, alpha) * 0.5 + 0.5
            plot_to_tensorboard(
                writer,
                loss_disc.item(),
                loss_gen.item(),
                real.detach(),
                fixed_fakes.detach(),
                tensorboard_step,
            )
            tensorboard_step += 1

        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_disc.item(),
        )

    return tensorboard_step, alpha


def main():
    """Train generator and critic stage by stage, starting at IMG_START_SIZE.

    Each entry of ``epochs`` from the start size's stage onward is one
    resolution; the image size doubles after each stage. ``alpha`` restarts at
    1e-5 at every new resolution. Checkpoints are written after every epoch
    when SAVE_MODEL is set. The TensorBoard writer is closed even if training
    raises.
    """
    gen = Generator(IN_CHANNELS, CHANNELS_IMG).to(DEVICE)
    disc = Discriminator(IN_CHANNELS, CHANNELS_IMG).to(DEVICE)
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))
    opt_critic = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.99))

    if LOAD_MODEL:
        load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
        load_checkpoint(CHECKPOINT_DISC, disc, opt_critic, LEARNING_RATE)

    gen.train()
    disc.train()
    writer = SummaryWriter("logs/gan")
    tensorboard_step = 0
    size = IMG_START_SIZE

    try:
        # Stage index of the start size (4 -> 0), same formula as get_loader.
        # `size // 4` would be 1 here, so the schedule would run one stage short.
        start_stage = int(log2(size / 4))
        for num_epochs in epochs[start_stage:]:
            alpha = 1e-5
            loader, dataset = get_loader(size)
            print(f"Current image size: {size}")

            for epoch in range(num_epochs):
                print(f"Epoch [{epoch+1}/{num_epochs}]")
                tensorboard_step, alpha = train_fn(
                    size,
                    disc,
                    gen,
                    loader,
                    dataset,
                    alpha,
                    opt_critic,
                    opt_gen,
                    tensorboard_step,
                    writer,
                )

                if SAVE_MODEL:
                    save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
                    save_checkpoint(disc, opt_critic, filename=CHECKPOINT_DISC)

            size = size * 2
    finally:
        writer.close()


if __name__ == "__main__":
    main()