Google Drive folder link for the whole project: https://drive.google.com/drive/folders/1tKA9agKcUH9rs0-_WzFxVVdeaMZOKx1l?usp=sharing

In [None]:
%matplotlib inline
%env CUBLAS_WORKSPACE_CONFIG=:4096:8
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import pickle
from skimage import io
from pytorch_fid import fid_score, inception
from PIL import Image
import matplotlib
from torchmetrics.image.fid import FrechetInceptionDistance
from torchsummary import summary

# Seed every random number generator the notebook draws from, so that a fresh
# "Restart Kernel -> Run All" repeats the same run.
random.seed(999)
np.random.seed(999)  # the training loop's caption dropout calls np.random.random()
torch.manual_seed(999)
torch.use_deterministic_algorithms(True) # Needed for reproducible results

# Single device handle used by every later cell (first GPU when available).
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

In [None]:
# Paths and hyperparameters shared by the cells below.
# NOTE(review): `dataroot` is not referenced by any cell visible here; the
# dataset cell passes its own image folder. Confirm which path is intended.
dataroot = "data/flowers102"
batch_size = 128  # samples per DataLoader batch
image_size = 64  # images are resized and centre-cropped to image_size x image_size

n_latent = 100  # channel count of the generator's input noise tensor
n_wordemb = 100 # width of each token embedding (input size of the GRU)
n_textemb = 100 # width of the caption vector produced by the text encoder
n_vocab = 1004 # number of rows in the token embedding table
n_feature = 64 # base channel count of the conv stacks (scaled by 2/4/8 per layer)
epochs = 200  # number of passes over the dataloader
learning_rate = 0.0002  # Adam learning rate used for both networks

In [None]:
class CustomDataset(Dataset):
    """Dataset yielding ``(right_image, wrong_image, caption)`` triples.

    The pickled table is indexed as ``table[idx, col]``: column 0 is the file
    name of the image paired with the caption, column 1 the file name of a
    second image returned as ``wrong_image``, and column 2 the caption, which
    is converted to an int tensor.
    NOTE(review): the ``[idx, col]`` indexing implies a 2-D NumPy-style array;
    confirm against the script that writes the pickle.

    Args:
        path_to_pickle: path of the pickled table.
        root_dir: folder the image file names are resolved against.
        transform: optional callable applied to each RGB PIL image. When
            ``None`` the PIL images are returned unchanged.

    Raises:
        FileNotFoundError: if ``path_to_pickle`` does not exist.
    """

    def __init__(self, path_to_pickle, root_dir, transform=None):
        # Explicit error instead of a bare assert: it carries the offending
        # path and is not stripped when Python runs with -O.
        if not os.path.exists(path_to_pickle):
            raise FileNotFoundError(f"Caption pickle not found: {path_to_pickle}")

        # SECURITY: pickle.load can execute arbitrary code; only open files
        # that come from a trusted source.
        with open(path_to_pickle, "rb") as pickle_fh:
            self.image_cap_pairs = pickle.load(pickle_fh)

        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the pickled table."""
        return len(self.image_cap_pairs)

    def _load_image(self, file_name):
        """Open ``file_name`` under ``root_dir`` as RGB and apply the transform, if any."""
        image_path = os.path.join(self.root_dir, file_name)
        with open(image_path, "rb") as fh:
            # convert() forces the pixel data to be read while the file is open.
            image = Image.open(fh).convert("RGB")
        # transform defaults to None, so it must not be called unconditionally.
        if self.transform is not None:
            image = self.transform(image)
        return image

    def __getitem__(self, idx):
        """Return ``(right_image, wrong_image, caption)`` for row ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        right_image = self._load_image(self.image_cap_pairs[idx, 0])
        wrong_image = self._load_image(self.image_cap_pairs[idx, 1])
        caption = torch.tensor(self.image_cap_pairs[idx, 2]).int()

        return right_image, wrong_image, caption

In [None]:
# Preprocessing pipeline: resize and centre-crop to image_size, convert to a
# tensor, then normalise every RGB channel with mean 0.5 and std 0.5.
image_transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

