#### Parse brat file to memory

In [None]:
#is cuda availabe
import torch
print(torch.cuda.is_available())

True


In [2]:
import os
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" 

In [3]:
from brat_parser import get_entities_relations_attributes_groups

def parseBratFile(file):
    entities, relations, attributes, groups = get_entities_relations_attributes_groups(file)
    #filter movelink relations
    relations = {id: rel for id, rel in relations.items() if "MOVELINK" not in rel.type}
    return entities, relations, attributes, groups
# The praser returns dataclasses.


def filter_annotations(annotations, type, match=False, exceptions=[]):
    if match:
        return {id: ann for id, ann in annotations.items() if type in ann.type and ann.type not in exceptions}
    else:
        return {id: ann for id, ann in annotations.items() if ann.type == type and ann.type not in exceptions}
    

In [5]:
import os
import re

path = "./lusa_news"


test_set_ids = ['lusa_97',
 'lusa_4',
 'lusa_67',
 'lusa_20',
 'lusa_83',
 'lusa_104',
 'lusa_80',
 'lusa_79',
 'lusa_34',
 'lusa_47',
 'lusa_30',
 'lusa_96',
 'lusa_11',
 'lusa_112',
 'lusa_100',
 'lusa_77',
 'lusa_38',
 'lusa_86',
 'lusa_60']


val_set_ids = ['lusa_12',
 'lusa_48',
 'lusa_45',
 'lusa_115',
 'lusa_44',
 'lusa_82',
 'lusa_111',
 'lusa_63',
 'lusa_8',
 'lusa_42',
 'lusa_50',
 'lusa_76',
 'lusa_114',
 'lusa_99',
 'lusa_18']

#test_set_ids = ['lusa_4']

files_train = []
files_test = []
files_val = []
#exceptions = ["lusa_52.txt", "lusa_88.txt", "lusa_22.txt"] # n têm SRLinks | lusa_113 flaggeed, no ha tlinks entre eventos e Times? 
for file in os.listdir(path):
    if file.endswith(".txt") and file.startswith("lusa") and file: #not in exceptions:
        file = re.match(r'(.*)\.txt',file)
        if file[1] in test_set_ids:
            files_test.append(os.path.join(path, file[1]))
        elif file[1] in val_set_ids:
            files_val.append(os.path.join(path, file[1]))
        else:
            files_train.append(os.path.join(path, file[1]))
        
len(files_train), len(files_test), len(files_val)


(83, 19, 15)

In [7]:
import spacy 
import torch
nlp = spacy.load('pt_core_news_lg')


In [9]:

def filter_relation_by_entity(relations, entities, link_type_prefix, entity_types={"Event", "Time"}):
    filtered = {}
    discarded = {}
    for rid, rel in relations.items():
        # Only TLINK relations
        if not rel.type.startswith(link_type_prefix):
            continue

        subj_ent = entities.get(rel.subj)
        obj_ent = entities.get(rel.obj)

        # Skip if entities are missing
        if subj_ent is None or obj_ent is None:
            continue

        # Event–Time or Time–Event
        types = {subj_ent.type, obj_ent.type}
        if types == entity_types:
            rel.type = "event-" + rel.type
            filtered[rid] = rel
            
        else:
            discarded[rid] = rel

    return filtered, discarded

def parse_brat(files):
    entities_list = []
    relations_list = []
    texts_list = []
    for file in files:
        #file = "../../Datasets/lusa_news_final/lusa_1"
        entities, relations, attributes, _ = parseBratFile(file + ".ann")
        #tlinks_event, tlinks_other = filter_relation_by_entity(relations, entities, "TLINK_", entity_types={"Event", "Time"})
        events = filter_annotations(entities, "Event")
        times = filter_annotations(entities, "Time")
        spatial_relations = filter_annotations(entities, "Spatial_Relation")
        participants = filter_annotations(entities, "Participant")
        #entities_ = events | participants | times | spatial_relations
        entities_ = events | participants | times | spatial_relations
        tlinks =  filter_annotations(relations, "TLINK", match=True)
        qslinks =  filter_annotations(relations, "QSLINK", match=True)
        olink =  filter_annotations(relations, "OLINK", match=True)#, exceptions=["OLINK_objIdentity"])
        mlink =  filter_annotations(relations, "MOVELINK", True)
        srl_links = filter_annotations(relations, "SRL", match=True)
        #relations_ = tlinks | qslinks | olink | mlink | srl_links
        relations_ = srl_links |  qslinks | olink | tlinks #tlinks_event | tlinks_other #tlinks #
        text = open(file + ".txt").read()
        entities_list.append(entities_)
        relations_list.append(relations_)
        texts_list.append(text)
    return entities_list, relations_list, texts_list




#### Generate Node Feature Vectors

In [None]:
from transformers import BertTokenizerFast
import tokenizations



entity_labels = ["O", "Event", "Participant", "Time", "Spatial_Relation"] #,"Time","Spatial_Relation"]
iob_labels = ["O", "B-Event", "I-Event", "B-Participant", "I-Participant", "B-Time", "I-Time", "B-Spatial_Relation", "I-Spatial_Relation"]
label2id = {label: i  for i, label in enumerate(entity_labels)}
iob2id = {label: i  for i, label in enumerate(iob_labels)}
id2label = {i: label for label, i in label2id.items()}
id2iob = {i: label for label, i in iob2id.items()}
universal_pos_tags = ["ADJ", "ADP", "ADV", "AUX", "CONJ", "CCONJ", "DET",
                      "INTJ", "NOUN", "NUM", "PART", "PRON", "PROPN",
                      "PUNCT", "SCONJ", "SYM", "VERB", "X", "SPACE"]


