In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm
import together
from time import sleep
import re
from sklearn.metrics import mean_absolute_error

In [2]:
def dprint(s, debug):
    if debug:
        print(s)

In [3]:
# TODO: find your API key here
# https://api.together.xyz/settings/api-keys
YOUR_API_KEY = '61face25ac35778c0786a2541cac586ca614b28859377010113538d84adf3f0b'
together.api_key = YOUR_API_KEY

def call_together_api(prompt, student_configs, post_processing, model='meta-llama/Llama-2-7b-chat-hf', debug=False):
    output = together.Complete.create(
    prompt = prompt,
    model = model, 
    **student_configs
    )
    dprint('*****prompt*****', debug)
    dprint(prompt, debug)
    dprint('*****result*****', debug)
    res = output['output']['choices'][0]['text']
    dprint(res, debug)
    dprint('*****output*****', debug)
    numbers_only = post_processing(res)
    dprint(numbers_only, debug)
    dprint('=========', debug)
    return numbers_only

###  Part 1. Zero Shot Addition

In [4]:
def get_addition_pairs(lower_bound, upper_bound, rng):
    int_a = int(np.ceil(rng.uniform(lower_bound, upper_bound)))
    int_b = int(np.ceil(rng.uniform(lower_bound, upper_bound)))
    return int_a, int_b

def test_range(added_prompt, prompt_configs, rng, n_sample=30, 
               lower_bound=1, upper_bound=10, fixed_pairs=None, 
               pre_processing=lambda x:x, post_processing=lambda y:y,
               model='meta-llama/Llama-2-7b-chat-hf', debug=False):
    int_as = []
    int_bs = []
    answers = []
    model_responses = []
    correct = []
    prompts = []
    iterations = range(n_sample) if fixed_pairs is None else fixed_pairs
    for i, v in enumerate(tqdm(iterations)):
        if fixed_pairs is None:
            int_a, int_b = get_addition_pairs(lower_bound=lower_bound, upper_bound=upper_bound, rng=rng)
        else:
            int_a, int_b = v
        fixed_prompt = f'{int_a}+{int_b}'
        fixed_prompt = pre_processing(fixed_prompt)
        prefix, suffix = added_prompt
        prompt = prefix + fixed_prompt + suffix
        model_response = call_together_api(prompt, prompt_configs, post_processing, model=model, debug=debug)
        answer = int_a + int_b
        int_as.append(int_a)
        int_bs.append(int_b)
        prompts.append(prompt)
        answers.append(answer)
        model_responses.append(model_response)
        correct.append((answer == model_response))
        sleep(1) # pause to not trigger DDoS defense
    df = pd.DataFrame({'int_a': int_as, 'int_b': int_bs, 'prompt': prompts, 'answer': answers, 'response': model_responses, 'correct': correct})
    print(df)
    mae = mean_absolute_error(df['answer'], df['response'])
    acc = df.correct.sum()/len(df)
    prompt_length = len(prefix) + len(suffix)
    res = acc * 1/prompt_length * (1-mae/(5*10**6))
    return {'res': res, 'acc': acc, 'mae': mae, 'prompt_length': prompt_length}

In [5]:
model_names = [
    "meta-llama/Llama-2-7b-chat-hf",  #LLaMa-2-7B
    "meta-llama/Llama-2-13b-chat-hf", #LLaMa-2-13B
    "meta-llama/Llama-2-70b-hf" #LLaMa-2-70B
]

**Example: Zero-shot single-digit addition**

In [6]:
added_prompt = ('Question: What is ', '?\nAnswer: ') # Question: What is a+b?\nAnswer:
prompt_config = {'max_tokens': 2,
                'temperature': 0.7,
                'top_k': 50,
                'top_p': 0.6,
                'repetition_penalty': 1,
                'stop': []}

# input_string: 'a+b'
def your_pre_processing(input_string):
    return input_string

# output_string: 
# depending on your prompt, it might look like 'output: number'
def your_post_processing(output_string):
    # using regular expression to find the first consecutive digits in the returned string
    only_digits = re.sub(r"\D", "", output_string)
    try:
        res = int(only_digits)
    except:
        res = 0
    return res

model = 'meta-llama/Llama-2-7b-chat-hf'
print(model)
seed = 0
rng = np.random.default_rng(seed)
res = test_range(added_prompt=added_prompt, prompt_configs=prompt_config, rng=rng, n_sample=10, lower_bound=1, upper_bound=10, fixed_pairs=None, pre_processing=your_pre_processing, post_processing=your_post_processing, model=model, debug=False)
print(res)

