## Step 1. Generate data splits.
**ACL_ARC**: https://github.com/oacore/dynamic_citation_context/tree/main/data/acl_arc

**ACT2**: https://github.com/oacore/dynamic_citation_context/tree/main/data/sdp_act

We reserve 15% of the training set for validation, as there is no published validation split. We replace the original "#AUTHOR_TAG" with "#CITATION_TAG".  

**FOCAL**: https://huggingface.co/datasets/adsabs/FOCAL

We use the original split. 

FOCAL allows multiple labels for one cited work (referred to as a "citation" in the dataset). Each label is associated with a piece of citation text (the cited work is explicitly mentioned multiple times, or the mentions are scattered throughout the text). Therefore, we reorganize the dataset and insert citation anchors as follows:

- If there are **multiple different labels**, we insert a tag at the end of each citation text to create **multiple samples with different labels**. 
- If there is only **one label and one citation text**, we replace the citation text with the tag. 
- If there is only **one label but multiple citation texts**, we insert a tag at the end of the **first text** to create **one sample**. 

Also, for the classes 'Similarities', 'Differences' and 'Compare/Contrast' (neutral), FOCAL's classification schema seems more fine-grained than those of the other two datasets. So, we merge them into the 'Compare/Contrast' class.



In [None]:
import pandas as pd
import re

# Matches a citation placeholder such as "#AUTHOR_TAG" or "#CITATION_TAG".
# The name part is restricted to word characters so a match can never run
# across an unrelated '#' (e.g. "#1 and #AUTHOR_TAG") and swallow the text
# in between.
_TAG_PATTERN = re.compile(r'#(\w+?)_TAG')

# ACL-ARC samples whose citation context is empty; they are dropped.
_ACL_ARC_EMPTY_CONTEXT = frozenset({'CC862', 'CC1543'})

# Manual anchors for samples that carry no placeholder tag:
# unique_id -> literal substring replaced by '#CITATION_TAG'. The replacement is
# a plain substring replace over the whole context, so every occurrence is tagged.
# Only consulted when dataset != 'acl_arc' (tag-less ACL-ARC rows always get the
# 'Clark and Curran' fix instead).
# NOTE(review): 'CC109', 'CC42' and 'CC1039' look like ACL-ARC ids, yet they can
# only be reached for other datasets — confirm which dataset they belong to.
_MANUAL_ANCHORS = {
    'CC109': '[14]',
    'CC42': '[14]',
    'CC1039': 'Wiener, 1948',
    'CCT218': '94',
    'CCT365': '38',
    'CCT591': '65',
    'CCT610': '74',
    'CCT879': 'Pledger et al',
}


def _join_context_sentences(raw):
    """Flatten a stringified list of sentences into one space-joined string.

    The raw column holds the repr of a list of sentence strings. It is parsed
    with ``ast.literal_eval`` so apostrophes and escaped quotes inside the
    sentences survive. If the value is not a literal list/tuple of strings,
    the original quote-swapping heuristic is used as a fallback.
    """
    import ast

    try:
        parsed = ast.literal_eval(raw)
    except (ValueError, SyntaxError, TypeError):
        parsed = None
    if isinstance(parsed, (list, tuple)) and all(isinstance(s, str) for s in parsed):
        return ' '.join(s.strip() for s in parsed)
    # Fallback heuristic (lossy: turns apostrophes into double quotes).
    pieces = raw[1:-1].replace('\'', '\"').split('\",')
    return ' '.join(p.strip('\'\" ') for p in pieces)


def process_acl(dataset, dataframe):
    """Normalise an ACL-ARC / ACT2 style frame into (unique_id, context, label).

    Args:
        dataset: dataset name; 'acl_arc' enables the ACL-ARC specific skips and
            the 'Clark and Curran' anchor fix.
        dataframe: raw frame with 'unique_id', 'cite_context_paragraph' (a
            stringified list of sentences) and 'citation_class_label' columns.

    Returns:
        A DataFrame with columns 'unique_id', 'context' (sentences joined by a
        space, citation placeholder normalised to '#CITATION_TAG') and 'label'.
    """
    uids, paragraphs, labels = [], [], []
    for _, data in dataframe.iterrows():
        uid = data['unique_id']
        if dataset == 'acl_arc' and uid in _ACL_ARC_EMPTY_CONTEXT:
            continue  # these citations have an empty context.
        paragraph = data['cite_context_paragraph']

        match = _TAG_PATTERN.search(paragraph)
        if match:
            # Rewrite only the placeholder token itself (all its occurrences).
            paragraph = paragraph.replace(match.group(0), '#CITATION_TAG')
        elif dataset == 'acl_arc':
            paragraph = paragraph.replace('Clark and Curran', '#CITATION_TAG')
        elif uid in _MANUAL_ANCHORS:
            paragraph = paragraph.replace(_MANUAL_ANCHORS[uid], '#CITATION_TAG')

        uids.append(uid)
        paragraphs.append(_join_context_sentences(paragraph))
        labels.append(data['citation_class_label'])

    return pd.DataFrame({'unique_id': uids, 'context': paragraphs, 'label': labels})


