# Brief Solution Description

The solution is based on a simple heuristic: a capitalized sequence of words that includes a keyword and is followed by a parenthesized abbreviation usually refers to a dataset. So, any sequence like 

``` Xxx Xxx Keyword Xxx (XXX)```

is a good candidate for a dataset name.
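
As a rough illustration, the pattern can be approximated with a single regular expression. This is a simplified sketch, not the token-based extractor used later in this notebook: `CANDIDATE_PAT` and `find_candidates` are hypothetical names, only all-caps abbreviations are handled, and the connector list is shortened.

```python
import re

KEYWORDS = ("Study", "Survey", "Assessment", "Initiative", "Data", "Dataset", "Database")

# One or more capitalized words (each optionally followed by short lowercase
# connectors), then an all-caps abbreviation in parentheses.
CANDIDATE_PAT = re.compile(
    r"((?:[A-Z][\w']*(?: (?:of|the|for|in|and|on|to|with))* )+)\(([A-Z]{2,})\)"
)


def find_candidates(text):
    """Return 'Name (ABBR)' strings whose name contains one of the keywords."""
    found = []
    for match in CANDIDATE_PAT.finditer(text):
        name = match.group(1).strip()
        if any(kw in name.split() for kw in KEYWORDS):
            found.append(f"{name} ({match.group(2)})")
    return found


print(find_candidates("We used the National Education Longitudinal Study (NELS) for analysis."))
```

A capitalized phrase without a keyword, such as `The Big Red Car (BRC)`, matches the regular expression but is rejected by the keyword check.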

All mentions of this form are extracted into a list of dataset names to look for. Each text in the test set is checked for the dataset names from the list, and every match is added to the prediction. Predictions that are substrings of other predictions for the same document are then removed.
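
The substring-removal step can be sketched in isolation as follows. It mirrors the logic of `remove_substring_predictions` defined later in this notebook; the name `drop_substrings` and the sample strings are made up for illustration.

```python
from itertools import combinations


def drop_substrings(preds):
    """Drop every prediction that is contained in another prediction."""
    preds = set(preds)
    shorter = set()
    for a, b in combinations(preds, 2):
        if a in b:
            shorter.add(a)
        if b in a:
            shorter.add(b)
    return sorted(preds - shorter)


print(drop_substrings([
    "Education Longitudinal Study",
    "National Education Longitudinal Study",
    "NELS",
]))
```

Here the shorter `Education Longitudinal Study` is dropped because it is contained in the longer name, while `NELS` survives because the comparison is case-sensitive.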

Keywords list:

- Study
- Survey
- Assessment
- Initiative
- Data
- Dataset
- Database

Also, many extracted mentions refer to organizations or systems rather than datasets, so they are not valid dataset names. To remove them, any mention that contains one of the following stopwords (as a case-insensitive substring) is dropped:

``` 
' lab', 'centre', 'center', 'consortium', 'office', 'agency', 'administration', 'clearinghouse',
'corps', 'organization', 'organisation', 'association', 'university', 'department',
'institute', 'foundation', 'service', 'bureau', 'company', 'test', 'tool', 'board', 'scale',
'framework', 'committee', 'system', 'group', 'rating', 'manual', 'division', 'supplement',
'variables', 'documentation', 'format' 
```
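
A minimal sketch of how such a substring filter behaves, using a shortened stopword list. The function name `keep_dataset_like` and the sample names are assumptions made for this example.

```python
STOPWORDS = [" lab", "center", "system", "test"]


def keep_dataset_like(names, stopwords=STOPWORDS):
    """Keep names that contain none of the stopwords (case-insensitive substring match)."""
    return [n for n in names if not any(sw in n.lower() for sw in stopwords)]


names = [
    "National Center for Education Statistics (NCES)",
    "Early Childhood Longitudinal Study (ECLS)",
    "Collaborative Perinatal Study (CPS)",
    "Global Positioning System (GPS)",
]
print(keep_dataset_like(names))
```

The leading space in `' lab'` keeps words such as `Collaborative` from being treated as a match. Because matching is by substring, a short stopword like `test` also removes any name that merely contains those letters.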

To exclude mentions not related to data, a simple count statistic is used:

$$ F_d = \frac{N_{data}(str)}{N_{total}(str)}$$

where $N_{data}(str)$ is the number of sentences in which `str` occurs together with the word `data` (the parenthesized part is dropped) and $N_{total}(str)$ is the total number of sentences in which `str` is present. All mentions with $F_d \le 0.1$ are dropped, as are mentions whose $N_{total}(str)$ does not exceed 2.
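
The statistic can be illustrated on a handful of made-up sentences. This is a small sketch under the assumption that each text unit is one sentence; `data_frequency` and the dataset name are hypothetical.

```python
def data_frequency(name, sentences, keyword="data"):
    """Share of sentences mentioning `name` that also contain `keyword`."""
    total = sum(name in s for s in sentences)
    with_kw = sum(name in s and keyword in s.lower() for s in sentences)
    return with_kw / total if total else 0.0


sentences = [
    "Data from the Example Household Survey were used",
    "The Example Household Survey is described in Section 2",
    "We thank the Example Household Survey staff",
    "Example Household Survey data are publicly available",
]
f_d = data_frequency("Example Household Survey", sentences)
print(f_d)  # 2 of the 4 mentioning sentences contain "data"
```

With a ratio of 0.5 this mention would pass the 0.1 threshold.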

In [1]:
import re
from collections import defaultdict, Counter
from pathlib import Path
from functools import partial
import json
from itertools import chain, combinations
from typing import Callable, List, Union, Optional, Set, Dict

import pandas as pd
from tqdm import tqdm


In [2]:
# Raw strings avoid invalid-escape warnings for \w, \s and \(
TOKENIZE_PAT = re.compile(r"[\w']+|[^\w ]")
CAMEL_PAT = re.compile(r'(\b[A-Z]+[a-z]+[A-Z]\w+)')
BR_PAT = re.compile(r'\s?\((.*)\)')
PREPS = {'from', 'for', 'of', 'the', 'in', 'with', 'to', 'on', 'and'}

In [3]:
def tokenize(text):
    return TOKENIZE_PAT.findall(text)


def clean_text(txt):
    return re.sub('[^A-Za-z0-9]+', ' ', str(txt).lower()).strip()


## Filters

In [4]:
def remove_substring_predictions(preds):
    preds = set(preds)
    to_filter = set()
    for d1, d2 in combinations(preds, 2):
        if d1 in d2:
            to_filter.add(d1)
        if d2 in d1:
            to_filter.add(d2)
    return list(preds - to_filter)


def filter_stopwords(datasets, stopwords, do_lower=True):
    # Remove all instances that contain any stopword as a substring
    filtered_datasets = []
    if do_lower:
        stopwords = [sw.lower() for sw in stopwords]
    for ds in datasets:
        ds_to_analyze = ds.lower() if do_lower else ds
        if any(sw in ds_to_analyze for sw in stopwords):
            continue
        filtered_datasets.append(ds)
    return filtered_datasets


def extend_parentehis(datasets):
    # Return each dataset from datasets plus
    # the same dataset without its parenthesized part (if it has one)
    pat = re.compile(r'\(.*\)')
    extended_datasets = []
    for ds in datasets:
        ds_no_parenth = pat.sub('', ds).strip()
        if ds != ds_no_parenth:
            extended_datasets.append(ds_no_parenth)
        extended_datasets.append(ds)
    return extended_datasets


def filler_intro_words(datasets):
    # Strip a leading capitalized word followed by "the" or "to the"
    miss_intro_pat = re.compile(r"^[A-Z][a-z']+ (?:the|to the) ")
    return [miss_intro_pat.sub('', ds) for ds in datasets]


def filter_parial_match_datasets(datasets):
    # Some matches are truncated due to parsing errors 
    # or other factors. To remove those, we look for 
    # the most common form of the dataset and remove
    # mentions, that are substrings of this form.
    # Obviously, some true mentions might be dropped 
    # at this stage
    counter = Counter(datasets)

    abbrs_used = set()
    golden_ds_with_br = []

    for ds, count in counter.most_common():
        abbr = BR_PAT.findall(ds)[0]
        no_br_ds = BR_PAT.sub('', ds)

        if abbr not in abbrs_used:
            abbrs_used.add(abbr)
            golden_ds_with_br.append(ds)

    filtered_datasets = []
    for ds in datasets:
        if not any((ds in ds_) and (ds != ds_) for ds_ in golden_ds_with_br):
            filtered_datasets.append(ds)
    return filtered_datasets


def filter_br_less_than_two_words(datasets):
    # Keep only names with more than two tokens once the parenthesized part is removed
    filtered_datasets = []
    for ds in datasets:
        no_br_ds = BR_PAT.sub('', ds)
        if len(tokenize(no_br_ds)) > 2:
            filtered_datasets.append(ds)
    return filtered_datasets


def filter_intro_ssai(datasets):
    # Filtering introductory words marked as a part of the mention by mistake
    connection_words = {'of', 'the', 'with', 'for', 'in', 'to', 'on', 'and', 'up'}
    keywords = {'Program', 'Study', 'Survey', 'Assessment'}
    filtered_datasets = []
    for ds in datasets:
        toks_spans = list(TOKENIZE_PAT.finditer(ds))
        toks = [t.group() for t in toks_spans]
        start = 0
        if len(toks) > 3:
            if toks[1] == 'the':
                start = toks_spans[2].span()[0]
            elif toks[0] not in keywords and  toks[1] in connection_words and len(toks) > 2 and toks[2] in connection_words:
                start = toks_spans[3].span()[0]
            elif toks[0].endswith('ing') and toks[1] in connection_words:
                if toks[2] not in connection_words:
                    start_tok = 2
                else:
                    start_tok = 3
                start = toks_spans[start_tok].span()[0]
            filtered_datasets.append(ds[start:])
        else:
            filtered_datasets.append(ds)
    return filtered_datasets


def get_index(texts: List[str], words: List[str]) -> Dict[str, Set[int]]:
    # Returns a dictionary where words are keys and values are indices 
    # of documents (sentences) in texts, in which the word present
    index = defaultdict(set)
    words = set(words)
    words = {w for w in words if w.lower() not in PREPS and re.sub('\'', '', w).isalnum()}
    for n, text in tqdm(enumerate(texts), total=len(texts)):
        tokens = tokenize(text)
        for tok in tokens:
            if tok in words:
                index[tok].add(n)
    return index


def get_train_predictions_counts_data(datasets, index, kw):
    # Returns N_data and N_total counts dictionary 
    # (check the formulas in the first cell)
    pred_count = Counter()
    data_count = Counter()
    if isinstance(kw, str):
        kw = [kw]
    
    for ds in tqdm(datasets):
        first_tok, *toks = tokenize(ds)
        to_search = None
        for tok in [first_tok] + toks:
            if index.get(tok):
                if to_search is None:
                    to_search = set(index[tok])
                else:
                    to_search &= index[tok]
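        # to_search stays None when no token of ds is in the index
        for doc_idx in to_search or ():
            text = texts[doc_idx]  # `texts` is the global list of sentences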
            if ds in text:
                pred_count[ds] += 1
                data_count[ds] += int(any(w in text.lower() for w in kw))
    return pred_count, data_count


def filter_by_train_counts(datasets, index, kw, min_train_count, rel_freq_threshold):
    # Filter by relative frequency (no parenthesis)
    # (check the formula in the first cell)
    tr_counts, data_counts = get_train_predictions_counts_data(extend_parentehis(set(datasets)), index, kw)
    stats = []

    for ds, count in Counter(datasets).most_common():
        stats.append([ds, count, tr_counts[ds], tr_counts[re.sub(r'[\s]?\(.*\)', '', ds)],
                      data_counts[ds], data_counts[re.sub(r'[\s]?\(.*\)', '', ds)]])
    
    filtered_datasets = []
    for ds, count, tr_count, tr_count_no_br, dcount, dcount_nobr in stats:
        if (tr_count_no_br > min_train_count) and (dcount_nobr / tr_count_no_br > rel_freq_threshold):
            filtered_datasets.append(ds)
    return filtered_datasets


def filter_and_the(datasets):
    pat = re.compile(' and [Tt]he ')
    return [pat.split(ds)[-1] for ds in datasets]


In [5]:
DATA_DIR = Path('/kaggle/input/coleridgeinitiative-show-us-the-data/')
TRAIN_MARKUP_FILE = DATA_DIR / 'train.csv'

## Data Utils

In [6]:
class Sentencizer:
    def __init__(self,
                 sentencize_fun: Callable,
                 split_by_newline: bool = True) -> None:
        self.sentencize = sentencize_fun
        self.split_by_newline = split_by_newline

    def __call__(self, text: str) -> List[str]:
        if self.split_by_newline:
            texts = text.split('\n')
        else:
            texts = [text]
        sents = []
        for text in texts:
            sents.extend(self.sentencize(text))
        return sents


class DotSplitSentencizer(Sentencizer):
    def __init__(self,
                 split_by_newline: bool) -> None:
        def _sent_fun(text: str) -> List[str]:
            return [sent.strip() for sent in text.split('.') if sent]
        super().__init__(_sent_fun, split_by_newline)


def get_coleridge_data(data_path: Union[str, Path],
                       sentencizer: Optional[Sentencizer] = None) -> tuple:
    data_path = Path(data_path)

    df = pd.read_csv(data_path / 'train.csv')

    samples = {}
    for _, (idx, pub_title, dataset_title, dataset_label, cleaned_label) in tqdm(df.iterrows()):
        if idx not in samples:
            with open(data_path / 'train' / (idx + '.json')) as fp:
                data = json.load(fp)
            samples[idx] = {'texts': [sec['text'] for sec in data], 
                            'dataset_titles': [],
                            'dataset_labels': [],
                            'cleaned_labels': [],
                            'pub_title': pub_title,
                            'idx': idx
                            }
        samples[idx]['dataset_titles'].append(dataset_title)
        samples[idx]['dataset_labels'].append(dataset_label)
        samples[idx]['cleaned_labels'].append(cleaned_label)

    train_ids = []
    train_texts = []
    train_labels = []
    for sample_dict in samples.values():
        train_ids.append(sample_dict['idx'])
        texts = sample_dict['texts']
        if sentencizer is not None:
            texts = list(chain(*[sentencizer(text) for text in texts]))
        train_texts.append(texts)
        train_labels.append(sample_dict['dataset_labels'])
    
    test_texts = []
    test_ids = []
    for test_file in (data_path / 'test').glob('*.json'):
        idx = test_file.name.split('.')[0]
        with open(test_file) as fp:
            data = json.load(fp)
        texts = [sec['text'] for sec in data]
        if sentencizer is not None:
            texts = list(chain(*[sentencizer(text) for text in texts]))

        test_texts.append(texts)
        test_ids.append(idx)
        
    return train_texts, train_ids, train_labels, test_texts, test_ids


In [7]:
train_texts, train_ids, train_labels, test_texts, test_ids = get_coleridge_data(DATA_DIR, DotSplitSentencizer(True))
train_labels_set = set(chain(*train_labels))

# all sentences from train and test as a single list
texts = list(chain(*(train_texts + test_texts)))

19661it [01:36, 204.38it/s]


# Pattern Extractor

In [8]:
def tokenzed_extract(texts, keywords):
    # Extracts all mentions of the form
    # Xxx Xxx Keyword Xxx (XXX)
    
    connection_words = {'of', 'the', 'with', 'for', 'in', 'to', 'on', 'and', 'up'}
    datasets = []
    for text in tqdm(texts):
        try:
            # Skip texts without parentheses or without any keyword
            if '(' not in text or all(kw not in text for kw in keywords):
                continue

            toks = list(TOKENIZE_PAT.finditer(text))
            toksg = [tok.group() for tok in toks]

            found = False
            current_dss = set()
            for n in range(1, len(toks) - 2):
                is_camel = bool(CAMEL_PAT.findall(toksg[n + 1]))
                is_caps = toksg[n + 1].isupper()
                
                if toksg[n] == '(' and (is_caps or is_camel) and toksg[n + 2] == ')':
                    end = toks[n + 2].span()[1]
                    n_capi = 0
                    has_kw = False
                    for tok, tokg in zip(toks[n - 1:: -1], toksg[n - 1:: -1]):
                        if tokg in keywords:
                            has_kw = True
                        if tokg[0].isupper() and tokg.lower() not in connection_words:
                            n_capi += 1
                            start = tok.span()[0]
                        elif tokg in connection_words or tokg == '-':
                            continue
                        else:
                            break
                    if n_capi > 1 and has_kw:
                        ds = text[start: end]
                        datasets.append(ds)
                        found = True
                        current_dss.add(ds)
        except Exception:
            # Print the offending text and move on to the next one
            print(text)

    return datasets


def get_parenthesis(t, ds):
    # Get abbreviations in the brackets if there are any 
    cur_abbrs = re.findall(re.escape(ds) + r'\s?(\([^\)]+\)|\[[^\]]+\])', t)
    cur_abbrs = [abbr.strip('()[]').strip() for abbr in cur_abbrs]
    cur_abbrs = [re.split(r'[\(\[]', abbr)[0].strip() for abbr in cur_abbrs]
    cur_abbrs = [re.split('[;,]', abbr)[0].strip() for abbr in cur_abbrs]
    cur_abbrs = [a for a in cur_abbrs if not any(ch in a for ch in '[]()')]
    cur_abbrs = [a for a in cur_abbrs if re.findall('[A-Z][A-Z]', a)]
    cur_abbrs = [a for a in cur_abbrs if len(a) > 2]
    cur_abbrs = [a for a in cur_abbrs if not any(tok.islower() for tok in tokenize(a))]
    fabbrs = []
    for abbr in cur_abbrs:
        if not (sum(bool(re.findall('[A-Z][a-z]+', tok)) for tok in tokenize(abbr)) > 2):
            fabbrs.append(abbr)
    return fabbrs
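
The idea behind `get_parenthesis` can be shown with a much smaller sketch that looks for a single bracketed short form right after a known dataset name. `abbreviation_after` is a hypothetical helper, and it applies only some of the checks used above.

```python
import re


def abbreviation_after(sentence, name):
    """Return the short form in (...) or [...] directly after `name`, or None."""
    m = re.search(re.escape(name) + r"\s?[\(\[]([^\)\]]+)[\)\]]", sentence)
    if not m:
        return None
    # Keep only the part before the first ';' or ','
    abbr = re.split(r"[;,]", m.group(1))[0].strip()
    # Require more than two characters and two consecutive capitals
    if len(abbr) > 2 and re.search(r"[A-Z][A-Z]", abbr):
        return abbr
    return None


print(abbreviation_after(
    "the Current Population Survey (CPS; see appendix) was used",
    "Current Population Survey",
))
```

A bracket that holds ordinary prose, such as `(see appendix)`, is rejected because it has no run of capital letters.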


# Evaluation

In [9]:
def get_datasets():
    STOPWORDS_PAR = [' lab', 'centre', 'center', 'consortium', 'office', 'agency', 'administration', 'clearinghouse',
                     'corps', 'organization', 'organisation', 'association', 'university', 'department',
                     'institute', 'foundation', 'service', 'bureau', 'company', 'test', 'tool', 'board', 'scale',
                     'framework', 'committee', 'system', 'group', 'rating', 'manual', 'division', 'supplement',
                     'variables', 'documentation', 'format']

    filter_stopwords_par_data = partial(filter_stopwords, stopwords=STOPWORDS_PAR)

    keywords = {'Study', 'Survey', 'Assessment', 'Initiative', 'Data', 'Dataset', 'Database'}
    
    # Datasets 
    ssai_par_datasets = tokenzed_extract(texts, keywords)
    
    words = list(chain(*[tokenize(ds) for ds in ssai_par_datasets]))
    texts_index = get_index(texts, words)
    filter_by_train_counts_filled = partial(filter_by_train_counts, index=texts_index,
                                            kw='data', min_train_count=2, rel_freq_threshold=0.1)

    filters = [filter_and_the, filter_stopwords_par_data, filter_intro_ssai, filler_intro_words, 
               filter_br_less_than_two_words, filter_parial_match_datasets, filter_by_train_counts_filled] 

    for filt in filters:
        ssai_par_datasets = filt(ssai_par_datasets)
    
    ssai_par_datasets = [BR_PAT.sub('', ds) for ds in ssai_par_datasets]

    return ssai_par_datasets
    

In [10]:
def solution():
    predictions = defaultdict(set)
    datasets = get_datasets()
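    # Keep only labels with at least one lowercase letter
    train_datasets = [ds for ds in train_labels_set if sum(ch.islower() for ch in ds) > 0 ]
    # This line iterates over train_labels_set again, so the list above is
    # replaced and all-caps labels remain in train_datasets
    train_datasets = [BR_PAT.sub('', ds).strip() for ds in train_labels_set]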
    datasets = set(datasets) | set(train_datasets)
    for filename in tqdm((DATA_DIR / 'test').glob('*')):
        idx = filename.name.split('.')[0]
        predictions[idx]  # create the key so documents without matches still get a row
        with open(filename) as fin:
            data = json.load(fin)
        
        for sec in data:
            text = sec['text']    
            current_preds = []
            for paragraph in text.split('\n'):
                for sent in re.split(r'[\.]', paragraph):
                    for ds in datasets:
                        if ds in sent:
                            current_preds.append(ds)
                            current_preds.extend(get_parenthesis(sent, ds))
            predictions[idx].update(current_preds)
        predictions[idx] = remove_substring_predictions(predictions[idx])

    prediction_str_list = []
    for idx, datasets in predictions.items():
        datasets_str = '|'.join(clean_text(d) for d in sorted(set(datasets)))
        prediction_str_list.append([idx, datasets_str])

    with open('submission.csv', 'w') as fout:
        for idx, datasets in [['Id', 'PredictionString']] + prediction_str_list:
            fout.write(','.join([idx, datasets]) + '\n')


solution()

100%|██████████| 6929257/6929257 [00:09<00:00, 737329.35it/s]
100%|██████████| 6929257/6929257 [01:25<00:00, 80824.24it/s]
100%|██████████| 4624/4624 [00:04<00:00, 1103.79it/s]
4it [00:00, 11.10it/s]


In [11]:
!cat submission.csv

Id,PredictionString
8e6996b4-ca08-4c0b-bed2-aaf07a4c6a60,ces|consumer expenditure survey|national health and nutrition examination survey|ruccs|rural urban continuum codes
2100032a-7c33-4bff-97ef-690822c43466,adni|alzheimer s disease neuroimaging initiative|chs|cardiovascular health study|framingham heart study
2f392438-e215-4169-bebf-21ac4ff253e1,ccd|cps|current population survey|idb|international data base|nces common core of data|nces schools and staffing survey|pirls|pisa|program for international student assessment|progress in international reading literacy study|sass|timss 2007|trends in international mathematics and science study
3f316b38-1a24-45a9-8d8c-4e05a42257c6,national geodetic survey|national hydrography dataset|slosh model|us geological survey|usgs
