# Qwen2.5 Training & Inference Setup (LlamaPIE Adaptation)

This notebook adapts the LlamaPIE pipeline to train and run inference using **Qwen2.5** models.

**Models:**
- Small (Classifier): `Qwen/Qwen2.5-1.5B-Instruct`
- Large (Generator): `Qwen/Qwen2.5-7B-Instruct`

**Steps:**
1. Setup Environment
2. Patch/Create Code for Qwen Support
3. Train Small Model (Classifier)
4. Train Large Model (Generator)
5. Run Inference (first 20 synthetic Claude test samples)

In [None]:
!nvidia-smi

In [None]:
# 1. Setup Environment
!git clone https://github.com/YudoongY/LlamaPIE.git
%cd LlamaPIE
!sed -i 's/trainer==0.0.36/# trainer==0.0.36/' requirements.txt
!pip install -q -r requirements.txt
!pip install -q huggingface_hub accelerate

In [None]:
import os

from huggingface_hub import login

# Log in only when a token is supplied through the HF_TOKEN environment variable.
# Hard-coding a token in the notebook leaks it with the saved file, and a leftover
# placeholder such as 'YOUR_TOKEN_HERE' is rejected by the Hub.
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("HF_TOKEN is not set; skipping Hugging Face login (only needed for gated or private repos).")

In [None]:
# Download Data
!gdown 1TEquHZR8E53WLR-v09F1Do5UMQH5tjYZ -O Llamapie_dataset.tar.gz
!tar -xf Llamapie_dataset.tar.gz

## 2. Code Adaptation for Qwen

In [None]:
%%writefile model/CasualTokenClassificationQwen.py
from transformers.models.qwen2.modeling_qwen2 import *
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.cache_utils import Cache
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)

QWEN2_INPUTS_DOCSTRING = ""
import math
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

_CONFIG_FOR_DOC = "Qwen2Config"

