In [7]:
import re

import numpy as np

import torch
import torch.nn as nn

from torch.utils.data import DataLoader, Dataset
from torch.optim import *

from torchtext.data import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

from nltk.tokenize.treebank import TreebankWordDetokenizer

from datasets import load_dataset

from tqdm.notebook import tqdm

In [30]:
class QADataset(Dataset):
    def __init__(self, dataset, tokenizer, vocab, maxlen, special_tokens, device):
        super(QADataset, self).__init__()

        self.maxlen = maxlen
        self.special_tokens = special_tokens

        self.tokenizer = get_tokenizer('spacy') if tokenizer is None else tokenizer
        self.vocab = self.get_vocab(dataset) if vocab is None else vocab

        dataset = dataset.map(self.create_sentence_pairs)
        self.questions = self.get_tensor(dataset['questions'])
        self.answers = self.get_tensor(dataset['answers'])

    def __getitem__(self, i):
        return self.questions[i], self.answers[i]

    def __len__(self):
        return len(self.questions)

    def create_sentence_pairs(self, x):
        sentences = re.split(r'[\s]*#Person\d#: ', x['dialogue'])[1:]

        first_sents = []
        second_sents = []
        for sent1, sent2 in zip(sentences[:-1], sentences[1:]):
            sent1 = self.tokenize_sent(sent1)
            sent2 = self.tokenize_sent(sent2)
            
            first_sents.append(sent1)
            second_sents.append(sent2)

        return {'questions': first_sents, 'answers': second_sents}

    def tokenize_sent(self, sent):
        sent = ' '.join([self.special_tokens['bos_token'], sent, self.special_tokens['eos_token']])
        sent = [tok for tok in self.tokenizer(sent)]
        sent = sent[:self.maxlen]
        sent = sent + [self.special_tokens['eos_token']] * (self.maxlen - len(sent))
        return sent
    
    def get_tensor(self, sents):
        tokens = torch.zeros((len(sent), self.maxlen), dtype=torch.long, device=device)
        for i, sent in enumerate(flatten_tokens):
            tokens[i, :] = torch.tensor(self.vocab(sent))
        return tokens
    
    def get_vocab(self, dataset):
        dataset = dataset.map(self.create_sentence_pairs)
                
        flatten_tokens = dataset['questions'] + dataset['answers']
        vocab = build_vocab_from_iterator(flatten_tokens, min_freq=5,
                                          specials=list(self.special_tokens.values()))
        vocab.set_default_index(vocab['|UNK|'])

        return vocab

lm_train = QADataset(train_dataset, None, None, 100, special_tokens, 'cpu')

  0%|          | 0/12460 [00:00<?, ?ex/s]

TypeError: unhashable type: 'list'

In [18]:
class BiLstmModel(nn.Module):
    def __init__(self):
        pass
    
    
class LstmModel(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, n_layers=32, n_hidden=128, dropout_rate=0.2):
        super(LstmModel, self).__init__()

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.dropout_rate = dropout_rate

        self.embed = nn.Embedding(self.vocab_size, embedding_dim=self.embed_dim)
        self.lstm = nn.LSTM(self.embed_dim, self.n_hidden, num_layers=self.n_layers,
                            batch_first=True)
        self.fc = nn.Linear(self.n_hidden, self.vocab_size)
        self.softmax = nn.Softmax(dim=-1)

        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, states=None):
        out = self.dropout(self.embed(x))
        out, states = self.lstm(out, states)
        out = self.fc(out)
        out = self.softmax(out)

        return out, states

    def init_states(self, batch_size, device):
        h = torch.zeros((self.n_layers, batch_size, self.n_hidden), device=device)
        c = torch.zeros((self.n_layers, batch_size, self.n_hidden), device=device)

        return (h, c)


In [19]:
def train_step(model, loader, optimizer, criterion, device):
    model.train()

    pbar = tqdm(loader)
    batch_losses = []

    for i, batch in enumerate(pbar):
        batch = batch.to(device)

        states = model.init_states(batch.size(0), device=device)

        optimizer.zero_grad()
        states = [state.detach() for state in states]
        y_pred, states = model(batch[:, :-1], states)
        loss = criterion(y_pred.moveaxis(1, -1), batch[:, 1:])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        batch_losses.append(loss.item())
        pbar.set_description(f'Batch Loss: {loss.item():.3f} Train Loss: {np.mean(batch_losses):.3f}')

    return np.mean(batch_losses)


