In [33]:
from datasets import Dataset, DatasetDict
from tqdm import tqdm
import pandas as pd
import ast
import os

In [34]:
rewrites_fir_path = "../cached_rewrites"
cached_rewrite_fils = os.listdir(rewrites_fir_path)
cached_rewrite_fils

['boss_sentiment_stabilityai_StableBeluga-13B_temp=0.0.csv',
 'boss_sentiment_stabilityai_StableBeluga-7b_temp=0.0.csv',
 'boss_toxicity_aug_back-translate.csv',
 'boss_toxicity_stabilityai_StableBeluga-7b_temp=0.0.csv',
 'boss_sentiment_aug_back-translate.csv',
 'boss_toxicity_aug_substitute.csv',
 'boss_sentiment_aug_substitute.csv',
 'ag_news_twitter_stabilityai_StableBeluga-7b_temp=0.0.csv',
 'ag_news_twitter_aug_insert.csv',
 'ag_news_twitter_aug_substitute.csv',
 'boss_toxicity_aug_insert.csv',
 'ag_news_twitter_aug_back-translate.csv',
 'boss_sentiment_aug_insert.csv']

In [35]:
def parse_rewrites(rewrites_string):
    if rewrites_string == "[]" or rewrites_string == "['']" or rewrites_string is None:
        return None

    try:
        rewrites = ast.literal_eval(rewrites_string)[:4]
        return rewrites
    except:
        return None

all_rewrites = DatasetDict()
for rewrite_file in tqdm(cached_rewrite_fils):
    rewrite_file_path = os.path.join(rewrites_fir_path, rewrite_file)
    rewrite_data = pd.read_csv(rewrite_file_path, on_bad_lines="warn")
    rewrite_data["rewrites"] = rewrite_data["rewrites"].apply(parse_rewrites)

    with_na_length = len(rewrite_data)
    rewrite_data = rewrite_data.dropna()
    without_na_length = len(rewrite_data)
    print(f"Removed {with_na_length - without_na_length} rows with NA rewrites")

    if "__index_level_0__" in rewrite_data.columns:
        rewrite_data = rewrite_data.drop(columns=["__index_level_0__"])

    split_name = rewrite_file.replace(".csv", "").replace(".", "dot").replace("-", "_").replace("=", "equals")
    all_rewrites[split_name] = Dataset.from_pandas(rewrite_data)

all_rewrites

  8%|▊         | 1/13 [00:00<00:01,  9.65it/s]

Removed 0 rows with NA rewrites


 15%|█▌        | 2/13 [00:03<00:23,  2.17s/it]

Removed 19 rows with NA rewrites


 23%|██▎       | 3/13 [00:04<00:14,  1.49s/it]

Removed 0 rows with NA rewrites
Removed 34 rows with NA rewrites


 38%|███▊      | 5/13 [00:13<00:24,  3.02s/it]

Removed 0 rows with NA rewrites


 46%|████▌     | 6/13 [00:14<00:16,  2.36s/it]

Removed 0 rows with NA rewrites


 54%|█████▍    | 7/13 [00:16<00:13,  2.33s/it]

Removed 0 rows with NA rewrites


 62%|██████▏   | 8/13 [00:18<00:10,  2.12s/it]

Removed 0 rows with NA rewrites


 69%|██████▉   | 9/13 [00:18<00:06,  1.66s/it]

Removed 0 rows with NA rewrites


 77%|███████▋  | 10/13 [00:19<00:03,  1.33s/it]

Removed 0 rows with NA rewrites
Removed 0 rows with NA rewrites


 92%|█████████▏| 12/13 [00:25<00:01,  1.98s/it]

Removed 0 rows with NA rewrites


100%|██████████| 13/13 [00:27<00:00,  2.14s/it]

Removed 0 rows with NA rewrites





