In [None]:
import os, sys
import random
import pandas as pd
import numpy as np
import torch
import datasets
from datasets import load_dataset

In [None]:
# Download XNLI Data
!wget https://dl.fbaipublicfiles.com/XNLI/XNLI-1.0.zip
!unzip XNLI-1.0.zip

# Load Dataset

In [None]:
def _load_configs(path, configs):
    """Load several configurations of one Hub dataset.

    Args:
        path: Dataset repository passed to ``load_dataset``.
        configs: Mapping from the key used in this notebook to the dataset
            configuration name.

    Returns:
        dict mapping each key to the object returned by ``load_dataset``,
        loaded in the insertion order of ``configs``.
    """
    return {key: load_dataset(path, name=config) for key, config in configs.items()}


# NusaTranslation Senti
nt_senti_dset = _load_configs("indonlp/nusatranslation_senti", {
    "btk": "nusatranslation_senti_btk_nusantara_text",
    "sun": "nusatranslation_senti_sun_nusantara_text",
    "jav": "nusatranslation_senti_jav_nusantara_text",
    "mad": "nusatranslation_senti_mad_nusantara_text",
    "mak": "nusatranslation_senti_mak_nusantara_text",
    "min": "nusatranslation_senti_min_nusantara_text",
})

# NusaTranslation MT
nt_mt_dset = _load_configs("indonlp/nusatranslation_mt", {
    "btk": "nusatranslation_mt_btk_ind_nusantara_t2t",
    "sun": "nusatranslation_mt_sun_ind_nusantara_t2t",
    "jav": "nusatranslation_mt_jav_ind_nusantara_t2t",
    "mad": "nusatranslation_mt_mad_ind_nusantara_t2t",
    "mak": "nusatranslation_mt_mak_ind_nusantara_t2t",
    "min": "nusatranslation_mt_min_ind_nusantara_t2t",
})

# NusaX Senti
# NOTE(review): the NusaX dictionaries reuse the NusaTranslation keys, so "btk"
# loads the "bbc" config and "mak" loads the "bug" config -- confirm the "mak" ->
# "bug" substitution is intended.
nusax_senti_dset = _load_configs("indonlp/NusaX-senti", {
    "btk": "bbc",
    "sun": "sun",
    "jav": "jav",
    "mad": "mad",
    "mak": "bug",
    "min": "min",
})

# NusaX MT ind
nusax_mt_ind_dset = _load_configs("indonlp/NusaX-MT", {
    "btk": "bbc-ind",
    "sun": "sun-ind",
    "jav": "jav-ind",
    "mad": "mad-ind",
    "mak": "bug-ind",
    "min": "min-ind",
})

# NusaX MT eng (Extended experiment)
nusax_mt_eng_dset = _load_configs("indonlp/NusaX-MT", {
    "btk": "bbc-eng",
    "sun": "sun-eng",
    "jav": "jav-eng",
    "mad": "mad-eng",
    "mak": "bug-eng",
    "min": "min-eng",
})

# MasakhaNews
masakhanews_dset = _load_configs("masakhane/masakhanews", {
    "amh": "amh",
    "hau": "hau",
    "ibo": "ibo",
    "lug": "lug",
    "pcm": "pcm",
    "sna": "sna",
    "swa": "swa",
    "xho": "xho",
    "yor": "yor",
})

# MAFAND
mafand_dset = _load_configs("masakhane/mafand", {
    "amh": "en-amh",
    "hau": "en-hau",
    "ibo": "en-ibo",
    "lug": "en-lug",
    "pcm": "en-pcm",
    "sna": "en-sna",
    "swa": "en-swa",
    "xho": "en-xho",
    "yor": "en-yor",
})

# AmericasNLI
americasnli_dset = _load_configs("americas_nli", {
    "aym": "aym",
    "bzd": "bzd",
    "cni": "cni",
    "gn": "gn",
    "hch": "hch",
    "nah": "nah",
    "oto": "oto",
    "quy": "quy",
    "shp": "shp",
    "tar": "tar",
})

# Standardize Dataset

### NLU Dataset

In [None]:
####
# Output schemas produced by this cell:
#   Single-Sentence Classification [text, label]
#   Pair-Sentence Classification [text_1, text_2, label]
####
# Directory under which the standardized datasets are written
save_path = '.'

