# StackGAN on the CUB-200-2011 Birds Dataset


In [None]:
import torch
from torchvision import transforms
import os
import numpy as np
import pandas as pd
from PIL import Image
import seaborn as sns
import os
import torch.nn as nn
import torch.optim as optim
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import torchvision.utils as vutils
from torch.utils.data import DataLoader
import warnings
import torch.utils.data as data
import pickle
from torchvision.utils import make_grid
warnings.filterwarnings("ignore") 

In [None]:
# Default output side length (pixels) of the Stage-I images; LoadDataset uses
# it as the default value of its ``imgSize`` argument.
imgSize: int = 64


In [None]:
class LoadDataset(data.Dataset):
    """CUB-200-2011 birds dataset yielding ``(image, caption_embedding)`` pairs.

    Files read (relative to ``dataDir``)::

        birds/<split>/filenames.pickle
        birds/<split>/char-CNN-RNN-embeddings.pickle
        birds/<split>/class_info.pickle            (optional)
        birds/CUB_200_2011/images.txt
        birds/CUB_200_2011/bounding_boxes.txt
        birds/CUB_200_2011/images/<filename>.jpg

    Each item is the bounding-box crop of the image, resized to
    ``imgSize * 76 / 64`` pixels square and passed through ``transform`` (if
    given), together with one randomly chosen caption embedding of that image.

    Security note: the metadata files are read with ``pickle.load``; only point
    ``dataDir`` at trusted data.
    """

    def __init__(self, dataDir, split='train', imgSize=imgSize, transform=None):
        super(LoadDataset, self).__init__()

        self.transform = transform
        # Normalisation to [-1, 1]; not used by get_img (which applies
        # ``transform``), kept for callers that may rely on the attribute.
        self.norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

        self.imgSize = imgSize
        self.dataDir = dataDir
        self.filenames, self.caps = self.load_info(dataDir, split)
        self.bbox = self.load_bbox()
        self.classes = self.load_class(dataDir, split)

    def load_info(self, dataDir, split):
        """Return ``(filenames, caption_embeddings)`` for ``split``.

        The embeddings are returned as a NumPy array indexed as
        ``[image_idx, caption_idx, :]`` by ``__getitem__``.
        """
        filenames = self.load_filenames(dataDir, split)
        captionFile = os.path.join(dataDir, 'birds', split, 'char-CNN-RNN-embeddings.pickle')
        with open(captionFile, 'rb') as f:
            # The original code referenced ``captions`` without loading it.
            captions = pickle.load(f, encoding='latin1')
        captions = np.array(captions)
        return filenames, captions

    def load_filenames(self, dataDir, split):
        """Return the list of image keys (paths without '.jpg') for ``split``."""
        path = os.path.join(dataDir, 'birds', split, 'filenames.pickle')
        with open(path, 'rb') as f:
            filenames = pickle.load(f, encoding='latin1')
        return filenames

    def load_bbox(self):
        """Map every image key (path without '.jpg') to its bounding box.

        Each box is the list of integer columns after the image id in
        ``bounding_boxes.txt`` (used by ``get_img`` as x, y, width, height).
        Assumes ``images.txt`` and ``bounding_boxes.txt`` list images in the
        same row order, so the files are NOT re-sorted here.
        """
        path = os.path.join(self.dataDir, 'birds', 'CUB_200_2011', 'bounding_boxes.txt')
        bbox_data = pd.read_csv(path, sep=r'\s+', header=None).astype(int)

        filepath = os.path.join(self.dataDir, 'birds', 'CUB_200_2011', 'images.txt')
        df_filenames = pd.read_csv(filepath, sep=r'\s+', header=None)
        filenames = list(df_filenames[1])

        # Row i of the bbox table belongs to row i of images.txt.
        boxes = bbox_data.iloc[:, 1:].values.tolist()
        return {fname[:-4]: box for fname, box in zip(filenames, boxes)}

    def load_class(self, dataDir, split):
        """Return per-image class ids, or ``arange(len(filenames))`` if absent."""
        path = os.path.join(dataDir, 'birds', split, 'class_info.pickle')
        if os.path.isfile(path):
            with open(path, 'rb') as f:
                classId = pickle.load(f, encoding='latin1')
        else:
            classId = np.arange(len(self.filenames))
        return classId

    def get_img(self, img_path, bbox=None):
        """Load an RGB image, optionally crop around ``bbox``, resize, transform.

        With a bbox, a square of side ``1.5 * max(w, h)`` centred on the box is
        cropped (clipped to the image). The result is resized to
        ``imgSize * 76 / 64`` pixels square, presumably leaving a margin for a
        ``RandomCrop(imgSize)`` inside ``transform``.
        """
        img = Image.open(img_path).convert('RGB')
        width, height = img.size
        if bbox is not None:
            R = int(np.maximum(bbox[2], bbox[3]) * 0.75)
            center_x = int((2 * bbox[0] + bbox[2]) / 2)
            center_y = int((2 * bbox[1] + bbox[3]) / 2)
            y1 = np.maximum(0, center_y - R)
            y2 = np.minimum(height, center_y + R)
            x1 = np.maximum(0, center_x - R)
            x2 = np.minimum(width, center_x + R)
            img = img.crop([x1, y1, x2, y2])
        load_size = int(self.imgSize * 76 / 64)
        img = img.resize((load_size, load_size), Image.BILINEAR)
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __getitem__(self, idx):
        key = self.filenames[idx]
        bbox = self.bbox[key] if self.bbox is not None else None
        emb = self.caps[idx, :, :]
        imagePath = os.path.join(self.dataDir, 'birds', 'CUB_200_2011', 'images', key + '.jpg')
        image = self.get_img(imagePath, bbox)

        # Pick one caption embedding uniformly at random. np.random.randint's
        # upper bound is exclusive, so pass the count itself (the original
        # ``shape[0] - 1`` never chose the last caption and failed for one).
        sample = np.random.randint(0, emb.shape[0])
        cap = emb[sample, :]
        return image, cap

    def __len__(self):
        return len(self.filenames)