#One Hot Encoding para labels de entidades / dep_parser labels
def one_hot_encode(index, size):
    one_hot = torch.zeros(size, dtype=torch.float)
    one_hot[index] = 1.0
    return one_hot

def get_spacy_token_indices(span, doc, idx):
        start, end = span[0], span[1]
        #print(start, end)
        
        token_indices = []
        for token in doc:
            #print(token)
            if token.idx >= start and token.idx + len(token.text) <= end: # rever...
                token_indices.append(token.i)
            elif token.idx >= start and token.idx <= end <= token.idx + len(token.text):
                token_indices.append(token.i)
            elif token.idx <= start  <= token.idx + len(token.text) and token.idx + len(token.text) <= end:
                token_indices.append(token.i)
        if token_indices == []:
            
            print("Error when mapping spacy tokens to entities spans")
            print(span)
            print(doc.text[start:end])
            print(doc.text[start- 10:end+10])
            print(doc)
            print("Entity ID:", idx)
            for token in doc:
                print(token.idx, len(token.text))
                print(token.text)
        return token_indices

# Calcular a média dos embeddings dos tokens de uma entidade
def aggregate_embeddings(indices, embeddings):
    if indices:
        
        return torch.mean(embeddings[indices], dim=0)
    else:
        print("error calculating entity embeddings")
        #print(indices)
        #print(entity_indices)
        return torch.zeros(embeddings.size(1))



def entity_indices_to_iob(n_tokens, entity_indices):
    iob_tags = [iob2id['O']] * n_tokens  # default all outside

    for token_indices, ent_type_idx, _ in entity_indices:
        ent_type = id2label[ent_type_idx]
        if ent_type != "O":
            for i, token_idx in enumerate(token_indices):
                prefix = 'B' if i == 0 else 'I'
                iob = f"{prefix}-{ent_type}"
                iob_tags[token_idx] = iob2id[iob]
    return iob_tags





def generateEntityLabels(doc, entity_indices):
    n_tokens = len(doc)
    labels = entity_indices_to_iob(n_tokens, entity_indices)
    # generate one-hot encoding for entity labels
    entity_type_one_hot = torch.stack([one_hot_encode(label, len(iob_labels)) for label in labels])
    return labels, entity_type_one_hot


tokenizer = BertTokenizerFast.from_pretrained('neuralmind/bert-base-portuguese-cased') ## neuralmind/bert-base-portuguese-cased

def tokenize_and_align_labels(tokenized_inputs, labels):
    word_ids = tokenized_inputs.word_ids()
    previous_word_idx = None
    label_ids = []
    for word_idx in word_ids:
        # Special tokens have a word id that is None. We set the label to -100 so they are automatically
        # ignored in the loss function.
        if word_idx is None:
            label_ids.append(-100)
        # We set the label for the first token of each word.
        elif word_idx != previous_word_idx:
            label_ids.append(labels[word_idx])
        # For the other tokens in a word, we set the label to either the current label or -100, depending on
        # the label_all_tokens flag.
        else:
            label_ids.append(-100)
        previous_word_idx = word_idx
    tokenized_inputs["labels"] = torch.tensor([label_ids])
    return tokenized_inputs        

def generate_bert_encodings(doc, labels):
    tokens = [token.text for token in doc]
    encodings = tokenizer(tokens, is_split_into_words=True, return_tensors='pt', max_length=512, truncation=True, padding='max_length', add_special_tokens=True)
    encodings = tokenize_and_align_labels(encodings,labels)
    
    tokens_b = tokenizer.convert_ids_to_tokens(encodings["input_ids"][0])
    tokens_b = [tok for tok in tokens_b if tok != tokenizer.pad_token]
    a2b, _ = tokenizations.get_alignments(tokens, tokens_b)
    encodings['alignment'] = a2b
    encodings['pos_vectors'] = [one_hot_encode(universal_pos_tags.index(token.pos_), len(universal_pos_tags)) for token in doc]
    return encodings

   

  from .autonotebook import tqdm as notebook_tqdm


#### Generate input Edge_Index with the dependency parser tree (graph connections)

In [13]:
def accumulate_one_hot(edge_dict, edge, one_hot_attr):
    if edge not in edge_dict:
        edge_dict[edge] = one_hot_attr
    else:
        edge_dict[edge] = torch.logical_or(edge_dict[edge].bool(), one_hot_attr.bool()).float() # Accumulate one-hot vectors
    return edge_dict


def generateEdgeIndex(doc, input_rel_label2id, strategy=["dep", "seq"]):
    # Create the dependency graph between entities
    edge_index = []
    edge_attr = []
    unique_edges = []
    n_tokens = len(doc)
    edges_dict = {}
    #token_to_entity_idx = generateTokenToEntityIndex(entity_indices)
    if "dep" in strategy:  
        for token in doc:
            if token.head != token: #and token.i in token_to_entity_idx and token.head.i in token_to_entity_idx: # se o token e o head forem entidades
                edge = (token.i, token.head.i)
                #reverse = (token_to_entity_idx[token.head.i], token_to_entity_idx[token.i])
                one_hot_attr = one_hot_encode(input_rel_label2id[token.dep_], len(input_rel_label2id)) #
                if edge[0] != edge[1]:
                    edges_dict = accumulate_one_hot(edges_dict, edge, one_hot_attr)
            
    if "seq" in strategy:
        for i in range(n_tokens - 1):
            edge = (i, i + 1)
            one_hot_attr = one_hot_encode(input_rel_label2id["seq"], len(input_rel_label2id ))
            edges_dict = accumulate_one_hot(edges_dict, edge, one_hot_attr)

    if "full" in strategy:
        for i in range(n_tokens):
            for j in range(i + 1, n_tokens):
                edge = (i, j)
                one_hot_attr = one_hot_encode(input_rel_label2id["full"], len(input_rel_label2id))
                edges_dict = accumulate_one_hot(edges_dict, edge, one_hot_attr)
    
    edge_index = torch.tensor(list(edges_dict.keys()), dtype=torch.long).t().contiguous()
    edge_attr = torch.stack(list(edges_dict.values()))

    #print("Edge index:", edge_index.shape)
    return edge_index, edge_attr
    

