# Блок обучения на данных учителя

In [None]:
import re
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

def define_model(path_to_model):
    """Load a tokenizer and a causal language model from ``path_to_model``.

    The model is loaded with ``torch_dtype="auto"`` and ``device_map="auto"``.

    Args:
        path_to_model: Location passed unchanged to both ``from_pretrained``
            calls.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    load_options = {"torch_dtype": "auto", "device_map": "auto"}
    tokenizer = AutoTokenizer.from_pretrained(path_to_model)
    model = AutoModelForCausalLM.from_pretrained(path_to_model, **load_options)
    return tokenizer, model

In [None]:
# Load the student tokenizer and model from the local ./student_model/ directory.
student_tokenizer, student_model = define_model('./student_model/')

In [None]:
# Convert the fine-tuning data to ChatML format

def parse_gsm8k_txt(filepath):
    """Read ChatML dialogues from a text file and return them as a ``Dataset``.

    A dialogue is one ``<|im_start|>user ... <|im_end|>`` turn followed
    (optionally after whitespace) by one
    ``<|im_start|>assistant ... <|im_end|>`` turn. Text outside such pairs
    is ignored. ``<think>`` blocks that contain only whitespace including at
    least one newline are removed, and the result is stripped.

    Args:
        filepath: Path of the UTF-8 text file to parse.

    Returns:
        A ``Dataset`` with a single ``"text"`` column, one row per dialogue.
    """
    dialogue_re = re.compile(
        r'(<\|im_start\|>user[\s\S]*?<\|im_end\|>\s*<\|im_start\|>assistant[\s\S]*?<\|im_end\|>)'
    )
    empty_think_re = re.compile(r'<think>\s*\n\s*</think>', flags=re.DOTALL)

    with open(filepath, "r", encoding="utf-8") as source:
        raw_text = source.read()

    records = [
        {"text": empty_think_re.sub('', dialogue).strip()}
        for dialogue in dialogue_re.findall(raw_text)
    ]
    return Dataset.from_list(records)

# Parse the training file and report how many complete user/assistant
# examples were extracted from it.
dataset = parse_gsm8k_txt('train_data.txt')
print(f"Полных примеров: {len(dataset)}")

In [None]:
# Wrap the student model with LoRA adapters

from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    # Names of the modules that receive an adapter.
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

peft_model = get_peft_model(student_model, lora_config)
# Print how many parameters are trainable after wrapping.
peft_model.print_trainable_parameters()

In [None]:
# Training

from trl import SFTTrainer, SFTConfig

OUTPUT_DIR = "./qwen2_rank_32"

sft_config = SFTConfig(
    # Output, logging and checkpointing
    output_dir=OUTPUT_DIR,
    logging_steps=10,
    report_to="none",
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    # Batch size and duration
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    num_train_epochs=3,
    # Optimiser and learning-rate schedule
    optim="adamw_torch",
    learning_rate=1e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    bf16=True,
    neftune_noise_alpha=0.6,
    # Dataset handling: each row's "text" column is used as-is, unpacked
    dataset_text_field="text",
    packing=False,
    remove_unused_columns=True,
    # max_seq_length=1024,
)

trainer = SFTTrainer(
    args=sft_config,
    model=peft_model,
    processing_class=student_tokenizer,
    train_dataset=dataset,
)

trainer.train()

# Блок валидации обученного ученика

In [None]:
# Load the fine-tuned student: base model plus the LoRA adapter of a saved checkpoint

from peft import PeftModel

# Single source of truth for the checkpoint used by both the adapter and
# the tokenizer.
CHECKPOINT_DIR = "./qwen2_rank_32/checkpoint-231"

# NOTE(review): `student_model` is the same object that was passed to
# get_peft_model() and trained earlier in this notebook. If those cells ran
# in the current session, confirm that loading the adapter on top of it is
# intended; otherwise reload the base model first.
supa_model = PeftModel.from_pretrained(student_model, CHECKPOINT_DIR)

# The model is only used for generation from here on. Switch to eval mode
# so dropout layers are disabled, in case the base model was left in
# training mode by the training run.
supa_model.eval()

supa_tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT_DIR)
# Use the end-of-sequence token for padding.
supa_tokenizer.pad_token = supa_tokenizer.eos_token

In [None]:
# Few-shot examples prepended to every task sent to the model.
#
# NOTE(review): the first example's final answer is followed by a
# zero-width space (U+200B). It is written below as an explicit escape so
# that it is visible in the source; the string value is unchanged. It looks
# like a copy-paste artifact - confirm before removing it, because it is
# part of the prompt the model receives.

few_shot = '''
Problem:
Olivia has $23. She bought five cupcakes for $3 each and a milkshake for $4. How much money does she have left?
Solution:
First, calculate the total cost of the cupcakes:
5 cupcakes × $3 = $15.
Add the cost of the milkshake: $15 + $4 = $19.
Subtract from her initial amount: $23 − $19 = $4.
So, Olivia has $4 left.
Final Answer: 4\u200b

Problem:
A bakery sells cookies in packs of 6. If a customer buys 9 packs, how many cookies does the customer get in total?
Solution:
Each pack contains 6 cookies.
The customer buys 9 packs.
Total cookies = 6 × 9 = 54.
Final Answer: 54

Problem:
There are 42 students in a class. One-third of them are boys. How many girls are in the class?
Solution:
Number of boys = 42 ÷ 3 = 14.
Number of girls = total students − boys = 42 − 14 = 28.
Final Answer: 28

Problem:
A car travels 60 miles per hour. How many miles does it travel in 2 hours and 30 minutes?
Solution:
Convert 2 hours 30 minutes to hours: 2.5 hours.
Distance = speed × time = 60 × 2.5 = 150 miles.
Final Answer: 150

Problem:
James has 3 times as many marbles as Lisa. Together, they have 48 marbles. How many marbles does James have?

Solution:
Let Lisa have x marbles.
Then James has 3x marbles.
Together: x + 3x = 4x = 48.
So, x = 48 ÷ 4 = 12.
James has 3 × 12 = 36 marbles.
Final Answer: 36
'''

In [None]:
# Generate an answer with the student model

def invoke_llm(type_of_model, task, few_shot=few_shot, max_new_tokens=32768):
    """Generate the model's reply to ``task`` preceded by few-shot examples.

    Args:
        type_of_model: Which model to query. Only ``'student'`` is supported.
        task: Text of the problem to solve.
        few_shot: Example block placed before the task in the user message.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded reply with special tokens removed and leading/trailing
        newlines stripped, or ``None`` when ``type_of_model`` is not
        ``'student'``.
    """
    if type_of_model != 'student':
        return None

    tokenizer = supa_tokenizer
    model = supa_model

    # Separate the examples from the task with a real newline. The previous
    # '/n' inserted the two literal characters "/" and "n" into the prompt.
    messages = [{"role": "user",
                 "content": few_shot + '\n' + task}]

    # Render the conversation as text and append the assistant-turn prefix;
    # enable_thinking=False is forwarded to the chat template.
    text = tokenizer.apply_chat_template(messages,
                                         tokenize=False,
                                         add_generation_prompt=True,
                                         enable_thinking=False)

    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens
    )

    # The generated sequence starts with the prompt tokens; keep only what
    # comes after them.
    prompt_length = len(model_inputs.input_ids[0])
    output_ids = generated_ids[0][prompt_length:].tolist()
    content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
    return content

In [None]:
# Load the validation data (GSM8K test split)

import polars as pl

splits = {'test': 'main/test-00000-of-00001.parquet'}
# Read the 'test' split. The mapping above has no 'train' key, and the code
# below reads `df_test`, so both the key and the variable name must refer to
# the test split.
df_test = pl.read_parquet("hf://datasets/openai/gsm8k/" + splits["test"])

questions = df_test['question'].to_list()
answers = df_test['answer'].to_list()

In [None]:
# Generation with the fine-tuned student
#
# Writes 15 report files (1_ft_32.txt ... 15_ft_32.txt). Each one holds the
# reference answer and the student's reply for the first 250 questions.

answer_marker = '\n#### '
record_divider = '-' * 100 + '\n\n'

for run_idx in range(1, 16):
    with open(f'{run_idx}_ft_32.txt', 'w', encoding='utf-8') as report:
        for question, reference in zip(questions[:250], answers[:250]):
            student_response = invoke_llm('student', question)
            # The reference answer is whatever follows the marker line.
            marker_pos = reference.find(answer_marker)
            correct_answer = reference[marker_pos + len(answer_marker):]

            report.write(
                f'Correct answer: {correct_answer}\n\n'
                f'Student solution: {student_response}\n\n'
                + record_divider
            )

            print(f'Correct answer: {correct_answer}\n\nStudent solution: {student_response}')
            print('-----' * 20)

print("Готово!")

In [None]:
# Count the number of correct answers

# Prompt text that is removed from each report file before scoring: a
# `prompt="` prefix, the five worked few-shot examples and one question
# (Janet's ducks). The few-shot part contains a zero-width space (\u200b)
# after the first "Final Answer: 4".
str_to_delete = '''prompt="\nProblem:\nOlivia has $23. She bought five cupcakes for $3 each and a milkshake for $4. How much money does she have left?\nSolution:\nFirst, calculate the total cost of the cupcakes:\n5 cupcakes × $3 = $15.\nAdd the cost of the milkshake: $15 + $4 = $19.\nSubtract from her initial amount: $23 − $19 = $4.\nSo, Olivia has $4 left.\nFinal Answer: 4\u200b\n\nProblem:\nA bakery sells cookies in packs of 6. If a customer buys 9 packs, how many cookies does the customer get in total?\nSolution:\nEach pack contains 6 cookies.\nThe customer buys 9 packs.\nTotal cookies = 6 × 9 = 54.\nFinal Answer: 54\n\nProblem:\nThere are 42 students in a class. One-third of them are boys. How many girls are in the class?\nSolution:\nNumber of boys = 42 ÷ 3 = 14.\nNumber of girls = total students − boys = 42 − 14 = 28.\nFinal Answer: 28\n\nProblem:\nA car travels 60 miles per hour. How many miles does it travel in 2 hours and 30 minutes?\nSolution:\nConvert 2 hours 30 minutes to hours: 2.5 hours.\nDistance = speed × time = 60 × 2.5 = 150 miles.\nFinal Answer: 150\n\nProblem:\nJames has 3 times as many marbles as Lisa. Together, they have 48 marbles. How many marbles does James have?\n\nSolution:\nLet Lisa have x marbles.\nThen James has 3x marbles.\nTogether: x + 3x = 4x = 48.\nSo, x = 48 ÷ 4 = 12.\nJames has 3 × 12 = 36 marbles.\nFinal Answer: 36\n\nJanet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?'''

student_score = []
teacher_score = []

import glob
import re

# Layout of one record inside a report file.
_CORRECT_TAG = 'Correct answer: '
_STUDENT_TAG = '\n\nStudent solution'
_TEACHER_TAG = '\n\nTeacher solution'
_RECORD_SEPARATOR = '-----' * 20
# Only the end of a solution is searched for the final answer.
_TAIL_CHARS = 50
_INTEGER_RE = re.compile(r'-?\b\d+\b')


def _drop_number_punctuation(text):
    """Remove commas and dots so that e.g. "1,000." reads as "1000"."""
    return text.replace(',', '').replace('.', '')


def _tail_numbers(section):
    """Return the integers found in the last ``_TAIL_CHARS`` characters."""
    return _INTEGER_RE.findall(_drop_number_punctuation(section[-_TAIL_CHARS:]))


def score_task(task):
    """Score one report record.

    The reference answer is compared with the integers found near the end of
    the student's solution and, when present, of the teacher's solution.
    Each search is restricted to its own section: a short reply cannot pick
    up the number from the "Correct answer" line, and a record without a
    teacher section never scores for the teacher.

    Args:
        task: Text of one record (the part between two separator lines).

    Returns:
        ``(student_correct, teacher_correct)`` as booleans. Records without
        a "Correct answer"/"Student solution" pair score ``(False, False)``.
    """
    tag_pos = task.find(_CORRECT_TAG)
    if tag_pos == -1:
        return False, False
    answer_start = tag_pos + len(_CORRECT_TAG)
    student_pos = task.find(_STUDENT_TAG, answer_start)
    if student_pos == -1:
        return False, False

    correct_answer = _drop_number_punctuation(task[answer_start:student_pos]).strip()
    if not correct_answer:
        return False, False

    student_start = student_pos + len(_STUDENT_TAG)
    teacher_pos = task.find(_TEACHER_TAG, student_start)
    if teacher_pos == -1:
        student_numbers = _tail_numbers(task[student_start:])
        teacher_numbers = []
    else:
        student_numbers = _tail_numbers(task[student_start:teacher_pos])
        teacher_numbers = _tail_numbers(task[teacher_pos + len(_TEACHER_TAG):])

    return correct_answer in student_numbers, correct_answer in teacher_numbers


# Sorted so the per-file scores are appended in a reproducible order.
txt_arr = sorted(glob.glob("*_ft_32.txt"))

for txt in txt_arr:
    print(f'{txt=}')
    with open(txt, 'r', encoding='utf-8') as f:
        data = f.read()
    # Drop the logged prompt (if any) so its numbers cannot be mistaken for
    # an answer.
    data = data.replace(str_to_delete, '')
    splitted = data.split(_RECORD_SEPARATOR)

    comparison_dict = {'student_score': 0,
                       'teacher_score': 0}

    for task in splitted:
        student_ok, teacher_ok = score_task(task)
        comparison_dict['student_score'] += student_ok
        comparison_dict['teacher_score'] += teacher_ok

    print(f'{comparison_dict=}')
    student_score.append(comparison_dict['student_score'])
    teacher_score.append(comparison_dict['teacher_score'])
    print('---------------------------------')