In [1]:
from models.sentence2vec import Sentence2Vec
import torch
import torchvision
import numpy as np
import os
from matplotlib import pyplot as plt
import matplotlib
from skimage.color import lab2rgb, rgb2lab
import pickle

# Notebook-wide sentence encoder; sample() uses it to embed the probe text.
embed_model = Sentence2Vec()

# SimHei as the sans-serif font so the Chinese plot titles render.
matplotlib.rcParams['font.sans-serif'] = ['simhei']

In [2]:
class CA_NET(torch.nn.Module):
    """Conditioning-augmentation head.

    Projects each input feature vector to a mean and a log-variance
    (``c_dim`` values each) and samples a conditioning code with the
    reparametrisation trick.

    Args:
        t_dim: size of the last dimension of the input (default 150).
        c_dim: size of the sampled conditioning code (default 150).
    """

    def __init__(self, t_dim=150, c_dim=150):
        super(CA_NET, self).__init__()
        self.t_dim = t_dim
        self.c_dim = c_dim
        self.fc = torch.nn.Linear(self.t_dim, self.c_dim * 2, bias=True)
        self.relu = torch.nn.ReLU()

    def encode(self, text_embedding):
        """Split the ReLU-activated projection into ``(mu, logvar)``.

        The slicing below indexes the third axis, so the input is expected to be
        3-D with ``t_dim`` features last; both outputs keep the two leading axes.
        NOTE: the ReLU is applied before the split, so ``logvar`` is always >= 0.
        """
        x = self.relu(self.fc(text_embedding))
        mu = x[:, :, :self.c_dim]
        logvar = x[:, :, self.c_dim:]
        return mu, logvar

    def reparametrize(self, mu, logvar):
        """Return ``mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, 1)``.

        ``randn_like`` draws ``eps`` on the same device and dtype as ``std``,
        so this works on CPU as well as CUDA.
        """
        std = logvar.mul(0.5).exp()
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, text_embedding):
        """Return ``(c_code, mu, logvar)`` for ``text_embedding``."""
        mu, logvar = self.encode(text_embedding)
        c_code = self.reparametrize(mu, logvar)
        return c_code, mu, logvar

    
class EncoderRNN(torch.nn.Module):
    """GRU text encoder followed by a CA_NET conditioning head.

    Args:
        hidden_size: GRU hidden size. NOTE(review): CA_NET() is built with its
            default 150-d input, so this presumably has to be 150 — confirm.
        n_layers: number of stacked GRU layers.
        dropout_p: dropout between stacked GRU layers (only used when n_layers > 1).
        input_size: feature size of each input step; 768 is the size of the
            sentence-embedding vectors used in this notebook.
    """

    def __init__(self, hidden_size, n_layers, dropout_p, input_size=768):
        super(EncoderRNN, self).__init__()

        self.hidden_size = hidden_size
        self.n_layers = n_layers

        # nn.GRU applies dropout only between stacked layers; a non-zero value
        # with a single layer has no effect and just emits a UserWarning.
        gru_dropout = dropout_p if n_layers > 1 else 0.0
        self.gru = torch.nn.GRU(input_size, hidden_size, n_layers, dropout=gru_dropout)
        self.ca_net = CA_NET()

    def forward(self, word_inputs, hidden):
        """Encode a batch-first sequence of embeddings.

        Args:
            word_inputs: (batch, seq_len, input_size); transposed here to the
                (seq_len, batch, input_size) layout the GRU expects.
            hidden: (n_layers, batch, hidden_size) initial GRU state.

        Returns:
            (c_code, hidden, mu, logvar): CA_NET outputs for every GRU output
            step, plus the final GRU hidden state.
        """
        # Follow the module's own device rather than a notebook-global `device`.
        module_device = next(self.parameters()).device
        word_inputs = word_inputs.transpose(0, 1).to(module_device)
        hidden = hidden.to(module_device)
        output, hidden = self.gru(word_inputs, hidden)
        c_code, mu, logvar = self.ca_net(output)

        return c_code, hidden, mu, logvar

    def init_hidden(self, batch_size, device=None):
        """Return a zero GRU state of shape (n_layers, batch_size, hidden_size).

        ``device`` defaults to None, i.e. the default (CPU) device as before.
        """
        return torch.zeros(self.n_layers, batch_size, self.hidden_size, device=device)


