In [None]:
from datasets import load_from_disk

# Load the dialogue dataset that was previously saved under ./data.
# NOTE(review): `ds` is not referenced by any other cell visible here
# (train_lora re-loads the dataset from its own `dataset_path`), so this
# cell looks like an exploratory leftover - confirm before relying on it.
ds = load_from_disk("./data/russian_dialogues_20k")

In [None]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
from transformers import BitsAndBytesConfig
from datasets import load_from_disk
import torch
import os
import torch

In [None]:
def load_model_and_tokenizer(model_name: str, local_dir: str = "./model/TinyLlama"):
    """Load a 4-bit (NF4) quantized causal LM and its tokenizer, caching them locally.

    If ``local_dir`` already holds a saved model (detected by the presence of
    ``config.json``), both the tokenizer and the model are loaded from there.
    Otherwise they are loaded from ``model_name`` and then written to
    ``local_dir`` so that later calls can reuse the local copy.

    A directory that merely exists but has no ``config.json`` (empty, or left
    half-written by an interrupted run) is NOT treated as a cache; the model is
    fetched from ``model_name`` again instead of failing on the broken folder.

    Args:
        model_name: Identifier or path handed to ``from_pretrained`` when no
            usable local copy exists.
        local_dir: Directory used as the local cache.

    Returns:
        ``(tokenizer, model)``. The tokenizer pads with its EOS token on the
        left; the model has ``use_cache`` disabled and ``pretraining_tp`` set
        to 1.
    """
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    # Only a directory that actually contains a model config counts as a
    # local copy; `isfile` is False for missing paths, so no separate
    # existence check is needed.
    has_local_copy = os.path.isfile(os.path.join(local_dir, "config.json"))

    if has_local_copy:
        print(f"üîÅ –ó–∞–≥—Ä—É–∑–∫–∞ –º–æ–¥–µ–ª–∏ –∏–∑ –ª–æ–∫–∞–ª—å–Ω–æ–≥–æ –∫–∞—Ç–∞–ª–æ–≥–∞: {local_dir}")
        source = local_dir
    else:
        print(f"‚¨áÔ∏è –ú–æ–¥–µ–ª—å –Ω–µ –Ω–∞–π–¥–µ–Ω–∞ –ª–æ–∫–∞–ª—å–Ω–æ, –∑–∞–≥—Ä—É–∂–∞–µ–º –∏–∑ –∏–Ω—Ç–µ—Ä–Ω–µ—Ç–∞ –∏ —Å–æ—Ö—Ä–∞–Ω—è–µ–º –≤: {local_dir}")
        source = model_name

    tokenizer = AutoTokenizer.from_pretrained(source)
    model = AutoModelForCausalLM.from_pretrained(
        source,
        device_map="auto",
        torch_dtype=torch.float16,
        quantization_config=bnb_config
    )

    if not has_local_copy:
        # Save the model and tokenizer to the local directory.
        tokenizer.save_pretrained(local_dir)
        model.save_pretrained(local_dir)

    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token
    # NOTE(review): left padding is usually chosen for generation; confirm it
    # is also what is wanted for training batches.
    tokenizer.padding_side = "left"
    model.config.use_cache = False
    model.config.pretraining_tp = 1

    return tokenizer, model


In [None]:
def apply_lora(model):
    """Wrap ``model`` with a LoRA adapter on its attention projections.

    The model is first passed through ``prepare_model_for_kbit_training``,
    then wrapped by ``get_peft_model`` with a rank-8 LoRA configuration
    (alpha 16, dropout 0.05, no bias training) targeting the q/k/v/o
    projection modules. Gradient checkpointing is enabled on the wrapped
    model before it is returned.

    Args:
        model: The base causal LM to adapt.

    Returns:
        The PEFT-wrapped model.
    """
    attention_projections = ["q_proj", "k_proj", "v_proj", "o_proj"]
    adapter_settings = LoraConfig(
        task_type="CAUSAL_LM",
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        target_modules=attention_projections,
    )

    prepared_model = prepare_model_for_kbit_training(model)
    peft_model = get_peft_model(prepared_model, adapter_settings)
    peft_model.gradient_checkpointing_enable()
    return peft_model