### Generate output Output Edge_Index with LUSA relations

In [None]:
k = []
def generateTargetEdgeIndex(relations, node_id_mapping, relation_type_mapping):
    edge_indices = []
    edge_dict = {}
    edge_attr_text_dict = {} 
    global k
    edge_attributes = []
    edge_attr_text = []
    for rel in relations.values():
        #print(rel)
        if rel.subj in node_id_mapping and rel.obj in node_id_mapping:
            subj_idx = node_id_mapping[rel.subj]
            obj_idx = node_id_mapping[rel.obj]
            if subj_idx == obj_idx:
                print(" Self-loop detected")
                print(rel)
            edge_key = tuple(sorted([subj_idx, obj_idx]))
            #one_hot_attr = torch.tensor(one_hot_encode(relation_type_mapping[rel.type.split("_")[0]],len(relation_type_mapping))).clone().detach()
            one_hot_attr = torch.as_tensor(one_hot_encode(relation_type_mapping[rel.type.split("_")[0]], len(relation_type_mapping))).clone().detach() #rel.type.split("_")[0]

   
                # If the edge exists, sum the one-hot encoded attributes
            if edge_key in edge_dict:
                
                edge_dict[edge_key] = torch.logical_or(edge_dict[edge_key].bool(), one_hot_attr.bool()).float() # Accumulate one-hot vectors 
                edge_attr_text_dict[edge_key].add(rel.type.split("_")[0])  # Store unique text labels
            else:
                edge_dict[edge_key] = one_hot_attr
                edge_attr_text_dict[edge_key] = {rel.type.split("_")[0]}
        else:
            print(rel)
            print("ERROR: Relation refers to missing nodes")
    # Convert the dictionary into tensors
# Convert the dictionary into tensors
    if len(edge_dict) > 0:
        edge_indices_y = torch.tensor(list(edge_dict.keys()), dtype=torch.long).t().contiguous()
        edge_attr_y = torch.stack(list(edge_dict.values()))
    else:
        edge_indices_y = torch.empty((2, 0), dtype=torch.long)
        edge_attr_y = torch.empty((0, len(relation_type_mapping)), dtype=torch.float)

    edge_attr_text = ["|".join(sorted(types)) for types in edge_attr_text_dict.values()]
    
    return edge_indices_y, edge_attr_y, edge_attr_text

### Generate a input and output graph given a Lusa document

In [15]:
import json
def get_entity_preds(filename):
    entity_preds = json.load(open(filename))
    #one_hot_encode entity preds
    entity_preds_one_hot = []
    for batch in entity_preds:
        batch_preds = []
        for pred in batch:
            batch_preds.append(one_hot_encode(iob2id[pred], len(iob_labels)))
        entity_preds_one_hot.append(batch_preds)
    return entity_preds_one_hot

In [None]:
from torch_geometric.data import Data

# Create a mapping from token indices to entity indices
def generateTokenToEntityIndex(entity_indices):
    token_to_entity_idx = {}
    entity_sizes = {i: len(indices) for i, (indices, _, _) in enumerate(entity_indices)}
    discarded_entities = set()
    for entity_idx, (indices, type, idx) in enumerate(entity_indices):
        for token_idx in indices:
            #print(token_idx)
            # If token already assigned, keep the entity with more tokens
            if token_idx in token_to_entity_idx:
                current_entity = token_to_entity_idx[token_idx]
                if entity_sizes[entity_idx] > entity_sizes[current_entity]:
                    discarded_entities.add(entity_indices[current_entity][2])  # Track discarded entity
                    token_to_entity_idx[token_idx] = entity_idx
                else:
                    discarded_entities.add(idx)  # Track the smaller entity
            else:
                token_to_entity_idx[token_idx] = entity_idx
    return token_to_entity_idx, discarded_entities

def getGroupedTokens(doc, entity_indices):
    grouped_tokens = []
    visited_tokens = set()	
    indices_entity, discarded_entities = generateTokenToEntityIndex(entity_indices)
    for token in doc:
        if token.i not in indices_entity:
            grouped_tokens.append(([token.i], 0, "")) # out of entity index: 0
        else:
            entity_idx = indices_entity[token.i]
            if entity_idx not in visited_tokens:
                grouped_tokens.append(entity_indices[entity_idx])
                visited_tokens.add(entity_idx)
    return grouped_tokens, discarded_entities

