# Download / Import libraries

In [25]:
import os
os.environ["HF_HOME"] = "/network/scratch/x/xut/hf_cache"
os.environ["TRANSFORMERS_CACHE"] = os.path.join(os.environ["HF_HOME"], "transformers")
os.environ["HF_DATASETS_CACHE"] = os.path.join(os.environ["HF_HOME"], "datasets")


## Mapping

In [1]:
fine_web2_labels = ['aeb_Arab', 'afr_Latn', 'amh_Ethi', 'arz_Arab', 'bam_Latn', 'bem_Latn', 'cjk_Latn', 'dyu_Latn', 'gaz_Latn', 'ibo_Latn', 'kab_Latn', 'kam_Latn', 'kbp_Latn', 'kin_Latn', 'kmb_Latn', 'knc_Arab', 'knc_Latn', 'lin_Latn', 'lug_Latn', 'luo_Latn', 'nus_Latn', 'plt_Latn', 'run_Latn', 'sag_Latn', 'sna_Latn', 'sot_Latn', 'ssw_Latn', 'swc_Latn', 'taq_Tfng', 'tir_Ethi', 'tsn_Latn', 'twi_Latn', 'tzm_Tfng', 'umb_Latn', 'xho_Latn', 'yor_Latn']

fineweb2_to_afrimgsm = {
    "amh_Ethi": "amh",
    "ewe_Latn": "ewe",
    "gaz_Latn": "orm",  
    "hau_Latn": "hau",
    "kin_Latn": "kin",
    "lin_Latn": "lin",
    "lug_Latn": "lug",
    "sna_Latn": "sna",
    "swc_Latn": "swa",
    "twi_Latn": "twi",
    "wol_Latn": "wol",
    "xho_Latn": "xho",
    "yor_Latn": "yor",
    "zul_Latn": "zul"
}

fineweb2_to_wura = {
        "afr_Latn": "af",
        "amh_Ethi": "am",
        "arz_Arab": "ar",
        "hau_Latn": "ha",
        "ibo_Latn": "ig",
        "kin_Latn": "ki",
        "plt_Latn": "mg",
        "gaz_Latn": "or",
        "som_Latn": "sm",
        "sna_Latn": "sn",
        "Sesotho": "st",
        "swc_Latn": "sw",
        "tir_Ethi": "ti",
        "xho_Latn": "xh",
        "yor_Latn": "yo",
        "zul_Latn": "zu",
    }



fineweb2_to_madlad = {
    "afr_Latn": "af",
    "aka_Latn": "ak",
    "amh_Ethi": "am",
    "bam_Latn": "bm",
    "dik_Latn": "din",
    "dyu_Latn": "dyu",
    "ewe_Latn": "ee",
    "fon_Latn": "fon",
    "fuv_Latn": "ff",
    "gaz_Latn": "om",   
    "hau_Latn": "ha",  
    "ibo_Latn": "ig",   
    "kbp_Latn": "kbp",
    "kin_Latn": "rw",
    "kmb_Latn": "kmb",
    "kon_Latn": "kg",
    "lin_Latn": "ln",
    "lug_Latn": "lg",
    "run_Latn": "rn",
    "sag_Latn": "sg",
    "sna_Latn": "sn",
    "som_Latn": "so",
    "sot_Latn": "st",
    "ssw_Latn": "ss",
    "swc_Latn": "sw",
    "tir_Ethi": "ti",
    "tsn_Latn": "tn",
    "tso_Latn": "ts",
    "tzm_Tfng": "ber",
    "wol_Latn": "wo",
    "xho_Latn": "xh",
    "yor_Latn": "yo",
    "zul_Latn": "zu"
}




# NLLB target language codes (can be adjusted)
NLLB_target_langs = [
    "fra_Latn",  # French
    "swa_Latn",  # Swahili
    "amh_Ethi",  # Amharic
    "yor_Latn",  # Yoruba
    "hau_Latn"   # Hausa
]


all_african_language_list = [
    'aeb_Arab',
    'afr_Latn',
    'aka_Latn',
    'amh_Ethi',
    'ary_Arab',
    'arz_Arab',
    'bam_Latn',
    'bem_Latn',
    'cjk_Latn',
    'dik_Latn',
    'dyu_Latn',
    'ewe_Latn',
    'fon_Latn',
    'fuv_Latn',
    'gaz_Latn',
    'hau_Latn',
    'ibo_Latn',
    'kab_Latn',
    'kam_Latn',
    'kbp_Latn',
    'kea_Latn',
    'kik_Tatn',
    'kin_Latn',
    'kmb_Latn',
    'knc_Arab',
    'knc_Latn',
    'kon_Latn',
    'lin_Latn',
    'lua_Latn',
    'lug_Latn',
    'luo_Latn',
    'Mos_Latn',
    'nqo_Nkoo',
    'nso_Latn',
    'nus_Latn',
    'nya_Latn',
    'plt_Latn',
    'run_Latn',
    'sag_Latn',
    'sna_Latn',
    'som_Latn',
    'sot_Latn',
    'ssw_Latn',
    'swc_Latn',
    'taq_Latn',
    'taq_Tfng',
    'tir_Ethi',
    'tsn_Latn',
    'tso_Latn',
    'tum_Latn',
    'twi_Latn',
    'tzm_Tfng',
    'umb_Latn',
    'wol_Latn',
    'xho_Latn',
    'yor_Latn',
    'zul_Latn',
]

## Fine Web 2 Dataset

