In [None]:
!pip install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu118  # Version for rtx3090

In [7]:
from test_data_gen import generate_synthetic_dataset

NUMBER_OF_RECORDS = 200
VECTOR_DIMENSION = 30

synthetic_dataset = generate_synthetic_dataset(
    number_of_records=NUMBER_OF_RECORDS,
    vector_dimension=VECTOR_DIMENSION
)
synthetic_dataset[50]

{'prompt': 'Как работает GPS?',
 'response': 'Система глобального позиционирования (GPS) использует сигналы от спутников для определения точного местоположения приёмника на Земле путем трилатерации.',
 'custom_vector': [0.9241718649864197,
  0.8862757086753845,
  0.9418854117393494,
  0.8010405898094177,
  0.9979512691497803,
  0.9523987174034119,
  0.9858677387237549,
  0.8162245750427246,
  0.9277184009552002,
  0.8832427263259888,
  0.12798306345939636,
  0.1623060256242752,
  0.04388086497783661,
  0.11491576582193375,
  0.08334853500127792,
  0.07457728683948517,
  0.0013142724055796862,
  0.06693758070468903,
  0.18873128294944763,
  0.12715288996696472,
  0.038583215326070786,
  0.09092347323894501,
  0.0837830901145935,
  0.01080994587391615,
  0.1938796192407608,
  0.13350287079811096,
  0.01674053631722927,
  0.005774518009275198,
  0.1798594743013382,
  0.07867621630430222]}

In [8]:
!pip install unsloth

Collecting unsloth
  Downloading unsloth-2025.6.1-py3-none-any.whl.metadata (47 kB)
Collecting unsloth_zoo>=2025.6.1 (from unsloth)
  Downloading unsloth_zoo-2025.6.1-py3-none-any.whl.metadata (8.1 kB)
Collecting torch<=2.7.0,>=2.4.0 (from unsloth)
  Using cached torch-2.7.0-cp312-cp312-win_amd64.whl.metadata (29 kB)
Collecting xformers>=0.0.27.post2 (from unsloth)
  Downloading xformers-0.0.30-cp312-cp312-win_amd64.whl.metadata (1.3 kB)
Collecting bitsandbytes (from unsloth)
  Downloading bitsandbytes-0.46.0-py3-none-win_amd64.whl.metadata (10 kB)
Collecting triton-windows (from unsloth)
  Downloading triton_windows-3.3.1.post19-cp312-cp312-win_amd64.whl.metadata (1.6 kB)
Collecting tyro (from unsloth)
  Downloading tyro-0.9.24-py3-none-any.whl.metadata (11 kB)
Collecting transformers!=4.47.0,!=4.52.0,!=4.52.1,!=4.52.2,>=4.51.3 (from unsloth)
  Downloading transformers-4.52.4-py3-none-any.whl.metadata (38 kB)
Collecting sentencepiece>=0.2.0 (from unsloth)
  Using cached sentencepiece-

In [1]:
import torch

def check_gpu_availability() -> None:
    """
    Checks for the availability of a CUDA-enabled GPU and prints its status.
    """
    # Check if a CUDA-enabled GPU is available
    is_gpu_available: bool = torch.cuda.is_available()

    if is_gpu_available:
        # Get the number of available GPUs
        gpu_count: int = torch.cuda.device_count()
        # Get the name of the current GPU
        current_gpu_name: str = torch.cuda.get_device_name(torch.cuda.current_device())
        print(f"✅ GPU доступен.")
        print(f"Количество GPU: {gpu_count}")
        print(f"Имя устройства: {current_gpu_name}")
    else:
        print("❌ GPU не доступен. PyTorch будет использовать CPU.")

if __name__ == '__main__':
    check_gpu_availability()

✅ GPU доступен.
Количество GPU: 1
Имя устройства: NVIDIA GeForce RTX 3090


Looking in indexes: https://download.pytorch.org/whl/cu118


In [1]:
import torch
import random
import numpy as np
import pandas as pd
from datasets import Dataset
from typing import List, Dict, Any, Tuple, Optional

# Unsloth, Transformers, TRL and PEFT imports
from unsloth import FastLanguageModel
from transformers import TrainingArguments, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
import torch.nn as nn
from trl import SFTTrainer
from peft import LoraConfig


