# Utils


## Configuration

In [None]:
from easydict import EasyDict
import json
import os
import numpy as np


def config():
    """Return the fixed data paths and hyper-parameters used by this notebook.

    Returns:
        EasyDict: every setting below, readable as an attribute
        (``args.lr``) or as a key (``args["lr"]``).
    """
    # NOTE(review): the "store_true" values look like argparse leftovers; being
    # non-empty strings they are truthy -- confirm how consumers test them.
    settings = {
        # dataset splits
        "train_set": "/home/hwiric/SagDRE-2024-07-05/data/DocRED/train_annotated.json",
        "dev_set": "/home/hwiric/SagDRE-2024-07-05/data/DocRED/dev.json",
        "test_set": "/home/hwiric/SagDRE-2024-07-05/data/DocRED/test.json",
        # checkpointing / model selection
        "checkpoint_dir": "checkpoint",
        "model_name": "SAGDRE_BERT_base",
        "pretrain_model": "",
        # label and entity inventory sizes
        "relation_nums": 97,
        "entity_type_num": 7,
        "max_entity_num": 80,
        # padding ids
        "word_pad": 0,
        "entity_type_pad": 0,
        "entity_id_pad": 0,
        # embedding sizes and feature switches
        "word_emb_size": 10,
        "use_entity_type": "store_true",
        "entity_type_size": 20,
        "use_entity_id": "store_true",
        "entity_id_size": 20,
        # recurrent encoder
        "nlayers": 1,
        "lstm_hidden_size": 32,
        "lstm_dropout": 0.4,
        # optimisation schedule
        "lr": 0.001,
        "batch_size": 2,
        "test_batch_size": 1,
        "epoch": 40,
        "test_epoch": 5,
        "weight_decay": 0.0001,
        "negativa_alpha": 4,
        "save_model_freq": 1,
        # graph layers
        "gcn_layers": 2,
        "gcn_dim": 128,
        "dropout": 0.6,
        "activation": "relu",
        # BERT encoder
        "bert_hid_size": 768,
        "coslr": "store_true",
        "use_model": "bert",
        "input_theta": 1.0,
    }
    return EasyDict(settings)


class Object:
    """Empty attribute holder; fields are attached to instances after creation."""


args = config()


def _read_json(path):
    """Parse the JSON file at ``path`` and close the handle afterwards."""
    # A context manager guarantees the descriptor is released even if parsing
    # fails; the previous inline ``json.load(open(...))`` calls never closed it.
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


# Look-up tables and pretrained vectors stored under ``data_opt.data_dir``.
data_opt = Object()
data_opt.data_dir = "/home/hwiric/SagDRE-2024-07-05/data/DocRED"
data_opt.rel2id = _read_json(os.path.join(data_opt.data_dir, "rel2id.json"))
# Inverse mapping: relation id -> relation name.
data_opt.id2rel = {v: k for k, v in data_opt.rel2id.items()}
data_opt.word2id = _read_json(os.path.join(data_opt.data_dir, "word2id.json"))
data_opt.ner2id = _read_json(os.path.join(data_opt.data_dir, "ner2id.json"))
data_opt.word2vec = np.load(os.path.join(data_opt.data_dir, "vec.npy"))

In [None]:
import json
import math
import random
from collections import defaultdict
import dgl
import numpy as np
import torch
from torch.utils.data import IterableDataset, DataLoader
from tqdm.std import tqdm

In [None]:
import os
from datetime import datetime
import numpy as np
import torch
import torch.nn.functional as F
import random
from torch import nn
import dgl.nn.pytorch as dglnn
from transformers import BertTokenizer


def get_cuda(tensor):
    """Return ``tensor.cuda()`` when CUDA is available, otherwise ``tensor`` itself."""
    return tensor.cuda() if torch.cuda.is_available() else tensor


class Bert:
    """Helper around a cased ``BertTokenizer`` producing fixed-length inputs.

    ``model_name == "SAGDRE_BERT_base"`` selects the ``bert-base-cased``
    vocabulary; any other name selects ``bert-large-cased``. Sequences are
    cut and padded to ``max_len`` (512) positions.
    """

    MASK = "[MASK]"
    CLS = "[CLS]"
    SEP = "[SEP]"

    def __init__(self, model_name):
        super().__init__()
        self.model_name = model_name
        if model_name == "SAGDRE_BERT_base":
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        else:
            self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased")
        self.max_len = 512

    def tokenize(self, text, masked_idxs=None):
        """Tokenize ``text`` and wrap the result in CLS ... SEP.

        Parameters
        ----------
        text: Raw string to split into subwords.
        masked_idxs: Optional iterable of subword positions (before CLS is
            prepended) to overwrite with the MASK symbol.
        Returns
        -------
        The list of subword strings including CLS and SEP.
        """
        tokenized_text = self.tokenizer.tokenize(text)
        if masked_idxs is not None:
            for idx in masked_idxs:
                tokenized_text[idx] = self.MASK
        return [self.CLS] + tokenized_text + [self.SEP]

    def tokenize_to_ids(self, text, masked_idxs=None, pad=True):
        """Return ``(tokens, ids)`` for ``text``; see convert_tokens_to_ids
        for the shape of ``ids`` depending on ``pad``."""
        tokens = self.tokenize(text, masked_idxs)
        return tokens, self.convert_tokens_to_ids(tokens, pad=pad)

    def convert_tokens_to_ids(self, tokens, pad=True):
        """Map subword strings to vocabulary ids, truncated to ``max_len``.

        Returns
        -------
        With ``pad=True``: a tuple ``(padded_ids, mask)``, both tensors of
        shape (1, max_len); ``mask`` is 1 on real tokens and 0 on padding.
        With ``pad=False``: the id tensor of shape (1, n) alone, where n is
        at most ``max_len``.
        """
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        ids = torch.tensor([token_ids])
        ids = ids[:, : self.max_len]
        if not pad:
            return ids
        # ``.to(ids)`` makes the zero buffers share dtype and device with ids.
        padded_ids = torch.zeros(1, self.max_len).to(ids)
        padded_ids[0, : ids.size(1)] = ids
        mask = torch.zeros(1, self.max_len).to(ids)
        mask[0, : ids.size(1)] = 1
        return padded_ids, mask

    def flatten(self, list_of_lists):
        """Yield the items of every inner iterable, in order."""
        for sublist in list_of_lists:
            yield from sublist

    def subword_tokenize(self, tokens):
        """Segment each token into subwords while keeping track of
        token boundaries.
        Parameters
        ----------
        tokens: A sequence of strings, representing input tokens.
        Returns
        -------
        A tuple consisting of:
            - A list of subwords, flanked by the special symbols required
                by Bert (CLS and SEP). At most ``max_len - 3`` subwords are
                kept between the two symbols.
            - An array of indices into the list of subwords, indicating
                that the corresponding subword is the start of a new
                token. For example, [1, 3, 4, 7] means that the subwords
                1, 3, 4, 7 are token starts, while all other subwords
                (0, 2, 5, 6, 8...) are in or at the end of tokens.
                Starts that fall beyond the kept subwords are replaced by
                ``max_len``, i.e. an index outside the padded sequence.
                This list allows selecting Bert hidden states that
                represent tokens, which is necessary in sequence
                labeling.
        """
        # With max_len == 512 this is the previously hard-coded 509.
        body_limit = self.max_len - 3
        pieces = list(map(self.tokenizer.tokenize, tokens))
        piece_lengths = list(map(len, pieces))
        subwords = [self.CLS] + list(self.flatten(pieces))[:body_limit] + [self.SEP]
        # +1 accounts for the leading CLS symbol.
        token_start_idxs = 1 + np.cumsum([0] + piece_lengths[:-1])
        token_start_idxs[token_start_idxs > body_limit] = self.max_len
        return subwords, token_start_idxs

    def subword_tokenize_to_ids(self, tokens):
        """Segment each token into subwords while keeping track of
        token boundaries and convert subwords into IDs.
        Parameters
        ----------
        tokens: A sequence of strings, representing input tokens.
        Returns
        -------
        A tuple consisting of:
            - A numpy array of shape (1, max_len) with the padded subword
                IDs, including IDs of the special symbols (CLS and SEP)
                required by Bert.
            - An array of indices into the list of subwords. See
                doc of subword_tokenize.
            - The list of subword strings the IDs were computed from.
        """
        subwords, token_start_idxs = self.subword_tokenize(tokens)
        # The padding mask is not part of this method's return value.
        subword_ids, _ = self.convert_tokens_to_ids(subwords)
        return subword_ids.numpy(), token_start_idxs, subwords

    def segment_ids(self, segment1_len, segment2_len):
        """Return a (1, segment1_len + segment2_len) tensor holding 0 for the
        first segment and 1 for the second."""
        ids = [0] * segment1_len + [1] * segment2_len
        return torch.tensor([ids])