def generateDataset(entities, relations, text, relation_type_mapping, input_rel_label2id):
    text = text.replace("\n", " ")
    text = text.replace("\xa0", " ")
    text = text.replace("´", "'")
    doc = nlp(text)

    entity_indices = [(get_spacy_token_indices(entity.span[0], doc, idx), label2id[entity.type], idx) for idx, entity in entities.items()]
    #print(entity_indices)
    
    labels, label_one_hot = generateEntityLabels(doc, entity_indices)
    tokens = [token.text for token in doc]
    encodings = generate_bert_encodings(doc,labels)

    iob_labels = [id2iob[label] for label in labels]

    edge_index, edge_attr = generateEdgeIndex(doc, input_rel_label2id)
 
    node_id_mapping = {node_id: indices[0] for idx, (indices,type,node_id) in enumerate(entity_indices)}
    y_edge_index, y_edge_attr, edge_attr_text = generateTargetEdgeIndex(relations, node_id_mapping, relation_type_mapping)


    data = Data(
        x=torch.zeros(len(tokens)), 
        edge_index=edge_index, edge_attr=edge_attr, 
        y=torch.tensor(labels), 
        encodings = encodings, 
        target_edge_index=y_edge_index, 
        y_edge_attr=y_edge_attr, 
        tokens= tokens,
        labels = torch.tensor(labels), 
        iob_labels = iob_labels, 
        label_one_hot = torch.tensor(label_one_hot),
    )
    return data

#### Generate list of graphs for each document of the Lusa Dataset

In [None]:
from torch_geometric.loader import DataLoader

relation_type_mapping = {'TLINK': 0, 'SRLINK': 1, 'QSLINK': 2, "OLINK": 3}#, "event-TLINK":4}

def generateDatasets(files):
    data_list = []
    entities_list, relations_list, texts_list = parse_brat(files)
    print(len(entities_list), len(relations_list), len(texts_list))
    

    print(relation_type_mapping)
    input_rel_label2id = {label: idx for idx, label in enumerate(list(nlp.get_pipe("parser").labels) + ["seq", "full", "entity"])}

    for i, entities in enumerate(entities_list):

        data = generateDataset(entities, relations_list[i], texts_list[i], relation_type_mapping,input_rel_label2id)
        data['file_name'] = files[i]

        data_list.append(data)

    return data_list

data_list_train = generateDatasets(files_train)
data_list_test = generateDatasets(files_test)
data_list_val = generateDatasets(files_val)


In [22]:
entity_preds = get_entity_preds("lusa_ner_predictions_test.json")
for data, preds in zip(data_list_test, entity_preds):
    data.entity_preds = torch.stack(preds)
entity_preds = get_entity_preds("lusa_ner_predictions_val.json")
for data, preds in zip(data_list_val, entity_preds):
    data.entity_preds = torch.stack(preds)


In [33]:
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

train_loader = DataLoader(data_list_train, batch_size=1, shuffle=False)
val_loader = DataLoader(data_list_val, batch_size=1, shuffle=False)
test_loader = DataLoader(data_list_test, batch_size=1, shuffle=False)

### Definning the Graph Auto-Encoder Model

In [None]:
import os
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv
from torch_geometric.nn import GATv2Conv
from torch.nn import BatchNorm1d
from transformers import BertModel, AutoConfig, BertTokenizer
from torch_geometric.nn import BatchNorm
import random
import torch
import tokenizations
import numpy as np

def set_all_seeds(seed=42):
    import os
    import random
    import numpy as np
    import torch

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Deterministic CuDNN
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.use_deterministic_algorithms(True)

    # Optional: environment variable for further determinism
    os.environ['PYTHONHASHSEED'] = str(seed)


set_all_seeds() 