# ==============================================================================
# 1. MODEL LOADING (Updated for Qwen3 1.7B)
# ==============================================================================
# We will load the 4-bit quantized version of Qwen3-1.7B-Instruct from Unsloth.
# Create a folder named 'model_cache' in your project directory
model_cache_path: str = "./model_cache"

max_seq_length: int = 2048
dtype = None # Let Unsloth auto-select the best dtype (float16 or bfloat16)
load_in_4bit: bool = True

print("==> Step 1: Loading the Qwen3-1.7B model...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/qwen3-1.7b-instruct-bnb-4bit",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    cache_dir = model_cache_path, # <-- ВОТ ЭТОТ ПАРАМЕТР
)
print("==> Model loaded successfully!\n")


# ==============================================================================
# 2. CUSTOM MODEL WRAPPER (No changes needed)
# ==============================================================================
# This wrapper class is generic and works with any model.

class ConditionalLM(PreTrainedModel):
    """
    A custom model that wraps a pre-trained language model and adds a conditional projection layer.
    """
    supports_gradient_checkpointing = True

    def __init__(
        self,
        language_model: PreTrainedModel,
        custom_vector_size: int
    ):
        super().__init__(language_model.config)
        self.language_model = language_model
        self.custom_vector_size = custom_vector_size
        self.embedding_size = self.language_model.get_input_embeddings().embedding_dim
        self.projection_layer = nn.Sequential(
            nn.Linear(self.custom_vector_size, self.embedding_size),
            nn.ReLU(),
            nn.Linear(self.embedding_size, self.embedding_size)
        )

    def get_input_embeddings(self) -> nn.Embedding:
        return self.language_model.get_input_embeddings()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        custom_vector: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        if custom_vector is None:
            return self.language_model(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels,
                inputs_embeds=inputs_embeds, **kwargs
            )
        projected_vector = self.projection_layer(custom_vector).unsqueeze(1)
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)
        inputs_embeds = torch.cat([projected_vector, inputs_embeds], dim=1)
        new_attention_mask = None
        if attention_mask is not None:
            projected_vector_mask = torch.ones(
                attention_mask.shape[0], 1, dtype=attention_mask.dtype, device=attention_mask.device
            )
            new_attention_mask = torch.cat([projected_vector_mask, attention_mask], dim=1)
        new_labels = None
        if labels is not None:
            projected_vector_label = torch.full(
                (labels.shape[0], 1), -100, dtype=labels.dtype, device=labels.device
            )
            new_labels = torch.cat([projected_vector_label, labels], dim=1)
        return self.language_model(
            inputs_embeds=inputs_embeds, attention_mask=new_attention_mask,
            labels=new_labels, **kwargs
        )

# ==============================================================================
# 3. DATA GENERATION AND PREPARATION (No changes needed)
# ==============================================================================
# The data generation and formatting functions remain the same.

def generate_synthetic_dataset(number_of_records: int, vector_dimension: int) -> Dataset:
    # (The function body is the same as before, so it is omitted here for brevity)
    if vector_dimension % 3 != 0: raise ValueError("vector_dimension must be divisible by 3.")
    source_data: Dict[str, List[Tuple[str, str]]] = {
        "science": [("Что такое черная дыра?", "Чёрная дыра — это область пространства-времени, гравитационное притяжение которой настолько велико, что покинуть её не могут даже объекты, движущиеся со скоростью света."), ("Объясни фотосинтез.", "Фотосинтез — это сложный химический процесс преобразования энергии видимого света в энергию химических связей органических веществ."),],
        "history": [("Расскажи о Ренессансе.", "Эпоха Возрождения, или Ренессанс, — это период в истории культуры Европы, пришедший на смену Средним векам и предшествующий Просвещению."), ("Кто такой Юлий Цезарь?", "Гай Юлий Цезарь был древнеримским государственным и политическим деятелем, полководцем и писателем."),],
        "creative": [("Придумай шутку про программиста.", "Почему программисты так не любят природу? Слишком много багов."), ("Напиши короткий стих о космосе.", "Средь миллиардов звёздных троп, летит бесшумно телескоп. Он ищет дом, он ищет свет, вдали от суетных планет."),],
    }
    records_list: List[Dict[str, Any]] = []
    categories: List[str] = list(source_data.keys())
    chunk_size: int = vector_dimension // 3
    for _ in range(number_of_records):
        chosen_category: str = random.choice(categories)
        prompt, response = random.choice(source_data[chosen_category])
        custom_vector = np.zeros(vector_dimension, dtype=np.float32)
        for i in range(3):
            start_index, end_index = i * chunk_size, (i + 1) * chunk_size
            custom_vector[start_index:end_index] = np.random.uniform(0.0, 0.2, size=chunk_size)
        category_index = categories.index(chosen_category)
        start_index, end_index = category_index * chunk_size, (category_index + 1) * chunk_size
        custom_vector[start_index:end_index] = np.random.uniform(0.8, 1.0, size=chunk_size)
        records_list.append({"prompt": prompt, "response": response, "custom_vector": custom_vector})
    return Dataset.from_list(records_list)

