### Import and Setup

In [1]:
!pip install unidecode

Collecting unidecode
  Downloading Unidecode-1.3.2-py3-none-any.whl (235 kB)
[K     |████████████████████████████████| 235 kB 2.1 MB/s 
[?25hInstalling collected packages: unidecode
Successfully installed unidecode-1.3.2


In [None]:
import json
import re
from unidecode import unidecode
from nltk.tokenize import word_tokenize, sent_tokenize
import nltk
nltk.download('punkt')
from collections import Counter
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils import data
import operator

device = "cuda" if torch.cuda.is_available() else "cpu"
assert device == "cuda"  

seed = 42
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

from google.colab import drive
drive.mount('/content/drive')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
Mounted at /content/drive


### Data

In [None]:
# # These IDs are reserved.
MAX_SENT_LENGTH = 15
MAX_SENT_LENGTH_PLUS_SOS_EOS = MAX_SENT_LENGTH + 2
PAD_INDEX = 0
UNK_INDEX = 1
SOS_INDEX = 2
EOS_INDEX = 3
RARE_WORD_TRESHOLD = 0

class TSTDataset(data.Dataset):
    def __init__(self, taylor_sentences, drake_sentences, vocab, vocab_counts, sampling=1.):
        self.taylor_sentences = taylor_sentences[:int(len(taylor_sentences) * sampling)]
        self.drake_sentences = drake_sentences[:int(len(drake_sentences) * sampling)]

        self.max_seq_length = MAX_SENT_LENGTH_PLUS_SOS_EOS
        self.vocab = vocab
        self.vocab_counts = vocab_counts

        self.v2id = {v : i for i, v in enumerate(self.vocab)}
        self.id2v = {val : key for key, val in self.v2id.items()}
    
    def __len__(self):
        return min(len(self.taylor_sentences), len(self.drake_sentences))
    
    def __getitem__(self, index):
        taylor_sent = self.taylor_sentences[index]
        taylor_len = len(taylor_sent) + 2   # add <s> and </s> to each sentence
        taylor_id = []
        for w in taylor_sent:
            if w not in self.vocab:
                w = '<unk>'
            if vocab_counts[w] <= RARE_WORD_TRESHOLD:
                w = '<unk>'
            taylor_id.append(self.v2id[w])

        taylor_id = ([SOS_INDEX] + taylor_id + [EOS_INDEX] + [PAD_INDEX] *
                  (self.max_seq_length - taylor_len))

        drake_sent = self.drake_sentences[index]
        drake_len = len(drake_sent) + 2   # add <s> and </s> to each sentence
        drake_id = []
        for w in drake_sent:
            if w not in self.vocab:
                w = '<unk>'
            if vocab_counts[w] <= RARE_WORD_TRESHOLD:
                w = '<unk>'
            drake_id.append(self.v2id[w])

        drake_id = ([SOS_INDEX] + drake_id + [EOS_INDEX] + [PAD_INDEX] *
                  (self.max_seq_length - drake_len))

        return torch.tensor(taylor_id), taylor_len, torch.tensor(drake_id), drake_len

train_dataset = torch.load('/content/drive/Shareddrives/MIT NLP 8.864/Data/train.pt')
valid_dataset = torch.load('/content/drive/Shareddrives/MIT NLP 8.864/Data/valid.pt')
test_dataset = torch.load('/content/drive/Shareddrives/MIT NLP 8.864/Data/test.pt')
vocab_file = open('/content/drive/Shareddrives/MIT NLP 8.864/Data/vocab.pkl', "rb")
vocab = pickle.load(vocab_file)
vocab_file.close()

vocab_counts = Counter(vocab)
vocab_counts['<pad>'] = RARE_WORD_TRESHOLD + 1
vocab_counts['<unk>'] = RARE_WORD_TRESHOLD + 1
vocab_counts['<s>'] = RARE_WORD_TRESHOLD + 1
vocab_counts['</s>'] = RARE_WORD_TRESHOLD + 1

# f = open('/content/drive/Shareddrives/MIT NLP 8.864/Data/drake.json')
# drake = json.load(f)
# f.close()

# f = open('/content/drive/Shareddrives/MIT NLP 8.864/Data/tswift.json')
# taylor = json.load(f)
# f.close()

# drake = [drake['songs'][i]['lyrics'] for i in range(len(drake['songs']))]
# taylor = [taylor['songs'][i]['lyrics'] for i in range(len(taylor['songs']))]

# taylor_lyrics = [re.sub('\u2005', ' ', re.sub(r'[\(\[].*?[\)\]]', '', taylor[i])).split('\n') for i in range(len(taylor))]
# taylor_lyrics = [[unidecode(i) for i in taylor_lyrics[j]] for j in range(len(taylor_lyrics))]
# taylor_lyrics = [[re.sub('\d+EmbedShare URLCopyEmbedCopy', '', i) for i in taylor_lyrics[j]] for j in range(len(taylor_lyrics))]
# taylor_lyrics = [[re.sub('\d+.EmbedShare URLCopyEmbedCopy', '', i) for i in taylor_lyrics[j]] for j in range(len(taylor_lyrics))]
# taylor_lyrics = [[re.sub('EmbedShare URLCopyEmbedCopy', '', i) for i in taylor_lyrics[j]] for j in range(len(taylor_lyrics))]
# taylor_lyrics = [[i for i in taylor_lyrics[j] if i != ''] for j in range(len(taylor_lyrics))]

# drake_lyrics = [re.sub('\u2005', ' ', re.sub(r'[\(\[].*?[\)\]]', '', drake[i])).split('\n') for i in range(len(drake))]
# drake_lyrics = [[unidecode(i) for i in drake_lyrics[j]] for j in range(len(drake_lyrics))]
# drake_lyrics = [[re.sub('\d+EmbedShare URLCopyEmbedCopy', '', i) for i in drake_lyrics[j]] for j in range(len(drake_lyrics))]
# drake_lyrics = [[re.sub('\d+.EmbedShare URLCopyEmbedCopy', '', i) for i in drake_lyrics[j]] for j in range(len(drake_lyrics))]
# drake_lyrics = [[re.sub('EmbedShare URLCopyEmbedCopy', '', i) for i in drake_lyrics[j]] for j in range(len(drake_lyrics))]
# drake_lyrics = [[i for i in drake_lyrics[j] if i != ''] for j in range(len(drake_lyrics))]

# # taylor_lyrics = [[line1 + ', ' + line2 for line1,line2 in zip(song[0::2], song[1::2])] for song in taylor_lyrics]
# # drake_lyrics = [[line1 + ', ' + line2 for line1,line2 in zip(song[0::2], song[1::2])] for song in drake_lyrics]

# drake_tokenized = [[word_tokenize(drake_lyrics[i][j]) for j in range(len(drake_lyrics[i]))] for i in range(len(drake_lyrics))]
# taylor_tokenized = [[word_tokenize(taylor_lyrics[i][j]) for j in range(len(taylor_lyrics[i]))] for i in range(len(taylor_lyrics))]

# drake_tokenized = [[[word.lower() for word in line] for line in song] for song in drake_tokenized]
# taylor_tokenized = [[[word.lower() for word in line] for line in song] for song in taylor_tokenized]

# drake_length = sum([[len(sent) for sent in song] for song in drake_tokenized], [])
# taylor_length = sum([[len(sent) for sent in song] for song in taylor_tokenized], [])

# # drake_lyrics = sum([[sent for sent in song if (len(sent) >= 10 and len (sent) <= 30)] for song in drake_tokenized], [])
# # taylor_lyrics = sum([[sent for sent in song if (len(sent) >= 10 and len (sent) <= 30)] for song in taylor_tokenized], [])

# drake_lyrics = sum([[sent for sent in song if (len(sent) >= 5 and len (sent) <= 15)] for song in drake_tokenized], [])
# taylor_lyrics = sum([[sent for sent in song if (len(sent) >= 5 and len (sent) <= 15)] for song in taylor_tokenized], [])

# taylor_vocab = sum(taylor_lyrics,[])
# drake_vocab = sum(drake_lyrics,[])

# def unique(list1):
     
#     # insert the list to the set
#     list_set = set(list1)
#     # convert the set to the list
#     unique_list = (list(list_set))
#     return unique_list

# vocab = taylor_vocab + drake_vocab
# vocab_counts = Counter(vocab)
# vocab = unique(vocab)
# vocab = ['<pad>','<unk>','<s>', '</s>'] + vocab

# vocab_counts['<pad>'] = RARE_WORD_TRESHOLD + 1
# vocab_counts['<unk>'] = RARE_WORD_TRESHOLD + 1
# vocab_counts['<s>'] = RARE_WORD_TRESHOLD + 1
# vocab_counts['</s>'] = RARE_WORD_TRESHOLD + 1

# class TSTDataset(data.Dataset):
#     def __init__(self, taylor_sentences, drake_sentences, vocab, vocab_counts, sampling=1.):
#         self.taylor_sentences = taylor_sentences[:int(len(taylor_sentences) * sampling)]
#         self.drake_sentences = drake_sentences[:int(len(drake_sentences) * sampling)]

#         self.max_seq_length = MAX_SENT_LENGTH_PLUS_SOS_EOS
#         self.vocab = vocab
#         self.vocab_counts = vocab_counts

#         self.v2id = {v : i for i, v in enumerate(self.vocab)}
#         self.id2v = {val : key for key, val in self.v2id.items()}
    
#     def __len__(self):
#         return min(len(self.taylor_sentences), len(self.drake_sentences))
    
#     def __getitem__(self, index):
#         taylor_sent = self.taylor_sentences[index]
#         taylor_len = len(taylor_sent) + 2   # add <s> and </s> to each sentence
#         taylor_id = []
#         for w in taylor_sent:
#             if w not in self.vocab:
#                 w = '<unk>'
#             if vocab_counts[w] <= RARE_WORD_TRESHOLD:
#                 w = '<unk>'
#             taylor_id.append(self.v2id[w])

#         taylor_id = ([SOS_INDEX] + taylor_id + [EOS_INDEX] + [PAD_INDEX] *
#                   (self.max_seq_length - taylor_len))

#         drake_sent = self.drake_sentences[index]
#         drake_len = len(drake_sent) + 2   # add <s> and </s> to each sentence
#         drake_id = []
#         for w in drake_sent:
#             if w not in self.vocab:
#                 w = '<unk>'
#             if vocab_counts[w] <= RARE_WORD_TRESHOLD:
#                 w = '<unk>'
#             drake_id.append(self.v2id[w])

#         drake_id = ([SOS_INDEX] + drake_id + [EOS_INDEX] + [PAD_INDEX] *
#                   (self.max_seq_length - drake_len))

#         return torch.tensor(taylor_id), taylor_len, torch.tensor(drake_id), drake_len

# dataset = TSTDataset(taylor_lyrics, drake_lyrics, vocab, vocab_counts)

# test_pct = 0.2
# valid_pct = 0.1

# train_dataset, test_dataset = torch.utils.data.random_split(dataset, [int(len(dataset)*(1-test_pct)),len(dataset)-int(len(dataset)*(1-test_pct))])
# valid_dataset, train_dataset = torch.utils.data.random_split(train_dataset, [int(len(dataset)*valid_pct),len(train_dataset)-int(len(dataset)*valid_pct)])

# train_dataset = torch.save(train_dataset, '/content/drive/Shareddrives/MIT NLP 8.864/Data/train.pt')
# valid_dataset = torch.save(valid_dataset, '/content/drive/Shareddrives/MIT NLP 8.864/Data/valid.pt')
# test_dataset = torch.save(test_dataset, '/content/drive/Shareddrives/MIT NLP 8.864/Data/test.pt')

# vocab_file = open('/content/drive/Shareddrives/MIT NLP 8.864/Data/vocab.pkl', "wb")
# pickle.dump(vocab, vocab_file)
# vocab_file.close()

# vocab_file = open('/content/drive/Shareddrives/MIT NLP 8.864/Data/vocab.pkl', "wb")
# pickle.dump(vocab, vocab_file)
# vocab_file.close()

In [None]:
len(train_dataset) + len(valid_dataset) + len(test_dataset)

17997

In [None]:
len(train_dataset)

12598

### Encoder

In [None]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class Encoder(nn.Module):
  def __init__(self, input_size, hidden_size, dropout=0.):
    """
    Inputs: 
      - `input_size`: an int representing the RNN input size.
      - `hidden_size`: an int representing the RNN hidden size.
      - `dropout`: a float representing the dropout rate during training. Note
          that for 1-layer RNN this has no effect since dropout only applies to
          outputs of intermediate layers.
    """
    super(Encoder, self).__init__()
    self.rnn = nn.GRU(input_size, hidden_size, num_layers=1, batch_first=True,
                      dropout=dropout, bidirectional=False)

  def forward(self, inputs, lengths, init_state=None):
    """
    Inputs:
      - `inputs`: a 3d-tensor of shape (batch_size, max_seq_length, embed_size)
          representing a batch of padded embedded word vectors of source
          sentences.
      - `lengths`: a 1d-tensor of shape (batch_size,) representing the sequence
          lengths of `inputs`.

    Returns:
      - `outputs`: a 3d-tensor of shape
        (batch_size, max_seq_length, hidden_size).
      - `finals`: a 3d-tensor of shape (num_layers, batch_size, hidden_size).
      Hint: `outputs` and `finals` are both standard GRU outputs. Check:
      https://pytorch.org/docs/stable/nn.html#gru
    """
    # Our variable-length inputs are padded to the same length for batching
    # Here we "pack" them for computational efficiency (see note below)
    packed = pack_padded_sequence(inputs, lengths.cpu(), batch_first=True,
                                  enforce_sorted=False)
    outputs, finals = self.rnn(packed, init_state)
    outputs, _ = pad_packed_sequence(outputs, batch_first=True,
                                     total_length=MAX_SENT_LENGTH_PLUS_SOS_EOS)
    return outputs, finals

### Decoder

#### Generator

In [None]:
class GeneratorTransferredSampled(nn.Module):
  """Define standard linear + softmax generation step."""
  def __init__(self, hidden_size, vocab_size, src_embed, gamma=0.001):
    """
    Inputs:
      - `src_embed`: a 2d-tensor of shape (vocab_size, embed_size )
    """
    super(GeneratorTransferredSampled, self).__init__()
    self.proj = nn.Linear(hidden_size, vocab_size, bias=True)
    self.gamma = gamma
    self.logsoftmax = nn.LogSoftmax(dim = 2)
    self.softmax = nn.Softmax(dim = 2)
    self.src_embed = src_embed

  def embedding(self,x):
    return torch.matmul(x,self.src_embed.weight)
    
  def gumbel_softmax(self,logits, eps=1e-20):
    U = torch.rand(logits.shape).to(device)
    G = -torch.log(-torch.log(U + eps) + eps).to(device)
    return self.logsoftmax((logits + G) / self.gamma)

  def gumbel(self,logits, eps=1e-20):
    U = torch.rand(logits.shape).to(device)
    G = -torch.log(-torch.log(U + eps) + eps).to(device)
    return (logits + G) / self.gamma

  def forward(self, x):
    logits = self.proj(x)
    logprob = self.logsoftmax(logits)
    prob = self.softmax(logits)
    output = self.embedding(prob)
    word  = logits.argmax(dim = 2, keepdim = False)

    return output, logprob, word

  def forward_gumbel(self, x):
    logits = self.proj(x)
    prob = self.softmax(self.gumbel(logits))
    logprob = self.logsoftmax(self.gumbel(logits))
    output = self.embedding(prob)
    word  = logits.argmax(dim = 2, keepdim = False)

    return output, logprob, word

#### Basic Decoder

In [None]:
class Decoder(nn.Module):
  """An RNN decoder + generator with GRU"""

  def __init__(self, input_size, hidden_size, max_len,generator, num_layers = 1, dropout=0.):
    """
      Inputs:
        - `input_size` , `hidden_size`, and `dropout` the same as in Encoder.
    """
    super(Decoder, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    self.rnn = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True,
                      dropout=dropout, bidirectional=False)
    self.generator = generator
    self.max_len = max_len
    self.dropout_layer = nn.Dropout(p=dropout)
    self.rnn_to_pre = nn.Linear(input_size + hidden_size,
                                hidden_size, bias=False)

  def forward_step(self, prev_embed, hidden):
    """Helper function for forward below:
       Perform a single decoder step (1 word).

       Inputs:
      - `prev_embed`: a 3d-tensor of shape (batch_size, 1, embed_size = vocab_size)
          representing the padded embedded word vectors at this step in training
      - `hidden`: a 3d-tensor of shape (1, batch_size, hidden_size) representing
          the current hidden state.

      Returns:
      - `hidden`: a 3d-tensor of shape (1, batch_size, hidden_size)
          representing the current decoder hidden state.
      - `output`: a 3d-tensor of shape (batch_size, max_len, vocab_size)
          representing the total generated outputs.
    """
    pre_output, hidden = self.rnn(prev_embed, hidden)
    pre_output = torch.cat([prev_embed, pre_output], dim=2)
    pre_output = self.dropout_layer(pre_output)
    pre_output = self.rnn_to_pre(pre_output)
    pre_output = torch.tanh(pre_output)

    return hidden, pre_output

    ### Your code here!
    pre_output, hidden = self.rnn(prev_embed, hidden)
    pre_output = torch.cat([prev_embed, pre_output], dim=2)
    pre_output = self.dropout_layer(pre_output)
    pre_output = self.rnn_to_pre(pre_output)
    pre_output = self.pre_activation(pre_output)
    
  def forward_step_beam(self, prev_embed, encoder_hidden, 
                   src_mask, proj_key, hidden):
    """Beam Search only: Unroll the decoder one step at a time.
    Inputs:
      - `input`: a 3d-tensor of shape (batch_size, 1, embed_size)
          representing the padded embedded word vectors at this step in training
      - `hidden`: a 3d-tensor of shape (1, batch_size, hidden_size) representing
          the current hidden state.
    Returns:
      - `hidden`: a 3d-tensor of shape (1, batch_size, hidden_size)
          representing the current decoder hidden state.
      - `output`: a 3d-tensor of shape (batch_size, 1, vocab_size)
          representing the total generated outputs.    
      - `gumbel_logits`: a 3d-tensor of shape
          (batch_size, 1, trg_vocab_size) representing the mapped decoder
          outputs from gumbel softmax.
      - `output_word`: a 2d-tensor of shape
          (batch_size, 1) representing output sentence and
          the corresponding word index (can be used for embedding)    
      - `logits`: a 2d-tensor of shape
          (batch_size, 1, trg_vocab_size) representing the mapped decoder
          outputs from log softmax 
          """
    temp_hidden, pre_output = self.forward_step(prev_embed,encoder_hidden, 
                   src_mask, proj_key, hidden)
    output, logits, output_word = self.generator.forward_gumbel(pre_output)
    return  temp_hidden, output, logits, output_word

  def forward(self, input, encoder_finals,max_len, hidden=None):
    """Unroll the decoder one step at a time.

    Inputs:
      - `inputs`: a 3d-tensor of shape (batch_size, 1, embed_size)
          representing a batch of padded embedded word vectors of SOS . 
          If size is (batch_size,max_len, embed_size), then it is teacher forcing.
      - `encoder_finals`: a 3d-tensor of shape
          (num_enc_layers, batch_size, hidden_size) representing the final
          encoder hidden states used to initialize the initial decoder hidden
          states.
      - `hidden`: a 3d-tensor of shape (1, batch_size, hidden_size) representing
          the value to be used to initialize the initial decoder hidden states.
          If None, then use `encoder_finals`.
      - `max_len`: an int representing the maximum decoding length.
      - `style`: TAYLOR_STYLE or DRAKE_STYLE

    Returns:
      - `hidden`: a 3d-tensor of shape
          (num_layers, batch_size, hidden_size) representing the final hidden
          state for each element in the batch.
      - `outputs`: a 3d-tensor of shape
          (batch_size, max_len, hidden_size) representing the raw decoder
          outputs (before mapping to a `trg_vocab_size`-dim vector).
      - `logits_vectors`: a 3d-tensor of shape
          (batch_size, max_len, trg_vocab_size) representing the mapped decoder
          outputs.
      - `words`: a 3d-tensor of shape
          (batch_size, max_len, 1) representing output sentence and
          the corresponding word index (can be used for embedding)  
    """

    # Initialize decoder hidden state.
    if hidden is None:
      hidden = self.init_hidden(encoder_finals)
    output_vectors = []
    logits_vectors = []
    words = []
    hidden_states = []
    hidden_states.append(hidden[-1][:,None,:])
    for i in range(max_len-1) :
      hidden, prev_output = self.forward_step(input,hidden)
      input, logits, output_word = self.generator.forward_gumbel(prev_output)
      # input, logits, output_word = self.generator(prev_output)

      # input = torch.concat([input,torch.full(input.shape,style)], axis = -1)
      logits_vectors.append(logits)
      output_vectors.append(input)
      words.append(output_word)
      hidden_states.append(prev_output)

    outputs = torch.cat(output_vectors, dim =1)
    logits_vectors = torch.cat(logits_vectors,dim = 1)
    words = torch.cat(words, axis = -1)
    hidden_states = torch.cat(hidden_states, axis = 1)
    return hidden, outputs , logits_vectors, words, hidden_states

  def forward_teacher(self, input, encoder_finals, max_len=None, hidden=None):
    """Unroll the decoder one step at a time.

    Inputs:
      - `inputs`: a 3d-tensor of shape (batch_size,max_len, embed_size)
          representing a batch of padded embedded word vectors of original 
          sentence and acts as  teacher forcing.

    Returns:
      - `hidden`: a 3d-tensor of shape
          (num_layers, batch_size, hidden_size) representing the final hidden
          state for each element in the batch.
      - `outputs`: a 3d-tensor of shape
          (batch_size, max_len, hidden_size) representing the raw decoder
          outputs (before mapping to a `trg_vocab_size`-dim vector).
      - `logits_vectors`: a 3d-tensor of shape
          (batch_size, max_len, trg_vocab_size) representing the mapped decoder
          outputs each represents the probability? 
      - `words`: a 3d-tensor of shape
          (batch_size, max_len, 1) representing output sentence and
          the corresponding word index (can be used for embedding)      
    """

    # Initialize decoder hidden state.
    if max_len is None:
      max_len = input.shape[1]
    if hidden is None:
      hidden = self.init_hidden(encoder_finals)
    output_vectors = []
    logits_vectors = []
    words = []
    hidden_states = []
    hidden_states.append(hidden[-1][:,None,:])
    for i in range(max_len):
      hidden, prev_output = self.forward_step(input[:,i:i+1,:],hidden)
      # output, logits, output_word = self.generator(prev_output)
      output, logits, output_word = self.generator.forward_gumbel(prev_output)
      
      logits_vectors.append(logits)
      output_vectors.append(output)
      words.append(output_word)
      hidden_states.append(prev_output)

    outputs = torch.cat(output_vectors, dim =1)
    logits_vectors = torch.cat(logits_vectors,dim = 1)
    words = torch.cat(words, axis = -1)
    hidden_states = torch.cat(hidden_states, axis = 1)
    return hidden, outputs , logits_vectors, words, hidden_states

  def init_hidden(self, encoder_finals):
    """Use encoder final hidden state to initialize decoder's first hidden
       state.

       Input: `encoder_finals` is same as in forward()

       Returns: 
         - `decoder_init_hiddens`: a 3d-tensor of shape 
              (num_layers, batch_size, hidden_size) representing the initial
              hidden state of the decoder for each element in the batch 
    """
    decoder_init_hiddens = torch.tanh(encoder_finals)
    return decoder_init_hiddens

#### Beam Search

In [None]:
from queue import PriorityQueue
class BeamSearchNode:
  def __init__(self, hiddenstate, previousNode, cur_embed, wordId, 
               logProb,  length ):
    self.h = hiddenstate
    self.prevNode = previousNode
    self.cur_embed = cur_embed
    self.wordid = wordId
    self.logp = logProb
    self.leng = length
    
  def __lt__(self,other):
    return self.logp < other.logp

  def eval(self, alpha=1.0):
    return self.logp 
    # Add here a function for shaping a reward
    # reward = 0
    # return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward'

class BeamSearch:
  def __init__(self,decoder, beam_width, topk, line_embed,max_len,
               max_iter=2000):
    """Use Beam Search to generate a full sentence with the given decoder model
  Inputs:
      - `decoder`: decoder module with forward_step_beam function
      - `beam_width` : the length of the beam 
      - `max_len`: an int representing the maximum decoding length.
      - `max_iter`: The maximum decoding iteration
    """
    self.decoder = decoder
    self.beam_width = beam_width
    self.topk = topk
    self.line_embed = line_embed
    self.max_len = max_len
    self.max_iter = max_iter

  def beam_decode(self, inputs, encoder_hidden, encoder_finals, src_mask, max_len,
                  hidden = None):
                  # inputs, encoder_finals,src_mask, proj_key, hidden):
    """Use Beam Search to generate a full sentence with the given decoder model
    Inputs:
        - `inputs`: a 3d-tensor of shape (batch_size, 1, embed_size)
            representing a batch of padded embedded word vectors of SOS . 
        - `encoder_finals`: a 3d-tensor of shape
            (num_enc_layers, batch_size, hidden_size) representing the final
            encoder hidden states used to initialize the initial decoder hidden
            states.

    Returns:
        - `final_logp_batch`: a 2d-tensor of shape
            (batch_size, sentences_num) representing the probability of generating 
            the sentence.
        - `final_hidden_batch`: a 4d-tensor of shape
            (sentences_num, num_layers, batch_size, hidden_size) representing 
            the final hidden layer
        - `decoded_batch`: a 3d-tensor of shape
            (batch_size,sentences_num,  max_len) representing output sentence and
            the corresponding word index (can be used for embedding)  
    
    """
    
    decoded_batch = []
    final_hidden_batch, final_logp_batch = [],[]
    # print("shape of encoder_finals:",encoder_finals.shape)
    for i in range(inputs.shape[0]):
      if hidden is None:
        hidden = self.decoder.init_hidden(encoder_finals[:,i:i+1,:])
      decoder_input = inputs[i:i+1,:,:]
      # Number of sentence to generate
      endnodes = []
      number_required = self.topk
      proj_key = self.decoder.attention.key_layer(encoder_hidden[i:i+1,:,:])

      # starting node -  hidden vector, previous node, cur_embed, word id , logp, length
      node = BeamSearchNode(self.decoder.init_hidden(encoder_finals[:,i:i+1,:]), 
                            None, decoder_input, [[SOS_INDEX]], 0, 1)
      nodes = PriorityQueue()

      nodes.put((-node.eval(), node))
      qsize = 1

      while qsize<=self.max_iter:
        tocheck = min(nodes.qsize(), self.beam_width)
        new_nodes = PriorityQueue()
        while tocheck>0:
          score, n = nodes.get()
          decoder_input = n.cur_embed
          decoder_hidden = n.h
          if n.leng > self.max_len:
            endnodes.append((score, n))
            # if we reached maximum # of sentences required
            if len(endnodes) >= number_required:
                break

          # decode for one step using decoder
          hidden, _, logsoftmax_logits, wordId = decoder.forward_step_beam(decoder_input, 
                                                                           encoder_hidden[i:i+1,:,:],
                                                                          src_mask[i:i+1,:,:], 
                                                                          proj_key, 
                                                                          decoder_hidden)
          tocheck -= 1
          # PUT HERE REAL BEAM SEARCH OF TOP
          log_prob, indexes = torch.topk(logsoftmax_logits, self.beam_width)
          for new_k in range(self.beam_width):
            
            decoded_t = indexes[0][0][new_k].view(1, -1)
            log_p = log_prob[0][0][new_k]
            prev_embed = self.line_embed(decoded_t)

            node = BeamSearchNode(decoder_hidden, n,prev_embed, decoded_t, 
                                  n.logp + log_p,n.leng + 1)
            score = -node.eval()
            new_nodes.put((score, node))
          qsize += self.beam_width - 1
        nodes = new_nodes

        if len(endnodes) >= number_required:
            break
        

      # choose nbest paths, back trace them
      if len(endnodes) == 0:
          endnodes = [nodes.get() for _ in range(self.topk)]

      utterances = []
      final_logps = []
      final_hiddens = []
      # final_gumbel_logits = []
      for score, n in sorted(endnodes, key=operator.itemgetter(0)):
          end_node = n
          utterance = []
          # gumbel_logits = []
          utterance.append(n.wordid[0][0])
          # gumbel_logits =gumbel_logits + [n.gumbel_logits]
          # back trace
          while n.prevNode != None:
              n = n.prevNode
              utterance.append(n.wordid[0][0])
              # if n.gumbel_logits is not None:
              #   gumbel_logits =gumbel_logits + [n.gumbel_logits]

          utterance = torch.unsqueeze(torch.tensor(utterance[::-1][:self.max_len]), axis = 0)
          utterances.append(utterance)
          # gumbel_logits = torch.cat(gumbel_logits, dim = 1)
          final_logp = end_node.logp 
          final_hidden = end_node.h
          final_logps.append(final_logp)
          final_hiddens.append(torch.unsqueeze(torch.tensor(final_hidden), axis = 0))
          # final_gumbel_logits.append(gumbel_logits)
        
      utterances = torch.cat(utterances, axis = 0 )
      decoded_batch.append(torch.unsqueeze(utterances, axis = 0))
      final_logp_batch.append(torch.unsqueeze(torch.tensor(final_logps), axis = 0))
      final_hiddens = torch.cat(final_hiddens, axis = 0)
      final_hidden_batch.append(final_hiddens)
      # final_gumbel_logits = torch.cat(final_gumbel_logits, axis = 0)
      # final_gumbel_logits_batch.append(torch.unsqueeze(final_gumbel_logits, axis = 0))

    # decoded_batch size = (batch, topk, sentence_len, 1)
    final_logp_batch = torch.cat(final_logp_batch, axis = 0)
    final_hidden_batch = torch.cat(final_hidden_batch, axis = 2)
    # final_gumbel_logits_batch = torch.cat(final_gumbel_logits_batch, axis = 0)
    decoded_batch = torch.cat(decoded_batch, axis = 0)
    print("final_logp_batch.shape:",final_logp_batch.shape)
    print("final_hidden_batch.shape:",final_hidden_batch.shape)
    # print("final_gumbel_logits_batch.shape:",final_gumbel_logits_batch.shape)
    print("decoded_batch.shape:",decoded_batch.shape)
    print("max_len:", self.max_len)
    print("topk:", self.topk)
    print("batch size:", inputs.shape[0])
    return final_hidden_batch, final_logp_batch,  decoded_batch




### Attention Decoder

In [None]:
class BahdanauAttention(nn.Module):
    """Implements Bahdanau (MLP) attention"""
    
    def __init__(self, hidden_size, key_size=None, query_size=None):
        super(BahdanauAttention, self).__init__()
        
        # We assume a bi-directional encoder so key_size is 2*hidden_size
        key_size = 2 * hidden_size if key_size is None else key_size
        query_size = hidden_size if query_size is None else query_size

        self.key_layer = nn.Linear(key_size, hidden_size, bias=False)
        self.query_layer = nn.Linear(query_size, hidden_size, bias=False)
        self.energy_layer = nn.Linear(hidden_size, 1, bias=False)
        
        # to store attention scores
        self.alphas = None
        
    def forward(self, query=None, proj_key=None, value=None, mask=None):
        assert mask is not None, "mask is required"

        # We first project the query (the decoder state).
        # The projected keys (the encoder states) were already pre-computated.
        query = self.query_layer(query)
        
        # Calculate scores.
        scores = self.energy_layer(torch.tanh(query + proj_key))
        scores = scores.squeeze(2).unsqueeze(1)
        
        # Mask out invalid positions.
        # The mask marks valid positions so we invert it using `mask & 0`.
        scores.data.masked_fill_(mask == 0, -float('inf'))
        
        # Turn scores to probabilities.
        alphas = F.softmax(scores, dim=-1)
        self.alphas = alphas        
        
        # The context vector is the weighted sum of the values.
        context = torch.bmm(alphas, value)
        
        # context shape: [B, 1, 2D], alphas shape: [B, 1, M]
        return context, alphas

In [None]:
class AttentionDecoder(nn.Module):
  """An RNN decoder + generator with GRU"""

  def __init__(self, input_size, hidden_size, attention, max_len,generator, num_layers = 1, dropout=0.):
    """
      Inputs:
        - `input_size` , `hidden_size`, and `dropout` the same as in Encoder.
    """
    super(AttentionDecoder, self).__init__()
    self.input_size = input_size
    self.hidden_size = hidden_size
    # self.rnn = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True,
    #                   dropout=dropout, bidirectional=False)
    self.rnn = nn.GRU(input_size + hidden_size, hidden_size, num_layers,
                          batch_first=True, dropout=dropout)
    
    self.generator = generator
    self.max_len = max_len
    self.dropout_layer = nn.Dropout(p=dropout)
    # self.rnn_to_pre = nn.Linear(input_size + hidden_size,
    #                             hidden_size, bias=False)
    self.rnn_to_pre = nn.Linear(hidden_size + hidden_size + input_size,
                                hidden_size, bias=False)
    self.attention = attention

  def forward_step(self, prev_embed, encoder_hidden, src_mask, proj_key, hidden):
    """Helper function for forward below:
       Perform a single decoder step (1 word).

       Inputs:
      - `prev_embed`: a 3d-tensor of shape (batch_size, 1, embed_size = vocab_size)
          representing the padded embedded word vectors at this step in training
      - `hidden`: a 3d-tensor of shape (1, batch_size, hidden_size) representing
          the current hidden state.

      Returns:
      - `hidden`: a 3d-tensor of shape (1, batch_size, hidden_size)
          representing the current decoder hidden state.
      - `output`: a 3d-tensor of shape (batch_size, max_len, vocab_size)
          representing the total generated outputs.
    """

    # compute context vector using attention mechanism
    query = hidden[-1].unsqueeze(1)  # [#layers, B, D] -> [B, 1, D]
    context, attn_probs = self.attention(
        query=query, proj_key=proj_key,
        value=encoder_hidden, mask=src_mask)
    
    # RNN
    rnn_input = torch.cat([prev_embed, context], dim=2)

    output, hidden = self.rnn(rnn_input, hidden)
    
    pre_output = torch.cat([prev_embed, output, context], dim=2)
    pre_output = self.dropout_layer(pre_output)
    pre_output = self.rnn_to_pre(pre_output)
    
    return hidden, pre_output

  def forward_step_beam(self, prev_embed, encoder_hidden, 
                   src_mask, proj_key, hidden):
    """Beam Search only: Unroll the decoder one step at a time."""
    temp_hidden, pre_output = self.forward_step(prev_embed,encoder_hidden, 
                   src_mask, proj_key, hidden)
    output, logits, output_word = self.generator.forward_gumbel(pre_output)
    return  temp_hidden, output, logits, output_word

  def forward(self, input, encoder_hidden, encoder_finals, src_mask, max_len, hidden=None):
    """Unroll the decoder one step at a time.

    Inputs:
      - `inputs`: a 3d-tensor of shape (batch_size, 1, embed_size)
          representing a batch of padded embedded word vectors of SOS . 
          If size is (batch_size,max_len, embed_size), then it is teacher forcing.
      - `encoder_finals`: a 3d-tensor of shape
          (num_enc_layers, batch_size, hidden_size) representing the final
          encoder hidden states used to initialize the initial decoder hidden
          states.
      - `hidden`: a 3d-tensor of shape (1, batch_size, hidden_size) representing
          the value to be used to initialize the initial decoder hidden states.
          If None, then use `encoder_finals`.
      - `max_len`: an int representing the maximum decoding length.
      - `style`: TAYLOR_STYLE or DRAKE_STYLE

    Returns:
      - `hidden`: a 3d-tensor of shape
          (num_layers, batch_size, hidden_size) representing the final hidden
          state for each element in the batch.
      - `outputs`: a 3d-tensor of shape
          (batch_size, max_len, hidden_size) representing the raw decoder
          outputs (before mapping to a `trg_vocab_size`-dim vector).
      - `logits_vectors`: a 3d-tensor of shape
          (batch_size, max_len, trg_vocab_size) representing the mapped decoder
          outputs.
      - `words`: a 3d-tensor of shape
          (batch_size, max_len, 1) representing output sentence and
          the corresponding word index (can be used for embedding)  
    """

    # Initialize decoder hidden state.
    if hidden is None:
      hidden = self.init_hidden(encoder_finals)

    proj_key = self.attention.key_layer(encoder_hidden)

    output_vectors = []
    logits_vectors = []
    words = []
    hidden_states = []
    hidden_states.append(hidden[-1][:,None,:])
    
    for i in range(max_len-1) :
      
      hidden, prev_output = self.forward_step(input,encoder_hidden, src_mask, proj_key, hidden)
      input, logits, output_word = self.generator.forward_gumbel(prev_output)

      logits_vectors.append(logits)
      output_vectors.append(input)
      words.append(output_word)
      hidden_states.append(prev_output)

    outputs = torch.cat(output_vectors, dim =1)
    logits_vectors = torch.cat(logits_vectors,dim = 1)
    words = torch.cat(words, axis = -1)
    hidden_states = torch.cat(hidden_states, axis = 1)

    return hidden, outputs , logits_vectors, words, hidden_states
  
  def forward_teacher(self, input, encoder_hidden, encoder_finals, src_mask, max_len=None, hidden=None):
    """Unroll the decoder one step at a time.

    Inputs:
      - `inputs`: a 3d-tensor of shape (batch_size,max_len, embed_size)
          representing a batch of padded embedded word vectors of original 
          sentence and acts as  teacher forcing.

    Returns:
      - `hidden`: a 3d-tensor of shape
          (num_layers, batch_size, hidden_size) representing the final hidden
          state for each element in the batch.
      - `outputs`: a 3d-tensor of shape
          (batch_size, max_len, hidden_size) representing the raw decoder
          outputs (before mapping to a `trg_vocab_size`-dim vector).
      - `logits_vectors`: a 3d-tensor of shape
          (batch_size, max_len, trg_vocab_size) representing the mapped decoder
          outputs each represents the probability? 
      - `words`: a 3d-tensor of shape
          (batch_size, max_len, 1) representing output sentence and
          the corresponding word index (can be used for embedding)      
    """

    # Initialize decoder hidden state.
    if max_len is None:
      max_len = input.shape[1]
    if hidden is None:
      hidden = self.init_hidden(encoder_finals)
    
    proj_key = self.attention.key_layer(encoder_hidden)

    output_vectors = []
    logits_vectors = []
    words = []
    hidden_states = []
    hidden_states.append(hidden[-1][:,None,:])

    for i in range(max_len):
      hidden, prev_output = self.forward_step(input[:,i:i+1,:], encoder_hidden, src_mask, proj_key, hidden)
      
      output, logits, output_word = self.generator.forward_gumbel(prev_output)
      logits_vectors.append(logits)
      output_vectors.append(output)
      words.append(output_word)
      hidden_states.append(prev_output)

    outputs = torch.cat(output_vectors, dim =1)
    logits_vectors = torch.cat(logits_vectors,dim = 1)
    words = torch.cat(words, axis = -1)
    hidden_states = torch.cat(hidden_states, axis = 1)
    return hidden, outputs , logits_vectors, words, hidden_states

  def init_hidden(self, encoder_finals):
    """Use encoder final hidden state to initialize decoder's first hidden
       state.

       Input: `encoder_finals` is same as in forward()

       Returns: 
         - `decoder_init_hiddens`: a 3d-tensor of shape 
              (num_layers, batch_size, hidden_size) representing the initial
              hidden state of the decoder for each element in the batch 
    """
    decoder_init_hiddens = torch.tanh(encoder_finals)
    return decoder_init_hiddens

#### EncoderDecoderAttention (in progress)

In [None]:
### Work in progress
class EncoderDecoderAttention(nn.Module):
  def __init__(self, encoder, decoder, line_embed, generator):
    super(EncoderDecoderAttention, self).__init__()

    self.encoder = encoder
    self.decoder = decoder
    self.line_embed = line_embed
    self.generator = generator

  def forward(self, lines, line_lens):
    encoder_hidden, encoder_finals = self.encode(lines, line_lens)
    src_mask = (lines != PAD_INDEX).unsqueeze(-2)
    # return self.reconstruct(encoder_hidden, encoder_finals, lines[:, :-1], src_mask)
    return self.reconstruct(encoder_hidden, encoder_finals, lines[:, :-1], src_mask), self.decode(encoder_hidden, encoder_finals, src_mask)

  def encode(self, lines, line_lens):
    return self.encoder(self.line_embed(lines), line_lens)
    
  def reconstruct(self, encoder_hidden, h0, lines, src_mask):
    original = self.line_embed(lines)
    return self.decoder.forward_teacher(original,encoder_hidden, h0, src_mask)

  def decode(self, encoder_hidden, h0, src_mask):
    target = self.line_embed(torch.tensor([SOS_INDEX]).repeat(h0.size()[1],1).to(device))
    return self.decoder.forward(target, encoder_hidden, h0, src_mask, max_len)

epochs = 3
lr = 1e-3
batch_size = 32
print_every = 100
max_len = dataset.max_seq_length
vocab_size = len(vocab)
embed_size = 256
hidden_size = 256
dropout = 0.2
gamma = 0.001

train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

line_embed = nn.Embedding(vocab_size, embed_size)
encoder = Encoder(embed_size,hidden_size)
generator = GeneratorTransferredSampled(hidden_size,vocab_size, line_embed, gamma = gamma)
attention = BahdanauAttention(hidden_size, key_size=hidden_size)
decoder = AttentionDecoder(embed_size, hidden_size, attention=attention, max_len=vocab_size, generator = generator,dropout=dropout)
model = EncoderDecoderAttention(encoder, decoder, line_embed, generator).to(device)
optimizer_model = torch.optim.Adam(model.parameters(), lr=lr)
rec_loss = nn.NLLLoss(reduction="mean",ignore_index = PAD_INDEX)

for epoch in range(epochs):
  epoch_rec_loss = 0
  epoch_tokens = 0
  model.train()
  for i,(taylor_lines, taylor_len, drake_lines, drake_len) in enumerate(train_loader):
    lines = torch.cat((taylor_lines, drake_lines), 0).to(device)    
    line_lens = torch.cat((taylor_len, drake_len), 0).to(device)

    # Train model
    rec_orig,dec_orig  = model(lines, line_lens)   
    # rec_orig  = model(lines, line_lens)    
    loss_rec = rec_loss(input=dec_orig[2].permute(0,2,1), target=lines[:, 1:])

    optimizer_model.zero_grad()
    loss_rec.backward()
    optimizer_model.step()
    
    epoch_rec_loss += loss_rec.item() * line_lens.sum().item()
    epoch_tokens += line_lens.sum().item()

  print("Finished Training Epoch ", epoch)
  print("Training PPL", np.exp(epoch_rec_loss / float(epoch_tokens)))
  val_loss = 0
  val_tokens = 0

  for i,(taylor_lines, taylor_len, drake_lines, drake_len) in enumerate(valid_loader):
    lines = torch.cat((taylor_lines, drake_lines), 0).to(device)    
    line_lens = torch.cat((taylor_len, drake_len), 0).to(device)

    rec_orig,dec_orig  = model(lines, line_lens)  
    # rec_orig  = model(lines, line_lens)    
    loss_rec = rec_loss(input=dec_orig[2].permute(0,2,1), target=lines[:, 1:])

    val_loss += loss_rec.item() * line_lens.sum().item()
    val_tokens += line_lens.sum().item()
  
  print("Valid PPL", np.exp(val_loss / float(val_tokens)))

  "num_layers={}".format(dropout, num_layers))


Finished Training Epoch  0
Training PPL 186.3012933666054
Valid PPL 36.84848284237294
Finished Training Epoch  1
Training PPL 13.585992878438688
Valid PPL 7.053865752802458
Finished Training Epoch  2
Training PPL 4.822962307204506
Valid PPL 4.504301768389754


In [None]:
dec_orig[3]

tensor([[ 4619,  9239,  9239,  ...,  9268, 12570,  9268],
        [ 6947,  2453,  6338,  ...,  7785, 10651, 10651],
        [ 7670, 12181,  6029,  ..., 12186,  3635, 12570],
        ...,
        [ 7372,  3998, 10130,  ...,  8616,  2051,  7983],
        [ 9077, 10213,  9280,  ...,  1119,  3461,  7149],
        [ 5248,  9280, 12234,  ..., 11901,  6566,  9320]], device='cuda:0')

In [None]:
def lookup_words(x, vocab):
  return [vocab[i] for i in x]

idx=17
print(lookup_words(lines[idx], vocab))
print(lookup_words(rec_orig[3][idx], vocab))
print(lookup_words(dec_orig[3][idx], vocab))

['<s>', 'i', 'just', ',', 'i', 'ca', "n't", ',', 'i', 'just', 'ca', "n't", 'be', 'lovin', "'", 'you', 'no', 'more', ',', 'i', 'love', 'you', 'more', 'than', 'i', 'love', 'myself', '</s>', '<pad>', '<pad>', '<pad>', '<pad>']
['i', 'just', ',', 'i', 'ca', "n't", ',', 'i', 'just', 'ca', "n't", 'be', 'lovin', "'", 'no', 'no', 'more', ',', 'i', 'love', 'more', 'more', 'than', 'i', '</s>', 'myself', '</s>', '</s>', '</s>', '</s>', '</s>']
['i', 'just', ',', 'i', 'ca', "n't", ',', 'i', 'just', 'ca', "n't", 'be', 'lovin', "'", 'you', 'more', 'more', ',', 'i', 'love', 'more', 'more', 'than', 'i', 'myself', '</s>', '</s>', '</s>', '</s>', '</s>', '</s>']


### Encoder-Decoder for Testing

In [None]:
# ### Work in progress
# class EncoderDecoder(nn.Module):
#   def __init__(self, encoder, decoder, line_embed, generator):
#     super(EncoderDecoder, self).__init__()

#     self.encoder = encoder
#     self.decoder = decoder
#     self.line_embed = line_embed
#     self.generator = generator

#   def forward(self, lines, line_lens):
#     encoder_hiddens, encoder_finals = self.encode(lines, line_lens)
#     del encoder_hiddens
#     return self.reconstruct(encoder_finals, lines[:, :-1]), self.decode(encoder_finals)

#   def encode(self, lines, line_lens):
#     return self.encoder(self.line_embed(lines), line_lens)
    
#   def reconstruct(self, h0, lines):
#     original = self.line_embed(lines)
#     return self.decoder.forward_teacher(original,h0)

#   def decode(self, h0):
#     target = self.line_embed(torch.tensor([SOS_INDEX]).repeat(h0.size()[1],1).to(device))
#     return self.decoder.forward(target,h0,max_len)

# epochs = 3
# lr = 1e-3
# batch_size = 32
# print_every = 100
# max_len = dataset.max_seq_length
# vocab_size = len(vocab)
# embed_size = 256
# hidden_size = 256
# dropout = 0.2
# gamma = 0.001

# train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# valid_loader = data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
# test_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# line_embed = nn.Embedding(vocab_size, embed_size)
# encoder = Encoder(embed_size,hidden_size)
# generator = GeneratorTransferredSampled(hidden_size,vocab_size, line_embed, gamma = gamma)
# decoder = Decoder(embed_size, hidden_size, max_len=vocab_size, generator = generator,dropout=dropout)
# model = EncoderDecoder(encoder, decoder, line_embed, generator).to(device)
# optimizer_model = torch.optim.Adam(model.parameters(), lr=lr)
# rec_loss = nn.NLLLoss(reduction="mean",ignore_index = PAD_INDEX)

# for epoch in range(epochs):
#   epoch_rec_loss = 0
#   epoch_tokens = 0
#   model.train()
#   for i,(taylor_lines, taylor_len, drake_lines, drake_len) in enumerate(train_loader):
#     lines = torch.cat((taylor_lines, drake_lines), 0).to(device)    
#     line_lens = torch.cat((taylor_len, drake_len), 0).to(device)

#     # Train model
#     rec_orig,dec_orig  = model(lines, line_lens)    
#     loss_rec = rec_loss(input=rec_orig[2].permute(0,2,1), target=lines[:, 1:])

#     optimizer_model.zero_grad()
#     loss_rec.backward()
#     optimizer_model.step()
    
#     epoch_rec_loss += loss_rec.item() * line_lens.sum().item()
#     epoch_tokens += line_lens.sum().item()

#   print("Finished Training Epoch ", epoch)
#   print("Training PPL", np.exp(epoch_rec_loss / float(epoch_tokens)))
#   val_loss = 0
#   val_tokens = 0

#   for i,(taylor_lines, taylor_len, drake_lines, drake_len) in enumerate(valid_loader):
#     lines = torch.cat((taylor_lines, drake_lines), 0).to(device)    
#     line_lens = torch.cat((taylor_len, drake_len), 0).to(device)

#     rec_orig,dec_orig  = model(lines, line_lens)    
#     loss_rec = rec_loss(input=rec_orig[2].permute(0,2,1), target=lines[:, 1:])

#     val_loss += loss_rec.item() * line_lens.sum().item()
#     val_tokens += line_lens.sum().item()
  
#   print("Valid PPL", np.exp(val_loss / float(val_tokens)))

### Classifier

In [None]:
class LSTMDiscriminator(nn.Module):
  def __init__(self, input_size, hidden_size, LSTMlayers=1, dropout = 0.5):
    super(LSTMDiscriminator, self).__init__()

    self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=LSTMlayers, 
                        batch_first=True, bidirectional=True)
    self.drop = nn.Dropout(p=dropout)
    self.fc = nn.Linear(2*hidden_size, 1)
    self.hidden_size = hidden_size

  def forward(self, text_emb, text_len):
    text_len[text_len==0] += 1

    packed_input = pack_padded_sequence(text_emb, text_len.cpu(), batch_first=True, enforce_sorted=False)
    packed_output, _ = self.lstm(packed_input)
    output, _ = pad_packed_sequence(packed_output, batch_first=True)

    out_forward = output[range(len(output)), text_len - 1, :self.hidden_size]
    out_reverse = output[:, 0, self.hidden_size:]
    out_reduced = torch.cat((out_forward, out_reverse), 1)
    text_fea = self.drop(out_reduced)

    text_fea = self.fc(text_fea)
    text_fea = torch.squeeze(text_fea, 1)
    text_out = torch.sigmoid(text_fea)

    return text_out
    
