In [None]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorWithPadding
from peft import LoraConfig, get_peft_model, PeftModel
from torch.utils.data import default_collate

In [None]:
# Hugging Face Hub id shared by the tokenizer, the training model and the inference model.
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"

In [None]:
# Tokenizer (fast implementation), cached on the external drive next to the model weights.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True, cache_dir="/media/tamal/New_HardDrive/Machine Learning, AI/Aaladin AI/model and tokenizers")
# No pad token defined: reuse EOS so sequences can be padded to a common length.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Pad on the right so every training row starts with its real tokens (the prompt)
# and padding only trails the response. Some tokenizer classes, e.g.
# LlamaTokenizerFast, default to padding_side="left", which would put the padding
# where the prompt-masking logic expects the prompt.
tokenizer.padding_side = "right"

In [None]:
# Base model in float16, placed automatically across available devices, read from the local cache.
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    cache_dir="/media/tamal/New_HardDrive/Machine Learning, AI/Aaladin AI/model and tokenizers",
    device_map="auto",
    torch_dtype=torch.float16,
)

### LoRA config

In [None]:
# LoRA: rank-8 adapters (alpha 16, dropout 0.05, no bias) on the q/k/v/o and
# gate/up/down projection modules of the base model.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"],
)
model = get_peft_model(base_model, lora_config)

# The base model was loaded in float16 and training runs with fp16=True mixed
# precision. The AMP grad scaler refuses to unscale float16 gradients, so keep the
# trainable adapter weights in float32. Newer PEFT releases already upcast adapters
# by default, in which case this loop changes nothing.
for param in model.parameters():
    if param.requires_grad:
        param.data = param.data.float()

model.print_trainable_parameters()

In [None]:
# First 10,000 rows of the Dolly-15k train split (instruction / context / response columns).
ds = load_dataset(
    path="databricks/databricks-dolly-15k",
    split="train[:10000]",
)


In [None]:
# Show the dataset summary: column names and number of rows.
ds

In [None]:
# Preview a few instructions; displaying the whole 10k-row column floods the notebook output.
ds[:5]["instruction"]


In [None]:
# Preview a few responses; displaying the whole 10k-row column floods the notebook output.
ds[:5]["response"]

### Preprocess

In [None]:
# Convert one Dolly record into the message list that the chat template renders.
def build_messages(instruction, context, response):
    """Return a two-turn ``[user, assistant]`` chat in Hugging Face message format.

    The user turn is ``instruction`` (``""`` when falsy). When ``context`` is a
    non-blank string it is appended under a ``Context:`` header, and the combined
    text is stripped of surrounding whitespace. The assistant turn is
    ``response`` (``""`` when falsy).
    """
    user_content = instruction or ""
    if context and context.strip():
        user_content = "{}\n\nContext:\n{}".format(user_content, context).strip()

    assistant_content = response or ""
    return [
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": assistant_content},
    ]



# Preprocess: render each example with the chat template twice (prompt-only and
# prompt + response), then mask the prompt and the padding out of the labels so the
# loss is computed on the assistant response only.
max_len = 1024  # maximum tokens per training example (prompt + response)


