In [53]:
# Environment switch: when True, later cells mount Google Drive and load the
# project from the Drive path; when False, the project is read from "./".
using_colab = False

In [5]:
# Only on Colab: mount Google Drive at /content/drive. The google.colab import
# is kept inside the branch so the cell does not require that package locally.
if using_colab:
    from google.colab import drive

    drive.mount("/content/drive", force_remount=True)

In [51]:
import torch
import torch.nn as nn
import numpy as np
import importlib
import json
import random
import os
import re
from importlib import reload

In [54]:
# Resolve the project directory and bind the four helpers from the project's
# `data` module (load_data, extract_sentences_and_labels, generate_label_vocab,
# split_data) for the current environment.
if using_colab:
    # Relative path to the project inside the mounted Drive.
    # NOTE(review): relative to the working directory — presumably /content on
    # Colab, matching the mount point above; confirm.
    dir_path = (
        "drive/Othercomputers/my_computer/dl-nlp_project_named-entity-recognition/"
    )
    # Alternative Drive location (disabled):
    # dir_path = "drive/MyDrive/dl-nlp_project_named-entity-recognition/"
    # Turn the directory path into a dotted module prefix; because dir_path
    # ends with "/", the prefix ends with "." and "data" can be appended.
    module_path = dir_path.replace("/", ".")
    # Import the data module dynamically and bind its helpers to the same names
    # that the local branch imports statically.
    data_module = importlib.import_module(module_path + "data")
    load_data = data_module.load_data
    extract_sentences_and_labels = data_module.extract_sentences_and_labels
    generate_label_vocab = data_module.generate_label_vocab
    split_data = data_module.split_data

else:
    # Local run: the project root is the current directory and `data` is
    # importable directly.
    dir_path = "./"
    from data import (
        load_data,
        extract_sentences_and_labels,
        generate_label_vocab,
        split_data,
    )

In [55]:
# Dataset locations, resolved against the environment-specific project root
# (dir_path always ends with "/").
train_file_path = f"{dir_path}data/train.json"
test_file_path = f"{dir_path}data/test.json"

In [56]:
# Load both dataset files and pull out sentences plus raw labels with the
# project helpers. Note: train_sentences / test_sentences are reassigned later
# by extract_sentences(); the raw labels are what feed the vocabulary here.
train_data, test_data = load_data(train_file_path, test_file_path)
train_sentences, train_raw_labels = extract_sentences_and_labels(train_data)
test_sentences, test_raw_labels = extract_sentences_and_labels(test_data)

# Generate label vocabulary from the labels of both splits combined
label_vocab = generate_label_vocab(train_raw_labels + test_raw_labels)

In [57]:
# Extra label name appended after the dataset's label vocabulary when the
# label set is built.
SPECIAL_TOKEN = "<SPC>"


class Labels:
    """Converts between label names and multi-hot label vectors.

    A label vector has ``num_classes`` entries; entry ``i`` equal to 1 means
    the label ``names[i]`` is active.
    """

    def __init__(self, num_classes, names):
        """Store the label names and vector size, printing the names once.

        Args:
            num_classes: Length of the vectors produced by ``encode``.
            names: Sequence of label names; position defines the vector index.
        """
        super().__init__()
        self.names = names
        print(self.names)
        self.num_classes = num_classes

    def __getitem__(self, label_vector):
        """Return the names whose entry in ``label_vector`` equals 1."""
        active = []
        for position, flag in enumerate(label_vector):
            if flag == 1:
                active.append(self.names[position])
        return active

    def decode(self, label_vector):
        """Alias for indexing: multi-hot vector -> list of label names."""
        return self[label_vector]

    def encode(self, names):
        """Build a multi-hot float tensor with a 1 at each given label's index.

        Raises:
            ValueError: If a name is not present in ``self.names``.
        """
        # Resolve every name first so an unknown label fails before any
        # tensor is allocated.
        positions = [self.names.index(name) for name in names]
        encoded = torch.zeros(self.num_classes)
        for position in positions:
            encoded[position] = 1
        return encoded

    def tensor2sentence(self, tensor):
        """Decode each vector of ``tensor`` into its list of label names."""
        return list(map(self.decode, tensor))


# Label codec over the dataset vocabulary plus the special token, which takes
# the last index (hence the "+ 1" on the class count).
ner_labels = Labels(
    num_classes=len(label_vocab) + 1, names=label_vocab + [SPECIAL_TOKEN]
)
# Convenience aliases: id2label maps a multi-hot vector to label names,
# label2id maps a list of label names to a multi-hot vector.
id2label = ner_labels.decode
label2id = ner_labels.encode
# Displayed as the cell output.
ner_labels.num_classes