In [None]:
def initateWeights(m):
    """Initialise a module's weights; intended for ``model.apply(initateWeights)``.

    - Modules whose class name contains 'Conv': weight ~ N(0, 0.02).
    - Modules whose class name contains 'BatchNorm': weight ~ N(1, 0.02),
      bias = 0.
    - All other modules are left untouched.

    Parameters that are absent or None (e.g. BatchNorm with ``affine=False``,
    or a custom container whose name happens to contain 'Conv') are skipped
    instead of raising AttributeError.
    """
    classname = type(m).__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

# Shared model hyper-parameters.
# CUDA is off: this notebook was written on an Apple M2 machine.
CUDA = False
z_dim = 100          # length of the noise vector fed to the generators
emb_dim = 1024       # length of a text embedding (input of CondArgumentModel)
cond_dim = 128       # length of the condition vector produced from it
df_dim = gf_dim = 128  # base widths; not referenced elsewhere in this notebook

def Conv_k3(in_p, out_p, stride=1):
    """Return a bias-free 3x3 convolution with padding 1.

    With ``stride=1`` the spatial size is preserved; BatchNorm usually follows,
    which is why the bias is omitted.
    """
    layer = nn.Conv2d(
        in_channels=in_p,
        out_channels=out_p,
        kernel_size=(3, 3),
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer

class UpperBlock(nn.Module):
    """Upsampling block: 2x nearest upsample -> conv3x3 -> BatchNorm -> ReLU.

    Doubles H and W and maps ``inp`` channels to ``outp`` channels.

    Args:
        inp: number of input channels.
        outp: number of output channels.
    """

    def __init__(self, inp, outp):
        super(UpperBlock, self).__init__()
        # Attribute names/order unchanged so existing state_dicts still load.
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = Conv_k3(inp, outp)
        self.batch = nn.BatchNorm2d(outp)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        o = self.up(x)
        # Normalise the conv output before the non-linearity. The previous
        # ReLU -> BatchNorm order let the block emit negative activations.
        o = self.batch(self.conv(o))
        return self.relu(o)

class DiscriminatorOutput(nn.Module):
    """Probability head shared by the Stage-I and Stage-II discriminators.

    Maps encoded image features of shape (B, in_channels, H, W) to one
    probability per sample (after ``view(-1)``). With 4x4 features, as both
    discriminator encoders in this notebook produce, the final 4x4/stride-4
    conv yields exactly one value per sample.

    If ``have_cond`` is True, the condition vector is tiled over the spatial
    grid and concatenated to the features. A conv3x3 -> BN -> LeakyReLU stage
    then precedes the output conv. Otherwise the features are classified
    directly.

    Args:
        have_cond: build the conditional head (``forward`` then requires
            ``encoded_cond``).
        in_channels: channels of the encoded image (default 1024).
        cond_channels: length of the condition vector (default 128).
    """

    def __init__(self, have_cond=True, in_channels=1024, cond_channels=128):
        super(DiscriminatorOutput, self).__init__()
        self.have_cond = have_cond
        self.cond_channels = cond_channels
        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=1, kernel_size=4, stride=4),
            nn.Sigmoid(),
        ]
        if have_cond:
            # Prepended so Sequential indices match the original layout
            # (0 conv3x3, 1 BN, 2 LeakyReLU, 3 conv, 4 Sigmoid).
            layers = [
                Conv_k3(in_p=in_channels + cond_channels, out_p=in_channels),
                nn.BatchNorm2d(in_channels),
                nn.LeakyReLU(0.2, inplace=True),
            ] + layers
        self.classifier = nn.Sequential(*layers)

    def forward(self, encoded_image, encoded_cond=None):
        """Return per-sample probabilities as a 1-D tensor.

        Raises:
            ValueError: if this is a conditional head and ``encoded_cond`` is
                None (previously this surfaced as an opaque channel-mismatch
                error inside the conv).
        """
        if self.have_cond:
            if encoded_cond is None:
                raise ValueError('conditional DiscriminatorOutput requires encoded_cond')
            cond = encoded_cond.view(-1, self.cond_channels, 1, 1)
            cond = cond.expand(-1, -1, encoded_image.size(2), encoded_image.size(3))
            features = torch.cat((encoded_image, cond), 1)
        else:
            features = encoded_image
        return self.classifier(features).view(-1)