class LSTMClassifier(nn.Module):

    def __init__(self, dimension=128):
        super(LSTMClassifier, self).__init__()

        self.embedding = nn.Linear(len(vocab), 300)
        self.dimension = dimension
        self.lstm = nn.LSTM(input_size=300,
                            hidden_size=dimension,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True)
        self.drop = nn.Dropout(p=0.5)

        self.fc = nn.Linear(2*dimension, 1)

    def forward(self, text, text_len):

        text_emb = self.embedding(text)
        text_len[text_len==0] += 1

        packed_input = pack_padded_sequence(text_emb, text_len.cpu(), batch_first=True, enforce_sorted=False)
        packed_output, _ = self.lstm(packed_input)
        output, _ = pad_packed_sequence(packed_output, batch_first=True)

        out_forward = output[range(len(output)), text_len - 1, :self.dimension]
        out_reverse = output[:, 0, self.dimension:]
        out_reduced = torch.cat((out_forward, out_reverse), 1)
        text_fea = self.drop(out_reduced)

        text_fea = self.fc(text_fea)
        text_fea = torch.squeeze(text_fea, 1)
        text_out = torch.sigmoid(text_fea)

        return text_out

### Training

#### TSTModel

In [None]:
class TSTModel(nn.Module):
  def __init__(self, max_len, vocab_size, embed_size, hidden_size_z, hidden_size_y, line_embed, encoder, generator, decoder, classifier):
    super(TSTModel, self).__init__()

    self.hidden_size = hidden_size_y + hidden_size_z

    self.encoder = encoder
    self.generator = generator
    self.decoder = decoder
    self.classifier = classifier

    self.line_embed = line_embed
    self.y_embed_enc = nn.Embedding(2,hidden_size_y)
    self.y_embed_gen = nn.Embedding(2,hidden_size_y)

    self.max_len = max_len
    self.vocab_size = vocab_size
    self.embed_size = embed_size
    self.hidden_size_z = hidden_size_z
    self.hidden_size_y = hidden_size_y

  def forward(self, lines, line_lens, labels):

    encoder_hidden, encoded_lines = self.encode(lines, line_lens, labels)
    z = encoded_lines[-1][:,self.hidden_size_y:]

    h0_orig = torch.cat((self.y_embed_gen(labels),z), 1)[None,:]
    h0_tsf = torch.cat((self.y_embed_gen(1-labels),z), 1)[None,:]

    # Decode back into original form for reconstruction
    rec_orig = self.reconstruct(h0_orig, lines[:, :-1])

    # Decode into original and transferred forms for classification
 
    decode_orig = self.decode(h0_orig)
    decode_tsf = self.decode(h0_tsf)
    
    half = int(lines.size(0) / 2)

    discrim1_input = torch.cat((rec_orig[4][:half], decode_tsf[4][half:]))
    discrim0_input = torch.cat((rec_orig[4][half:], decode_tsf[4][:half]))

    classifier_lines = torch.cat((torch.exp(decode_orig[2]), torch.exp(decode_tsf[2]), F.one_hot(lines[:,1:], self.vocab_size).to(torch.float)), 0)
    
    rec_orig_len = first_eos(rec_orig[3]) + 1
    decode_orig_len = first_eos(decode_orig[3]) + 1
    decode_tsf_len = first_eos(decode_tsf[3]) + 1

    classifier_line_lens = torch.cat((decode_orig_len, decode_tsf_len, line_lens),0)
    # classifier_line_lens = torch.cat((line_lens, line_lens, line_lens),0)
    discrim0_lens = torch.cat((rec_orig_len[half:], decode_tsf_len[:half]))
    discrim1_lens = torch.cat((rec_orig_len[:half], decode_tsf_len[half:]))

    pred_class = self.classifier(classifier_lines, classifier_line_lens-1)

    # return rec_orig, decode_orig
    return rec_orig, pred_class, decode_orig, decode_tsf, (discrim0_input, discrim0_lens), (discrim1_input, discrim1_lens)

  def encode(self, lines, line_lens, labels):
    init_state = torch.cat((self.y_embed_enc(labels), torch.zeros((len(lines),self.hidden_size_z), device=device)), 1)[None,:].to(device)
    return self.encoder(self.line_embed(lines), line_lens, init_state)

  def reconstruct(self, h0, lines):
    original = self.line_embed(lines)
    return self.decoder.forward_teacher(original,h0)

  def decode(self, h0):
    target = self.line_embed(torch.tensor([SOS_INDEX]).repeat(h0.size()[1],1).to(device))
    return self.decoder.forward(target,h0,self.max_len)

