In [1]:
from math import log2
import os
import numpy as np
import random
from tqdm.auto import tqdm

from scipy.stats import truncnorm

import torch
from torch import nn, optim
from torch.nn import functional
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms, datasets
from torchvision.utils import save_image

# Ask cuDNN to auto-tune its convolution algorithms. The flag is spelled
# `benchmark` (singular); assigning to `benchmarks` leaves the real flag untouched.
torch.backends.cudnn.benchmark = True

from torch.utils.tensorboard import SummaryWriter

# Base directory holding the image datasets; get_loader() appends DATASET_NAME to it.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
root_path = '/mnt/c/Users/121js/OneDrive/Desktop/TorchImages/'

# Models

In [2]:
class WSConv2d(nn.Module):
    """Convolution whose input is rescaled at call time and whose bias is held outside the conv.

    The wrapped layer is an ``nn.Conv2d`` (or ``nn.ConvTranspose2d`` when
    ``use_transpose`` is set). Its weight is drawn from a standard normal, its
    bias starts at zero, and every forward pass multiplies the input by
    ``sqrt(gain / (in_channels * kernel_size**2))`` before convolving.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: square kernel size (an int; it is squared for the scale).
        stride: convolution stride.
        padding: convolution padding.
        use_transpose: build a transposed convolution instead of a regular one.
        gain: numerator of the scale factor.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_transpose=False, gain=2):
        super().__init__()
        if use_transpose:
            self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)
        else:
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)

        fan_in = in_channels * kernel_size**2
        self.scale = (gain / fan_in)**0.5

        # Take ownership of the bias so it is added after the conv and is not
        # affected by the input scaling; the wrapped conv itself becomes bias-free.
        self.bias = self.conv.bias
        self.conv.bias = None

        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Convolve the scaled input and add the per-channel bias."""
        per_channel_bias = self.bias.view(1, self.bias.shape[0], 1, 1)
        scaled_input = x * self.scale
        return self.conv(scaled_input) + per_channel_bias


class PixNorm(nn.Module):
    """Normalize each spatial position by the RMS of its values across dim 1 (channels)."""

    def __init__(self):
        super().__init__()
        # Added under the square root to avoid dividing by zero.
        self.epsilon = 1e-8

    def forward(self, x):
        """Return ``x`` divided by ``sqrt(mean(x**2 over dim 1) + epsilon)``."""
        mean_square = x.pow(2).mean(dim=1, keepdim=True)
        return x / (mean_square + self.epsilon).sqrt()


class ConvBlock(nn.Module):
    """Two 3x3 weight-scaled convolutions, each followed by LeakyReLU and optional PixNorm.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by both convolutions.
        use_pixnorm: apply ``PixNorm`` after each activation when True.
    """

    def __init__(self, in_channels, out_channels, use_pixnorm=True):
        super().__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.act = nn.LeakyReLU(0.2)
        # Always constructed (it has no parameters); `use_pn` decides whether it runs.
        self.pn = PixNorm()
        self.use_pn = use_pixnorm

    def forward(self, x):
        """Apply conv -> LeakyReLU -> (PixNorm) twice and return the result."""
        # The flag must be tested, not the module: an nn.Module instance is always
        # truthy, so checking `self.pn` normalized even when use_pixnorm=False.
        # NOTE(review): checkpoints trained before this fix had PixNorm active in
        # every block built with use_pixnorm=False; retraining may be needed.
        x = self.act(self.conv1(x))
        if self.use_pn:
            x = self.pn(x)
        x = self.act(self.conv2(x))
        if self.use_pn:
            x = self.pn(x)
        return x


