# Light GAN 1024

## Import

In [1]:
import argparse
import os
import numpy as np
import math
import cv2 as cv

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torch.autograd import Variable
from torch.cuda.amp import autocast, GradScaler

import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision

## Hyperparameters

In [2]:
n_epochs = 100 # type=int, "number of epochs of training"
batch_size = 10 # type=int, "size of the batches"

lr = 0.0005 # type=float "adam: learning rate"
b1 = 0.9 # type=float "adam: decay of first order momentum of gradient"
b2 = 0.999 # type=float "adam: decay of first order momentum of gradient"

num_gpu = 2 
cuda = torch.cuda.is_available()

latent_dim = 4 # type=int "dimensionality of the latent space"
img_size = 1024 # type=int "size of each image dimension"
channels = 1 # type=int "number of image channels"
sample_interval = 10000 # int "interval betwen image samples"

dataset_dir = r"C:\Users\Leo's PC\Documents\SSTP Tests\stylegan2-ada-pytorch\Font1024"

## Datasets

In [3]:
class Dataset(Dataset):

    def __init__(self, file_dir, transform=None):

        self.dir = file_dir
        self.transform = transform
        self.diction = {}
        
        idx = 0
        for filename in os.listdir(self.dir):
            if filename.endswith('png'):
                self.diction[idx] = filename
                idx += 1
                        
    def __len__(self):
        return len(self.diction)

    
    def __getitem__(self, idx):
        img_name = self.diction[idx]
        directory = self.dir + "\\" + str(img_name)
        image = cv.imread(directory, cv.IMREAD_GRAYSCALE)
        if self.transform:
            image = self.transform(image)
        return image
    

dataset = Dataset(file_dir=dataset_dir)

## Dataloaders

In [4]:
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)

## Model classes

In [5]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        #activation functions
        self.leakyrelu = nn.LeakyReLU(0.2)
        self.tanh = nn.Tanh()
        
        #upsampler
        self.upsamplerx4 = nn.Upsample(scale_factor=4)
        self.upsamplerx2 = nn.Upsample(scale_factor=2)
        self.pool = nn.AdaptiveMaxPool2d(output_size = 1024)
        
        #L1
        self.conv1 = torch.nn.ConvTranspose2d(in_channels=512, out_channels=512, kernel_size=1, stride=1, padding=0, bias=True)
        self.norm1 = nn.BatchNorm2d(512)
        
        #L2
        self.conv2 = torch.nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=5, stride=2, padding=2, bias=True)
        self.norm2 = nn.BatchNorm2d(256)
        
        #L3
        self.conv3 = torch.nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=5, stride=2, padding=2, bias=True)
        self.norm3 = nn.BatchNorm2d(128)
        
        #L4
        self.conv4 = torch.nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=5, stride=2, padding=2, bias=True)
        self.norm4 = nn.BatchNorm2d(64)
        
        #L5
        self.conv5 = torch.nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=7, stride=2, padding=2, bias=True)
        self.norm5 = nn.BatchNorm2d(32)
        
        #L6
        self.conv6 = torch.nn.ConvTranspose2d(in_channels=32, out_channels=channels, kernel_size=7, stride=2, padding=1, bias=True)
        self.norm6 = nn.BatchNorm2d(channels)


    @autocast()
    def forward(self, x):
        
        #L1
        x = self.conv1(x)
        x = self.upsamplerx2(x)
        x = self.norm1(x)
        x = self.leakyrelu(x)
        
        #print(x.shape)
        
        #L2
        x = self.conv2(x)
        x = self.upsamplerx2(x)
        x = self.norm2(x)
        x = self.leakyrelu(x)
        
        #print(x.shape)
        
        #L3
        x = self.conv3(x)
        x = self.upsamplerx2(x)
        x = self.norm3(x)
        x = self.leakyrelu(x)
        
        #print(x.shape)
        
        #L4
        x = self.conv4(x)
        x = self.norm4(x)
        x = self.leakyrelu(x)
        
        #print(x.shape)
        
        #L5
        x = self.conv5(x)
        x = self.norm5(x)
        x = self.leakyrelu(x)
        
        #print(x.shape)
        
        #L6
        x = self.conv6(x)
        #x = self.pool(x)
        x = self.norm6(x)
        x = self.tanh(x)
        
        return x
    
    
    def name(self):
        return "Generator"


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        #activation functions
        self.leakyrelu = nn.LeakyReLU(0.2)
        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=1)
        
        #L1
        self.conv1 = nn.Conv2d(in_channels=channels, out_channels=32, kernel_size=3, stride=1, padding=1, bias=True)
        self.norm1 = nn.BatchNorm2d(32)
        self.pool1 = nn.AdaptiveMaxPool2d(output_size=512)
        
        #L2
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=True)
        self.norm2 = nn.BatchNorm2d(64)
        self.pool2 = nn.AdaptiveMaxPool2d(output_size=256)
        
        #L3
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=True)
        self.norm3 = nn.BatchNorm2d(128)
        self.pool3 = nn.AdaptiveMaxPool2d(output_size=128)
        
        #L4
        self.conv4 = torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=True)
        self.norm4 = nn.BatchNorm2d(256)
        self.pool4 = nn.AdaptiveMaxPool2d(output_size=64)
        
        #L5
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=True)
        self.norm5 = nn.BatchNorm2d(512)
        self.pool5 = nn.AdaptiveMaxPool2d(output_size = 32)
        
        #L6
        self.conv6 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1, bias=True)
        self.norm6 = nn.BatchNorm2d(1024)
        self.pool6 = nn.AdaptiveMaxPool2d(output_size=1)
        
        #L7
        self.fc1 = nn.Linear(in_features=1024, out_features=512, bias=True)
        self.norm7 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(p=0.5)
        
        #L8
        self.fc2 = nn.Linear(in_features=512, out_features=2, bias=True)


    @autocast()
    def forward(self, x):
        #L1
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.leakyrelu(x)
        x = self.pool1(x)
        
        #print(x.shape)
        
        #L2
        x = self.conv2(x)
        x = self.norm2(x)
        x = self.leakyrelu(x)
        x = self.pool2(x)
        
        #print(x.shape)
        
        #L3
        x = self.conv3(x)
        x = self.norm3(x)
        x = self.leakyrelu(x)
        x = self.pool3(x)
        
        #print(x.shape)
        
        #L4
        x = self.conv4(x)
        x = self.norm4(x)
        x = self.leakyrelu(x)
        x = self.pool4(x)
        
        #print(x.shape)
        
        #L5
        x = self.conv5(x)
        x = self.norm5(x)
        x = self.leakyrelu(x)
        x = self.pool5(x)
        
        #print(x.shape)
        
        #L6
        x = self.conv6(x)
        x = self.norm6(x)
        x = self.leakyrelu(x)
        x = self.pool6(x)
        
        x = x.view(x.shape[0], -1)
         
        #print(x.shape)
        
        #L7
        x = self.fc1(x)
        x = self.norm7(x)
        x = self.dropout1(x)
        x = self.sigmoid(x)
        
        #print(x.shape)
        
        #L8
        x = self.fc2(x)
        x = self.softmax(x)
        
        #print(x.shape)

        return x
    
    
    def name(self):
        return "Discriminator"
    
    