dataset = 'acl_arc'

# Load the two published splits (train first, then test).
raw_splits = {
    split: pd.read_csv(f"raw_data/{dataset}_{split}.txt", sep='\t')
    for split in ('train', 'test')
}

# No validation split is published: hold out a seeded 15% of the training rows.
valid_data = raw_splits['train'].sample(frac=.15, random_state=1)
train_data = raw_splits['train'].drop(valid_data.index)
test_data = raw_splits['test']
del raw_splits

# Normalise every split (train, valid, test — in that order).
train_data, valid_data, test_data = (
    process_acl(dataset, frame) for frame in (train_data, valid_data, test_data)
)

for suffix, frame in zip(('train', 'valid', 'test'), (train_data, valid_data, test_data)):
    frame.to_csv(f"data/{dataset}_{suffix}.txt", sep='\t', index=False)


In [None]:
import json
import pandas as pd

# Integer class ids for FOCAL citation-function labels. Labels are looked up
# here after 'Similarities'/'Differences' have been folded into
# 'Compare/Contrast' by process_focal.
MAP = {
    name: idx
    for idx, name in enumerate([
        'Background',
        'Compare/Contrast',
        'Extends',
        'Future Work',
        'Motivation',
        'Uses',
    ])
}


def annotate_first_text(uid, sd, para, labels, res):
    """Replace the first citation span ``sd[0]`` of ``para`` with the tag.

    Appends one sample to ``res`` under ``uid`` with label ``MAP[labels[0]]``;
    newline, tab and carriage-return characters are removed from the context.
    """
    first_start, first_end = sd[0][0], sd[0][1]
    tagged = ''.join((para[:first_start], '(#CITATION_TAG)', para[first_end:]))
    label_id = MAP[labels[0]]
    tagged = tagged.translate(str.maketrans('', '', '\n\t\r'))
    for column, value in (('unique_id', uid), ('context', tagged), ('label', label_id)):
        res[column].append(value)

def annotate_every_text(uid, text_ranges, para, labels, res):
    """Create one sample per (label, function-text range) pair.

    For each pair the tag is inserted at the range's end offset ``end``. If the
    character just before ``end`` is a '.', the tag goes before that period
    instead. Each sample is stored under ``f"{uid}_{end}"``; a pair whose id is
    already in ``res['unique_id']`` (including ids added earlier in this call)
    is skipped. Newline, tab and carriage-return characters are removed.
    """
    strip_controls = str.maketrans('', '', '\n\t\r')
    for label, text_range in zip(labels, text_ranges):
        end = text_range[1]
        sample_id = f"{uid}_{end}"
        if sample_id in res['unique_id']:
            continue
        cut = end - 1 if para[end - 1] == '.' else end
        tagged = para[:cut] + '(#CITATION_TAG)' + para[cut:]
        label_id = MAP[label]
        res['unique_id'].append(sample_id)
        res['context'].append(tagged.translate(strip_controls))
        res['label'].append(label_id)

def annotate_citation_entry(uid, s, d, para, labels, res):
    """Replace ``para[s:d]`` with the citation tag and record one sample.

    The sample is stored under ``uid`` with label ``MAP[labels[0]]``; newline,
    tab and carriage-return characters are removed from the context.
    """
    tagged = ''.join((para[:s], '(#CITATION_TAG)', para[d:]))
    label_id = MAP[labels[0]]
    tagged = tagged.translate(str.maketrans('', '', '\n\t\r'))
    for column, value in (('unique_id', uid), ('context', tagged), ('label', label_id)):
        res[column].append(value)