class Generator(nn.Module):
    """Progressively grown generator.

    An initial block turns the latent input into a feature map. Each later step
    doubles the spatial size by nearest-neighbour upsampling and runs one
    ``ConvBlock``. ``rgb_layers[k]`` converts the features available after ``k``
    steps into an image with ``img_channels`` channels.

    Args:
        z_dim: channel count of the latent input.
        in_channels: base feature width; block ``i`` maps
            ``int(in_channels*factors[i])`` to ``int(in_channels*factors[i+1])`` channels.
        img_channels: channels of the produced image.
        factors: per-resolution width multipliers.
    """

    def __init__(self, z_dim, in_channels, img_channels=3, factors = [1, 1, 1, 1, 1/2, 1/4, 1/8, 1/16, 1/32]):
        super().__init__()
        self.prog_blocks = nn.ModuleList([])
        self.rgb_layers = nn.ModuleList([])
        self.initial_block = nn.Sequential(
            PixNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, 3, 1, 1),
            nn.LeakyReLU(0.2),
            PixNorm(),
        )
        # The step-0 image head is also stored as rgb_layers[0].
        self.initial_rgb = WSConv2d(in_channels, img_channels, 1, 1, 0)
        self.rgb_layers.append(self.initial_rgb)

        # One growth block plus one image head per consecutive pair of factors.
        for wide_factor, narrow_factor in zip(factors[:-1], factors[1:]):
            block_in = int(in_channels*wide_factor)
            block_out = int(in_channels*narrow_factor)
            self.prog_blocks.append(ConvBlock(block_in, block_out))
            self.rgb_layers.append(WSConv2d(block_out, img_channels, 1, 1, 0))

    def fade_in(self, alpha, upscaled, generated):
        """Blend the two images with weight ``alpha`` on ``generated`` and squash with tanh."""
        blended = alpha*generated + (1-alpha)*upscaled
        return torch.tanh(blended)

    def forward(self, x, alpha, steps):
        """Generate an image after ``steps`` growth steps.

        Args:
            x: latent input (callers in this notebook pass shape (N, z_dim, 1, 1)).
            alpha: fade-in weight given to the newest block's output.
            steps: number of growth blocks to apply; 0 returns the initial image head
                output directly (no tanh is applied on that path).
        """
        features = self.initial_block(x)
        if steps == 0:
            return self.initial_rgb(features)

        for block_idx in range(steps):
            upscaled = functional.interpolate(features, scale_factor=2, mode='nearest')
            features = self.prog_blocks[block_idx](upscaled)

        # Image from the last block's upsampled input (previous head) versus
        # image from the last block's output (current head).
        rgb_previous = self.rgb_layers[steps-1](upscaled)
        rgb_current = self.rgb_layers[steps](features)
        return self.fade_in(alpha, rgb_previous, rgb_current)

