In [1]:
from datasets import load_dataset
import datasets
import random
import os
import pandas as pd

In [2]:
# Label-name tuples shared by several tasks (tuples are immutable, so sharing is safe).
_WSC_COLUMNS = ("span1_text", "span1_index", "span2_text", "span2_index", "text")
_FALSE_TRUE = ("false", "true")
_NEG_POS = ("negative", "positive")
_FIVE_STARS = ("terrible", "bad", "middle", "good", "wonderful")

# Task name -> dataset columns that form the input text.
# A second entry of None marks a task whose text is a single column.
task_to_keys = {
    # GLUE
    "mnli": ("premise", "hypothesis"),
    "qqp": ("question1", "question2"),
    "rte": ("sentence1", "sentence2"),
    "sst2": ("sentence", None),
    # SuperGLUE
    "boolq": ("passage", "question"),
    "copa": ("choice1", "choice2", "premise", "question"),
    "wic": ("start1", "end1", "sentence1", "start2", "end2", "sentence2", "word"),
    "wsc": _WSC_COLUMNS,
    "wsc_bool": _WSC_COLUMNS,
    "cb": ("premise", "hypothesis"),
    "record": ("passage", "query", "entities"),
    "multirc": ("question", "answer", "paragraph"),
    "rte_superglue": ("premise", "hypothesis"),
    # Other classification datasets
    "imdb": ("text", None),
    "ag_news": ("text", None),
    "yelp_review_full": ("text", None),
    "yahoo_answers_topics": ("question_content", "best_answer"),
    "dbpedia_14": ("title", "content"),
    "amazon": ("content", None),
}

# Task name -> label names (indexed by the integer class id) and output folder.
# The key order of this dict is the order in which tasks are processed.
task_to_labels = {
    "mnli": {
        "label": ("entailment", "neutral", "contradiction"),
        "fold": "mnli_csv",
    },
    "qqp": {
        "label": ("not_duplicate", "duplicate"),
        "fold": "qqp_csv",
    },
    "rte": {
        "label": ("entailment", "not_entailment"),
        "fold": "rte_csv",
    },
    "sst2": {"label": _NEG_POS, "fold": "sst2_csv"},
    "boolq": {"label": _FALSE_TRUE, "fold": "boolq_csv"},
    "copa": {"label": _FALSE_TRUE, "fold": "copa_csv"},
    "wic": {"label": _FALSE_TRUE, "fold": "wic_csv"},
    "cb": {
        "label": ("entailment", "contradiction", "neutral"),
        "fold": "cb_csv",
    },
    "multirc": {"label": _FALSE_TRUE, "fold": "multirc_csv"},
    "imdb": {"label": _NEG_POS, "fold": "imdb_csv"},
    "ag_news": {
        "label": ("world", "sports", "business", "science"),
        "fold": "ag_news_csv",
    },
    "yelp_review_full": {"label": _FIVE_STARS, "fold": "yelp_review_full_csv"},
    "yahoo_answers_topics": {
        "label": (
            "society and culture",
            "science",
            "health",
            "education and reference",
            "computers and internet",
            "sports",
            "business",
            "entertainment and music",
            "family and relationships",
            "politics and government",
        ),
        "fold": "yahoo_answers_csv",
    },
    "dbpedia_14": {
        "label": (
            "company",
            "educationalinstitution",
            "artist",
            "athlete",
            "officeholder",
            "meanoftransportation",
            "building",
            "naturalplace",
            "village",
            "animal",
            "plant",
            "album",
            "film",
            "writtenwork",
        ),
        "fold": "dbpedia_csv",
    },
    "amazon": {"label": _FIVE_STARS, "fold": "amazon_csv"},
}

# Task names that are loaded from the "glue" benchmark.
glue_datasets = [
    "cola",
    "sst2",
    "mrpc",
    "qqp",
    "stsb",
    "mnli",
    "mnli_mismatched",
    "mnli_matched",
    "qnli",
    "rte",
    "wnli",
    "ax",
]

# Task names that are loaded from the "super_glue" benchmark.
superglue_datasets = [
    "copa",
    "boolq",
    "wic",
    "wsc",
    "cb",
    "record",
    "multirc",
    "rte_superglue",
    "wsc_bool",
]

In [3]:
# Output is split into folders named <run>_<seed>/, each holding train.txt, valid.txt and test.txt.
runs = [0, 1, 2]  # run indices; one output folder per run and task
seed = 0  # base seed; the per-run value is seed + run * 100
fewshotnum = 16  # up to this many examples per label go to train.txt, and as many again to valid.txt

In [4]:
def preprocess_function(examples, task, label_key):
    """Turn one raw dataset row into a ``{"text", "target"}`` record.

    Args:
        examples: A single row, mapping column name to value.
        task: Task name; must be a key of ``task_to_keys`` and ``task_to_labels``.
        label_key: Name of the column that holds the integer class index.

    Returns:
        A dict with two entries: ``"text"``, the model input, and ``"target"``,
        the label name looked up in ``task_to_labels[task]["label"]``.

    Raises:
        ValueError: If the class index is negative. Without this check a
            negative index (e.g. -1 for an unlabeled row) would silently wrap
            around and select a label from the end of the tuple.
    """
    keys = task_to_keys[task]
    if keys[1] is not None:
        # Multi-column task: emit "column: value " for every column, in order
        # (the trailing space after the last column is kept on purpose).
        text = ''.join(key + ': ' + str(examples[key]) + ' ' for key in keys)
    else:
        # Single-column task: the raw column value is used unchanged.
        text = examples[keys[0]]

    label_index = examples[label_key]
    if label_index < 0:
        raise ValueError(
            "Negative label index %r for task %r; is this an unlabeled split?" % (label_index, task)
        )
    target = task_to_labels[task]['label'][label_index]
    return {"text": text, "target": target}