class Qwen2ForCausalLM_TokenClassification(Qwen2PreTrainedModel):
    """Qwen2 decoder backbone with a per-token classification head.

    ``classifier`` projects every final hidden state to ``config.num_labels``
    logits. ``lm_head`` is created but not used in ``forward``; presumably it is
    kept so causal-LM checkpoints (with the tied ``lm_head.weight``) load without
    missing/unexpected-key warnings — TODO confirm.
    """

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = Qwen2Model(config)
        self.num_labels = config.num_labels
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.classifier = nn.Linear(config.hidden_size, self.num_labels, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    # The transformers doc decorators (add_start_docstrings_to_model_forward /
    # replace_return_docstrings) are deliberately not applied: replace_return_docstrings
    # expects an existing docstring containing a "Returns:" placeholder, and without one
    # it fails as soon as this module is imported. They only affect documentation.
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the Qwen2 backbone and classify every position.

        Args:
            input_ids, attention_mask, position_ids, past_key_values, inputs_embeds,
            use_cache, output_attentions, output_hidden_states, cache_position:
                forwarded unchanged to ``Qwen2Model``.
            labels: optional per-position class ids whose element count matches the
                logits' leading dimensions; positions equal to ``-100`` are ignored.
            return_dict: return a ``CausalLMOutputWithPast`` (default taken from
                ``config.use_return_dict``) instead of a tuple.

        Returns:
            ``CausalLMOutputWithPast`` whose ``logits`` are float32 with a trailing
            ``num_labels`` dimension, and whose ``loss`` is the mean cross-entropy
            over non-ignored positions (``None`` when ``labels`` is not given).
            With ``return_dict=False``: ``(loss, logits, *rest)`` or ``(logits, *rest)``.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        # Compute the loss in float32 even when the backbone runs in bf16/fp16.
        logits = self.classifier(hidden_states).float()

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # reshape (not view) tolerates non-contiguous label tensors.
            flat_labels = labels.to(logits.device).reshape(-1)
            loss = loss_fct(logits.reshape(-1, self.num_labels), flat_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    # NOTE: no prepare_inputs_for_generation override; this head is only called via
    # forward() in this pipeline, and generate() support has not been checked.


In [None]:
%%writefile mydatasets/Gen_dataset_Qwen.py
import torch
from torch.utils.data import Dataset
from pathlib import Path
from transformers import AutoTokenizer, PreTrainedTokenizer
import json
import os
import numpy as np 
from .data_augmentation import augement_dialogue

class Gen_dataset_Qwen(Dataset):
    """Chat-formatted (prompt, whisper) samples for fine-tuning the Qwen generator.

    Positive samples are read from ``<dname>/<split_set>/Pos/[0-9]*``. Negative
    samples come from ``<dname>/<split_set>/Neg/[0-9]*``, and only for datasets
    that have a ``Pos`` directory. Each item is the tokenized string
    ``chat_prompt + whisper + eos``. Negatives use an empty whisper, so their
    target is just the end token.

    Args:
        tokenizer: object providing ``apply_chat_template``, ``encode`` and ``eos_token``
            (``"<|endoftext|>"`` is used when ``eos_token`` is falsy).
        dataset_names: dataset root directories (default: none).
        split_set: split sub-directory, e.g. ``"Train"`` or ``"Val"``. For ``"Val"``,
            only the first half of each positive directory is used.
        inference: stored on the instance; not used by this class.
        mem_drop_rate: per-item probability of replacing the memory text with ``""``.
        neg_prob: target fraction of negatives, in ``[0, 1)``. Per dataset,
            ``int(num_pos / (1 - neg_prob) * neg_prob)`` negatives are taken, capped
            by what exists on disk. ``0`` disables negatives.
        history_aware: read ``dialogue_aware.txt`` instead of ``dialogue.txt``.
        aug_config: if truthy, passed to ``augement_dialogue`` for dialogue augmentation.

    Raises:
        ValueError: if ``neg_prob >= 1`` (the negative count would be undefined).
    """

    def __init__(self, tokenizer, dataset_names=None, split_set="Train", inference=False, mem_drop_rate=0, neg_prob=0, history_aware=False, aug_config=None):
        if neg_prob >= 1:
            raise ValueError(f"neg_prob must be < 1, got {neg_prob}")
        # None instead of a shared mutable [] default.
        dataset_names = [] if dataset_names is None else dataset_names

        self.history_aware = history_aware
        self.mem_drop_rate = mem_drop_rate
        self.tokenizer = tokenizer
        self.datasets = {}
        self.inference = inference
        self.sample_size = 0
        self.aug_config = aug_config

        positive_samples = []
        negative_samples = []

        for dname in dataset_names:
            positive_path = os.path.join(dname, split_set, "Pos")
            if not os.path.exists(positive_path):
                continue
            samples = sorted(Path(positive_path).glob('[0-9]*'))
            print("Loading pos ", positive_path, len(samples))
            if split_set == "Val":
                samples = samples[:len(samples) // 2]
            positive_samples.extend(samples)

            if neg_prob > 0:
                # Size this dataset's negatives so they form ~neg_prob of its samples.
                target_neg_num = int(len(samples) / (1 - neg_prob) * neg_prob)
                negative_path = os.path.join(dname, split_set, "Neg")
                if os.path.exists(negative_path):
                    neg_samples = sorted(Path(negative_path).glob('[0-9]*'))
                    print("Loading neg ", negative_path, len(neg_samples))
                    negative_samples.extend(neg_samples[:target_neg_num])

        print(f"Positive {len(positive_samples)}, Negative {len(negative_samples)}")
        # Global index layout: [0, n_pos) -> positives, [n_pos, n_pos + n_neg) -> negatives.
        self.datasets["positive"] = {"name": "positive", "samples": positive_samples, "len": len(positive_samples), "start_index": 0, "end_index": len(positive_samples)}
        self.datasets["negative"] = {"name": "negative", "samples": negative_samples, "len": len(negative_samples), "start_index": len(positive_samples), "end_index": len(positive_samples) + len(negative_samples)}
        self.sample_size = len(positive_samples) + len(negative_samples)

    def __len__(self):
        return self.sample_size

    def shifted_index(self, i):
        """Map global index ``i`` to ``(group_key, local_index)``, or ``(-1, -1)`` if out of range."""
        for group in self.datasets.keys():
            if self.datasets[group]['start_index'] <= i < self.datasets[group]['end_index']:
                return group, i - self.datasets[group]['start_index']
        return -1, -1

    def __getitem__(self, i):
        """Return ``{"input_ids": 1-D tensor}`` for the chat prompt followed by the whisper and eos.

        Raises:
            IndexError: if ``i`` is outside ``[0, len(self))``.
        """
        selected_dataset, shifted_i = self.shifted_index(i)
        if selected_dataset == -1:
            raise IndexError(f"index {i} out of range for dataset of size {self.sample_size}")
        sample = Path(self.datasets[selected_dataset]['samples'][shifted_i])
        name = self.datasets[selected_dataset]['name']

        dialogue_file = 'dialogue_aware.txt' if self.history_aware else 'dialogue.txt'
        dialogue_text = (sample / dialogue_file).read_text(encoding='utf-8')
        whisper_text = (sample / 'whisper.txt').read_text(encoding='utf-8') if name != "negative" else ""
        # Memory dropout: with probability mem_drop_rate, train without the memory facts.
        memory_text = (sample / 'memory.txt').read_text(encoding='utf-8') if not (np.random.rand() < self.mem_drop_rate) else ""

        if self.aug_config:
            dialogue_text = augement_dialogue(dialogue_text, self.aug_config)

        messages = [
            {"role": "system", "content": "You are a proactive AI agent designed to actively help humans by reminding and assisting them in following dialogue, by whispering short, concise phrases (1-3 words) to its user."},
            {'role': 'user', 'content': f'You have the following memory of facts for the user:\n{memory_text}'},
            {"role": "user", "content": dialogue_text},
        ]
        conv_text = self.tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        eos_token = self.tokenizer.eos_token if self.tokenizer.eos_token else "<|endoftext|>"
        all_text = conv_text + whisper_text + eos_token

        label_ids = self.tokenizer.encode(all_text, return_tensors='pt')[0]
        return {"input_ids": label_ids}


In [None]:
%%writefile mydatasets/Pipeline_dataset_Qwen.py
# NEW: Dataset for Qwen Inference that calculates masks dynamically from raw.txt
import torch
from pathlib import Path
from transformers import PreTrainedTokenizer

START_INDEX = 0


class Syn_samples(object):
    """One synthetic test dialogue, prepared for streaming inference with Qwen tokenizers.

    Reads ``raw.txt`` and ``memory.txt`` from the ``sample_id``-th (sorted) numbered
    sub-directory of ``output_base``. It then re-derives per-token labels and masks
    with ``tokenizer``:

    * a token equal to the last id of ``" ^^"`` (whisper trigger) is dropped; the
      preceding kept token gets label 1 and mask 1;
    * a token equal to the last id of ``" >"`` (whisper opportunity) is kept with mask 1.

    If ``sample_id`` is past the last sample, ``valid`` is False and no other
    per-sample attributes are set.

    Args:
        tokenizer: classifier tokenizer used for the dialogue stream and inserted whispers.
        gen_tokenizer: generator tokenizer used by ``get_gen_inputs``.
        output_base: directory containing numbered sample sub-directories.
        sample_id: index into the sorted sample directories.
        split_set, input_dirs: accepted for interface compatibility; unused.
    """

    def __init__(self, tokenizer, gen_tokenizer, output_base, sample_id, split_set="Test", input_dirs=None):
        self.tokenizer = tokenizer
        self.gen_tokenizer = gen_tokenizer
        samples = sorted(Path(output_base).glob('[0-9]*'))
        self.sample_size = len(samples)
        self.valid = True
        if sample_id >= self.sample_size:
            self.valid = False
            return

        sample = Path(samples[sample_id])
        # Masks are rebuilt from raw.txt because token positions depend on the Qwen tokenizer.
        self.raw_text = (sample / 'raw.txt').read_text(encoding='utf-8').strip()
        self.memory_text = (sample / 'memory.txt').read_text(encoding='utf-8')

        # NOTE(review): matching only the LAST id of each marker assumes the tokenizer
        # splits " ^^" / " >" the same way inside running text as in isolation — verify.
        whisper_id = self._marker_id(" ^^")
        symbol_id = self._marker_id(" >")

        flat_text = " ".join(self.raw_text.split('\n'))
        raw_tokens = self.tokenizer.encode(flat_text, return_tensors='pt')[0]

        labels, masks, clean_tokens = [], [], []
        for token in raw_tokens.tolist():
            if token == whisper_id:
                # Drop the trigger marker and flag the token it follows.
                if labels:
                    labels[-1] = 1
                if masks:
                    masks[-1] = 1
                continue
            clean_tokens.append(token)
            labels.append(0)
            masks.append(1 if token == symbol_id else 0)

        self.tokenized_dialogue = torch.tensor(clean_tokens, dtype=torch.long)
        self.labels = torch.tensor(labels, dtype=torch.int)
        self.mask = torch.tensor(masks, dtype=torch.int)

        # The *_history copies receive inserted whispers; the originals stay untouched.
        self.tokenized_dialogue_history = self.tokenized_dialogue.clone()
        self.mask_history = self.mask.clone()
        self.stream_id = START_INDEX
        self.stream_id_history = START_INDEX

    def _marker_id(self, marker):
        """Return the last token id of ``marker`` (no special tokens), or -999 if it encodes to nothing."""
        ids = self.tokenizer.encode(marker, add_special_tokens=False)
        return ids[-1] if len(ids) > 0 else -999

    def get_mem(self):
        """Return the raw memory text of this sample."""
        return self.memory_text

    def count_turn(self):
        """Approximate turn count: occurrences of "User:" plus "Speaker" in raw.txt."""
        return self.raw_text.count("User:") + self.raw_text.count("Speaker")

    def reset_streaming(self):
        """Rewind both stream cursors to START_INDEX (already-inserted whispers are kept)."""
        self.stream_id = START_INDEX
        self.stream_id_history = START_INDEX

    def insert_whisper(self, whisper_text):
        """Insert the tokenized ``whisper_text`` at the history cursor, with mask 0."""
        whisper_token = self.tokenizer.encode(whisper_text, return_tensors='pt', add_special_tokens=False)[0]
        # Use mask_history's dtype so the concatenated mask keeps a single integer dtype.
        whisper_mask = torch.zeros_like(whisper_token, dtype=self.mask_history.dtype)
        cut = self.stream_id_history
        self.tokenized_dialogue_history = torch.cat([self.tokenized_dialogue_history[:cut], whisper_token, self.tokenized_dialogue_history[cut:]])
        self.mask_history = torch.cat([self.mask_history[:cut], whisper_mask, self.mask_history[cut:]])

    def get_gen_inputs(self, curr_diag, old=False):
        """Return generator prompt token ids (1-D) for dialogue text ``curr_diag``.

        Uses the same system / memory / dialogue chat messages as ``Gen_dataset_Qwen``,
        with ``add_generation_prompt=True``. ``old`` is accepted and ignored.
        """
        messages = [
            {"role": "system", "content": "You are a proactive AI agent designed to actively help humans by reminding and assisting them in following dialogue, by whispering short, concise phrases (1-3 words) to its user."},
            {'role': 'user', 'content': f'You have the following memory of facts for the user:\n{self.memory_text}'},
            {"role": "user", "content": curr_diag},
        ]
        conv_text = self.gen_tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
        return self.gen_tokenizer.encode(conv_text, return_tensors='pt')[0]

    def snap_dialogue(self):
        """Return ``(original_tokens, history_tokens_with_inserted_whispers)``."""
        return self.tokenized_dialogue, self.tokenized_dialogue_history

    def streaming_diaglogue(self):
        """Advance the history cursor past the next mask==1 token and return the prefix.

        Returns:
            ``None`` when no mask==1 token remains after the cursor. Otherwise
            ``(chunk, chunk, mask, mask, labels)``: ``chunk`` and ``mask`` are the
            history tokens and mask up to and including that position. The
            duplicated entries and the full ``labels`` tensor are placeholders kept
            for the 5-tuple interface; ``labels`` is not aligned with ``chunk``.
        """
        remaining = self.mask_history[self.stream_id_history:]
        if remaining.sum() == 0:
            return None

        hits = (remaining == 1).nonzero(as_tuple=True)[0]
        end = self.stream_id_history + hits[0].item() + 1
        current_chunk = self.tokenized_dialogue_history[:end]
        current_mask = self.mask_history[:end]
        self.stream_id_history = end
        return current_chunk, current_chunk, current_mask, current_mask, self.labels[:len(self.labels)]


In [None]:
%%writefile train_small_qwen.py
from datasets import load_dataset
import os
import numpy as np
from peft import LoraConfig, get_peft_model, TaskType
import argparse
import torch
from transformers import AutoTokenizer, DataCollatorForTokenClassification, TrainingArguments, Trainer, EvalPrediction
from model.CasualTokenClassificationQwen import Qwen2ForCausalLM_TokenClassification
from sklearn.metrics import accuracy_score
from mydatasets.Active_dataset import New_WhisperAware_dataset

os.environ["WANDB_DISABLED"] = "true"

model_id = "Qwen/Qwen2.5-1.5B-Instruct"

# Use only the dataset sources that were actually extracted.
positive_path = ["Llamapie_dataset/synthetic/Train/claude", "Llamapie_dataset/perl/Train/claude", "Llamapie_dataset/soda/Train/claude"]
positive_dev = ["Llamapie_dataset/synthetic/Val/claude", "Llamapie_dataset/perl/Val/claude", "Llamapie_dataset/soda/Val/claude"]
positive_path = [p for p in positive_path if os.path.exists(p)]
positive_dev = [p for p in positive_dev if os.path.exists(p)]
# Fail fast, before downloading/loading the model, rather than training on an empty set.
if not positive_path:
    raise FileNotFoundError(
        "No classifier training data found (looked for Llamapie_dataset/{synthetic,perl,soda}/Train/claude). "
        "Check that Llamapie_dataset.tar.gz was downloaded and extracted."
    )

tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="./model_cache", trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Qwen2 backbone + 2-way per-token classifier head (label 1 triggers the generator at inference).
model = Qwen2ForCausalLM_TokenClassification.from_pretrained(model_id, device_map='auto', torch_dtype=torch.bfloat16, num_labels=2, cache_dir="./model_cache", trust_remote_code=True)

# NOTE(review): New_WhisperAware_dataset comes from the original LlamaPIE code; confirm its
# marker/label handling is correct for the Qwen tokenizer.
dataset = New_WhisperAware_dataset(tokenizer, input_dirs=positive_path, split_set="Train", aug_config=None)
dataset_val = New_WhisperAware_dataset(tokenizer, input_dirs=positive_dev, split_set="Val", aug_config=None)

# LoRA on all attention/MLP projections; the new classifier head is trained in full and saved with the adapter.
config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], bias="none", task_type=TaskType.CAUSAL_LM, modules_to_save=["classifier"])
model = get_peft_model(model, config)

training_args = TrainingArguments(output_dir='models/qwen-small', eval_strategy="steps", per_device_train_batch_size=8, per_device_eval_batch_size=8, learning_rate=2e-5, num_train_epochs=1, save_steps=200, logging_steps=50)

def compute_metrics(pred):
    """Token-level accuracy of the classifier over positions that carry a label.

    Args:
        pred: ``EvalPrediction``-like object. ``predictions`` holds logits with the
            class axis last (or a tuple whose first element is those logits), and
            ``label_ids`` has the matching leading shape, with ``-100`` marking
            ignored positions.

    Returns:
        ``{'accuracy': float}``; ``nan`` when no position carries a label.
    """
    logits = pred.predictions[0] if isinstance(pred.predictions, tuple) else pred.predictions
    labels = np.asarray(pred.label_ids).reshape(-1)
    preds = np.argmax(logits, axis=-1).reshape(-1)
    keep = labels != -100
    if not keep.any():
        return {'accuracy': float('nan')}
    return {'accuracy': float(np.mean(preds[keep] == labels[keep]))}

# Pads input ids (and labels) of each batch to a common length.
token_collator = DataCollatorForTokenClassification(tokenizer=dataset.tokenizer)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=dataset_val,
    data_collator=token_collator,
    compute_metrics=compute_metrics,
)
trainer.train()
# Persist the LoRA adapter together with the trainable classifier copy (modules_to_save).
model.save_pretrained('models/qwen-small')


