### The Purpose of This Notebook

The issue with this competition is that many dataset labels are deliberately not provided. Because of these missing labels, **True Positive** samples end up being treated as **True Negative**s, and training the model on this data confuses it. To address this problem, we extract candidate labels in the two ways shown below and use them in the train/validation process:

    1. Use AbbreviationDetector from scispacy to detect all capitalized strings of the form "LONG-NAME (ACRONYM)", then keep only the labels that contain keywords like (Dataset, Database, Study, Survey, ...).
    2. (Optional) Detect the positions of the keywords (Dataset, Database, Study, Survey, ...) in the input string, then look forward/backward from each keyword until two consecutive lowercase words are met.

For all labels found by the above steps, only those that are similar to an original train label (after cleaning: **Jaccard Similarity** ≥ 0.5, or one string contains the other) are used for training; the rest are passed to validation.

**Note**: The strategy here is to find as many candidate labels as possible and let the model learn whether they are dataset titles, rather than treating them as **True Negative**s (caused by missing labels).

## SCISPACY

In [None]:
from typing import Tuple, List, Optional, Set, Dict
from collections import defaultdict
from spacy.tokens import Span, Doc
from spacy.matcher import Matcher
from spacy.language import Language

import spacy
import pickle


def find_abbreviation(
    long_form_candidate: Span, short_form_candidate: Span
) -> Tuple[Span, Optional[Span]]:
    """
    Implements the abbreviation detection algorithm in "A simple algorithm
    for identifying abbreviation definitions in biomedical text.", (Schwartz & Hearst, 2003).
    The algorithm works by enumerating the characters in the short form of the abbreviation,
    checking that they can be matched against characters in a candidate text for the long form
    in order, as well as requiring that the first letter of the abbreviated form matches the
    _beginning_ letter of a word.

    Parameters
    ----------
    long_form_candidate: Span, required.
        The spaCy span for the long form candidate of the definition.
    short_form_candidate: Span, required.
        The spaCy span for the abbreviation candidate.

    Returns
    -------
    A Tuple[Span, Optional[Span]], representing the short form abbreviation and the
    span corresponding to the long form expansion, or None if a match is not found.
    """
    # Both forms are flattened to strings whose tokens are separated by
    # exactly one space; all character offsets below refer to these strings.
    long_form = " ".join([x.text for x in long_form_candidate])
    short_form = " ".join([x.text for x in short_form_candidate])

    long_index = len(long_form) - 1
    short_index = len(short_form) - 1

    while short_index >= 0:
        current_char = short_form[short_index].lower()
        # We don't check non alpha-numeric characters.
        if not current_char.isalnum():
            short_index -= 1
            continue

        # Walk left through the long form until this character matches ...
        while (
            (long_index >= 0 and long_form[long_index].lower() != current_char)
            or
            # ... or, for the first character of the abbreviation, until the
            # match is also the _starting_ character of a word.
            (
                short_index == 0
                and long_index > 0
                and long_form[long_index - 1].isalnum()
            )
        ):
            long_index -= 1

        if long_index < 0:
            return short_form_candidate, None

        long_index -= 1
        short_index -= 1

    # The last subtraction will either take us on to a whitespace character, or
    # off the front of the string (i.e. long_index == -1). Either way, we want to add
    # one to get back to the start character of the long form
    long_index += 1

    # Translate the character offset back into a token index. ``long_form``
    # joins tokens with exactly one space, so token i occupies
    # len(token.text) + 1 characters. Counting ``text_with_ws`` instead would
    # come up short for tokens with no trailing whitespace in the source text
    # (e.g. "Study-Kindergarten"), shifting or truncating the long form.
    word_lengths = 0
    starting_index = None
    for i, word in enumerate(long_form_candidate):
        word_lengths += len(word.text) + 1
        if word_lengths > long_index:
            starting_index = i
            break

    return short_form_candidate, long_form_candidate[starting_index:]


def filter_matches(
    matcher_output: List[Tuple[int, int, int]], doc: Doc
) -> List[Tuple[Span, Span]]:
    """
    Turn matcher hits (token offsets of the text between the brackets) into
    (long form candidate, short form candidate) span pairs.

    Two layouts are recognised:
      1. <Short Form> ( <Long Form> )  -- more than 3 tokens inside the parens
      2. <Long Form> ( <Short Form> )  -- 1-3 tokens inside; the common case

    Hits spanning more than 8 tokens, or starting at token 1, are skipped.
    A pair is returned only if its short form passes ``short_form_filter``.
    """
    pairs = []
    for hit in matcher_output:
        begin, finish = hit[1], hit[2]
        width = finish - begin
        # Too long to be an abbreviation/definition, or nothing before the
        # opening bracket to act as the other half of the pair.
        if width > 8 or begin == 1:
            continue

        if width <= 3:
            # The abbreviation sits inside the parens; look back for the
            # definition, at most min(chars + 5, chars * 2) tokens.
            short_span = doc[begin:finish]
            n_chars = sum(len(token) for token in short_span)
            lookback = min(n_chars + 5, n_chars * 2)
            long_span = doc[max(begin - lookback - 1, 0) : begin - 1]
        else:
            # The definition sits inside the parens; the single token just
            # before the opening bracket is the abbreviation candidate.
            short_span = doc[begin - 2 : begin - 1]
            long_span = doc[begin:finish]

        if short_form_filter(short_span):
            pairs.append((long_span, short_span))

    return pairs


