In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchsummary import summary
import os
from PIL import Image

In [2]:
# Select the compute device once: the GPU when CUDA is available, otherwise the CPU.
# Models and tensors created in later cells are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
class Dataset(torch.utils.data.Dataset):
    """In-memory image dataset loaded from a single file.

    After one of the loader methods has run, ``self.images`` is a float32
    tensor of shape (N, C, H, W) scaled to [-1, 1], and indexing the dataset
    returns one (C, H, W) image.

    Args:
        path: File to read (``.npy`` array, or a ``torch.save``-d tensor for
            ``load_images_data(numpy=False)``).
        image_size: Size handed to ``PIL.Image.resize`` by
            ``create_images_data``. PIL interprets it as (width, height).
    """

    def __init__(self, path, image_size):

        self.path = path
        self.image_size = image_size

        # Filled in by create_images_data() / load_images_data().
        self.images = None

    @staticmethod
    def _to_normalized_tensor(images):
        """Turn a 0..255 (N, H, W, C) array into a [-1, 1] (N, C, H, W) float32 tensor."""

        images = images.astype('float32') / 127.5 - 1.0
        tensor = torch.from_numpy(images)

        # permute, not view: view() would only relabel the axes of the same
        # memory, mixing pixels of different channels and positions together.
        # permute() actually moves the channel axis to the front.
        # Code that converts these tensors (or the output of a model trained
        # on them) back to HWC for display must use permute(0, 2, 3, 1).
        return tensor.permute(0, 3, 1, 2).contiguous()

    def create_images_data(self):
        """Load the raw array, resize every image to ``image_size`` and store the result.

        The arithmetic below maps [-1, 1] onto [0, 255], so the file is
        presumably expected to hold float images in [-1, 1] laid out as
        (N, H, W, C) -- confirm against whatever produced the file.
        """

        data = np.load(self.path)

        # [-1, 1] -> [0, 255]. Clip first so slightly out-of-range values
        # saturate instead of wrapping around in the uint8 cast.
        data = np.clip((data + 1.0) * 0.5 * 255.0, 0.0, 255.0).astype('uint8')

        resized = []

        for frame in data:

            image = Image.fromarray(frame).resize(self.image_size)
            resized.append(np.array(image))
            image.close()

        self.images = self._to_normalized_tensor(np.stack(resized, 0))

        print(f"Images data created with size {self.images.size()} and ready to go!")

    def load_images_data(self, numpy=True):
        """Load already-resized images.

        Args:
            numpy: When True, ``path`` is a ``.npy`` array in the 0..255 range
                laid out as (N, H, W, C); it is normalised to [-1, 1] and made
                channel-first. When False, ``path`` is read with ``torch.load``
                and stored as-is.
        """

        if numpy:

            self.images = self._to_normalized_tensor(np.load(self.path))

            print(f"Images data loaded with size {self.images.size()} and ready to go!")

        else:

            self.images = torch.load(self.path)

            print(f"Images data loaded with size {self.images.size()} and ready to go!")

    def __getitem__(self, idx):

        return self.images[idx]

    def __len__(self):

        return len(self.images)

In [21]:
# Build the dataset wrapper; nothing is read from disk until a loader method is called.
# (8, 8) is the image_size used by create_images_data when resizing.
# NOTE(review): absolute, machine-specific path -- consider a configurable data directory.
data = Dataset('D:/Python/Projects/filtered_images.npy', (8,8))

In [22]:
# Read the .npy file, resize every image and store the normalised tensor in data.images.
data.create_images_data()

Images data created with size torch.Size([6854, 3, 8, 8]) and ready to go!