def eval_step(model, loader, criterion, device):
    model.eval()

    pbar = tqdm(loader)
    batch_losses = []

    for i, batch in enumerate(pbar):
        batch = batch.to(device)

        y_pred, _ = model(batch[:, :-1])
        loss = criterion(y_pred.moveaxis(1, -1), batch[:, 1:])

        batch_losses.append(loss.item())
        pbar.set_description(f'Batch Loss: {loss.item():.3f} Validation Loss: {np.mean(batch_losses):.3f}')

    return np.mean(batch_losses)


def answer(model, sent, tokenizer, vocab, maxlen, special_tokens, device):
    model.eval()

    detokenizer = TreebankWordDetokenizer()

    sent = ' '.join([special_tokens['bos_token'], sent, special_tokens['pad_token']])
    sent = vocab(tokenizer(sent))
    sent = torch.tensor(sent, device=device)

    with torch.no_grad():
        y_pred, states = model(sent)

        pred_tokens = y_pred.argmax(dim=-1, keepdim=True)
        # sent_preds = vocab.lookup_tokens(list(pred_tokens))

        answer = []
        for j in range(maxlen - len(sent)):
            last_idx = pred_tokens[-1]
            answer.append(vocab.lookup_token(last_idx))

            if answer[-1] == special_tokens['eos_token']:
                break

            y_pred, states = model(last_idx, states)
            pred_tokens = y_pred.argmax(dim=-1, keepdim=True)

        return detokenizer.detokenize(answer)

In [6]:
def train(max_len, epochs, bs, lr, embed_dim, n_layers, n_hidden, device):
    special_tokens = {'bos_token': '|BOS|',
                      'pad_token': '|PAD|',
                      'eos_token': '|EOS|',
                      'unk_token': '|UNK|'}

    train_dataset = load_dataset('knkarthick/dialogsum', split='train')
    val_dataset = load_dataset('knkarthick/dialogsum', split='validation')

    lm_train = LMDataset(train_dataset, None, None, max_len, special_tokens, device)
    lm_valid = LMDataset(val_dataset, lm_train.tokenizer, lm_train.vocab, max_len, special_tokens, device)

    train_loader = DataLoader(lm_train, batch_size=bs, shuffle=True)
    val_loader = DataLoader(lm_valid, batch_size=bs)

    model = LstmModel(len(lm_train.vocab), embed_dim=embed_dim, n_layers=n_layers, n_hidden=n_hidden, dropout_rate=0.1)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=lr)

    for e in range(epochs):
        train_loss = train_step(model, train_loader, optimizer, criterion, device)
        val_loss = eval_step(model, val_loader, criterion, device)
    print(answer(model, 'Hi, how are you?', lm_train.tokenizer, lm_train.vocab, max_len, special_tokens, device))

    return val_loss

In [28]:
special_tokens = {'bos_token': '|BOS|',
                  'pad_token': '|PAD|',
                  'eos_token': '|EOS|',
                  'unk_token': '|UNK|'}

train_dataset = load_dataset('knkarthick/dialogsum', split='train')
lm_train = QADataset(train_dataset, None, None, 100, special_tokens, 'cpu')


