In [1]:
import glob
import itertools
import math
import os
import shutil

import torch
from accelerate import Accelerator
from bitsandbytes.optim import AdamW8bit
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    get_linear_schedule_with_warmup,
)

In [2]:
# ----- Model -----
MODEL_NAME = "mistralai/Mistral-Nemo-Base-2407"

# ----- Data -----
DATA_DIR = "nemo-rawtext-data"   # searched recursively for *.txt files
TEST_SPLIT_SIZE = 0.1            # fraction of dataset rows held out for validation
BLOCK_SIZE = 1024                # tokens per packed training sequence

# ----- Optimisation -----
OUTPUT_DIR = "mistral-nemo-continued-pretrain-accelerate"
NUM_TRAIN_EPOCHS = 1
LEARNING_RATE = 0.0002
BATCH_SIZE = 4                   # micro-batch size on each device
GRAD_ACCUM_STEPS = 4             # micro-batches accumulated per optimizer step

# ----- Evaluation & checkpoint retention -----
SAVE_TOTAL_LIMIT = 3             # how many lowest-eval-loss checkpoints stay on disk
EVAL_STEPS = 50                  # evaluate + checkpoint every N optimizer steps

In [3]:
# One Accelerator instance coordinates distributed wrapping (`accelerator.prepare`)
# and gradient accumulation (`accelerator.accumulate`) for the rest of the notebook.
accelerator = Accelerator(
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
)

In [None]:
# Collect every .txt file under DATA_DIR. glob's order depends on the filesystem,
# so sort it: the seeded train/test split below is only reproducible across
# machines if the input order is fixed too.
text_files = sorted(glob.glob(f'{DATA_DIR}/**/*.txt', recursive=True))
if not text_files:
    raise ValueError(f"No .txt files found in {DATA_DIR}. Please check the path.")

# Let the main process build the `datasets` cache (and split-index cache) first;
# the other ranks then reuse it instead of racing to write the same files.
with accelerator.main_process_first():
    # With its default settings the `text` loader yields one row per line.
    raw_datasets = load_dataset('text', data_files=text_files, split='train')
    # NOTE(review): the split is therefore line-level, so lines from the same file
    # can land in both train and validation. Split by file if that leakage matters.
    split_datasets = raw_datasets.train_test_split(test_size=TEST_SPLIT_SIZE, shuffle=True, seed=42)
train_dataset = split_datasets['train']
eval_dataset = split_datasets['test']

accelerator.print(f"Train dataset size: {len(train_dataset)}")
accelerator.print(f"Validation dataset size: {len(eval_dataset)}")

# NOTE(review): trust_remote_code=True executes Python shipped with the hub repo.
# Mistral-Nemo appears to be supported natively by transformers; confirm and drop it if so.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
if tokenizer.pad_token is None:
    # The causal-LM collator masks pad tokens out of the labels. With pad == eos,
    # any EOS tokens present in the data are masked from the loss as well.
    tokenizer.pad_token = tokenizer.eos_token

# Tokenization and block-packing helpers used by the `.map` calls below.
def tokenize_function(examples):
    """Tokenize the batch's ``"text"`` column with the notebook-level ``tokenizer``.

    Returns whatever the tokenizer returns for a list of strings (its encoding columns).
    """
    texts = examples["text"]
    return tokenizer(texts)