class CondArgumentModel(nn.Module):
    """Conditioning-augmentation module.

    A linear layer (followed by ReLU) maps a text embedding to
    ``2 * condition_dim`` values. These are split into ``mean`` and ``sigma``;
    ``sigma`` is used as a log-variance (std = exp(0.5 * sigma)). A condition
    vector is sampled as ``mean + std * eps`` with ``eps ~ N(0, I)``.

    Note: ReLU is applied before the split, so ``mean`` and ``sigma`` are
    non-negative.

    Args:
        embedding_dim: input embedding size; defaults to module-level ``emb_dim``.
        condition_dim: sampled condition size; defaults to module-level ``cond_dim``.
    """

    def __init__(self, embedding_dim=None, condition_dim=None):
        super(CondArgumentModel, self).__init__()
        if embedding_dim is None:
            embedding_dim = emb_dim
        if condition_dim is None:
            condition_dim = cond_dim
        self.condition_dim = condition_dim
        self.fc = nn.Linear(in_features=embedding_dim, out_features=condition_dim * 2)
        self.relu = nn.ReLU(True)

    def convert(self, embed):
        """Return ``(mean, sigma)``, each of shape (B, condition_dim)."""
        x = self.relu(self.fc(embed))
        mean, sigma = x[:, :self.condition_dim], x[:, self.condition_dim:]
        return mean, sigma

    def forward(self, x):
        """Return ``(condition, mean, sigma)`` for a batch of embeddings ``x``."""
        mean, sigma = self.convert(x)
        std = torch.exp(sigma * 0.5)
        # randn_like draws eps on the same device/dtype as std. The old
        # FloatTensor(...) noise lived on the CPU unless the global CUDA flag
        # was set, which crashed when the model ran on a GPU.
        eps = torch.randn_like(std)
        condition = std * eps + mean
        return condition, mean, sigma