class Discrimiator(nn.Module):
    """Progressively grown critic that mirrors ``Generator``.

    Blocks and from-RGB layers are built by walking ``factors`` backwards, so
    index 0 handles the largest ``steps`` value. The last entry of
    ``rgb_layers`` is the from-RGB layer used when ``steps == 0``.

    Args:
        in_channels: base feature width (multiplied by each entry of ``factors``).
        img_channels: channels of the input image.
        factors: per-resolution width multipliers.
    """

    def __init__(self, in_channels, img_channels=3, factors = [1, 1, 1, 1, 1/2, 1/4, 1/8, 1/16, 1/32]):
        super().__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList([]), nn.ModuleList([])
        self.leaky = nn.LeakyReLU(0.2)

        # Reverse walk: each block maps the narrower width back to the wider one.
        for i in range(len(factors)-1, 0, -1):
            conv_in_chan = int(in_channels*factors[i])
            conv_out_chan = int(in_channels*factors[i-1])
            self.prog_blocks.append(ConvBlock(conv_in_chan, conv_out_chan, use_pixnorm=False))
            self.rgb_layers.append(WSConv2d(img_channels, conv_in_chan, 1, 1, 0))

        self.final_rgb = WSConv2d(img_channels, in_channels, 1, 1, 0)
        self.rgb_layers.append(self.final_rgb)
        self.avg_pool = nn.AvgPool2d(2, 2)

        # `in_channels+1` accounts for the extra channel appended by minibatch_std().
        self.final_block = nn.Sequential(
            WSConv2d(in_channels+1, in_channels, 3, 1, 1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels, 4, 1, 0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, 1, 1, 1, 0),
        )

    def fade_in(self, alpha, downscaled, out):
        """Blend the two feature maps with weight ``alpha`` on ``out``."""
        return alpha*out + (1-alpha)*downscaled

    def minibatch_std(self, x):
        """Append one channel holding the mean of the per-feature std across the batch.

        For a batch of one sample the unbiased standard deviation is undefined and
        ``torch.std`` yields NaN, which would propagate into the critic score. A
        single sample has no batch variation, so the statistic is 0 in that case.
        """
        if x.shape[0] > 1:
            statistic = torch.std(x, dim=0).mean()
        else:
            statistic = x.new_zeros(())
        batch_stats = statistic.repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        return torch.cat([x, batch_stats], dim=1)

    def forward(self, x, alpha, steps):
        """Score a batch of images that were generated/resized for ``steps`` growth steps.

        Args:
            x: image batch with ``img_channels`` channels.
            alpha: fade-in weight given to the newest (highest-resolution) block.
            steps: number of growth steps of the input resolution; 0 uses only the
                final block.

        Returns:
            Tensor of shape (batch, -1) produced by flattening the final block output.
        """
        # Index of the block that corresponds to the current resolution.
        curr_steps = len(self.prog_blocks) - steps
        out = self.leaky(self.rgb_layers[curr_steps](x))

        if steps == 0:
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # Fade between the pooled input fed to the next from-RGB layer and the
        # pooled output of the newest block.
        downscaled = self.leaky(self.rgb_layers[curr_steps+1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[curr_steps](out))
        out = self.fade_in(alpha, downscaled, out)

        for curr_step in range(curr_steps+1, len(self.prog_blocks)):
            out = self.prog_blocks[curr_step](out)
            out = self.avg_pool(out)

        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)

In [3]:
# z_dim = 50
# in_channels = 256
# factors = [1, 1, 1, 1, 1/2, 1/4, 1/8, 1/16, 1/32]
# img_sizes = [4, 8, 16, 32, 64, 128, 256, 512, 1024]
# gen = Generator(z_dim, in_channels, img_channels=3, factors=factors)
# critic = Discrimiator(in_channels, img_channels=3, factors=factors)

# for img_size in img_sizes:
#     num_steps = int(log2(img_size/4))
#     x = torch.randn((1, z_dim, 1, 1))
#     z = gen(x, 0.5, steps=num_steps)
#     assert z.shape == (1, 3, img_size, img_size)
    
#     out = critic(z, alpha=0.5, steps=num_steps)
#     assert out.shape == (1, 1)

#     print(f'Success at image size: {img_size}')

# Config

In [4]:
# Resolution at which main_func() starts training (converted to a step index via
# log2(size/4)); main_func() stops once the next resolution would be twice this value.
STARTING_IMG_SIZE = 512
# Sub-folder of root_path that get_loader() hands to ImageFolder.
DATASET_NAME = 'spiral_galaxies'
# Checkpoint files read by load_checkpoint() and written by save_checkpoint().
CHECKPOINT_GEN = 'generator.pth'
CHECKPOINT_CRIT = 'critic.pth'
# Folder that train_func() writes sample image grids into.
SAVE_IMG_FOLDER = 'results'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
SAVE_MODEL = True
# When True, main_func() loads both checkpoints before training.
LOAD_MODEL = True
# One batch size per resolution, indexed by log2(img_size/4) in get_loader() (4x4 first).
BATCH_SIZES = [256, 256, 256, 64, 64, 16, 8, 4, 2]
# Learning rate for both Adam optimizers; also re-applied after loading a checkpoint.
LR = 1e-3
# NOTE(review): IMG_SIZE only feeds NUM_STEPS below, and NUM_STEPS is not referenced
# by the code visible in this notebook — training resolution comes from STARTING_IMG_SIZE.
IMG_SIZE = 256
IMG_CHANNELS = 3
# Base feature width passed to Generator and Discrimiator.
IN_CHANNELS = 128
# Channel count of the latent noise fed to the generator.
NOISE_DIM = 128
# Weight of the gradient-penalty term in the critic loss.
LAMBDA_GP = 10
NUM_STEPS = int(log2(IMG_SIZE/4)) + 1