def preprocess_batch(batch):
    """Tokenize a batch of Dolly rows and build response-only labels.

    For every row the user turn (instruction plus optional context) is rendered
    twice with the chat template:

    * prompt-only, with the assistant generation header;
    * the full conversation, including the assistant response.

    ``labels`` copy the full token ids, but the prompt prefix and the padding are
    set to -100, so they are ignored by the loss. Rows are truncated to
    ``max_len`` and right-padded to exactly ``max_len``. Padding is built
    manually, so the layout does not depend on ``tokenizer.padding_side``.

    Args:
        batch: batched ``datasets`` slice with list-valued ``"instruction"``,
            ``"context"`` and ``"response"`` columns.

    Returns:
        dict with ``"input_ids"``, ``"attention_mask"`` and ``"labels"``. Each is a
        list with one length-``max_len`` int list per input row.
    """
    pad_id = tokenizer.pad_token_id
    input_ids, attention_mask, labels = [], [], []

    for instr, ctx, resp in zip(batch["instruction"], batch["context"], batch["response"]):
        msgs = build_messages(instr, ctx, resp)

        # Prompt-only text: the user turn plus the assistant header the model continues from.
        prompt_text = tokenizer.apply_chat_template(
            msgs[:1], tokenize=False, add_generation_prompt=True
        )
        # Full conversation including the assistant response.
        full_text = tokenizer.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=False
        )

        # The rendered template text is tokenized as-is. add_special_tokens=True would
        # let the tokenizer prepend its own BOS on top of any BOS the template
        # already rendered. NOTE(review): confirm the template emits BOS with
        # `tokenizer.bos_token in full_text`.
        prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
        full_ids = tokenizer(full_text, add_special_tokens=False)["input_ids"][:max_len]

        # A generation prompt is not guaranteed to be a token-exact prefix of the
        # full conversation (some templates append extra tokens after the assistant
        # header), so mask only the prefix the two renderings actually share.
        n_prompt = 0
        for p_tok, f_tok in zip(prompt_ids, full_ids):
            if p_tok != f_tok:
                break
            n_prompt += 1

        # Right-pad explicitly to max_len; padded positions get mask 0 and label -100.
        n_pad = max_len - len(full_ids)
        input_ids.append(full_ids + [pad_id] * n_pad)
        attention_mask.append([1] * len(full_ids) + [0] * n_pad)
        labels.append([-100] * n_prompt + full_ids[n_prompt:] + [-100] * n_pad)

    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

proc_all = ds.map(
    preprocess_batch,
    batched=True,
    remove_columns=ds.column_names,
    desc="Formatting with chat template + masking",
)

# Rows whose response was truncated away entirely (prompt >= max_len) have every
# label set to -100. They contribute no supervised tokens, and with a per-device
# batch of 1 they can yield a NaN/empty loss depending on the transformers version.
proc = proc_all.filter(
    lambda ex: any(tok != -100 for tok in ex["labels"]),
    desc="Dropping fully-masked rows",
)
print(f"Kept {len(proc)} / {len(proc_all)} examples with at least one supervised token")


In [None]:
# Data collator: labels are already masked during preprocessing. This only batches,
# right-padding any ragged sequences so it also works without fixed-length padding.
class LMDataCollator:
    """Batch pre-tokenized causal-LM features into tensors.

    Each feature is a mapping with the same keys (here ``input_ids``,
    ``attention_mask`` and ``labels``). Values are int lists, which become
    ``torch.long`` tensors, or tensors, which keep their dtype. Features that
    already share one shape are simply stacked. 1-D sequences of different
    lengths are right-padded to the longest one in the batch:

    * ``labels`` with ``label_pad_id`` (ignored by the loss);
    * ``attention_mask`` with 0;
    * any other key with ``tokenizer.pad_token_id`` (0 if unset).

    Args:
        tokenizer: object exposing ``pad_token_id``, used to pad ``input_ids``.
        label_pad_id: fill value for padded ``labels`` positions (default -100).
    """

    def __init__(self, tokenizer, label_pad_id=-100):
        self.tokenizer = tokenizer
        self.label_pad_id = label_pad_id

    def _pad_value(self, key):
        """Return the fill value used when right-padding column ``key``."""
        if key == "labels":
            return self.label_pad_id
        if key == "attention_mask":
            return 0
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        return 0 if pad_id is None else pad_id

    def __call__(self, features):
        """Collate a list of feature dicts into a dict of batched tensors.

        The input feature dicts are not modified.

        Raises:
            ValueError: if ``features`` is empty.
        """
        if not features:
            raise ValueError("LMDataCollator received an empty batch")

        batch = {}
        for key in features[0]:
            seqs = [
                value if isinstance(value, torch.Tensor) else torch.tensor(value, dtype=torch.long)
                for value in (feature[key] for feature in features)
            ]
            if all(seq.shape == seqs[0].shape for seq in seqs):
                batch[key] = torch.stack(seqs)
                continue

            # Ragged 1-D sequences: right-pad to the longest one in the batch.
            width = max(seq.shape[0] for seq in seqs)
            padded = torch.full((len(seqs), width), self._pad_value(key), dtype=seqs[0].dtype)
            for row, seq in enumerate(seqs):
                padded[row, : seq.shape[0]] = seq
            batch[key] = padded
        return batch