class GNNForEdgePrediction(nn.Module):
    def __init__(self, input_dim, hidden_dim, config):
        super(GNNForEdgePrediction, self).__init__()
        decoder_size = 512
        dep_size = len(nlp.get_pipe("parser").labels) + 3
  
        self.num_entity_labels = len(iob_labels)
        num_relation_labels = len(relation_type_mapping)

        # Encoder Layers
        self.bert = BertModel(config, add_pooling_layer=False)
        self.conv1 = GATv2Conv(input_dim, hidden_dim, heads=4, concat=False, edge_dim=dep_size)
        self.conv2 = GATv2Conv(hidden_dim, hidden_dim, heads=4, concat=False,  edge_dim=dep_size)
        self.conv3 = GATv2Conv(hidden_dim, hidden_dim, heads=4, concat=False,  edge_dim=dep_size)

        # Latent Layers (GVAE)
        #self.conv_mu = TransformerConv(hidden_dim, hidden_dim, heads=4, concat=False,  edge_dim=dep_size)
        #self.conv_logstd = TransformerConv(hidden_dim, hidden_dim, heads=4, concat=False,  edge_dim=dep_size)


        # Decoder layers
        self.decoder_dense_1 = nn.Linear(hidden_dim*2 + self.num_entity_labels*2, decoder_size)
        self.decoder_dense_2 = nn.Linear(decoder_size, decoder_size)
        self.decoder_dense_3 = nn.Linear(decoder_size, decoder_size)
        self.label_pred_layer = nn.Linear(decoder_size, len(relation_type_mapping)) #srlinks and tlinks | qslinks
        self.entity_classifier = nn.Linear(hidden_dim, self.num_entity_labels)
        classifier_dropout = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.dropout = nn.Dropout(classifier_dropout)
  # Relation classifier dropout
        self.relation_dropout = nn.Dropout(classifier_dropout)

        # Fusion + final relation decoders
        fusion_input_dim = hidden_dim * 2 + 2 * self.num_entity_labels  # relation_feat + source_ent + target_ent
        self.fusion_layer = nn.Linear(fusion_input_dim, hidden_dim)
        self.relation_classifier_final = nn.Linear(hidden_dim, num_relation_labels)

        # Final GNN to refine entities using relations
        self.relation_to_entity_gnn = GATv2Conv(hidden_dim, hidden_dim, heads=4, concat=False,  edge_dim=hidden_dim * 2)
        self.entity_classifier_final = nn.Linear(hidden_dim, self.num_entity_labels)
        
    def encode(self, x, edge_index, edge_attr):

        h1 = self.conv1(x, edge_index, edge_attr)
        #h2 = self.conv2(h1, edge_index, edge_attr) + h1
        #h3 = self.conv3(h2, edge_index, edge_attr) + h2

        return h1

        return z_mu, z_logstd
  
    def preprocess_bert(self, data):
        tokens = data.tokens
        encodings = data.encodings

        bert_output = self.bert(encodings["input_ids"], attention_mask=encodings["attention_mask"], token_type_ids=encodings["token_type_ids"])
        bert_embeddings = bert_output.last_hidden_state#.squeeze()
        all_embeddings = []
        #print("BERT embeddings shape:", bert_embeddings.shape)
        all_embeddings = []
        for i, indices_list in enumerate(encodings['alignment']):
            z_embeddings = []
            for j, indices in enumerate(indices_list):
                if indices:
                    token_embedding = bert_embeddings[i, indices[0], :]
                else:
                    token_embedding = torch.zeros(bert_embeddings.shape[2]).to(bert_embeddings.device)

                # Concatenate POS one-hot vector
                pos_vec = encodings['pos_vectors'][j].to(token_embedding.device)
                token_embedding = torch.cat([token_embedding, pos_vec], dim=-1)

                z_embeddings.append(token_embedding)

            all_embeddings.append(torch.stack(z_embeddings))
        
        return torch.cat(all_embeddings, dim=0)


    def decode_joint(self, z, entity_logits_init):

        graph_z = z
        #print("Graph Z shape:", graph_z.shape)
        edge_indicesX = torch.triu_indices(graph_z.shape[0], graph_z.shape[0], offset=1) # calculate upper part of the matrix indexes
        #print("Edge indicesX shape:", edge_indicesX.shape)
        source_indices = edge_indicesX[0]
        #print(source_indices.shape)
        target_indices = edge_indicesX[1]

        source_features = graph_z[source_indices]
        target_features = graph_z[target_indices]

        entity_source = entity_logits_init[source_indices]
        entity_target = entity_logits_init[target_indices]

        source_features_joint = torch.cat([source_features, entity_source], dim=1)
        target_features_joint = torch.cat([target_features, entity_target], dim=1)
        #print(source_features.shape)
        edge_features_joint = torch.cat([source_features_joint, target_features_joint], axis=1)
        graph_inputs = torch.cat([source_features, target_features], axis=1)

        k = self.fusion_layer(edge_features_joint).relu()
        edge_logits_joint = self.relation_classifier_final(k)
        return  edge_logits_joint #edge_indicesX #graph_inputs
    
    
    
    
    def decode(self, z):

        graph_z = z

        edge_indicesX = torch.triu_indices(graph_z.shape[0], graph_z.shape[0], offset=1) # calculate upper part of the matrix indexes

        source_indices = edge_indicesX[0]

        target_indices = edge_indicesX[1]

        source_features = graph_z[source_indices]
        target_features = graph_z[target_indices]

        graph_inputs = torch.cat([source_features, target_features], axis=1)
        # Get predictions
        x = self.decoder_dense_1(graph_inputs).relu()
        x = self.decoder_dense_2(x).relu()

        x = self.relation_dropout(x)
        edge_logits = self.label_pred_layer(x)
        return  edge_logits #edge_indicesX #graph_inputs


        
    def reparameterize(self, mu, logstd):
        if self.training:
            #getting the std deviation
            std = torch.exp(logstd)
            #generate a random number from a normal distribution
            eps = torch.randn_like(std)
            #return the z value
            return eps.mul(std).add_(mu)
        else:
            #return the mean
            return mu

    def forward(self, data, pipeline=False):
        data = data.to(device)
        x , edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        x = self.preprocess_bert(data)
 

        z = self.encode(x, edge_index, edge_attr) #mu, logstd

        sequence_output = self.dropout(z)
        entity_logits_init = self.entity_classifier(sequence_output)
        #print("Entity logits init shape:", entity_logits_init.shape)
        #enitity logits to one hot

        if self.training:

            entity_features_detached = entity_logits_init.detach()
        else:

            if pipeline:
                entity_features_detached = data.entity_preds.to(device)
                print("using predicted entity labels")
            else:
                entity_features_detached = entity_logits_init.detach()

        
        z = torch.cat([z, entity_features_detached], dim=1) #data.label_one_hot # entity_one_hot
        

        edge_logits = self.decode(z)
        return entity_logits_init, edge_logits
        


In [None]:
model_save_path = "./models/__joint_best_pos55.pt"
input_dim = 768 + len(universal_pos_tags) #768 + 5 BERT // 300 + 5 spacy
hidden_dim = 768 + len(universal_pos_tags)#  512 spacy // 1024 BERT
model_checkpoint = "neuralmind/bert-base-portuguese-cased"  # or any other model checkpoint
config = AutoConfig.from_pretrained(model_checkpoint) # lfcc/lusa_events neuralmind/bert-base-portuguese-cased
config.num_labels = len(iob_labels)  # Set the number of labels for token classification
config.id2label = id2iob
config.label2id = iob2id
bert = BertModel.from_pretrained(
    model_checkpoint,
    config=config,
    add_pooling_layer=True  # Required to match pretrained checkpoint
)
model = GNNForEdgePrediction(input_dim, hidden_dim, config)
missing_keys, unexpected_keys = model.bert.load_state_dict(
    bert.state_dict(), strict=False
)
model.to(device)
model.load_state_dict(torch.load(model_save_path))