def short_form_filter(span: Span) -> bool:
    """
    Return True if ``span`` looks like a plausible abbreviation.

    A plausible abbreviation is non-empty, every token is 2-9 characters long,
    at least half of its characters are alphabetic, and its first character
    is alphabetic.
    """
    text = span.text
    # An empty span cannot be an abbreviation; checking up front also keeps
    # the alpha ratio below from dividing by zero.
    if not text:
        return False

    # All tokens are 2 to 9 characters long (upper bound exclusive).
    if not all(2 <= len(token) < 10 for token in span):
        return False

    # At least 50% of the short form should be alpha.
    if sum(ch.isalpha() for ch in text) / len(text) < 0.5:
        return False

    # The first character of the short form should be alpha.
    return text[0].isalpha()

@Language.factory("abbreviation_detector")
class AbbreviationDetector:
    """
    Detects abbreviations using the algorithm in "A simple algorithm for identifying
    abbreviation definitions in biomedical text.", (Schwartz & Hearst, 2003).
    This class sets the `._.abbreviations` attribute on spaCy Doc.
    The abbreviations attribute is a `List[Span]` where each Span has the `Span._.long_form`
    attribute set to the long form definition of the abbreviation.
    Note that this class does not replace the spans, or merge them.
    It is registered as the spaCy pipeline factory "abbreviation_detector".
    Parameters
    ----------
    nlp: `Language`, a required argument for spacy to use this as a factory
    name: `str`, a required argument for spacy to use this as a factory
    """

    def __init__(self, nlp: Language, name: str = "abbreviation_detector") -> None:
        # Register the custom attributes; force=True overwrites an existing
        # registration, e.g. when the pipe is created more than once.
        Doc.set_extension("abbreviations", default=[], force=True)
        Span.set_extension("long_form", default=None, force=True)

        # Candidate finder: "(" followed by one or more tokens, then ")".
        self.matcher = Matcher(nlp.vocab)
        self.matcher.add("parenthesis", [[{"ORTH": "("}, {"OP": "+"}, {"ORTH": ")"}]])
        # Scratch matcher: find_matches_for adds one rule per accepted
        # definition and removes the rules that matched once it has run.
        self.global_matcher = Matcher(nlp.vocab)

    def find(self, span: Span, doc: Doc) -> Tuple[Span, Set[Span]]:
        """
        Functional version of calling the matcher for a single span.
        This method is helpful if you already have an abbreviation which
        you want to find a definition for.

        Returns the first (long form, short-form occurrences) pair produced by
        find_matches_for, or ``(span, set())`` if no definition is found.
        NOTE(review): on failure the first element is the input span rather
        than a long form; callers must check the set to tell the cases apart.
        """
        # Present the span as if it were the bracket-stripped content of a
        # "parenthesis" match; the match id (-1) is not read by filter_matches.
        dummy_matches = [(-1, int(span.start), int(span.end))]
        filtered = filter_matches(dummy_matches, doc)
        abbreviations = self.find_matches_for(filtered, doc)

        if not abbreviations:
            return span, set()
        else:
            return abbreviations[0]

    def __call__(self, doc: Doc) -> Doc:
        """
        Pipeline entry point: append every short-form occurrence to
        ``doc._.abbreviations`` and set each one's ``._.long_form``.
        """
        matches = self.matcher(doc)
        # Shrink each match by one token on both sides to drop "(" and ")".
        matches_no_brackets = [(x[0], x[1] + 1, x[2] - 1) for x in matches]
        filtered = filter_matches(matches_no_brackets, doc)
        occurences = self.find_matches_for(filtered, doc)

        for (long_form, short_forms) in occurences:
            for short in short_forms:
                short._.long_form = long_form
                doc._.abbreviations.append(short)
        return doc

    def find_matches_for(
        self, filtered: List[Tuple[Span, Span]], doc: Doc
    ) -> List[Tuple[Span, Set[Span]]]:
        """
        Resolve (long candidate, short candidate) pairs into definitions, then
        collect every occurrence of each defined short form in ``doc``.

        Returns a list of (long form span, set of short-form spans) pairs.
        """
        # long-form text -> long-form span; maps global-matcher hits back.
        rules = {}
        all_occurences: Dict[Span, Set[Span]] = defaultdict(set)
        already_seen_long: Set[str] = set()
        already_seen_short: Set[str] = set()
        for (long_candidate, short_candidate) in filtered:
            short, long = find_abbreviation(long_candidate, short_candidate)
            # We need the long and short form definitions to be unique, because we need
            # to store them so we can look them up later. This is a bit of a
            # pathological case also, as it would mean an abbreviation had been
            # defined twice in a document. There's not much we can do about this,
            # but at least the case which is discarded will be picked up below by
            # the global matcher. So it's likely that things will work out ok most of the time.
            new_long = long.text not in already_seen_long if long else False
            new_short = short.text not in already_seen_short
            if long is not None and new_long and new_short:
                already_seen_long.add(long.text)
                already_seen_short.add(short.text)
                all_occurences[long].add(short)
                rules[long.text] = long
                # Add a rule to a matcher to find exactly this substring.
                self.global_matcher.add(long.text, [[{"ORTH": x.text} for x in short]])
        to_remove = set()
        # Find every occurrence of each accepted short form. The rule key is
        # the long-form text, which indexes back into `rules`.
        global_matches = self.global_matcher(doc)
        for match, start, end in global_matches:
            string_key = self.global_matcher.vocab.strings[match]
            to_remove.add(string_key)
            all_occurences[rules[string_key]].add(doc[start:end])
        for key in to_remove:
            # Clean up the global matcher.
            self.global_matcher.remove(key)

        return list((k, v) for k, v in all_occurences.items())
    
    
