# **StackGAN: Text to Photo-realistic Image Synthesis with Stacked Generative Adversarial Networks**

In this project, I have implemented the StackGAN Model

**StackGAN Paper** : [StackGAN](https://arxiv.org/pdf/1612.03242v2.pdf)

**Dataset Download** - [CUB-Dataset Images](http://www.vision.caltech.edu/visipedia/CUB-200-2011.html)
, [CUB-Dataset Text Description](https://drive.google.com/file/d/0B3y_msrWZaXLT1BZdVdycDY5TEE/view?resourcekey=0-sZrhftoEfdvHq6MweAeCjA)

**Special Note**

All file paths used below are specific to my system. Please change them accordingly if you are using the code as it is.

In [1]:
!pip install transformers
!pip install sentence-transformers

import pandas as pd
import numpy as np
import random
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torchvision.utils import save_image
from PIL import Image
import PIL
import pickle
from torch.utils.data import Dataset
from glob import glob
import time
import gc
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
%matplotlib inline
from sentence_transformers import SentenceTransformer, util

In [2]:
# Mount Google Drive at /content/drive; later cells read the dataset archive
# from it and write checkpoints and sample images to it.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## **Downloading the Data**

In [None]:
%cd /content
# !gdown --id 1hbzc_P1FuxMkcabkgn9ZKinBwW683j45      # Uncomment this if you are downloading the dataset for the first time
!tar -xzvf /content/drive/MyDrive/CUB_200_2011.tgz

!gdown --id 0B3y_msrWZaXLT1BZdVdycDY5TEE
!unzip /content/birds.zip

## **Loading the Data**

In [4]:
# Drive folder for the CUB dataset (not referenced by the cells shown here).
PATH = '/content/drive/MyDrive/CUB_Dataset'
# Mini-batch size shared by both DataLoaders.
batch_size = 64
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
''' Dataloader for StageI Training'''

class CUB_Dataset(Dataset):
    """Stage-I training set: 64x64 CUB-200-2011 crops paired with caption files.

    Each item is ``(image, caption_path)``. The image is cropped around its
    annotated bounding box, resized to 76x76 and, when ``transform`` is
    truthy, randomly cropped to 64x64, randomly flipped and normalised to
    [-1, 1]. ``caption_path`` is the path of the caption ``.txt`` file at the
    same position in the sorted listing.

    Args:
        root_dir: Stored as ``data_dir``. The files actually read are given
            by the class-level path constants below.
        mode: Only ``'train'`` is supported; any other value prints a
            message and leaves the dataset empty.
        transform: When truthy, build the augmentation pipeline; otherwise
            ``get_img`` returns the resized PIL image.
    """

    # Paths are class attributes so they can be overridden (on a subclass or
    # an instance) without editing the methods.
    IMAGE_DIR = "/content/CUB_200_2011/images"
    TEXT_DIR = "/content/birds/text_c10"
    BBOX_FILE = '/content/CUB_200_2011/bounding_boxes.txt'
    IMAGE_LIST_FILE = '/content/CUB_200_2011/images.txt'
    FILENAME_PICKLES = ('/content/birds/train/filenames.pickle',
                        '/content/birds/test/filenames.pickle')
    # Final (post-crop) image side; images are first resized to 76/64 of it.
    IMG_SIZE = 64

    def __init__(self, root_dir, mode='train', transform=True):
        self.data_dir = root_dir
        self.mode = mode
        self.transforms = transform
        # Always defined, so get_img() works even when transform is falsy.
        self.transform = None
        self.bbox = self.load_bbox()
        self._init_dataset()
        self.filenames = self.load_filenames()
        if transform:
            self._init_transform()

    def _init_dataset(self):
        """Collect image and caption paths, class folder by class folder."""
        self.files = []
        self.text_files = []
        if self.mode != 'train':
            print("No Such Dataset Mode")
            return None
        for class_dir in sorted(os.listdir(self.IMAGE_DIR)):
            self.files += sorted(
                glob(os.path.join(self.IMAGE_DIR, class_dir, '*.jpg')))
            self.text_files += sorted(
                glob(os.path.join(self.TEXT_DIR, class_dir, '*.txt')))

    def load_bbox(self):
        """Return ``{'<class_dir>/<image_stem>': [x_left, y_top, width, height]}``.

        Rows of the bounding-box file are paired with rows of the image list
        by position.
        """
        df_bounding_boxes = pd.read_csv(self.BBOX_FILE, sep=r'\s+',
                                        header=None).astype(int)
        df_filenames = pd.read_csv(self.IMAGE_LIST_FILE, sep=r'\s+',
                                   header=None)
        # Strip the 4-character extension (".jpg") to form the lookup key.
        keys = [name[:-4] for name in df_filenames[1].tolist()]
        boxes = df_bounding_boxes.iloc[:, 1:].values.tolist()
        filename_bbox = {key: [] for key in keys}
        filename_bbox.update(zip(keys, boxes))
        return filename_bbox

    @staticmethod
    def _bbox_key(img_path):
        """Map an image path to its ``'<class_dir>/<image_stem>'`` bbox key."""
        class_dir = os.path.basename(os.path.dirname(img_path))
        stem = os.path.splitext(os.path.basename(img_path))[0]
        return f"{class_dir}/{stem}"

    def get_img(self, img_path, bbox):
        """Load an image, crop it around ``bbox``, resize and transform it.

        Args:
            img_path: Path of the image file.
            bbox: ``[x_left, y_top, width, height]``, or ``None``/empty to
                skip cropping.

        Returns:
            The transformed image, or the resized PIL image when no
            transform pipeline was built.
        """
        img = Image.open(img_path).convert('RGB')
        width, height = img.size
        if bbox is not None and len(bbox) == 4:
            # Square window with half-side 0.75 * the longer bbox edge,
            # centred on the box and clamped to the image borders.
            half_side = int(max(bbox[2], bbox[3]) * 0.75)
            center_x = int((2 * bbox[0] + bbox[2]) / 2)
            center_y = int((2 * bbox[1] + bbox[3]) / 2)
            y1 = max(0, center_y - half_side)
            y2 = min(height, center_y + half_side)
            x1 = max(0, center_x - half_side)
            x2 = min(width, center_x + half_side)
            img = img.crop([x1, y1, x2, y2])
        load_size = int(self.IMG_SIZE * 76 / 64)
        img = img.resize((load_size, load_size), PIL.Image.BILINEAR)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def load_filenames(self):
        """Return the pickled train and test filename lists, concatenated.

        NOTE: the pickles come from the downloaded dataset archive; only
        load them from a trusted source.
        """
        filenames = []
        for filepath in self.FILENAME_PICKLES:
            with open(filepath, 'rb') as f:
                filenames = filenames + pickle.load(f)
        return filenames

    def _init_transform(self):
        """Build the random-crop / flip / normalise pipeline."""
        self.transform = transforms.Compose([
            transforms.RandomCrop((self.IMG_SIZE, self.IMG_SIZE)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __getitem__(self, index):
        img_path = self.files[index]
        # Look the box up from the image's own path: the pickled filename
        # lists are not guaranteed to follow the sorted order of self.files,
        # so indexing them by position can return another image's box.
        bbox = self.bbox[self._bbox_key(img_path)]
        img = self.get_img(img_path, bbox)
        text = self.text_files[index]
        return img, text

    def __len__(self):
        # Items are indexed through self.files, so that list sets the length.
        return len(self.files)

# Stage-I data pipeline: shuffled batches of (image, caption_path);
# drop_last=True discards a final batch smaller than batch_size.
trainset = CUB_Dataset(root_dir='/content/CUB_200_2011/images')
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True)

In [5]:
'''Dataloader for StageII Training'''

class CUB_Dataset_2(Dataset):
    """Stage-II training set: Stage-I output, 256x256 real crop and caption file.

    Each item is ``(stageI_img, image, caption_path)``. ``stageI_img`` is a
    previously saved Stage-I PNG. ``image`` is the real photo cropped around
    its bounding box, resized to 304x304 and, when ``transform`` is truthy,
    randomly cropped to 256x256, randomly flipped and normalised to [-1, 1].

    Args:
        root_dir: Stored as ``data_dir``. The files actually read are given
            by the class-level path constants below.
        mode: Only ``'train'`` is supported; any other value prints a
            message and leaves the dataset empty.
        transform: When truthy, build both transform pipelines; otherwise
            PIL images are returned unchanged.
    """

    # Paths are class attributes so they can be overridden (on a subclass or
    # an instance) without editing the methods.
    IMAGE_DIR = "/content/CUB_200_2011/images"
    TEXT_DIR = "/content/birds/text_c10"
    STAGE1_DIR = "/content/drive/MyDrive/StackGAN/StackGAN-1_Images"
    BBOX_FILE = '/content/CUB_200_2011/bounding_boxes.txt'
    IMAGE_LIST_FILE = '/content/CUB_200_2011/images.txt'
    FILENAME_PICKLES = ('/content/birds/train/filenames.pickle',
                        '/content/birds/test/filenames.pickle')
    # Final (post-crop) image side; images are first resized to 76/64 of it.
    IMG_SIZE = 256

    def __init__(self, root_dir, mode='train', transform=True):
        self.data_dir = root_dir
        self.mode = mode
        self.transforms = transform
        # Always defined, so item loading works even when transform is falsy.
        self.transform = None
        self.transform1 = None
        self.bbox = self.load_bbox()
        self._init_dataset()
        self.filenames = self.load_filenames()
        if transform:
            self._init_transform()

    def _init_dataset(self):
        """Collect image, caption and Stage-I image paths per class folder."""
        self.files = []
        self.text_files = []
        self.stage1_img_files = []
        if self.mode != 'train':
            print("No Such Dataset Mode")
            return None
        for class_dir in sorted(os.listdir(self.IMAGE_DIR)):
            self.files += sorted(
                glob(os.path.join(self.IMAGE_DIR, class_dir, '*.jpg')))
            self.text_files += sorted(
                glob(os.path.join(self.TEXT_DIR, class_dir, '*.txt')))
            self.stage1_img_files += sorted(
                glob(os.path.join(self.STAGE1_DIR, class_dir, '*.png')))

    def load_bbox(self):
        """Return ``{'<class_dir>/<image_stem>': [x_left, y_top, width, height]}``.

        Rows of the bounding-box file are paired with rows of the image list
        by position.
        """
        df_bounding_boxes = pd.read_csv(self.BBOX_FILE, sep=r'\s+',
                                        header=None).astype(int)
        df_filenames = pd.read_csv(self.IMAGE_LIST_FILE, sep=r'\s+',
                                   header=None)
        # Strip the 4-character extension (".jpg") to form the lookup key.
        keys = [name[:-4] for name in df_filenames[1].tolist()]
        boxes = df_bounding_boxes.iloc[:, 1:].values.tolist()
        filename_bbox = {key: [] for key in keys}
        filename_bbox.update(zip(keys, boxes))
        return filename_bbox

    @staticmethod
    def _bbox_key(img_path):
        """Map an image path to its ``'<class_dir>/<image_stem>'`` bbox key."""
        class_dir = os.path.basename(os.path.dirname(img_path))
        stem = os.path.splitext(os.path.basename(img_path))[0]
        return f"{class_dir}/{stem}"

    def get_img(self, img_path, bbox):
        """Load an image, crop it around ``bbox``, resize and transform it.

        Args:
            img_path: Path of the image file.
            bbox: ``[x_left, y_top, width, height]``, or ``None``/empty to
                skip cropping.

        Returns:
            The transformed image, or the resized PIL image when no
            transform pipeline was built.
        """
        img = Image.open(img_path).convert('RGB')
        width, height = img.size
        if bbox is not None and len(bbox) == 4:
            # Square window with half-side 0.75 * the longer bbox edge,
            # centred on the box and clamped to the image borders.
            half_side = int(max(bbox[2], bbox[3]) * 0.75)
            center_x = int((2 * bbox[0] + bbox[2]) / 2)
            center_y = int((2 * bbox[1] + bbox[3]) / 2)
            y1 = max(0, center_y - half_side)
            y2 = min(height, center_y + half_side)
            x1 = max(0, center_x - half_side)
            x2 = min(width, center_x + half_side)
            img = img.crop([x1, y1, x2, y2])
        load_size = int(self.IMG_SIZE * 76 / 64)
        img = img.resize((load_size, load_size), PIL.Image.BILINEAR)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def load_filenames(self):
        """Return the pickled train and test filename lists, concatenated.

        NOTE: the pickles come from the downloaded dataset archive; only
        load them from a trusted source.
        """
        filenames = []
        for filepath in self.FILENAME_PICKLES:
            with open(filepath, 'rb') as f:
                filenames = filenames + pickle.load(f)
        return filenames

    def _init_transform(self):
        """Build the pipelines for Stage-I images and for real photos."""
        # Stage-I images are used at their stored size: flip and normalise.
        # NOTE(review): this flip is sampled independently of the one applied
        # to the real photo, so the pair may be mirrored relative to each
        # other; confirm this is intended.
        self.transform1 = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.transform = transforms.Compose([
            transforms.RandomCrop((self.IMG_SIZE, self.IMG_SIZE)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __getitem__(self, index):
        img_path = self.files[index]
        # Look the box up from the image's own path: the pickled filename
        # lists are not guaranteed to follow the sorted order of self.files,
        # so indexing them by position can return another image's box.
        bbox = self.bbox[self._bbox_key(img_path)]
        img = self.get_img(img_path, bbox)
        text = self.text_files[index]
        stageI_img = Image.open(self.stage1_img_files[index]).convert('RGB')
        if self.transform1 is not None:
            stageI_img = self.transform1(stageI_img)
        return stageI_img, img, text

    def __len__(self):
        return len(self.stage1_img_files)

# Stage-II data pipeline: shuffled batches of (stageI_img, image, caption_path);
# drop_last=True discards a final batch smaller than batch_size.
trainset2 = CUB_Dataset_2(root_dir='/content/CUB_200_2011/images')
trainloader2 = DataLoader(trainset2, batch_size=batch_size, shuffle=True, num_workers=2, drop_last=True, pin_memory = True)

## **StackGAN Model**




In [None]:
def weights_init(m):
    """Initialise one layer in place; meant to be passed to ``Module.apply``.

    Layers are matched by class name:
      * names containing ``Conv2d``: weights drawn from N(0, 0.02)
      * names containing ``BatchNorm``: weights from N(1, 0.02), bias zeroed
      * names containing ``Linear``: weights from N(0, 0.02), bias zeroed if present
    Any other layer is left untouched.
    """
    layer_kind = m.__class__.__name__

    if 'Conv2d' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return

    if 'BatchNorm' in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
        return

    if 'Linear' in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        has_bias = m.bias is not None
        if has_bias:
            m.bias.data.fill_(0.0)

class Conditioning_Augmentation_StageI(nn.Module):
    """Conditioning augmentation for Stage I.

    Projects a 768-feature text embedding to 256 values and splits them into
    a 128-d mean and a 128-d log-variance. It then samples
    ``mean + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``.
    """

    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(768, 256)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``(sample, mean, logvar)``, each of shape ``[batch, 128]``."""
        # Follow the module's own device instead of a module-level global.
        x = x.to(self.fc1.weight.device)
        y = self.relu(self.fc1(x))
        u0 = y[:, :128]
        logvar = y[:, 128:]
        sigma0 = torch.exp(logvar / 2)
        # Draw the noise directly on the target device with a matching dtype.
        epsilon = torch.randn_like(u0)
        out = u0 + sigma0 * epsilon
        return out, u0, logvar

class StageI_GAN_Gen(nn.Module):
    """Stage-I generator: text embedding to a 64x64 RGB image.

    The conditioning vector (128-d) is concatenated with 100-d Gaussian
    noise, projected to a 1024x4x4 feature map and upsampled four times
    (nearest-neighbour x2 followed by a 3x3 conv) to 64x64. A final 3x3 conv
    and tanh produce the image.

    Args:
        condaug1: Zero-argument factory (e.g. a class) for the conditioning
            augmentation module. Its forward must return
            ``(sample, mean, logvar)`` with a 128-d sample.
    """

    def __init__(self, condaug1):
        super().__init__()

        # In: text embedding; out: 128-d conditioning sample, mean, logvar.
        self.CA1 = condaug1()

        # 228 = 128 (conditioning) + 100 (noise).
        self.fc = nn.Sequential(
            nn.Linear(228, 4*4*128*8),
            nn.BatchNorm1d(4*4*128*8),
            nn.ReLU(True))

        self.upsample1 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = nn.Conv2d(128*8, 64*8, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(64*8)

        self.upsample2 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv2 = nn.Conv2d(64*8, 32*8, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(32*8)

        self.upsample3 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv3 = nn.Conv2d(32*8, 16*8, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchnorm3 = nn.BatchNorm2d(16*8)

        self.upsample4 = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv4 = nn.Conv2d(16*8, 8*8, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchnorm4 = nn.BatchNorm2d(8*8)

        self.conv5 = nn.Conv2d(8*8, 3, kernel_size=3, stride=1, padding=1, bias=False)
        self.tanh = nn.Tanh()
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``(image, mean, logvar)``; ``image`` is ``[batch, 3, 64, 64]`` in [-1, 1]."""
        # Follow the module's own device instead of a module-level global.
        x = x.to(self.conv5.weight.device)
        cond, u0, logvar = self.CA1(x)
        # Allocate the noise directly next to the conditioning vector.
        z = torch.randn(cond.shape[0], 100, device=cond.device)
        x = torch.cat((cond, z), 1)
        x = self.fc(x)
        x = torch.reshape(x, (-1, 128*8, 4, 4))
        x = self.relu(self.batchnorm1(self.conv1(self.upsample1(x))))
        x = self.relu(self.batchnorm2(self.conv2(self.upsample2(x))))
        x = self.relu(self.batchnorm3(self.conv3(self.upsample3(x))))
        x = self.relu(self.batchnorm4(self.conv4(self.upsample4(x))))
        x = self.tanh(self.conv5(x))

        return x, u0, logvar

class DownSample1(nn.Module):
    """Stage-I discriminator encoder: four stride-2 4x4 convs, 3 to 512 channels.

    Each conv halves the spatial size, so a 64x64 input becomes 4x4. Every
    layer is followed by LeakyReLU(0.2); all but the first also use
    batch normalisation.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False)

        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(128)

        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm3 = nn.BatchNorm2d(256)

        self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm4 = nn.BatchNorm2d(512)

        self.leakyrelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Map ``[batch, 3, H, W]`` to ``[batch, 512, H/16, W/16]``."""
        # Follow the module's own device instead of a module-level global.
        x = x.to(self.conv1.weight.device)
        x = self.leakyrelu(self.conv1(x))
        x = self.leakyrelu(self.batchnorm2(self.conv2(x)))
        x = self.leakyrelu(self.batchnorm3(self.conv3(x)))
        x = self.leakyrelu(self.batchnorm4(self.conv4(x)))

        return x

class StageI_GAN_Dis(nn.Module):
    """Stage-I discriminator: scores an (image, text embedding) pair.

    The image is encoded to a 512-channel map and the 768-feature text
    embedding is compressed to 128 values that are tiled over the map's
    spatial grid. The two are concatenated (640 channels), fused by a 1x1
    conv and reduced by a 4x4 conv plus sigmoid to one probability.

    Args:
        downsample: Zero-argument factory (e.g. a class) for the image
            encoder; it must output 512 channels.
    """

    def __init__(self, downsample):
        super().__init__()

        self.fc1 = nn.Linear(768, 128)
        self.downsample = downsample()
        self.conv1 = nn.Conv2d(640, 512, kernel_size=1, stride=1, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(512)
        self.leakyrelu = nn.LeakyReLU(0.2)

        self.conv2 = nn.Conv2d(512, 1, kernel_size=4, stride=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, text):
        """Return probabilities of shape ``[batch, 1]`` for 64x64 inputs."""
        # Follow the module's own device instead of a module-level global.
        dev = self.fc1.weight.device
        x = self.downsample(x.to(dev))
        text = self.fc1(text.to(dev))
        # Tile the compressed embedding over the encoder's spatial grid
        # (4x4 for a 64x64 input) instead of hard-coding the size.
        text = text.unsqueeze(2).unsqueeze(3)
        text = text.expand(-1, -1, x.size(2), x.size(3))
        x = torch.cat((x, text), 1)
        x = self.leakyrelu(self.batchnorm1(self.conv1(x)))
        x = self.conv2(x)
        x = torch.squeeze(x, 3)
        x = torch.squeeze(x, 2)
        x = self.sigmoid(x)

        return x

class Conditioning_Augmentation_StageII(nn.Module):
    """Conditioning augmentation for Stage II.

    Projects a 768-feature text embedding to 256 values and splits them into
    a 128-d mean and a 128-d log-variance. It then samples
    ``mean + exp(logvar / 2) * eps`` with ``eps ~ N(0, I)``.
    """

    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(768, 256)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``(sample, mean, logvar)``, each of shape ``[batch, 128]``."""
        # Follow the module's own device instead of a module-level global.
        x = x.to(self.fc1.weight.device)
        y = self.relu(self.fc1(x))
        u0 = y[:, :128]
        logvar = y[:, 128:]
        sigma0 = torch.exp(logvar / 2)
        # Draw the noise directly on the target device with a matching dtype.
        epsilon = torch.randn_like(u0)
        out = u0 + sigma0 * epsilon
        return out, u0, logvar


class DownSample2(nn.Module):
    """Stage-II generator encoder: 3 to 512 channels at 1/4 spatial size.

    A 3x3 stride-1 conv is followed by two 4x4 stride-2 convs with batch
    normalisation; every layer is followed by ReLU.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 128, kernel_size=3, stride=1, padding=1, bias=False)

        self.conv2 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(256)

        self.conv3 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm3 = nn.BatchNorm2d(512)

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Map ``[batch, 3, H, W]`` to ``[batch, 512, H/4, W/4]``."""
        # Follow the module's own device instead of a module-level global.
        x = x.to(self.conv1.weight.device)
        x = self.relu(self.conv1(x))
        x = self.relu(self.batchnorm2(self.conv2(x)))
        x = self.relu(self.batchnorm3(self.conv3(x)))

        return x

class ResidualBlock(nn.Module):
    """512-channel residual block: conv-BN-ReLU-conv-BN, skip add, then ReLU."""

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(512)

        self.conv2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(512)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Return a tensor with the same shape as ``x`` (``[batch, 512, H, W]``)."""
        # Follow the module's own device instead of a module-level global.
        x = x.to(self.conv1.weight.device)
        identity = x
        x = self.relu(self.batchnorm1(self.conv1(x)))
        x = self.batchnorm2(self.conv2(x))
        x = x + identity
        x = self.relu(x)

        return x

class UpSampling2(nn.Module):
    """Stage-II generator decoder: 512 to 3 channels at 16x the spatial size.

    Four rounds of nearest-neighbour x2 upsampling, each followed by a 3x3
    conv, batch normalisation and ReLU, then a final 3x3 conv to 3 channels.
    No output activation is applied here.
    """

    def __init__(self):
        super().__init__()

        # A single parameter-free upsampler is reused at every stage.
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(256)

        self.conv2 = nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(128)

        self.conv3 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchnorm3 = nn.BatchNorm2d(64)

        self.conv4 = nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchnorm4 = nn.BatchNorm2d(32)

        self.conv5 = nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1, bias=False)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map ``[batch, 512, H, W]`` to ``[batch, 3, 16*H, 16*W]``."""
        # Follow the module's own device instead of a module-level global.
        x = x.to(self.conv1.weight.device)
        x = self.relu(self.batchnorm1(self.conv1(self.upsample(x))))
        x = self.relu(self.batchnorm2(self.conv2(self.upsample(x))))
        x = self.relu(self.batchnorm3(self.conv3(self.upsample(x))))
        x = self.relu(self.batchnorm4(self.conv4(self.upsample(x))))
        x = self.conv5(x)

        return x

class StageII_GAN_Gen(nn.Module):
    """Stage-II generator: refines a Stage-I image conditioned on the text.

    The Stage-I image is encoded to a 512-channel map and the 128-d
    conditioning sample is tiled over that map's spatial grid. The two are
    concatenated (640 channels), fused by a 3x3 conv, passed through the
    residual block four times, decoded by the upsampler and squashed by tanh.

    Args:
        downsample: Zero-argument factory for the image encoder (512 channels out).
        resblock: Zero-argument factory for the residual block.
        upsample: Zero-argument factory for the decoder (512 channels in).
        condaug2: Zero-argument factory for the conditioning augmentation
            module; its forward must return ``(sample, mean, logvar)`` with a
            128-d sample.
    """

    def __init__(self, downsample, resblock, upsample, condaug2):
        super().__init__()

        self.downsample = downsample()
        self.resblock = resblock()
        self.upsample = upsample()
        self.CA2 = condaug2()
        self.conv = nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchnorm = nn.BatchNorm2d(512)
        self.relu = nn.ReLU(True)
        self.tanh = nn.Tanh()

    def forward(self, x, text):
        """Return ``(image, mean, logvar)`` with ``image`` values in [-1, 1]."""
        # Follow the module's own device instead of a module-level global.
        dev = self.conv.weight.device
        x = x.to(dev)
        text = text.to(dev)
        text, u0, logvar = self.CA2(text)
        x = self.downsample(x)
        # Tile the conditioning vector over the encoder's spatial grid
        # (16x16 for a 64x64 Stage-I image) instead of hard-coding the size.
        text = text.unsqueeze(2).unsqueeze(3)
        text = text.expand(-1, -1, x.size(2), x.size(3))
        x = torch.cat((x, text), 1)
        x = self.relu(self.batchnorm(self.conv(x)))
        # NOTE(review): the same residual block (shared weights) is applied
        # four times. Confirm this is intended rather than four separate
        # blocks; it is kept as is so saved checkpoints still load.
        x = self.resblock(self.resblock(self.resblock(self.resblock(x))))
        x = self.upsample(x)
        x = self.tanh(x)

        return x, u0, logvar

class DownSample3(nn.Module):
    """Stage-II discriminator encoder: 3 to 512 channels at 1/64 spatial size.

    Six 4x4 stride-2 convs grow the channels from 3 to 2048 while halving
    the spatial size each time (256x256 becomes 4x4). Two 3x3 stride-1 convs
    then reduce the channels to 1024 and 512. Every layer is followed by
    LeakyReLU(0.2); all but the first also use batch normalisation.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False)

        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm2 = nn.BatchNorm2d(128)

        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm3 = nn.BatchNorm2d(256)

        self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm4 = nn.BatchNorm2d(512)

        self.conv5 = nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm5 = nn.BatchNorm2d(1024)

        self.conv6 = nn.Conv2d(1024, 2048, kernel_size=4, stride=2, padding=1, bias=False)
        self.batchnorm6 = nn.BatchNorm2d(2048)

        self.conv7 = nn.Conv2d(2048, 1024, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchnorm7 = nn.BatchNorm2d(1024)

        self.conv8 = nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchnorm8 = nn.BatchNorm2d(512)
        self.leakyrelu = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Map ``[batch, 3, H, W]`` to ``[batch, 512, H/64, W/64]``."""
        # Follow the module's own device instead of a module-level global.
        x = x.to(self.conv1.weight.device)
        x = self.leakyrelu(self.conv1(x))
        x = self.leakyrelu(self.batchnorm2(self.conv2(x)))
        x = self.leakyrelu(self.batchnorm3(self.conv3(x)))
        x = self.leakyrelu(self.batchnorm4(self.conv4(x)))
        x = self.leakyrelu(self.batchnorm5(self.conv5(x)))
        x = self.leakyrelu(self.batchnorm6(self.conv6(x)))
        x = self.leakyrelu(self.batchnorm7(self.conv7(x)))
        x = self.leakyrelu(self.batchnorm8(self.conv8(x)))

        return x

class StageII_GAN_Dis(nn.Module):
    """Stage-II discriminator: scores an (image, text embedding) pair.

    The image is encoded to a 512-channel map and the 768-feature text
    embedding is compressed to 128 values that are tiled over the map's
    spatial grid. The two are concatenated (640 channels), fused by a 3x3
    conv and reduced by a 4x4 stride-4 conv plus sigmoid to one probability.

    Args:
        downsample: Zero-argument factory (e.g. a class) for the image
            encoder; it must output 512 channels.
    """

    def __init__(self, downsample):
        super().__init__()

        self.fc0 = nn.Linear(768, 128)
        self.downsample = downsample()
        self.conv1 = nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, bias=False)
        self.batchnorm1 = nn.BatchNorm2d(512)
        self.leakyrelu = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(512, 1, kernel_size=4, stride=4)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, text):
        """Return probabilities of shape ``[batch, 1]`` when the encoder yields a 4x4 map."""
        # Follow the module's own device instead of a module-level global.
        dev = self.fc0.weight.device
        x = self.downsample(x.to(dev))
        text = self.fc0(text.to(dev))
        # Tile the compressed embedding over the encoder's spatial grid
        # instead of hard-coding 4x4.
        text = text.unsqueeze(2).unsqueeze(3)
        text = text.expand(-1, -1, x.size(2), x.size(3))
        x = torch.cat((x, text), 1)
        x = self.leakyrelu(self.batchnorm1(self.conv1(x)))
        x = self.sigmoid(self.conv2(x))
        x = x.squeeze(3)
        x = x.squeeze(2)

        return x

# Build both stages on the selected device and apply the custom weight
# initialisation to every sub-module.
StageI_Gen = StageI_GAN_Gen(Conditioning_Augmentation_StageI).to(device)
StageI_Gen = StageI_Gen.apply(weights_init)
StageI_Dis = StageI_GAN_Dis(DownSample1).to(device)
StageI_Dis = StageI_Dis.apply(weights_init)
StageII_Gen = StageII_GAN_Gen(DownSample2, ResidualBlock, UpSampling2, Conditioning_Augmentation_StageII).to(device)
StageII_Gen = StageII_Gen.apply(weights_init)
StageII_Dis = StageII_GAN_Dis(DownSample3).to(device)
StageII_Dis = StageII_Dis.apply(weights_init)
# Sentence encoder used to turn caption text into embeddings for the models.
sbert_model = SentenceTransformer('paraphrase-mpnet-base-v2')

## **Training**

In [8]:
# Per-epoch training history, appended to by the training loops.
# Loss histories for the Stage-I (D1/G1) and Stage-II (D2/G2) networks.
epoch_D1losses = []             
epoch_G1losses = []
epoch_D2losses = []             
epoch_G2losses = []
# Mean discriminator outputs per epoch.
epoch_Real_Score = []
epoch_Fake_Score = []
epoch_Generator_Score = []

In [None]:
# Stage-I training hyper-parameters.
epochs = 600
lrG = 0.0002
lrD = 0.0002

# Separate Adam optimizers for the Stage-I discriminator and generator.
optimizerD1 = torch.optim.Adam(StageI_Dis.parameters(), lr=lrD, betas=(0.5,0.999))
optimizerG1 = torch.optim.Adam(StageI_Gen.parameters(), lr=lrG, betas=(0.5,0.999))

# Binary cross-entropy on the discriminators' sigmoid outputs.
BCEloss = nn.BCELoss()

def KL_loss(mu, logvar):
    """KL divergence between N(mu, exp(logvar)) and N(0, I), averaged over all elements.

    Computes ``-0.5 * mean(1 + logvar - mu^2 - exp(logvar))`` and returns it
    as a scalar tensor. The inputs are not modified.
    """
    second_moment = mu.pow(2) + logvar.exp()
    per_element = (second_moment * -1 + 1) + logvar
    return per_element.mean() * -0.5

def train_StageI_Dis(real_images, wrong_images, text, optimizer):
    """Run one Stage-I discriminator update.

    The loss is ``BCE(real, 1) + 0.5 * (BCE(generated, 0) + BCE(mismatched, 0))``,
    where "mismatched" pairs real images with captions that belong to other
    images.

    Args:
        real_images: Real images matching ``text``.
        wrong_images: Real images that do not match ``text``.
        text: Text embeddings for the batch.
        optimizer: Optimizer over the Stage-I discriminator parameters.

    Returns:
        ``(loss, real_score, fake_score)`` as Python floats. The scores are
        the mean discriminator outputs on real pairs and on the two kinds of
        fake pairs (averaged).
    """
    optimizer.zero_grad()

    real_images = real_images.to(device)
    wrong_images = wrong_images.to(device)
    text = text.to(device)

    real_pred = StageI_Dis(real_images, text)
    real_loss = BCEloss(real_pred, torch.ones_like(real_pred))
    real_score = torch.mean(real_pred).item()

    fake_images, _, _ = StageI_Gen(text)
    # Detach so this step does not backpropagate through the generator; its
    # gradients are not used by the discriminator optimizer.
    fake_pred1 = StageI_Dis(fake_images.detach(), text)
    fake_loss1 = BCEloss(fake_pred1, torch.zeros_like(fake_pred1))
    fake_score1 = torch.mean(fake_pred1).item()

    fake_pred2 = StageI_Dis(wrong_images, text)
    fake_loss2 = BCEloss(fake_pred2, torch.zeros_like(fake_pred2))
    fake_score2 = torch.mean(fake_pred2).item()

    discriminator_loss = real_loss + (fake_loss1 + fake_loss2)/2
    discriminator_loss.backward()
    optimizer.step()

    return discriminator_loss.item(), real_score, (fake_score1+fake_score2)/2


def train_StageI_Gen(text, optimizer):
    """Run one Stage-I generator update.

    The loss is ``BCE(D(G(text)), 1) + 2 * KL``, with the KL term computed
    from the conditioning-augmentation mean and log-variance.

    Args:
        text: Text embeddings for the batch.
        optimizer: Optimizer over the Stage-I generator parameters.

    Returns:
        ``(loss, adversarial_loss, kl_loss, generator_score)``. ``loss`` and
        ``generator_score`` are floats; the two component losses are
        detached scalar tensors.
    """
    optimizer.zero_grad()

    text = text.to(device)
    generator_images, mu, logvar = StageI_Gen(text)
    generator_pred = StageI_Dis(generator_images, text)
    # Size the targets from the actual batch, not the global batch_size, so
    # a smaller final batch also works.
    generator_targets = torch.ones_like(generator_pred)
    gen_bin_loss = BCEloss(generator_pred, generator_targets)
    generator_score = torch.mean(generator_pred).item()
    kl_loss = KL_loss(mu, logvar)

    generator_loss = gen_bin_loss + 2*kl_loss
    generator_loss.backward()
    optimizer.step()

    # Detach the components so callers do not keep the autograd graph alive.
    return generator_loss.item(), gen_bin_loss.detach(), kl_loss.detach(), generator_score

def epoch_time(start_time, end_time):
    """Split the interval between two timestamps (seconds) into ``(minutes, seconds)``.

    Both parts are truncated to integers.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - (whole_minutes * 60))
    return whole_minutes, leftover_seconds

'''Load the weights if you have trained already'''

# if os.path.isfile("/content/drive/MyDrive/StackGAN/checkpoints/StageI_Dis_GAN.pt"):
#     checkpointD1 = torch.load('/content/drive/MyDrive/StackGAN/checkpoints/StageI_Dis_GAN.pt')
#     StageI_Dis.load_state_dict(checkpointD1['model_state_dict'])
#     StageI_Dis.to(device)
#     optimizerD1.load_state_dict(checkpointD1['optimizer_state_dict'])
#     epoch = checkpointD1['epoch']
#     best_D1loss = checkpointD1['loss']

# if os.path.isfile("/content/drive/MyDrive/StackGAN/checkpoints/StageI_Gen_GAN.pt"):
#     checkpointG1 = torch.load('/content/drive/MyDrive/StackGAN/checkpoints/StageI_Gen_GAN.pt')
#     StageI_Gen.load_state_dict(checkpointG1['model_state_dict'])
#     StageI_Gen.to(device)
#     optimizerG1.load_state_dict(checkpointG1['optimizer_state_dict'])
#     epoch = checkpointG1['epoch']
#     best_G1loss = checkpointG1['loss']

def save_samples(index1, text, show=True):
    """Generate Stage-I samples for ``text`` and save up to 16 as a 4-column grid.

    Args:
        index1: Value inserted into the output file name
            (``generated-images-<index1>.png``).
        text: Text embeddings to condition the generator on.
        show: When true, also display the grid with matplotlib.
    """
    # Sampling only: no autograd graph is needed.
    with torch.no_grad():
        fake_images, _, _ = StageI_Gen(text)
    # The generator ends in tanh, so map [-1, 1] to [0, 1] before saving;
    # otherwise negative values are clipped to black.
    fake_images = ((fake_images[0:16, :, :, :] + 1) / 2).clamp(0, 1)
    fake_fname = 'generated-images-{}.png'.format(index1)
    out_dir = "/content/drive/MyDrive/GAN Images/Birds/birds-5"
    os.makedirs(out_dir, exist_ok=True)
    save_image((fake_images), os.path.join(out_dir, fake_fname), nrow=4)
    if show:
        # make_grid is not imported at file level; import it where it is used.
        from torchvision.utils import make_grid
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        ax.imshow(make_grid(fake_images.cpu().detach(), nrow=8).permute(1, 2, 0))

''' Stage1 Training'''

# For each epoch: encode the captions of every batch, update the Stage-I
# discriminator and generator, record epoch averages and write checkpoints.
for epoch in range(epochs):

    start_time = time.monotonic()

    print(f"Epoch: {epoch + 1}")
    train_D1Loss_batch = []
    train_G1Loss_batch = []
    train_real_score = []
    train_fake_score = []
    train_generator_score = []

    for idx, (real_images, text) in enumerate(trainloader):

        # Read each caption file (closing it afterwards) and embed the text
        # with the HuggingFace sentence transformer.
        captions = []
        for caption_path in text:
            with open(caption_path, "r") as caption_file:
                captions.append(caption_file.read())
        emb = torch.from_numpy(sbert_model.encode(captions))

        # Reversing the batch pairs each caption with a different image,
        # which provides the "mismatched" negatives for the discriminator.
        wrong_images = torch.flip(real_images, [0])
        discriminator_loss, real_score, fake_score = train_StageI_Dis(real_images, wrong_images, emb, optimizerD1)
        generator_loss, a, b, generator_score = train_StageI_Gen(emb, optimizerG1)
        train_D1Loss_batch.append(discriminator_loss)
        train_G1Loss_batch.append(generator_loss)
        train_real_score.append(real_score)
        train_fake_score.append(fake_score)
        train_generator_score.append(generator_score)

        if (idx+1)%180 == 0:
            emb_save = emb[0:16, :]
            save_samples(epoch+1, emb_save, show=False)

        gc.collect()
        # CUDA bookkeeping is only meaningful when a GPU is present.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

    if (epoch+1)%50 == 0:
        lrG = lrG/2
        lrD = lrD/2
        # Push the new rates into the optimizers; rebinding lrG/lrD alone
        # leaves the optimizers running at their original learning rate.
        for param_group in optimizerG1.param_groups:
            param_group['lr'] = lrG
        for param_group in optimizerD1.param_groups:
            param_group['lr'] = lrD
        print(f"Learning Rate Halved: {lrG} {lrD}")

    epoch_D1losses.append(sum(train_D1Loss_batch)/len(trainloader))
    epoch_G1losses.append(sum(train_G1Loss_batch)/len(trainloader))
    epoch_Real_Score.append(sum(train_real_score)/len(trainloader))
    epoch_Fake_Score.append(sum(train_fake_score)/len(trainloader))
    epoch_Generator_Score.append(sum(train_generator_score)/len(trainloader))

    torch.save({
        'epoch': epoch,
        'model_state_dict': StageI_Dis.state_dict(),
        'optimizer_state_dict': optimizerD1.state_dict(),
        'loss': epoch_D1losses[-1],
        }, '/content/drive/MyDrive/StackGAN/checkpoints/StageI_Dis_GAN.pt')

    torch.save({
        'epoch': epoch,
        'model_state_dict': StageI_Gen.state_dict(),
        'optimizer_state_dict': optimizerG1.state_dict(),
        'loss': epoch_G1losses[-1],
        }, '/content/drive/MyDrive/StackGAN/checkpoints/StageI_Gen_GAN.pt')

    print(f"Discriminator Epoch Loss: {epoch_D1losses[-1]:.5f}   Generator Epoch Loss: {epoch_G1losses[-1]:.5f}   Real Score: {epoch_Real_Score[-1]:.5f}   Fake Score: {epoch_Fake_Score[-1]:.5f}   Generator Score: {epoch_Generator_Score[-1]:.5f}")

    end_time = time.monotonic()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
    print("\n\n\n TIME TAKEN FOR THE EPOCH: {} mins and {} seconds".format(epoch_mins, epoch_secs))


print("OVERALL TRAINING COMPLETE")

In [None]:
# Stage-II training hyper-parameters.
epochs = 600
lrG = 0.0002
lrD = 0.0002

# Separate Adam optimizers for the Stage-II discriminator and generator.
optimizerD2 = torch.optim.Adam(StageII_Dis.parameters(), lr=lrD, betas=(0.5,0.999))
optimizerG2 = torch.optim.Adam(StageII_Gen.parameters(), lr=lrG, betas=(0.5,0.999))

# Binary cross-entropy on the discriminators' sigmoid outputs.
BCEloss = nn.BCELoss()

def KL_loss(mu, logvar):
    """KL divergence between N(mu, exp(logvar)) and N(0, I), averaged over all elements.

    Computes ``-0.5 * mean(1 + logvar - mu^2 - exp(logvar))`` and returns it
    as a scalar tensor. The inputs are not modified.
    """
    variance = logvar.exp()
    inner = mu.pow(2) + variance
    kld_terms = (inner * -1 + 1) + logvar
    return torch.mean(kld_terms) * -0.5

def train_StageII_Dis(real_images, wrong_images, stageI_img, text, optimizer):
    """Run one Stage-II discriminator update.

    The loss is ``BCE(real, 1) + 0.5 * (BCE(generated, 0) + BCE(mismatched, 0))``,
    where "mismatched" pairs real images with captions that belong to other
    images.

    Args:
        real_images: Real images matching ``text``.
        wrong_images: Real images that do not match ``text``.
        stageI_img: Stage-I images fed to the Stage-II generator.
        text: Text embeddings for the batch.
        optimizer: Optimizer over the Stage-II discriminator parameters.

    Returns:
        ``(loss, real_score, fake_score)`` as Python floats. The scores are
        the mean discriminator outputs on real pairs and on the two kinds of
        fake pairs (averaged).
    """
    optimizer.zero_grad()

    real_images = real_images.to(device)
    wrong_images = wrong_images.to(device)
    text = text.to(device)

    real_pred = StageII_Dis(real_images, text)
    real_loss = BCEloss(real_pred, torch.ones_like(real_pred))
    real_score = torch.mean(real_pred).item()

    fake_images, _, _ = StageII_Gen(stageI_img, text)
    # Detach so this step does not backpropagate through the generator; its
    # gradients are not used by the discriminator optimizer.
    fake_pred1 = StageII_Dis(fake_images.detach(), text)
    fake_loss1 = BCEloss(fake_pred1, torch.zeros_like(fake_pred1))
    fake_score1 = torch.mean(fake_pred1).item()

    fake_pred2 = StageII_Dis(wrong_images, text)
    fake_loss2 = BCEloss(fake_pred2, torch.zeros_like(fake_pred2))
    fake_score2 = torch.mean(fake_pred2).item()

    discriminator_loss = (fake_loss1 + fake_loss2)/2 + real_loss
    discriminator_loss.backward()
    optimizer.step()

    return discriminator_loss.item(), real_score, (fake_score1 + fake_score2)/2


def train_StageII_Gen(gen1_image, text, optimizer):
    """Run one Stage-II generator update.

    The loss is ``BCE(D(G(gen1_image, text)), 1) + 2 * KL``, with the KL term
    computed from the conditioning-augmentation mean and log-variance.

    Args:
        gen1_image: Stage-I images to refine.
        text: Text embeddings for the batch.
        optimizer: Optimizer over the Stage-II generator parameters.

    Returns:
        ``(loss, generator_score)`` as Python floats. The score is the mean
        discriminator output on the generated images.
    """
    optimizer.zero_grad()

    gen1_image = gen1_image.to(device)
    text = text.to(device)
    generator_images, mu, logvar = StageII_Gen(gen1_image, text)
    generator_pred = StageII_Dis(generator_images, text)
    # Size the targets from the actual batch, not the global batch_size, so
    # a smaller final batch also works.
    generator_targets = torch.ones_like(generator_pred)
    gen_bin_loss = BCEloss(generator_pred, generator_targets)
    generator_score = torch.mean(generator_pred).item()
    kl_loss = KL_loss(mu, logvar)

    generator_loss = gen_bin_loss + 2*kl_loss
    generator_loss.backward()
    optimizer.step()

    return generator_loss.item(), generator_score

def epoch_time(start_time, end_time):
    """Split the interval between two timestamps into minutes and seconds.

    Args:
        start_time: timestamp, in seconds, at the start of the interval.
        end_time: timestamp, in seconds, at the end of the interval.

    Returns:
        ``(minutes, seconds)`` as ints, each truncated by ``int()``.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    # Whatever is left once the whole minutes are removed.
    leftover_seconds = int(duration - 60 * whole_minutes)
    return whole_minutes, leftover_seconds

# if os.path.isfile("/content/drive/MyDrive/StackGAN/checkpoints/StageII_Dis_GAN.pt"):
#     checkpointD2 = torch.load('/content/drive/MyDrive/StackGAN/checkpoints/StageII_Dis_GAN.pt')
#     StageII_Dis.load_state_dict(checkpointD2['model_state_dict'])
#     StageII_Dis.to(device)
#     optimizerD2.load_state_dict(checkpointD2['optimizer_state_dict'])
#     epoch = checkpointD2['epoch']
#     best_D2loss = checkpointD2['loss']

# if os.path.isfile("/content/drive/MyDrive/StackGAN/checkpoints/StageII_Gen_GAN.pt"):
#     checkpointG2 = torch.load('/content/drive/MyDrive/StackGAN/checkpoints/StageII_Gen_GAN.pt')
#     StageII_Gen.load_state_dict(checkpointG2['model_state_dict'])
#     StageII_Gen.to(device)
#     optimizerG2.load_state_dict(checkpointG2['optimizer_state_dict'])
#     epoch = checkpointG2['epoch']
#     best_G2loss = checkpointG2['loss']

def save_samples(index1, stageI_img, text, show=True,
                 sample_dir="/content/drive/MyDrive/GAN Images/Birds/birds-6"):
    """Generate Stage-II images and save the first four as a 2x2 PNG grid.

    Args:
        index1: value inserted into the file name
            ``generated-images-<index1>.png``.
        stageI_img: batch passed to ``StageII_Gen`` as its image input.
        text: batch of text embeddings passed to ``StageII_Gen``.
        show: when true, also draw the generated images with matplotlib.
        sample_dir: directory the PNG is written to; created if missing.
    """
    # Preview only: no gradients are needed, so skip building the autograd
    # graph. Inputs are moved to the training device first (a no-op if they
    # are already there).
    with torch.no_grad():
        fake_images, _, _ = StageII_Gen(stageI_img.to(device), text.to(device))
    fake_images = fake_images[0:4]
    fake_fname = 'generated-images-{}.png'.format(index1)
    os.makedirs(sample_dir, exist_ok=True)
    save_image(fake_images, os.path.join(sample_dir, fake_fname), nrow=2)
    if show:
        fig, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        # Referenced through the `torchvision` module so the call does not
        # depend on a separate `make_grid` import.
        grid = torchvision.utils.make_grid(fake_images.cpu().detach(), nrow=8)
        # imshow expects channels last, hence the permute.
        ax.imshow(grid.permute(1, 2, 0))

# Stage-II training.
# For every batch: read the caption files, embed them with `sbert_model`,
# take one discriminator step and one generator step, and record the losses
# and scores. After every epoch: append the epoch averages to the history
# lists and overwrite both Stage-II checkpoints.

for epoch in range(epochs):

    start_time = time.monotonic()

    print(f"Epoch: {epoch + 1}")
    train_D2Loss_batch = []
    train_G2Loss_batch = []
    train_real_score = []
    train_fake_score = []
    train_generator_score = []

    for idx,(stageI_img, real_images, text) in enumerate(trainloader2):
        # HuggingFace Sentence Transformer path: `text` holds caption file
        # paths; each file is read and the contents are embedded together.
        text = list(text)
        embedding = []
        for caption_path in text:
            # `with` closes each caption file; the handles were previously
            # left open for every sample of every batch.
            with open(caption_path, "r") as caption_file:
                embedding.append(caption_file.read())
        emb = sbert_model.encode(embedding)
        emb = torch.from_numpy(emb)

        # "Wrong" images are the same batch in reverse order, so each caption
        # is paired with another sample's image.
        # NOTE(review): for an odd-sized batch the middle sample stays paired
        # with its own image.
        wrong_images = torch.flip(real_images, [0])
        discriminator_loss, real_score, fake_score = train_StageII_Dis(real_images, wrong_images, stageI_img, emb, optimizerD2)
        generator_loss, generator_score = train_StageII_Gen(stageI_img, emb, optimizerG2)
        train_D2Loss_batch.append(discriminator_loss)
        train_G2Loss_batch.append(generator_loss)
        train_real_score.append(real_score)
        train_fake_score.append(fake_score)
        train_generator_score.append(generator_score)

        # Dump a 4-image preview on every 180th batch.
        if (idx+1)%180 == 0:
            emb_save = emb[0:4, :]
            stageI_img_save = stageI_img[0:4, :, :, :]
            # 372 offsets the index used in the sample file name; presumably
            # the number of epochs finished in earlier runs — TODO confirm.
            save_samples(epoch+1+372, stageI_img_save, emb_save, show=False)

        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        gc.collect()

    if (epoch+1)%50 == 0:
        lrG = lrG/2
        lrD = lrD/2
        # Rebinding lrG/lrD alone does not affect optimizers that were already
        # constructed; the new rate has to be written into their param groups.
        for param_group in optimizerG2.param_groups:
            param_group['lr'] = lrG
        for param_group in optimizerD2.param_groups:
            param_group['lr'] = lrD
        print(f"Learning Rate Halved: {lrG} {lrD}")

    epoch_D2losses.append(sum(train_D2Loss_batch)/len(trainloader2))
    epoch_G2losses.append(sum(train_G2Loss_batch)/len(trainloader2))
    epoch_Real_Score.append(sum(train_real_score)/len(trainloader2))
    epoch_Fake_Score.append(sum(train_fake_score)/len(trainloader2))
    epoch_Generator_Score.append(sum(train_generator_score)/len(trainloader2))

    # Checkpoints are overwritten every epoch with the latest weights,
    # optimizer state and epoch-average loss.
    torch.save({
        'epoch': epoch,
        'model_state_dict': StageII_Dis.state_dict(),
        'optimizer_state_dict': optimizerD2.state_dict(),
        'loss': epoch_D2losses[-1],
        }, '/content/drive/MyDrive/StackGAN/checkpoints/StageII_Dis_GAN.pt')

    torch.save({
        'epoch': epoch,
        'model_state_dict': StageII_Gen.state_dict(),
        'optimizer_state_dict': optimizerG2.state_dict(),
        'loss': epoch_G2losses[-1],
        }, '/content/drive/MyDrive/StackGAN/checkpoints/StageII_Gen_GAN.pt')

    print(f"Discriminator Epoch Loss: {epoch_D2losses[-1]:.5f}   Generator Epoch Loss: {epoch_G2losses[-1]:.5f}   Real Score: {epoch_Real_Score[-1]:.5f}   Fake Score: {epoch_Fake_Score[-1]:.5f}   Generator Score: {epoch_Generator_Score[-1]:.5f}")

    end_time = time.monotonic()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    gc.collect()
    print("\n\n\n TIME TAKEN FOR THE EPOCH: {} mins and {} seconds".format(epoch_mins, epoch_secs))


print("OVERALL TRAINING COMPLETE")

## Testing

In [None]:
# Rebuild the optimizers, then restore both generators (weights and optimizer
# state) from their saved checkpoints.
lrD = 0.0002
lrG = 0.0002
optimizerD1 = torch.optim.Adam(StageI_Dis.parameters(), lr=lrD, betas=(0.5,0.999))
optimizerG1 = torch.optim.Adam(StageI_Gen.parameters(), lr=lrG, betas=(0.5,0.999))
optimizerD2 = torch.optim.Adam(StageII_Dis.parameters(), lr=lrD, betas=(0.5,0.999))
optimizerG2 = torch.optim.Adam(StageII_Gen.parameters(), lr=lrG, betas=(0.5,0.999))

# map_location=device remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU runtime still loads when only the CPU is
# available.
checkpointG1 = torch.load('/content/drive/MyDrive/StackGAN/checkpoints/StageI_Gen_GAN.pt', map_location=device)
StageI_Gen.load_state_dict(checkpointG1['model_state_dict'])
StageI_Gen.to(device)
optimizerG1.load_state_dict(checkpointG1['optimizer_state_dict'])
epoch = checkpointG1['epoch']
best_G1loss = checkpointG1['loss']
print(best_G1loss)

checkpointG2 = torch.load('/content/drive/MyDrive/StackGAN/checkpoints/StageII_Gen_GAN.pt', map_location=device)
StageII_Gen.load_state_dict(checkpointG2['model_state_dict'])
StageII_Gen.to(device)
optimizerG2.load_state_dict(checkpointG2['optimizer_state_dict'])
# Note: `epoch` now holds the Stage-II value and replaces the Stage-I one
# read above.
epoch = checkpointG2['epoch']
best_G2loss = checkpointG2['loss']
print(best_G2loss)

In [123]:
# Generate one image from a free-text caption: embed the caption, run it
# through the Stage-I and then the Stage-II generator, and save the result.
StageI_Gen.eval()
StageII_Gen.eval()
emb = sbert_model.encode('''The bird is black in colour with white belly''')
emb = torch.from_numpy(emb)
# Add a batch dimension and place the embedding on the same device the
# generators were moved to.
emb = emb.unsqueeze(0).to(device)
# Pure inference: disable autograd so no graph is built or kept in memory.
with torch.no_grad():
    generator_images, mu, logvar = StageI_Gen(emb)
    generator_images2, mu, logvar = StageII_Gen(generator_images, emb)
generator_images2 = generator_images2.squeeze(0)
fake_fname = '0.png'
# NOTE(review): save_image is called without normalize=True; if the generator
# output is not already in [0, 1] the saved image will be clipped — confirm.
save_image((generator_images2), os.path.join("/content", fake_fname), nrow=1)