# Lab5 - Conditional Sequence-to-sequence VAE

In [None]:
import string
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import matplotlib.pyplot as plt
from tqdm import trange
from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu

%matplotlib inline

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [339]:
# Compute BLEU-4 score
def compute_bleu(output, reference):
    cc = SmoothingFunction()
    if len(reference) == 3:
        weights = (0.33, 0.33, 0.33)
    else:
        weights = (0.25, 0.25, 0.25, 0.25)
    return sentence_bleu([reference], output, weights=weights, smoothing_function=cc.method1)

# Compute Gaussian score
def Gaussian_score(words):
    words_list = []
    score = 0
    yourpath = 'data/train.txt'#should be your directory of train.txt
    with open(yourpath,'r') as fp:
        for line in fp:
            word = line.split(' ')
            word[3] = word[3].strip('\n')
            words_list.extend([word])
        for t in words:
            for i in words_list:
                if t == i:
                    score += 1
    return score/len(words)

In [272]:
# Define characteristic to vector dictionary for embedding
SOS_token = 0
EOS_token = 1
ch_to_ix = {ch: i+2 for i, ch in enumerate(string.ascii_lowercase)}
ix_to_ch = {i+2: ch for i, ch in enumerate(string.ascii_lowercase)}

# Define condition to dictionary
conditions = ['sp', 'tp', 'pg', 'p']
cond_to_ix = {conditions[i]: i for i in range(len(conditions))}

# Load the datasets
def load_data(filename):
    data = []
    with open(filename, 'r') as f:
        for line in f.readlines():
            data.append(line.rstrip().split(' '))
    return data


train_data = load_data('./data/train.txt')
train_cond = ['sp', 'tp', 'pg', 'p']

test_data = load_data('./data/test.txt')
test_cond = [
    ['sp', 'p'],
    ['sp', 'pg'],
    ['sp', 'tp'],
    ['sp', 'tp'],
    ['p', 'tp'],
    ['sp', 'pg'],
    ['p', 'sp'],
    ['pg', 'sp'],
    ['pg', 'p'],
    ['pg', 'tp'],
] # Given

In [252]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, latent_size, num_conds, cond_size):
        super(EncoderRNN, self).__init__()
        
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.cond_size = cond_size
        
        # Embedding the input
        self.word_embedding = nn.Embedding(input_size, hidden_size)
        self.cond_embedding = nn.Embedding(num_conds, cond_size)
        
        # LSTM
        self.lstm = nn.LSTM(hidden_size, hidden_size)
        
        # Fully connected layer to generate mean and var for hidden and cell
        self.hidden_mean = nn.Linear(hidden_size, latent_size)
        self.hidden_logvar = nn.Linear(hidden_size, latent_size)
        self.cell_mean = nn.Linear(hidden_size, latent_size)
        self.cell_logvar = nn.Linear(hidden_size, latent_size)
        
    def forward(self, input, condition, hidden):
        condition = self.cond_embedding(condition).view(1, 1, -1)
        hidden = (torch.cat((hidden[0], condition), dim=2), torch.cat((hidden[1], condition), dim=2))
        output = self.word_embedding(input).view(input.size(0), 1, -1)
        output, hidden = self.lstm(output, hidden)
        
        mean = (self.hidden_mean(hidden[0]), self.cell_mean(hidden[1]))
        logvar = (self.hidden_logvar(hidden[0]), self.cell_logvar(hidden[1]))
        
        # Reparameterize
        std = (torch.exp(logvar[0]), torch.exp(logvar[1]))
        eps = (torch.rand_like(std[0]), torch.rand_like(std[1]))
        latent = (mean[0]+eps[0]*std[0], mean[1]+eps[1]*std[1])
        
        return latent, mean, logvar
        
    def init_hidden(self):
        return (torch.zeros(1, 1, self.hidden_size-self.cond_size, device=device),
                torch.zeros(1, 1, self.hidden_size-self.cond_size, device=device))
    
    def condition_embedding(self, condition):
        return self.cond_embedding(condition)
    
    
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        
        self.hidden_size = hidden_size
        self.output_size = output_size
        
        self.embedding = nn.Embedding(output_size, hidden_size)
        
        self.lstm = nn.LSTM(hidden_size, hidden_size)
        
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)
        
    def forward(self, input, hidden):
        output = self.embedding(input).view(1, 1, -1)
        output = F.relu(output)
        output, hidden = self.lstm(output, hidden)
        output = self.softmax(self.out(output[0]))
        return output, hidden
    
    def init_hidden(self):
        return (torch.zeros(1, 1, self.hidden_size, device=device),
                torch.zeros(1, 1, self.hidden_size, device=device))