In [None]:
class Generate_Stage_1(nn.Module):
    """Stage-I generator: (noise, text embedding) -> 3x64x64 image in [-1, 1].

    The embedding is turned into a condition vector by ``CondArgumentModel``.
    It is concatenated with the noise and projected by a linear layer to a
    1024x4x4 feature map. Four ``UpperBlock``s upsample 4 -> 64, and a 3x3
    conv plus Tanh produce the RGB image.
    """

    def __init__(self):
        super(Generate_Stage_1, self).__init__()
        self.CA = CondArgumentModel()
        # Noise and condition are concatenated: z_dim + cond_dim (= 228 with
        # the notebook's constants) input features.
        self.in_dim = z_dim + cond_dim
        self.fc = nn.Sequential(
            nn.Linear(in_features=self.in_dim, out_features=128*8*4*4, bias=False),
            nn.BatchNorm1d(128*8*4*4),
            nn.ReLU(inplace=True)
        )
        self.img = nn.Sequential(
            UpperBlock(128*8, 64*8),   # 4x4   -> 8x8
            UpperBlock(64*8, 32*8),    # 8x8   -> 16x16
            UpperBlock(32*8, 16*8),    # 16x16 -> 32x32
            UpperBlock(16*8, 8*8),     # 32x32 -> 64x64
            Conv_k3(8*8, 3),
            nn.Tanh()
        )

    def forward(self, noise, emb):
        """Return ``(fake_img, mean, sigma)``.

        Args:
            noise: (B, z_dim) or (B, z_dim, 1, 1) noise.
            emb: (B, emb_dim) text embeddings.
        """
        cond, mean, sigma = self.CA(emb)
        batch = noise.size(0)
        # Flatten both parts so 2-D and 4-D noise are accepted; the order
        # (noise first, then condition) matches the original concatenation.
        x = torch.cat((noise.reshape(batch, -1), cond.reshape(batch, -1)), 1)
        h_code = self.fc(x).view(-1, 128*8, 4, 4)
        fake_img = self.img(h_code)
        return fake_img, mean, sigma

### Building the Stage-I Discriminator

In [None]:
class DiscriminatorStage1(nn.Module):
    """Stage-I discriminator.

    ``forward`` only encodes a 3x64x64 image into 1024x4x4 features; the
    probability heads are exposed as ``condition_classifier`` (text-conditioned)
    and ``uncondition_classifier`` (None at Stage I) and are called by the loss
    functions.
    """

    def __init__(self):
        super(DiscriminatorStage1, self).__init__()
        # Every conv is 4x4 / stride 2 / padding 1, which halves H and W:
        #   output_size = (input_size - kernel + 2 * padding) / stride + 1
        # The first conv has no BatchNorm.
        stack = [
            nn.Conv2d(in_channels=3, out_channels=128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),  # 3x64x64 -> 128x32x32
        ]
        # 128x32x32 -> 256x16x16 -> 512x8x8 -> 1024x4x4
        for c_in, c_out in ((128, 256), (256, 512), (512, 1024)):
            stack.extend([
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ])
        self.encoder = nn.Sequential(*stack)
        self.condition_classifier = DiscriminatorOutput()
        self.uncondition_classifier = None

    def forward(self, image):
        features = self.encoder(image)
        return features

In [None]:
def KL_loss(mean, sigma):
    """KL divergence between N(mean, exp(sigma)) and N(0, I), averaged.

    ``sigma`` is the log-variance produced by the conditioning-augmentation
    module (which uses ``exp(0.5 * sigma)`` as the std). Per element the
    divergence is ``-0.5 * (1 + sigma - mean**2 - exp(sigma))``; the mean over
    all elements is returned as a scalar tensor.

    The previous expression ``1 + sigma - (mean**2 + sigma)`` cancelled
    ``sigma`` and omitted ``exp(sigma)``, so the variance was never regularised.
    """
    kld = 1 + sigma - mean.pow(2) - sigma.exp()
    return torch.mean(kld) * (-0.5)

def CalculateGeneratorLoss(net_dim, fake_images, y_labels, condition):
    """Adversarial loss for the generator.

    Args:
        net_dim: discriminator; ``net_dim(images)`` returns image features and
            it exposes ``condition_classifier(features, cond)`` and
            ``uncondition_classifier(features)`` (the latter may be None).
        fake_images: generator output (not detached, so gradients reach G).
        y_labels: target labels for the fakes (the training loops pass the
            "real" labels).
        condition: condition vectors. They are detached, so this loss does not
            back-propagate through the discriminator's condition input.

    Returns:
        Scalar tensor: BCE of the conditional head, plus BCE of the
        unconditional head when it exists.
    """
    criterion = nn.BCELoss()
    cond = condition.detach()
    fake_features = net_dim(fake_images)

    fake_cond_output = net_dim.condition_classifier(fake_features, cond)
    err_fake = criterion(fake_cond_output, y_labels)
    if net_dim.uncondition_classifier is not None:
        fake_uncond_output = net_dim.uncondition_classifier(fake_features)
        err_fake = err_fake + criterion(fake_uncond_output, y_labels)
    return err_fake

