In [None]:
!pip install -q transformers datasets

import os, time, glob
# Switch off Weights & Biases reporting and Hugging Face Hub telemetry.
# Set before any transformers import further down in the script.
os.environ.update(
    {
        "WANDB_DISABLED": "true",
        "WANDB_MODE": "offline",
        "HF_HUB_DISABLE_TELEMETRY": "1",
    }
)

# ---------------- CONFIG ----------------
# All tunables for the script live here; everything below reads these names.

# Plain-text corpus; documents are separated by a blank line ("\n\n").
MERGED_FILE = "/content/merged_corpus.txt"
# Trainer output directory: checkpoints (checkpoint-*) and the final
# model + tokenizer are written here, and resume detection scans it.
OUT_DIR = "/content/roberta_pretrained_continued"
TMP_SHARD_DIR = "/content/rt_shards"     # temporary local shards (directory is created; not otherwise used in the visible script)
COMBINED_DS_DIR = "/content/rt_combined" # where the block dataset is saved with save_to_disk (best-effort)
# Path reported at the end of the run as the training-log location.
LOG_FILE = "/content/rt_training_log.txt"

# tokenization / block params
# Length, in tokens, of each training block and of its attention mask.
MAX_LEN = 512
TOKENIZE_BATCH = 2000    # docs passed to the tokenizer per call (adjust based on memory)
DOC_SEPARATOR = True     # append the EOS (fallback: SEP) token id after every doc before concatenation

# training params (tuned for single T4)
# These are passed straight to TrainingArguments.
PER_DEVICE_BATCH = 4     # per_device_train_batch_size
GRAD_ACC = 8             # gradient_accumulation_steps (4 * 8 = 32 blocks per optimizer step per device)
EPOCHS = 3               # num_train_epochs
LR = 2e-5                # learning_rate
FP16 = True              # fp16 mixed precision
SAVE_STEPS = 2000        # save_steps: checkpoint interval
LOGGING_STEPS = 200      # logging_steps: log interval
NUM_PROC = 2             # dataset.map workers where used (not referenced in the visible script)

# Make sure every working/output directory exists before anything writes to it.
for _needed_dir in (TMP_SHARD_DIR, OUT_DIR, COMBINED_DS_DIR):
    os.makedirs(_needed_dir, exist_ok=True)

# ---------------- helper: find latest checkpoint ----------------
def find_latest_checkpoint(path):
    """Return the most advanced ``checkpoint-*`` directory under *path*.

    Checkpoints are ranked by the integer step in their name
    (``checkpoint-<step>``), not by modification time: mtimes are easily
    disturbed (copying a run to/from Drive, re-saving an older checkpoint),
    whereas the step number always reflects training progress. Modification
    time is used only as a tie-breaker and for names without a numeric
    suffix, which rank below every numbered checkpoint.

    Args:
        path: Directory to search (typically the Trainer ``output_dir``).

    Returns:
        The path of the latest checkpoint directory, or ``None`` when
        *path* contains no ``checkpoint-*`` directory.
    """

    def _rank(ckpt_dir):
        # "checkpoint-1500" -> 1500; anything non-numeric sorts first (-1).
        suffix = os.path.basename(ckpt_dir).rpartition("-")[2]
        step = int(suffix) if suffix.isdecimal() else -1
        return step, os.path.getmtime(ckpt_dir)

    # Only directories can be resumed from; skip stray files that match.
    candidates = [
        match
        for match in glob.glob(os.path.join(path, "checkpoint-*"))
        if os.path.isdir(match)
    ]
    return max(candidates, key=_rank, default=None)

# Look for an earlier run in OUT_DIR; None means there is nothing to resume.
resume_ckpt = find_latest_checkpoint(OUT_DIR)
if not resume_ckpt:
    print("No checkpoint found in", OUT_DIR, "- will continue from roberta-base")
else:
    print("Auto-detected checkpoint to resume:", resume_ckpt)

# ---------------- sanity check ----------------
# Fail fast with an explicit exception rather than `assert`, which is
# stripped when Python runs with -O and would let the script continue.
if not os.path.exists(MERGED_FILE):
    raise FileNotFoundError(f"merged corpus not found at {MERGED_FILE}")

# ---------------- Step 1: load docs ----------------
# The corpus is read in one go and split on blank lines; each chunk is
# whitespace-stripped and chunks that end up empty are discarded.
print("Loading merged corpus...")
with open(MERGED_FILE, "r", encoding="utf-8") as corpus_file:
    docs = list(filter(None, map(str.strip, corpus_file.read().split("\n\n"))))
print("Total documents:", len(docs))

# ---------------- Step 2: tokenizer & tokenization (batched) ----------------
from transformers import RobertaTokenizerFast
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")

# Documents are tokenized whole and only cut into MAX_LEN blocks afterwards,
# so many of them exceed the tokenizer's configured maximum. Raise the limit
# temporarily to silence the ">512 tokens" warning, and restore it afterwards:
# this same tokenizer object is saved next to the model at the end of the
# script, and a saved limit of 10**12 would make `truncation=True` a no-op
# for anyone loading it later.
_default_max_len = tokenizer.model_max_length
tokenizer.model_max_length = 10**12

print("Tokenizing documents in batches...")
all_ids = []   # list of lists (per-doc token ids)
try:
    for i in range(0, len(docs), TOKENIZE_BATCH):
        batch = docs[i:i+TOKENIZE_BATCH]
        # No special tokens here: document separators are added when the
        # token streams are concatenated in the next step.
        enc = tokenizer(batch, add_special_tokens=False)
        ids = enc["input_ids"]
        all_ids.extend(ids)
        print(f"Tokenized docs {i}..{i+len(batch)-1} -> total doc-tokens lists: {len(all_ids)}")