dataset = CustomDataset("train.pkl", "data/coco/train_64", transform=image_transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Preview: draw one shuffled batch and tile up to 64 of its first-slot images.
real_batch = next(iter(dataloader))
preview_grid = vutils.make_grid(
    real_batch[0].to(device)[:64], padding=2, normalize=True
).cpu()

plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))
print(len(dataset))

In [None]:
def weights_init(m):
    """Re-initialise one module in place; meant to be passed to ``Module.apply``.

    A module whose class name contains "Conv" gets weights drawn from a normal
    distribution with mean 0 and std 0.02. Otherwise, a module whose class
    name contains "BatchNorm" gets weights drawn with mean 1 and std 0.02 and
    a bias of zero. Every other module is left untouched.
    """
    layer_type = type(m).__name__

    if "Conv" in layer_type:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if "BatchNorm" in layer_type:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
class TextEncoder(nn.Module):
    """Encode a batch of token-id captions into one fixed-size vector each.

    Pipeline: token embedding -> GRU -> ReLU on the last time step -> linear.

    NOTE(review): the following cell defines ``TextEncoder`` a second time;
    that later definition is the one the rest of the notebook picks up.

    Args:
        hidden_dim: width of the GRU hidden state.
    """

    def __init__(self, hidden_dim=100):
        super().__init__()

        self.embed = nn.Embedding(n_vocab, n_wordemb)
        # Captions arrive as (batch, seq_len), so the GRU has to be
        # batch-first. With batch_first=False it treats the batch axis as
        # time and carries hidden state from one sample's caption to the next.
        self.gru = nn.GRU(n_wordemb, hidden_dim, batch_first=True)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(hidden_dim, n_textemb)

    def forward(self, text):
        """Return caption vectors of shape ``(batch, n_textemb, 1)``.

        Args:
            text: integer tensor of token ids, shape ``(batch, seq_len)``.
        """
        batch_size = text.shape[0]

        embedded = self.embed(text)  # (batch, seq_len, n_wordemb)
        out_gru, _ = self.gru(embedded, None)  # (batch, seq_len, hidden_dim)
        out_relu = self.relu(out_gru[:, -1])  # GRU output at the last time step
        context_vector = self.fc(out_relu)
        return context_vector.view(batch_size, n_textemb, 1)

In [None]:
class TextEncoder(nn.Module):
    """Encode a batch of token-id captions into one fixed-size vector each.

    Pipeline: token embedding -> GRU -> ReLU on the last time step -> linear.

    NOTE(review): this cell re-declares ``TextEncoder``, which the previous
    cell already defines. Being the later definition, this is the one the
    rest of the notebook uses; consider deleting one of the two cells.

    Args:
        hidden_dim: width of the GRU hidden state.
    """

    def __init__(self, hidden_dim=100):
        super().__init__()

        self.embed = nn.Embedding(n_vocab, n_wordemb)
        # Captions arrive as (batch, seq_len), so the GRU has to be
        # batch-first. With batch_first=False it treats the batch axis as
        # time and carries hidden state from one sample's caption to the next.
        self.gru = nn.GRU(n_wordemb, hidden_dim, batch_first=True)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(hidden_dim, n_textemb)

    def forward(self, text):
        """Return caption vectors of shape ``(batch, n_textemb, 1)``.

        Args:
            text: integer tensor of token ids, shape ``(batch, seq_len)``.
        """
        batch_size = text.shape[0]

        embedded = self.embed(text)  # (batch, seq_len, n_wordemb)
        out_gru, _ = self.gru(embedded, None)  # (batch, seq_len, hidden_dim)
        out_relu = self.relu(out_gru[:, -1])  # GRU output at the last time step
        context_vector = self.fc(out_relu)
        return context_vector.view(batch_size, n_textemb, 1)