# Epochs per resolution, indexed by step; also sets how fast alpha ramps in train_func().
PROGRESSIVE_EPOCHS = [200, 200, 200, 150, 150, 150, 150, 100, 100]
# Number of images per logged/saved sample grid.
NUM_IMG_TO_SHOW = 6

# Utils

In [5]:
def find_gradient_penalty(critic, real, fake, alpha, train_step, device='cuda'):
    """Gradient penalty of the critic at random blends of real and fake images.

    Args:
        critic: callable invoked as ``critic(images, alpha, train_step)``.
        real: batch of real images, shape (B, C, H, W).
        fake: batch of generated images of the same shape; it is detached here.
        alpha: fade-in weight forwarded to the critic.
        train_step: growth step forwarded to the critic.
        device: device the random mixing coefficients are moved to.

    Returns:
        Scalar tensor: mean over the batch of ``(||grad||_2 - 1)**2``, where the
        gradient is taken with respect to the blended images.
    """
    batch, channels, height, width = real.shape
    # One mixing coefficient per sample, repeated over every pixel of that sample.
    mix = torch.rand((batch, 1, 1, 1)).repeat(1, channels, height, width).to(device)
    blended = real*mix + fake.detach()*(1-mix)
    blended.requires_grad_(True)

    scores = critic(blended, alpha, train_step)
    # create_graph=True keeps the penalty differentiable w.r.t. the critic weights.
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )
    per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return torch.mean((per_sample_norm-1)**2)

def save_checkpoint(model, optimizer, filename='my_checkpoint.pth.tar'):
    """Write the model and optimizer state dicts to ``filename`` with ``torch.save``."""
    print('==> Saving Checkpoint <==')
    state = {}
    state['model'] = model.state_dict()
    state['optimizer'] = optimizer.state_dict()
    torch.save(state, filename)

