## Data

Downloading and reading the data

In [1]:
%%capture
!git clone https://github.com/openai/grade-school-math.git

In [2]:
import re
import json
from os import path
from typing import Dict, List, Union

In [3]:
BASE_PATH = "/content/grade-school-math/grade_school_math/data/"

# one JSON object per line: {"question": ..., "answer": ...}
with open(path.join(BASE_PATH, "test.jsonl")) as f:
    test = [json.loads(line) for line in f if line.strip()]

In [4]:
test[1]

{'question': 'A robe takes 2 bolts of blue fiber and half that much white fiber.  How many bolts in total does it take?',
 'answer': 'It takes 2/2=<<2/2=1>>1 bolt of white fiber\nSo the total amount of fabric is 2+1=<<2+1=3>>3 bolts of fabric\n#### 3'}

In [5]:
len(test)

1319

The 8-shot prompt from the original paper, which we will use as well

In [6]:
# Few-shot exemplars for all arithmetic reasoning tasks, from Wei et al. (2022)
# table 17 from https://arxiv.org/pdf/2203.11171.pdf
ARTICLE_PROMPT = """Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: We start with 15 trees. Later we have 21 trees. The difference must be the number of trees they planted. So, they must have planted 21 - 15 = 6 trees. The answer is 6.
Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are 3 cars in the parking lot already. 2 more arrive. Now there are 3 + 2 = 5 cars. The answer is 5.
Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
A: Leah had 32 chocolates and Leah’s sister had 42. That means there were originally 32 + 42 = 74 chocolates. 35 have been eaten. So in total they still have 74 - 35 = 39 chocolates. The answer is 39.
Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?
A: Jason had 20 lollipops. Since he only has 12 now, he must have given the rest to Denny. The number of lollipops he has given to Denny must have been 20 - 12 = 8 lollipops. The answer is 8.
Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?
A: He has 5 toys. He got 2 from mom, so after that he has 5 + 2 = 7 toys. Then he got 2 more from dad, so in total he has 7 + 2 = 9 toys. The answer is 9.
Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?
A: There are 4 days from monday to thursday. 5 computers were added each day. That means in total 4 * 5 = 20 computers were added. There were 9 computers in the beginning, so now there are 9 + 20 = 29 computers. The answer is 29.
Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?
A: Michael initially had 58 balls. He lost 23 on Tuesday, so after that he has 58 - 23 = 35 balls. On Wednesday he lost 2 more so now he has 35 - 2 = 33 balls. The answer is 33.
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
A: She bought 5 bagels for $3 each. This means she spent 5 * $3 = $15 on the bagels. She had $23 in beginning, so now she has $23 - $15 = $8. The answer is 8.
"""

Helper functions for working with the data

In [7]:
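def get_dataset_answer(answer: str) -> int:
    # GSM8K stores the final answer after "#### " at the end of the solution
    number = answer.split("#### ")[-1]

    # drop thousands separators so that e.g. "2,125" parses as 2125
    return int(number.replace(",", ""))

def compose_prompt(question: str) -> str:
    # 8-shot exemplars followed by the new question; the model continues after "A:"
    return f'{ARTICLE_PROMPT}Q: {question}\nA:'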
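Each GSM8K `answer` is a worked solution with calculator annotations such as `<<2/2=1>>` and the final number after `####`. One hypothetical way to get more few-shot exemplars is to convert dataset solutions into the same `Q: ... A: ... The answer is N.` format that `ARTICLE_PROMPT` uses. The sketch below assumes that stripping the annotations and joining the solution lines with periods is close enough to the paper's style; `split_gsm8k_answer` and `to_exemplar` are illustrative names, not part of the grade-school-math repo.

```python
import re
from typing import Tuple

# GSM8K calculator annotations look like "<<2+1=3>>"
ANNOTATION_RE = re.compile(r"<<[^<>]*>>")


def split_gsm8k_answer(answer: str) -> Tuple[str, int]:
    """Split a GSM8K answer into annotation-free reasoning and the final integer."""
    reasoning, final = answer.rsplit("####", 1)
    reasoning = ANNOTATION_RE.sub("", reasoning).strip()
    return reasoning, int(final.strip().replace(",", ""))


def to_exemplar(question: str, answer: str) -> str:
    """Format a GSM8K item like the Q/A pairs in ARTICLE_PROMPT (hypothetical helper)."""
    reasoning, final = split_gsm8k_answer(answer)
    steps = []
    for line in reasoning.splitlines():
        line = line.strip()
        if line:
            # GSM8K solution lines usually lack a trailing period
            steps.append(line if line.endswith(".") else line + ".")
    return f"Q: {question}\nA: {' '.join(steps)} The answer is {final}.\n"


sample = {
    "question": "A robe takes 2 bolts of blue fiber and half that much white fiber.  "
                "How many bolts in total does it take?",
    "answer": "It takes 2/2=<<2/2=1>>1 bolt of white fiber\n"
              "So the total amount of fabric is 2+1=<<2+1=3>>3 bolts of fabric\n#### 3",
}
print(to_exemplar(sample["question"], sample["answer"]))
```

Exemplars built this way should come from `train.jsonl` rather than from the test split being evaluated; whether a longer prompt actually helps BLOOM 3B is not tested here.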

## BLOOM 3B

Environment information and setup

In [8]:
!nvidia-smi

Fri Feb 17 22:12:52 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   50C    P0    28W /  70W |      0MiB / 15360MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [1]:
%%capture
!pip install --quiet bitsandbytes
!pip install --quiet git+https://github.com/huggingface/transformers.git
!pip install --quiet accelerate

In [2]:
!pip freeze > environment_info.txt

Importing libraries and initializing everything we need

In [11]:
import torch
import pickle
import numpy as np
import pandas as pd
from scipy import stats
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [12]:
SEED = 1234
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f9d7bbc8050>

In [13]:
MODEL_NAME = "bigscience/bloom-3b"

model_8bit = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, 
    device_map="auto",
    load_in_8bit=True
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [14]:
# You will need to change this path and make sure the folder it points to exists
SAVE_PATH = '/content/drive/MyDrive/ML/Tinkoff NLP Lab 23/bloom_3b/'

In [15]:
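# half-open range [start_sample_idx, end_sample_idx): test problems 0..138
start_sample_idx = 0
end_sample_idx = 139

reasoning_paths = 40  # sampled solutions per problem, as in the original paper
batch_size = 10       # sequences produced by one generate() call
assert reasoning_paths % batch_size == 0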

In [16]:
all_results = []

Running the model on the data

In [None]:
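for sample_idx in tqdm(range(start_sample_idx, end_sample_idx)):
    sample_outputs = []
    # the prompt is the same for every batch of this problem, so encode it once
    prompt_to_infer = compose_prompt(test[sample_idx]['question'])
    text_to_infer = [prompt_to_infer] * batch_size
    encoded_input = tokenizer(text_to_infer, return_tensors='pt').to('cuda')

    for batch_idx in range(reasoning_paths // batch_size):
        # nucleus sampling: each copy of the prompt yields its own reasoning path
        outputs = model_8bit.generate(
            input_ids=encoded_input['input_ids'],
            attention_mask=encoded_input['attention_mask'],
            max_new_tokens=100,
            do_sample=True,
            top_p=0.9
        )

        for output in outputs:
            decoded_seq = tokenizer.decode(output, skip_special_tokens=True)
            # drop the prompt (plus the space after "A:") and keep the first generated line
            generated_lines = decoded_seq[len(prompt_to_infer) + 1:].splitlines()
            sample_outputs.append(generated_lines[0] if generated_lines else '')

    all_results.append(sample_outputs)

    # save every problem separately so an interrupted Colab session loses little work
    with open(f'{SAVE_PATH}{sample_idx}_res.pkl', 'wb') as handle:
        pickle.dump(sample_outputs, handle, protocol=pickle.HIGHEST_PROTOCOL)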

## Results

Packing the results into a DataFrame

In [23]:
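def get_predicted_answer(text: str) -> Union[None, float]:
    # match numbers with an optional sign and decimal point
    # (pattern from https://www.regular-expressions.info/floatingpoint.html);
    # thousands separators are not handled: "1,000" is split into "1" and "000"
    numbers = re.findall(r'[-+]?[0-9]*\.?[0-9]+', text)
    # the last number in the line is taken as the model's answer
    return float(numbers[-1]) if numbers else None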
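The pattern above splits a number written with thousands separators: for "1,000" it finds "1" and "000", so the extracted answer becomes 0. A hedged sketch of a variant that first removes commas sitting between a digit and a group of exactly three digits; `get_predicted_answer_v2` is a hypothetical name, and the numbers reported in this notebook were produced with the original function, not with this one.

```python
import re
from typing import Optional

# a comma between a digit and exactly three digits is treated as a thousands separator
THOUSANDS_SEP_RE = re.compile(r"(?<=\d),(?=\d{3}\b)")
# same float pattern as get_predicted_answer
NUMBER_RE = re.compile(r"[-+]?[0-9]*\.?[0-9]+")


def get_predicted_answer_v2(text: str) -> Optional[float]:
    """Last number in `text`, with "1,000"-style separators removed first."""
    numbers = NUMBER_RE.findall(THOUSANDS_SEP_RE.sub("", text))
    return float(numbers[-1]) if numbers else None


print(get_predicted_answer_v2("So in total they earned 1,000 dollars."))  # 1000.0
```

The trade-off is that a comma-separated list written without spaces, such as "2,125", is always read as a single number.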

In [24]:
df_raws = []
for sample_idx in range(len(all_results)):
    generated_answers = []
    for answer_string in all_results[sample_idx]:
        generated_answers.append(get_predicted_answer(answer_string))
    df_raws.append(generated_answers)

In [25]:
df = pd.DataFrame(df_raws)

# majority vote over the sampled answers (NaN is ignored); pandas returns the
# modes in sorted order, so ties are broken toward the smallest value
df['mode'] = df.iloc[:, :reasoning_paths].mode(axis=1)[0]

df['correct'] = [
    get_dataset_answer(test[start_sample_idx + i]['answer'])
    for i in range(len(all_results))
]

df['mode_eq_correct'] = df['mode'] == df['correct']

# number of individual reasoning paths whose answer matches the reference
df['n_matches'] = df.iloc[:, :reasoning_paths].eq(df['correct'], axis=0).sum(axis=1)

df

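|     | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 34 | 35 | 36 | 37 | 38 | 39 | mode | correct | mode_eq_correct | n_matches |
|-----|---|---|---|---|---|---|---|---|---|---|-----|----|----|----|----|----|----|------|---------|-----------------|-----------|
| 0 | 12.0 | 2.0 | 104.0 | 18.0 | 104.0 | 250.0 | 36.0 | 2.0 | 3.0 | 38.0 | ... | 1.0 | 24.0 | 72.00 | 460.0 | 12.0 | 80.0 | 2.0 | 18 | False | 1 |
| 1 | 8.0 | 8.0 | 4.0 | 3.0 | 2.0 | 1.0 | 4.0 | 7.5 | 4.0 | 2.0 | ... | -2.0 | 4.0 | 1.50 | 9.5 | 2.0 | 4.0 | 4.0 | 3 | False | 4 |
| 2 | 0.0 | 150.0 | 50.0 | 0.0 | 0.0 | 0.0 | 150.0 | 0.0 | 0.0 | 50.0 | ... | 0.0 | 0.0 | 650.00 | 0.0 | 0.0 | 0.0 | 0.0 | 70000 | False | 0 |
| 3 | 720.0 | 2550.0 | 120.0 | 240.0 | 30.0 | 600.0 | 60.0 | 30.0 | 240.0 | 360.0 | ... | 1600.0 | 960.0 | 6.25 | 2160.0 | 720.0 | 720.0 | 720.0 | 540 | False | 0 |
| 4 | 10.0 | 75.0 | 30.0 | 2440.0 | 90.0 | 35.0 | 45.0 | 17.0 | NaN | 31.0 | ... | 10.0 | 60.0 | 7.00 | 45.0 | 110.0 | 900.0 | 10.0 | 20 | False | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 134 | 44.0 | 60.0 | 4500.0 | 16.0 | 8.0 | 72.0 | 200.0 | 120.0 | 0.0 | 40.0 | ... | 3600.0 | 20.0 | 60.00 | 60.0 | 150.0 | 0.0 | 60.0 | 720 | False | 0 |
| 135 | 40.0 | 10.0 | 4.0 | 15.0 | 45.0 | 15.0 | 15.0 | 5.0 | 25.0 | 15.0 | ... | 5.0 | 20.0 | 5.00 | 10.0 | 20.0 | 25.0 | 15.0 | 40 | False | 1 |
| 136 | 91.0 | 22.0 | 67.0 | 6.0 | 17.0 | 13.0 | 14.0 | 20.0 | 23.0 | 39.0 | ... | 11.0 | 16.0 | 9.00 | 15.0 | 23.0 | 14.0 | 6.0 | 6 | True | 3 |
| 137 | 24.5 | 16.0 | 5.4 | 22.5 | 23.0 | 5.0 | 30.0 | 5.0 | 100.0 | 18.5 | ... | 5.0 | 20.0 | 23.00 | 12.5 | 12.5 | 20.0 | 5.0 | 29 | False | 0 |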


In [26]:
name = '_'.join(map(str, df.shape))
df.to_csv(f'{SAVE_PATH}{name}.csv')

"Correct" reasoning paths that led to the right answer...

In [46]:
test[33]['question']

'Gretchen has 110 coins. There are 30 more gold coins than silver coins. How many gold coins does Gretchen have?'

In [47]:
all_results[33][1]

'There are 110 coins. There are 30 more gold coins than silver coins. That means 110 - 30 = 70 gold coins. So Gretchen has 70 gold coins now.'

In [48]:
all_results[33][35]

'Gretchen has 110 coins. There are 30 more gold coins than silver coins. This means that she has 110 - 30 = 70 gold coins. She still has 70 coins, so the answer is 70.'
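Both paths reach the reference answer 70 only by accident: 110 - 30 is 80, and subtracting the difference from the total is not the right step in the first place. With $g$ gold and $s$ silver coins, the intended solution is:

$$
\begin{aligned}
g + s &= 110, \qquad g = s + 30 \\
(s + 30) + s &= 110 \;\Rightarrow\; 2s = 80 \;\Rightarrow\; s = 40 \\
g &= s + 30 = 70
\end{aligned}
$$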

Final statistics

In [49]:
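# plain CoT: share of individual reasoning paths with the right answer, per problem
plain_CoT = df['n_matches'] / reasoning_paths
# voting CoT: share of problems whose majority-vote answer is right
mode_CoT = df['mode_eq_correct'].mean()
# share of paths with no extractable number (count only the answer columns)
na_percent = df.iloc[:, :reasoning_paths].isna().values.sum() / (reasoning_paths * df.shape[0])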

print(f'plain CoT accuracy:  mean = {plain_CoT.mean() * 100:.2f}%')
print(f'voting CoT accuracy: mean = {mode_CoT * 100:.2f}%')
print(f'NaN percent in df: {na_percent * 100:.2f}%')

plain CoT accuracy:  mean = 2.50%
voting CoT accuracy: mean = 4.32%
NaN percent in df: 1.15%


## Conclusions

The goal of this work was to compare plain Chain-of-Thought (CoT) prompting with its ensembled version on the GSM8K dataset using the BLOOM 176B model.

During the study I ran into a technical difficulty: inference of BLOOM 176B through Petals was extremely slow. I had to either wait for hours until enough peers were available to run the model, or wait a minute to generate 100 new tokens. The original paper uses 40 reasoning paths, so generating even 100 samples would take 100 × 40 × 1 min = 4000 minutes, almost three days. Google Colab does not provide sessions that long, so I switched to BLOOM 3B in 8-bit. This cost a lot of quality but made it possible to run the experiment in a reasonable time.

**How the experiment was run**: the prompt consisted of eight problems with their answers (the same as in the [original paper](https://arxiv.org/pdf/2203.11171.pdf)), followed by the problem the model had to solve. The first line of the model's output was kept, and the last number in it was taken as the answer. This was repeated 40 times for each problem. Each of the forty numbers counted as an answer of plain CoT, while the result of the ensembled version was their mode.
The model stopped after 100 new tokens (enough to generate the complete answer in 99% of cases) and sampled with top_p = 0.9 (nucleus sampling). I did not experiment with other decoding settings, since the paper showed (Figure 4) that "Self-consistency is robust to various sampling strategies and parameters."
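The two accuracies compared in this experiment can be computed without pandas. A minimal sketch with hypothetical helper names: as in the notebook, unparsed answers (`None` here, NaN in the DataFrame) are ignored by the vote, and ties are broken toward the smallest value to mirror `mode(axis=1)[0]`, since pandas returns modes in sorted order.

```python
from collections import Counter
from typing import List, Optional, Sequence

Answer = Optional[float]  # None means no number could be extracted


def majority_vote(answers: Sequence[Answer]) -> Answer:
    """Most frequent extracted answer; ties go to the smallest value."""
    counts = Counter(a for a in answers if a is not None)
    if not counts:
        return None
    top = max(counts.values())
    return min(a for a, c in counts.items() if c == top)


def plain_accuracy(paths: List[List[Answer]], gold: List[float]) -> float:
    """Share of individual reasoning paths whose answer equals the reference."""
    hits = sum(a == g for answers, g in zip(paths, gold) for a in answers)
    return hits / sum(len(answers) for answers in paths)


def voting_accuracy(paths: List[List[Answer]], gold: List[float]) -> float:
    """Share of problems whose majority-vote answer equals the reference."""
    hits = sum(majority_vote(answers) == g for answers, g in zip(paths, gold))
    return hits / len(paths)


# toy data: 2 problems x 4 sampled paths each
paths = [[70.0, 70.0, 40.0, None], [3.0, 4.0, 4.0, 2.0]]
gold = [70.0, 3.0]
print(plain_accuracy(paths, gold), voting_accuracy(paths, gold))  # 0.375 0.5
```

With the same number of paths for every problem, `plain_accuracy` equals the notebook's mean of `n_matches / reasoning_paths`. The toy data also shows that voting can fix a problem (the first) or lose one that a single path got right (the second).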

**Results**: answers were generated for 139 problems. Plain CoT reached 2.50% accuracy, and its ensembled version 4.32% (6 of 139 problems, 1.7× higher). BLOOM 3B with voting even slightly outperformed PaLM 8B with plain CoT (4.32% vs 4.1%, [table 3](https://arxiv.org/pdf/2206.14858v2.pdf)) while having less than half as many parameters. This supports voting over several generated solutions as a reasonable strategy.
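How much weight the 4.32% vs 4.1% comparison can bear depends on the sample size: 4.32% on 139 problems means 6 correct answers. A minimal sketch of a Wilson score interval for that proportion in plain Python; `wilson_interval` is a hypothetical helper, and z = 1.96 assumes a two-sided 95% level.

```python
import math
from typing import Tuple


def wilson_interval(successes: int, n: int, z: float = 1.96) -> Tuple[float, float]:
    """Wilson score interval for a binomial proportion (z = 1.96 gives ~95%)."""
    if n <= 0:
        raise ValueError("n must be positive")
    p = successes / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half_width = z * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    return center - half_width, center + half_width


# voting CoT: 6 correct out of 139 problems
low, high = wilson_interval(6, 139)
print(f"95% CI for voting CoT accuracy: {low:.1%} .. {high:.1%}")
```

SciPy should give the same interval through `scipy.stats.binomtest(6, 139).proportion_ci(method='wilson')` (SciPy 1.7 or newer), which would put the notebook's otherwise unused `stats` import to work. The resulting interval spans several percentage points and includes 4.1%.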

**How the results could be improved:**
1. Use a larger language model (BLOOM 176B, GPT-3 175B, PaLM 540B, etc.)
2. Ask the model to write arithmetic expressions between special markers, so that a calculator can be called to compute them correctly (the current model often generates lines like *That means 110 - 30 = 70 gold coins.*)
3. Increase the number of reasoning paths
4. Put more examples into the prompt
5. Also vote across different prompts (reorder the exemplars, use different exemplar sets, paraphrase the question itself)
6. Instead of majority voting, take the arithmetic mean of the answers, or a mode over characters or even BPE tokens
7. Generation could be made somewhat faster and more accurate (fewer NaNs) by stopping at the first newline instead of always generating N new tokens. A custom `stopping_criteria` passed to `generate` in Hugging Face Transformers is enough for that.
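The calculator idea from item 2 could look roughly like the sketch below. It assumes the model has been prompted to wrap arithmetic in GSM8K-style `<<expression=result>>` markers (the same markup as in the dataset's `answer` field). `safe_eval` and `apply_calculator` are hypothetical helpers, not the calculator from the grade-school-math repo, and they only repair the text after generation.

```python
import ast
import operator
import re

# arithmetic the hypothetical calculator accepts: + - * / and unary signs
_BIN_OPS = {ast.Add: operator.add, ast.Sub: operator.sub,
            ast.Mult: operator.mul, ast.Div: operator.truediv}
_UNARY_OPS = {ast.UAdd: operator.pos, ast.USub: operator.neg}

# "<<expression=claimed>>", optionally followed by the claimed number as in GSM8K
CALC_RE = re.compile(r"<<([^=<>]+)=([^<>]*)>>([-+]?\d*\.?\d+)?")


def safe_eval(expr: str) -> float:
    """Evaluate a plain arithmetic expression without calling eval()."""
    def walk(node: ast.AST) -> float:
        if (isinstance(node, ast.Constant) and isinstance(node.value, (int, float))
                and not isinstance(node.value, bool)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _BIN_OPS:
            return _BIN_OPS[type(node.op)](walk(node.left), walk(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _UNARY_OPS:
            return _UNARY_OPS[type(node.op)](walk(node.operand))
        raise ValueError(f"unsupported expression: {expr!r}")
    return walk(ast.parse(expr.strip(), mode="eval").body)


def apply_calculator(text: str) -> str:
    """Replace the model's claimed results with the calculator's own."""
    def fix(match):
        expr, _, claimed_tail = match.groups()
        try:
            value = safe_eval(expr)
        except (SyntaxError, ValueError, ZeroDivisionError):
            return match.group(0)  # leave anything we cannot evaluate untouched
        result = str(int(value)) if float(value).is_integer() else str(value)
        return f"<<{expr}={result}>>" + (result if claimed_tail is not None else "")
    return CALC_RE.sub(fix, text)


print(apply_calculator("That means 110-30=<<110-30=70>>70 gold coins."))
```

A calculator that intervenes during decoding would instead pause at the `=` inside a marker and inject the value, so that later steps build on the corrected number. This post-hoc version only fixes the final text, which is still enough for `get_predicted_answer` to pick up a correct last number.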