<a href="https://colab.research.google.com/github/M1croZavr/CoTResearch/blob/master/CoT_SelfConsistency_research.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Установка и импорт необходимых библиотек, а также клонирование (git clone) репозитория с необходимым кодом и данными. В случае запуска с petals необходимо раскомментировать соответствующие ячейки.

In [None]:
# %pip install -q petals

In [None]:
# Fetch the project repository: later cells import its modules and read its data files.
!git clone https://github.com/M1croZavr/CoTResearch.git

In [None]:
# Print the Python version of the current runtime.
!python --version

In [None]:
from google.colab import drive


# Mount Google Drive at ./drive; the HTTP generation loop writes its result file under ./drive/MyDrive.
drive.mount('./drive')

In [None]:
import torch
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import time
import json
import shutil
import requests
from pathlib import Path
from tqdm.auto import tqdm
# from transformers import BloomTokenizerFast, set_seed
# from petals import DistributedBloomForCausalLM
from CoTResearch.data_preprocessing import FormattedPrompts, FormattedInputs
from CoTResearch.data_postprocessing import AnswersList

In [None]:
# Use the GPU when torch reports CUDA as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print(f'Running device: {DEVICE}')

Загрузка модели BLOOM из petals

In [None]:
# MODEL_NAME = "bigscience/bloom-petals"
# tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)
# model = DistributedBloomForCausalLM.from_pretrained(MODEL_NAME)
# model = model.to(DEVICE)

Пример одного prompt с 2 CoT-примерами (exemplars)

In [None]:
# Demo: build a prompt object with n_exemplars=2 and print the input produced
# for the first line of the test file.
example_prompts = FormattedPrompts(
    data_path=Path('CoTResearch/GSM8K_data/train_data.jsonl'),
    n_exemplars=2,
    random_seed=123
    )
example_prompts.sample_prompts()
example_inputs = FormattedInputs(example_prompts)


# JSON Lines is UTF-8 text: pass the encoding explicitly so that decoding does
# not depend on the platform's locale default.
with open(Path('CoTResearch/GSM8K_data/test_data.jsonl'), encoding='utf-8') as file:
    example_prompt = example_inputs.sample_input(file.readline())
print(example_prompt)

Инициализирую объект форматированных prompts и делаю сэмплинг из тренировочного набора. Для воспроизводимости экспериментов фиксирую seed как для формирования экземпляров prompts, так и для выбора тестовых вопросов из GSM8K.

In [None]:
# Seeds are fixed so that both the prompt exemplars and the selected test
# questions are reproducible between runs.
PROMPTS_SET_SEED = 14799
TEST_SPLIT_SEED = 77777
prompts = FormattedPrompts(
    data_path=Path('CoTResearch/GSM8K_data/train_data.jsonl'),
    n_exemplars=5,
    random_seed=PROMPTS_SET_SEED
    )
prompts.sample_prompts()
inputs = FormattedInputs(prompts)


# Build the few-shot prompting subsample dataset:
# N_DATA_POINTS test questions, N_PATHS generations ensembled per question.
N_DATA_POINTS = 100
N_PATHS = 10
# JSON Lines is UTF-8 text: pass the encoding explicitly so that decoding does
# not depend on the platform's locale default. The file is only needed while
# reading, so it is closed before the sampling starts.
with open(Path('CoTResearch/GSM8K_data/test_data.jsonl'), encoding='utf-8') as file:
    lines = file.readlines()

np.random.seed(TEST_SPLIT_SEED)
# NOTE: randint draws WITH replacement, so the same question may be selected
# more than once. Left unchanged on purpose: a different sampling scheme would
# change which questions a given seed selects.
data_points_indices = np.random.randint(0, len(lines), size=(N_DATA_POINTS, ))
for data_point_index in data_points_indices:
    inputs.sample_input(lines[data_point_index])

In [None]:
# Show the first prepared input; the generation loops use these entries as prompts.
print(inputs.inputs[0])

Объект answers_list хранит отформатированные ответы модели и истинные ответы

In [None]:
# Shared accumulator: the generation loops call add_answer() on it once per question.
answers_list = AnswersList()

В Hugging Face Inference API использую аналогичную модель BLOOM 176B и свой токен для доступа к HTTP API моделей.

In [None]:
import os

