In [1]:
!pip install datasets

Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.1.0-py3-none-any.whl (480 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m480.6/480.6 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (1

In [2]:
import tensorflow as tf
from datasets import load_dataset
import random
from transformers import AutoTokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
from transformers import TFBertModel
from tensorflow.keras import layers


In [3]:
tag_dict = {
    'B-GENE-Y': 1, 'I-GENE-Y': 2,
    'B-GENE-N': 3, 'I-GENE-N': 4,
    'B-CHEMICAL': 5, 'I-CHEMICAL': 6
}

tag_dict_map = {
    (0,): 0, (0, 1): 1, (0, 2, 3, 5): 2,
    (0, 4): 3, (0, 1, 6): 4, (0, 2, 6): 5, (0, 2): 6, (0, 3, 6): 7,
    (0, 5): 8, (0, 4, 6): 9, (0, 2, 4, 6): 10, (0, 2, 4): 11,
    (0, 5, 6): 12, (0, 3): 13, (0, 6): 14, (0, 1, 5): 15, (0, 2, 5): 16, (0, 3, 5): 17, (0, 4, 5): 18
}

con = {2:[2,5,6,10,11,16], 4:[3,9,10,11,18], 6:[4,5,7,9,10,12,14]}

next_entity = {0:[[0],[]], 1:[con[2],[]], 2:[con[4],con[6]], 3:[[3],[]], 4:[con[2],[]], 5:[[5],[]], 6:[[6],[]], 7:[con[4],[]], 8:[con[6],[]], 9:[[9],[]] ,10:[[10],[]], 11:[[11],[]], 12:[con[6],[]], 13:[con[4],[]], 14:[[14],[]], 15:[con[2],con[6]], 16:[con[6],[]], 17:[con[4],con[6]], 18:[con[6],[]]}

r_dict = {
    'CPR:0': 0, 'CPR:1': 1, 'CPR:2': 2, 'CPR:3': 3,
    'CPR:4': 4, 'CPR:5': 5, 'CPR:6': 6, 'CPR:7': 7,
    'CPR:8': 8, 'CPR:9': 9, 'CPR:10': 10
}

batch_size = 30
num_batches = 131

num_tags = 19
hidden_dim = 128
seq_len = 128
num_relations = 12


In [37]:
#tf.config.run_functions_eagerly(True)


class CustomCRF(tf.keras.layers.Layer):
    def __init__(self, num_tags, seq_len, **kwargs):
        super(CustomCRF, self).__init__(**kwargs)
        self.num_tags = num_tags
        self.seq_len = seq_len

    def build(self):
        self.start_transitions = self.add_weight(
            shape=(self.num_tags,)
        )
        self.transition_matrix = self.add_weight(
            shape=(self.num_tags, self.num_tags)
        )
        self.end_transitions = self.add_weight(
            shape=(self.num_tags,)
        )
    @tf.function
    def call(self, inputs, labels=None, training=None):
        emissions, attention_mask = inputs
        seq_len = tf.shape(emissions)[1]

        if training:
            return self._crf_loss(emissions, labels, attention_mask)
        else:
            return self.viterbi(emissions, attention_mask)

    @tf.function
    def viterbi(self, emissions, attention_mask):
        batch_size = tf.shape(emissions)[0]
        seq_len = tf.shape(emissions)[1]
        num_tags = tf.shape(emissions)[2]

        dp = tf.TensorArray(dtype=tf.float32, size=seq_len, clear_after_read=False)
        backpointer = tf.TensorArray(dtype=tf.int32, size=seq_len, clear_after_read=False)

        first_step = self.start_transitions + emissions[:, 0, :]
        dp = dp.write(0, first_step)

        t = tf.constant(1)

        def loop_body(t, dp, backpointer):
            prev_scores = dp.read(t - 1)
            scores = tf.expand_dims(prev_scores, axis=2) + self.transition_matrix
            best_scores = tf.reduce_max(scores, axis=1)
            best_paths = tf.argmax(scores, axis=1, output_type=tf.int32)
            current_scores = emissions[:, t, :] + best_scores
            dp = dp.write(t, current_scores)
            backpointer = backpointer.write(t, best_paths)
            return t + 1, dp, backpointer

        _, dp, backpointer = tf.while_loop(
            cond=lambda t, *_: t < seq_len,
            body=loop_body,
            loop_vars=(t, dp, backpointer)
        )

        last_step_scores = dp.read(seq_len - 1) + self.end_transitions
        last_tag = tf.argmax(last_step_scores, axis=1, output_type=tf.int32)

        def backtrack_fn(i):
            best_path = tf.TensorArray(dtype=tf.int32, size=seq_len, clear_after_read=False)
            best_path = best_path.write(seq_len - 1, last_tag[i])
            t = seq_len - 2

            def backtrack_body(t, best_path):
                next_tag = best_path.read(t + 1)
                best_tag = backpointer.read(t + 1)[i, next_tag]
                best_path = best_path.write(t, best_tag)
                return t - 1, best_path

            _, best_path = tf.while_loop(
                cond=lambda t, *_: t >= 0,
                body=backtrack_body,
                loop_vars=(t, best_path)
            )

            return best_path.stack()

        best_paths = tf.map_fn(backtrack_fn, tf.range(batch_size), fn_output_signature=tf.int32)
        return best_paths


class NERModel(tf.keras.Model):
    def __init__(self, num_tags, hidden_dim, seq_len, **kwargs):
        super(NERModel, self).__init__(**kwargs)
        self.bert = TFBertModel.from_pretrained('bert-base-uncased')
        self.lstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(hidden_dim, return_sequences=True))
        self.lstm.build((None, None, self.bert.config.hidden_size))

        self.dense = tf.keras.layers.Dense(num_tags)
        self.dense.build((None, 256))

        self.num_tags = num_tags
        self.metrics_dict = {i: tf.Variable([0, 0, 0, 0], dtype=tf.int32) for i in range(num_tags)}
        self.crf = CustomCRF(num_tags, seq_len)
        self.crf.build()

    @tf.function
    def call(self, inputs, labels=None, training=None):
        bert_output = self.bert(inputs['input_ids'], attention_mask=inputs['attention_mask'])['last_hidden_state']
        lstm_output = self.lstm(bert_output)
        logits = self.dense(lstm_output)
        if training:
            return logits
        else:
            return self.crf((logits, inputs['attention_mask']), training=False)

    def reset_metrics(self):
        for key in self.metrics_dict:
            self.metrics_dict[key].assign([0, 0, 0, 0])

    def test_step(self, data, validation=False, inference = False):
        if inference:
            predictions = self(data, training=False)
            return predictions
        inputs, labels = data
        emissions = self(inputs, training=validation)
        if validation:
            loss = self.crf._crf_loss(emissions, labels, inputs['attention_mask'], tf.shape(emissions)[1])
            return {'loss':loss}
        predictions = tf.reshape(emissions, [-1])  # Flatten to 1D
        true_labels = tf.reshape(labels, [-1])     # Flatten to 1D

        for i in range(self.num_tags):
            true_positives = tf.reduce_sum(tf.cast((predictions == i) & (true_labels == i), tf.int32))
            false_positives = tf.reduce_sum(tf.cast((predictions == i) & (true_labels != i), tf.int32))
            false_negatives = tf.reduce_sum(tf.cast((predictions != i) & (true_labels == i), tf.int32))
            true_negatives = tf.reduce_sum(tf.cast((predictions != i) & (true_labels != i), tf.int32))

            self.metrics_dict[i].assign_add([true_positives, true_negatives, false_positives, false_negatives])


class RelationExtractionModel(tf.keras.Model):
    def __init__(self, ner_model, num_relations, hidden_dim, dropout_rate=0.2):
        super(RelationExtractionModel, self).__init__()
        self.bert = TFBertModel.from_pretrained('bert-base-uncased')
        self.bilstm = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(hidden_dim, return_sequences=True))
        self.bilstm.build((None, None, ner_model.bert.config.hidden_size))
        self.dense = tf.keras.layers.Dense(num_relations, activation='softmax')
        self.dense.build((None, 512))
        self.num_relations = num_relations
        self.metrics_dict = {i: tf.Variable([0, 0, 0, 0], dtype=tf.int32) for i in range(num_relations)}
    @tf.function
    def call(self, inputs, training=False):
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        ner_tags = inputs['ner_tags']
        re_mask = inputs['re_mask']

        bert_output = self.bert(input_ids, attention_mask=attention_mask)
        token_embeddings = bert_output.last_hidden_state
        lstm_output = self.bilstm(token_embeddings, training=training)
        combined_embeddings = self.extract_entity_pairs(lstm_output, ner_tags, re_mask)
        combined_embeddings_stacked = tf.stack(combined_embeddings)
        return self.dense(combined_embeddings_stacked)

        '''def return_zeros():
            return tf.zeros((1, self.dense.units))

        def proceed_with_embeddings():
            combined_embeddings_stacked = tf.stack(combined_embeddings)
            return self.dense(self.dense1(combined_embeddings_stacked))

        return tf.cond(
            tf.equal(tf.shape(combined_embeddings)[0], 0),
            return_zeros,
            proceed_with_embeddings
        )'''

    @tf.function
    def extract_entity_pairs(self, lstm_output, ner_tags, re_mask):
        combined_embeddings = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
        batch_size = tf.shape(ner_tags)[0]

        def process_batch(i, embeddings_list):
            re_pairs = re_mask[i]
            valid_pairs = tf.boolean_mask(re_pairs, tf.not_equal(re_pairs[:, 0], -1))

            def process_pair(j, embeddings_list):
                e1_sidx, e1_eidx, relation_type, e2_sidx, e2_eidx = tf.unstack(valid_pairs[j])
                e1_emb = self.pool_entity(lstm_output[i], ner_tags[i], e1_sidx, e1_eidx)
                e2_emb = self.pool_entity(lstm_output[i], ner_tags[i], e2_sidx, e2_eidx)
                combined = tf.concat([e1_emb, e2_emb], axis=-1)
                return j + 1, embeddings_list.write(embeddings_list.size(), combined)

            _, embeddings_list = tf.while_loop(
                lambda j, _: j < tf.shape(valid_pairs)[0],
                process_pair,
                loop_vars=[0, embeddings_list]
            )
            return i + 1, embeddings_list

        _, final_embeddings = tf.while_loop(
            lambda i, _: i < batch_size,
            process_batch,
            loop_vars=[0, combined_embeddings]
        )

        return final_embeddings.stack()
    @tf.function
    def pool_entity(self, lstm_output, ner_tags, start_idx, end_idx):
        entity_span = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)

        def condition(idx, _):
            return idx <= end_idx

        def body(idx, entity_span):
            entity_span = entity_span.write(entity_span.size(), lstm_output[idx])
            return idx + 1, entity_span

        _, collected_span = tf.while_loop(
            condition, body, loop_vars=[start_idx, entity_span]
        )

        return tf.reduce_mean(collected_span.stack(), axis=0)

    @tf.function
    def extract_relation_labels(self, re_mask):
        reduced_re_mask = re_mask[:, :, 2]
        valid_mask = tf.not_equal(reduced_re_mask, -1)
        valid_relation_labels = tf.boolean_mask(reduced_re_mask, valid_mask)
        return valid_relation_labels


    def reset_metrics(self):
        for key in self.metrics_dict:
            self.metrics_dict[key].assign([0, 0, 0, 0])

    def test_step(self, inputs, validation = False, inference = False):
        if tf.size(inputs['re_mask']) == 0 or all(tf.size(mask) == 0 for mask in inputs['re_mask']):
            return tf.constant([])
        logits = self(inputs, training=False)
        predictions = tf.argmax(logits, axis=-1)
        if inference:
            return predictions
        true_labels = self.extract_relation_labels(inputs['re_mask'])
        if validation :
            loss = tf.reduce_mean(tf.keras.losses.sparse_categorical_crossentropy(
                true_labels, logits
            ))
            return {'loss':loss}

        for i in range(self.num_relations):
            tp = tf.reduce_sum(tf.cast((predictions == i) & (true_labels == i), tf.int32))
            fp = tf.reduce_sum(tf.cast((predictions == i) & (true_labels != i), tf.int32))
            fn = tf.reduce_sum(tf.cast((predictions != i) & (true_labels == i), tf.int32))
            tn = tf.reduce_sum(tf.cast((predictions != i) & (true_labels != i), tf.int32))

            self.metrics_dict[i].assign_add([tp, tn, fp, fn])






