# Блок инициализации ученика/учителя

In [1]:
import re
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

def define_model(path_to_model):
    tokenizer = AutoTokenizer.from_pretrained(path_to_model)
    model = AutoModelForCausalLM.from_pretrained(path_to_model,
                                                 torch_dtype="auto",
                                                 device_map="auto")
    return tokenizer, model

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
student_tokenizer, student_model = define_model('./student_model/')
# teacher_tokenizer, teacher_model = define_model('./teacher_model/')

`torch_dtype` is deprecated! Use `dtype` instead!
The module name  (originally ) is not a valid Python identifier. Please rename the original module to avoid import issues.


# Блок парсинга правильных ответов от учителя для обучения ученика

In [None]:
import re
from datasets import Dataset

def parse_csqa_txt(filepath):
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()
    blocks = content.split(f"\n{'-'*100}\n")
    
    filtered_examples = []

    for block in blocks:
        if not block.strip():
            continue
        try:
            correct_answer = block[block.find('Correct answer from original table: ')+len('Correct answer from original table: '):block.find('\n\n<|im_start|>user')]
        except Exception:
            continue

        student_text = block[block.find('<|im_start|>assistant\n<think>\n\n</think>\n')+len('<|im_start|>assistant\n<think>\n\n</think>\n'):]

        patterns = [
            r'(?:correct|final|best)?\s*answer\s*[—:-]?\s*\*?([A-E])\*?',
            r'answer\s+is\s*:?\s*\*?([A-E])\*?',
            r'so the (?:best|correct) answer is\s*:?\s*\*?([A-E])',
            r'^\s*\*?([A-E])\s*—',
            r'\b([A-E])\s*—\s*[a-z]',
        ]

        student_answer = None
        for pattern in patterns:
            match = re.search(pattern, student_text, re.IGNORECASE | re.MULTILINE)
            if match:
                student_answer = match.group(1).upper()
                break

        if student_answer is None:
            candidates = re.findall(r'\b([A-E])\b', student_text, re.IGNORECASE)
            if candidates:
                student_answer = candidates[-1].upper()

        if student_answer and correct_answer == student_answer:
            clean_block = block.replace('<think>\n\n</think>\n\n', '')
            filtered_examples.append({'text': clean_block})

    return Dataset.from_list(filtered_examples)

dataset = parse_csqa_txt('train_data_CSQA.txt')
print(f"Осталось {len(dataset)} правильных примеров")

In [None]:
def parse_gsm8k_txt(filepath):
    with open(filepath, "r", encoding="utf-8") as f:
        content = f.read()

    pattern = r'<\|im_start\|>user\s*(.*?)\s*<\|im_end\|>\s*<\|im_start\|>assistant\s*(.*?)\s*<\|im_end\|>'
    matches = re.findall(pattern, content, flags=re.DOTALL)

    examples = []
    for user_content, assistant_content in matches:
        cleaned_assistant = re.sub(r'(<think>.*?</think>)', '', assistant_content, flags=re.DOTALL).strip()
        examples.append({
            "messages": [
                {"role": "user", "content": user_content.strip()},
                {"role": "assistant", "content": cleaned_assistant}
            ]
        })
    return Dataset.from_list(examples)

In [6]:
import re
from datasets import Dataset

def parse_coin_flip(filepath):
    with open(filepath, 'r', encoding='utf-8') as f:
        content = f.read()
    blocks = content.split(f"\n{'-'*100}\n")
    
    filtered_examples = []

    for block in blocks:
        if not block.strip():
            continue
        try:
            correct_answer = block[block.find('Correct answer: ')+len('Correct answer: '):block.find('\n\n<|im_start|>user')]
            # print(f"{correct_answer=}")
        except Exception:
            continue

        teacher_text = block[block.find('<|im_start|>assistant\n<think>\n\n</think>\n')+len('<|im_start|>assistant\n<think>\n\n</think>\n\n'):]

        teacher_text = teacher_text.lower()
        # print(f"{teacher_text=}")

        teacher_answer = teacher_text[teacher_text .find('**answer: ')+len('**answer: '):]
        # print(f"{teacher_answer=}")

        if correct_answer in teacher_answer:
            clean_block = block.replace('<think>\n\n</think>\n\n', '')
            filtered_examples.append({'text': clean_block})

    return Dataset.from_list(filtered_examples)