DatasetDict({
    boss_sentiment_stabilityai_StableBeluga_13B_tempequals0dot0: Dataset({
        features: ['prompt_hash', 'prompt', 'rewrites'],
        num_rows: 2132
    })
    boss_sentiment_stabilityai_StableBeluga_7b_tempequals0dot0: Dataset({
        features: ['prompt_hash', 'prompt', 'rewrites', '__index_level_0__'],
        num_rows: 90974
    })
    boss_toxicity_aug_back_translate: Dataset({
        features: ['prompt_hash', 'prompt', 'rewrites'],
        num_rows: 23299
    })
    boss_toxicity_stabilityai_StableBeluga_7b_tempequals0dot0: Dataset({
        features: ['prompt_hash', 'prompt', 'rewrites', '__index_level_0__'],
        num_rows: 122180
    })
    boss_sentiment_aug_back_translate: Dataset({
        features: ['prompt_hash', 'prompt', 'rewrites'],
        num_rows: 61580
    })
    boss_toxicity_aug_substitute: Dataset({
        features: ['prompt_hash', 'prompt', 'rewrites'],
        num_rows: 34531
    })
    boss_sentiment_aug_substitute: Dataset({
        

In [36]:
dot_name = []

for split in all_rewrites:
    if "__index_level_0__" in all_rewrites[split].column_names:
        all_rewrites[split] = all_rewrites[split].remove_columns(["__index_level_0__"])

    if "." in split:
        dot_name.append(split)

all_rewrites

DatasetDict({
    boss_sentiment_stabilityai_StableBeluga_13B_tempequals0dot0: Dataset({
        features: ['prompt_hash', 'prompt', 'rewrites'],
        num_rows: 2132
    })
    boss_sentiment_stabilityai_StableBeluga_7b_tempequals0dot0: Dataset({
        features: ['prompt_hash', 'prompt', 'rewrites'],
        num_rows: 90974
    })
    boss_toxicity_aug_back_translate: Dataset({
        features: ['prompt_hash', 'prompt', 'rewrites'],
        num_rows: 23299
    })
    boss_toxicity_stabilityai_StableBeluga_7b_tempequals0dot0: Dataset({
        features: ['prompt_hash', 'prompt', 'rewrites'],
        num_rows: 122180
    })
    boss_sentiment_aug_back_translate: Dataset({
        features: ['prompt_hash', 'prompt', 'rewrites'],
        num_rows: 61580
    })
    boss_toxicity_aug_substitute: Dataset({
        features: ['prompt_hash', 'prompt', 'rewrites'],
        num_rows: 34531
    })
    boss_sentiment_aug_substitute: Dataset({
        features: ['prompt_hash', 'prompt', 'rewri

In [37]:
# display a datafram of all the feature datatypes
pd.DataFrame.from_dict({split: all_rewrites[split].features for split in all_rewrites}).T

Unnamed: 0,prompt_hash,prompt,rewrites
boss_sentiment_stabilityai_StableBeluga_13B_tempequals0dot0,"Value(dtype='string', id=None)","Value(dtype='string', id=None)","Sequence(feature=Value(dtype='string', id=None..."
boss_sentiment_stabilityai_StableBeluga_7b_tempequals0dot0,"Value(dtype='string', id=None)","Value(dtype='string', id=None)","Sequence(feature=Value(dtype='string', id=None..."
boss_toxicity_aug_back_translate,"Value(dtype='string', id=None)","Value(dtype='string', id=None)","Sequence(feature=Value(dtype='string', id=None..."
boss_toxicity_stabilityai_StableBeluga_7b_tempequals0dot0,"Value(dtype='string', id=None)","Value(dtype='string', id=None)","Sequence(feature=Value(dtype='string', id=None..."
boss_sentiment_aug_back_translate,"Value(dtype='string', id=None)","Value(dtype='string', id=None)","Sequence(feature=Value(dtype='string', id=None..."
boss_toxicity_aug_substitute,"Value(dtype='string', id=None)","Value(dtype='string', id=None)","Sequence(feature=Value(dtype='string', id=None..."
boss_sentiment_aug_substitute,"Value(dtype='string', id=None)","Value(dtype='string', id=None)","Sequence(feature=Value(dtype='string', id=None..."
ag_news_twitter_stabilityai_StableBeluga_7b_tempequals0dot0,"Value(dtype='string', id=None)","Value(dtype='string', id=None)","Sequence(feature=Value(dtype='string', id=None..."
ag_news_twitter_aug_insert,"Value(dtype='string', id=None)","Value(dtype='string', id=None)","Sequence(feature=Value(dtype='string', id=None..."
ag_news_twitter_aug_substitute,"Value(dtype='string', id=None)","Value(dtype='string', id=None)","Sequence(feature=Value(dtype='string', id=None..."


In [38]:
all_rewrites.push_to_hub("LLM-TTA-Cached-Rewrites")

Creating parquet from Arrow format: 100%|██████████| 3/3 [00:00<00:00, 124.71ba/s]
Pushing dataset shards to the dataset hub: 100%|██████████| 1/1 [00:00<00:00,  1.01it/s]
Creating parquet from Arrow format: 100%|██████████| 91/91 [00:00<00:00, 194.82ba/s]
Pushing dataset shards to the dataset hub: 100%|██████████| 1/1 [00:02<00:00,  2.70s/it]
Creating parquet from Arrow format: 100%|██████████| 24/24 [00:00<00:00, 313.43ba/s]
Pushing dataset shards to the dataset hub: 100%|██████████| 1/1 [00:00<00:00,  1.40it/s]
Creating parquet from Arrow format: 100%|██████████| 123/123 [00:00<00:00, 164.62ba/s]
Pushing dataset shards to the dataset hub: 100%|██████████| 1/1 [00:03<00:00,  3.53s/it]
Creating parquet from Arrow format: 100%|██████████| 62/62 [00:00<00:00, 242.28ba/s]
Pushing dataset shards to the dataset hub: 100%|██████████| 1/1 [00:01<00:00,  1.56s/it]
Creating parquet from Arrow format: 100%|██████████| 35/35 [00:00<00:00, 194.69ba/s]
Pushing dataset shards to the dataset hub: 10