def formatting_prompts_func(example: Dict[str, Any]) -> Dict[str, Any]:
    text_parts = [
        f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
        f"<|im_start|>user\n{example['prompt']}<|im_end|>\n",
        f"<|im_start|>assistant\n{example['response']}<|im_end|>"
    ]
    example["text"] = "".join(text_parts) + tokenizer.eos_token
    return example

class ConditionalDataCollator:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        tokenized_inputs = self.tokenizer(
            [f["text"] for f in features], return_tensors="pt", padding=True,
            truncation=True, max_length=max_seq_length
        )
        custom_vectors = torch.tensor([f["custom_vector"] for f in features], dtype=torch.float)
        tokenized_inputs["labels"] = tokenized_inputs["input_ids"].clone()
        tokenized_inputs["custom_vector"] = custom_vectors
        return tokenized_inputs

print("==> Step 3: Generating and preparing dataset...")
NUMBER_OF_RECORDS = 200
VECTOR_DIMENSION = 30
synthetic_dataset = generate_synthetic_dataset(
    number_of_records=NUMBER_OF_RECORDS, vector_dimension=VECTOR_DIMENSION
)
processed_dataset = synthetic_dataset.map(formatting_prompts_func, num_proc=4)
print("==> Dataset prepared successfully!\n")


# ==============================================================================
# 4. TRAINING SETUP (Updated for Qwen3)
# ==============================================================================
print("==> Step 4: Setting up the training components...")
custom_model = ConditionalLM(language_model=model, custom_vector_size=VECTOR_DIMENSION)

# LoRA configuration for Qwen3
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    # NOTE: We assume these are the correct modules for Qwen3,
    # as they are standard for Qwen1.5 and Qwen2. This is an educated guess.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"],
    modules_to_save=["projection_layer"], # Don't forget our custom layer!
)

training_arguments = TrainingArguments(
    output_dir="qwen3_1.7b_conditional_finetune", # <-- Updated output directory
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=5,
    max_steps=100,
    learning_rate=2e-4,
    fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(),
    logging_steps=1,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=42,
)

data_collator = ConditionalDataCollator(tokenizer=tokenizer)

