In [98]:
import torch
import numpy as np
import torch.utils
import requests
from transformers import BertTokenizer, BertModel
from  data import create_dataset
torch.manual_seed(42)
np.random.seed(42)
MAX_LEN = 13

In [123]:
def encode(tokenizer, sentences):
    input_ids = []
    attention_masks = []

    for sentence in sentences:

        encoded_dict = tokenizer.encode_plus(
                            sentence,                  
                            add_special_tokens=True,   # Add '[CLS]' and '[SEP]'
                            max_length=MAX_LEN+1,             # Adjust sentence length
                            padding='max_length',    # Pad/truncate sentences
                            return_attention_mask=True,# Generate attention masks
                            return_tensors='pt',       # Return PyTorch tensors
                            truncation = False
                    )
        input_ids.append(encoded_dict['input_ids'])
        attention_masks.append(encoded_dict['attention_mask'])
        # Construct an attention mask (identifying padding/non-padding).
    input_ids = torch.cat(input_ids,dim=0)
    attention_masks = torch.cat(attention_masks,dim=0)
    
    return   input_ids,attention_masks 

def token(data, tokenizer):
    input_a = []
    input_b = []
    input_c = []
    input_d = []
    for sentence in data:
    
        a,b,c,d = sentence.split()
        input_a.append(a)
        input_b.append(b)
        input_c.append(c)
        input_d.append(d)
    input_idA, maskA = encode(tokenizer, input_a)
    input_idB, maskB = encode(tokenizer, input_b)
    input_idC, maskC = encode(tokenizer, input_c)
    input_idD, maskD = encode(tokenizer, input_d)
    
    return input_idA, maskA, input_idB, maskB, input_idC, maskC, input_idD, maskD

def embedding(model,tokens_tensor, attention_masks):
    with torch.no_grad():
        outputs = model(tokens_tensor, attention_mask=attention_masks)
    hidden_states = outputs.last_hidden_state
    #CLS contains all the information of the sentence so we use CLS token for embedding
    CLS_embedding = hidden_states[:,0,:]
    
    return CLS_embedding


def l2_dis(ab,c,d_prime, k =20):
    candidates = []
    for d in d_prime:
        temp = ab +c-d

        cos = torch.dot(temp,ab)/(torch.norm(temp,p=2)*torch.norm(ab,p=2))
        candidates.append(cos)
    
    return  torch.topk(torch.Tensor(candidates), k, sorted=True, largest=False)

def cos_sim(ab, c, d_prime, k=20):
    candidates = []
    for d in d_prime:
        temp = ab +c-d
        cos = torch.dot(temp,ab)/(torch.norm(temp,p=2)*torch.norm(ab,p=2))
        candidates.append(cos)

    return  torch.topk(torch.Tensor(candidates), k, sorted=True)
def accuracy(predictions, words,k):
    num = 0
    for row in predictions:
        if any(x == words[i] for i,x in enumerate(row[:k])):
            num+=1
    return num/len(words)


def predicition(data,model,tokenizer, k_groups):
    #Tokenize a,b,c,d and their coresponding masks
    input_idA, maskA, input_idB, maskB, input_idC, maskC, input_idD, maskD = token(data,tokenizer)
    #Add B and D together
    
    cd = torch.concat((input_idB,input_idD),0)
    #maskUnique = torch.concat((maskB,maskD), 0)

    #Find unique words as those are our candidates
    unique_words = {tuple(tensor) for tensor in cd.tolist()}
    unique_lst = [list(lst) for lst in unique_words]
    unique_tensors = torch.stack([torch.tensor(list(lst)) for lst in unique_words])
    unique_embedding = embedding(model,unique_tensors, maskD[:len(unique_lst)])
    d = torch.Tensor(input_idD).tolist()
    words = [ unique_lst.index(d[i]) for i in range(len(d))]
    #sentence_embedding = embeddings(model, input_ids[:100], torch.tensor(attention_masks[:100]))
    #find the embedding of a,b,c
    embedding_A =  embedding(model, input_idA, maskA)
    embedding_B =  embedding(model, input_idB, maskB)
    embedding_C =  embedding(model, input_idC, maskC)
    #embedding_D =  embedding(model, input_idD, maskD)

    # For each sentence find the top k most likely words
    pred_cos = []
    pred_l2 = []
    for i in range(embedding_A.shape[0]):
        AB = embedding_A-embedding_B
        top_k_cos, top_k_indicesC = cos_sim(AB[i], embedding_C[i],unique_embedding )
        top_k_l2, top_k_indicesL = l2_dis(AB[i], embedding_C[i],unique_embedding )
        #print(top_k_simil[0])
        # for i in top_k_indices:
        #     print(tokenizer.decode(unique_tensors[i], skip_special_tokens=True, clean_up_tokenization_spaces=True))
        # print(torch.dot(AB[0],embedding_C[i] - embedding_D[i])/(torch.norm(AB[0],p=2)*torch.norm(embedding_C[i] - embedding_D[i],p=2)))
        pred_cos.append(top_k_indicesC)
        pred_l2.append(top_k_indicesL)
    
    for k in k_groups:
        print(f'\nk = {k}')
        print(f'Cos Simmiliairty Accuracy : {100*accuracy(pred_cos,words,k):.4f} ')
        print(f'L2 Accuracy :  {100*accuracy(pred_l2,words,k):.4f}')

In [100]:
url = 'https://www.cs.fsu.edu/~liux/courses/deepRL/assignments/word-test.v1.txt'
url_voca = 'https://www.cs.fsu.edu/~liux/courses/deepRL/assignments/bert_vocab.txt'
batch_size = 16
num_epochs = 3
learning_rate = 5e-5

In [101]:
# Load pre-trained BERT tokenizer and model
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', url_vocab= url_voca)
model = BertModel.from_pretrained('bert-base-uncased' , output_hidden_states = True)

In [102]:
#Load Dataset
dataset = create_dataset(url)
data = dataset['city-in-state']

In [103]:
max_len = 0
i = 0
index = -1
for sen in data:
    a,b,c,d = sen.split()
    la = len(a)
    lb = len(b)
    lc = len(c)
    ld = len(d)
    size = max(la,lb,lc,ld)
    if size> max_len:
        index = i
        max_len = size
    i+=1
print(data[index])
print(max_len)

Chicago Illinois Boston Massachusetts
13


In [124]:
groups = ('capital-common-countries')
#groups = ('capital-common-countries', 'currency', 'city-in-state')

k_group= (1,2,5,10,20)
data = dataset['family']
predicition(data,model,tokenizer, k_group)


k = 1
Cos Simmiliairty Accuracy : 7.7075 
L2 Accuracy :  0.0000

k = 2
Cos Simmiliairty Accuracy : 8.3004 
L2 Accuracy :  0.0000

k = 5
Cos Simmiliairty Accuracy : 28.6561 
L2 Accuracy :  0.9881

k = 10
Cos Simmiliairty Accuracy : 45.8498 
L2 Accuracy :  20.3557

k = 20
Cos Simmiliairty Accuracy : 65.0198 
L2 Accuracy :  45.8498
