## Data Ingestion
Parses different dataset formats

### College Admissions
[data (Zenodo)](https://zenodo.org/record/7114359#.Y6ILC3bMKXI) / [paper](https://jantrienes.com/assets/papers/tsar2022.pdf)

In [1]:
def edit_dist(s1, s2):
    """Return the Levenshtein distance (insertions, deletions, substitutions) between s1 and s2.

    Uses the row-by-row dynamic-programming formulation and keeps only one row
    of the table at a time. The shorter input runs along the row, so memory is
    proportional to min(len(s1), len(s2)).
    """
    shorter, longer = (s1, s2) if len(s1) <= len(s2) else (s2, s1)

    prev_row = list(range(len(shorter) + 1))
    for row_idx, ch_long in enumerate(longer, start=1):
        cur_row = [row_idx]
        for col_idx, ch_short in enumerate(shorter):
            if ch_short == ch_long:
                # Characters match: carry the diagonal cost forward unchanged.
                cost = prev_row[col_idx]
            else:
                # Cheapest of substitute (diagonal), delete (above), insert (left).
                cost = 1 + min(prev_row[col_idx], prev_row[col_idx + 1], cur_row[-1])
            cur_row.append(cost)
        prev_row = cur_row
    return prev_row[-1]

In [2]:
# Document ids taken from the test split. Some documents are not clean enough
# to include; their ids are listed in `exclude` and filtered out below.
test_set_ids = [4, 30, 38, 52, 65, 71, 77, 91, 97, 100, 106, 107, 109, 119, 121, 149, 166, 172, 174, 184, 192, 199, 200, 215, 221, 227, 250, 253, 258, 259, 278, 286, 309]
exclude = [97]
test_set_ids = list(filter(lambda doc_id: doc_id not in exclude, test_set_ids))

In [3]:
# Build (original, simplified) sentence pairs from the College Admissions
# test documents. Each document id has three files: an alignment file
# (original sentence index -> simplified sentence indices), the original
# sentences, and the simplified sentences, one sentence per line.
admissions = []
for id_ in test_set_ids:
    alignment_path = f'admissions/alignments/alignments_{id_}.txt'
    original_path = f'admissions/original/original_{id_}.txt'
    simplified_path = f'admissions/simplified/simplified_{id_}.txt'

    # Each alignment line is parsed as "orig_idx,simp_idx[,simp_idx...]" after
    # stripping parentheses and spaces. Lines without a comma (e.g. blank lines)
    # are skipped. Lines whose second field is 'N' are also skipped
    # (presumably "no simplified counterpart").
    with open(alignment_path, 'r', encoding='utf-8') as f:
        fields = [line.replace('(', '').replace(')', '').replace(' ', '').split(',')
                  for line in f.read().split('\n')]
    alignment = [[int(row[0]), [int(y) for y in row[1:]]]
                 for row in fields if len(row) != 1 and row[1] != 'N']

    with open(original_path, 'r', encoding='utf-8') as f:
        original = f.read().split('\n')

    with open(simplified_path, 'r', encoding='utf-8') as f:
        simplified = f.read().split('\n')

    # Map every simplified sentence index to *all* original indices aligned to
    # it. The append must run once per simplified index. Otherwise only the
    # last index of each alignment is recorded and sentence fusions are missed.
    align_map = {}
    for orig_idx, simp_idxs in alignment:
        for simp_idx in simp_idxs:
            align_map.setdefault(simp_idx, []).append(orig_idx)
    # Original sentences that share a simplified sentence with another original
    # sentence (N:1, i.e. sentence fusion).
    n_to_1 = {src for sources in align_map.values() if len(sources) > 1 for src in sources}

    for orig_idx, simp_idxs in alignment:
        orig_sent = original[orig_idx]
        # Join the aligned simplified sentences with a single space. Joining,
        # instead of concatenating and then adding a space after every '.',
        # leaves in-sentence periods such as "U.S." or "3.5" untouched and
        # cannot fail on an empty result.
        parts = [simplified[i].strip() for i in simp_idxs]
        simp_sent = ' '.join(part for part in parts if part)

        if (
            orig_idx not in n_to_1 and   # Our interface does not support N:1 simplification. Throw these out.
            orig_sent != simp_sent and   # Exclude identical sentences
            len(orig_sent) > 50 and      # Exclude short sentences
            len(simp_sent) > 50 and      # Exclude sentences which delete almost all original information. Typically this is because they are contained elsewhere
            edit_dist(orig_sent, simp_sent) > 50  # Exclude sentences with minimal change (checked last: edit distance is the costly test)
        ):
            admissions.append({
                'original': orig_sent,
                'simplified': simp_sent,
            })

### Medical Transcriptions
[github](https://github.com/babylonhealth/laymaker) / [paper](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC8861686/pdf/3576988.pdf)

In [4]:
import csv
path = 'medical-transcriptions/test.csv'

# Read the CSV into one dict per row, keyed by the header row. The csv module
# requires newline='' so quoted fields containing line breaks parse correctly.
with open(path, encoding='utf-8', newline='') as f:
    reader = csv.reader(f)
    keys = next(reader)
    loaded = [dict(zip(keys, row)) for row in reader]

# Keep only pairs with a substantial rewrite of reasonably long sentences.
medical = []
for sent in loaded:
    orig_sent = sent['ORIGINAL']
    simp_sent = sent['REFERENCE']

    if (
        orig_sent != simp_sent and   # Exclude identical sentences
        len(orig_sent) > 40 and      # Exclude short sentences
        len(simp_sent) > 40 and      # Exclude sentences which delete almost all original information. Typically this is because they are contained elsewhere
        edit_dist(orig_sent, simp_sent) > 45  # Exclude sentences with minimal change (checked last: edit distance is the costly test)
    ):
        medical.append({
            'original': orig_sent,
            'simplified': simp_sent,
        })

### Medical Abstracts
[github](https://github.com/AshOlogn/Paragraph-level-Simplification-of-Medical-Texts) / [paper](https://aclanthology.org/2021.naacl-main.395/)

In [5]:
import json
# Encoding is explicit (like the other cells) so loading does not depend on
# the platform's default locale encoding.
with open('devaraj-medical/data-1024.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Each record's 'abstract' is used as the original text and its 'pls' field as
# the simplified text.
devaraj = [
    {'original': record['abstract'], 'simplified': record['pls']}
    for record in data
]

In [6]:
# These are paragraphs. This will *not* work without a sentence alignment

### Clinical Notes
*"Sharing of the dataset is currently underway."*

[github](https://github.com/jantrienes/simple-patho) / [paper](https://jantrienes.com/assets/papers/tsar2022.pdf)

### Newsela
[paper](https://aclanthology.org/Q15-1021/)

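The data is not publicly available; access must be requested. We take the least and most complicated versions of each sentence, pairing a sentence from the original article with its most simplified aligned version. TODO: cite the paper that introduced this approach.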

In [7]:
import csv

# TODO(review): the original note here only said "CAREFUL!". This reads the
# crowdsourced *train* split; confirm that is the intended split.
path = 'newsela/crowdsourced/train.tsv'

newsela = []
# Columns are assigned positionally from this fixed list; no header row is consumed.
keys = ['alignment', 'simp_info', 'orig_info', 'simplified', 'original']
# newline='' is required by the csv module for correct line handling.
# NOTE(review): csv's default quoting treats a field starting with '"' as quoted.
# If the TSV holds raw unquoted text, quoting=csv.QUOTE_NONE may be needed. Confirm.
with open(path, encoding='utf-8', newline='') as f:
    reader = csv.reader(f, delimiter='\t')
    loaded = [dict(zip(keys, row)) for row in reader]

# Drop rows too short to carry an 'original' column *before* parsing their
# metadata, so a truncated row cannot abort the parse below.
loaded = [row for row in loaded if 'original' in row]

for sent in loaded:
    # Info fields are unpacked as "<article>.<language>-<readability>-<pg idx>-<sent idx>".
    sent['article'], orig_info = sent['orig_info'].split('.')
    sent['language'], sent['orig_readability'], sent['orig_pg_idx'], sent['orig_sent_idx'] = orig_info.split('-')

    sent['article'], simp_info = sent['simp_info'].split('.')
    sent['language'], sent['simp_readability'], sent['simp_pg_idx'], sent['simp_sent_idx'] = simp_info.split('-')

    sent['orig_readability'] = int(sent['orig_readability'])
    sent['simp_readability'] = int(sent['simp_readability'])

# Index rows by their source-sentence id. Each hop of the chain walk below is
# then a dict lookup rather than a scan of every row, which would be quadratic overall.
rows_by_orig_info = {}
for row in loaded:
    rows_by_orig_info.setdefault(row['orig_info'], []).append(row)

for sent in loaded:
    if sent['alignment'] != 'aligned' or sent['orig_readability'] != 0:
        continue

    # Find the most simplified version. Starting from a readability-0 sentence,
    # follow unique 'aligned' links (simp_info -> next row's orig_info) for up to
    # 4 further hops. Stop early on an ambiguous link, an unaligned link, or a
    # very short (< 10 chars) simplification.
    curr = sent
    level = 0
    while level != 4:
        cands = rows_by_orig_info.get(curr['simp_info'], [])
        if len(cands) != 1 or cands[0]['alignment'] != 'aligned' or len(cands[0]['simplified']) < 10:
            break
        curr = cands[0]
        level += 1
    # No hop beyond the starting pair was found: skip this sentence.
    if level == 0:
        continue

    newsela.append({
        'original': sent['original'],
        'simplified': curr['simplified'],
        # 'orig_readability': sent['orig_readability'],
        # 'simp_readability': curr['simp_readability']
    })

### Create Batches

In [8]:
# Datasets sampled into the annotation batches, in this order. Each key is
# passed to prepare_sent as its `system` argument. Devaraj is left out: its
# pairs are paragraphs and would need a sentence alignment first.
data = dict([
    ('multi-domain/admissions', admissions),
    ('multi-domain/medical-transcriptions', medical),
    # ('multi-domain/devaraj', devaraj),
    ('multi-domain/newsela', newsela),
])

In [9]:
import re

def prepare_sent(original, simplified, system, id_):
    """Build one annotation record for an (original, simplified) pair.

    The simplified text is stripped and " 's" is re-attached to the preceding
    word. A "||" marker is then inserted after every ". " so that sentence
    boundaries become annotatable spans.

    Args:
        original: the source text; stored stripped.
        simplified: the simplified text; normalized as described above.
        system: label stored in the record's "system" field.
        id_: value stored in the record's "id" field.

    Returns:
        A dict with keys "id", "original", "original_spans" (empty),
        "simplified", "simplified_spans" and "system". Each "simplified_spans"
        entry is [2, start, end, k] for the k-th "||" marker, where
        simplified[start:end] == "||".
    """
    simplified = simplified.strip().replace(" 's", "'s").replace('. ', '. || ')

    # One span per "||" marker. NOTE(review): the leading 2 is kept as-is;
    # presumably a span-type code expected by the annotation tool. Confirm.
    simplified_spans = [
        [2, match.start(), match.end(), k]
        for k, match in enumerate(re.finditer(r'\|\|', simplified))
    ]

    return {
        "id": id_,
        "original": original.strip(),
        "original_spans": [],
        "simplified": simplified,
        "simplified_spans": simplified_spans,
        # Record the dataset/system name passed in (previously the literal "system").
        "system": system,
    }

system_list = data.keys()  # not used in this cell
annotators = ['yao', 'david']
batch_size = 20

# Each annotator gets an equal, disjoint slice of every dataset: annotator i
# receives items [i*segment, (i+1)*segment) of each one. Integer division
# means any remainder of batch_size is left unused.
segment = batch_size // len(data)

id_ = 0
for i, annotator in enumerate(annotators):
    batch = []
    for system, dataset in data.items():
        for sent in dataset[i * segment:(i + 1) * segment]:
            batch.append(prepare_sent(sent['original'], sent['simplified'], system, id_))
            # Advance per example so every record across all batches gets a distinct id.
            id_ += 1
    with open(f"batches/{annotator}.json", "w", encoding="utf-8") as f:
        json.dump(batch, f, indent=4)