# Settings


In [1]:
from easydict import EasyDict
import json
import os
import numpy as np


def config(data_dir="/home/hwiric/SagDRE-2024-07-05/data/DocRED"):
    """Return the experiment configuration as an attribute-access dict.

    Args:
        data_dir: directory containing ``train_annotated.json``, ``dev.json``
            and ``test.json``. Defaults to the original author's absolute
            path; override it to run the notebook on another machine.

    Returns:
        An ``EasyDict`` with dataset paths, model/embedding sizes and
        training hyper-parameters.
    """
    args = EasyDict()
    # Dataset splits (built from one directory instead of three hardcoded paths).
    args.train_set = os.path.join(data_dir, "train_annotated.json")
    args.dev_set = os.path.join(data_dir, "dev.json")
    args.test_set = os.path.join(data_dir, "test.json")

    args.checkpoint_dir = "checkpoint"
    # Bert() maps "SAGDRE_BERT_base" to bert-base-cased, any other name to bert-large-cased.
    args.model_name = "SAGDRE_BERT_base"
    args.pretrain_model = ""

    args.relation_nums = 97
    args.entity_type_num = 7
    args.max_entity_num = 80

    args.word_pad = 0
    args.entity_type_pad = 0
    args.entity_id_pad = 0

    # NOTE(review): the "store_true" values below look copied from argparse
    # ``action="store_true"``; as non-empty strings they are simply truthy.
    args.word_emb_size = 10
    args.use_entity_type = "store_true"
    args.entity_type_size = 20

    args.use_entity_id = "store_true"
    args.entity_id_size = 20

    args.nlayers = 1
    args.lstm_hidden_size = 32
    args.lstm_dropout = 0.4

    args.lr = 0.001
    args.batch_size = 1
    args.test_batch_size = 1
    args.epoch = 40
    args.test_epoch = 5
    args.weight_decay = 0.0001
    args.negativa_alpha = 4  # (sic) key name kept for compatibility with readers
    args.save_model_freq = 1

    args.gcn_layers = 2
    args.gcn_dim = 128
    args.dropout = 0.6
    args.activation = "relu"

    args.bert_hid_size = 768
    args.coslr = "store_true"

    args.use_model = "bert"

    args.input_theta = 1.0

    return args


class Object:
    """Empty attribute bag: callers attach arbitrary fields to its instances."""


args = config()