# Duplicate FOCAL identifiers that need non-default handling. Every other
# repeated identifier keeps only its first record. These ids were reviewed
# manually and also fall into that default:
#   2021MNRAS.504..146V__Vink_&_Gräfener_2012_Instance_1
#   2021AandA...655A..99D__Carigi_et_al._2005_Instance_2 / _Instance_3
#   2018ApJ...863..162M__Liu_et_al._2013_Instance_3
#   2018MNRAS.479.3254V___2000_Instance_1
_FOCAL_MERGE_DUPLICATES = frozenset({
    '2015MNRAS.450.4364N__Wu_et_al._2004_Instance_1',  # concatenate labels/ranges
})
_FOCAL_REPLACE_DUPLICATES = frozenset({
    '2018AandA...619A..13V__Saviane_et_al._2012_Instance_4',  # later record wins
})
_FOCAL_SPLIT_DUPLICATES = frozenset({
    '2016AandA...588A..44Y__Jones_et_al._2014_Instance_1',  # kept as '<id>_d'
    '2022ApJ...936..102A__Williams_et_al._2006_Instance_1',
})

# FOCAL labels folded into 'Compare/Contrast' to match the coarser schema.
_FOCAL_COMPARE_LABELS = frozenset({'Similarities', 'Differences'})


def _dedupe_focal_records(records):
    """Key FOCAL records by 'Identifier', resolving duplicate identifiers.

    Returns a new insertion-ordered dict. The input records and their lists
    are never modified (merged records are built as fresh dicts).
    """
    kept = {}
    for record in records:
        uid = record['Identifier']
        if uid not in kept:
            kept[uid] = record
        elif uid in _FOCAL_MERGE_DUPLICATES:
            first = kept[uid]
            kept[uid] = {
                **first,
                'Functions Label': first['Functions Label'] + record['Functions Label'],
                'Functions Start End': (
                    first['Functions Start End'] + record['Functions Start End']
                ),
            }
        elif uid in _FOCAL_REPLACE_DUPLICATES:
            kept[uid] = record
        elif uid in _FOCAL_SPLIT_DUPLICATES:
            kept[uid + '_d'] = record
        # Any other duplicate: keep the first occurrence only.
    return kept


def process_focal(train_data):
    """Convert decoded FOCAL records (any split) into tagged samples.

    Args:
        train_data: list of FOCAL JSON records; the name is historical and any
            split may be passed.

    Returns:
        DataFrame with 'unique_id', 'context' and 'label' (MAP id) columns.
        - Records with several distinct labels yield one sample per function
          text range.
        - Single-label records yield one sample with the citation span replaced
          by the tag. Two spans whose second text starts with a digit are
          treated as one split "(author, year)" entry.
    """
    print(f'Load {len(train_data)} samples.')
    records = _dedupe_focal_records(train_data)
    print(f'To be processed: {len(records)} samples.')

    res = {'unique_id': [], 'context': [], 'label': []}
    for uid, record in records.items():
        para = record['Paragraph']
        # Build a new list so the caller's record is left untouched.
        labels = [
            'Compare/Contrast' if lab in _FOCAL_COMPARE_LABELS else lab
            for lab in record['Functions Label']
        ]
        sd = record['Citation Start End']
        text_ranges = record['Functions Start End']
        texts = record['Citation Text']

        if len(set(labels)) > 1:
            # multiple texts & labels
            annotate_every_text(uid, text_ranges, para, labels, res)
        elif len(sd) == 1:
            # one label, one citation entry
            annotate_citation_entry(uid, sd[0][0], sd[0][1], para, labels, res)
        elif len(texts) > 1 and '0' <= texts[1][:1] <= '9':
            # the entry is locally split as (author, year): span both parts.
            # Slicing [:1] keeps an empty second text from raising IndexError.
            annotate_citation_entry(uid, sd[0][0], sd[1][1], para, labels, res)
        else:
            # multiple texts are far apart: tag the first one only.
            annotate_first_text(uid, sd, para, labels, res)

    print('Got ', len(res['unique_id']), 'samples.', 'Unique id:', len(set(res['unique_id'])))
    return pd.DataFrame(res)

path = 'raw_data'
save_path = 'data'


def _load_focal_jsonl(split_name):
    """Decode every line of ``{path}/FOCAL-<split_name>.jsonl`` with json.loads."""
    with open(f'{path}/FOCAL-{split_name}.jsonl', 'r') as fp:
        return [json.loads(row) for row in fp.readlines()]


train_data = process_focal(_load_focal_jsonl('TRAINING'))
train_data.to_csv(f"{save_path}/focal_train.txt", sep='\t', index=False)

valid_data = process_focal(_load_focal_jsonl('VALIDATION'))
valid_data.to_csv(f"{save_path}/focal_valid.txt", sep='\t', index=False)

test_data = process_focal(_load_focal_jsonl('TESTING'))
test_data.to_csv(f"{save_path}/focal_test.txt", sep='\t', index=False)


