# In [None]:
# Differentially private LoRA fine-tuning of a causal language model using Opacus.

import logging
import os
import pickle

import numpy as np
import torch
import torch.optim as optim
from datasets import load_dataset
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


# -----------------------
# Logging
# -----------------------
# Configure the root logger once (timestamp | level | message) and create the
# module-level logger used by the rest of this script.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)


# -----------------------
# User-editable settings (no hard-coded paths)
# -----------------------
# Dataset to fine-tune on: a hub identifier or a local path accepted by
# `datasets.load_dataset`.
DATASET_NAME = "your_dataset_name_or_path"          # e.g., "OpenAssistant/oasst_top1_2023-08-25" or local path
DATASET_SPLIT = "train"
# Number of leading rows kept from the split; DELTA is derived from this value.
SUBSET_SIZE = 1304
# Name of the dataset column that holds the raw text. tokenize_function reads
# this setting, so it must be defined before the dataset is tokenized.
TEXT_FIELD = "text"                                 # e.g., "text", "content"

BASE_MODEL_ID = "your-model-id"                     # e.g., "meta-llama/Llama-2-7b-chat-hf", "meta-llama/Meta-Llama-3.1-8B-Instruct"
# Every example is truncated/padded to exactly this many tokens.
MAX_LENGTH = 1024                                   # e.g., 1536, 2048

DEVICE_INDEX = 0                                    # choose GPU index if available

# Where the final state_dict is written (OUTPUT_DIR is created if missing).
OUTPUT_DIR = "./saved_model"
OUTPUT_NAME = "dp_lora_model.pth"


# -----------------------
# Training / DP settings: User-editable
# -----------------------
# BATCH_SIZE is the batch size handed to the DataLoader; MAX_PHYSICAL_BATCH_SIZE
# is the cap passed to BatchMemoryManager for a single forward/backward pass.
BATCH_SIZE = 32
MAX_PHYSICAL_BATCH_SIZE = 2

# Optimizer schedule (AdamW).
EPOCHS = 4          # e.g., 10
LR = 1e-4           # e.g., 5e-4, 2e-5
ADAM_EPS = 1e-8

# Differential-privacy budget: target (EPSILON, DELTA) and the per-sample
# gradient clipping norm given to the privacy engine.
EPSILON = 1.0       # e.g., 4.0, 8.0
DELTA = 1.0 / SUBSET_SIZE
MAX_GRAD_NORM = 1.0


# -----------------------
# Helpers
# -----------------------
def print_trainable_parameters(model):
    """Log the trainable and total parameter counts of ``model``.

    Counts elements of every tensor yielded by ``model.parameters()`` and
    reports how many belong to parameters with ``requires_grad`` set, along
    with that share as a percentage (0.0 when the model has no parameters).
    """
    n_trainable = 0
    n_total = 0
    for param in model.parameters():
        size = param.numel()
        n_total += size
        if param.requires_grad:
            n_trainable += size

    # Guard against division by zero for a parameter-free module.
    percent = (100 * n_trainable / n_total) if n_total else 0.0
    logger.info(
        "Trainable params: %d | All params: %d | Trainable%%: %.2f%%",
        n_trainable, n_total, percent
    )


def custom_collate_fn(batch):
    """Collate a list of samples into a single ``input_ids`` batch.

    Each sample must be a mapping with an ``"input_ids"`` tensor; all tensors
    must share the same shape so they can be stacked along a new leading
    dimension. Any other keys in the samples are dropped, because only
    ``input_ids`` is fed to the model.
    """
    stacked = torch.stack(tuple(sample["input_ids"] for sample in batch), dim=0)
    return {"input_ids": stacked}


# -----------------------
# Device
# -----------------------
# Use the configured GPU when CUDA is available, otherwise fall back to CPU.
# `device` is where input batches are moved; `device_map` is forwarded to
# `from_pretrained` (None on CPU-only machines).
if torch.cuda.is_available():
    device = torch.device(f"cuda:{DEVICE_INDEX}")
    device_map = {"": DEVICE_INDEX}
else:
    device = torch.device("cpu")
    device_map = None


# -----------------------
# Load dataset
# -----------------------
dataset = load_dataset(DATASET_NAME)

# Keep only the first SUBSET_SIZE rows of the chosen split. DELTA is computed
# from SUBSET_SIZE, so the training set size and the privacy parameter stay tied.
train_dataset = dataset[DATASET_SPLIT]
train_dataset = train_dataset.select(range(SUBSET_SIZE))

