In [2]:
import random

import numpy as np

from test_data_gen import generate_synthetic_dataset

# Preview one record of the synthetic conditioning dataset.
# NOTE(review): the next cell re-defines `generate_synthetic_dataset` locally with
# different source data; keep only one definition so this preview matches training data.
NUMBER_OF_RECORDS = 200
VECTOR_DIMENSION = 30
SEED = 42
PREVIEW_INDEX = 50

# Seed both global RNGs so the preview is reproducible on "Restart & Run All".
# NOTE(review): assumes test_data_gen draws from `random` / `np.random` global state — confirm.
random.seed(SEED)
np.random.seed(SEED)

synthetic_dataset = generate_synthetic_dataset(
    number_of_records=NUMBER_OF_RECORDS,
    vector_dimension=VECTOR_DIMENSION,
)
synthetic_dataset[PREVIEW_INDEX]

{'prompt': 'Сочини забавный факт.',
 'response': 'Осьминоги имеют три сердца. Два качают кровь через жабры, а третье — по всему остальному телу. Видимо, поэтому они так хороши в многозадачности.',
 'custom_vector': [0.17773109674453735,
  0.09064134955406189,
  0.022509407252073288,
  0.1318150907754898,
  0.017499104142189026,
  0.1785397231578827,
  0.17107288539409637,
  0.12427335232496262,
  0.1430147886276245,
  0.023171450942754745,
  0.04329017922282219,
  0.1562054455280304,
  0.1573420763015747,
  0.152131587266922,
  0.09085719287395477,
  0.10491936653852463,
  0.10821109265089035,
  0.18221844732761383,
  0.14653459191322327,
  0.12184739857912064,
  0.9005042314529419,
  0.8830434083938599,
  0.8846337199211121,
  0.839137852191925,
  0.975746214389801,
  0.9308818578720093,
  0.8285808563232422,
  0.8986854553222656,
  0.9762755036354065,
  0.8013295531272888]}

In [None]:
import torch
import random
import numpy as np
import pandas as pd
from datasets import Dataset
from typing import List, Dict, Any, Tuple, Optional

# Unsloth, Transformers, TRL and PEFT imports
from unsloth import FastLanguageModel
from transformers import TrainingArguments, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
import torch.nn as nn
from trl import SFTTrainer
from peft import LoraConfig


# ==============================================================================
# 1. MODEL LOADING (Qwen3 1.7B)
# ==============================================================================
# Qwen3-1.7B is released as a single hybrid (thinking / non-thinking) chat model;
# there is no separate "-Instruct" checkpoint at this size, so we load Unsloth's
# pre-quantized 4-bit build of Qwen3-1.7B.
MODEL_NAME: str = "unsloth/Qwen3-1.7B-unsloth-bnb-4bit"

# Directory passed as the download cache for the model weights.
model_cache_path: str = "./model_cache"

max_seq_length: int = 2048
dtype = None  # Let Unsloth auto-select the best dtype (float16 or bfloat16)
load_in_4bit: bool = True

print("==> Step 1: Loading the Qwen3-1.7B model...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    cache_dir=model_cache_path,  # keep downloaded weights inside the project directory
)
print("==> Model loaded successfully!\n")


# ==============================================================================
# 2. CUSTOM MODEL WRAPPER
# ==============================================================================
# The wrapper only relies on `get_input_embeddings()` and on the wrapped model
# accepting `inputs_embeds`, so it is not tied to a particular architecture.