In [395]:
def decode(decoder, hidden, target_tensor, teacher_forcing):
    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = hidden
    
    target_length = target_tensor.size(0)
    
    outputs = []
    if teacher_forcing:
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            outputs.append(decoder_output)
            decoder_input = target_tensor[di]
    else:
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            outputs.append(decoder_output)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()
            
            if decoder_input.item() == EOS_token:
                break
                
    outputs = torch.cat(outputs, dim=0)
    return outputs
    

def train_pair(input_data_tensor, input_cond_tensor, target_data_tensor, target_cond_tensor,
               encoder, encoder_optimizer, decoder, decoder_optimizer, kl_weight,
               criterion, teacher_forcing_ratio=0.5, max_length=MAX_LENGTH):
    
    encoder_hidden = encoder.init_hidden()
    
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()
    
    input_length = input_data_tensor.size(0)
    target_length = target_data_tensor.size(0)
    
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    
    loss = 0
    
    #----------sequence to sequence part for encoder----------#
    latent, mean, logvar = encoder(input_data_tensor, input_cond_tensor, encoder_hidden)
    
    #----------sequence to sequence part for decoder----------#
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False
    
    decoder_cond = encoder.condition_embedding(target_cond_tensor).view(1, 1, -1)
    decoder_hidden = (torch.cat((latent[0], decoder_cond), dim=2), torch.cat((latent[1], decoder_cond), dim=2))
    output = decode(decoder, decoder_hidden, target_data_tensor, teacher_forcing=use_teacher_forcing)
    
    crossEntropy_loss = criterion(output, target_data_tensor[:output.size(0)].view(-1))
    kl_loss = torch.sum(0.5*(-logvar[0]+(mean[0]**2)+torch.exp(logvar[0])-1)) + torch.sum(0.5*(-logvar[1]+(mean[1]**2)+torch.exp(logvar[1])-1))
    (crossEntropy_loss+kl_weight*kl_loss).backward()
    
    encoder_optimizer.step()
    decoder_optimizer.step()
    
    return crossEntropy_loss.item(), kl_loss.item()
    
    
def predict(input_data_tensor, input_cond_tensor, target_data_tensor, target_cond_tensor, 
            encoder, decoder, criterion, max_length=MAX_LENGTH):
    
    encoder_hidden = encoder.init_hidden()
    input_data_length = input_data_tensor.size(0)
    
    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    
    with torch.no_grad():
        
        latent, mean, logvar = encoder(input_data_tensor, input_cond_tensor, encoder_hidden)
        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_cond = encoder.cond_embedding(target_cond_tensor).view(1, 1, -1)
        decoder_hidden = (torch.cat((latent[0], decoder_cond), dim=2), torch.cat((latent[1], decoder_cond), dim=2))
        
        crossEntropy_loss = 0
        kl_loss = torch.sum(0.5*(-logvar[0]+(mean[0]**2)+torch.exp(logvar[0])-1)) + torch.sum(0.5*(-logvar[1]+(mean[1]**2)+torch.exp(logvar[1])-1))
        
        pred = ''
        for di in range(target_data_tensor.size(0)):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.squeeze().detach()
            
            crossEntropy_loss += criterion(decoder_output, target_data_tensor[di])
            if decoder_input == EOS_token:
                break
            
            pred += ix_to_ch[decoder_input.item()]
    
    return pred, crossEntropy_loss, kl_loss
    

def train_data_combination(n=4):
    for i in range(n):
        for j in range(i+1, n):
            yield (i, j)
            yield (j, i)
            

def get_kl_anealing_func(method, slope=0.0002, max_val=1, period=None):
    if method == 'monotonic':
        return lambda iteration: min(iteration*slope, max_val)
    elif method == 'cyclical' and period:
        return lambda iteration: min((iteration%period)*slope, max_val)
    else:
        return None
    