In [59]:
def mask_predicted_edges(edge_preds, edge_src_types, edge_tgt_types, ignore_O=True, ignore_I=True):

    masked_preds = edge_preds.copy()
    x = 0
    for i, (src, tgt) in enumerate(zip(edge_src_types, edge_tgt_types)):
        #print(src, tgt)
        if ignore_O and (id2iob[src] == 'O' or id2iob[tgt] == 'O'):
            x = x + 1
            masked_preds[i, :] = 0
            #print("masked O")
        if ignore_I and (id2iob[src].startswith('I') or id2iob[tgt].startswith('I')):
            masked_preds[i, :] = 0
            #print("masked I")
    print(f"Masked {x} edges")
    return masked_preds


import torch
import torch.nn.functional as F

def refine_entities_with_relations(entity_logits, edge_index, edge_logits, topk=1):

    num_nodes, num_entity_types = entity_logits.shape
    refined_entity_logits = entity_logits.clone()

    # 1. Compute edge probabilities
    edge_probs = torch.sigmoid(edge_logits)  # [num_edges, num_relation_types]
    # Keep only topk relations per edge
    topk_vals, topk_idx = torch.topk(edge_probs, k=topk, dim=-1)

    # 2. Message passing: propagate relation info to nodes
    for i, (src, tgt) in enumerate(edge_index.t()):
        for k in range(topk):
            rel_prob = topk_vals[i, k]
            # Heuristic: if edge exists with high probability, boost non-O predictions
            if rel_prob > 0.5:
                # Apply a small boost to entity logits of connected nodes
                refined_entity_logits[src] += 0.1 * rel_prob
                refined_entity_logits[tgt] += 0.1 * rel_prob

    return refined_entity_logits





def refine_entity_logits_with_relations(
    entity_logits,
    edge_probs,
    edge_indices,
    rel_threshold=0.5,
    penalty=5.0,
):
    """
    Suppress O-label for tokens involved in confident relations.
    """
    refined_logits = entity_logits.clone()

    # Identify edges with at least one confident relation
    confident_edges = (edge_probs > rel_threshold).any(dim=1)

    if confident_edges.sum() == 0:
        return refined_logits

    src = edge_indices[0][confident_edges]
    tgt = edge_indices[1][confident_edges]
    involved_nodes = torch.unique(torch.cat([src, tgt]))

    # Penalize O logit
    refined_logits[involved_nodes, label2id['O']] -= penalty

    return refined_logits

def refine_entity_logits_with_relations_(
    entity_logits: torch.Tensor,
    edge_probs: torch.Tensor,
    edge_indices: torch.Tensor,
    o_label_id: int,
    alpha: float = 0.5,
    rel_threshold: float = 0.5,
):
    refined_logits = entity_logits.clone()
    device = entity_logits.device
    edge_indices = edge_indices.to(device)
    edge_probs = edge_probs.to(device)

    # Identify confident edges
    confident_edges = (edge_probs > rel_threshold).any(dim=1)
    if confident_edges.sum() == 0:
        return refined_logits

    src = edge_indices[0][confident_edges]
    tgt = edge_indices[1][confident_edges]

    # Nodes involved in confident relations
    involved_nodes = torch.unique(torch.cat([src, tgt]))

    # Penalize O
    refined_logits[involved_nodes, o_label_id] -= alpha

    return refined_logits


def boost_label1(entity_logits: torch.Tensor, label1_id: int = 1, alpha: float = 0.5):
    """
    Simply increases the logit of label1 for all nodes.
    """
    refined_logits = entity_logits.clone()
    refined_logits[:, label1_id] += alpha
    return refined_logits


def refine_entity_logits_with_relations(
    entity_logits,
    edge_probs_joint,
    num_nodes,
    alpha=0.1
):
    refined_logits = entity_logits.clone()

    # Upper-triangular edge indexing (must match your edge construction)
    edge_indices = torch.triu_indices(num_nodes, num_nodes, offset=1)
    src_idx = edge_indices[0]
    tgt_idx = edge_indices[1]

    # Use max relation confidence per edge
    rel_strength = edge_probs_joint.max(dim=1).values  # [E]
    #print(edge_probs_joint.shape, rel_strength.shape)

    for e, strength in enumerate(rel_strength):
        if strength <= 0:
            continue

        i = src_idx[e]
        j = tgt_idx[e]
        refined_logits[i] += alpha #* strength
        refined_logits[j] += alpha #* strength

    return refined_logits


from collections import defaultdict
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score