def first_eos(x):
  eos_pos = (x == EOS_INDEX)
  found, indices = ((eos_pos.cumsum(1) == 1) & eos_pos).max(1)
  indices = indices + (~found*x.size(1))
  return indices

#### TSTModelAttention

In [None]:
class TSTModelAttention(nn.Module):
  def __init__(self, max_len, vocab_size, embed_size, hidden_size_z, 
               hidden_size_y, line_embed, encoder, generator, decoder, 
               classifier,beamSeasrch):
    super(TSTModelAttention, self).__init__()

    self.hidden_size = hidden_size_y + hidden_size_z

    self.encoder = encoder
    self.generator = generator
    self.decoder = decoder
    self.classifier = classifier

    self.beamSeasrch = beamSeasrch

    self.line_embed = line_embed
    self.y_embed_enc = nn.Embedding(2,hidden_size_y)
    self.y_embed_gen = nn.Embedding(2,hidden_size_y)

    self.max_len = max_len
    self.vocab_size = vocab_size
    self.embed_size = embed_size
    self.hidden_size_z = hidden_size_z
    self.hidden_size_y = hidden_size_y

  def forward(self, lines, line_lens, labels):

    src_mask = (lines != PAD_INDEX).unsqueeze(-2)
    encoder_hidden, encoder_finals = self.encode(lines, line_lens, labels)

    z = encoder_finals[-1][:,self.hidden_size_y:]

    h0_orig = torch.cat((self.y_embed_gen(labels),z), 1)[None,:]
    h0_tsf = torch.cat((self.y_embed_gen(1-labels),z), 1)[None,:]

    # Decode back into original form for reconstruction
    rec_orig = self.reconstruct(encoder_hidden, h0_orig, lines[:, :-1], src_mask)

    # Decode into original and transferred forms for classification
    decode_orig = self.decode(encoder_hidden, h0_orig, src_mask)
    decode_tsf = self.decode(encoder_hidden, h0_tsf, src_mask)
    
    half = int(lines.size(0) / 2)

    discrim1_input = torch.cat((rec_orig[4][:half], decode_tsf[4][half:]))
    discrim0_input = torch.cat((rec_orig[4][half:], decode_tsf[4][:half]))

    classifier_lines = torch.cat((torch.exp(decode_orig[2]), torch.exp(decode_tsf[2]), F.one_hot(lines[:,1:], self.vocab_size).to(torch.float)), 0)

    rec_orig_len = first_eos(rec_orig[3]) + 1
    decode_orig_len = first_eos(decode_orig[3]) + 1
    decode_tsf_len = first_eos(decode_tsf[3]) + 1

    classifier_line_lens = torch.cat((decode_orig_len, decode_tsf_len, line_lens),0)
    # classifier_line_lens = torch.cat((line_lens, line_lens, line_lens),0)
    discrim0_lens = torch.cat((rec_orig_len[half:], decode_tsf_len[:half]))
    discrim1_lens = torch.cat((rec_orig_len[:half], decode_tsf_len[half:]))

    pred_class = self.classifier(classifier_lines, classifier_line_lens-1)
    
    # return rec_orig, pred_class, decode_orig, decode_tsf, (discrim0_input, line_lens), (discrim1_input, line_lens)
    return rec_orig, pred_class, decode_orig, decode_tsf, (discrim0_input, discrim0_lens), (discrim1_input, discrim1_lens)

  def forward_beam(self,lines, line_lens, labels):
    src_mask = (lines != PAD_INDEX).unsqueeze(-2)
    encoder_hidden, encoder_finals = self.encode(lines, line_lens, labels)
    z = encoder_finals[-1][:,self.hidden_size_y:]

    h0_orig = torch.cat((self.y_embed_gen(labels),z), 1)[None,:]
    h0_tsf = torch.cat((self.y_embed_gen(1-labels),z), 1)[None,:]

    decode_orig = self.decode_beam(encoder_hidden, h0_orig, src_mask)
    decode_tsf = self.decode_beam(encoder_hidden, h0_tsf, src_mask)

    return decode_orig, decode_tsf #rec_orig, ,pred_class,  (discrim0_input, line_lens), (discrim1_input, line_lens)

  
  def decode_beam(self,encoder_hidden,h0,src_mask):
    target = self.line_embed(torch.tensor([SOS_INDEX]).repeat(h0.size()[1],1).to(device))
    return self.beamSeasrch.beam_decode(target, encoder_hidden,h0,src_mask, self.max_len)

  def encode(self, lines, line_lens, labels):
    init_state = torch.cat((self.y_embed_enc(labels), torch.zeros((len(lines),self.hidden_size_z), device=device)), 1)[None,:].to(device)
    return self.encoder(self.line_embed(lines), line_lens, init_state)

  def reconstruct(self, encoder_hidden, h0, lines, src_mask):
    original = self.line_embed(lines)
    return self.decoder.forward_teacher(original,encoder_hidden, h0, src_mask)

  def decode(self, encoder_hidden, h0, src_mask):
    target = self.line_embed(torch.tensor([SOS_INDEX]).repeat(h0.size()[1],1).to(device))
    return self.decoder.forward(target, encoder_hidden, h0, src_mask, self.max_len)