class Attn(torch.nn.Module):
    """Additive attention over encoder outputs.

    energy_t = attn_energy(sigmoid(attn_e(e_t) + attn_h(h))), normalised with a
    softmax over the sequence dimension.

    Args:
        hidden_size: feature size of the decoder state and of each encoder output.
        max_length: unused; kept for backward compatibility.
    """

    def __init__(self, hidden_size, max_length=768):
        super(Attn, self).__init__()
        self.hidden_size = hidden_size
        self.softmax = torch.nn.Softmax(dim=0)
        self.attn_e = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.attn_h = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.attn_energy = torch.nn.Linear(self.hidden_size, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, hidden, encoder_outputs, each_size):
        """Return attention weights of shape (batch, 1, seq_len).

        Args:
            hidden: (batch, hidden_size) decoder state; a 1-D (hidden_size,)
                vector also works (it broadcasts across the batch).
            encoder_outputs: (seq_len, batch, hidden_size).
            each_size: unused — no length masking is applied.
        """
        # Score all time steps at once; the result lives on encoder_outputs'
        # device instead of a buffer hard-wired to .cuda().
        encoder_ = self.attn_e(encoder_outputs)            # (seq_len, batch, hidden)
        hidden_ = self.attn_h(hidden).unsqueeze(0)         # (1, batch, hidden)
        attn_energies = self.attn_energy(self.sigmoid(encoder_ + hidden_))  # (seq_len, batch, 1)

        attn_energies = self.softmax(attn_energies)        # normalise over seq_len
        return attn_energies.permute(1, 2, 0)              # (batch, 1, seq_len)

    def score(self, hidden, encoder_output):
        """Unnormalised energy for one time step: (batch, hidden) -> (batch, 1)."""
        encoder_ = self.attn_e(encoder_output)
        hidden_ = self.attn_h(hidden)
        energy = self.attn_energy(self.sigmoid(encoder_ + hidden_))

        return energy

    