In [None]:
%%writefile train_large_qwen.py
from mydatasets.Gen_dataset_Qwen import Gen_dataset_Qwen
import os
import numpy as np
from peft import LoraConfig, get_peft_model, TaskType
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer
from mydatasets.collator import DataCollatorForCompletionOnlyLM
from sklearn.metrics import accuracy_score

os.environ["WANDB_DISABLED"] = "true"
model_id = "Qwen/Qwen2.5-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir="./model_cache", trust_remote_code=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Use only the dataset sources that were actually extracted.
dataset_names = ["Llamapie_dataset/perl/", "Llamapie_dataset/soda/", "Llamapie_dataset/synthetic/"]
dataset_names = [d for d in dataset_names if os.path.exists(d)]

# The datasets need only the tokenizer: build and validate them before loading the
# 7B model, so missing data fails fast.
dataset = Gen_dataset_Qwen(tokenizer, dataset_names=dataset_names, split_set="Train", mem_drop_rate=0.15, neg_prob=0.25, history_aware=True)
dataset_val = Gen_dataset_Qwen(tokenizer, dataset_names=dataset_names, split_set="Val", mem_drop_rate=0, neg_prob=0.25, history_aware=True)
if len(dataset) == 0:
    raise FileNotFoundError(
        "No generator training samples found (expected Llamapie_dataset/{perl,soda,synthetic}/Train/Pos). "
        "Check that Llamapie_dataset.tar.gz was downloaded and extracted."
    )