#### Training

In [None]:
# Changing learning rate can affect things
# Focusing on reconstruction first, then later building in discriminator or classifier seems to help
# Without attention, autoencoder needs at least 7-8 epochs to get reasonable reconstructions

attention = True
classify = False
discriminate = True

pre_train_classifier = False

epochs = 10
class_epochs = 2
lr = 1e-3
batch_size = 32
print_every = 100

max_len = MAX_SENT_LENGTH_PLUS_SOS_EOS
vocab_size = len(vocab)
embed_size = 100
hidden_size_z = 500
hidden_size_y = 200
hidden_size = hidden_size_z + hidden_size_y
dropout = 0.2
gamma = 0.1

TAYLOR_STYLE=1 # for information only, don't change
DRAKE_STYLE=0  # for information only, don't change
train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
test_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

line_embed = nn.Embedding(vocab_size, embed_size)
encoder = Encoder(embed_size,hidden_size)
generator = GeneratorTransferredSampled(hidden_size,vocab_size, line_embed, gamma = gamma)
classifier = LSTMClassifier()
discriminator0 = LSTMDiscriminator(hidden_size, hidden_size).to(device)
discriminator1 = LSTMDiscriminator(hidden_size, hidden_size).to(device)

stop

if pre_train_classifier:
  classifier = classifier.to(device)
  optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3) 
  class_loss = nn.BCELoss()

  for epoch in range(class_epochs):
    correct = 0
    classifier.train()
    for i,(taylor_lines, taylor_len, drake_lines, drake_len) in enumerate(train_loader):
      lines = torch.cat((taylor_lines, drake_lines), 0).to(device)  
      classifier_lines = F.one_hot(lines[:,1:], len(vocab)).to(torch.float).to(device)

      line_lens = torch.cat((taylor_len, drake_len), 0).to(device)
      labels = torch.cat((torch.ones(size=(len(taylor_lines),), dtype=torch.int32),torch.zeros(size=(len(drake_lines),),dtype=torch.int32))).to(device)

      pred_class = classifier(classifier_lines, line_lens-1)
      loss_class = class_loss(input=pred_class, target=labels.to(torch.float))

      optimizer.zero_grad()
      loss_class.backward()
      optimizer.step()

      correct += torch.sum((pred_class >= 0.5) == labels)
    print("Pre-Training Accuracy: ", correct / float(2*len(train_dataset)))
    classifier.eval()
    correct = 0
    for i,(taylor_lines, taylor_len, drake_lines, drake_len) in enumerate(valid_loader):
      lines = torch.cat((taylor_lines, drake_lines), 0).to(device)  
      classifier_lines = F.one_hot(lines[:,1:], len(vocab)).to(torch.float).to(device)

      line_lens = torch.cat((taylor_len, drake_len), 0).to(device)
      labels = torch.cat((torch.ones(size=(len(taylor_lines),), dtype=torch.int32),torch.zeros(size=(len(drake_lines),),dtype=torch.int32))).to(device)

      pred_class = classifier(classifier_lines, line_lens-1)
      correct += torch.sum((pred_class >= 0.5) == labels)
    print("Pre-Valid Accuracy: ", correct / float(2*len(valid_dataset)))