def get_acronym(doc):
    """
    Run the global ``nlp`` pipeline over the raw text ``doc`` and collect the
    abbreviations it detected.

    Returns two parallel lists: the short-form Spans, and the text of each
    one's long form.
    """
    parsed = nlp(doc)
    pairs = [(abbr, abbr._.long_form.text) for abbr in parsed._.abbreviations]
    short_form = [short for short, _ in pairs]
    long_form = [long_text for _, long_text in pairs]
    return short_form, long_form

In [None]:
# Install the en_core_sci_lg 0.4.0 model from a local tarball
# (IPython "!" shell escape; only valid inside a notebook).
!pip install ../data/en_core_sci_lg-0.4.0.tar.gz

In [None]:
# Load the model; get_acronym() uses this global `nlp`.
nlp = spacy.load("en_core_sci_lg")

In [None]:
# Append the AbbreviationDetector factory so every Doc gets `._.abbreviations`.
nlp.add_pipe("abbreviation_detector")

In [None]:
import pandas as pd
import gc
import json
import numpy as np
import random
from tqdm import tqdm
import re
from collections import Counter
import glob
from functools import partial
from multiprocessing import Pool
import os

In [None]:
# Load the run configuration (directory paths). The context manager closes
# the file handle right away instead of leaving it to garbage collection.
with open("../settings.json", "rb") as settings_file:
    settings = json.load(settings_file)

In [None]:
# Show the loaded settings (notebook cell output).
settings

In [None]:
# Prefix every configured path with "." so it is resolved from this
# notebook's directory (presumably turning "./x" into "../x", since the
# settings file itself lives at "../settings.json").
# NOTE: not idempotent; re-running this cell adds another ".".
settings.update({key: "." + value for key, value in settings.items()})

In [None]:
def generate_s_e_window_sliding(sample_len, win_size, step_size):
    """
    Split ``range(sample_len)`` into overlapping [start, end) windows.

    A window of ``win_size`` items starts every ``step_size`` items. The last
    window is slid left so that it ends exactly at ``sample_len`` (its start
    clamped at 0). For positive ``win_size``, every window has length
    ``win_size`` except a sole window when ``sample_len < win_size``.

    Parameters
    ----------
    sample_len : int
        Number of items (e.g. words) to cover.
    win_size : int
        Window length.
    step_size : int
        Offset between consecutive window starts.

    Returns
    -------
    list[list[int]]
        ``[[start, end], ...]``; always at least one window.

    Raises
    ------
    ValueError
        If more than one window is needed but ``step_size <= 0``. The window
        would never advance, and the original code looped forever.
    """
    if step_size <= 0 and win_size < sample_len:
        raise ValueError(f"step_size must be positive, got {step_size}")

    start = 0
    end = win_size
    s_e = [[start, end]]
    while end < sample_len:
        start += step_size
        end = start + win_size
        s_e.append([start, end])

    # Slide the final window back so that it ends exactly at sample_len.
    last = s_e[-1]
    last[0] = max(last[0] - (last[1] - sample_len), 0)
    last[1] = sample_len
    return s_e

In [None]:
# Competition training metadata; the columns dataset_label, dataset_title
# and cleaned_label are read below.
train_df = pd.read_csv(os.path.join(settings["RAW_DATA_DIR"], "train.csv"))

In [None]:
# Per-row label columns of the training metadata.
clean_label = train_df.cleaned_label.tolist()
# Read the raw label column; this previously re-read `cleaned_label` by mistake.
dataset_label = train_df.dataset_label.tolist()
dataset_title = train_df.dataset_title.tolist()

In [None]:
# Lower-cased, stripped variants of every known training label (raw label,
# dataset title and cleaned label), de-duplicated into all_train_labels.
temp_1 = [label.lower().strip() for label in train_df["dataset_label"].unique()]
temp_2 = [label.lower().strip() for label in train_df["dataset_title"].unique()]
temp_3 = [label.lower().strip() for label in train_df["cleaned_label"].unique()]
all_train_labels = list({*temp_1, *temp_2, *temp_3})