In [None]:
import spacy
import dgl
from spacy.tokens import Doc
import networkx as nx
import matplotlib.pyplot as plt

# Module-level spaCy pipeline; build_g and add_same_words_links read it as a
# global.
# NOTE(review): presumably requires the "en_core_web_lg" model package to be
# installed locally -- confirm for new environments.
nlp = spacy.load("en_core_web_lg")
# nlp.tokenizer = nlp.tokenizer.tokens_from_list


def build_g(sentences, pos_idx, max_id):
    """Build the document graph and the node paths between entity pairs.

    Args:
        sentences: list of word lists, ``[[w_0, ..., w_n], ...]``, one inner
            list per sentence.
        pos_idx: mapping forwarded to ``add_entity_node`` (1-based entity id
            -> node ids to connect that entity to).
        max_id: number of entity nodes appended after the word nodes.

    Returns:
        Tuple ``(graph, paths)``: the networkx graph converted with
        ``dgl.from_networkx`` and the dict returned by ``get_entity_paths``.
    """
    nx_graph = nx.DiGraph()
    sentence_roots = []
    # Wrapping the given words in Doc objects hands spaCy the existing
    # tokenisation instead of raw text.
    prepared = [Doc(nlp.vocab, words=words) for words in sentences]
    for parsed in nlp.pipe(prepared):
        nx_graph, sentence_roots = parse_sent(parsed, nx_graph, sentence_roots)
    nx_graph, first_entity_node = add_entity_node(nx_graph, pos_idx, max_id)
    # Paths are computed on the networkx graph, before the DGL conversion.
    entity_paths = get_entity_paths(nx_graph, max_id, first_entity_node)
    return dgl.from_networkx(nx_graph), entity_paths


def parse_sent(tokens, g, pre_roots):
    """Add one parsed sentence to ``g`` and link it to earlier sentences.

    Every token becomes a node numbered consecutively after the nodes already
    in ``g``. Each token is connected to its syntactic head in both
    directions; the roots of the sentence are chained together; and the first
    root is connected to the roots recorded for earlier sentences.

    Args:
        tokens: iterable of parsed tokens exposing ``head``, ``text``,
            ``vector``, ``tag_``, ``pos_`` and ``dep_``.
        g: graph exposing ``nodes()``, ``add_node`` and ``add_edge``.
        pre_roots: node ids of the first root of each earlier sentence;
            extended in place.

    Returns:
        Tuple ``(g, pre_roots)`` after the sentence has been added.
    """
    roots = [token for token in tokens if token.head == token]
    start = len(g.nodes())
    dic = {}
    for offset, token in enumerate(tokens):
        node_id = start + offset
        g.add_node(
            node_id,
            text=token.text,
            vector=token.vector,
            is_root=token in roots,
            tag=token.tag_,
            pos=token.pos_,
            dep=token.dep_,
        )
        dic[token] = node_id

    # Dependency arcs, stored in both directions with the child's label.
    for token in tokens:
        g.add_edge(dic[token], dic[token.head], dep=token.dep_)
        g.add_edge(dic[token.head], dic[token], dep=token.dep_)

    # A sentence may parse into several roots; chain neighbouring roots.
    # These edges previously took their label from the loop variable leaked
    # by the loop above (the last token of the sentence); they now carry the
    # same explicit "rootconn" label as the other root-to-root links.
    for left, right in zip(roots, roots[1:]):
        g.add_edge(dic[left], dic[right], dep="rootconn")
        g.add_edge(dic[right], dic[left], dep="rootconn")

    if pre_roots:
        # Two-way link between this sentence and the one just before it.
        pre_root_idx = pre_roots[-1]
        for root in roots[:1]:
            g.add_edge(dic[root], pre_root_idx, dep="rootconn")
            g.add_edge(pre_root_idx, dic[root], dep="rootconn")
    # One-way links from every earlier sentence root to this sentence's root.
    for pre_root_idx in pre_roots:
        for root in roots[:1]:
            g.add_edge(pre_root_idx, dic[root], dep="rootconn")
    for root in roots[:1]:
        pre_roots.append(dic[root])
    return g, pre_roots


def add_same_words_links(g, remove_stopwords=True):
    """Link nodes that carry the same lower-cased ``text`` attribute.

    Each later occurrence of a word is connected, in both directions, to all
    earlier occurrences. Words shorter than five characters are skipped, and
    so are spaCy stop words when ``remove_stopwords`` is true.

    Args:
        g: networkx graph whose nodes have a ``text`` attribute.
        remove_stopwords: whether stop words are excluded.

    Returns:
        The same graph ``g`` with the extra edges added.
    """
    stop_words = nlp.Defaults.stop_words
    occurrences = {}
    for node_id, raw_text in nx.get_node_attributes(g, "text").items():
        word = raw_text.lower()
        if (remove_stopwords and word in stop_words) or len(word) < 5:
            continue
        earlier_nodes = occurrences.setdefault(word, [])
        for earlier in earlier_nodes:
            g.add_edge(node_id, earlier)
            g.add_edge(earlier, node_id)
        earlier_nodes.append(node_id)
    return g


def add_entity_node(g, pos_idx, max_id):
    """Append ``max_id`` entity nodes to ``g`` and wire them to their nodes.

    Args:
        g: graph exposing ``nodes()``, ``add_node`` and ``add_edge``.
        pos_idx: mapping from 1-based entity id to the node ids that entity
            is connected to; entities missing from it get an isolated node.
        max_id: number of entity nodes to create.

    Returns:
        Tuple ``(g, start)`` where ``start`` is the id of the first entity
        node (the node count of ``g`` before the call).
    """
    first_entity = len(g.nodes())
    for offset in range(max_id):
        entity_node = first_entity + offset
        g.add_node(entity_node, text="entity_%s" % offset, is_root=False)
        entity_id = offset + 1  # pos_idx keys are 1-based
        if entity_id in pos_idx:
            for word_node in pos_idx[entity_id]:
                g.add_edge(entity_node, word_node)
                g.add_edge(word_node, entity_node)
    return g, first_entity


