<a href="https://colab.research.google.com/github/MoritzLaurer/zeroshot-classifier/blob/main/data_nli_formatting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Install and setup

In [None]:
#!pip install transformers[sentencepiece]~=4.33.0 -qq
!pip install datasets~=2.14.0 -qq

In [None]:
## load packages
import torch
from torch.utils.data import DataLoader

import pandas as pd
import numpy as np
import os
from datasets import load_dataset
import re
import time
import random
import tqdm

from datasets import load_dataset, Dataset, DatasetDict, concatenate_datasets
from datasets import ClassLabel

from google.colab.data_table import DataTable
from google.colab import data_table
from IPython.display import display
data_table.enable_dataframe_formatter() # https://colab.research.google.com/notebooks/data_table.ipynb#scrollTo=JgBtx0xFFv_i  -- Colab-only: switches DataFrame output to Colab's data_table formatter (see link)

## set global seed for reproducibility and against seed hacking
SEED_GLOBAL = 42
# Seed every global RNG this notebook imports (numpy, Python's `random`, torch),
# not only numpy. Otherwise any code relying on `random` or torch would not be reproducible.
np.random.seed(SEED_GLOBAL)
random.seed(SEED_GLOBAL)
torch.manual_seed(SEED_GLOBAL)

In [None]:
# info on the GPU you are using (IPython shell escape; needs the NVIDIA driver / a GPU runtime)
!nvidia-smi
# info on available ram
from psutil import virtual_memory
# NOTE: `.total` is the runtime's total physical memory (in GB, 1 GB = 1e9 bytes),
# not the currently free memory, despite the word "available" in the message below.
ram_gb = virtual_memory().total / 1e9
print('\n\nYour runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

Thu Sep 28 14:08:07 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   33C    P8     9W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
## connect to google drive
# mounts Google Drive under /content/drive (no-op remount if already mounted)
from google.colab import drive
drive.mount('/content/drive', force_remount=False)

#set wd
# All relative paths used later (e.g. './datasets_clean', './datasets_standardized')
# resolve against this directory.
# NOTE(review): hard-coded, user-specific Drive path -- adjust for your own Drive layout.
print(os.getcwd())
os.chdir("/content/drive/My Drive/PhD/zero-shot-models")
print(os.getcwd())

# local config.py file with tokens
# presumably importable because it lives in the new working directory -- confirm;
# keep this file out of version control since it holds tokens.
import config

Mounted at /content/drive
/content
/content/drive/My Drive/PhD/zero-shot-models


### Overarching functions

In [None]:
# functions for formatting the data into the universal NLI format
# note that train and test data needs to be handled differently

def format_nli_trainset(df_train=None, hypo_label_dic=None, random_seed=42):
  """Convert a classification train set into the binary NLI training format.

  For each (label_text, hypothesis) pair in `hypo_label_dic`:
    * every row of `df_train` with that label_text is paired with the hypothesis and
      gets label 0 (entailment);
    * min(n_entail, n_other) rows with a different label_text are sampled and paired
      with the same hypothesis, getting label 1 (not_entailment).
  All resulting rows are shuffled.

  Args:
    df_train: DataFrame with at least a `label_text` column. It is not modified.
    hypo_label_dic: dict mapping label_text -> hypothesis string. Classes whose
      hypothesis is missing (e.g. np.nan) are skipped: no rows are created for that
      hypothesis. Their texts can still be sampled as not_entail examples for other
      hypotheses.
    random_seed: random_state used for the not_entail sampling and the final shuffle.

  Returns:
    A new DataFrame with the columns of `df_train` plus `hypothesis` and integer
    `labels` (0 = entailment, 1 = not_entailment).

  Raises:
    ValueError: if `df_train` or `hypo_label_dic` is missing, or if no class in
      `hypo_label_dic` has a usable (non-missing) hypothesis.
  """
  if df_train is None or hypo_label_dic is None:
    raise ValueError("format_nli_trainset requires both df_train and hypo_label_dic.")

  print("\nFor NLI: Augmenting data by adding random not_entail examples to the train set from other classes within the train set.")
  print(f"Length of df_train before this step is: {len(df_train)}.\n")
  print(f"Max augmentation can be: len(df_train) * 2 = {len(df_train)*2}. Can also be lower, if there are more entail examples than not-entail for a majority class")

  df_train_lst = []
  for label_text, hypothesis in hypo_label_dic.items():
    # Task dictionaries deliberately map unclear/noisy classes to np.nan.
    # Pairing texts with a NaN hypothesis would create unusable training rows.
    if pd.isna(hypothesis):
      continue
    ## entailment/true
    df_entail = df_train[df_train.label_text == label_text].copy(deep=True)
    df_entail["hypothesis"] = hypothesis
    df_entail["labels"] = 0
    ## not_entailment/not_true
    df_not_entail = df_train[df_train.label_text != label_text].copy(deep=True)
    # could try weighing the sample texts for not_entail here, e.g. to get the same n texts for each label
    n_not_entail = min(len(df_entail), len(df_not_entail))  # can try sampling more not_entail here
    df_not_entail = df_not_entail.sample(n=n_not_entail, random_state=random_seed)
    df_not_entail["hypothesis"] = hypothesis
    df_not_entail["labels"] = 1
    df_train_lst.append(pd.concat([df_entail, df_not_entail]))

  if not df_train_lst:
    # pd.concat([]) would otherwise fail with an unhelpful "No objects to concatenate"
    raise ValueError("hypo_label_dic contains no class with a usable (non-missing) hypothesis.")
  df_train_nli = pd.concat(df_train_lst)

  # shuffle
  df_train_nli = df_train_nli.sample(frac=1, random_state=random_seed)
  df_train_nli["labels"] = df_train_nli.labels.apply(int)
  n_not_entail_added = int((df_train_nli["labels"] == 1).sum())
  print(f"For NLI: {n_not_entail_added} not_entail training examples were added, which leads to an augmented training dataset of length {len(df_train_nli)}.")

  return df_train_nli.copy(deep=True)


def format_nli_testset(df_test=None, hypo_label_dic=None):
  """Pair every test text with every class hypothesis (NLI test format).

  Each input row becomes K consecutive rows, where K = len(hypo_label_dic). There is
  one row per hypothesis, in the iteration order of `hypo_label_dic`. `labels` is 0
  for the hypothesis that equals the row's own class hypothesis and 1 for all others
  (a missing/NaN hypothesis never compares equal, so it is always 1). The original
  index is repeated K times and row order is kept; evaluation relies on this order,
  so the result must not be sampled or shuffled afterwards.

  Args:
    df_test: DataFrame with at least a `label_text` column. It is not modified.
    hypo_label_dic: dict mapping label_text -> hypothesis string.

  Returns:
    A new, exploded DataFrame with added `hypothesis` and `labels` columns.
  """
  all_hypotheses = list(hypo_label_dic.values())
  print("Number of hypotheses/classes: ", len(all_hypotheses), "\n")

  # one 0/1 vector per class, aligned with `all_hypotheses`
  label_vector_by_class = {
    class_name: [0 if own_hypothesis == candidate else 1 for candidate in all_hypotheses]
    for class_name, own_hypothesis in hypo_label_dic.items()
  }

  df_nli = df_test.copy(deep=True)
  df_nli["labels"] = df_nli.label_text.map(label_vector_by_class)
  df_nli["hypothesis"] = [all_hypotheses] * len(df_nli)
  print(f"For normal test, N classifications necessary: {len(df_nli)}")

  # Explode into K rows per text (1 entail + K-1 not_entail). Sampling after this
  # point would break the row order that evaluation depends on.
  print("Reminder: do not sample these test-sets anymore after formatting. Row order needs to be preserved for testing.")
  df_nli = df_nli.explode(["hypothesis", "labels"])  # multi-column explode needs pandas >= 1.3.0
  print(f"For NLI test, N classifications necessary: {len(df_nli)}\n")

  return df_nli


### Formulate hypotheses for each task/class

In [None]:

# Hypothesis templates per task: {task_name: {label_text: hypothesis}}.
# The label_text keys are expected to match the `label_text` values of the
# corresponding dataset. A value of np.nan marks a class that is deliberately left
# without a hypothesis (unclear or noisy category).
# Hypotheses within one task keep parallel wording and punctuation, so that
# surface differences (e.g. a missing full stop) cannot act as a spurious cue.
task_hypotheses = {
    "wellformedquery":
        {
        "well_formed": "This text is a well formed Google query.",
        "not_well_formed": "This text is not a well formed Google query."
        },
    "biasframes_sex":
        {
        "sex": "This text contains allusions to sexual content.",
        "not_sex": "This text does not contain allusions to sexual content."
        },
    "biasframes_intent":
        {
        "intent": "The intent of this text is to be offensive/disrespectful.",
        "not_intent": "The intent of this text is not to be offensive/disrespectful."
        },
    "biasframes_offensive":
        {
        "offensive": "This text could be considered offensive, disrespectful, or toxic.",
        "not_offensive": "This text could not be considered offensive, disrespectful, or toxic."
        },
    "financialphrasebank":
        {
        "negative": "The sentiment in this text is negative from an investor's perspective.",
        "neutral": "The sentiment in this text is neutral from an investor's perspective.",
        "positive": "The sentiment in this text is positive from an investor's perspective."
        },
    "rottentomatoes":
        {
        "negative": "The sentiment in this rotten tomatoes movie review is negative",
        "positive": "The sentiment in this rotten tomatoes movie review is positive"
        },
    "amazonpolarity":
        {
        "negative": "The sentiment in this amazon product review is negative",
        "positive": "The sentiment in this amazon product review is positive"
        },
    "imdb":
        {
        "negative": "The sentiment in this imdb movie review is negative",
        "positive": "The sentiment in this imdb movie review is positive"
        },
    "appreviews":
        {
        "positive": "The sentiment in this app review is positive.",
        "negative": "The sentiment in this app review is negative."
        },
    "yelpreviews":
        {
        "positive": "The sentiment in this yelp review is positive.",
        "negative": "The sentiment in this yelp review is negative."
        },
    'wikitoxic_toxicaggregated':
        {
        'toxicaggregated': 'This wikipedia comment contains toxic language.',
        'not_toxicaggregated': 'This wikipedia comment does not contain toxic language.'
        },
    'wikitoxic_obscene':
        {
        'obscene': 'This wikipedia comment contains obscene language.',
        'not_obscene': 'This wikipedia comment does not contain obscene language.'
        },
    'wikitoxic_threat':
        {
        'threat': 'This wikipedia comment contains a threat.',
        'not_threat': 'This wikipedia comment does not contain a threat.'
        },
    'wikitoxic_insult':
        {
        'insult': 'This wikipedia comment contains an insult.',
        'not_insult': 'This wikipedia comment does not contain an insult.'
        },
    'wikitoxic_identityhate':
        {
        'identityhate': 'This wikipedia comment contains identity hate.',
        'not_identityhate': 'This wikipedia comment does not contain identity hate.'
        },
    "hateoffensive":
        {
        "hate_speech": "This tweet contains hate speech.",
        "offensive": "This tweet contains offensive language without hate speech.",
        "neither": "This tweet contains neither offensive language nor hate speech.",
        },
    "hatexplain":
        {
        "hate_speech": "This text from twitter or gab contains hate speech.",
        "offensive": "This text from twitter or gab contains offensive language without hate speech.",
        "neither": "This text from twitter or gab contains neither offensive language nor hate speech.",
        },
    "spam":
        {
        "spam": "This sms is spam.",
        "not_spam": "This sms is not spam.",
        },
    "emotiondair":
        {
        'anger': "This tweet expresses the emotion: anger",
        'fear': "This tweet expresses the emotion: fear",
        'joy': "This tweet expresses the emotion: joy",
        'love': "This tweet expresses the emotion: love",
        'sadness': "This tweet expresses the emotion: sadness",
        'surprise': "This tweet expresses the emotion: surprise",
        },
    "emocontext":
        {
        'angry': "This tweet expresses the emotion: anger",
        'sad': "This tweet expresses the emotion: sadness",
        'happy': "This tweet expresses the emotion: happiness",
        'others': "This tweet does not express any of the emotions: anger, sadness, or happiness",
        },
    "empathetic":
        {
        'afraid': 'The main emotion of this dialogue is: afraid',
        'angry': 'The main emotion of this dialogue is: angry',
        'annoyed': 'The main emotion of this dialogue is: annoyed',
        'anticipating': 'The main emotion of this dialogue is: anticipating',
        'anxious': 'The main emotion of this dialogue is: anxious',
        'apprehensive': 'The main emotion of this dialogue is: apprehensive',
        'ashamed': 'The main emotion of this dialogue is: ashamed',
        'caring': 'The main emotion of this dialogue is: caring',
        'confident': 'The main emotion of this dialogue is: confident',
        'content': 'The main emotion of this dialogue is: content',
        'devastated': 'The main emotion of this dialogue is: devastated',
        'disappointed': 'The main emotion of this dialogue is: disappointed',
        'disgusted': 'The main emotion of this dialogue is: disgusted',
        'embarrassed': 'The main emotion of this dialogue is: embarrassed',
        'excited': 'The main emotion of this dialogue is: excited',
        'faithful': 'The main emotion of this dialogue is: faithful',
        'furious': 'The main emotion of this dialogue is: furious',
        'grateful': 'The main emotion of this dialogue is: grateful',
        'guilty': 'The main emotion of this dialogue is: guilty',
        'hopeful': 'The main emotion of this dialogue is: hopeful',
        'impressed': 'The main emotion of this dialogue is: impressed',
        'jealous': 'The main emotion of this dialogue is: jealous',
        'joyful': 'The main emotion of this dialogue is: joyful',
        'lonely': 'The main emotion of this dialogue is: lonely',
        'nostalgic': 'The main emotion of this dialogue is: nostalgic',
        'prepared': 'The main emotion of this dialogue is: prepared',
        'proud': 'The main emotion of this dialogue is: proud',
        'sad': 'The main emotion of this dialogue is: sad',
        'sentimental': 'The main emotion of this dialogue is: sentimental',
        'surprised': 'The main emotion of this dialogue is: surprised',
        'terrified': 'The main emotion of this dialogue is: terrified',
        'trusting': 'The main emotion of this dialogue is: trusting'
        },
    "agnews":
        {
        'Business': "This news text is about business news",
        'Sci/Tech': "This news text is about science and technology",
        'Sports': "This news text is about sports",
        'World': "This news text is about world news"
        },
    "yahootopics":
        {
        'Business & Finance': 'This question from the Yahoo Q&A forum is categorized in the topic: Business & Finance',
        'Computers & Internet': 'This question from the Yahoo Q&A forum is categorized in the topic: Computers & Internet',
        'Education & Reference': 'This question from the Yahoo Q&A forum is categorized in the topic: Education & Reference',
        'Entertainment & Music': 'This question from the Yahoo Q&A forum is categorized in the topic: Entertainment & Music',
        'Family & Relationships': 'This question from the Yahoo Q&A forum is categorized in the topic: Family & Relationships',
        'Health': 'This question from the Yahoo Q&A forum is categorized in the topic: Health',
        'Politics & Government': 'This question from the Yahoo Q&A forum is categorized in the topic: Politics & Government',
        'Science & Mathematics': 'This question from the Yahoo Q&A forum is categorized in the topic: Science & Mathematics',
        'Society & Culture': 'This question from the Yahoo Q&A forum is categorized in the topic: Society & Culture',
        'Sports': 'This question from the Yahoo Q&A forum is categorized in the topic: Sports'
        },
    "massive":
        {
        'datetime_query': "The intent of this utterance is a datetime query.",
        'iot_hue_lightchange': "The intent of this utterance is changing the light.",
        'transport_ticket': "This utterance is about transport tickets.",
        'takeaway_query': "This utterance is about takeaway food.",
        'qa_stock': "This utterance is about stocks.",
        'general_greet': "This utterance is a general greet.",
        'recommendation_events': "This utterance is about event recommendations.",
        'music_dislikeness': "The intent of this utterance is signalling music dislike.",
        'iot_wemo_off': "The intent of this utterance is turning an IoT device off.",
        'cooking_recipe': "This utterance is about cooking recipes.",
        'qa_currency': "This utterance is about currencies.",
        'transport_traffic': "This utterance is about transport or traffic.",
        'general_quirky': np.nan,  # unclear category, better to exclude
        'weather_query': "This utterance is a query about the weather.",
        'audio_volume_up': "The intent of this utterance is turning the audio volume up.",
        'email_addcontact': "The intent of this utterance is adding an email address to contacts.",
        'takeaway_order': "The intent of this utterance is to order takeaway food.",
        'email_querycontact': "The intent of this utterance is to query contact details.",
        'iot_hue_lightup': "The intent of this utterance is to brighten lights.",
        'recommendation_locations': "The intent of this utterance is receiving recommendations for good locations.",
        'play_audiobook': "The utterance is related to playing audiobooks.",
        'lists_createoradd': "The utterance is related to creating or adding to lists.",
        'news_query': "The utterance is a query about the news.",
        'alarm_query': "The utterance is a query about alarms.",
        'iot_wemo_on': "The intent of the utterance is to turn an IoT device on.",
        'general_joke': "The intent of the utterance is to hear a joke.",
        'qa_definition': "The utterance is a query about a definition.",
        'social_query': "The utterance is a query about a social network.",
        'music_settings': "The intent of the utterance is to change music settings.",
        'audio_volume_other': "The utterance is related to audio volume.",
        'calendar_remove': "The intent of the utterance is to remove something from a calendar.",
        'iot_hue_lightdim': "The intent of the utterance is to dim the lights.",
        'calendar_query': "The utterance is a query about a calendar.",
        'email_sendemail': "The intent of the utterance is to send an email.",
        'iot_cleaning': "The intent of the utterance is for an IoT device to start cleaning.",
        'audio_volume_down': "The intent of the utterance is to lower the volume.",
        'play_radio': "The intent of the utterance is to play something on the radio.",
        'cooking_query': "The utterance is a query about cooking.",
        'datetime_convert': "The utterance is related to date time changes or conversion.",
        'qa_maths': "The utterance is a question about maths.",
        'iot_hue_lightoff': "The utterance is related to turning the lights off.",
        'iot_hue_lighton': "The utterance is related to turning the lights on.",
        'transport_query': "The utterance is a query about transport or travels.",
        'music_likeness': "The utterance is related to liking music.",
        'email_query': "The utterance is a query about emails.",
        'play_music': "The intent of this utterance is for an IoT device to play music.",
        'audio_volume_mute': "The intent of this utterance is to mute the volume.",
        'social_post': "The utterance is about social media posts.",
        'alarm_set': "The intent of the utterance is to set an alarm.",
        'qa_factoid': "The utterance is a factoid question.",
        'calendar_set': "The intent of this utterance is to set something in a calendar.",
        'play_game': "The intent of this utterance is to start playing a game.",
        'alarm_remove': "The intent of this utterance is to remove an alarm.",
        'lists_remove': "The intent of this utterance is to remove a list or remove something from a list.",
        'transport_taxi': "The intent of this utterance is to get a taxi.",
        'recommendation_movies': "This utterance is about movie recommendations.",
        'iot_coffee': "The intent of this utterance is for an IoT device to make coffee.",
        'music_query': "The utterance is a query about music.",
        'play_podcasts': "The utterance is related to playing podcasts.",
        'lists_query': "The utterance is a query about a list."
        },
    "banking77":
        {
        'activate_my_card': "This banking customer message is about activating a card.",
        'age_limit': "This banking customer message is related to age limits.",
        'apple_pay_or_google_pay': "This banking customer message is about apple pay or google pay.",
        'atm_support': "This banking customer message requests ATM support.",
        'automatic_top_up': "This banking customer message is about automatic top up.",
        'balance_not_updated_after_bank_transfer': "This banking customer message is about a balance not updated after a transfer.",
        'balance_not_updated_after_cheque_or_cash_deposit': "This banking customer message is about a balance not updated after a cheque or cash deposit.",
        'beneficiary_not_allowed': "This banking customer message is related to a beneficiary not being allowed or a failed transfer.",
        'cancel_transfer': "This banking customer message is related to the cancellation of a transfer.",
        'card_about_to_expire': "This banking customer message is related to the expiration of a card.",
        'card_acceptance': "This banking customer message is related to the scope of acceptance of a card.",
        'card_arrival': "This banking customer message is about the arrival of a card.",
        'card_delivery_estimate': "This banking customer message is about a card delivery estimate or timing.",
        'card_linking': np.nan,  # category does not seem coherent.
        'card_not_working': "This banking customer message is about a card not working.",
        'card_payment_fee_charged': "This banking customer message is about a card payment fee.",
        'card_payment_not_recognised': "This banking customer message is about a payment the customer does not recognise.",
        'card_payment_wrong_exchange_rate': "This banking customer message is about a wrong exchange rate.",
        'card_swallowed': "This banking customer message is about a card swallowed by a machine.",
        'cash_withdrawal_charge': "This banking customer message is about a cash withdrawal charge.",
        'cash_withdrawal_not_recognised': "This banking customer message is about an unrecognised cash withdrawal.",
        'change_pin': "This banking customer message is about changing a pin code.",
        'compromised_card': "This banking customer message is about a compromised card.",
        'contactless_not_working': "This banking customer message is about contactless not working.",
        'country_support': "This banking customer message is about country-specific support.",
        'declined_card_payment': "This banking customer message is about a declined card payment.",
        'declined_cash_withdrawal': "This banking customer message is about a declined cash withdrawal.",
        'declined_transfer': "This banking customer message is about a declined transfer.",
        'direct_debit_payment_not_recognised': "This banking customer message is about an unrecognised direct debit payment.",
        'disposable_card_limits': "This banking customer message is about the limits of disposable cards.",
        'edit_personal_details': "This banking customer message is about editing personal details.",
        'exchange_charge': "This banking customer message is about exchange rate charges.",
        'exchange_rate': "This banking customer message is about exchange rates.",
        'exchange_via_app': np.nan, # noisy category
        'extra_charge_on_statement': "This banking customer message is about an extra charge.",
        'failed_transfer': "This banking customer message is about a failed transfer.",
        'fiat_currency_support': "This banking customer message is about fiat currency support.",
        'get_disposable_virtual_card': "This banking customer message is about getting a disposable virtual card.",
        'get_physical_card': np.nan,  # noisy category
        'getting_spare_card': "This banking customer message is about getting a spare card.",
        'getting_virtual_card': "This banking customer message is about getting a virtual card.",
        'lost_or_stolen_card': "This banking customer message is about a lost or stolen card.",
        'lost_or_stolen_phone': "This banking customer message is about a lost or stolen phone.",
        'order_physical_card': "This banking customer message is about ordering a card.",
        'passcode_forgotten': "This banking customer message is about a forgotten passcode.",
        'pending_card_payment': "This banking customer message is about a pending card payment.",
        'pending_cash_withdrawal': "This banking customer message is about a pending cash withdrawal.",
        'pending_top_up': "This banking customer message is about a pending top up.",
        'pending_transfer': "This banking customer message is about a pending transfer.",
        'pin_blocked': "This banking customer message is about a blocked pin.",
        'receiving_money': "This banking customer message is about receiving money.",
        'Refund_not_showing_up': "This banking customer message is about a refund not showing up.",
        'request_refund': "This banking customer message is about a refund request.",
        'reverted_card_payment?': "This banking customer message is about reverting a card payment.",
        'supported_cards_and_currencies': np.nan,  # don't understand category.
        'terminate_account': "This banking customer message is about terminating an account.",
        'top_up_by_bank_transfer_charge': np.nan,  # noisy
        'top_up_by_card_charge': "This banking customer message is about the charge for topping up by card.",
        'top_up_by_cash_or_cheque': "This banking customer message is about topping up by cash or cheque.",
        'top_up_failed': "This banking customer message is about top up issues or failures.",
        'top_up_limits': "This banking customer message is about top up limitations.",
        'top_up_reverted': "This banking customer message is about issues with topping up.",
        'topping_up_by_card': "This banking customer message is about topping up by card.",
        'transaction_charged_twice': "This banking customer message is about a transaction charged twice.",
        'transfer_fee_charged': "This banking customer message is about an issue with a transfer fee charge.",
        'transfer_into_account': "This banking customer message is about transfers into the customer's own account.",
        'transfer_not_received_by_recipient': "This banking customer message is about a transfer that has not arrived yet.",
        'transfer_timing': "This banking customer message is about transfer timing.",
        'unable_to_verify_identity': "This banking customer message is about an issue with identity verification.",
        'verify_my_identity': "This banking customer message is about identity verification.",
        'verify_source_of_funds': "This banking customer message is about the source of funds.",
        'verify_top_up': "This banking customer message is about verification and top ups.",
        'virtual_card_not_working': "This banking customer message is about a virtual card not working.",
        'visa_or_mastercard': "This banking customer message is about types of bank cards.",
        'why_verify_identity': "This banking customer message questions why identity verification is necessary.",
        'wrong_amount_of_cash_received': "This banking customer message is about a wrong amount of cash received.",
        'wrong_exchange_rate_for_cash_withdrawal': "This banking customer message is about a wrong exchange rate for a cash withdrawal."
        },
    "trueteacher":
        {
        "factually_inconsistent": "The summary is factually inconsistent with the full article.",
        "factually_consistent": "The summary is factually consistent with the full article.",
        },
}


# Re-order all hypotheses alphabetically by label_text, to avoid potential
# label_num - label_text - hypothesis mismatches. Note: Python string ordering puts
# uppercase before lowercase (e.g. 'Refund_not_showing_up' sorts first in banking77).
task_hypotheses = {key_task_name: dict(sorted(value_task_hypo_dict.items())) for key_task_name, value_task_hypo_dict in task_hypotheses.items()}


### Load non-NLI datasets

In [None]:
# Hand-maintained list of every task name. The automatic file discovery further
# down is checked against it, so that no dataset is silently missed.
task_names_manual = [
    'wellformedquery',
    'financialphrasebank',
    'rottentomatoes',
    'amazonpolarity',
    'imdb',
    'appreviews',
    'yelpreviews',
    'wikitoxic_toxicaggregated',
    'wikitoxic_obscene',
    'wikitoxic_threat',
    'wikitoxic_insult',
    'wikitoxic_identityhate',
    'hateoffensive',
    'hatexplain',
    'trueteacher',
    'spam',
    'massive',
    'banking77',
    'emotiondair',
    'emocontext',
    'empathetic',
    'agnews',
    'yahootopics',
    'biasframes_offensive',
    'biasframes_sex',
    'biasframes_intent',
    # to be included in v2
    'anthropic_harmless',
    'anthropic_helpful',
]

In [None]:
## load (cleaned) train files

def find_files(directory, filter_string=None):
    """Return paths of the regular files in `directory` whose base name contains `filter_string`.

    Args:
        directory: directory to scan. The scan is not recursive and subdirectories are ignored.
        filter_string: substring that must occur in the file's base name. If None,
            every regular file in the directory is returned.

    Returns:
        A sorted list of paths built with ``os.path.join(directory, name)``. Sorting
        removes the dependence on the arbitrary order of ``os.listdir``, so the order
        in which datasets are processed is reproducible across runs and machines.
    """
    # List all regular files in the dataset directory, in a deterministic order
    file_paths = [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if os.path.isfile(os.path.join(directory, name))
    ]
    if filter_string is None:
        return file_paths
    # Keep only files whose name contains the requested word (e.g. "train" / "test")
    return [path for path in file_paths if filter_string in os.path.basename(path)]

# find cleaned train files
directory_path_cl = './datasets_clean'
train_files_lst_cl = find_files(directory_path_cl, filter_string="train")
# also add datasets that were not cleaned automatically (taken from the standardized folder)
directory_path_uncleaned = './datasets_standardized'
datasets_no_automatic_clean = ["trueteacher", "anthropic_harmless", "anthropic_helpful"]
train_files_lst_uncleaned = find_files(directory_path_uncleaned, filter_string="train")
train_files_lst_uncleaned = [path for path in train_files_lst_uncleaned if any(dataset in path for dataset in datasets_no_automatic_clean)]

train_files_lst = train_files_lst_cl + train_files_lst_uncleaned

print("All identified files: ", train_files_lst)
print("N files: ", len(train_files_lst))

# task name extraction from dataset file paths:
# strip the "./<folder>/ds_" prefix and everything from "_train" onwards
pattern = re.compile(r'^\.(?:/datasets_clean|/datasets_standardized)/ds_|_train.*$')
task_names_train = [re.sub(pattern, '', fp) for fp in train_files_lst]

print("task names: ", task_names_train)

# double check that all intended dataset paths and task names were loaded
print("N tasks: ", len(task_names_manual))

from collections import Counter
# Report mismatches BEFORE asserting, so that the diagnostics are visible when the check fails.
missing_train_tasks = set(task_names_manual) - set(task_names_train)
unexpected_train_tasks = set(task_names_train) - set(task_names_manual)
duplicate_train_tasks = [task for task, count in Counter(task_names_train).items() if count > 1]
print("Tasks in task_names_manual without a train file:", missing_train_tasks)
print("Train files whose task is not in task_names_manual:", unexpected_train_tasks)
print("Tasks with more than one train file:", duplicate_train_tasks)
assert Counter(task_names_manual) == Counter(task_names_train), (
    f"Train task mismatch. Missing: {missing_train_tasks}; "
    f"unexpected: {unexpected_train_tasks}; duplicates: {duplicate_train_tasks}"
)


All identified files:  ['./datasets_clean/ds_wellformedquery_train_cl.gzip', './datasets_clean/ds_rottentomatoes_train_cl.gzip', './datasets_clean/ds_amazonpolarity_train_cl.gzip', './datasets_clean/ds_imdb_train_cl.gzip', './datasets_clean/ds_hatexplain_train_cl.gzip', './datasets_clean/ds_massive_train_cl.gzip', './datasets_clean/ds_yelpreviews_train_cl.gzip', './datasets_clean/ds_banking77_train_cl.gzip', './datasets_clean/ds_emotiondair_train_cl.gzip', './datasets_clean/ds_emocontext_train_cl.gzip', './datasets_clean/ds_empathetic_train_cl.gzip', './datasets_clean/ds_agnews_train_cl.gzip', './datasets_clean/ds_biasframes_offensive_train_cl.gzip', './datasets_clean/ds_yahootopics_train_cl.gzip', './datasets_clean/ds_biasframes_sex_train_cl.gzip', './datasets_clean/ds_biasframes_intent_train_cl.gzip', './datasets_clean/ds_financialphrasebank_train_cl.gzip', './datasets_clean/ds_appreviews_train_cl.gzip', './datasets_clean/ds_hateoffensive_train_cl.gzip', './datasets_clean/ds_spam_tra

In [None]:
## load test files

# find test files: test sets are read from the standardized (not the cleaned) datasets
directory_path_test = './datasets_standardized'
test_files_lst = find_files(directory_path_test, filter_string="test")

print("All identified files: ", test_files_lst)
print("N files: ", len(test_files_lst))

# task name extraction from dataset file paths:
# strip the "./<folder>/ds_" prefix and everything from "_test" onwards
pattern = re.compile(r'^\.(?:/datasets_clean|/datasets_standardized)/ds_|_test.*$')
task_names_test = [re.sub(pattern, '', fp) for fp in test_files_lst]

print("task names: ", task_names_test)

# double check that all intended dataset paths and task names were loaded,
# using the manually written task names as the reference
print("N tasks: ", len(task_names_manual))

from collections import Counter
# Report mismatches BEFORE asserting, so that the diagnostics are visible when the check fails.
missing_test_tasks = set(task_names_manual) - set(task_names_test)
unexpected_test_tasks = set(task_names_test) - set(task_names_manual)
duplicate_test_tasks = [task for task, count in Counter(task_names_test).items() if count > 1]
print("Tasks in task_names_manual without a test file:", missing_test_tasks)
print("Test files whose task is not in task_names_manual:", unexpected_test_tasks)
print("Tasks with more than one test file:", duplicate_test_tasks)
assert Counter(task_names_manual) == Counter(task_names_test), (
    f"Test task mismatch. Missing: {missing_test_tasks}; "
    f"unexpected: {unexpected_test_tasks}; duplicates: {duplicate_test_tasks}"
)


All identified files:  ['./datasets_standardized/ds_wellformedquery_test.gzip', './datasets_standardized/ds_rottentomatoes_test.gzip', './datasets_standardized/ds_amazonpolarity_test.gzip', './datasets_standardized/ds_imdb_test.gzip', './datasets_standardized/ds_yelpreviews_test.gzip', './datasets_standardized/ds_hatexplain_test.gzip', './datasets_standardized/ds_massive_test.gzip', './datasets_standardized/ds_banking77_test.gzip', './datasets_standardized/ds_emotiondair_test.gzip', './datasets_standardized/ds_emocontext_test.gzip', './datasets_standardized/ds_empathetic_test.gzip', './datasets_standardized/ds_agnews_test.gzip', './datasets_standardized/ds_yahootopics_test.gzip', './datasets_standardized/ds_biasframes_sex_test.gzip', './datasets_standardized/ds_biasframes_offensive_test.gzip', './datasets_standardized/ds_biasframes_intent_test.gzip', './datasets_standardized/ds_financialphrasebank_test.gzip', './datasets_standardized/ds_appreviews_test.gzip', './datasets_standardized/d

#### Format df_train to NLI format

In [None]:
# ! Excluding anthropic from v0.1. Do reformatting work for v2
# TODO: formulate hypotheses for anthropic datasets in a simple binary way for testing.
# created complex and overlapping hypotheses to increase variety for training data, but I should binarize again for inference/testing
# Probably need to do updates both in harmonization script and here (e.g. in task_hypotheses dic)

tasks_with_existing_hypotheses = [
    "anthropic"  #'anthropic_harmless', 'anthropic_helpful',
]

# task name = file path minus the directory/"ds_" prefix and the "_train..." suffix
# (loop-invariant, so compiled once)
pattern = re.compile(r'^\.(?:/datasets_clean|/datasets_standardized)/ds_|_train.*$')

# format train set
df_train_format_lst = []
for train_file_path in train_files_lst:
    # extract task name from file path
    task_name = re.sub(pattern, '', train_file_path)

    # Datasets that already have hypotheses and are already formatted are excluded from v0.1.
    # TODO: include anthropic datasets in v2. They only need `task_name` added and must then be
    # appended to df_train_format_lst.
    if any(existing_task in train_file_path for existing_task in tasks_with_existing_hypotheses):
        print("*** Skipping file that is already formatted (excluded from v0.1): ", train_file_path)
        continue

    # datasets that do not have hypotheses yet
    print("*** Loading unformatted file: ", train_file_path)
    df_train = pd.read_parquet(train_file_path).reset_index(drop=True)

    # select hypotheses for task
    hypo_label_dic = task_hypotheses[task_name]

    df_train_format = format_nli_trainset(
        df_train=df_train, hypo_label_dic=hypo_label_dic,
        random_seed=SEED_GLOBAL
    )

    df_train_format["task_name"] = task_name

    # remove rows where no hypothesis was formulated due to noise
    df_train_format = df_train_format[df_train_format["hypothesis"].notna()]

    df_train_format_lst.append(df_train_format)
    print("\n")


df_train_format_concat = pd.concat(df_train_format_lst)
display(df_train_format_concat)


*** Loading unformatted file:  ./datasets_clean/ds_wellformedquery_train_cl.gzip

For NLI: Augmenting data by adding random not_entail examples to the train set from other classes within the train set.
Length of df_train before this step is: 12246.

Max augmentation can be: len(df_train) * 2 = 24492. Can also be lower, if there are more entail examples than not-entail for a majority class
For NLI:  not_entail training examples were added, which leads to an augmented training dataset of length 23202.


*** Loading unformatted file:  ./datasets_clean/ds_rottentomatoes_train_cl.gzip

For NLI: Augmenting data by adding random not_entail examples to the train set from other classes within the train set.
Length of df_train before this step is: 8231.

Max augmentation can be: len(df_train) * 2 = 16462. Can also be lower, if there are more entail examples than not-entail for a majority class
For NLI:  not_entail training examples were added, which leads to an augmented training dataset of leng

Unnamed: 0,text,label_text,label_standard,label_quality,hypothesis,labels,task_name
259,Scandinavia is part of this continent ?,not_well_formed,0,0.705590,This text is not a well formed Google query,0,wellformedquery
5730,What are the two main peripherals for a comput...,well_formed,1,0.745767,This text is a well formed Google query.,0,wellformedquery
884,Aspirin ulcer treatment ?,not_well_formed,0,0.838492,This text is a well formed Google query.,1,wellformedquery
9694,Where does the name Colombia come from ?,well_formed,1,0.642482,This text is not a well formed Google query,1,wellformedquery
2004,Style of dance in the 1960s ?,not_well_formed,0,0.530566,This text is not a well formed Google query,0,wellformedquery
...,...,...,...,...,...,...,...
2878,Summary:\nThe son of a football legend has bee...,factually_inconsistent,1,,The summary is factually inconsistent with the...,0,trueteacher
12597,Summary:\nA powerful earthquake has killed at ...,factually_consistent,0,,The summary is factually consistent with the f...,0,trueteacher
30534,Summary:\nThe 2014 Emmy Awards drew a slightly...,factually_consistent,0,,The summary is factually inconsistent with the...,1,trueteacher
1705,Summary:\nA plane crash in Nevada left four pe...,factually_consistent,0,,The summary is factually consistent with the f...,0,trueteacher


In [None]:
# keep only the columns used downstream, in a fixed order
df_train_format_concat = df_train_format_concat.loc[
    :, ["text", "label_text", "labels", "hypothesis", "task_name", "label_quality"]
]
display(df_train_format_concat)



Unnamed: 0,text,label_text,labels,hypothesis,task_name,label_quality
259,Scandinavia is part of this continent ?,not_well_formed,0,This text is not a well formed Google query,wellformedquery,0.705590
5730,What are the two main peripherals for a comput...,well_formed,0,This text is a well formed Google query.,wellformedquery,0.745767
884,Aspirin ulcer treatment ?,not_well_formed,1,This text is a well formed Google query.,wellformedquery,0.838492
9694,Where does the name Colombia come from ?,well_formed,1,This text is not a well formed Google query,wellformedquery,0.642482
2004,Style of dance in the 1960s ?,not_well_formed,0,This text is not a well formed Google query,wellformedquery,0.530566
...,...,...,...,...,...,...
2878,Summary:\nThe son of a football legend has bee...,factually_inconsistent,0,The summary is factually inconsistent with the...,trueteacher,
12597,Summary:\nA powerful earthquake has killed at ...,factually_consistent,0,The summary is factually consistent with the f...,trueteacher,
30534,Summary:\nThe 2014 Emmy Awards drew a slightly...,factually_consistent,1,The summary is factually inconsistent with the...,trueteacher,
1705,Summary:\nA plane crash in Nevada left four pe...,factually_consistent,0,The summary is factually consistent with the f...,trueteacher,


In [None]:
# downsample
# TODO: improve sampling. currently overestimating majority classes / not tailored to specific datasets
# do sampling before or after NLI augmentation?
n_max_sample_per_task = 20_000
# draw at most n_max_sample_per_task rows per task (every row when the task is smaller)
df_train_format_concat_samp = (
    df_train_format_concat
    .groupby("task_name", group_keys=False, as_index=False)
    .apply(lambda task_df: task_df.sample(n=min(len(task_df), n_max_sample_per_task), random_state=SEED_GLOBAL))
)

# compare data distributions before vs. after downsampling
data_distribution = df_train_format_concat["task_name"].value_counts(ascending=False)
data_distribution_downsampled = df_train_format_concat_samp["task_name"].value_counts(ascending=False)
print("n_texts before downsampling: ", data_distribution.sum())
print("n_texts after downsampling: ", data_distribution_downsampled.sum())

df_data_distribution = pd.DataFrame(
    {"n_texts": data_distribution, "n_texts_downsample": data_distribution_downsampled}
).sort_values(by="n_texts", ascending=False)

display(df_data_distribution)


n_texts before downsampling:  769758
n_texts after downsampling:  401499


Unnamed: 0,n_texts,n_texts_downsample
yahootopics,130912,20000
agnews,71934,20000
trueteacher,71614,20000
yelpreviews,38638,20000
emocontext,38480,20000
amazonpolarity,38179,20000
imdb,36784,20000
wikitoxic_toxicaggregated,36308,20000
biasframes_intent,33776,20000
biasframes_offensive,33306,20000


#### Format df_test to NLI format

In [None]:

tasks_with_existing_hypotheses = [
    "anthropic"  #'anthropic_harmless', 'anthropic_helpful',
]

# cap on test texts per class (label_standard), for faster testing
n_max_sample_per_class = 5_000

# task name = file path minus the "./datasets_standardized/ds_" prefix and the "_test..." suffix
# (loop-invariant, so compiled once)
pattern = re.compile(r'^\.(?:/datasets_standardized)/ds_|_test.*$')  #(r'^\.(?:/datasets_clean|/datasets_standardized)/ds_|_test.*$')

# format test set
df_test_format_dic = {}
for test_file_path in test_files_lst:
    # extract task name from file path
    task_name = re.sub(pattern, '', test_file_path)

    # Datasets that already have hypotheses and are already formatted are excluded from v0.1.
    # TODO: include anthropic datasets in v2. They need: per-"labels" downsampling, "label_text"
    # copied from "hypothesis", a "task_name" column, and an entry in df_test_format_dic.
    if any(existing_task in test_file_path for existing_task in tasks_with_existing_hypotheses):
        print("*** Skipping file that is already formatted (excluded from v0.1): ", test_file_path)
        continue

    # datasets that do not have hypotheses yet (file is only read when it will actually be used)
    print("*** Loading unformatted file: ", test_file_path)
    df_test = pd.read_parquet(test_file_path).reset_index(drop=True)

    # downsample per class for faster testing
    df_test_samp = df_test.groupby("label_standard", group_keys=False, as_index=False).apply(
        lambda x: x.sample(n=min(len(x), n_max_sample_per_class), random_state=SEED_GLOBAL)
    ).reset_index(drop=True)

    # select hypotheses for task
    hypo_label_dic = task_hypotheses[task_name]

    df_test_format = format_nli_testset(
        df_test=df_test_samp, hypo_label_dic=hypo_label_dic,
    )

    df_test_format["task_name"] = task_name

    # remove rows where no hypothesis was formulated due to noise
    # NOTE(review): format_nli_testset warns that row order must be preserved for testing; dropping
    # rows here could break the per-text grouping of hypotheses if only some are NaN — confirm.
    df_test_format = df_test_format[df_test_format["hypothesis"].notna()]

    df_test_format_dic[task_name] = df_test_format[["text", "label_text", "labels", "hypothesis", "task_name"]].reset_index(drop=True)
    print("\n")


display(df_test_format_dic)


*** Loading unformatted file:  ./datasets_standardized/ds_wellformedquery_test.gzip
Number of hypotheses/classes:  2 

For normal test, N classifications necessary: 2967
Reminder: do not sample these test-sets anymore after formatting. Row order needs to be preserved for testing.
For NLI test, N classifications necessary: 5934



*** Loading unformatted file:  ./datasets_standardized/ds_rottentomatoes_test.gzip
Number of hypotheses/classes:  2 

For normal test, N classifications necessary: 1066
Reminder: do not sample these test-sets anymore after formatting. Row order needs to be preserved for testing.
For NLI test, N classifications necessary: 2132



*** Loading unformatted file:  ./datasets_standardized/ds_amazonpolarity_test.gzip
Number of hypotheses/classes:  2 

For normal test, N classifications necessary: 10000
Reminder: do not sample these test-sets anymore after formatting. Row order needs to be preserved for testing.
For NLI test, N classifications necessary: 20000



*** 

{'wellformedquery':                                                    text       label_text  \
 0                           How much is 110 in pounds ?  not_well_formed   
 1                           How much is 110 in pounds ?  not_well_formed   
 2     Other responsibilities of the records custodian ?  not_well_formed   
 3     Other responsibilities of the records custodian ?  not_well_formed   
 4        Scope of information technology in education ?  not_well_formed   
 ...                                                 ...              ...   
 5929  What is the statute of limitations on check fr...      well_formed   
 5930                What is meant by stalemate in war ?      well_formed   
 5931                What is meant by stalemate in war ?      well_formed   
 5932              What does the logo of HONDA signify ?      well_formed   
 5933              What does the logo of HONDA signify ?      well_formed   
 
      labels                                   hypothes

In [None]:
#df_test_format_concat = df_test_format_concat[["text", "label_text", "labels", "hypothesis", "task_name"]]
#display(df_test_format_concat)

In [None]:
# inspect distribution in testset (after downsampling)
# counts are rows of the NLI-formatted test sets (text-hypothesis pairs), one entry per task

n_texts_per_task_dic = {key_task: len(value_df) for key_task, value_df in df_test_format_dic.items()}

print("n_texts after downsampling: ", sum(n_texts_per_task_dic.values()), "\n")
print("n_texts per task after downsampling:\n", pd.Series(n_texts_per_task_dic).sort_values(ascending=False), "\n")


n_texts after downsampling:  1246492 

n_texts per class downsampling:
 yahootopics                  500000
banking77                    221760
massive                      175466
empathetic                    81344
agnews                        30400
emocontext                    22036
amazonpolarity                20000
imdb                          20000
yelpreviews                   20000
wikitoxic_toxicaggregated     20000
trueteacher                   17910
wikitoxic_obscene             17382
wikitoxic_insult              16854
emotiondair                   12000
wikitoxic_identityhate        11424
wikitoxic_threat              10422
biasframes_sex                 8808
appreviews                     8000
biasframes_offensive           7676
biasframes_intent              7296
wellformedquery                5934
hatexplain                     2922
hateoffensive                  2586
rottentomatoes                 2132
financialphrasebank            2070
spam                        

### Load NLI datasets

In [None]:
# NLI datasets loaded by the cells below (also supported: "xnli")
NLI_DATASETS_TO_USE = [
    "mnli",
    "anli",
    "fever",
    "wanli",
    "ling",
]


In [None]:
"""## Load NLI datasets"""

# containers for the selected NLI datasets (name -> dataset), filled by the loader cells below
dic_datasets_train = dict()
dic_datasets_test = dict()


In [None]:
### SNLI
# too noisy
#dataset_train_snli = load_dataset('snli', split="train")
#dataset_test_snli = load_dataset('snli', split="test")

### other potential datasets:
# climatefever: 1.5k climate related claims. questionable quality: low krippendorf alpha ~0.3  https://arxiv.org/abs/2012.00614
# feverours: includes info from tables


In [None]:
### MNLI
if "mnli" in NLI_DATASETS_TO_USE:
    # MultiNLI: keep only premise / hypothesis / label; drop ids, parse trees and genre
    dataset_mnli = load_dataset('multi_nli')
    dataset_mnli = dataset_mnli.remove_columns([
        'promptID', 'pairID',
        'premise_binary_parse', 'premise_parse',
        'hypothesis_binary_parse', 'hypothesis_parse',
        'genre',
    ])
    print(dataset_mnli)
    print(dataset_mnli["train"].features)

    # register splits: train for training; matched and mismatched validation sets serve as test sets
    dic_datasets_train["mnli"] = dataset_mnli["train"]
    dic_datasets_test["mnli_m"] = dataset_mnli["validation_matched"]
    dic_datasets_test["mnli_mm"] = dataset_mnli["validation_mismatched"]

    del dataset_mnli

Downloading builder script:   0%|          | 0.00/5.14k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.88k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/8.67k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/227M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/392702 [00:00<?, ? examples/s]

Generating validation_matched split:   0%|          | 0/9815 [00:00<?, ? examples/s]

Generating validation_mismatched split:   0%|          | 0/9832 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 392702
    })
    validation_matched: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 9815
    })
    validation_mismatched: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 9832
    })
})
{'premise': Value(dtype='string', id=None), 'hypothesis': Value(dtype='string', id=None), 'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)}


In [None]:
### FEVER-nli
if "fever" in NLI_DATASETS_TO_USE:

    dataset_fever = load_dataset("pietrolesci/nli_fever")

    # test set does not have labels. use dev for final test
    # data has duplicates for some reason
    df_fever_train = dataset_fever["train"].to_pandas().drop_duplicates(subset=['premise', 'hypothesis', 'label'], keep='first')  # removes around 12k texts. some duplicate texts have different labels. maintaining them by including 'labels' to reflect ambiguity
    df_fever_dev = dataset_fever["dev"].to_pandas().drop_duplicates(subset=['premise', 'hypothesis', 'label'], keep="first")
    # remove unnecessary columns
    df_fever_train = df_fever_train[["premise", "hypothesis", "label"]].reset_index(drop=True)
    df_fever_dev = df_fever_dev[["premise", "hypothesis", "label"]].reset_index(drop=True)

    # create final dataset
    dataset_fever = DatasetDict({"train": Dataset.from_pandas(df_fever_train), "test": Dataset.from_pandas(df_fever_dev)})  # "test": dataset_dev_fever_samp
    print(dataset_fever)
    print(dataset_fever["train"].features)

    # add classlabel names for NLI
    new_features = dataset_fever["train"].features.copy()
    new_features['label'] = ClassLabel(names=["entailment", "neutral", "contradiction"])
    dataset_fever = dataset_fever.cast(new_features)
    print(dataset_fever)
    print(dataset_fever["train"].features)

    # write to general dic container
    dic_datasets_train.update({"fever": dataset_fever["train"]})
    dic_datasets_test.update({"fever": dataset_fever["test"]})

    del (df_fever_train, df_fever_dev, dataset_fever)


Downloading readme:   0%|          | 0.00/6.61k [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


Downloading metadata:   0%|          | 0.00/1.21k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/46.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.95M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.44M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/208346 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/19998 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/19998 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 196805
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 19652
    })
})
{'premise': Value(dtype='string', id=None), 'hypothesis': Value(dtype='string', id=None), 'label': Value(dtype='int64', id=None)}


