In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import time
import numpy as np
import gc
from torch.utils.data import DataLoader
from discriminator_utils import COCO2017

In [2]:
class CGenerator(nn.Module):
    # define constructor for the generator
    def __init__(self, noise_in_dimension, noise_out_dimension, vocab_size, embedding_length=128):
        # inherit nn.Module class methods
        super().__init__()
        # set class attributes:
        # model branch at start of model structure for noise (image)
        self.initial_noise_processing = nn.Sequential(nn.Linear(in_features=noise_in_dimension, out_features=4*4*noise_out_dimension, bias=True), nn.CELU(alpha=1, inplace=True))
        # moddel branch at start of model structure (parallel to noise branch) for condition (text)
        self.initial_condition_processing = nn.Sequential(nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_length, padding_idx=1), nn.Linear(in_features=embedding_length, out_features=16), nn.LeakyReLU(negative_slope=0.1, inplace=True))
        self.generator_model = nn.Sequential(
            nn.ConvTranspose2d(in_channels=59+noise_out_dimension, out_channels=64*8, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=64*8), # momentum parameter
            nn.ReLU(inplace=True), # 64x512x8x8
            nn.ConvTranspose2d(in_channels=64*8, out_channels=64*4, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=64*4),
            nn.ReLU(inplace=True), # 64x256x16x16
            nn.ConvTranspose2d(in_channels=64*4, out_channels=64*2, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(num_features=64*2),
            nn.ReLU(inplace=True), # 64x128x32x32
            # nn.ConvTranspose2d(in_channels=64*2, out_channels=64, kernel_size=4, stride=2, padding=1, bias=False),
            # nn.BatchNorm2d(num_features=64),
            # nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh() #64x3x64x64
            )
        
    def forward(self, noise, text):
        preconcat_noise = self.initial_noise_processing(noise)
        # reshape noise data
        preconcat_noise = torch.reshape(preconcat_noise, (-1, 512, 4, 4))
        preconcat_text = self.initial_condition_processing(text)
        # reshape text data
        preconcat_text = torch.reshape(preconcat_text, (64, 944))
        preconcat_text = torch.reshape(preconcat_text, (-1, 59, 4, 4))
        # concatenate the image data and text data
        image_and_text = torch.concat((preconcat_noise, preconcat_text), dim=1)
        generative_result = self.generator_model(image_and_text)
        return generative_result

In [3]:
class Discriminator(nn.Module):
    def __init__(self, nc, ndf, opt):
        super(Discriminator, self).__init__()
        self.opt = opt
        self.nc = nc
        self.ndf = ndf
        self.convD = nn.Sequential(
            nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.fcD = nn.Sequential(
            nn.Linear(opt['txtSize'], opt['nt']),
            nn.BatchNorm1d(opt['nt']),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.netD = nn.Sequential(
            nn.Conv2d(ndf * 8 + opt['nt'], ndf * 8, kernel_size=1),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(ndf * 8, 1, kernel_size=4),
            nn.Sigmoid()
        )

    def forward(self, img, txt):
        conv_output = self.convD(img)
        replicated_txt = self.fcD(txt.float())
        concatenated_input = torch.cat((conv_output, replicated_txt.unsqueeze(2).unsqueeze(3).repeat(1, 1, conv_output.size(2), conv_output.size(3))), dim=1)
        output = self.netD(concatenated_input)
        return output.view(-1, 1).squeeze(1)


In [4]:
dataset = COCO2017('./train_annotations.json', r'C:\Users\danie\Documents\Data\COMP652\FinalProject\train2017', 64000)
data_loader = DataLoader(dataset, batch_size=64) # shuffle=True

# dictionary for discriminator model parameters
opt = {
    'txtSize': 59, # Determined by max_length in tokenizing
    'nt': 128,
    'ndf': 64,
    'nc': 3, # number of input channels, for RGB images in COCO is 3x64x64
}
# assign an instance of the Discriminator Class to a variable
discriminator = Discriminator(opt['nc'], opt['ndf'], opt).to('cuda')
print(discriminator)
# assign an instance of the CGenerator Class to a variable
generator = CGenerator(noise_in_dimension=100, noise_out_dimension=512, vocab_size=24784).to('cuda')
print(generator)
# define the loss function
loss = nn.BCELoss()
# define optimizers with learning rate and momentum (Beta1) defined in research paper for the generator
discriminator_optimizer = optim.Adam(params=discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
generator_optimizer = optim.Adam(params=generator.parameters(), lr=0.0002, betas=(0.5, 0.999))

Discriminator(
  (convD): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (fcD): Sequential(
    (0): Linear(in_features=59, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [5]:
epoch_count = 50
batch_size = 64
# discriminator_batch_loss_real_list = []
# discriminator_batch_loss_synthetic_list = []
discriminator_batch_loss_list = []
generator_batch_loss_list = []
discriminator_epoch_loss_list = []
generator_epoch_loss_list = []
best_generator_loss = np.inf
best_discriminator_loss = np.inf
for epoch in range(epoch_count): # train GAN epoch_count number of times
    epoch_start = time.time()
    for idx, (image, text) in enumerate(data_loader, 0): # iterate through the training data loader (get batched data)
        # batch_start = time.time()
        image = image.to('cuda')
        text = text.to('cuda')
        text_4g = torch.unsqueeze(text, 1).long()
        # set gradient for all discriminator parametrs to None
        discriminator.zero_grad()
        # get discriminator predictions
        predictions_real = discriminator(image, text)
        # vector of ones
        labels_real = torch.ones(predictions_real.size(0)).to('cuda')
        # get discriminator loss on real images
        d_batch_loss_real = loss(predictions_real, labels_real)
        # backwards propogation for discriminator (these gardients will be accumulated with the gradients from the synthetic image results)
        d_batch_loss_real.backward()
        # get noise from normal distribution for this batch round
        noise = torch.randn(batch_size, 100, device='cuda')
        # generate synthetic images
        synthetic = generator(noise, text_4g)
        # get discriminator predictions on synthetic images
        predictions_synthetic = discriminator(synthetic.detach(), text)
        # vector of zeros
        labels_synthetic = torch.zeros(predictions_synthetic.size(0)).to('cuda')
        # get discriminator loss on synthetic images
        d_batch_loss_synthetic = loss(predictions_synthetic, labels_synthetic)
        # backwards propogation for discriminator (these gardients are accumulated with the gradients from the real image results)
        d_batch_loss_synthetic.backward()
        # update discriminator
        discriminator_optimizer.step()  
        # set gradient for all generator parametrs to None
        generator.zero_grad(set_to_none=True)
        # get discriminator predictions on synthetic images
        predictions_synthetic_updated = discriminator(synthetic, text)
        # generator loss
        generator_batch_loss = loss(predictions_synthetic_updated, labels_real)
        # backwards propogation for generator
        generator_batch_loss.backward()
        # update generator
        generator_optimizer.step()

        # batch_end = time.time()
        # append discriminator batch loss
        discriminator_batch_loss_list.append(d_batch_loss_synthetic.item() + d_batch_loss_real.item()) 
        # append generator batch loss
        generator_batch_loss_list.append(generator_batch_loss.item())

        # print update every 300 batches
        # if idx % 300 == 0:
        #     print(f"Batch # {idx}==> Batch Loss (Discriminator): {discriminator_batch_loss_list[idx]}. Batch Loss (Generator): {generator_batch_loss}.") # Cumulative Average Batch Loss (Generator): {np.mean(generator_batch_loss_list)}
    epoch_end = time.time()
    print(f"Epoch {epoch}==> Duration (sec): {round((epoch_end - epoch_start),1)}. Average Batch Loss (Discriminator): {np.mean(discriminator_batch_loss_list)}. Average Batch Loss (Generator): {np.mean(generator_batch_loss_list)}.")
    generator_epoch_loss_list.append(np.mean(generator_batch_loss_list))
    discriminator_epoch_loss_list.append(np.mean(discriminator_batch_loss_list))


Epoch 0==> Duration (sec): 14.2. Average Batch Loss (Discriminator): 0.8571754950318635. Average Batch Loss (Generator): 3.3201395372748377.
Epoch 1==> Duration (sec): 13.0. Average Batch Loss (Discriminator): 0.8166059298200375. Average Batch Loss (Generator): 3.1149997739046813.
Epoch 2==> Duration (sec): 12.9. Average Batch Loss (Discriminator): 0.7617775377637175. Average Batch Loss (Generator): 3.1803221253504357.
Epoch 3==> Duration (sec): 13.0. Average Batch Loss (Discriminator): 0.7207302550051988. Average Batch Loss (Generator): 3.2325179138686506.
Epoch 4==> Duration (sec): 13.0. Average Batch Loss (Discriminator): 0.6867798178899086. Average Batch Loss (Generator): 3.298916923199594.
Epoch 5==> Duration (sec): 13.4. Average Batch Loss (Discriminator): 0.6569176209905914. Average Batch Loss (Generator): 3.369518345706165.
Epoch 6==> Duration (sec): 12.9. Average Batch Loss (Discriminator): 0.6307448239297907. Average Batch Loss (Generator): 3.46123771426401.
Epoch 7==> Durati

In [6]:
del generator
del discriminator
gc.collect()
torch.cuda.empty_cache()