if attention:
  attention_mech = BahdanauAttention(hidden_size, key_size=hidden_size)
  decoder = AttentionDecoder(embed_size, hidden_size, attention=attention_mech, max_len=vocab_size, generator = generator,dropout=dropout)
  # BEAM SEARCH
  beamSeasrch = BeamSearch(decoder, 3,3,line_embed,max_len)
  model = TSTModelAttention(max_len, vocab_size, embed_size, hidden_size_z, hidden_size_y, line_embed, encoder, generator, decoder, classifier,beamSeasrch).to(device)
else:
  decoder = Decoder(embed_size, hidden_size, max_len=vocab_size, generator = generator, dropout=dropout)
  # BEAM SEARCH
  beamSeasrch = BeamSearch(decoder, 3,3,line_embed,max_len)
  model = TSTModel(max_len, vocab_size, embed_size, hidden_size_z, hidden_size_y, line_embed, encoder, generator, decoder, classifier).to(device)


optimizer_model = torch.optim.Adam(model.parameters(), lr=lr) 
optimizer_discr = torch.optim.Adam(list(discriminator0.parameters()) + list(discriminator1.parameters()), lr=lr) 

rec_loss = nn.NLLLoss(reduction="mean",ignore_index = PAD_INDEX)
class_loss = nn.BCELoss()
discr_loss = nn.BCELoss()