def train(encoder, decoder, train_dataset, train_condition, test_dataset, test_condition,
          kl_anealing_func, teacher_forcing_ratio, num_epochs=20, learning_rate=1e-02):
    
    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)
    
    criterion = nn.CrossEntropyLoss()
    
    train_result_list = {'cross_entropy': [], 'kl': []}
    test_result_list = {'cross_entropy': [], 'kl': [], 'score': []}
    iteration = 0
    for epoch in range(num_epochs):

        # Train
        encoder.train()
        decoder.train()
        
        shuffled_idx = np.random.permutation(len(train_dataset))
        train_loss = {'cross_entropy': 0, 'kl': 0}
        for i in range(len(train_dataset)):
            data = train_dataset[shuffled_idx[i]]
            cond = train_condition
            for a, b in train_data_combination(4):
                input_data_tensor = torch.tensor([ch_to_ix[ch] for ch in data[a]], device=device).view(-1, 1)
                input_cond_tensor = torch.tensor(cond_to_ix[train_cond[a]], device=device)
                target_data_tensor = torch.tensor([ch_to_ix[ch] for ch in data[b]]+[EOS_token], device=device).view(-1, 1)
                target_cond_tensor = torch.tensor(cond_to_ix[train_cond[b]], device=device)
                loss = train_pair(input_data_tensor, input_cond_tensor, target_data_tensor, target_cond_tensor,
                           encoder, encoder_optimizer, decoder, decoder_optimizer, kl_anealing_func(iteration),
                           criterion, teacher_forcing_ratio)
                train_loss['cross_entropy'] += loss[0]
                train_loss['kl'] += loss[1]
                
        train_loss['cross_entropy'] /= (len(train_dataset)*12)
        train_loss['kl'] /= (len(train_dataset)*12)
        train_result_list['cross_entropy'].append(train_loss['cross_entropy'])
        train_result_list['kl'].append(train_loss['kl'])
                
        # Test
        encoder.eval()
        decoder.eval()
        
        test_loss = {'cross_entropy': 0, 'kl': 0}
        test_score = 0
        for i in range(len(test_dataset)):
            input_data_tensor = torch.tensor([ch_to_ix[ch] for ch in test_dataset[i][0]], device=device).view(-1, 1)
            input_cond_tensor = torch.tensor(cond_to_ix[test_condition[i][0]], device=device)
            target_data_tensor = torch.tensor([ch_to_ix[ch] for ch in test_dataset[i][1]], device=device).view(-1, 1)
            target_cond_tensor = torch.tensor(cond_to_ix[test_condition[i][1]], device=device)
            pred, crossEntropy_loss, kl_loss = predict(input_data_tensor, input_cond_tensor, target_data_tensor, target_cond_tensor, encoder, decoder, criterion)
            test_loss['cross_entropy'] += crossEntropy_loss
            test_loss['kl'] += kl_loss
            score = compute_bleu(pred, test_dataset[i][1])
            test_score += score
        
        test_loss['cross_entropy'] /= len(test_dataset)
        test_loss['kl'] /= len(test_dataset)
        test_result_list['cross_entropy'].append(test_loss['cross_entropy'])
        test_result_list['kl'].append(test_loss['kl'])
        test_score /= len(test_dataset)
        test_result_list['score'].append(test_score)
        
        # Print the result
        print('=====================================================')
        print('Epoch: {} / {}'.format(epoch+1, num_epochs))
        print()
        print('Train Cross Entropy Loss: {}'.format(train_loss['cross_entropy']))
        print('Train KL Loss: {}'.format(train_loss['kl']))
        print()
        print('Test Cross Entropy Loss: {}'.format(test_loss['cross_entropy']))
        print('Test KL Loss: {}'.format(test_loss['kl']))
        print('Test BLEU-4 Score: {}'.format(test_score))
        print()
        print()
         
    return encoder, decoder, train_result_list, test_result_list