def edge_type_statistics(edge_all_y, edge_all_preds, edge_src_gt, edge_tgt_gt, relation_type_mapping):

    inverted_relation_mapping = {v: k for k, v in relation_type_mapping.items()}
    
    # Group edges by (src_entity_type, tgt_entity_type, relation_type)
    metrics_by_group = defaultdict(list)
    
    # Determine the relation type for each edge (argmax for simplicity)
    edge_true_type = np.argmax(edge_all_y, axis=1)
    
    for i, (rel, src, tgt) in enumerate(zip(edge_true_type, edge_src_gt, edge_tgt_gt)):
        metrics_by_group[(src, tgt, rel)].append(i)
    
    # Compute F1, precision, recall per group
    stats = {}
    for (src, tgt, rel), indices in metrics_by_group.items():
        y_true = edge_all_y[indices]
        y_pred = edge_all_preds[indices]

        f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)
        precision = precision_score(y_true, y_pred, average='micro', zero_division=0)
        recall = recall_score(y_true, y_pred, average='micro', zero_division=0)

        stats[(src, tgt, inverted_relation_mapping[rel])] = {
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'support': y_true.sum(),
            'predicted': y_pred.sum()
        }
    
    return stats

from collections import defaultdict

def summarize_edge_errors(stats, inverted_relation_mapping, id2iob):
    """
    Summarize which source-target label types produce errors.
    stats: dict with keys (relation_id, src_node, tgt_node)
           and values {'tp', 'fp', 'fn', 'support', 'predicted'}
    """
    error_summary = defaultdict(lambda: {"fp":0, "fn":0, "support":0, "predicted":0})

    for (r, src, tgt), v in stats.items():
        rel_name = inverted_relation_mapping[r]
        src_prefix = id2iob[src].split('-')[0] if '-' in id2iob[src] else id2iob[src]
        tgt_prefix = id2iob[tgt].split('-')[0] if '-' in id2iob[tgt] else id2iob[tgt]

        key = (rel_name, src_prefix, tgt_prefix)

        error_summary[key]["fp"] += v["fp"]
        error_summary[key]["fn"] += v["fn"]
        error_summary[key]["support"] += v["support"]
        error_summary[key]["predicted"] += v["predicted"]

    # Print sorted by relation
    for rel in sorted(set(k[0] for k in error_summary.keys())):
        print(f"\n=== {rel} ===")
        for (r, src_p, tgt_p), v in sorted(error_summary.items()):
            if r != rel:
                continue
            print(f"{src_p} → {tgt_p} | FP={v['fp']} FN={v['fn']} Support={v['support']} Predicted={v['predicted']}")


In [None]:
import evaluate
metric = evaluate.load("seqeval")
from tqdm import tqdm
import numpy as np
from seqeval.metrics import classification_report
from sklearn.metrics import f1_score as sklearn_f1, precision_score as sklearn_p, recall_score as sklearn_r
from collections import defaultdict

def edge_index_2_adj_matrix_label(data):
    #print(data)
    matrix = torch.squeeze(to_dense_adj(data.target_edge_index, edge_attr=data.y_edge_attr, max_num_nodes=data.x.shape[0]))
    triu_indices = torch.triu_indices(matrix.shape[0], matrix.shape[0], offset=1)
    triu_mask = torch.squeeze(to_dense_adj(triu_indices)).bool()
    #print(edge_logits.squeeze().shape, batch_targets[triu_mask].shape)
    
    return matrix[triu_mask]