def _load_json(path):
    """Parse the UTF-8 JSON file at ``path``, closing the file handle afterwards."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


# Vocabularies and pretrained word vectors stored alongside the DocRED splits;
# the directory is derived from the configured train file rather than repeated.
data_opt = Object()
data_opt.data_dir = os.path.dirname(args.train_set)
data_opt.rel2id = _load_json(os.path.join(data_opt.data_dir, "rel2id.json"))
data_opt.id2rel = {v: k for k, v in data_opt.rel2id.items()}
data_opt.word2id = _load_json(os.path.join(data_opt.data_dir, "word2id.json"))
data_opt.ner2id = _load_json(os.path.join(data_opt.data_dir, "ner2id.json"))
data_opt.word2vec = np.load(os.path.join(data_opt.data_dir, "vec.npy"))

In [2]:
import json
import math
import random
from collections import defaultdict
import dgl
import numpy as np
import torch
from torch.utils.data import IterableDataset, DataLoader
from tqdm.std import tqdm

  from .autonotebook import tqdm as notebook_tqdm
  return torch._C._cuda_getDeviceCount() > 0


In [3]:
import os
from datetime import datetime
import numpy as np
import torch
import torch.nn.functional as F
import random
from torch import nn
import dgl.nn.pytorch as dglnn
from transformers import BertTokenizer


def get_cuda(tensor):
    """Move ``tensor`` to the default CUDA device if one is available.

    Without CUDA the very same tensor object is returned unchanged.
    """
    return tensor.cuda() if torch.cuda.is_available() else tensor


class Bert:
    """Thin wrapper around a cased BERT WordPiece tokenizer.

    Builds ``[CLS] ... [SEP]`` token sequences, converts them to fixed-length
    (``max_len``) id tensors, and maps word-level tokens to subword positions.
    """

    MASK = "[MASK]"
    CLS = "[CLS]"
    SEP = "[SEP]"

    def __init__(self, model_name):
        super().__init__()
        self.model_name = model_name
        # "SAGDRE_BERT_base" selects the base checkpoint; anything else the large one.
        if model_name == "SAGDRE_BERT_base":
            self.tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
        else:
            self.tokenizer = BertTokenizer.from_pretrained("bert-large-cased")
        self.max_len = 512

    def tokenize(self, text, masked_idxs=None):
        """Tokenize ``text``, replace subwords at ``masked_idxs`` with [MASK],
        and wrap the result in [CLS]/[SEP] (no truncation here)."""
        tokenized_text = self.tokenizer.tokenize(text)
        if masked_idxs is not None:
            for idx in masked_idxs:
                tokenized_text[idx] = self.MASK
        tokenized = [self.CLS] + tokenized_text + [self.SEP]
        return tokenized

    def tokenize_to_ids(self, text, masked_idxs=None, pad=True):
        """Return ``(tokens, convert_tokens_to_ids(tokens, pad))`` for ``text``."""
        tokens = self.tokenize(text, masked_idxs)
        return tokens, self.convert_tokens_to_ids(tokens, pad=pad)

    def convert_tokens_to_ids(self, tokens, pad=True):
        """Convert tokens to a ``(1, n)`` id tensor truncated to ``max_len``.

        With ``pad=True`` returns ``(ids, mask)``, both zero-padded to shape
        ``(1, max_len)``, where ``mask`` is 1 on real tokens. With
        ``pad=False`` returns only the truncated ids.
        """
        token_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        ids = torch.tensor([token_ids])
        ids = ids[:, : self.max_len]
        if pad:
            padded_ids = torch.zeros(1, self.max_len).to(ids)
            padded_ids[0, : ids.size(1)] = ids
            mask = torch.zeros(1, self.max_len).to(ids)
            mask[0, : ids.size(1)] = 1
            return padded_ids, mask
        else:
            return ids

    def flatten(self, list_of_lists):
        """Yield the items of each inner iterable, in order."""
        for inner in list_of_lists:
            yield from inner

    def subword_tokenize(self, tokens):
        """Segment each token into subwords while keeping track of
        token boundaries.
        Parameters
        ----------
        tokens: A sequence of strings, representing input tokens.
        Returns
        -------
        A tuple consisting of:
            - A list of subwords, flanked by the special symbols required
                by Bert (CLS and SEP). At most ``max_len - 3`` content
                subwords are kept.
            - An array of indices into the list of subwords, indicating
                that the corresponding subword is the start of a new
                token. For example, [1, 3, 4, 7] means that the subwords
                1, 3, 4, 7 are token starts, while all other subwords
                (0, 2, 5, 6, 8...) are in or at the end of tokens.
                This list allows selecting Bert hidden states that
                represent tokens, which is necessary in sequence
                labeling. Tokens whose first subword was truncated away
                get the out-of-range sentinel ``max_len``.
        """
        # 509 content subwords for max_len=512, derived so both limits stay in sync.
        max_content = self.max_len - 3
        subwords = list(map(self.tokenizer.tokenize, tokens))
        subword_lengths = list(map(len, subwords))
        subwords = [self.CLS] + list(self.flatten(subwords))[:max_content] + [self.SEP]
        # +1 skips [CLS]; each token starts after all earlier tokens' subwords.
        token_start_idxs = 1 + np.cumsum([0] + subword_lengths[:-1])
        token_start_idxs[token_start_idxs > max_content] = self.max_len
        return subwords, token_start_idxs

    def subword_tokenize_to_ids(self, tokens):
        """Segment each token into subwords while keeping track of
        token boundaries and convert subwords into IDs.
        Parameters
        ----------
        tokens: A sequence of strings, representing input tokens.
        Returns
        -------
        A tuple consisting of:
            - A numpy array of shape (1, max_len) holding the zero-padded
                subword IDs, including IDs of the special symbols
                (CLS and SEP) required by Bert.
            - An array of indices into the list of subwords. See
                doc of subword_tokenize.
            - The list of subword strings itself.
        The attention mask is computed but not returned.
        """
        subwords, token_start_idxs = self.subword_tokenize(tokens)
        subword_ids, _mask = self.convert_tokens_to_ids(subwords)
        return subword_ids.numpy(), token_start_idxs, subwords

    def segment_ids(self, segment1_len, segment2_len):
        """Return a ``(1, n)`` tensor of 0s for segment 1 followed by 1s for segment 2."""
        ids = [0] * segment1_len + [1] * segment2_len
        return torch.tensor([ids])

In [4]:
import spacy
import dgl
from spacy.tokens import Doc
import networkx as nx
import matplotlib.pyplot as plt

nlp = spacy.load("en_core_web_lg")


def build_g(sentences, pos_idx, max_id):
    """Build the document dependency graph and the entity-pair paths.

    Args:
        sentences: list of sentences, each a list of word strings; every
            sentence becomes a spaCy ``Doc`` built from those exact words
            before being run through the pipeline.
        pos_idx: mapping from 1-based entity id to the word positions of its
            mentions (consumed by ``add_entity_node``).
        max_id: number of entity nodes appended after the word nodes.

    Returns:
        ``(dgl_graph, paths)``, where ``paths`` is the ``get_entity_paths``
        result computed on the networkx graph before conversion to DGL.
    """
    graph = nx.DiGraph()
    sentence_roots = []
    docs = [Doc(nlp.vocab, words=words) for words in sentences]
    for parsed_sentence in nlp.pipe(docs):
        graph, sentence_roots = parse_sent(parsed_sentence, graph, sentence_roots)
    graph, entity_start = add_entity_node(graph, pos_idx, max_id)
    paths = get_entity_paths(graph, max_id, entity_start)
    return dgl.from_networkx(graph), paths


def parse_sent(tokens, g, pre_roots):
    """Append one parsed sentence to the document graph ``g``.

    Each token becomes a node, numbered consecutively after the nodes already
    in ``g``. The node carries the token's text, vector, tag, POS and
    dependency label. Every dependency arc is added in both directions.
    Within the sentence, consecutive roots are linked in both directions.
    Across sentences, this sentence's first root is linked both ways with
    the last entry of ``pre_roots``, and it also receives an incoming edge
    from every entry of ``pre_roots``.

    Args:
        tokens: a parsed sentence (the spaCy ``Doc`` objects produced in
            ``build_g``); it is iterated more than once.
        g: graph being built (a networkx ``DiGraph`` in ``build_g``).
        pre_roots: node ids of the first root of each earlier sentence.

    Returns:
        ``(g, pre_roots)``, both updated in place. This sentence's first
        root, if any, is appended to ``pre_roots``.
    """
    roots = [token for token in tokens if token.head == token]
    start = len(g.nodes())
    node_of = {}
    for offset, token in enumerate(tokens):
        node_of[token] = start + offset
        g.add_node(
            start + offset,
            text=token.text,
            vector=token.vector,
            is_root=token in roots,
            tag=token.tag_,
            pos=token.pos_,
            dep=token.dep_,
        )

    for token in tokens:
        head = node_of[token.head]
        g.add_edge(node_of[token], head, dep=token.dep_)
        g.add_edge(head, node_of[token], dep=token.dep_)

    # Chain consecutive roots of a multi-root parse. Label these edges like
    # the other root-to-root links ("rootconn") instead of with the dep of
    # whatever token a leftover loop variable last pointed at.
    for left, right in zip(roots, roots[1:]):
        g.add_edge(node_of[left], node_of[right], dep="rootconn")
        g.add_edge(node_of[right], node_of[left], dep="rootconn")

    if roots:
        first_root = node_of[roots[0]]
        if pre_roots:
            g.add_edge(first_root, pre_roots[-1], dep="rootconn")
            g.add_edge(pre_roots[-1], first_root, dep="rootconn")
        for pre_root in pre_roots:
            g.add_edge(pre_root, first_root, dep="rootconn")
        pre_roots.append(first_root)
    return g, pre_roots


def add_same_words_links(g, remove_stopwords=True):
    """Connect, in both directions, every pair of nodes sharing the same text.

    Texts are compared case-insensitively. Texts shorter than 5 characters
    are ignored, and so are spaCy stop words when ``remove_stopwords`` is
    true. Returns ``g``, mutated in place.
    """
    stop_words = nlp.Defaults.stop_words
    seen = {}
    for node, text in nx.get_node_attributes(g, "text").items():
        key = text.lower()
        if (remove_stopwords and key in stop_words) or len(key) < 5:
            continue
        for earlier in seen.get(key, []):
            g.add_edge(node, earlier)
            g.add_edge(earlier, node)
        seen.setdefault(key, []).append(node)
    return g


def add_entity_node(g, pos_idx, max_id):
    """Append ``max_id`` entity nodes to ``g`` and wire them to their mentions.

    Entity ``k`` (0-based) becomes node ``start + k`` with text
    ``"entity_k"``. It is connected in both directions to every node listed
    in ``pos_idx[k + 1]``; ``pos_idx`` keys are 1-based entity ids. Entities
    absent from ``pos_idx`` get a node but no edges here.

    Returns:
        ``(g, start)``, where ``start`` is the id of the first entity node.
    """
    first_entity = len(g.nodes())
    for offset in range(max_id):
        entity_node = first_entity + offset
        g.add_node(entity_node, text="entity_%s" % offset, is_root=False)
        for mention_node in pos_idx.get(offset + 1, ()):
            g.add_edge(entity_node, mention_node)
            g.add_edge(mention_node, entity_node)
    return g, first_entity


def get_entity_paths(g, max_id, start):
    """Collect, for each ordered pair of distinct entities, one widened path.

    The path is a shortest path between the two entity nodes. It is widened
    with every neighbour of a path node whose ``pos`` attribute is "ADP",
    then returned as a sorted node list.

    Args:
        g: networkx graph whose entity nodes are ``start .. start + max_id - 1``.
        max_id: number of entity nodes.
        start: node id of the first entity node.

    Returns:
        dict keyed by 1-based ``(i, j)`` entity ids with ``i != j``. Each
        value holds at most one sorted node list, or is empty when the two
        entity nodes are not connected.
    """
    pos_tags = nx.get_node_attributes(g, "pos")
    ent_paths = {}
    for src in range(max_id):
        for dst in range(max_id):
            if src == dst:
                continue
            try:
                shortest = [nx.shortest_path(g, start + src, start + dst)]
            except nx.NetworkXNoPath:
                shortest = []
            widened = []
            for path in shortest[:3]:
                nodes = set(path)
                for node in path:
                    nodes.update(
                        nb for nb in g.neighbors(node) if pos_tags.get(nb) == "ADP"
                    )
                widened.append(sorted(nodes))
            ent_paths[(src + 1, dst + 1)] = widened[:3]
    return ent_paths

In [5]:
IGNORE_INDEX = -100


class BERTDGLREDataset(IterableDataset):

    def __init__(
        self,
        src_file,
        ner2id,
        rel2id,
        dataset="train",
        instance_in_train=None,
        model_name="SAGDRE_BERT_base",
    ):

        super(BERTDGLREDataset, self).__init__()

        if instance_in_train is None:
            self.instance_in_train = set()
        else:
            self.instance_in_train = instance_in_train
        self.data = None
        self.document_max_length = 512
        self.bert = Bert(model_name)
        self.dataset = dataset
        self.rel2id = rel2id
        self.ner2id = ner2id
        print("Reading data from {}.".format(src_file))

        self.create_data(src_file)
        self.get_instance_in_train()

    def get_instance_in_train(self):
        for doc in self.data:
            entity_list = doc["vertexSet"]
            labels = doc.get("labels", [])
            for label in labels:
                head, tail, relation = label["h"], label["t"], label["r"]
                label["r"] = self.rel2id[relation]
                if self.dataset == "train":
                    for n1 in entity_list[head]:
                        for n2 in entity_list[tail]:
                            mention_triple = (n1["name"], n2["name"], relation)
                            self.instance_in_train.add(mention_triple)

    def process_doc(self, doc, dataset, ner2id, bert):
        title, entity_list = doc["title"], doc["vertexSet"]
        labels, sentences = doc.get("labels", []), doc["sents"]

        Ls = [0]
        L = 0
        for x in sentences:
            L += len(x)
            Ls.append(L)
        for j in range(len(entity_list)):
            for k in range(len(entity_list[j])):
                sent_id = int(entity_list[j][k]["sent_id"])
                entity_list[j][k]["sent_id"] = sent_id

                dl = Ls[sent_id]
                pos0, pos1 = entity_list[j][k]["pos"]
                entity_list[j][k]["global_pos"] = (pos0 + dl, pos1 + dl)

        # generate positive examples
        train_triple = []
        new_labels = []
        for label in labels:
            head, tail, relation = label["h"], label["t"], label["r"]
            # label['r'] = rel2id[relation]
            train_triple.append((head, tail))
            label["in_train"] = False

            # record training set mention triples and mark for dev and test
            for n1 in entity_list[head]:
                for n2 in entity_list[tail]:
                    mention_triple = (n1["name"], n2["name"], relation)
                    if dataset != "train":
                        if mention_triple in self.instance_in_train:
                            label["in_train"] = True
                            break

            new_labels.append(label)

        # generate negative examples
        na_triple = []
        for j in range(len(entity_list)):
            for k in range(len(entity_list)):
                if j != k and (j, k) not in train_triple:
                    na_triple.append((j, k))

        # generate document ids
        words = []
        for sentence in sentences:
            for word in sentence:
                words.append(word)

        bert_token, bert_starts, bert_subwords = bert.subword_tokenize_to_ids(words)

        word_id = np.zeros((self.document_max_length,), dtype=np.int32)
        pos_id = np.zeros((self.document_max_length,), dtype=np.int32)
        ner_id = np.zeros((self.document_max_length,), dtype=np.int32)
        mention_id = np.zeros((self.document_max_length,), dtype=np.int32)
        word_id[:] = bert_token[0]

        entity2mention = defaultdict(list)
        mention_idx = 1
        already_exist = set()
        pos_idx = {}
        ent_idx = {}
        for idx, vertex in enumerate(entity_list, 1):
            for v in vertex:

                sent_id, ner_type = v["sent_id"], v["type"]
                pos0_w, pos1_w = v["global_pos"]

                pos0 = bert_starts[pos0_w]
                if pos1_w < len(bert_starts):
                    pos1 = bert_starts[pos1_w]
                else:
                    pos1 = self.document_max_length

                if (pos0, pos1) in already_exist:
                    continue

                if pos0 >= len(pos_id):
                    continue

                if idx not in pos_idx:
                    pos_idx[idx] = []
                    ent_idx[idx] = []

                pos_idx[idx].extend(range(pos0_w, pos1_w))
                ent_idx[idx].extend(range(pos0, pos1))
                pos_id[pos0:pos1] = idx
                ner_id[pos0:pos1] = ner2id[ner_type]
                mention_id[pos0:pos1] = mention_idx
                entity2mention[idx].append(mention_idx)
                mention_idx += 1
                already_exist.add((pos0, pos1))

        # ======================================================
        # compute subword to word index
        sub2word = np.zeros(
            (len(bert_starts) + len(entity_list), self.document_max_length)
        )
        for idx in range(len(bert_starts) - 1):
            start, end = bert_starts[idx], bert_starts[idx + 1]
            if start == end:
                continue
            sub2word[idx, start:end] = 1 / (end - start)
        start, end = bert_starts[-1], len(bert_subwords)
        sub2word[len(bert_starts) - 1, start:end] = 1 / (end - start)
        # compute convertion matrix for entity
        for idx, poss in ent_idx.items():
            # print('------------>', idx, poss)
            sub2word[len(bert_starts) + idx - 1, poss] = 1 / len(poss)
        # ======================================================
        # compute words to sent index
        word2sent = np.zeros((len(Ls) - 1, Ls[-1]))
        for i in range(1, len(Ls)):
            word2sent[i - 1, Ls[i - 1] : Ls[i]] = 1 / (Ls[i] - Ls[i - 1])
        # ======================================================

        replace_i = 0
        idx = len(entity_list)
        if entity2mention[idx] == []:
            entity2mention[idx].append(mention_idx)
            while mention_id[replace_i] != 0:
                replace_i += 1
            mention_id[replace_i] = mention_idx
            pos_id[replace_i] = idx
            ner_id[replace_i] = ner2id[vertex[0]["type"]]
            mention_idx += 1

        new_Ls = [0]
        for ii in range(1, len(Ls)):
            if Ls[ii] < len(bert_starts):
                new_Ls.append(bert_starts[Ls[ii]])
            else:
                new_Ls.append(len(bert_subwords))

        Ls = new_Ls

        graph2, path2 = build_g(sentences, pos_idx, pos_id.max())

        return {
            "title": title,
            "num_sent": len(doc["sents"]),
            "entities": entity_list,
            "labels": new_labels,
            "na_triple": na_triple,
            "word_id": word_id,
            "pos_id": pos_id,
            "ner_id": ner_id,
            "sub2word": sub2word,
            "word2sent": word2sent,
            "graph2": graph2,
            "path2": path2,
        }

    def create_data(self, src_file):
        with open(file=src_file, mode="r", encoding="utf-8") as fr:
            ori_data = json.load(fr)
        self.data = ori_data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        doc = self.data[idx]
        cur_d = self.process_doc(
            doc, dataset=self.dataset, ner2id=self.ner2id, bert=self.bert
        )
        return cur_d

    def __iter__(self):
        return iter(self.data)

In [6]:
# Train split: its constructor records the (head name, tail name, relation)
# mention triples of the training labels in ``instance_in_train``.
train_set = BERTDGLREDataset(
    args.train_set,
    ner2id=data_opt.ner2id,
    rel2id=data_opt.rel2id,
    dataset="train",
    model_name=args.model_name,
)

# Dev split shares that set so its labels can be flagged as seen in training.
dev_set = BERTDGLREDataset(
    args.dev_set,
    ner2id=data_opt.ner2id,
    rel2id=data_opt.rel2id,
    dataset="dev",
    instance_in_train=train_set.instance_in_train,
    model_name=args.model_name,
)



Reading data from /home/hwiric/SagDRE-2024-07-05/data/DocRED/train_annotated.json.
Reading data from /home/hwiric/SagDRE-2024-07-05/data/DocRED/dev.json.


# Utils


In [34]:
from collections import defaultdict
from transformers import BertModel
import torch


class SampleEncoder:
    """Encode one document's word-id array with a pretrained ``BertModel``.

    Args:
        words_id: numpy array of token ids (e.g. the ``word_id`` feature built
            by ``BERTDGLREDataset``); ids of 0 are treated as padding.
        pretrained_name: checkpoint passed to ``BertModel.from_pretrained``.
            It defaults to ``"bert-base-cased"``, the tokenizer that
            ``Bert("SAGDRE_BERT_base")`` uses; pass ``"bert-large-cased"`` to
            match the large tokenizer.
    """

    def __init__(self, words_id, pretrained_name="bert-base-cased") -> None:
        self.bert = BertModel.from_pretrained(pretrained_name, return_dict=False)
        # Prepend a batch dimension of size 1.
        self.words_id = torch.unsqueeze(torch.from_numpy(words_id), 0)
        # Attend only to non-padding positions (id > 0).
        self.mask = torch.unsqueeze(torch.from_numpy(words_id > 0), 0)

    def encode(self):
        """Run BERT and return the first element of its output tuple.

        With ``return_dict=False`` this is presumably the per-token hidden
        states; the second (pooled) element is discarded.
        """
        sequence_output, _pooled = self.bert(
            input_ids=self.words_id, attention_mask=self.mask
        )
        return sequence_output


def convert_edges_format_dgl2dict(src, dst):
    """Turn parallel source/destination edge tensors into an adjacency dict.

    Args:
        src, dst: equal-length tensors of edge endpoints (presumably from a
            DGL graph's ``edges()``). They are detached and copied to CPU
            numpy before use.

    Returns:
        ``defaultdict(list)`` mapping each source node (a numpy scalar) to the
        list of its destinations, in edge order.
    """
    sources = src.detach().cpu().numpy()
    targets = dst.detach().cpu().numpy()
    adjacency = defaultdict(list)
    for u, v in zip(sources, targets):
        adjacency[u].append(v)
    return adjacency


def raw_data(src_file):
    """Load and return the parsed JSON content of ``src_file`` (read as UTF-8)."""
    with open(src_file, encoding="utf-8") as handle:
        return json.load(handle)


def subwords_cnt_dist(raw_data):
    """Count WordPiece subwords per document, to inspect document lengths.

    Args:
        raw_data: iterable of DocRED-style documents, each with a ``"sents"``
            list of word lists. Note that this parameter shadows the
            ``raw_data`` function inside this body.

    Returns:
        list holding, for each document, the number of subwords of all its
        words plus 2 for [CLS]/[SEP]. The count is NOT truncated:
        ``Bert.subword_tokenize`` caps sequences at 511 tokens, which would
        hide every document exceeding BERT's 512-token limit.
    """
    tokenizer = Bert("SAGDRE_BERT_base").tokenizer
    subwords_cnt = []
    for sample in raw_data:
        n_subwords = sum(
            len(tokenizer.tokenize(word))
            for sentence in sample["sents"]
            for word in sentence
        )
        # +2 for the [CLS] and [SEP] markers that subword_tokenize adds.
        subwords_cnt.append(n_subwords + 2)
    return subwords_cnt

# Temp