def group_texts(examples, block_size=None):
    """Concatenate a batch of tokenized examples and re-split it into fixed-size blocks.

    Every column in ``examples`` (e.g. ``input_ids``, ``attention_mask``) is flattened
    across the batch and cut into chunks of ``block_size`` tokens. When the batch holds
    at least one full block, the trailing tokens that do not fill a whole block are
    dropped. A batch shorter than one block is kept as a single short chunk.

    Args:
        examples: Mapping of column name -> list of token lists, i.e. a batched
            ``datasets.map`` input. Must contain ``"input_ids"``.
        block_size: Tokens per block. Defaults to the notebook's ``BLOCK_SIZE``.

    Returns:
        Dict with the same columns, each a list of blocks, plus a ``"labels"`` column
        that is an independent copy of ``"input_ids"``.
    """
    if block_size is None:
        block_size = BLOCK_SIZE
    # chain.from_iterable flattens in linear time. sum(lists, []) re-copies the
    # growing accumulator on every addition and is quadratic in the batch size.
    concatenated = {
        column: list(itertools.chain.from_iterable(sequences))
        for column, sequences in examples.items()
    }
    total_length = len(concatenated[next(iter(examples))])
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    result = {
        column: [tokens[i : i + block_size] for i in range(0, total_length, block_size)]
        for column, tokens in concatenated.items()
    }
    # Copy each block so labels never alias the input_ids lists.
    result["labels"] = [list(block) for block in result["input_ids"]]
    return result

with accelerator.main_process_first():
    # The main process tokenizes and packs first so its results land in the
    # `datasets` cache. The other ranks then read that cache instead of racing to build it.
    tokenized_train, tokenized_eval = (
        split.map(tokenize_function, batched=True, remove_columns=list(split.features))
        for split in (train_dataset, eval_dataset)
    )
    lm_train_dataset, lm_eval_dataset = (
        tokenized.map(group_texts, batched=True)
        for tokenized in (tokenized_train, tokenized_eval)
    )

In [None]:

# 4-bit NF4 weight quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=torch.bfloat16
)

# bitsandbytes-quantized weights cannot be moved after loading, and Accelerate will
# not prepare a k-bit model that lives on a different device from the one it trains
# on. So Accelerate does NOT place this model. Load each process's copy directly onto
# that process's local GPU. (Avoid device_map="auto", which would shard across GPUs.)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map={"": accelerator.local_process_index},
    trust_remote_code=True,
)
# The KV cache only helps generation; disable it for training.
model.config.use_cache = False

# Prepare the quantized base for k-bit training. With PEFT defaults this also
# enables gradient checkpointing. Then attach LoRA adapters to the attention projections.
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    bias="none",
    task_type="CAUSAL_LM"
)
peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()


In [None]:

# Causal-LM collator (mlm=False): labels mirror input_ids, with pad positions set to -100.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
train_dataloader = DataLoader(lm_train_dataset, shuffle=True, collate_fn=data_collator, batch_size=BATCH_SIZE)
eval_dataloader = DataLoader(lm_eval_dataset, collate_fn=data_collator, batch_size=BATCH_SIZE)

# Only the LoRA adapter weights require grad, so give the optimizer just those.
trainable_params = [p for p in peft_model.parameters() if p.requires_grad]
optimizer = AdamW8bit(trainable_params, lr=LEARNING_RATE)

# Accelerate also syncs gradients at the end of the dataloader, so a trailing
# partial accumulation window is a real optimizer step. Round up rather than down.
# Flooring undercounted steps and produced 0 whenever the loader had fewer than
# GRAD_ACCUM_STEPS batches, which pinned the linear schedule's LR at 0.
# NOTE(review): this uses the pre-`prepare` (unsharded) loader length. Under
# multi-process runs each rank sees fewer batches, so the progress-bar total and
# the loop's early-exit bound overestimate the real step count; confirm if running on >1 GPU.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / GRAD_ACCUM_STEPS)
total_training_steps = NUM_TRAIN_EPOCHS * num_update_steps_per_epoch

lr_scheduler = get_linear_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=total_training_steps,
)

# Wrap everything for the current distributed setup.
peft_model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
    peft_model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)