In [None]:
# Publication ids: each file name in train/ up to its first ".".
# os.path.basename (rather than split("/")) strips the directory with the
# platform's own separator, matching what os.path.join/glob produce.
TRAIN_IDS = [
    os.path.basename(path).split(".")[0]
    for path in glob.glob(os.path.join(settings["RAW_DATA_DIR"], "train/*"))
]

In [None]:
# Number of training publication files found.
len(TRAIN_IDS)

In [None]:
win_size = 30


def process(i):
    """
    Load training publication ``TRAIN_IDS[i]`` and cut it into text windows.

    Each section's text (newlines replaced by spaces) is split on whitespace
    and sliced into ``win_size``-word windows with 50% overlap, using
    generate_s_e_window_sliding.

    Returns a dict with:
      "id"        - the publication id, once per window
      "text"      - the window strings
      "label"     - the placeholder label "unknow", once per window
      "unique_id" - the publication id
      "full_text" - all section texts joined by single spaces, stripped
    """
    pub_id = TRAIN_IDS[i]
    # Close the JSON file deterministically (this runs in pool workers).
    with open(f"{settings['RAW_DATA_DIR']}/train/{pub_id}.json", "rt") as f:
        sections = json.load(f)

    label = "unknow"
    ids, texts, labels, section_texts = [], [], [], []
    for section in sections:
        raw_text = section["text"].replace("\n", " ")
        section_texts.append(raw_text)
        raw_words = raw_text.split()
        windows = generate_s_e_window_sliding(len(raw_words), win_size, int(0.5 * win_size))
        for s, e in windows:
            texts.append(" ".join(raw_words[s:e]).strip())
            ids.append(pub_id)
            labels.append(label)

    return {
        "id": ids,
        "text": texts,
        "label": labels,
        "unique_id": pub_id,
        # Same result as appending "text + ' '" per section and stripping.
        "full_text": " ".join(section_texts).strip(),
    }


# Fan the publications out over 8 worker processes; imap keeps input order.
# The context manager tears the pool down once every result has been read.
with Pool(8) as p:
    train_map = p.imap(
        process,
        tqdm(range(len(TRAIN_IDS)), total=len(TRAIN_IDS), desc="[Preprocessing TestSet]"),
        chunksize=10,
    )
    results = list(tqdm(train_map))

ids, texts, labels, unique_ids, full_texts = [], [], [], [], []
for result in tqdm(results):
    ids.extend(result["id"])
    texts.extend(result["text"])
    labels.extend(result["label"])
    unique_ids.append(result["unique_id"])
    full_texts.append(result["full_text"])

# NOTE: despite its name, test_df holds the *training* publications' windows.
test_df = pd.DataFrame()
test_df["id"] = ids
test_df["text"] = texts
test_df["label"] = labels
test_df["group"] = [-1] * len(ids)
test_df["title"] = [""] * len(ids)

In [None]:
# Text of every overlapping 30-word window built above.
all_texts = test_df["text"].tolist()

In [None]:
def check_valid_text(string):
    """
    Return True if ``string`` contains any dataset-indicating keyword below
    (case-sensitive substring test), else False.
    """
    # Previously considered but disabled: "Program", "Programs", "Programme".
    accepted_keywords = (
        "Study", "Studies", "Survey", "Surveys",
        "Dataset", "Datasets", "Database", "Databases",
        "Data Set", "Data System", "Data Systems",
    )
    return any(keyword in string for keyword in accepted_keywords)

In [None]:
# Keep only windows that mention at least one dataset keyword.
valid_texts = list(filter(check_valid_text, tqdm(all_texts)))

In [None]:
# Number of windows that mention a dataset keyword.
len(valid_texts)

In [None]:
full_text = " ".join(valid_texts)  # concatenate all keyword-bearing windows into one string for spaCy

In [None]:
# Split full_text into chunks of at most len_chunk (100,000) characters for
# spaCy. Each cut is moved back to the last space before the limit so that
# no word is split in half across two chunks. A cut can still separate a
# long form from its "(ACRONYM)"; that pair is then missed.
chunk_texts = []
len_chunk = 100000

start = 0
while start < len(full_text):
    end = start + len_chunk
    if end < len(full_text):
        last_space = full_text.rfind(" ", start, end)
        if last_space > start:
            end = last_space
    chunk_texts.append(full_text[start:end])
    start = end

In [None]:
# Long forms accepted from the scispacy AbbreviationDetector pass.
accepted_preds = []

In [None]:
# Multiprocess finding all **LONG FORM (SHORT FORM)** pairs with scispacy.

def get_extra_label(string):
    """Return the long-form texts found in ``string`` (short forms are dropped)."""
    _, long_form = get_acronym(string)
    return long_form


batch_size = 48
extra_labels = []
for i in range(0, len(chunk_texts), batch_size):
    batch = chunk_texts[i:i + batch_size]
    # A fresh pool per batch, as before (presumably to bound worker memory
    # growth); the context manager guarantees it is torn down.
    with Pool(8) as p:
        extra_label_results = p.imap(
            get_extra_label,
            tqdm(batch, total=len(batch), desc="[Get extra labels]"),
            chunksize=10,
        )
        for result in tqdm(extra_label_results):
            extra_labels.extend(result)