class AttnDecoderRNN(torch.nn.Module):
    """GRUCell decoder that emits one 3-value colour per call.

    Step 0 uses the mean encoder output as context; later steps attend over the
    encoder outputs with ``Attn``.

    Args:
        hidden_size: size of the GRUCell state and of each encoder output.
        n_layers, dropout_p: stored on the module but not used by it.
    """

    def __init__(self, hidden_size, n_layers=1, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.attn = Attn(hidden_size)
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout_p = dropout_p
        self.palette_dim = 3

        self.gru = torch.nn.GRUCell(self.hidden_size + self.palette_dim, hidden_size)

        self.out = torch.nn.Sequential(
                        torch.nn.Linear(hidden_size, hidden_size),
                        torch.nn.ReLU(inplace=True),
                        torch.nn.BatchNorm1d(hidden_size),
                        torch.nn.Linear(hidden_size, self.palette_dim)
                   )

    def forward(self, last_palette, last_decoder_hidden, encoder_outputs, each_input_size, i):
        """Decode one palette colour.

        Args:
            last_palette: (batch, 3) colour emitted at the previous step.
            last_decoder_hidden: (batch, hidden_size) GRUCell state.
            encoder_outputs: (seq_len, batch, hidden_size).
            each_input_size: forwarded to Attn, which does not use it.
            i: step index (0 selects mean-pooled context).

        Returns:
            (palette, context, gru_hidden). NOTE: ``context`` gets an extra leading
            axis and its layout differs between step 0 (1, 1, batch, hidden) and
            later steps (1, batch, 1, hidden); kept for compatibility — the callers
            in this notebook ignore it.
        """
        if i == 0:
            # (1, batch, hidden): plain average over the time axis.
            context = torch.mean(encoder_outputs, dim=0, keepdim=True)
            flat_context = context.squeeze(0)
        else:
            # The hidden state is passed unsqueezed; squeeze(0) would drop the
            # batch axis when batch == 1.
            attn_weights = self.attn(last_decoder_hidden, encoder_outputs, each_input_size)
            # (batch, 1, seq_len) x (batch, seq_len, hidden) -> (batch, 1, hidden)
            context = torch.bmm(attn_weights, encoder_outputs.transpose(0, 1))
            flat_context = context.squeeze(1)

        # Compute gru output from the previous colour and the context vector.
        gru_input = torch.cat((last_palette, flat_context), 1)
        gru_hidden = self.gru(gru_input, last_decoder_hidden)

        # Generate palette color.
        palette = self.out(gru_hidden)
        return palette, context.unsqueeze(0), gru_hidden
    
    
class DecoderRNN(torch.nn.Module):
    """Attention-free GRUCell decoder: the caller supplies the context vector.

    Args:
        hidden_size: size of the GRUCell state and of the context vector.
        n_layers, dropout_p: stored on the module but not used by it.
    """

    def __init__(self, hidden_size, n_layers=1, dropout_p=0.1):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.dropout_p = dropout_p
        self.palette_dim = 3
        # No `self.encoder` here: `torch.nn.Linear()` without sizes raises
        # TypeError, and the attribute was never used.

        self.gru = torch.nn.GRUCell(self.hidden_size + self.palette_dim, hidden_size)

        self.out = torch.nn.Sequential(
                        torch.nn.Linear(hidden_size, hidden_size),
                        torch.nn.ReLU(inplace=True),
                        torch.nn.BatchNorm1d(hidden_size),
                        torch.nn.Linear(hidden_size, self.palette_dim)
        )

    def forward(self, last_palette, last_decoder_hidden, encoder_output):
        """Decode one colour from the previous colour and a fixed context.

        Args:
            last_palette: (batch, 3) previous colour.
            last_decoder_hidden: (batch, hidden_size) GRUCell state.
            encoder_output: context, presumably (batch, 1, hidden_size); axis 1
                is squeezed before concatenation.

        Returns:
            (palette, context.unsqueeze(0), gru_hidden).
        """
        context = encoder_output

        # Compute gru output.
        gru_input = torch.cat((last_palette, context.squeeze(1)), 1)
        gru_hidden = self.gru(gru_input, last_decoder_hidden)

        # Generate palette color.
        palette = self.out(gru_hidden.squeeze(1))
        return palette, context.unsqueeze(0), gru_hidden



class Discriminator(torch.nn.Module):
    """MLP that scores a (palette, text-condition) pair with a sigmoid probability.

    The concatenated input of width ``color_size + hidden_dim`` is narrowed to
    1/2, 1/4 and 1/8 of that width (integer division), with a ReLU after each
    narrowing, then projected to a single sigmoid output.
    """

    def __init__(self, color_size=15, hidden_dim=150):
        super(Discriminator, self).__init__()
        in_dim = color_size + hidden_dim
        widths = [in_dim, in_dim // 2, in_dim // 4, in_dim // 8]

        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack += [torch.nn.Linear(fan_in, fan_out), torch.nn.ReLU(inplace=True)]
        stack += [torch.nn.Linear(widths[-1], 1), torch.nn.Sigmoid()]

        self.main = torch.nn.Sequential(*stack)

    def forward(self, color, text):
        """Return a (batch,) tensor of probabilities.

        color: (batch, color_size) flattened palette.
        text: (batch, hidden_dim) condition, or (batch, 1, hidden_dim) — axis 1 is
              squeezed before concatenation.
        """
        joint = torch.cat([color, text.squeeze(1)], dim=1)
        return self.main(joint).squeeze(1)


In [3]:
class Text2ColorDataset(torch.utils.data.Dataset):
    """Pairs of (sentence embedding, flattened CIELAB palette).

    Args:
        text_path: pickle with one entry per sample; each entry is joined with
            spaces before embedding (presumably a list of words — confirm).
        palette_path: pickle of RGB palettes in 0-255 that ``torch.FloatTensor``
            can turn into a (N, n_colours, >=3) tensor.

    NOTE: ``pickle.load`` can execute arbitrary code — only load trusted files.
    NOTE(review): this builds its own Sentence2Vec in addition to the
    notebook-level ``embed_model``; consider sharing one instance.
    """

    def __init__(self, text_path, palette_path):
        self.text_path = text_path
        self.palette_path = palette_path

        # Context managers close the file handles as soon as loading finishes.
        with open(text_path, 'rb') as f:
            self.text_data = pickle.load(f)
        with open(palette_path, 'rb') as f:
            palette_data = pickle.load(f)

        palette_data = torch.FloatTensor(palette_data) / 255.

        # Convert each palette to CIELAB (D50) in a single call by treating it as
        # a 1 x n_colours image; flattening gives [L0, a0, b0, L1, a1, b1, ...],
        # the same values and order as converting colours one at a time.
        self.palette_list = []
        for palettes in palette_data:
            rgb_image = palettes[:, :3].numpy()[np.newaxis, :, :]
            lab = rgb2lab(rgb_image, illuminant='D50')
            self.palette_list.append(lab.flatten().tolist())

        self.palette_list = torch.FloatTensor(self.palette_list)

        self.embed_model = Sentence2Vec()

    def __len__(self):
        return len(self.text_data)

    def __getitem__(self, idx):
        """Return ``(text_item, palette_item)`` for sample ``idx``.

        The text is space-joined and embedded on every access (no caching).
        """
        text_item = self.embed_model.embed(' '.join(self.text_data[idx]))['pooler_output'].squeeze(1)
        palette_item = self.palette_list[idx]

        return text_item, palette_item

In [4]:
# parameters
text_path = './data_for_training/palette_gen/words421.pkl'
palette_path = './data/rgb421.pkl'

batch_size = 16
lr = 1e-4
weight_decay = 5e-5
beta1 = 0.5        # Adam betas (passed to d_optimizer)
beta2 = 0.99
hidden_dim = 768   # NOTE(review): this global is not referenced by the visible cells

max_iter_cnt = 1e5

print_every_iter = 1
save_every_epoch = 3

# NOTE(review): 'date_time' looks like a placeholder for a run timestamp — confirm.
ckpt_dir = './palette_gen_ckpt/date_time'
output_dir = './palette_gen_ckpt/output/date_time'

# Fall back to the CPU when CUDA is unavailable instead of failing at the first .to(device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# exist_ok makes the cell idempotent on re-run and avoids a check-then-create race.
os.makedirs(ckpt_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok=True)


# utils
def lab2rgb_1d(in_lab, clip=True):
    """Convert one CIELAB colour (D50 illuminant) to a flat RGB array.

    Args:
        in_lab: 1-D numpy array with a single colour (L, a, b).
        clip: when True, clamp every RGB component into [0, 1].

    Returns:
        Flat numpy array of RGB components.
    """
    # skimage converts images, so wrap the colour as a 1x1 image and flatten back.
    rgb_flat = lab2rgb(in_lab[np.newaxis, np.newaxis, :], illuminant='D50').flatten()
    return np.clip(rgb_flat, 0, 1) if clip else rgb_flat


# define data_loader
dataset = Text2ColorDataset(text_path, palette_path)
# Use the configured batch_size rather than repeating the literal 16.
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [5]:
# define models
# 150 matches CA_NET's 150-d input/output; color_size=15 is 5 colours x 3 channels.
encoder = EncoderRNN(hidden_size=150, n_layers=1, dropout_p=0.2).to(device)
decoder = AttnDecoderRNN(hidden_size=150, n_layers=1, dropout_p=0.2).to(device)
discriminator = Discriminator(color_size=15, hidden_dim=150).to(device)


In [6]:
# define loss function
criterion_GAN = torch.nn.BCELoss()            # discriminator emits sigmoid probabilities
criterion_smoothL1 = torch.nn.SmoothL1Loss()  # palette reconstruction term

def KL_loss(mu, logvar):
    """KL divergence of N(mu, exp(logvar)) from N(0, 1), averaged over all elements.

    Computes -0.5 * mean(1 + logvar - mu**2 - exp(logvar)); inputs are not modified.
    """
    per_element = 1 + logvar - mu.pow(2) - logvar.exp()
    return -0.5 * per_element.mean()

In [7]:
# define optimizer
# The generator step updates the text encoder and the palette decoder together.
G_parameters = [*encoder.parameters(), *decoder.parameters()]

# NOTE(review): the generator uses Adam's default betas plus weight decay, while
# the discriminator uses the configured (beta1, beta2) and no weight decay —
# confirm this asymmetry is intended.
g_optimizer = torch.optim.Adam(G_parameters, lr=lr, weight_decay=weight_decay)
d_optimizer = torch.optim.Adam(discriminator.parameters(), betas=(beta1, beta2), lr=lr)

In [8]:
# define training step
def train_step(text, palettes):
    """Run one discriminator update followed by one generator update.

    Args:
        text: (batch, seq_len, 768) embedding sequence, batch-first (EncoderRNN
            transposes axes 0/1 before its GRU).
        palettes: (batch, 15) ground-truth palettes — 5 CIELAB colours flattened.

    Returns:
        (g_loss, d_loss) as detached scalar tensors (no autograd graph kept alive).
    """
    batch_size = text.size(0)
    # Valid (non-padding) time steps per sample: a step counts if any feature is
    # non-zero. Counting every non-zero scalar (torch.nonzero) would give roughly
    # seq_len * 768 per sample and shrink the pooled condition below ~768x.
    step_counts = (text != 0).any(dim=2).sum(dim=1)
    each_input_size = step_counts.tolist()

    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)

    palette = torch.zeros(batch_size, 3, device=device)
    fake_palettes = torch.zeros(batch_size, 15, device=device)

    encoder_hidden = encoder.init_hidden(batch_size).to(device)
    encoder_outputs, decoder_hidden, mu, logvar = encoder(text, encoder_hidden)
    decoder_hidden = decoder_hidden.squeeze(0)

    # Decode the 5 palette colours autoregressively.
    for i in range(5):
        palette, _context, decoder_hidden = decoder(palette,
                                                    decoder_hidden,
                                                    encoder_outputs,
                                                    each_input_size,
                                                    i)
        fake_palettes[:, 3 * i:3 * (i + 1)] = palette

    # Condition for the discriminator: encoder outputs summed over time and
    # divided by the number of valid steps (clamped so it never divides by 0).
    lengths = step_counts.clamp(min=1).to(device=device, dtype=torch.float).unsqueeze(1)
    text_condition = torch.sum(encoder_outputs, 0) / lengths

    palettes = palettes.to(device)

    # --- train discriminator ---
    # Detached inputs: the discriminator's gradients are unchanged, the generator
    # graph is left intact for the step below, and retain_graph is unnecessary.
    real = discriminator(palettes, text_condition.detach())
    d_loss_real = criterion_GAN(real, real_labels)

    fake = discriminator(fake_palettes.detach(), text_condition.detach())
    d_loss_fake = criterion_GAN(fake, fake_labels)

    d_loss = d_loss_real + d_loss_fake

    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    # --- train generator ---
    fake = discriminator(fake_palettes, text_condition)
    g_loss_GAN = criterion_GAN(fake, real_labels)

    g_loss_smoothL1 = criterion_smoothL1(fake_palettes, palettes)

    kl_loss = KL_loss(mu, logvar)

    # All three generator terms are weighted equally here.
    g_loss = g_loss_GAN + g_loss_smoothL1 + kl_loss

    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()

    return g_loss.detach(), d_loss.detach()

In [9]:
# sample some results in training process
def sample(prefix, input_text='快乐 悲伤 浪漫', n_samples=2):
    """Decode palettes for ``input_text`` and save each as a 5-swatch image.

    CA_NET draws a random conditioning code, so each repetition can give a
    different palette. Files are written to ``output_dir`` as
    ``epoch{prefix}_sample{k}.jpg`` for k = 1..n_samples. The decoder is put in
    eval mode while sampling and restored to its previous mode afterwards.

    Args:
        prefix: tag inserted into the file names.
        input_text: text passed to ``embed_model.embed``.
        n_samples: number of palettes to draw (default 2).
    """
    was_training = decoder.training
    decoder.eval()
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            text = embed_model.embed(input_text)['pooler_output'].unsqueeze(0).to(device)
            batch_size = text.size(0)
            # Non-padding time steps per sample; only forwarded to the decoder.
            each_input_size = (text != 0).any(dim=-1).sum(dim=-1).tolist()

            for x in range(n_samples):
                fig1, axs1 = plt.subplots(nrows=1, ncols=5)

                palette = torch.zeros(batch_size, 3, device=device)
                fake_palettes = torch.zeros(batch_size, 15, device=device)
                encoder_hidden = encoder.init_hidden(batch_size).to(device)

                encoder_outputs, decoder_hidden, mu, logvar = encoder(text, encoder_hidden)
                decoder_hidden = decoder_hidden.squeeze(0)

                for i in range(5):
                    palette, _context, decoder_hidden = decoder(palette,
                                                                decoder_hidden,
                                                                encoder_outputs,
                                                                each_input_size,
                                                                i)
                    fake_palettes[:, 3 * i:3 * (i + 1)] = palette

                axs1[0].set_title(input_text)
                # First batch entry as 5 float64 (L, a, b) rows; .cpu() makes the
                # numpy conversion safe for CUDA tensors.
                lab_colours = fake_palettes[0].cpu().numpy().astype('float64').reshape(5, 3)
                for k in range(5):
                    rgb = lab2rgb_1d(lab_colours[k])
                    axs1[k].imshow([[rgb]])
                    axs1[k].axis('off')

                fig1.savefig(os.path.join(output_dir,
                                          'epoch{}_sample{}.jpg'.format(prefix, x+1)))
                plt.close(fig1)
    finally:
        decoder.train(was_training)
    print('Saved train sample...')

In [10]:

# the whole train function
def train():
    """Adversarial training loop.

    Calls train_step on batches from ``data_loader`` until ``max_iter_cnt``
    iterations have run, printing losses every ``print_every_iter`` iterations.
    Every ``save_every_epoch`` epochs it writes a checkpoint to ``ckpt_dir`` and
    saves sample palettes via sample().

    Raises:
        ValueError: if ``data_loader`` yields no batches (the loop would never end).
    """
    print('Start training Loop...')

    epoch = 0
    iter_cnt = 0
    iter_every_epoch = len(data_loader)
    if iter_every_epoch == 0:
        raise ValueError('data_loader is empty; nothing to train on')

    encoder.train()
    decoder.train()
    discriminator.train()

    while iter_cnt < max_iter_cnt:
        for text, palettes in data_loader:
            g_loss, d_loss = train_step(text, palettes)

            iter_cnt += 1

            if iter_cnt % print_every_iter == 0:
                print('Epoch: {:.2f}, Iteration: {:6d}. G_Loss: {:.4f}, D_Loss: {:.4f}'.format(
                        iter_cnt / iter_every_epoch, iter_cnt, g_loss, d_loss))

            # Stop mid-epoch once the iteration budget is used up.
            if iter_cnt >= max_iter_cnt:
                break

        epoch += 1

        if epoch % save_every_epoch == 0:
            # save checkpoint
            torch.save({
                    'epoch': epoch,
                    'encoder': encoder.state_dict(),
                    'decoder_state_dict': decoder.state_dict(),
                    'discriminator_state_dict': discriminator.state_dict(),
                    'g_opt_state_dict': g_optimizer.state_dict(),
                    'd_opt_state_dict': d_optimizer.state_dict(),
                    'loss': 0,
                    }, os.path.join(ckpt_dir, 'ckpt_%d.pt' % epoch))

            # Three stochastic samples of the same probe text.
            for k in range(1, 4):
                sample('%d_%d' % (epoch, k), '快乐 悲伤')
            decoder.train()

In [34]:
def test(ckpt_path='./ckpt/20121103/ckpt_666.pt'):
    """Load encoder/decoder weights from ``ckpt_path`` and save sample palettes.

    Samples a fixed list of probe texts with file prefixes 1..14, in order.
    NOTE: torch.load unpickles the file — only load trusted checkpoints.
    """
    # Load once onto CPU storage; load_state_dict copies into the existing parameters.
    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    encoder.load_state_dict(checkpoint['encoder'])
    decoder.load_state_dict(checkpoint['decoder_state_dict'])
    decoder.eval()

    probe_texts = [
        # single words
        '美好', '生命', '夜晚', '放肆', '童年', '秋天',
        # the same words paired with '秋天'
        '美好 秋天', '生命 秋天', '夜晚 秋天', '放肆 秋天', '童年 秋天',
        # further single words
        '上班', '疼痛', '喊叫',
    ]
    for prefix, probe in enumerate(probe_texts, start=1):
        sample(prefix, probe)
    
    

In [35]:
# Load the saved checkpoint and write sample palette images for the probe texts.
test()

Saved train sample...
Saved train sample...
Saved train sample...
Saved train sample...
Saved train sample...
Saved train sample...
Saved train sample...
Saved train sample...
Saved train sample...
Saved train sample...
Saved train sample...
Saved train sample...
Saved train sample...
Saved train sample...