In [5]:
# Build the few-shot files for every task listed in task_to_labels.
# For each run and task this writes, under
#   longseqtextclsdata/<fold>/<run>_<seed>/
# the files train.txt, valid.txt and test.txt, one "<text>\t<label name>"
# example per line.
# NOTE: every dataset is reloaded and re-mapped once per run; hoist the loading
# out of the `runs` loop if this becomes too slow.


def _load_amazon_csv(csv_path):
    """Read a header-less CSV with columns (label, title, content) into a Dataset.

    One is subtracted from every label (presumably the raw labels are 1-based
    ratings and must become 0-based indices -- TODO confirm against the data).
    """
    df = pd.read_csv(csv_path, header=None)
    df = df.rename(columns={0: "label", 1: "title", 2: "content"})
    df['label'] = df['label'] - 1
    return datasets.Dataset.from_pandas(df)


def _group_texts_by_target(dataset):
    """Map each target label to the list of its texts, in dataset order.

    Tabs inside a text are replaced by spaces because the tab is the column
    separator of the output files. Only tabs are normalised: a newline inside
    a text would still break the one-example-per-line layout.
    """
    grouped = {}
    for item in dataset:
        grouped.setdefault(item["target"], []).append(item['text'].replace("\t", " "))
    return grouped


def _sample_up_to(rng, items, k):
    """Return k items drawn without replacement, or `items` itself if it has at most k."""
    if k < len(items):
        return rng.sample(items, k)
    return items


def _write_examples(path, grouped):
    """Write every text of `grouped` (label -> texts) as a "<text>\\t<label>" line."""
    # Explicit UTF-8 so that non-ASCII text does not depend on the platform default.
    with open(path, 'w', encoding='utf-8') as f:
        for label, texts in grouped.items():
            for text in texts:
                f.write(text + "\t" + label + "\n")


dataset_names = list(task_to_labels)
sampletestnum = 500  # maximum number of test examples kept per label

for run in runs:
    curr_seed = seed + run * 100
    for name in dataset_names:
        # To rebuild a single task only, skip the others here,
        # e.g. `if name != 'amazon': continue`.
        outpath = "longseqtextclsdata/" + task_to_labels[name]["fold"] + '/' + str(run) + '_' + str(curr_seed)
        print("out path: ", outpath)
        os.makedirs(outpath, exist_ok=True)

        label_key = 'label' if 'yahoo_' not in name else 'topic'
        print("Task label key: ", label_key)

        if name == 'mnli':
            dataset_train = load_dataset("LysandreJik/glue-mnli-train", split='train')
            dataset_test = load_dataset("LysandreJik/glue-mnli-train", split='validation')
        elif name == 'amazon':
            # Local CSV files; the directory can be overridden with the
            # AMAZON_DATA_DIR environment variable.
            prefix_path = os.environ.get(
                "AMAZON_DATA_DIR",
                '/home/nguyen/projects/prompt_cl/datasets/src/data/amazon',
            )
            dataset_train = _load_amazon_csv(os.path.join(prefix_path, "train.csv"))
            dataset_test = _load_amazon_csv(os.path.join(prefix_path, 'test.csv'))
        elif name not in glue_datasets and name not in superglue_datasets:
            print("Task not glue", name)
            dataset_train = load_dataset(name, split='train')
            dataset_test = load_dataset(name, split='test')
        else:
            benchmark = 'glue' if name in glue_datasets else 'super_glue'
            # Strip the local suffixes to get the benchmark's own config name.
            subset = name.replace('_superglue', '').replace('_bool', '')
            dataset_train = load_dataset(benchmark, subset, split="train")
            # The validation split is used as the test set for benchmark tasks.
            dataset_test = load_dataset(benchmark, subset, split="validation")

        # processed datasets:
        dataset_train = dataset_train.map(lambda x: preprocess_function(x, name, label_key))
        dataset_test = dataset_test.map(lambda x: preprocess_function(x, name, label_key))

        trainresult = _group_texts_by_target(dataset_train)
        testresult = _group_texts_by_target(dataset_test)

        # Seed a private generator per (run, task): the sample drawn for a task
        # is reproducible and does not depend on which other tasks were
        # processed before it.
        rng = random.Random(curr_seed)

        # Draw up to 2 * fewshotnum texts per label; the first fewshotnum go
        # to train, the remainder to valid.
        tousetrainres = {}
        tousevalidres = {}
        for key, texts in trainresult.items():
            picked = _sample_up_to(rng, texts, 2 * fewshotnum)
            tousetrainres[key] = picked[0:fewshotnum]
            tousevalidres[key] = picked[fewshotnum:2 * fewshotnum]

        sampletestres = {
            key: _sample_up_to(rng, texts, sampletestnum)
            for key, texts in testresult.items()
        }

        _write_examples(os.path.join(outpath, "train.txt"), tousetrainres)
        _write_examples(os.path.join(outpath, "valid.txt"), tousevalidres)
        _write_examples(os.path.join(outpath, "test.txt"), sampletestres)

out path:  longseqtextclsdata/amazon_csv/0_0
Task label key:  label


Map:   0%|          | 0/115000 [00:00<?, ? examples/s]

Map:   0%|          | 0/7600 [00:00<?, ? examples/s]

out path:  longseqtextclsdata/amazon_csv/1_100
Task label key:  label


Map:   0%|          | 0/115000 [00:00<?, ? examples/s]

Map:   0%|          | 0/7600 [00:00<?, ? examples/s]

out path:  longseqtextclsdata/amazon_csv/2_200
Task label key:  label


Map:   0%|          | 0/115000 [00:00<?, ? examples/s]

Map:   0%|          | 0/7600 [00:00<?, ? examples/s]