Load 2421 samples.
To be processed: 2403 samples.
Got  2617 samples. Unique id: 2617
Load 606 samples.
To be processed: 606 samples.
Got  660 samples. Unique id: 660
Load 821 samples.
To be processed: 821 samples.
Got  889 samples. Unique id: 889


## Step 2: Extract input context and generate SP augmentations.

The split_into_sentences function is based on https://stackoverflow.com/questions/4576077/how-can-i-split-a-text-into-sentences

We manually add some rules to handle special cases in all datasets. For Citss, we only need to generate SP augmentations for the training set. 

In [None]:
import pandas as pd
import json
import re
# Regex fragments used by split_into_sentences to mask periods that do not end
# a sentence. Patterns containing backslashes are raw strings: the previous
# plain strings relied on invalid escape sequences ("\{", "\(", "\s"), which
# raise SyntaxWarning on Python 3.12+. The resulting values are unchanged.
alphabets = "([A-Za-z])"
prefixes = r"(Mr|St|Mrs|Ms|Dr| al|Fig|\{|Eq|right\)|right\) |μm|fig| etc|,|Eqs|eqs|\(cf| vs|Sect)[.]"
# NOTE(review): the trailing '|' adds an empty alternative, so the suffix rules
# in split_into_sentences also match a bare " ." — confirm this is intended.
suffixes = "(Inc|Ltd|Jr|Sr|Co|)"
starters = r"(Mr|Mrs|Ms|Dr|Prof|Capt|Cpt|Lt|He\s|She\s|It\s|They\s|Their\s|Our\s|We\s|But\s|However\s|That\s|This\s|Wherever)"
acronyms = "([A-Z][.][A-Z][.](?:[A-Z][.])?)"
websites = "[.](com|net|org|io|gov|edu|me)"
digits = "([0-9])"
multiple_dots = r'\.{2,}'

# Literal strings whose periods are always masked by split_into_sentences.
specials = ['0.′′', 'Sect. 5.7']