Using custom data configuration knkarthick--dialogsum-caf2f3e75d9073aa
Found cached dataset csv (/Users/bugrahamzagundog/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-caf2f3e75d9073aa/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


  0%|          | 0/12460 [00:00<?, ?ex/s]

  0%|          | 0/12460 [00:00<?, ?ex/s]

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



TypeError: unhashable type: 'list'

In [23]:
lm_train

tensor([   0,  285,    5,  207,  565,    4,    6,   34, 1899,    3,    4,  139,
          23,    7,   78,  173,    9,    1,    6,  511,   13,   69,   25,   11,
          68,  218,   10,   53,   11,  250,   56,   85,    4,    2,    2,    2,
           2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
           2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
           2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
           2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
           2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,    2,
           2,    2,    2,    2])

In [5]:
DEVICE = 'cuda'
MAXLEN = 50
EPOCHS = 10
BS = 64
LR = 0.01

N_EMBED = 256
N_LAYERS = 2
N_HIDDEN = 128
DROPOUT_RATE = 0.2

special_tokens = {'bos_token': '|BOS|', 
                  'eos_token': '|EOS|',
                  'pad_token': '|PAD|',
                  'unk_token': '|UNK|',
                  'mask_token': '|MASK|'}

def create_tokenized_pairs(x):
    sentences = re.split(r'[\s]*#Person\d#: ', x['dialogue'])[1:]
    
    sentences1, sentences2 = [], []
    for sent1, sent2 in zip(sentences[:-1], sentences[1:]):
        sent1 = ' '.join([special_tokens['bos_token'], sent1, special_tokens['eos_token']])
        sent2 = ' '.join([special_tokens['bos_token'], sent2, special_tokens['eos_token']])
        
        sent1 = tokenizer(sent1)[:MAXLEN]
        sent2 = tokenizer(sent2)[:MAXLEN]
        
        sent1 += [special_tokens['pad_token']] * (MAXLEN - len(sent1))
        sent2 += [special_tokens['pad_token']] * (MAXLEN - len(sent2))
        
        sentences1.append(sent1)
        sentences2.append(sent2)
        
    return {'sent1': sentences1, 'sent2': sentences2}

tokenizer = get_tokenizer('spacy')

train_dataset = load_dataset('knkarthick/dialogsum', split='train')
val_dataset = load_dataset('knkarthick/dialogsum', split='validation')

Using custom data configuration knkarthick--dialogsum-b0174fca0a26ed84
Found cached dataset csv (/home/sefa/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-b0174fca0a26ed84/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)
Using custom data configuration knkarthick--dialogsum-b0174fca0a26ed84
Found cached dataset csv (/home/sefa/.cache/huggingface/datasets/knkarthick___csv/knkarthick--dialogsum-b0174fca0a26ed84/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


In [6]:
## vocab from iterator
tokenized_dataset = train_dataset.map(create_tokenized_pairs)

flatten_sent1 = [sent for sents in tokenized_dataset['sent1'] for sent in sents]
flatten_sent2 = [sents[-1] for sents in tokenized_dataset['sent2']]
flatten_sents = flatten_sent1 + flatten_sent2

vocab = build_vocab_from_iterator(flatten_sents, min_freq=5, specials=list(special_tokens.values()))
vocab.set_default_index(vocab[special_tokens['unk_token']])
vocab.set_default_index(vocab[special_tokens['mask_token']])
del flatten_sent1
del flatten_sent2
del flatten_sents

  0%|          | 0/12460 [00:00<?, ?ex/s]

In [7]:
class LMDataset(Dataset):
    def __init__(self, dataset):
        super(LMDataset, self).__init__()

        dataset = dataset.map(create_tokenized_pairs)
        self.flatten_sent1 = [vocab(sent) for sents in dataset['sent1'] for sent in sents]
        self.flatten_sent1 = torch.tensor(self.flatten_sent1)
        
    def __getitem__(self, i):
        return self.flatten_sent1[i]
                
    def __len__(self):
        return len(self.flatten_sent1)
    

# Masked Language Modeling with BiLSTM

In [8]:
lm_train = LMDataset(train_dataset)
lm_val = LMDataset(val_dataset)

train_loader = DataLoader(lm_train, batch_size=BS, shuffle=True)
val_loader = DataLoader(lm_val, batch_size=BS, shuffle=True)

  0%|          | 0/12460 [00:00<?, ?ex/s]

  0%|          | 0/500 [00:00<?, ?ex/s]

In [9]:
class BiLSTMEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, n_layers=32, n_hidden=128):
        super(BiLSTMEncoder, self).__init__()
        
        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=vocab[special_tokens['pad_token']])
        self.bilstm = nn.LSTM(embed_dim, n_hidden, num_layers=n_layers, 
                            dropout=DROPOUT_RATE, bidirectional=True, batch_first=True)
        
        self.dropout = nn.Dropout(DROPOUT_RATE)
        
        
    def forward(self, x):
        x = self.dropout(self.embed(x))
        x, (hx, cx) = self.bilstm(x)
        h_fold = hx[:2, :, :] + torch.flip(hx[2:, :, :], dims=[-1])
        c_fold = cx[:2, :, :] + torch.flip(cx[2:, :, :], dims=[-1])
        return x, (h_fold, c_fold)

    
class MLM(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, n_layers=32, n_hidden=128):
        super(MLM, self).__init__()
        
        self.encoder = BiLSTMEncoder(vocab_size, embed_dim=embed_dim, 
                                     n_layers=n_layers, n_hidden=n_hidden)
        self.fc = nn.Linear(2*n_hidden, vocab_size)
        
    def forward(self, x):
        x, (hx, cx) = self.encoder(x)
        x = x.mean(dim=1)
        x = self.fc(x)
        
        return x