def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` from ``checkpoint_file``, then force the learning rate.

    The file is expected to hold a dict with ``'model'`` and ``'optimizer'`` entries
    (the layout written by ``save_checkpoint``). Tensors are mapped onto ``DEVICE``.
    """
    print('==> Loading Checkpoint <==')
    state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])

    # The restored optimizer state carries the learning rate it was saved with;
    # overwrite it with the requested one.
    for group in optimizer.param_groups:
        group['lr'] = lr

def show_in_tensorboard(writer, loss_crit, loss_gen, real, fake, tensorboard_step):
    """Log both losses and one grid each of fake and real images at ``tensorboard_step``.

    Only the first ``NUM_IMG_TO_SHOW`` images of each batch are put in the grids,
    and the grids are built with ``normalize=True``.
    """
    writer.add_scalar('Loss Gen', loss_gen, global_step=tensorboard_step)
    writer.add_scalar('Loss Critic', loss_crit, global_step=tensorboard_step)

    with torch.no_grad():
        # Build both grids first, then log them: fake before real, as tags 'Fake' / 'Real'.
        grids = [
            (tag, torchvision.utils.make_grid(batch[:NUM_IMG_TO_SHOW], nrow=NUM_IMG_TO_SHOW, normalize=True))
            for tag, batch in (('Fake', fake), ('Real', real))
        ]
        for tag, grid in grids:
            writer.add_image(tag, grid, global_step=tensorboard_step)

# def save_some_examples(gen, steps, gen_size, folder='progan_results'):
#     gen.eval()
#     alpha=1
#     with torch.no_grad():
#         to_save = torchvision.utils.make_grid(gen(FIXED_NOISE, alpha, steps)*0.5 + 0.5, nrow=NUM_IMG_TO_SHOW)
#         save_image(to_save, folder+f'/img_{gen_size}x{gen_size}_epoch.png')
#     gen.train()

# Training

In [6]:
def get_loader(img_size):
    """Build the dataset and shuffled DataLoader for square images of side ``img_size``.

    Images are resized to ``(img_size, img_size)``, converted to tensors and
    normalized with mean 0.5 and std 0.5 per channel. The batch size is looked up
    in ``BATCH_SIZES`` at index ``log2(img_size/4)``.

    Returns:
        (dataset, loader) tuple.
    """
    channel_means = [0.5] * IMG_CHANNELS
    channel_stds = [0.5] * IMG_CHANNELS
    pipeline = transforms.Compose(
        transforms=[
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            # transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize(channel_means, channel_stds),
        ]
    )
    images = datasets.ImageFolder(root=root_path+DATASET_NAME, transform=pipeline)
    resolution_idx = int(log2(img_size/4))
    images_loader = DataLoader(dataset=images, batch_size=BATCH_SIZES[resolution_idx], shuffle=True)
    return images, images_loader

In [7]:
def train_func(critic, gen, loader, dataset, step, alpha, opt_crit, opt_gen, tensorboard_step, writer, scaler_crit, scaler_gen):
    """Train critic and generator for one pass over ``loader`` at growth step ``step``.

    Args:
        critic, gen: the two networks, called as ``net(x, alpha, step)``.
        loader: iterable yielding ``(real_images, labels)`` batches; labels are ignored.
        dataset: only ``len(dataset)`` is used, for the alpha schedule.
        step: current growth step (resolution is ``4*2**step``).
        alpha: current fade-in weight; it is increased after every batch and capped at 1.
        opt_crit, opt_gen: optimizers for critic and generator.
        tensorboard_step: running counter used for logging and sample file names.
        writer: TensorBoard writer passed to ``show_in_tensorboard``.
        scaler_crit, scaler_gen: gradient scalers used for the two backward passes.

    Returns:
        ``(alpha, tensorboard_step)`` after the pass.
    """
    # save_image() does not create missing directories; without this a fresh run
    # fails on the first sample grid.
    os.makedirs(SAVE_IMG_FOLDER, exist_ok=True)

    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(DEVICE)
        noise = torch.randn(real.shape[0], NOISE_DIM, 1, 1).to(DEVICE)

        # --- critic update: Wasserstein loss + gradient penalty + small drift term ---
        with torch.cuda.amp.autocast():
            fake = gen(noise, alpha, step)
            crit_real = critic(real, alpha, step)
            crit_fake = critic(fake.detach(), alpha, step)
            gp = find_gradient_penalty(critic, real, fake, alpha, step, DEVICE)
            loss_crit = (
                - (torch.mean(crit_real) - torch.mean(crit_fake))
                + LAMBDA_GP*gp
                + (1e-3*torch.mean(crit_real**2))
            )
        critic.zero_grad()
        scaler_crit.scale(loss_crit).backward()
        scaler_crit.step(opt_crit)
        scaler_crit.update()

        # --- generator update: maximize the critic score of the same fakes ---
        with torch.cuda.amp.autocast():
            gen_fake = critic(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # Linear ramp: summed over epochs, alpha reaches 1 after half of the
        # epochs budgeted for this step in PROGRESSIVE_EPOCHS.
        alpha += real.shape[0] / ((PROGRESSIVE_EPOCHS[step]*0.5)*len(dataset))
        alpha = min(alpha, 1)

        # Log and save samples roughly sqrt(len(loader)) times per pass.
        if batch_idx % int(np.sqrt(len(loader))) == 0:
            with torch.no_grad():
                # Fresh noise is drawn for every snapshot (it is not fixed across calls).
                sample_noise = torch.randn(NUM_IMG_TO_SHOW, NOISE_DIM, 1, 1).to(DEVICE)
                # Map generator output to [0, 1] by undoing the 0.5/0.5 normalization.
                sample_fakes = gen(sample_noise, alpha, step)*0.5 + 0.5
                to_save = torchvision.utils.make_grid(sample_fakes, nrow=NUM_IMG_TO_SHOW)
                save_image(to_save, SAVE_IMG_FOLDER+f'/img_{4*2**step}x{4*2**step}_{tensorboard_step}.png')
                show_in_tensorboard(
                    writer,
                    loss_crit.item(),
                    loss_gen.item(),
                    real.detach(),
                    sample_fakes.detach(),
                    tensorboard_step
                )
                tensorboard_step += 1

    return alpha, tensorboard_step

In [8]:
def main_func():
    """Build the networks and optimizers, optionally resume, and train from STARTING_IMG_SIZE.

    Training starts at the step for ``STARTING_IMG_SIZE`` and stops as soon as the
    next resolution would be twice that size. So one call trains a single
    resolution for its entry in ``PROGRESSIVE_EPOCHS``. Checkpoints are written
    after every epoch when ``SAVE_MODEL`` is set.
    """
    gen = Generator(NOISE_DIM, IN_CHANNELS, IMG_CHANNELS).to(DEVICE)
    critic = Discrimiator(IN_CHANNELS, IMG_CHANNELS).to(DEVICE)
    opt_gen = optim.Adam(gen.parameters(), lr=LR, betas=(0, 0.99))
    opt_crit = optim.Adam(critic.parameters(), lr=LR, betas=(0, 0.99))
    scaler_crit, scaler_gen = torch.cuda.amp.GradScaler(), torch.cuda.amp.GradScaler()

    writer = SummaryWriter('logs/progan')
    # Close the writer even if training is interrupted, so pending events are
    # flushed to disk and the file handle is released.
    try:
        if LOAD_MODEL:
            load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LR)
            load_checkpoint(CHECKPOINT_CRIT, critic, opt_crit, LR)

        gen.train()
        critic.train()

        tensorboard_step = 1
        step = int(log2(STARTING_IMG_SIZE/4))

        for num_epochs in PROGRESSIVE_EPOCHS[step:]:
            # Each resolution starts its fade-in from (almost) zero.
            alpha = 1e-5
            images, images_loader = get_loader(4*2**step)
            print(f'Current Image size: {4*2**step}')

            for epoch in tqdm(range(num_epochs)):
                alpha, tensorboard_step = train_func(
                    critic, gen, images_loader, images, step, alpha, opt_crit, opt_gen, tensorboard_step, writer,
                    scaler_crit, scaler_gen
                    )
                if SAVE_MODEL:
                    save_checkpoint(gen, opt_gen, CHECKPOINT_GEN)
                    save_checkpoint(critic, opt_crit, CHECKPOINT_CRIT)

            step += 1
            if 4*2**step == STARTING_IMG_SIZE*2:
                print('Ending')
                break
    finally:
        writer.close()

In [9]:
# Run training with the settings from the Config cell (long-running; writes checkpoints
# and sample images as it goes).
main_func()

==> Loading Checkpoint <==
==> Loading Checkpoint <==
Current Image size: 512


  0%|          | 0/100 [00:00<?, ?it/s]

==> Saving Checkpoint <==
==> Saving Checkpoint <==
==> Saving Checkpoint <==
==> Saving Checkpoint <==
==> Saving Checkpoint <==
==> Saving Checkpoint <==
==> Saving Checkpoint <==
==> Saving Checkpoint <==
==> Saving Checkpoint <==
==> Saving Checkpoint <==
==> Saving Checkpoint <==
==> Saving Checkpoint <==
==> Saving Checkpoint <==
==> Saving Checkpoint <==
==> Saving Checkpoint <==
==> Saving Checkpoint <==


In [None]:
# Start from epoch 65 with size=256