In [None]:
import os
import json
from tqdm import tqdm
from datatrove.pipeline.readers import ParquetReader
from transformers import AutoTokenizer

In [None]:
# Read token
with open("/network/scratch/x/xut/hf_cache/token", "r") as f:
    token = f.read().strip()

In [None]:
# Hyperparam
output_dir = "../../../scratch/data/data_pretrain/fineweb2"
os.makedirs(output_dir, exist_ok=True)
MAX_TOKENS = 1_000_000_000
BUFFER_SIZE = 1000

In [None]:
# Tokenizer (Gemma 2B/7B tokenizer)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b", token=token)

# Count tokens
def count_tokens(text):
    return len(tokenizer.encode(text, add_special_tokens=False))

# Loop through languages
for lang_code in fine_web2_labels:
    reader = ParquetReader(f"hf://datasets/HuggingFaceFW/fineweb-2/data/{lang_code}/train")
    output_path = os.path.join(output_dir, f"{lang_code}_fw2.jsonl")
    
    token_count = 0
    doc_count = 0
    buffer = []

    with open(output_path, "w", encoding="utf-8") as out_file:
        for doc in tqdm(reader(), desc=f"{lang_code} docs"):
            text = doc.text.strip()
            if not text:
                continue
            tokens = count_tokens(text)
            token_count += tokens
            doc_count += 1
            buffer.append(json.dumps({"text": text}))

            if len(buffer) >= BUFFER_SIZE:
                out_file.write("\n".join(buffer) + "\n")
                buffer = []

            if token_count >= MAX_TOKENS:
                break

        # Write remaining documents in buffer
        if buffer:
            out_file.write("\n".join(buffer) + "\n")

    print(f"[✓] {lang_code}: {doc_count:,} docs, {token_count:,} tokens ➜ {output_path}")


