In [1]:
import pandas as pd
import json
import torch
from torch.optim import AdamW
from transformers import AutoModel
from transformers import AutoTokenizer
from datasets import load_dataset
from torch.utils.data import DataLoader
from pytorch_metric_learning import miners, losses
from datasets import load_metric

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Locations of the two section-scaffold training files (JSON Lines).
data_path_aclarc = "./acl-arc/scaffolds/sections-scaffold-train.jsonl"
data_path_scicite = "./scicite/scaffolds/sections-scaffold-train.jsonl"

# Parse the SciCite file one JSON object per line, then build a frame with
# exact duplicate rows removed.
with open(data_path_scicite, encoding='utf-8') as data_file:
    data = list(map(json.loads, data_file))
    df = pd.DataFrame(data).drop_duplicates()

In [24]:
# Columns kept by add_label, in output order: anchor text, its positive
# partner, and the integer class label.
final_cols = [
    'cleaned_cite_text',
    'cleaned_cite_text_pos',
    'label',
]

def split_and_concatenate(group):
    """Pair the first half of ``group`` with its second half, side by side.

    The first ``len(group) // 2`` rows contribute only their
    ``cleaned_cite_text`` value. The remaining rows keep every column, with
    ``cleaned_cite_text`` renamed to ``cleaned_cite_text_pos``. Both parts are
    re-indexed from 0 and joined column-wise, so row ``i`` of the output pairs
    the i-th text of the first half with the i-th row of the second half.

    When the group has an odd number of rows the second half is one row
    longer, and the unmatched row gets NaN in ``cleaned_cite_text``.

    Args:
        group: DataFrame containing a ``cleaned_cite_text`` column.

    Returns:
        DataFrame with ``cleaned_cite_text`` followed by the second half's
        columns. ``group`` itself is not modified.
    """
    midpoint = len(group) // 2

    # Anchor texts: only the text column of the leading rows.
    anchors = group['cleaned_cite_text'].iloc[:midpoint].reset_index(drop=True)

    # Positive partners: the trailing rows, text column relabelled.
    positives = (
        group.iloc[midpoint:]
        .reset_index(drop=True)
        .rename(columns={'cleaned_cite_text': 'cleaned_cite_text_pos'})
    )

    # Align on the fresh 0..n-1 index and place the two parts next to each other.
    return pd.concat([anchors, positives], axis=1)

# Builds (text, positive text) pairs by splitting each group in half
def get_pos_samples_concat(df, sort_cols):
    """Group ``df`` by ``sort_cols`` and pair up rows within every group.

    Each group is handed to ``split_and_concatenate``; the per-group results
    are stacked and given a fresh 0..n-1 index.

    The grouping is done on ``<col>_drop`` duplicates of the key columns
    rather than on the key columns themselves: ``include_groups=False`` hides
    the grouping columns from the applied function, so grouping on copies lets
    the original key columns survive into the output.

    NOTE(review): ``groupby`` is called with its default ``dropna``, so rows
    whose key is missing are left out of the result - confirm this is intended.

    Args:
        df: Source frame; it is deep-copied and not modified.
        sort_cols: List of column names (strings) to group by.

    Returns:
        DataFrame of paired rows as produced by ``split_and_concatenate``.
    """
    working = df.copy(deep=True)

    # Shadow copies of the key columns, used only as grouping keys.
    shadow_keys = []
    for col in sort_cols:
        shadow = col + '_drop'
        working[shadow] = working[col]
        shadow_keys.append(shadow)

    grouped = working.groupby(shadow_keys)
    paired = grouped.apply(split_and_concatenate, include_groups=False)
    return paired.reset_index(drop=True)

def add_label(result, sort_cols):
    """Assign an integer label to every distinct ``sort_cols`` key.

    Labels are numbered from 0 in order of first appearance of each key.

    For multi-column keys the key is kept as a tuple of the column values.
    Joining the values into one string without a separator would merge
    distinct keys such as ``('ab', 'c')`` and ``('a', 'bc')`` into a single
    label and would fail on non-string columns; tuples avoid both problems and
    do not require transposing the frame.

    Side effects: adds the helper column ``combined`` (the key: a tuple for
    multi-column keys, the column value otherwise) and the column ``label``
    to ``result`` in place.

    Args:
        result: DataFrame holding the ``sort_cols`` columns plus the columns
            listed in the module-level ``final_cols`` other than ``label``.
        sort_cols: List of column names that together identify a class.

    Returns:
        A new DataFrame restricted to ``final_cols``. It is an explicit copy,
        so later in-place edits neither touch ``result`` nor trigger
        pandas' SettingWithCopyWarning.
    """
    if len(sort_cols) > 1:
        # One hashable tuple per row; unambiguous for any value types.
        result['combined'] = list(result[sort_cols].itertuples(index=False, name=None))
    else:
        result['combined'] = result[sort_cols]

    # factorize numbers the distinct keys by first appearance.
    labels, _ = pd.factorize(result['combined'])
    result['label'] = labels
    return result[final_cols].copy()

# Fill missing anchor text with its own positive text (per the original note:
# dropout in roberta will treat such a pair as unsupervised learning)
def handle_na(input_df):
    """Copy ``cleaned_cite_text_pos`` into ``cleaned_cite_text`` where the latter is NA.

    Modifies ``input_df`` in place and returns ``None``. Rows whose
    ``cleaned_cite_text`` is already present are left untouched.
    """
    missing = input_df['cleaned_cite_text'].isna()
    input_df.loc[missing, 'cleaned_cite_text'] = input_df.loc[missing, 'cleaned_cite_text_pos']

def process_data(df, sort_cols):
    """Turn ``df`` into labelled (text, positive text) pairs.

    Steps, in order:
      1. pair rows within each ``sort_cols`` group (``get_pos_samples_concat``),
      2. attach an integer label per key and trim columns (``add_label``),
      3. fill missing anchor texts from the positive text (``handle_na``),
      4. rename the three remaining columns to ``text``, ``text_pos``, ``label``.

    Args:
        df: Source frame; passed to ``get_pos_samples_concat``.
        sort_cols: List of column names defining a class.

    Returns:
        DataFrame with columns ``text``, ``text_pos`` and ``label``.
    """
    paired = get_pos_samples_concat(df, sort_cols=sort_cols)
    labelled = add_label(paired, sort_cols)

    # handle_na edits the frame in place and returns nothing.
    handle_na(labelled)

    labelled.columns = ['text', 'text_pos', 'label']
    return labelled

# Candidate grouping keys: section plus cited paper, or section alone.
section_paper = ['section_name', 'cited_paper_id']
section = ['section_name']

# scicite has no cited paper id, so group by section only
concat_section = process_data(df, section)


In [25]:
# Rows per label, followed by the number of pairs whose two texts are identical.
print(concat_section['label'].value_counts())
print((concat_section['text'] == concat_section['text_pos']).sum())

label
2    17656
0     9441
1     7496
3     7387
4     3727
Name: count, dtype: int64
8