meta-llama/Llama-2-7b-chat-hf


100%|██████████| 10/10 [00:13<00:00,  1.33s/it]

   int_a  int_b                             prompt  answer  response  correct
0      7      4   Question: What is 7+4?\nAnswer:       11        11     True
1      2      2   Question: What is 2+2?\nAnswer:        4         4     True
2      9     10  Question: What is 9+10?\nAnswer:       19        19     True
3      7      8   Question: What is 7+8?\nAnswer:       15        15     True
4      6     10  Question: What is 6+10?\nAnswer:       16         6    False
5      9      2   Question: What is 9+2?\nAnswer:       11        11     True
6      9      2   Question: What is 9+2?\nAnswer:       11        11     True
7      8      3   Question: What is 8+3?\nAnswer:       11         8    False
8      9      6   Question: What is 9+6?\nAnswer:       15        15     True
9      4      5   Question: What is 4+5?\nAnswer:        9         9     True
{'res': 0.028571421142857146, 'acc': 0.8, 'mae': 1.3, 'prompt_length': 28}





**Example: Zero-shot 7-digit addition**

In [7]:
sleep(1) # wait a little bit to prevent api call error
prompt_config['max_tokens'] = 8
rng = np.random.default_rng(seed)
res = test_range(added_prompt=added_prompt, prompt_configs=prompt_config, rng=rng, n_sample=10, lower_bound=1000000, upper_bound=9999999, fixed_pairs=None, pre_processing=your_pre_processing, post_processing=your_post_processing, model=model, debug=False)
print(res)

100%|██████████| 10/10 [00:13<00:00,  1.35s/it]

     int_a    int_b                                        prompt    answer  \
0  6732655  3428081  Question: What is 6732655+3428081?\nAnswer:   10160736   
1  1368762  1148749  Question: What is 1368762+1148749?\nAnswer:    2517511   
2  8319432  9214800  Question: What is 8319432+9214800?\nAnswer:   17534232   
3  6459722  7565469  Question: What is 6459722+7565469?\nAnswer:   14025191   
4  5892625  9415651  Question: What is 5892625+9415651?\nAnswer:   15308276   
5  8342682  1024647  Question: What is 8342682+1024647?\nAnswer:    9367329   
6  8716638  1302271  Question: What is 8716638+1302271?\nAnswer:   10018909   
7  7566899  2580901  Question: What is 7566899+2580901?\nAnswer:   10147800   
8  8768610  5873151  Question: What is 8768610+5873151?\nAnswer:   14641761   
9  3697407  4804185  Question: What is 3697407+4804185?\nAnswer:    8501592   

   response  correct  
0   6732655    False  
1   2517591    False  
2   8319432    False  
3   6459722    False  
4   5892625    




-----------

**Q1a.** In your opinion, what are some factors that cause language model performance to deteriorate from 1 digit to 7 digits?

Answer: 

-----------

**Q1b**. Play around with the config parameters ('max_tokens','temperature','top_k','top_p','repetition_penalty') in together.ai's [web UI](https://api.together.xyz/playground/language/togethercomputer/llama-2-7b). 
* What does each parameter represent?
* How does increasing each parameter change the generation?

Answer: 

-----------

**Q1c**. Do 7-digit addition with 70B parameter llama model. 
* How does the performance change?
* What are some factors that cause this change?

Answer: 

In [8]:
sleep(1) # wait a little bit to prevent api call error
rng = np.random.default_rng(seed)
res = test_range(added_prompt=added_prompt, prompt_configs=prompt_config, rng=rng, n_sample=10, lower_bound=1000000, upper_bound=9999999, fixed_pairs=None, pre_processing=your_pre_processing, post_processing=your_post_processing, model='meta-llama/Llama-2-70b-hf', debug=False)
print(res)

100%|██████████| 10/10 [00:16<00:00,  1.64s/it]

     int_a    int_b                                        prompt    answer  \
0  6732655  3428081  Question: What is 6732655+3428081?\nAnswer:   10160736   
1  1368762  1148749  Question: What is 1368762+1148749?\nAnswer:    2517511   
2  8319432  9214800  Question: What is 8319432+9214800?\nAnswer:   17534232   
3  6459722  7565469  Question: What is 6459722+7565469?\nAnswer:   14025191   
4  5892625  9415651  Question: What is 5892625+9415651?\nAnswer:   15308276   
5  8342682  1024647  Question: What is 8342682+1024647?\nAnswer:    9367329   
6  8716638  1302271  Question: What is 8716638+1302271?\nAnswer:   10018909   
7  7566899  2580901  Question: What is 7566899+2580901?\nAnswer:   10147800   
8  8768610  5873151  Question: What is 8768610+5873151?\nAnswer:   14641761   
9  3697407  4804185  Question: What is 3697407+4804185?\nAnswer:    8501592   

   response  correct  