In [5]:
!pip install datasets



In [7]:
ner_model = NERModel(num_tags=num_tags, hidden_dim=hidden_dim, seq_len=seq_len)

ner_model.load_weights('/content/ner_model.keras')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions w

In [38]:
re_model = RelationExtractionModel(ner_model=ner_model, hidden_dim=hidden_dim, num_relations=num_relations)

re_model.load_weights('/content/re_model.keras')

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions w

In [9]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
ds = load_dataset("bigbio/chemprot", "chemprot_full_source")

def gen(i, j, ner_tags_segment, next_tags):
    l, m = i + 1, j + 1
    tem = []
    while ner_tags_segment[l] in next_tags[ner_tags_segment[i]][0]:
        l += 1
    while ner_tags_segment[m] in next_tags[ner_tags_segment[j]][0]:
        m += 1
    if m - 1 != j and l - 1 != i:
        tem.append((i, l - 1, j, m - 1))
    l, m = i + 1, j + 1
    while ner_tags_segment[l] in next_tags[ner_tags_segment[i]][0]:
        l += 1
    while ner_tags_segment[m] in next_tags[ner_tags_segment[j]][1]:
        m += 1
    if m - 1 != j and l - 1 != i:
        tem.append((i, l - 1, j, m - 1))
    l, m = i + 1, j + 1
    while ner_tags_segment[l] in next_tags[ner_tags_segment[i]][1]:
        l += 1
    while ner_tags_segment[m] in next_tags[ner_tags_segment[j]][0]:
        m += 1
    if m - 1 != j and l - 1 != i:
        tem.append((i, l - 1, j, m - 1))
    l, m = i + 1, j + 1
    while ner_tags_segment[l] in next_tags[ner_tags_segment[i]][1]:
        l += 1
    while ner_tags_segment[m] in next_tags[ner_tags_segment[j]][1]:
        m += 1
    if m - 1 != j and l - 1 != i:
        tem.append((i, l - 1, j, m - 1))
    return tem