finally:
    # Restore even if tokenization fails part-way through.
    tokenizer.model_max_length = _default_max_len

# optionally free docs list to save RAM
# del docs

# ---------------- Step 3: concatenate token ids and create contiguous 512-token blocks ----------------
print("Concatenating token ids and creating 512-token blocks...")

# Separator id: EOS when the tokenizer defines one, otherwise SEP.
EOS_ID = tokenizer.eos_token_id
if EOS_ID is None:
    EOS_ID = tokenizer.sep_token_id

# Tokens appended after every document; empty when separators are disabled
# or the tokenizer has neither EOS nor SEP.
doc_tail = [EOS_ID] if (DOC_SEPARATOR and EOS_ID is not None) else []

# One long token stream: doc, separator, doc, separator, ...
flat = []
for seq in all_ids:
    flat.extend(seq)
    flat.extend(doc_tail)

# Only complete blocks are kept; the trailing partial block is dropped.
total_tokens = len(flat)
num_blocks = total_tokens // MAX_LEN
total_used = num_blocks * MAX_LEN
print("Total tokens (concatenated):", total_tokens)
print("Number of full 512-token blocks:", num_blocks)
if not num_blocks:
    raise RuntimeError("Not enough tokens to form a single 512-token block. Check your corpus.")

# Slice the stream into consecutive, non-overlapping MAX_LEN-sized lists.
blocks = [flat[k * MAX_LEN:(k + 1) * MAX_LEN] for k in range(num_blocks)]
print("Created blocks length:", len(blocks))

# The intermediate token lists are no longer needed; release them.
del flat, all_ids

# ---------------- Step 4: convert blocks into HF Dataset ----------------
from datasets import Dataset
import torch

print("Building Hugging Face Dataset from blocks (this may take a moment)...")
ds = Dataset.from_dict({"input_ids": blocks})

# Blocks are full-length, so every position is a real token: mask is all ones.
attn = [[1] * MAX_LEN for _ in blocks]
ds = ds.add_column("attention_mask", attn)
print("Dataset created. Examples:", len(ds))

# Keep a copy on disk for reuse; a failure here is reported but not fatal.
try:
    ds.save_to_disk(COMBINED_DS_DIR)
except Exception as save_err:
    print("Warning: could not save combined dataset:", save_err)
else:
    print("Saved combined dataset to:", COMBINED_DS_DIR)

# Return torch tensors for the two model inputs.
ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

# ---------------- Step 5: prepare model & collator (continued pretraining) ----------------
from transformers import DataCollatorForLanguageModeling, RobertaForMaskedLM

# Start from the published roberta-base weights with a masked-LM head.
print("Loading roberta-base model for continued pretraining...")
model = RobertaForMaskedLM.from_pretrained("roberta-base")

# MLM collator configured with a 0.15 masking probability.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

# ---------------- Step 6: TrainingArguments & Trainer ----------------
from transformers import Trainer, TrainingArguments

# Collect the run configuration in one place, then hand it to TrainingArguments.
_training_kwargs = {
    # where checkpoints go; existing contents are not overwritten
    "output_dir": OUT_DIR,
    "overwrite_output_dir": False,
    # schedule / optimisation
    "num_train_epochs": EPOCHS,
    "per_device_train_batch_size": PER_DEVICE_BATCH,
    "gradient_accumulation_steps": GRAD_ACC,
    "learning_rate": LR,
    "fp16": FP16,
    # logging / checkpointing cadence; keep at most 3 checkpoints
    "logging_steps": LOGGING_STEPS,
    "save_steps": SAVE_STEPS,
    "save_total_limit": 3,
    # dataset columns are passed through untouched
    "remove_unused_columns": False,
    # disable tracking/reporting
    "report_to": [],
}
training_args = TrainingArguments(**_training_kwargs)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds,
    data_collator=data_collator,
)

# ---------------- Step 7: Train (resume if checkpoint exists) ----------------
import json  # local import: only used to dump the training log below

print("Starting training. Resume checkpoint:", resume_ckpt)
start_time = time.time()
try:
    # resume_ckpt is None when no checkpoint was found; None is also the
    # default of resume_from_checkpoint, so one call covers both cases.
    trainer.train(resume_from_checkpoint=resume_ckpt)
except KeyboardInterrupt:
    # Manual stop: fall through to the save below without re-raising.
    print("Interrupted by user.")
except Exception as e:
    print("Training crashed:", repr(e))
    raise
finally:
    # Runs on success, interruption and crash alike, so whatever state the
    # model reached is kept.
    # save final model and tokenizer
    trainer.save_model(OUT_DIR)
    tokenizer.save_pretrained(OUT_DIR)

    # Write the Trainer's accumulated log records (one JSON object per line)
    # so that LOG_FILE really contains the logs the message below refers to.
    # Best-effort: a logging failure must not mask a training exception.
    try:
        with open(LOG_FILE, "w", encoding="utf-8") as log_f:
            for record in trainer.state.log_history:
                log_f.write(json.dumps(record) + "\n")
        logs_written = True
    except (OSError, TypeError, ValueError) as log_err:
        logs_written = False
        print("Warning: could not write training log:", repr(log_err))

    elapsed = time.time() - start_time
    print(f"Training finished (or stopped). Time elapsed: {elapsed/60:.2f} minutes")
    print("Model + tokenizer saved to:", OUT_DIR)
    if logs_written:
        print("Logs saved to:", LOG_FILE)