def get_entity_paths(g, max_id, start):
    """Collect, for every ordered entity pair, the nodes along their path.

    For each pair of distinct entities the shortest path between their
    entity nodes is looked up. The path is widened with every neighbour of a
    path node whose ``pos`` attribute is ``"ADP"`` and stored as a sorted
    list of node ids.

    Args:
        g: networkx graph containing the word and entity nodes.
        max_id: number of entity nodes.
        start: node id of the first entity node.

    Returns:
        Dict keyed by 1-based ``(head_entity, tail_entity)``. Each value is a
        list holding one sorted node list, or an empty list when the two
        entity nodes are not connected.
    """
    node_pos = nx.get_node_attributes(g, "pos")
    ent_paths = {}
    for src in range(max_id):
        for dst in range(max_id):
            if src == dst:
                continue
            try:
                found = [nx.shortest_path(g, start + src, start + dst)]
            except nx.NetworkXNoPath:
                found = []

            widened = []
            for path in found[:3]:
                keep = set(path)
                for node in path:
                    # Only neighbours tagged "ADP" are pulled in.
                    keep.update(
                        neighbour
                        for neighbour in g.neighbors(node)
                        if node_pos.get(neighbour) == "ADP"
                    )
                widened.append(sorted(keep))

            ent_paths[(src + 1, dst + 1)] = widened[:3]
    return ent_paths

## Dataset

In [None]:
# Filler written into every relation_label slot before a batch is populated
# (DGLREDataloader.__iter__), so slots that receive no pair keep this value.
# NOTE(review): -100 presumably mirrors the default ``ignore_index`` of the
# PyTorch loss functions -- confirm against the training loop.
IGNORE_INDEX = -100