class ConditionalLM(PreTrainedModel):
    """
    Condition a causal language model on a dense "custom vector".

    The custom vector is mapped by a two-layer MLP (``projection_layer``) into the
    model's embedding space and prepended to the token embeddings as one extra
    "soft token". The attention mask and labels are extended by one leading
    position so the soft token is attended to but excluded from the loss.
    """
    supports_gradient_checkpointing = True

    def __init__(
        self,
        language_model: PreTrainedModel,
        custom_vector_size: int
    ):
        """
        Args:
            language_model: Causal LM to wrap; its config is reused for this wrapper.
            custom_vector_size: Length of the conditioning vector passed to ``forward``.
        """
        super().__init__(language_model.config)
        self.language_model = language_model
        self.custom_vector_size = custom_vector_size
        input_embeddings = self.language_model.get_input_embeddings()
        self.embedding_size = input_embeddings.embedding_dim
        self.projection_layer = nn.Sequential(
            nn.Linear(self.custom_vector_size, self.embedding_size),
            nn.ReLU(),
            nn.Linear(self.embedding_size, self.embedding_size)
        )
        # nn.Linear is created on the CPU; move the projection next to the embeddings so
        # the wrapper also works when nothing moves the whole module (e.g. plain inference).
        embedding_weight = getattr(input_embeddings, "weight", None)
        if embedding_weight is not None and embedding_weight.device.type != "meta":
            self.projection_layer.to(embedding_weight.device)

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the wrapped model's token embedding layer."""
        return self.language_model.get_input_embeddings()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        custom_vector: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """
        Run the wrapped LM, optionally prefixed with the projected custom vector.

        Args:
            input_ids: Token ids; embedded when ``inputs_embeds`` is not given.
            attention_mask: Token padding mask; extended with a leading column of ones.
            custom_vector: Conditioning vectors, one row per batch element. ``None``
                disables conditioning and calls the wrapped model unchanged.
            labels: LM labels; extended with a leading -100 so the prefix is not scored.
            inputs_embeds: Precomputed token embeddings (alternative to ``input_ids``).
            **kwargs: Forwarded to the wrapped model unchanged.
                NOTE(review): ``position_ids`` / cache arguments passed here are not
                shifted for the extra prefix position.

        Returns:
            Whatever the wrapped model returns (``CausalLMOutputWithPast`` for HF causal LMs).

        Raises:
            ValueError: If ``custom_vector`` is given without ``input_ids`` or ``inputs_embeds``.
        """
        if custom_vector is None:
            return self.language_model(
                input_ids=input_ids, attention_mask=attention_mask, labels=labels,
                inputs_embeds=inputs_embeds, **kwargs
            )

        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("custom_vector requires either input_ids or inputs_embeds.")
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Run the projection in its own parameter dtype/device, then cast the result to
        # the token embeddings' dtype/device (e.g. bf16 on GPU) so torch.cat never mixes them.
        projection_param = next(self.projection_layer.parameters())
        projected_vector = self.projection_layer(
            custom_vector.to(device=projection_param.device, dtype=projection_param.dtype)
        ).unsqueeze(1)
        projected_vector = projected_vector.to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
        inputs_embeds = torch.cat([projected_vector, inputs_embeds], dim=1)

        new_attention_mask = None
        if attention_mask is not None:
            # The soft token is always visible.
            projected_vector_mask = torch.ones(
                attention_mask.shape[0], 1, dtype=attention_mask.dtype, device=attention_mask.device
            )
            new_attention_mask = torch.cat([projected_vector_mask, attention_mask], dim=1)

        new_labels = None
        if labels is not None:
            # -100 is the ignore index, so no loss is computed for the soft token.
            projected_vector_label = torch.full(
                (labels.shape[0], 1), -100, dtype=labels.dtype, device=labels.device
            )
            new_labels = torch.cat([projected_vector_label, labels], dim=1)

        return self.language_model(
            inputs_embeds=inputs_embeds, attention_mask=new_attention_mask,
            labels=new_labels, **kwargs
        )

# ==============================================================================
# 3. DATA GENERATION AND PREPARATION
# ==============================================================================
# NOTE(review): this local definition shadows the `generate_synthetic_dataset`
# imported from `test_data_gen` in the first cell; keep only one of them.