dataset = parse_coin_flip('train_data_Coin_Flip.txt')
print(f"Осталось {len(dataset)} правильных примеров")

Осталось 1588 правильных примеров


# Блок LoRA-обертки и обучения

In [7]:
from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(r=16, 
                         lora_alpha=32,
                         target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
                         lora_dropout=0.05, 
                         bias="none", 
                         task_type=TaskType.CAUSAL_LM)

peft_model = get_peft_model(student_model, lora_config)
peft_model.print_trainable_parameters()

trainable params: 10,092,544 || all params: 606,142,464 || trainable%: 1.6650


In [9]:
from trl import SFTTrainer, SFTConfig
OUTPUT_DIR = "./qwen2_rank_16_Coin_Flip"

sft_config = SFTConfig(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    num_train_epochs=3,
    learning_rate=1e-4,
    bf16=True,
    logging_steps=10,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=3,
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    report_to="none",
    remove_unused_columns=True,    
    neftune_noise_alpha=0.6,
    # max_seq_length=1024,
    packing=False,
    dataset_text_field="text"
)

trainer = SFTTrainer(
    model=peft_model,
    train_dataset=dataset,
    args=sft_config,
    processing_class=student_tokenizer
)

trainer.train()

Adding EOS to train dataset: 100%|██████████| 1588/1588 [00:00<00:00, 46853.18 examples/s]
Tokenizing train dataset: 100%|██████████| 1588/1588 [00:00<00:00, 8406.13 examples/s]
Truncating train dataset: 100%|██████████| 1588/1588 [00:00<00:00, 1502493.74 examples/s]
The model is already on multiple devices. Skipping the move to device specified in `args`.


Step,Training Loss
10,2.3044
20,1.3886
30,0.7924
40,0.5787
50,0.4791
60,0.4589
70,0.4148
80,0.4239
90,0.4016
100,0.3964




TrainOutput(global_step=300, training_loss=0.4986765185991923, metrics={'train_runtime': 996.275, 'train_samples_per_second': 4.782, 'train_steps_per_second': 0.301, 'total_flos': 1713908613120000.0, 'train_loss': 0.4986765185991923, 'epoch': 3.0})

# Блок валидации обученного ученика

In [10]:
# Инициализация дообученного ученика

from peft import PeftModel

supa_model = PeftModel.from_pretrained(student_model, "./qwen2_rank_16_Coin_Flip/checkpoint-300")
supa_tokenizer = AutoTokenizer.from_pretrained("./qwen2_rank_16_Coin_Flip/checkpoint-300")
supa_tokenizer.pad_token = supa_tokenizer.eos_token



In [11]:
# Инициализация few-shot'a

few_shot_GSM8K = '''
Problem:
Olivia has $23. She bought five cupcakes for $3 each and a milkshake for $4. How much money does she have left?
Solution:
First, calculate the total cost of the cupcakes:
5 cupcakes × $3 = $15.
Add the cost of the milkshake: $15 + $4 = $19.
Subtract from her initial amount: $23 − $19 = $4.
So, Olivia has $4 left.
Final Answer: 4​

Problem:
A bakery sells cookies in packs of 6. If a customer buys 9 packs, how many cookies does the customer get in total?
Solution:
Each pack contains 6 cookies.
The customer buys 9 packs.
Total cookies = 6 × 9 = 54.
Final Answer: 54

Problem:
There are 42 students in a class. One-third of them are boys. How many girls are in the class?
Solution:
Number of boys = 42 ÷ 3 = 14.
Number of girls = total students − boys = 42 − 14 = 28.
Final Answer: 28

Problem:
A car travels 60 miles per hour. How many miles does it travel in 2 hours and 30 minutes?
Solution:
Convert 2 hours 30 minutes to hours: 2.5 hours.
Distance = speed × time = 60 × 2.5 = 150 miles.
Final Answer: 150

Problem:
James has 3 times as many marbles as Lisa. Together, they have 48 marbles. How many marbles does James have?

Solution:
Let Lisa have x marbles.
Then James has 3x marbles.
Together: x + 3x = 4x = 48.
So, x = 48 ÷ 4 = 12.
James has 3 × 12 = 36 marbles.
Final Answer: 36
'''