In [None]:
def tokenize_russian_dialogues(example, tokenizer, max_length=512):
    """Render prompt/response pairs into the chat template and tokenize them.

    Works for both calling conventions:

    * a single example, where ``example["prompt"]`` and
      ``example["response"]`` are strings;
    * a batch, where each of those columns is a sequence of strings of equal
      length (what a ``map(..., batched=True)`` call hands to its function).
      Previously the batch form raised ``AttributeError`` because ``.strip()``
      was called on the list itself.

    Args:
        example: Mapping with ``"prompt"`` and ``"response"`` entries.
        tokenizer: Callable tokenizer; invoked with the rendered text (a
            string, or a list of strings for a batch) plus truncation and
            max-length padding options.
        max_length: Length every sequence is truncated/padded to.

    Returns:
        Whatever ``tokenizer`` returns for the rendered text.
    """
    def render(prompt, response):
        # NOTE(review): the system line below looks like mis-encoded
        # (mojibake) Cyrillic; it is kept byte-for-byte here - confirm the
        # notebook's real encoding before training on it.
        return (
            "<|system|>\n–¢—ã –ø–æ–ª–µ–∑–Ω—ã–π AI-–∞—Å—Å–∏—Å—Ç–µ–Ω—Ç.\n"
            f"<|user|>\n{prompt.strip()}\n"
            f"<|assistant|>\n{response.strip()}"
        )

    prompts = example["prompt"]
    responses = example["response"]

    if isinstance(prompts, str):
        # Single example.
        text = render(prompts, responses)
    else:
        # Batch: one rendered string per (prompt, response) pair.
        text = [render(p, r) for p, r in zip(prompts, responses)]

    return tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=max_length
    )


In [None]:
def tokenize_dataset(dataset, tokenizer, tokenize_fn, max_length=512):
    """Tokenize ``dataset`` with ``tokenize_fn`` and drop its original columns.

    ``tokenize_fn`` is called as ``tokenize_fn(item, tokenizer,
    max_length=max_length)`` through ``dataset.map`` with ``batched=True``,
    and every column listed in ``dataset.column_names`` is removed from the
    result.

    Args:
        dataset: Object exposing ``column_names`` and ``map``.
        tokenizer: Passed through to ``tokenize_fn`` as its second argument.
        tokenize_fn: Callable that turns the mapped item into tokenizer output.
        max_length: Forwarded to ``tokenize_fn`` as a keyword argument.

    Returns:
        The value returned by ``dataset.map``.
    """
    original_columns = dataset.column_names

    def _tokenize(batch):
        # Bind the tokenizer and length so `map` only has to supply the data.
        return tokenize_fn(batch, tokenizer, max_length=max_length)

    return dataset.map(
        _tokenize,
        batched=True,
        remove_columns=original_columns,
    )