model = AutoModelForCausalLM.from_pretrained(model_id, device_map='auto', torch_dtype=torch.bfloat16, cache_dir="./model_cache", trust_remote_code=True)

config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], bias="none", lora_dropout=0.05, task_type=TaskType.CAUSAL_LM)
model.gradient_checkpointing_enable()
model = get_peft_model(model, config)

# per_device_eval_batch_size otherwise defaults to 8 (4x the training batch); keep eval
# memory in line with training for the 7B model.
training_args = TrainingArguments(output_dir='models/qwen-large', eval_strategy="steps", per_device_train_batch_size=2, per_device_eval_batch_size=2, gradient_accumulation_steps=4, learning_rate=2e-5, num_train_epochs=1, save_steps=200, logging_steps=50, bf16=True)

# Completion-only collator: the loss covers only tokens after the assistant header
# (i.e. the whisper and its end token), not the prompt.
completion_collator = DataCollatorForCompletionOnlyLM(
    instruction_template=None,
    response_template="<|im_start|>assistant\n",
    tokenizer=tokenizer,
    mlm=False,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=dataset_val,
    data_collator=completion_collator,
)
trainer.train()
model.save_pretrained('models/qwen-large')


In [None]:
%%writefile infer_qwen.py
import os
import torch
import json
import argparse
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from model.CasualTokenClassificationQwen import Qwen2ForCausalLM_TokenClassification

arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('--dataset', type=str, default='Sync_claude')
arg_parser.add_argument('--save-path', type=str, required=True)
args = arg_parser.parse_args()

dataset_name = args.dataset
output_samples = f"{args.save_path}/{dataset_name}"
os.makedirs(output_samples, exist_ok=True)

# Only the synthetic Claude test split has a Qwen-aware sample loader so far.
if dataset_name != "Sync_claude":
    raise ValueError("For demo, currently only Sync_claude is supported via Qwen pipeline")
from mydatasets.Pipeline_dataset_Qwen import Syn_samples as SingleSample
output_base = "Llamapie_dataset/synthetic/Test/claude"

# Base checkpoints; fine-tuned LoRA adapters are applied on top when present.
base_small = "Qwen/Qwen2.5-1.5B-Instruct"
base_large = "Qwen/Qwen2.5-7B-Instruct"

print("Loading Models...")

# Classifier: its prediction on the latest token decides whether to call the generator.
tokenizer_small = AutoTokenizer.from_pretrained(base_small, trust_remote_code=True)
if tokenizer_small.pad_token is None:
    tokenizer_small.pad_token = tokenizer_small.eos_token
model_small = Qwen2ForCausalLM_TokenClassification.from_pretrained(
    base_small,
    device_map='cuda',
    torch_dtype=torch.bfloat16,
    num_labels=2,
    trust_remote_code=True,
)
small_adapter_dir = "models/qwen-small"
if not os.path.exists(small_adapter_dir):
    print("Using base small model")