extra_labels = list(set(extra_labels))

# Drop long forms that are entirely lowercase.
accepted_preds.extend(label for label in extra_labels if not label.islower())

In [None]:
# Number of distinct scispacy long forms that are not entirely lowercase.
len(accepted_preds)

## Forward/Backward Complete Dataset Finding

**Note**: This is an optional step; the labels provided by scispacy alone are good enough to reproduce our results. However, many candidate labels do not match the form "LONG FORM (SHORT FORM)" (for example, a LONG FORM with no acronym), so this algorithm still yields more candidate labels that can improve the model's performance (though not by much in our experiments). Our latest submission used this method.

In [None]:
def jaccard_similarity(str1, str2):
    """
    Case-insensitive Jaccard similarity of the word sets of two strings.

    Words are split on runs of whitespace. The old ``split(" ")`` turned
    leading, trailing or doubled spaces (common in clean_text output) into a
    spurious "" word that both strings could share. Two strings with no words
    at all are treated as identical (1.0), as before.
    """
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    union = a | b
    if not union:
        return 1.0
    return len(a & b) / len(union)

def clean_text(txt, lower=True):
    """
    Replace every run of non-alphanumeric characters in ``str(txt)`` with a
    single space. Lower-case the text first when ``lower`` is True (the
    default); the flag used to be ignored. Leading/trailing spaces are kept.
    """
    text = str(txt)
    if lower:
        text = text.lower()
    return re.sub("[^A-Za-z0-9]+", " ", text)

# Characters that disqualify a candidate label. Written as a raw string: the
# non-raw "\|" is an invalid Python escape that warns on recent versions
# (the resulting regex is identical). Compiled once at import time.
_SPECIAL_CHARS = re.compile(r"[@_!#$%^&*()<>?/\|}{~:.]")


def check_splcharacter(test):
    """
    Return True if ``test`` contains NONE of the characters
    ``@ _ ! # $ % ^ & * ( ) < > ? / | } { ~ : .``, otherwise False.
    """
    return _SPECIAL_CHARS.search(test) is None

In [None]:
def custom_find_dataset0(string):
    """
    Extract capitalised dataset-name candidates around the keywords
    "Database", "Dataset", "Databases" and "Datasets".

    Only space-delimited keyword occurrences are considered. Two passes:

    1. "<keyword> <prep>", with prep in (on, for, of, in): the name is grown
       backwards from the keyword and forwards past the preposition, and the
       two halves are joined around the preposition.
    2. "<keyword>": the name is grown backwards from the keyword.

    A "lower" word is one that is lowercase, or not alphabetic after removing
    apostrophes. Growing keeps a single lower word between name words (e.g.
    "of"). It stops at the second of two consecutive lower words (dropping
    both) or at a lower article ("a", "an", "the"). In backward mode, a
    leading capitalised "A"/"An"/"The" is stripped. The backward scan now
    also considers word 0 of ``string``; it used to be skipped.

    A candidate is kept if it has 3-10 space-separated words, is at most 60
    characters long, has no special characters (check_splcharacter), and
    the initials of its words are uppercase. Prepositions are ignored for the
    initials unless they are the last word.

    Returns the kept candidates in discovery order; duplicates are possible.
    """
    keywords = ["Database", "Dataset", "Databases", "Datasets"]
    prepositions = ["on", "for", "of", "in"]

    # Index each space-separated word by its "start-end" character offsets,
    # so a regex match position can be mapped back to a word index.
    word_list = string.split(" ")
    se2idx = {}
    offset = 0
    for i, word in enumerate(word_list):
        se2idx[f"{offset}-{offset + len(word)}"] = i
        offset += len(word) + 1

    candidates = []

    def complete_dataset(dataset_idx, word_list, mode="backward"):
        # Grow a name from word_list[dataset_idx] backwards or forwards.
        if mode == "backward":
            # Stop at -1 so word 0 is considered too.
            indices = range(dataset_idx, -1, -1)
        else:
            indices = range(dataset_idx, len(word_list))
        candidate_word = []
        n_consecutive_lower = 0
        for i in indices:
            word_i = word_list[i]
            if not word_i.islower() and word_i.replace("'", "").isalpha():
                candidate_word.append(word_i)
                n_consecutive_lower = 0
                continue
            # A lower word: keep it tentatively, unless it is the second in a
            # row (then also drop the first one) or an article.
            n_consecutive_lower += 1
            if n_consecutive_lower == 2 or word_i.lower() in ["a", "an", "the"]:
                if n_consecutive_lower == 2:
                    candidate_word = candidate_word[:-1]
                break
            candidate_word.append(word_i)

        if mode == "backward":
            # Words were collected right-to-left; strip a leading article.
            if candidate_word[-1] in ["A", "An", "The"]:
                candidate_word = candidate_word[:-1]
            return " ".join(candidate_word[::-1])
        return " ".join(candidate_word)

    def is_valid(candidate):
        # Initials of all words, skipping prepositions except a trailing one
        # (so a dangling "... Dataset of" fails the uppercase test).
        words = candidate.split()
        kept = [w for i, w in enumerate(words) if w not in prepositions or i == len(words) - 1]
        acronym = "".join(w[0] for w in kept)
        n_words = len(candidate.split(" "))
        return (
            3 <= n_words <= 10
            and acronym.isupper()
            and len(candidate) <= 60
            and check_splcharacter(candidate)
        )

    # Pass 1: "<keyword> <prep>" -> backward part + prep + forward part.
    for k in keywords:
        for prepos in prepositions:
            for match in re.finditer(" " + k + f" {prepos} ", string):
                start_index = match.start() + 1
                keyword_idx = se2idx[f"{start_index}-{start_index + len(k)}"]
                forward_candidate = complete_dataset(keyword_idx, word_list, "forward")
                backward_candidate = complete_dataset(keyword_idx, word_list, "backward")
                # The forward part normally starts with "<keyword> <prep>";
                # both are already present, so skip its first two words.
                candidate = " ".join(
                    backward_candidate.split(" ") + [prepos] + forward_candidate.split(" ")[2:]
                )
                if is_valid(candidate):
                    candidates.append(candidate)

    # Pass 2: "<keyword>" -> backward part only.
    for k in keywords:
        for match in re.finditer(" " + k + " ", string):
            start_index = match.start() + 1
            end_index = match.end() - 1
            candidate = complete_dataset(se2idx[f"{start_index}-{end_index}"], word_list, "backward")
            if is_valid(candidate):
                candidates.append(candidate)

    return candidates

