In [96]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from transformers import BertModel, BertTokenizer
import numpy as np 
import math


from Data import PKLS_FILES


In [35]:
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
bert_model = BertModel.from_pretrained('bert-base-cased')


In [76]:

def get_relation_embedings(relations):
    tokenizer =bert_tokenizer
    model = bert_model
    model.eval()
    relation_embedings = []
    for label in relations:
        inputs = tokenizer(label, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
        
        hidden_states = outputs.hidden_states 
        
        last_layer = hidden_states[-1].squeeze(0) # (seq_len, hidden_size)
        before_last_layer = hidden_states[-2].squeeze(0)
        
        average_layers = (last_layer + before_last_layer) / 2.0
        r_j = average_layers.mean(dim=0)
        
        relation_embedings.append(r_j)
        
    return relation_embedings

In [98]:

from torch.utils.data import Dataset, DataLoader



class SentencesDS(Dataset):
    def __init__(self, descriptions, max_length=128):
        self.descriptions_dict = descriptions
        self.ids = list(descriptions.keys())
        self.sentences = list(descriptions.values())
        
        encoded = bert_tokenizer(self.sentences, padding=True, truncation=True, return_tensors="pt", max_length=max_length)
        input_ids = encoded['input_ids']
        attention_mask = encoded['attention_mask']
        
        with torch.no_grad():
            bert_output = bert_model(input_ids=input_ids, attention_mask=attention_mask)
            self.embeddings = bert_output.last_hidden_state # (batch_size, seq_len ,hidden_size)
        
        mask_expanded = attention_mask.unsqueeze(-1).expand(self.embeddings.size())

        masked_embeddings = self.embeddings * mask_expanded

        sum_embeddings = masked_embeddings.sum(dim=1)  # [batch, hidden_size]
        token_counts = mask_expanded.sum(dim=1)  # [batch, 1]
        token_counts = token_counts.clamp(min=1)

        self.hg  =  sum_embeddings / token_counts # [batch, hidden_size]
        
        
    
    def __getitem__(self,idx):
        return  {"string": self.sentences[idx], "hg": self.hg[idx], "embedding": self.embeddings[idx], "id": self.ids[idx] }

    def getitem_byid(self,desc_id):
        idx = self.ids.index(desc_id)
        return self.__getitem__(idx)
    

    def __len__(self):
        return len(self.ids) 


class RelationsDS(Dataset):
    def __init__(self, relations_dict, relations_transe_embs):
        self.relations_dict = relations_dict 
        self.ids = list(relations_dict.keys())
        self.relations_lst = list(relations_dict.values())
        embeddings_list = get_relation_embedings(self.relations_lst)
        self.embeddings = torch.stack(embeddings_list, dim=0)  #(num_relations, hidden_size)
        self.transe_embeddings = relations_transe_embs
        
        
    
    def __getitem__(self,idx):
        return {"string": self.relations_lst[idx],   "id":self.ids[idx], "embedding": self.embeddings[idx]}
     

    def getitem_byid(self,relation_id):
        idx = self.ids.index(relation_id)
        return self.__getitem__(idx)
    

    def __len__(self):
        return len(self.ids) 



Parameters:
- token_embeddings with shape (batch_size, seq_len, hidden_size)
- start_probs, end_probs are shape (batch_size, seq_len) 
- threshold (the probability that is above the threshold) -> consider it a subject start or end 

Function: 
- For each sentence in the batch, find indices having probabilities above the threshold 
- For each start index, find the nearest end index 
- Compute Sk which would be the average of the start token and the end token embeddings
- stack subject embeddings for each sentence 
- pad all sentences to the seq_len  and create corresponding mask of booleans indicating valid subjects embeddings
- apply linear transformation Ws to the padded tensor
- Return subject_embeddings containing sk and tuples containing indices (start, end) of the subjects

In [94]:
def extract_subject_embeddings(token_embeddings, start_probs, end_probs, threshold=.5 ):
    batch_size, seq_len, hidden_size=  token_embeddings.shape 

    subjs_list = [] # list (len = batch_size) of tensors (n_subs, hidden_size)
    masks_list = []
    subjects_idxs = []

    for b in range(batch_size):
        start_idxs = (start_probs[b] > threshold).nonzero(as_tuple = True)[0]
        end_idxs = (end_probs[b] > threshold).nonzero(as_tuple = True)[0]
        subject_embeddings = []
        used_ends = set()
        sub_idxs_sentence = []
        for start in start_idxs:
            valid_ends = [e.item() for e in end_idxs if e.item() >= start.item() and e.item() not in used_ends]
            if valid_ends:
                end = valid_ends[0]
                used_ends.add(end)
                s_k = (token_embeddings[b, start] + token_embeddings[b, end]) / 2.0
                subject_embeddings.append(s_k)
                sub_idxs_sentence.append((start.item(),end))
                
        if len(subject_embeddings) > 0:
            subjs_tensor = torch.stack(subject_embeddings, dim=0)# (n_subs, hidden_size)
            subjs_list.append(subjs_tensor)
            masks_list.append(torch.ones(subjs_tensor.shape[0], dtype=torch.bool, device=token_embeddings.device))
        else:
            subjs_list.append(torch.empty(0, hidden_size, device=token_embeddings.device))
            masks_list.append(torch.empty(0, dtype=torch.bool, device=token_embeddings.device))
        subjects_idxs.append(sub_idxs_sentence)
    padded_subjects = pad_sequence(subjs_list, batch_first=True, padding_value=0)         # (batch_size, seq_len)
    padded_mask = pad_sequence(masks_list, batch_first=True, padding_value=False)
    return padded_subjects, padded_mask,  subjects_idxs


In [97]:
def extract_obj_spans_relation(obj_start_probs, obj_end_probs, threshold):
  
  seq_len = obj_start_probs.shape[0]
  obj_idxs = []
  used_ends = set()
  for tk_idx in range(seq_len):
    if obj_start_probs[tk_idx] > threshold:
      for j in range(tk_idx, seq_len):
        if obj_end_probs[j] > threshold and j not in used_ends:
          obj_idxs.append((tk_idx, j))
          used_ends.add(j)
          break
  return obj_idxs


def extract_triples(token_embs, subjects_idxs, obj_start_probs, obj_end_probs, threshold=.5):
  batch_size, seq_len, num_relations = obj_start_probs.shape
  batch_triples = []
  for sentence_idx in range(batch_size):
    sentence_triples = []
    subj_spans = subjects_idxs[sentence_idx] #[(s_start, s_end), ...]

    for rel_idx in range(num_relations):
      obj_start_sentence = obj_start_probs[sentence_idx, :, rel_idx] #probabilities of this sentence with this relation of all tokens 
      obj_end_sentence = obj_end_probs[sentence_idx, :, rel_idx] # vector with size (seq_len)
      
      obj_idxs = extract_obj_spans_relation(obj_start_sentence, obj_end_sentence, threshold) 
      for subj in subj_spans:
        for obj in obj_idxs:
          sentence_triples.append((subj, rel_idx , obj))
      batch_triples.append(sentence_triples)
  return batch_triples

In [91]:
class MyBRASKModel(nn.Module):
    def __init__(self,  hidden_size=768):
        super(MyBRASKModel, self).__init__()
        self.bert = bert_model 
        
        self.start_subject_fc = nn.Linear(hidden_size, 1)
        self.end_subject_fc = nn.Linear(hidden_size, 1)
        
        self.start_object_fc = nn.Linear(hidden_size, 1)
        self.end_object_fc = nn.Linear(hidden_size, 1)
        
        
        #I will put after starT_object_fc and end_object_fc. 
        
        
        self.W_r = nn.Linear(hidden_size, hidden_size)
        self.W_g = nn.Linear(hidden_size, hidden_size)
        self.W_x = nn.Linear(hidden_size, hidden_size)
        self.V   = nn.Linear(hidden_size, 1)  
        
        
        self.W_s = nn.Linear(hidden_size, hidden_size)
        self.W_x2 = nn.Linear(hidden_size, hidden_size)
        
        
        
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, descriptions_dataset, relation_dataset):

        batch_size = descriptions_dataset.embeddings.shape[0]
        token_embs = descriptions_dataset.embeddings # (batch_size, seq_len, hidden_size)
        
        
        
        print(f"batch_size: {batch_size}")
        #apply linear + sigmoid to each token 
        sub_start_probs = self.sigmoid(self.start_subject_fc(token_embs )).squeeze(-1) # (batch_size, seq_len)
        sub_end_probs = self.sigmoid(self.end_subject_fc(token_embs)).squeeze(-1) 
        
        padded_subjects, padded_mask,subjects_idxs   = extract_subject_embeddings(
            token_embs, sub_start_probs, sub_end_probs, threshold=0.5
        ) # padded_subjects (batch_size, seq_len, hidden_size)
        s_k_w = self.W_s(padded_subjects)
        s_k_w = s_k_w * padded_mask.unsqueeze(-1).float() #(batch_size, seq_len, hidden_size)


        h_g = descriptions_dataset.hg # (batch_size, hidden_size)
        relations_embeddings = relation_dataset.embeddings #(num_relations, hidden_size)
        #backward relation embeddings
        b_relations_embeddings = relation_dataset.transe_embeddings #(num_relations, 80)

                
        
        
        

        token_embs_exp = token_embs.unsqueeze(2) #(batch_size, seq_len, 1, hidden_size) 
        h_g_exp = h_g.unsqueeze(1).unsqueeze(2) #(batch_size, 1,1 , hidden_size)
        relation_exp = relations_embeddings.unsqueeze(0).unsqueeze(1)  #(1, 1, num_relations, hidden_size)
        relation_exp = relation_exp.expand(batch_size, -1, -1, -1) #(batch_size, 1, num_relations, hidden_size) #repeat through the batch

        
        #attention scores
        #tanh to add nonlinearity
        e = torch.tanh(
            self.W_r(relation_exp) + self.W_g(h_g_exp) + self.W_x(token_embs_exp)
        ) # (batch_size, seq_len, num_relations, hidden_size)
        #self.V is linear (hidden, 1) because we want a scalar value for each token (to calculate the relevance between the token and the relation)
        v_e = self.V(e).squeeze(-1)  # shape: (batch_size, seq_len, num_relations)
        
        #normalize attention score
        #on dim = 1 because on the sentences dimension because we want to distribute attention on tokens, like we want to get probability distribution for all tokens if they are relevant to the relation
        A = F.softmax(v_e, dim=1)  # (batch_size, seq_len, num_relations) 
        A_exp = A.unsqueeze(-1) # (batch_size, seq_len, num_relations, 1)
        C = torch.sum(A_exp * token_embs_exp, dim=1)  # (batch_size ,num_relations, hidden_size)
        
        W_x_xi = self.W_x2(token_embs) #(batch_size, seq_len,  hidden_size)
        h_i_k = s_k_w + W_x_xi #(batch_size, seq_len,  hidden_size)
        h_i_k_exp = h_i_k.unsqueeze(2) #(batch_size, seq_len, 1,  hidden_size)
        
        
        C_exp = C.unsqueeze(1) # (batch_size, 1 ,num_relations, hidden_size)
        # token_embs_exp (batch_size, seq_len, 1, hidden_size) 
        h_i_j = C_exp + token_embs_exp  # (batch_size, seq_len, num_relations, hidden_size) 
        
        
        H_i_j_k = h_i_j + h_i_k_exp
        
        
        #For each sentence, for each token, for each relation, what is the probability that this token is the start (or end) of the object?
        obj_start_probs = self.sigmoid(self.start_object_fc(H_i_j_k )).squeeze(-1) # (batch_size, seq_len, num_relations)
        obj_end_probs = self.sigmoid(self.end_object_fc(H_i_j_k)).squeeze(-1)  # (batch_size, seq_len, num_relations)
        
        forward_triples  = extract_triples(token_embs, subjects_idxs, obj_start_probs, obj_end_probs)
        
        
        
        

In [92]:
from Data import PKLS_FILES


k  = 10_000

descriptions = PKLS_FILES["descriptions_normalized"][k]

descriptions_dataset = SentencesDS(descriptions_f)
relations = get_relations_dict_from_descriptions(descriptions_dataset.descriptions_dict, relations_f)
relations_dataset = RelationsDS(relations)
print(relations_dataset.embeddings.shape)


model = SubjectPredictor()
start_probs = model(descriptions_dataset, relations_dataset)



torch.Size([75, 768])
batch_size: 10
h_g_exp shape: torch.Size([10, 1, 1, 768])
token_exp shape: torch.Size([10, 128, 1, 768])
relation_exp shape: torch.Size([10, 1, 75, 768])
e shape: torch.Size([10, 128, 75])
A shape: torch.Size([10, 128, 75])
C shape: torch.Size([10, 75, 768])
