#### Source
https://www.kaggle.com/code/alturutin/watson-xlm-r-nli-inference

In [None]:
from datasets import load_dataset
import pandas as pd
import re
import string

#### Class that finds external-dataset rows whose premise duplicates a premise in our train/valid data, and deletes them

In [None]:
class DuplicateDeleter:
    """Drop rows of an external NLI dataset whose premise also occurs in a reference set.

    Premises are compared after :meth:`preprocess_query` normalisation. As a
    result, differences in case, in punctuation (except the apostrophe) and in
    runs of spaces are ignored.

    Args:
        valid_set: Reference DataFrame with ``premise`` and ``lang_abv``
            columns. :meth:`delete` adds an int ``duplicate`` column to it in
            place (1 if the row's premise occurs in the external dataset,
            else 0). Its index may contain repeated labels (e.g. the result of
            ``pd.concat`` without ``ignore_index``).
        external_dataset: DataFrame with a ``premise`` column. :meth:`delete`
            drops the matching rows from it in place. Rows are dropped by index
            label, so this index is assumed to be unique.
        name: Label for the external dataset, used only in printed messages.
    """

    # Every punctuation character except the apostrophe, so contractions such
    # as "it's" stay intact. re.escape keeps "]", "\\", "^" and "-" literal
    # inside the character class.
    _PUNCT_RE = re.compile(
        '[' + re.escape(''.join(c for c in string.punctuation if c != "'")) + ']'
    )
    _MULTI_SPACE_RE = re.compile('[ ]{2,}')

    def __init__(self, valid_set, external_dataset, name="external dataset"):
        self.valid_set = valid_set
        self.external_dataset = external_dataset
        self.name = name

    def preprocess_query(self, q):
        """Return ``q`` normalised for duplicate matching.

        The string is lower-cased, punctuation other than ``'`` is replaced by
        a space, and runs of two or more spaces are collapsed to one.
        ``q`` must be a ``str``.
        """
        q = q.lower()
        q = self._PUNCT_RE.sub(' ', q)
        return self._MULTI_SPACE_RE.sub(' ', q)

    def search_in_base(self, q, kb):
        """Return 1 if the normalised form of ``q`` is in ``kb``, else 0."""
        return int(self.preprocess_query(q) in kb)

    def delete(self):
        """Flag reference rows found externally and drop the matching external rows.

        Side effects:
            - Sets ``valid_set['duplicate']``.
            - Drops rows from ``external_dataset`` in place.
            - Stores ``external_dataset_preprocessed`` and ``knowledge_base``
              on the instance.

        Returns:
            The de-duplicated external dataset with a fresh 0..n-1 index.
        """
        original_length = self.external_dataset.shape[0]

        # Normalise the external premises once. The result is reused for both
        # the lookup set and the row mask below.
        self.external_dataset_preprocessed = self.external_dataset['premise'].apply(self.preprocess_query)
        self.knowledge_base = set(self.external_dataset_preprocessed)

        valid_preprocessed = self.valid_set['premise'].apply(self.preprocess_query)
        # .to_numpy() assigns positionally, so a repeated index on valid_set is harmless.
        self.valid_set['duplicate'] = (
            valid_preprocessed.isin(self.knowledge_base).astype(int).to_numpy()
        )

        is_english = (self.valid_set['lang_abv'] == 'en').to_numpy()
        if is_english.any():
            en_rate = self.valid_set['duplicate'].to_numpy()[is_english].mean() * 100
            print(f"fraction of valid set english premises occurring in {self.name} = {en_rate}%")
        else:
            print(f"no english premises in valid set; overlap rate with {self.name} not computed")

        # Drop every external row whose normalised premise equals the
        # normalised premise of any reference row. Matching on the premise
        # values avoids valid_set's index, which may be non-unique and need
        # not equal row positions.
        drop_mask = self.external_dataset_preprocessed.isin(set(valid_preprocessed)).to_numpy()
        index_to_delete = self.external_dataset_preprocessed.index[drop_mask]

        n_flagged = int(self.valid_set['duplicate'].sum())
        print(f"{n_flagged} reference rows have a premise found in {self.name}")
        print("index in external dataset to drop:", list(index_to_delete))
        self.external_dataset.drop(index=index_to_delete, inplace=True)

        print(original_length - self.external_dataset.shape[0], " duplicates deleted")
        return self.external_dataset.reset_index(drop=True)

#### Loading datasets

In [None]:
# Load data/valid.csv and display its (rows, columns) shape.
valid_set = pd.read_csv(filepath_or_buffer="data/valid.csv")
valid_set.shape