In [None]:
def custom_find_dataset1(string):
    """
    Extract capitalised dataset-name candidates that end in a two-word
    keyword: "Data Set(s)", "Data System(s)" or "Dataset System(s)".

    For each space-delimited occurrence, the name is grown backwards from the
    keyword's first word, and the keyword's second word is then appended.
    A "lower" word is one that is lowercase, or not alphabetic after removing
    apostrophes. Growing keeps a single lower word between name words. It
    stops at the second of two consecutive lower words (dropping both) or at
    a lower article; a leading capitalised "A"/"An"/"The" is stripped. The
    scan now also considers word 0 of ``string``; it used to be skipped.

    A candidate is kept if it has 4-10 space-separated words, is at most 60
    characters long, has no special characters (check_splcharacter), and
    the initials of all its words are uppercase.

    Returns the kept candidates in discovery order; duplicates are possible.
    """
    keywords = ["Data Set", "Data System", "Data Systems", "Data Sets", "Dataset System", "Dataset Systems"]

    # Index each space-separated word by its "start-end" character offsets.
    word_list = string.split(" ")
    se2idx = {}
    offset = 0
    for i, word in enumerate(word_list):
        se2idx[f"{offset}-{offset + len(word)}"] = i
        offset += len(word) + 1

    candidates = []

    def complete_dataset(dataset_idx, word_list):
        # Grow a name backwards from word_list[dataset_idx]; stop at -1 so
        # word 0 is considered too.
        candidate_word = []
        n_consecutive_lower = 0
        for i in range(dataset_idx, -1, -1):
            word_i = word_list[i]
            if not word_i.islower() and word_i.replace("'", "").isalpha():
                candidate_word.append(word_i)
                n_consecutive_lower = 0
                continue
            n_consecutive_lower += 1
            if n_consecutive_lower == 2 or word_i.lower() in ["a", "an", "the"]:
                if n_consecutive_lower == 2:
                    candidate_word = candidate_word[:-1]
                break
            candidate_word.append(word_i)

        # Words were collected right-to-left; strip a leading article.
        if candidate_word[-1] in ["A", "An", "The"]:
            candidate_word = candidate_word[:-1]
        return " ".join(candidate_word[::-1])

    for k in keywords:
        first_part, second_part = k.split(" ")
        for match in re.finditer(" " + k + " ", string):
            start_index = match.start() + 1
            candidate = complete_dataset(se2idx[f"{start_index}-{start_index + len(first_part)}"], word_list)
            candidate += " " + second_part
            acronym = "".join(w[0] for w in candidate.split())
            n_words = len(candidate.split(" "))
            if 4 <= n_words <= 10 and acronym.isupper() and len(candidate) <= 60 and check_splcharacter(candidate):
                candidates.append(candidate)

    return candidates