epoch_losses = []
for epoch in range(epochs):
  epoch_loss = 0
  epoch_class_loss = 0
  epoch_rec_loss = 0
  epoch_adv_loss = 0
  epoch_loss_d = 0
  epoch_tokens = 0
  model.train()
  for i,(taylor_lines, taylor_len, drake_lines, drake_len) in enumerate(train_loader):
    lines = torch.cat((taylor_lines, drake_lines), 0).to(device)    
    line_lens = torch.cat((taylor_len, drake_len), 0).to(device)
    labels = torch.cat((torch.ones(size=(len(taylor_lines),), dtype=torch.int32),torch.zeros(size=(len(drake_lines),),dtype=torch.int32))).to(device)
    classifier_labels = torch.cat((labels,1-labels, labels))
    
    fake_labels = torch.cat((torch.zeros(size=(len(taylor_lines),), dtype=torch.int32),torch.ones(size=(len(taylor_lines),),dtype=torch.int32))).to(device)
    fake_labels = fake_labels

    # Train discriminator

    if discriminate:
      rec_orig, pred_class, decode_orig, decode_tsf, pred_fake0, pred_fake1 = model(lines, line_lens, labels)
      
      pred_fake0 = discriminator0(pred_fake0[0], pred_fake0[1])
      pred_fake1 = discriminator1(pred_fake1[0], pred_fake1[1])

      loss_d0 = discr_loss(pred_fake0, fake_labels.to(torch.float))
      loss_d1 = discr_loss(pred_fake1, fake_labels.to(torch.float))
      loss_d = loss_d0 + loss_d1

      optimizer_discr.zero_grad()
      loss_d.backward()
      optimizer_discr.step()

    # Train model

    rec_orig, pred_class, decode_orig, decode_tsf, pred_fake0, pred_fake1 = model(lines, line_lens, labels)

    loss_rec = rec_loss(input=rec_orig[2].permute(0,2,1), target=lines[:, 1:])

    loss = loss_rec

    if attention:
      rec_treshold  = 10
    else:
      rec_treshold = 10
    
    if discriminate:
      pred_fake0 = discriminator0(pred_fake0[0], pred_fake0[1])
      pred_fake1 = discriminator1(pred_fake1[0], pred_fake1[1])
      loss_adv0 = class_loss(pred_fake0[len(drake_lines):], fake_labels[len(drake_lines):].to(torch.float))
      loss_adv1 = class_loss(pred_fake1[len(taylor_lines):], fake_labels[len(taylor_lines):].to(torch.float))

      if loss_adv0 < 1.2 and loss_adv1 < 1.2 and loss_rec < rec_treshold:
      # Don't use adversarial training unless discriminator and reconstruction are both good
        loss -= (loss_adv0 + loss_adv1)
    
    if classify:
      loss_class = class_loss(pred_class, classifier_labels.to(torch.float))
      loss_class_generated = class_loss(pred_class[:-len(labels)], classifier_labels[:-len(labels)].to(torch.float))

      # if loss_class_generated > 0.5 and loss_rec < rec_treshold:
      # # If generated examples are too similar and reconstruction is good, only focus on achieving better style (i.e. classifier)
      # # Else, focus on both
      # # Note: play around with these, it probably affects performance
      #   loss = loss_class
      # else:
      #   loss += loss_class
      if loss_rec < rec_treshold:
        loss += loss_class

    optimizer_model.zero_grad()
    loss.backward()
    optimizer_model.step()
    
    epoch_loss += loss.item()
    epoch_rec_loss += loss_rec.item() * line_lens.sum().item()
    epoch_tokens += line_lens.sum().item()

    if discriminate:
      epoch_loss_d += loss_d.item()
      epoch_adv_loss += (loss_adv0.item() + loss_adv1.item())
      
    if classify:
      epoch_class_loss += loss_class.item()

    if model.training and i % print_every == 0:
      print("Epoch Step: %d Loss: %f" % (i, loss.item()))
      if classify:
        print("Epoch Step: %d Class Loss: %f" % (i, loss_class.item()))
  
  epoch_losses.append(epoch_loss)
  print("Finished Training Epoch ", epoch)
  print("Training PPL", np.exp(epoch_rec_loss / float(epoch_tokens)))

  if discriminate:
    print("Adversarial Loss", epoch_adv_loss)
    print("Discriminator Loss", epoch_loss_d)
  
  if classify:
    print("Classification Loss", epoch_class_loss)

  val_loss = 0
  val_tokens = 0
  val_class_loss = 0
  correct_pred = 0
  correct_pred_all = 0
  correct_pred_drake = 0
  correct_pred_tay = 0

  for i,(taylor_lines, taylor_len, drake_lines, drake_len) in enumerate(valid_loader):
    lines = torch.cat((taylor_lines, drake_lines), 0).to(device)    
    line_lens = torch.cat((taylor_len, drake_len), 0).to(device)
    labels = torch.cat((torch.ones(size=(len(taylor_lines),), dtype=torch.int32),torch.zeros(size=(len(drake_lines),),dtype=torch.int32))).to(device)
    classifier_labels = torch.cat((labels,1-labels, labels))
    
    fake_labels = torch.cat((torch.zeros(size=(len(taylor_lines),), dtype=torch.int32),torch.ones(size=(len(taylor_lines),),dtype=torch.int32))).to(device)

    rec_orig, pred_class, decode_orig, decode_tsf, pred_fake0, pred_fake1 = model(lines, line_lens, labels)
    loss_rec = rec_loss(input=rec_orig[2].permute(0,2,1), target=lines[:, 1:])
    # loss_class = class_loss(pred_class, classifier_labels.to(torch.float))

    val_loss += loss_rec.item() * line_lens.sum().item()
    val_tokens += line_lens.sum().item()
    # val_class_loss += loss_class.item()*classifier_labels.size(0)

    if classify:
      correct_pred += torch.sum((pred_class[-len(lines):] >= 0.5) == classifier_labels[-len(lines):])
      correct_pred_all += torch.sum((pred_class >= 0.5) == classifier_labels)

    if discriminate:
      pred_fake0 = discriminator0(pred_fake0[0], pred_fake0[1])
      pred_fake1 = discriminator1(pred_fake1[0], pred_fake1[1])

      correct_pred_drake += torch.sum((pred_fake0 >= 0.5) == fake_labels) 
      correct_pred_tay += torch.sum((pred_fake1 >= 0.5) == fake_labels)
  
  print("Valid PPL", np.exp(val_loss / float(val_tokens)))
  if classify:
    print("Valid Classification Accuracy on True", correct_pred / (2.*len(valid_dataset)))
    print("Valid Classification Accuracy on All", correct_pred_all / (3.*2.*len(valid_dataset)))

  if discriminate:
    print("Valid Classification Accuracy on Drake", correct_pred_drake / (2.*len(valid_dataset)))
    print("Valid Classification Accuracy on Taylor", correct_pred_tay / (2.*len(valid_dataset)))

