# SciBERT: Model Loading and Inferencing

This notebook makes an initial attempt at loading and implementing just the SciBERT model finetuned by Nomita and Daniel.

## 1. Imports

In [1]:
## Model definition related, adopted from train_cdr.py
import argparse
import os

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

from model_bio import Model    ## Not using this yet as we are trying to just use the model definition here
from scibert_utils import set_seed  ## Haven't checked if this is needed for inferencing yet
from scibert_prepro_bio import read_bio
# from save_result import Logger
# from evaluation import train, evaluate  ## Not using this as we are not using the built-in train and evaluate modules

In [2]:
## Model definition related, adopted from model_bio.py
from opt_einsum import contract
from scibert_model_utils.long_seq import process_long_input
from scibert_model_utils.losses import *
from scibert_model_utils.attn_unet import AttentionUNet
from scibert_model_utils.graph_networks import GraphConvolution, GraphAttentionLayer

In [3]:
## Data preprocessing related
from tqdm import tqdm
import spacy
from scibert_prepro_bio import chunks, WhitespaceTokenizer, cdr_id2rel, biored_cd_id2rel, biored_id2rel
from torch.utils.data import DataLoader
from scibert_utils import collate_fn

## 2. Model Loading
- First cell defines the model class (from model_bio.py)
- Second cell instantiates the model
- Third cell loads the checkpoint fine-tuned on BC5CDR
- Fourth cell inspects the loaded model

In [4]:
## Code directly copied from the "model_bio.py".  Can also do an import if preferred.