class BERTDGLREDataset(IterableDataset):
    """Documents from a DocRED-style JSON file, featurised on demand.

    Construction loads the JSON file, converts every label's relation name
    to its id in place and, for the training split, records the
    ``(head mention name, tail mention name, relation name)`` triples in
    ``instance_in_train``. Indexing (``dataset[i]``) runs ``process_doc`` on
    document ``i``.
    """

    def __init__(
        self,
        src_file,
        ner2id,
        rel2id,
        dataset="train",
        instance_in_train=None,
        model_name="SAGDRE_BERT_base",
    ):
        """
        Args:
            src_file: path of the JSON file holding the documents.
            ner2id: mapping from entity type name to integer id.
            rel2id: mapping from relation name to integer id.
            dataset: split name; ``"train"`` enables triple recording.
            instance_in_train: triples collected from the training split,
                shared with dev/test datasets; a new set when ``None``.
            model_name: forwarded to ``Bert`` to pick the tokenizer.
        """
        super(BERTDGLREDataset, self).__init__()

        if instance_in_train is None:
            self.instance_in_train = set()
        else:
            self.instance_in_train = instance_in_train
        self.data = None
        self.document_max_length = 512
        self.bert = Bert(model_name)
        self.dataset = dataset
        self.rel2id = rel2id
        self.ner2id = ner2id
        print("Reading data from {}.".format(src_file))

        self.create_data(src_file)
        self.get_instance_in_train()

    def get_instance_in_train(self):
        """Convert relation names to ids and record training mention triples.

        Mutates ``self.data``: after this call every ``label["r"]`` is an
        integer id. Raises ``KeyError`` for a relation missing from
        ``rel2id`` (which also happens if the method is run twice).
        """
        for doc in self.data:
            entity_list = doc["vertexSet"]
            for label in doc.get("labels", []):
                head, tail, relation = label["h"], label["t"], label["r"]
                label["r"] = self.rel2id[relation]
                if self.dataset != "train":
                    continue
                # Triples keep the relation *name*, not the id.
                for n1 in entity_list[head]:
                    for n2 in entity_list[tail]:
                        self.instance_in_train.add((n1["name"], n2["name"], relation))

    def _seen_in_train(self, head_mentions, tail_mentions, relation, id2rel=None):
        """Tell whether any mention pair of a fact was recorded in training.

        Args:
            head_mentions: mention dicts (with a ``name``) of the head entity.
            tail_mentions: mention dicts (with a ``name``) of the tail entity.
            relation: relation id or relation name of the fact.
            id2rel: optional id -> name mapping; derived from ``self.rel2id``
                when omitted.

        Returns:
            bool: True when some ``(head name, tail name, relation name)``
            triple is in ``self.instance_in_train``.
        """
        if id2rel is None:
            id2rel = {rel_id: name for name, rel_id in self.rel2id.items()}
        # Labels hold ids after get_instance_in_train, while the recorded
        # triples hold names; translate so the lookup can actually match.
        relation_name = id2rel.get(relation, relation)
        return any(
            (n1["name"], n2["name"], relation_name) in self.instance_in_train
            for n1 in head_mentions
            for n2 in tail_mentions
        )

    def process_doc(self, doc, dataset, ner2id, bert):
        """Turn one raw document into the feature dict used for batching.

        Args:
            doc: document dict with ``title``, ``vertexSet``, ``sents`` and
                optionally ``labels``; mention and label dicts are annotated
                in place (``sent_id``, ``global_pos``, ``in_train``).
            dataset: split name; ``in_train`` is only evaluated when it is
                not ``"train"``.
            ner2id: mapping from entity type name to integer id.
            bert: object providing ``subword_tokenize_to_ids``.

        Returns:
            dict with the keys ``title``, ``num_sent``, ``entities``,
            ``labels``, ``na_triple`` (unlabeled ordered entity pairs),
            ``word_id`` / ``pos_id`` / ``ner_id`` (int32 arrays of length
            ``document_max_length`` holding subword ids, 1-based entity ids
            and entity type ids), ``sub2word`` (averaging matrix from
            subwords to words followed by one row per entity),
            ``word2sent`` (averaging matrix from words to sentences),
            ``graph2`` and ``path2`` (outputs of ``build_g``).
        """
        title, entity_list = doc["title"], doc["vertexSet"]
        labels, sentences = doc.get("labels", []), doc["sents"]

        # Ls[i] is the document-level index of the first word of sentence i;
        # Ls[-1] is the total number of words.
        Ls = [0]
        for sentence in sentences:
            Ls.append(Ls[-1] + len(sentence))

        # Attach document-level word spans to every mention.
        for mentions in entity_list:
            for mention in mentions:
                sent_id = int(mention["sent_id"])
                mention["sent_id"] = sent_id
                dl = Ls[sent_id]
                pos0, pos1 = mention["pos"]
                mention["global_pos"] = (pos0 + dl, pos1 + dl)

        # generate positive examples
        id2rel = {rel_id: name for name, rel_id in self.rel2id.items()}
        train_triple = set()  # set: membership is tested for every entity pair
        new_labels = []
        for label in labels:
            head, tail, relation = label["h"], label["t"], label["r"]
            train_triple.add((head, tail))
            # mark dev/test facts whose mention triple occurred in training
            label["in_train"] = dataset != "train" and self._seen_in_train(
                entity_list[head], entity_list[tail], relation, id2rel
            )
            new_labels.append(label)

        # generate negative examples
        na_triple = [
            (j, k)
            for j in range(len(entity_list))
            for k in range(len(entity_list))
            if j != k and (j, k) not in train_triple
        ]

        # generate document ids
        words = [word for sentence in sentences for word in sentence]
        bert_token, bert_starts, bert_subwords = bert.subword_tokenize_to_ids(words)

        word_id = np.zeros((self.document_max_length,), dtype=np.int32)
        pos_id = np.zeros((self.document_max_length,), dtype=np.int32)
        ner_id = np.zeros((self.document_max_length,), dtype=np.int32)
        mention_id = np.zeros((self.document_max_length,), dtype=np.int32)
        word_id[:] = bert_token[0]

        entity2mention = defaultdict(list)
        mention_idx = 1
        already_exist = set()
        pos_idx = {}  # entity id -> word indices covered by its mentions
        ent_idx = {}  # entity id -> subword indices covered by its mentions
        for idx, vertex in enumerate(entity_list, 1):
            for v in vertex:
                ner_type = v["type"]
                pos0_w, pos1_w = v["global_pos"]

                # translate the word span into a subword span
                pos0 = bert_starts[pos0_w]
                if pos1_w < len(bert_starts):
                    pos1 = bert_starts[pos1_w]
                else:
                    pos1 = self.document_max_length

                if (pos0, pos1) in already_exist:
                    continue

                # mention starts beyond the truncated sequence
                if pos0 >= len(pos_id):
                    continue

                if idx not in pos_idx:
                    pos_idx[idx] = []
                    ent_idx[idx] = []

                pos_idx[idx].extend(range(pos0_w, pos1_w))
                ent_idx[idx].extend(range(pos0, pos1))
                pos_id[pos0:pos1] = idx
                ner_id[pos0:pos1] = ner2id[ner_type]
                mention_id[pos0:pos1] = mention_idx
                entity2mention[idx].append(mention_idx)
                mention_idx += 1
                already_exist.add((pos0, pos1))

        # ======================================================
        # compute subword to word index
        sub2word = np.zeros(
            (len(bert_starts) + len(entity_list), self.document_max_length)
        )
        for idx in range(len(bert_starts) - 1):
            start, end = bert_starts[idx], bert_starts[idx + 1]
            if start == end:
                continue
            sub2word[idx, start:end] = 1 / (end - start)
        start, end = bert_starts[-1], len(bert_subwords)
        sub2word[len(bert_starts) - 1, start:end] = 1 / (end - start)
        # compute conversion matrix for entity
        for idx, poss in ent_idx.items():
            if not poss:
                # every mention mapped to an empty subword span: leave the
                # row at zero instead of dividing by zero
                continue
            sub2word[len(bert_starts) + idx - 1, poss] = 1 / len(poss)
        # ======================================================
        # compute words to sent index
        word2sent = np.zeros((len(Ls) - 1, Ls[-1]))
        for i in range(1, len(Ls)):
            word2sent[i - 1, Ls[i - 1] : Ls[i]] = 1 / (Ls[i] - Ls[i - 1])
        # ======================================================

        # If the last entity ended up without any mention, give it the first
        # free position so that pos_id.max() still equals the entity count.
        idx = len(entity_list)
        if entity_list and entity2mention[idx] == []:
            entity2mention[idx].append(mention_idx)
            replace_i = 0
            while mention_id[replace_i] != 0:
                replace_i += 1
            mention_id[replace_i] = mention_idx
            pos_id[replace_i] = idx
            ner_id[replace_i] = ner2id[entity_list[-1][0]["type"]]
            mention_idx += 1

        graph2, path2 = build_g(sentences, pos_idx, pos_id.max())

        return {
            "title": title,
            "num_sent": len(doc["sents"]),
            "entities": entity_list,
            "labels": new_labels,
            "na_triple": na_triple,
            "word_id": word_id,
            "pos_id": pos_id,
            "ner_id": ner_id,
            "sub2word": sub2word,
            "word2sent": word2sent,
            "graph2": graph2,
            "path2": path2,
        }

    def create_data(self, src_file):
        """Load the JSON documents from ``src_file`` into ``self.data``."""
        with open(file=src_file, mode="r", encoding="utf-8") as fr:
            ori_data = json.load(fr)
        self.data = ori_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the processed feature dict of document ``idx``."""
        doc = self.data[idx]
        cur_d = self.process_doc(
            doc, dataset=self.dataset, ner2id=self.ner2id, bert=self.bert
        )
        return cur_d

    def __iter__(self):
        # NOTE(review): iterates the raw documents, not process_doc output,
        # unlike __getitem__ -- confirm this asymmetry is intended.
        return iter(self.data)

In [None]:
# Training split. Building it also fills train_set.instance_in_train with the
# mention triples found in the training labels.
train_set = BERTDGLREDataset(
    args.train_set,
    data_opt.ner2id,
    data_opt.rel2id,
    dataset="train",
    model_name=args.model_name,
)

# Dev split. It receives the training triples so that its labels can be
# flagged with "in_train" during processing.
dev_set = BERTDGLREDataset(
    args.dev_set,
    data_opt.ner2id,
    data_opt.rel2id,
    dataset="dev",
    instance_in_train=train_set.instance_in_train,
    model_name=args.model_name,
)

## DataLoader


In [None]:
import pickle


class DGLREDataloader(DataLoader):
    """Batch iterator over pre-processed documents read from a pickle cache.

    Iteration is implemented by ``__iter__`` below rather than by the
    ``DataLoader`` machinery; the examples come from ``self.data`` (the
    unpickled cache), while ``dataset`` is only used for its length.
    """

    def __init__(
        self,
        dataset,
        batch_size,
        shuffle=False,
        h_t_limit=1722,
        relation_num=97,
        max_length=512,
        negativa_alpha=0.0,
        dataset_type="train",
    ):
        """
        Args:
            dataset: dataset whose length defines the number of examples.
            batch_size: documents per batch.
            shuffle: reshuffle the document order on every iteration.
            h_t_limit: number of entity-pair slots per document.
            relation_num: size of the relation label vector.
            max_length: number of token positions per document.
            negativa_alpha: when > 0, the number of sampled negative pairs is
                ``max(20, positives * negativa_alpha)``; otherwise all
                unlabeled pairs are used.
            dataset_type: ``"train"`` builds labeled training batches; any
                other value enumerates all ordered entity pairs. Also selects
                the ``<dataset_type>_data.pkl`` cache file.
        """
        super(DGLREDataloader, self).__init__(
            dataset, batch_size=batch_size, num_workers=4
        )
        self.shuffle = shuffle
        # NOTE(review): the length comes from ``dataset`` while examples come
        # from the pickle; the two are assumed to describe the same documents.
        self.length = len(self.dataset)
        self.max_length = max_length
        self.negativa_alpha = negativa_alpha
        self.dataset_type = dataset_type
        self.h_t_limit = h_t_limit
        self.relation_num = relation_num
        self.order = list(range(self.length))
        self.data = []
        self.boosted = 0  # running count of re-sampled wrong predictions
        # The cache presumably stems from the loop below (TODO confirm):
        # for idx in tqdm(range(self.length)):
        #     self.data.append(self.dataset[idx])
        #     self.data[idx]["idx"] = idx
        # SECURITY: pickle.load executes arbitrary code embedded in the file;
        # only load caches that were produced locally.
        with open(
            f"/home/hwiric/SagDRE-2024-07-05/data/DocRED/{self.dataset_type}_data.pkl",
            "rb",
        ) as handle:
            self.data = pickle.load(handle)

    @staticmethod
    def _negative_budget(num_na, num_positive, num_hard, negativa_alpha):
        """Number of ordinary negative pairs to add for one document.

        Starts from all ``num_na`` unlabeled pairs, or from
        ``max(20, num_positive * negativa_alpha)`` when sampling is enabled,
        and subtracts the ``num_hard`` re-sampled wrong predictions. The
        result is clamped at 0: a negative value would make the
        ``na_triple[:budget]`` slice count from the end of the list.
        """
        budget = num_na
        if negativa_alpha > 0.0:
            budget = int(max(20, num_positive * negativa_alpha))
        return max(0, budget - num_hard)

    def __iter__(self):
        """Yield one dict of CPU tensors and Python lists per batch.

        The tensor buffers are allocated once and reset for every batch, so
        some yielded tensors are views into shared storage; consume or copy
        a batch before requesting the next one.
        """
        if self.shuffle:
            random.shuffle(self.order)
        batch_num = math.ceil(self.length / self.batch_size)
        self.batches = [
            (idx * self.batch_size, min(self.length, (idx + 1) * self.batch_size))
            for idx in range(0, batch_num)
        ]
        self.batches_order = [self.order[start:end] for start, end in self.batches]

        # Reusable buffers sized for a full batch.
        bsz, n_pairs = self.batch_size, self.h_t_limit
        context_word_ids = torch.zeros(bsz, self.max_length, dtype=torch.long)
        context_pos_ids = torch.zeros(bsz, self.max_length, dtype=torch.long)
        context_ner_ids = torch.zeros(bsz, self.max_length, dtype=torch.long)
        ht_pairs = torch.zeros(bsz, n_pairs, 2, dtype=torch.long)
        relation_multi_label = torch.zeros(bsz, n_pairs, self.relation_num)
        relation_mask = torch.zeros(bsz, n_pairs)
        relation_example_idx = torch.zeros(bsz, n_pairs, dtype=torch.long)
        relation_label = torch.zeros(bsz, n_pairs, dtype=torch.long)
        buffers = (
            context_word_ids,
            context_pos_ids,
            context_ner_ids,
            ht_pairs,
            relation_multi_label,
            relation_mask,
            relation_label,
            relation_example_idx,
        )

        for idx, (batch_s, batch_e) in enumerate(self.batches):
            minibatch = [self.data[doc_idx] for doc_idx in self.order[batch_s:batch_e]]
            cur_bsz = len(minibatch)

            for buffer in buffers:
                buffer.zero_()
            relation_label.fill_(IGNORE_INDEX)

            max_h_t_cnt = 0

            label_list = []
            L_vertex = []
            titles = []
            indexes = []
            graphs = []
            path2_table = []
            sub2word_list = []
            word2sent_list = []

            for i, example in enumerate(minibatch):
                entities, labels, na_triple, word_id, pos_id, ner_id = (
                    example["entities"],
                    example["labels"],
                    example["na_triple"],
                    example["word_id"],
                    example["pos_id"],
                    example["ner_id"],
                )
                graphs.append(dgl.add_self_loop(example["graph2"]))
                path2_table.append(example["path2"])

                # pairs mispredicted earlier, stored by feedback()
                prewrong = example.get("wrong_predits", [])

                sub2word_list.append(torch.Tensor(example["sub2word"]))
                word2sent_list.append(torch.Tensor(example["word2sent"]))
                L = len(entities)
                word_num = word_id.shape[0]

                context_word_ids[i, :word_num].copy_(torch.from_numpy(word_id))
                context_pos_ids[i, :word_num].copy_(torch.from_numpy(pos_id))
                context_ner_ids[i, :word_num].copy_(torch.from_numpy(ner_id))

                # group the relation ids of each (head, tail) pair
                idx2label = defaultdict(list)
                label_set = {}
                for label in labels:
                    head, tail, relation, intrain = (
                        label["h"],
                        label["t"],
                        label["r"],
                        label["in_train"],
                    )
                    idx2label[(head, tail)].append(relation)
                    label_set[(head, tail, relation)] = intrain

                label_list.append(label_set)

                if self.dataset_type == "train":
                    # 1) labeled pairs
                    positive_pairs = list(idx2label.keys())
                    for j, (h_idx, t_idx) in enumerate(positive_pairs):
                        # entity ids are 1-based in the tensors
                        ht_pairs[i, j, :] = torch.Tensor([h_idx + 1, t_idx + 1])
                        for r in idx2label[(h_idx, t_idx)]:
                            relation_multi_label[i, j, r] = 1

                        relation_mask[i, j] = 1
                        relation_label[i, j] = 1
                        relation_example_idx[i, j] = example["idx"]

                    # 2) a tenth of the earlier wrong predictions, capped by
                    #    the number of labeled pairs, re-added with label 0
                    # NOTE(review): feedback() also records mispredicted
                    # *labeled* pairs, which would then be re-added here as
                    # negatives -- confirm whether that is intended.
                    to_sample = min(len(positive_pairs), int(len(prewrong) * 0.1))
                    na_train_triple = random.sample(prewrong, to_sample)
                    self.boosted += to_sample
                    for j, (h_idx, t_idx) in enumerate(
                        na_train_triple, len(positive_pairs)
                    ):
                        ht_pairs[i, j, :] = torch.Tensor([h_idx + 1, t_idx + 1])
                        relation_multi_label[i, j, 0] = 1
                        relation_label[i, j] = 0
                        relation_mask[i, j] = 1
                        relation_example_idx[i, j] = example["idx"]

                    # 3) ordinary negatives (shuffled in place when sampling)
                    if self.negativa_alpha > 0.0:
                        random.shuffle(na_triple)
                    lower_bound = self._negative_budget(
                        len(na_triple),
                        len(positive_pairs),
                        len(na_train_triple),
                        self.negativa_alpha,
                    )

                    for j, (h_idx, t_idx) in enumerate(
                        na_triple[:lower_bound],
                        len(positive_pairs) + len(na_train_triple),
                    ):
                        ht_pairs[i, j, :] = torch.Tensor([h_idx + 1, t_idx + 1])
                        relation_multi_label[i, j, 0] = 1
                        relation_label[i, j] = 0
                        relation_mask[i, j] = 1
                        relation_example_idx[i, j] = example["idx"]

                    max_h_t_cnt = max(
                        max_h_t_cnt,
                        len(positive_pairs) + lower_bound + len(na_train_triple),
                    )
                else:
                    # evaluation: every ordered pair of distinct entities
                    j = 0
                    for h_idx in range(L):
                        for t_idx in range(L):
                            if h_idx != t_idx:
                                ht_pairs[i, j, :] = torch.Tensor([h_idx + 1, t_idx + 1])
                                relation_mask[i, j] = 1
                                relation_example_idx[i, j] = example["idx"]
                                j += 1

                    max_h_t_cnt = max(max_h_t_cnt, j)
                    L_vertex.append(L)
                    indexes.append(self.batches_order[idx][i])
                titles.append(example["title"])

            # positions holding a non-zero token id count as real tokens
            context_word_mask = context_word_ids > 0
            context_word_length = context_word_mask.sum(1)
            batch_max_length = context_word_length.max()
            sub2word_list = [sw[:, :batch_max_length] for sw in sub2word_list]

            # Everything stays on the CPU; get_cuda can move tensors later.
            yield {
                "context_idxs": context_word_ids[
                    :cur_bsz, :batch_max_length
                ].contiguous(),
                "context_pos": context_pos_ids[
                    :cur_bsz, :batch_max_length
                ].contiguous(),
                "context_ner": context_ner_ids[
                    :cur_bsz, :batch_max_length
                ].contiguous(),
                "context_word_mask": context_word_mask[
                    :cur_bsz, :batch_max_length
                ].contiguous(),
                "context_word_length": context_word_length[:cur_bsz].contiguous(),
                "h_t_pairs": ht_pairs[:cur_bsz, :max_h_t_cnt, :2],
                "relation_label": relation_label[:cur_bsz, :max_h_t_cnt].contiguous(),
                "relation_multi_label": relation_multi_label[:cur_bsz, :max_h_t_cnt],
                "relation_mask": relation_mask[:cur_bsz, :max_h_t_cnt],
                "relation_example_idx": relation_example_idx[:cur_bsz, :max_h_t_cnt],
                "labels": label_list,
                "graph2s": graphs,
                "sub2words": sub2word_list,
                "word2sents": word2sent_list,
                "path2_table": path2_table,
                "L_vertex": L_vertex,
                "titles": titles,
                "indexes": indexes,
            }

    def feedback(self, m_preds, m_label, r_mask, h_t_pairs, relation_example_idx):
        """Remember, per document, the entity pairs that were mispredicted.

        A pair counts as wrong when it is active in ``r_mask`` and the
        arg-max class of ``m_preds`` is not set in ``m_label``. The 0-based
        ``(head, tail)`` pairs are stored under ``"wrong_predits"`` of the
        matching entry in ``self.data``, replacing any earlier list.

        Args:
            m_preds: scores, indexed ``[doc, pair, class]``.
            m_label: multi-hot labels, indexed ``[doc, pair, class]``.
            r_mask: indexed ``[doc, pair]``; 1 marks a real pair.
            h_t_pairs: 1-based entity ids, indexed ``[doc, pair, 0/1]``.
            relation_example_idx: document index, indexed ``[doc, pair]``.
        """
        predicted = torch.argmax(m_preds, dim=-1).data.cpu().numpy()
        m_label = m_label.data.cpu().numpy()
        r_mask = r_mask.data.cpu().numpy()
        h_t_pairs = h_t_pairs.data.cpu().numpy()
        wrong_predits = defaultdict(list)
        for i, mask_row in enumerate(r_mask):
            for j, active in enumerate(mask_row):
                if active != 1 or m_label[i, j, predicted[i, j]] != 0:
                    continue
                doc_idx = int(relation_example_idx[i, j])
                wrong_predits[doc_idx].append(
                    (int(h_t_pairs[i, j, 0]) - 1, int(h_t_pairs[i, j, 1]) - 1)
                )
        for doc_idx, pairs in wrong_predits.items():
            self.data[doc_idx]["wrong_predits"] = pairs

In [None]:
# Training batches: the document order is reshuffled on every iteration and
# negativa_alpha from the config sets the negative-sampling budget.
train_loader = DGLREDataloader(
    train_set,
    batch_size=args.batch_size,
    shuffle=True,
    negativa_alpha=args.negativa_alpha,
)

# Dev batches: dataset_type="dev" makes the loader enumerate every ordered
# entity pair and read the dev pickle cache.
dev_loader = DGLREDataloader(
    dev_set, batch_size=args.test_batch_size, dataset_type="dev"
)

## Model

In [None]:
import os
from datetime import datetime
import numpy as np
import torch
import torch.nn.functional as F
import random
from torch import nn
import dgl.nn.pytorch as dglnn


class GCNLayer(nn.Module):
    """One graph-convolution step: dropout, aggregation with ``A``, linear projection.

    Args:
        in_dim: Size of each input feature vector.
        out_dim: Size of each output feature vector.
        activation: Optional callable applied to the projected features.
        dropout: Dropout probability for the input features; when it is not
            positive an identity module is used instead.
    """

    def __init__(self, in_dim, out_dim, activation=None, dropout=0.0):
        super().__init__()
        self.proj = nn.Linear(in_dim, out_dim)
        self.act = activation
        if dropout > 0.0:
            self.drop = nn.Dropout(p=dropout)
        else:
            self.drop = nn.Identity()

    def forward(self, A, X):
        """Return ``proj(A @ drop(X))``, passed through the activation if one is set.

        Args:
            A: Matrix that left-multiplies the features (in this file, the
                normalised adjacency produced by ``filter_g``).
            X: Node features whose last dimension is ``in_dim``.
        """
        aggregated = torch.matmul(A, self.drop(X))
        projected = self.proj(aggregated)
        if self.act:
            return self.act(projected)
        return projected


class AttnLayer(nn.Module):
    """Eight-head attention over 2-D inputs, with a learned projection of the values.

    Args:
        in_feats: Feature size of query, key and value (the embedding
            dimension of the attention module).
        activation: Optional callable applied to the attention output.
        dropout: Dropout probability passed to ``nn.MultiheadAttention``.
    """

    def __init__(self, in_feats, activation=None, dropout=0.0):
        super().__init__()
        self.attn = nn.MultiheadAttention(in_feats, 8, dropout=dropout)
        self.activation = activation
        self.v_proj = nn.Linear(in_feats, in_feats)

    def forward(self, query, key, value):
        """Attend from ``query`` to ``key`` / projected ``value`` and return the result.

        Each input gets a singleton axis inserted at position 1 before the
        attention call, and that axis is removed from the output again.
        """
        projected_value = self.v_proj(value)
        attended = self.attn(
            query.unsqueeze(1),
            key.unsqueeze(1),
            projected_value.unsqueeze(1),
            need_weights=False,
        )[0]
        attended = attended.squeeze(1)
        if not self.activation:
            return attended
        return self.activation(attended)


def norm_g(g):
    """Scale a square matrix by its row sums, guarding against empty rows.

    ``degrees`` is the vector of row sums.  Because it is broadcast along the
    last axis, entry ``g[i, j]`` is divided by the sum of row ``j``.
    NOTE(review): this is not a per-row normalisation; the broadcasting is
    kept exactly as before — confirm it is the intended scheme.

    Rows that sum to zero would produce NaN/inf through division by zero, so
    their divisor is replaced by 1 and the affected entries stay unscaled.

    Args:
        g: 2-D tensor (adjacency-like matrix).

    Returns:
        Tensor of the same shape as ``g`` holding the scaled values.
    """
    degrees = torch.sum(g, 1)
    # Only zero sums are replaced; every other divisor is unchanged.
    safe_degrees = degrees.masked_fill(degrees == 0, 1)
    return g / safe_degrees


def filter_g(g, features):
    """Build a similarity-filtered, normalised dense adjacency for graph ``g``.

    Steps:
      1. Compute the cosine similarity between every pair of rows of
         ``features``.
      2. Take the dense adjacency of ``g`` and mark the positions where an
         edge and its reverse are both present (``adj + adj.T == 2``;
         presumably the entries are 0/1 — verify against the graph builder).
      3. Keep the remaining one-directional edges only where the similarity
         of their endpoints exceeds 0.6.
      4. Add both groups together and normalise the result with ``norm_g``.

    Args:
        g: Graph object exposing ``adj().to_dense()``.
        features: 2-D tensor with one feature row per node.
    """
    similarity = F.cosine_similarity(
        features.unsqueeze(2), features.t().unsqueeze(0), dim=1, eps=1e-6
    )
    dense_adj = g.adj().to_dense()
    mutual = (dense_adj + dense_adj.t()).eq(2).float()
    one_way = dense_adj - mutual
    kept_one_way = one_way * (similarity > 0.6)
    return norm_g(kept_one_way + mutual)

In [None]:
from collections import defaultdict
import torch
import torch.nn as nn
from transformers import BertModel


class SAGDRE_BERT(nn.Module):
    """Relation classifier over entity pairs built on BERT, graph layers and path attention.

    The model concatenates BERT token states with entity-type and entity-id
    embeddings, pools them to graph nodes, refines the nodes with a stack of
    ``GCNLayer`` / ``AttnLayer`` blocks, and classifies every (head, tail)
    entity pair from the two entity vectors plus an attention summary of the
    LSTM-encoded paths between them.

    Args:
        config: Configuration object.  The constructor reads
            ``entity_type_num``, ``entity_type_size``, ``entity_type_pad``,
            ``max_entity_num``, ``entity_id_size``, ``entity_id_pad``,
            ``model_name``, ``bert_hid_size``, ``gcn_dim``, ``gcn_layers``,
            ``dropout`` and ``relation_nums``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.activation = nn.ReLU()
        self.entity_type_emb = nn.Embedding(
            config.entity_type_num,
            config.entity_type_size,
            padding_idx=config.entity_type_pad,
        )
        self.entity_id_emb = nn.Embedding(
            config.max_entity_num + 1,
            config.entity_id_size,
            padding_idx=config.entity_id_pad,
        )

        # Any model name other than "SAGDRE_BERT_base" selects the large encoder.
        if config.model_name == "SAGDRE_BERT_base":
            self.bert = BertModel.from_pretrained("bert-base-cased", return_dict=False)
        else:
            self.bert = BertModel.from_pretrained("bert-large-cased", return_dict=False)

        # The entity-type and entity-id embeddings are always concatenated to
        # the BERT states, independent of any `use_entity_*` config flag.
        self.start_dim = config.bert_hid_size
        self.start_dim += config.entity_type_size + config.entity_id_size

        self.gcn_dim = config.gcn_dim

        self.start_gcn = GCNLayer(self.start_dim, self.gcn_dim)

        self.GCNs = nn.ModuleList(
            [
                GCNLayer(self.gcn_dim, self.gcn_dim, activation=self.activation)
                for _ in range(config.gcn_layers)
            ]
        )

        self.Attns = nn.ModuleList(
            [
                AttnLayer(self.gcn_dim, activation=self.activation)
                for _ in range(config.gcn_layers)
            ]
        )

        # Node representation = input features + output of start_gcn + output
        # of each of the `gcn_layers` blocks, concatenated.
        self.bank_size = self.start_dim + self.gcn_dim * (self.config.gcn_layers + 1)
        self.dropout = nn.Dropout(self.config.dropout)

        self.rnn = nn.LSTM(
            self.bank_size, self.bank_size, 2, bidirectional=False, batch_first=True
        )

        self.path_attn = nn.MultiheadAttention(self.bank_size, 4)

        # Five concatenated blocks: head, tail, |head - tail|, head * tail, path info.
        self.predict2 = nn.Sequential(
            nn.Linear(self.bank_size * 5, self.bank_size * 5),
            self.activation,
            self.dropout,
        )

        self.out_linear = nn.Linear(self.bank_size * 5, config.relation_nums)
        self.out_linear_binary = nn.Linear(self.bank_size * 5, 2)

    def forward(self, **params):
        """Score every (head, tail) pair of the batch.

        Keyword Args:
            words: Token ids fed to BERT.
            mask: Attention mask fed to BERT.
            entity_type: Per-token entity-type ids.
            entity_id: Per-token entity ids; the per-document maximum is used
                as the number of entity nodes.
            graph2s: One graph per document (passed to ``filter_g``).
            sub2words: Per-document matrix multiplied with the token states to
                obtain node features.
            h_t_pairs: Entity-id pairs indexed ``[batch, pair, 0|1]``; ids are
                shifted to 0-based here, with 0 (padding) mapped to index 0.
            relation_mask: Per-pair mask, or ``None``.  When given, the first
                zero entry ends path collection for that document.
            path2_table: Per-document mapping from an ``(head, tail)`` id pair
                to a list of node-index paths.

        Returns:
            Tuple ``(m_preds, b_preds, None)``: relation scores with
            ``config.relation_nums`` classes, binary scores with 2 classes,
            and a ``None`` placeholder.
        """
        words = params["words"]
        mask = params["mask"]
        bsz = words.size(0)

        encoder_outputs, _ = self.bert(input_ids=words, attention_mask=mask)
        encoder_outputs = torch.cat(
            [
                encoder_outputs,
                self.entity_type_emb(params["entity_type"]),
                self.entity_id_emb(params["entity_id"]),
            ],
            dim=-1,
        )
        # Scratch tensors are allocated on the encoder's device instead of
        # being pinned to the CPU.
        device = encoder_outputs.device

        graphs = params["graph2s"]
        sub2words = params["sub2words"]
        features = []

        for i, graph in enumerate(graphs):
            x = torch.mm(sub2words[i], encoder_outputs[i])
            graph = filter_g(graph, x)
            xs = [x]
            x = self.start_gcn(graph, x)
            xs.append(x)
            for GCN, Attn in zip(self.GCNs, self.Attns):
                x1 = GCN(graph, x)
                x2 = Attn(x, x1, x1)
                x = x1 + x2
                xs.append(x)
            features.append(torch.cat(xs, dim=-1))

        h_t_pairs = params["h_t_pairs"]
        # Shift ids to 0-based; padding id 0 stays at 0 instead of becoming -1.
        h_t_pairs = h_t_pairs + (h_t_pairs == 0).long() - 1
        h_t_limit = h_t_pairs.size(1)
        path_info = torch.zeros((bsz, h_t_limit, self.bank_size), device=device)
        rel_mask = params["relation_mask"]
        path_table = params["path2_table"]

        # Paths are grouped by length so that each group can be stacked and
        # run through the LSTM as one batch.
        path_len_dict = defaultdict(list)

        entity_num = int(torch.max(params["entity_id"]).item())
        # Zero-initialised: `torch.Tensor(sizes)` would return uninitialised
        # memory for the rows of documents with fewer entities.
        entity_bank = torch.zeros((bsz, entity_num, self.bank_size), device=device)

        for i in range(len(graphs)):
            max_id = int(torch.max(params["entity_id"][i]).item())
            # The last `max_id` node rows are taken as the entity nodes.
            # `[-0:]` would select every row, so an entity-free document is
            # handled explicitly.
            entity_feas = features[i][-max_id:] if max_id > 0 else features[i][:0]
            entity_bank[i, : entity_feas.size(0)] = entity_feas
            path_t = path_table[i]
            for j in range(h_t_limit):
                h_ent = h_t_pairs[i, j, 0].item()
                t_ent = h_t_pairs[i, j, 1].item()

                if rel_mask is not None and rel_mask[i, j].item() == 0:
                    break

                if rel_mask is None and h_ent == 0 and t_ent == 0:
                    continue

                # The path table is keyed by the original (1-based) ids.
                paths = path_t[(h_ent + 1, t_ent + 1)]
                for path in paths:
                    path = torch.as_tensor(path, dtype=torch.long, device=device)
                    cur_h = torch.index_select(features[i], 0, path)
                    path_len_dict[len(path)].append((i, j, cur_h))

        h_ent_idx = h_t_pairs[:, :, 0].unsqueeze(-1).expand(-1, -1, self.bank_size)
        t_ent_idx = h_t_pairs[:, :, 1].unsqueeze(-1).expand(-1, -1, self.bank_size)
        h_ent_feas = torch.gather(input=entity_bank, dim=1, index=h_ent_idx)
        t_ent_feas = torch.gather(input=entity_bank, dim=1, index=t_ent_idx)

        path_embedding = {}

        for items in path_len_dict.values():
            cur_hs = torch.stack([h for _, _, h in items], 0)
            cur_hs2, _ = self.rnn(cur_hs)
            # Max-pool the LSTM outputs over the path positions.
            cur_hs = cur_hs2.max(1)[0]
            for idx, (i, j, _) in enumerate(items):
                path_embedding.setdefault((i, j), []).append(cur_hs[idx])

        querys = h_ent_feas - t_ent_feas

        for (i, j), emb in path_embedding.items():
            # Query: (1, 1, bank_size); keys/values: (num_paths, 1, bank_size).
            query = querys[i : i + 1, j : j + 1]
            keys = torch.stack(emb).unsqueeze(1)
            output, _ = self.path_attn(query, keys, keys)
            path_info[i, j] = output.squeeze(0).squeeze(0)

        out_feas = torch.cat(
            [
                h_ent_feas,
                t_ent_feas,
                torch.abs(h_ent_feas - t_ent_feas),
                torch.mul(h_ent_feas, t_ent_feas),
                path_info,
            ],
            dim=-1,
        )
        out_feas = self.predict2(out_feas)
        m_preds = self.out_linear(out_feas)
        b_preds = self.out_linear_binary(out_feas)
        return m_preds, b_preds, None

