For TPU

In [None]:
# Download and run the PyTorch/XLA environment setup script (only needed when training on a TPU).
# NOTE(review): the script is fetched from the unpinned `master` branch — confirm the URL still resolves.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import tensorflow as tf
import tensorflow_hub as hub

import torchvision
from torchvision import transforms
from PIL import Image
import tarfile

import os
import shutil
import random
from tqdm import tqdm

def seed_everything(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports the seed as ``PYTHONHASHSEED`` and configures cuDNN for
    reproducible runs.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    # Only affects interpreters started after this point; the hash seed of the
    # running kernel was fixed at start-up.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU rather than only the current one; this call is
    # silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Auto-tuning (benchmark=True) may select different convolution algorithms
    # from run to run, which defeats the purpose of seeding.
    torch.backends.cudnn.benchmark = False


seed_everything(42)

In [None]:
# Index of the dataset: one (first field, second field) pair per line of the
# listing file. Later cells unpack each pair as (image_name, class_folder).
images_n_classes = []
with open('../input/images-classes-for-oxford102/images_n_classes.txt') as index_file:
    images_n_classes.extend(
        (fields[0], fields[1])
        for fields in (row.strip().split() for row in index_file)
    )

For TPU

In [None]:
# Optional TPU cell: requires torch_xla to be installed (see the setup cell at the top).
import torch_xla
import torch_xla.core.xla_model as xm
# Handle to the XLA device.
# NOTE(review): `device` is assigned again in the model-setup cell further down — check that cell when running on TPU.
device = xm.xla_device()
# Newly created tensors default to CPU float32.
torch.set_default_tensor_type('torch.FloatTensor')

For USE (Universal Sentence Encoder) embeddings

In [None]:
# Encode the captions of every image with the Universal Sentence Encoder.
# `embs[k]` ends up holding one embedding row per caption line of image k.
embs = []
path_to_classes = '../input/cvpr2016/text_c10/'
with tf.device('cpu'):
    module_url = "https://tfhub.dev/google/universal-sentence-encoder/4"
    model = hub.load(module_url)
    for image_name, class_folder in images_n_classes:
        # The caption file shares the image's name: swap the last four
        # characters (the extension) for '.txt'.
        text_name = image_name[:-4] + '.txt'
        caption_path = path_to_classes + class_folder + '/' + text_name
        with open(caption_path) as caption_file:
            captions = [raw_line.strip() for raw_line in caption_file]
        embs.append(torch.tensor(model(captions).numpy()))

For Skip-Thoughts embeddings

In [None]:
# Load pre-computed embeddings: every line of the file becomes one tensor.
# Within a line, ';' ends each vector and whitespace separates the numbers.
embs = []
with open('../input/embs-for-skip/embs_skip_thoughts.txt') as skip_file:
    for record in tqdm(skip_file):
        # Whatever follows the last ';' (normally just the newline) is dropped.
        vectors = [
            [float(number) for number in chunk.split()]
            for chunk in record.split(';')[:-1]
        ]
        embs.append(torch.tensor(vectors))

In [None]:
# Image pre-processing applied by the dataset: resize, centre-crop to
# 128x128, then convert the PIL image to a tensor.
_preprocessing_steps = [
    transforms.Resize(128),
    transforms.CenterCrop(128),
    transforms.ToTensor(),
]
transform = transforms.Compose(_preprocessing_steps)

In [None]:
# Directory the dataset class reads image files from (joined with each image name).
path_to_images = './jpg'

In [None]:
# Unpack the image archive into the current working directory.
# The context manager closes the archive even when extraction fails part-way.
# NOTE(review): extractall() trusts the member paths stored in the archive;
# only run it on archives from a trusted source.
with tarfile.open('../input/flower-dataset-102/102flowers.tgz', 'r:gz') as tgz:
    tgz.extractall()

In [None]:
class TAC_GAN_Dataset(Dataset):
    """Pairs every image with one randomly chosen caption embedding and a class vector.

    Args:
        embs: Sequence aligned with ``images_classes``; item ``i`` is indexed
            by caption number to obtain a single embedding for image ``i``.
        images_classes: Sequence of ``(image_name, class_folder)`` pairs.

    Images are opened from the module-level ``path_to_images`` directory and
    pre-processed with the module-level ``transform``.
    """

    # Captions are sampled from at most this many leading entries per image.
    MAX_CAPTIONS = 5
    # Length of the one-hot class vector.
    NUM_CLASSES = 102

    def __init__(self, embs, images_classes):
        self.embs = embs
        # Store the argument itself; the previous code silently read the
        # notebook-level global instead of what the caller passed in.
        self.images_classes = images_classes

    def __getitem__(self, index):
        """Return ``(image, text, one_hot_classes)`` for sample ``index``."""
        emb = self.embs[index]
        image_name, class_folder = self.images_classes[index][0], self.images_classes[index][1]

        path = path_to_images + '/' + image_name
        # Close the file handle deterministically and force three channels so
        # grayscale/RGBA files cannot break batching.
        with Image.open(path) as raw_image:
            image = transform(raw_image.convert('RGB'))

        # Never index past the captions that actually exist for this image.
        n_choices = min(self.MAX_CAPTIONS, len(emb))
        text_i = random.randint(0, n_choices - 1)
        text = emb[text_i]

        one_hot_classes = torch.zeros(self.NUM_CLASSES)
        # Everything after the first six characters of the folder name is
        # parsed as a 1-based class id (presumably names like 'class_00001').
        one_hot_classes[int(class_folder[6:]) - 1] = 1.0
        return image, text, one_hot_classes

    def __len__(self):
        """Number of (image, class) pairs."""
        return len(self.images_classes)

In [None]:
# Build the training set and a shuffling loader that yields batches of 128
# samples using two worker processes.
dataset = TAC_GAN_Dataset(embs=embs, images_classes=images_n_classes)
dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=2,
)