def label2str(row, dset):
    """Add the string name of a row's integer label under ``str_label``.

    Args:
        row: Mutable mapping holding an integer class id in ``row['label']``.
        dset: Object whose ``features['label']`` provides ``int2str``; it
            supplies the id -> name mapping.

    Returns:
        The same ``row`` object, modified in place with the new ``str_label`` key.
    """
    label_feature = dset.features['label']
    row['str_label'] = label_feature.int2str(row['label'])
    return row

# Each NLU corpus below is reduced to its `test` split, and its integer `label`
# column is replaced by the matching label name (via label2str).

# Process NusaTranslation Senti (drop the `id` column)
nt_senti_test = datasets.DatasetDict({
    lang: splits['test'].remove_columns(['id'])
    for lang, splits in nt_senti_dset.items()
})
# NOTE(review): labels are decoded with the NusaX-senti label feature rather than
# the corpus' own -- confirm both corpora use the same class order.
nt_senti_decoded = nt_senti_test.map(
    label2str, remove_columns=['label'], fn_kwargs={"dset": nusax_senti_dset['jav']['train']}
)
nt_senti_dset_clean = nt_senti_decoded.rename_columns({'str_label': 'label'})

# Process NusaX Senti (drop the `id` and `lang` columns)
nusax_senti_test = datasets.DatasetDict({
    lang: splits['test'].remove_columns(['id', 'lang'])
    for lang, splits in nusax_senti_dset.items()
})
nusax_senti_decoded = nusax_senti_test.map(
    label2str, remove_columns=['label'], fn_kwargs={"dset": nusax_senti_dset['jav']['train']}
)
nusax_senti_dset_clean = nusax_senti_decoded.rename_columns({'str_label': 'label'})

# Process MasakhaNews (news-topic classification): the headline becomes `text`,
# and every language is decoded with its own label feature.
masakhanews_dset_clean = {}
for lang, splits in masakhanews_dset.items():
    headlines = splits['test'].remove_columns(['text', 'headline_text', 'url'])
    headlines = headlines.rename_columns({'headline': 'text'})
    decoded = headlines.map(
        label2str, remove_columns=['label'], fn_kwargs={"dset": headlines}
    )
    masakhanews_dset_clean[lang] = decoded.rename_columns({'str_label': 'label'})
masakhanews_dset_clean = datasets.DatasetDict(masakhanews_dset_clean)

# Process AmericasNLI (all columns kept; labels decoded with the 'aym' test split's feature)
americasnli_test = datasets.DatasetDict({
    lang: splits['test'] for lang, splits in americasnli_dset.items()
})
americasnli_decoded = americasnli_test.map(
    label2str, remove_columns=['label'], fn_kwargs={"dset": americasnli_dset['aym']['test']}
)
americasnli_dset_clean = americasnli_decoded.rename_columns({'str_label': 'label'})

# Save all NLU test sets
nt_senti_dset_clean.save_to_disk(f'{save_path}/nt_senti_test_dset')
nusax_senti_dset_clean.save_to_disk(f'{save_path}/nusax_senti_test_dset')
masakhanews_dset_clean.save_to_disk(f'{save_path}/masakhanews_test_dset')
americasnli_dset_clean.save_to_disk(f'{save_path}/americasnli_test_dset')

### MT Dataset

In [None]:
####
# MT [text_1, text_2]
####
# Directory under which the standardized MT datasets are written
save_path = '.'


def _merge_mt_splits(mt_dset):
    """Collapse every split of each NusaX-MT language pair into one dataset.

    Args:
        mt_dset: Mapping from language key to a mapping of split name -> dataset
            (e.g. ``nusax_mt_ind_dset``).

    Returns:
        datasets.DatasetDict keyed by language. Each value is the concatenation
        of all splits, in split order, with the ``id``, ``text_1_lang`` and
        ``text_2_lang`` columns removed.
    """
    merged = {}
    for key, splits in mt_dset.items():
        merged[key] = datasets.concatenate_datasets([
            split_dset.remove_columns(['id', 'text_1_lang', 'text_2_lang'])
            for split_dset in splits.values()
        ])
    return datasets.DatasetDict(merged)