In [None]:
# Instantiate the model from the notebook-wide `args`; the constructor loads
# the pretrained BERT encoder through `BertModel.from_pretrained`.
model = SAGDRE_BERT(args)

# MyUtils


In [None]:
from collections import defaultdict
from transformers import BertModel
import torch
from operator import itemgetter
from random import shuffle
import pickle


def search_sample_batch_with_title(title, data_loader):
    """Return the first batch that contains a document whose title includes ``title``.

    Both arguments are honoured: the search runs over ``data_loader`` (not a
    module-level loader) and looks for ``title`` (not a fixed string), and a
    batch may hold any number of titles.

    Args:
        title: Substring to look for in each document title.
        data_loader: Iterable of batches; each batch is a mapping whose
            ``"titles"`` entry is an iterable of title strings.

    Returns:
        The first matching batch, or ``None`` when no batch matches.
    """
    for batch in tqdm(data_loader):
        if any(title in doc_title for doc_title in batch["titles"]):
            return batch

    return None


def _collect_with_index(dataset):
    """Return every item of ``dataset`` as a list, tagging each with its position under "idx"."""
    records = []
    for idx in tqdm(range(len(dataset))):
        record = dataset[idx]
        # The item returned by the dataset is modified in place.
        record["idx"] = idx
        records.append(record)
    return records