NameError: ignored

In [None]:
# Quick assessment
def lookup_words(x, vocab):
  return [vocab[i] for i in x]

idx=0
print(lookup_words(lines[idx], vocab))
print(lookup_words(rec_orig[3][idx], vocab))
print(lookup_words(decode_orig[3][idx], vocab))
print(lookup_words(decode_tsf[3][idx], vocab))

idx=9
print(lookup_words(lines[idx], vocab))
print(lookup_words(rec_orig[3][idx], vocab))
print(lookup_words(decode_orig[3][idx], vocab))
print(lookup_words(decode_tsf[3][idx], vocab))

['<s>', 'you', "'re", 'being', 'too', 'loud', '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
['you', "'re", 'being', 'too', 'loud', '</s>', '</s>', 'too', 'loud', '</s>', '</s>', 'too', 'loud', '</s>', '</s>', 'too']
['you', "'re", 'being', 'too', 'loud', '</s>', '</s>', 'too', 'loud', '</s>', '</s>', 'too', 'loud', '</s>', '</s>', 'too']
['you', "'re", 'being', 'too', 'loud', '</s>', 'being', 'too', 'loud', '</s>', '</s>', 'too', 'loud', '</s>', '</s>', 'too']
['<s>', 'love', 'certain', 'ones', 'but', 'never', 'get', 'attached', 'to', "'em", '</s>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>', '<pad>']
['love', 'certain', 'ones', 'but', 'never', 'get', 'attached', 'to', "'em", '</s>', 'tales', "'em", '</s>', '</s>', "'em", '</s>']
['love', 'certain', 'ones', 'but', 'never', 'get', 'attached', 'to', "'em", '</s>', 'closed', "'em", '</s>', '</s>', "'em", '</s>']
['love', 'certain', 'ones', 'but', 'never', 'get', 'attached', 'to', "'em",

#### Save Model

In [None]:
from datetime import datetime
import pytz
import os
now = datetime.now()
now.astimezone(pytz.timezone('America/New_York'))

model_dir = '/content/drive/Shareddrives/MIT NLP 8.864/model'

In [None]:
def current_time():
  now = datetime.now()
  now = now.astimezone(pytz.timezone('America/New_York'))
  return now.strftime("%Y%m%d_%H%M%S")
def lookup_words(x, vocab):
  return [vocab[i] for i in x]

##### Pre-trained Classifier

In [None]:
if pre_train_classifier:
  classifier_path = os.path.join(model_dir,
                                 f'classifier/classifier_{current_time()}.pt')
  torch.save(classifier.state_dict(), classifier_path)


In [None]:
# # Reload model Example
# model = LSTMClassifier()
# model.load_state_dict(torch.load(classifier_path))
# model = model.to(device)
# model.eval()

##### Style-Transfer Model

###### Set model path

In [None]:
model_name = 'model_'
if not attention and not classify and not discriminate:
  # Dana
  model_name += '1'
elif attention and not classify and not discriminate:
  # Dana
  model_name += '2'
elif not attention and classify and not discriminate:
  # Sirui
  model_name += '3'
elif attention and classify and not discriminate:
  # Sirui
  model_name += '4'
elif not attention and classify and discriminate:
  # Chenwei
  model_name += '5'
elif attention and classify and discriminate:
  # Chenwei
  model_name += '6'
elif not attention and not classify and discriminate:
  # Hammaad
  model_name += '7'
elif attention and not classify and discriminate:
  # Hammaad
  model_name += '8'


###### Save model

In [None]:
model_init_args = {"Embedding_num_embeddings":vocab_size,
           'Embedding_embedding_dim':embed_size,
           'Encoder_input_size':embed_size, 
           'Encoder_hidden_size':hidden_size,
           'GeneratorTransferredSampled_hidden_size':hidden_size, 
           'GeneratorTransferredSampled_vocab_size':vocab_size, 
           'GeneratorTransferredSampled_gamma':gamma,
           'LSTMClassifier_dimension':classifier.dimension,
           'LSTMDiscriminator_input_size':hidden_size, 
           'LSTMDiscriminator_hidden_size':hidden_size}
if attention: 
  model_init_args.update({
           'BahdanauAttention_hidden_size':hidden_size,
           'BahdanauAttention_key_size':hidden_size,
           'AttentionDecoder_input_size':embed_size,
            'AttentionDecoder_hidden_size':hidden_size, 
            'AttentionDecoder_max_len':vocab_size,  
            'AttentionDecoder_dropout':dropout,
            'TSTModelAttention_max_len': max_len, 
            'TSTModelAttention_vocab_size':vocab_size, 
            'TSTModelAttention_embed_size':embed_size, 
            'TSTModelAttention_hidden_size_z':hidden_size_z, 
            'TSTModelAttention_hidden_size_y':hidden_size_y})
else:
  model_init_args.update({
            'Decoder_':embed_size, 
            'Decoder_':hidden_size, 
            'Decoder_max_len':vocab_size, 
            'Decoder_dropout':dropout,
            'TSTModel_max_len': max_len, 
            'TSTModel_vocab_size':vocab_size, 
            'TSTModel_embed_size':embed_size, 
            'TSTModel_hidden_size_z':hidden_size_z, 
            'TSTModel_hidden_size_y':hidden_size_y})
           


In [None]:
model_dir_current = os.path.join(model_dir,model_name)
model_dir_current = os.path.join(model_dir_current,
                                 f'{model_name}_{current_time()}')
if not os.path.isdir(model_dir_current):
  os.mkdir(model_dir_current)
model_path = os.path.join(model_dir_current,
                          f'{model_name}_{current_time()}.pt')
torch.save(model.state_dict(), model_path)
model_args_filename = model_path[:-3]+"_params.txt"
import json
with open(model_args_filename,'w') as f:
  json.dump(model_init_args, f)

In [None]:
model_path

'/content/drive/Shareddrives/MIT NLP 8.864/model/model_8/model_8_20211205_213744/model_8_20211205_213744.pt'

In [None]:
line_embed = nn.Embedding(vocab_size, embed_size)
encoder = Encoder(embed_size,hidden_size)
generator = GeneratorTransferredSampled(hidden_size,vocab_size, 
                                        line_embed, gamma = gamma)
classifier = LSTMClassifier()
discriminator0 = LSTMDiscriminator(hidden_size, hidden_size).to(device)
discriminator1 = LSTMDiscriminator(hidden_size, hidden_size).to(device)
if attention:
  attention_mech = BahdanauAttention(hidden_size, key_size=hidden_size)
  decoder = AttentionDecoder(embed_size, hidden_size, attention=attention_mech, max_len=vocab_size, generator = generator,dropout=dropout)
  beamSeasrch = BeamSearch(decoder, 3,3,line_embed,max_len)
  model = TSTModelAttention(max_len, vocab_size, embed_size, hidden_size_z, hidden_size_y, line_embed, encoder, generator, decoder, classifier, beamSeasrch).to(device)
else:
  decoder = Decoder(embed_size, hidden_size, max_len=vocab_size, generator = generator, dropout=dropout)
  model = TSTModel(max_len, vocab_size, embed_size, hidden_size_z, hidden_size_y, line_embed, encoder, generator, decoder, classifier).to(device)


  "num_layers={}".format(dropout, num_layers))


In [None]:
model_path_load = '/content/drive/Shareddrives/MIT NLP 8.864/model/model_8/model_8_20211205_192248/model_8_20211205_192248.pt'
model.load_state_dict(torch.load(model_path_load))
model = model.to(device)
model.eval()

TSTModelAttention(
  (encoder): Encoder(
    (rnn): GRU(100, 700, batch_first=True)
  )
  (generator): GeneratorTransferredSampled(
    (proj): Linear(in_features=700, out_features=12458, bias=True)
    (logsoftmax): LogSoftmax(dim=2)
    (softmax): Softmax(dim=2)
    (src_embed): Embedding(12458, 100)
  )
  (decoder): AttentionDecoder(
    (rnn): GRU(800, 700, batch_first=True, dropout=0.2)
    (generator): GeneratorTransferredSampled(
      (proj): Linear(in_features=700, out_features=12458, bias=True)
      (logsoftmax): LogSoftmax(dim=2)
      (softmax): Softmax(dim=2)
      (src_embed): Embedding(12458, 100)
    )
    (dropout_layer): Dropout(p=0.2, inplace=False)
    (rnn_to_pre): Linear(in_features=1500, out_features=700, bias=False)
    (attention): BahdanauAttention(
      (key_layer): Linear(in_features=700, out_features=700, bias=False)
      (query_layer): Linear(in_features=700, out_features=700, bias=False)
      (energy_layer): Linear(in_features=700, out_features=1, bia

#### Save Test Result

In [None]:
if os.path.exists(os.path.join(model_dir_current,"raw.txt")):
  os.remove(os.path.join(model_dir_current,"raw.txt"))
if os.path.exists(os.path.join(model_dir_current,"orig.txt")):
  os.remove(os.path.join(model_dir_current,"orig.txt"))
if os.path.exists(os.path.join(model_dir_current,"tsf.txt")):
  os.remove(os.path.join(model_dir_current,"tsf.txt"))
for i,(taylor_lines, taylor_len, drake_lines, drake_len) in enumerate(test_loader):
    lines = torch.cat((taylor_lines, drake_lines), 0).to(device)    
    line_lens = torch.cat((taylor_len, drake_len), 0).to(device)
    labels = torch.cat((torch.ones(size=(len(taylor_lines),), dtype=torch.int32),torch.zeros(size=(len(drake_lines),),dtype=torch.int32))).to(device)
    classifier_labels = torch.cat((labels,1-labels, labels))
    
    fake_labels = torch.cat((torch.zeros(size=(len(taylor_lines),), dtype=torch.int32),torch.ones(size=(len(taylor_lines),),dtype=torch.int32))).to(device)

    rec_orig, pred_class, decode_orig, decode_tsf, _, _ = model(lines, line_lens, labels)
    for i in range(lines.shape[0]):
      with open(os.path.join(model_dir_current,"raw.txt"),'a') as f:
        f.write(str(lookup_words(lines[i],vocab))+"\n")
      with open(os.path.join(model_dir_current,"orig.txt"),'a') as f:
        f.write(str(lookup_words(decode_orig[3][i],vocab))+"\n")
      with open(os.path.join(model_dir_current,"tsf.txt"),'a') as f:
        f.write(str(lookup_words(decode_tsf[3][i],vocab))+"\n")
      
    

#### BeamSearch Validation (only for the first validation sample)

In [None]:
for i,(taylor_lines, taylor_len, drake_lines, drake_len) in enumerate(valid_loader):
  lines = torch.cat((taylor_lines, drake_lines), 0).to(device)    
  line_lens = torch.cat((taylor_len, drake_len), 0).to(device)
  
  labels = torch.cat((torch.ones(size=(len(taylor_lines),), dtype=torch.int32),torch.zeros(size=(len(drake_lines),),dtype=torch.int32))).to(device)
  classifier_labels = torch.cat((labels,1-labels, labels))

  lines = lines[0:1,:]
  line_lens = line_lens[0:1]
  labels = labels[0:1]
  decode_orig_beam, decode_tsf_beam = model.forward_beam(lines, line_lens, labels)
  
  print("3 Original Sentences decoded by beam search:")
  print(lookup_words(decode_orig[-1][0], vocab))
  print(lookup_words(decode_orig[-1][1], vocab))
  print(lookup_words(decode_orig[-1][2], vocab))
  print("3 Transferred Sentences decoded by beam search:")
  print(lookup_words(decode_tsf[-1][0], vocab))
  print(lookup_words(decode_tsf[-1][1], vocab))
  print(lookup_words(decode_tsf[-1][2], vocab))
  break

AttributeError: ignored

### Pre-Trained Classifier for Evaluation

In [None]:
# from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence

# # Train classifier used in eval
# class_epochs = 1
# lr_class = 1e-3
# batch_size = 32
# vocab_size = len(vocab)

# train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# valid_loader = data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
# test_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# classifier_for_eval = LSTMClassifier().to(device)
# optimizer = torch.optim.Adam(classifier_for_eval.parameters(), lr=lr_class) 
# class_loss = nn.BCELoss()

# for epoch in range(class_epochs):
#   correct = 0
#   classifier_for_eval.train()
#   for i,(taylor_lines, taylor_len, drake_lines, drake_len) in enumerate(train_loader):
#     lines = torch.cat((taylor_lines, drake_lines), 0).to(device)  
#     classifier_lines = F.one_hot(lines[:,1:], len(vocab)).to(torch.float).to(device)

#     line_lens = torch.cat((taylor_len, drake_len), 0).to(device)
#     labels = torch.cat((torch.ones(size=(len(taylor_lines),), dtype=torch.int32),torch.zeros(size=(len(drake_lines),),dtype=torch.int32))).to(device)

#     pred_class = classifier_for_eval(classifier_lines, line_lens-1)
#     loss_class = class_loss(input=pred_class, target=labels.to(torch.float))

#     optimizer.zero_grad()
#     loss_class.backward()
#     optimizer.step()

#     correct += torch.sum((pred_class >= 0.5) == labels)
#   print("Pre-Training Accuracy: ", correct / float(2*len(train_dataset)))
#   classifier_for_eval.eval()
#   correct = 0
#   for i,(taylor_lines, taylor_len, drake_lines, drake_len) in enumerate(valid_loader):
#     lines = torch.cat((taylor_lines, drake_lines), 0).to(device)  
#     classifier_lines = F.one_hot(lines[:,1:], len(vocab)).to(torch.float).to(device)

#     line_lens = torch.cat((taylor_len, drake_len), 0).to(device)
#     labels = torch.cat((torch.ones(size=(len(taylor_lines),), dtype=torch.int32),torch.zeros(size=(len(drake_lines),),dtype=torch.int32))).to(device)

#     pred_class = classifier_for_eval(classifier_lines, line_lens-1)
#     correct += torch.sum((pred_class >= 0.5) == labels)
#   print("Pre-Valid Accuracy: ", correct / float(2*len(valid_dataset)))


Pre-Training Accuracy:  tensor(0.7757, device='cuda:0')
Pre-Valid Accuracy:  tensor(0.8396, device='cuda:0')


In [None]:
classifier_for_eval = LSTMClassifier().to(device)
classifier_path = '/content/drive/Shareddrives/MIT NLP 8.864/model/classifier/classifier_for_eval_20211205_130209.pt'
classifier_for_eval.load_state_dict(torch.load(classifier_path))
classifier_for_eval = classifier_for_eval.to(device)
classifier_for_eval.eval()

LSTMClassifier(
  (embedding): Linear(in_features=12458, out_features=300, bias=True)
  (lstm): LSTM(300, 128, batch_first=True, bidirectional=True)
  (drop): Dropout(p=0.5, inplace=False)
  (fc): Linear(in_features=256, out_features=1, bias=True)
)

In [None]:
# Evaluate TST using classifier
model.eval()
classifier_for_eval.eval()

pred_orig=[]
pred_tsf=[]
pred_real=[]

y_orig=[]
y_tsf=[]
y_real=[]

for i,(taylor_lines, taylor_len, drake_lines, drake_len) in enumerate(test_loader):
  lines = torch.cat((taylor_lines, drake_lines), 0).to(device)    
  line_lens = torch.cat((taylor_len, drake_len), 0).to(device)
  labels = torch.cat((torch.ones(size=(len(taylor_lines),), dtype=torch.int32),torch.zeros(size=(len(drake_lines),),dtype=torch.int32))).to(device)
  classifier_labels = torch.cat((labels,1-labels, labels))
  
  rec_orig, pred_class, decode_orig, decode_tsf, pred_fake0, pred_fake1 = model(lines, line_lens, labels)

  classifier_lines = torch.cat((torch.exp(decode_orig[2]), torch.exp(decode_tsf[2]), F.one_hot(lines[:,1:], vocab_size).to(torch.float)), 0)

  rec_orig_len = first_eos(rec_orig[3]) + 1
  decode_orig_len = first_eos(decode_orig[3]) + 1
  decode_tsf_len = first_eos(decode_tsf[3]) + 1
  classifier_line_lens = torch.cat((decode_orig_len, decode_tsf_len, line_lens),0)
  
  pred_class = classifier_for_eval(classifier_lines, classifier_line_lens-1)
  pred_class = pred_class.cpu().detach().numpy()
  classifier_labels = classifier_labels.cpu().detach().numpy()

  pred_orig.append(pred_class[:len(lines)])
  pred_tsf.append(pred_class[len(lines):-len(lines)])
  pred_real.append(pred_class[-len(lines):])

  y_orig.append(classifier_labels[:len(lines)])
  y_tsf.append(classifier_labels[len(lines):-len(lines)])
  y_real.append(classifier_labels[-len(lines):])

from sklearn.metrics import roc_auc_score
print("AUC on Real Sentences: ", roc_auc_score(np.concatenate(y_real), np.concatenate(pred_real)))
print("AUC on Reconstructed Sentences: ", roc_auc_score(np.concatenate(y_orig), np.concatenate(pred_orig)))
print("AUC on Transferred Sentences: ", roc_auc_score(np.concatenate(y_tsf), np.concatenate(pred_tsf)))


AUC on Real Sentences:  0.9229819830246914
AUC on Reconstructed Sentences:  0.9188062885802468
AUC on Transferred Sentences:  0.10500956790123459