In [11]:
def random_mask(ids):
    pad_id = vocab.lookup_indices([special_tokens['pad_token']])[0]
    mask_id = vocab.lookup_indices([special_tokens['mask_token']])[0]
    
    rand_ids = torch.zeros_like(ids, dtype=bool)
    new_ids = torch.empty_like(ids).copy_(ids)
    
    for i, sent in enumerate(ids):
        max_index = sent.tolist().index(pad_id)+1 if pad_id in sent else MAXLEN
        rand_id = torch.randint(1, max_index, (1,)).item()
        rand_ids[i, rand_id] = 1
        new_ids[i, rand_id] = mask_id

    return new_ids, rand_ids


In [12]:
input_sent = ' '.join(['Hi, how', special_tokens['mask_token'],  'you?'])
input_sent = ' '.join([special_tokens['bos_token'], input_sent, special_tokens['eos_token']])
input_sent = vocab(tokenizer(input_sent))
input_sent = torch.tensor([input_sent], device=DEVICE)

lm_head = MLM(len(vocab), embed_dim=N_EMBED, n_layers=N_LAYERS, n_hidden=N_HIDDEN).to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = Adam(lm_head.parameters(), lr=LR)

for epoch in range(EPOCHS):
    lm_head.train()
    pbar = tqdm(train_loader)
    train_losses = []
    for batch_idx, x in enumerate(pbar):
        x = x.to(DEVICE)
        y = x.to(DEVICE)
        
        x, ids = random_mask(x)
        
        optimizer.zero_grad()        
        y_pred = lm_head(x)
        loss = criterion(y_pred, y[ids])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(lm_head.parameters(), 0.9)
        optimizer.step()
        
        train_losses.append(loss.item())
        pbar.set_description(f'Loss: {np.mean(train_losses):.3f}')
    
        #wandb.log({'epochs': epoch,
        #           'learning_rate': LR,
        #           'loss': loss.item()})
    
    lm_head.eval()
    pbar = tqdm(val_loader)
    val_losses = []
    for batch_idx, x in enumerate(pbar):
        x = x.to(DEVICE)
        y = x.to(DEVICE)
        
        x, ids = random_mask(x)
        
        y_pred = lm_head(x)
        loss = criterion(y_pred, y[ids])
        
        val_losses.append(loss.item())
        pbar.set_description(f'Loss: {np.mean(val_losses):.3f}')
    
    # on epoch end
    out = lm_head(input_sent)
    tokens = vocab.lookup_tokens(out.argmax(dim=-1).tolist())
    print(tokens)


  0%|          | 0/1654 [00:00<?, ?it/s]

  0%|          | 0/66 [00:00<?, ?it/s]

['are']


  0%|          | 0/1654 [00:00<?, ?it/s]

  0%|          | 0/66 [00:00<?, ?it/s]

['are']


  0%|          | 0/1654 [00:00<?, ?it/s]

  0%|          | 0/66 [00:00<?, ?it/s]

['are']


  0%|          | 0/1654 [00:00<?, ?it/s]

  0%|          | 0/66 [00:00<?, ?it/s]

['about']


  0%|          | 0/1654 [00:00<?, ?it/s]

  0%|          | 0/66 [00:00<?, ?it/s]

['about']


  0%|          | 0/1654 [00:00<?, ?it/s]

  0%|          | 0/66 [00:00<?, ?it/s]

['are']


  0%|          | 0/1654 [00:00<?, ?it/s]

  0%|          | 0/66 [00:00<?, ?it/s]

['about']


  0%|          | 0/1654 [00:00<?, ?it/s]

  0%|          | 0/66 [00:00<?, ?it/s]

['about']


  0%|          | 0/1654 [00:00<?, ?it/s]

  0%|          | 0/66 [00:00<?, ?it/s]

['about']


  0%|          | 0/1654 [00:00<?, ?it/s]

  0%|          | 0/66 [00:00<?, ?it/s]

['about']


In [13]:
encoder = lm_head.encoder
torch.save(encoder, 'bilstm_encoder.pt')
# lm_head = torch.load('bilstm_encoder.pt')

# Dialogue Generation Head