def save_dataset_with_index(
    train_set, dev_set, output_dir="/home/hwiric/SagDRE-2024-07-05/data/DocRED"
):
    """Pickle the train and dev samples, each annotated with its dataset index.

    Every sample gets an ``"idx"`` key holding its position in its dataset.
    The lists are written to ``train_data.pkl`` and ``dev_data.pkl`` inside
    ``output_dir`` using the highest pickle protocol.  Both lists are built
    before anything is written.

    Args:
        train_set: Indexable, sized collection of dict-like training samples.
        dev_set: Indexable, sized collection of dict-like dev samples.
        output_dir: Directory receiving the two pickle files; defaults to the
            location that was previously hard-coded.
    """
    train_data = _collect_with_index(train_set)
    dev_data = _collect_with_index(dev_set)

    outputs = (("train_data.pkl", train_data), ("dev_data.pkl", dev_data))
    for file_name, records in outputs:
        with open(os.path.join(output_dir, file_name), "wb") as handle:
            pickle.dump(records, handle, protocol=pickle.HIGHEST_PROTOCOL)


def shuffle_indicies(data_set, last_index):
    """Return the first ``last_index`` entries of a random permutation of dataset indices.

    Args:
        data_set: Sized collection; only its length is used.
        last_index: Slice end applied to the shuffled index list.

    Returns:
        A list of distinct indices into ``data_set`` in random order.
    """
    order = [*range(len(data_set))]
    shuffle(order)
    return order[:last_index]