['PublicationYear', 'CTDesign', 'RelativeChangeValue', 'PMID', 'NumberPatientsCT', 'ObjectiveDescription', 'ConfIntervalDiff', 'PercentageAffected', 'NumberAffected', 'PvalueDiff', 'Title', 'PValueChangeValue', 'Country', 'Precondition', 'ResultMeasuredValue', 'MinAge', 'ConfIntervalChangeValue', 'SubGroupDescription', 'SdDevChangeValue', 'AvgAge', 'DiffGroupAbsValue', 'FinalNumPatientsArm', 'NumberPatientsArm', 'Author', 'ConclusionComment', 'DoseValue', 'Frequency', 'AllocationRatio', 'Journal', 'SdDevBL', 'Drug', 'SdDevResValue', 'DoseDescription', 'TimePoint', 'AggregationMethod', 'ObservedResult', '<SPC>']


37

In [58]:
def extract_sentences(json_file_path):
    """Read a dataset JSON file into token lists and per-token label vectors.

    The file is expected to hold a list of entries, each with a "sentences"
    list; every sentence provides "words" (the tokens) and "entities", where
    each entity has a "label" and an inclusive "start_pos".."end_pos" token
    span. Overlapping entities simply set several positions of the same
    token's vector.

    Args:
        json_file_path: Path of the JSON file to read.

    Returns:
        A pair ``(sentences, labels)``: ``sentences[i]`` is the token list of
        the i-th sentence and ``labels[i]`` is a list with one multi-hot
        tensor of length ``ner_labels.num_classes`` per token.

    Raises:
        ValueError: If an entity label is not in ``ner_labels.names``.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(json_file_path, "r", encoding="utf-8") as file:
        data = json.load(file)

    all_tokens = []
    all_labels = []

    for entry in data:
        for sentence in entry["sentences"]:
            tokens = sentence["words"]
            labels_list = [torch.zeros(ner_labels.num_classes) for _ in tokens]

            for label_entity in sentence["entities"]:
                # Look the index up directly instead of encoding a one-hot
                # tensor and taking its argmax.
                label_id = ner_labels.names.index(label_entity["label"])
                start_pos = label_entity["start_pos"]
                end_pos = label_entity["end_pos"]
                # end_pos is treated as inclusive.
                for label_index in range(start_pos, end_pos + 1):
                    labels_list[label_index][label_id] = 1

            all_tokens.append(tokens)
            all_labels.append(labels_list)

    return all_tokens, all_labels

In [104]:
# Re-extract tokens and multi-hot label vectors for both files (this replaces
# the train_sentences / test_sentences bound in the earlier loading cell).
train_sentences, train_labels = extract_sentences(train_file_path)
test_sentences, test_labels = extract_sentences(test_file_path)
# Carve a validation set out of the training data.
# NOTE(review): split ratio and shuffling are decided inside split_data.
train_sentences, train_labels, val_sentences, val_labels = split_data(
    train_sentences, train_labels
)

# Sanity check: sentence and label counts must match within each split.
print(len(train_sentences), len(train_labels))
print(len(val_sentences), len(val_labels))
print(len(test_sentences), len(test_labels))

1300 1300
145 145
385 385


In [150]:
# Target label for this session and the JSON file that accumulates the extra
# sentences collected for it.
label = "CTDesign"
data_file_name = f"{dir_path}data/labels/{label}.json"
if not os.path.exists(data_file_name):
    # Nothing saved for this label yet: start with empty parallel lists.
    data = {"sentences": [], "labels_lists": []}
else:
    # Resume from what was saved previously.
    with open(data_file_name, "r") as json_file:
        data = json.load(json_file)

In [167]:
# Short aliases shown in place of long label names when examples are printed.
label_abbreviations = dict(
    ObjectiveDescription="OD",
    Precondition="PC",
)
# Inverse lookup (alias -> full label name) used when parsing pasted input.
label_unabbreviations = {
    short: full for full, short in label_abbreviations.items()
}

In [168]:
# Collect every training sentence that has at least one token tagged with the
# target `label`, in two parallel views:
#   examples             - the tokens, with tagged tokens wrapped as !!token!!
#   examples_with_labels - "token [labels]" strings (labels abbreviated)
examples = []
examples_with_labels = []
for sentence, labels_list in zip(train_sentences, train_labels):
    # Work on a copy. Assigning `sentence` directly would alias the list and
    # write the !! markers permanently into train_sentences.
    new_sentence = list(sentence)
    found = False
    for i, (token, labels) in enumerate(zip(sentence, labels_list)):
        if label in id2label(labels):
            found = True
            # Guard against double wrapping if the token already carries the
            # marker (e.g. data highlighted by an earlier run).
            if not token.startswith("!!"):
                new_sentence[i] = f"!!{token}!!"
    if found:
        words = []
        # Iterate the highlighted copy so the markers also appear in the
        # "token [labels]" view.
        for word, labels in zip(new_sentence, labels_list):
            abbreviated_labels = [
                label_abbreviations.get(name, name) for name in id2label(labels)
            ]
            words.append(f"{word} {abbreviated_labels}")
        examples.append(new_sentence)
        examples_with_labels.append(words)

In [169]:
example_count = 5

# Show up to `example_count` distinct examples. random.sample draws without
# replacement, so the same sentence is never printed twice.
sample_size = min(example_count, len(examples))
for index in random.sample(range(len(examples)), sample_size):
    print(" ".join(examples[index]))
    print(examples_with_labels[index])
    print()

A 24 - week , randomized , !!treat!! !!-!! !!to!! !!-!! !!target!! trial comparing initiation of insulin glargine once - daily with insulin detemir twice - daily in patients with type 2 diabetes inadequately controlled on oral glucose - lowering drugs .
["A ['Title']", "24 ['Title']", "- ['Title']", "week ['Title']", ", ['Title']", "randomized ['Title']", ", ['Title']", "!!treat!! ['CTDesign', 'Title']", "!!-!! ['CTDesign', 'Title']", "!!to!! ['CTDesign', 'Title']", "!!-!! ['CTDesign', 'Title']", "!!target!! ['CTDesign', 'Title']", "trial ['Title']", "comparing ['Title']", "initiation ['Title']", "of ['Title']", "insulin ['Title']", "glargine ['Title']", "once ['Title']", "- ['Title']", "daily ['Title']", "with ['Title']", "insulin ['Title']", "detemir ['Title']", "twice ['Title']", "- ['Title']", "daily ['Title']", "in ['Title']", "patients ['Title']", "with ['Title']", "type ['Title']", "2 ['Title']", "diabetes ['Title']", "inadequately ['Title']", "controlled ['Title']", "on ['Title

In [170]:
import ast

# Parse pasted sentences in the "token [labels]" list format printed above
# (one bracketed list per sentence, sentences separated by two spaces) and
# append the recovered tokens and label lists to `data`.

# A bracketed list of two or more double-quoted labels, e.g. ["A", "B"].
multi_label_pattern = r"\[\"\w*\"(?:, \"\w*\")+\]"

input_text = input("Input: ")
sentences = input_text.strip().split("  ")
for sentence in sentences:
    if not sentence.strip():
        # Blank input or stray separators: nothing to parse.
        continue
    tokens = []
    labels_list = []
    # Normalise all quotes to double quotes so entries can be split on '", "'.
    sentence = sentence.replace("'", '"')
    # Multi-label lists contain '", "' themselves, which would break the split
    # below, so switch each distinct one back to single quotes (not just the
    # first one found).
    for match in set(re.findall(multi_label_pattern, sentence)):
        sentence = sentence.replace(match, match.replace('"', "'"))
    token_pairs = sentence[sentence.index("[") + 1 : -1].split('", "')
    for token_pair in token_pairs:
        token, labels_str = token_pair.split(" ", 1)
        token = token.strip('"')
        # Remove exactly one surrounding "!!" highlight marker pair.
        # str.strip("!!") would strip every leading/trailing "!" character
        # and destroy tokens such as "!".
        if len(token) > 4 and token.startswith("!!") and token.endswith("!!"):
            token = token[2:-2]
        # TODO split up hyphenated tokens
        labels_str = labels_str.strip('"')

        # literal_eval only accepts Python literals, so pasted text cannot
        # execute code.
        labels = ast.literal_eval(labels_str)
        # Expand abbreviated label names back to their full form.
        labels = [label_unabbreviations.get(name, name) for name in labels]

        tokens.append(token)
        labels_list.append(labels)
    print(tokens)
    print(labels_list)
    data["sentences"].append(tokens)
    data["labels_lists"].append(labels_list)

['In', 'a', 'randomized', ',', 'controlled', 'clinical', 'trial', 'involving', 'patients', 'with', 'type', '2', 'diabetes', ',', 'insulin', 'detemir', 'was', 'compared', 'to', 'insulin', 'glargine', 'using', 'a', 'basal', '-', 'bolus', 'regimen', '.']
[['Title'], ['Title'], ['Title'], ['Title'], ['CTDesign', 'Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title']]
['A', 'randomized', ',', 'controlled', 'clinical', 'investigation', 'was', 'conducted', 'to', 'evaluate', 'the', 'effectiveness', 'of', 'insulin', 'detemir', 'versus', 'insulin', 'glargine', 'in', 'a', 'basal', '-', 'bolus', 'regimen', 'among', 'type', '2', 'diabetes', 'patients', '.']
[['Title'], ['Title'], ['Title'], ['CTDesign', 'Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Title'], ['Titl

In [171]:
# Persist the accumulated sentences for the current label.
data_file_name = f"{dir_path}data/labels/{label}.json"
# The labels directory may not exist the first time a label is saved;
# open(..., "w") would then fail with FileNotFoundError.
os.makedirs(os.path.dirname(data_file_name), exist_ok=True)
with open(data_file_name, "w") as json_file:
    json.dump(data, json_file, indent=4)
# Report only after the file has been closed (and therefore flushed).
print(f"{data_file_name} was updated.")

./data/labels/CTDesign.json was updated.