def evaluate_bert_4(model, eval_dataloader, device, pipeline=False,
                    node_type_mapping=None, relation_type_mapping=None):
    
    node_type_mapping = id2iob
    if relation_type_mapping is None:
        raise ValueError("relation_type_mapping must be provided")
    inverted_relation_mapping = {v: k for k, v in relation_type_mapping.items()}
    
    model.eval()
    all_preds = []
    all_labels = []

    edge_all_y = []
    edge_all_preds_joint = []
    edge_src_types = []
    edge_tgt_types = []
    print(relation_type_mapping)
    with torch.no_grad():
        for batch in tqdm(eval_dataloader):
            # Forward pass
            entity_logits, edge_logits_joint = model(batch, pipeline=pipeline)

        

            # Edge processing
            edge_y = edge_index_2_adj_matrix_label(batch)  # [num_edges, num_relation_types]
            edge_probs_joint = torch.sigmoid(edge_logits_joint)
            edge_preds_joint = (edge_probs_joint > 0.5).int()

            edge_all_y.extend(edge_y.cpu())
            edge_all_preds_joint.extend(edge_preds_joint.cpu())
            
            # Determine node types
            num_nodes = batch.x.shape[0]
            edge_indices = torch.triu_indices(num_nodes, num_nodes, offset=1)
            src_idx = edge_indices[0]
            tgt_idx = edge_indices[1]
            
                        # Entity predictions
                        
            #entity_logits_refined = refine_entity_logits_with_relations_(
            #    entity_logits,
            #    edge_probs_joint,
            #    edge_indices,
            #    o_label_id=iob2id['O'],
            #    alpha=1
            #)
            #entity_logits_refined = boost_label1(entity_logits, label1_id=1, alpha=10)

            predictions = torch.argmax(entity_logits, dim=-1)
            entity_labels = batch.labels  # ground-truth node labels



            # Use predicted or ground-truth entity types as node types
            if pipeline:
                node_type_ids = predictions#batch.labels
            else:
                node_type_ids = predictions#batch.labels
            #print(src_idx.shape, tgt_idx.shape, node_type_ids.shape)
            src_types = node_type_ids[src_idx].cpu().numpy()
            tgt_types = node_type_ids[tgt_idx].cpu().numpy()
            edge_src_types.extend(src_types)
            edge_tgt_types.extend(tgt_types)

            # Convert entity labels to IOB format
            entity_labels_iob = [id2iob[label.item()] for label in entity_labels]
            predictions_iob = [id2iob[pred.item()] for pred in predictions]
            all_preds.append(predictions_iob)
            all_labels.append(entity_labels_iob)

    # Convert edge lists to numpy
    edge_all_y = np.array(edge_all_y)
    edge_all_preds_joint = np.array(edge_all_preds_joint)
    edge_src_types = np.array(edge_src_types)
    edge_tgt_types = np.array(edge_tgt_types)
    print(edge_all_y.shape, edge_src_types.shape, edge_tgt_types.shape)

    # ---- Overall edge metrics ----
    #edge_all_preds_joint = mask_predicted_edges(edge_all_preds_joint, edge_src_types, edge_tgt_types, ignore_I=False, ignore_O=True)

    edge_joint_f1 = sklearn_f1(edge_all_y, edge_all_preds_joint, average="micro", zero_division=0)
    print(f"\nOverall Edge Joint F1: {edge_joint_f1:.4f}")


    # ---- Edge metrics per relation type and node-type pair with counts ----
    x = {}
    y = {}
    
    print("\n=== Edge metrics per (relation, src_type, tgt_type) ===")

    stats = defaultdict(lambda: {
        "tp": 0,
        "fp": 0,
        "fn": 0,
        "support": 0,
        "predicted": 0
    })

    num_edges = edge_all_y.shape[0]
    num_relations = edge_all_y.shape[1]


    m = {} # count edges between I or O
    for i in range(num_edges):
        src = edge_src_types[i]
        tgt = edge_tgt_types[i]
        src, tgt = tuple(sorted((src, tgt)))
        for r in range(num_relations):
            key = (r, src, tgt)

            y_t = edge_all_y[i, r]
            y_p = edge_all_preds_joint[i, r]

            if y_t == 1:
                stats[key]["support"] += 1
            if y_p == 1:
                stats[key]["predicted"] += 1

            
            if y_t == 1 and y_p == 1:
                stats[key]["tp"] += 1
            elif y_t == 0 and y_p == 1:
                stats[key]["fp"] += 1
            elif y_t == 1 and y_p == 0:
                stats[key]["fn"] += 1
                
                
    sorted_stats = sorted(
        stats.items(),
        key=lambda item: inverted_relation_mapping[item[0][0]]  # sort by rel_name
    )
                
    for (r, src, tgt), v in sorted_stats:
        tp, fp, fn = v["tp"], v["fp"], v["fn"]

        if v["support"] == 0 and v["predicted"] == 0:
            continue

        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = (
            2 * precision * recall / (precision + recall)
            if (precision + recall) > 0 else 0.0
        )

        rel_name = inverted_relation_mapping[r]
        src_name = node_type_mapping[src]
        tgt_name = node_type_mapping[tgt]

        print(
            f"{rel_name:15s}: "
            f"{src_name} → {tgt_name} | "
            f"P={precision:.4f} R={recall:.4f} F1={f1:.4f} | "
            f"Support={v['support']} Predicted={v['predicted']}"
        ) 
        x[rel_name] = x.get(rel_name, 0) + v['predicted']
        y[rel_name] = y.get(rel_name, 0) + v['support']
        if "I-" in src_name or "I-" in tgt_name or src_name == "O" or tgt_name == "O":
            m["I"] = m.get("I", 0) + v['predicted']     
    #---- Summarize edge errors ----
    print("\n=== Edge error summary ===")
    summarize_edge_errors(stats, inverted_relation_mapping, node_type_mapping)
    # ---- Entity metrics ----
    f1_entity = metric.compute(predictions=all_preds, references=all_labels)["overall_f1"]
    print("\nEntity metrics by label:")
    print(classification_report(all_labels, all_preds, digits=4))
    print(f"Entity F1 score: {f1_entity:.4f}")
    print(x)
    print(y)
    print(m)
    return f1_entity, edge_joint_f1


In [None]:
test_entity_f1, test_edge_f1 = evaluate_bert_4(model, test_loader, device, relation_type_mapping=relation_type_mapping)
print(f"Final Test Edge F1: {test_edge_f1:.4f}, Entity F1: {test_entity_f1:.4f}")


{'TLINK': 0, 'SRLINK': 1, 'QSLINK': 2, 'OLINK': 3}


100%|██████████| 19/19 [00:03<00:00,  5.57it/s]


(342474, 4) (342474,) (342474,)

Overall Edge Joint F1: 0.5576

=== Edge metrics per (relation, src_type, tgt_type) ===
OLINK          : O → O | P=0.0000 R=0.0000 F1=0.0000 | Support=3 Predicted=0
OLINK          : O → B-Participant | P=0.4444 R=0.0952 F1=0.1569 | Support=42 Predicted=9
OLINK          : B-Participant → B-Participant | P=0.4927 R=0.5488 F1=0.5192 | Support=246 Predicted=274
OLINK          : B-Participant → B-Spatial_Relation | P=0.4444 R=0.5714 F1=0.5000 | Support=14 Predicted=18
OLINK          : B-Participant → I-Participant | P=0.5455 R=0.1500 F1=0.2353 | Support=40 Predicted=11
OLINK          : I-Participant → I-Participant | P=0.0000 R=0.0000 F1=0.0000 | Support=10 Predicted=0
OLINK          : I-Participant → B-Spatial_Relation | P=0.0000 R=0.0000 F1=0.0000 | Support=2 Predicted=0
OLINK          : I-Event → B-Participant | P=0.0000 R=0.0000 F1=0.0000 | Support=1 Predicted=0
OLINK          : B-Participant → I-Spatial_Relation | P=0.3333 R=0.5000 F1=0.4000 | Support=2 