In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import torch.nn.functional as F
import cv2
from math import log2


In [5]:
from datetime import datetime
# Timestamp captured once, when this cell runs. It is reused later for the
# printed run time and for the TensorBoard log directory name in train_wrapper.
now = datetime.now()


In [6]:
# Channel multipliers relative to the base `in_channels`. In the Generator,
# progressive block i maps int(in_channels*factors[i]) channels to
# int(in_channels*factors[i+1]); the Discriminator walks the same list in
# reverse. Nine entries therefore yield eight progressive blocks.
factors = [1, 1, 1, 1, 1/2, 1/4, 1/8, 1/16, 1/32]


In [7]:
class WSConv2d(nn.Module):
    """2-D convolution whose input is multiplied by a fixed scale at run time.

    The scale is ``sqrt(gain / (in_channels * kernal_size**2))``. The wrapped
    convolution carries no bias of its own: its bias parameter is moved onto
    this module and added after the convolution, so it is not affected by the
    input scaling.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernal_size: side length of the square kernel (spelling kept, since
            callers pass it by keyword).
        stride: convolution stride.
        padding: zero padding applied on each side.
        gain: numerator of the scale factor.
    """

    def __init__(self, in_channels, out_channels, kernal_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernal_size,
            stride=stride,
            padding=padding,
        )

        fan_in = in_channels * kernal_size ** 2
        self.scale = (gain / fan_in) ** 0.5

        # Re-home the bias parameter on this module and strip it from the conv,
        # so the conv itself computes a bias-free result.
        self.bias = self.conv.bias
        self.conv.bias = None

        # Weights ~ N(0, 1); bias starts at zero.
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """Convolve the scaled input and add the per-channel bias."""
        scaled_input = x * self.scale
        channel_bias = self.bias.view(1, -1, 1, 1)
        return self.conv(scaled_input) + channel_bias


In [8]:
class PixelNorm(nn.Module):
    """Divide each position's feature vector by its root-mean-square over dim 1.

    A small epsilon is added under the square root to avoid division by zero.
    The module has no learnable parameters.
    """

    def __init__(self):
        super().__init__()
        self.epsilon = 1e-8

    def forward(self, x):
        """Return ``x`` normalised along dimension 1 (keeps the input shape)."""
        mean_square = torch.mean(x ** 2, dim=1, keepdim=True)
        return x / torch.sqrt(mean_square + self.epsilon)


In [9]:
class ConvBlock(nn.Module):
    """Two 3x3 WSConv2d layers, each followed by LeakyReLU(0.2).

    When ``pixelnorm`` is true, PixelNorm is applied after each activation.

    Args:
        in_channels: channels entering the first convolution.
        out_channels: channels produced by both convolutions.
        pixelnorm: whether to apply PixelNorm after each activation.
    """

    def __init__(self, in_channels, out_channels, pixelnorm=True):
        super().__init__()
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()
        self.use_pn = pixelnorm

    def forward(self, x):
        """Run conv -> LeakyReLU -> (optional PixelNorm) twice."""
        for conv in (self.conv1, self.conv2):
            x = self.leaky(conv(x))
            if self.use_pn:
                x = self.pn(x)
        return x