class Discriminator_Res(nn.Module):
    def __init__(self):
        super(Discriminator_Res, self).__init__()
        
        self.prepool = nn.AdaptiveAveragePool2d(512, 512)
        self.ResNet = torchvision.models.resnet18(pretrained=True)
        self.ResNet.fc = nn.Linear(in_features=512, out_features=1, bias=True)

    @autocast()
    def forward(self, x):
        
        x = self.ResNet(x)

        return x
    
    
    def name(self):
        return "Discriminator_Res"

## Loss, Optimizer, Training setup

In [6]:
# Loss function
adversarial_loss = torch.nn.BCEWithLogitsLoss()


# Initialize generator and discriminator
G = Generator()
D = Discriminator()


def init_weights(m):
    if type(m) == nn.Linear or type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

G.apply(init_weights)
D.apply(init_weights)


device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")


G.cuda()
D.cuda()
adversarial_loss.cuda()


G = torch.nn.DataParallel(G)
D = torch.nn.DataParallel(D)


optimizer_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(b1, b2))

scaler = GradScaler()

In [None]:
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


for epoch in range(n_epochs):
    
    for idx, imgs in enumerate(loader):

        # Adversarial ground truths
        valid = Variable(Tensor(imgs.shape[0], 2).fill_(1.0), requires_grad=False).cuda()
        fake = Variable(Tensor(imgs.shape[0], 2).fill_(0.0), requires_grad=False).cuda()

        # Configure input
        real_imgs = Variable(imgs.type(Tensor)).cuda().to(device)

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        latent_vector = Variable(Tensor(np.random.randn(imgs.shape[0], 512, latent_dim, latent_dim))).cuda()
        
        G.train()
        D.eval()
        
        with autocast():
            gen_imgs = G(latent_vector) # Generate a batch of images
            g_loss = adversarial_loss(D(gen_imgs), valid) # Loss measures generator's ability to fool the discriminator

        scaler.scale(g_loss).backward() #back propagation with calculated loss
        scaler.step(optimizer_G) 
        scaler.update()
        
        g_loss_avg = g_loss.item() if idx==0 else (g_loss_avg + g_loss.item()) / 2
        
        # ---------------------
        #  Train Discriminator
        # ---------------------
        
        D.train()
        optimizer_D.zero_grad()
        
        real_imgs.unsqueeze_(1)
        
        with autocast():
            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(D(real_imgs), valid)
            fake_loss = adversarial_loss(D(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

        scaler.scale(d_loss).backward() #back propagation with calculated loss
        scaler.step(optimizer_D) 
        scaler.update()

       
        batches_done = epoch * len(loader) + idx
        
        d_loss_avg = d_loss.item() if idx==0 else (d_loss_avg + d_loss.item()) / 2
        
    save_image(gen_imgs.data[:25], r"C:/Users/Leo's PC/Documents/SSTP Tests/Chinese Characters/LightGAN out/%d.png" % batches_done, nrow=5, normalize=True)
    print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, n_epochs, idx, len(loader), d_loss_avg, g_loss_avg))



[Epoch 0/100] [Batch 2850/2851] [D loss: 0.724098] [G loss: 0.474077]
[Epoch 1/100] [Batch 2850/2851] [D loss: 0.724077] [G loss: 0.474077]


In [None]:

checkpoint_file = open(r"C:/Users/Leo's PC/Documents/SSTP Tests/Chinese Characters/LightGAN out/G.tar", 'wb')
torch.save({'model': G.state_dict()}, checkpoint_file)
checkpoint_file.close()
'''

checkpoint = torch.load(open("C:/Users/Leo's PC/Documents/SSTP Tests/Chinese Characters/LightGAN out/G.tar", 'rb'))
G.load_state_dict(checkpoint['model'])

'''