def calculate_discriminator_loss(net_dis, real_images, fake_images, y_real, y_fake, condition):
    """Discriminator loss over real, mismatched ("wrong") and fake pairs.

    Args:
        net_dis: discriminator; ``net_dis(images)`` returns image features and
            it exposes ``condition_classifier(features, cond)`` and
            ``uncondition_classifier(features)`` (the latter may be None).
        real_images: batch of real images.
        fake_images: generator output; detached here so G receives no gradient.
        y_real: labels for real pairs (one per sample).
        y_fake: labels for fake / mismatched pairs (one per sample).
        condition: condition vectors; detached here.

    Returns:
        ``(errD, errD_real, errD_wrong, errD_fake)``: the total loss tensor to
        back-propagate and three Python floats for logging. With an
        unconditional head, the real and fake terms reported are the averages
        of the conditional and unconditional terms.
    """
    criterion = nn.BCELoss()
    batch_size = real_images.size(0)
    cond = condition.detach()
    fake = fake_images.detach()

    real_img_feature = net_dis(real_images)
    fake_img_feature = net_dis(fake)

    real_output = net_dis.condition_classifier(real_img_feature, cond)
    errD_real = criterion(real_output, y_real)
    # Mismatched pairs: real image i paired with the condition of sample i+1.
    wrong_output = net_dis.condition_classifier(real_img_feature[:(batch_size - 1)], cond[1:])
    errD_wrong = criterion(wrong_output, y_fake[1:])

    fake_output = net_dis.condition_classifier(fake_img_feature, cond)
    errD_fake = criterion(fake_output, y_fake)

    if net_dis.uncondition_classifier is not None:
        real_uncond_output = net_dis.uncondition_classifier(real_img_feature)
        errD_real_uncond = criterion(real_uncond_output, y_real)

        fake_uncond_output = net_dis.uncondition_classifier(fake_img_feature)
        errD_fake_uncond = criterion(fake_uncond_output, y_fake)

        errD = (errD_real + errD_real_uncond) / 2. + (errD_fake + errD_wrong + errD_fake_uncond) / 3.
        errD_real = (errD_real + errD_real_uncond) / 2.
        errD_fake = (errD_fake + errD_fake_uncond) / 2.
    else:
        errD = errD_real + (errD_fake + errD_wrong) * 0.5
    return errD, errD_real.item(), errD_wrong.item(), errD_fake.item()

### Training the Stage-I GAN

In [None]:
# Reference implementation: https://github.com/donnaphat-ut/StackGAN