In [None]:
def weights_init(m):
    """Re-initialise a single layer in place, chosen by its class name.

    Layers whose class name contains ``'Conv'`` get weights drawn from
    N(0, 0.02). Layers whose class name contains ``'BatchNorm'`` get weights
    drawn from N(1, 0.02) and a zero bias. Any other layer is left untouched.

    Args:
        m: The module to initialise.
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [None]:
class Generator(nn.Module):
    """Text-conditioned image generator.

    The text embedding is compressed by a linear layer, concatenated with the
    noise vector, projected to an 8x8 feature map and upsampled by four
    stride-2 transposed convolutions to a 3x128x128 image. The final ``tanh``
    output is rescaled into the range [0, 1].

    Args:
        noise_size: Length of the noise vector.
        embed_size: Length of the incoming text embedding.
        ner_fc1: Width of the compressed text representation.
        ner_fc2: Channel multiplier; the 8x8 map has ``8 * ner_fc2`` channels.
        gen_conv_ch: Base channel count of the upsampling stack.
    """

    def __init__(self, noise_size = 100, embed_size = 4800, ner_fc1 = 256, ner_fc2 = 64, gen_conv_ch = 64):
        super().__init__()
        self.noise_shape = noise_size
        self.embed_shape = embed_size
        self.ner_fc1 = ner_fc1
        self.ner_fc2 = ner_fc2

        # Text compression followed by the projection to the 8x8 seed map.
        self.FC1 = nn.Linear(self.embed_shape, self.ner_fc1)
        self.emb_leak = nn.LeakyReLU()
        self.FC2 = nn.Linear(self.noise_shape + self.ner_fc1, 8*8*8*self.ner_fc2)
        self.emb_bn = nn.BatchNorm2d(8*self.ner_fc2)
        self.emb_rl = nn.ReLU()

        # Three upsample-normalise-activate stages, then the RGB output stage.
        widths = (8*self.ner_fc2, gen_conv_ch*4, gen_conv_ch*2, gen_conv_ch)
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages.append(nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU())
        stages.append(nn.ConvTranspose2d(gen_conv_ch, 3, kernel_size=4, stride=2, padding=1))
        stages.append(nn.Tanh())
        self.net = nn.Sequential(*stages)

        self.intialize_weights()

    def forward(self, noise, embed):
        """Generate images from ``noise`` and text embeddings ``embed``.

        Both inputs are concatenated along dimension 1, so they must share
        their first (batch) dimension.
        """
        n_samples = noise.shape[0]
        text_code = self.emb_leak(self.FC1(embed))
        joint = torch.cat((noise, text_code), 1)
        seed_map = self.FC2(joint).reshape((n_samples, 8*self.ner_fc2, 8, 8))
        seed_map = self.emb_rl(self.emb_bn(seed_map))

        # tanh output lies in [-1, 1]; shift and scale it into [0, 1].
        return (self.net(seed_map) / 2.0) + 0.5

    def intialize_weights(self):
        """Reset conv weights to N(0, 0.02) with zero bias, and batch-norm to weight 1, bias 0."""
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                layer.weight.data.normal_(0, 0.02)
                layer.bias.data.zero_()
            elif isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()


In [None]:
class Discriminator(nn.Module):
    """Text-conditioned discriminator with an auxiliary class head.

    Four stride-2 convolutions reduce the image to an 8x8 map (so inputs are
    expected to be 128x128). The compressed text embedding is tiled over that
    map, both are fused by a 1x1 convolution and fed to two linear heads.

    Args:
        embed_size: Length of the incoming text embedding.
        ner_fc1: Width of the compressed text representation.
        out_net: Channel count produced by the convolutional stack.
        gen_conv_ch: Accepted but not used by this class.
    """

    def __init__(self, embed_size = 4800, ner_fc1 = 256, out_net = 384, gen_conv_ch = 64):
        super().__init__()
        self.embed_shape = embed_size
        self.ner_fc1 = ner_fc1
        self.out_net = out_net

        self.FC1 = nn.Linear(self.embed_shape, self.ner_fc1)

        # Four downsample-normalise-activate stages: 3 -> 64 -> 128 -> 256 -> out_net.
        widths = (3, 64, 128, 256, self.out_net)
        stages = []
        for c_in, c_out in zip(widths, widths[1:]):
            stages.append(nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.LeakyReLU(0.2))
        self.net = nn.Sequential(*stages)

        self.conv_cat = nn.Conv2d(self.out_net + self.ner_fc1, 512, kernel_size=1, stride=1)
        self.last_fc = nn.Linear(8*8*512, 64)
        self.FC_real_fake = nn.Linear(64, 1)
        self.FC_class = nn.Linear(64, 102)
        self.leak = nn.LeakyReLU()
        self.sig = nn.Sigmoid()
        self.intialize_weights()

    def forward(self, img, emb):
        """Score ``img`` against the text embedding ``emb``.

        Returns:
            A 2-tuple ``(real_fake_dist, class_dist)``: the raw (un-squashed)
            real/fake output of shape ``(batch, 1)`` and the sigmoid class
            scores of shape ``(batch, 102)``.
        """
        n_samples = img.shape[0]
        # Compress the text and tile it over the 8x8 spatial grid.
        text_code = self.FC1(emb).reshape(n_samples, self.ner_fc1, -1).unsqueeze(2)
        text_map = text_code.repeat(1, 1, 8, 8)

        img_map = self.net(img)

        fused = self.conv_cat(torch.cat((img_map, text_map), 1))
        hidden = self.leak(fused).reshape((n_samples, 8*8*512))
        hidden = self.leak(self.last_fc(hidden))

        return self.FC_real_fake(hidden), self.sig(self.FC_class(hidden))

    def intialize_weights(self):
        """Reset conv weights to N(0, 0.02) with zero bias, and batch-norm to weight 1, bias 0."""
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                layer.weight.data.normal_(0, 0.02)
                layer.bias.data.zero_()
            elif isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()

In [None]:
# Pick the compute device instead of assuming CUDA: use the GPU when present,
# otherwise the XLA device if the optional TPU cell imported torch_xla as `xm`,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif 'xm' in globals():
    device = xm.xla_device()
else:
    device = torch.device('cpu')

# The discriminator's class head already applies a sigmoid (-> BCELoss), while
# its real/fake head returns raw values (-> BCEWithLogitsLoss).
criterion_class = nn.BCELoss()
criterion_real_fake = nn.BCEWithLogitsLoss()

# Take the embedding width from the embeddings that were actually loaded, so
# both the USE cell and the Skip-Thoughts cell work without editing defaults.
embed_size = embs[0].shape[-1]

# Both networks initialise their own weights inside their constructors.
netG = Generator(embed_size=embed_size)
netG.to(device)

netD = Discriminator(embed_size=embed_size)
netD.to(device)

optG = torch.optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
optD = torch.optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))

In [None]:
# Number of full passes over the dataloader performed by the training loop.
N_EPOCHS = 300

In [None]:
from IPython.display import clear_output

In [None]:
# Adversarial training: for every batch, one discriminator update on three
# kinds of pairs (real, mismatched, generated) followed by one generator update.
G_losses = []
D_losses = []
# The epoch counter gets its own name; the batch index `i` below used to shadow it.
for epoch in range(N_EPOCHS):
    for i, (images, texts, labels) in enumerate(dataloader):
        netG.train()
        netD.train()
        batch_size = images.shape[0]
        images = images.to(device)
        texts = texts.to(device)
        labels = labels.to(device)

        real_label = torch.ones((batch_size, 1)).to(device)
        fake_label = torch.zeros((batch_size, 1)).to(device)
        # Standard-normal noise for the generator, sized from the model itself.
        noise = torch.randn(batch_size, netG.noise_shape).to(device)

        # Independent permutations used to build mismatched / reshuffled pairs.
        rand_1 = torch.randperm(batch_size).to(device)
        rand_2 = torch.randperm(batch_size).to(device)
        rand_3 = torch.randperm(batch_size).to(device)
        rand_4 = torch.randperm(batch_size).to(device)

        ############### Train D ###################
        netD.zero_grad()

        # netD returns exactly two tensors: real/fake logits and class scores.
        # Real images with their own captions and classes.
        outS_real, outC_real = netD(images, texts)
        lossS_real = criterion_real_fake(outS_real, real_label)
        lossC_real = criterion_class(outC_real, labels)

        # Real images paired with captions drawn from a different permutation.
        outS_wrong, outC_wrong = netD(images[rand_1], texts[rand_2])
        lossS_wrong = criterion_real_fake(outS_wrong, fake_label)
        lossC_wrong = criterion_class(outC_wrong, labels[rand_1])

        # Generated images: condition the generator on the same permuted
        # captions that the discriminator and the class loss see, so the
        # image, caption and label of every sample belong together.
        fake_images = netG(noise, texts[rand_3])
        outS_fake, outC_fake = netD(fake_images.detach(), texts[rand_3])
        lossS_fake = criterion_real_fake(outS_fake, fake_label)
        lossC_fake = criterion_class(outC_fake, labels[rand_3])

        loss_D = (lossS_real + lossS_wrong + lossS_fake) + (lossC_real + lossC_wrong + lossC_fake)
        loss_D.backward()
        optD.step()

        ############### Train G ###################
        netG.zero_grad()
        # Fresh noise for the generator step.
        noise.data.normal_(0, 1)
        fake_images = netG(noise, texts[rand_4])
        S_fake, C_fake = netD(fake_images, texts[rand_4])
        lossS_G = criterion_real_fake(S_fake, real_label)
        lossC_G = criterion_class(C_fake, labels[rand_4])

        loss_G = lossS_G + lossC_G
        loss_G.backward()
        optG.step()
        ###########################################
        # Every 5th batch: record losses, checkpoint both nets, refresh the plots.
        if i % 5 == 0:
            # Store plain floats so the history holds no tensors.
            G_losses.append(loss_G.item())
            D_losses.append(loss_D.item())
            torch.save(netG.state_dict(), './netG.pt')
            torch.save(netD.state_dict(), './netD.pt')
            clear_output(True)
            plt.figure(figsize=(10,5))
            plt.title("Generator and Discriminator Loss During Training")
            plt.plot(G_losses,label="G")
            plt.plot(D_losses,label="D")
            plt.xlabel("iterations")
            plt.ylabel("Loss")
            plt.legend()
            plt.show()
            # Preview the first image of the generator step that just ran; it
            # is already computed, so no eval()/no_grad() switch is needed.
            plt.imshow(fake_images[0].detach().cpu().permute(1, 2, 0))
            # Show (and thereby release) the preview so figures do not pile up.
            plt.show()

If using USE

In [None]:
# Generate one image from a free-text description (requires the USE cell,
# which defines `model`, to have been run).
netG.eval()
with tf.device('cpu'):
    with torch.no_grad():
        texts = ["this flower is yellow in color, with petals that are very skinny"]
        # Encode the caption list defined above (the old code passed the
        # undefined name `text`, which raised NameError).
        emb = torch.tensor(model(texts).numpy(), device=device)
        emb = emb.view(1, -1)
        # Noise length comes from the generator rather than a hard-coded 100.
        images = netG(torch.randn(1, netG.noise_shape).to(device), emb)
        images = images.squeeze().detach().cpu()
        # Channels-first -> channels-last for matplotlib.
        plt.imshow(images.permute(1, 2, 0))