def evaluate(test_dataset, test_condition, encoder, decoder, max_length=MAX_LENGTH):
    
    score = 0
    for i in range(len(test_dataset)):
        input_data_tensor = torch.tensor([ch_to_ix[ch] for ch in test_dataset[i][0]], device=device).view(-1, 1)
        input_cond_tensor = torch.tensor(cond_to_ix[test_condition[i][0]], device=device)
        target_cond_tensor = torch.tensor(cond_to_ix[test_condition[i][1]], device=device)
            
        encoder_hidden = encoder.init_hidden()
        input_data_length = input_data_tensor.size(0)
    
        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    
        with torch.no_grad():
        
            latent, mean, logvar = encoder(input_data_tensor, input_cond_tensor, encoder_hidden)
            decoder_input = torch.tensor([[SOS_token]], device=device)
            decoder_cond = encoder.cond_embedding(target_cond_tensor).view(1, 1, -1)
            decoder_hidden = (torch.cat((latent[0], decoder_cond), dim=2), torch.cat((latent[1], decoder_cond), dim=2))

        
            pred = ''
            while True:
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
                topv, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze().detach()

                if decoder_input == EOS_token:
                    break
            
                pred += ix_to_ch[decoder_input.item()]
    
        score += compute_bleu(pred, test_dataset[i][1])
        
        print('Input:      {}'.format(test_dataset[i][0]))
        print('Target:     {}'.format(test_dataset[i][1]))
        print('Prediction: {}'.format(pred))
        print()
    print('Average BLEU-4 Score: {}'.format(score/len(test_dataset)))

In [396]:
# Define hyperparameters
hidden_size = 512
latent_size = 64
vocab_size = 28
num_conds = 4
cond_size = 32
teacher_forcing_ratio = 0.5
learning_rate = 0.001
MAX_LENGTH = 40
num_epochs = 5

# Encoder & Decoder
#encoder = EncoderRNN(vocab_size, hidden_size, latent_size, num_conds, cond_size).to(device)
#decoder = DecoderRNN(latent_size+cond_size, vocab_size).to(device)

# Train
encoder, decoder, train_loss_list, test_score_list = train(encoder, decoder,
                                                           train_data, train_cond, test_data, test_cond,
                                                           kl_anealing_func=get_kl_anealing_func('cyclical', period=10000),
                                                           teacher_forcing_ratio=teacher_forcing_ratio,
                                                           num_epochs=num_epochs)

Epoch: 1 / 5

Train Cross Entropy Loss: 0.0011314563325108148
Train KL Loss: 1107148.3614365729

Test Cross Entropy Loss: 0.008216426707804203
Test KL Loss: 1031.5662841796875
Test BLEU-4 Score: 1.0


Epoch: 2 / 5

Train Cross Entropy Loss: 0.0010978779801758902
Train KL Loss: 1112188.0331284632

Test Cross Entropy Loss: 0.007876241579651833
Test KL Loss: 1029.3585205078125
Test BLEU-4 Score: 1.0


Epoch: 3 / 5

Train Cross Entropy Loss: 0.0010828311358603232
Train KL Loss: 895540.7413091745

Test Cross Entropy Loss: 0.00800984725356102
Test KL Loss: 1026.9705810546875
Test BLEU-4 Score: 1.0


Epoch: 4 / 5

Train Cross Entropy Loss: 0.0010611051219330929
Train KL Loss: 913191.0349916626

Test Cross Entropy Loss: 0.007418683264404535
Test KL Loss: 1034.4884033203125
Test BLEU-4 Score: 1.0


Epoch: 5 / 5

Train Cross Entropy Loss: 0.001034282368021017
Train KL Loss: 892065.9049988104

Test Cross Entropy Loss: 0.007953687570989132
Test KL Loss: 1031.2242431640625
Test BLEU-4 Score: 1.0




In [397]:
def generate_words(encoder, decoder, num=100):
    
    encoder.eval()
    decoder.eval()
    
    cond_size = encoder.cond_size
    embedded_conditions = {}
    for cond in conditions:
        cond_tensor = torch.tensor(cond_to_ix[cond], device=device)
        embedded_conditions[cond] = encoder.condition_embedding(cond_tensor)
    
    words_list = []
    with torch.no_grad():
        for _ in range(num):
            noise = decoder.init_hidden()

            # Generate Gaussian noise
            noise = (noise[0].normal_(std=1), noise[1].normal_(std=1))

            # Generate words with 4 different tenses
            words = []
            for embedded_cond in embedded_conditions.values():
                
                noise[0][:, :, -cond_size:] = embedded_cond
                noise[1][:, :, -cond_size:] = embedded_cond
                
                decoder_input = torch.tensor([[SOS_token]], device=device)
                decoder_hidden = noise
            
                word = ''
                while True:
                    decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
                    topv, topi = decoder_output.topk(1)
                    decoder_input = topi.squeeze().detach()
                    
                    if decoder_input == EOS_token:
                        break
                        
                    word += ix_to_ch[decoder_input.item()]
                words.append(word)
            words_list.append(words)
    return words_list