In [None]:
#!pip install torchsummary
from torchsummary import summary

In [None]:
import warnings
warnings.filterwarnings('ignore')
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torchvision.utils import make_grid
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import PIL
from tqdm import tqdm
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

In [None]:
# ---- Reproducibility ----
# GAN training is stochastic (shuffling, latent noise, noisy labels); seed every
# RNG the notebook uses so that Restart & Run All gives comparable runs.
# Note: a resumed run (EPOCH_START != 0) replays the same random stream.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ---- Data ----
img_size = 256   # side length the images are resized to; the networks below are built for 256x256
n_channels = 1   # number of image channels fed to / produced by the networks

# ---- Model ----
latent_size = 100          # length of the noise vector given to the generator
batch_size = 64
step_conv_channels = 64    # base channel width; layers use multiples of it

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- Optimizer (Adam) ----
beta1 = 0.5
beta2 = 0.999

lr = 0.0002

num_workers = 4  # DataLoader worker processes

# ---- Paths / resuming ----
DATA_PATH = '...'    # placeholder: root folder passed to ImageFolder
EPOCH_START = 0      # number of epochs already trained; 0 means start from scratch
UPLOADED = False     # True when the checkpoint weights are already loaded into the models
LOAD_FILENAME_PATH_GENERATOR = ('weights/generator_epoch_%d.pth' % EPOCH_START)
LOAD_FILENAME_PATH_DISCRIMINATOR = ('weights/discriminator_epoch_%d.pth' % EPOCH_START)

In [None]:
# Display the selected torch device ('cuda' when available, otherwise 'cpu').
device

In [None]:
class Split(object):
    """Transform that keeps only channel index 1 of an image tensor.

    The selected plane (the green channel for an RGB C x H x W tensor) is
    reshaped to (n_channels, img_size, img_size) and passed through
    torchvision's Grayscale transform. The input must therefore already be
    exactly img_size x img_size, otherwise the reshape raises.
    """
    def __call__(self, image):
        to_gray = transforms.Grayscale(num_output_channels=n_channels)
        selected_plane = image[1, :, :]
        reshaped = selected_plane.view(n_channels, img_size, img_size)
        return to_gray(reshaped)

# Preprocessing pipeline: resize -> square crop -> tensor -> scale to [-1, 1] -> single channel.
# Optional augmentation (disabled): transforms.RandomHorizontalFlip(p=0.5)
dataset = ImageFolder(DATA_PATH, transform=transforms.Compose([
        transforms.Resize(img_size, interpolation=transforms.InterpolationMode.BICUBIC),
        # Resize with an int only matches the shorter edge to img_size; crop the
        # centre so every sample is exactly img_size x img_size, which Split's
        # .view() requires. This is a no-op for images that are already square.
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        # Maps each channel from [0, 1] to [-1, 1], matching the generator's Tanh output range.
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
        Split()
        ]))
dataloader = DataLoader(dataset, batch_size, shuffle=True, num_workers=num_workers)

In [None]:
# Pull one batch from the loader (labels are ignored) and preview it as a grid.
batch, _ = next(iter(dataloader))

# make_grid tiles the batch into one image; normalize=True rescales it for display.
sample_grid = make_grid(batch.to(device), padding=2, normalize=True).cpu()

plt.figure(figsize=(12, 12))
plt.axis("off")
plt.title("Training Images")
# Reorder from (C, H, W) to (H, W, C), the layout imshow expects.
plt.imshow(np.transpose(sample_grid, (1, 2, 0)))