In [14]:
class DialogueDataset(Dataset):
    def __init__(self, dataset):
        super(DialogueDataset, self).__init__()
        
        dataset = dataset.map(create_tokenized_pairs)
        self.flatten_sent1 = [vocab(sent) for sents in dataset['sent1'] for sent in sents]
        self.flatten_sent2 = [vocab(sent) for sents in dataset['sent2'] for sent in sents]
        
        self.flatten_sent1 = torch.tensor(self.flatten_sent1)
        self.flatten_sent2 = torch.tensor(self.flatten_sent2)
    
    def __getitem__(self, i):
        return self.flatten_sent1[i], self.flatten_sent2[i]
                
    def __len__(self):
        return len(self.flatten_sent1)


In [15]:
dialog_train = DialogueDataset(train_dataset)
dialog_valid = DialogueDataset(val_dataset)

train_loader = DataLoader(dialog_train, batch_size=BS, shuffle=True)
val_loader = DataLoader(dialog_valid, batch_size=BS, shuffle=True)

  0%|          | 0/12460 [00:00<?, ?ex/s]

  0%|          | 0/500 [00:00<?, ?ex/s]

In [17]:
class LSTMDecoder(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, n_layers=32, n_hidden=128):
        super(LSTMDecoder, self).__init__()

        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=vocab[special_tokens['pad_token']])
        self.lstm = nn.LSTM(embed_dim, n_hidden, num_layers=n_layers, 
                            dropout=DROPOUT_RATE, batch_first=True)
        self.fc = nn.Linear(n_hidden, vocab_size)
        self.dropout = nn.Dropout(DROPOUT_RATE)
        
    def forward(self, x, states):
        x = self.dropout(self.embed(x))
        out, states = self.lstm(x, states)
        out = self.fc(out)
        return out, states


In [20]:
## training loop
#wandb.init('lstm_tuner', project='chat2learn', config={'batch_size': BS, 
#                                                       'learning_rate': LR, 
#                                                       'epochs': EPOCHS})
input_sent = 'Hi, how are you?'
input_sent = ' '.join([special_tokens['bos_token'], input_sent, MAXLEN * special_tokens['eos_token']])
input_sent = vocab(tokenizer(input_sent)[:MAXLEN])
input_sent = torch.tensor([input_sent], device=DEVICE)

encoder = torch.load('bilstm_encoder.pt')
dialog_head = LSTMDecoder(len(vocab), embed_dim=N_EMBED, n_layers=N_LAYERS, 
                          n_hidden=N_HIDDEN).to(DEVICE)

criterion = nn.CrossEntropyLoss(ignore_index=vocab[special_tokens['pad_token']])
optimizer = Adam(dialog_head.parameters(), lr=LR)

#encoder.eval()

min_val_loss = None