In [None]:
progress_bar = tqdm(range(total_training_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
best_checkpoints = []  # (eval_loss: float, checkpoint_dir: str), kept sorted best-first


def _evaluate():
    """Run the whole eval set and return ``(mean_loss, perplexity)`` as Python floats.

    Per-batch losses are gathered from every process. The model is put in eval mode
    for the pass and switched back to train mode before returning.
    """
    peft_model.eval()
    losses = []
    for eval_batch in eval_dataloader:
        with torch.no_grad():
            outputs = peft_model(**eval_batch)
        # One copy of the batch loss per sample, so gather_for_metrics can drop the
        # duplicate samples added to even out the final batch.
        losses.append(accelerator.gather_for_metrics(outputs.loss.repeat(BATCH_SIZE)))
    peft_model.train()

    eval_loss = torch.cat(losses).mean().item()
    try:
        perplexity = math.exp(eval_loss)
    except OverflowError:
        perplexity = float("inf")
    return eval_loss, perplexity


def _save_and_prune_checkpoint(eval_loss, step):
    """Save full training state for ``step`` and keep only the SAVE_TOTAL_LIMIT lowest-loss checkpoints."""
    checkpoint_dir = os.path.join(OUTPUT_DIR, f"checkpoint-{step}")
    # save_state must be called on every process. Each rank writes its own RNG state,
    # and Accelerate restricts the shared model/optimizer files to the main process
    # internally. Calling it only on the main process can skip the other ranks'
    # state, or hang under DeepSpeed/FSDP.
    # NOTE(review): for a PeftModel this appears to serialize the full state_dict,
    # frozen 4-bit base weights included, so each checkpoint is roughly model-sized.
    accelerator.save_state(checkpoint_dir)
    accelerator.wait_for_everyone()

    best_checkpoints.append((eval_loss, checkpoint_dir))
    best_checkpoints.sort(key=lambda entry: entry[0])
    if len(best_checkpoints) > SAVE_TOTAL_LIMIT:
        worst_checkpoint_loss, worst_checkpoint_dir = best_checkpoints.pop()
        accelerator.print(f"Removing old checkpoint with loss {worst_checkpoint_loss:.4f}: {worst_checkpoint_dir}")
        if accelerator.is_main_process and os.path.exists(worst_checkpoint_dir):
            shutil.rmtree(worst_checkpoint_dir)


for _epoch in range(NUM_TRAIN_EPOCHS):
    peft_model.train()
    for batch in train_dataloader:
        with accelerator.accumulate(peft_model):
            outputs = peft_model(**batch)
            accelerator.backward(outputs.loss)
            # Inside `accumulate`, the prepared optimizer/scheduler only act on sync steps.
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # Count, evaluate and checkpoint only on real optimizer steps. On the other
        # GRAD_ACCUM_STEPS - 1 micro-batches completed_steps does not change, so checking
        # the eval condition there re-ran evaluation and re-saved the same checkpoint.
        if not accelerator.sync_gradients:
            continue
        progress_bar.update(1)
        completed_steps += 1

        if completed_steps % EVAL_STEPS == 0:
            eval_loss, perplexity = _evaluate()
            accelerator.print(f"Step {completed_steps}: Eval Loss: {eval_loss:.4f}, Perplexity: {perplexity:.4f}")
            _save_and_prune_checkpoint(eval_loss, completed_steps)

        if completed_steps >= total_training_steps:
            break
    if completed_steps >= total_training_steps:
        break


In [None]:

# Barrier: make sure every rank has finished its last step before rank 0 writes.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    # These are the adapter weights from the LAST training step. The lowest-eval-loss
    # states are the checkpoint-* directories tracked in best_checkpoints.
    final_adapter_dir = os.path.join(OUTPUT_DIR, "final_adapter")
    os.makedirs(final_adapter_dir, exist_ok=True)

    # save_pretrained must be called on the bare PeftModel, not the Accelerate wrapper.
    unwrapped_model = accelerator.unwrap_model(peft_model)
    unwrapped_model.save_pretrained(final_adapter_dir)

    accelerator.print(f"Final PEFT adapter saved to {final_adapter_dir}")
    accelerator.print("Training complete.")