class REModel(nn.Module):  #Renamed to REModel
    def __init__(self, args, config, model, emb_size=768, block_size=64, num_labels=-1):
        super().__init__()
        self.device = args.device
        self.sizeA = 256
        self.gnn = args.gnn
        if self.gnn == 'GCN':
            self.gc1 = GraphConvolution(config.hidden_size, self.sizeA)
            self.gc2 = GraphConvolution(config.hidden_size, self.sizeA // 2)
        elif self.gnn == 'GAT':
            self.gc1 = GraphAttentionLayer(config.hidden_size, self.sizeA)
            self.gc2 = GraphAttentionLayer(config.hidden_size, self.sizeA // 2)
        else:
            raise ValueError('This is a GNN Error')

        if args.dropout > 0.0:
            self.dropout = nn.Dropout(args.dropout)
        else:
            self.dropout = None
        self.args = args
        self.config = config
        self.model = model
        self.hidden_size = config.hidden_size
        if args.loss == 'BSCELoss':
            self.loss_fn = BSCELoss(args.s0)
        elif args.loss == 'BalancedLoss':
            self.loss_fn = BalancedLoss()
        elif args.loss == 'ATLoss':
            self.loss_fn = ATLoss()
        elif args.loss == 'AsymmetricLoss':
            self.loss_fn = AsymmetricLoss()
        elif args.loss == 'APLLoss':
            self.loss_fn = APLLoss()
        else:
            print('error loss')
            return
        self.rels = args.num_class - 1
        # 768 * 2, 768
        # self.head_extractor = nn.Linear(2 * config.hidden_size, emb_size)
        # self.tail_extractor = nn.Linear(2 * config.hidden_size, emb_size)
        # self.head_extractor = nn.Linear(config.hidden_size + self.sizeA + args.unet_out_dim, emb_size)
        # self.tail_extractor = nn.Linear(config.hidden_size + self.sizeA + args.unet_out_dim, emb_size)
        self.head_extractor = nn.Linear(config.hidden_size + 2 * args.unet_out_dim, emb_size)
        self.tail_extractor = nn.Linear(config.hidden_size + 2 * args.unet_out_dim, emb_size)
        self.bilinear = nn.Linear(emb_size * block_size, config.num_labels)

        self.emb_size = emb_size
        self.block_size = block_size
        self.num_labels = num_labels

        self.bertdrop = nn.Dropout(0.6)
        self.unet_in_dim = args.unet_in_dim
        self.unet_out_dim = args.unet_in_dim
        self.liner = nn.Linear(config.hidden_size + self.sizeA, args.unet_in_dim)
        self.min_height = args.max_height

        self.segmentation_net = AttentionUNet(input_channels=args.unet_in_dim,
                                              class_number=args.unet_out_dim,
                                              down_channel=args.down_dim)
        self.use_gcn = args.use_gcn
        self.adj_linear = nn.Linear(self.sizeA * 2, self.sizeA)

    def encode(self, input_ids, attention_mask):
        config = self.config
        if config.transformer_type == "bert":
            start_tokens = [config.cls_token_id]
            end_tokens = [config.sep_token_id]
        elif config.transformer_type == "roberta":
            start_tokens = [config.cls_token_id]
            end_tokens = [config.sep_token_id, config.sep_token_id]
        sequence_output, attention = process_long_input(self.model, input_ids, attention_mask, start_tokens, end_tokens)
        return sequence_output, attention

    def get_hrt(self, sequence_output, attention, entity_pos, hts):
        offset = 1 if self.config.transformer_type in ["bert", "roberta"] else 0
        n, h, _, c = attention.size()
        hss, tss, rss = [], [], []
        entity_es = []
        entity_as = []
        for i in range(len(entity_pos)):
            entity_embs, entity_atts = [], []
            # for e in entity_pos[i]:
            for entity_num, e in enumerate(entity_pos[i]):
                if len(e) > 1:
                    e_emb, e_att = [], []
                    for start, end in e:
                        if start + offset < c:
                            # In case the entity mention is truncated due to limited max seq length.
                            e_emb.append(sequence_output[i, start + offset])
                            e_att.append(attention[i, :, start + offset])
                    if len(e_emb) > 0:
                        e_emb = torch.logsumexp(torch.stack(e_emb, dim=0), dim=0)
                        e_att = torch.stack(e_att, dim=0).mean(0)
                    else:
                        e_emb = torch.zeros(c).to(sequence_output)
                        e_att = torch.zeros(h, c).to(attention)
                else:
                    start, end = e[0]
                    if start + offset < c:
                        e_emb = sequence_output[i, start + offset]
                        e_att = attention[i, :, start + offset]
                    else:
                        e_emb = torch.zeros(c).to(sequence_output)
                        e_att = torch.zeros(h, c).to(attention)
                entity_embs.append(e_emb)
                entity_atts.append(e_att)
            # 这句话是新增的，这句话很离谱
            for _ in range(self.min_height - entity_num - 1):
                entity_atts.append(e_att)

            entity_embs = torch.stack(entity_embs, dim=0)  # [n_e, d]
            entity_atts = torch.stack(entity_atts, dim=0)  # [n_e, h, seq_len]

            # 这句话是新增的
            entity_es.append(entity_embs)
            entity_as.append(entity_atts)

            # if len(hts[i]) == 0:
            #     hss.append(torch.FloatTensor([]).to(sequence_output.device))
            #     tss.append(torch.FloatTensor([]).to(sequence_output.device))
            #     rss.append(torch.FloatTensor([]).to(sequence_output.device))
            #     continue
            ht_i = torch.LongTensor(hts[i]).to(sequence_output.device)
            hs = torch.index_select(entity_embs, 0, ht_i[:, 0])
            ts = torch.index_select(entity_embs, 0, ht_i[:, 1])
            h_att = torch.index_select(entity_atts, 0, ht_i[:, 0])
            t_att = torch.index_select(entity_atts, 0, ht_i[:, 1])
            # 将 ht 降维
            ht_att = (h_att * t_att).mean(1)
            ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)
            rs = contract("ld,rl->rd", sequence_output[i], ht_att)
            # rrs = contract("rl,ld->rd", ht_att, sequence_output[i])
            hss.append(hs)
            tss.append(ts)
            rss.append(rs)

        hss = torch.cat(hss, dim=0)
        tss = torch.cat(tss, dim=0)
        rss = torch.cat(rss, dim=0)
        return hss, rss, tss, entity_es, entity_as

    def get_channel_map(self, sequence_output, entity_as):
        # sequence_output = sequence_output.to('cpu')
        # attention = attention.to('cpu')
        bs, _, d = sequence_output.size()
        # ne = max([len(x) for x in entity_as])  # 本次bs中的最大实体数
        ne = self.min_height
        index_pair = []
        for i in range(ne):
            tmp = torch.cat((torch.ones((ne, 1), dtype=int) * i, torch.arange(0, ne).unsqueeze(1)), dim=-1)
            index_pair.append(tmp)
        index_pair = torch.stack(index_pair, dim=0).reshape(-1, 2).to(sequence_output.device)
        map_rss = []
        for b in range(bs):
            entity_atts = entity_as[b]
            h_att = torch.index_select(entity_atts, 0, index_pair[:, 0])
            t_att = torch.index_select(entity_atts, 0, index_pair[:, 1])
            ht_att = (h_att * t_att).mean(1)
            ht_att = ht_att / (ht_att.sum(1, keepdim=True) + 1e-5)
            rs = contract("ld,rl->rd", sequence_output[b], ht_att)
            map_rss.append(rs)
        map_rss = torch.cat(map_rss, dim=0).reshape(bs, ne, ne, d)
        return map_rss

    def get_ht(self, rel_enco, hts):
        htss = []
        for i in range(len(hts)):
            ht_index = hts[i]
            for (h_index, t_index) in ht_index:
                htss.append(rel_enco[i, h_index, t_index])
        htss = torch.stack(htss, dim=0)
        return htss

    def forward(self, input_ids=None, attention_mask=None, labels=None, entity_pos=None, hts=None, list_feature_id=None,
                adj_mention=None, adj_syntactic_dependency_tree=None):
        sequence_output, attention = self.encode(input_ids, attention_mask)
        if self.dropout is not None:
            sequence_output = self.dropout(sequence_output)
        # GCN
        if self.use_gcn == 'both':
            a = F.normalize(adj_mention)
            b = F.normalize(adj_syntactic_dependency_tree)
            sequence_output_A = torch.relu(self.gc2(sequence_output, a))
            sequence_output_B = torch.relu(self.gc2(sequence_output, b))
            sequence_output = torch.cat([sequence_output, sequence_output_A, sequence_output_B], dim=2)
        elif self.use_gcn == 'mentions':
            a = F.normalize(adj_mention)
            sequence_output_A = torch.relu(self.gc1(sequence_output, a))
            sequence_output = torch.cat([sequence_output, sequence_output_A], dim=2)
        elif self.use_gcn == 'tree':
            a = F.normalize(adj_syntactic_dependency_tree)
            sequence_output_A = torch.relu(self.gc1(sequence_output, a))
            sequence_output = torch.cat([sequence_output, sequence_output_A], dim=2)
        else:
            # 这时的adj_mention和adj_syntactic_dependency_tree都是空矩阵
            a1, a2, _ = adj_mention.size()
            sequence_output_A = adj_mention.clone()
            sequence_output_A = sequence_output_A.resize_(a1, a2, self.sizeA)
            sequence_output_A = sequence_output_A.zero_()
            sequence_output = torch.cat([sequence_output, sequence_output_A], dim=2)

        hs, rs, ts, entity_embs, entity_as = self.get_hrt(sequence_output, attention, entity_pos, hts)
        feature_map = self.get_channel_map(sequence_output, entity_as)
        # print('feature_map:', feature_map.shape)
        attn_input = self.liner(feature_map).permute(0, 3, 1, 2).contiguous()
        attn_map = self.segmentation_net(attn_input)
        # attn_map = self.segmentation_net_acc_unet(attn_input)
        # attn_map = attn_map.permute(0, 2, 3, 1).contiguous()
        rs = self.get_ht(attn_map, hts)

        # Binary Classifier
        # a = torch.cat([hs, rs], dim=1)
        # b = self.head_extractor(a)
        # hr进行了拼接变成了(1376, 1536)，然后再变回之前的维度(1376, 768)
        hs = torch.tanh(self.head_extractor(torch.cat([hs, rs], dim=1)))  # zs
        ts = torch.tanh(self.tail_extractor(torch.cat([ts, rs], dim=1)))  # zo
        # (1376, 768) -> (1376, 12, 64)
        b1 = hs.view(-1, self.emb_size // self.block_size, self.block_size)
        b2 = ts.view(-1, self.emb_size // self.block_size, self.block_size)
        bl = (b1.unsqueeze(3) * b2.unsqueeze(2)).view(-1, self.emb_size * self.block_size)
        logits = self.bilinear(bl)

        output = (get_label(logits, num_labels=self.num_labels))
        if labels is not None:
            labels = [torch.tensor(label) for label in labels]
            labels = torch.cat(labels, dim=0).to(logits)
            loss = self.loss_fn(logits.float(), labels.float())
            output = [loss.to(sequence_output), output]
        return output

In [7]:
## Configuring the parser for the args variable
parser = argparse.ArgumentParser()
parser.add_argument('--gnn', type=str, default='GCN', help="GCN/GAT")
parser.add_argument("--model_name_or_path", default="./RE_base", type=str)
parser.add_argument("--transformer_type", default="bert", type=str)
parser.add_argument("--config_name", default="", type=str,
                    help="Pretrained config name or path if not the same as model_name")
parser.add_argument("--tokenizer_name", default="", type=str,
                    help="Pretrained tokenizer name or path if not the same as model_name")
parser.add_argument('--use_gcn', type=str, default='tree', help="use gcn, both/mentions/tree/false")
parser.add_argument('--dropout', type=float, default=0.5, help="0.0/0.2/0.5")
parser.add_argument('--loss', type=str, default='BSCELoss',
                    help="use BSCELoss/BalancedLoss/ATLoss/AsymmetricLoss/APLLoss")
parser.add_argument('--s0', type=float, default=0.3)
parser.add_argument("--unet_in_dim", type=int, default=3, help="unet_in_dim.")
parser.add_argument("--unet_out_dim", type=int, default=256, help="unet_out_dim.")
parser.add_argument("--down_dim", type=int, default=256, help="down_dim.")
parser.add_argument("--bert_lr", default=3e-5, type=float, help="The initial learning rate for Adam.")
parser.add_argument("--max_height", type=int, default=64, help="max_height.")
parser.add_argument("--max_seq_length", default=1024, type=int,   ## This may require changing in future
                    help="The maximum total input sequence length after tokenization. Sequences longer "
                         "than this will be truncated, sequences shorter will be padded.")

args, _ = parser.parse_known_args()
args.n_gpu = torch.cuda.device_count()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.device = device
args.num_class = 2

tokenizer = AutoTokenizer.from_pretrained(
    args.tokenizer_name if args.tokenizer_name else args.model_name_or_path, )

config = AutoConfig.from_pretrained(
    args.config_name if args.config_name else args.model_name_or_path, num_labels=args.num_class, )
config.cls_token_id = tokenizer.cls_token_id
config.sep_token_id = tokenizer.sep_token_id
config.transformer_type = args.transformer_type

# Being a bit lazy here: not downloading the model checkpoint as we are loading our own state dict anyway?
model = AutoModel.from_pretrained(
    args.model_name_or_path, from_tf=bool(".ckpt" in args.model_name_or_path), config=config, )

## Initializing model, state dict loading in next cell
RE_model  = REModel(args, config, model, num_labels=1)


In [8]:
## Loading model_checkpoint
model_checkpoint = "train_filter_bert_cdr_seed_BSCELoss_tree_03_05_66_best"
RE_model.load_state_dict(torch.load(model_checkpoint, weights_only=True))

<All keys matched successfully>

In [9]:
## Showing model structure
RE_model

REModel(
  (gc1): GraphConvolution (768 -> 256)
  (gc2): GraphConvolution (768 -> 128)
  (dropout): Dropout(p=0.5, inplace=False)
  (model): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(31116, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_featur

## 3. Loading Data
- We now try to look at the data structure used by the RE model

In [10]:
## Loading the test data, with the version being the one used by the SSGU-CD-all repo
args.data_dir = './data_scibert_version'   ## filepath name altered, remember to update this.
args.train_file = 'train_filter.data'
args.dev_file = 'dev_filter.data'
args.test_file = 'test_filter.data'

train_file = os.path.join(args.data_dir, args.train_file)
dev_file = os.path.join(args.data_dir, args.dev_file)
test_file = os.path.join(args.data_dir, args.test_file)

In [11]:
file_in = dev_file

with open(file_in, 'r') as infile:
        lines = infile.readlines()

In [12]:
lines

['6794356\tTricuspid valve regurgitation and lithium carbonate toxicity in a newborn infant .|A newborn with massive tricuspid regurgitation , atrial flutter , congestive heart failure , and a high serum lithium level is described .|This is the first patient to initially manifest tricuspid regurgitation and atrial flutter , and the 11th described patient with cardiac disease among infants exposed to lithium compounds in the first trimester of pregnancy .|Sixty - three percent of these infants had tricuspid valve involvement .|Lithium carbonate may be a factor in the increasing incidence of congenital heart disease when taken during early pregnancy .|It also causes neurologic depression , cyanosis , and cardiac arrhythmia when consumed prior to delivery .\t1:CID:2\tL2R\tCROSS\t82-84\t105-107\tD016651\tlithium carbonate|Lithium carbonate\tChemical\t4:82\t6:84\t0:4\tD003866\tneurologic depression\tDisease\t105\t107\t5\t1:CID:2\tL2R\tCROSS\t82-84\t108-109\tD016651\tlithium carbonate|Lithiu

In [13]:
args.rel2 = 0
rel2id = {'1:NR:2': 0, '1:CID:2': 1}

file_in = dev_file
max_seq_length=1024

features = []
pos_samples = 0
neg_samples = 0
pmids = set()
nlp = spacy.load('en_core_web_sm')

with open(file_in, 'r') as infile:
        lines = infile.readlines()

        for i_l, line in enumerate(tqdm(lines)):
            line = line.rstrip().split('\t')
            pmid = line[0]
            entities = {}
            ent2ent_type = {}
            if pmid not in pmids:
                pmids.add(pmid)
                text = line[1]
                sents = [t.split(' ') for t in text.split('|')]
                sent_len = [len(i) for i in sents]
                sents_len = []
                for i in range(len(sent_len)):
                    if i == 0:
                        sents_len.append(sent_len[0])
                    else:
                        sents_len.append(sent_len[i] + sents_len[i - 1])
                prs = chunks(line[2:], 17)
                ent2idx = {}
                train_triples = {}
                entity_pos = set()
                for p in prs:
                    if p[0] == "not_include":
                        continue
                    es = list(map(int, p[8].split(':')))
                    ed = list(map(int, p[9].split(':')))
                    tpy = p[7]
                    entity_str = list(map(str, p[6].split('|')))
                    entity_id = p[5]
                    if entity_id not in entities:
                        entities[entity_id] = []
                    for start, end, string in zip(es, ed, entity_str):
                        entity_pos.add((start, end, tpy))
                        # entity in which sent
                        sent_in_id = -1
                        for i in range(len(sents_len)):
                            if sents_len[i] > end:
                                sent_in_id = i
                                break
                        if [start, end, tpy, string, sent_in_id] not in entities[entity_id]:
                            entities[entity_id].append([start, end, tpy, string, sent_in_id])

                    es = list(map(int, p[14].split(':')))
                    ed = list(map(int, p[15].split(':')))
                    tpy = p[13]
                    entity_str = list(map(str, p[12].split('|')))
                    entity_id = p[11]
                    if entity_id not in entities:
                        entities[entity_id] = []
                    for start, end, string in zip(es, ed, entity_str):
                        entity_pos.add((start, end, tpy))
                        # entity in which sent
                        sent_in_id = -1
                        for i in range(len(sents_len)):
                            if sents_len[i] > end:
                                sent_in_id = i
                                break
                        if [start, end, tpy, string, sent_in_id] not in entities[entity_id]:
                            entities[entity_id].append([start, end, tpy, string, sent_in_id])
                    if p[5] not in ent2ent_type:
                        ent2ent_type[p[5]] = p[7]
                    if p[11] not in ent2ent_type:
                        ent2ent_type[p[11]] = p[13]
                if len(entity_pos) == 0:
                    # print(pmid, 'rel is none')
                    continue

                # spacy 分析
                nlp.tokenizer = WhitespaceTokenizer(nlp.vocab)
                doc = nlp(text.replace('|', ' '))
                spacy_tokens = []
                spacy_offset = nlp("*")[0]
                for token in doc:
                    spacy_tokens.append(token)
                # 依据spacy的分词解析结果，存放开始的index
                # id对应的单词
                index2word = {}
                # 一个单词对应的所有分词片段id
                word2piecesid = {}
                # spacy中token的当前id
                spacy_token_id = 0

                # entitys 中缺少来源于哪个句子的id，可以考虑添加
                sents = [t.split(' ') for t in text.split('|')]

                new_sents = []
                token_map = []
                lengthofPice = 0
                sent_map = {}
                entity_pos = list(entity_pos)
                entity_pos.sort()
                i_t = 0
                for sent in sents:
                    for token in sent:
                        tokens_wordpiece = tokenizer.tokenize(token)
                        # 每个实体的起始位置都加进去
                        oneToken = []
                        eid = 0
                        for i, ep in enumerate(entity_pos):
                            start, end, tpy = ep
                            if i_t == start or i_t == end:
                                eid = i
                                break

                        if i_t == entity_pos[eid][0] or i_t == entity_pos[eid][1]:
                            # 标记实体的'*'应该算在是实体的一部分，但不是一个单词的一部分，故不算其中。
                            index2word[len(index2word)] = spacy_offset
                            for token_wordpiece in tokens_wordpiece:
                                index2word[len(index2word)] = spacy_tokens[spacy_token_id]
                                if spacy_tokens[spacy_token_id] not in word2piecesid:
                                    word2piecesid[spacy_tokens[spacy_token_id]] = []
                                word2piecesid[spacy_tokens[spacy_token_id]].append(len(index2word) - 1)

                            oneToken.append(lengthofPice + 1)
                            if 'Chemical' in entity_pos[eid][2]:
                                special_token = '<<Chemical>>'
                            elif 'Disease' in entity_pos[eid][2]:
                                special_token = '<<Disease>>'
                            elif 'Gene' in entity_pos[eid][2]:
                                special_token = '<<Gene>>'
                            elif 'Variant' in entity_pos[eid][2]:
                                special_token = '<<Variant>>'
                            else:
                                raise KeyError('not Chemical or Disease or Gene or Variant')
                            tokens_wordpiece = [special_token] + tokens_wordpiece
                            lengthofPice += len(tokens_wordpiece)
                            oneToken.append(lengthofPice)
                        else:
                            for token_wordpiece in tokens_wordpiece:
                                index2word[len(index2word)] = spacy_tokens[spacy_token_id]
                                if spacy_tokens[spacy_token_id] not in word2piecesid:
                                    word2piecesid[spacy_tokens[spacy_token_id]] = []
                                word2piecesid[spacy_tokens[spacy_token_id]].append(len(index2word) - 1)

                            oneToken.append(lengthofPice)
                            lengthofPice += len(tokens_wordpiece)
                            oneToken.append(lengthofPice)
                        # 相当于docred中的new_map，分词后每个词对应的位置
                        sent_map[i_t] = len(new_sents)
                        new_sents.extend(tokens_wordpiece)
                        token_map.append(oneToken)
                        i_t += 1
                        spacy_token_id += 1
                    # sent_map[i_t] = len(new_sents)
                sents = new_sents

                entity_pos = []
                for p in prs:
                    if p[0] == "not_include":
                        continue
                    if p[1] == "L2R":
                        h_id, t_id = p[5], p[11]
                        h_start, t_start = p[8], p[14]
                        h_end, t_end = p[9], p[15]
                    else:
                        t_id, h_id = p[5], p[11]
                        t_start, h_start = p[8], p[14]
                        t_end, h_end = p[9], p[15]
                    h_start = map(int, h_start.split(':'))
                    h_end = map(int, h_end.split(':'))
                    t_start = map(int, t_start.split(':'))
                    t_end = map(int, t_end.split(':'))
                    h_start = [sent_map[idx] for idx in h_start]
                    h_end = [sent_map[idx] for idx in h_end]
                    t_start = [sent_map[idx] for idx in t_start]
                    t_end = [sent_map[idx] for idx in t_end]
                    if h_id not in ent2idx:
                        ent2idx[h_id] = len(ent2idx)
                        entity_pos.append(list(zip(h_start, h_end)))
                    if t_id not in ent2idx:
                        ent2idx[t_id] = len(ent2idx)
                        entity_pos.append(list(zip(t_start, t_end)))
                    h_id, t_id = ent2idx[h_id], ent2idx[t_id]

                    if args.rel2:
                        r = 0 if p[0] == '1:NR:2' else 1
                    else:
                        r = rel2id[p[0]]
                    if (h_id, t_id) not in train_triples:
                        train_triples[(h_id, t_id)] = [{'relation': r}]
                    else:
                        train_triples[(h_id, t_id)].append({'relation': r})

                relations, hts = [], []
                for h, t in train_triples.keys():
                    relation = [0] * len(rel2id)
                    for mention in train_triples[h, t]:
                        relation[mention["relation"]] = 1
                        if mention["relation"] != 0:
                            pos_samples += 1
                        else:
                            neg_samples += 1
                    relations.append(relation)
                    hts.append([h, t])
            sents = sents[:max_seq_length - 2]
            input_ids = tokenizer.convert_tokens_to_ids(sents)
            input_ids_new = tokenizer.build_inputs_with_special_tokens(input_ids)

            max_len = len(input_ids_new)
            # 结构计算
            a_mentions = np.eye(len(input_ids))
            a_mentions_new = np.eye(max_len)
            adj_syntactic_dependency_tree = np.eye(len(input_ids))
            adj_syntactic_dependency_tree_new = np.eye(max_len)
            offset = 1
            edges = 0
            for token_s in token_map:
                start = token_s[0]
                end = token_s[1]
                for i in range(start, end):
                    for j in range(start, end):
                        if i < (len(input_ids) - 1) and j < (len(input_ids) - 1):
                            if a_mentions[i][j] == 0:
                                a_mentions[i][j] = 1
                                a_mentions_new[i + 1][j + 1] = 1
                                edges += 1
            # 所有实体在 tokens 中的跨度
            mentionsofPice = []
            for eid in entities:
                for i in entities[eid]:
                    ment = [i[0], i[1]]
                    mentionsofPice.append([token_map[ment[0]][0], token_map[ment[1] - 1][1]])
            for ment in mentionsofPice:
                start = ment[0]
                end = ment[1]
                for i in range(start, end):
                    for j in range(start, end):
                        if i < (len(input_ids) - 1) and j < (len(input_ids) - 1):
                            if a_mentions[i][j] == 0:
                                a_mentions[i][j] = 1
                                a_mentions_new[i + 1][j + 1] = 1
                                edges += 1
            # 各类实体的实体跨度
            entityofPice = []
            for ent in entity_pos:
                # 每个单词（属于实体）的起始位置，可能是字母或者索引
                oneEntityP = []
                for ment in ent:
                    if (ment[0] + offset) == ment[1]:
                        oneEntityP.append(ment[0] + offset)
                    for i in range(ment[0] + offset, ment[1]):
                        oneEntityP.append(i)
                entityofPice.append(oneEntityP)
            predicted_Doc2 = []
            for h in range(0, len(entities)):
                item = [0, h]
                predicted_Doc2.append(item)

            predictedEntityPairPiece = []
            for item in predicted_Doc2:
                one_predicted = entityofPice[item[0]] + entityofPice[item[1]]
                predictedEntityPairPiece.append(one_predicted)
            for line in predictedEntityPairPiece:
                for i in line:
                    for j in line:
                        if a_mentions[i][j] == 0:
                            a_mentions[i][j] = 1
                            a_mentions_new[i + 1][j + 1] = 1
                            edges += 1

            # 句法树
            count = 0
            i = 0
            while i < len(input_ids):
                if index2word[i] == spacy_offset:
                    i += 1
                    continue
                word = spacy_tokens[count]
                word_sp = tokenizer.tokenize(word.text)
                for child in word.children:
                    adj_word_list = word2piecesid[child]
                    word_list = word2piecesid[word]
                    # obtain the start index of child
                    child_key = next(key for key, val in index2word.items() if val == child)
                    # obtain the start index of spacy_word
                    word_key = next(key for key, val in index2word.items() if val == word)
                    # print("child:{}, word:{}".format(child, word))
                    for m in range(child_key, len(adj_word_list) + child_key):
                        for n in range(word_key, len(word_list) + word_key):
                            # print("m:{}, n:{}".format(m, n))
                            adj_syntactic_dependency_tree[m][n] = 1  # 无向图
                            adj_syntactic_dependency_tree[n][m] = 1
                            adj_syntactic_dependency_tree_new[m + 1][n + 1] = 1
                            adj_syntactic_dependency_tree_new[n + 1][m + 1] = 1

                i += len(word_sp)
                count += 1
                
            adj_syntactic_dependency_tree_new[0][0] = 0
            adj_syntactic_dependency_tree_new[-1][-1] = 0
            a_mentions_new[0][0] = 0
            a_mentions_new[-1][-1] = 0
            assert len(ent2idx) == len(ent2ent_type)
            if len(hts) > 0:
                feature = {'input_ids': input_ids_new,
                           'entity_pos': entity_pos,
                           'labels': relations,
                           'hts': hts,
                           'title': pmid,
                           'ent2idx': ent2idx,
                           'ent2ent_type': ent2ent_type,
                           'adj_mention': a_mentions_new.tolist(),
                           'adj_syntactic_dependency_tree': adj_syntactic_dependency_tree_new.tolist()
                           }
                features.append(feature)

100%|██████████| 500/500 [01:25<00:00,  5.84it/s]


In [14]:
## These are the features used by the model
type(features)

list

In [15]:
len(features)

500

In [16]:
sample0 = features[0]
sample0.keys()

dict_keys(['input_ids', 'entity_pos', 'labels', 'hts', 'title', 'ent2idx', 'ent2ent_type', 'adj_mention', 'adj_syntactic_dependency_tree'])

In [17]:
print(f"length of input_ids: {len(sample0['input_ids'])}")
sample0['input_ids']

length of input_ids: 173


[101,
 100,
 19288,
 8842,
 172,
 10050,
 27913,
 2247,
 100,
 136,
 100,
 18536,
 20735,
 100,
 6623,
 100,
 124,
 105,
 19226,
 11402,
 211,
 105,
 19226,
 188,
 11425,
 100,
 19288,
 8842,
 172,
 27913,
 2247,
 100,
 430,
 100,
 11866,
 1558,
 14217,
 30112,
 100,
 430,
 100,
 22985,
 339,
 872,
 3562,
 3225,
 100,
 430,
 136,
 105,
 620,
 2913,
 100,
 18536,
 100,
 627,
 163,
 1340,
 211,
 306,
 163,
 111,
 803,
 1607,
 146,
 6335,
 8573,
 100,
 19288,
 8842,
 172,
 27913,
 2247,
 100,
 136,
 100,
 11866,
 1558,
 14217,
 30112,
 100,
 430,
 136,
 111,
 1019,
 279,
 1340,
 1607,
 188,
 100,
 4572,
 1354,
 100,
 1410,
 7229,
 4934,
 146,
 100,
 18536,
 100,
 3901,
 124,
 111,
 803,
 24857,
 125,
 6118,
 211,
 3272,
 1125,
 578,
 939,
 2264,
 125,
 530,
 7229,
 907,
 19288,
 8842,
 172,
 10050,
 6165,
 211,
 100,
 18536,
 20735,
 100,
 579,
 203,
 105,
 1525,
 124,
 111,
 2037,
 4225,
 125,
 100,
 14156,
 3562,
 1354,
 100,
 716,
 2644,
 855,
 2093,
 6118,
 211,
 305,
 498,
 4290,
 10

In [18]:
## Coreference data included!!
print(f"length of input_ids: {len(sample0['entity_pos'])}")
sample0['entity_pos']

length of input_ids: 10


[[(9, 12), (122, 125)],
 [(149, 152)],
 [(154, 157)],
 [(160, 164)],
 [(0, 7), (24, 30), (66, 72)],
 [(51, 53), (96, 98)],
 [(12, 14)],
 [(32, 37), (74, 79)],
 [(39, 45)],
 [(88, 91), (135, 139)]]

In [19]:
print(f"length of input_ids: {len(sample0['labels'])}")
sample0['labels']

length of input_ids: 15


[[0, 1],
 [0, 1],
 [0, 1],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0]]

In [20]:
## This is the last sample of the dev set

# Original text:
'''Water intoxication associated with oxytocin administration during saline-induced abortion. 
Four cases of water intoxication in connection with oxytocin administration during saline-induced abortions are described.
The mechanism of water intoxication is discussed in regard to these cases. Oxytocin administration during 
midtrimester-induced abortions is advocated only if it can be carried out under careful observations of an alert nursing 
staff, aware of the symptoms of water intoxication and instructed to watch the diuresis and report such early signs of 
the syndrome as asthenia, muscular irritability, or headaches. The oxytocin should be given only in Ringers lactate or, 
alternately, in Ringers lactate and a 5 per cent dextrose and water solutions. The urinary output should be monitored and 
the oxytocin administration discontinued and the serum electrolytes checked if the urinary output decreases. 
The oxytocin should not be administered in excess of 36 hours. If the patient has not aborted by then the oxytocin 
should be discontinued for 10 to 12 hours in order to perform electrolyte determinations and correct any electrolyte 
imbalance.'''

sents

['<<Disease>>',
 'water',
 'intox',
 '##ication',
 '<<Disease>>',
 'associated',
 'with',
 '<<Chemical>>',
 'ox',
 '##yt',
 '##ocin',
 '<<Chemical>>',
 'administration',
 'during',
 'saline',
 '-',
 'induced',
 '<<Disease>>',
 'abortion',
 '<<Disease>>',
 '.',
 'four',
 'cases',
 'of',
 '<<Disease>>',
 'water',
 'intox',
 '##ication',
 '<<Disease>>',
 'in',
 'connection',
 'with',
 '<<Chemical>>',
 'ox',
 '##yt',
 '##ocin',
 '<<Chemical>>',
 'administration',
 'during',
 'saline',
 '-',
 'induced',
 '<<Disease>>',
 'abortion',
 '##s',
 '<<Disease>>',
 'are',
 'described',
 '.',
 'the',
 'mechanism',
 'of',
 '<<Disease>>',
 'water',
 'intox',
 '##ication',
 '<<Disease>>',
 'is',
 'discussed',
 'in',
 'regard',
 'to',
 'these',
 'cases',
 '.',
 '<<Chemical>>',
 'ox',
 '##yt',
 '##ocin',
 '<<Chemical>>',
 'administration',
 'during',
 'mid',
 '##trim',
 '##ester',
 '-',
 'induced',
 '<<Disease>>',
 'abortion',
 '##s',
 '<<Disease>>',
 'is',
 'advocated',
 'only',
 'if',
 'it',
 'can',
 'b

In [21]:
## The function "build_inputs_with_special_tokens" just adds bos and eos tokens?
input_ids_new[1:-1] == input_ids

True

In [22]:
print(f"length of entity_pos: {len(sample0['entity_pos'])}")
sample0['entity_pos']

length of entity_pos: 10


[[(9, 12), (122, 125)],
 [(149, 152)],
 [(154, 157)],
 [(160, 164)],
 [(0, 7), (24, 30), (66, 72)],
 [(51, 53), (96, 98)],
 [(12, 14)],
 [(32, 37), (74, 79)],
 [(39, 45)],
 [(88, 91), (135, 139)]]

In [24]:
args.train_batch_size = 4

RE_model.to(args.device)

dataloader = DataLoader(features, batch_size=8, shuffle=False, collate_fn=collate_fn, drop_last=False)
preds, golds = [], []

for i, batch in enumerate(dataloader):
        RE_model.eval()
        list_feature_id = torch.tensor([i for i in range(args.train_batch_size)])
        inputs = {'input_ids': batch[0].to(args.device),
                  'attention_mask': batch[1].to(args.device),
                  'labels': batch[2],
                  'entity_pos': batch[3],
                  'hts': batch[4],
                  'adj_mention': batch[5].to(args.device),
                  'adj_syntactic_dependency_tree': batch[6].to(args.device),
                  'list_feature_id': list_feature_id.to(args.device)
                  }
        with torch.no_grad():
            output = RE_model(**inputs)
            loss = output[0]
            pred = output[1].cpu().numpy()
            pred[np.isnan(pred)] = 0
            preds.append(pred)
            golds.append(np.concatenate([np.array(label, np.float32) for label in batch[2]], axis=0))

preds = np.concatenate(preds, axis=0).astype(np.float32)
golds = np.concatenate(golds, axis=0).astype(np.float32)

In [25]:
preds.shape

(5087, 2)

In [26]:
preds[0:15]

array([[0., 1.],
       [0., 1.],
       [0., 1.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.],
       [1., 0.]], dtype=float32)

In [27]:
preds[1]

array([0., 1.], dtype=float32)

In [28]:
golds.shape

(5087, 2)

## 4. Looking at how a single data point is processed
- This part examines how the input data gets processed
- It follows the preprocessing steps shown above and see how sample 0 of the dev set gets processed
- The ultimate aim is to see the data format that we need to do to the NER model outputs
- As well as detect any missing information that we need.

- Line-by-line examination of the read_bio function from prepare_bio.py

In [29]:
## Reading in all the lines
with open(file_in, 'r') as infile:
        lines = infile.readlines()

In [30]:
## Each line corresponds to 1 single data point in the dev set (i.e. 500 in total)
lines[0]

'6794356\tTricuspid valve regurgitation and lithium carbonate toxicity in a newborn infant .|A newborn with massive tricuspid regurgitation , atrial flutter , congestive heart failure , and a high serum lithium level is described .|This is the first patient to initially manifest tricuspid regurgitation and atrial flutter , and the 11th described patient with cardiac disease among infants exposed to lithium compounds in the first trimester of pregnancy .|Sixty - three percent of these infants had tricuspid valve involvement .|Lithium carbonate may be a factor in the increasing incidence of congenital heart disease when taken during early pregnancy .|It also causes neurologic depression , cyanosis , and cardiac arrhythmia when consumed prior to delivery .\t1:CID:2\tL2R\tCROSS\t82-84\t105-107\tD016651\tlithium carbonate|Lithium carbonate\tChemical\t4:82\t6:84\t0:4\tD003866\tneurologic depression\tDisease\t105\t107\t5\t1:CID:2\tL2R\tCROSS\t82-84\t108-109\tD016651\tlithium carbonate|Lithium

In [31]:
## 1. Frist step is to record the "pmid", which should be the PubMed ID
sample_line = lines[0]
sample_line = sample_line.rstrip().split('\t')
sample_line[0]

'6794356'

In [32]:
## 2. The code then processes the main text, which is shown below after the strip and split
sample_line

['6794356',
 'Tricuspid valve regurgitation and lithium carbonate toxicity in a newborn infant .|A newborn with massive tricuspid regurgitation , atrial flutter , congestive heart failure , and a high serum lithium level is described .|This is the first patient to initially manifest tricuspid regurgitation and atrial flutter , and the 11th described patient with cardiac disease among infants exposed to lithium compounds in the first trimester of pregnancy .|Sixty - three percent of these infants had tricuspid valve involvement .|Lithium carbonate may be a factor in the increasing incidence of congenital heart disease when taken during early pregnancy .|It also causes neurologic depression , cyanosis , and cardiac arrhythmia when consumed prior to delivery .',
 '1:CID:2',
 'L2R',
 'CROSS',
 '82-84',
 '105-107',
 'D016651',
 'lithium carbonate|Lithium carbonate',
 'Chemical',
 '4:82',
 '6:84',
 '0:4',
 'D003866',
 'neurologic depression',
 'Disease',
 '105',
 '107',
 '5',
 '1:CID:2',
 'L

In [33]:
## 3. Then the text (including title) gets split into sentences, and then each sentence becomes a list of words.
sample_text = sample_line[1]
sample_sents = [t.split(' ') for t in sample_text.split('|')]
sample_sents

[['Tricuspid',
  'valve',
  'regurgitation',
  'and',
  'lithium',
  'carbonate',
  'toxicity',
  'in',
  'a',
  'newborn',
  'infant',
  '.'],
 ['A',
  'newborn',
  'with',
  'massive',
  'tricuspid',
  'regurgitation',
  ',',
  'atrial',
  'flutter',
  ',',
  'congestive',
  'heart',
  'failure',
  ',',
  'and',
  'a',
  'high',
  'serum',
  'lithium',
  'level',
  'is',
  'described',
  '.'],
 ['This',
  'is',
  'the',
  'first',
  'patient',
  'to',
  'initially',
  'manifest',
  'tricuspid',
  'regurgitation',
  'and',
  'atrial',
  'flutter',
  ',',
  'and',
  'the',
  '11th',
  'described',
  'patient',
  'with',
  'cardiac',
  'disease',
  'among',
  'infants',
  'exposed',
  'to',
  'lithium',
  'compounds',
  'in',
  'the',
  'first',
  'trimester',
  'of',
  'pregnancy',
  '.'],
 ['Sixty',
  '-',
  'three',
  'percent',
  'of',
  'these',
  'infants',
  'had',
  'tricuspid',
  'valve',
  'involvement',
  '.'],
 ['Lithium',
  'carbonate',
  'may',
  'be',
  'a',
  'factor',
 

In [34]:
## 4. The number of words (including punctuations) in each sentence is then recorded,
## with word position number of the start/end (depending on whether you count 0) set out in the variable "sample_sents_len"
sample_sent_len = [len(i) for i in sample_sents]
sample_sents_len = []
for i in range(len(sample_sent_len)):
    if i == 0:
        sample_sents_len.append(sample_sent_len[0])
    else:
        sample_sents_len.append(sample_sent_len[i] + sample_sents_len[i - 1])

# Note that we have 5 sentences and 6 points here, so it covers all end points (start + finish)
sample_sents_len

[12, 35, 70, 82, 102, 119]

In [35]:
## 5. Extracting the "prs" information, essentially the info on the entities and possible relationship info in the text
## One thing we note is that each chunk of information is supposed to have 17 fields of info, e.g:
'''
- Data from the sample 0 of the dev set has 16 (=len(sample_line[2:])) chunks of 17 entries each.  
- Each chunk contains one possible relationship.  
- Since the text has 2 unique chemicals and 8 unique diseases, we have 16 possible relationships, with 17 fields of info
  expected by the model (or its preprocessor to be exact) for each of the possible relationships

The first of these 16 chunks is shown below:
(1) '1:CID:2',  ## 1:CID:2 or 1:NR:2, signifies CID-relationship exists or no relationship.  This is the target/label.
(2) 'L2R',      ## left-to-right or right-to-left, signifies whether entity 1 came before or after entity 2
(3) 'CROSS',    ## CROSS or NON-CROSS, no idea what this means, but not used in the extraction of features or labels
(4) '82-84',    ## word location of entity 1, in this case lithium carbonate. No idea why the second mention is chosen
(5) '105-107',  ## word location of entity 2, in this case neurologic depression.  Only appeared once in the sample text.
(6) 'D016651',  ## MeSH ID of lithium carbonate
(7) 'lithium carbonate|Lithium carbonate',  ## exact wording used for each mention.  Note that case appears to matter.
(8) 'Chemical', ## entity type information
(9) '4:82',     ## locations where each mention appeared, e.g. lithium carbonate appeared twice at words 4 and 82.
(10) '6:84',    ## locations where each mention stopped, e.g. lithium carbonate stopped at words 5 and 82.
(11) '0:4',     ## sentence number where the entity was mentioned, e.g. lithium carbaonate appeared in sentences 0 and 4.
(12) 'D003866', ## MeSH ID of neurologic depression
(13) 'neurologic depression',               ## exact wording used for each mention.
(14) 'Disease', ## entity type information
(15) '105',     ## locations where each mention appeared, e.g. neurologic depression appeared once at words 105-106.
(16) '107',     ## locations where each mention appeared, e.g. neurologic depression appeared once at words 105-106.
(17) '5',       ## sentence number where the entity was mentioned, e.g. neurologic depression mentioned in sentence 5.
'''
'''
And this is the chunks function, so it just separate the parts of the sample after the text into chunks,
assuming each chunk has exactly 17 entries (see above for an example) each:

def chunks(l, n):
    res = []
    for i in range(0, len(l), n):
        assert len(l[i:i + n]) == n
        res += [l[i:i + n]]
    return res
'''
sample_entities = {}
sample_ent2ent_type = {}
sample_ent2idx = {}
sample_train_triples = {}
sample_entity_pos = set()

sample_prs = chunks(sample_line[2:], 17)
sample_prs

[['1:CID:2',
  'L2R',
  'CROSS',
  '82-84',
  '105-107',
  'D016651',
  'lithium carbonate|Lithium carbonate',
  'Chemical',
  '4:82',
  '6:84',
  '0:4',
  'D003866',
  'neurologic depression',
  'Disease',
  '105',
  '107',
  '5'],
 ['1:CID:2',
  'L2R',
  'CROSS',
  '82-84',
  '108-109',
  'D016651',
  'lithium carbonate|Lithium carbonate',
  'Chemical',
  '4:82',
  '6:84',
  '0:4',
  'D003490',
  'cyanosis',
  'Disease',
  '108',
  '109',
  '5'],
 ['1:CID:2',
  'L2R',
  'CROSS',
  '82-84',
  '111-113',
  'D016651',
  'lithium carbonate|Lithium carbonate',
  'Chemical',
  '4:82',
  '6:84',
  '0:4',
  'D001145',
  'cardiac arrhythmia',
  'Disease',
  '111',
  '113',
  '5'],
 ['1:NR:2',
  'R2L',
  'NON-CROSS',
  '4-6',
  '0-3',
  'D016651',
  'lithium carbonate|Lithium carbonate',
  'Chemical',
  '4:82',
  '6:84',
  '0:4',
  'D014262',
  'Tricuspid valve regurgitation|tricuspid regurgitation|tricuspid regurgitation',
  'Disease',
  '0:16:43',
  '3:18:45',
  '0:1:2'],
 ['1:NR:2',
  'R2L'

In [36]:
## 6. The next step involve processing each "p" in the "prs" list.  
## First the start, end, entity type, exact wording and MeSH ID of the first entity is extracted
sample_p = sample_prs[0]

'''part skipped, but this is in the originally loop iterating through the p's
if sample_p[0] == "not_include":  ## So why is something not included??
    continue
'''
sample_es = list(map(int, sample_p[8].split(':')))
sample_ed = list(map(int, sample_p[9].split(':')))
sample_tpy = sample_p[7]
sample_entity_str = list(map(str, sample_p[6].split('|')))
sample_entity_id = p[5]

'''
Here are the 5 sentences of sample 0, along with the word indices:
0-11:    Tricuspid valve regurgitation and lithium carbonate toxicity in a newborn infant .
12-34:   A newborn with massive tricuspid regurgitation , atrial flutter , congestive heart failure , 
         and a high serum lithium level is described .
35-69:   This is the first patient to initially manifest tricuspid regurgitation and atrial flutter , 
         and the 11th described patient with cardiac disease among infants exposed to lithium compounds in the first 
         trimester of pregnancy .
70-81:   Sixty - three percent of these infants had tricuspid valve involvement .
82-101:  Lithium carbonate may be a factor in the increasing incidence of congenital heart disease when taken during 
         early pregnancy .
102-118: It also causes neurologic depression , cyanosis , and cardiac arrhythmia when consumed prior to delivery .
'''
# We can now speculate what the numbers mean
# Note that "lithium carbonate" appeared 2 times in the text, at words [4-5] and [82-83], 
# so es and ed probably marks the beginning and end word locations.
(sample_es, sample_ed, sample_tpy, sample_entity_str, sample_entity_id)

([4, 82],
 [6, 84],
 'Chemical',
 ['lithium carbonate', 'Lithium carbonate'],
 'D005947')

In [37]:
## 7. After extracting the details of the first mentioned entity, the following processing is done:

if sample_entity_id not in sample_entities:
    sample_entities[sample_entity_id] = []

for sample_start, sample_end, sample_string in zip(sample_es, sample_ed, sample_entity_str):
    sample_entity_pos.add((sample_start, sample_end, sample_tpy))
    # entity in which sent
    sample_sent_in_id = -1
    for i in range(len(sample_sents_len)):
        if sample_sents_len[i] > sample_end:
            sample_sent_in_id = i
            break
    if [sample_start, sample_end, sample_tpy, sample_string, sample_sent_in_id] not in sample_entities[sample_entity_id]:
        sample_entities[sample_entity_id].append([sample_start, sample_end, sample_tpy, sample_string, sample_sent_in_id])
        
print(f"Here is the compiled sample_entity_pos:{sample_entity_pos}")
print("\nAnd here is the compiled sample_entities:")
sample_entities

Here is the compiled sample_entity_pos:{(4, 6, 'Chemical'), (82, 84, 'Chemical')}

And here is the compiled sample_entities:


{'D005947': [[4, 6, 'Chemical', 'lithium carbonate', 0],
  [82, 84, 'Chemical', 'Lithium carbonate', 4]]}

In [38]:
## 8. Then the code handles the 2nd entity set out in the "p" of each value of "prs"
sample_es2 = list(map(int, sample_p[14].split(':')))
sample_ed2 = list(map(int, sample_p[15].split(':')))
sample_tpy2 = sample_p[13]
sample_entity_str2 = list(map(str, sample_p[12].split('|')))
sample_entity_id2 = sample_p[11]

if sample_entity_id2 not in sample_entities:
    sample_entities[sample_entity_id2] = []
    
for sample_start2, sample_end2, sample_string2 in zip(sample_es2, sample_ed2, sample_entity_str2):
    sample_entity_pos.add((sample_start2, sample_end2, sample_tpy2))
    
    # entity in which sent
    sample_sent_in_id2 = -1
    
    for i in range(len(sample_sents_len)):
        if sample_sents_len[i] > sample_end2:
            sample_sent_in_id2 = i
            break
    if [sample_start2, sample_end2, sample_tpy2, sample_string2, sample_sent_in_id2] not in sample_entities[sample_entity_id2]:
        sample_entities[sample_entity_id2].append([sample_start2, sample_end2, sample_tpy2, sample_string2, sample_sent_in_id2])
        
print(f"Here is the compiled sample_entity_pos:{sample_entity_pos}")
print("\nAnd here is the compiled sample_entities:")
sample_entities

Here is the compiled sample_entity_pos:{(4, 6, 'Chemical'), (82, 84, 'Chemical'), (105, 107, 'Disease')}

And here is the compiled sample_entities:


{'D005947': [[4, 6, 'Chemical', 'lithium carbonate', 0],
  [82, 84, 'Chemical', 'Lithium carbonate', 4]],
 'D003866': [[105, 107, 'Disease', 'neurologic depression', 5]]}

In [40]:
## 9. Preparing a dictionary of the MeSH IDs against entity typing
if sample_p[5] not in sample_ent2ent_type:
    sample_ent2ent_type[sample_p[5]] = sample_p[7]
if sample_p[11] not in sample_ent2ent_type:
    sample_ent2ent_type[sample_p[11]] = sample_p[13]

'''part skipped, but this is in the originally loop iterating through the samples
if len(sample_entity_pos) == 0:
    # print(pmid, 'rel is none')
    continue
'''
    
sample_ent2ent_type

{'D016651': 'Chemical', 'D003866': 'Disease'}

In [42]:
## 10. This part uses spacy for analysis, based on the annotations (in Chinese) from the authors
## Let's first look at what the nlp.vocab is

'''
Here is the definition of the WhitespaceTokenizer:

class WhitespaceTokenizer:
    def __init__(self, vocab):
        self.vocab = vocab

    def __call__(self, text):
        words = text.split(" ")
        spaces = [True] * len(words)
        # Avoid zero-length tokens
        for i, word in enumerate(words):
            if word == "":
                words[i] = " "
                spaces[i] = False
        # Remove the final trailing space
        if words[-1] == " ":
            words = words[0:-1]
            spaces = spaces[0:-1]
        else:
            spaces[-1] = False

        return Doc(self.vocab, words=words, spaces=spaces)
'''

# spacy 分析
nlp.tokenizer = WhitespaceTokenizer(nlp.vocab)
sample_doc = nlp(sample_text.replace('|', ' '))

# so it reverts back to the state without the '|' and becomes a paragraph again.
# Note that it is not a simple string, but a spacy object
print(f"Type of the doc varuiable: {type(doc)}\n")
sample_doc

Type of the doc varuiable: <class 'spacy.tokens.doc.Doc'>



Tricuspid valve regurgitation and lithium carbonate toxicity in a newborn infant . A newborn with massive tricuspid regurgitation , atrial flutter , congestive heart failure , and a high serum lithium level is described . This is the first patient to initially manifest tricuspid regurgitation and atrial flutter , and the 11th described patient with cardiac disease among infants exposed to lithium compounds in the first trimester of pregnancy . Sixty - three percent of these infants had tricuspid valve involvement . Lithium carbonate may be a factor in the increasing incidence of congenital heart disease when taken during early pregnancy . It also causes neurologic depression , cyanosis , and cardiac arrhythmia when consumed prior to delivery .

In [43]:
## 11. Uses spacy to break up the whole paragraph in the doc object back into a list of spacy tokens
## At least for sample 0, this is equivalent to just splitting the words using .split(), but the list contains spacy tokens
## rather than word in strings.
sample_spacy_tokens = []
sample_spacy_offset = nlp("*")[0]

for sample_token in sample_doc:
    sample_spacy_tokens.append(sample_token)

print(f"Type of the each token in the list: {type(sample_spacy_tokens[0])}\n")
sample_spacy_tokens

Type of the each token in the list: <class 'spacy.tokens.token.Token'>



[Tricuspid,
 valve,
 regurgitation,
 and,
 lithium,
 carbonate,
 toxicity,
 in,
 a,
 newborn,
 infant,
 .,
 A,
 newborn,
 with,
 massive,
 tricuspid,
 regurgitation,
 ,,
 atrial,
 flutter,
 ,,
 congestive,
 heart,
 failure,
 ,,
 and,
 a,
 high,
 serum,
 lithium,
 level,
 is,
 described,
 .,
 This,
 is,
 the,
 first,
 patient,
 to,
 initially,
 manifest,
 tricuspid,
 regurgitation,
 and,
 atrial,
 flutter,
 ,,
 and,
 the,
 11th,
 described,
 patient,
 with,
 cardiac,
 disease,
 among,
 infants,
 exposed,
 to,
 lithium,
 compounds,
 in,
 the,
 first,
 trimester,
 of,
 pregnancy,
 .,
 Sixty,
 -,
 three,
 percent,
 of,
 these,
 infants,
 had,
 tricuspid,
 valve,
 involvement,
 .,
 Lithium,
 carbonate,
 may,
 be,
 a,
 factor,
 in,
 the,
 increasing,
 incidence,
 of,
 congenital,
 heart,
 disease,
 when,
 taken,
 during,
 early,
 pregnancy,
 .,
 It,
 also,
 causes,
 neurologic,
 depression,
 ,,
 cyanosis,
 ,,
 and,
 cardiac,
 arrhythmia,
 when,
 consumed,
 prior,
 to,
 delivery,
 .]

In [45]:
sample_entity_pos

[(4, 6, 'Chemical'), (82, 84, 'Chemical'), (105, 107, 'Disease')]

In [44]:
## 12. This step basically tries to insert special tokens, e.g. <<Chemical>> and <<Disease>> before the entities
## while at the same time checks how each word is tokenized (since 1 word can be made of multiple tokens) and finally
## creates a mapping that bridges the word location indices with the token indices.

# 依据spacy的分词解析结果，存放开始的index
# id对应的单词
sample_index2word = {}
# 一个单词对应的所有分词片段id
sample_word2piecesid = {}
# spacy中token的当前id
sample_spacy_token_id = 0

# entitys 中缺少来源于哪个句子的id，可以考虑添加
sample_sents = [t.split(' ') for t in sample_text.split('|')]

sample_new_sents = []
sample_token_map = []
sample_lengthofPice = 0
sample_sent_map = {}
sample_entity_pos = list(sample_entity_pos)
sample_entity_pos.sort()
sample_i_t = 0

for sample_sent in sample_sents:
    for sample_token in sample_sent:
        sample_tokens_wordpiece = tokenizer.tokenize(sample_token)
        # 每个实体的起始位置都加进去
        sample_oneToken = []
        sample_eid = 0
        for sample_i, sample_ep in enumerate(sample_entity_pos):
            sample_start, sample_end, sample_tpy = sample_ep
            if sample_i_t == sample_start or sample_i_t == sample_end:
                sample_eid = sample_i
                break
                
        if sample_i_t == sample_entity_pos[sample_eid][0] or sample_i_t == sample_entity_pos[sample_eid][1]:
            # 标记实体的'*'应该算在是实体的一部分，但不是一个单词的一部分，故不算其中。
            sample_index2word[len(sample_index2word)] = sample_spacy_offset
            for sample_token_wordpiece in sample_tokens_wordpiece:
                sample_index2word[len(sample_index2word)] = sample_spacy_tokens[sample_spacy_token_id]
                if sample_spacy_tokens[sample_spacy_token_id] not in sample_word2piecesid:
                    sample_word2piecesid[sample_spacy_tokens[sample_spacy_token_id]] = []
                sample_word2piecesid[sample_spacy_tokens[sample_spacy_token_id]].append(len(sample_index2word) - 1)
                
            sample_oneToken.append(sample_lengthofPice + 1)
            if 'Chemical' in sample_entity_pos[sample_eid][2]:
                sample_special_token = '<<Chemical>>'
            elif 'Disease' in sample_entity_pos[sample_eid][2]:
                sample_special_token = '<<Disease>>'
            elif 'Gene' in sample_entity_pos[sample_eid][2]:
                sample_special_token = '<<Gene>>'
            elif 'sample_Variant' in sample_entity_pos[sample_eid][2]:
                sample_special_token = '<<Variant>>'
            else:
                raise KeyError('not Chemical or Disease or Gene or Variant')
            sample_tokens_wordpiece = [sample_special_token] + sample_tokens_wordpiece
            sample_lengthofPice += len(sample_tokens_wordpiece)
            sample_oneToken.append(sample_lengthofPice)
        else:
            for sample_token_wordpiece in sample_tokens_wordpiece:
                sample_index2word[len(sample_index2word)] = sample_spacy_tokens[sample_spacy_token_id]
                if sample_spacy_tokens[sample_spacy_token_id] not in sample_word2piecesid:
                    sample_word2piecesid[sample_spacy_tokens[sample_spacy_token_id]] = []
                sample_word2piecesid[sample_spacy_tokens[sample_spacy_token_id]].append(len(sample_index2word) - 1)
            
            sample_oneToken.append(sample_lengthofPice)
            sample_lengthofPice += len(sample_tokens_wordpiece)
            sample_oneToken.append(sample_lengthofPice)
            
        # 相当于docred中的new_map，分词后每个词对应的位置
        sample_sent_map[sample_i_t] = len(sample_new_sents)
        sample_new_sents.extend(sample_tokens_wordpiece)
        sample_token_map.append(sample_oneToken)
        sample_i_t += 1
        sample_spacy_token_id += 1
    # sent_map[i_t] = len(new_sents)
sample_sents2 = sample_new_sents

sample_sents2

['tric',
 '##usp',
 '##id',
 'valve',
 'regurg',
 '##itation',
 'and',
 '<<Chemical>>',
 'lithium',
 'carbonate',
 '<<Chemical>>',
 'toxicity',
 'in',
 'a',
 'newborn',
 'infant',
 '.',
 'a',
 'newborn',
 'with',
 'massive',
 'tric',
 '##usp',
 '##id',
 'regurg',
 '##itation',
 ',',
 'atrial',
 'flu',
 '##tte',
 '##r',
 ',',
 'cong',
 '##est',
 '##ive',
 'heart',
 'failure',
 ',',
 'and',
 'a',
 'high',
 'serum',
 'lithium',
 'level',
 'is',
 'described',
 '.',
 'this',
 'is',
 'the',
 'first',
 'patient',
 'to',
 'initially',
 'manifest',
 'tric',
 '##usp',
 '##id',
 'regurg',
 '##itation',
 'and',
 'atrial',
 'flu',
 '##tte',
 '##r',
 ',',
 'and',
 'the',
 '11',
 '##th',
 'described',
 'patient',
 'with',
 'cardiac',
 'disease',
 'among',
 'infants',
 'exposed',
 'to',
 'lithium',
 'compounds',
 'in',
 'the',
 'first',
 'trimester',
 'of',
 'pregnancy',
 '.',
 'six',
 '##ty',
 '-',
 'three',
 'percent',
 'of',
 'these',
 'infants',
 'had',
 'tric',
 '##usp',
 '##id',
 'valve',
 'invo

In [46]:
## sample_sent_map and sample_sent_map tries to map the original word-by-word token indices with 
## the indices after the words go through the bert tokenizer
len(sample_token_map) == sum([len(sample_sent) for sample_sent in sample_sents])

True

In [47]:
## E.g sample_token_map is a list of 119 sub-lists, corresponding to the 119 words.  
## Each sublist signifies where in the tokenized indices the words can be found.

## Note that there are gaps, e.g. [8,9] follows [6,7] because the special token "<<Chemical>>" is added as [7,8]
sample_token_map

[[0, 3],
 [3, 4],
 [4, 6],
 [6, 7],
 [8, 9],
 [9, 10],
 [11, 12],
 [12, 13],
 [13, 14],
 [14, 15],
 [15, 16],
 [16, 17],
 [17, 18],
 [18, 19],
 [19, 20],
 [20, 21],
 [21, 24],
 [24, 26],
 [26, 27],
 [27, 28],
 [28, 31],
 [31, 32],
 [32, 35],
 [35, 36],
 [36, 37],
 [37, 38],
 [38, 39],
 [39, 40],
 [40, 41],
 [41, 42],
 [42, 43],
 [43, 44],
 [44, 45],
 [45, 46],
 [46, 47],
 [47, 48],
 [48, 49],
 [49, 50],
 [50, 51],
 [51, 52],
 [52, 53],
 [53, 54],
 [54, 55],
 [55, 58],
 [58, 60],
 [60, 61],
 [61, 62],
 [62, 65],
 [65, 66],
 [66, 67],
 [67, 68],
 [68, 70],
 [70, 71],
 [71, 72],
 [72, 73],
 [73, 74],
 [74, 75],
 [75, 76],
 [76, 77],
 [77, 78],
 [78, 79],
 [79, 80],
 [80, 81],
 [81, 82],
 [82, 83],
 [83, 84],
 [84, 85],
 [85, 86],
 [86, 87],
 [87, 88],
 [88, 90],
 [90, 91],
 [91, 92],
 [92, 93],
 [93, 94],
 [94, 95],
 [95, 96],
 [96, 97],
 [97, 100],
 [100, 101],
 [101, 102],
 [102, 103],
 [104, 105],
 [105, 106],
 [107, 108],
 [108, 109],
 [109, 110],
 [110, 111],
 [111, 112],
 [112, 113]

In [62]:
## The next chunk looks populates the object "train_triples" (called sample_train_triples here)
## The first part involves look at the head and tail entities, hence (h_id, t_id), relation,
## while the second part is to prepare the labels.

sample_entity_pos2 = []
for sample_p in sample_prs:
    if sample_p[0] == "not_include":
        continue
    if sample_p[1] == "L2R":
        sample_h_id, sample_t_id = sample_p[5], sample_p[11]
        sample_h_start, sample_t_start = sample_p[8], sample_p[14]
        sample_h_end, sample_t_end = sample_p[9], sample_p[15]
                    
    else:
        sample_t_id, sample_h_id = sample_p[5], sample_p[11]
        sample_t_start, sample_h_start = sample_p[8], sample_p[14]
        sample_t_end, sample_h_end = sample_p[9], sample_p[15]
    
print(f"These are the head entities {(sample_h_id, sample_h_start, sample_h_end)}")
print(f"And the tail ones {(sample_t_id, sample_t_start, sample_t_end)}")

print("Remember this is only the last p in the prs.")

These are the head entities ('D008094', '30:61', '31:62')
And the tail ones ('D001145', '111', '113')
Remember this is only the last p in the prs.


In [77]:
## From the cell above
sample_ent2idx = {}
sample_train_triples = {}
sample_entity_pos2 = []
for sample_p in sample_prs:
    if sample_p[0] == "not_include":
        continue
    if sample_p[1] == "L2R":
        sample_h_id, sample_t_id = sample_p[5], sample_p[11]
        sample_h_start, sample_t_start = sample_p[8], sample_p[14]
        sample_h_end, sample_t_end = sample_p[9], sample_p[15]
                    
    else:
        sample_t_id, sample_h_id = sample_p[5], sample_p[11]
        sample_t_start, sample_h_start = sample_p[8], sample_p[14]
        sample_t_end, sample_h_end = sample_p[9], sample_p[15]

## Next steps
    sample_h_start = map(int, sample_h_start.split(':'))
    sample_h_end = map(int, sample_h_end.split(':'))
    sample_t_start = map(int, sample_t_start.split(':'))
    sample_t_end = map(int, sample_t_end.split(':'))
    sample_h_start = [sample_sent_map[idx] for idx in sample_h_start]
    sample_h_end = [sample_sent_map[idx] for idx in sample_h_end]
    sample_t_start = [sample_sent_map[idx] for idx in sample_t_start]
    sample_t_end = [sample_sent_map[idx] for idx in sample_t_end]
    
    
    if sample_h_id not in sample_ent2idx:
        sample_ent2idx[sample_h_id] = len(sample_ent2idx)
        sample_entity_pos2.append(list(zip(sample_h_start, sample_h_end)))
    if sample_t_id not in sample_ent2idx:
        sample_ent2idx[sample_t_id] = len(sample_ent2idx)
        sample_entity_pos2.append(list(zip(sample_t_start, sample_t_end)))
    sample_h_id, sample_t_id = sample_ent2idx[sample_h_id], sample_ent2idx[sample_t_id]
    
    if args.rel2:
        sample_r = 0 if sample_p[0] == '1:NR:2' else 1
    else:
        sample_r = rel2id[sample_p[0]]
    if (sample_h_id, sample_t_id) not in sample_train_triples:
        sample_train_triples[(sample_h_id, sample_t_id)] = [{'relation': sample_r}]
    else:
        sample_train_triples[(sample_h_id, sample_t_id)].append({'relation': sample_r})
        
        
print(f"- These are the final list of entities (note that 1 and 5 are chemicals) {sample_ent2idx}\n")
print(f"- The number of relationships in the train_triples: {len(sample_train_triples)}\n")
print(f"- And the train_triples {sample_train_triples}")

## Note that there are 15 relationships because there is 1 "not include", i.e. 0 vs 9

- These are the final list of entities (note that 1 and 5 are chemicals) {'D016651': 0, 'D003866': 1, 'D003490': 2, 'D001145': 3, 'D014262': 4, 'D008094': 5, 'D064420': 6, 'D001282': 7, 'D006333': 8, 'D006331': 9}

- The number of relationships in the train_triples: 15

- And the train_triples {(0, 1): [{'relation': 1}], (0, 2): [{'relation': 1}], (0, 3): [{'relation': 1}], (4, 0): [{'relation': 0}], (4, 5): [{'relation': 0}], (0, 6): [{'relation': 0}], (0, 7): [{'relation': 0}], (0, 8): [{'relation': 0}], (6, 5): [{'relation': 0}], (7, 5): [{'relation': 0}], (8, 5): [{'relation': 0}], (5, 9): [{'relation': 0}], (5, 1): [{'relation': 0}], (5, 2): [{'relation': 0}], (5, 3): [{'relation': 0}]}


In [81]:
## This next part uses the train_triples to prepare 'relations' (i.e. labels field of the features dict)
## and hts (i.e. hts field of the features dict)
sample_pos_samples = 0
sample_neg_samples = 0
sample_relations, sample_hts = [], []
for sample_h, sample_t in sample_train_triples.keys():
    sample_relation = [0] * len(rel2id)
    for sample_mention in sample_train_triples[sample_h, sample_t]:
        sample_relation[sample_mention["relation"]] = 1
        if sample_mention["relation"] != 0:
            sample_pos_samples += 1
        else:
            sample_neg_samples += 1
    sample_relations.append(sample_relation)
    sample_hts.append([sample_h, sample_t])

In [82]:
sample_relations

[[0, 1],
 [0, 1],
 [0, 1],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0],
 [1, 0]]

In [83]:
## So this is kind of only a matching table showing which entity pair each row is representing
sample_hts

[[0, 1],
 [0, 2],
 [0, 3],
 [4, 0],
 [4, 5],
 [0, 6],
 [0, 7],
 [0, 8],
 [6, 5],
 [7, 5],
 [8, 5],
 [5, 9],
 [5, 1],
 [5, 2],
 [5, 3]]

In [94]:
## This part just adjusts the length of the inputs.  The -2 presumably deals with the CLS and SEP tokens
sample_sents = sample_new_sents[:max_seq_length - 2]
sample_input_ids = tokenizer.convert_tokens_to_ids(sample_sents)
sample_input_ids_new = tokenizer.build_inputs_with_special_tokens(sample_input_ids)

sample_max_len = len(sample_input_ids_new)

len(sample_input_ids_new)

148

In [97]:
## Sample_input_ids_new is just sample_input_ids plus CLS and SEP at the beginning and end
sample_input_ids == sample_input_ids_new[1:-1]

True

In [104]:
## The next part is the adjacency mentions part (a_mentions)
## The explanation in Chinese suggests it's some kind of "structure calculations"

# 结构计算
sample_a_mentions = np.eye(len(sample_input_ids))   ## np.eye is an identity matrix
sample_a_mentions_new = np.eye(sample_max_len)

sample_adj_syntactic_dependency_tree = np.eye(len(sample_input_ids))
sample_adj_syntactic_dependency_tree_new = np.eye(sample_max_len)
offset = 1
sample_edges = 0

for token_s in sample_token_map:    ## Remember, the special tokens such as "<<Chemical>>" are not included in the mapping
    start = token_s[0]
    end = token_s[1]
    for i in range(start, end):
        for j in range(start, end):
            if i < (len(sample_input_ids) - 1) and j < (len(sample_input_ids) - 1):
                if sample_a_mentions[i][j] == 0:
                    sample_a_mentions[i][j] = 1
                    sample_a_mentions_new[i + 1][j + 1] = 1    ## account for the offset caused by the CLS token?
                    sample_edges += 1

## So this is kind of a map showing where the tokens deriving from the same words can be found.
print(f"Number of edges added in the 1st dev sample: {sample_edges}")
sample_a_mentions[:16, :16]

Number of edges added in the 1st dev sample: 56


array([[1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

In [103]:
## tokens 21-23 were from the same word, so are words 24-25
sample_a_mentions[17:32, 17:32]

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0.

In [108]:
## We compile a full list of entities here
sample_entities_full = {}
sample_ent2ent_type_full = {}
sample_entity_pos_full = set()

for p in sample_prs:
    if p[0] == "not_include":
        continue
    es = list(map(int, p[8].split(':')))
    ed = list(map(int, p[9].split(':')))
    tpy = p[7]
    entity_str = list(map(str, p[6].split('|')))
    entity_id = p[5]
    if entity_id not in sample_entities_full:
        sample_entities_full[entity_id] = []
    for start, end, string in zip(es, ed, entity_str):
        sample_entity_pos_full.add((start, end, tpy))
        # entity in which sent
        sent_in_id = -1
        for i in range(len(sents_len)):
            if sents_len[i] > end:
                sent_in_id = i
                break
        if [start, end, tpy, string, sent_in_id] not in sample_entities_full[entity_id]:
            sample_entities_full[entity_id].append([start, end, tpy, string, sent_in_id])

    es = list(map(int, p[14].split(':')))
    ed = list(map(int, p[15].split(':')))
    tpy = p[13]
    entity_str = list(map(str, p[12].split('|')))
    entity_id = p[11]
    if entity_id not in sample_entities_full:
        sample_entities_full[entity_id] = []
    for start, end, string in zip(es, ed, entity_str):
        sample_entity_pos_full.add((start, end, tpy))
        # entity in which sent
        sent_in_id = -1
        for i in range(len(sents_len)):
            if sents_len[i] > end:
                sent_in_id = i
                break
        if [start, end, tpy, string, sent_in_id] not in sample_entities_full[entity_id]:
            sample_entities_full[entity_id].append([start, end, tpy, string, sent_in_id])
    if p[5] not in ent2ent_type:
        sample_ent2ent_type_full[p[5]] = p[7]
    if p[11] not in ent2ent_type:
        sample_ent2ent_type_full[p[11]] = p[13]


sample_entities_full

{'D016651': [[4, 6, 'Chemical', 'lithium carbonate', 0],
  [82, 84, 'Chemical', 'Lithium carbonate', 3]],
 'D003866': [[105, 107, 'Disease', 'neurologic depression', 4]],
 'D003490': [[108, 109, 'Disease', 'cyanosis', 4]],
 'D001145': [[111, 113, 'Disease', 'cardiac arrhythmia', 4]],
 'D014262': [[0, 3, 'Disease', 'Tricuspid valve regurgitation', 0],
  [16, 18, 'Disease', 'tricuspid regurgitation', 1],
  [43, 45, 'Disease', 'tricuspid regurgitation', 3]],
 'D008094': [[30, 31, 'Chemical', 'lithium', 2],
  [61, 62, 'Chemical', 'lithium', 3]],
 'D064420': [[6, 7, 'Disease', 'toxicity', 0]],
 'D001282': [[19, 21, 'Disease', 'atrial flutter', 1],
  [46, 48, 'Disease', 'atrial flutter', 3]],
 'D006333': [[22, 25, 'Disease', 'congestive heart failure', 1]],
 'D006331': [[55, 57, 'Disease', 'cardiac disease', 3],
  [93, 96, 'Disease', 'congenital heart disease', 3]]}

In [136]:
## From previous cells
sample_a_mentions = np.eye(len(sample_input_ids))   ## np.eye is an identity matrix
sample_a_mentions_new = np.eye(sample_max_len)

sample_adj_syntactic_dependency_tree = np.eye(len(sample_input_ids))
sample_adj_syntactic_dependency_tree_new = np.eye(sample_max_len)
offset = 1
sample_edges = 0

for token_s in sample_token_map:    ## Remember, the special tokens such as "<<Chemical>>" are not included in the mapping
    start = token_s[0]
    end = token_s[1]
    for i in range(start, end):
        for j in range(start, end):
            if i < (len(sample_input_ids) - 1) and j < (len(sample_input_ids) - 1):
                if sample_a_mentions[i][j] == 0:
                    sample_a_mentions[i][j] = 1
                    sample_a_mentions_new[i + 1][j + 1] = 1    ## account for the offset caused by the CLS token?
                    sample_edges += 1

## New part examined here
# 所有实体在 tokens 中的跨度
sample_mentionsofPice = []
for eid in sample_entities_full:    # eid is the MeSH ID of the entities
    for i in sample_entities_full[eid]:
        sample_ment = [i[0], i[1]]  # this is the word index of the mentioning of the entity
        # Using the token_map to locate the token index (beginning and end)
        sample_mentionsofPice.append([sample_token_map[sample_ment[0]][0], sample_token_map[sample_ment[1] - 1][1]])
for sample_ment in sample_mentionsofPice:
    start = sample_ment[0]
    end = sample_ment[1]
    for i in range(start, end):
        for j in range(start, end):
            if i < (len(sample_input_ids) - 1) and j < (len(sample_input_ids) - 1):
                if sample_a_mentions[i][j] == 0:
                    sample_a_mentions[i][j] = 1
                    sample_a_mentions_new[i + 1][j + 1] = 1
                    sample_edges += 1
                    
print(f"Number of edges added in the 1st dev sample: {sample_edges}")
## So tokens from the same entity but from different words are also toggled to 1 (from 0)
sample_a_mentions[:16, :16]

Number of edges added in the 1st dev sample: 146


array([[1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

In [119]:
sample_a_mentions[17:32, 17:32]

array([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0.],
       [0., 0., 0., 0., 0., 0., 0.

In [121]:
sample_entity_pos_full

{(0, 3, 'Disease'),
 (4, 6, 'Chemical'),
 (6, 7, 'Disease'),
 (16, 18, 'Disease'),
 (19, 21, 'Disease'),
 (22, 25, 'Disease'),
 (30, 31, 'Chemical'),
 (43, 45, 'Disease'),
 (46, 48, 'Disease'),
 (55, 57, 'Disease'),
 (61, 62, 'Chemical'),
 (82, 84, 'Chemical'),
 (93, 96, 'Disease'),
 (105, 107, 'Disease'),
 (108, 109, 'Disease'),
 (111, 113, 'Disease')}

In [125]:
for ent in sample_entity_pos_full:
    print(ent[0])

4
82
0
19
108
105
111
16
55
93
43
6
61
22
30
46


In [139]:
sample_entity_pos_full2 = []
sample_ent2idx_full = {}
for p in sample_prs:
    if p[0] == "not_include":
        continue
    if p[1] == "L2R":
        h_id, t_id = p[5], p[11]
        h_start, t_start = p[8], p[14]
        h_end, t_end = p[9], p[15]
    else:
        t_id, h_id = p[5], p[11]
        t_start, h_start = p[8], p[14]
        t_end, h_end = p[9], p[15]
    h_start = map(int, h_start.split(':'))
    h_end = map(int, h_end.split(':'))
    t_start = map(int, t_start.split(':'))
    t_end = map(int, t_end.split(':'))
    h_start = [sample_sent_map[idx] for idx in h_start]
    h_end = [sample_sent_map[idx] for idx in h_end]
    t_start = [sample_sent_map[idx] for idx in t_start]
    t_end = [sample_sent_map[idx] for idx in t_end]
    if h_id not in sample_ent2idx_full:
        sample_ent2idx_full[h_id] = len(sample_ent2idx_full)
        sample_entity_pos_full2.append(list(zip(h_start, h_end)))
    if t_id not in sample_ent2idx_full:
        sample_ent2idx_full[t_id] = len(sample_ent2idx_full)
        sample_entity_pos_full2.append(list(zip(t_start, t_end)))
        
sample_entity_pos_full2

[[(7, 10), (103, 106)],
 [(128, 131)],
 [(133, 135)],
 [(137, 140)],
 [(0, 6), (21, 26), (55, 60)],
 [(42, 43), (79, 80)],
 [(10, 12)],
 [(27, 31), (61, 65)],
 [(32, 37)],
 [(73, 75), (116, 119)]]

In [140]:
## The last part deals with this entityofPice.  
## This basically identifies all token locations that belong to different mentioning of the same entity.

# 各类实体的实体跨度
sample_entityofPice = []
for ent in sample_entity_pos_full2:
    # 每个单词（属于实体）的起始位置，可能是字母或者索引
    oneEntityP = []
    for ment in ent:
        if (ment[0] + offset) == ment[1]:
            oneEntityP.append(ment[0] + offset)
        for i in range(ment[0] + offset, ment[1]):
            oneEntityP.append(i)
        sample_entityofPice.append(oneEntityP)
predicted_Doc2 = []
for h in range(0, len(sample_entities_full)):
    item = [0, h]
    predicted_Doc2.append(item)

predictedEntityPairPiece = []
for item in predicted_Doc2:
    one_predicted = sample_entityofPice[item[0]] + sample_entityofPice[item[1]]
    predictedEntityPairPiece.append(one_predicted)

sample_entityofPice

[[8, 9, 104, 105],
 [8, 9, 104, 105],
 [129, 130],
 [134],
 [138, 139],
 [1, 2, 3, 4, 5, 22, 23, 24, 25, 56, 57, 58, 59],
 [1, 2, 3, 4, 5, 22, 23, 24, 25, 56, 57, 58, 59],
 [1, 2, 3, 4, 5, 22, 23, 24, 25, 56, 57, 58, 59],
 [43, 80],
 [43, 80],
 [11],
 [28, 29, 30, 62, 63, 64],
 [28, 29, 30, 62, 63, 64],
 [33, 34, 35, 36],
 [74, 117, 118],
 [74, 117, 118]]

In [141]:
predictedEntityPairPiece

[[8, 9, 104, 105, 8, 9, 104, 105],
 [8, 9, 104, 105, 8, 9, 104, 105],
 [8, 9, 104, 105, 129, 130],
 [8, 9, 104, 105, 134],
 [8, 9, 104, 105, 138, 139],
 [8, 9, 104, 105, 1, 2, 3, 4, 5, 22, 23, 24, 25, 56, 57, 58, 59],
 [8, 9, 104, 105, 1, 2, 3, 4, 5, 22, 23, 24, 25, 56, 57, 58, 59],
 [8, 9, 104, 105, 1, 2, 3, 4, 5, 22, 23, 24, 25, 56, 57, 58, 59],
 [8, 9, 104, 105, 43, 80],
 [8, 9, 104, 105, 43, 80]]

In [142]:
for line in predictedEntityPairPiece:
    for i in line:
        for j in line:
            if sample_a_mentions[i][j] == 0:
                sample_a_mentions[i][j] = 1
                sample_a_mentions_new[i + 1][j + 1] = 1
                sample_edges += 1
                
print(f"Number of edges added in the 1st dev sample: {sample_edges}")
## So tokens from the same entity but from different mentionings are also toggled to 1 (from 0)
sample_a_mentions[7:10, 103:105]

Number of edges added in the 1st dev sample: 428


array([[0., 0.],
       [0., 1.],
       [0., 1.]])

In [143]:
sample_index2word

{0: Tricuspid,
 1: Tricuspid,
 2: Tricuspid,
 3: valve,
 4: regurgitation,
 5: regurgitation,
 6: and,
 7: *,
 8: lithium,
 9: carbonate,
 10: *,
 11: toxicity,
 12: in,
 13: a,
 14: newborn,
 15: infant,
 16: .,
 17: A,
 18: newborn,
 19: with,
 20: massive,
 21: tricuspid,
 22: tricuspid,
 23: tricuspid,
 24: regurgitation,
 25: regurgitation,
 26: ,,
 27: atrial,
 28: flutter,
 29: flutter,
 30: flutter,
 31: ,,
 32: congestive,
 33: congestive,
 34: congestive,
 35: heart,
 36: failure,
 37: ,,
 38: and,
 39: a,
 40: high,
 41: serum,
 42: lithium,
 43: level,
 44: is,
 45: described,
 46: .,
 47: This,
 48: is,
 49: the,
 50: first,
 51: patient,
 52: to,
 53: initially,
 54: manifest,
 55: tricuspid,
 56: tricuspid,
 57: tricuspid,
 58: regurgitation,
 59: regurgitation,
 60: and,
 61: atrial,
 62: flutter,
 63: flutter,
 64: flutter,
 65: ,,
 66: and,
 67: the,
 68: 11th,
 69: 11th,
 70: described,
 71: patient,
 72: with,
 73: cardiac,
 74: disease,
 75: among,
 76: infants,
 7

In [145]:
sample_spacy_offset

*

In [176]:
## The last part on adjaceny and syntax tree (句法樹) 
## We now examine what it does

# 句法树
count = 0
i = 0
while i < len(sample_input_ids):
    ## sample_spacy_offset (='*') is found for the special symbols, e.g. "<<Chemical>>"
    if sample_index2word[i] == sample_spacy_offset:    
        i += 1
        continue
    sample_word = sample_spacy_tokens[count]
    sample_word_sp = tokenizer.tokenize(sample_word.text)
    for child in sample_word.children:   
        ## According to stackoverflow, token.children uses the dependency parse to get all tokens that directly 
        ## depend on the token in question. In a visualization (try displacy), 
        ## this will be all the tokens with arrows pointing away from a token; 
        ## if the word is a verb this could be the subject and any objects, 
        ## if the word is a noun it could be any adjectives modifying it, for example.
        ## source: https://stackoverflow.com/questions/74794944/what-does-this-children-attribute-do
        sample_adj_word_list = sample_word2piecesid[child]
        sample_word_list = sample_word2piecesid[sample_word]
        
        # obtain the start index of child
        child_key = next(key for key, val in sample_index2word.items() if val == child)
        # obtain the start index of spacy_word
        word_key = next(key for key, val in sample_index2word.items() if val == sample_word)
        # print("child:{}, word:{}".format(child, word))
        for m in range(child_key, len(sample_adj_word_list) + child_key):
            for n in range(word_key, len(sample_word_list) + word_key):
                # print("m:{}, n:{}".format(m, n))
                sample_adj_syntactic_dependency_tree[m][n] = 1  # 无向图
                sample_adj_syntactic_dependency_tree[n][m] = 1
                sample_adj_syntactic_dependency_tree_new[m + 1][n + 1] = 1
                sample_adj_syntactic_dependency_tree_new[n + 1][m + 1] = 1

    i += len(sample_word_sp)
    count += 1
        
    sample_adj_syntactic_dependency_tree_new[0][0] = 0
    sample_adj_syntactic_dependency_tree_new[-1][-1] = 0

In [178]:
sample_adj_syntactic_dependency_tree[:16, :16]

array([[1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 1., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,

In [165]:
x = sample_spacy_tokens[1].children
x

<generator at 0x2151fb99d60>

In [169]:
next(x)

.