In [48]:
def transconv2out(input, kernel, stride, padding, dilation=1, output_padding=0):
    """Spatial output size of a transposed convolution along one dimension.

    Uses the formula documented for ``nn.ConvTranspose2d``:
    ``(input - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1``.

    Args:
        input: Input size along the dimension.
        kernel: Kernel size.
        stride: Stride.
        padding: Padding.
        dilation: Kernel dilation (default 1, matching the previous behaviour).
        output_padding: Extra size added to one side of the output (default 0).

    Returns:
        The output size along that dimension.
    """
    return (input - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

def conv2out(input, kernel, stride, padding, dilation=1):
    """Spatial output size of a convolution along one dimension.

    Uses the formula documented for ``nn.Conv2d``:
    ``floor((input + 2 * padding - dilation * (kernel - 1) - 1) / stride) + 1``.

    Floor division is used so the result is an integer and matches the layer's
    real output size even when the stride does not divide evenly.

    Args:
        input: Input size along the dimension.
        kernel: Kernel size.
        stride: Stride.
        padding: Padding.
        dilation: Kernel dilation (default 1, matching the previous behaviour).

    Returns:
        The output size along that dimension.
    """
    return (input + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

In [20]:
class Generator(nn.Module):
    """Level-2 generator: maps a (N, 100, 1, 1) latent to a (N, 3, 8, 8) image in [-1, 1].

    Layer names are part of the checkpoint format (state-dict keys), so they
    are kept stable across levels.
    """

    def __init__(self):

        super().__init__()

        # Level 1 (4x4 images) used a single layer:
        #   self.transconv1 = nn.ConvTranspose2d(100, 3, 4, 1, 0, bias=False)

        # Level 2 (8x8 images): 1x1 -> 4x4 -> 8x8.
        self.transconv1 = nn.ConvTranspose2d(
            in_channels=100, out_channels=50, kernel_size=4, stride=1, padding=0, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(num_features=50)
        self.transconv2 = nn.ConvTranspose2d(
            in_channels=50, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False)
        # Further levels append more blocks here.

        self.PRelu = nn.PReLU()
        # Images live in [-1, 1]. Consider removing if vanishing gradients become a problem.
        self.tanh = nn.Tanh()

    # Note: LSTMs were considered (see Prototype.py) but dropped in favour of attention layers.
    # Note 2: LeakyReLU(0.2) as in NVidia's work, or PReLU as in SRGAN, are the candidate activations.

    def forward(self, input):
        """Run the latent batch through both levels and squash to [-1, 1]."""
        # Level 1: first transposed convolution only.
        features = self.transconv1(input)
        # Level 2: normalise, activate, upsample to the final resolution.
        features = self.PRelu(self.batchnorm1(features))
        return self.tanh(self.transconv2(features))

In [23]:
# Instantiate the generator and move its parameters to the selected device.
netG = Generator().to(device)

In [8]:
def weights_init(net, level):
    """Re-initialise the parameters that belong to one growth level.

    Only parameters whose name contains ``str(level)`` are touched, so layers
    carried over from earlier levels keep their weights.

    * names containing ``conv``: normal(mean=0, std=0.02)
    * ``batchnorm`` weights:     normal(mean=1, std=0.02)
    * ``batchnorm`` biases:      constant 0

    Args:
        net: Module whose ``named_parameters()`` are inspected.
        level: Level number that must appear in the parameter name.
    """
    tag = str(level)

    for name, param in net.named_parameters():

        # Each substring needs its own `in` test: `'conv' and tag in name`
        # only checks `tag in name`, which sent batch-norm parameters down
        # the conv branch and left the other branches unreachable.
        if tag not in name:
            continue

        if 'batchnorm' in name:

            if 'weight' in name:
                nn.init.normal_(param, 1, 0.02)

            elif 'bias' in name:
                nn.init.constant_(param, 0.)

        elif 'conv' in name:

            nn.init.normal_(param, 0, 0.02)

In [26]:
# Re-initialise only the generator parameters whose name contains "2" (the layers added at level 2).
weights_init(netG, 2)

In [27]:
# Print the model, then a per-layer summary for a single (100, 1, 1) latent input.
print(netG)
summary(netG, (100, 1,1))

Generator(
  (transconv1): ConvTranspose2d(100, 50, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (batchnorm1): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (PRelu): PReLU(num_parameters=1)
  (transconv2): ConvTranspose2d(50, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (tanh): Tanh()
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1             [-1, 50, 4, 4]          80,000
       BatchNorm2d-2             [-1, 50, 4, 4]             100
             PReLU-3             [-1, 50, 4, 4]               1
   ConvTranspose2d-4              [-1, 3, 8, 8]           2,400
              Tanh-5              [-1, 3, 8, 8]               0
Total params: 82,501
Trainable params: 82,501
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Param

In [28]:
class Discriminator(nn.Module):
    """Level-2 discriminator: maps a (N, 3, 8, 8) image to one raw logit of shape (N, 1, 1, 1).

    No sigmoid is applied; the logit is meant for ``BCEWithLogitsLoss``.
    Layer names are part of the checkpoint format (state-dict keys).
    """

    def __init__(self):
        super().__init__()

        # Level 1 used a single layer:
        #   self.conv1 = nn.Conv2d(3, 1, 4, 1, 0, bias=False)

        # Level 2: 8x8 -> (pool) 4x4 -> 1x1.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=100, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv2 = nn.Conv2d(in_channels=100, out_channels=1, kernel_size=4, stride=1, padding=0, bias=False)

        # Planned deeper levels (not active yet):
        #   self.batchnorm2 = nn.BatchNorm2d(ndf * 2)
        #   self.conv3 = nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False)
        #   self.batchnorm3 = nn.BatchNorm2d(ndf * 4)
        #   self.conv4 = nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False)
        #   self.batchnorm4 = nn.BatchNorm2d(ndf * 8)
        #   self.conv5 = nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False)
        # self.sigmoid = nn.Sigmoid() ---> already included in BCEWithLogits (in log version)

        self.pool2x2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.LRelu = nn.LeakyReLU(negative_slope=0.2)
        self.dropout = nn.Dropout(p=0.4)

    def forward(self, input):
        """Return the raw real/fake logit for every image in the batch."""
        # Level 2: conv -> dropout -> 2x2 average pool -> LeakyReLU -> conv.
        features = self.dropout(self.conv1(input))
        features = self.LRelu(self.pool2x2(features))
        logits = self.conv2(features)

        # Planned deeper levels repeat, per extra conv:
        #   (optional) x = torch.randn(x.size()).to(device) + x
        #   batchnorm -> LeakyReLU -> dropout -> conv   (conv3, conv4, conv5)

        return logits

In [29]:
# Instantiate the discriminator and move its parameters to the selected device.
netD = Discriminator().to(device)

In [30]:
# Re-initialise only the discriminator parameters whose name contains "2" (the layer added at level 2).
weights_init(netD, 2)

In [32]:
# Print the model, then a per-layer summary for a single (3, 8, 8) image input.
print(netD)
summary(netD, (3, 8, 8))

Discriminator(
  (conv1): Conv2d(3, 100, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (conv2): Conv2d(100, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (pool2x2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (LRelu): LeakyReLU(negative_slope=0.2)
  (dropout): Dropout(p=0.4, inplace=False)
)
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 100, 8, 8]           2,700
           Dropout-2            [-1, 100, 8, 8]               0
         AvgPool2d-3            [-1, 100, 4, 4]               0
         LeakyReLU-4            [-1, 100, 4, 4]               0
            Conv2d-5              [-1, 1, 1, 1]           1,600
Total params: 4,300
Trainable params: 4,300
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.12
Params size (MB): 0.02
Estimated Tot

In [33]:
# From level 2 and beyond: load weights from previous level into new Generator. Manipulating the tensor shape is necessary.
# The checkpoint is a dict; later cells read its 'generator_params' and 'discriminator_params' entries.
# NOTE(review): no map_location is given, so tensors are restored to the device they were saved from.

previous_models = torch.load('Saves/default_Cocogoat.tar')

In [39]:
# State dicts saved at the previous level.
previous_generator = previous_models['generator_params']
previous_discriminator = previous_models['discriminator_params']

# State dicts of the freshly built level-2 networks, for shape comparison.
current_generator = netG.state_dict()
current_discriminator = netD.state_dict()

In [44]:
# List the parameter names stored in the previous-level checkpoint (generator first, then discriminator).
for layer in previous_generator.keys():
    print(layer)

for layer in previous_discriminator.keys():
    print(layer)

transconv1.weight
conv1.weight


In [81]:
# From level 2 and beyond: load weights from previous level into new Generator. Manipulating the tensor shape is necessary.
# Exploratory cell: it only builds and prints the reshaped tensor, nothing is written into netG here.

previous_generator = previous_models['generator_params']
previous_discriminator = previous_models['discriminator_params']

#print(previous_generator['transconv1.weight'].size()) # (100, 3, 4, 4)
#print(current_generator['transconv1.weight'].size()) # (100, 50, 4, 4)

weights = previous_generator['transconv1.weight']

# Repeat the 3 previous output channels 16 times along dim 1 -> 48 channels.
weights = torch.cat([weights]*16, dim=1)

print(weights.size()) # (100, 48, 4, 4)

# Pad the remaining 2 channels with zeros to reach the 50 channels of the new layer.
zeros = torch.zeros(weights.size(0), 2, weights.size(2), weights.size(3)).to(device)
weights = torch.cat((weights, zeros), dim=1)
print(weights.size()) # (100, 50, 4, 4)

torch.Size([100, 48, 4, 4])
torch.Size([100, 50, 4, 4])


In [76]:
# Applying weights from previous level. Again, just for visualization. The real thing will happen inside the training function.

#print(previous_discriminator['conv1.weight'].size()) # (1, 3, 4, 4)
#print(current_discriminator['conv1.weight'].size()) # (100, 3, 3, 3)

# Shapes kept for reference; they are not used further down in this cell.
desired_shape = current_discriminator['conv1.weight'].size()
previous_shape = previous_discriminator['conv1.weight'].size()

weights = previous_discriminator['conv1.weight']

# Repeat the single previous filter 100 times along dim 0 (output channels).
weights = torch.cat([weights]*100, 0)

print(weights.size()) # (100, 3, 4, 4)

# Resample each 4x4 kernel to the 3x3 kernel size of the new first layer.
upsampler = torch.nn.Upsample((3,3))
weights = upsampler(weights)
print(weights.size()) # (100, 3, 3, 3)

torch.Size([100, 3, 4, 4])
torch.Size([100, 3, 3, 3])


In [16]:
# Label convention used during training: 0.9 for real and 0. for fake instead of 1 and 0
# ---> one-sided label smoothing.
# https://arxiv.org/pdf/1606.03498.pdf
real_label = 0.9 
fake_label = 0.

In [84]:
# Setup Adam optimizers for both G and D ---> NVidia Progressive Grow: Same optimizer parameters with Adam - lr = 0.001, b1 = 0, b2 = 0.99
optimizerD = optim.Adam(netD.parameters(), lr=1e-3, betas=(0, 0.99)) # Consider changing learning rate or even the optimizer
optimizerG = optim.Adam(netG.parameters(), lr=1e-3, betas=(0, 0.99)) # If learning rate is too aggressive --> might fail to converge or might collapse

# Setting up schedulers - Maybe it's better if we use the same decay and steps for both generator and discriminator
# StepLR multiplies the learning rate by gamma=0.1 every 10000 scheduler steps.
# Each scheduler is bound to the optimizer created above, so re-creating an optimizer requires re-creating its scheduler.
schedulerD = optim.lr_scheduler.StepLR(optimizerD, 10000, gamma=0.1)
schedulerG = optim.lr_scheduler.StepLR(optimizerG, 10000, gamma=0.1)

In [None]:
# Lists to keep track of progress
# NOTE(review): train() does not append to these lists in the code shown here.

costsD = []
costsG = []
#content_losses = [] # I really liked the idea in SRGAN of using content loss. We could use something like this to improve diversity somehow.
#adversarial_losses = []

In [85]:
def _fit_channels(weights, size, dim):
    """Tile ``weights`` along ``dim`` as many whole times as fit in ``size`` and zero-pad the rest."""
    repeats = size // weights.size(dim)

    if repeats == 0:
        # The previous layer is wider than the new one: keep the leading slices only.
        return weights.narrow(dim, 0, size)

    tiled = torch.cat([weights] * repeats, dim=dim)
    missing = size - tiled.size(dim)

    if missing:
        pad_shape = list(tiled.shape)
        pad_shape[dim] = missing
        zeros = torch.zeros(pad_shape, dtype=tiled.dtype, device=tiled.device)
        tiled = torch.cat((tiled, zeros), dim=dim)

    return tiled


def update_weights(model_name=None, generator=netG, discriminator=netD):
    """Seed the first layer of both networks from the previous level's checkpoint.

    Reads ``Saves/{model_name}.tar`` and adapts two tensors to the shapes of
    the current models:

    * generator ``transconv1.weight``: tiled along dim 1 (output channels of
      the transposed convolution) and zero-padded up to the new channel count.
    * discriminator ``conv1.weight``: tiled along dim 0 (output channels) and
      resampled to the new kernel size.

    For the shapes used in this notebook, (100, 3, 4, 4) -> (100, 50, 4, 4)
    and (1, 3, 4, 4) -> (100, 3, 3, 3), this gives the same tensors as the
    previous hard-coded 16x repeat + 2 zero channels and 100x repeat.

    Args:
        model_name: Checkpoint name; the file ``Saves/{model_name}.tar`` is read.
        generator: Generator to update in place.
        discriminator: Discriminator to update in place.

    Raises:
        RuntimeError: From ``load_state_dict`` if an adapted tensor still does
            not match the shape of the target parameter.
    """

    # Load on the CPU so the checkpoint works regardless of the device it was
    # saved from; load_state_dict copies onto each model's own device.
    previous_models = torch.load(f'Saves/{model_name}.tar', map_location='cpu')

    previous_generator = previous_models['generator_params']
    previous_discriminator = previous_models['discriminator_params']

    current_generator = generator.state_dict()
    current_discriminator = discriminator.state_dict()

    # Updating Generator's weights

    target = current_generator['transconv1.weight']
    weights = _fit_channels(previous_generator['transconv1.weight'], target.size(1), dim=1)

    current_generator['transconv1.weight'] = weights.to(dtype=target.dtype)
    # Assigning into the dict only rebinds the key; load_state_dict is what
    # actually writes the values into the model's parameters.
    generator.load_state_dict(current_generator)

    # Updating Discriminator's weights

    target = current_discriminator['conv1.weight']
    weights = _fit_channels(previous_discriminator['conv1.weight'], target.size(0), dim=0)
    upsampler = torch.nn.Upsample(tuple(target.shape[2:]))
    weights = upsampler(weights)

    current_discriminator['conv1.weight'] = weights.to(dtype=target.dtype)
    discriminator.load_state_dict(current_discriminator)

    print("Weights Updated!")

In [86]:
def train(
    data=None,
    generator=netG,
    discriminator=netD,
    epochs=50000,
    batch_size=2048,
    loss=nn.BCEWithLogitsLoss(),
    optimizerD=optimizerD,
    optimizerG=optimizerG,
    save_point=1000,
    checkpoint=1000,
    model_name='default_Cocogoat',
    keep_going="no",
    schedulerD=schedulerD,
    schedulerG=schedulerG):
    """Adversarial training loop for the current level.

    Args:
        data: Dataset yielding image tensors; wrapped in a shuffling DataLoader.
        generator: Generator network (latent of shape (N, 100, 1, 1) in).
        discriminator: Discriminator network returning one logit per image.
        epochs: Epoch index at which training stops (exclusive).
        batch_size: DataLoader batch size.
        loss: Criterion called as ``loss(logits, labels)``.
        optimizerD: Optimizer stepping the discriminator.
        optimizerG: Optimizer stepping the generator.
        save_point: A sample image grid is saved and shown every ``save_point`` epochs.
        checkpoint: Stats are printed and a checkpoint is written whenever the
            batch index within an epoch is a multiple of ``checkpoint``.
        model_name: Base name for ``Saves/{model_name}.tar`` and the saved images.
        keep_going: ``"no"`` seeds the models via ``update_weights`` and starts at
            epoch 0; ``"yes"`` resumes from ``Saves/{model_name}.tar``.
        schedulerD: LR scheduler stepped once per epoch for the discriminator.
        schedulerG: LR scheduler stepped once per epoch for the generator.

    Raises:
        ValueError: If ``keep_going`` is neither ``"yes"`` nor ``"no"``.
    """

    if keep_going not in ("yes", "no"):
        # Previously this surfaced later as a NameError on start_epoch.
        raise ValueError(f'keep_going must be "yes" or "no", got {keep_going!r}')

    print("Starting Training Loop...")

    # Make sure the output folders exist before the first save.
    os.makedirs('Saves', exist_ok=True)
    os.makedirs('Cocogoat', exist_ok=True)

    dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=True)

    if keep_going == "no":
        # Seed the models that were passed in, not the notebook-level globals.
        update_weights(model_name=model_name, generator=generator, discriminator=discriminator)

        start_epoch = 0

    else:

        # Checkpoints are written below as .tar, so resume from the same file.
        save = torch.load(f'Saves/{model_name}.tar', map_location=device)

        start_epoch = save['epoch']
        discriminator.load_state_dict(save['discriminator_params'])
        generator.load_state_dict(save['generator_params'])

        print("Continuing from last save")

    for epoch in range(start_epoch, epochs):

        for item, image in enumerate(dataloader):

            ############################
            # (1) Update D network
            ###########################

            discriminator.zero_grad()

            real_images = image.to(device)
            label = torch.full((real_images.size(0),), real_label, dtype=torch.float, device=device)

            # Forward pass real batch through D

            output = discriminator(real_images).view(-1) # (batch_size, )

            # Calculate loss on all-real batch

            errD_real = loss(output, label) # (input, target)

            # Calculate gradients for D in backward pass

            errD_real.backward()

            ## Train with all-fake batch

            # Generate batch of latent vectors

            noise = torch.randn((real_images.size(0), 100, 1, 1), device=device)

            # Generate fake image batch with G

            fake = generator(noise)
            label.fill_(fake_label)

            # Classify all fake batch with D

            output = discriminator(fake.detach()).view(-1) # Using detach to avoid backpropagation through generator on this step

            # Calculate D's loss on the all-fake batch

            errD_fake = loss(output, label)

            # Calculate the gradients for this batch, accumulated (summed) with previous gradients

            errD_fake.backward()

            # Compute error of D as sum over the fake and the real batches

            errD = errD_real.item() + errD_fake.item()

            # Checking Discriminator's gradients to see if we're getting vanishing/exploding gradients

            Dgrads_avg = torch.mean(discriminator.conv1.weight.grad)

            # Update D

            optimizerD.step()

            ############################
            # (2) Update G network
            ###########################

            generator.zero_grad()
            label.fill_(real_label)  # fake labels are real for generator cost --> What the Discriminator predicted correctly is "wrong" --> backpropagation trick

            # Since we just updated D, perform another forward pass of all-fake batch through D

            output = discriminator(fake).view(-1)

            # Calculate G's loss based on this output

            errG = loss(output, label)

            # Calculate gradients for G

            errG.backward()

            # Checking Generator's gradients to see if we're getting vanishing/exploding gradients

            Ggrads_avg = torch.mean(generator.transconv1.weight.grad)

            # Update G

            optimizerG.step()

            # Output training stats
            if item % checkpoint == 0:
                print(f"{epoch}|{epochs}")
                print(f"Discriminator Loss: {errD}\tGradients Average: {Dgrads_avg}")
                print(f"Generator Loss: {errG.item()}\tGradients Average: {Ggrads_avg}")

                # Save the models being trained (the arguments), not the globals.
                torch.save({
                    'epoch': epoch,
                    'discriminator_params': discriminator.state_dict(),
                    'generator_params': generator.state_dict(),
                }, f"Saves/{model_name}.tar")

                print("Models saved!")

        if epoch % save_point == 0:

            # Re-uses the latent batch of the last training step of this epoch.
            with torch.no_grad():

                saving_image = generator(noise).cpu()

            # NCHW -> NHWC for matplotlib. permute moves the channel axis;
            # view would only relabel the axes of the same memory.
            # NOTE: this assumes the training tensors are genuinely
            # channel-first. If the dataset was built by reinterpreting NHWC
            # memory with .view, the snapshots will look scrambled.
            saving_image = saving_image.permute(0, 2, 3, 1).numpy()
            saving_image = (saving_image+1.0)*0.5
            np.save(f'Cocogoat/Image_{model_name}_{epoch}.npy', saving_image, allow_pickle=True)
            print(f'Fake image saved!')

            fig, ax = plt.subplots(2,2)

            for panel in ax.flat:
                panel.axis('off')

            # Same placement as before (column by column); zip also tolerates
            # a final batch with fewer than four samples.
            panels = (ax[0,0], ax[1,0], ax[0,1], ax[1,1])
            for panel, picture in zip(panels, saving_image):
                panel.imshow(picture)

            plt.show()
            # Release the figure so long runs do not accumulate open figures.
            plt.close(fig)

        schedulerD.step()
        schedulerG.step()

In [None]:
# Launch training on the 8x8 dataset with batches of 64.
# checkpoint is compared against the batch index inside each epoch and save_point against the epoch index.
train(data=data, epochs=50001, batch_size=64, checkpoint=6000, save_point=5000, model_name='default_Cocogoat')