# Retrieve Poetry
## Poetry Retriever using the Poly-encoder Transformer architecture (Humeau et al., 2019) for retrieval

In [1]:
# This notebook is based on :
# https://aritter.github.io/CS-7650/
# This Project was developed at the Georgia Institute of Technology by Ashutosh Baheti (ashutosh.baheti@cc.gatech.edu), 
# borrowing  from the Neural Machine Translation Project (Project 2) 
# of the UC Berkeley NLP course https://cal-cs288.github.io/sp20/

In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import torch
from torch.jit import script, trace
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import numpy as np
import csv
import random
import re
import os
import unicodedata
import codecs
from io import open
import itertools
import math
import pickle
import statistics
import sys
from functools import partial

from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import tqdm
import nltk
#from google.colab import files

In [3]:
# General util functions
def make_dir_if_not_exists(directory):
	if not os.path.exists(directory):
		logging.info("Creating new directory: {}".format(directory))
		os.makedirs(directory)

def print_list(l, K=None):
	# If K is given then only print first K
	for i, e in enumerate(l):
		if i == K:
			break
		print(e)
	print()

def remove_multiple_spaces(string):
	return re.sub(r'\s+', ' ', string).strip()

def save_in_pickle(save_object, save_file):
	with open(save_file, "wb") as pickle_out:
		pickle.dump(save_object, pickle_out)

def load_from_pickle(pickle_file):
	with open(pickle_file, "rb") as pickle_in:
		return pickle.load(pickle_in)

def save_in_txt(list_of_strings, save_file):
	with open(save_file, "w") as writer:
		for line in list_of_strings:
			line = line.strip()
			writer.write(f"{line}\n")

def load_from_txt(txt_file):
	with open(txt_file, "r") as reader:
		all_lines = list()
		for line in reader:
			line = line.strip()
			all_lines.append(line)
		return all_lines

In [4]:
import pandas as pd

print(torch.cuda.is_available())
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

True
Using device: cuda


In [17]:
bert_model_name = 'distilbert-base-uncased' 
# Bert Imports
from transformers import DistilBertTokenizer, DistilBertModel
#bert_model = DistilBertModel.from_pretrained(bert_model_name)
tokenizer = DistilBertTokenizer.from_pretrained(bert_model_name)

## Load Data

### Poetry Database

In [5]:
data_file = '../data/with_epoque.csv'
data = pd.read_csv(data_file)
print(len(data))
print(data.head())

573
                                    author  \
0                      WILLIAM SHAKESPEARE   
1  DUCHESS OF NEWCASTLE MARGARET CAVENDISH   
2                           THOMAS BASTARD   
3                           EDMUND SPENSER   
4                        RICHARD BARNFIELD   

                                             content  \
0  Let the bird of loudest lay\nOn the sole Arabi...   
1  Sir Charles into my chamber coming in,\nWhen I...   
2  Our vice runs beyond all that old men saw,\nAn...   
3  Lo I the man, whose Muse whilome did maske,\nA...   
4  Long have I longd to see my love againe,\nStil...   

                                 poem name          age                  type  
0               The Phoenix and the Turtle  Renaissance  Mythology & Folklore  
1                 An Epilogue to the Above  Renaissance  Mythology & Folklore  
2                       Book 7, Epigram 42  Renaissance  Mythology & Folklore  
3  from The Faerie Queene: Book I, Canto I  Renaissance  Mytho

## Dataset Preparation

In [6]:
def make_data_training(df, char_max_line = 20):
    inputs = []
    context = []
    targets = []
    for i,rows in df.iterrows():
        splitted = rows['content'].split('\r\n')
        for line in splitted:
            if len(line.strip()) > 0 and len(line.split(' ')) <= char_max_line:
                inputs.append(line)
                targets.append(line)
                context.append(' '.join([str(rows['poem name'])]))
        
    return pd.DataFrame(list(zip(inputs, context, targets)),columns =['text', 'context','target'])