In [None]:
class Generator(nn.Module):
    """Generator network: latent vector -> (n_channels, 256, 256) image in [-1, 1].

    A linear layer projects the latent vector to a 4x4 feature map with
    step_conv_channels*16 channels. Six Conv -> BatchNorm -> ReLU -> 2x Upsample
    stages then grow it to 256x256 while narrowing the channels, and a final
    Conv + Tanh produces the image.
    """
    def __init__(self):
        super().__init__()
        base = step_conv_channels
        widest = base * 16

        # Projection to the initial 4x4 feature map (reshaped in forward()).
        self.l1 = nn.Sequential(nn.Linear(latent_size, widest * 4 * 4))

        # Channel width entering each stage, followed by the width leaving the last one.
        # Spatial size after each stage: 8, 16, 32, 64, 128, 256.
        widths = [widest, widest, widest, base * 8, base * 4, base * 2, base]

        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Conv2d(c_in, c_out, 3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.Upsample(scale_factor=2),
            ])
        # Output head: collapse to the image channels and squash to [-1, 1].
        layers.extend([
            nn.Conv2d(base, n_channels, 3, stride=1, padding=1),
            nn.Tanh(),
        ])
        self.conv_blocks = nn.Sequential(*layers)

    def forward(self, z):
        """Generate images from latent input `z`; any trailing dims are flattened per sample."""
        n_samples = z.shape[0]
        flat_latent = z.view(n_samples, -1)
        projected = self.l1(flat_latent)
        feature_map = projected.view(projected.shape[0], step_conv_channels * 16, 4, 4)
        return self.conv_blocks(feature_map)

In [None]:
# Build a throwaway generator on the CPU and print its layer-by-layer summary.
g = Generator()
summary(g, (latent_size, 1), device='cpu')

In [None]:
class Discriminator(nn.Module):
    """Discriminator network: (n_channels, 256, 256) image -> probability in (0, 1).

    Six Conv -> LeakyReLU -> 2x AvgPool stages halve the resolution down to 4x4
    while widening the channels. Every stage except the first ends with a
    BatchNorm. The flattened 4x4 map feeds a single Linear + Sigmoid output.
    """
    def __init__(self):
        super().__init__()
        base = step_conv_channels

        # Channel width entering each stage, followed by the width leaving the last one.
        # Spatial size after each stage: 128, 64, 32, 16, 8, 4.
        widths = [n_channels, base, base * 2, base * 4, base * 8, base * 16, base * 16]

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.extend([
                nn.Conv2d(c_in, c_out, 3, 1, 1, bias=True),
                nn.LeakyReLU(0.2, inplace=True),
                nn.AvgPool2d(2),
            ])
            # The first stage has no normalisation; all later ones do.
            if stage > 0:
                layers.append(nn.BatchNorm2d(c_out))
        self.main = nn.Sequential(*layers)

        self.adv_layer = nn.Sequential(nn.Linear(base * 16 * 4 * 4, 1), nn.Sigmoid())

    def forward(self, img):
        """Return one real/fake score per image, shape (batch, 1)."""
        features = self.main(img)
        flat = features.view(features.shape[0], -1)
        return self.adv_layer(flat)

In [None]:
# Build a throwaway discriminator on the CPU and print its layer-by-layer summary.
d = Discriminator()
summary(d, (1, img_size, img_size), device='cpu')