def main():
    """Train the Stage-I StackGAN (64x64) on the CUB birds training split.

    Outputs written to ``results1/``:
      * ``epch-<epoch>.png`` every 50 epochs: samples for a fixed noise batch
        and the last training batch's captions;
      * ``netG_epoch_<num_epoch>.pth`` and ``netD_epoch_last.pth``: final
        generator / discriminator state dicts.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    out_dir = 'results1'
    os.makedirs(out_dir, exist_ok=True)
    batch_size = 64
    kl_coeff = 1.0  # weight of the conditioning-augmentation KL term in G's loss
    transform = transforms.Compose([
                transforms.RandomCrop(64),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    dataset = LoadDataset(dataDir='Datasets/', split='train', transform=transform)
    loaded_datasets = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    netG = Generate_Stage_1().to(device)
    netG.apply(initateWeights)
    netD = DiscriminatorStage1().to(device)
    netD.apply(initateWeights)
    lr = 0.0002
    optD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
    optG = optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))

    fixed_noise = torch.rand(batch_size, z_dim, 1, 1).to(device)

    real_labels = (torch.FloatTensor(batch_size).fill_(1)).to(device)
    fake_labels = (torch.FloatTensor(batch_size).fill_(0)).to(device)

    num_epoch = 600
    for epoch in range(num_epoch):
        # Halve the learning rate every 100 epochs.
        if epoch % 100 == 0 and epoch > 0:
            lr = lr * 0.5
            for param_group in optG.param_groups:
                param_group['lr'] = lr
            for param_group in optD.param_groups:
                param_group['lr'] = lr
        for i, batch in enumerate(loaded_datasets, 0):
            real_imgs, encoded_caps = batch
            real_imgs = real_imgs.to(device)
            encoded_caps = encoded_caps.to(device)

            ## update discriminator
            netD.zero_grad()
            noise = torch.rand(batch_size, z_dim, 1, 1).to(device)
            fake_imgs, m, s = netG(noise, encoded_caps)
            errD, errD_real, errD_wrong, errD_fake = calculate_discriminator_loss(netD, real_imgs, fake_imgs, real_labels, fake_labels, m)
            errD.backward()
            optD.step()

            ## update generator
            netG.zero_grad()
            errG = CalculateGeneratorLoss(netD, fake_imgs, real_labels, m)
            # Previously ``errG += errG + KL``, which doubled the adversarial term.
            errG = errG + kl_coeff * KL_loss(m, s)
            errG.backward()
            optG.step()

            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tLoss_D_R: %.4f\tLoss_D_W: %.4f\tLoss_D_F %.4f'
                      % (epoch, num_epoch, i, len(loaded_datasets),
                         errD.item(), errG.item(), errD_real, errD_wrong, errD_fake))
        if epoch % 50 == 0:
            # Sample in eval mode so BatchNorm running stats are not updated.
            netG.eval()
            with torch.no_grad():
                fake, _, _ = netG(fixed_noise, encoded_caps)
            netG.train()
            grid = make_grid(fake.cpu(), nrow=8, normalize=True).permute(1, 2, 0).numpy()
            # The 'agg' backend shows nothing, so save the figure and release it.
            fig = plt.figure(figsize=(10, 10))
            plt.axis('off')
            plt.imshow(grid)
            fig.savefig(os.path.join(out_dir, 'epch-{}.png'.format(epoch)))
            plt.close(fig)

    # The Stage-II training and inference cells load this generator checkpoint.
    torch.save(netG.state_dict(), os.path.join(out_dir, 'netG_epoch_{}.pth'.format(num_epoch)))
    torch.save(netD.state_dict(), os.path.join(out_dir, 'netD_epoch_last.pth'))


if __name__ == '__main__':
    main()

In [None]:
class ResBlock(nn.Module):
    """Channel-preserving residual block: ``relu(x + F(x))``.

    F is conv3x3 -> BN -> ReLU -> conv3x3 -> BN, all with ``plane`` channels.
    """

    def __init__(self, plane):
        super(ResBlock, self).__init__()
        residual_branch = [
            Conv_k3(plane, plane),
            nn.BatchNorm2d(plane),
            nn.ReLU(True),
            Conv_k3(plane, plane),
            nn.BatchNorm2d(plane),
        ]
        self.block = nn.Sequential(*residual_branch)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        return self.relu(self.block(x) + x)
    
class G_Stage2(nn.Module):
    """Stage-II generator: refines the Stage-I 64x64 image into 256x256.

    Pipeline:
      1. The wrapped Stage-I generator (frozen) produces a 3x64x64 image.
      2. The image is encoded to 512x16x16.
      3. A new condition vector from ``CondArgumentModel`` is tiled over the
         16x16 grid and concatenated.
      4. The result passes a 3x3 combine conv and four residual blocks.
      5. Four ``UpperBlock``s upsample 16 -> 256, then conv3x3 + Tanh.

    Args:
        G_Stage1: Stage-I generator returning ``(image, mean, sigma)``. Its
            parameters get ``requires_grad = False``.

    ``forward`` returns ``(stage1_image, stage2_image, mean, sigma)``.
    """

    def __init__(self, G_Stage1):
        super(G_Stage2, self).__init__()
        self.G1 = G_Stage1
        # The original referenced undefined ``CondAugment_Model``.
        self.CA = CondArgumentModel()
        for p in self.G1.parameters():
            p.requires_grad = False
        self.encoder = nn.Sequential(
            Conv_k3(3, 128),                                # 128 x 64 x 64
            nn.ReLU(True),
            nn.Conv2d(128, 128 * 2, 4, 2, 1, bias=False),   # 256 x 32 x 32
            nn.BatchNorm2d(128 * 2),
            nn.ReLU(True),
            nn.Conv2d(128 * 2, 128 * 4, 4, 2, 1, bias=False),  # 512 x 16 x 16
            nn.BatchNorm2d(128 * 4),
            nn.ReLU(True))
        self.combine = nn.Sequential(
            Conv_k3(512 + cond_dim, 512),  # 640 input channels with cond_dim = 128
            nn.BatchNorm2d(512),
            nn.ReLU(True)
        )
        self.residual = nn.Sequential(*[ResBlock(512) for _ in range(4)])
        # The original referenced undefined ``Upblock``.
        self.decoder = nn.Sequential(
            UpperBlock(512, 256),   # 16  -> 32
            UpperBlock(256, 128),   # 32  -> 64
            UpperBlock(128, 64),    # 64  -> 128
            UpperBlock(64, 32),     # 128 -> 256
            Conv_k3(32, 3),
            nn.Tanh()
        )

    def forward(self, noise, emb):
        init_image, _, _ = self.G1(noise, emb)
        # G1 is frozen: cut the graph so backward never traverses it.
        init_image = init_image.detach()
        encoded = self.encoder(init_image)

        cond, m, s = self.CA(emb)
        cond = cond.view(-1, cond_dim, 1, 1)
        cond = cond.expand(-1, -1, encoded.size(2), encoded.size(3))

        encoded_cond = torch.cat([encoded, cond], 1)
        img_feature = self.combine(encoded_cond)
        img_feature = self.residual(img_feature)
        img = self.decoder(img_feature)

        return init_image, img, m, s

In [None]:
class D_Stage2(nn.Module):
    """Stage-II discriminator.

    ``forward`` encodes a 3x256x256 image into 1024x4x4 features. The
    probability heads are exposed as ``condition_classifier`` and
    ``uncondition_classifier`` and are called by the loss functions.
    """

    def __init__(self):
        super(D_Stage2, self).__init__()
        # 3x256x256 -> 128x128x128 (no BatchNorm on the first layer)
        layers = [nn.Conv2d(3, 128, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, True)]
        # Five stride-2 downsamplings: 128 -> 64 -> 32 -> 16 -> 8 -> 4 pixels,
        # channels 256 -> 512 -> 1024 -> 2048 -> 4096.
        for c_in, c_out in ((128, 256), (256, 512), (512, 1024), (1024, 2048), (2048, 4096)):
            layers += [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                       nn.BatchNorm2d(c_out),
                       nn.LeakyReLU(0.2, True)]
        # Two 3x3 convs shrinking channels at 4x4: 4096 -> 2048 -> 1024.
        for c_in, c_out in ((4096, 2048), (2048, 1024)):
            layers += [Conv_k3(c_in, c_out),
                       nn.BatchNorm2d(c_out),
                       nn.LeakyReLU(0.2, True)]
        self.img_encoder = nn.Sequential(*layers)

        self.condition_classifier = DiscriminatorOutput()
        self.uncondition_classifier = DiscriminatorOutput(have_cond=False)

    def forward(self, img):
        return self.img_encoder(img)

In [None]:
def main():
    """Train the Stage-II StackGAN (256x256) on top of a trained Stage-I generator.

    Loads the Stage-I generator from ``results1/netG_epoch_600.pth`` and keeps
    it frozen. Every 15 epochs the Stage-II generator is saved to
    ``results2/netG2_epoch_<epoch>.pth``, and the final discriminator to
    ``results2/netD2_epoch_last.pth``. Every 10 epochs a sample grid is saved
    to ``results2/epch-<epoch>.png``.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    out_dir = 'results2'
    os.makedirs(out_dir, exist_ok=True)
    stage1_ckpt = 'results1/netG_epoch_600.pth'
    kl_coeff = 1.0  # weight of the conditioning-augmentation KL term in G's loss
    # load dataset with size 256x256
    batch_size = 55
    transform = transforms.Compose([
                transforms.RandomCrop(256),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    dataset = LoadDataset(dataDir='Datasets/', split='train', transform=transform, imgSize=256)
    tr_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    # Build the Stage-II generator around a Stage-I generator.
    G1 = Generate_Stage_1()
    netG = G_Stage2(G1).to(device)
    netG.apply(initateWeights)
    # apply() also recurses into G1, so load the trained Stage-I weights only
    # now. Previously they were never loaded and G1 stayed random.
    netG.G1.load_state_dict(torch.load(stage1_ckpt, map_location=device))
    netG.G1.eval()
    netD = D_Stage2().to(device)
    netD.apply(initateWeights)

    lr = 0.0002
    optD = optim.Adam(netD.parameters(), lr=lr, betas=(0.5, 0.999))
    # Only optimise trainable parameters (the Stage-I generator is frozen).
    netG_param = [p for p in netG.parameters() if p.requires_grad]
    optG = optim.Adam(netG_param, lr=lr, betas=(0.5, 0.999))

    fixed_noise = torch.rand(batch_size, z_dim, 1, 1).to(device)

    real_labels = (torch.FloatTensor(batch_size).fill_(1)).to(device)
    fake_labels = (torch.FloatTensor(batch_size).fill_(0)).to(device)

    num_epoch = 45
    for epoch in range(num_epoch + 1):
        if epoch % 100 == 0 and epoch > 0:
            lr = lr * 0.5
            for param_group in optG.param_groups:
                param_group['lr'] = lr
            for param_group in optD.param_groups:
                param_group['lr'] = lr
        for i, batch in enumerate(tr_loader, 0):
            real_imgs, encoded_caps = batch
            real_imgs = real_imgs.to(device)
            encoded_caps = encoded_caps.to(device)

            ## update discriminator
            netD.zero_grad()
            noise = torch.rand(batch_size, z_dim, 1, 1).to(device)
            init_img, fake_imgs, m, s = netG(noise, encoded_caps)
            errD, errD_real, errD_wrong, errD_fake = calculate_discriminator_loss(netD, real_imgs, fake_imgs, real_labels, fake_labels, m)
            errD.backward()
            optD.step()

            ## update generator
            netG.zero_grad()
            errG = CalculateGeneratorLoss(netD, fake_imgs, real_labels, m)
            # Previously ``errG += errG + KL``, which doubled the adversarial term.
            errG = errG + kl_coeff * KL_loss(m, s)
            errG.backward()
            optG.step()

            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tLoss_D_R: %.4f\tLoss_D_W: %.4f\tLoss_D_F %.4f'
                      % (epoch, num_epoch, i, len(tr_loader),
                         errD.item(), errG.item(), errD_real, errD_wrong, errD_fake))
        if epoch % 10 == 0:
            netG.eval()
            with torch.no_grad():
                _, fake, _, _ = netG(fixed_noise, encoded_caps)
            netG.train()
            netG.G1.eval()  # train() recursed into the frozen Stage-I generator
            grid = make_grid(fake.cpu(), nrow=8, normalize=True).permute(1, 2, 0).numpy()
            fig = plt.figure(figsize=(10, 10))
            plt.axis('off')
            plt.imshow(grid)
            fig.savefig(os.path.join(out_dir, 'epch-{}.png'.format(epoch)))
            plt.close(fig)

        if epoch % 15 == 0:
            torch.save(netG.state_dict(), 'results2/netG2_epoch_{}.pth'.format(epoch))
    torch.save(netD.state_dict(), 'results2/netD2_epoch_last.pth')

if __name__ == '__main__':
    main()

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Stage-I generator. map_location lets a GPU-saved checkpoint load on a CPU-only
# machine (and vice versa).
G1 = Generate_Stage_1().to(device)
G1checkpoint = torch.load('results1/netG_epoch_600.pth', map_location=device)
G1.load_state_dict(G1checkpoint)
G1.eval()

# Stage-II generator wrapping G1. G1 is registered as its submodule ``G1``, so
# this state dict also contains (and overwrites) the Stage-I weights.
# NOTE(review): the Stage-II training cell runs 45 epochs and saves every 15,
# so it does not produce 'netG2_epoch_450.pth' as written; confirm the path.
netG = G_Stage2(G1).to(device)
G2checkpoint = torch.load('results2/netG2_epoch_450.pth', map_location=device)
netG.load_state_dict(G2checkpoint)
netG.eval()

In [None]:
# Evaluation pipeline: deterministic centre crop and no flipping, so the test
# images displayed below do not vary between runs through augmentation
# (training uses RandomCrop + RandomHorizontalFlip instead).
transform = transforms.Compose([
                transforms.CenterCrop(256),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

dataset = LoadDataset(dataDir='Datasets/', split='test', transform=transform, imgSize=256)
te_loader = DataLoader(dataset, batch_size=55, shuffle=True, drop_last=True)

In [None]:
%matplotlib inline
te_loader = iter(te_loader)
imgTensor, captions = next(te_loader)

batch_size = 55
 
with torch.no_grad():# Generate image grid
    grid1 = make_grid(imgTensor[:10], padding = 4, nrow=10)
    grid1 = grid1.permute(1, 2, 0)
    plt.figure(figsize=(20, 8), dpi=300)
    plt.imshow(grid1)

with torch.no_grad():
    noise = torch.rand(batch_size, z_dim, 1, 1).to(device)
    cap = captions.to(device)
    init_imgs, fake_imgs, m, s = netG(noise, cap)

    grid2 = make_grid(fake_imgs.detach().cpu()[:10], padding = 4, nrow=10).permute(1, 2, 0).numpy()

    plt.figure(figsize=(20, 8), dpi=300)
    plt.axis('off')
    plt.imshow(grid2)