def print_children_nodes(edges, id_to_word, id):
    """Print the children of node ``id`` as ``(child_id, child_word)`` pairs.

    The words are looked up one by one, so the output is correct for any
    number of children.  (``itemgetter(*ids)`` returns a bare value for a
    single id and raises ``TypeError`` for none, which broke those cases.)

    Args:
        edges: Mapping from a node id to the list of its child ids.
        id_to_word: Mapping from a node id to its word.
        id: Id of the node whose children are printed.
    """
    children_ids = edges[id]
    children = [(child_id, id_to_word[child_id]) for child_id in children_ids]

    print(f"({id_to_word[id]})의 자식 노드: {children}")


# entities added
def id_to_word(raw_sample):
    """Map consecutive integer ids to the words and entity names of one sample.

    Ids start at 0 and follow the words of ``raw_sample["sents"]`` in reading
    order; the ids after the last word are assigned to the entities of
    ``raw_sample["vertexSet"]``, each represented by the ``"name"`` of its
    first mention.

    Returns:
        A ``defaultdict`` without a default factory, mapping id -> string.
    """
    tokens = [word for sentence in raw_sample["sents"] for word in sentence]
    entity_names = [mentions[0]["name"] for mentions in raw_sample["vertexSet"]]
    return defaultdict(None, enumerate(tokens + entity_names))