In [None]:
# Generator Code
class Generator(nn.Module):
    """Caption-conditioned image generator built from transposed convolutions.

    The noise tensor and the encoded caption are concatenated along the
    channel axis and upsampled through five transposed-convolution stages
    (1x1 -> 4 -> 8 -> 16 -> 32 -> 64), ending in a 3-channel ``tanh`` output.
    """

    def __init__(self):
        super(Generator, self).__init__()
        self.textEncoder = TextEncoder()
        self.layer1 = nn.Sequential(
            nn.ConvTranspose2d(
                n_latent + n_textemb, n_feature * 8, 4, 1, 0, bias=False
            ),
            nn.BatchNorm2d(n_feature * 8),
            nn.ReLU(True),
        )
        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(n_feature * 8, n_feature * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_feature * 4),
            nn.ReLU(True),
        )
        self.layer3 = nn.Sequential(
            nn.ConvTranspose2d(n_feature * 4, n_feature * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_feature * 2),
            nn.ReLU(True),
        )
        self.layer4 = nn.Sequential(
            nn.ConvTranspose2d(n_feature * 2, n_feature, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_feature),
            nn.ReLU(True),
        )
        self.output = nn.Sequential(
            nn.ConvTranspose2d(n_feature, 3, 4, 2, 1, bias=False), nn.Tanh()
        )

    def forward(self, input_noise, caption=None):
        """Generate images from noise and an optional caption batch.

        Args:
            input_noise: tensor of shape ``(batch, n_latent, 1, 1)``.
            caption: integer token ids of shape ``(batch, seq_len)``, or
                ``None`` to use an all-zero placeholder caption.

        Returns:
            Tensor of shape ``(batch, 3, 64, 64)`` with values in [-1, 1].
        """
        batch_size = input_noise.size(0)
        if caption is None:
            # Allocate the placeholder on the noise tensor's own device rather
            # than the notebook-global `device`, so the module works wherever
            # it and its inputs are placed.
            # NOTE(review): 17 is presumably the padded caption length of the
            # dataset - confirm.
            caption = torch.zeros(
                batch_size, 17, dtype=torch.long, device=input_noise.device
            )

        # (batch, n_textemb, 1) -> (batch, n_textemb, 1, 1) to match the noise.
        caption_encoding = self.textEncoder(caption).unsqueeze(-1)
        feature_vector = torch.cat((input_noise, caption_encoding), 1)

        x1 = self.layer1(feature_vector)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        return self.output(x4)

In [None]:
# Create the generator
netG = Generator().to(device)
netG.apply(weights_init)
# Describe the input with n_latent instead of a literal 100 so the summary
# keeps matching the configured noise size.
summary(netG, (n_latent, 1, 1))

