In [84]:
from itertools import combinations

import numpy as np
import pandas as pd
from tqdm import tqdm

# Read the three original splits (the first CSV column is used as the index)
# and stack them into one frame in train, test, val order.
df_train, df_test, df_val = (
    pd.read_csv(csv_path, index_col=0)
    for csv_path in (
        'dataset/original_data/train.csv',
        'dataset/original_data/test.csv',
        'dataset/original_data/val.csv',
    )
)

df = pd.concat([df_train, df_test, df_val])

In [85]:
def merge_duplicates(df: pd.DataFrame) -> pd.DataFrame:
    """Pair up every two rows of ``df`` that share the same ``sentence1``.

    Rows are ordered by ``sentence1`` (ties keep their original relative
    order). Within each run of equal ``sentence1`` values, every combination
    of an earlier row with a later row produces one output row.

    Args:
        df: Frame with the columns ``sentence1``, ``sentence2``, ``pronoun``,
            ``referent`` and ``wrong_referent``. It is not modified.

    Returns:
        A frame with a fresh ``0..n-1`` index and the columns ``sentence1``,
        ``sentence21``, ``sentence22``, ``pronoun``, ``referent1``,
        ``wrong_referent1``, ``referent2`` and ``wrong_referent2``.
        ``pronoun`` and the ``*1`` columns come from the earlier row of each
        pair, the ``*2`` columns from the later row. When no ``sentence1``
        value repeats, the frame is empty but still carries these columns.
    """
    columns = [
        'sentence1', 'sentence21', 'sentence22', 'pronoun',
        'referent1', 'wrong_referent1', 'referent2', 'wrong_referent2',
    ]

    # Sort a copy instead of the caller's frame, and use a stable algorithm so
    # the earlier/later roles inside a group are deterministic.
    ordered = df.sort_values(by='sentence1', kind='mergesort', ignore_index=True)

    records = []
    # sort=False keeps the groups in the order established by the sort above.
    # Rows whose sentence1 is missing are left out by groupby, which matches
    # the fact that NaN never compares equal to another NaN.
    for _, group in ordered.groupby('sentence1', sort=False):
        if len(group) < 2:
            continue
        rows = group.to_dict('records')
        for first, second in combinations(rows, 2):
            records.append({
                'sentence1': first['sentence1'],
                'sentence21': first['sentence2'],
                'sentence22': second['sentence2'],
                'pronoun': first['pronoun'],
                'referent1': first['referent'],
                'wrong_referent1': first['wrong_referent'],
                'referent2': second['referent'],
                'wrong_referent2': second['wrong_referent'],
            })

    # Build the frame once; passing ``columns`` keeps the schema even when
    # there are no pairs at all.
    return pd.DataFrame(records, columns=columns)

In [86]:
# Persist the merged pairs under dataset/, which is where the later cell reads
# 'dataset/dupes.csv' from; writing to the working directory instead left that
# read pointing at a file this notebook never produced.
merge_duplicates(df).to_csv('dataset/dupes.csv')

In [91]:
def split_data(df: pd.DataFrame, random_state=None) -> None:
    """Randomly split ``df`` into 60% / 20% / 20% parts and write them as CSV.

    The parts are written to ``train_60.csv``, ``valid_20.csv`` and
    ``test_20.csv`` in the current working directory, each with a fresh
    ``0..n-1`` index.

    Args:
        df: Frame to split. It is not modified.
        random_state: Optional seed (or random generator) forwarded to
            ``DataFrame.sample`` for both draws. The default ``None`` gives a
            different split on every call; pass an int for a reproducible one.
    """
    # drop() below removes rows by index label, so repeated labels would throw
    # away more rows than were sampled. Renumber on a copy to make every label
    # unique without touching the caller's frame.
    pool = df.reset_index(drop=True)

    df_train = pool.sample(frac=0.6, random_state=random_state)
    df_rest = pool.drop(df_train.index)

    # Half of the remaining 40% becomes validation, the other half test.
    df_valid = df_rest.sample(frac=0.5, random_state=random_state)
    df_test = df_rest.drop(df_valid.index)

    for part, path in (
        (df_train, 'train_60.csv'),
        (df_valid, 'valid_20.csv'),
        (df_test, 'test_20.csv'),
    ):
        part.reset_index(drop=True).to_csv(path)

In [92]:
# Load the merged duplicate pairs from dataset/dupes.csv, using the first CSV
# column as the index.
df_dupes = pd.read_csv('dataset/dupes.csv', index_col=0)

# Write the train/validation/test CSV files for these pairs.
split_data(df_dupes)

In [115]:
def check_duplicates(df: pd.DataFrame) -> bool:
    """Report whether ``df`` contains a redundant sentence pairing.

    A frame is flagged when any of the following holds:

    * a row pairs a sentence with itself (``sentence21 == sentence22``);
    * two rows carry the same ``(sentence1, sentence21, sentence22)`` triplet;
    * two rows carry mirrored triplets, i.e. ``(sentence1, a, b)`` and
      ``(sentence1, b, a)``.

    Args:
        df: Frame with the columns ``sentence1``, ``sentence21`` and
            ``sentence22``. It is not modified.

    Returns:
        ``True`` if at least one such redundancy exists, ``False`` otherwise.
    """
    seen = set()
    for sent1, sent21, sent22 in zip(df['sentence1'], df['sentence21'], df['sentence22']):
        # A row that pairs a sentence with itself.
        if sent21 == sent22:
            return True
        # NaN never equals itself, so rows without a sentence1 cannot belong
        # to a common group and are not tracked.
        if sent1 != sent1:
            continue
        # Compare against a freshly built mirrored tuple: list.reverse()
        # returns None, so testing its result for membership never matches.
        if (sent1, sent21, sent22) in seen or (sent1, sent22, sent21) in seen:
            return True
        seen.add((sent1, sent21, sent22))

    return False

In [117]:
# Print whether check_duplicates finds a redundant pairing in df_dupes.
print(check_duplicates(df_dupes))

False