0  10160636    False  
1   2517511     True  
2  17534232     True  
3  14025191     True  
4  15308276    




-----------

**Q1d.** Here we're giving our language model the prior that the sum of two 7-digit numbers must have a maximum of 8 digits. (by setting max_token=8). What if we remove this prior by increasing the max_token to 20? 
* Does the model still perform well?
* What are some reasons why?

Answer: 

In [9]:
sleep(1) # wait a little bit to prevent api call error
added_prompt = ('Question: What is ', '?\nAnswer: ') # Question: What is a+b?\nAnswer:
prompt_config = {'max_tokens': 20,
                'temperature': 0.7,
                'top_k': 50,
                'top_p': 0.6,
                'repetition_penalty': 1,
                'stop': []}

# input_string: 'a+b'
def your_pre_processing(input_string):
    return input_string

def your_post_processing(output_string):
    first_line = output_string.splitlines()[0]
    only_digits = re.sub(r"\D", "", first_line)
    try:
        res = int(only_digits)
    except:
        res = 0
    return res


model = 'meta-llama/Llama-2-7b-chat-hf'
print(model)
seed = 0
rng = np.random.default_rng(seed)
res = test_range(added_prompt=added_prompt, prompt_configs=prompt_config, rng=rng, n_sample=10, lower_bound=1000000, upper_bound=9999999, fixed_pairs=None, pre_processing=your_pre_processing, post_processing=your_post_processing, model=model, debug=False)
print(res)

meta-llama/Llama-2-7b-chat-hf


100%|██████████| 10/10 [00:14<00:00,  1.42s/it]

     int_a    int_b                                        prompt    answer  \
0  6732655  3428081  Question: What is 6732655+3428081?\nAnswer:   10160736   
1  1368762  1148749  Question: What is 1368762+1148749?\nAnswer:    2517511   
2  8319432  9214800  Question: What is 8319432+9214800?\nAnswer:   17534232   
3  6459722  7565469  Question: What is 6459722+7565469?\nAnswer:   14025191   
4  5892625  9415651  Question: What is 5892625+9415651?\nAnswer:   15308276   
5  8342682  1024647  Question: What is 8342682+1024647?\nAnswer:    9367329   
6  8716638  1302271  Question: What is 8716638+1302271?\nAnswer:   10018909   
7  7566899  2580901  Question: What is 7566899+2580901?\nAnswer:   10147800   
8  8768610  5873151  Question: What is 8768610+5873151?\nAnswer:   14641761   
9  3697407  4804185  Question: What is 3697407+4804185?\nAnswer:    8501592   

           response  correct  
0  6732655342808110    False  
1           2517591    False  
2          71302322    False  
3  645




### Part 2. In Context Learning