MODEL_NAME = "bloom"
API_URL = f"https://api-inference.huggingface.co/models/bigscience/{MODEL_NAME}"
# SECURITY: never commit an access token. It is read from the HF_API_TOKEN
# environment variable instead (in a notebook, set it beforehand, e.g.
# os.environ["HF_API_TOKEN"] = getpass.getpass()). Any token that was ever
# committed to a repository must be considered leaked and revoked.
_HF_API_TOKEN = os.environ.get("HF_API_TOKEN", "").strip()
# Without a token no Authorization header is sent; the API may then reject or
# rate-limit the requests.
HEADERS = {"Authorization": f"Bearer {_HF_API_TOKEN}"} if _HF_API_TOKEN else {}


def query(payload, timeout=300):
    """Send one request to the Inference API endpoint and return the decoded JSON.

    Args:
        payload: JSON-serialisable request body that is posted to ``API_URL``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` can block indefinitely on a stalled connection.

    Returns:
        The decoded JSON body of the response.

    Raises:
        requests.exceptions.HTTPError: If the server answers with a 4xx/5xx status.
        requests.exceptions.RequestException: On connection problems or timeout.
        ValueError: If the response body is not valid JSON.
    """
    response = requests.post(API_URL, headers=HEADERS, json=payload, timeout=timeout)
    # Fail here with the HTTP status instead of handing an error document to
    # the caller, where it would only surface as an opaque indexing error.
    response.raise_for_status()
    return response.json()

Цикл получения генераций по всей тестовой подвыборке при помощи сформированных входов. Для каждого входа генерирую несколько вариантов ответа, чтобы в дальнейшем их агрегировать.

In [None]:
# Self-consistency chain-of-thought prompting: for every question collect
# N_PATHS sampled completions through the HTTP API, then store them together
# with the ground truth.
results_path = f'./drive/MyDrive/{PROMPTS_SET_SEED}_ensemble10_30_02.jsonl'
sampling_parameters = {
    "top_k": 30,
    "top_p": None,
    "temperature": 0.2,
    "repetition_penalty": None,
    "max_new_tokens": 249,
    "max_time": None,
    "return_full_text": False,
    "num_return_sequences": 1,
    "do_sample": True,
    "stop": ["Q:", "\n\n"],
}
request_options = {
    "use_cache": False,
    "wait_for_model": True,
}

for i in tqdm(range(N_DATA_POINTS)):
    prompt = inputs.inputs[i]
    gt_answer = inputs.ground_truths[i]
    predictions = []
    completed = 0
    # A failed request is retried without any upper bound; every attempt,
    # successful or not, is preceded by a 5 second pause.
    while completed < N_PATHS:
        time.sleep(5)
        try:
            output = query(
                payload={
                    "inputs": prompt.strip(),
                    "parameters": sampling_parameters,
                    "options": request_options,
                }
            )
            predictions.append(output[0]["generated_text"])
            print(output)
        except Exception as error:
            print(f'Exception occured on iteration {i}/{[completed]}...{error}')
            continue
        completed += 1
    answers_list.add_answer(predictions, gt_answer)
    # Checkpoint every 10th question so that an interrupted run keeps its progress.
    if i % 10 == 0:
        answers_list.write_to_file(results_path)
    print('\n')

answers_list.write_to_file(results_path)

Цикл получения генераций по всей тестовой подвыборке при помощи сформированных prompts и сэмплирования с агрегацией. Здесь используется распределённый запуск через petals.

In [None]:
# Petals variant of the generation loop. It relies on `tokenizer` and `model`,
# which are only defined once the commented-out model loading cell is enabled.
# NOTE(review): `return_full_text` and `stop` mirror the parameters sent to the
# HTTP API in the other loop — confirm that model.generate accepts them.
# NOTE(review): no sampling arguments (do_sample / temperature / top_k) are
# passed — if generate decodes deterministically by default, all N_PATHS
# outputs of one prompt are identical and ensembling adds nothing; confirm.
for i in tqdm(range(N_DATA_POINTS)):
    prompt = inputs.inputs[i]
    gt_answer = inputs.ground_truths[i]
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(DEVICE)
    predictions = [
        tokenizer.decode(
            model.generate(
                input_ids,
                max_new_tokens=128,
                return_full_text=False,
                stop=['\n\n', 'Q:'],
            )[0]
        )
        for _ in range(N_PATHS)
    ]
    answers_list.add_answer(predictions, gt_answer)

# Анализ полученных результатов

In [None]:
# List the stored ensemble result files that the analysis cells load.
!ls './CoTResearch/experiments/ensemble_results'