def prepare_data(row, tag_dict, r_dict, tag_dict_map, next_tags, max_length=128, stride=10, padding=True):
    text = row['text']
    entities = row['entities']
    relations = row['relations']

    tokenized_inputs = tokenizer(
        text,
        return_offsets_mapping=True,
        return_overflowing_tokens=True,
        max_length=max_length,
        truncation=True,
        stride=stride,
        padding='max_length' if padding else False
    )

    input_ids = []
    attention_mask = []
    ner_tags = []
    relation_mask = []

    for i, offset_mapping in enumerate(tokenized_inputs['offset_mapping']):
        input_ids_segment = tokenized_inputs['input_ids'][i]
        attention_mask_segment = tokenized_inputs['attention_mask'][i]

        ner_tags_segment = [(0,) for i in range(max_length)]
        rel_mask_segment = []
        entity_map = {}
        b_tag_indices = []

        for entity_id, entity_type, (start_char, end_char) in zip(entities['id'], entities['type'], entities['offsets']):
            entity_start_found = False
            for token_idx, (token_start, token_end) in enumerate(offset_mapping):
                if token_idx == 0 or token_idx == max_length - 1:
                    continue
                if token_start == start_char and offset_mapping[-2][1] >= end_char:
                    ner_tags_segment[token_idx] += (tag_dict[f'B-{entity_type}'],)
                    entity_map[entity_id] = [token_idx, token_idx]
                    entity_start_found = True
                elif entity_start_found and token_start < end_char:
                    ner_tags_segment[token_idx] += (tag_dict[f'I-{entity_type}'],)
                    entity_map[entity_id][1] = token_idx

        rel_pairs = set()
        for rel_type, arg1_id, arg2_id in zip(relations['type'], relations['arg1'], relations['arg2']):
            if arg1_id in entity_map and arg2_id in entity_map:
                rel_mask_segment.append([entity_map[arg1_id][0], entity_map[arg1_id][1], r_dict[rel_type], entity_map[arg2_id][0], entity_map[arg2_id][1]])
                rel_pairs.add((entity_map[arg1_id][0], entity_map[arg1_id][1], entity_map[arg2_id][0], entity_map[arg2_id][1]))

        for j in range(max_length):
            ner_tags_segment[j] = tag_dict_map[tuple(sorted(list(set(ner_tags_segment[j]))))]
            if ner_tags_segment[j] in {1, 2, 4, 7, 8, 12, 13, 15, 16, 17, 18}:
                b_tag_indices.append(j)

        relpairs = set(rel_pairs)
        neg_samples = []
        for i in rel_pairs:
            for j in b_tag_indices:
                tem = gen(i[0], j, ner_tags_segment, next_tags)
                for ke in tem:
                    if ke not in relpairs:
                        relpairs.add(ke)
                        neg_samples.append([ke[0], ke[1], 11, ke[2], ke[3]])
                tem = gen(j, i[2], ner_tags_segment, next_tags)
                for ke in tem:
                    if ke not in relpairs:
                        relpairs.add(ke)
                        neg_samples.append([ke[0], ke[1], 11, ke[2], ke[3]])

        neg_samples = random.sample(neg_samples, len(rel_pairs)) if len(rel_pairs) != 0 and len(rel_pairs) <= len(neg_samples) else random.sample(neg_samples, 15) if len(neg_samples) >= 15 else neg_samples

        neg_samples_r = []
        for i in b_tag_indices:
            for j in b_tag_indices:
                tem = gen(i, j, ner_tags_segment, next_tags)
                for ke in tem:
                    if ke not in relpairs:
                        relpairs.add(ke)
                        neg_samples_r.append([ke[0], ke[1], 11, ke[2], ke[3]])

        neg_samples_r = random.sample(neg_samples_r, len(rel_pairs)) if len(rel_pairs) != 0 and len(rel_pairs) <= len(neg_samples_r) else random.sample(neg_samples_r, 15) if len(neg_samples_r) >= 15 else neg_samples_r
        input_ids.append(input_ids_segment)
        attention_mask.append(attention_mask_segment)
        ner_tags.append(ner_tags_segment)
        relation_mask.append(rel_mask_segment + neg_samples + neg_samples_r)

    return input_ids, attention_mask, ner_tags, relation_mask



tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



README.md:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

chemprot.py:   0%|          | 0.00/15.8k [00:00<?, ?B/s]

chemprot_full_source/sample/0000.parquet:   0%|          | 0.00/80.7k [00:00<?, ?B/s]

0000.parquet:   0%|          | 0.00/1.20M [00:00<?, ?B/s]

chemprot_full_source/test/0000.parquet:   0%|          | 0.00/950k [00:00<?, ?B/s]

(…)prot_full_source/validation/0000.parquet:   0%|          | 0.00/727k [00:00<?, ?B/s]

Generating sample split:   0%|          | 0/50 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/1020 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/800 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/612 [00:00<?, ? examples/s]

In [10]:

def calculate_final_metrics(model, num_classes, threshold):
    avg_precision, avg_recall, avg_f1_score = [], [], []
    total_tp, total_fp, total_fn, total_tn = 0, 0, 0, 0
    tpa,fpa,tna,fna,k=[],[],[],[],0
    for i in range(num_classes):
        tp, tn, fp, fn = model.metrics_dict[i].numpy()
        if tp+fn < 10:
            continue

        k += 1
        tpa.append(tp)
        fpa.append(fp)
        tna.append(tn)
        fna.append(fn)

        total_tp += tp
        total_fp += fp
        total_fn += fn
        total_tn += tn

        precision = tp / (tp + fp + tf.keras.backend.epsilon())
        recall = tp / (tp + fn + tf.keras.backend.epsilon())
        f1 = 2 * (precision * recall) / (precision + recall + tf.keras.backend.epsilon())

        avg_precision.append(precision)
        avg_recall.append(recall)
        avg_f1_score.append(f1)

    accuracy = (total_tp ) / (total_tp + total_fp + tf.keras.backend.epsilon())
    return {
        "average_precision": tf.reduce_mean(avg_precision),
        "average_recall": tf.reduce_mean(avg_recall),
        "average_f1_score": tf.reduce_mean(avg_f1_score),
        "accuracy": accuracy
    }


