# Training a Dual Encoder

In [1]:
#Warning control
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
import numpy as np
import pandas as pd

from transformers import AutoTokenizer

## Contrastive Loss using CrossEntropyLoss

In [3]:
#Demonstrate Contrastive loss using CEL on similarity tensor
df = pd.DataFrame(
    [
        [4.3, 1.2, 0.05, 1.07],
        [0.18, 3.2, 0.09, 0.05],
        [0.85, 0.27, 2.2, 1.03],
        [0.23, 0.57, 0.12, 5.1]
    ]
)
data = torch.tensor(df.values, dtype=torch.float32)

In [9]:
#Define contrastive loss function
def contrastive_loss(data):
    target = torch.arange(data.size(0)) #Indexes corresponding to row label
    loss = torch.nn.CrossEntropyLoss()(data,target) #Compares index of highest logit to actual
    return loss

In [10]:
torch.nn.Softmax(dim = 1)(data)

tensor([[0.9100, 0.0410, 0.0130, 0.0360],
        [0.0429, 0.8801, 0.0393, 0.0377],
        [0.1512, 0.0846, 0.5832, 0.1810],
        [0.0075, 0.0105, 0.0067, 0.9753]])

In [12]:
# Sample optimization loop
N = data.size(0)
non_diag_mask = ~torch.eye(N, N, dtype=bool)
print('None diagonal mask:')
print(non_diag_mask)

for inx in range(3):
    data = torch.tensor(df.values, dtype=torch.float32)
    data[range(N), range(N)] += inx*0.5 #Targets diagonal elemenets using interval masking
    data[non_diag_mask] -= inx*0.02 #Targets non-diagonal elements using boolean masking
    print(data)
    print(f"Loss = {contrastive_loss(data)}")


None diagonal mask:
tensor([[False,  True,  True,  True],
        [ True, False,  True,  True],
        [ True,  True, False,  True],
        [ True,  True,  True, False]])
tensor([[4.3000, 1.2000, 0.0500, 1.0700],
        [0.1800, 3.2000, 0.0900, 0.0500],
        [0.8500, 0.2700, 2.2000, 1.0300],
        [0.2300, 0.5700, 0.1200, 5.1000]])
Loss = 0.19657586514949799
tensor([[4.8000, 1.1800, 0.0300, 1.0500],
        [0.1600, 3.7000, 0.0700, 0.0300],
        [0.8300, 0.2500, 2.7000, 1.0100],
        [0.2100, 0.5500, 0.1000, 5.6000]])
Loss = 0.12602083384990692
tensor([[5.3000, 1.1600, 0.0100, 1.0300],
        [0.1400, 4.2000, 0.0500, 0.0100],
        [0.8100, 0.2300, 3.2000, 0.9900],
        [0.1900, 0.5300, 0.0800, 6.1000]])
Loss = 0.07888662070035934


## Encoder Module

In [18]:
class Encoder(torch.nn.Module):
    def __init__(self,vocab_size, embed_dim,output_embed_dim):
        super().__init__()
        self.embedding_layer = torch.nn.Embedding(vocab_size, embed_dim) #Creates lookup of embeddings from tokens
        self.encoder = torch.nn.TransformerEncoder(
            torch.nn.TransformerEncoderLayer(embed_dim, nhead=8, batch_first=True),
            num_layers=3, # Number of Transformer Encoder layers
            norm=torch.nn.LayerNorm([embed_dim]), #Normalize along the embedding dimension
            enable_nested_tensor=False
        )
        self.projection = torch.nn.Linear(embed_dim, output_embed_dim)
        
    def forward(self, tokenizer_output):
        x = self.embedding_layer(tokenizer_output['input_ids']) # Returns a sentence matrix
        x = self.encoder(x,src_key_padding_mask=tokenizer_output['attention_mask'].logical_not())
        cls_embed = x[:,0,:]
        return self.projection(cls_embed)
        

## Training Loop