aeb_Arab docs: 0it [00:00, ?it/s][32m2025-05-07 17:51:39.372[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
aeb_Arab docs: 262884it [03:46, 1160.54it/s]


[✓] aeb_Arab: 262,884 docs, 138,746,907 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/aeb_Arab_fw2.jsonl


afr_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 17:55:25.906[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
afr_Latn docs: 877108it [28:46, 508.03it/s]


[✓] afr_Latn: 877,109 docs, 1,000,000,248 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/afr_Latn_fw2.jsonl


amh_Ethi docs: 0it [00:00, ?it/s][32m2025-05-07 18:24:12.423[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
amh_Ethi docs: 280355it [07:37, 612.43it/s]


[✓] amh_Ethi: 280,355 docs, 637,971,403 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/amh_Ethi_fw2.jsonl


arz_Arab docs: 0it [00:00, ?it/s][32m2025-05-07 18:31:50.212[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
arz_Arab docs: 1410134it [23:24, 1003.91it/s]


[✓] arz_Arab: 1,410,134 docs, 847,888,995 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/arz_Arab_fw2.jsonl


bam_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 18:55:14.867[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
bam_Latn docs: 14044it [00:14, 957.48it/s] 


[✓] bam_Latn: 14,044 docs, 6,834,477 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/bam_Latn_fw2.jsonl


bem_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 18:55:29.539[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
bem_Latn docs: 1143it [00:02, 535.61it/s]


[✓] bem_Latn: 1,143 docs, 1,468,373 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/bem_Latn_fw2.jsonl


cjk_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 18:55:31.681[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
cjk_Latn docs: 44it [00:00, 305.00it/s]


[✓] cjk_Latn: 44 docs, 32,816 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/cjk_Latn_fw2.jsonl


dyu_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 18:55:31.829[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
dyu_Latn docs: 2209it [00:02, 925.42it/s] 


[✓] dyu_Latn: 2,209 docs, 1,998,324 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/dyu_Latn_fw2.jsonl


gaz_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 18:55:34.222[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
gaz_Latn docs: 43468it [01:01, 702.95it/s]


[✓] gaz_Latn: 43,468 docs, 42,534,779 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/gaz_Latn_fw2.jsonl


ibo_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 18:56:36.065[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
ibo_Latn docs: 95184it [02:54, 544.58it/s]


[✓] ibo_Latn: 95,184 docs, 144,594,856 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/ibo_Latn_fw2.jsonl


kab_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 18:59:30.858[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
kab_Latn docs: 7717it [00:05, 1441.15it/s]


[✓] kab_Latn: 7,717 docs, 3,799,925 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/kab_Latn_fw2.jsonl


kam_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 18:59:36.218[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
kam_Latn docs: 1218it [00:01, 935.76it/s] 


[✓] kam_Latn: 1,218 docs, 1,007,712 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/kam_Latn_fw2.jsonl


kbp_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 18:59:37.525[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
kbp_Latn docs: 1231it [00:01, 1153.56it/s]


[✓] kbp_Latn: 1,231 docs, 1,050,748 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/kbp_Latn_fw2.jsonl


kin_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 18:59:38.598[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
kin_Latn docs: 199112it [03:17, 1009.90it/s]


[✓] kin_Latn: 199,112 docs, 136,741,122 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/kin_Latn_fw2.jsonl


kmb_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:02:55.763[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
kmb_Latn docs: 1132it [00:02, 524.04it/s]


[✓] kmb_Latn: 1,132 docs, 1,496,461 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/kmb_Latn_fw2.jsonl


knc_Arab docs: 0it [00:00, ?it/s][32m2025-05-07 19:02:58.403[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
knc_Arab docs: 290it [00:04, 68.05it/s]


[✓] knc_Arab: 290 docs, 3,145,798 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/knc_Arab_fw2.jsonl


knc_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:03:02.727[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
knc_Latn docs: 437it [00:00, 680.62it/s]


[✓] knc_Latn: 437 docs, 252,638 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/knc_Latn_fw2.jsonl


lin_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:03:03.379[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
lin_Latn docs: 15241it [00:26, 580.71it/s]


[✓] lin_Latn: 15,241 docs, 16,269,314 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/lin_Latn_fw2.jsonl


lug_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:03:29.631[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
lug_Latn docs: 32954it [00:34, 958.06it/s] 


[✓] lug_Latn: 32,954 docs, 24,286,951 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/lug_Latn_fw2.jsonl


luo_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:04:04.036[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
luo_Latn docs: 2210it [00:03, 703.49it/s]


[✓] luo_Latn: 2,210 docs, 2,005,909 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/luo_Latn_fw2.jsonl


nus_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:04:07.183[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
nus_Latn docs: 152it [00:00, 351.92it/s]


[✓] nus_Latn: 152 docs, 222,008 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/nus_Latn_fw2.jsonl


plt_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:04:07.620[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
plt_Latn docs: 254482it [07:29, 566.71it/s]


[✓] plt_Latn: 254,482 docs, 305,273,095 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/plt_Latn_fw2.jsonl


run_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:11:36.680[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
run_Latn docs: 88823it [01:26, 1022.85it/s]


[✓] run_Latn: 88,823 docs, 56,824,462 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/run_Latn_fw2.jsonl


sag_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:13:03.524[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
sag_Latn docs: 4537it [00:09, 469.65it/s]


[✓] sag_Latn: 4,537 docs, 7,111,907 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/sag_Latn_fw2.jsonl


sna_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:13:13.192[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
sna_Latn docs: 80003it [02:18, 576.71it/s]


[✓] sna_Latn: 80,003 docs, 95,262,722 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/sna_Latn_fw2.jsonl


sot_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:15:31.919[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
sot_Latn docs: 83329it [02:55, 474.96it/s]


[✓] sot_Latn: 83,329 docs, 120,628,123 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/sot_Latn_fw2.jsonl


ssw_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:18:27.371[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
ssw_Latn docs: 1668it [00:03, 439.26it/s]


[✓] ssw_Latn: 1,668 docs, 2,670,011 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/ssw_Latn_fw2.jsonl


swc_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:18:31.179[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
swc_Latn docs: 2161it [00:01, 1264.64it/s]


[✓] swc_Latn: 2,161 docs, 826,467 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/swc_Latn_fw2.jsonl


taq_Tfng docs: 0it [00:00, ?it/s][32m2025-05-07 19:18:32.893[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
taq_Tfng docs: 208it [00:00, 521.71it/s]


[✓] taq_Tfng: 208 docs, 441,499 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/taq_Tfng_fw2.jsonl


tir_Ethi docs: 0it [00:00, ?it/s][32m2025-05-07 19:18:33.299[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
tir_Ethi docs: 65569it [01:40, 651.98it/s]


[✓] tir_Ethi: 65,569 docs, 135,822,971 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/tir_Ethi_fw2.jsonl


tsn_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:20:13.884[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
tsn_Latn docs: 5530it [00:13, 423.78it/s]


[✓] tsn_Latn: 5,530 docs, 9,162,081 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/tsn_Latn_fw2.jsonl


twi_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:20:26.947[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
twi_Latn docs: 5655it [00:13, 405.99it/s]


[✓] twi_Latn: 5,655 docs, 11,239,289 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/twi_Latn_fw2.jsonl


tzm_Tfng docs: 0it [00:00, ?it/s][32m2025-05-07 19:20:40.887[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
tzm_Tfng docs: 2376it [00:03, 773.27it/s]


[✓] tzm_Tfng: 2,376 docs, 4,448,373 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/tzm_Tfng_fw2.jsonl


umb_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:20:43.968[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
umb_Latn docs: 709it [00:00, 779.46it/s]


[✓] umb_Latn: 709 docs, 536,879 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/umb_Latn_fw2.jsonl


xho_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:20:44.884[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
xho_Latn docs: 99567it [02:45, 602.44it/s]


[✓] xho_Latn: 99,567 docs, 119,244,239 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/xho_Latn_fw2.jsonl


yor_Latn docs: 0it [00:00, ?it/s][32m2025-05-07 19:23:30.164[0m | [1mINFO    [0m | [36mdatatrove.pipeline.readers.base[0m:[36mread_files_shard[0m:[36m201[0m - [1mReading input file 000_00000.parquet, 1/1[0m
yor_Latn docs: 67447it [01:58, 571.02it/s]

[✓] yor_Latn: 67,447 docs, 93,824,538 tokens ➜ ../../../scratch/data/data_pretrain/fineweb2/yor_Latn_fw2.jsonl





# Afri-MGSM Dataset

In [18]:
from datasets import load_dataset

## Hyperparam

In [19]:
# Output directory
output_dir = "../../../scratch/data/data_pretrain/afrimgsm"
os.makedirs(output_dir, exist_ok=True)

In [20]:
# Loop through each mapping
for long_code, mgsm_code in fineweb2_to_afrimgsm.items():
    
    try:
        dataset = load_dataset("masakhane/afrimgsm", name=mgsm_code, split="train")
    except Exception as e:
        print(f"Failed to load {mgsm_code}: {e}")
        continue

    output_path = os.path.join(output_dir, f"{long_code}_mgsm.jsonl")

    with open(output_path, "w", encoding="utf-8") as f:
        for example in dataset:
            json.dump(example, f)
            f.write("\n")

    print(f"Saved to {output_path}")

Generating train split: 100%|██████████| 8/8 [00:00<00:00, 268.14 examples/s]
Generating test split: 100%|██████████| 250/250 [00:00<00:00, 29184.67 examples/s]


Saved to ../../../scratch/data/data_pretrain/afrimgsm/amh_Ethi_mgsm.jsonl


Generating train split: 100%|██████████| 8/8 [00:00<00:00, 1282.42 examples/s]
Generating test split: 100%|██████████| 250/250 [00:00<00:00, 30372.38 examples/s]


Saved to ../../../scratch/data/data_pretrain/afrimgsm/ewe_Latn_mgsm.jsonl


Generating train split: 100%|██████████| 8/8 [00:00<00:00, 823.42 examples/s]
Generating test split: 100%|██████████| 250/250 [00:00<00:00, 22172.38 examples/s]


Saved to ../../../scratch/data/data_pretrain/afrimgsm/gaz_Latn_mgsm.jsonl


Generating train split: 100%|██████████| 8/8 [00:00<00:00, 1121.77 examples/s]
Generating test split: 100%|██████████| 250/250 [00:00<00:00, 28855.39 examples/s]


Saved to ../../../scratch/data/data_pretrain/afrimgsm/hau_Latn_mgsm.jsonl


Generating train split: 100%|██████████| 8/8 [00:00<00:00, 1115.88 examples/s]
Generating test split: 100%|██████████| 250/250 [00:00<00:00, 32098.94 examples/s]


Saved to ../../../scratch/data/data_pretrain/afrimgsm/kin_Latn_mgsm.jsonl


Generating train split: 100%|██████████| 8/8 [00:00<00:00, 1903.80 examples/s]
Generating test split: 100%|██████████| 250/250 [00:00<00:00, 40888.13 examples/s]


Saved to ../../../scratch/data/data_pretrain/afrimgsm/lin_Latn_mgsm.jsonl


Generating train split: 100%|██████████| 8/8 [00:00<00:00, 1280.85 examples/s]
Generating test split: 100%|██████████| 250/250 [00:00<00:00, 36393.72 examples/s]


Saved to ../../../scratch/data/data_pretrain/afrimgsm/lug_Latn_mgsm.jsonl


Generating train split: 100%|██████████| 8/8 [00:00<00:00, 1326.94 examples/s]
Generating test split: 100%|██████████| 250/250 [00:00<00:00, 36653.24 examples/s]


Saved to ../../../scratch/data/data_pretrain/afrimgsm/sna_Latn_mgsm.jsonl


Generating train split: 100%|██████████| 8/8 [00:00<00:00, 1106.64 examples/s]
Generating test split: 100%|██████████| 250/250 [00:00<00:00, 39711.27 examples/s]


Saved to ../../../scratch/data/data_pretrain/afrimgsm/swc_Latn_mgsm.jsonl


Generating train split: 100%|██████████| 8/8 [00:00<00:00, 1208.78 examples/s]
Generating test split: 100%|██████████| 250/250 [00:00<00:00, 28920.65 examples/s]


Saved to ../../../scratch/data/data_pretrain/afrimgsm/twi_Latn_mgsm.jsonl


Generating train split: 100%|██████████| 8/8 [00:00<00:00, 1250.68 examples/s]
Generating test split: 100%|██████████| 250/250 [00:00<00:00, 31310.12 examples/s]


Saved to ../../../scratch/data/data_pretrain/afrimgsm/wol_Latn_mgsm.jsonl


Generating train split: 100%|██████████| 8/8 [00:00<00:00, 1225.69 examples/s]
Generating test split: 100%|██████████| 250/250 [00:00<00:00, 33375.01 examples/s]


Saved to ../../../scratch/data/data_pretrain/afrimgsm/xho_Latn_mgsm.jsonl


Generating train split: 100%|██████████| 8/8 [00:00<00:00, 1191.86 examples/s]
Generating test split: 100%|██████████| 250/250 [00:00<00:00, 23812.87 examples/s]


Saved to ../../../scratch/data/data_pretrain/afrimgsm/yor_Latn_mgsm.jsonl


Generating train split: 100%|██████████| 8/8 [00:00<00:00, 995.27 examples/s]
Generating test split: 100%|██████████| 250/250 [00:00<00:00, 35208.38 examples/s]

Saved to ../../../scratch/data/data_pretrain/afrimgsm/zul_Latn_mgsm.jsonl





# Wura

## Hyperparam

In [27]:
# Output directory
output_dir = "../../../scratch/data/data_pretrain/wura"
os.makedirs(output_dir, exist_ok=True)

In [28]:
# Loop through each language
for long_code, wura_code in fineweb2_to_wura.items():
    print(f"⬇Downloading WURA for {long_code} ({wura_code})...")

    try:
        ds = load_dataset("llama-lang-adapt/wura", name=wura_code, split="train")
    except Exception as e:
        print(f"Failed to load {wura_code}: {e}")
        continue

    output_path = os.path.join(output_dir, f"{long_code}_wura.jsonl")

    with open(output_path, "w", encoding="utf-8") as f:
        for example in ds:
            json.dump(example, f)
            f.write("\n")

    print(f"Saved to {output_path}")

⬇Downloading WURA for afr_Latn (af)...


KeyboardInterrupt: 

# Madlad 400

## Hyperparam

In [32]:
from datasets import load_dataset

# Folder
output_dir = "../../../scratch/data/data_pretrain/madlad400"
os.makedirs(output_dir, exist_ok=True)

# write every 1000 docs
BUFFER_SIZE = 1000

In [33]:
for fineweb_code, madlad_code in fineweb2_to_madlad.items():

    try:
        dataset = load_dataset("allenai/madlad-400", languages=[madlad_code], split="clean", streaming=True)
    except Exception as e:
        print(f"Failed to stream {madlad_code}: {e}")
        continue

    output_path = os.path.join(output_dir, f"{fineweb_code}_madlad.jsonl")
    buffer = []

    with open(output_path, "w", encoding="utf-8") as f:
        for i, example in enumerate(dataset, 1):
            buffer.append(json.dumps(example))

            # write buffer to file every BUFFER_SIZE items
            if len(buffer) >= BUFFER_SIZE:
                f.write("\n".join(buffer) + "\n")
                buffer = []

        # write remaining items
        if buffer:
            f.write("\n".join(buffer) + "\n")

    print(f"Done: {output_path} (streamed {i} docs)")

Done: ../../../scratch/data/data_pretrain/madlad400/lug_Latn_madlad.jsonl (streamed 13030 docs)


# Other data

## Malagsay Data

In [6]:
# Libraries

import os
import json
from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
# Output folder & path
output_dir = "../../../scratch/data/data_pretrain/extradata"
os.makedirs(output_dir, exist_ok=True)
output_path = os.path.join(output_dir, "tsn_Latn.jsonl")

In [9]:
# Load the extra dataset
ds = load_dataset("OxxoCodes/Marothodi", split="train")  

# 3) Choose write mode: 'a' to append if exists, else 'w' to write new
mode = "a" if os.path.isfile(output_path) else "w"

# 4) Stream and write (append) each example as JSONL
with open(output_path, mode, encoding="utf-8") as out_f:
    for example in ds:
        json.dump(example, out_f)
        out_f.write("\n")

print(f"✅ {'Appended to' if mode=='a' else 'Wrote'} {output_path}")


✅ Wrote ../../../scratch/data/data_pretrain/extradata/tsn_Latn.jsonl


# Combine all data

In [10]:
# Folders & suffixes to look for; note extradata uses plain "<lang>.jsonl"
datasets = [
    ("fineweb2",   "_fw2.jsonl"),
    ("afrimgsm",   "_mgsm.jsonl"),
    ("madlad400",  "_madlad.jsonl"),
    ("wura",       "_wura.jsonl"),
    ("extradata",  ".jsonl"),   
]

In [11]:
# Loop, collect paths, and concatenate
for lang in all_african_language_list:
    # gather any existing source files for this lang
    sources = []
    for folder, suffix in datasets:
        path = os.path.join(base_dir, folder, f"{lang}{suffix}")
        if os.path.isfile(path):
            sources.append(path)

    if not sources:
        print(f"⏭ no files found for {lang}, skipping.")
        continue

    out_path = os.path.join(output_dir, f"{lang}_data.jsonl")
    with open(out_path, "w", encoding="utf-8") as out_f:
        for src in sources:
            with open(src, "r", encoding="utf-8") as in_f:
                for line in in_f:
                    if line.strip():
                        out_f.write(line)
    print(f"Wrote {lang}_data.jsonl ({len(sources)} sources)")

print("Done concatenating all languages.")


Wrote aeb_Arab_data.jsonl (1 sources)
Wrote afr_Latn_data.jsonl (3 sources)
Wrote aka_Latn_data.jsonl (1 sources)
Wrote amh_Ethi_data.jsonl (4 sources)
⏭ no files found for ary_Arab, skipping.
Wrote arz_Arab_data.jsonl (2 sources)
Wrote bam_Latn_data.jsonl (2 sources)
Wrote bem_Latn_data.jsonl (1 sources)
Wrote cjk_Latn_data.jsonl (1 sources)
Wrote dik_Latn_data.jsonl (1 sources)
Wrote dyu_Latn_data.jsonl (2 sources)
Wrote ewe_Latn_data.jsonl (2 sources)
Wrote fon_Latn_data.jsonl (1 sources)
Wrote fuv_Latn_data.jsonl (1 sources)
Wrote gaz_Latn_data.jsonl (4 sources)
Wrote hau_Latn_data.jsonl (3 sources)
Wrote ibo_Latn_data.jsonl (3 sources)
Wrote kab_Latn_data.jsonl (1 sources)
Wrote kam_Latn_data.jsonl (1 sources)
Wrote kbp_Latn_data.jsonl (2 sources)
⏭ no files found for kea_Latn, skipping.
⏭ no files found for kik_Tatn, skipping.
Wrote kin_Latn_data.jsonl (4 sources)
Wrote kmb_Latn_data.jsonl (2 sources)
Wrote knc_Arab_data.jsonl (1 sources)
Wrote knc_Latn_data.jsonl (1 sources)
Wro

# Data Analysis

In [2]:
import os
import json
import tiktoken

In [3]:
base_dir       = "../../../scratch/data/data_pretrain"
concat_dir     = os.path.join(base_dir, "concat_data")

# raw‐data folders and file suffixes
datasets = {
    "fineweb2":   "_fw2.jsonl",
    "afrimgsm":   "_mgsm.jsonl",
    "madlad400":  "_madlad.jsonl",
    "wura":       "_wura.jsonl",
    "extradata":  ".jsonl",    # catches e.g. tsn_Latn.jsonl
}


In [4]:
# load a fast tokenizer
enc = tiktoken.get_encoding("cl100k_base")

# ——— Loop over languages ———
for lang in all_african_language_list:
    # 1) find which folders have this lang
    present = []
    for folder, suffix in datasets.items():
        path = os.path.join(base_dir, folder, f"{lang}{suffix}")
        if os.path.isfile(path):
            present.append(folder)
    if not present:
        print(f"{lang}:  ⏭ no raw data found")
        continue

    print(f"{lang}:  files in → {', '.join(present)}")

    # 2) count tokens in the merged file, if it exists
    merged_path = os.path.join(concat_dir, f"{lang}_data.jsonl")
    if not os.path.isfile(merged_path):
        print("    ✖ no merged file found")
        continue

    total_tokens = 0
    with open(merged_path, "r", encoding="utf-8") as f:
        for line in f:
            line=line.strip()
            if not line:
                continue
            data = json.loads(line)
            text = data.get("text")
            if not text:
                continue
            # count subword tokens
            total_tokens += len(enc.encode(text))

    print(f"merged tokens = {total_tokens:,}")

print("Done.")    


aeb_Arab:  files in → fineweb2
merged tokens = 252,259,810
afr_Latn:  files in → fineweb2, madlad400, wura
merged tokens = 4,020,892,876
aka_Latn:  files in → madlad400
merged tokens = 13,456,823
amh_Ethi:  files in → fineweb2, afrimgsm, madlad400, wura
merged tokens = 3,592,851,378
ary_Arab:  ⏭ no raw data found
arz_Arab:  files in → fineweb2, wura
merged tokens = 1,610,913,304
bam_Latn:  files in → fineweb2, madlad400
merged tokens = 9,779,956
bem_Latn:  files in → fineweb2
merged tokens = 1,595,796
cjk_Latn:  files in → fineweb2
merged tokens = 34,858
dik_Latn:  files in → madlad400
merged tokens = 1,638,176
dyu_Latn:  files in → fineweb2, madlad400
merged tokens = 3,423,550
ewe_Latn:  files in → afrimgsm, madlad400
merged tokens = 19,013,371
fon_Latn:  files in → madlad400
merged tokens = 6,032,326
fuv_Latn:  files in → madlad400
merged tokens = 99,385
gaz_Latn:  files in → fineweb2, afrimgsm, madlad400, wura
merged tokens = 100,091,781
hau_Latn:  files in → afrimgsm, madlad400, wu

# Tiny Stories

## Hyperparam

In [56]:
# Folder directory
output_dir = "../../../scratch/data/data_pretrain/tinystories"
os.makedirs(output_dir, exist_ok=True)
BUFFER_SIZE = 1000  # Write every 1000 examples

In [58]:
output_path = os.path.join(output_dir, "eng_Latn_story.jsonl")

try:
    dataset = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
except Exception as e:
    print(f"Failed to load TinyStories: {e}")
    exit(1)

buffer = []

with open(output_path, "w", encoding="utf-8") as f:
    for i, example in enumerate(dataset, 1):
        buffer.append(json.dumps(example))

        if len(buffer) >= BUFFER_SIZE:
            f.write("\n".join(buffer) + "\n")
            buffer = []

    # flush remaining examples
    if buffer:
        f.write("\n".join(buffer) + "\n")

print(f"Saved TinyStories to {output_path} ({i} examples)")


Saved TinyStories to ../../../scratch/data/data_pretrain/tinystories/eng_Latn_story.jsonl (2119719 examples)


## Load NLLB model and tokenizer

In [59]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

translation_pipeline = pipeline("translation", model=model, tokenizer=tokenizer, device=0)  # use device=0 for GPU


Device set to use cuda:0


## Translate into Different Langauges

In [64]:
# Directories and paths and hyperparams
input_path = "../../../scratch/data/data_pretrain/tinystories/eng_Latn_story.jsonl"
output_base_dir = "../../../scratch/data/data_pretrain/tinystories"
os.makedirs(output_base_dir, exist_ok=True)

BATCH_SIZE = 16 
MAX_INPUT_LENGTH = 1024

In [65]:
from tqdm import tqdm


def translate_batch(batch_texts, tgt_lang_code):
    tokenizer.src_lang = "eng_Latn"
    encoded = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=False, max_length=MAX_INPUT_LENGTH).to(model.device)
    forced_bos_id = tokenizer.convert_tokens_to_ids(f"__{tgt_lang_code}__")
    generated_tokens = model.generate(**encoded, forced_bos_token_id=forced_bos_id)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

# Read original English stories
with open(input_path, "r", encoding="utf-8") as f:
    stories = [json.loads(line.strip())["text"] for line in f if line.strip()]

# For each language
for lang_code in NLLB_target_langs:
    print(f"Translating to {lang_code}...")

    translated = []
    for i in tqdm(range(0, len(stories), BATCH_SIZE), desc=f"Translating {lang_code}"):
        batch = stories[i:i+BATCH_SIZE]
        try:
            translations = translate_batch(batch, lang_code)
            translated.extend(translations)
        except Exception as e:
            print(f"Error translating batch {i}-{i+BATCH_SIZE}: {e}")
            continue

    output_path = os.path.join(output_base_dir, f"{lang_code}_story.jsonl")
    with open(output_path, "w", encoding="utf-8") as out_f:
        for t in translated:
            json.dump({"text": t}, out_f)
            out_f.write("\n")

    print(f"Saved: {output_path}")


Translating to fra_Latn...


Translating fra_Latn:   0%|          | 62/132483 [02:58<106:02:39,  2.88s/it]


KeyboardInterrupt: 

In [68]:
import json
import os
import torch
import logging
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm.auto import tqdm
import gc
import argparse
from torch.utils.data import Dataset, DataLoader
import time
from functools import partial

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# Global configurations
MAX_INPUT_LENGTH = 512  # Reduced from 1024 to 512 to improve memory efficiency
SAVE_CHUNK_SIZE = 10000  # Save after processing this many examples
MEMORY_CHECKPOINT_INTERVAL = 100  # Check memory and clean up every N batches

# Custom dataset for efficient loading
class TextDataset(Dataset):
    def __init__(self, texts):
        self.texts = texts
        
    def __len__(self):
        return len(self.texts)
        
    def __getitem__(self, idx):
        return self.texts[idx]

def collate_fn(batch):
    return batch

def get_memory_usage():
    if torch.cuda.is_available():
        return {
            "allocated": torch.cuda.memory_allocated() / (1024 ** 3),  # GB
            "cached": torch.cuda.memory_reserved() / (1024 ** 3)  # GB
        }
    return {"allocated": 0, "cached": 0}

def clear_memory():
    """Aggressively clean up memory"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

def load_model(model_name, use_fp16=True, device="cuda" if torch.cuda.is_available() else "cpu"):
    """Load model with optimizations"""
    logger.info(f"Loading model {model_name} on {device}")
    
    # Load tokenizer with caching
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    
    # Load model with memory optimization
    model = AutoModelForSeq2SeqLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if use_fp16 else torch.float32,
        low_cpu_mem_usage=True,
    )
    
    # Apply performance optimizations
    if device == "cuda":
        model = model.to(device)
        if use_fp16:
            model = model.half()  # Ensure FP16
        
        # Enable further optimizations
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.allow_tf32 = True
    
    model.eval()  # Set to evaluation mode
    
    # Log memory usage after loading
    mem = get_memory_usage()
    logger.info(f"Model loaded. GPU memory allocated: {mem['allocated']:.2f} GB")
    
    return tokenizer, model

def chunked_file_reader(file_path, chunk_size=50000):
    """Read large files in chunks to avoid memory issues"""
    stories = []
    chunk_count = 0
    
    logger.info(f"Reading data in chunks from {file_path}")
    with open(file_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            if i % 100000 == 0 and i > 0:
                logger.info(f"Read {i} lines so far")
            
            if line.strip():
                try:
                    stories.append(json.loads(line.strip())["text"])
                except json.JSONDecodeError:
                    continue
                    
            if len(stories) >= chunk_size:
                chunk_count += 1
                logger.info(f"Yielding chunk {chunk_count} with {len(stories)} stories")
                yield stories
                stories = []
    
    if stories:  # Don't forget the last chunk
        logger.info(f"Yielding final chunk with {len(stories)} stories")
        yield stories

def translate_chunk(chunk, model, tokenizer, lang_code, batch_size=32, max_length=MAX_INPUT_LENGTH):
    """Translate a chunk of text"""
    logger.info(f"Processing chunk of {len(chunk)} items with batch size {batch_size}")
    
    translated = []
    dataset = TextDataset(chunk)
    dataloader = DataLoader(
        dataset, 
        batch_size=batch_size,
        collate_fn=collate_fn,
        num_workers=0,  # Avoid fork issues with tokenizers
    )
    
    forced_bos_id = tokenizer.convert_tokens_to_ids(f"__{lang_code}__")
    tokenizer.src_lang = "eng_Latn"
    
    # Process in batches
    for batch_idx, batch_texts in enumerate(tqdm(dataloader, desc=f"Translating batch")):
        if batch_idx % MEMORY_CHECKPOINT_INTERVAL == 0 and batch_idx > 0:
            mem = get_memory_usage()
            logger.info(f"Memory check: {mem['allocated']:.2f} GB allocated, {mem['cached']:.2f} GB cached")
            if mem['allocated'] > 10:  # If using more than 10GB
                logger.info("High memory usage detected, cleaning up...")
                clear_memory()
        
        try:
            # Tokenize the batch
            encoded = tokenizer(
                batch_texts, 
                return_tensors="pt", 
                padding=True, 
                truncation=True,  # Enable truncation
                max_length=max_length,
            )
            
            # Move to device
            encoded = {k: v.to(model.device) for k, v in encoded.items()}
            
            # Generate translations
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
                generated_tokens = model.generate(
                    **encoded,
                    forced_bos_token_id=forced_bos_id,
                    max_length=int(max_length * 1.2),
                    num_beams=2,  # Faster beam search
                    early_stopping=True,
                    length_penalty=0.6,  # Slightly penalize length
                )
                
            # Decode translations
            batch_translations = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            translated.extend(batch_translations)
            
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                logger.error(f"GPU OOM error. Clearing cache and retrying with smaller batch")
                clear_memory()
                
                # Try again with smaller batch size
                if len(batch_texts) > 1:
                    half_point = len(batch_texts) // 2
                    first_half = batch_texts[:half_point]
                    second_half = batch_texts[half_point:]
                    
                    # Process each half separately
                    for mini_batch in [first_half, second_half]:
                        mini_encoded = tokenizer(
                            mini_batch, 
                            return_tensors="pt", 
                            padding=True, 
                            truncation=True,
                            max_length=max_length,
                        )
                        mini_encoded = {k: v.to(model.device) for k, v in mini_encoded.items()}
                        
                        with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
                            mini_generated = model.generate(
                                **mini_encoded,
                                forced_bos_token_id=forced_bos_id,
                                max_length=int(max_length * 1.2),
                                num_beams=1,  # Use greedy search for recovery
                                early_stopping=True,
                            )
                            
                        mini_translations = tokenizer.batch_decode(mini_generated, skip_special_tokens=True)
                        translated.extend(mini_translations)
                else:
                    # If single sample causes OOM, truncate further
                    logger.warning(f"Single sample causing OOM. Truncating further to {max_length//2}")
                    mini_encoded = tokenizer(
                        batch_texts, 
                        return_tensors="pt", 
                        padding=True, 
                        truncation=True,
                        max_length=max_length//2,
                    )
                    mini_encoded = {k: v.to(model.device) for k, v in mini_encoded.items()}
                    
                    with torch.no_grad(), torch.cuda.amp.autocast(enabled=True):
                        mini_generated = model.generate(
                            **mini_encoded,
                            forced_bos_token_id=forced_bos_id,
                            max_length=max_length//2,
                            num_beams=1,  # Use greedy search for recovery
                        )
                        
                    mini_translations = tokenizer.batch_decode(mini_generated, skip_special_tokens=True)
                    translated.extend(mini_translations)
            else:
                logger.error(f"Error translating batch: {e}")
                # Add empty strings for this batch to maintain alignment
                translated.extend(["" for _ in range(len(batch_texts))])
                
    return translated

def main():
    parser = argparse.ArgumentParser(description="Optimized NLLB Translation Pipeline")
    parser.add_argument("--model", default="facebook/nllb-200-distilled-600M", help="Model name")
    parser.add_argument("--input", default="../../../scratch/data/data_pretrain/tinystories/eng_Latn_story.jsonl", help="Input file path")
    parser.add_argument("--output_dir", default="../../../scratch/data/data_pretrain/tinystories", help="Output directory")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size for translation")
    parser.add_argument("--chunk_size", type=int, default=50000, help="Number of stories to process in each chunk")
    parser.add_argument("--max_length", type=int, default=512, help="Max input length")
    parser.add_argument("--fp16", action="store_true", default=True, help="Use FP16 precision")
    parser.add_argument("--languages", nargs="+", default=["fra_Latn", "spa_Latn", "deu_Latn"], help="Target languages")
    args = parser.parse_args()
    
    # Create output directory
    os.makedirs(args.output_dir, exist_ok=True)
    
    # Load model and tokenizer
    tokenizer, model = load_model(args.model, use_fp16=args.fp16)
    
    # Process each language
    for lang_code in args.languages:
        output_path = os.path.join(args.output_dir, f"{lang_code}_story.jsonl")
        
        # Check if we should resume
        if os.path.exists(output_path):
            # Count existing examples
            with open(output_path, "r", encoding="utf-8") as f:
                existing_count = sum(1 for _ in f)
                
            if existing_count > 0:
                logger.info(f"Found existing file with {existing_count} translations. Skipping {lang_code}.")
                continue
        
        logger.info(f"Starting translation to {lang_code}")
        start_time = time.time()
        
        # Process in chunks to manage memory
        chunk_id = 0
        total_translated = 0
        
        with open(output_path, "w", encoding="utf-8") as out_f:
            # Process data in chunks
            for chunk in chunked_file_reader(args.input, chunk_size=args.chunk_size):
                chunk_id += 1
                chunk_time = time.time()
                
                logger.info(f"Processing chunk {chunk_id} ({len(chunk)} examples)")
                translations = translate_chunk(
                    chunk, 
                    model, 
                    tokenizer, 
                    lang_code, 
                    batch_size=args.batch_size,
                    max_length=args.max_length
                )
                
                # Write results
                for t in translations:
                    json.dump({"text": t}, out_f)
                    out_f.write("\n")
                
                # Update counters and log progress
                total_translated += len(translations)
                chunk_elapsed = time.time() - chunk_time
                total_elapsed = time.time() - start_time
                
                # Calculate speeds and ETA
                examples_per_sec = len(translations) / chunk_elapsed
                total_examples_per_sec = total_translated / total_elapsed
                
                logger.info(f"Chunk {chunk_id} completed: {len(translations)} examples in {chunk_elapsed:.2f}s")
                logger.info(f"Speed: {examples_per_sec:.2f} examples/sec for chunk, {total_examples_per_sec:.2f} examples/sec overall")
                
                # Clean up memory after each chunk
                clear_memory()
        
        total_time = time.time() - start_time
        logger.info(f"Completed translation to {lang_code}: {total_translated} examples in {total_time:.2f}s")
        logger.info(f"Average speed: {total_translated/total_time:.2f} examples/sec")
    
    logger.info("All translations completed!")

if __name__ == "__main__":
    main()

usage: ipykernel_launcher.py [-h] [--model MODEL] [--input INPUT]
                             [--output_dir OUTPUT_DIR]
                             [--batch_size BATCH_SIZE]
                             [--chunk_size CHUNK_SIZE]
                             [--max_length MAX_LENGTH] [--fp16]
                             [--languages LANGUAGES [LANGUAGES ...]]
ipykernel_launcher.py: error: argument --fp16: ignored explicit argument '/home/mila/x/xut/.local/share/jupyter/runtime/kernel-v3a4ebf84ddf5d92b39c10750abd0bc1162f22a7ca.json'


SystemExit: 2