In [None]:
def fit(model, criterion, epochs, lr, epochs_start=0, uploaded=False):
    """Train the generator/discriminator pair and return the collected logs.

    For every batch the discriminator is updated first (real batch + freshly
    generated batch), then the generator is updated on a new generated batch.

    Args:
        model: dict with keys "discriminator" and "generator" holding the networks.
        criterion: dict with the same keys; each value is called as
            criterion[key](predictions, targets).
        epochs: number of epochs to run in this call.
        lr: learning rate used for both Adam optimizers.
        epochs_start: number of epochs already trained. Offsets the printed epoch
            numbers and checkpoint file names. When non-zero (and `uploaded` is
            False) the weights are first loaded from LOAD_FILENAME_PATH_*.
        uploaded: pass True when the checkpoint weights are already loaded, to
            skip loading them again.

    Returns:
        A list of eight lists, in this order:
        [losses_g, loss_g_per_batch, losses_d, loss_d_per_batch,
         real_scores, real_score_per_batch, fake_scores, fake_score_per_batch]
        where the un-suffixed entries are per-epoch means.

    Side effects:
        Prints a log line and shows a grid of generated images after every
        epoch. Every 10th (global) epoch both state dicts are saved to the
        current working directory.
    """
    disc = model["discriminator"]
    gen = model["generator"]

    if epochs_start != 0 and not uploaded:
        # map_location lets a checkpoint written on a GPU machine be resumed on
        # a CPU-only one (and vice versa).
        disc.load_state_dict(torch.load(LOAD_FILENAME_PATH_DISCRIMINATOR, map_location=device))
        gen.load_state_dict(torch.load(LOAD_FILENAME_PATH_GENERATOR, map_location=device))
        print('Model uploaded')

    disc.to(device)
    gen.to(device)
    disc.train()
    gen.train()
    torch.cuda.empty_cache()

    # Per-epoch means.
    losses_g, losses_d, real_scores, fake_scores = [], [], [], []

    # Raw per-batch values over the whole run.
    loss_g_per_batch, loss_d_per_batch, real_score_per_batch, fake_score_per_batch = [], [], [], []

    optimizer = {
        "discriminator": torch.optim.Adam(disc.parameters(),
                                          lr=lr, betas=(beta1, beta2)),
        "generator": torch.optim.Adam(gen.parameters(),
                                      lr=lr, betas=(beta1, beta2))
    }

    for epoch in tqdm(range(epochs)):
        loss_d_per_epoch = []
        loss_g_per_epoch = []
        real_score_per_epoch = []
        fake_score_per_epoch = []
        for real_images, _ in dataloader:
            real_images = real_images.to(device)
            n_samples = real_images.size(0)

            # ---------------- discriminator step ----------------
            optimizer["discriminator"].zero_grad()

            # Real images; targets are noisy "real" labels in [0.95, 1.0],
            # created directly on the target device.
            real_preds = disc(real_images)
            real_targets = torch.empty(n_samples, 1, device=device).uniform_(0.95, 1.0)
            real_loss = criterion["discriminator"](real_preds, real_targets)

            # Generated images; targets are noisy "fake" labels in [0.0, 0.05].
            latent = torch.randn(n_samples, latent_size, device=device)
            fake_images = gen(latent)
            fake_targets = torch.empty(n_samples, 1, device=device).uniform_(0.0, 0.05)
            # detach(): the discriminator update must not backpropagate into the
            # generator. Those gradients would be zeroed before the generator
            # step anyway, so computing them is wasted work.
            fake_preds = disc(fake_images.detach())
            fake_loss = criterion["discriminator"](fake_preds, fake_targets)

            # Mean discriminator outputs on real / fake batches.
            cur_real_score = torch.mean(real_preds).item()
            cur_fake_score = torch.mean(fake_preds).item()
            real_score_per_epoch.append(cur_real_score)
            real_score_per_batch.append(cur_real_score)
            fake_score_per_epoch.append(cur_fake_score)
            fake_score_per_batch.append(cur_fake_score)

            loss_d = real_loss + fake_loss
            loss_d.backward()
            optimizer["discriminator"].step()

            loss_d_per_epoch.append(loss_d.item())
            loss_d_per_batch.append(loss_d.item())

            # ------------------ generator step ------------------
            optimizer["generator"].zero_grad()

            latent = torch.randn(n_samples, latent_size, device=device)
            fake_images = gen(latent)

            # The generator is rewarded when the discriminator labels its
            # samples as real, hence the "real" target range here.
            preds = disc(fake_images)
            targets = torch.empty(n_samples, 1, device=device).uniform_(0.95, 1.0)
            loss_g = criterion["generator"](preds, targets)

            loss_g.backward()
            optimizer["generator"].step()

            loss_g_per_epoch.append(loss_g.item())
            loss_g_per_batch.append(loss_g.item())

        # Epoch-level aggregates.
        losses_g.append(np.mean(loss_g_per_epoch))
        losses_d.append(np.mean(loss_d_per_epoch))
        real_scores.append(np.mean(real_score_per_epoch))
        fake_scores.append(np.mean(fake_score_per_epoch))

        global_epoch = epoch + 1 + epochs_start
        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            global_epoch, epochs + epochs_start,
            losses_g[-1], losses_d[-1], real_scores[-1], fake_scores[-1]))

        # Preview: first 8 images of the last generated batch of this epoch.
        plt.figure(figsize=(12,12))
        plt.axis("off")
        plt.title("Generated Images")
        preview = make_grid(fake_images.detach()[:8], padding=2, normalize=True).cpu()
        plt.imshow(np.transpose(preview, (1, 2, 0)))
        plt.show()

        # NOTE(review): checkpoints are written to the working directory, while
        # resuming reads LOAD_FILENAME_PATH_* (under 'weights/'); files must be
        # moved there before resuming.
        if global_epoch % 10 == 0:
            torch.save(gen.state_dict(), 'generator_epoch_%d.pth' % global_epoch)
            torch.save(disc.state_dict(), 'discriminator_epoch_%d.pth' % global_epoch)
            print('Model Saved! Epoch: %d' % global_epoch)

    return [losses_g, loss_g_per_batch, losses_d, loss_d_per_batch, real_scores, real_score_per_batch, fake_scores, fake_score_per_batch]