In [None]:
def train_lora(
    model_name: str,
    dataset_path: str,
    output_dir: str,
    epochs: int = 3,
    batch_size: int = 4,
    lr: float = 2e-4,
    max_length: int = 512,
    tokenize_fn=None
):
    """Fine-tune a LoRA adapter on a dataset stored on disk and save it.

    Steps: load the quantized model and tokenizer, wrap the model with LoRA,
    load the dataset from ``dataset_path`` and tokenize its ``"train"`` split,
    train with ``Trainer``, then save the model to ``output_dir``.

    Side effect: several ``torch.backends`` flags are set process-wide
    (TF32 / reduced-precision matmul allowed, cuDNN benchmark mode on,
    deterministic mode off).

    Args:
        model_name: Forwarded to ``load_model_and_tokenizer``.
        dataset_path: Directory passed to ``load_from_disk``; the loaded
            object must provide a ``"train"`` entry.
        output_dir: Where checkpoints and the final model are written.
        epochs: Number of training epochs.
        batch_size: Per-device train batch size (gradients are additionally
            accumulated over 4 steps).
        lr: Peak learning rate for the cosine schedule.
        max_length: Sequence length forwarded to the tokenization step.
        tokenize_fn: Callable used to tokenize the dataset. When ``None``,
            ``tokenize_russian_dialogues`` is used; previously ``None`` was
            forwarded as-is and failed with a ``TypeError`` only after the
            model had been loaded.
    """
    if tokenize_fn is None:
        tokenize_fn = tokenize_russian_dialogues

    # Loading
    tokenizer, model = load_model_and_tokenizer(model_name)
    model = apply_lora(model)
    dataset = load_from_disk(dataset_path)
    dataset = tokenize_dataset(
        dataset["train"],
        tokenizer=tokenizer,
        tokenize_fn=tokenize_fn,
        max_length=max_length
    )

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False  # required for causal LM
    )

    # Global speed-over-reproducibility switches for the CUDA/cuDNN backends.
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = False

    # Training arguments.
    # No evaluation strategy is passed: "no" is already the default, and the
    # `evaluation_strategy` keyword is renamed to `eval_strategy` in newer
    # transformers releases, so omitting it works on every version.
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=4,
        lr_scheduler_type="cosine",
        warmup_steps=100,
        num_train_epochs=epochs,
        learning_rate=lr,
        fp16=True,
        logging_steps=20,
        save_steps=200,
        save_total_limit=2,
        report_to="none"
    )

    # Trainer
    # NOTE(review): recent transformers releases prefer `processing_class`
    # over `tokenizer` here - confirm against the installed version.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        data_collator=data_collator
    )

    # Training
    trainer.train()
    model.save_pretrained(output_dir)

In [None]:
def merge_lora_adapter(
    base_model_path: str,
    adapter_path: str,
    save_path: str
):
    """Merge a trained LoRA adapter into its base model and save the result.

    The base model and tokenizer are obtained through
    ``load_model_and_tokenizer``. When ``base_model_path`` is an existing
    directory it is also passed as ``local_dir`` so that exactly this
    directory is loaded; before, the loader's default ``local_dir`` took
    precedence whenever it existed and ``base_model_path`` was silently
    ignored. For any other value the call is unchanged.

    NOTE(review): the base model comes back 4-bit quantized from
    ``load_model_and_tokenizer``, so the merge runs on quantized weights -
    confirm that the resulting precision is acceptable.

    Args:
        base_model_path: Local directory (or model identifier) of the base model.
        adapter_path: Directory containing the trained LoRA adapter.
        save_path: Directory the merged model and tokenizer are written to.
    """
    # Load the base model and tokenizer through the shared loader.
    if os.path.isdir(base_model_path):
        tokenizer, model = load_model_and_tokenizer(
            base_model_path, local_dir=base_model_path
        )
    else:
        tokenizer, model = load_model_and_tokenizer(base_model_path)

    # Attach the LoRA adapter and fold it into the base weights.
    model = PeftModel.from_pretrained(model, adapter_path)
    model = model.merge_and_unload()

    # Save the merged model together with its tokenizer.
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    print(f"‚úÖ –û–±—ä–µ–¥–∏–Ω—ë–Ω–Ω–∞—è –º–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞ –≤: {save_path}")

In [None]:
# Locations used by this fine-tuning step.
HUB_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
DIALOGUE_DATASET_DIR = "./data/russian_dialogues_20k"
ADAPTER_DIR = "./checkpoints/step_dialogue"
LOCAL_MODEL_DIR = "./model/TinyLlama"

# Train the LoRA adapter on the dialogue dataset.
train_lora(
    model_name=HUB_MODEL_ID,
    dataset_path=DIALOGUE_DATASET_DIR,
    output_dir=ADAPTER_DIR,
    epochs=3,
    batch_size=4,
    tokenize_fn=tokenize_russian_dialogues,
)

# Merge the trained adapter into the locally stored model. The base and save
# paths are the same directory, so the stored model is overwritten in place.
merge_lora_adapter(
    base_model_path=LOCAL_MODEL_DIR,   # locally saved / updated model
    adapter_path=ADAPTER_DIR,          # path to the trained adapter
    save_path=LOCAL_MODEL_DIR,         # overwrite the model
)