In [1]:
import warnings
warnings.filterwarnings('ignore', category=DeprecationWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

import re
import pandas as pd
import numpy as np
import json
import os
import string
from functools import partial
from collections import defaultdict, Counter
from tqdm.autonotebook import tqdm
import gc

In [2]:
# Offline installation of spacy v3
# Stage every package file from the attached input dataset in a local directory
# so that pip can later resolve them without network access.
!mkdir -p /tmp/pip/cache/
# The smart_open archive is stored with an .xyz suffix and gets its .tar.gz name back on copy.
!cp ../input/spacy-v3/spacy_v3_download/smart_open-3.0.0.xyz /tmp/pip/cache/smart_open-3.0.0.tar.gz
!cp ../input/spacy-v3/spacy_v3_download/spacy-3.0.6-cp37-cp37m-manylinux2014_x86_64.whl /tmp/pip/cache/
!cp ../input/spacy-v3/spacy_v3_download/pydantic-1.7.3-cp37-cp37m-manylinux2014_x86_64.whl /tmp/pip/cache/
!cp ../input/spacy-v3/spacy_v3_download/srsly-2.4.1-cp37-cp37m-manylinux2014_x86_64.whl /tmp/pip/cache/
!cp ../input/spacy-v3/spacy_v3_download/thinc-8.0.3-cp37-cp37m-manylinux2014_x86_64.whl  /tmp/pip/cache/
!cp ../input/spacy-v3/spacy_v3_download/catalogue-2.0.3-py3-none-any.whl  /tmp/pip/cache/
!cp ../input/spacy-v3/spacy_v3_download/pathy-0.5.0-py3-none-any.whl  /tmp/pip/cache/
!cp ../input/spacy-v3/spacy_v3_download/spacy_legacy-3.0.4-py2.py3-none-any.whl  /tmp/pip/cache/
!cp ../input/spacy-v3/spacy_v3_download/typer-0.3.2-py3-none-any.whl  /tmp/pip/cache/
!cp ../input/spacy-v3/spacy_v3_download/spacy_transformers-1.0.3-py2.py3-none-any.whl  /tmp/pip/cache/
!cp ../input/spacy-v3/spacy_v3_download/spacy_alignments-0.8.3-cp37-cp37m-manylinux2014_x86_64.whl  /tmp/pip/cache/

In [3]:
# Offline installation of spacy v3
# List the staged files, remove the preinstalled spacy, then install spacy with the
# transformers and cuda110 extras using only the local directory (--no-index).
!ls /tmp/pip/cache
!pip uninstall -y spacy
!pip install --no-index --find-links /tmp/pip/cache/ spacy[transformers,cuda110]

catalogue-2.0.3-py3-none-any.whl
pathy-0.5.0-py3-none-any.whl
pydantic-1.7.3-cp37-cp37m-manylinux2014_x86_64.whl
smart_open-3.0.0.tar.gz
spacy-3.0.6-cp37-cp37m-manylinux2014_x86_64.whl
spacy_alignments-0.8.3-cp37-cp37m-manylinux2014_x86_64.whl
spacy_legacy-3.0.4-py2.py3-none-any.whl
spacy_transformers-1.0.3-py2.py3-none-any.whl
srsly-2.4.1-cp37-cp37m-manylinux2014_x86_64.whl
thinc-8.0.3-cp37-cp37m-manylinux2014_x86_64.whl
typer-0.3.2-py3-none-any.whl
Found existing installation: spacy 2.3.5
Uninstalling spacy-2.3.5:
  Successfully uninstalled spacy-2.3.5
Looking in links: /tmp/pip/cache/
Processing /tmp/pip/cache/spacy-3.0.6-cp37-cp37m-manylinux2014_x86_64.whl
Processing /tmp/pip/cache/catalogue-2.0.3-py3-none-any.whl
Processing /tmp/pip/cache/thinc-8.0.3-cp37-cp37m-manylinux2014_x86_64.whl
Processing /tmp/pip/cache/pydantic-1.7.3-cp37-cp37m-manylinux2014_x86_64.whl
Processing /tmp/pip/cache/srsly-2.4.1-cp37-cp37m-manylinux2014_x86_64.whl
Processing /tmp/pip/cache/s

In [4]:
# Frame whose "Id" column drives every per-document loop below.
sample_sub = pd.read_csv('../input/coleridgeinitiative-show-us-the-data/sample_submission.csv')
#sample_sub = pd.read_csv('../input/coleridgeinitiative-show-us-the-data/train.csv')
# Training annotations; later cells read its dataset_title / dataset_label columns.
train_df = pd.read_csv('../input/coleridgeinitiative-show-us-the-data/train.csv')
#test_files_path = '../input/coleridgeinitiative-show-us-the-data/train'
# Directories holding one "<Id>.json" file per publication.
train_files_path = '../input/coleridgeinitiative-show-us-the-data/train'
test_files_path = '../input/coleridgeinitiative-show-us-the-data/test'

In [5]:
def clean_label(txt):
    """Lower-case ``txt`` and turn each run of non-alphanumeric characters into one space.

    The value is converted with ``str()`` first. Spaces produced at either end
    of the result are kept (no stripping).
    """
    lowered = str(txt).lower()
    non_alnum_run = re.compile('[^A-Za-z0-9]+')
    return non_alnum_run.sub(' ', lowered)

def read_append_return(filename, train_files_path=train_files_path, output='all'):
    """
    Read ``<train_files_path>/<filename>.json`` and return its text as one string.

    The JSON document is iterated as a sequence of sections, each providing a
    ``section_title`` and a ``text`` entry. A missing or null entry is treated
    as an empty string so that one incomplete section does not abort the read.

    Parameters:
        filename: file name without the ``.json`` extension.
        train_files_path: directory containing the JSON files.
        output: ``'text'`` returns the section texts joined by spaces,
            ``'head'`` returns the section titles joined by spaces, and any
            other value returns title and text of every section, interleaved
            and joined by ``'. '``.

    Returns:
        str: the joined text selected by ``output``.

    From: https://www.kaggle.com/prashansdixit/coleridge-initiative-eda-baseline-model
    """
    json_path = os.path.join(train_files_path, f"{filename}.json")
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        sections = json.load(f)

    # `or ''` maps both a missing key and an explicit null to an empty string.
    headings = [section.get('section_title') or '' for section in sections]
    contents = [section.get('text') or '' for section in sections]

    if output == 'text':
        return ' '.join(contents)
    if output == 'head':
        return ' '.join(headings)

    combined = []
    for heading, content in zip(headings, contents):
        combined.append(heading)
        combined.append(content)
    return '. '.join(combined)
    
def sanitize_text(text):
    """Normalise raw document text before sentence splitting.

    Steps, in order:
      1. remove double-quote characters;
      2. turn every whitespace character into a space and collapse runs of them;
      3. keep only ASCII letters, digits, spaces and ``string.punctuation``;
      4. delete sequences in which one word character or period (optionally
         followed by whitespace) repeats six or more times in a row;
      5. collapse runs of spaces and repeat step 4 on the result.

    Parameters:
        text (str): text to clean.

    Returns:
        str: the cleaned text.
    """
    # Remove quotes
    text = text.replace('"', "")

    # A set gives constant-time membership tests for the per-character filter.
    allowed_chars = set(" abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789")
    allowed_chars.update(string.punctuation)
    text = re.sub(r"\s", " ", text)
    text = re.sub(r"\s+", " ", text)
    text = ''.join(k for k in text if k in allowed_chars)

    # Remove repeated sequences (char and spaces)
    repeated_seq = re.compile(r"([\w\.]\s*)\1{5,}")

    def _strip_repeats(value):
        """Delete every occurrence of each repeated run found, longest run first."""
        runs = sorted((m.group(0) for m in repeated_seq.finditer(value)),
                      reverse=True, key=len)
        for run in runs:
            value = value.replace(run, "")
        return value

    text = _strip_repeats(text)
    # Removing runs can leave adjacent spaces and expose new repeats, so clean up once more.
    text = re.sub(" +", " ", text)
    return _strip_repeats(text)

def clean_text(text):
    """Lower-case ``text``, replace each non-alphanumeric run with a space, and strip the ends.

    The value is converted with ``str()`` before processing.
    """
    lowered = str(text).lower()
    spaced = re.sub('[^A-Za-z0-9]+', ' ', lowered)
    return spaced.strip()

In [6]:
# Read each test publication; with the default output='all' this yields titles and text combined.
sample_sub['text'] = sample_sub['Id'].apply(partial(read_append_return, train_files_path=test_files_path))

In [7]:
%%time
# Enable pandas' progress_apply via tqdm, then sanitise every document's text in place.
tqdm.pandas()
sample_sub['text'] = sample_sub['text'].progress_apply(sanitize_text)

  0%|          | 0/4 [00:00<?, ?it/s]

CPU times: user 238 ms, sys: 7.28 ms, total: 246 ms
Wall time: 248 ms


In [8]:
import spacy
from spacy.pipeline import Sentencizer
from spacy.lang.en import English
from spacy.matcher import PhraseMatcher

In [9]:
nlp = English()
sentencizer = Sentencizer()
def extract_sentences(text):
    """Split ``text`` into sentences with the rule-based spaCy sentencizer.

    Parameters:
        text (str): document text.

    Returns:
        list[str]: stripped sentence strings; sentences of 1500 characters or
        more (measured before stripping) are dropped.
    """
    # Drop the period after " al" (presumably from "et al.") so it is not taken
    # as a sentence end. The dot is escaped: unescaped, it matched any character
    # and rewrote words such as " also" to " alo".
    text = re.sub(r" al\.", " al", text)
    # Raise the pipeline's length limit when a document exceeds the 1,000,000 floor.
    nlp.max_length = max(1000000, len(text) + 1)
    doc = nlp(text)
    return [sent.text.strip() for sent in sentencizer(doc).sents if len(sent.text) < 1500]

In [10]:
# One list of sentence strings per document row.
sample_sub['sentences'] = sample_sub['text'].apply(extract_sentences)

In [11]:
# Gather the per-row sentence lists by Id: each entry is a list of sentence lists,
# which is why later cells read id_sentences[Id][0]. Then drop the column from the frame.
id_sentences = sample_sub.groupby("Id")["sentences"].agg(list)
sample_sub = sample_sub.drop("sentences", axis=1)

In [12]:
import sys
sys.path.append("../input/dataset-grouping")
sys.path.append("../input/extract-abbreviations")

from dataset_grouping_v3 import group_candidates
from schwartz_hearst_v2 import extract_abbreviation_definition_pairs


# Flatten the sentences of all documents into one list, with a parallel list
# holding the Id each sentence came from.
all_sentences = []
all_ids = []
for Id in sample_sub["Id"].unique():
    doc_sentences = id_sentences[Id][0]
    all_sentences += doc_sentences
    all_ids += [Id] * len(doc_sentences)
print(len(all_sentences))

2013


In [13]:
# First pass: extract abbreviation/definition pairs from all sentences.
abbreviation_definition_pairs = extract_abbreviation_definition_pairs(doc_text=all_sentences, ids=all_ids)
abbreviations = abbreviation_definition_pairs["abbreviations"]
# Second pass: run the extraction again, passing the lower-cased abbreviations found
# in the first pass as existing_abbreviations.
abbreviation_definition_pairs = extract_abbreviation_definition_pairs(
    doc_text=all_sentences,
    ids=all_ids,existing_abbreviations={abb.lower() for abb in abbreviations})
abbreviations = abbreviation_definition_pairs["abbreviations"]
sentences_candidates = abbreviation_definition_pairs["sentences_candidates"]
# Later cells index this as ids_abbreviations[Id][abbreviation] and iterate the definitions.
ids_abbreviations = abbreviation_definition_pairs["ids_abbreviations"]
print(len(abbreviations), len(ids_abbreviations), len(sentences_candidates))
# Release the large intermediates; only the three extracted structures are kept.
del all_sentences, all_ids, abbreviation_definition_pairs
gc.collect()

0
0
76 4 2013


0

In [14]:
# Map every training dataset title to its distinct labels, longest label first.
labels_by_title = train_df.groupby("dataset_title")["dataset_label"].agg(set).agg(list).to_dict()
title_labels = {
    title: sorted(labels, key=len, reverse=True)
    for title, labels in labels_by_title.items()
}

In [15]:
# Group the extracted candidates together with the known training titles/labels.
# Only the "groups" entry of the result is used; the next cell iterates it as
# group name -> collection of candidate strings.
result = group_candidates(abbreviations, sentences_candidates, title_labels)
groups = result["groups"]
print(len(groups))
# Free the inputs that are no longer needed.
del result, sentences_candidates
gc.collect()

100%|██████████| 74/74 [00:00<00:00, 16991.21it/s]

74
60
60
60
60
60
59
59
59
117
59
117
59
117
56
129
56
193
55
193
55
193
55
193
55
193
53
187
53
190
54
194
55
196
56
199
56
199
57
201
58
203
59
206
60
209
61
210
62
211
63
215
64
216
65
219
66
221
67
225
67
225
68
228
69
229
70
230
71
232
72
233
73
237
74
240
75
246
76
247
77
248
78
249
79
252
80
256
81
259
82
260
83
263
84
280
85
281
86
285
87
289
88
293
89
298
90
302
91
306
92
307
93
309
94
312
94





0

In [16]:
# Invert `groups`: for each candidate string, list the groups that contain it
# (each group at most once, in iteration order).
label_dataset_mapping = defaultdict(list)
for group_name, group_candidates_list in groups.items():
    for cand in group_candidates_list:
        owners = label_dataset_mapping[cand]
        if group_name not in owners:
            owners.append(group_name)
print(len(label_dataset_mapping))

312


In [17]:
# Load the large English pipeline with the "ner" and "tok2vec" components disabled.
nlp = spacy.load("../input/en-core-web-lg/en_core_web_lg-3.0.0/en_core_web_lg/en_core_web_lg-3.0.0", disable=["ner", "tok2vec"])
terms = set(label_dataset_mapping.keys())
# Ignore terms of two characters or fewer.
terms = {term for term in terms if len(term) > 2}
# "Long" terms: more than 6 characters and more than 2 words; lower-cased,
# de-duplicated and ordered longest first.
terms_long = sorted(list({term.lower() for term in terms
                     if len(term) > 6 and len(term.split()) > 2}), reverse=True, key=lambda x: len(x))
# "Short" terms: everything else, original casing kept, longest first.
terms_short = sorted([term for term in terms if len(term) <= 6 or len(term.split()) <= 2], reverse=True, key=lambda x: len(x))
# Only the tokenizer is needed to turn terms into matcher patterns.
patterns_long = list(nlp.tokenizer.pipe(terms_long))
patterns_short = list(nlp.tokenizer.pipe(terms_short))
# Long terms are matched on the LOWER token attribute; short terms use the
# matcher's default attribute.
matcher_long = PhraseMatcher(nlp.vocab, attr="LOWER")
matcher_long.add(f"TerminologyListDataLong", patterns_long)
matcher_short = PhraseMatcher(nlp.vocab)
matcher_short.add(f"TerminologyListDataShort", patterns_short)

In [18]:
def is_intersecting(prev_spans, s, e):
    """Tell whether the half-open index range ``[s, e)`` overlaps any earlier span.

    Parameters:
        prev_spans: iterable of spans; the first two items of each span are its
            start (inclusive) and end (exclusive) index.
        s (int): start of the new range (inclusive).
        e (int): end of the new range (exclusive).

    Returns:
        bool: True if at least one index belongs to both the new range and
        some span in ``prev_spans``. Empty ranges never overlap anything.
    """
    # Two half-open integer ranges share an index exactly when the larger start
    # lies before the smaller end; no need to materialise the index sets.
    return any(max(span[0], s) < min(span[1], e) for span in prev_spans)

def merge_spans(spans):
    """Collapse matcher spans into a list of non-overlapping spans.

    Consecutive spans that share the same start are merged into one span that
    runs to the largest end among them. A merged span is kept only if it does
    not overlap a span that was kept before it.

    Parameters:
        spans: sequence of ``(start, end)`` pairs; spans with equal starts are
            expected to be adjacent in the sequence.

    Returns:
        list[tuple]: the kept ``(start, end)`` pairs, in order of appearance.
        An empty input yields an empty list.
    """
    # Guard: without it the unpacking of spans[0] below raises IndexError.
    if not spans:
        return []
    new_spans = []
    start, end = spans[0]
    for span in spans:
        cur_start, cur_end = span
        if cur_start != start:
            # A new start closes the current run: keep it unless it overlaps an earlier one.
            if not is_intersecting(new_spans, start, end):
                new_spans.append((start, end))
            start, end = cur_start, cur_end
        else:
            # Same start: extend the run to the longest match.
            end = max(end, cur_end)
    # Flush the final run.
    if not is_intersecting(new_spans, start, end):
        new_spans.append((start, end))
    return new_spans

# For every document, run both phrase matchers over the sentences that mention
# "data" and record, per matching sentence, a (text, {"entities", "spans"}) pair:
# "spans" holds token index ranges, "entities" holds character offsets.
ner_data = defaultdict(list)
for Id in tqdm(sample_sub["Id"].unique(), total=sample_sub["Id"].nunique()):
    try:
        # Only sentences containing "data" (case-insensitive) are searched.
        sentences = [sent for sent in id_sentences[Id][0] if "data" in sent.lower()]
        for doc in nlp.pipe(sentences):
            matches_long = matcher_long(doc)
            if matches_long:
                # Long-term matches are recorded first.
                merged_spans = merge_spans([(start, end) for _, start, end in matches_long])
                ner_data[Id].append(
                    (doc.text,
                     {
                         "entities": [],
                         "spans": []
                     })
                )
                for start, end in merged_spans:
                    ner_data[Id][-1][1]["spans"].append((start, end))
                    # Character offsets: start of the first token to end of the last token.
                    ner_data[Id][-1][1]["entities"].append((doc[start].idx, doc[end-1].idx+len(doc[end-1])))
                # The entry is already a tuple, so this conversion leaves it unchanged.
                ner_data[Id][-1] = tuple(ner_data[Id][-1])
                matches_short = matcher_short(doc)
                if matches_short:
                    merged_spans = merge_spans([(start, end) for _, start, end in matches_short])

                    # Short-term matches are added only where they do not overlap a recorded span.
                    for start, end in merged_spans:
                        if not is_intersecting(ner_data[Id][-1][1]["spans"], start, end):
                            ner_data[Id][-1][1]["spans"].append((start, end))
                            ner_data[Id][-1][1]["entities"].append((doc[start].idx, doc[end-1].idx+len(doc[end-1])))
                    ner_data[Id][-1] = tuple(ner_data[Id][-1])
            else:
                # No long-term match: the sentence is recorded only if a short term matches.
                matches_short = matcher_short(doc)
                if matches_short:
                    merged_spans = merge_spans([(start, end) for _, start, end in matches_short])
                    ner_data[Id].append(
                        (doc.text,
                         {
                             "entities": [],
                             "spans": []
                         })
                    )
                    for start, end in merged_spans:
                        ner_data[Id][-1][1]["spans"].append((start, end))
                        ner_data[Id][-1][1]["entities"].append((doc[start].idx, doc[end-1].idx+len(doc[end-1])))
                    ner_data[Id][-1] = tuple(ner_data[Id][-1])
    # NOTE(review): a bare except abandons the rest of this document on any error
    # (KeyboardInterrupt included) without reporting it.
    except:
        continue
                
# sort the entities and merge consecutive candidates
# Two entities separated by exactly one character are fused into one candidate when
# the second one maps to a single group; the fused text is added to that group.
for Id, items in ner_data.items():
    try:
        for j, item in enumerate(items):
            indices_to_remove = set()
            sentence = item[0]
            # Order entities by their starting character offset.
            item[1]["entities"] = sorted(item[1]["entities"], key=lambda x: x[0])
            for i in range(len(item[1]["entities"]) - 1):
                # Adjacent means: the next entity starts one character after this one ends.
                if item[1]["entities"][i][1] == item[1]["entities"][i+1][0] - 1:
                    ent1 = sentence[item[1]["entities"][i][0]:item[1]["entities"][i][1]]
                    ent2 = sentence[item[1]["entities"][i+1][0]:item[1]["entities"][i+1][1]]
                    # Two all-upper-case neighbours are left as separate entities.
                    if ent1 == ent1.upper() and ent2 == ent2.upper():
                        continue
                    # Fuse only when the second entity belongs to exactly one group.
                    if (clean_label(ent2) in label_dataset_mapping and
                        len(label_dataset_mapping[clean_label(ent2)]) == 1):
                        group = label_dataset_mapping[clean_label(ent2)][0]
                        # The second entity is widened to start where the first one starts.
                        item[1]["entities"][i+1] = (item[1]["entities"][i][0], item[1]["entities"][i+1][1])
                        new_candidate = sentence[item[1]["entities"][i+1][0]:item[1]["entities"][i+1][1]]
                        if new_candidate not in groups[group]:
                            groups[group].append(new_candidate)
                        # The first entity is now covered by the widened one.
                        indices_to_remove.add(i)
            if indices_to_remove:
                items[j][1]["entities"] = [ex for i, ex in enumerate(items[j][1]["entities"]) if i not in indices_to_remove]
    # NOTE(review): a bare except skips the remaining sentences of this document on any error.
    except:
        continue
        

  0%|          | 0/4 [00:00<?, ?it/s]

In [19]:
# Drop the matcher inputs and matchers built for the candidate search, then collect garbage.
del terms, terms_long, terms_short, patterns_long, patterns_short, matcher_long, matcher_short
gc.collect()

240

In [20]:
# Build the classifier input: one record per detected entity, consisting of the
# sentence with that entity replaced by the "@CAND@" placeholder and a context
# dict carrying the document Id and the original candidate text.
full_data = []
for Id, items in ner_data.items():
    try:
        for sentence, annotations in items:
            for ent_start, ent_end in annotations["entities"]:
                candidate = sentence[ent_start:ent_end]
                masked_sentence = sentence[:ent_start] + "@CAND@" + sentence[ent_end:]
                full_data.append((masked_sentence, {"Id": Id, "candidate": candidate}))
    # Best-effort: skip a malformed document, but let KeyboardInterrupt/SystemExit through.
    except Exception:
        continue

print(len(full_data))
# The span data is rebuilt later for the "<candidate> dataset" search.
del ner_data
gc.collect()

162


0

In [21]:
# Use the GPU when one is available, then load the trained pipeline from disk.
spacy.prefer_gpu()
nlp = spacy.load("../input/textcat-trf/model-best")

# Score every masked sentence; as_tuples=True carries the context dict through
# the pipeline next to each text.
predictions = []
for doc, annts in tqdm(nlp.pipe(full_data, as_tuples=True, batch_size=150), total=len(full_data)):
    try:
        # Keep the "positive" category score together with the Id and candidate.
        predictions.append((doc.text, {
            "Id": annts["Id"], "candidate": annts["candidate"], "score": doc.cats["positive"]
        }))
    # NOTE(review): a bare except drops the record silently on any error.
    except:
        continue

  0%|          | 0/162 [00:00<?, ?it/s]

In [22]:
# The masked sentences are no longer needed once predictions exist.
del full_data
gc.collect()

6525

In [23]:
# First search for "{candidate} dataset" occurrences and construct initial dataset groups

# Reload the large English pipeline with the "ner" and "tok2vec" components disabled.
nlp = spacy.load("../input/en-core-web-lg/en_core_web_lg-3.0.0/en_core_web_lg/en_core_web_lg-3.0.0", disable=["ner", "tok2vec"])
# Every known candidate followed by the word "dataset".
terms = set([f"{cand} dataset" for cand in label_dataset_mapping.keys()])

terms = sorted(list(terms), reverse=True, key=lambda x: len(x))
patterns = list(nlp.tokenizer.pipe(terms))
# Matching is done on the LOWER token attribute.
matcher = PhraseMatcher(nlp.vocab, attr="LOWER")
matcher.add(f"TerminologyListDataset", patterns)

# Same record layout as before: Id -> list of (sentence text, {"entities", "spans"}).
ner_data = defaultdict(list)
#for Id in tqdm(train_df["Id"].unique(), total=train_df["Id"].nunique()):
for Id in tqdm(sample_sub["Id"].unique(), total=sample_sub["Id"].nunique()):
    try:
        # Only sentences containing "data" (case-insensitive) are searched.
        sentences = [sent for sent in id_sentences[Id][0] if "data" in sent.lower()]
        for doc in nlp.pipe(sentences):
            matches = matcher(doc)
            if matches:
                merged_spans = merge_spans([(start, end) for _, start, end in matches])
                ner_data[Id].append(
                    (doc.text,
                     {
                         "entities": [],
                         "spans": []
                     })
                )
                for start, end in merged_spans:
                    ner_data[Id][-1][1]["spans"].append((start, end))
                    # Character offsets: start of the first token to end of the last token.
                    ner_data[Id][-1][1]["entities"].append((doc[start].idx, doc[end-1].idx+len(doc[end-1])))
                # The entry is already a tuple, so this conversion leaves it unchanged.
                ner_data[Id][-1] = tuple(ner_data[Id][-1])
    # NOTE(review): a bare except abandons the rest of this document on any error.
    except:
        continue



  0%|          | 0/4 [00:00<?, ?it/s]

In [24]:
# Drop the "<candidate> dataset" matcher and its inputs, then collect garbage.
del terms, patterns, matcher
gc.collect()

1490

In [25]:
class Dataset():
    """Bookkeeping for labels assigned to dataset groups, with their scores.

    Attributes set by ``__init__``:
        groups: mapping of group name -> iterable of label strings, as passed in.
        label_group_mapping: defaultdict mapping a cleaned label
            (``clean_label``) to the list of group names containing it.
        scores: list, initialised empty and not touched by this class.
        data_groups: mapping of group name -> ``{"Ids": set, "Id_scores":
            defaultdict(list), "labels": {label: {"Ids": [...], "scores": [...]}}}``,
            filled by ``add_label``.
        label_data_mapping: mapping of label (exactly as passed to ``add_label``)
            -> list of groups the label was added to.
        rejected: set of labels rejected explicitly or through ``remove_group``.
    """

    def __init__(self, groups):
        self.groups = groups
        self.label_group_mapping = self._label_group_mapping()
        self.scores = []
        self.data_groups = {}
        self.label_data_mapping = {}
        self.rejected = set()

    def _label_group_mapping(self):
        """Build cleaned label -> list of group names (each group listed once)."""
        label_group_mapping = defaultdict(list)
        for k, v in self.groups.items():
            for label in v:
                cleaned_label = clean_label(label)
                if k not in label_group_mapping[cleaned_label]:
                    label_group_mapping[cleaned_label].append(k)
        return label_group_mapping

    def _group_scores(self, group):
        """Return all scores recorded for ``group``, across every label, as one list."""
        scores = []
        for entry in self.data_groups[group]["labels"].values():
            scores.extend(entry["scores"])
        return scores

    def label_count(self, group):
        """Return the length of the ``label_group_mapping`` entry stored under ``group``.

        Unknown keys give 0. ``.get`` is used so that merely asking does not add
        an empty entry to the defaultdict.
        """
        return len(self.label_group_mapping.get(group, ()))

    def add_label(self, group, label, Id, score):
        """Record that ``label`` was seen for ``group`` in document ``Id`` with ``score``.

        Creates the group and label entries on first use, and appends to the
        per-document and per-label score lists on every call.
        """
        if group not in self.data_groups:
            self.data_groups[group] = {"Ids": set(), "Id_scores": defaultdict(list), "labels": {}}
        group_entry = self.data_groups[group]
        group_entry["Ids"].add(Id)
        group_entry["Id_scores"][Id].append(score)
        if label not in group_entry["labels"]:
            group_entry["labels"][label] = {
                "Ids": [],
                "scores": []
            }
        group_entry["labels"][label]["Ids"].append(Id)
        group_entry["labels"][label]["scores"].append(score)
        if label not in self.label_data_mapping:
            self.label_data_mapping[label] = []
        if group not in self.label_data_mapping[label]:
            self.label_data_mapping[label].append(group)

    def get_is_dataset_freq(self, group, threshold=0.5):
        """Return the fraction of the group's documents whose best score exceeds ``threshold``."""
        id_scores = self.data_groups[group]["Id_scores"]
        n_dataset = len([Id for Id, scores in id_scores.items()
                         if np.max(scores) > threshold])
        return n_dataset / len(id_scores)

    def mean_group_score(self, group):
        """Return the mean of all scores recorded for ``group`` (0.0 if there are none)."""
        scores = self._group_scores(group)
        if not scores:
            return 0.0
        return np.mean(scores)

    def max_group_score(self, group):
        """Return the highest score recorded for ``group`` (0.0 if there are none)."""
        scores = self._group_scores(group)
        if not scores:
            return 0.0
        return np.max(scores)

    def mean_label_score(self, group, label):
        """Return the mean score of ``label`` within ``group`` (0 if unknown or unscored)."""
        labels = self.data_groups[group]["labels"]
        if label not in labels or not labels[label]["scores"]:
            return 0
        return np.mean(labels[label]["scores"])

    def max_label_score(self, group, label):
        """Return the highest score of ``label`` within ``group`` (0 if unknown or unscored)."""
        labels = self.data_groups[group]["labels"]
        if label not in labels or not labels[label]["scores"]:
            return 0
        return np.max(labels[label]["scores"])

    def remove_group(self, group):
        """Delete ``group`` and mark every label it held as rejected.

        Labels left without any group are also removed from ``label_data_mapping``.
        """
        for label in self.data_groups[group]["labels"]:
            self.rejected.add(label)
            self.label_data_mapping[label].remove(group)
            if len(self.label_data_mapping[label]) == 0:
                del self.label_data_mapping[label]
        del self.data_groups[group]

    def reject(self, label):
        """Add ``label`` to the rejected set."""
        self.rejected.add(label)

In [26]:
# Remove the sentences that contributed to the first step,
# i.e. those whose masked text contains "@cand@ dataset" (case-insensitive).
print(len(predictions))
indices_to_remove = {
    idx for idx, (masked_text, _) in enumerate(predictions)
    if "@cand@ dataset" in masked_text.lower()
}

predictions = [entry for idx, entry in enumerate(predictions) if idx not in indices_to_remove]
print(len(predictions))

162
156


In [27]:
# Seed the Dataset with the "<candidate> dataset" matches found above. Such matches
# get a fixed score of 2 (or 1 for the upper-case fallback) instead of a model score.
dataset = Dataset(groups)
for Id, items in ner_data.items():
    try:
        for item in items:
            sentence = item[0]
            for ent in item[1]["entities"]:
                candidate = sentence[ent[0]:ent[1]]
                # The candidate without its final word.
                candidate_base = " ".join(candidate.split()[:-1])
                cleaned_candidate = clean_label(candidate)
                cleaned_candidate_base = clean_label(candidate_base)
                if len(candidate_base.split()) > 1:
                    # Multi-word base: prefer the full candidate, then the base.
                    if cleaned_candidate in dataset.label_group_mapping:
                        # label_group_mapping is a defaultdict, so the key function below
                        # also inserts an empty entry for every group name it looks up.
                        cand_groups = sorted(dataset.label_group_mapping[cleaned_candidate],
                                     reverse=True, key=lambda x: len(dataset.label_group_mapping[x]))
                        group = cand_groups[0]
                        dataset.add_label(group, candidate, Id, 2)
                    elif cleaned_candidate_base in dataset.label_group_mapping:
                        cand_groups = sorted(dataset.label_group_mapping[cleaned_candidate_base],
                                     reverse=True, key=lambda x: len(dataset.label_group_mapping[x]))
                        group = cand_groups[0]
                        dataset.add_label(group, candidate_base, Id, 2)
                else:
                    # Single-word base: resolve it through this document's abbreviation definitions.
                    if Id in ids_abbreviations and candidate_base in ids_abbreviations[Id]:
                        for definition in ids_abbreviations[Id][candidate_base]:
                            if clean_label(definition) in dataset.label_group_mapping:    
                                cand_groups = sorted(dataset.label_group_mapping[clean_label(definition)],
                                             reverse=True, key=lambda x: len(dataset.label_group_mapping[x]))
                                group = cand_groups[0]
                                dataset.add_label(group, candidate_base, Id, 2)
                                # Only the first definition that maps to a group is used.
                                break
                    # Fallback: an all-upper-case base longer than 3 characters that maps to exactly one group.
                    elif (len(candidate_base) > 3 and candidate_base == candidate_base.upper() and
                          cleaned_candidate_base in dataset.label_group_mapping and
                          len(dataset.label_group_mapping[cleaned_candidate_base]) == 1):
                        group = dataset.label_group_mapping[cleaned_candidate_base][0]
                        dataset.add_label(group, candidate_base, Id, 1)
    # NOTE(review): a bare except skips the remaining sentences of this document on any error.
    except:
        continue
print(len(dataset.data_groups))
del ner_data
gc.collect()

3


0

In [28]:
# Remove titles containing lower case words
# A word counts against a group when it equals its own lower-case form, is not a
# connecting word, is not "dataset", and contains no digit. ("in" is listed twice;
# the duplicate has no effect on the membership test.)
connecting_words = ["of", "in", "on", "and", "for", "from", "the", "in", "a", "to", "after", "with", "at", "by", "&"]
# Iterate over a copy because remove_group deletes entries from data_groups.
for group in dataset.data_groups.copy():
    if any([w == w.lower() and w not in connecting_words  and w != "dataset" and not re.search(r"[0-9]", w) for w in group.split()]):
        dataset.remove_group(group)
print(len(dataset.data_groups))

3


In [29]:
candidate_scores = defaultdict(list)
databases = ["database", "data base", "dataset", "data set", "data", "data system"]

def get_thresholds(group, dataset):
    """Pick the ``[max_threshold, mean_threshold]`` pair used to filter a candidate.

    * no group                      -> ``[0.95, 0.4]``
    * title looks like a database   -> ``[0.05, 0.05]``
    * group already in the dataset  -> ``[0.85, 0.25]``
    * anything else                 -> ``[0.95, 0.4]``

    The title is the lower-cased group name, cut before its last "(" if it has
    one. It "looks like a database" when it ends with an entry of ``databases``
    or contains such an entry followed by " for" or " of".
    """
    strict = [0.95, 0.4]
    if not group:
        return strict

    cut = group.rfind("(")
    title = group.lower() if cut == -1 else group[:cut].strip().lower()

    names_a_database = title.endswith(tuple(databases))
    if not names_a_database:
        names_a_database = any(
            f"{database} {link}" in title
            for link in ("for", "of")
            for database in databases
        )
    if names_a_database:
        return [0.05, 0.05]

    return [0.85, 0.25] if group in dataset.data_groups else strict
    

# Step 1: collect all model scores per cleaned candidate.
for i, item in enumerate(predictions):
    try:
        sentence = item[0]
        candidate = item[1]["candidate"]
        clean_candidate = clean_label(candidate)
        candidate_scores[clean_candidate].append(item[1]["score"])
    except:
        continue
    
# Step 2: mark candidates whose best or mean score falls below the thresholds
# chosen for their group.
cands_to_remove = set()
for candidate, v in candidate_scores.items():
    try:
        clean_candidate = clean_label(candidate)
        # Only multi-word candidates are looked up in a group; the rest use the default thresholds.
        if len(candidate.split()) > 1:
            # label_group_mapping is a defaultdict: an unknown candidate yields an empty list here.
            cand_groups = sorted(dataset.label_group_mapping[clean_candidate],
                             reverse=True, key=lambda x: len(dataset.label_group_mapping[x]))
            if cand_groups:
                group = cand_groups[0]
            else:
                group = None
        else:
            group = None
        max_threshold, mean_threshold = get_thresholds(group, dataset)
        if np.max(v) < max_threshold or np.mean(v) < mean_threshold:
            cands_to_remove.add(candidate)
    except:
        continue
    
# Step 3: drop every prediction that belongs to a marked candidate.
indices_to_remove = set()
for i, item in enumerate(predictions):
    try:
        candidate = item[1]["candidate"]
        clean_candidate = clean_label(candidate)
        if clean_candidate in cands_to_remove:
            indices_to_remove.add(i)
    except:
        continue

predictions = [item for i, item in enumerate(predictions) if i not in indices_to_remove]
print(len(predictions))

78


In [30]:
# Assign the surviving multi-word candidates to a group and take them out of `predictions`.
indices_to_remove = set()

for i, item in enumerate(predictions):
    try:
        Id = item[1]["Id"]
        score = item[1]["score"]
        candidate = item[1]["candidate"]
        cleaned_candidate = clean_label(candidate)
        # Single-word candidates are handled in the next cell.
        if len(candidate.split()) == 1:
            continue
        cand_groups = sorted(dataset.label_group_mapping[cleaned_candidate],
                             reverse=True, key=lambda x: len(dataset.label_group_mapping[x]))
        # An empty list raises IndexError here, which the except below turns into
        # "leave this prediction in place".
        group = cand_groups[0]
        dataset.add_label(group, candidate, Id, score)
        indices_to_remove.add(i)
    except:
        continue
predictions = [item for i, item in enumerate(predictions) if i not in indices_to_remove]
print(len(predictions))
print(len(dataset.data_groups))

67
6


In [31]:
def find_matching_label(label, abbreviation, dataset):
    """Find a recorded label that ``label`` starts with.

    Example: label="National Education Longitudinal Study of 1988"
             matching_label = "National Education Longitudinal Study"

    Only labels stored in the groups that ``abbreviation`` was added to are
    considered. A label qualifies when it is not a single word, the cleaned
    ``label`` starts with its cleaned form, and it has at least as many words
    as ``abbreviation`` has characters.

    Returns the longest qualifying label, or None when ``abbreviation`` is
    unknown or nothing qualifies.
    """
    target = clean_label(label)
    if abbreviation not in dataset.label_data_mapping:
        return None

    min_words = len(abbreviation)
    found = set()
    for group in dataset.label_data_mapping[abbreviation]:
        for cand in dataset.data_groups[group]["labels"]:
            n_words = len(cand.split())
            if n_words != 1 and target.startswith(clean_label(cand)) and n_words >= min_words:
                found.add(cand)

    if not found:
        return None
    ranked = sorted(list(found), reverse=True, key=lambda x: len(x))
    return ranked[0]

# Pass 1: single-word candidates that are abbreviations in their document and whose
# definition is a known label are attached to the definition's group.
indices_to_remove = set()

for i, item in enumerate(predictions):
    try:
        Id = item[1]["Id"]
        score = item[1]["score"]
        candidate = item[1]["candidate"]
        cleaned_candidate = clean_label(candidate)
        if len(candidate.split()) != 1:
            continue
        if candidate not in ids_abbreviations[Id]:
            continue
        for definition in ids_abbreviations[Id][candidate]:
            if clean_label(definition) in dataset.label_group_mapping:    
                cand_groups = sorted(dataset.label_group_mapping[clean_label(definition)],
                                     reverse=True, key=lambda x: len(dataset.label_group_mapping[x]))
                group = cand_groups[0]
                # The label is recorded only for groups that already hold data ...
                if group in dataset.data_groups:
                    dataset.add_label(group, candidate, Id, score)
                # ... but the prediction is consumed either way.
                indices_to_remove.add(i)
                break
    except:
        continue
predictions = [item for i, item in enumerate(predictions) if i not in indices_to_remove]
print(len(predictions))
print(len(dataset.data_groups))

# Pass 2: for definitions that are not known labels themselves, look for a recorded
# label that the definition starts with (find_matching_label) and use its group.
indices_to_remove = set()
for i, item in enumerate(predictions):
    try:
        Id = item[1]["Id"]
        score = item[1]["score"]
        candidate = item[1]["candidate"]
        cleaned_candidate = clean_label(candidate)
        if len(candidate.split()) != 1:
            continue
        if candidate not in ids_abbreviations[Id]:
            continue
        for definition in ids_abbreviations[Id][candidate]:
            if clean_label(definition) not in dataset.label_group_mapping:    
                matched_definition = find_matching_label(definition, candidate, dataset)
                if matched_definition and clean_label(matched_definition) in dataset.label_group_mapping:
                    cand_groups = sorted(dataset.label_group_mapping[clean_label(matched_definition)],
                                     reverse=True, key=lambda x: len(dataset.label_group_mapping[x]))
                    group = cand_groups[0]
                    if group in dataset.data_groups:
                        dataset.add_label(group, candidate, Id, score)
                    indices_to_remove.add(i)
                    break
    except:
        continue
        
predictions = [item for i, item in enumerate(predictions) if i not in indices_to_remove]
print(len(predictions))
print(len(dataset.data_groups))

0
6
0
6


In [32]:
# Final pass: attach each remaining prediction to the first group already recorded
# for its candidate, and discard the prediction whether or not a group was found.
indices_to_remove = set()

for i, item in enumerate(predictions):
    try:
        Id = item[1]["Id"]
        score = item[1]["score"]
        candidate = item[1]["candidate"]
        cleaned_candidate = clean_label(candidate)
        # NOTE(review): label_data_mapping is keyed by the label string given to
        # add_label, which the cells above pass uncleaned, while this lookup uses
        # the cleaned candidate - confirm that the mismatch is intended.
        if cleaned_candidate in dataset.label_data_mapping:
            group = dataset.label_data_mapping[cleaned_candidate][0]
            dataset.add_label(group, candidate, Id, score)
            indices_to_remove.add(i)
        else:
            indices_to_remove.add(i)
    except:
        continue
        
predictions = [item for i, item in enumerate(predictions) if i not in indices_to_remove]
print(len(predictions))
print(len(dataset.data_groups))

0
6


In [33]:
def is_group_unclear(title, cands):
    """Decide whether a group title reads like a study/survey/data collection name.

    Parameters:
        title (str): group title; compared case-insensitively.
        cands: iterable of candidate strings belonging to the group.

    Returns:
        bool:
          * False when the title has fewer than three words.
          * Otherwise, for the first keyword that the title ends with, or that
            occurs in it followed by " of ", " on " or " for ": False if the
            title also contains an "institute/association of|for|on" phrase,
            else True.
          * Otherwise True if any candidate ends with a space plus a keyword.
          * False in all remaining cases.
    """
    if len(title.split()) < 3:
        return False
    keywords = ["survey", "study", "initiative", "program", "programme",
                "assessment", "database", "data base", "data set", "dataset", "data"]
    lowered = title.lower()
    org_phrases = ("institute of", "association of",
                   "institute for", "association for",
                   "institute on", "association on")
    for key in keywords:
        if (lowered.endswith(key) or
                f"{key} of " in lowered or f"{key} on " in lowered or f"{key} for " in lowered):
            # A keyword hit inside an organisation-style title does not count.
            return not any(phrase in lowered for phrase in org_phrases)
    # The suffix must interpolate the keyword; the former literal " key" only
    # matched candidates that end with the word "key".
    suffixes = tuple(f" {key}" for key in keywords)
    return any(cand.lower().endswith(suffixes) for cand in cands)

def is_org(candidate):
    """Heuristically decide whether ``candidate`` names an organisation.

    Args:
        candidate: Title string. Keyword matching is case-sensitive against
            lower-case keywords; the visible caller passes a lower-cased title.

    Returns:
        ``False`` for single-word candidates and for candidates containing
        "data" that do not end with "center". Otherwise ``True`` when the
        candidate ends with an organisation keyword, or contains
        ``"<keyword> of"``, ``"<keyword> on"`` or ``"<keyword> for"``, and
        ``is_group_unclear`` does not consider it ambiguous. ``False`` in all
        other cases.
    """
    if len(candidate.split()) == 1:
        return False
    if "data" in candidate and not candidate.endswith("center"):
        return False
    # "division" was previously misspelled "divisiion" and so never matched;
    # the duplicated "institute" entry is dropped.
    keywords = ("institute", "center", "foundation", "organisation", "administration",
                "organizations", "alliance", "clinics", "institut", "institutes", "society", "centers",
                "unit", "collaboration",
                "bureau", "university", "service", "department", "division", "agency",
                "office", "library", "organization", "board", "council", "union", "college",
                "committee", "consortium", "association", "clinic", "hospital", "laboratory",
                "centre", "ministry", "panel", "school", "schools", "facility", "commission", "league",
                "taskforce", "register", "insurance")

    looks_like_org = candidate.endswith(keywords) or any(
        f"{keyword} {preposition}" in candidate
        for preposition in ("of", "on", "for")
        for keyword in keywords
    )
    if not looks_like_org:
        return False
    # The ambiguity check gives the same answer for every pattern, so it is
    # evaluated once.
    lowered = candidate.lower()
    return not is_group_unclear(lowered, [lowered])

In [34]:
# Remove every group whose title (text before the last "(", lower-cased) is
# classified as an organisation by is_org.
for group in dataset.data_groups.copy():
    try:
        paren_at = group.rfind("(")
        title = group.lower() if paren_at == -1 else group[:paren_at].strip().lower()
        if is_org(title):
            dataset.remove_group(group)
    except:
        continue
print(len(dataset.data_groups))

5


In [35]:
# Words that mark a title token as data-related.
data_keywords = ["survey", "study", "initiative", "program", "programme",
                 "assessment", "database", "data base", "data set", "dataset", "data"]

# Titles with fewer than three words are removed when any of these holds:
#   * at least one word is entirely lower-case,
#   * more than one word is a data keyword,
#   * the title is all upper-case, or the abbreviation has fewer than 3 chars.
for group in dataset.data_groups.copy():
    try:
        paren_at = group.rfind("(")
        if paren_at == -1:
            title, abb = group, None
        else:
            title = group[:paren_at].strip()
            abb = group[paren_at + 1:-1].strip()
        words = title.split()
        if len(words) < 3:
            has_lowercase_word = any(w == w.lower() for w in words)
            keyword_hits = sum(1 for w in title.lower().split() if w in data_keywords)
            upper_or_tiny_abb = title == title.upper() or (abb and len(abb) < 3)
            if has_lowercase_word or keyword_hits > 1 or upper_or_tiny_abb:
                dataset.remove_group(group)
    except:
        continue
print(len(dataset.data_groups))

5


In [36]:
# A title containing any of these substrings is removed.
nondata_keywords = ["method", "example", "resonance", "tool", "agreement", "procedure", "builder"]
# A title ending with any of these is removed.
reject_keywords  = ["test", "sample", "cohort", "supplement",
                    "act", "file", "index", "trial", "protocol", "instrument", "form",
                    "conference", "infrastructure", "trials"]

for group in dataset.data_groups.copy():
    try:
        paren_at = group.rfind("(")
        if paren_at == -1:
            title, abb = group.lower(), None
        else:
            title = group[:paren_at].strip().lower()
            abb = group[paren_at + 1:-1].strip()
        mentions_nondata = any(keyword in title for keyword in nondata_keywords)
        if mentions_nondata or title.endswith(tuple(reject_keywords)):
            dataset.remove_group(group)
    except:
        continue
print(len(dataset.data_groups))

5


In [37]:
# Remove groups in which every label is shorter than 10 characters
# (a group with no labels at all is removed as well).
for group in dataset.data_groups.copy():
    try:
        group_labels = dataset.data_groups[group]["labels"]
        longest = max((len(cand) for cand in group_labels), default=0)
        if longest < 10:
            dataset.remove_group(group)
    except:
        continue
print(len(dataset.data_groups))

3


In [38]:
# Compare every pair of groups, the one with more labels first. A later group
# is flagged for removal when one of its labels extends a label of the earlier
# group and the abbreviations (text inside the last parentheses of the group
# name) satisfy the checks below.
groups_to_remove = set()
all_group_titles = sorted(dataset.data_groups, reverse=True,
                          key=lambda x: len(dataset.data_groups[x]["labels"]))
for i in range(len(all_group_titles)-1):
    try:
        group_1 = all_group_titles[i]
        if "(" in group_1:
            abb1 = group_1[group_1.rfind("(")+1:-1].strip()
        else:
            abb1 = None
        labels_1 = dataset.data_groups[group_1]["labels"]
        for j in range(i+1, len(all_group_titles)):
            group_2 = all_group_titles[j]
            for label in dataset.data_groups[group_2]["labels"]:
                # Length of what is left of `label` after removing `cand`, for
                # each label of group_1 that `label` starts with. str.replace
                # removes every occurrence of `cand`, not only the prefix.
                leftover_lengths = [len(label.replace(cand, ""))
                                    for cand in labels_1 if label.startswith(cand)]
                if any(2 < n < 10 for n in leftover_lengths):
                    if "(" in group_2:
                        abb2 = group_2[group_2.rfind("(")+1:-1].strip()
                        if len(abb2.split()) > 1 or re.search(r"[^A-Za-z]", abb2):
                            groups_to_remove.add(group_2)
                        # NOTE(review): this branch cannot fire as written: it
                        # needs a non-letter in abb2, which the `if` above has
                        # already ruled out. It previously added the undefined
                        # name `title2`; confirm the intended condition.
                        elif abb1 and not abb2.startswith(abb1) and re.search(r"[^A-Za-z]", abb2):
                            groups_to_remove.add(group_2)
                elif any(n > 2 for n in leftover_lengths):
                    if "(" in group_2:
                        abb2 = group_2[group_2.rfind("(")+1:-1].strip()
                        if abb1 and abb2.startswith(abb1) and re.search(r"[^A-Za-z]", abb1):
                            groups_to_remove.add(group_2)
    except Exception:
        # Best effort: an unexpected group shape skips the rest of this
        # outer iteration, while KeyboardInterrupt/SystemExit still propagate.
        continue


for group in groups_to_remove:
    try:
        dataset.remove_group(group)
    except Exception:
        continue

print(len(dataset.data_groups))

3


In [39]:
def is_data_like(title, cands):
    """Return True when the title or any candidate label looks like a dataset name.

    The title matches when it contains one of the data keywords anywhere
    (case-insensitive substring test). A candidate matches when its
    lower-cased form ends with one of the keywords.
    """
    keywords = ("survey", "study", "initiative", "program", "programme", "inventory",
                "assessment", "model", "network", "sequence", "practice", "project", "datum",
                "database", "data base", "data set", "list", "archive", "interpolation", "atlas",
                "surveys", "studies", "dataset", "data", "model", "registry", "census", "encyclopedia")
    lowered_title = title.lower()
    for key in keywords:
        if key in lowered_title:
            return True
    # str.endswith accepts a tuple and is true if any suffix matches.
    return any(cand.lower().endswith(keywords) for cand in cands)

In [40]:
# For groups that is_data_like does not recognise, remove the group when any
# of these holds:
#   * its abbreviation contains a character outside [A-Za-z0-9&],
#   * some whitespace-separated token of the group name is entirely lower-case,
#     is not in connecting_words and contains no digit,
#   * the group has fewer than 5 Ids,
#   * its title ends with "system".
for group in dataset.data_groups.copy():
    try:
        if "(" in group:
            title = group[:group.rfind("(")].strip().lower()
            abb = group[group.rfind("(")+1:-1].strip()
        else:
            title = group.lower()
            # Reset explicitly: otherwise a group without parentheses reuses
            # the abbreviation left over from an earlier iteration or cell.
            abb = None
        if not is_data_like(title, list(dataset.data_groups[group]["labels"].keys())):
            if abb and re.search(r"[^A-Z&a-z0-9]", abb):
                dataset.remove_group(group)
            elif any(w == w.lower() and w not in connecting_words and not re.search(r"[0-9]", w)
                     for w in group.split()):
                dataset.remove_group(group)
            elif len(dataset.data_groups[group]["Ids"]) < 5:
                dataset.remove_group(group)
            elif title.endswith("system"):
                dataset.remove_group(group)
    except Exception:
        # Best effort: skip groups with an unexpected shape, but let
        # KeyboardInterrupt/SystemExit propagate.
        continue
print(len(dataset.data_groups))

3


In [41]:
# Snapshot: group name -> list of its distinct label strings.
data_groups = {}
for group in dataset.data_groups.copy():
    try:
        unique_labels = set(dataset.data_groups[group]["labels"].keys())
        data_groups[group] = list(unique_labels)
    except:
        continue

In [42]:
# group name -> union of the "Ids" recorded under each of the group's labels.
group_ids = defaultdict(set)
for group in dataset.data_groups:
    try:
        label_records = dataset.data_groups[group]["labels"]
        for label in label_records:
            ids_for_label = label_records[label]["Ids"]
            group_ids[group].update(ids_for_label)
    except:
        continue

print(len(group_ids))

3


In [43]:
%%time
# Enable the tqdm-backed `progress_apply` method on pandas objects.
tqdm.pandas()
# Store clean_text(text) for every row in a new 'clean_text' column,
# showing a progress bar while it runs.
sample_sub['clean_text'] = sample_sub['text'].progress_apply(clean_text)

  0%|          | 0/4 [00:00<?, ?it/s]

CPU times: user 56 ms, sys: 1.27 ms, total: 57.2 ms
Wall time: 55.9 ms


In [44]:
# For every group and cleaned label, count in how many documents the label
# occurs as a space-delimited substring of the cleaned text:
#   group_label_counter_all   - all documents containing the label
#   group_label_counter_found - only documents whose Id is in group_ids[group]
# Only the first clean_text entry of each Id is inspected.
id_clean_text = sample_sub.groupby("Id")["clean_text"].agg(list)
group_label_counter_found = {}
group_label_counter_all = {}
for group in tqdm(dataset.data_groups, total=len(dataset.data_groups)):
    try:
        found_counter = group_label_counter_found[group] = Counter()
        all_counter = group_label_counter_all[group] = Counter()
        for row_id in sample_sub["Id"].unique():
            cleaned_text = id_clean_text[row_id][0]
            cleaned_label_set = {clean_label(l) for l in dataset.data_groups[group]["labels"]}
            for label in cleaned_label_set:
                if f" {label.strip()} " not in cleaned_text:
                    continue
                if row_id in group_ids[group]:
                    found_counter[label] += 1
                all_counter[label] += 1
    except:
        continue

  0%|          | 0/3 [00:00<?, ?it/s]

In [45]:
# Prune labels whose hits mostly fall outside the group's own documents.
# A label is dropped when its "found" count is below a fraction of its "all"
# count: 1/4 for labels shorter than 8 characters, 1/20 otherwise. Labels with
# no entry in the "found" counter are not examined. After a drop, the group is
# removed if it has no labels left or only labels shorter than 6 characters.
# NOTE(review): the label loop keeps running after remove_group; any lookup
# failure that follows is swallowed by the outer except - confirm intended.
for group in dataset.data_groups.copy():
    try:
        for label, count_all in group_label_counter_all[group].items():
            has_found_count = (group in group_label_counter_found
                               and label in group_label_counter_found[group])
            if not has_found_count:
                continue
            count_found = group_label_counter_found[group][label]
            divisor = 4 if len(label) < 8 else 20
            if not count_found < count_all / divisor:
                continue

            # Delete every raw label that cleans to this label.
            for cand in dataset.data_groups[group]["labels"].copy():
                if clean_label(cand) != label:
                    continue
                try:
                    del dataset.data_groups[group]["labels"][cand]
                except KeyError:
                    continue

            survivors = dataset.data_groups[group]["labels"]
            if len(survivors) == 0 or all(len(cand) < 6 for cand in survivors):
                dataset.remove_group(group)
    except:
        continue

In [46]:
# For every group, credit each document to at most one label: the longest
# cleaned label that occurs as a space-delimited substring of the document's
# first clean_text entry.
id_clean_text = sample_sub.groupby("Id")["clean_text"].agg(list)
group_label_single_counter = {}
for group in tqdm(dataset.data_groups, total=len(dataset.data_groups)):
    try:
        single_counter = group_label_single_counter[group] = Counter()
        for row_id in sample_sub["Id"].unique():
            cleaned_text = id_clean_text[row_id][0]
            longest_first = sorted({clean_label(l) for l in dataset.data_groups[group]["labels"]},
                                   key=len, reverse=True)
            first_hit = next(
                (label for label in longest_first if f" {label.strip()} " in cleaned_text),
                None,
            )
            if first_hit is not None:
                single_counter[first_hit] += 1
    except:
        continue

  0%|          | 0/3 [00:00<?, ?it/s]

In [47]:
# Turn the single-hit counts into per-group relative frequencies and pick the
# short labels (4 to 7 characters) accepted as abbreviations:
#   * 4 characters:      0.15 < frequency < 0.25
#   * 5 to 7 characters: 0.01 < frequency < 0.35
included_abbreviations = set()
group_label_frequencies = {}
for group in group_label_single_counter:
    try:
        frequencies = group_label_frequencies[group] = {}
        total_hits = sum(group_label_single_counter[group].values())
        if total_hits == 0:
            continue
        for label in group_label_single_counter[group]:
            frequency = group_label_single_counter[group][label] / total_hits
            frequencies[label] = frequency
            if not 3 < len(label) < 8:
                continue
            lower, upper = (0.15, 0.25) if len(label) < 5 else (0.01, 0.35)
            if lower < frequency < upper:
                included_abbreviations.add(label)
    except:
        continue

In [48]:
# Build, per group, the labels that will be searched for in the documents:
#   "abbreviation" - a short label accepted in included_abbreviations (or None)
#   "most_common"  - the long label with the highest count in group_label_counter_all
#   "remaining"    - the other long labels, in descending count order
final_group_labels = {}
for group in dataset.data_groups:
    try:
        # Groups without an entry in group_label_counter_all are skipped entirely.
        if group not in group_label_counter_all:
            continue
        final_group_labels[group] = {"abbreviation": None, "most_common": None, "remaining": []}
        # Keep cleaned labels present in both counters, with relative frequency
        # above 0.005 and credited to more than one document.
        all_labels = [clean_label(label) for label in dataset.data_groups[group]["labels"]
                      if (clean_label(label) in group_label_counter_all[group] and
                          clean_label(label) in group_label_single_counter[group] and
                          group_label_frequencies[group][clean_label(label)] > 0.005 and
                          group_label_single_counter[group][clean_label(label)] > 1)]
        # Labels shorter than 8 characters never stay in all_labels; an accepted
        # one becomes the abbreviation (a later match overwrites an earlier one).
        for label in all_labels.copy():
            if len(label) < 8 and label in included_abbreviations:
                final_group_labels[group]["abbreviation"] = label
                all_labels.remove(label)
            elif len(label) < 8:
                all_labels.remove(label)

        try:
            # De-duplicate and order by total document count, highest first.
            all_labels = sorted(set(all_labels), reverse=True, key=lambda x: group_label_counter_all[group][x])
            # Raises IndexError when no long label survived; the group then keeps
            # its None / [] defaults via the except below.
            final_group_labels[group]["most_common"] = all_labels[0]
            final_group_labels[group]["remaining"] = all_labels[1:]
            # Unless exactly one label remains, drop remaining labels that are
            # the most common label extended by exactly one extra word.
            if len(final_group_labels[group]["remaining"]) != 1:
                for label in final_group_labels[group]["remaining"].copy():
                    if (label.startswith(final_group_labels[group]["most_common"]) and
                        len(label.split()) == len(final_group_labels[group]["most_common"].split()) + 1):
                        final_group_labels[group]["remaining"].remove(label)
        except:
            continue
    except:
        continue

In [49]:
# Build the per-document predictions. For every group, a document receives
# the group's abbreviation and most common label if each occurs as a
# space-delimited substring of the cleaned text, plus the first matching
# "remaining" label.
id_list = []
lables_list = []
for index, row in tqdm(sample_sub.iterrows(), total=len(sample_sub)):
    sample_clean_text = row['clean_text']
    row_id = row['Id']
    cleaned_labels = []
    for group in final_group_labels:
        try:
            abbreviation = final_group_labels[group]["abbreviation"]
            most_common = final_group_labels[group]["most_common"]
            remaining_labels = final_group_labels[group]["remaining"]
            if abbreviation and f" {clean_text(abbreviation)} " in sample_clean_text:
                cleaned_labels.append(abbreviation)
            if most_common and f" {clean_text(most_common)} " in sample_clean_text:
                cleaned_labels.append(most_common)
            for remaining_label in remaining_labels:
                if f" {clean_text(remaining_label)} " in sample_clean_text:
                    cleaned_labels.append(remaining_label)
                    # Only the first matching remaining label is reported.
                    break
        except Exception:
            # Best effort per group; KeyboardInterrupt/SystemExit still propagate.
            continue
    cleaned_labels = {clean_label(x) for x in cleaned_labels}
    # Sort before joining: iteration order of a set of strings is not stable
    # across interpreter runs, which made the submission text nondeterministic.
    lables_list.append('|'.join(sorted(cleaned_labels)))
    id_list.append(row_id)

0it [00:00, ?it/s]

In [50]:
# Assemble the submission table: one row per document Id with its
# pipe-joined predicted labels.
submission = pd.DataFrame({
    'Id': id_list,
    'PredictionString': lables_list,
})

In [51]:
# Write the submission without the row index. `index` is documented as a
# bool, so pass False explicitly instead of relying on None being falsy.
submission.to_csv("submission.csv", index=False)

In [52]:
# Preview the first ten prediction strings.
lables_list[:10]

['', '', '', '']