#Defining torch dataset class for poems
class PoemDataset(Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        return self.df.iloc[idx]

In [7]:
df = make_data_training(data, char_max_line = 30)

In [8]:
pad_word = "<pad>"
bos_word = "<bos>"
eos_word = "<eos>"
unk_word = "<unk>"
sep_word = "sep"

pad_id = 0
bos_id = 1
eos_id = 2
unk_id = 3
sep_id = 4
    
def normalize_sentence(s):
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    s = re.sub(r"\s+", r" ", s).strip()
    return s

class Vocabulary:
    def __init__(self):
        self.word_to_id = {pad_word: pad_id, bos_word: bos_id, eos_word:eos_id, unk_word: unk_id, sep_word: sep_id}
        self.word_count = {}
        self.id_to_word = {pad_id: pad_word, bos_id: bos_word, eos_id: eos_word, unk_id: unk_word, sep_id: sep_word}
        self.num_words = 5
    
    def get_ids_from_sentence(self, sentence):
        sentence = normalize_sentence(sentence)
        sent_ids = [bos_id] + [self.word_to_id[word.lower()] if word.lower() in self.word_to_id \
                               else unk_id for word in sentence.split()] + \
                               [eos_id]
        return sent_ids
    
    def tokenized_sentence(self, sentence):
        sent_ids = self.get_ids_from_sentence(sentence)
        return [self.id_to_word[word_id] for word_id in sent_ids]

    def decode_sentence_from_ids(self, sent_ids):
        words = list()
        for i, word_id in enumerate(sent_ids):
            if word_id in [bos_id, eos_id, pad_id]:
                # Skip these words
                continue
            else:
                words.append(self.id_to_word[word_id])
        return ' '.join(words)

    def add_words_from_sentence(self, sentence):
        sentence = normalize_sentence(sentence)
        for word in sentence.split():
            if word not in self.word_to_id:
                # add this word to the vocabulary
                self.word_to_id[word] = self.num_words
                self.id_to_word[self.num_words] = word
                self.word_count[word] = 1
                self.num_words += 1
            else:
                # update the word count
                self.word_count[word] += 1

vocab = Vocabulary()
for src in df['text']:
    vocab.add_words_from_sentence(src.lower())

print(f"Total words in the vocabulary = {vocab.num_words}")

Total words in the vocabulary = 319


In [9]:
class Poem_dataset(Dataset):
    """Single-Turn version of Cornell Movie Dialog Cropus dataset."""

    def __init__(self, poems, context,vocab, device):
        """
        Args:
            conversations: list of tuple (src_string, tgt_string) 
                         - src_string: String of the source sentence
                         - tgt_string: String of the target sentence
            vocab: Vocabulary object that contains the mapping of 
                    words to indices
            device: cpu or cuda
        """
        l = []
        
        for i in range(len(poems)):
            l.append( ( context[i] + ' sep ' + poems[i] , poems[i] ))
        
        self.conversations = l.copy()
        self.vocab = vocab
        self.device = device

        def encode(src, tgt):
            src_ids = self.vocab.get_ids_from_sentence(src)
            tgt_ids = self.vocab.get_ids_from_sentence(tgt)
            return (src_ids, tgt_ids)

        # We will pre-tokenize the conversations and save in id lists for later use
        self.tokenized_conversations = [encode(src, tgt) for src, tgt in self.conversations]
        
    def __len__(self):
        return len(self.conversations)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return {"conv_ids":self.tokenized_conversations[idx], "conv":self.conversations[idx]}

def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (src_seq, tgt_seq).
    We should build a custom collate_fn rather than using default collate_fn,
    because merging sequences (including padding) is not supported in default.
    Seqeuences are padded to the maximum length of mini-batch sequences (dynamic padding).
    Args:
        data: list of dicts {"conv_ids":(src_ids, tgt_ids), "conv":(src_str, trg_str)}.
            - src_ids: list of src piece ids; variable length.
            - tgt_ids: list of tgt piece ids; variable length.
            - src_str: String of src
            - tgt_str: String of tgt
    Returns: dict { "conv_ids":     (src_ids, tgt_ids), 
                    "conv":         (src_str, tgt_str), 
                    "conv_tensors": (src_seqs, tgt_seqs)}
            src_seqs: torch tensor of shape (src_padded_length, batch_size).
            tgt_seqs: torch tensor of shape (tgt_padded_length, batch_size).
            src_padded_length = length of the longest src sequence from src_ids
            tgt_padded_length = length of the longest tgt sequence from tgt_ids
    """
    # Sort conv_ids based on decreasing order of the src_lengths.
    # This is required for efficient GPU computations.
    src_ids = [torch.LongTensor(e["conv_ids"][0]) for e in data]
    tgt_ids = [torch.LongTensor(e["conv_ids"][1]) for e in data]
    src_str = [e["conv"][0] for e in data]
    tgt_str = [e["conv"][1] for e in data]
    data = list(zip(src_ids, tgt_ids, src_str, tgt_str))
    data.sort(key=lambda x: len(x[0]), reverse=True)
    src_ids, tgt_ids, src_str, tgt_str = zip(*data)


    # Pad the src_ids and tgt_ids using token pad_id to create src_seqs and tgt_seqs
    
    # Implementation tip: You can use the nn.utils.rnn.pad_sequence utility
    # function to combine a list of variable-length sequences with padding.
    
    # YOUR CODE HERE
    src_seqs = nn.utils.rnn.pad_sequence(src_ids, padding_value = pad_id,
                                         batch_first = False)
    tgt_seqs = nn.utils.rnn.pad_sequence(tgt_ids, padding_value = pad_id, 
                                         batch_first = False)
    
    src_padded_length = len(src_seqs[0])
    tgt_padded_length = len(tgt_seqs[0])
    return {"conv_ids":(src_ids, tgt_ids), "conv":(src_str, tgt_str), "conv_tensors":(src_seqs.to(device), tgt_seqs.to(device))}

In [10]:
# Create the DataLoader for all_conversations

all_poems = df['text'].tolist()
context = df['context'].tolist()

dataset = Poem_dataset(all_poems, context, vocab, device)

batch_size = 5

data_loader = DataLoader(dataset=dataset, batch_size=batch_size, 
                               shuffle=True, collate_fn=collate_fn)

In [11]:

for src, tgt in dataset.conversations[:3]:
    sentence = src
    word_tokens = vocab.tokenized_sentence(sentence)
    # Automatically adds bos_id and eos_id before and after sentence ids respectively
    word_ids = vocab.get_ids_from_sentence(sentence)
    print(sentence)
    print(word_tokens)
    print(word_ids)
    print(vocab.decode_sentence_from_ids(word_ids))
    print()

word = "the"
word_id = vocab.word_to_id[word.lower()]
print(f"Word = {word}")
print(f"Word ID = {word_id}")
print(f"Word decoded from ID = {vocab.decode_sentence_from_ids([word_id])}")

Written in her French Psalter sep No crooked leg, no bleared eye,
No part deformed out of kind,
Nor yet so ugly half can be
As is the inward suspicious mind.
['<bos>', '<unk>', 'in', 'her', '<unk>', '<unk>', 'sep', 'no', 'crooked', 'leg', 'no', 'bleared', 'eye', 'no', 'part', 'deformed', 'out', 'of', 'kind', 'nor', 'yet', 'so', 'ugly', 'half', 'can', 'be', 'as', 'is', 'the', 'inward', 'suspicious', 'mind', '.', '<eos>']
[1, 3, 89, 232, 3, 3, 4, 5, 6, 7, 5, 8, 9, 5, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 2]
<unk> in her <unk> <unk> sep no crooked leg no bleared eye no part deformed out of kind nor yet so ugly half can be as is the inward suspicious mind .

Song of the Witches: Double, double toil and trouble sep Notes:
Macbeth: IV.i 10-19; 35-38
['<bos>', 'song', 'of', 'the', '<unk>', '<unk>', '<unk>', '<unk>', 'and', '<unk>', 'sep', 'notes', 'macbeth', 'iv', '.i', '<eos>']
[1, 281, 13, 24, 3, 3, 3, 3, 45, 3, 4, 29, 30, 31, 32, 2]
song of the <unk> <

In [12]:
# Test one batch of training data
first_batch = next(iter(data_loader))
print(f"Testing first training batch of size {len(first_batch['conv'][0])}")
print(f"List of source strings:")
print_list(first_batch["conv"][0])
print(f"Tokenized source ids:")
print_list(first_batch["conv_ids"][0])
print(f"Padded source ids as tensor (shape {first_batch['conv_tensors'][0].size()}):")
print(first_batch["conv_tensors"][0])

Testing first training batch of size 5
List of source strings:
The Poem that Took the Place of a Mountain sep Wallace Stevens, "The Poem that Took the Place of a Mountain" from The Collected Poems. Copyright  1954 by Wallace Stevens.  Reprinted by permission of Random House, Inc.
Ars Poetica sep Archibald MacLeish, Ars Poetica from Collected Poems 1917-1982. Copyright  1985 by The Estate of Archibald MacLeish. Reprinted with the permission of Houghton Mifflin Company. All rights reserved.
Written in her French Psalter sep No crooked leg, no bleared eye,
No part deformed out of kind,
Nor yet so ugly half can be
As is the inward suspicious mind.
The Eemis Stane sep Hugh MacDiarmid, The Eemis Stane from Selected Poetry. Copyright  1992 by Alan Riach and Michael Grieve. Reprinted with the permission of New Directions Publishing Corporation.
To a Dead Lover sep Originally published in Poetry, August 1922.

Tokenized source ids:
tensor([  1,  24, 109, 110, 111,  24, 112,  13,  90, 113,   4, 

In [15]:
def transformer_collate_fn(batch, tokenizer):
    bert_vocab = tokenizer.get_vocab()
    bert_pad_token = bert_vocab['[PAD]']
    bert_unk_token = bert_vocab['[UNK]']
    bert_cls_token = bert_vocab['[CLS]']
    inputs, masks_input, outputs, masks_output = [], [], [], []

    sentences, masks_sentences, targets, masks_targets = [], [], [], []
    for data in batch:

        tokenizer_output = tokenizer([data['text']])
        tokenized_sent = tokenizer_output['input_ids'][0]
        
        tokenizer_target = tokenizer([data['target']])
        tokenized_sent_target = tokenizer_target['input_ids'][0]
        
        mask_sentence = tokenizer_output['attention_mask'][0]
        mask_target = tokenizer_target['attention_mask'][0]
        sentences.append(torch.tensor(tokenized_sent))
        targets.append(torch.tensor(tokenized_sent_target))
        masks_targets.append(torch.tensor(mask_targets))
        masks_sentences.append(torch.tensor(mask_sentences))
    sentences = pad_sequence(sentences, batch_first=True, padding_value=bert_pad_token)
    targets = pad_sequence(targets, batch_first=True, padding_value=bert_pad_token)
    masks = pad_sequence(masks, batch_first=True, padding_value=0.0)
    return sentences, targets, masks

In [18]:
#create pytorch dataloaders from train_dataset, val_dataset, and test_datset
batch_size=5
train_dataloader = DataLoader(dataset,batch_size=batch_size,collate_fn=partial(transformer_collate_fn, tokenizer=tokenizer), shuffle = True)

In [19]:
#tokenizer.batch_decode(transformer_collate_fn(train_dataset,tokenizer)[0], skip_special_tokens=True)

## Polyencoder Model

In [20]:
#torch.cuda.empty_cache()
#bert1 = DistilBertModel.from_pretrained(bert_model_name)
#bert2 = DistilBertModel.from_pretrained(bert_model_name)

bert = DistilBertModel.from_pretrained(bert_model_name)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [21]:
#Double Bert
class RetrieverPolyencoder(nn.Module):
    def __init__(self, contextBert, candidateBert, vocab, max_len = 300, hidden_dim = 768, out_dim = 64, num_layers = 2, dropout=0.1, device=device):
        super().__init__()

        self.device = device
        self.hidden_dim = hidden_dim
        self.max_len = max_len
        self.out_dim = out_dim
        
        # Context layers
        self.contextBert = contextBert
        self.contextDropout = nn.Dropout(dropout)
        self.contextFc = nn.Linear(self.hidden_dim, self.out_dim)
        
        # Candidates layers
        self.candidatesBert = candidateBert
        self.pos_emb = nn.Embedding(self.max_len, self.hidden_dim)
        self.candidatesDropout = nn.Dropout(dropout)
        self.candidatesFc = nn.Linear(self.hidden_dim, self.out_dim)
        
        self.att_dropout = nn.Dropout(dropout)


    def attention(self, q, k, v, vMask=None):
        w = torch.matmul(q, k.transpose(-1, -2))
        if vMask is not None:
            w *= vMask.unsqueeze(1)
            w = F.softmax(w, -1)
        w = self.att_dropout(w)
        score = torch.matmul(w, v)
        return score

    def score(self, context, context_mask, responses, responses_mask):
        """Run the model on the source and compute the loss on the target.

        Args:
            source: An integer tensor with shape (max_source_sequence_length,
                batch_size) containing subword indices for the source sentences.
            target: An integer tensor with shape (max_target_sequence_length,
                batch_size) containing subword indices for the target sentences.

        Returns:
            A scalar float tensor representing cross-entropy loss on the current batch
            divided by the number of target tokens in the batch.
            Many of the target tokens will be pad tokens. You should mask the loss 
            from these tokens using appropriate mask on the target tokens loss.
        """
        batch_size, nb_cand, seq_len = responses.shape
        # Context
        context_encoded = self.contextBert(context,context_mask)[-1]
        pos_emb = self.pos_emb(torch.arange(self.max_len).to(self.device))
        context_att = self.attention(pos_emb, context_encoded, context_encoded, context_mask)

        # Response
        responses_encoded = self.candidatesBert(responses.view(-1,responses.shape[2]), responses_mask.view(-1,responses.shape[2]))[-1][:,0,:]
        responses_encoded = responses_encoded.view(batch_size,nb_cand,-1)
        
        context_emb = self.attention(responses_encoded, context_att, context_att).squeeze() 
        dot_product = (context_emb*responses_encoded).sum(-1)
        
        return dot_product

    
    def compute_loss(self, context, context_mask, response, response_mask):
        """Run the model on the source and compute the loss on the target.

        Args:
            source: An integer tensor with shape (max_source_sequence_length,
                batch_size) containing subword indices for the source sentences.
            target: An integer tensor with shape (max_target_sequence_length,
                batch_size) containing subword indices for the target sentences.

        Returns:
            A scalar float tensor representing cross-entropy loss on the current batch
            divided by the number of target tokens in the batch.
            Many of the target tokens will be pad tokens. You should mask the loss 
            from these tokens using appropriate mask on the target tokens loss.
        """
        batch_size = context.shape[0]
        
        # Context
        context_encoded = self.contextBert(context,context_mask)[-1]
        pos_emb = self.pos_emb(torch.arange(self.max_len).to(self.device))
        context_att = self.attention(pos_emb, context_encoded, context_encoded, context_mask)

        # Response
        response_encoded = self.candidatesBert(response, response_mask)[-1][:,0,:]
        
        response_encoded = response_encoded.unsqueeze(0).expand(batch_size, batch_size, response_encoded.shape[1]) 
        context_emb = self.attention(response_encoded, context_att, context_att).squeeze() 
        dot_product = (context_emb*response_encoded).sum(-1)
        mask = torch.eye(batch_size).to(self.device)
        loss = F.log_softmax(dot_product, dim=-1) * mask
        loss = (-loss.sum(dim=1)).mean()
        return loss

In [22]:
#Single Bert
class RetrieverPolyencoder_single(nn.Module):
    def __init__(self, bert, max_len = 300, hidden_dim = 768, out_dim = 64, num_layers = 2, dropout=0.1, device=device):
        super().__init__()

        self.device = device
        self.hidden_dim = hidden_dim
        self.max_len = max_len
        self.out_dim = out_dim
        self.bert = bert
        
        # Context layers
        self.contextDropout = nn.Dropout(dropout)
        
        # Candidates layers
        self.pos_emb = nn.Embedding(self.max_len, self.hidden_dim)
        self.candidatesDropout = nn.Dropout(dropout)
        
        self.att_dropout = nn.Dropout(dropout)


    def attention(self, q, k, v, vMask=None):
        w = torch.matmul(q, k.transpose(-1, -2))
        if vMask is not None:
            w *= vMask.unsqueeze(1)
            w = F.softmax(w, -1)
        w = self.att_dropout(w)
        score = torch.matmul(w, v)
        return score

    def score(self, context, context_mask, responses, responses_mask):
        """Run the model on the source and compute the loss on the target.

        Args:
            source: An integer tensor with shape (max_source_sequence_length,
                batch_size) containing subword indices for the source sentences.
            target: An integer tensor with shape (max_target_sequence_length,
                batch_size) containing subword indices for the target sentences.

        Returns:
            A scalar float tensor representing cross-entropy loss on the current batch
            divided by the number of target tokens in the batch.
            Many of the target tokens will be pad tokens. You should mask the loss 
            from these tokens using appropriate mask on the target tokens loss.
        """
        batch_size, nb_cand, seq_len = responses.shape
        # Context
        context_encoded = self.bert(context,context_mask)[0][:,0,:]
        pos_emb = self.pos_emb(torch.arange(self.max_len).to(self.device))
        context_att = self.attention(pos_emb, context_encoded, context_encoded, context_mask)

        # Response
        responses_encoded = self.bert(responses.view(-1,responses.shape[2]), responses_mask.view(-1,responses.shape[2]))[0][:,0,:]
        responses_encoded = responses_encoded.view(batch_size,nb_cand,-1)
        response_encoded = self.candidatesFc(response_encoded)
        
        context_emb = self.attention(responses_encoded, context_att, context_att).squeeze() 
        dot_product = (context_emb*responses_encoded).sum(-1)
        
        return dot_product

    
    def compute_loss(self, context, context_mask, response, response_mask):
        """Run the model on the source and compute the loss on the target.

        Args:
            source: An integer tensor with shape (max_source_sequence_length,
                batch_size) containing subword indices for the source sentences.
            target: An integer tensor with shape (max_target_sequence_length,
                batch_size) containing subword indices for the target sentences.

        Returns:
            A scalar float tensor representing cross-entropy loss on the current batch
            divided by the number of target tokens in the batch.
            Many of the target tokens will be pad tokens. You should mask the loss 
            from these tokens using appropriate mask on the target tokens loss.
        """
        batch_size = context.shape[0]
        seq_len = response.shape[1]
        
        # Context
        context_encoded = self.bert(context,context_mask)[0][:,0,:]
        pos_emb = self.pos_emb(torch.arange(self.max_len).to(self.device))
        context_att = self.attention(pos_emb, context_encoded, context_encoded, context_mask)

        # Response
        print(response.shape)
        response_encoded = self.bert(response, response_mask)[0][:,0,:]
        print(response_encoded.shape)
        response_encoded = response_encoded.view(batch_size, -1)

        
        response_encoded = response_encoded.unsqueeze(0).expand(batch_size, batch_size, response_encoded.shape[1]) 
        context_emb = self.attention(response_encoded, context_att, context_att).squeeze() 
        dot_product = (context_emb*response_encoded).sum(-1)
        mask = torch.eye(batch_size).to(self.device)
        loss = F.log_softmax(dot_product, dim=-1) * mask
        loss = (-loss.sum(dim=1)).mean()
        return loss

In [23]:
#Bi-encoder
class RetrieverBiencoder(nn.Module):
    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        
    def score(self, context, context_mask, responses, responses_mask):

        context_vec = self.bert(context, context_mask)[0][:,0,:]  # [bs,dim]

        batch_size, res_length = response.shape

        responses_vec = self.bert(responses_input_ids, responses_input_masks)[0][:,0,:]  # [bs,dim]
        responses_vec = responses_vec.view(batch_size, 1, -1)

        responses_vec = responses_vec.squeeze(1)        
        context_vec = context_vec.unsqueeze(1)
        dot_product = torch.matmul(context_vec, responses_vec.permute(0, 2, 1)).squeeze()
        return dot_product
    
    def compute_loss(self, context, context_mask, response, response_mask):

        context_vec = self.bert(context, context_mask)[0]  # [bs,dim]

        batch_size, res_length = response.shape

        responses_vec = self.bert(response, response_mask)[0][:,0,:]  # [bs,dim]
        #responses_vec = responses_vec.view(batch_size, 1, -1)
        
        print(context_vec.shape)
        print(responses_vec.shape)

        responses_vec = responses_vec.squeeze(1)
        dot_product = torch.matmul(context_vec, responses_vec.t())  # [bs, bs]
        mask = torch.eye(context.size(0)).to(context_mask.device)
        loss = F.log_softmax(dot_product, dim=-1) * mask
        loss = (-loss.sum(dim=1)).mean()
        return loss


In [24]:
def train(model, data_loader, num_epochs, model_file, learning_rate=0.0001):
    """Train the model for given µnumber of epochs and save the trained model in 
    the final model_file.
    """

    decoder_learning_ratio = 5.0
    #encoder_parameter_names = ['word_embedding', 'encoder']
    encoder_parameter_names = ['encode_emb', 'encode_gru', 'l1', 'l2']
                           
    encoder_named_params = list(filter(lambda kv: any(key in kv[0] for key in encoder_parameter_names), model.named_parameters()))
    decoder_named_params = list(filter(lambda kv: not any(key in kv[0] for key in encoder_parameter_names), model.named_parameters()))
    encoder_params = [e[1] for e in encoder_named_params]
    decoder_params = [e[1] for e in decoder_named_params]
    optimizer = torch.optim.AdamW([{'params': encoder_params},
                {'params': decoder_params, 'lr': learning_rate * decoder_learning_ratio}], lr=learning_rate)
    
    clip = 50.0
    for epoch in tqdm.notebook.trange(num_epochs, desc="training", unit="epoch"):
        # print(f"Total training instances = {len(train_dataset)}")
        # print(f"train_data_loader = {len(train_data_loader)} {1180 > len(train_data_loader)/20}")
        with tqdm.notebook.tqdm(
                data_loader,
                desc="epoch {}".format(epoch + 1),
                unit="batch",
                total=len(data_loader)) as batch_iterator:
            model.train()
            total_loss = 0.0
            for i, batch_data in enumerate(batch_iterator, start=1):
                source, mask_source, target, mask_target = batch_data["conv_tensors"]
                optimizer.zero_grad()
                loss = model.compute_loss(source, mask_source, target, mask_target)
                total_loss += loss.item()
                loss.backward()
                # Gradient clipping before taking the step
                _ = nn.utils.clip_grad_norm_(model.parameters(), clip)
                optimizer.step()

                batch_iterator.set_postfix(mean_loss=total_loss / i, current_loss=loss.item())
    # Save the model after training         
    torch.save(model.state_dict(), model_file)

In [25]:
# You are welcome to adjust these parameters based on your model implementation.
num_epochs = 10
batch_size = 32
learning_rate = 0.001
# Reloading the data_loader to increase batch_size

baseline_model = RetrieverBiencoder(bert).to(device)
train(baseline_model, train_dataloader, num_epochs, "baseline_model.pt",learning_rate=learning_rate)
# Download the trained model to local for future use
#files.download('baseline_model.pt')

training:   0%|          | 0/10 [00:00<?, ?epoch/s]

epoch 1:   0%|          | 0/15 [00:00<?, ?batch/s]

KeyError: 'text'

In [None]:
baseline_model = RetrieverPolyencoder(bert1,bert2,vocab).to(device)
baseline_model.load_state_dict(torch.load("baseline_model3.pt", map_location=device))

In [None]:
vals = transformer_collate_fn(all_conversations[0:100],tokenizer)

In [None]:
i=3

In [None]:
scores = baseline_model.score(vals[0][i].unsqueeze(0).cuda(),vals[1][i].unsqueeze(0).cuda(),vals[2].unsqueeze(0).cuda(),vals[3].unsqueeze(0).cuda()).detach().cpu().numpy()

In [None]:
all_conversations[i][0]

In [None]:
all_conversations[np.argmax(scores)][1]

In [None]:
max_v = 100
vals = transformer_collate_fn(all_conversations[0:max_v],tokenizer)
correct = 0
for i in range(max_v):
    scores = baseline_model.score(vals[0][i].unsqueeze(0).cuda(),vals[1][i].unsqueeze(0).cuda(),vals[2].unsqueeze(0).cuda(),vals[3].unsqueeze(0).cuda()).detach().cpu().numpy()
    if np.argmax(scores)==i:
        correct+=1
    print(all_conversations[i][0])
    print(all_conversations[np.argmax(scores)][1]+"\n")

In [None]:
print(correct/max_v)