In [15]:
# Функция генерации ответа на запрос

def invoke_llm(type_of_model, task, type):
    if type_of_model == 'teacher':
        tokenizer = teacher_tokenizer
        model = teacher_model
    # Когда обучим ученика
    elif type_of_model == 'student':
        tokenizer = supa_tokenizer
        model = supa_model
    # elif type_of_model == 'student':
    #         tokenizer = student_tokenizer
    #         model = student_model
    else:
        return None
    
    if type == 'GSM8K':
        few_shot = few_shot_GSM8K
    # elif type == 'Coin Flip':
    #     few_shot = few_shot_Coin_Flip
    # elif type == 'CSQA':
    #     few_shot = few_shot_CSQA

    if type  == 'GSM8K':
        content = few_shot + '\n' + task
    elif type in ['CSQA', 'Coin Flip']:
        content = task
    
    messages = [{"role": "user", 
                "content": content}]
        
    text = tokenizer.apply_chat_template(messages,
                                         tokenize=False,
                                         add_generation_prompt=True,
                                         enable_thinking=False)
    
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=32768
    )

    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
    output_text = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
    return output_text

In [16]:
# Импорт данных

import polars as pl

# GSM8K
splits_GSM8K = {'train': 'main/train-00000-of-00001.parquet'}
df_train_GSM8K = pl.read_parquet("hf://datasets/openai/gsm8k/" + splits_GSM8K["train"])
questions_GSM8K = df_train_GSM8K['question'].to_list()
answers_GSM8K = df_train_GSM8K['answer'].to_list()

# CSQA
splits_CSQA = {'train': 'data/train-00000-of-00001.parquet', 'validation': 'data/validation-00000-of-00001.parquet', 'test': 'data/test-00000-of-00001.parquet'}
df_train_CSQA = pl.read_parquet('hf://datasets/tau/commonsense_qa/' + splits_CSQA['train'])
questions_CSQA = df_train_CSQA['question'].to_list()
choices_CSQA = df_train_CSQA['choices'].to_list()
answers_CSQA = df_train_CSQA['answerKey'].to_list()

# Coin Flip
from datasets import load_dataset
dataset = load_dataset("skrishna/coin_flip")
df_train = dataset['train'].to_pandas()  # или .to_dict(), если удобнее
df_train_coinf_flip = pl.from_pandas(df_train)
question_coin_flip_test = df_train_coinf_flip['inputs'].to_list()
answer_coin_flip_test = df_train_coinf_flip['targets'].to_list()

In [17]:
# Предобработка coin flip (установка заглавных букв для имен)

def find_all_indices(text, substring):
    indices = []
    start_index = 0
    while True:
        index = text.find(substring, start_index)
        if index == -1:
            break
        indices.append(index+len(substring))
        start_index = index + 1
    return indices

for p in range(len(question_coin_flip_test)):
    question_coin_flip_test[p] = question_coin_flip_test[p].replace(' Q: ', '')
    question_coin_flip_test[p] = question_coin_flip_test[p].replace('  Is the', ' Is the') + 'Answer yes or no.'

    new_massive = []
    a_coin = 'A coin is heads up. '
    flip = 'flip the coin. '
    flips = 'flips the coin. '

    idx1 = find_all_indices(question_coin_flip_test[p], a_coin)[0]
    idx2 = find_all_indices(question_coin_flip_test[p], flip)
    idx3 = find_all_indices(question_coin_flip_test[p], flips)

    question_coin_flip_test[p] = question_coin_flip_test[p][:idx1] + question_coin_flip_test[p][idx1].upper() + question_coin_flip_test[p][idx1+1:]

    # Для flip the coin. 
    for fl in [idx2, idx3]:
        if isinstance(fl, list):
            if len(fl) != 0:
                for elem in fl:
                    question_coin_flip_test[p] = question_coin_flip_test[p][:elem] + question_coin_flip_test[p][elem].upper() + question_coin_flip_test[p][elem+1:]
        elif isinstance(fl, int):
            question_coin_flip_test[p] = question_coin_flip_test[p][:fl] + question_coin_flip_test[p][fl].upper() + question_coin_flip_test[p][fl+1:]