trainer = SFTTrainer(
    model=custom_model,
    args=training_arguments,
    train_dataset=processed_dataset,
    peft_config=lora_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
print("==> Trainer is ready for Qwen3 1.7B!\n")
print("To start training, run the command: trainer.train()")

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


    PyTorch 2.7.0+cu126 with CUDA 1206 (you have 2.6.0+cu118)
    Python  3.12.10 (you have 3.12.10)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


🦥 Unsloth Zoo will now patch everything to make training faster!
==> Step 1: Loading the Qwen3-1.7B model...


  GPU_BUFFERS = tuple([torch.empty(2*256*2048, dtype = dtype, device = f"{DEVICE_TYPE}:{i}") for i in range(n_gpus)])


FileNotFoundError: unsloth/qwen3-1.7b-instruct-bnb-4bit/*.json (repository not found)

In [1]:
import torch
import torch.nn as nn
import random
import numpy as np
import copy
from datasets import Dataset
from typing import List, Dict, Any, Tuple, Optional
from enum import Enum

# TRL, Transformers, Unsloth imports
from unsloth import FastLanguageModel
from transformers import PreTrainedModel, AutoTokenizer
from trl import PPOTrainer, PPOConfig
from trl.core import LengthSampler

# ==============================================================================
# 1. Core Classes (ConditionalLM and InjectionMethod Enum)
# ==============================================================================
# These are the same classes we finalized before.

class InjectionMethod(Enum):
    PREPEND_EMBEDDING = "prepend_embedding"
    ADD_AFTER_LAYER_N = "add_after_layer_n"
    ADD_TO_EVERY_LAYER = "add_to_every_layer"

class ConditionalLM(PreTrainedModel):
    supports_gradient_checkpointing = True
    def __init__(
            self,
            language_model: PreTrainedModel,
            custom_vector_size: int,

             injection_method: InjectionMethod = InjectionMethod.PREPEND_EMBEDDING,
                 injection_layer_index: Optional[int] = None):
        super().__init__(language_model.config)
        self.language_model = language_model
        self.custom_vector_size = custom_vector_size
        self.injection_method = injection_method
        self.injection_layer_index = injection_layer_index
        self.embedding_size = self.language_model.get_input_embeddings().embedding_dim
        self._validate_settings()
        self.projection_layer = nn.Sequential(nn.Linear(self.custom_vector_size, self.embedding_size), nn.ReLU(), nn.Linear(self.embedding_size, self.embedding_size))
        self.projected_vector_cache: Optional[torch.Tensor] = None
        self._register_hooks()
    def _validate_settings(self):
        if self.injection_method == InjectionMethod.ADD_AFTER_LAYER_N:
            if self.injection_layer_index is None: raise ValueError("`injection_layer_index` must be set.")
            num_layers = len(self.language_model.model.layers)
            if not (0 <= self.injection_layer_index < num_layers): raise ValueError(f"`injection_layer_index` must be between 0 and {num_layers - 1}.")
    def _register_hooks(self):
        if self.injection_method == InjectionMethod.ADD_TO_EVERY_LAYER:
            for layer in self.language_model.model.layers: layer.register_forward_hook(self._addition_hook)
        elif self.injection_method == InjectionMethod.ADD_AFTER_LAYER_N:
            self.language_model.model.layers[self.injection_layer_index].register_forward_hook(self._addition_hook)
    def _addition_hook(self, module: nn.Module, inputs: Any, outputs: Any) -> Any:
        hidden_states = outputs[0]
        modified_hidden_states = hidden_states + self.projected_vector_cache.unsqueeze(1)
        return (modified_hidden_states,) + outputs[1:]
    def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            custom_vector: Optional[torch.Tensor] = None,
            **kwargs
         ):
        if custom_vector is None: return self.language_model(input_ids=input_ids, **kwargs)
        if self.injection_method == InjectionMethod.PREPEND_EMBEDDING:
            projected_vector = self.projection_layer(custom_vector).unsqueeze(1)
            token_embeddings = self.get_input_embeddings()(input_ids)
            inputs_embeds = torch.cat([projected_vector, token_embeddings], dim=1)
            attention_mask, labels = kwargs.get("attention_mask"), kwargs.get("labels")
            if attention_mask is not None:
                proj_mask = torch.ones(attention_mask.shape[0], 1, dtype=attention_mask.dtype, device=attention_mask.device)
                kwargs["attention_mask"] = torch.cat([proj_mask, attention_mask], dim=1)
            if labels is not None:
                proj_label = torch.full((labels.shape[0], 1), -100, dtype=labels.dtype, device=labels.device)
                kwargs["labels"] = torch.cat([proj_label, labels], dim=1)
            return self.language_model(inputs_embeds=inputs_embeds, **kwargs)
        elif self.injection_method in [InjectionMethod.ADD_TO_EVERY_LAYER, InjectionMethod.ADD_AFTER_LAYER_N]:
            self.projected_vector_cache = self.projection_layer(custom_vector)
            outputs = self.language_model(input_ids=input_ids, **kwargs)
            self.projected_vector_cache = None
            return outputs
        else: raise NotImplementedError(f"Injection method {self.injection_method} is not implemented.")
    def get_input_embeddings(self) -> nn.Embedding: return self.language_model.get_input_embeddings()

# ==============================================================================
# 2. NEW: Placeholder Reward Model (The Supervisor)
# ==============================================================================
# In a real project, this would be a separately trained model.
# Here, it just returns a random score for any given input.

def get_rewards_from_api(
    prompts: List[str],
    responses: List[str],
    api_url: str,
    api_token: str,
    device: str
) -> torch.Tensor:
    """
    Gets rewards for a batch of prompts and responses from an external API.

    Args:
        prompts (List[str]): A list of prompts sent to the policy model.
        responses (List[str]): A list of responses generated by the policy model.
        api_url (str): The endpoint URL of the external reward model.
        api_token (str): The authentication token for the API.
        device (str): The torch device to place the resulting tensor on.

    Returns:
        torch.Tensor: A tensor of shape (batch_size, 1) with rewards.
    """
    headers = {
        "Authorization": f"Bearer {api_token}",
        "Content-Type": "application/json",
    }

    # We assume the API can handle batch requests for efficiency.
    # The payload is a list of objects, each with a prompt and a response.
    payload = {
        "inputs": [
            {"prompt": p, "response": r} for p, r in zip(prompts, responses)
        ]
    }

    rewards = []
    try:
        # Make the POST request to the external API
        response = requests.post(api_url, headers=headers, json=payload)
        response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)

        # Parse the JSON response
        results = response.json()

        # We expect the API to return a list of scores
        # Example format: {"scores": [0.8, 0.2, 0.9, 0.5]}
        scores = results.get("scores", [])
        if len(scores) != len(prompts):
             raise ValueError("API returned a different number of scores than expected.")

        rewards = [torch.tensor(score) for score in scores]

    except requests.exceptions.RequestException as e:
        print(f"Error calling external API: {e}")
        # On error, return a neutral reward (0.0) for the whole batch
        rewards = [torch.tensor(0.0) for _ in prompts]

    # The PPOTrainer expects a tensor of shape (batch_size,).
    return torch.tensor(rewards, dtype=torch.float32, device=device)

class PlaceholderRewardModel(nn.Module):
    """
    A placeholder for a real reward model. It returns a random scalar reward.
    """
    def __init__(self, model_name_or_path: str):
        super().__init__()
        # In a real scenario, you might load a model with a scalar output head.
        # For this example, we don't need to load anything.
        print(f"Initialized PlaceholderRewardModel. It will return random rewards.")

    def forward(
        self,
        input_ids: torch.LongTensor,
        **kwargs,
    ) -> torch.Tensor:
        """
        Args:
            input_ids (torch.LongTensor): Tokenized sequence of (prompt + response).
        Returns:
            torch.Tensor: A tensor of shape (batch_size, 1) with random rewards.
        """
        batch_size = input_ids.shape[0]
        # Return a random reward for each item in the batch
        return torch.randn(batch_size, 1, device=input_ids.device)

# ==============================================================================
# 3. Data Generation for PPO
# ==============================================================================
# For PPO, we only need prompts and their corresponding vectors to start generation.

def generate_ppo_dataset(number_of_prompts: int, vector_dimension: int) -> Dataset:
    """
    Generates a synthetic dataset of prompts and custom vectors for PPO.
    """
    prompts = [
        "Что такое черная дыра?", "Расскажи о Ренессансе.", "Придумай шутку про программиста.",
        "Объясни фотосинтез.", "Кто такой Юлий Цезарь?", "Напиши короткий стих о космосе."
    ]
    records_list = []
    for i in range(number_of_prompts):
        records_list.append({
            "prompt": random.choice(prompts),
            # In a real case, the vector would be meaningful. Here, it's random.
            "custom_vector": np.random.randn(vector_dimension).astype(np.float32)
        })
    return Dataset.from_list(records_list)

# ==============================================================================
# 4. Main PPO Script Setup
# ==============================================================================

# --- Configuration ---
SFT_MODEL_PATH = "unsloth/qwen3-1.7b-instruct-bnb-4bit" # Path to your SFT model
BASE_MODEL_NAME = "unsloth/qwen3-1.7b-instruct-bnb-4bit"
VECTOR_DIMENSION = 30
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

ppo_config = PPOConfig(
    batch_size=4,
    mini_batch_size=2,
    gradient_accumulation_steps=1,
    learning_rate=1.4e-6, # Use a very low learning rate for PPO
    log_with="wandb", # or "tensorboard" or None
)

# --- Load Base Model and Tokenizer ---
print("Loading base model and tokenizer...")
base_model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=BASE_MODEL_NAME,
    max_seq_length=512,
    load_in_4bit=True,
)
# Qwen models need a padding token
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

# --- Initialize Models for PPO ---
print("Initializing Policy, Reference, and Reward Models...")

# In a real workflow, you would load the weights from your SFT training here.
# For this example, we start from the base model.
policy_model = ConditionalLM(
    language_model=base_model,
    custom_vector_size=VECTOR_DIMENSION,
    injection_method=InjectionMethod.PREPEND_EMBEDDING,
).to(DEVICE)

# The reference model is a frozen copy of the policy model before PPO training.
ref_model = copy.deepcopy(policy_model)
for param in ref_model.parameters():
    param.requires_grad = False
ref_model.eval()

# Our placeholder reward model
reward_model = PlaceholderRewardModel(SFT_MODEL_PATH).to(DEVICE)


# --- Prepare Dataset and Collator ---
print("Preparing dataset...")
dataset = generate_ppo_dataset(number_of_prompts=100, vector_dimension=VECTOR_DIMENSION)

def collator(data: List[Dict[str, Any]]):
    # Collate function to prepare batches for the PPOTrainer
    batch = {}
    prompts_with_template = [f"<|im_start|>user\n{x['prompt']}<|im_end|>\n<|im_start|>assistant\n" for x in data]
    batch["input_ids"] = tokenizer(prompts_with_template, padding=True, truncation=True, return_tensors="pt")["input_ids"]
    batch["query"] = tokenizer.batch_decode(batch["input_ids"])
    batch["custom_vector"] = torch.tensor([x['custom_vector'] for x in data], dtype=torch.float32)
    return batch

# --- Instantiate PPOTrainer ---
print("Initializing PPOTrainer...")
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=policy_model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=dataset,
    data_collator=collator,
)

# ==============================================================================
# 5. The PPO Training Loop
# ==============================================================================
print("\n=== Starting PPO Training Loop ===\n")

output_min_length = 32
output_max_length = 128
output_length_sampler = LengthSampler(output_min_length, output_max_length)

for step, batch in enumerate(ppo_trainer.dataloader):
    if step >= ppo_config.total_ppo_epochs:
        break

    prompt_tensors = batch["input_ids"].to(DEVICE)
    custom_vectors = batch["custom_vector"].to(DEVICE)

    # --- Generation ---
    # Generate responses from the policy model, passing the custom vector
    generation_kwargs = {
        "min_length": -1,
        "top_k": 0.0,
        "top_p": 1.0,
        "do_sample": True,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "custom_vector": custom_vectors, # Pass our vector here!
    }
    response_tensors = ppo_trainer.generate(
        prompt_tensors,
        length_sampler=output_length_sampler,
        **generation_kwargs,
    )

    # The output from generate is the full sequence (prompt + response)
    batch["response"] = tokenizer.batch_decode(response_tensors)

    # --- Reward Calculation ---
    # Prepare input for the reward model: prompt + response
    texts_for_reward = [q + r for q, r in zip(batch["query"], batch["response"])]
    tokenized_texts_for_reward = tokenizer(texts_for_reward, padding=True, truncation=True, return_tensors="pt").to(DEVICE)

    # Get rewards from the supervisor/reward model
    # NOTE: In a real scenario, the reward model might also need the custom_vector
    rewards = reward_model(input_ids=tokenized_texts_for_reward["input_ids"])

    # --- PPO Step ---
    # The trainer performs the PPO update.
    stats = ppo_trainer.step(prompt_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

    print(f"--- Step {step+1} ---")
    print(f"Objective/kl: {stats['objective/kl']:.4f}")
    print(f"Mean reward: {torch.mean(rewards).item():.4f}")
    print(f"Example response: {batch['response'][0]}\n")

print("\n=== PPO Training Finished ===\n")

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


    PyTorch 2.7.0+cu126 with CUDA 1206 (you have 2.6.0+cu118)
    Python  3.12.10 (you have 3.12.10)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


🦥 Unsloth Zoo will now patch everything to make training faster!


TypeError: PPOConfig.__init__() got an unexpected keyword argument 'log_with'