In [None]:
def extract_result(filepath: str) -> list:
    """Load a JSON Lines file into a list of decoded documents.

    Args:
        filepath: Path to a text file holding one JSON document per line.

    Returns:
        The decoded documents in file order. Blank lines are skipped, so an
        empty file yields an empty list.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # JSON text is UTF-8: do not rely on the platform's locale default.
    with open(filepath, encoding='utf-8') as f:
        # json.loads rejects empty input, so whitespace-only lines (for
        # example a trailing empty line) must be filtered out first.
        return [json.loads(line) for line in f if line.strip()]


# Load the four stored runs. Judging by the plot labels used further down, the
# file names encode <seed>_ensemble<paths>_<top_k>_<temperature digits>.
results_dir = './CoTResearch/experiments/ensemble_results/'
ensemble_results_list = [
    AnswersList(extract_result(results_dir + file_name))
    for file_name in (
        '12345_ensemble5_50_06.jsonl',
        '14799_ensemble10_30_02.jsonl',
        '77777_ensemble10_40_015.jsonl',
        '77777_ensemble5_40_085.jsonl',
    )
]
# Keep the individual names as well; the error analysis cell refers to them.
(ensemble_result1, ensemble_result2,
 ensemble_result3, ensemble_result4) = ensemble_results_list

In [None]:
# Accuracy of every stored run, followed by the mean and spread across runs.
ensemble_results_acc = [run.calculate_accuracy() for run in ensemble_results_list]
ensemble_results_mean = np.mean(ensemble_results_acc)
# np.std is called with its default arguments here.
ensemble_results_std = np.std(ensemble_results_acc)
print('Среднее значение accuracy:', ensemble_results_mean)
print('Стандартное отклонение accuracy:', ensemble_results_std)
print(ensemble_results_acc)

In [None]:
# One bar per stored run, in the order of ensemble_results_acc:
# (legend label, bar colour).
bar_styles = [
    ('5 paths | top_k=50 | T=0.6', 'forestgreen'),
    ('10 paths | top_k=30 | T=0.2', 'limegreen'),
    ('10 paths | top_k=40 | T=0.15', 'turquoise'),
    ('5 paths | top_k=40 | T=0.85', 'mediumseagreen'),
]
seed_tick_labels = ['12345', '14799', '77777', '77777']

plt.figure(figsize=(12, 9))
for position, (legend_label, bar_color) in enumerate(bar_styles):
    plt.bar(position, ensemble_results_acc[position], label=legend_label, color=bar_color)
plt.xticks(range(len(ensemble_results_acc)), labels=seed_tick_labels)
plt.ylim((0, 0.25))
plt.xlabel('random seed')
plt.ylabel('Test accuracy')
# The trailing semicolon keeps the notebook from echoing the return value.
plt.legend();

In [None]:
def atleast_one(answers_list_object: 'AnswersList') -> int:
    """Count the items whose ensemble output contains the correct answer at least once.

    A prediction counts as correct when it equals the ground truth exactly,
    or when both values parse as numbers that differ by less than 0.01.
    Predictions that cannot be parsed as numbers are skipped one by one, so a
    single malformed prediction no longer hides a correct one that follows it.

    Args:
        answers_list_object: Iterable of mappings with the keys
            ``'ground_truth'`` and ``'predicted'`` (the collection of
            predictions made for that item).

    Returns:
        The number of items with at least one correct prediction.
    """
    correct = 0
    for item in answers_list_object:
        gt = item['ground_truth']
        preds = item['predicted']
        if gt in preds:
            correct += 1
            continue
        gt_value = _as_float(gt)
        if gt_value is None:
            # A non-numeric ground truth can only be matched exactly.
            continue
        for pred in preds:
            pred_value = _as_float(pred)
            if pred_value is not None and abs(gt_value - pred_value) < 0.01:
                correct += 1
                break
    return correct


def _as_float(value):
    """Return ``value`` converted to float, or None when it is not numeric."""
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return None

In [None]:
# Summary table with one row per stored run. The first two columns are typed in
# by hand rather than computed from the loaded results.
seed_index = pd.Series(['12345', '14799', '77777', '77777'], name='random seed')
column_names = ['Корректные', 'Количество сгенерированных \'размышлений\' для ансамблирования']
run_rows = [
    [14, 5],
    [17, 10],
    [19, 10],
    [15, 5],
]
errors_analysis_df = pd.DataFrame(data=run_rows, index=seed_index, columns=column_names)

# Per run: number of questions with at least one correct prediction in the ensemble.
atleast_array = np.array([
    atleast_one(run)
    for run in (ensemble_result1, ensemble_result2, ensemble_result3, ensemble_result4)
])
errors_analysis_df['Количество CoT в одном prompt'] = 5
errors_analysis_df['Хотя бы один правильный'] = atleast_array
errors_analysis_df['Всего'] = 100
# The bare expression makes the notebook render the table.
errors_analysis_df