In [None]:
class Discriminator(nn.Module):
    """Caption-conditioned image discriminator.

    Four stride-2 convolutions halve the image resolution each time
    (64x64 -> 4x4); the encoded caption is tiled over that grid, concatenated
    along the channel axis and reduced by a final 4x4 convolution plus sigmoid.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        self.textEncoder = TextEncoder()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, n_feature, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(n_feature, n_feature * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_feature * 2),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(n_feature * 2, n_feature * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_feature * 4),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(n_feature * 4, n_feature * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(n_feature * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.output = nn.Sequential(
            nn.Conv2d(n_feature * 8 + n_textemb, 1, 4, 1, 0, bias=False), nn.Sigmoid()
        )

    def forward(self, input_image, caption=None):
        """Score a batch of images against an optional caption batch.

        Args:
            input_image: tensor of shape ``(batch, 3, H, W)``; 64x64 inputs
                give one score per image.
            caption: integer token ids of shape ``(batch, seq_len)``, or
                ``None`` to use an all-zero placeholder caption.

        Returns:
            Sigmoid scores; shape ``(batch, 1, 1, 1)`` for 64x64 inputs.
        """
        batch_size = input_image.size(0)
        if caption is None:
            # Allocate on the image's own device instead of the notebook-global
            # `device`, so the module works wherever it and its inputs live.
            # NOTE(review): 17 is presumably the padded caption length of the
            # dataset - confirm.
            caption = torch.zeros(
                batch_size, 17, dtype=torch.long, device=input_image.device
            )

        x1 = self.layer1(input_image)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        # Tile the caption vector over layer4's actual spatial grid (4x4 for
        # 64x64 inputs) instead of assuming a fixed 4x4.
        caption_encoding = (
            self.textEncoder(caption)
            .unsqueeze(-1)
            .repeat(1, 1, x4.size(2), x4.size(3))
        )

        x4_ = torch.cat((x4, caption_encoding), 1)

        x5 = self.output(x4_)
        return x5

In [None]:
# Create the Discriminator
netD = Discriminator().to(device)
netD.apply(weights_init)
# Probe with the training resolution. The discriminator tiles its caption map
# to 4x4, which only matches layer4's output for 64x64 inputs; a 128x128 probe
# produces an 8x8 feature map and the channel concatenation fails.
summary(netD, (3, image_size, image_size))

In [None]:
# Binary cross-entropy between the discriminator's sigmoid score and the target.
criterion = nn.BCELoss()

# One fixed batch of 64 latent vectors, reused to sample the generator after
# every epoch.
fixed_noise = torch.randn(64, n_latent, 1, 1, device=device)

# Target values: 1 for real inputs, 0 for both wrong and fake inputs.
real_label, wrong_label, fake_label = 1.0, 0.0, 0.0

# Both networks use Adam with the same learning rate and betas.
adam_settings = dict(lr=learning_rate, betas=(0.5, 0.999))
optimizerD = optim.Adam(netD.parameters(), **adam_settings)
optimizerG = optim.Adam(netG.parameters(), **adam_settings)

In [None]:
# Training Loop
# cuDNN autotuning may select different kernels from run to run, which works
# against the deterministic setup requested in the first cell, so it stays off.
torch.backends.cudnn.benchmark = False

# Checkpoints are written here after every epoch; create the folder up front
# so torch.save does not fail on a fresh checkout.
os.makedirs("model_saves", exist_ok=True)

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
fids = []

fid = FrechetInceptionDistance(feature=64, normalize=True).to(device)


def _to_unit_range(images):
    """Map images from [-1, 1] to [0, 1], the range FID expects with normalize=True.

    The training images are normalised with mean 0.5 / std 0.5 and the
    generator ends in tanh, so both real and fake batches live in [-1, 1].
    """
    return ((images.detach() + 1) / 2).clamp(0, 1)


# For each epoch
for epoch in range(1, epochs + 1):
    # For each batch in the dataloader
    tqdm_dataloader = tqdm(dataloader)
    for i, data in enumerate(tqdm_dataloader, 1):
        tqdm_dataloader.set_description_str(f"Epoch: {epoch}/{epochs}")

        # Caption dropout: roughly 10% of the batches are trained without a
        # caption (the networks then substitute an all-zero caption).
        caption = data[2].to(device) if np.random.random() > 0.1 else None

        ############################
        # (1) Update D network: real (image, caption) pairs should score 1,
        #     mismatched pairs and generated images should score 0.
        ###########################
        # Gradients are zeroed once; the real, wrong and fake passes below
        # accumulate into the same buffers before the single optimizer step.
        netD.zero_grad()

        ## Train with all-real batch
        real_image = data[0].to(device)
        b_size = real_image.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_image, caption).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-wrong batch (real image, mismatched caption)
        if caption is not None:
            wrong_image = data[1].to(device)
            label.fill_(wrong_label)
            # Forward pass wrong batch through D
            output = netD(wrong_image, caption).view(-1)
            # Calculate loss on all-wrong batch
            errD_wrong = criterion(output, label)
            # Accumulate gradients on top of the real-batch gradients
            errD_wrong.backward()
        else:
            # Without a caption the "wrong" image is simply another real
            # image; scoring it as 0 would contradict the real pass above.
            errD_wrong = torch.zeros((), device=device)

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, n_latent, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise, caption)
        label.fill_(fake_label)
        # Classify all fake batch with D (detached: no gradient flows into G here)
        output = netD(fake.detach(), caption).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the real, wrong and fake batches
        errD = errD_real + errD_wrong + errD_fake

        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake, caption).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Accumulate FID statistics on images rescaled to [0, 1].
        fid.update(_to_unit_range(real_image), real=True)
        fid.update(_to_unit_range(fake), real=False)

        tqdm_dataloader.set_postfix_str(
            f"Loss_D: {errD.item():.4f}, Loss_G: {errG.item():.4f}, D(x): {D_x:.4f}, D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}"
        )

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

    # One FID value per epoch, computed over every batch seen in that epoch.
    fids.append(fid.compute().item())
    fid.reset()

    tqdm_dataloader.set_postfix_str(
        f"Loss_D: {errD.item():.4f}, Loss_G: {errG.item():.4f}, D(x): {D_x:.4f}, D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}, fid: {fids[-1]}"
    )

    # Check how the generator is doing by saving G's output on fixed_noise
    with torch.no_grad():
        fake = netG(fixed_noise).detach().cpu()
    img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

    torch.save(netG.state_dict(), f"model_saves/generator_E{epoch}.pth")
    torch.save(netD.state_dict(), f"model_saves/discriminator_E{epoch}.pth")
    # Remember the checkpoint pair with the lowest FID so far.
    if fids[-1] < min([*fids[:-1],np.inf]):
        best_generator = f"model_saves/generator_E{epoch}.pth"
        best_discriminator = f"model_saves/discriminator_E{epoch}.pth"

In [None]:
print(fids)

# Losses are recorded once per training batch.
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss vs Iterations")
plt.plot(G_losses,label="Generator")
plt.plot(D_losses,label="Discriminator")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()


# FID is recorded once per epoch, so label the axes accordingly (this figure
# previously reused the "iterations" / "Loss" labels of the loss plot).
plt.figure(figsize=(10,5))
plt.title("FID Score vs Epochs")
plt.plot(range(1, len(fids) + 1), fids)
plt.xlabel("epoch")
plt.ylabel("FID")
plt.show()

In [None]:
# Raise matplotlib's size cap (in MB) for animations embedded as HTML.
matplotlib.rcParams["animation.embed_limit"] = 500

fig = plt.figure(figsize=(8,8))
plt.axis("off")

# One animation frame per saved sample grid, in the order they were recorded.
ims = []
for grid in img_list:
    frame = plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)
    ims.append([frame])

ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())

In [None]:
# Reload the checkpoint pair with the lowest FID recorded during training.
# map_location keeps loading working when the current device differs from the
# one the checkpoint was saved on.
netG_best = Generator().to(device)
netG_best.load_state_dict(torch.load(best_generator, map_location=device))
netG_best.eval()

netD_best = Discriminator().to(device)
netD_best.load_state_dict(torch.load(best_discriminator, map_location=device))
# Put the discriminator in eval mode (this line called netG_best.eval() a
# second time, leaving netD_best in training mode).
netD_best.eval()

real_batch = next(iter(dataloader))


plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

# NOTE(review): this panel shows the sample grid saved after the last epoch
# (img_list[-1]), not fresh output of netG_best - confirm that is intended.
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

In [None]:
# Reload the checkpoints written after the final epoch. The file names follow
# the configured `epochs` instead of a literal 200, matching what the training
# loop saves; map_location keeps loading device-independent.
netG_current = Generator().to(device)
netG_current.load_state_dict(
    torch.load(f"model_saves/generator_E{epochs}.pth", map_location=device)
)
netG_current.eval()

netD_current = Discriminator().to(device)
netD_current.load_state_dict(
    torch.load(f"model_saves/discriminator_E{epochs}.pth", map_location=device)
)
# Put the discriminator in eval mode (this line called netG_current.eval() a
# second time, leaving netD_current in training mode).
netD_current.eval()

In [None]:
n_samples = 9  # number of images to draw, shown as a 3x3 grid

# Token ids of the conditioning caption, zero-padded to length 17.
caption = torch.zeros((1, 17), dtype=int, device=device)
caption[0, 0:8] = torch.tensor([1, 4, 166, 77, 54, 48, 78, 2], device=device)

noise = torch.randn(n_samples, n_latent, 1, 1, device=device)
with torch.no_grad():
    # Pass the caption to the generator (it was built above but never used),
    # repeated so that every noise sample gets its own caption row.
    fake_images = netG(noise, caption.repeat(n_samples, 1)).detach().cpu()
plt.figure(figsize=(10, 10))
plt.imshow(np.transpose(vutils.make_grid(fake_images, nrow=3, padding=5, normalize=True, pad_value=1),(1,2,0)))
#plt.title("Caption:\nPetals of this flower are pink with a yellow centre\n")
plt.axis("off")