Casting the dataset:   0%|          | 0/196805 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/19652 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 196805
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 19652
    })
})
{'premise': Value(dtype='string', id=None), 'hypothesis': Value(dtype='string', id=None), 'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)}


In [None]:
### ANLI
if "anli" in NLI_DATASETS_TO_USE:
    # train
    dataset_train_anli = load_dataset('anli', split=["train_r1", "train_r2", "train_r3"])
    dataset_train_anli = concatenate_datasets([dataset_train_anli[0], dataset_train_anli[1], dataset_train_anli[2]])
    dataset_train_anli = dataset_train_anli.remove_columns(["uid", "reason"])
    # test
    dataset_test_anli = load_dataset('anli', split=["test_r1", "test_r2", "test_r3"])
    dataset_test_anli = DatasetDict({f"test_r{i+1}": dataset for i, dataset in enumerate(dataset_test_anli)})
    dataset_test_anli = dataset_test_anli.remove_columns(["uid", "reason"])

    # create final dataset
    dataset_anli = DatasetDict({"train": dataset_train_anli, **dataset_test_anli})
    print(dataset_anli)
    print(dataset_anli["train"].features)

    # write to general dic container
    dic_datasets_train.update({"anli": dataset_anli["train"]})
    dic_datasets_test.update({"anli_r1": dataset_anli["test_r1"], "anli_r2": dataset_anli["test_r2"], "anli_r3": dataset_anli["test_r3"]})
    # tripple weight of anli test set, because smaller
    #dic_datasets_test_concat.update({"anli": concatenate_datasets([dataset_anli["test_all"], dataset_anli["test_all"], dataset_anli["test_all"]])})

    del (dataset_train_anli, dataset_test_anli, dataset_anli)



Downloading builder script:   0%|          | 0.00/5.55k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/2.76k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.47k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/18.6M [00:00<?, ?B/s]

Generating train_r1 split:   0%|          | 0/16946 [00:00<?, ? examples/s]

Generating dev_r1 split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating test_r1 split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating train_r2 split:   0%|          | 0/45460 [00:00<?, ? examples/s]

Generating dev_r2 split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating test_r2 split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Generating train_r3 split:   0%|          | 0/100459 [00:00<?, ? examples/s]

Generating dev_r3 split:   0%|          | 0/1200 [00:00<?, ? examples/s]

Generating test_r3 split:   0%|          | 0/1200 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 162865
    })
    test_r1: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1000
    })
    test_r2: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1000
    })
    test_r3: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 1200
    })
})
{'premise': Value(dtype='string', id=None), 'hypothesis': Value(dtype='string', id=None), 'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)}


In [None]:
## WANLI   https://arxiv.org/pdf/2201.05955.pdf
if "wanli" in NLI_DATASETS_TO_USE:
    dataset_wanli = load_dataset("alisawuffles/WANLI")
    dataset_wanli = dataset_wanli.remove_columns(['pairID', 'genre', 'id'])
    dataset_wanli = dataset_wanli.rename_column("gold", "label")

    # unelegant conversion from text to int necessary to make casting below work for some reason
    def test(label):
      if label == "entailment":
        return 0
      elif label == "neutral":
        return 1
      elif label == "contradiction":
        return 2
    label_num = [test(label) for label in dataset_wanli["train"]["label"]]
    dataset_wanli["train"] = dataset_wanli["train"].remove_columns(["label"])
    dataset_wanli["train"] = dataset_wanli["train"].add_column("label", label_num)
    label_num = [test(label) for label in dataset_wanli["test"]["label"]]
    dataset_wanli["test"] = dataset_wanli["test"].remove_columns(["label"])
    dataset_wanli["test"] = dataset_wanli["test"].add_column("label", label_num)

    # adapt label column
    new_features = dataset_wanli["train"].features.copy()
    new_features['label'] = ClassLabel(names=["entailment", "neutral", "contradiction"])  # the label ids are attributed in the (non-alphabetical) order to the list
    dataset_wanli = dataset_wanli.cast(new_features)

    # print
    print(dataset_wanli)
    print(dataset_wanli["train"].features)

    # write to general dic container
    dic_datasets_train.update({"wanli": dataset_wanli["train"]})
    dic_datasets_test.update({"wanli": dataset_wanli["test"]})

    del (dataset_wanli)


Downloading readme:   0%|          | 0.00/9.94k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/25.4M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.23M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Casting the dataset:   0%|          | 0/102885 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/5000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 102885
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 5000
    })
})
{'premise': Value(dtype='string', id=None), 'hypothesis': Value(dtype='string', id=None), 'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)}


In [None]:
### Linguist in the loop NLI data
if "ling" in NLI_DATASETS_TO_USE:
    links_dict_all = {
      #"train_round5_baseline_combined": "https://raw.githubusercontent.com/Alicia-Parrish/ling_in_loop/master/NLI_data/1_Baseline_protocol/train_round5_baseline_combined.jsonl",
      #"val_round5_baseline_combined": "https://raw.githubusercontent.com/Alicia-Parrish/ling_in_loop/master/NLI_data/1_Baseline_protocol/val_round5_base_combined.jsonl",
      "train_round5_litl_combined": "https://raw.githubusercontent.com/Alicia-Parrish/ling_in_loop/master/NLI_data/3_Ling_in_loop_protocol/train_round5_LitL_combined.jsonl",
      "val_round5_litl_combined": "https://raw.githubusercontent.com/Alicia-Parrish/ling_in_loop/master/NLI_data/3_Ling_in_loop_protocol/val_round5_LitL_combined.jsonl",
      "train_round5_lots_combined": "https://raw.githubusercontent.com/Alicia-Parrish/ling_in_loop/master/NLI_data/2_Ling_on_side_protocol/train_round5_LotS_combined.jsonl",
      "val_round5_lots_combined": "https://raw.githubusercontent.com/Alicia-Parrish/ling_in_loop/master/NLI_data/2_Ling_on_side_protocol/val_round5_LotS_combined.jsonl",
    }
    df_train_ling_lst = []
    df_val_ling_lst = []
    for dataset_name, link in links_dict_all.items():
        if "train" in dataset_name:
            df_step = pd.read_json(link, lines=True)
            df_step["dataset_type"] = dataset_name
            df_train_ling_lst.append(df_step)
        elif "val" in dataset_name:
            df_step = pd.read_json(link, lines=True)
            df_step["dataset_type"] = dataset_name
            df_val_ling_lst.append(df_step)

    df_ling_train = pd.concat(df_train_ling_lst)
    df_ling_val = pd.concat(df_val_ling_lst)
    df_ling_train_cl = df_ling_train[["premise", "hypothesis", "label"]]
    df_ling_val_cl = df_ling_val[["premise", "hypothesis", "label"]]

    label_map_ling = {"entailment": 0, "neutral": 1, "contradiction": 2}
    df_ling_train_cl.label = df_ling_train_cl.label.map(label_map_ling)
    df_ling_val_cl.label = df_ling_val_cl.label.map(label_map_ling)
    dataset_ling = DatasetDict({'train': Dataset.from_pandas(df_ling_train_cl), 'test': Dataset.from_pandas(df_ling_val_cl)}).remove_columns(['__index_level_0__'])

    new_features = dataset_ling["train"].features.copy()
    new_features['label'] = ClassLabel(names=["entailment", "neutral", "contradiction"])  # the label ids are attributed in the (non-alphabetical) order to the list
    dataset_ling = dataset_ling.cast(new_features)

    # print
    print(dataset_ling)
    print(dataset_ling["train"].features)

    # write to general dic container
    dic_datasets_train.update({"ling": dataset_ling["train"]})
    dic_datasets_test.update({"ling": dataset_ling["test"]})
    #dic_datasets_test_concat.update({"ling": dataset_ling["test"]})

    del (dataset_ling, df_train_ling_lst, df_val_ling_lst, df_ling_train, df_ling_val, df_ling_train_cl, df_ling_val_cl)


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_ling_train_cl.label = df_ling_train_cl.label.map(label_map_ling)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_ling_val_cl.label = df_ling_val_cl.label.map(label_map_ling)


Casting the dataset:   0%|          | 0/29985 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/4893 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 29985
    })
    test: Dataset({
        features: ['premise', 'hypothesis', 'label'],
        num_rows: 4893
    })
})
{'premise': Value(dtype='string', id=None), 'hypothesis': Value(dtype='string', id=None), 'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)}


In [None]:
### XNLI
if "xnli" in NLI_DATASETS_TO_USE:
    dataset_xnli_raw = load_dataset("xnli", "all_languages", split=["validation", "test"])

    # put in proper dataset dict
    dataset_xnli_raw = DatasetDict({"validation": dataset_xnli_raw[0], "test": dataset_xnli_raw[1]})
    print("\nRaw xnli: ", dataset_xnli_raw)

    ## clean & augment nested data structure of XNLI
    source2source = True  # is automatically included in many2many
    many2en = False  # is automatically included in many2many
    many2many = False

    xnli_dict = {}
    for dataset_name in dataset_xnli_raw:
      if dataset_name != "train":  # train test takes way too long and don't need it
        premise_lst = []
        language_lst = []
        hypothesis_lst = []
        label_lst = []
        for data_point in dataset_xnli_raw[dataset_name]:
          # create test set
          if dataset_name == "test":
            premise_lst.append(list(data_point["premise"].values()))
            language_lst.append(data_point["hypothesis"]["language"])
            hypothesis_lst.append(data_point["hypothesis"]["translation"])
            label_lst.append([data_point["label"]] * len(data_point["hypothesis"]["translation"]))
          # normal source-source for premise-hypo
          if (source2source == True) and (dataset_name != "test"):
            premise_lst.append(list(data_point["premise"].values()))
            language_lst.append(data_point["hypothesis"]["language"])
            hypothesis_lst.append(data_point["hypothesis"]["translation"])
            label_lst.append([data_point["label"]] * len(data_point["hypothesis"]["translation"]))
          # data augmentation many2en
          """if (many2en == True) and (dataset_name != "test"):
            # premise en - hypo source
            premise_lst.append([list(data_point["premise"].values())[4]] * len(list(data_point["premise"].values())))
            language_lst.append(["en-" + lang for lang in data_point["hypothesis"]["language"]])
            hypothesis_lst.append(data_point["hypothesis"]["translation"])
            label_lst.append([data_point["label"]] * len(data_point["hypothesis"]["translation"]))
            # premise source - hypo en
            premise_lst.append(list(data_point["premise"].values()))
            language_lst.append([lang + "-en" for lang in data_point["hypothesis"]["language"]])
            hypothesis_lst.append([data_point["hypothesis"]["translation"][4]] * len(list(data_point["premise"].values())))
            label_lst.append([data_point["label"]] * len(data_point["hypothesis"]["translation"]))
          # could add many2many
          if (many2many == True) and (dataset_name != "test"):
            premise_lst_m2m = [[premise] * len(list(data_point["premise"].values())) for premise in list(data_point["premise"].values())]
            premise_lst.append([item for sublist in premise_lst_m2m for item in sublist])
            language_lst_premise_m2m = [[premise_lang] * len(list(data_point["premise"].values())) for premise_lang in list(data_point["premise"].keys())]
            language_lst_premise_m2m = [item for sublist in language_lst_premise_m2m for item in sublist]
            language_lst_hypo_m2m = [data_point["hypothesis"]["language"]] * len(data_point["hypothesis"]["translation"])
            language_lst_hypo_m2m = [item for sublist in language_lst_hypo_m2m for item in sublist]
            language_lst_m2m = [premise_lang + "-" + hypo_lang for premise_lang, hypo_lang in zip(language_lst_premise_m2m, language_lst_hypo_m2m)]
            language_lst.append(language_lst_m2m)
            hypothesis_lst_m2m = [data_point["hypothesis"]["translation"]] * len(data_point["hypothesis"]["translation"])
            hypothesis_lst.append([item for sublist in hypothesis_lst_m2m for item in sublist])
            label_lst_m2m = [data_point["label"]] * len(data_point["hypothesis"]["translation"])**2
            label_lst.append(label_lst_m2m)"""
          # ...
        # unnest list
        premise_lst = [item for sublist in premise_lst for item in sublist]
        language_lst = [item for sublist in language_lst for item in sublist]
        hypothesis_lst = [item for sublist in hypothesis_lst for item in sublist]
        label_lst = [item for sublist in label_lst for item in sublist]
        # update master dict
        xnli_dict.update({dataset_name: Dataset.from_dict({"premise": premise_lst, "hypothesis": hypothesis_lst, "language": language_lst, "label": label_lst})})

    dataset_xnli = DatasetDict(xnli_dict)

    # remove en-en duplicates:
    #dataset_xnli["validation"] = dataset_xnli["validation"].filter(lambda example: example['language'] != "en-en")
    print(dataset_xnli)

    # harmonised label column type with MNLI etc.
    from datasets import ClassLabel
    new_features = dataset_xnli["validation"].features.copy()
    new_features['label'] = ClassLabel(names=["entailment", "neutral", "contradiction"])
    dataset_xnli = dataset_xnli.cast(new_features)
    #dataset_xnli["validation"].features

    # separate test datasets by language
    xnli_dict_test_separated = {}
    for group_name, group_df in pd.DataFrame(dataset_xnli["test"]).groupby(by="language"):
      xnli_dict_test_separated.update({group_name: Dataset.from_pandas(group_df).remove_columns("__index_level_0__")})

    dataset_xnli["test_separated"] = DatasetDict(xnli_dict_test_separated)

    # print
    print(dataset_xnli)
    print(dataset_xnli["validation"].features)

    # write to general dic container
    # ! augment XNLI in trainset
    dic_datasets_train.update({"xnli": concatenate_datasets([dataset_xnli["validation"], dataset_xnli["validation"]])})
    dic_datasets_test.update({**dataset_xnli["test_separated"].remove_columns(["language"])})
    #dic_datasets_test_concat.update({"xnli": dataset_xnli["test"]})
    del(xnli_dict, dataset_xnli_raw, xnli_dict_test_separated)


In [None]:
## add task_name column to NLI datasets for downstream inspection possibilities
# tag every row with the key of the dataset it belongs to, for train and test dictionaries alike
for dic_datasets in (dic_datasets_train, dic_datasets_test):
    for task_key in list(dic_datasets):
        task_dataset = dic_datasets[task_key]
        dic_datasets[task_key] = task_dataset.add_column(name="task_name", column=[task_key] * len(task_dataset))


In [None]:
### NLI datasets concatenation

## train set
# print train sets
print("TRAIN SETS: ")
for dataset_name in NLI_DATASETS_TO_USE:
    print(dataset_name, " number of train examples: ", len(dic_datasets_train[dataset_name]))
# concatenate all train sets collected in dic_datasets_train and shuffle them together
dataset_train_nli = concatenate_datasets([value_dataset for key_dataset_name, value_dataset in dic_datasets_train.items()])
dataset_train_nli = dataset_train_nli.shuffle(seed=SEED_GLOBAL)

## test set
# disaggregated test set: one split per test set key
dataset_test_disaggregated_nli = DatasetDict(**dic_datasets_test)

# Aggregated NLI test set, giving a single number for choosing the best model over all NLI datasets.
# Some test sets are re-weighted before concatenation.
N_ANLI_TEST_REPEATS = 3  # anli test sets are small compared to the others, so they are repeated
N_FEVER_TEST_SAMPLES = 10_000  # fever is downsampled to at most this many rows
# Dedicated RNG seeded with SEED_GLOBAL: the setup cell only seeds numpy, not the `random` module,
# so sampling from global `random` made the fever subsample (and the aggregated metric) differ per run.
rng_test_sampling = random.Random(SEED_GLOBAL)

test_sets_concat_nli = []
for key_dataset_name, value_dataset in dic_datasets_test.items():
    if "anli" in key_dataset_name:
        # increase weight of anli in the overall test set
        value_dataset = concatenate_datasets([value_dataset] * N_ANLI_TEST_REPEATS)
    elif "fever" in key_dataset_name:
        # downsample fever; min() avoids a ValueError from random.sample on a smaller test set
        n_fever_samples = min(N_FEVER_TEST_SAMPLES, len(value_dataset))
        value_dataset = value_dataset.select(rng_test_sampling.sample(range(len(value_dataset)), n_fever_samples))
    elif key_dataset_name == "xnli":
        # an aggregated xnli test set is not supported
        raise NotImplementedError
    test_sets_concat_nli.append(value_dataset)

dataset_test_concat_nli = concatenate_datasets(test_sets_concat_nli)


# harmonise label column name
dataset_train_nli = dataset_train_nli.rename_columns({"label": "labels", "premise": "text"})
dataset_test_concat_nli = dataset_test_concat_nli.rename_columns({"label": "labels", "premise": "text"})
dataset_test_disaggregated_nli = dataset_test_disaggregated_nli.rename_columns({"label": "labels", "premise": "text"})
print("Full train set: ", len(dataset_train_nli))
print("\nAll available test sets for disaggregated testing: ", dataset_test_disaggregated_nli)
print("\nAll available test data for aggregated testing: ", dataset_test_concat_nli)


TRAIN SETS: 
mnli  number of train examples:  392702
anli  number of train examples:  162865
fever  number of train examples:  196805
wanli  number of train examples:  102885
ling  number of train examples:  29985
Full train set:  885242

All available test sets for disaggregated testing:  DatasetDict({
    mnli_m: Dataset({
        features: ['text', 'hypothesis', 'labels', 'task_name'],
        num_rows: 9815
    })
    mnli_mm: Dataset({
        features: ['text', 'hypothesis', 'labels', 'task_name'],
        num_rows: 9832
    })
    fever: Dataset({
        features: ['text', 'hypothesis', 'labels', 'task_name'],
        num_rows: 19652
    })
    anli_r1: Dataset({
        features: ['text', 'hypothesis', 'labels', 'task_name'],
        num_rows: 1000
    })
    anli_r2: Dataset({
        features: ['text', 'hypothesis', 'labels', 'task_name'],
        num_rows: 1000
    })
    anli_r3: Dataset({
        features: ['text', 'hypothesis', 'labels', 'task_name'],
        num_rows: 1

In [None]:
## make nli datasets binary

def binarize_labels(example):
    """Collapse a multi-class NLI label into the binary "true" / "not_true" scheme.

    Any label below 1 maps to 0 and every label >= 1 maps to 1. Presumably 0 is
    entailment and 1/2 are neutral/contradiction — confirm for every source dataset.

    Args:
        example: a row mapping with an integer-like "labels" entry.

    Returns:
        dict containing only the updated "labels" key, as consumed by ``Dataset.map``.
    """
    return {"labels": int(example["labels"] >= 1)}

# collapse the NLI classes into the binary "true" / "not_true" scheme (train set first,
# because its schema is reused for the binary features below)
dataset_train_nli = dataset_train_nli.map(binarize_labels)

# train schema with the `labels` column swapped for the binary ClassLabel
new_features = dataset_train_nli.features.copy()
new_features["labels"] = ClassLabel(names=["true", "not_true"])

dataset_train_nli = dataset_train_nli.cast(new_features)
dataset_test_concat_nli = dataset_test_concat_nli.map(binarize_labels).cast(new_features)
dataset_test_disaggregated_nli = dataset_test_disaggregated_nli.map(binarize_labels).cast(new_features)

# label_text column so the NLI test splits have the same columns as the non-NLI ones
# (needed to upload one DatasetDict holding both)
label_text_map = {0: "true", 1: "not_true"}
dataset_test_disaggregated_nli = dataset_test_disaggregated_nli.map(
    lambda example: {"label_text": label_text_map[example["labels"]]}
)

print(dataset_train_nli.features["labels"].names)

Map:   0%|          | 0/59140 [00:00<?, ? examples/s]

Map:   0%|          | 0/19652 [00:00<?, ? examples/s]

Map:   0%|          | 0/4893 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/59140 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/19652 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/4893 [00:00<?, ? examples/s]

Map:   0%|          | 0/19652 [00:00<?, ? examples/s]

Map:   0%|          | 0/4893 [00:00<?, ? examples/s]

['true', 'not_true']


### Combine NLI datasets with non-NLI datasets

In [None]:
# harmonise non-nli datasets
# preserve_index=False: never store the pandas index as an extra `__index_level_0__` column.
# Such a column (created for any non-RangeIndex frame) would make the `.cast(new_features)`
# calls below fail, because cast requires exactly the same columns as the target features.
dataset_train_not_nli = Dataset.from_pandas(df_train_format_concat_samp.reset_index(drop=True), preserve_index=False)
dataset_train_not_nli = dataset_train_not_nli.remove_columns(["label_quality"])

dataset_test_not_nli = DatasetDict({
    key_task: Dataset.from_pandas(value_df, preserve_index=False)
    for key_task, value_df in df_test_format_dic.items()
})

# same binary ClassLabel names as used for the NLI data, so both can be concatenated later
new_features = dataset_train_not_nli.features.copy()
new_features['labels'] = ClassLabel(names=["true", "not_true"])
dataset_train_not_nli = dataset_train_not_nli.cast(new_features)
dataset_test_not_nli = dataset_test_not_nli.cast(new_features)

print(dataset_train_not_nli)
print(dataset_test_not_nli)

Casting the dataset:   0%|          | 0/401499 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/5934 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/2132 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/20000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/20000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/20000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/2922 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/175466 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/221760 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/12000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/22036 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/81344 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/30400 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/500000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/8808 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/7676 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/7296 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/2070 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/8000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/2586 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/17910 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/2070 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/20000 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/17382 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/11424 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/10422 [00:00<?, ? examples/s]

Casting the dataset:   0%|          | 0/16854 [00:00<?, ? examples/s]

Dataset({
    features: ['text', 'label_text', 'labels', 'hypothesis', 'task_name'],
    num_rows: 401499
})
DatasetDict({
    wellformedquery: Dataset({
        features: ['text', 'label_text', 'labels', 'hypothesis', 'task_name'],
        num_rows: 5934
    })
    rottentomatoes: Dataset({
        features: ['text', 'label_text', 'labels', 'hypothesis', 'task_name'],
        num_rows: 2132
    })
    amazonpolarity: Dataset({
        features: ['text', 'label_text', 'labels', 'hypothesis', 'task_name'],
        num_rows: 20000
    })
    imdb: Dataset({
        features: ['text', 'label_text', 'labels', 'hypothesis', 'task_name'],
        num_rows: 20000
    })
    yelpreviews: Dataset({
        features: ['text', 'label_text', 'labels', 'hypothesis', 'task_name'],
        num_rows: 20000
    })
    hatexplain: Dataset({
        features: ['text', 'label_text', 'labels', 'hypothesis', 'task_name'],
        num_rows: 2922
    })
    massive: Dataset({
        features: ['text', 'label

In [None]:
# final harmonized datasets

dataset_train = concatenate_datasets([dataset_train_nli, dataset_train_not_nli])
dataset_train = dataset_train.shuffle(seed=SEED_GLOBAL)

dataset_test_concat_nli = dataset_test_concat_nli.shuffle(seed=SEED_GLOBAL)

# NLI and non-NLI test splits are merged by task name; a name present in both would make the
# non-NLI split silently replace the NLI one, so fail loudly instead
overlapping_test_tasks = set(dataset_test_disaggregated_nli) & set(dataset_test_not_nli)
if overlapping_test_tasks:
    raise ValueError(f"Test task names present in both NLI and non-NLI test sets: {sorted(overlapping_test_tasks)}")
dataset_test_disaggregated = DatasetDict({**dataset_test_disaggregated_nli, **dataset_test_not_nli})


### Save final train and test sets to disk

In [None]:
store_data = False

if store_data:
    # (local save path, private hub repo id, dataset) for every final artifact
    final_artifacts = [
        ("./datasets_final/dataset_train", "dataset_train_nli", dataset_train),
        ("./datasets_final/dataset_test_concat_nli", "dataset_test_concat_nli", dataset_test_concat_nli),
        ("./datasets_final/dataset_test_disaggregated", "dataset_test_disaggregated_nli", dataset_test_disaggregated),
    ]

    # save everything to disk first ...
    for disk_path, _, artifact in final_artifacts:
        artifact.save_to_disk(disk_path)

    # ... then push to the hub
    for _, hub_repo_id, artifact in final_artifacts:
        artifact.push_to_hub(repo_id=hub_repo_id, private=True, token=config.HF_ACCESS_TOKEN)


Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/10 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/10 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/20 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/2 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/5 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/6 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/20 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/20 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/20 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/176 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/222 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/12 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/23 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/82 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/31 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/500 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/9 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/8 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/8 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/8 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/18 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/20 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/18 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/12 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/11 [00:00<?, ?ba/s]

Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/17 [00:00<?, ?ba/s]

### (Hypothesis formulation scripts)

In [None]:
# deliberate stop: the helper cells below are not part of the "Run All" pipeline
assert False, "Block following code from executing when running entire notebook"

In [None]:
# wikitoxic dataset
# object phrase shared by the positive and the negated hypothesis of each task
wiki_toxic_phrases = {
    "toxic_aggreg": "toxic language",
    "obscene": "obscene language",
    "threat": "a threat",
    "insult": "an insult",
    "identity_hate": "identity hate",
}

hypo_label_dic_wiki_toxic = {}
for task in ["toxic_aggreg", "obscene", "threat", "insult", "identity_hate"]:
    if task not in wiki_toxic_phrases:
        raise NotImplementedError
    phrase = wiki_toxic_phrases[task]
    # format in UD/NLI format: one positive and one negated hypothesis per task
    hypo_label_dic = {
        task: f"This wikipedia comment contains {phrase}.",
        "not_" + task: f"This wikipedia comment does not contain {phrase}.",
    }
    hypo_label_dic_wiki_toxic[task] = hypo_label_dic

hypo_label_dic_wiki_toxic

In [None]:
# empathetic dialogue
# ! Challenge with this dataset: it is effectively multi-label, so NLI augmentation can lead to wrong labels.
# Strong cleaning and downsampling to reduce this risk.

from datasets import load_dataset

dataset_empathetic = load_dataset("empathetic_dialogues")
print(dataset_empathetic)

# the train and validation splits are pooled into a single training frame
df_empathetic_train = pd.concat(
    [dataset_empathetic[split_name].to_pandas() for split_name in ("train", "validation")]
)

def merge_dialogues(example):
    """Flatten the rows of one conversation into a single dialogue text.

    Args:
        example: DataFrame holding one conversation's rows with at least the
            "prompt", "utterance" and "context" columns; rows are used in their
            current order.

    Returns:
        pd.Series with "text" (the first prompt prefixed by "Context: ", followed
        by the utterances alternately attributed to "Speaker 1" and "Speaker 2",
        with "_comma_" placeholders replaced by ",") and "label_text" (the first
        "context" value).
    """
    turns = []
    for turn_idx, utterance in enumerate(example["utterance"].to_list()):
        speaker = "Speaker 1" if turn_idx % 2 == 0 else "Speaker 2"
        turns.append(f"\n{speaker}: {utterance}")
    dialogue = "Context: " + example["prompt"].iloc[0] + "".join(turns)

    return pd.Series({
        "text": dialogue.replace("_comma_", ","),
        "label_text": example["context"].iloc[0],
    })


# one row per conversation: merged dialogue text plus the conversation's emotion label
df_empathetic_train = df_empathetic_train.groupby(by="conv_id", as_index=False, group_keys=False).apply(merge_dialogues)
df_empathetic_train = (
    df_empathetic_train[~df_empathetic_train.text.duplicated()]
    .drop(columns=["conv_id"])
    .reset_index(drop=True)
)

# integer id per label, in sorted label order
df_empathetic_train["label_standard"] = df_empathetic_train.label_text.factorize(sort=True)[0]

print("Label distribution in dataset:\n", df_empathetic_train.label_text.value_counts())


## create hypotheses
empathetic_labels_sorted = np.sort(df_empathetic_train.label_text.unique())
print({label_text: i for i, label_text in enumerate(empathetic_labels_sorted)})

hypothesis_template_empathetic = "The main emotion of this dialogue is: "

hypo_label_dic_empathetic = {
    label_text: hypothesis_template_empathetic + label_text for label_text in sorted(empathetic_labels_sorted)
}
print(hypo_label_dic_empathetic)



In [None]:
dataset_yahoo_topics = load_dataset("yahoo_answers_topics")

# integer topic id -> topic name, from the names of the dataset's `topic` feature
label_mapping_yahoo_topics = {
    idx: name for idx, name in enumerate(dataset_yahoo_topics["train"].features["topic"].names)
}

df_yahoo_topics_train = dataset_yahoo_topics["train"].to_pandas()

df_yahoo_topics_train = df_yahoo_topics_train.rename(columns={"topic": "label"})
# text = question title + question body, followed by the best answer
df_yahoo_topics_train["text"] = "Question: " + df_yahoo_topics_train["question_title"] + " " + df_yahoo_topics_train["question_content"] + "\n\n" + "Answer: " + df_yahoo_topics_train["best_answer"]

df_yahoo_topics_train["label_text"] = df_yahoo_topics_train.label.map(label_mapping_yahoo_topics)
# .copy(): the boolean filter yields a derived frame; adding `label_standard` to it without a copy
# is a chained assignment that triggers pandas' SettingWithCopyWarning
df_yahoo_topics_train = df_yahoo_topics_train[~df_yahoo_topics_train.text.duplicated()].copy()
df_yahoo_topics_train["label_standard"] = df_yahoo_topics_train.label_text.factorize(sort=True)[0]

# too large, downsample to at most n_data_per_label texts per topic
n_data_per_label = 10_000
print("Dataset length before downsampling: ", len(df_yahoo_topics_train))
df_yahoo_topics_train = df_yahoo_topics_train.groupby("label_text", as_index=False, group_keys=False).apply(
    lambda x: x.sample(min(n_data_per_label, len(x)), random_state=SEED_GLOBAL)
)
print("Dataset length after downsampling: ", len(df_yahoo_topics_train))


df_yahoo_topics_train = df_yahoo_topics_train[["text", "label_text", "label_standard"]].reset_index(drop=True)

print("Label distribution in dataset:\n", df_yahoo_topics_train.label_text.value_counts())


# create hypotheses
print({label_text: i for i, label_text in enumerate(np.sort(df_yahoo_topics_train.label_text.unique()))})

hypothesis_template_yahoo = "This question from the Yahoo Q&A forum is categorized in the topic: "

hypo_label_dic_yahoo = {
    label_text: hypothesis_template_yahoo + label_text for label_text in np.sort(df_yahoo_topics_train.label_text.unique())
}
hypo_label_dic_yahoo = {k: hypo_label_dic_yahoo[k] for k in sorted(hypo_label_dic_yahoo.keys())}
print(hypo_label_dic_yahoo)



In [None]:
### Anthropic helpful-harmful
# hypotheses were already added during the dataset standardisation phase

from datasets import load_dataset, concatenate_datasets, DatasetDict

# dataset: https://huggingface.co/datasets/Anthropic/hh-rlhf
# paper: https://arxiv.org/pdf/2204.05862.pdf
# subsets described on p. 12
dataset_anthropic_rlhf_helpful_base = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-base")
dataset_anthropic_rlhf_helpful_rs = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-rejection-sampled")
dataset_anthropic_rlhf_helpful_online = load_dataset("Anthropic/hh-rlhf", data_dir="helpful-online")
dataset_anthropic_rlhf_harmless = load_dataset("Anthropic/hh-rlhf", data_dir="harmless-base")
# the "red-team-attempts" subset is not loaded: it is not meant for preference modelling and has a
# different format, see: https://huggingface.co/datasets/Anthropic/hh-rlhf

# the three helpfulness subsets are merged split by split into one DatasetDict
helpful_subsets = [
    dataset_anthropic_rlhf_helpful_base,
    dataset_anthropic_rlhf_helpful_rs,
    dataset_anthropic_rlhf_helpful_online,
]
dataset_anthropic_rlhf_helpful = DatasetDict({
    split_name: concatenate_datasets([subset[split_name] for subset in helpful_subsets])
    for split_name in ("train", "test")
})

## specific dataset characteristics:
# the same conversation can appear multiple times, only that new X conversation turns have been added

print("dataset helpful: ", dataset_anthropic_rlhf_helpful)
print("dataset harmless: ", dataset_anthropic_rlhf_harmless)


In [None]:
## tailored prompt
import random

def prompt_format_anthropic_rlhf(example, helpful_or_harmless=None, make_prompt_true=True, rng=None):
    """Turn one hh-rlhf preference pair into a text/hypothesis NLI example.

    Both conversations are shown in `text` ("First conversation" / "Second conversation").
    The order of the two conversations and whether the hypothesis is phrased with "more"
    or "less" are drawn at random.

    Args:
        example: mapping with the "chosen" and "rejected" conversation strings.
        helpful_or_harmless: attribute inserted into the hypothesis, e.g. "harmless".
        make_prompt_true: if truthy, the hypothesis correctly identifies the chosen
            conversation (labels=0); otherwise the two conversations are shown in the
            opposite order, so the same hypothesis is wrong (labels=1).
        rng: optional object with a ``choice`` method (e.g. ``random.Random(seed)``) for
            reproducible draws. Defaults to the global ``random`` module.

    Returns:
        dict with "text", "labels" (0 if make_prompt_true else 1) and "hypothesis".
    """
    # Draws come only from `rng`; numpy's global RNG is deliberately left untouched, so the
    # notebook-wide numpy seed is not discarded.
    if rng is None:
        rng = random
    chosen_first = rng.choice([True, False])
    positive_prompt = rng.choice([True, False])

    # a wrong prompt keeps the hypothesis but flips the order of the two conversations
    show_chosen_first = chosen_first if make_prompt_true else not chosen_first
    if show_chosen_first:
        first_conversation, second_conversation = example["chosen"], example["rejected"]
    else:
        first_conversation, second_conversation = example["rejected"], example["chosen"]
    text = f"First conversation:\n{first_conversation}\n\nSecond conversation:\n{second_conversation}"

    # hypothesis is formulated relative to `chosen_first`:
    # "more X" points at the chosen conversation, "less X" at the rejected one
    if positive_prompt:
        comparative = "more"
        position = "first" if chosen_first else "second"
    else:
        comparative = "less"
        position = "second" if chosen_first else "first"
    hypothesis = f"The assistant is {comparative} {helpful_or_harmless} in the {position} conversation."

    label = 0 if make_prompt_true else 1

    return {"text": text, "labels": label, "hypothesis": hypothesis}


# For each subset (helpful & harmless) every conversation pair yields two rows: one whose
# hypothesis correctly identifies the chosen conversation (labels=0) and one whose conversation
# order is flipped so that the hypothesis is wrong (labels=1).


def _true_and_false_variants(dataset_dict, helpful_or_harmless):
    """Map `dataset_dict` once with correct and once with wrong prompts.

    Returns a (true_variant, false_variant) pair of DatasetDicts; the correct variant is built first.
    """
    return tuple(
        dataset_dict.map(
            prompt_format_anthropic_rlhf,
            fn_kwargs={"helpful_or_harmless": helpful_or_harmless, "make_prompt_true": make_prompt_true},
        )
        for make_prompt_true in (True, False)
    )


# helpful dataset
dataset_anthropic_rlhf_helpful_true, dataset_anthropic_rlhf_helpful_false = _true_and_false_variants(
    dataset_anthropic_rlhf_helpful, "helpful and honest"
)
dataset_anthropic_rlhf_helpful_train, dataset_anthropic_rlhf_helpful_test = (
    concatenate_datasets([dataset_anthropic_rlhf_helpful_true[split_name], dataset_anthropic_rlhf_helpful_false[split_name]])
    for split_name in ("train", "test")
)
dataset_anthropic_rlhf_helpful = DatasetDict(
    {"train": dataset_anthropic_rlhf_helpful_train, "test": dataset_anthropic_rlhf_helpful_test}
).shuffle(seed=SEED_GLOBAL)

# harmless dataset
dataset_anthropic_rlhf_harmless_true, dataset_anthropic_rlhf_harmless_false = _true_and_false_variants(
    dataset_anthropic_rlhf_harmless, "harmless"
)
dataset_anthropic_rlhf_harmless_train, dataset_anthropic_rlhf_harmless_test = (
    concatenate_datasets([dataset_anthropic_rlhf_harmless_true[split_name], dataset_anthropic_rlhf_harmless_false[split_name]])
    for split_name in ("train", "test")
)
dataset_anthropic_rlhf_harmless = DatasetDict(
    {"train": dataset_anthropic_rlhf_harmless_train, "test": dataset_anthropic_rlhf_harmless_test}
).shuffle(seed=SEED_GLOBAL)

# inspect results
for subset_name, subset_dataset in (("harmless", dataset_anthropic_rlhf_harmless), ("helpful", dataset_anthropic_rlhf_helpful)):
    print(f"Label distribution {subset_name}\n", subset_dataset["train"].to_pandas().labels.value_counts(), "\n")

display(dataset_anthropic_rlhf_harmless["train"].to_pandas().head(100))