for epoch in range(EPOCHS):
    pbar = tqdm(train_loader)
    train_losses = []
    for batch_idx, (x, y) in enumerate(pbar):
        encoder.eval(), dialog_head.train()
        x = x.to(DEVICE)
        y = y.to(DEVICE)

        optimizer.zero_grad()        
        _, states = encoder(x)
        states = [state.detach() for state in states]
        y_pred, states = dialog_head(y, states)
        loss = criterion(y_pred.moveaxis(1, -1), y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(dialog_head.parameters(), 0.9)
        optimizer.step()
        
        train_losses.append(loss.item())
        pbar.set_description(f'Loss: {np.mean(train_losses):.3f}')
    
        #wandb.log({'epochs': epoch,
        #           'learning_rate': LR,
        #           'loss': loss.item()})
        
        if batch_idx % 200 == 0:
            encoder.eval(), dialog_head.eval()
            out_word = vocab(tokenizer(special_tokens['bos_token']))
            out_word = torch.tensor([out_word], device=DEVICE)
            
            _, states = encoder(input_sent)

            answer = [special_tokens['bos_token']]
            for i in range(MAXLEN):
                pred, states = dialog_head(out_word, states)
                out_word = pred.argmax(dim=-1)
                out_token = vocab.lookup_token(out_word)
                
                answer.append(out_token)
                if answer[-1] == special_tokens['eos_token']:
                    break
                    
            print('ANSWER: ', ' '.join(answer))

    val_losses = []
    encoder.eval(), dialog_head.eval()
    pbar = tqdm(val_loader)
    for x, y in pbar:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        
        _, states = encoder(x)
        y_pred, states = dialog_head(y, states=states)
        loss = criterion(y_pred.moveaxis(1, -1), y)
        val_losses.append(loss.item())
        
        pbar.set_description(f'Validation Loss: {np.mean(val_losses):.3f}')
        
        # curr_lr = scheduler.get_last_lr()[-1]
        #wandb.log({'epochs': epoch, 'learning_rate': LR, 'val_loss': loss.item()})
    
    #if min_val_loss is None or np.mean(val_losses) < min_val_loss:
    #    min_val_loss = np.mean(val_losses)
    #    torch.save(model, 'lstm_model.pt')

    # scheduler.step()

  0%|          | 0/1654 [00:00<?, ?it/s]

ANSWER:  |BOS| This 's 's . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
ANSWER:  |BOS| |BOS| |BOS| |BOS| so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so
ANSWER:  |BOS| |BOS| |BOS| |BOS| so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so
ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS|
ANSWER:  |BOS| |BOS| |BOS| |BOS| The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The The 

  0%|          | 0/66 [00:00<?, ?it/s]

  0%|          | 0/1654 [00:00<?, ?it/s]

ANSWER:  |BOS| |BOS| |BOS| |BOS| so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so
ANSWER:  |BOS| |BOS| |BOS| so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so
ANSWER:  |BOS| |BOS| |BOS| |BOS| so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so
ANSWER:  |BOS| |BOS| |BOS| |BOS| so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so
ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS

  0%|          | 0/66 [00:00<?, ?it/s]

  0%|          | 0/1654 [00:00<?, ?it/s]

ANSWER:  |BOS| |BOS| |BOS| so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so so
ANSWER:  |BOS| |BOS| |BOS| at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at
ANSWER:  |BOS| |BOS| |BOS| at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at
ANSWER:  |BOS| |BOS| at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at
ANSWER:  |BOS| |BOS| |BOS| |BOS| at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at at
ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| new new new new new new new new new new new new new new new new new new new new new new new new new new ne

  0%|          | 0/66 [00:00<?, ?it/s]

  0%|          | 0/1654 [00:00<?, ?it/s]

ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS|
ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS|
ANSWER:  |BOS| |BOS| |BOS| one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one one
ANSWER:  |BOS| |BOS| |BOS| anything anything anything anything anything anything anything anything anything anything anything anything anything anythin

  0%|          | 0/66 [00:00<?, ?it/s]

  0%|          | 0/1654 [00:00<?, ?it/s]

ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities
ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities universities u

  0%|          | 0/66 [00:00<?, ?it/s]

  0%|          | 0/1654 [00:00<?, ?it/s]

ANSWER:  |BOS| |BOS| |BOS| |BOS| Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others Others
ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked booked
ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS

  0%|          | 0/66 [00:00<?, ?it/s]

  0%|          | 0/1654 [00:00<?, ?it/s]

ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells
ANSWER:  |BOS| |BOS| Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks
ANSWER:  |BOS| |BOS| |BOS| bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells bells
ANSWER

  0%|          | 0/66 [00:00<?, ?it/s]

  0%|          | 0/1654 [00:00<?, ?it/s]

ANSWER:  |BOS| |BOS| |BOS| Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks
ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS|
ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Thanks Tha

  0%|          | 0/66 [00:00<?, ?it/s]

  0%|          | 0/1654 [00:00<?, ?it/s]

ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS|
ANSWER:  |BOS| |BOS| |BOS| |BOS| Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian
ANSWER:  |BOS| |BOS| |BOS| |BOS| Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian Brian
ANSWER:  |BOS| |BOS| |BOS| , , , , , , , , , , , , , , 

  0%|          | 0/66 [00:00<?, ?it/s]

  0%|          | 0/1654 [00:00<?, ?it/s]

ANSWER:  |BOS| |BOS| |BOS| Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden Sweden
ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more more
ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS| |BOS|
ANSWER:  |BOS| |BOS| |BOS| |BOS| |BOS| patients patie

  0%|          | 0/66 [00:00<?, ?it/s]

In [None]:
#torch.save(dialog_head, 'lstm_decoder.pt')

In [None]:
' '.join(vocab.lookup_tokens(y[0].tolist()))

In [None]:
' '.join(vocab.lookup_tokens(y_pred.argmax(dim=-1)[0].tolist()))

In [34]:
np.mean([1.306264, 1.341686, 1.316051])

1.3213336666666666