def split_into_sentences(text: str) -> list[str]:
    """Split ``text`` into sentences using rule-based period masking.

    The function works in three stages:
    1. Periods that should not end a sentence (abbreviations, decimals,
       acronyms, website suffixes, ellipses, a few hand-listed strings) are
       rewritten to the ``<prd>`` placeholder.
    2. Every remaining '.', '?' and '!' gets a ``<stop>`` marker after it.
    3. The placeholders are restored, and the text is split on ``<stop>``.

    Sentences are whitespace-stripped and a trailing empty sentence is dropped.
    Adapted from https://stackoverflow.com/questions/4576077 with extra
    dataset-specific rules. Only the escape-sequence literals were changed to
    raw strings; the rule order is significant and is kept as-is.
    """
    text = " " + text + "  "
    text = text.replace("\n"," ")
    text = re.sub(prefixes,"\\1<prd>",text)
    text = re.sub(websites,"<prd>\\1",text)
    text = re.sub(digits + "[.]" + digits,"\\1<prd>\\2",text)
    # An ellipsis keeps its dots but always terminates the sentence.
    text = re.sub(multiple_dots, lambda match: "<prd>" * len(match.group(0)) + "<stop>", text)
    if "Ph.D" in text: text = text.replace("Ph.D.","Ph<prd>D<prd>")
    if "2.2.2" in text: text = text.replace("2.2.2","2<prd>2<prd>2<prd>")
    if "4.2.1" in text: text = text.replace("4.2.1","4<prd>2<prd>1<prd>")
    if "e.g." in text: text = text.replace("e.g.","e<prd>g<prd>")
    # NOTE(review): this rule emits two <prd> for one consumed period, so an
    # extra '.' appears in the output (e.g. "e.g," -> "e.g.,") — confirm intent.
    if "e.g" in text: text = text.replace("e.g","e<prd>g<prd>")
    if r".\end" in text: text = text.replace(r".\end","<prd>\\end")
    for s in specials:
        if s in text:
            tmp = s.replace('.',"<prd>")+"<prd>"
            text = text.replace(s, tmp)

    text = re.sub(r"\s" + alphabets + "[.] "," \\1<prd> ",text)
    text = re.sub(acronyms+" "+starters,"\\1<stop> \\2",text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>\\3<prd>",text)
    text = re.sub(alphabets + "[.]" + alphabets + "[.]","\\1<prd>\\2<prd>",text)
    text = re.sub(" "+suffixes+"[.] "+starters," \\1<stop> \\2",text)
    text = re.sub(" "+suffixes+"[.]"," \\1<prd>",text)
    text = re.sub(" " + alphabets + "[.]"," \\1<prd>",text)

    # Move terminal punctuation outside closing quotes so the quote stays with
    # its sentence.
    if "”" in text: text = text.replace(".”","”.")
    if "\"" in text: text = text.replace(".\"","\".")
    if "!" in text: text = text.replace("!\"","\"!")
    if "?" in text: text = text.replace("?\"","\"?")
    text = text.replace(".",".<stop>")
    text = text.replace("?","?<stop>")
    text = text.replace("!","!<stop>")
    text = text.replace("<prd>",".")
    sentences = text.split("<stop>")
    sentences = [s.strip() for s in sentences]
    if sentences and not sentences[-1]: sentences = sentences[:-1]
    return sentences

def _extract_window(sents, window=3, generate_aug=False):
    """Locate the first sentence containing '#CITATION_TAG' and build its window.

    Returns ``(context, citance, augmentations)``, or ``None`` when no sentence
    carries the tag.
    - ``context`` joins up to ``window`` sentences on each side of the citance.
    - ``augmentations`` lists every other contiguous span inside that window
      that still contains the citance; the full window is excluded. It is
      empty unless ``generate_aug`` is set.
    """
    num_sent = len(sents)
    for j, sent in enumerate(sents):
        if '#CITATION_TAG' not in sent:
            continue
        min_clip = max(0, j - window)
        max_clip = min(num_sent, j + window + 1)
        context = ' '.join(sents[min_clip:max_clip])
        augmentations = []
        if generate_aug:
            for k in range(1, 2 * window + 2):  # span length
                for b in range(max(min_clip, j - k + 1), j + 1):  # b <= j keeps the citance
                    if b + k > max_clip:
                        break
                    if not (b == min_clip and b + k == max_clip):
                        augmentations.append(' '.join(sents[b:b + k]))
        return context, sent, augmentations
    return None


def sp(data_path, generate_aug=False, window=3):
    """Add 'input_context' and 'citance' columns to a split file, in place.

    Args:
        data_path: tab-separated file with 'unique_id' and 'context' columns;
            it is rewritten with the two extra columns.
        generate_aug: also build sentence-pruning (SP) augmentations.
        window: number of sentences kept on each side of the citance.

    Returns:
        dict mapping unique_id -> list of augmentation strings (empty lists
        unless ``generate_aug``).

    Raises:
        ValueError: if a context contains no sentence with '#CITATION_TAG'.
    """
    data = pd.read_csv(data_path, sep='\t')

    sas = {}
    input_context = []  # T_i
    citance = []
    num_augmentation = []

    for _, line in data.iterrows():
        uid = line['unique_id']
        found = _extract_window(split_into_sentences(line['context']), window, generate_aug)
        if found is None:
            raise ValueError(f"{data_path}: no '#CITATION_TAG' sentence in sample {uid!r}")
        context, cited_sentence, augmentations = found
        input_context.append(context)
        # Store the whole citance sentence (previously sent[j], a single char).
        citance.append(cited_sentence)
        num_augmentation.append(len(augmentations))
        sas[uid] = augmentations

    if generate_aug and sas:
        print('Avg. #SP', sum(num_augmentation)/len(sas), 'Maximum', max(num_augmentation))

    data['input_context'] = input_context
    data['citance'] = citance
    data.to_csv(data_path, sep='\t', index=False)
    return sas


def extract_input(dataset, path='data'):
    """Run SP extraction for one dataset's split files under ``path``.

    The train split is processed with augmentations, which are dumped to
    ``{path}/{dataset}_sas.json``. Valid and test get only the
    input_context/citance columns. Each split file is rewritten in place by sp.
    """
    sp_augs = sp(f'{path}/{dataset}_train.txt', generate_aug=True)
    with open(f'{path}/{dataset}_sas.json', 'w') as fp:
        json.dump(sp_augs, fp)
    # The augmenting pass above already wrote input_context/citance for the
    # train split, so it is not processed a second time.
    sp(f'{path}/{dataset}_valid.txt')
    sp(f'{path}/{dataset}_test.txt')


# Build inputs (and train-set SP augmentations) for every dataset, in order.
for corpus in ('acl_arc', 'act2', 'focal'):
    extract_input(corpus)

Avg. #SP 4.644031451036454 Maximum 15
Avg. #SP 6.850196078431373 Maximum 15
Avg. #SP 12.988536492166602 Maximum 15