In [18]:
# Генерация обученным учеником Coin Flip

for p in range(1, 11):
    thousand_counter = 0
    with open(f'./student_outputs_Coin_Flip_posttrain/{p}_16.txt', 'w', encoding='utf-8', buffering=1) as f:
        for q, a in zip(question_coin_flip_test, answer_coin_flip_test):
            student_response = invoke_llm('student', q, 'Coin Flip')
            correct_answer = a
            if thousand_counter == 100:
                break
            if '**Answer: ' in student_response:
                f.write(f'Correct answer: {correct_answer}\n\n')
                f.write(f'Student solution: {student_response}\n\n')
                f.write('-'*100 + '\n\n')
                f.flush()
                print(f'Correct answer: {correct_answer}\n\nStudent solution: {student_response}')
                print(f'-----'*20)
                thousand_counter += 1

print("Готово!")

Correct answer: no

Student solution: - The coin starts **heads up**.
- **Sager does **not** flip** the coin.
- **Zyheir flips** the coin.

Since Zyheir flips the coin, it changes from **heads** to **tails**.

**Answer: No.**
----------------------------------------------------------------------------------------------------
Correct answer: yes

Student solution: No.  

The coin is **heads up** initially.  
- **Mailey does not flip** the coin, so it remains heads up.  
- **Maurisa does not flip** the coin, so it remains heads up.  

So, the answer is: **No**, the coin is **not** still heads up.  

**Answer: No.**
----------------------------------------------------------------------------------------------------
Correct answer: no

Student solution: - The coin starts **heads up**.
- **Murraylee does **not** flip** the coin.
- **Meilich flips** the coin.

Since Meilich flips the coin, the coin will **change** from heads to tails.

**Answer: No.**
----------------------------------------

In [None]:
# Генерация обученным учеником CSQA

for p in range(1, 11):
    with open(f'./student_outputs_CSQA_posttrain/{p}_16.txt', 'w', encoding='utf-8', buffering=1) as f:
        for q, c, a in zip(questions_CSQA[:100], choices_CSQA[:100], answers_CSQA[:100]):
            
            labels = c['label']
            text_labels = c['text']
            variants = ''

            for glyph, txt in zip(labels, text_labels):
                if glyph == labels[-1]:
                    variants = variants + glyph + ' — ' + txt
                else:
                    variants = variants + glyph + ' — ' + txt + ', '

            print(f'{q + ' Here are the answer options: ' + variants}')

            student_response = invoke_llm('student', q + '\nHere are the answer options:\n' + variants, 'CSQA')
            correct_answer = a

            f.write(f'Correct answer: {correct_answer}\n\n')
            f.write(f'Student solution: {student_response}\n\n')
            f.write('-'*100 + '\n\n')
            f.flush()
            print(f'Correct answer: {correct_answer}\n\nStudent solution: {student_response}')
            print(f'-----'*20)

print("Готово!")

In [None]:
# Подсчет числа правильных ответов

