In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from datasets import load_dataset
from transformers import LogitsProcessorList, RepetitionPenaltyLogitsProcessor

from peft import get_peft_model, LoraConfig, PeftModel
from torch import nn
import torch.nn.functional as F

import os
import re

In [None]:
# Allocator options only take effect if set before the CUDA caching allocator is
# first used, so set this before touching the device at all.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
device = torch.accelerator.current_accelerator().type if torch.cuda.is_available() else "cpu"
print(device)

# Single source of truth for the base checkpoint (tokenizer and weights must match).
BASE_MODEL_DIR = "../models/deepseek/DeepSeek-R1-Distill-Qwen-1.5B/"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_DIR)
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_DIR)

In [None]:
# Attach the saved "cr_v3" adapter on top of the base model.
# NOTE(review): `is_trainable` is left at its default here — confirm whether the
# adapter is meant to be further trained or only evaluated.
model = PeftModel.from_pretrained(model=model, model_id="../models/own/cr_v3")

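## Conditioning section

Fine-tune the LoRA adapter with a KL-divergence loss between the averaged next-token distributions at chosen reasoning positions and at chosen output positions.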

In [None]:
ds = load_dataset(path="Elfsong/BBQ")  # BBQ bias benchmark; the notebook uses its 'religion' split

### Utilities:

In [None]:
def load_data(i=0, verbose=True):
    """Build and tokenize the prompt for row `i` of the BBQ ``religion`` split.

    The prompt lists the context, the question and the three answer choices,
    and ends with an opening ``<think>`` tag.

    Args:
        i: row index into ``ds['religion']``.
        verbose: print the prompt text before tokenizing (default True, as in
            this cell's original behaviour).

    Returns:
        The tokenizer output for the prompt, with ``return_tensors="pt"``.
    """
    row = ds['religion'][i]  # fetch the row once instead of re-indexing the dataset five times
    input_text = (
        "Context: " + row['context']
        + "\nQuestion: " + row['question']
        + "\nChoices: " + row['ans0'] + ", " + row['ans1'] + ", " + row['ans2']
        + "\n\n<think>\n"
    )
    if verbose:
        print(input_text)
    return tokenizer(input_text, return_tensors="pt")

In [None]:
# Fresh LoRA adapter for the conditioning run.
# NOTE(review): if the "v3 load" cell was run first, `model` is already a
# PeftModel and this wraps it a second time — confirm that stacking is intended.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=8,
    lora_dropout=0.05,
    bias="none",
)

# nn.Module.to returns the module itself, so the move can be chained.
model = get_peft_model(model, lora_config).to(device)
model.print_trainable_parameters()

In [None]:
def check_infer(model, inputs, max_new_tokens=600):
    """Greedy-decode a continuation of `inputs` and print the full decoded text.

    Args:
        model: causal LM (optionally PEFT-wrapped) called as
            ``model(input_ids=..., past_key_values=..., use_cache=True)``.
        inputs: tokenizer output exposing ``.input_ids``; assumes batch size 1
            (the EOS check uses ``.item()``).
        max_new_tokens: upper bound on generated tokens. It was previously
            overwritten with 600 inside the body and therefore ignored.

    Decoding is argmax with a repetition penalty of 1.2 and stops at
    ``tokenizer.eos_token_id``. The model is switched to eval mode for the
    rollout (so dropout is off even after a training step) and its previous
    train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        generated_ids = inputs.input_ids.clone().to(model.device)
        past_key_values = None
        processors = LogitsProcessorList([RepetitionPenaltyLogitsProcessor(penalty=1.2)])

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Once the KV cache holds the prefix, only the newest token needs to be fed.
                next_input_ids = generated_ids if past_key_values is None else generated_ids[:, -1:]
                outputs = model(input_ids=next_input_ids, past_key_values=past_key_values, use_cache=True)
                past_key_values = outputs.past_key_values

                next_token_logits = processors(generated_ids, outputs.logits[:, -1, :])
                next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

                if next_token_id.item() == tokenizer.eos_token_id:
                    break
    finally:
        model.train(was_training)

    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print("Generated Text:\n", generated_text)
    torch.cuda.empty_cache()

In [None]:
def _condition_generate(model, input_ids, processors, max_new_tokens):
    """Greedy-decode after `input_ids` without gradients; return prompt + generated ids.

    KV-cached argmax with the given logits processors, stopping early at
    ``tokenizer.eos_token_id``. Assumes batch size 1 (the EOS check uses ``.item()``).
    """
    generated_ids = input_ids.clone().to(model.device)
    past_key_values = None
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # After the first step the cache holds the prefix; feed only the newest token.
            next_input_ids = generated_ids if past_key_values is None else generated_ids[:, -1:]
            outputs = model(input_ids=next_input_ids, past_key_values=past_key_values, use_cache=True)
            past_key_values = outputs.past_key_values

            next_token_id = torch.argmax(processors(generated_ids, outputs.logits[:, -1, :]), dim=-1, keepdim=True)
            generated_ids = torch.cat([generated_ids, next_token_id], dim=1)

            if next_token_id.item() == tokenizer.eos_token_id:
                break
    return generated_ids


def _condition_position_dists(logits, generated_ids, processors, positions, normalize):
    """Return `normalize`-d, repetition-penalised distributions at each valid position.

    ``logits[p]`` is the prediction made after reading tokens ``0..p``, so the
    penalty context is ``generated_ids[:, :p + 1]`` — the same context the
    processor saw for that position during generation. Positions outside
    ``[0, seq_len)`` are skipped (negative values previously indexed from the end).
    """
    seq_len = logits.size(0)
    dists = []
    for p in positions:
        if 0 <= p < seq_len:
            scores = processors(generated_ids[:, :p + 1], logits[p].unsqueeze(0))
            dists.append(normalize(scores, dim=-1))
    return dists


def condition_model(model, tid_reasoning, tid_output, inputs=None, optimizer=None,
                    max_new_tokens=600, lr=1e-3):
    """Run one conditioning step on a greedy rollout of `inputs`.

    The model first generates a continuation greedily. A gradient-tracking
    forward pass over prompt + continuation follows. Then a KL loss between the
    averaged distributions at `tid_output` (target) and `tid_reasoning` (input)
    drives a single optimizer step.

    Args:
        model: trainable (e.g. LoRA-wrapped) causal LM.
        tid_reasoning: positions in the full (prompt + generated) sequence whose
            log-probability vectors are averaged to form the KL input.
        tid_output: positions whose probability vectors are averaged to form
            the KL target.
        inputs: tokenizer output with ``.input_ids`` (batch size 1). Defaults to
            the notebook-global ``inputs``, which the original version always
            read implicitly.
        optimizer: optional persistent optimizer. When omitted, a fresh AdamW
            with learning rate `lr` over ``model.parameters()`` is built.
        max_new_tokens: generation budget for the rollout.
        lr: learning rate for the default AdamW.

    Returns:
        The loss as a Python float after the update. Returns None (after
        printing a message, with no update) when no valid position remains in
        either list.

    Raises:
        NameError: `inputs` was not passed and no global ``inputs`` exists.
    """
    if inputs is None:
        if "inputs" not in globals():
            raise NameError("condition_model needs `inputs`: pass inputs=load_data(i) or define a global `inputs`")
        inputs = globals()["inputs"]

    processors = LogitsProcessorList([RepetitionPenaltyLogitsProcessor(penalty=1.2)])
    if optimizer is None:
        # NOTE(review): a fresh AdamW per call resets its moment estimates every
        # step; pass a persistent optimizer to keep state across calls.
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    kl_loss = nn.KLDivLoss(reduction='batchmean')

    # Roll out in eval mode so dropout (left on by a previous call's train())
    # does not perturb the greedy trajectory.
    model.eval()
    generated_ids = _condition_generate(model, inputs.input_ids, processors, max_new_tokens)

    model.train()
    logits = model(generated_ids).logits.squeeze(0)  # (seq_len, vocab), assuming batch size 1

    reasoning = _condition_position_dists(logits, generated_ids, processors, tid_reasoning, F.log_softmax)
    if not reasoning:
        print("No valid reasoning token positions.")
        return None
    avg_reasoning_logprobs = torch.stack(reasoning).mean(dim=0)

    output = _condition_position_dists(logits, generated_ids, processors, tid_output, F.softmax)
    if not output:
        print("No valid output token positions.")
        return None
    avg_output_probs = torch.stack(output).mean(dim=0)

    # KLDivLoss(input=log q, target=p) computes KL(p || q).
    # NOTE(review): a mean of log-probs is not itself a normalised log-distribution,
    # and the target is not detached, so gradients also flow through the output side.
    loss = kl_loss(avg_reasoning_logprobs, avg_output_probs)

    optimizer.zero_grad()  # drop any stale gradients (e.g. from a shared optimizer)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()

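## Evaluating section

Greedy-decode BBQ `religion` questions and print the tail of each completion next to the gold answer.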

### Utilities:

In [None]:
def load_data(i=0, verbose=False):
    """Build and tokenize the prompt for row `i` of the BBQ ``religion`` split.

    The prompt lists the context, the question and the three answer choices,
    and ends with an opening ``<think>`` tag.

    Args:
        i: row index into ``ds['religion']``.
        verbose: print the prompt text (default False; in the evaluation
            section, `check_infer` prints the decoded input itself).

    Returns:
        The tokenizer output for the prompt, with ``return_tensors="pt"``.
    """
    row = ds['religion'][i]  # fetch the row once instead of re-indexing the dataset five times
    input_text = (
        "Context: " + row['context']
        + "\nQuestion: " + row['question']
        + "\nChoices: " + row['ans0'] + ", " + row['ans1'] + ", " + row['ans2']
        + "\n\n<think>\n"
    )
    if verbose:
        print(input_text)
    return tokenizer(input_text, return_tensors="pt")

In [None]:
def check_infer(model, inputs, max_new_tokens=600):
    """Print the prompt, greedy-decode a continuation, and print only its answer part.

    Args:
        model: causal LM (optionally PEFT-wrapped) called as
            ``model(input_ids=..., past_key_values=..., use_cache=True)``.
        inputs: tokenizer output exposing ``.input_ids``; assumes batch size 1
            (the EOS check uses ``.item()``).
        max_new_tokens: upper bound on generated tokens. It was previously
            overwritten with 600 inside the body and therefore ignored.

    Decoding is argmax with a repetition penalty of 1.2 and stops at
    ``tokenizer.eos_token_id``. The model is switched to eval mode for the
    rollout (so dropout is off even after a training step) and its previous
    train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        generated_ids = inputs.input_ids.clone().to(model.device)
        print("Input:\n", tokenizer.decode(generated_ids.squeeze(0)))
        past_key_values = None
        processors = LogitsProcessorList([RepetitionPenaltyLogitsProcessor(penalty=1.2)])

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Once the KV cache holds the prefix, only the newest token needs to be fed.
                next_input_ids = generated_ids if past_key_values is None else generated_ids[:, -1:]
                outputs = model(input_ids=next_input_ids, past_key_values=past_key_values, use_cache=True)
                past_key_values = outputs.past_key_values

                next_token_logits = processors(generated_ids, outputs.logits[:, -1, :])
                next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
                generated_ids = torch.cat([generated_ids, next_token_id], dim=-1)

                if next_token_id.item() == tokenizer.eos_token_id:
                    break
    finally:
        model.train(was_training)

    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # Show the answer: everything from "</think>" on, or the last 150 characters
    # when the model never closed its reasoning block.
    think_end = generated_text.find("</think>")
    start = think_end if think_end != -1 else -150
    print(generated_text[start:].strip())
    torch.cuda.empty_cache()

In [None]:
def eval_model(model, samples=50):
    """Run `check_infer` on the first `samples` BBQ ``religion`` rows, printing each gold answer first.

    `samples` is clamped to the split length, so an oversized request no longer
    raises IndexError partway through a long run.
    """
    religion = ds['religion']
    for i in range(min(samples, len(religion))):
        row = religion[i]  # one dataset lookup per sample
        gold_answer = row['ans' + str(row['answer_label'])]
        print("-------------------------\nActual label: ", gold_answer)
        check_infer(model, inputs=load_data(i=i))