In [None]:
# %pip install -U -q transformers==4.51.3 datasets==3.5.0 bitsandbytes==0.45.5 triton==3.2.0 unsloth==2025.3.19 torch==2.6.0 peft==0.15.2 trl==0.15.2 wandb==0.19.10

In [None]:
import os
from getpass import getpass

import wandb

# Never hard-code the W&B API key in the notebook: take it from the environment
# and only prompt (hidden input) when it is not already set.
if not os.environ.get("WANDB_API_KEY"):
    os.environ["WANDB_API_KEY"] = getpass("W&B API key: ")
os.environ["WANDB_PROJECT"] = "Coursework"
# Ask the W&B integration to upload Trainer checkpoints as artifacts.
os.environ["WANDB_LOG_MODEL"] = "checkpoint"

In [None]:
import os

from huggingface_hub import login

# Authenticate to the Hugging Face Hub (push_to_hub is enabled further down).
# When HF_TOKEN is set the login is non-interactive, so "Run All" does not stop
# on a prompt; with HF_TOKEN unset, login(token=None) prompts exactly like login().
login(token=os.environ.get("HF_TOKEN"))

In [None]:
import os

# Restrict this process to GPU 0. CUDA reads this variable when it initialises,
# so it must be set before torch makes its first CUDA call in this kernel.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
import torch

# The 4-bit bitsandbytes models below are placed on this CUDA device; fail fast
# with a readable message on CPU-only runtimes instead of an opaque CUDA error.
if not torch.cuda.is_available():
    raise RuntimeError("A CUDA GPU is required: the models are loaded in 4-bit with bitsandbytes.")
device = torch.cuda.current_device()
device

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoModelForImageTextToText
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# 4-bit NF4 quantisation with double quantisation; compute in fp16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model_name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# collate_fn pads batches and bnf_loss masks positions by pad_token_id, so a pad
# token is mandatory. Fall back to EOS for checkpoints that do not define one.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def _load_quantized_base():
    """Load a fresh 4-bit copy of `model_name` onto `device`."""
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map={"": device},
        torch_dtype=torch.float16,
    )


# Trainable policy: quantised base prepared for k-bit training, plus LoRA adapters
# on every attention and MLP projection.
policy_model = prepare_model_for_kbit_training(_load_quantized_base())

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0,
    bias="none",
    task_type="CAUSAL_LM",
)

policy_model = get_peft_model(policy_model, lora_config)
policy_model.print_trainable_parameters()

# Frozen reference: an identically quantised base model without adapters.
# (With PEFT's default LoRA initialisation the policy presumably starts out
# matching it — verify if the init settings change.)
ref_model = _load_quantized_base()
ref_model.eval()
ref_model.requires_grad_(False);

In [None]:
import random
def process_and_mask(examples, pairing_ratio=0.5, seed=42):
    """Flatten a batch of preference pairs into (instruction, response, label) rows.

    For each pair, with probability ``pairing_ratio`` only ONE response is kept
    (chosen or rejected, with equal chance); otherwise BOTH are kept. Despite its
    name, ``pairing_ratio`` is therefore the fraction of pairs that end up
    *unpaired*: 0.0 keeps every pair intact, 1.0 keeps one response per pair.

    Randomness comes from a private ``random.Random`` seeded with ``seed`` plus the
    example's own text, so:
      * the global ``random`` state is never touched;
      * every pair's decision is deterministic and independent of where it falls
        in a ``datasets.map`` batch (reseeding one shared RNG per call would
        replay the same decision pattern in every batch).

    Args:
        examples: batched mapping with ``"chosen"`` and ``"rejected"`` columns;
            each entry is a conversation whose element 0 is the prompt turn and
            element 1 the response turn (dicts with a ``"content"`` key).
        pairing_ratio: probability of keeping a single response from a pair.
        seed: base seed for reproducibility.

    Returns:
        dict of equal-length lists ``"instruction"``, ``"response"`` and
        ``"label"`` (+1.0 for the chosen response, -1.0 for the rejected one).
    """
    instrs, resps, labels = [], [], []
    for chosen_pair, rejected_pair in zip(examples["chosen"], examples["rejected"]):
        instruction = chosen_pair[0]["content"]
        chosen_resp = chosen_pair[1]["content"]
        rejected_resp = rejected_pair[1]["content"]

        # String seeds are converted deterministically (not via hash()), so this is
        # reproducible across runs regardless of PYTHONHASHSEED.
        rng = random.Random(f"{seed}\x00{instruction}\x00{chosen_resp}\x00{rejected_resp}")

        if rng.random() < pairing_ratio:
            # Keep a single response from this pair.
            if rng.random() < 0.5:
                instrs.append(instruction)
                resps.append(chosen_resp)
                labels.append(1.0)
            else:
                instrs.append(instruction)
                resps.append(rejected_resp)
                labels.append(-1.0)
        else:
            # Keep both responses.
            instrs.extend([instruction, instruction])
            resps.extend([chosen_resp, rejected_resp])
            labels.extend([1.0, -1.0])

    return {"instruction": instrs, "response": resps, "label": labels}

In [None]:
 # Training set: each preference pair is flattened by process_and_mask (with the
# default ratio, some pairs keep one response and the rest keep both).
# Code starts at column 0 so the cell also runs as a plain Python script.
train_raw = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
train_ds = train_raw.map(
    process_and_mask,
    batched=True,
    remove_columns=["chosen", "rejected", "score_chosen", "score_rejected"],
)

In [None]:
def expand_pair_batch(batch):
    """Turn every preference pair into two rows: chosen (+1.0) then rejected (-1.0).

    Args:
        batch: batched mapping with "chosen"/"rejected" conversations; element 0
            of a conversation is the prompt turn, element 1 the response turn.

    Returns:
        dict of equal-length lists "instruction", "response" and "label".
    """
    rows = []
    for chosen, rejected in zip(batch["chosen"], batch["rejected"]):
        prompt = chosen[0]["content"]
        rows.append((prompt, chosen[1]["content"], 1.0))
        rows.append((prompt, rejected[1]["content"], -1.0))

    return {
        "instruction": [row[0] for row in rows],
        "response": [row[1] for row in rows],
        "label": [row[2] for row in rows],
    }

# Small held-out set: the first 32 test pairs, each expanded by expand_pair_batch
# into a chosen (+1.0) row and a rejected (-1.0) row.
eval_raw = load_dataset("trl-lib/ultrafeedback_binarized", split="test[:32]")
pair_columns = ["chosen", "rejected", "score_chosen", "score_rejected"]
eval_ds = eval_raw.map(expand_pair_batch, batched=True, remove_columns=pair_columns)

In [None]:
def bnf_loss(policy_logits, ref_logits, input_ids, pref_labels, pad_token_id, loss_mask=None):
    """Bidirectional Negative Feedback (BNF) loss for a batch of labelled sequences.

    Causal-LM logits at position ``t`` predict token ``t + 1``, so logits are
    shifted left and scored against ``input_ids[:, 1:]``. At every scored position
    the policy is trained with cross-entropy against a fixed (stop-gradient)
    target distribution ``f``:

    * the observed next token ``y`` gets ``f(y) = min(pi(y) / pi_ref(y), 1)``;
    * the remaining mass ``1 - f(y)`` is spread over the other tokens in
      proportion to the policy's own (detached) probabilities.

    Per-sequence losses are averaged over scored tokens and multiplied by the
    preference label, so minimising the batch mean lowers the loss of +1 rows
    and raises the loss of -1 rows.

    Args:
        policy_logits: (B, T, V) logits of the trainable model.
        ref_logits: (B, T, V) logits of the frozen reference model.
        input_ids: (B, T) token ids both logits were computed on.
        pref_labels: (B,) labels, +1.0 preferred / -1.0 dispreferred.
        pad_token_id: positions whose *target* token equals this id are not scored.
        loss_mask: optional (B, T) mask aligned with ``input_ids``; only target
            tokens where it is nonzero are scored (e.g. response tokens only).

    Returns:
        ``(loss, pref_loss, dispref_loss)``: the scalar training loss and the mean
        per-sequence loss over preferred / dispreferred rows (no gradient; 0.0 when
        that group is absent from the batch).
    """
    # Next-token alignment: logits[:, t] score input_ids[:, t + 1].
    targets = input_ids[:, 1:]                                          # (B, T-1)
    idx = targets.unsqueeze(-1)                                         # (B, T-1, 1)

    # fp32: fp16 log-probs over a large vocabulary lose the precision needed for
    # the 1 - p(y) denominator below.
    policy_logps = policy_logits[:, :-1].float().log_softmax(-1)        # (B, T-1, V)
    resp_logps = policy_logps.gather(-1, idx).squeeze(-1)               # (B, T-1)

    # The target distribution is a constant: build it entirely without gradient.
    with torch.no_grad():
        ref = ref_logits[:, :-1].float()
        # log pi_ref(y) via gather - logsumexp (no second full log-softmax tensor).
        resp_logps_ref = ref.gather(-1, idx).squeeze(-1) - ref.logsumexp(-1)
        # min(pi/pi_ref, 1) in log-space: an underflowed reference probability can
        # no longer produce 0/0 = NaN.
        resp_target = (resp_logps.detach() - resp_logps_ref).clamp(max=0.0).exp()

        policy_ps = policy_logps.detach().exp()                         # (B, T-1, V)
        resp_ps = resp_logps.detach().exp()                             # (B, T-1)
        # Rescale pi over the non-target tokens so they carry mass 1 - f(y).
        scale = (1.0 - resp_target) / (1.0 - resp_ps).clamp(min=1e-8)
        f_target = policy_ps * scale.unsqueeze(-1)
        f_target.scatter_(-1, idx, resp_target.unsqueeze(-1))

    ce = -(f_target * policy_logps).sum(-1)                             # (B, T-1)

    mask = targets != pad_token_id
    if loss_mask is not None:
        mask = mask & loss_mask[:, 1:].bool()
    mask = mask.to(ce.dtype)
    seq_loss = (ce * mask).sum(-1) / mask.sum(-1).clamp(min=1)          # (B,)

    pref_labels = pref_labels.to(seq_loss.dtype)
    loss = (pref_labels * seq_loss).mean()

    with torch.no_grad():
        zero = seq_loss.new_zeros(())
        pref_mask = pref_labels == 1
        dispref_mask = pref_labels == -1
        pref_loss = seq_loss[pref_mask].mean() if pref_mask.any() else zero
        dispref_loss = seq_loss[dispref_mask].mean() if dispref_mask.any() else zero

    return loss, pref_loss, dispref_loss

In [None]:
def collate_fn(batch):
    """Tokenise a list of rows into one padded batch.

    Each row (a dict with "instruction", "response" and "label") is rendered as
    "Instruction: <instruction>\\nResponse: <response>" and tokenised with the
    module-level ``tokenizer``: padded to the longest row, truncated at 1024 tokens.

    Returns:
        dict with "input_ids", "attention_mask", a copy of the ids under "labels",
        and a float tensor "preference_labels" holding each row's label.
    """
    texts, pref = [], []
    for row in batch:
        texts.append(f"Instruction: {row['instruction']}\nResponse: {row['response']}")
        pref.append(row["label"])

    enc = tokenizer(
        texts,
        return_tensors="pt",
        truncation=True,
        max_length=1024,
        padding="longest",
    )
    input_ids = enc.input_ids
    return {
        "input_ids": input_ids,
        "attention_mask": enc.attention_mask,
        "labels": input_ids.clone(),
        "preference_labels": torch.tensor(pref).float(),
    }

In [None]:
from transformers import Trainer, TrainingArguments
import numpy as np

class BNFTrainer(Trainer):
    """``Trainer`` that optimises ``bnf_loss`` against a frozen reference model.

    Batches are expected to look like ``collate_fn``'s output: ``input_ids``,
    ``attention_mask`` and ``preference_labels`` (+1.0 preferred, -1.0 dispreferred).

    Extra constructor arguments (keyword-only, beyond ``Trainer``'s):
        ref_model: frozen reference model; only its ``.logits`` output is used.
        pad_token_id: token id treated as padding when scoring next-token targets.
    """

    def __init__(self, *args, ref_model=None, pad_token_id=None, **kwargs):
        if ref_model is None:
            raise ValueError("BNFTrainer requires a ref_model.")
        super().__init__(*args, **kwargs)
        self.ref_model = ref_model
        self.pad_token_id = pad_token_id
        # compute_loss already returns a per-micro-batch mean. Trainer divides the
        # loss by gradient_accumulation_steps only when model_accepts_loss_kwargs is
        # False; in recent transformers versions it sets the flag to True for models
        # whose forward takes **kwargs (assuming the model normalises the loss
        # itself), which this custom loss does not. Force the division so the loss
        # and gradients are not scaled up by the accumulation factor.
        self.model_accepts_loss_kwargs = False

    @staticmethod
    def _safe_div(num, den):
        """Return num / den, or NaN when den is zero (e.g. a label group is absent)."""
        return num / den if den else float("nan")

    @staticmethod
    def _target_logps(logits, targets):
        """Next-token log-probs: log_softmax(logits[:, t])[targets[:, t]].

        ``targets`` is ``input_ids[:, 1:]``. Computed as gather - logsumexp to avoid
        materialising a full (B, T, V) log-softmax tensor.
        """
        logits = logits[:, :-1].float()
        return logits.gather(-1, targets.unsqueeze(-1)).squeeze(-1) - logits.logsumexp(-1)

    @staticmethod
    def _group_loss_logs(pref_loss, dispref_loss, labels):
        """Log entries for the label groups actually present in the batch.

        bnf_loss returns a 0.0 placeholder for an absent group; logging it would
        record a fake value and make ``loss_ratio`` explode via the 1e-8 guard.
        """
        has_pref = bool((labels == 1).any())
        has_dispref = bool((labels == -1).any())
        logs = {}
        if has_pref:
            logs["pref_loss"] = float(pref_loss)
        if has_dispref:
            logs["dispref_loss"] = float(dispref_loss)
        if has_pref and has_dispref:
            logs["loss_ratio"] = float(dispref_loss) / (abs(float(pref_loss)) + 1e-8)
        return logs

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the BNF loss for one micro-batch.

        Returns the scalar loss, or ``(loss, policy_logits)`` when ``return_outputs``.
        Extra keyword arguments passed by ``Trainer`` are ignored.
        """
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        labels = inputs["preference_labels"]

        with torch.no_grad():
            ref_logits = self.ref_model(input_ids=input_ids, attention_mask=attention_mask).logits
        policy_logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

        loss, pref_loss, dispref_loss = bnf_loss(
            policy_logits, ref_logits, input_ids, labels, self.pad_token_id
        )

        # NOTE: with gradient accumulation this fires on every micro-batch of a
        # logged optimizer step.
        if self.state.global_step % self.args.logging_steps == 0:
            logs = self._group_loss_logs(pref_loss, dispref_loss, labels)
            if logs:
                self.log(logs)

        return (loss, policy_logits) if return_outputs else loss

    def evaluate(self, eval_dataset=None, ignore_keys=None, metric_key_prefix: str = "eval"):
        """Compute BNF loss and log-likelihood-shift metrics over the eval set.

        Metrics (each prefixed with ``metric_key_prefix + "_"``):
          * ``loss``: row-weighted mean of the batch training objective;
          * ``pref_loss`` / ``dispref_loss``: mean per-sequence BNF loss over
            preferred / dispreferred rows, and ``loss_ratio`` = dispref / pref;
          * ``ll_shift`` and ``ll_shift_pref`` / ``ll_shift_dispref``: mean per-token
            ``log pi(y) - log pi_ref(y)`` over non-padding next-token targets.
        A metric whose group never appears in the eval set is NaN.
        """
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset
        dataloader = self.get_eval_dataloader(eval_dataset)

        loss_sum, n_rows = 0.0, 0
        # Keyed by preference label: 1 = preferred, -1 = dispreferred.
        group_loss_sum = {1: 0.0, -1: 0.0}
        group_rows = {1: 0, -1: 0}
        shift_sum = {"all": 0.0, 1: 0.0, -1: 0.0}
        shift_tokens = {"all": 0.0, 1: 0.0, -1: 0.0}

        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for batch in dataloader:
                    batch = self._prepare_inputs(batch)
                    input_ids = batch["input_ids"]
                    attention_mask = batch["attention_mask"]
                    labels = batch["preference_labels"]

                    ref_logits = self.ref_model(input_ids=input_ids, attention_mask=attention_mask).logits
                    policy_logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits

                    loss, pref_loss, dispref_loss = bnf_loss(
                        policy_logits, ref_logits, input_ids, labels, self.pad_token_id
                    )

                    bs = input_ids.size(0)
                    loss_sum += loss.item() * bs
                    n_rows += bs

                    # Weight each group mean by that group's row count in the batch and
                    # skip absent groups (bnf_loss returns a placeholder for them).
                    for label, group_loss in ((1, pref_loss), (-1, dispref_loss)):
                        n = int((labels == label).sum().item())
                        if n > 0:
                            group_loss_sum[label] += group_loss.item() * n
                            group_rows[label] += n

                    # Per-token log-likelihood shift on next-token targets.
                    targets = input_ids[:, 1:]
                    shift = self._target_logps(policy_logits, targets) - self._target_logps(ref_logits, targets)
                    tok_mask = (targets != self.pad_token_id).float()
                    shift_sum["all"] += (shift * tok_mask).sum().item()
                    shift_tokens["all"] += tok_mask.sum().item()
                    for label in (1, -1):
                        group_tok_mask = tok_mask * (labels == label).float().unsqueeze(-1)
                        shift_sum[label] += (shift * group_tok_mask).sum().item()
                        shift_tokens[label] += group_tok_mask.sum().item()
        finally:
            if was_training:
                self.model.train()

        pref_mean = self._safe_div(group_loss_sum[1], group_rows[1])
        dispref_mean = self._safe_div(group_loss_sum[-1], group_rows[-1])
        metrics = {
            f"{metric_key_prefix}_loss": self._safe_div(loss_sum, n_rows),
            f"{metric_key_prefix}_pref_loss": pref_mean,
            f"{metric_key_prefix}_dispref_loss": dispref_mean,
            f"{metric_key_prefix}_ll_shift": self._safe_div(shift_sum["all"], shift_tokens["all"]),
            f"{metric_key_prefix}_ll_shift_pref": self._safe_div(shift_sum[1], shift_tokens[1]),
            f"{metric_key_prefix}_ll_shift_dispref": self._safe_div(shift_sum[-1], shift_tokens[-1]),
            f"{metric_key_prefix}_loss_ratio": self._safe_div(dispref_mean, pref_mean),
        }
        metrics = {k: float(v) for k, v in metrics.items()}

        self.log(metrics)
        # Mirror Trainer.evaluate so callbacks observe the evaluation.
        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
        return metrics

In [None]:
# NOTE(review): 5e-7 is a full-fine-tuning-scale learning rate; LoRA adapters
# usually need a considerably larger one. Confirm this is intended.
args = TrainingArguments(
    # Output and checkpointing
    output_dir='Qwen2.5-0.5B-Instruct-BNF',
    save_strategy='steps',
    save_steps=1000,
    # Batching: 2 rows per device x 4 accumulation steps per optimizer update
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=1,
    # Memory and precision
    fp16=True,
    gradient_checkpointing=True,
    optim='adamw_8bit',
    # Schedule
    num_train_epochs=3,
    learning_rate=5e-7,
    lr_scheduler_type='cosine',
    warmup_ratio=0.1,
    # Evaluation and logging
    eval_strategy='steps',
    eval_steps=40,
    logging_steps=40,
    report_to='wandb',
    # Hub upload
    push_to_hub=True,
    hub_model_id='theevolutionisnear/Qwen2.5-0.5B-Instruct-BNF',
    hub_strategy='checkpoint',
    hub_token=True,
    # collate_fn reads the raw "instruction"/"response"/"label" columns, so the
    # Trainer must not drop columns the model's forward does not accept.
    remove_unused_columns=False,
)

In [None]:
# run = wandb.init(project="Coursework",
#                  id="crmnp6jy",
#                  resume="must")
# artifact = run.use_artifact('animavestra888-independent/Coursework/model-crmnp6jy:v15', type='model')
# artifact_dir = artifact.download()

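# WARNING: the patch below forces torch.load(weights_only=False), which unpickles
# arbitrary Python objects (arbitrary code execution) — only use it with
# checkpoints you trust.
# _torch_load = torch.load

# def _load_with_full_pickle(*args, **kwargs):
#     kwargs["weights_only"] = False

#     return _torch_load(*args, **kwargs)

# torch.load = _load_with_full_pickle 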

In [None]:
# Wire the policy, frozen reference, datasets and collator into the BNF trainer,
# then train. Padding positions are identified by the tokenizer's pad id.
trainer = BNFTrainer(
    model=policy_model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=collate_fn,
    ref_model=ref_model,
    pad_token_id=tokenizer.pad_token_id,
)
trainer.train()
#trainer.train(resume_from_checkpoint=artifact_dir)