In [19]:
def train(dataset, num_epochs=10):
    embed_size = 512
    output_embed_size = 128
    max_seq_len = 64
    batch_size = 32

    n_iters = len(dataset) // batch_size + 1
    
    # define the question/answer encoders
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    question_encoder = Encoder(tokenizer.vocab_size, embed_size, output_embed_size)
    answer_encoder = Encoder(tokenizer.vocab_size, embed_size, output_embed_size)

    # define the dataloader, optimizer and loss function    
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)    
    optimizer = torch.optim.Adam(list(question_encoder.parameters()) + list(answer_encoder.parameters()), lr=1e-5)
    loss_fn = torch.nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        running_loss = []
        for idx, data_batch in enumerate(dataloader):

            # Tokenize the question/answer pairs (each is a batc of 32 questions and 32 answers)
            question, answer = data_batch
            question_tok = tokenizer(question, padding=True, truncation=True, return_tensors='pt', max_length=max_seq_len)
            answer_tok = tokenizer(answer, padding=True, truncation=True, return_tensors='pt', max_length=max_seq_len)
            if inx == 0 and epoch == 0:
                print(question_tok['input_ids'].shape, answer_tok['input_ids'].shape)
            
            # Compute the embeddings: the output is of dim = 32 x 128
            question_embed = question_encoder(question_tok)
            answer_embed = answer_encoder(answer_tok)
            if inx == 0 and epoch == 0:
                print(question_embed.shape, answer_embed.shape)
    
            # Compute similarity scores: a 32x32 matrix
            # row[N] reflects similarity between question[N] and answers[0...31]
            similarity_scores = question_embed @ answer_embed.T
            if inx == 0 and epoch == 0:
                print(similarity_scores.shape)
    
            # we want to maximize the values in the diagonal
            target = torch.arange(question_embed.shape[0], dtype=torch.long)
            loss = loss_fn(similarity_scores, target)
            running_loss += [loss.item()]
            if idx == n_iters-1:
                print(f"Epoch {epoch}, loss = ", np.mean(running_loss))
    
            # this is where the magic happens
            optimizer.zero_grad()    # reset optimizer so gradients are all-zero
            loss.backward()
            optimizer.step()

    return question_encoder, answer_encoder

## Dataset Module

In [20]:
class EncoderDataset(torch.utils.data.Dataset):
    def __init__(self, datapath):
        self.data = pd.read_csv(datapath, sep='\t', nrows=300)
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self,idx):
        return self.data.iloc[idx]['questions'], self.data.iloc[idx]['answers']

In [22]:
dataset = EncoderDataset('nq_sample.tsv')
dataset.data.head(5)

Unnamed: 0,questions,answers
0,who played bubba in the tv series in the heat ...,Carlos Alan Autry Jr. (also known for a period...
1,where did the 2017 tour de france start,"The 3,540 km (2,200 mi)-long race commenced wi..."
2,who is the chess champion of the world,Current world champion Magnus Carlsen won the ...
3,who scored the most hat tricks in football,Cristiano Ronaldo and Messi have scored three ...
4,what do you need to be an ontario scholar,Ontario Scholars are high school graduates in ...


## Training 

In [23]:
Q_encoder, A_encoder = train(dataset, num_epochs=5)

Epoch 0, loss =  3.810750651359558
Epoch 1, loss =  3.5239919662475585
Epoch 2, loss =  3.429384684562683
Epoch 3, loss =  3.361122727394104
Epoch 4, loss =  3.2964673757553102


In [50]:
question = 'What is the tallest mountain in the world?'
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
question_tok = tokenizer(question, padding=True, truncation=True, return_tensors='pt', max_length=64)
question_emb = Q_encoder(question_tok)[0]
print(question_tok)
print('\n')
print(question_emb[:5])


{'input_ids': tensor([[  101,  2054,  2003,  1996, 13747,  3137,  1999,  1996,  2088,  1029,
           102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


tensor([-0.0242, -0.0792, -0.3418, -0.9653,  0.3829])


In [51]:
answers = [
    "What is the tallest mountain in the world?",
    "The tallest mountain in the world is Mount Everest.",
    "Who is donald duck?"
]
answer_tok = []
answer_emb = []
for answer in answers:
    tok = tokenizer(answer, padding=True, truncation=True, return_tensors='pt', max_length=64)
    answer_tok.append(tok['input_ids'])
    emb = A_encoder(tok)[0]
    answer_emb.append(emb)

print(answer_tok)
print(answer_emb[0][:5])
print(answer_emb[1][:5])
print(answer_emb[2][:5])

[tensor([[  101,  2054,  2003,  1996, 13747,  3137,  1999,  1996,  2088,  1029,
           102]]), tensor([[  101,  1996, 13747,  3137,  1999,  1996,  2088,  2003,  4057, 23914,
          1012,   102]]), tensor([[ 101, 2040, 2003, 6221, 9457, 1029,  102]])]
tensor([-0.3121,  0.4846, -0.0618, -0.9400, -0.6620])
tensor([-0.3959,  0.2701, -0.2616, -1.0928, -0.5686])
tensor([-0.8347, -0.0393, -0.3410, -0.4569, -0.3530])


In [52]:
# Similarity
question_emb @ torch.stack(answer_emb).T

tensor([-0.4880, -0.3826,  0.3147])