collator = LMDataCollator(tokenizer=tokenizer)  # batches the pre-masked features for the Trainer

## Training

In [None]:

# Training: effective batch = 1 per device x 4 accumulation steps = 4 sequences per update.
args = TrainingArguments(
    # checkpoints: one per epoch, keep only the newest
    output_dir="./deepseek-lora-fixed",
    save_strategy="epoch",
    save_total_limit=1,
    # optimisation schedule
    num_train_epochs=2,
    learning_rate=2e-4,
    per_device_train_batch_size=1,          # raise if VRAM allows
    gradient_accumulation_steps=4,
    # fp16 mixed precision (the base weights were loaded in float16)
    fp16=True,
    # logging, without external trackers
    logging_steps=50,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=args,
    data_collator=collator,
    train_dataset=proc,
    tokenizer=tokenizer,
)


In [None]:
# Run fine-tuning with the TrainingArguments, dataset and collator configured in the previous cell.
trainer.train()

In [None]:
# Save the trained PEFT model (reloaded below as a LoRA adapter) and the tokenizer side by side.
adapter_dir = "./deepseek-lora-fixed"
trainer.save_model(adapter_dir)
tokenizer.save_pretrained(adapter_dir)


## Inference

In [None]:
# Reload the base weights from the same cache directory used for training. Without
# cache_dir they are looked up (and likely re-downloaded) in the default Hugging Face
# cache. Then attach the saved LoRA adapter and fold it into the base weights.
# NOTE(review): the training model from earlier cells is still in memory; on small
# GPUs restart the kernel (or `del` it) before running this cell.
base = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    cache_dir="/media/tamal/New_HardDrive/Machine Learning, AI/Aaladin AI/model and tokenizers",
)
ft = PeftModel.from_pretrained(base, "./deepseek-lora-fixed")
ft = ft.merge_and_unload()  # merge LoRA into base weights
ft.eval();  # trailing ';' suppresses the very long module repr in Jupyter

In [None]:
def chat_infer(user_text, max_new_tokens=128, temperature=0.8, top_p=0.9, repetition_penalty=1.15):
    """Sample a reply from the merged fine-tuned model ``ft`` for one user message.

    Args:
        user_text: content of the single user turn.
        max_new_tokens: upper bound on generated tokens.
        temperature: sampling temperature.
        top_p: nucleus-sampling probability mass.
        repetition_penalty: penalty applied to already generated tokens.

    Returns:
        Only the newly generated text (the prompt is not echoed), decoded with
        special tokens removed.
    """
    messages = [{"role": "user", "content": user_text}]
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # Tokenize the rendered template without adding special tokens again. This also
    # returns an attention mask: generate() cannot infer one itself when the pad
    # token is the same as EOS (as set up for the tokenizer when it had no pad token).
    enc = tokenizer(prompt_text, add_special_tokens=False, return_tensors="pt").to(ft.device)
    prompt_len = enc["input_ids"].shape[-1]

    with torch.inference_mode():
        out = ft.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    # generate() returns prompt + continuation; decode only the continuation.
    return tokenizer.decode(out[0, prompt_len:], skip_special_tokens=True)

sample_question = "When did Virgin Australia start operating?"  # question for a quick smoke test of the fine-tuned model
print("\n--- Inference sample ---")
print(chat_infer(sample_question))