else:
    # Fold the fine-tuned adapter into the base weights.
    model_small = PeftModel.from_pretrained(model_small, small_adapter_dir).merge_and_unload()

# Generator: writes the whisper text when the classifier fires.
tokenizer_big = AutoTokenizer.from_pretrained(base_large, trust_remote_code=True)
model_big = AutoModelForCausalLM.from_pretrained(
    base_large,
    device_map='cuda',
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
large_adapter_dir = "models/qwen-large"
if not os.path.exists(large_adapter_dir):
    print("Using base large model")
else:
    model_big = PeftModel.from_pretrained(model_big, large_adapter_dir)

# Stop generation at either the tokenizer's EOS or Qwen's end-of-turn marker.
terminators = [tokenizer_big.eos_token_id, tokenizer_big.convert_tokens_to_ids("<|im_end|>")]
response_template = "<|im_start|>assistant\n"

resp_lengths = []

# Pure inference: disable autograd so the classifier forward passes don't build graphs.
with torch.inference_mode():
    for i in range(20):  # Only 20 samples for quick checking
        print(f"Processing {i}", end="\r")
        try:
            sample = SingleSample(tokenizer_small, tokenizer_big, output_base, sample_id=i)
        except (OSError, UnicodeDecodeError) as err:
            # Missing/unreadable sample files: report and move on instead of hiding the error.
            print(f"\nSkipping sample {i}: {err}")
            continue
        if not sample.valid:
            continue

        while True:
            info = sample.streaming_diaglogue()
            if info is None:
                break
            token, _, mask, _, _ = info
            if mask[-1] == 0:
                continue

            # Classify the last token of the dialogue prefix streamed so far.
            input_ids = token.unsqueeze(0).to(model_small.device)
            pred = torch.argmax(model_small(input_ids=input_ids).logits[0], dim=-1)[-1]
            if pred != 1:
                continue

            curr_diag = tokenizer_small.decode(token, skip_special_tokens=True)
            input_ids2 = sample.get_gen_inputs(curr_diag).unsqueeze(0).to(model_big.device)
            outputs = model_big.generate(
                input_ids2,
                attention_mask=torch.ones_like(input_ids2),
                max_new_tokens=128,
                eos_token_id=terminators,
                pad_token_id=tokenizer_big.pad_token_id,
                do_sample=True,
                temperature=0.6,
            )
            # Decode only the newly generated tokens; the prompt is a prefix of outputs[0].
            resp = tokenizer_big.decode(outputs[0, input_ids2.shape[1]:], skip_special_tokens=True).strip()
            if resp:
                sample.insert_whisper(" Agent: " + resp)
                resp_lengths.append(len(resp.split()))

        data = {"mem": sample.memory_text, "diag": tokenizer_small.decode(sample.tokenized_dialogue_history, skip_special_tokens=True)}
        with open(f"{output_samples}/{i:05d}.json", 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=4)

print(f"Finished. Avg Length: {np.mean(resp_lengths) if resp_lengths else 0}")


In [None]:
# 3. Run Training (Classifier)
!python train_small_qwen.py

In [None]:
# 4. Run Training (Generator)
!python train_large_qwen.py

In [None]:
# 5. Run Inference
!python infer_qwen.py --save-path results_qwen