In [None]:

all_input_ids, all_attention_masks, all_ner_tags, all_relation_masks = [], [], [], []
for row in ds['test']:
    input_ids, attention_mask, ner_tags, relation_mask = prepare_data(row, tag_dict, r_dict, tag_dict_map, next_entity)
    all_input_ids.extend(input_ids)
    all_attention_masks.extend(attention_mask)
    all_ner_tags.extend(ner_tags)
    all_relation_masks.extend(relation_mask)

prepared_data = {
    'input_ids': np.array(all_input_ids),
    'attention_mask': np.array(all_attention_masks),
    'ner_tags': np.array(all_ner_tags),
    're_mask': all_relation_masks
}

padded_relation_mask = pad_sequences(
    [np.array(item, dtype=int) for item in prepared_data['re_mask']],
    padding='post',
    value=-1,
    dtype='object'
)

padded_relation_mask = np.array(padded_relation_mask, dtype=int)

inputs = {
    'input_ids': prepared_data['input_ids'],
    'attention_mask': prepared_data['attention_mask'],
    're_mask': padded_relation_mask,
    'ner_tags': prepared_data['ner_tags']
}

In [None]:
ner_model.reset_metrics()
dataset = tf.data.Dataset.from_tensor_slices((
    {
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask']
    },
    inputs['ner_tags']
))
dataset = dataset.batch(batch_size)

