In [None]:
from pathlib import Path
import argparse
from datasets import load_from_disk
from collections import defaultdict
import shutil
import math
from tqdm import tqdm
from joblib import Parallel, delayed
from transformers import AutoTokenizer

PILE_DOMAINS = ['ArXiv', 'BookCorpus2', 'Books3', 'DM Mathematics', 'Enron Emails', 'EuroParl', 'FreeLaw', 'Github', 'Gutenberg (PG-19)', 'HackerNews', 'NIH ExPorter', 'OpenSubtitles', 'OpenWebText2', 'PhilPapers', 'Pile-CC', 'PubMed Abstracts', 'PubMed Central', 'StackExchange', 'USPTO Backgrounds', 'Ubuntu IRC', 'Wikipedia (en)', 'YoutubeSubtitles']
SLIM_DOMAINS = ['RedPajamaCommonCrawl', 'RedPajamaC4', 'RedPajamaGithub', 'RedPajamaWikipedia', 'RedPajamaBook', 'RedPajamaArXiv', 'RedPajamaStackExchange']
preprocessed_dir = "/home/wth/My_codes/doremi/data/slim_preprocessed/preprocessed"
preprocessed_dir = Path(preprocessed_dir) / 'train'

In [None]:
tokenizer = AutoTokenizer.from_pretrained("/home/wth/My_codes/doremi/tokenizer")
nopack = False

def process_shard(shard_dir):
    curr_count = 0
    ds = load_from_disk(dataset_path=str(shard_dir))
    if nopack:
        # in the DoReMi paper, we first padded to the context length then counted
        # the number of chunks, and dynamically packed the examples
        # together (possibly even from different domains)
        num_tokens_in_curr_doc = 0
        chunk_size = 2048
        for ex in tqdm(ds):
            toks = ex['input_ids']
            sep_idxs = [i for i in range(len(toks)) if toks[i] == tokenizer.eos_token_id]
            if len(sep_idxs) > 0:
                prev_sep_idx = -1
                for sep_idx in sep_idxs:
                    num_tokens_in_curr_doc += sep_idx - prev_sep_idx - 1
                    prev_sep_idx = sep_idx
                    curr_count += math.ceil(num_tokens_in_curr_doc / chunk_size)
                    num_tokens_in_curr_doc = 0
                if prev_sep_idx != len(toks) - 1:
                    num_tokens_in_curr_doc += len(toks) - prev_sep_idx - 1
            else:
                num_tokens_in_curr_doc += len(toks)
        if num_tokens_in_curr_doc > 0:
            curr_count += math.ceil(num_tokens_in_curr_doc / chunk_size)
    else:
        curr_count = len(ds)

    return curr_count

domain_lens = defaultdict(int)
for domain_dir in preprocessed_dir.iterdir():
    print("Counting domain", domain_dir.name)
    counts = Parallel(n_jobs=30)(delayed(process_shard)(shard_dir) for shard_dir in domain_dir.iterdir())
    domain_lens[domain_dir.name] = sum(counts)
print(domain_lens)

In [None]:
domain_lens = defaultdict(int)
for domain_dir in preprocessed_dir.iterdir():
    print("Counting domain", domain_dir.name)
    counts = Parallel(n_jobs=30)(delayed(process_shard)(shard_dir) for shard_dir in domain_dir.iterdir())
    domain_lens[domain_dir.name] = sum(counts)
print(domain_lens)

In [None]:
nums = 0
for domain in domain_lens.keys():
    nums += domain_lens[domain]
print(nums)

In [None]:
slim_epochs = {
            "RedPajamaCommonCrawl": 1.0,
            "RedPajamaC4": 1.0,
            "RedPajamaGithub": 1.0,
            "RedPajamaWikipedia": 1.0,
            "RedPajamaBook": 1.0,
            "RedPajamaArXiv": 1.0,
            "RedPajamaStackExchange": 1.0
        }
domain_lens = {k: v * slim_epochs[k] for k,v in domain_lens.items()}
print(domain_lens)

In [None]:
# renormalize domain_lens
total_len = sum(domain_lens.values())
domain_lens = {k: v / total_len for k, v in domain_lens.items()}
print("Baseline domain weights:", domain_lens)

In [20]:
from datasets import load_from_disk
import jsonlines

jsonl_path = "/home/wth/My_codes/doremi/data/multi_domain/train/train_allenai-WildChat-1m.jsonl"
with open(jsonl_path, "r", encoding="utf-8") as f:
    items = jsonlines.Reader(f)
    for item in items:
        print(item)
        print(item.keys())
        print(item["text"])
        print(item["meta"])
        break

FileNotFoundError: [Errno 2] No such file or directory: '/home/wth/My_codes/doremi/data/multi_domain/train/train_allenai-WildChat-1m.jsonl'

In [21]:
from pathlib import Path
import os
dataset_dir = "/home/wth/My_codes/doremi/data/multi_domain/train"
dataset_dir = Path(dataset_dir)
DOMAINS = list(sorted([str(domain_dir.name.split(".j")[0]) for domain_dir in dataset_dir.iterdir() if not str(domain_dir.name).endswith('txt')]))
print(DOMAINS)
for domain in DOMAINS:
    domain_path = Path(dataset_dir) / domain
    print(domain_path)
    # os.mkdir(domain_path)

['train_Open-Orca-1million-gpt-4', 'train_allenai-WildChat-1m', 'train_alpaca_chat_turn2', 'train_alpaca_de_49963', 'train_alpaca_es_51942', 'train_alpaca_fr_55178', 'train_alpaca_gpt4', 'train_alpaca_it_51710', 'train_alpaca_ja_51999', 'train_alpaca_ko_49620', 'train_alpaca_pt_51759', 'train_alpaca_ru_29822', 'train_alpaca_zh_48818', 'train_jondurbin-airoboros-3.2', 'train_lmsys-chat-1m', 'train_share_gpt4', 'train_slimorca', 'train_teknium-GPTeacher-General-Instruct']
/home/wth/My_codes/doremi/data/multi_domain/train/train_Open-Orca-1million-gpt-4
/home/wth/My_codes/doremi/data/multi_domain/train/train_allenai-WildChat-1m
/home/wth/My_codes/doremi/data/multi_domain/train/train_alpaca_chat_turn2
/home/wth/My_codes/doremi/data/multi_domain/train/train_alpaca_de_49963
/home/wth/My_codes/doremi/data/multi_domain/train/train_alpaca_es_51942
/home/wth/My_codes/doremi/data/multi_domain/train/train_alpaca_fr_55178
/home/wth/My_codes/doremi/data/multi_domain/train/train_alpaca_gpt4
/home/wth/

In [None]:
from datasets import load_dataset

data_file = "/home/wth/My_codes/doremi/data/multi_domain/train/train_allenai-WildChat-1m/train_allenai-WildChat-1m.jsonl"
ds = load_dataset(
    "json",
    data_files=data_file
)
print(ds)