def generate_synthetic_dataset(
    number_of_records: int,
    vector_dimension: int,
    seed: Optional[int] = None,
) -> Dataset:
    """
    Build a toy dataset of (prompt, response, custom_vector) records.

    Each record picks a random category and then a random (prompt, response) pair
    from it. Its ``custom_vector`` is split into one equal-length chunk per category:
    every chunk is filled with low noise drawn from [0.0, 0.2), except the chosen
    category's chunk, which is filled from [0.8, 1.0). The vector therefore encodes
    the record's category.

    Args:
        number_of_records: Number of records to generate.
        vector_dimension: Length of each custom vector; must be divisible by the
            number of categories (3).
        seed: Optional seed. When given, private RNGs are used so the result is
            reproducible without touching global state; when ``None`` the global
            ``random`` / ``np.random`` state is used.

    Returns:
        A ``datasets.Dataset`` with ``prompt``, ``response`` and ``custom_vector`` columns.

    Raises:
        ValueError: If ``vector_dimension`` is not divisible by the number of categories.
    """
    source_data: Dict[str, List[Tuple[str, str]]] = {
        "science": [
            ("Что такое черная дыра?", "Чёрная дыра — это область пространства-времени, гравитационное притяжение которой настолько велико, что покинуть её не могут даже объекты, движущиеся со скоростью света."),
            ("Объясни фотосинтез.", "Фотосинтез — это сложный химический процесс преобразования энергии видимого света в энергию химических связей органических веществ."),
        ],
        "history": [
            ("Расскажи о Ренессансе.", "Эпоха Возрождения, или Ренессанс, — это период в истории культуры Европы, пришедший на смену Средним векам и предшествующий Просвещению."),
            ("Кто такой Юлий Цезарь?", "Гай Юлий Цезарь был древнеримским государственным и политическим деятелем, полководцем и писателем."),
        ],
        "creative": [
            ("Придумай шутку про программиста.", "Почему программисты так не любят природу? Слишком много багов."),
            ("Напиши короткий стих о космосе.", "Средь миллиардов звёздных троп, летит бесшумно телескоп. Он ищет дом, он ищет свет, вдали от суетных планет."),
        ],
    }
    categories: List[str] = list(source_data.keys())
    number_of_categories = len(categories)
    if vector_dimension % number_of_categories != 0:
        raise ValueError(f"vector_dimension must be divisible by {number_of_categories}.")

    # Same call pattern on both branches: `.choice` for Python, `.uniform` for NumPy.
    py_rng = random.Random(seed) if seed is not None else random
    np_rng = np.random.default_rng(seed) if seed is not None else np.random

    chunk_size: int = vector_dimension // number_of_categories
    records_list: List[Dict[str, Any]] = []
    for _ in range(number_of_records):
        chosen_category: str = py_rng.choice(categories)
        prompt, response = py_rng.choice(source_data[chosen_category])
        custom_vector = np.zeros(vector_dimension, dtype=np.float32)
        # Low background noise in every category's chunk...
        for chunk_index in range(number_of_categories):
            start_index, end_index = chunk_index * chunk_size, (chunk_index + 1) * chunk_size
            custom_vector[start_index:end_index] = np_rng.uniform(0.0, 0.2, size=chunk_size)
        # ...then a strong signal in the chosen category's chunk.
        category_index = categories.index(chosen_category)
        start_index, end_index = category_index * chunk_size, (category_index + 1) * chunk_size
        custom_vector[start_index:end_index] = np_rng.uniform(0.8, 1.0, size=chunk_size)
        records_list.append({"prompt": prompt, "response": response, "custom_vector": custom_vector})
    return Dataset.from_list(records_list)

def formatting_prompts_func(example: Dict[str, Any], eos_token: Optional[str] = None) -> Dict[str, Any]:
    """
    Render one record as a ChatML conversation in a new ``text`` field.

    Args:
        example: A record with ``prompt`` and ``response`` keys.
        eos_token: End-of-sequence token appended to the text. Defaults to the
            notebook-level ``tokenizer.eos_token``.

    Returns:
        A copy of ``example`` with an added ``text`` key (the input is not mutated).
    """
    if eos_token is None:
        eos_token = tokenizer.eos_token
    text = (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        f"<|im_start|>user\n{example['prompt']}<|im_end|>\n"
        f"<|im_start|>assistant\n{example['response']}<|im_end|>"
    )
    # For ChatML tokenizers the EOS token is often `<|im_end|>` itself; only append
    # it when it would not duplicate the assistant turn terminator.
    if eos_token and not text.endswith(eos_token):
        text += eos_token
    return {**example, "text": text}