str_to_delete = '''prompt="\nProblem:\nOlivia has $23. She bought five cupcakes for $3 each and a milkshake for $4. How much money does she have left?\nSolution:\nFirst, calculate the total cost of the cupcakes:\n5 cupcakes × $3 = $15.\nAdd the cost of the milkshake: $15 + $4 = $19.\nSubtract from her initial amount: $23 − $19 = $4.\nSo, Olivia has $4 left.\nFinal Answer: 4\u200b\n\nProblem:\nA bakery sells cookies in packs of 6. If a customer buys 9 packs, how many cookies does the customer get in total?\nSolution:\nEach pack contains 6 cookies.\nThe customer buys 9 packs.\nTotal cookies = 6 × 9 = 54.\nFinal Answer: 54\n\nProblem:\nThere are 42 students in a class. One-third of them are boys. How many girls are in the class?\nSolution:\nNumber of boys = 42 ÷ 3 = 14.\nNumber of girls = total students − boys = 42 − 14 = 28.\nFinal Answer: 28\n\nProblem:\nA car travels 60 miles per hour. How many miles does it travel in 2 hours and 30 minutes?\nSolution:\nConvert 2 hours 30 minutes to hours: 2.5 hours.\nDistance = speed × time = 60 × 2.5 = 150 miles.\nFinal Answer: 150\n\nProblem:\nJames has 3 times as many marbles as Lisa. Together, they have 48 marbles. How many marbles does James have?\n\nSolution:\nLet Lisa have x marbles.\nThen James has 3x marbles.\nTogether: x + 3x = 4x = 48.\nSo, x = 48 ÷ 4 = 12.\nJames has 3 × 12 = 36 marbles.\nFinal Answer: 36\n\nJanet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?'''

student_score = []
teacher_score = []

import glob
import re
txt_arr = []

for filename in glob.glob("*_ft_32.txt"):
    if '_ft' in filename:
        txt_arr.append(filename)

for txt in txt_arr:
    print(f'{txt=}')
    with open(f'{txt}', 'r', encoding='utf-8') as f:
        data = f.read()
    if str_to_delete in data:
        data = data.replace(str_to_delete, '')
    splitted = data.split(f'-----'*20)
    # print(splitted)

    comparison_dict = {'student_score': 0,
                       'teacher_score': 0}

    for task in splitted:
        # Correct answer
        correct_answer = task[task.find('Correct answer: ')+len('Correct answer: '):task.find('\n\nStudent solution')]
        correct_answer = correct_answer.replace(',', '').replace('.', '')
        
        # Student parsed answer
        idx = task.find('\n\nTeacher solution')
        if idx == -1:
            student_answer = task[-50:]
        else:
            start = max(0, idx - 50)
            student_answer = task[start:idx]

        student_answer = student_answer.replace(',', '').replace('.', '')

        # Teacher parsed answer
        teacher_answer = task[-50:].replace(',', '').replace('.', '')

        student_answer = re.findall(r'-?\b\d+\b', student_answer)
        teacher_answer = re.findall(r'-?\b\d+\b', teacher_answer)
        
        # print(f"{correct_answer=}")
        # print(f"{student_answer=}")
        # print(f"{teacher_answer=}")

        if correct_answer in student_answer:
            comparison_dict['student_score'] += 1
        if correct_answer in teacher_answer:
            comparison_dict['teacher_score'] += 1
        
        # print('---------------------------------')
    print(f'{comparison_dict=}')
    student_score.append(comparison_dict['student_score'])
    teacher_score.append(comparison_dict['teacher_score'])
    print('---------------------------------')

In [None]:
import re
from datasets import Dataset

def parse_gsm8k_txt(filepath):
    with open(filepath, "r", encoding="utf-8") as f:
        content = f.read()

    # Находим все пары user/assistant
    pattern = r'<\|im_start\|>user\s*(.*?)\s*<\|im_end\|>\s*<\|im_start\|>assistant\s*(.*?)\s*<\|im_end\|>'
    matches = re.findall(pattern, content, flags=re.DOTALL)

    examples = []
    for user_content, assistant_content in matches:
        # Убираем возможные <think>...</think>, если они есть (по вашему примеру их нет, но на всякий случай)
        cleaned_assistant = re.sub(r'<think>\s*</think>', '', assistant_content, flags=re.DOTALL).strip()
        
        examples.append({
            "prompt": [{"role": "user", "content": user_content.strip()}],
            "completion": [{"role": "assistant", "content": cleaned_assistant}]
        })

    return Dataset.from_list(examples)

# Пример использования
dataset = parse_gsm8k_txt('train_data_GSM8K.txt')
print(f"Полных примеров: {len(dataset)}")
print(next(iter(dataset)))