# Process NusaX MT ind
nusax_mt_ind_dset_clean = _merge_mt_splits(nusax_mt_ind_dset)

# Process NusaX MT eng
nusax_mt_eng_dset_clean = _merge_mt_splits(nusax_mt_eng_dset)

# Process MAFAND: flatten the nested `translation` dict into text_1 (the language
# named by `key`) and text_2 ('en'), pooling all splits.
mafand_dset_clean = {}
for key, splits in mafand_dset.items():
    tmp_dset = {'text_1': [], 'text_2': []}
    for split_dset in splits.values():
        # Iterate the rows once; indexing split_dset[i] separately for each field
        # fetched and decoded every row twice.
        for example in split_dset:
            pair = example['translation']
            tmp_dset['text_1'].append(pair[key])
            tmp_dset['text_2'].append(pair['en'])
    mafand_dset_clean[key] = datasets.Dataset.from_dict(tmp_dset)
mafand_dset_clean = datasets.DatasetDict(mafand_dset_clean)

# Save all datasets
nusax_mt_ind_dset_clean.save_to_disk(f'{save_path}/nusax_mt_ind_dset')
nusax_mt_eng_dset_clean.save_to_disk(f'{save_path}/nusax_mt_eng_dset')
mafand_dset_clean.save_to_disk(f'{save_path}/mafand_mt_dset')

### Merge Dataset

In [None]:
# Reload the NusaX MT datasets from the directories the MT cell saved them to.
# NOTE: `save_path` is not assigned in this cell; an earlier cell must have defined it.
nusax_mt_ind_dset_clean = datasets.load_from_disk(f'{save_path}/nusax_mt_ind_dset')
nusax_mt_eng_dset_clean = datasets.load_from_disk(f'{save_path}/nusax_mt_eng_dset')

In [None]:
####
# NusaX Combined [text_1, text_2, label]
####
# Directory under which the combined datasets are written
save_path = '.'

def label2str(row, dset):
    """Add the string name of a row's integer label under ``str_label``.

    Same helper as in the NLU cell; it is defined again here so this cell can
    run without that one.

    Args:
        row: Mutable mapping holding an integer class id in ``row['label']``.
        dset: Object whose ``features['label']`` provides ``int2str``.

    Returns:
        The same ``row`` object, modified in place with the new ``str_label`` key.
    """
    label_feature = dset.features['label']
    row['str_label'] = label_feature.int2str(row['label'])
    return row

def _combine_nusax_senti_with_mt(mt_dset_clean):
    """Attach NusaX-senti labels to the split-merged NusaX-MT sentence pairs.

    For every language key in ``nusax_senti_dset``, all sentiment splits are
    concatenated (with ``id``, ``lang`` and ``text`` removed) and joined
    column-wise with ``mt_dset_clean[key]``. The integer ``label`` column is
    then replaced by its string name.

    NOTE(review): the column-wise join pairs rows purely by position. It assumes
    NusaX-senti and NusaX-MT list the same examples in the same split and row
    order -- confirm.

    Args:
        mt_dset_clean: Mapping from language key to the merged MT dataset
            (e.g. ``nusax_mt_ind_dset_clean``).

    Returns:
        datasets.DatasetDict keyed by language, with a string-valued ``label`` column.
    """
    combined = {}
    for key, senti_splits in nusax_senti_dset.items():
        labels = datasets.concatenate_datasets([
            split_dset.remove_columns(['id', 'lang', 'text'])
            for split_dset in senti_splits.values()
        ])
        combined[key] = datasets.concatenate_datasets([labels, mt_dset_clean[key]], axis=1)
    combined = datasets.DatasetDict(combined)
    # label2str is passed directly: map() forwards fn_kwargs as keyword arguments,
    # so a one-argument `lambda x: label2str(x)` wrapper fails with a TypeError on `dset`.
    return combined.map(
        label2str, remove_columns=['label'], fn_kwargs={"dset": nusax_senti_dset['jav']['train']}
    ).rename_columns({'str_label': 'label'})


# Process NusaX Combined ind
nusax_combined_ind_dset = _combine_nusax_senti_with_mt(nusax_mt_ind_dset_clean)