In [None]:
# custom weights initialization called on netG and netD
def weights_init(m):
    """Re-initialise a single module in place; meant for use with `Module.apply`.

    Modules whose class name contains 'Conv' get weights drawn from N(0, 0.02).
    Modules whose class name contains 'BatchNorm' get weights drawn from
    N(1, 0.02) and biases set to 0. Every other module is left untouched.
    """
    layer_type = m.__class__.__name__
    if 'Conv' in layer_type:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_type:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
# Build both networks on the selected device and re-initialise them with
# weights_init: Conv weights ~ N(0, 0.02), BatchNorm weights ~ N(1, 0.02),
# BatchNorm biases = 0. Module.apply returns the module, so the calls chain.
netD = Discriminator().to(device).apply(weights_init)
netG = Generator().to(device).apply(weights_init)

# fit() looks the networks and the losses up by these two keys.
model = {"discriminator": netD, "generator": netG}

# Both players use binary cross-entropy on the discriminator's sigmoid output.
criterion = {"discriminator": nn.BCELoss(), "generator": nn.BCELoss()}

In [None]:
# Training-run settings passed to fit().
epochs = 70     # number of epochs to train in this run
lr = 0.0002     # learning rate for both optimizers

In [None]:
# Run training; `logs` holds the eight per-epoch / per-batch series returned by fit().
logs = fit(model, criterion, epochs, lr, epochs_start=EPOCH_START, uploaded=UPLOADED)

In [None]:
# One output file per series, in the same order as the list returned by fit().
txts = [
    'losses_g.txt', 'loss_g_per_batch.txt',
    'losses_d.txt', 'loss_d_per_batch.txt',
    'real_scores.txt', 'real_score_per_batch.txt',
    'fake_scores.txt', 'fake_score_per_batch.txt',
]

# Each file receives its series as space-separated values (with a trailing space).
for idx, filename in enumerate(txts):
    with open(filename, 'w') as out_file:
        out_file.writelines(str(value) + ' ' for value in logs[idx])

In [None]:
# Sample a batch of images from the trained generator; no gradients are needed.
# The latent has trailing 1x1 dims, which the generator flattens in forward().
with torch.no_grad():
    latent = torch.randn(batch_size, latent_size, 1, 1, device=device)
    fake_images = model["generator"](latent)

# Tile the samples into one image and rescale it for display.
generated_grid = make_grid(fake_images.to(device), padding=2, normalize=True).cpu()

plt.figure(figsize=(12,12))
plt.axis("off")
plt.title("Generated Images")
# Reorder from (C, H, W) to (H, W, C), the layout imshow expects.
plt.imshow(np.transpose(generated_grid, (1, 2, 0)))
plt.show()