In [2]:
!pip install -q petals datasets

/bin/bash: /opt/conda/lib/libtinfo.so.6: no version information available (required by /bin/bash)
[0m

In [3]:
# Базовые импорты
import os

import torch
from transformers import BloomTokenizerFast
from petals import DistributedBloomForCausalLM
from tqdm import tqdm
import json
from collections import Counter

## Попробуем сгенерировать текст при помощи модели

In [4]:
# --- Experiment configuration ---
# Checkpoint identifier handed to both the tokenizer and the distributed model.
MODEL_NAME = "bigscience/bloom-petals"

# Values forwarded as `tuning_mode` / `pre_seq_len` when the model is created.
TUNING_MODE = "ptune"
NUM_PREFIX_TOKENS = 16

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

SEED = 42
# Assigned to `tokenizer.model_max_length` once the tokenizer is loaded.
MODEL_MAX_LENGTH = 256

In [5]:
# Load the tokenizer for MODEL_NAME, pad on the right and set its maximum length.
tokenizer = BloomTokenizerFast.from_pretrained(MODEL_NAME)
tokenizer.padding_side = 'right'
tokenizer.model_max_length = MODEL_MAX_LENGTH
# Create the distributed BLOOM model and move the returned object to DEVICE.
# NOTE(review): tuning_mode='ptune' with pre_seq_len presumably enables trainable prefix prompts;
# nothing in this notebook trains them -- confirm they are wanted for plain inference.
# NOTE(review): request_timeout is presumably expressed in seconds -- confirm against the Petals docs.
model = DistributedBloomForCausalLM.from_pretrained(
    MODEL_NAME,
    pre_seq_len=NUM_PREFIX_TOKENS, 
    tuning_mode=TUNING_MODE,
    request_timeout=1000
).to(DEVICE)

Downloading:   0%|          | 0.00/263 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/96.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/641 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/7.19G [00:00<?, ?B/s]

In [6]:
# Smoke test: sample a short continuation of a greeting to check the model responds.
TOP_K = 100
TEMPERATURE = 0.6
user_phrase = 'Привет, как твои дела?\n'

# Only the token-id tensor of the encoded phrase is needed by generate().
inputs = tokenizer(user_phrase, return_tensors='pt')['input_ids'].to(DEVICE)
outputs = model.generate(inputs, do_sample=True, temperature=TEMPERATURE, top_k=TOP_K, max_new_tokens=8)

# Decode the first (and only) returned sequence and show it.
bloom_answer_token = tokenizer.decode(outputs[0])
print(bloom_answer_token)

Feb 27 17:16:36.183 [[1m[38;5;208mWARN[0m] [[1m/opt/conda/lib/python3.7/site-packages/petals/client/inference_session.py.step:312[0m] Caught exception when running inference from block 0 (retry in 0 sec): ControlFailure('Connect failed. msg=failed to find peers: routing: not found')
Feb 27 17:16:36.515 [[1m[38;5;208mWARN[0m] [[1m/opt/conda/lib/python3.7/site-packages/petals/client/inference_session.py.step:312[0m] Caught exception when running inference from block 0 (retry in 1 sec): ControlFailure('Connect failed. msg=failed to find peers: routing: not found')


Привет, как твои дела?
Здравствуй, как


Отлично, не без проблем, но с моделью установлен контакт, можно переходить к следующему этапу.

## Теперь посмотрим на датасет, с которым предлагается поэкспериментировать

In [7]:
# Read the JSON-lines test file: one JSON object per line.
# Lines read from a file keep their trailing newline, so a bare truthiness check never
# filters anything; strip() is what actually skips blank / whitespace-only lines.
# The encoding is explicit because the questions contain non-ASCII punctuation.
with open('/kaggle/input/chain-of-thoughts/test.jsonl', encoding='utf-8') as file:
    data = [json.loads(line) for line in file if line.strip()]
data[:3]

[{'question': "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
  'answer': 'Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.\n#### 18'},
 {'question': 'A robe takes 2 bolts of blue fiber and half that much white fiber.  How many bolts in total does it take?',
  'answer': 'It takes 2/2=<<2/2=1>>1 bolt of white fiber\nSo the total amount of fabric is 2+1=<<2+1=3>>3 bolts of fabric\n#### 3'},
 {'question': 'Josh decides to try flipping a house.  He buys a house for $80,000 and then puts in $50,000 in repairs.  This increased the value of the house by 150%.  How much profit did he make?',
  'answer': 'The cost of the house and repairs came out to 80,000+50,000=$<<80000+50000=130000>>130

Как мы видим, предложенный датасет состоит из списка словарей с вопросом и развернутым ответом. Отметим, что в ответах численные выражения отдельно выделены в <<>>, а правильный ответ дополнительно помещен в конец строки после ####. Насколько я понял, ответы на все вопросы в GSM8K представляют собой целые числа, что немного облегчает задачу.

## Перейдем к эксперименту, сперва поймем что необходимо делать и что мы ожидаем

Отметим, что в статье https://arxiv.org/pdf/2203.11171.pdf (SELF-CONSISTENCY IMPROVES CHAIN OF THOUGHT REASONING IN LANGUAGE MODELS) есть аналогичные эксперименты с моделью GPT-3 c 175B параметров, что близко к рассматриваемой нами BLOOM-176B. Причем есть табличные результаты и подтверждающие их json-файлы на том же датасете GSM8K, так что нам есть на что ориентироваться по качеству. 

Основная рассматриваемая метрика - это Accuracy, то есть доля правильных ответов.

Также возьмем параметры сэмплирования из статьи, температура T=0.7, top-k с k=40. (Замечу, что для GPT-3 в работе отсечение по top-k не использовали совсем, но для BLOOM его лучше осуществлять). Далее данные параметры можно будет потюнить, чтобы попробовать получить эффект робастности из статьи.

Теперь натравим языковую модель на арифметические задачи из GSM8K и измерим accuracy в случае жадного CoT и self-consistency CoT.

In [8]:
# Few-shot chain-of-thought prompt: eight worked arithmetic Q/A examples, each ending with "The answer is N.".
# The standalone \n lines are escape sequences inside the (non-raw) triple-quoted string,
# so each of them contributes one extra newline between consecutive examples.
PROMT="""Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?
A: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.
\n
Q: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?
A: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.
\n
Q: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?
A: Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The answer is 39.
\n
Q: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?
A: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So he gave Denny 20 - 12 = 8. The answer is 8.
\n
Q: Shawn has five toys. For Christmas, he got two toys each from his mom and dad. How many toys does he have now?
A: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, then that is 4 more toys. 5 + 4 = 9. The answer is 9.
\n
Q: There were nine computers in the server room. Five more computers were installed each day, from monday to thursday. How many computers are now in the server room?
A: There were originally 9 computers. For each of 4 days, 5 more computers were added. So 5 * 4 = 20 computers were added. 9 + 20 is 29. The answer is 29.
\n
Q: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more. How many golf balls did he have at the end of wednesday?
A: Michael started with 58 golf balls. After losing 23 on tuesday, he had 58 - 23 = 35. After losing 2 more, he had 35 - 2 = 33 golf balls. The answer is 33.
\n
Q: Olivia has $23. She bought five bagels for $3 each. How much money does she have left?
A: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. So she has 23 - 15 dollars left. 23 - 15 is 8. The answer is 8.
\n
"""

В качестве промта воспользуемся примерами, предложенными авторами https://arxiv.org/pdf/2201.11903.pdf (Chain-of-Thought Prompting Elicits Reasoning in Large Language Models) для арифметических задач. К тому же именно их использовали в статье про self-consistency, так что сравнение с полученными там результатами будет более корректным. В конце каждого ответа в промте есть фраза "The answer is ...", что должно помочь языковой модели выдавать ответ.

Посмотрим на то, как справится модель с одним вопросом из датасета.

In [13]:
# Try the few-shot prompt on a single question with sampling enabled.
q = "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"
text = PROMT + 'Q: ' + q + '\nA: '
inputs = tokenizer(text, return_tensors='pt')['input_ids'].to(DEVICE)
# `stop_token_ids` is not a generate() argument; the end-of-sequence token is passed
# through `eos_token_id` so that generation can stop before max_new_tokens is exhausted.
outputs = model.generate(
    inputs,
    max_new_tokens=100,
    temperature=0.7,
    top_k=40,
    do_sample=True,
    eos_token_id=tokenizer.eos_token_id,
)

Feb 27 17:27:08.733 [[1m[38;5;208mWARN[0m] [[1m/opt/conda/lib/python3.7/site-packages/petals/client/inference_session.py.step:312[0m] Caught exception when running inference from block 0 (retry in 0 sec): ControlFailure('Connect failed. msg=failed to find peers: routing: not found')


In [14]:
# Decode the whole returned sequence; as the cell output shows, the text begins with the prompt itself.
tokenizer.decode(outputs[0])

"Q: There are 15 trees in the grove. Grove workers will plant trees in the grove today. After they are done, there will be 21 trees. How many trees did the grove workers plant today?\nA: There are 15 trees originally. Then there were 21 trees after some more were planted. So there must have been 21 - 15 = 6. The answer is 6.\n\n\nQ: If there are 3 cars in the parking lot and 2 more cars arrive, how many cars are in the parking lot?\nA: There are originally 3 cars. 2 more cars arrive. 3 + 2 = 5. The answer is 5.\n\n\nQ: Leah had 32 chocolates and her sister had 42. If they ate 35, how many pieces do they have left in total?\nA: Originally, Leah had 32 chocolates. Her sister had 42. So in total they had 32 + 42 = 74. After eating 35, they had 74 - 35 = 39. The answer is 39.\n\n\nQ: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops. How many lollipops did Jason give to Denny?\nA: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. So 

Стоит отметить, что ответа пришлось ждать порядка 10 минут. Так что придется работать с небольшим количеством примеров.

# Вспомогательные функции
Используются для того, чтобы получить ответ из предсказания нейросети и для агрегации нескольких ответов.

In [15]:
import re
# The ten decimal digit characters, '0' through '9', as a list of one-character strings.
NUMBER_SET = list('0123456789')

def _is_float(s):
    try:
        float(s)
        return True
    except:
        return False

# Marker phrase that precedes the final answer in a completion.
FINAL_ANS = 'answer is '
def clean_ans(ans):
    """Reduce the text following the answer marker to a bare answer token.

    Steps, in order: cut the text after the first '.' and the digits that directly
    follow it; drop trailing dots; keep only what follows the last '='; remove
    currency/percent/comma/quote characters; return the first space-separated token
    that parses as a float. If none does, take the first token that is not purely
    alphabetic (or the first token at all) and strip trailing letters from it.
    """
    dot = ans.find('.')
    if dot != -1:
        # Extend the cut over the digits that immediately follow the first dot.
        cut = dot + 1
        for ch in ans[cut:]:
            if ch not in NUMBER_SET:
                break
            cut += 1
        ans = ans[:cut]
    ans = ans.rstrip('.')

    ans = ans.rpartition('=')[2].strip()
    ans = ans.translate(str.maketrans('', '', '$,%€"'))
    tokens = ans.split(' ')
    numeric = next((tok for tok in tokens if _is_float(tok)), None)
    if numeric is not None:
        return numeric

    # No numeric token: fall back to the first token that is not purely alphabetic.
    candidate = next((tok for tok in tokens if not tok.isalpha()), tokens[0])
    while candidate and candidate[-1].isalpha():
        candidate = candidate[:-1]
    return candidate.strip()
    
def get_ans(pred):
    """Extract the cleaned final answer from a model completion.

    Everything from the first 'Q:' (and from the first '[eot]') onward is discarded,
    so a string that itself starts with 'Q:' yields ''. Newlines are removed, and the
    text after the last occurrence of FINAL_ANS is passed to clean_ans.
    Returns '' when the marker is absent.
    """
    head = pred.partition('Q:')[0].partition('[eot]')[0]
    flat = head.replace('\n', '').strip()
    marker_pos = flat.rfind(FINAL_ANS)
    if marker_pos < 0:
        return ''
    return clean_ans(flat[marker_pos + len(FINAL_ANS):].strip())


from collections import Counter
def get_maj(ans_list):
    """Return the most frequent answer in ``ans_list``.

    When every answer parses as a float, votes are counted on the float values and a
    float is returned; otherwise votes are counted on the raw answers as given.
    Raises IndexError for an empty list.
    """
    if all(_is_float(item) for item in ans_list):
        votes = Counter(float(item) for item in ans_list)
    else:
        votes = Counter(ans_list)
    winner, _count = votes.most_common()[0]
    return winner

def get_str_ans(pred):
    """Return the raw text after the last FINAL_ANS marker, minus one trailing dot.

    The completion is first cut at the first 'Q:' and at the first '[eot]', and
    newlines are removed. Returns '' when the marker is absent.
    """
    flat = pred.partition('Q:')[0].partition('[eot]')[0].replace('\n', '').strip()
    marker_pos = flat.rfind(FINAL_ANS)
    if marker_pos < 0:
        return ''
    tail = flat[marker_pos + len(FINAL_ANS):].strip()
    return tail[:-1] if tail.endswith('.') else tail

Так как, к сожалению, доступ к языковой модели по публичному API занимает очень много времени, в эксперименте придется сильно урезать количество решаемых задач. Чтобы исключить предвзятость при отборе, возьмем случайную подвыборку задач из датасета.

In [18]:
import random

# Seed the global RNG with the notebook-level SEED so that the sampled subset
# is identical after Restart Kernel -> Run All.
random.seed(SEED)

num_tasks = 20

# Draw `num_tasks` distinct task indices and materialise the evaluation subset.
inds = random.sample(range(len(data)), num_tasks)
data_test = [data[i] for i in inds]

Здесь представлены две основные функции: для получения предсказаний и для подсчета accuracy по этим предсказаниям.

In [19]:
def get_predictions(data_test, PROMT, examples=False, max_tokens=100, self_cons=False, num_samples=10, temp=0.7, top_k=40):
    """Generate model completions for every task in ``data_test``.

    Args:
        data_test: iterable of dicts with 'question' and 'answer' keys.
        PROMT: few-shot prompt placed in front of every question.
        examples: when True, print the question, the reference answer and the first completion.
        max_tokens: value passed as ``max_new_tokens`` to ``model.generate``.
        self_cons: when True, draw ``num_samples`` sampled completions per task;
            otherwise produce a single completion without sampling arguments.
        num_samples: number of sampled completions per task (used only when ``self_cons``).
        temp: sampling temperature (used only when ``self_cons``).
        top_k: top-k cutoff (used only when ``self_cons``).

    Returns:
        A list with one dict per task: ``{'true': reference answer, 'predicted': [completion, ...]}``.
        Each completion holds only the newly generated text, without the prompt.
    """
    # Loop-invariant: the end-of-sequence id lets generation stop before max_tokens.
    eos_token_id = tokenizer.eos_token_id

    def _decode_completion(outputs, prompt_len):
        # generate() returns prompt + continuation; keep the continuation only.
        # Otherwise the few-shot prompt (which starts with "Q:") makes get_ans
        # return '' for every prediction.
        return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    result = []
    for elem in tqdm(data_test):
        q = elem['question']
        a = elem['answer']

        text = PROMT + '\nQ: ' + q + '\nA: '
        inputs = tokenizer(text, return_tensors='pt')['input_ids'].to(DEVICE)
        prompt_len = inputs.shape[1]

        ans = []
        if self_cons:
            for _ in tqdm(range(num_samples)):
                # `do_sample` (not `do_samples`) is the flag that enables sampling.
                outputs = model.generate(
                    inputs,
                    max_new_tokens=max_tokens,
                    temperature=temp,
                    top_k=top_k,
                    do_sample=True,
                    eos_token_id=eos_token_id,
                )
                ans.append(_decode_completion(outputs, prompt_len))
        else:
            outputs = model.generate(inputs, max_new_tokens=max_tokens, eos_token_id=eos_token_id)
            ans.append(_decode_completion(outputs, prompt_len))

        result.append({'true': a, 'predicted': ans})

        if examples:
            print('Q: ' + q + '\nTrue A: ' + a + '\nPredicted A: ' + ans[0] + '\n\n')

    return result


def get_accuracy(result):
    """Score predictions against reference answers using a majority vote.

    Args:
        result: list of dicts with 'true' (reference answer text whose last
            space-separated token is the final answer) and 'predicted'
            (list of completion strings).

    Returns:
        Tuple ``(correct, total, accuracy)``. Tasks for which no answer could be
        extracted count as incorrect. For an empty ``result`` the accuracy is 0.0
        instead of raising ZeroDivisionError.
    """
    correct = 0

    for res in result:
        # Last token of the reference; surrounding whitespace and thousands
        # separators are removed so that e.g. '1,000' compares equal to 1000.
        true = res['true'].strip().split(' ')[-1].replace(',', '')

        # Keep only completions from which an answer could be extracted.
        pred_list = [ans for ans in map(get_ans, res['predicted']) if ans]
        if not pred_list:
            continue

        maj_ans = get_maj(pred_list)
        if _is_float(true) and _is_float(maj_ans):
            # Numeric comparison with a small tolerance.
            if abs(float(true) - float(maj_ans)) <= 1e-5:
                correct += 1
        elif str(true) == str(maj_ans):
            correct += 1

    total = len(result)
    accuracy = correct / total if total else 0.0
    return correct, total, accuracy

Получим результаты для жадного сэмплирования

In [None]:
# Greedy chain-of-thought baseline: a single completion per task.
result_default = get_predictions(data_test=data_test, PROMT=PROMT, examples=True)
correct, total, acc = get_accuracy(result=result_default)
print("Correct tasks:", correct, "Total tasks:", total, "Accuracy:", acc)

  0%|          | 0/20 [00:00<?, ?it/s]Feb 27 17:42:18.194 [[1m[38;5;208mWARN[0m] [[1m/opt/conda/lib/python3.7/site-packages/petals/client/inference_session.py.step:312[0m] Caught exception when running inference from block 0 (retry in 0 sec): ControlFailure('Connect failed. msg=routing: not found')
Feb 27 17:42:18.592 [[1m[38;5;208mWARN[0m] [[1m/opt/conda/lib/python3.7/site-packages/petals/client/inference_session.py.step:312[0m] Caught exception when running inference from block 0 (retry in 1 sec): ControlFailure('Connect failed. msg=failed to find peers: routing: not found')
Feb 27 17:42:19.993 [[1m[38;5;208mWARN[0m] [[1m/opt/conda/lib/python3.7/site-packages/petals/client/inference_session.py.step:312[0m] Caught exception when running inference from block 0 (retry in 2 sec): ControlFailure('Connect failed. msg=failed to find peers: routing: not found')
Feb 27 17:42:22.394 [[1m[38;5;208mWARN[0m] [[1m/opt/conda/lib/python3.7/site-packages/petals/client/inference_ses

Теперь результаты для Self-consistency метода

In [None]:
# Self-consistency run: several sampled completions per task, scored by majority vote.
result_cons = get_predictions(
    data_test,
    PROMT,
    examples=True,
    max_tokens=100,
    self_cons=True,
    num_samples=10,
    temp=0.7,
    top_k=40,
)
correct, total, acc = get_accuracy(result=result_cons)
print("Correct tasks:", correct, "Total tasks:", total, "Accuracy:", acc)

In [None]:
# Repeat the greedy experiment on fresh random subsets to reduce the influence of randomness
correct_def, total_def, acc_def = [], [], []
num_tasks = 20
for _ in range(5):
    # New random subset of tasks for every repetition.
    inds = random.sample(range(len(data)), num_tasks)
    data_test = [data[i] for i in inds]

    result_default = get_predictions(data_test, PROMT, examples=True)
    correct, total, acc = get_accuracy(result_default)

    for history, value in ((correct_def, correct), (total_def, total), (acc_def, acc)):
        history.append(value)

In [None]:
# Plain Python lists have no .mean(); average the per-run values explicitly.
print(
    "Correct tasks:", sum(correct_def) / len(correct_def),
    "Total tasks:", sum(total_def) / len(total_def),
    "Accuracy:", sum(acc_def) / len(acc_def),
)

In [None]:
# Repeat the self-consistency experiment on fresh random subsets to reduce the influence of randomness
correct_cons, total_cons, acc_cons = [], [], []
num_tasks = 20
for _ in range(5):
    # New random subset of tasks for every repetition.
    inds = random.sample(range(len(data)), num_tasks)
    data_test = [data[i] for i in inds]

    result_cons = get_predictions(
        data_test,
        PROMT,
        examples=True,
        max_tokens=100,
        self_cons=True,
        num_samples=10,
        temp=0.7,
        top_k=40,
    )
    correct, total, acc = get_accuracy(result_cons)

    for history, value in ((correct_cons, correct), (total_cons, total), (acc_cons, acc)):
        history.append(value)

In [None]:
# Plain Python lists have no .mean(); average the per-run values explicitly.
print(
    "Correct tasks:", sum(correct_cons) / len(correct_cons),
    "Total tasks:", sum(total_cons) / len(total_cons),
    "Accuracy:", sum(acc_cons) / len(acc_cons),
)

## Здесь предполагается вывод по качеству полученных ответов и сравнение с результатами в статье

Теперь можно потюнить температуру и top-k и посмотреть, как будет меняться качество. Также желательно повторить эксперимент несколько раз на случайном наборе задач из датасета.

In [None]:
# Grid search over sampling temperature and top-k for the self-consistency setup.
temp_list = [0.5, 0.7, 0.9]
top_k_list = [20, 40, 60]
result_list = []  # (mean accuracy, temperature, top_k) for every grid point
num_tasks = 20

for temp in temp_list:
    for top_k in top_k_list:
        correct_cons = []
        total_cons = []
        acc_cons = []
        for _ in range(5):
            # New random subset of tasks for every repetition.
            inds = random.sample(range(len(data)), num_tasks)
            data_test = [data[i] for i in inds]

            result_cons = get_predictions(data_test, PROMT, examples=True, max_tokens=100,
                                          self_cons=True, num_samples=10, temp=temp, top_k=top_k)
            correct, total, acc = get_accuracy(result_cons)

            correct_cons.append(correct)
            total_cons.append(total)
            acc_cons.append(acc)

        # Plain Python lists have no .mean(); average the per-run values explicitly.
        mean_correct = sum(correct_cons) / len(correct_cons)
        mean_total = sum(total_cons) / len(total_cons)
        mean_acc = sum(acc_cons) / len(acc_cons)

        result_list.append((mean_acc, temp, top_k))
        print("Correct tasks:", mean_correct, "Total tasks:", mean_total, "Accuracy:", mean_acc)

In [None]:
# Tuples sort by their first element (mean accuracy) first, so the last entry is the best grid point.
sorted_results = sorted(result_list)
best_acc, best_temp, best_top_k = sorted_results[-1]
print("Best parameters:")
print("Mean accuracy:", best_acc, "Best temp:", best_temp, "Best top_k:", best_top_k)

Ещё было бы интересно попробовать воспроизвести зависимость: чем больше разных решений агрегируется, тем выше вероятность получить верный ответ. Во второй статье был отмечен значительный рост качества, если агрегируется от 10 вариантов. Также тут можно попробовать более сложный метод агрегации, например метод Дэвида — Скина (Dawid–Skene).

# Предложения для дальнейших экспериментов и улучшений качества:
## 1. Судя по результатам из статьи, полезно взять модель, у которой ещё больше параметров, чем у рассмотренной BLOOM-176B
## 2. Попробовать поиграться с промтами: поискать более качественные, сэмплить случайные из датасета, создать свои
## 3. Усложнить сэмплирование ответов, избавиться от приближенно хорошего Majority Vote.