# Process NusaX Combined eng
nusax_combined_eng_dset = _combine_nusax_senti_with_mt(nusax_mt_eng_dset_clean)

nusax_combined_ind_dset.save_to_disk(f'{save_path}/nusax_combined_ind_dset')
nusax_combined_eng_dset.save_to_disk(f'{save_path}/nusax_combined_eng_dset')

In [None]:
####
# AmericasNLI - XNLI Combined
#
# For every AmericasNLI language, builds a dataset with the columns
# [premise_1, hypothesis_1, label] taken from the AmericasNLI dev TSV and
# [premise_2, hypothesis_2] taken from the XNLI dev TSV, joined column-wise.
# Requires the XNLI-1.0 archive to be extracted in the working directory and
# reads the AmericasNLI TSV files over HTTP from GitHub.
####
save_path = '.'

# Cannot use the XNLI from HuggingFace, somehow the results are not aligned,
# so we use the original XNLI file (https://dl.fbaipublicfiles.com/XNLI/XNLI-1.0.zip) instead
# reset_index() copies the row numbers into an `index` column and leaves a fresh 0..n-1 index
xnli_df = pd.read_csv('XNLI-1.0/xnli.dev.tsv', sep='\t').reset_index()
americasnli_combined_dset = {}
for key in americasnli_dset.keys():
    # Dev split of the AmericasNLI release for this language
    anli_df = pd.read_csv(f'https://github.com/abteen/americasnli/raw/main/data/anli_final/dev/{key}.tsv', sep='\t')
    anli_dset = datasets.Dataset.from_pandas(
        anli_df[['premise', 'hypothesis', 'label']].rename({
            'premise': 'premise_1', 'hypothesis': 'hypothesis_1'
        }, axis='columns')
    )    
    
    xnli_dset = datasets.Dataset.from_pandas(
        # Label-based lookup of the XNLI rows whose index equals `id - 1`.
        # NOTE(review): assumes each AmericasNLI `id` is the 1-based row number of
        # its source example in xnli.dev.tsv -- TODO confirm.
        xnli_df.loc[anli_df.id-1, ['sentence1', 'sentence2']].rename({
            'sentence1': 'premise_2', 'sentence2': 'hypothesis_2'
        }, axis='columns')
    # from_pandas stores the selected rows' index as `__index_level_0__`; drop it
    ).remove_columns('__index_level_0__')
    # axis=1 places the two datasets side by side, pairing rows by position
    americasnli_combined_dset[key] = datasets.concatenate_datasets([anli_dset, xnli_dset], axis=1)
americasnli_combined_dset = datasets.DatasetDict(americasnli_combined_dset)
americasnli_combined_dset.save_to_disk(f'{save_path}/americasnli_combined_dev_dset')

In [None]:
####
# MAFAND Random Label [text_1, text_2, label]
#
# Same flattening of MAFAND as in the MT cell, plus a `label` column holding a
# MasakhaNews label name drawn uniformly at random for every sentence pair.
# NOTE: this rebinds `mafand_dset_clean`, replacing the unlabeled version built
# in the MT cell.
####
# Fixed seed so the random labels are reproducible
random.seed(12345)
save_path = '.'

# Process MAFAND
mafand_dset_clean = {}
for key, splits in mafand_dset.items():
    # Candidate labels come from the MasakhaNews train split of the same language
    label_names = masakhanews_dset[key]['train'].features['label'].names
    tmp_dset = {'text_1': [], 'text_2': [], 'label': []}
    for split_dset in splits.values():
        # Iterate the rows once; indexing split_dset[i] separately for each field
        # fetched and decoded every row twice. One random.choice call per row, in
        # row order, keeps the drawn labels identical for the fixed seed.
        for example in split_dset:
            pair = example['translation']
            tmp_dset['text_1'].append(pair[key])
            tmp_dset['text_2'].append(pair['en'])
            tmp_dset['label'].append(random.choice(label_names))
    mafand_dset_clean[key] = datasets.Dataset.from_dict(tmp_dset)
mafand_dset_clean = datasets.DatasetDict(mafand_dset_clean)
mafand_dset_clean.save_to_disk(f'{save_path}/mafand_rand_label_dset')