In [None]:
def custom_find_dataset2(string):
    """
    Extract capitalised dataset-name candidates around the keywords
    "Survey", "Surveys", "Study" and "Studies".

    Only space-delimited keyword occurrences are considered. Three passes:

    1. "<keyword> <prep>", with prep in (on, for, of, in): the name is grown
       backwards from the keyword and forwards past the preposition, and the
       two halves are joined around the preposition.
    2. "<keyword>": the name is grown backwards from the keyword.
    3. "<keyword>": the name is grown forwards from the keyword.

    A "lower" word is one that is lowercase, or not alphabetic after removing
    apostrophes. Growing keeps a single lower word between name words (e.g.
    "of"). It stops at the second of two consecutive lower words (dropping
    both) or at a lower article ("a", "an", "the"). In backward mode, a
    leading capitalised "A"/"An"/"The" is stripped. The backward scan now
    also considers word 0 of ``string``; it used to be skipped.

    A candidate is kept if it has 3-10 space-separated words, is at most 60
    characters long, has no special characters (check_splcharacter), and
    the initials of its words are uppercase. Prepositions are ignored for the
    initials unless they are the last word.

    Returns the kept candidates in discovery order; duplicates are possible.
    """
    keywords = ["Survey", "Surveys", "Study", "Studies"]
    prepositions = ["on", "for", "of", "in"]

    # Index each space-separated word by its "start-end" character offsets,
    # so a regex match position can be mapped back to a word index.
    word_list = string.split(" ")
    se2idx = {}
    offset = 0
    for i, word in enumerate(word_list):
        se2idx[f"{offset}-{offset + len(word)}"] = i
        offset += len(word) + 1

    candidates = []

    def complete_dataset(dataset_idx, word_list, mode="backward"):
        # Grow a name from word_list[dataset_idx] backwards or forwards.
        if mode == "backward":
            # Stop at -1 so word 0 is considered too.
            indices = range(dataset_idx, -1, -1)
        else:
            indices = range(dataset_idx, len(word_list))
        candidate_word = []
        n_consecutive_lower = 0
        for i in indices:
            word_i = word_list[i]
            if not word_i.islower() and word_i.replace("'", "").isalpha():
                candidate_word.append(word_i)
                n_consecutive_lower = 0
                continue
            # A lower word: keep it tentatively, unless it is the second in a
            # row (then also drop the first one) or an article.
            n_consecutive_lower += 1
            if n_consecutive_lower == 2 or word_i.lower() in ["a", "an", "the"]:
                if n_consecutive_lower == 2:
                    candidate_word = candidate_word[:-1]
                break
            candidate_word.append(word_i)

        if mode == "backward":
            # Words were collected right-to-left; strip a leading article.
            if candidate_word[-1] in ["A", "An", "The"]:
                candidate_word = candidate_word[:-1]
            return " ".join(candidate_word[::-1])
        return " ".join(candidate_word)

    def is_valid(candidate):
        # Initials of all words, skipping prepositions except a trailing one
        # (so a dangling "... Survey of" fails the uppercase test).
        words = candidate.split()
        kept = [w for i, w in enumerate(words) if w not in prepositions or i == len(words) - 1]
        acronym = "".join(w[0] for w in kept)
        n_words = len(candidate.split(" "))
        return (
            3 <= n_words <= 10
            and acronym.isupper()
            and len(candidate) <= 60
            and check_splcharacter(candidate)
        )

    # Pass 1: "<keyword> <prep>" -> backward part + prep + forward part.
    for k in keywords:
        for prepos in prepositions:
            for match in re.finditer(" " + k + f" {prepos} ", string):
                start_index = match.start() + 1
                keyword_idx = se2idx[f"{start_index}-{start_index + len(k)}"]
                forward_candidate = complete_dataset(keyword_idx, word_list, "forward")
                backward_candidate = complete_dataset(keyword_idx, word_list, "backward")
                # The forward part normally starts with "<keyword> <prep>";
                # both are already present, so skip its first two words.
                candidate = " ".join(
                    backward_candidate.split(" ") + [prepos] + forward_candidate.split(" ")[2:]
                )
                if is_valid(candidate):
                    candidates.append(candidate)

    # Passes 2 and 3: "<keyword>" -> backward part, then forward part.
    for mode in ("backward", "forward"):
        for k in keywords:
            for match in re.finditer(" " + k + " ", string):
                start_index = match.start() + 1
                end_index = match.end() - 1
                candidate = complete_dataset(se2idx[f"{start_index}-{end_index}"], word_list, mode)
                if is_valid(candidate):
                    candidates.append(candidate)

    return candidates

In [None]:
# Candidates produced by the custom forward/backward keyword search.
custom_accepted_preds = []

In [None]:
def get_extra_label(string):
    """Return every candidate the three custom finders extract from ``string``."""
    candidates0 = custom_find_dataset0(string)
    candidates1 = custom_find_dataset1(string)
    candidates2 = custom_find_dataset2(string)
    return candidates0 + candidates1 + candidates2


# One task per publication full text. The context manager tears the pool
# down once all results have been consumed.
with Pool(16) as p:
    extra_label_results = p.imap(
        get_extra_label,
        tqdm(full_texts, total=len(full_texts), desc="[Get extra labels]"),
        chunksize=10,
    )
    for result in tqdm(extra_label_results):
        custom_accepted_preds.extend(result)

custom_accepted_preds = list(set(custom_accepted_preds))

In [None]:
# Number of distinct custom-algorithm candidates.
len(custom_accepted_preds)

## Filter Bad Labels from Scispacy and A Custom Algorithm

In [None]:
# Occurrence counts per candidate. NOTE(review): both lists were already
# de-duplicated above, so every count is 1 and the `v >= 1` tests below
# always pass.
accepted_preds_counter = Counter(accepted_preds)
custom_accepted_preds_counter = Counter(custom_accepted_preds)