# -----------------------
# Tokenizer + tokenize dataset
# -----------------------
# 8-bit quantization config, consumed later when the base model is loaded.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL_ID,
    add_bos_token=True,
    add_eos_token=True,
    legacy=True,
)
# Reuse EOS as the padding token so that padding="max_length" has a token to use.
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    """Tokenize the text column of a batch of examples.

    Reads ``examples[TEXT_FIELD]`` (``TEXT_FIELD`` is a module-level setting
    naming the text column) and encodes it with the module-level ``tokenizer``,
    truncating and padding every sequence to exactly ``MAX_LENGTH`` tokens.
    Returns whatever the tokenizer call returns.
    """
    texts = examples[TEXT_FIELD]
    encode_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": MAX_LENGTH,
    }
    return tokenizer(texts, **encode_options)

# Tokenize in batches and drop every original column, leaving only the
# tokenizer outputs; then expose just `input_ids` as torch tensors.
tokenized = train_dataset.map(
    tokenize_function,
    remove_columns=train_dataset.column_names,
    batched=True,
)
tokenized.set_format(columns=["input_ids"], type="torch")
train_dataset = tokenized


# -----------------------
# Model + LoRA
# -----------------------
# Load the base model with the 8-bit quantization config on the selected device.
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=bnb_config,
    low_cpu_mem_usage=True,
    device_map=device_map,
)

model.config.pad_token_id = tokenizer.pad_token_id
# The KV cache is disabled for training with gradient checkpointing.
model.config.use_cache = False
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

# prepare_model_for_kbit_training enables gradient checkpointing again by
# default, using its own `gradient_checkpointing_kwargs`. Pass the same
# non-reentrant setting explicitly so the call above is not overridden.
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

# LoRA adapters on the attention and MLP projections plus the LM head.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,                              # 4
    lora_alpha=32,
    lora_dropout=0.05,                # 0.0
    bias="none",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
)

# Wrap the base model with the adapters and switch to training mode.
model = get_peft_model(model, lora_config)
model.train()
print_trainable_parameters(model)

optimizer = optim.AdamW(model.parameters(), eps=ADAM_EPS, lr=LR)


# -----------------------
# DataLoader + DP
# -----------------------
# Plain PyTorch loader over the tokenized dataset; it is handed to the privacy
# engine below, which returns the loader actually used for training.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    drop_last=True,
    collate_fn=custom_collate_fn,
    pin_memory=True,
    num_workers=1,
)

# Let Opacus wrap model, optimizer and loader so that training for EPOCHS
# epochs targets the (EPSILON, DELTA) budget, clipping per-sample gradients
# to MAX_GRAD_NORM.
privacy_engine = PrivacyEngine()
model, optimizer, train_dataloader = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=train_dataloader,
    epochs=EPOCHS,
    target_epsilon=EPSILON,
    target_delta=DELTA,
    max_grad_norm=MAX_GRAD_NORM,
)


# -----------------------
# Train loop
# -----------------------
for epoch in range(1, EPOCHS + 1):
    # Per-physical-batch losses collected for this epoch's summary line.
    losses = []

    # BatchMemoryManager yields batches of at most MAX_PHYSICAL_BATCH_SIZE
    # samples from the private loader.
    memory_manager = BatchMemoryManager(
        data_loader=train_dataloader,
        max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,
        optimizer=optimizer,
    )
    with memory_manager as memory_safe_loader:
        for batch in tqdm(memory_safe_loader):
            inputs = batch["input_ids"].to(device)

            optimizer.zero_grad(set_to_none=True)
            # Causal-LM objective: the inputs double as labels.
            # NOTE(review): sequences are padded to MAX_LENGTH with pad == EOS
            # and no attention mask / label masking is applied here, so padded
            # positions appear to count toward the loss -- confirm intended.
            outputs = model(inputs, labels=inputs)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            losses.append(loss.item())

    # Report the privacy budget spent so far together with the mean loss.
    eps = privacy_engine.get_epsilon(delta=DELTA)
    mean_loss = float(np.mean(losses))
    logger.info("Epoch %d | Loss: %.4f | (ε=%.2f, δ=%g)", epoch, mean_loss, eps, DELTA)


# -----------------------
# Save
# -----------------------
# Persist the state_dict of the trained model under OUTPUT_DIR/OUTPUT_NAME.
# NOTE(review): `model` here is the object returned by the privacy engine, so
# the saved keys are those of that wrapper -- confirm the loading side matches.
save_path = os.path.join(OUTPUT_DIR, OUTPUT_NAME)
os.makedirs(OUTPUT_DIR, exist_ok=True)
torch.save(model.state_dict(), save_path)
logger.info("Saved model state_dict to %s", save_path)