In [None]:
# Load data/train.csv and display its (rows, columns) shape.
train_set = pd.read_csv(filepath_or_buffer="data/train.csv")
train_set.shape

In [None]:
# Load data/test.csv and display its (rows, columns) shape.
test_set = pd.read_csv(filepath_or_buffer="data/test.csv")
test_set.shape

In [None]:
# Stack the validation rows and then the training rows into one reference set.
# ignore_index=True gives a fresh 0..n-1 index. Without it, both frames' labels
# start at 0, so the combined labels are neither unique nor equal to row
# positions.
val_train_set = pd.concat([valid_set, train_set], axis=0, ignore_index=True)

In [None]:
# NOTE(review): these "translated" frames are read from the same CSV files as
# valid_set / train_set. Confirm the intended translated file paths.
valid_set_translated = pd.read_csv("data/valid.csv")
train_set_translated = pd.read_csv("data/train.csv")
# Valid rows first, then train rows (the order used for val_train_set).
# ignore_index=True keeps index labels unique and equal to row positions.
val_train_set_translated = pd.concat(
    [valid_set_translated, train_set_translated], axis=0, ignore_index=True
)

### MNLI

In [None]:
mnli = load_dataset('glue', 'mnli')
# Convert the Arrow-backed split to pandas in one call instead of letting the
# DataFrame constructor iterate the Dataset row by row.
df_mnli = mnli["train"].to_pandas()
# Alternative split: df_mnli = mnli["validation_matched"].to_pandas()
# 'idx' is dropped here; only the remaining columns are exported.
df_mnli = df_mnli.drop(columns=['idx'])
original_count = df_mnli.shape[0]

In [None]:
# Remove MNLI rows whose premise also occurs in val_train_set_translated, then
# write the de-duplicated MNLI training CSV.
mnli_deleter = DuplicateDeleter(valid_set=val_train_set_translated, external_dataset=df_mnli)
mnli_deduplicated = mnli_deleter.delete()
mnli_deduplicated.to_csv("data/mnli_train.csv", index=False)

### SNLI

In [None]:
snli = load_dataset('snli')
# One Arrow -> pandas conversion instead of row-by-row iteration.
df_snli = snli["train"].to_pandas()
# Per the Hugging Face SNLI dataset card, pairs without a gold label carry
# label -1. They cannot serve as classification targets, so drop them.
df_snli = df_snli[df_snli["label"] != -1].reset_index(drop=True)

In [None]:
# Remove SNLI rows whose premise also occurs in val_train_set_translated, then
# write the de-duplicated SNLI training CSV.
snli_deleter = DuplicateDeleter(valid_set=val_train_set_translated, external_dataset=df_snli)
snli_deduplicated = snli_deleter.delete()
snli_deduplicated.to_csv("data/snli_train.csv", index=False)

### XNLI

In [None]:
xnli_languages = ['ar', 'bg', 'de', 'el', 'en', 'es', 'fr', 'hi', 'ru', 'sw', 'th', 'tr', 'ur', 'vi', 'zh']

# XNLI training split for each language, in the same order as xnli_languages.
dataframes = [load_dataset('xnli', lang)["train"].to_pandas() for lang in xnli_languages]

# Expose each language's frame as df_<lang>. The unpacking order must match
# xnli_languages.
(df_ar, df_bg, df_de, df_el, df_en, df_es, df_fr, df_hi,
 df_ru, df_sw, df_th, df_tr, df_ur, df_vi, df_zh) = dataframes

# A single concat avoids re-copying the accumulated frame on every iteration.
xnli_dataframe = pd.concat(dataframes, ignore_index=True)

In [None]:
# XNLI is de-duplicated against val_train_set, not the *_translated frames
# used for MNLI/SNLI.
xnli_deleter = DuplicateDeleter(valid_set=val_train_set, external_dataset=xnli_dataframe)
xnli_deleted = xnli_deleter.delete()

In [None]:
# Shuffle with a fixed seed so the exported subset is reproducible across runs,
# then export at most the first XNLI_MAX_ROWS shuffled rows.
XNLI_SHUFFLE_SEED = 42
XNLI_MAX_ROWS = 500_000
xnli_deleted = xnli_deleted.sample(frac=1, random_state=XNLI_SHUFFLE_SEED).reset_index(drop=True)
xnli_deleted.iloc[:XNLI_MAX_ROWS].to_csv("data/xnli_train.csv", index=False)