In [10]:
class Generator(nn.Module):
    """Progressively grown generator mapping a latent tensor to an image.

    ``initial`` expands the latent input with a 4x4 transposed convolution
    (a 1x1 latent therefore becomes a 4x4 feature map). Each entry of
    ``prog_blocks`` processes a 2x nearest-neighbour upsampled feature map,
    and ``rgb_layers[k]`` projects the features available after ``k``
    progressive blocks down to ``img_channels`` channels.

    Args:
        z_dim: channel count of the latent input.
        in_channels: base feature width; progressive block ``i`` maps
            ``int(in_channels * factors[i])`` to
            ``int(in_channels * factors[i + 1])`` channels.
        img_channels: channel count of the generated image.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super().__init__()
        self.initial = nn.Sequential(
            PixelNorm(),
            nn.ConvTranspose2d(z_dim, in_channels, 4, 1, 0),
            # Fixed: this was `nn.LeakyReLU(0, 2)`, which is parsed as
            # negative_slope=0 and inplace=2 (a plain in-place ReLU) instead of
            # the 0.2 slope used by every other activation in this file.
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels,
                     kernal_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            PixelNorm()
        )
        # 1x1 projection used directly at the lowest resolution (steps == 0);
        # it is also rgb_layers[0].
        self.rgb = WSConv2d(in_channels, img_channels,
                            kernal_size=1, stride=1, padding=0)
        self.prog_blocks, self.rgb_layers = nn.ModuleList(), nn.ModuleList([
            self.rgb])

        for i in range(len(factors)-1):
            conv_in_c = int(in_channels*factors[i])
            conv_out_c = int(in_channels*factors[i+1])
            self.prog_blocks.append(ConvBlock(conv_in_c, conv_out_c))
            self.rgb_layers.append(
                WSConv2d(conv_out_c, img_channels, kernal_size=1, stride=1, padding=0))

    def fade_in(self, alpha, upscaled, generated):
        """Blend the new block's image with the upscaled previous image.

        Returns ``tanh(alpha * generated + (1 - alpha) * upscaled)``.
        """
        return torch.tanh(alpha*generated+(1-alpha)*upscaled)

    def forward(self, x, alpha, steps):
        """Generate an image after ``steps`` progressive blocks.

        Args:
            x: latent input with ``z_dim`` channels.
            alpha: fade-in weight given to the newest block's output.
            steps: number of progressive blocks to apply; each one doubles the
                spatial size. Must not exceed ``len(self.prog_blocks)``.

        Returns:
            For ``steps == 0`` the raw output of ``self.rgb`` (no tanh);
            otherwise the tanh-squashed fade-in blend of the last two
            resolutions.
        """
        out = self.initial(x)

        if steps == 0:
            return self.rgb(out)
        for step in range(steps):
            # `upscaled` is kept after the loop: on the final iteration it holds
            # the upsampled features of the previous resolution.
            upscaled = F.interpolate(out, scale_factor=2, mode="nearest")
            out = self.prog_blocks[step](upscaled)
        final_upscaled = self.rgb_layers[steps-1](upscaled)
        final_out = self.rgb_layers[steps](out)
        return self.fade_in(alpha, final_upscaled, final_out)


In [11]:
class Discriminator(nn.Module):
    """Progressively grown critic that mirrors the Generator.

    ``prog_blocks`` and ``rgb_layers`` are stored from the highest resolution
    down to the lowest, so a forward pass at ``steps`` enters at index
    ``len(prog_blocks) - steps``. The last ``rgb_layers`` entry is the
    projection used at the lowest resolution.

    Args:
        z_dim: unused by this class; kept for signature parity with Generator.
        in_channels: base feature width (see ``factors`` for per-block widths).
        img_channels: channel count of the input image.
    """

    def __init__(self, z_dim, in_channels, img_channels=3):
        super().__init__()
        self.prog_blocks, self.rgb_layers = nn.ModuleList(), nn.ModuleList()
        self.leaky = nn.LeakyReLU(0.2)

        # Walk `factors` backwards so that index 0 is the highest-resolution block.
        for i in range(len(factors)-1, 0, -1):
            conv_in_c = int(in_channels*factors[i])
            conv_out_c = int(in_channels*factors[i-1])
            self.prog_blocks.append(
                ConvBlock(conv_in_c, conv_out_c, pixelnorm=False))
            self.rgb_layers.append(
                WSConv2d(img_channels, conv_in_c, kernal_size=1, stride=1, padding=0))

        self.initial_rgb = WSConv2d(
            img_channels, in_channels, kernal_size=1, stride=1, padding=0)
        self.rgb_layers.append(self.initial_rgb)
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)

        # `in_channels+1`: minibatch_std appends one statistic channel.
        # The 4x4 unpadded conv collapses a 4x4 map to 1x1 before the final
        # 1x1 conv produces a single score channel.
        self.final_block = nn.Sequential(
            WSConv2d(in_channels+1, in_channels,
                     kernal_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, in_channels,
                     kernal_size=4, stride=1, padding=0),
            nn.LeakyReLU(0.2),
            WSConv2d(in_channels, 1, kernal_size=1, stride=1, padding=0),
        )

    def fade_in(self, alpha, downscale, out):
        """Return ``alpha * out + (1 - alpha) * downscale``."""
        return alpha*out+(1-alpha)*downscale

    def minibatch_std(self, x):
        """Append one channel holding the mean across-batch standard deviation.

        The standard deviation is taken over the batch dimension and averaged
        to a single scalar, which is broadcast to an ``(N, 1, H, W)`` channel
        and concatenated to ``x`` along dim 1.

        For a batch of one sample the (unbiased) standard deviation has zero
        degrees of freedom and evaluates to NaN, which would poison every
        downstream value; in that case the statistic is defined as 0.
        """
        if x.shape[0] > 1:
            stat = torch.std(x, dim=0).mean()
        else:
            stat = x.new_zeros(())
        batch_stat = stat.repeat(x.shape[0], 1, x.shape[2], x.shape[3])
        return torch.cat([x, batch_stat], dim=1)

    def forward(self, x, alpha, steps):
        """Score a batch of images at the resolution given by ``steps``.

        Args:
            x: image batch with ``img_channels`` channels.
            alpha: fade-in weight given to the newest (highest-resolution) block.
            steps: number of progressive blocks the input passes through
                (0 means the lowest resolution).

        Returns:
            Tensor of shape ``(N, -1)`` obtained by flattening the final
            block's output per sample.
        """
        cur_step = len(self.prog_blocks) - steps
        out = self.leaky(self.rgb_layers[cur_step](x))

        if steps == 0:
            out = self.minibatch_std(out)
            return self.final_block(out).view(out.shape[0], -1)

        # Fade-in: blend the new block's pooled output with the projection of
        # the directly downsampled image.
        downscaled = self.leaky(self.rgb_layers[cur_step+1](self.avg_pool(x)))
        out = self.avg_pool(self.prog_blocks[cur_step](out))
        out = self.fade_in(alpha, downscaled, out)

        for step in range(cur_step+1, len(self.prog_blocks)):
            out = self.prog_blocks[step](out)
            out = self.avg_pool(out)
        out = self.minibatch_std(out)
        return self.final_block(out).view(out.shape[0], -1)


In [12]:
# Shape sanity check: build small models and verify, for every supported
# resolution, that the generator emits (1, 3, res, res) and the critic (1, 1).
# Z_DIM / IN_CHANNELS set here are temporary; the config cell redefines them.
Z_DIM = 50
IN_CHANNELS = 256
gen = Generator(Z_DIM, IN_CHANNELS, img_channels=3)
disc = Discriminator(Z_DIM, IN_CHANNELS, img_channels=3)

for num_steps in range(9):
    # Resolution doubles with every progressive step: 4, 8, ..., 1024.
    resolution = 4 * 2 ** num_steps
    x = torch.randn((1, Z_DIM, 1, 1))
    z = gen(x, 0.5, steps=num_steps)
    assert z.shape == (1, 3, resolution, resolution)
    out = disc(z, alpha=0.5, steps=num_steps)
    assert out.shape == (1, 1)
    print("OK", resolution)


OK 4
OK 8
OK 16
OK 32
OK 64
OK 128
OK 256
OK 512
OK 1024


In [13]:
# Enable the cuDNN autotuner. The flag is `benchmark` (singular); the previous
# `benchmarks` spelling only created an unrelated attribute and had no effect.
torch.backends.cudnn.benchmark = True
# Resolution at which training starts; train_wrapper derives the first step from it.
INIT_IMG_SIZE = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Learning rate shared by both Adam optimizers.
LR = 1e-3
# Batch size per resolution stage, indexed by int(log2(image_size / 4)).
BATCH_SIZE = [32, 32, 32, 16, 16, 16, 16, 8, 4]
CHANNELS_IMG = 3
Z_DIM = 512
IN_CHANNELS = 512
# Not referenced by the training code visible in this notebook.
DISC_ITERATIONS = 1
# Weight of the gradient penalty term in the critic loss.
LAMBDA_GP = 10
# Epochs per resolution stage (one entry per BATCH_SIZE entry).
PROGAN_EPOCHS = [10]*len(BATCH_SIZE)
# Fixed latent batch reused for the "Fake" image grid written to TensorBoard.
FIXED_NOISE = torch.randn(8, Z_DIM, 1, 1).to(DEVICE)
# Worker processes used by the DataLoader.
NUM_WORKERS = 4


In [14]:
# Show the run timestamp in the same format that train_wrapper uses for the
# TensorBoard log directory name.
print("TIME: ", now.strftime("%Y%m%d-%H%M%S"))


TIME:  20230205-094156


In [23]:
def plotTensorBoard(writer, loss_critic, loss_gen, real, fake, tensorboard_step):
    """Write the current losses and sample image grids to TensorBoard.

    Args:
        writer: object exposing ``add_scalar`` and ``add_image``
            (a ``SummaryWriter`` in this notebook).
        loss_critic: critic loss value to log.
        loss_gen: generator loss value to log.
        real: batch of real images; the first (up to) 8 are plotted.
        fake: batch of generated images; the first (up to) 8 are plotted.
        tensorboard_step: global step under which everything is recorded.
    """
    writer.add_scalar("Loss Critic", loss_critic, global_step=tensorboard_step)
    # `loss_gen` used to be accepted but silently dropped; log it as well.
    writer.add_scalar("Loss Generator", loss_gen, global_step=tensorboard_step)

    with torch.no_grad():
        # take out (up to) 8 examples to plot
        img_grid_real = torchvision.utils.make_grid(real[:8], normalize=True)
        img_grid_fake = torchvision.utils.make_grid(fake[:8], normalize=True)
        writer.add_image("Real", img_grid_real, global_step=tensorboard_step)
        writer.add_image("Fake", img_grid_fake, global_step=tensorboard_step)


def gradient_penalty(critic, real, fake, alpha, train_step, device="cpu"):
    """Mean squared deviation of the critic's input-gradient norm from 1.

    The critic is evaluated on random per-sample blends of ``real`` and the
    detached ``fake``; the gradient of its scores with respect to those
    blends is flattened per sample and its L2 norm compared against 1.

    Args:
        critic: callable invoked as ``critic(images, alpha, train_step)``.
        real: 4-D batch of real images.
        fake: batch of generated images (detached internally).
        alpha: forwarded unchanged to ``critic``.
        train_step: forwarded unchanged to ``critic``.
        device: device the random blend coefficients are moved to.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)``; the graph is kept so
        the penalty itself can be back-propagated.
    """
    n_samples, _, _, _ = real.shape
    # One coefficient per sample; broadcasting spreads it over C, H and W.
    mix = torch.rand((n_samples, 1, 1, 1)).to(device)
    blended = real * mix + fake.detach() * (1 - mix)
    blended.requires_grad_(True)

    # Critic scores for the blended images
    scores = critic(blended, alpha, train_step)

    # d(scores) / d(blended), differentiable so it can appear in the loss
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )
    per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return torch.mean((per_sample_norm - 1) ** 2)

def get_loader(image_size):
    """Build the dataset and shuffled DataLoader for one training resolution.

    Images under ``./datasets`` are resized to ``image_size`` x ``image_size``,
    converted to tensors, randomly flipped horizontally and normalised with
    mean 0.5 / std 0.5 per channel. The batch size is looked up in
    ``BATCH_SIZE`` by ``int(log2(image_size / 4))``.

    Args:
        image_size: target side length of the square training images.

    Returns:
        ``(loader, dataset)`` tuple.
    """
    channel_mean = [0.5] * CHANNELS_IMG
    channel_std = [0.5] * CHANNELS_IMG
    pipeline = [
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize(channel_mean, channel_std),
    ]
    transform = transforms.Compose(pipeline)

    stage = int(log2(image_size / 4))
    batch_size = BATCH_SIZE[stage]

    dataset = datasets.ImageFolder(root="./datasets", transform=transform)
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    loader = DataLoader(dataset, **loader_options)
    return loader, dataset

def trainFunc(disc,gen,loader,dataset,step,alpha,opt_disc,opt_gen,tb_step,writer,scaler_gen,scaler_disc):
    """Run one epoch of critic/generator updates at the current resolution.

    Per batch: update the critic on the Wasserstein loss plus gradient
    penalty and a small drift term, update the generator on the negated
    critic score, then advance the fade-in weight ``alpha``. Every 500
    batches, losses and image grids are written to TensorBoard.

    Args:
        disc: critic network.
        gen: generator network.
        loader: iterable yielding ``(real_images, label)`` batches.
        dataset: dataset behind ``loader``; only its length is used.
        step: current progressive step (resolution index).
        alpha: fade-in weight at the start of the epoch.
        opt_disc: optimizer for ``disc``.
        opt_gen: optimizer for ``gen``.
        tb_step: TensorBoard global step at the start of the epoch.
        writer: TensorBoard writer passed to ``plotTensorBoard``.
        scaler_gen: gradient scaler used for the generator update.
        scaler_disc: gradient scaler used for the critic update.

    Returns:
        ``(tb_step, alpha)`` after the epoch.
    """
    loop = tqdm(loader, leave=True)
    for batch_indx, (real, _) in enumerate(loop):
        real = real.to(DEVICE)
        cur_batch_size = real.shape[0]

        # --- critic update ---
        noise = torch.randn(cur_batch_size, Z_DIM, 1, 1).to(DEVICE)
        with torch.cuda.amp.autocast():
            fake = gen(noise, alpha, step)
            critic_real = disc(real, alpha, step)
            # detach: the critic loss must not back-propagate into the generator
            critic_fake = disc(fake.detach(), alpha, step)
            gp = gradient_penalty(disc, real, fake, alpha, step, DEVICE)
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + LAMBDA_GP * gp
                + (0.001 * torch.mean(critic_real ** 2))
            )
        opt_disc.zero_grad()
        scaler_disc.scale(loss_critic).backward()
        scaler_disc.step(opt_disc)
        scaler_disc.update()

        # --- generator update (re-scores the same fakes with the updated critic) ---
        with torch.cuda.amp.autocast():
            gen_fake = disc(fake, alpha, step)
            loss_gen = -torch.mean(gen_fake)
        opt_gen.zero_grad()
        scaler_gen.scale(loss_gen).backward()
        scaler_gen.step(opt_gen)
        scaler_gen.update()

        # alpha reaches 1 after half of this stage's total number of samples
        alpha += cur_batch_size / (len(dataset) * PROGAN_EPOCHS[step] * 0.5)
        alpha = min(alpha, 1)

        if batch_indx % 500 == 0:
            with torch.no_grad():
                fixed_fakes = gen(FIXED_NOISE, alpha, step) * 0.5 + 0.5
            plotTensorBoard(
                writer,
                loss_critic.item(),
                loss_gen.item(),
                real.detach(),
                fixed_fakes.detach(),
                tb_step
            )
            tb_step += 1

        # refresh=False: let tqdm redraw on its own schedule instead of forcing
        # a redraw every batch, which floods the notebook output channel.
        loop.set_postfix(
            gp=gp.item(),
            loss_critic=loss_critic.item(),
            refresh=False,
        )

    return tb_step, alpha

def train_wrapper():
    """Train the ProGAN through every resolution stage.

    Builds fresh generator/critic models, Adam optimizers and gradient
    scalers, then for each stage starting at ``INIT_IMG_SIZE`` runs the
    configured number of epochs via ``trainFunc`` and saves both models to
    ``./gen_<step>.pt`` and ``./disc_<step>.pt``. The models are local to this
    function and are not returned.

    The TensorBoard writer is closed on exit, including when training is
    interrupted, so buffered events are flushed to disk.
    """
    gen = Generator(Z_DIM, IN_CHANNELS, CHANNELS_IMG).to(DEVICE)
    disc = Discriminator(Z_DIM, IN_CHANNELS, CHANNELS_IMG).to(DEVICE)

    opt_gen = optim.Adam(gen.parameters(), lr=LR, betas=(0.0, 0.99))
    opt_disc = optim.Adam(disc.parameters(), lr=LR, betas=(0.0, 0.99))

    scaler_disc = torch.cuda.amp.GradScaler()
    scaler_gen = torch.cuda.amp.GradScaler()

    writer = SummaryWriter('logs/progan/' + now.strftime("%Y%m%d-%H%M%S") + "/")

    gen.train()
    disc.train()

    tb_step = 0
    step = int(log2(INIT_IMG_SIZE / 4))
    try:
        for num_epochs in PROGAN_EPOCHS[step:]:
            # Each stage fades the new block in from (almost) zero.
            alpha = 1e-5
            loader, dataset = get_loader(4 * 2 ** step)
            print("IMAGE SIZE", 4 * 2 ** step)
            for epoch in range(num_epochs):
                print("Range", 4 * 2 ** step)
                tb_step, alpha = trainFunc(
                    disc, gen, loader, dataset, step, alpha, opt_disc, opt_gen, tb_step, writer, scaler_gen, scaler_disc
                )
            # Checkpoint after every completed resolution stage.
            torch.save(gen, './gen_' + str(step) + '.pt')
            torch.save(disc, './disc_' + str(step) + '.pt')
            step += 1
    finally:
        # Flush and release the event file even on KeyboardInterrupt / errors.
        writer.close()

    

In [24]:
# Echo the device string chosen in the config cell ("cuda" if available, else "cpu").
DEVICE

'cuda'

In [25]:
# Run the full progressive training schedule. The models live inside
# train_wrapper, which saves ./gen_<step>.pt and ./disc_<step>.pt after each
# completed resolution stage and returns nothing.
train_wrapper()


IMAGE SIZE 4
Range 4


100%|██████████| 1987/1987 [00:47<00:00, 41.49it/s, gp=0.00745, loss_critic=0.00799] 


Range 4


100%|██████████| 1987/1987 [00:48<00:00, 41.07it/s, gp=0.00899, loss_critic=0.0298] 


Range 4


100%|██████████| 1987/1987 [00:47<00:00, 42.14it/s, gp=0.00261, loss_critic=-.0746]  


Range 4


 88%|████████▊ | 1744/1987 [00:40<00:05, 45.14it/s, gp=0.00629, loss_critic=0.217]   IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 1987/1987 [00:46<00:00, 42.43it/s, gp=0.0132, loss_critic=0.0596]   


IMAGE SIZE 8
Range 8


 43%|████▎     | 846/1987 [00:33<00:46, 24.62it/s, gp=0.00864, loss_critic=0.0877]  IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 1987/1987 [01:14<00:00, 26.54it/s, gp=0.00265, loss_critic=0.15]    


Range 8


 21%|██        | 408/1987 [00:15<01:00, 26.06it/s, gp=0.00207, loss_critic=-.0255]  IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 1987/1987 [01:13<00:00, 26.95it/s, gp=0.0065, loss_critic=0.296]    


IMAGE SIZE 16
Range 16


 28%|██▊       | 551/1987 [00:32<01:25, 16.81it/s, gp=0.00347, loss_critic=-.0948] IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 1987/1987 [01:55<00:00, 17.21it/s, gp=0.00751, loss_critic=-.623]  


Range 16


 77%|███████▋  | 1521/1987 [01:28<00:27, 16.73it/s, gp=0.00301, loss_critic=-.083]   IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 1987/1987 [01:55<00:00, 17.22it/s, gp=0.00102, loss_critic=0.0654] 


Range 16


 97%|█████████▋| 1929/1987 [01:51<00:03, 17.04it/s, gp=0.00253, loss_critic=-.377]   IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 3973/3973 [06:58<00:00,  9.49it/s, gp=0.000965, loss_critic=0.331]


Range 32


 26%|██▌       | 1030/3973 [01:48<05:10,  9.49it/s, gp=0.00153, loss_critic=0.396]  IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 91%|█████████ | 3596/3973 [06:19<00:39,  9.48it/s, gp=0.0215, loss_critic=2.11]    IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 56%|█████▌    | 2209/3973 [03:53<03:05,  9.51it/s, gp=0.01, loss_critic=-.515]     IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the clien

Range 32


 10%|▉         | 383/3973 [00:40<06:19,  9.46it/s, gp=0.0065, loss_critic=-.943]   IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 73%|███████▎  | 2908/3973 [05:07<01:51,  9.51it/s, gp=0.0106, loss_critic=-.335]    IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 3973/3973 [06:58<00:00,  9.49it/s, gp=0.00234, loss_critic=-.197]  


Range 32


 26%|██▌       | 1017/3973 [01:47<05:10,  9.51it/s, gp=0.00726, loss_critic=-.532] IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 79%|███████▉  | 3155/3973 [15:27<03:59,  3.42it/s, gp=0.007, loss_critic=-1.97]     IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 39%|███▉      | 1553/3973 [07:36<11:54,  3.39it/s, gp=0.00597, loss_critic=-1.32]   IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the clie

Range 64


 13%|█▎        | 498/3973 [02:27<17:07,  3.38it/s, gp=0.0149, loss_critic=-.715]   IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 73%|███████▎  | 2914/3973 [14:20<05:15,  3.36it/s, gp=0.00863, loss_critic=-1.39]  IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

100%|██████████| 3973/3973 [19:40<00:00,  3.37it/s, gp=0.00515, loss_critic=0.565]   


IMAGE SIZE 128
Range 128


100%|██████████| 3973/3973 [38:51<00:00,  1.70it/s, gp=0.00467, loss_critic=-3.12] 


Range 128


100%|██████████| 3973/3973 [38:52<00:00,  1.70it/s, gp=0.0111, loss_critic=-1.27]    


Range 128


 30%|███       | 1210/3973 [11:48<26:51,  1.71it/s, gp=0.017, loss_critic=-2.02]  IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 93%|█████████▎| 3678/3973 [35:53<02:53,  1.70it/s, gp=0.0144, loss_critic=-.763]   IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

 89%|████████▊ | 3522/3973 [34:17<04:23,  1.71it/s, gp=0.0203, loss_critic=-.969] 


KeyboardInterrupt: 

In [27]:
# NOTE(review): `gen` and `disc` here resolve to the module-level objects
# created in the shape sanity-check cell (Z_DIM=50, IN_CHANNELS=256).
# train_wrapper keeps its trained models in local variables and does not
# return them, so unless these names were rebound in a cell not shown, these
# two files contain the untrained check models. Confirm before relying on them.
torch.save(gen,'./gen_latest'+'.pt')
torch.save(disc,'./disc_latest'+'.pt')