In [None]:
!pip -q install -U "transformers>=4.45" "datasets>=2.20" "accelerate>=0.33" peft trl bitsandbytes

In [None]:
import os

import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from trl import SFTConfig, SFTTrainer


In [None]:
# Best-voted answers from crypto.stackexchange; each row pairs a question with its top answer.
ds = load_dataset(
    "flax-sentence-embeddings/stackexchange_title_best_voted_answer_jsonl",
    split="crypto",
)

first_row = ds[0]
print(first_row.keys(), len(ds))
print("\nSAMPLE QUESTION:\n", first_row["title_body"][:800])
print("\nSAMPLE ANSWER:\n", first_row["upvoted_answer"][:800])


In [None]:
SPLIT_SEED = 42

# Shuffle inside the expression instead of reassigning `ds`. Overwriting `ds` meant a second
# run of this cell reshuffled an already-shuffled dataset and produced a different split.
# (train_test_split shuffles on its own as well; the explicit shuffle is kept so the split
# is identical to the one earlier runs of this notebook produced.)
split = ds.shuffle(seed=SPLIT_SEED).train_test_split(test_size=0.02, seed=SPLIT_SEED)
train_ds = split["train"]
eval_ds = split["test"]

assert len(train_ds) + len(eval_ds) == len(ds), "split must partition the full dataset"
print(len(train_ds), len(eval_ds))


In [None]:
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"

# Fast (Rust-backed) tokenizer matching the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)

# Padding requires a pad token; reuse EOS when the checkpoint does not define one.
has_pad_token = tokenizer.pad_token is not None
if not has_pad_token:
    tokenizer.pad_token = tokenizer.eos_token


In [None]:
# Native bf16 needs an Ampere-or-newer GPU (compute capability >= 8). Older cards such as
# T4/V100 expose CUDA but not native bf16, so `torch.cuda.is_available()` alone would pick
# an unsupported (or slowly emulated) dtype there. Fall back to fp16 in that case.
USE_BF16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
COMPUTE_DTYPE = torch.bfloat16 if USE_BF16 else torch.float16

# QLoRA setup: 4-bit NF4 weights with double quantization; matmuls run in COMPUTE_DTYPE.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=COMPUTE_DTYPE,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=COMPUTE_DTYPE,
)

# Disable the KV cache for training (gradient checkpointing is enabled for the trainer).
model.config.use_cache = False


In [None]:
# System message prepended to every training example and to the inference prompt.
_SYSTEM_PROMPT_PARTS = (
    "You are a helpful assistant specialized in cryptography. ",
    "Explain concepts clearly, be precise with definitions, and when helpful include short examples.",
)
SYSTEM_PROMPT = "".join(_SYSTEM_PROMPT_PARTS)

def to_chat_text(example):
    """Render one StackExchange Q/A row as a single chat-formatted training string.

    Args:
        example: a dataset row with a "title_body" field (the question) and an
            "upvoted_answer" field (the answer).

    Returns:
        The system/user/assistant conversation as plain text (not token ids). It is
        rendered with the tokenizer's chat template when one is configured, otherwise
        with a simple tag-based fallback that ends with the tokenizer's EOS token.
    """
    question = example["title_body"]
    answer = example["upvoted_answer"]

    # Every modern tokenizer *has* an apply_chat_template method, but it raises when no
    # template is configured, so test for the template itself rather than the method.
    if getattr(tokenizer, "chat_template", None):
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": question},
            {"role": "assistant", "content": answer},
        ]
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )

    # Fallback: terminate the answer with EOS so the fine-tuned model learns where to stop.
    eos = tokenizer.eos_token or ""
    return (
        f"<|system|>\n{SYSTEM_PROMPT}\n"
        f"<|user|>\n{question}\n"
        f"<|assistant|>\n{answer}{eos}"
    )

def to_text_dataset(dataset):
    """Map every row of `dataset` to a single chat-formatted "text" column, dropping the rest."""
    return dataset.map(
        lambda ex: {"text": to_chat_text(ex)},
        remove_columns=dataset.column_names,
    )


# Build from the untouched `split` rather than mapping train_ds/eval_ds onto themselves.
# remove_columns drops "title_body"/"upvoted_answer", so re-running a self-mapping cell
# failed on the missing columns.
train_ds = to_text_dataset(split["train"])
eval_ds = to_text_dataset(split["test"])

print(train_ds[0]["text"][:1200])


In [None]:
# Phi-3 fuses its projections: attention uses a single `qkv_proj` and the MLP a single
# `gate_up_proj`. The Llama-style names (q_proj/k_proj/v_proj/gate_proj/up_proj) match no
# module in this model, and PEFT only raises when *no* target matches. The previous list
# therefore silently adapted just o_proj and down_proj.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,  # LoRA update is scaled by lora_alpha / r = 2
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["qkv_proj", "o_proj", "gate_up_proj", "down_proj"],
)


In [None]:
OUTPUT_DIR = "./phi3_crypto_lora"

args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    num_train_epochs=1,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    logging_steps=20,
    eval_strategy="steps",
    eval_steps=200,
    save_steps=200,
    save_total_limit=2,
    bf16=torch.cuda.is_available(),
    fp16=(not torch.cuda.is_available()),
    gradient_checkpointing=True,
    report_to="none",
)


In [None]:
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    dataset_text_field="text",
    max_seq_length=2048,
    peft_config=peft_config,
    args=args,
    packing=True,
)

trainer.train()


In [None]:
# Write the trained PEFT-wrapped model (the LoRA adapter) and then the tokenizer to OUTPUT_DIR.
for artifact in (trainer.model, tokenizer):
    artifact.save_pretrained(OUTPUT_DIR)
print("Saved LoRA adapter to:", OUTPUT_DIR)


In [None]:
prompt = "Explain the difference between semantic security (IND-CPA) and IND-CCA2. Keep it concise."

messages = [
    {"role": "system", "content": SYSTEM_PROMPT},
    {"role": "user", "content": prompt},
]

# trainer.train() leaves the model in train mode, where LoRA dropout is still active.
trainer.model.eval()

# Tokenize through the chat template directly. Rendering to text and re-tokenizing can add
# special tokens (e.g. BOS) a second time if the template already emits them.
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
).to(trainer.model.device)

torch.manual_seed(42)  # sampling is stochastic; seed it so the printed answer is reproducible
with torch.no_grad():
    out = trainer.model.generate(
        **inputs,
        max_new_tokens=220,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        use_cache=True,  # model.config.use_cache was disabled for training; re-enable for decoding
    )

# generate() returns prompt + continuation; decode only the newly generated tokens.
prompt_len = inputs["input_ids"].shape[1]
print(tokenizer.decode(out[0, prompt_len:], skip_special_tokens=True))
