# 3. Model Training

**Purpose:**  
Fine-tune two LLaMA-3-8B-Instruct models (Question Parser and CoT Parser) using LoRA adapters and our ICL prompt templates.

**Training Inputs:**  
- `train_question_parsing.jsonl`  
- `train_cot_parsing.jsonl`  

**Key Config:**
- LoRA: rank=64, α=16, dropout=0.05
- Model: `unsloth/llama-3-8b-Instruct-bnb-4bit`
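- Epochs: 12, per-device batch size: 8, gradient accumulation: 2 steps (effective batch size 16)
- Learning rate: 5e-5, fp16 mixed precision
- Max sequence length: 1024 tokens (QP), 2048 tokens (CoT)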
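These settings imply a few derived quantities. The sketch below assumes the standard LoRA scaling of the adapter update by α / r (no rank-stabilized variant) and a single GPU. The 800-example dataset size is purely hypothetical. The step count treats a trailing partial accumulation window as one extra update, and exact counts may differ slightly across `transformers` versions.

```python
import math


def lora_scaling(rank: int, alpha: float) -> float:
    """Standard LoRA multiplies the low-rank update B @ A by alpha / rank."""
    return alpha / rank


def effective_batch_size(per_device: int, grad_accum: int, num_gpus: int = 1) -> int:
    """Number of examples that contribute to a single optimizer update."""
    return per_device * grad_accum * num_gpus


def total_optimizer_steps(num_examples: int, per_device: int, grad_accum: int,
                          epochs: int, num_gpus: int = 1) -> int:
    """Approximate update count; a trailing partial window counts as one update."""
    micro_batches = math.ceil(num_examples / (per_device * num_gpus))
    return math.ceil(micro_batches / grad_accum) * epochs


# Values from the Key Config; 800 training examples is a hypothetical size.
print(lora_scaling(64, 16))                  # 0.25
print(effective_batch_size(8, 2))            # 16
print(total_optimizer_steps(800, 8, 2, 12))  # 600
```

With α=16 and r=64 the adapter update is scaled by 0.25. Changing `r` without adjusting `lora_alpha` also changes this factor.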

**Outputs:**  
- `…/llm-sr-project/finetuned_llama3_question_parsing/` (LoRA-adapter weights + tokenizer)  
- `…/llm-sr-project/finetuned_llama3_cot_parsing/` (LoRA-adapter weights + tokenizer)  
- A `.zip` archive of each directory, written next to it  


## Environments and Imports

In [None]:
# Install core evaluation utilities
!pip install -q evaluate json5

!pip uninstall -y nltk
!pip install -q --upgrade nltk

In [None]:
%%capture
# Install Unsloth for efficient LLM fine-tuning (%%capture must be the first line of the cell)
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps "xformers<0.0.27" "trl<0.9.0" peft accelerate bitsandbytes

In [None]:
# ─────────────────────────────────────────────────────────────────────────────
# Define In-Context Learning Demonstrations and Prompt Templates
# ─────────────────────────────────────────────────────────────────────────────

# QP_DEMON: One-shot example for Question Parsing
QP_DEMON = '''The question is:

There are 6 volunteers: A, B, C, D, E and F. They will be assigned to either Project Alpha or Project Beta. Each person works on exactly one project. This assignment must satisfy:
(1) If A works on Alpha, then B works on Beta.
(2) If C works on Alpha, then D and E work on Beta.
(3) F works on a different project than E.
(4) D must work on a different project than A.
(5) If F works on Alpha, then B works on Alpha.

If A works on Beta, which of the following must be true?
A. B works on Alpha
B. C works on Beta
C. D works on Alpha
D. F works on Beta

The parsing result is:

[
  "There are 6 volunteers: A, B, C, D, E and F. They will be assigned to either Project Alpha or Project Beta. Each person works on exactly one project.",
  "If A works on Alpha, then B works on Beta",
  "If C works on Alpha, then D and E work on Beta",
  "F works on a different project than E",
  "D must work on a different project than A",
  "If F works on Alpha, then B works on Alpha",
  "A works on Beta"
]
'''

# QP_TEMPLATE: Formats a new question with the demonstration
QP_TEMPLATE = '''Given a question, extract all relevant information from the question that would help to solve it.

This includes:
- General setup information (e.g., number of people, projects involved)
- Explicit facts given in the question
- All logical constraints or conditions

Output only a JSON list and nothing else. Follow the format shown in the example.

Example:

{demon}

Now, the question is:

{question}

Your output:
'''

# CP_DEMON: One-shot example for Chain-of-Thought Parsing
CP_DEMON = '''The question is:

There are 6 volunteers: A, B, C, D, E and F. They will be assigned to either Project Alpha or Project Beta. Each person works on exactly one project.

Conditions:
(1) If A works on Alpha, then B works on Beta.
(2) If C works on Alpha, then D and E work on Beta.
(3) F works on a different project than E.
(4) D must work on a different project than A.
(5) If F works on Alpha, then B works on Alpha.

Question:
If A works on Beta, which of the following must be true?

CoT:
Since A works on Beta, Condition (1) is not triggered. Condition (2) is not triggered since C’s assignment is unknown. Condition (3) doesn’t give anything because E’s assignment is unspecified. Condition (4) says D must work on a different project than A, so D must work on Alpha. Condition (5) depends on F, which is unknown.

Parsing result:

[
  {
    "statement": "Condition (1) is not applicable",
    "evidence": "Condition (1): If A works on Alpha, then B works on Beta. | A is working on Beta",
    "Verification": "false"
  },
  {
    "statement": "Condition (2) is not applicable",
    "evidence": "Condition (2): If C works on Alpha, then D and E work on Beta. | C’s assignment is unknown",
    "Verification": "false"
  },
  {
    "statement": "Condition (3) does not provide any info",
    "evidence": "Condition (3): F works on a different project than E. | E’s assignment is unknown",
    "Verification": "false"
  },
  {
    "statement": "D must work on Alpha",
    "evidence": "Condition (4): D must work on a different project than A, and A is working on Beta",
    "Verification": "true"
  },
  {
    "statement": "Condition (5) is not applicable",
    "evidence": "Condition (5): If F works on Alpha, then B works on Alpha. | F’s assignment is unknown",
    "Verification": "false"
  }
]
'''
# CP_TEMPLATE: Formats a question + CoT + conditions for CP model
CP_TEMPLATE = '''You are a reasoning assistant. Based on the question, conditions, and chain-of-thought (CoT), extract every inference or non-inference step as a JSON object.

For each CoT sentence that either:
  1. Refers to a condition (e.g. “Condition (2) …”)
  2. Starts with an inference cue (“Since”, “Therefore”, “This means”, “We can deduce”, etc.)

Produce one object with:
  • "statement": the new claim you read in that CoT sentence (don’t quote the entire sentence—just the core inference).
  • "evidence":
      – if the claim restates a constraint, use the exact line from the **Conditions** block,
      – otherwise, use the CoT fragment that you extracted it from.
  • "Verification":
      – `"false"` if the sentence rejects or blocks a condition (contains “not applicable”, “does not provide”, etc.),
      – otherwise `"true"`.

Keep the objects in the same order as they appear in the CoT.

Example:

{demon}

Now, given:

Question:
{question}

Conditions:
{conditions}

Chain-of-Thought:
{cot}

Your output:
'''
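The templates are not applied anywhere else in this notebook. Training reads the already-formatted `input`/`output` fields of the `.jsonl` files. Assuming those files were produced by filling the templates with `str.format`, one training record could be built roughly as follows. `MINI_QP_TEMPLATE` and `build_record` are hypothetical names, and the indented-JSON layout of `output` is an assumption modeled on `QP_DEMON`.

```python
import json

# Hypothetical, shortened stand-in for QP_TEMPLATE (same {demon}/{question} slots).
MINI_QP_TEMPLATE = """Example:

{demon}

Now, the question is:

{question}

Your output:
"""


def build_record(template: str, parsed: list, **fields: str) -> dict:
    """Fill the template and pair it with the gold JSON answer.

    The "input"/"output" keys match what preprocess_fn reads from the .jsonl files.
    """
    return {
        "input": template.format(**fields),
        "output": json.dumps(parsed, ensure_ascii=False, indent=2),
    }


# str.format only parses the template itself, so braces inside substituted values
# (e.g. the JSON objects in CP_DEMON) do not need {{ }} escaping.
demo_with_braces = '[{"statement": "D must work on Alpha"}]'
record = build_record(MINI_QP_TEMPLATE, ["A works on Beta"],
                      demon=demo_with_braces, question="If A works on Beta, ...")
jsonl_line = json.dumps(record, ensure_ascii=False)  # one line of a training file
print(jsonl_line)
```

This is also why the demonstrations are passed in through `{demon}` rather than pasted into the template strings. A literal JSON object inside the template would make `str.format` fail unless every brace were doubled.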

In [None]:
import unsloth  # import before transformers so Unsloth's patches are applied
import os
import shutil
import torch
from datasets import load_dataset
from transformers import TrainingArguments, Trainer
from unsloth import FastLanguageModel

## Model Factory

In [None]:
# Factory: load the 4-bit LLaMA-3 base model and attach a fresh LoRA adapter
def make_lora_model(model_name, max_length):
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name,
        max_seq_length=max_length,
        dtype=torch.float16,
        load_in_4bit=True,  # pre-quantized bnb-4bit checkpoint
    )
    # LoRA on the attention projections only (rank 64, alpha 16, dropout 0.05)
    model = FastLanguageModel.get_peft_model(
        model,
        r=64, lora_alpha=16, lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        bias="none", random_state=42,
        max_seq_length=max_length,
    )
    return model, tokenizer

# Tokenize one prompt/output pair. Prompt tokens and padding are labeled -100
# so that only the target output (plus EOS) contributes to the loss.
def preprocess_fn(example, tokenizer, max_length):
    prompt = example["input"]
    output = example["output"] + tokenizer.eos_token  # teach the model to stop
    ids = tokenizer(prompt + output, truncation=True, max_length=max_length)["input_ids"]
    n_prompt = min(len(tokenizer(prompt)["input_ids"]), len(ids))
    pad = max_length - len(ids)  # right-pad to max_length
    return {
        "input_ids": ids + [tokenizer.pad_token_id] * pad,
        "attention_mask": [1] * len(ids) + [0] * pad,
        "labels": [-100] * n_prompt + ids[n_prompt:] + [-100] * pad,
    }
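To see what prompt masking does, here is a self-contained sketch. It uses a hypothetical word-level `ToyTokenizer` as a stand-in for the LLaMA-3 tokenizer, which would also insert a BOS token and split words into sub-word pieces. As in `preprocess_fn`, prompt positions get label `-100`, the default `ignore_index` of PyTorch's cross-entropy loss. This sketch also masks padding and appends an end-of-sequence marker, so only the answer tokens and that marker are supervised. `mask_prompt_labels` is a hypothetical name.

```python
class ToyTokenizer:
    """Hypothetical word-level tokenizer; ids are assigned on first sight."""
    pad_token_id = 0
    eos_token = "</s>"

    def __init__(self):
        self.vocab = {"<pad>": 0}

    def __call__(self, text):
        ids = [self.vocab.setdefault(word, len(self.vocab)) for word in text.split()]
        return {"input_ids": ids}


def mask_prompt_labels(tokenizer, prompt, output, max_length):
    """Right-pad to max_length; label -100 on prompt and padding positions."""
    ids = tokenizer(f"{prompt} {output} {tokenizer.eos_token}")["input_ids"][:max_length]
    n_prompt = min(len(tokenizer(prompt)["input_ids"]), len(ids))
    pad = max_length - len(ids)
    return {
        "input_ids": ids + [tokenizer.pad_token_id] * pad,
        "attention_mask": [1] * len(ids) + [0] * pad,
        "labels": [-100] * n_prompt + ids[n_prompt:] + [-100] * pad,
    }


ex = mask_prompt_labels(ToyTokenizer(), "Q : A works on Beta", "[ ok ]", max_length=12)
print(ex["labels"])  # [-100, -100, -100, -100, -100, -100, 7, 8, 9, 10, -100, -100]
```

Only 4 of the 12 positions (`[`, `ok`, `]`, `</s>`) carry a real label. The prompt is still visible to the model as context, but the model is not trained to reproduce it.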

## Fine-Tune Question Parsing (QP)

In [None]:
# Paths
PROJECT_DIR  = "/content/drive/MyDrive/llm-sr-project"
TRAIN_FILE   = os.path.join(PROJECT_DIR, "train_question_parsing.jsonl")
MODEL_OUTPUT = os.path.join(PROJECT_DIR, "qp_model")
MODEL_NAME   = "unsloth/llama-3-8b-Instruct-bnb-4bit"
MAX_LEN      = 2048

# Load the dataset and a fresh LoRA model
ds_qp = load_dataset("json", data_files={"train": TRAIN_FILE})["train"]
qp_model, qp_tokenizer = make_lora_model(MODEL_NAME, MAX_LEN)

# Tokenize; QP examples are truncated/padded to 1024 tokens (below MAX_LEN)
qp_tokenizer_ds = ds_qp.map(
    lambda x: preprocess_fn(x, qp_tokenizer, max_length=1024),
    batched=False, remove_columns=ds_qp.column_names
)

# Trainer
args_qp = TrainingArguments(
    output_dir=MODEL_OUTPUT,
    num_train_epochs=12,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=5e-5,
    fp16=True,
    save_strategy="epoch",
    logging_strategy="epoch",
    report_to="none",
    eval_strategy="no"
)
trainer_qp = Trainer(
    model=qp_model,
    args=args_qp,
    train_dataset=qp_tokenizer_ds,
    tokenizer=qp_tokenizer
)
trainer_qp.train()

# Save
lora_qp_dir = os.path.join(PROJECT_DIR, "finetuned_llama3_question_parsing")

qp_model.save_pretrained(lora_qp_dir)
qp_tokenizer.save_pretrained(lora_qp_dir)

# Create a ZIP archive
shutil.make_archive(lora_qp_dir, 'zip', lora_qp_dir)

print(f"✅ QP model saved and zipped at: {lora_qp_dir}.zip")
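Because `logging_strategy="epoch"` is set, `Trainer` appends one entry per epoch to `trainer_qp.state.log_history`. A small helper can pull out the training loss so you can check how it evolves over the 12 epochs. `epoch_losses` is a hypothetical name. The demo list below uses dummy values shaped like real log entries, not results from this run.

```python
def epoch_losses(log_history):
    """Return (epoch, loss) pairs from the per-epoch logging entries.

    The end-of-training summary entry uses the key "train_loss" instead of
    "loss", so it is skipped.
    """
    return [(round(entry["epoch"]), entry["loss"])
            for entry in log_history if "loss" in entry]


# Dummy entries shaped like Trainer.state.log_history (values are made up).
dummy_history = [
    {"loss": 1.2, "learning_rate": 4.6e-5, "epoch": 1.0, "step": 50},
    {"loss": 0.8, "learning_rate": 4.2e-5, "epoch": 2.0, "step": 100},
    {"train_loss": 1.0, "train_runtime": 10.0, "epoch": 2.0, "step": 100},
]
print(epoch_losses(dummy_history))  # [(1, 1.2), (2, 0.8)]
```

In the notebook, the equivalent call would be `epoch_losses(trainer_qp.state.log_history)` (or `cot_trainer.state.log_history` for the CoT run).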

## Fine-Tune Chain-Of-Thought (CoT) Parsing

In [None]:
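# Paths and settings (redefined so this cell does not depend on the QP cell)
PROJECT_DIR      = "/content/drive/MyDrive/llm-sr-project"
TRAIN_COT_FILE   = os.path.join(PROJECT_DIR, "train_cot_parsing.jsonl")
COT_MODEL_OUTPUT = os.path.join(PROJECT_DIR, "cot_model")
MODEL_NAME       = "unsloth/llama-3-8b-Instruct-bnb-4bit"
MAX_LEN          = 2048

# Load a fresh base model + LoRA adapter for CoT parsing (using the model factory)
cot_model, cot_tokenizer = make_lora_model(MODEL_NAME, MAX_LEN)
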
# Load CoT dataset (take the "train" split, as for QP)
cot_data = load_dataset("json", data_files={"train": TRAIN_COT_FILE})["train"]

cot_tokenizer_ds = cot_data.map(
    lambda x: preprocess_fn(x, cot_tokenizer, MAX_LEN),
    batched=False, remove_columns=cot_data.column_names
)

# Training (same hyperparameters as QP; only the 2 most recent checkpoints are kept)
cot_args = TrainingArguments(
    output_dir=COT_MODEL_OUTPUT,
    num_train_epochs=12,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=5e-5,
    fp16=True,
    save_strategy="epoch",
    save_total_limit=2,
    logging_strategy="epoch",
    report_to="none",
    eval_strategy="no",
)

cot_trainer = Trainer(
    model=cot_model,
    args=cot_args,
    train_dataset=cot_tokenizer_ds,
    tokenizer=cot_tokenizer,
)

cot_trainer.train()

# Save
lora_cot_dir = os.path.join(PROJECT_DIR, "finetuned_llama3_cot_parsing")

cot_model.save_pretrained(lora_cot_dir)
cot_tokenizer.save_pretrained(lora_cot_dir)

# Zip the directory
shutil.make_archive(lora_cot_dir, 'zip', lora_cot_dir)

print(f"✅ CoT model saved and zipped at: {lora_cot_dir}.zip")
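Both adapter directories are archived with `shutil.make_archive`, so it can be worth checking that an archive actually contains an adapter before downloading it. The sketch below assumes PEFT's usual file names: `adapter_config.json` plus either `adapter_model.safetensors` (recent `peft` releases) or `adapter_model.bin` (older releases). It also checks for `tokenizer_config.json`, which the tokenizer save should produce. `missing_from_zip` is a hypothetical helper. The demo zips placeholder files, not real weights.

```python
import os
import shutil
import tempfile
import zipfile

REQUIRED = ("adapter_config.json", "tokenizer_config.json")
WEIGHTS = ("adapter_model.safetensors", "adapter_model.bin")


def missing_from_zip(zip_path):
    """List expected adapter/tokenizer files that are absent from the archive."""
    with zipfile.ZipFile(zip_path) as zf:
        present = {os.path.basename(name.rstrip("/")) for name in zf.namelist()}
    missing = [name for name in REQUIRED if name not in present]
    if not any(weights in present for weights in WEIGHTS):
        missing.append("adapter weights")
    return missing


# Demo on a throwaway directory with placeholder files (the tokenizer is left out).
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "adapter")
    os.makedirs(src)
    for name in ("adapter_config.json", "adapter_model.safetensors"):
        with open(os.path.join(src, name), "w") as f:
            f.write("{}")
    archive = shutil.make_archive(src, "zip", src)
    result = missing_from_zip(archive)

print(result)  # ['tokenizer_config.json']
```

For a complete save, `missing_from_zip(f"{lora_cot_dir}.zip")` should return an empty list.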

## Training Complete

- QP model → `…/llm-sr-project/finetuned_llama3_question_parsing/`
- CoT model → `…/llm-sr-project/finetuned_llama3_cot_parsing/`

Proceed to **4_Evaluation.ipynb** for inference and metrics.
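Before computing metrics, it can help to check that a CoT-parser generation actually follows the schema requested in `CP_TEMPLATE`. That schema is a JSON list of objects with `statement`, `evidence`, and `Verification` keys, where `Verification` is the string `"true"` or `"false"`. The sketch below is hypothetical (`validate_cot_parse` is not part of this project) and uses only the standard `json` module. The evaluation notebook may handle malformed output differently.

```python
import json

REQUIRED_KEYS = {"statement", "evidence", "Verification"}


def validate_cot_parse(text):
    """Parse a CoT-parser generation and return (items, problems)."""
    try:
        items = json.loads(text)
    except json.JSONDecodeError as err:
        return None, [f"not valid JSON: {err.msg}"]
    if not isinstance(items, list):
        return None, ["top level is not a JSON list"]
    problems = []
    for i, item in enumerate(items):
        if not isinstance(item, dict):
            problems.append(f"item {i} is not an object")
            continue
        missing = REQUIRED_KEYS - set(item)
        if missing:
            problems.append(f"item {i} missing {sorted(missing)}")
        # The template asks for the strings "true"/"false", not JSON booleans
        if item.get("Verification") not in ("true", "false"):
            problems.append(f"item {i} has Verification={item.get('Verification')!r}")
    return items, problems


good = ('[{"statement": "D must work on Alpha", '
        '"evidence": "Condition (4)", "Verification": "true"}]')
bad = '[{"statement": "D must work on Alpha", "Verification": true}]'
print(validate_cot_parse(good)[1])  # []
print(validate_cot_parse(bad)[1])
```

Returning a list of problems instead of raising makes it easy to count how many generations are malformed before scoring the rest.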