class SampleEncoder:
    """Run a single sample's token ids through a pretrained ``bert-base-cased`` model.

    Args:
        words_id: NumPy array of token ids.  A leading batch axis of size 1 is
            added, and positions with an id greater than 0 form the attention
            mask.
    """

    def __init__(self, words_id) -> None:
        self.bert = BertModel.from_pretrained("bert-base-cased", return_dict=False)
        self.words_id = torch.from_numpy(words_id).unsqueeze(0)
        self.mask = torch.from_numpy(words_id > 0).unsqueeze(0)

    def encode(self):
        """Return the two outputs of the BERT forward pass for the stored sample."""
        first_output, second_output = self.bert(
            input_ids=self.words_id, attention_mask=self.mask
        )
        return first_output, second_output


def convert_edges_format_dgl2dict(src, dst):
    """Turn parallel source/destination tensors into an adjacency mapping.

    Args:
        src: Tensor of source node ids.
        dst: Tensor of destination node ids, aligned with ``src``.

    Returns:
        ``defaultdict(list)`` mapping each source id to its destination ids,
        in the order the edges appear.
    """
    sources = src.detach().cpu().numpy()
    targets = dst.detach().cpu().numpy()

    adjacency = defaultdict(list)
    for source, target in zip(sources, targets):
        adjacency[source].append(target)

    return adjacency


def raw_data(src_file):
    """Load and return the JSON content of ``src_file`` (read as UTF-8)."""
    with open(file=src_file, mode="r", encoding="utf-8") as handle:
        return json.load(handle)


def subwords_cnt_dist(raw_data):
    """Return, for every sample, the number of subwords its words are split into.

    The words of all sentences of a sample are flattened into one list and
    handed to ``Bert.subword_tokenize_to_ids``; the length of the third value
    it returns is recorded.

    Args:
        raw_data: Iterable of samples, each with a ``"sents"`` list of word
            lists.

    Returns:
        List with one subword count per sample, in input order.
    """
    tokenizer = Bert("SAGDRE_BERT_base")
    counts = []

    for sample in raw_data:
        words = [word for sentence in sample["sents"] for word in sentence]
        _, _, subwords = tokenizer.subword_tokenize_to_ids(words)
        counts.append(len(subwords))

    return counts

# Temp