for batch_num, batch_data in enumerate(dataset.take(num_batches)):
    ner_model.test_step(batch_data)


In [None]:
re_model.reset_metrics()
dataset = tf.data.Dataset.from_tensor_slices((
    {
        'input_ids': inputs['input_ids'],
        'attention_mask': inputs['attention_mask'],
        're_mask': padded_relation_mask,
        'ner_tags': inputs['ner_tags']
    }
))
dataset = dataset.batch(batch_size)

for batch_num, batch_data in enumerate(dataset.take(num_batches)):
    re_model.test_step(batch_data)





In [None]:
print('In test data rare labels for NER and RE are ignored')
print('For NER, count of labels in test data if less than 10, that label will be ignored')
print('NER task metrics')
ner_metrics = calculate_final_metrics(ner_model, 19, 10)
for key in ner_metrics.keys():
    print(f"{key}: {ner_metrics[key]}")

In test data rare labels for NER and RE are ignored
For NER, count of labels in test data if less than 10, that label will be ignored
NER task metrics
average_precision: 0.7057518311082903
average_recall: 0.5608184309103595
average_f1_score: 0.613044709264921
accuracy: 0.9334979838707325


In [None]:
print('In test data rare labels for NER and RE are ignored')
print('For RE, count of labels in test data if less than 30, that label will be ignored')
print('RE task metrics')
re_metrics = calculate_final_metrics(re_model, 12, 30)
for key in re_metrics.keys():
    print(f"{key}: {re_metrics[key]}")

In test data rare labels for NER and RE are ignored
For RE, count of labels in test data if less than 30, that label will be ignored
RE task metrics
average_precision: 0.44164211880959725
average_recall: 0.25567830960201055
average_f1_score: 0.3021802531153275
accuracy: 0.8676540850422005


In [13]:
!pip install neo4j