In [None]:
def num_there(s):
    """Return True if any character of ``s`` is a digit (str.isdigit), else False."""
    for ch in s:
        if ch.isdigit():
            return True
    return False

In [None]:
# Re-filter both candidate pools. A candidate is kept if it has no digits
# and no special characters, and has between a minimum word count and 10
# words after clean_text. The minimum is 3 for scispacy candidates and 4 for
# the custom-algorithm candidates (the latter are optional and can be ignored).
accepted_preds = []
for min_words, counter in ((3, accepted_preds_counter), (4, custom_accepted_preds_counter)):
    for k, v in counter.items():
        n_words = len(clean_text(k).strip().split(" "))
        if v >= 1 and not num_there(k) and min_words <= n_words <= 10 and check_splcharacter(k):
            accepted_preds.append(k.strip())

In [None]:
# De-duplicate (a label may come from both pools); the resulting order is arbitrary.
accepted_preds = list(set(accepted_preds))

In [None]:
# Number of candidate labels after filtering.
len(accepted_preds)

## Train/Valid Extra Labels

In [None]:
# Split candidates into those that resemble an existing training label (used
# as extra training labels) and those that do not (used for validation).
# A candidate resembles a label when their cleaned forms have Jaccard
# similarity >= 0.5, or one contains the other as a substring.
# Cleaned forms are computed once each, instead of once per
# (candidate, label) pair.
clean_train_labels = [clean_text(train_label) for train_label in all_train_labels]

preds_not_in_train = []
for pred in tqdm(accepted_preds):
    clean_pred = clean_text(pred)
    in_train = any(
        jaccard_similarity(clean_pred, clean_train) >= 0.5
        or clean_pred in clean_train
        or clean_train in clean_pred
        for clean_train in clean_train_labels
    )
    if not in_train:
        preds_not_in_train.append(pred)

In [None]:
# Candidates that resemble a training label (set difference, so the order is arbitrary).
preds_in_train = list(set(accepted_preds) - set(preds_not_in_train))

In [None]:
def check_valid_extra_label(extra_label):
    """
    Return True if the candidate ``extra_label`` satisfies any of these rules.
    All keyword tests are case-sensitive substring tests; words are obtained
    with ``split(" ")``.

    1. "Study"/"Studies"/"Survey"/"Surveys" occurs in the first or last
       word, or is followed by " on", " for", " of" or " in" anywhere in the
       label, AND both the first and last words start with an uppercase letter.
    2. "Dataset(s)"/"Database(s)" occurs in the last word, or is followed by
       one of those prepositions, AND the first word starts with an
       uppercase letter.
    3. A two-word keyword ("Data Set(s)", "Data System(s)",
       "Dataset System(s)") occurs in the last two words joined by a space,
       AND the first word starts with an uppercase letter.

    An empty first or last word (empty label, leading/trailing space) never
    counts as starting with an uppercase letter. It used to raise IndexError.
    """
    accepted_keywords = ["Study", "Studies", "Survey", "Surveys"]
    # "Program", "Programs" and "Programe" were considered for this list and left out.
    accepted_keywords2 = ["Dataset", "Database", "Datasets", "Databases"]
    accepted_keywords3 = ["Data Set", "Data System", "Data Systems", "Data Sets", "Dataset System", "Dataset Systems"]
    prepositions = ["on", "for", "of", "in"]

    def starts_upper(word):
        return bool(word) and word[0].isupper()

    def followed_by_preposition(keyword):
        return any((keyword + " " + prep) in extra_label for prep in prepositions)

    words = extra_label.split(" ")
    first_word = words[0]
    last_word = words[-1]

    for k in accepted_keywords:
        if (
            (k in first_word or k in last_word or followed_by_preposition(k))
            and starts_upper(first_word)
            and starts_upper(last_word)
        ):
            return True
    for k in accepted_keywords2:
        if (k in last_word or followed_by_preposition(k)) and starts_upper(first_word):
            return True
    last_two = " ".join(words[-2:])
    for k in accepted_keywords3:
        if k in last_two and starts_upper(first_word):
            return True
    return False

In [None]:
# Ensure the output directory exists. The extra-label CSVs below are written
# into RAW_DATA_DIR, the same directory train.csv was read from.
os.makedirs(settings["RAW_DATA_DIR"], exist_ok=True)

In [None]:
# Candidates that resemble a training label and pass check_valid_extra_label
# are saved to RAW_DATA_DIR/extra_train_labels.csv (single "label" column).
pseudo_labels = [label for label in preds_in_train if check_valid_extra_label(label)]
pseudo = pd.DataFrame({"label": pseudo_labels})
pseudo.to_csv(f"{settings['RAW_DATA_DIR']}/extra_train_labels.csv", index=False)

In [None]:
# Candidates unlike any training label that pass check_valid_extra_label
# are saved to RAW_DATA_DIR/extra_valid_labels.csv (single "label" column).
pseudo_labels = [label for label in preds_not_in_train if check_valid_extra_label(label)]
pseudo = pd.DataFrame({"label": pseudo_labels})
pseudo.to_csv(f"{settings['RAW_DATA_DIR']}/extra_valid_labels.csv", index=False)