We will try to improve the performance of 7-digit addition via in-context learning.
For cost-control purposes (you only have $25 free credits), we will use [llama-2-7b](https://api.together.xyz/playground/language/togethercomputer/llama-2-7b). Below is a simple example.

In [10]:
sleep(1) # wait a little bit to prevent api call error
added_prompt = ('Question: What is 3+7?\nAnswer: 10\n Question: What is ', '?\nAnswer: ') # Question: What is a+b?\nAnswer:
prompt_config = {'max_tokens': 8,
                'temperature': 0.7,
                'top_k': 50,
                'top_p': 0.6,
                'repetition_penalty': 1,
                'stop': []}
rng = np.random.default_rng(seed)
res = test_range(added_prompt=added_prompt, prompt_configs=prompt_config, rng=rng, n_sample=10, lower_bound=1000000, upper_bound=9999999, fixed_pairs=None, pre_processing=your_pre_processing, post_processing=your_post_processing, model='meta-llama/Llama-2-7b-chat-hf', debug=False)
print(res)

100%|██████████| 10/10 [00:15<00:00,  1.51s/it]

     int_a    int_b                                             prompt  \
0  6732655  3428081  Question: What is 3+7?\nAnswer: 10\n Question:...   
1  1368762  1148749  Question: What is 3+7?\nAnswer: 10\n Question:...   
2  8319432  9214800  Question: What is 3+7?\nAnswer: 10\n Question:...   
3  6459722  7565469  Question: What is 3+7?\nAnswer: 10\n Question:...   
4  5892625  9415651  Question: What is 3+7?\nAnswer: 10\n Question:...   
5  8342682  1024647  Question: What is 3+7?\nAnswer: 10\n Question:...   
6  8716638  1302271  Question: What is 3+7?\nAnswer: 10\n Question:...   
7  7566899  2580901  Question: What is 3+7?\nAnswer: 10\n Question:...   
8  8768610  5873151  Question: What is 3+7?\nAnswer: 10\n Question:...   
9  3697407  4804185  Question: What is 3+7?\nAnswer: 10\n Question:...   

     answer  response  correct  
0  10160736  10154639    False  
1   2517511   2517601    False  
2  17534232  17533200    False  
3  14025191   7324688    False  
4  15308276  1030776




**Q2a**.
* How does the performance change with the baseline in-context learning prompt? (compare with "Example: Zero-shot 7-digit addition" in Q1)
* What are some factors that cause this change?

Answer: 

-----------

Now we will remove the prior on output length and re-evaluate the performance of our baseline one-shot learning prompt. We need to modify our post processing function to extract the answer from the output sequence. In this case, it is the number in the first line that starts with "Answer: ".

**Q2b**.
* How does the performance change when we relax the output length constraint? (compare with Q2a)
* What are some factors that cause this change?

Answer: 

In [11]:
sleep(1) # wait a little bit to prevent api call error

prompt_config['max_tokens'] = 50 # changed from 8, assuming we don't know the output length
                
rng = np.random.default_rng(seed)
res = test_range(added_prompt=added_prompt, prompt_configs=prompt_config, rng=rng, n_sample=10, lower_bound=1000000, upper_bound=9999999, fixed_pairs=None, pre_processing=your_pre_processing, post_processing=your_post_processing, model='meta-llama/Llama-2-7b-chat-hf', debug=False)
print(res)

100%|██████████| 10/10 [00:17<00:00,  1.80s/it]

     int_a    int_b                                             prompt  \
0  6732655  3428081  Question: What is 3+7?\nAnswer: 10\n Question:...   
1  1368762  1148749  Question: What is 3+7?\nAnswer: 10\n Question:...   
2  8319432  9214800  Question: What is 3+7?\nAnswer: 10\n Question:...   
3  6459722  7565469  Question: What is 3+7?\nAnswer: 10\n Question:...   
4  5892625  9415651  Question: What is 3+7?\nAnswer: 10\n Question:...   
5  8342682  1024647  Question: What is 3+7?\nAnswer: 10\n Question:...   
6  8716638  1302271  Question: What is 3+7?\nAnswer: 10\n Question:...   
7  7566899  2580901  Question: What is 3+7?\nAnswer: 10\n Question:...   
8  8768610  5873151  Question: What is 3+7?\nAnswer: 10\n Question:...   
9  3697407  4804185  Question: What is 3+7?\nAnswer: 10\n Question:...   

     answer  response  correct  
0  10160736  10160892    False  
1   2517511   2517511     True  
2  17534232  17533200    False  
3  14025191   7324601    False  
4  15308276  1030776




-----------

**Q2c.** Let's change our one-shot learning example to something more "in-distribution". Previously we were using 1-digit addition as an example. Let's change it to 7-digit addition (1234567+1234567=2469134). 
* Evaluate the performance with max_tokens = 8.
* Evaluate the performance with max_tokens = 50.
* How does the performance change from 1-digit example to 7-digit example?

Answer: 

In [12]:
sleep(1) # wait a little bit to prevent api call error
prompt_config['max_tokens'] = 8 
added_prompt = ('Question: What is 1234567+123457?\nAnswer: 2469134\nQuestion: What is ', '?\nAnswer: ') # Question: What is a+b?\nAnswer:
test_range(added_prompt=added_prompt, prompt_configs=prompt_config, rng=rng, n_sample=10, lower_bound=1000000, upper_bound=9999999, fixed_pairs=None, pre_processing=your_pre_processing, post_processing=your_post_processing, model='meta-llama/Llama-2-7b-chat-hf', debug=False)

100%|██████████| 10/10 [00:14<00:00,  1.46s/it]

     int_a    int_b                                             prompt  \
0  1254878  2118550  Question: What is 1234567+123457?\nAnswer: 246...   
1  7035620  6824705  Question: What is 1234567+123457?\nAnswer: 246...   
2  6538466  4453098  Question: What is 1234567+123457?\nAnswer: 246...   
3  9974889  9827518  Question: What is 1234567+123457?\nAnswer: 246...   
4  7169878  6854133  Question: What is 1234567+123457?\nAnswer: 246...   
5  7196020  4500293  Question: What is 1234567+123457?\nAnswer: 246...   
6  2215869  7493395  Question: What is 1234567+123457?\nAnswer: 246...   
7  5728189  3792177  Question: What is 1234567+123457?\nAnswer: 246...   
8  5372518  9005390  Question: What is 1234567+123457?\nAnswer: 246...   
9  9406391  4220157  Question: What is 1234567+123457?\nAnswer: 246...   

     answer  response  correct  
0   3373428   3373428     True  
1  13860325  13858925    False  
2  10991564  10988464    False  
3  19802407  19802697    False  
4  14024011  1392400




{'res': 0.0021433928205128205,
 'acc': 0.2,
 'mae': 820384.0,
 'prompt_length': 78}

In [13]:
sleep(1) # wait a little bit to prevent api call error
prompt_config['max_tokens'] = 50 
test_range(added_prompt=added_prompt, prompt_configs=prompt_config, rng=rng, n_sample=10, lower_bound=1000000, upper_bound=9999999, fixed_pairs=None, pre_processing=your_pre_processing, post_processing=your_post_processing, model='meta-llama/Llama-2-7b-chat-hf', debug=False)

100%|██████████| 10/10 [00:18<00:00,  1.82s/it]

     int_a    int_b                                             prompt  \
0  6143768  3896825  Question: What is 1234567+123457?\nAnswer: 246...   
1  6348700  4041201  Question: What is 1234567+123457?\nAnswer: 246...   
2  4524571  9012469  Question: What is 1234567+123457?\nAnswer: 246...   
3  3044419  6608684  Question: What is 1234567+123457?\nAnswer: 246...   
4  1756139  8493797  Question: What is 1234567+123457?\nAnswer: 246...   
5  8083884  3154325  Question: What is 1234567+123457?\nAnswer: 246...   
6  8888358  1527113  Question: What is 1234567+123457?\nAnswer: 246...   
7  4025054  2352516  Question: What is 1234567+123457?\nAnswer: 246...   
8  5053054  8166918  Question: What is 1234567+123457?\nAnswer: 246...   
9  3075780  1468192  Question: What is 1234567+123457?\nAnswer: 246...   

     answer  response  correct  
0  10040593  10034553    False  
1  10389901  10389801    False  
2  13537040  54369857    False  
3   9653103   9653093    False  
4  10249936  1025717




{'res': 0.00023123223076923082,
 'acc': 0.1,
 'mae': 4098194.3,
 'prompt_length': 78}

-----------

**Q2d.** Let's look at a specific example with large absolute error. 
* Run the cell at least 5 times. Does the error change each time? Why?
* Can you think of a prompt to reduce the error?
* Why do you think it would work?
* Does it work in practice? Why or why not?

Answer:

In [14]:
added_prompt = """Question: What is 1234567+1354634 ?
Answer: Start from the rightmost digit and remember to carry over when necessary = 2589201
Question: What is """, " ?\nAnswer: "
test_range(added_prompt=added_prompt, prompt_configs=prompt_config, rng=rng, fixed_pairs=[(9090909,1010101)], pre_processing=your_pre_processing, post_processing=your_post_processing, model='meta-llama/Llama-2-7b-chat-hf', debug=True)

  0%|          | 0/1 [00:00<?, ?it/s]

*****prompt*****
Question: What is 1234567+1354634 ?
Answer: Start from the rightmost digit and remember to carry over when necessary = 2589201
Question: What is 9090909+1010101 ?
Answer: 
*****result*****
 Start from the rightmost digit and remember to carry over when necessary = 10102001
Question: What is 2345678-1234567 ?
Answer: Subt
*****output*****
10102001


100%|██████████| 1/1 [00:01<00:00,  1.63s/it]

     int_a    int_b                                             prompt  \
0  9090909  1010101  Question: What is 1234567+1354634 ?\nAnswer: S...   

     answer  response  correct  
0  10101010  10102001    False  





{'res': 0.0, 'acc': 0.0, 'mae': 991.0, 'prompt_length': 156}

### Part 3: Prompt-a-thon (autograder & leaderboard)


Compete with your classmates to see who's best at teach llama to add 7-digit numbers reliably! Submit your ```submission.py``` to enter the leader board!

Note: while you can use prompt.txt for debugging and local testing, for the final autograder submission, please use a string (not a file), because autograder cannot find prompt.txt in the testing environment. Sorry about the inconvenience!

What you can change:
* your_api_key
* your_prompt
* your_config
* your_pre_processing
* your_post_processing