In [None]:
%%capture
!pip install spacy
!pip install scispacy
!pip install https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/releases/v0.5.1/en_core_sci_lg-0.5.1.tar.gz

In [None]:
import nltk
import pandas as pd
import spacy
import scispacy
from scispacy.abbreviation import AbbreviationDetector
from scispacy.linking import EntityLinker

# nltk.word_tokenize (used by Augmentor.augment) needs the Punkt sentence
# models. Newer NLTK releases look them up under 'punkt_tab', older ones under
# 'punkt', so both are fetched to keep the notebook working on either.
for resource in ("punkt", "punkt_tab"):
    nltk.download(resource)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [None]:
class Augmentor:
    """
    A class for augmenting tagged text data.

    Mentions tagged with ``B-``/``I-``/``E-``/``S-`` prefixed labels are looked
    up with the scispacy UMLS linker; when a sufficiently confident candidate
    has an alias that differs from the mention, a copy of the sentence is
    produced in which the mention is replaced by that alias.

    Attributes
    ----------
    nlp : spacy.lang.en.English
        A spacy language model loaded with the 'en_core_sci_lg' pipeline.
    linker : scispacy.linking.EntityLinker
        A linker object used for entity resolution.

    Methods
    -------
    augment(df, tag="ADE", score_threshold=0.95)
        Augments tagged text data with new entities and returns the augmented DataFrame.

    """

    # Aliases with this many characters or more are never used as replacements.
    _MAX_ALIAS_CHARS = 100

    def __init__(self):
        """
        Constructs necessary attributes for the Augmentor object.

        Loads the 'en_core_sci_lg' pipeline and appends the abbreviation
        detector and the UMLS entity linker to it.

        Attributes
        ----------
        nlp : spacy.lang.en.English
            A spacy language model loaded with the 'en_core_sci_lg' pipeline.
        linker : scispacy.linking.EntityLinker
            A linker object used for entity resolution.
        """
        self.nlp = spacy.load("en_core_sci_lg")
        self.nlp.add_pipe("abbreviation_detector")
        self.nlp.add_pipe(
            "scispacy_linker",
            config={"resolve_abbreviations": True, "linker_name": "umls"}
        )
        self.linker = self.nlp.get_pipe("scispacy_linker")

    @staticmethod
    def _collect_mentions(tokens, tags, tag):
        """
        Locate the tagged mentions of one sentence.

        Parameters
        ----------
        tokens : list
            The tokens of the sentence.
        tags : list
            The tags of the sentence, aligned with ``tokens``.
        tag : str
            The tag name without its positional prefix.

        Returns
        -------
        list of [str, int, int]
            One ``[text, start, end]`` entry per mention, where ``text`` is the
            mention's tokens joined by single spaces and ``start``/``end`` are
            inclusive token positions.
        """
        opening = ("B-" + tag, "S-" + tag)
        continuing = ("I-" + tag, "E-" + tag)

        mentions = []
        for idx, current in enumerate(tags):
            if current in opening:
                mentions.append([tokens[idx], idx, idx])
            elif current in continuing:
                # A continuation only extends a mention that ends on the
                # previous token. Orphaned continuation tags (no mention yet,
                # or a gap since the last one) are ignored rather than raising
                # IndexError or swallowing the untagged tokens in between.
                if mentions and mentions[-1][2] == idx - 1:
                    mentions[-1][0] += " " + tokens[idx]
                    mentions[-1][2] = idx
        return mentions

    @staticmethod
    def _span_tags(n_tokens, tag):
        """Return the S- or B-/I-/E- tag sequence for a mention of ``n_tokens`` (>= 1) tokens."""
        if n_tokens == 1:
            return ["S-" + tag]
        return ["B-" + tag] + ["I-" + tag] * (n_tokens - 2) + ["E-" + tag]

    def _select_alias_tokens(self, ent, score_threshold):
        """
        Pick one replacement alias for a linked entity and tokenize it.

        Candidates are visited in the order the linker stored them in
        ``ent._.kb_ents``; candidates scoring below ``score_threshold`` are
        skipped. Only a single alias is used per mention.

        Returns
        -------
        list of str or None
            The tokens of the chosen alias, or None when no candidate offers a
            usable alias.
        """
        for cui, score in ent._.kb_ents:
            if score < score_threshold:
                continue

            # dict.fromkeys de-duplicates while keeping the knowledge base's
            # order, so the chosen alias is the same on every run (iterating a
            # set of strings is not stable across interpreter sessions).
            for alias in dict.fromkeys(self.linker.kb.cui_to_entity[cui].aliases):
                # TODO: Add conditions on alias
                if alias == ent.text or len(alias) >= self._MAX_ALIAS_CHARS:
                    continue
                alias_tokens = nltk.word_tokenize(alias)
                if alias_tokens:  # an empty/whitespace alias cannot replace a mention
                    return alias_tokens
        return None

    def augment(self, df, tag="ADE", score_threshold=0.95):
        """
        Augments tagged text data with new entities and returns the augmented DataFrame.

        A sentence is identified by its ("sid", "contains_rel", "uid") values.
        Only sentences with at least one tag containing ``tag`` are kept. For
        each mention in such a sentence that the pipeline recognises as exactly
        one entity, one augmented copy of the sentence is appended in which the
        mention is replaced by an alias of the linked concept.

        Parameters
        ----------
        df : pandas.DataFrame
            The DataFrame containing the tagged text data to be augmented. The
            columns "token", "tag", "sid", "contains_rel" and "uid" are read.
            The frame itself is not modified.
        tag : str, optional
            The tag to be used for the new entities. Default is 'ADE'.
        score_threshold : float, optional
            The threshold score for selecting candidate entities for augmentation. Default is 0.95.

        Returns
        -------
        pandas.DataFrame
            A new DataFrame with the kept original rows ("is_augmented" == 0,
            "sid" converted to strings) followed by the augmented rows
            ("is_augmented" == 1, "sid" of the form "<sid>.<mention index>").

        """
        df = df.copy()
        df["is_augmented"] = 0

        key_cols = ["sid", "contains_rel", "uid"]

        # The tag is matched literally (it is concatenated literally with the
        # B-/I-/E-/S- prefixes further down) and missing tags count as no match.
        has_tag = df["tag"].str.contains(tag, regex=False, na=False)
        sentence_keys = df.loc[has_tag, key_cols].drop_duplicates()

        # The inner merge keeps the rows of the selected sentences in their
        # original order and yields a fresh 0..n-1 index. Joining on the key
        # columns themselves avoids both a row-wise apply and the collisions
        # that "_"-joined string keys can produce.
        df_original = df.merge(sentence_keys, on=key_cols, how="inner")

        # Group once instead of re-filtering the whole frame for every sentence.
        sentences = df_original.groupby(key_cols, sort=False)

        augmented_data = {
            "token": [],
            "tag": [],
            "sid": [],
            "contains_rel": [],
            "uid": [],
            "is_augmented": []
        }

        for key in sentence_keys.itertuples(index=False, name=None):
            try:
                sentence = sentences.get_group(key)
            except KeyError:
                # Keys containing missing values do not form a group.
                continue

            sid, contains_rel, uid = key
            tokens = list(sentence["token"])
            tags = list(sentence["tag"])

            mentions = self._collect_mentions(tokens, tags, tag)
            for i, (mention, s_idx, e_idx) in enumerate(mentions):
                doc = self.nlp(mention)

                # Only augment when the whole mention is recognised as exactly one entity.
                if len(doc.ents) != 1 or doc.ents[0].text != mention:
                    continue

                alias_tokens = self._select_alias_tokens(doc.ents[0], score_threshold)
                if alias_tokens is None:
                    continue

                alias_tags = self._span_tags(len(alias_tokens), tag)
                aug_tokens = tokens[:s_idx] + alias_tokens + tokens[e_idx + 1:]
                aug_tags = tags[:s_idx] + alias_tags + tags[e_idx + 1:]
                n_rows = len(aug_tokens)

                augmented_data["token"] += aug_tokens
                augmented_data["tag"] += aug_tags
                augmented_data["sid"] += [f"{sid}.{i}"] * n_rows
                augmented_data["contains_rel"] += [contains_rel] * n_rows
                augmented_data["uid"] += [uid] * n_rows
                augmented_data["is_augmented"] += [1] * n_rows

        df_augmented = pd.DataFrame.from_dict(augmented_data)

        # Augmented sids are strings ("<sid>.<i>"), so the originals are converted too.
        df_original["sid"] = df_original["sid"].astype("string")

        df_original_and_augmented = pd.concat([df_original, df_augmented], ignore_index=True)

        return df_original_and_augmented