class ConditionalDataCollator:
    """
    Batch ``text`` + ``custom_vector`` records for ``ConditionalLM``.

    Tokenizes the batch with padding and truncation, builds causal-LM labels from the
    input ids with padding positions set to -100 (so they do not contribute to the
    loss), and attaches the stacked custom vectors as a float32 tensor.
    """

    def __init__(self, tokenizer, max_length: Optional[int] = None):
        """
        Args:
            tokenizer: Hugging Face-style tokenizer (callable returning ``input_ids``
                and ``attention_mask``).
            max_length: Truncation length. Defaults to the notebook-level ``max_seq_length``.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        max_length = self.max_length if self.max_length is not None else max_seq_length
        tokenized_inputs = self.tokenizer(
            [feature["text"] for feature in features], return_tensors="pt", padding=True,
            truncation=True, max_length=max_length
        )
        labels = tokenized_inputs["input_ids"].clone()
        # Mask padding via the attention mask (works even when pad_token == eos_token).
        if "attention_mask" in tokenized_inputs:
            labels[tokenized_inputs["attention_mask"] == 0] = -100
        tokenized_inputs["labels"] = labels
        tokenized_inputs["custom_vector"] = torch.tensor(
            np.asarray([feature["custom_vector"] for feature in features], dtype=np.float32)
        )
        return tokenized_inputs

print("==> Step 3: Generating and preparing dataset...")
NUMBER_OF_RECORDS = 200
VECTOR_DIMENSION = 30
SEED = 42

# Seed the global RNGs used by the generator so the dataset is reproducible.
random.seed(SEED)
np.random.seed(SEED)

synthetic_dataset = generate_synthetic_dataset(
    number_of_records=NUMBER_OF_RECORDS, vector_dimension=VECTOR_DIMENSION
)
# A single process is plenty for 200 short records; worker start-up would dominate.
processed_dataset = synthetic_dataset.map(formatting_prompts_func)

assert len(processed_dataset) == NUMBER_OF_RECORDS
assert {"text", "custom_vector"} <= set(processed_dataset.column_names)
print("==> Dataset prepared successfully!\n")


# ==============================================================================
# 4. TRAINING SETUP (Qwen3)
# ==============================================================================
print("==> Step 4: Setting up the training components...")
custom_model = ConditionalLM(language_model=model, custom_vector_size=VECTOR_DIMENSION)

# LoRA configuration for Qwen3
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    # Attention and MLP projection names used by Qwen2/Qwen3 decoder layers; PEFT
    # matches target modules by name suffix, so they are found inside the wrapper.
    # Verify with `print(model)` if switching architectures.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    # The conditioning projection is newly initialised: train it fully and save it.
    modules_to_save=["projection_layer"],
)

training_arguments = TrainingArguments(
    output_dir="qwen3_1.7b_conditional_finetune",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=5,
    max_steps=100,
    learning_rate=2e-4,
    fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(),
    logging_steps=1,
    optim="adamw_8bit",
    weight_decay=0.01,
    lr_scheduler_type="linear",
    seed=42,
    # The custom collator reads the raw "text" column. With the default (True) the
    # Trainer drops every column that is not an argument of the model's forward(),
    # so "text" would be gone before the collator runs.
    remove_unused_columns=False,
)

data_collator = ConditionalDataCollator(tokenizer=tokenizer)

trainer = SFTTrainer(
    model=custom_model,
    args=training_arguments,
    train_dataset=processed_dataset,
    peft_config=lora_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    data_collator=data_collator,
)
print("==> Trainer is ready for Qwen3 1.7B!\n")
print("To start training, run the command: trainer.train()")