Collecting neo4j
  Downloading neo4j-5.26.0-py3-none-any.whl.metadata (5.9 kB)
Downloading neo4j-5.26.0-py3-none-any.whl (302 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/302.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m297.0/302.0 kB[0m [31m10.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.0/302.0 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: neo4j
Successfully installed neo4j-5.26.0


In [None]:
rel_dict = {
    0: "Undefined",
    1: "Part of",
    2: "Regulator or direct regulator or indirect regulator of",
    3: "Activator or upregulator or indirect upregulator of",
    4: "Inhibitor or downregulator or indirect downregulator of",
    5: "Agonist or agonist-activator or agonist-inhibitor of",
    6: "Antagonist of",
    7: "Modulator or modulator-activator or modulator-inhibitor of",
    8: "Cofactor of",
    9: "Substrate or product or substrate product of",
    10: "No relation"
}


In [40]:
from neo4j import GraphDatabase

uri = "neo4j+s://fcdacc86.databases.neo4j.io"
username = "neo4j"
password = "wZcPOK4_j5tFD8_bxTWm-jrGGb7IdDNlkqLm-r9l2Q0"

driver = GraphDatabase.driver(uri, auth=(username, password))

def delete_all_nodes_and_relationships(tx):
    tx.run("MATCH (n) DETACH DELETE n")

with driver.session() as session:
    session.execute_write(delete_all_nodes_and_relationships)

def add_relations(text, offset_mapping, ner_segments, rel_segment, relations, tag_dict):
    with driver.session() as session:
        for i, rel in enumerate(rel_segment):
            start1, end1, _, start2, end2 = rel
            if relations[i] in {0, 10, 11}:
                continue
            relation_type = rel_dict[relations[i]]
            ent1_offsets = offset_mapping[start1:end1 + 1]
            ent2_offsets = offset_mapping[start2:end2 + 1]
            ent1_text = text[ent1_offsets[0][0]:ent1_offsets[-1][1]]
            ent2_text = text[ent2_offsets[0][0]:ent2_offsets[-1][1]]
            ent1_type = tag_dict[ner_segments[i][0]]
            ent2_type = tag_dict[ner_segments[i][1]]
            session.execute_write(
                create_or_update_relation,
                ent1_text, ent1_type,
                ent2_text, ent2_type,
                relation_type
            )

def create_or_update_relation(tx, ent1_text, ent1_type, ent2_text, ent2_type, relation_type):
    query = """
    MERGE (e1:%s {name: $ent1_text})
    MERGE (e2:%s {name: $ent2_text})
    MERGE (e1)-[:`%s`]->(e2)
    """ % (ent1_type, ent2_type, relation_type)
    tx.run(query, ent1_text=ent1_text, ent2_text=ent2_text)


In [28]:

b_i = {}
for i in con.keys():
    for j in con[i]:
        b_i[j] = i-1


def gen(i, j, ner_tags_segment, next_tags):
    l, m = i + 1, j + 1
    tem = []
    types = []
    while ner_tags_segment[l] in next_tags[ner_tags_segment[i]][0]:
        l += 1
    while ner_tags_segment[m] in next_tags[ner_tags_segment[j]][0]:
        m += 1
    if m - 1 != j and l - 1 != i:
        tem.append((i, l - 1, j, m - 1))
        types.append([b_i[next_tags[ner_tags_segment[i]][0][0]],b_i[next_tags[ner_tags_segment[j]][0][0]]])
    l, m = i + 1, j + 1
    while ner_tags_segment[l] in next_tags[ner_tags_segment[i]][0]:
        l += 1
    while ner_tags_segment[m] in next_tags[ner_tags_segment[j]][1]:
        m += 1
    if m - 1 != j and l - 1 != i:
        tem.append((i, l - 1, j, m - 1))
        types.append([b_i[next_tags[ner_tags_segment[i]][0][0]],b_i[next_tags[ner_tags_segment[j]][1][0]]])
    l, m = i + 1, j + 1
    while ner_tags_segment[l] in next_tags[ner_tags_segment[i]][1]:
        l += 1
    while ner_tags_segment[m] in next_tags[ner_tags_segment[j]][0]:
        m += 1
    if m - 1 != j and l - 1 != i:
        tem.append((i, l - 1, j, m - 1))
        types.append([b_i[next_tags[ner_tags_segment[i]][1][0]],b_i[next_tags[ner_tags_segment[j]][0][0]]])
    l, m = i + 1, j + 1
    while ner_tags_segment[l] in next_tags[ner_tags_segment[i]][1]:
        l += 1
    while ner_tags_segment[m] in next_tags[ner_tags_segment[j]][1]:
        m += 1
    if m - 1 != j and l - 1 != i:
        tem.append((i, l - 1, j, m - 1))
        types.append([b_i[next_tags[ner_tags_segment[i]][1][0]],b_i[next_tags[ner_tags_segment[j]][1][0]]])
    return tem, types

def infer(text, tag_dict, r_dict, tag_dict_map, next_tags, ner_model, re_model, max_length=128, stride=10, padding=True):


    tokenized_inputs = tokenizer(
        text,
        return_offsets_mapping=True,
        return_overflowing_tokens=True,
        max_length=max_length,
        truncation=True,
        stride=stride,
        padding='max_length' if padding else False
    )

    k=0
    for (input_ids, attention_mask) in zip(tokenized_inputs['input_ids'], tokenized_inputs['attention_mask']):
        ner_tags_segment = ner_model.test_step({'input_ids':np.array([input_ids]), 'attention_mask':np.array([attention_mask])}, inference = True)[0].numpy()
        b_tag_indices = []
        for i in range(max_length):
            if attention_mask[i] == 0:
                ner_tags_segment[i] = 0
            if ner_tags_segment[i] in {1, 2, 4, 7, 8, 12, 13, 15, 16, 17, 18}:
                b_tag_indices.append(i)
        rel_segment = []
        entity_types = []
        for i in b_tag_indices:
            for j in b_tag_indices:
                tem, types = gen(i, j, ner_tags_segment, next_tags)
                for (ke, t) in zip(tem, types):
                    rel_segment.append([ke[0], ke[1], -1, ke[2], ke[3]])
                    entity_types.append(t)
        relations = re_model.test_step({'input_ids':np.array([input_ids]), 'attention_mask':np.array([attention_mask]), 're_mask':np.array([rel_segment]), 'ner_tags':np.array([ner_tags_segment])}, inference = True).numpy()
        add_relations(text, tokenized_inputs['offset_mapping'][k], entity_types, rel_segment, relations, { 1:'GENE-Y', 3:'GENE-N', 5:'CHEMICAL'})
        k += 1



In [39]:
for row in ds['test']:
    infer(row['text'], tag_dict, r_dict, tag_dict_map, next_entity, ner_model, re_model)



[11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11
 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11
 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11
 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11
 11 11 11 11] 100 100
[11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11
 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11
 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11
 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11
 11 11 11 11] 100 100




[11 11 11 11 11 11 11 11  2 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11
 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11
 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11
 11 11 11 11 11 11 11  2  2 11 11 11 11  2  2 11 11 11 11 11 11 11 11 11
 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11
 11] 121 121
[11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11] 16 16
[11 11 11 11 11  2 11 11 11 11 11 11 11  2 11 11 11 11 11 11 11 11 11 11
 11 11 11 11 11  2  2 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11
 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11] 64 64
[11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11] 16 16
[] 0 0
[11 11 11 11] 4 4
[] 0 0
[11 11 11 11] 4 4
[11 11 11 11] 4 4
[] 0 0
[11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11] 16 16
[11] 1 1
[] 0 0
[11 11 11 11 11 11 11 11 11 11 11 11 11 11 11 11  9 11 11 11 11 11  9 11
 11 11 11 11 11 11 11 11 11 11 11 11] 36 36
[11 11 11 11 11 11 11 11 11 11 11 11 11 11 1

KeyboardInterrupt: 

In [None]:
driver.close()