In [None]:
# Load the tagged training data. Augmentor.augment reads the columns
# "token", "tag", "sid", "contains_rel" and "uid" from this frame.
df = pd.read_parquet("/content/ner_train.parquet")

In [None]:
# Builds the spaCy pipeline (en_core_sci_lg + abbreviation detector + UMLS linker).
# As the output below shows, linker resources that are not in the scispacy
# cache yet are downloaded at this point.
augmentor = Augmentor()

https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/linkers/2020-10-09/umls/tfidf_vectors_sparse.npz not found in cache, downloading to /tmp/tmp5kjq1g6c
Finished download, copying /tmp/tmp5kjq1g6c to cache at /root/.scispacy/datasets/e9f7327283e43f0482f7c0c71b71dec278a58ccb3ffdd03c2c2350159e7ef146.f2a350ad19015b2591545f7feeed6a6d6d2fffcd635d868a5d7fc0dfc3cadfd8.tfidf_vectors_sparse.npz
https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/linkers/2020-10-09/umls/nmslib_index.bin not found in cache, downloading to /tmp/tmpl0ce3xof
Finished download, copying /tmp/tmpl0ce3xof to cache at /root/.scispacy/datasets/f48455d6c79262057cce66b4619123c2b558b21092d42fac97f47bb99a5b8f9f.dd70d3dffe7d90d7ac8914460e16a48375dab32485fb6313a34e6fbcaf53218b.nmslib_index.bin
https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/linkers/2020-10-09/umls/tfidf_vectorizer.joblib not found in cache, downloading to /tmp/tmpuz9o330w
Finished download, copying /tmp/tmpuz9o330w to cache at /root/.scispacy/da

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/linkers/2020-10-09/umls/concept_aliases.json not found in cache, downloading to /tmp/tmpf4z2y_a6
Finished download, copying /tmp/tmpf4z2y_a6 to cache at /root/.scispacy/datasets/1428ec15d3b1061731ea273c03699130b3d6b90948993e74bda66af605ff8e2a.aeb7a686c654df6bccb6c2c23d3eda3eb381daaefda4592b58158d0bee53b352.concept_aliases.json
https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/kbs/2020-10-09/umls_2020_aa_cat0129.jsonl not found in cache, downloading to /tmp/tmp71l60hs9
Finished download, copying /tmp/tmp71l60hs9 to cache at /root/.scispacy/datasets/4d7fb8fcae1035d1e0a47d9072b43d5a628057d35497fbfb2499b4b7b2dd4dd7.05ec7eef12f336d4666da85b7fa69b9401883a7dd4244473f7b88b413ccbba03.umls_2020_aa_cat0129.jsonl
https://s3-us-west-2.amazonaws.com/ai2-s2-scispacy/data/umls_semantic_type_tree.tsv not found in cache, downloading to /tmp/tmp1aevc7kq
Finished download, copying /tmp/tmp1aevc7kq to cache at /root/.scispacy/datasets/21a1012c53

In [None]:
# Runs with the defaults tag="ADE" and score_threshold=0.95. The result holds
# the original rows of every sentence that has an ADE tag, followed by the
# augmented copies (flagged with is_augmented == 1).
df_original_and_augmented = augmentor.augment(df)

  global_matches = self.global_matcher(doc)


In [None]:
# Export as one "token tag" pair per line with a blank line between sentences
# (presumably the column format Flair reads, judging by the file name).
# A sentence is identified by (sid, contains_rel, uid), the same key
# Augmentor.augment groups on, so adjacent sentences that happen to share a
# sid are still separated.
sentence_cols = ["sid", "contains_rel", "uid"]
export_rows = df_original_and_augmented[["token", "tag"] + sentence_cols]

# Explicit UTF-8 keeps the output independent of the platform's default encoding.
with open("/content/flair_ner_train_augmented.txt", mode="w", encoding="utf-8") as f:
    prev_key = None  # None until the first row, so an empty frame writes an empty file
    for token, label, *key in export_rows.itertuples(index=False, name=None):
        if prev_key is not None and key != prev_key:
            f.write("\n")
        prev_key = key
        f.write(f"{token} {label}\n")