# Align Codex tokens to FlanT5 tokens using dynamic time warping (DTW)

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys 
sys.path.append('..')

In [34]:
import time
import pickle
import random
import editdistance

from collections import OrderedDict
from tqdm import tqdm
from transformers import T5Tokenizer, GPT2Tokenizer
from src.utils import (parse_codex_outputs, 
                       parse_flan_t5_outputs, 
                       vis_prob_flow, 
                       vis_heatmap, 
                       test_acc, 
                       majority_vote_acc,
                       ClosestToken,
                       transform_codex_token_to_t5_token,
                       print_transformed_probs,
                       dtw
                      )

In [143]:
def _load_pickle(name):
    """Load ``../processed_data/<name>.pkl`` and close the file handle afterwards.

    NOTE: pickle.load can execute arbitrary code; only use it on trusted,
    locally produced files like these.
    """
    with open('../processed_data/%s.pkl' % name, 'rb') as f:
        return pickle.load(f)


codex_questions = _load_pickle('codex_questions')
codex_answers = _load_pickle('codex_answers')
codex_predictions = _load_pickle('codex_predictions')
codex_per_step_probs = _load_pickle('codex_per_step_probs')
codex_prediction_labels = _load_pickle('codex_prediction_labels')

In [144]:
len(codex_questions)

7473

In [122]:
codex_questions[3324]

'Question 3324: Monica and Sheila are twins. Their mother gave them $50 and told them to buy some toilet paper and spend the remainder on groceries. The toilet paper cost $12. They bought apples, butter, eggs, and a large ham for twice the cost of the toilet paper. Since they still had some leftover money, they called their mother and she gave them permission to buy whatever they wanted for themselves as long as they shared the money evenly. They saw some boots they really liked, but a pair of boots costs 3 times the amount they had left. How much more would Monica and Sheila each have to add of their own money to buy two pairs of boots?\n'

In [123]:
codex_questions[3325]

'Question 3324: Monica and Sheila are twins. Their mother gave them $50 and told them to buy some toilet paper and spend the remainder on groceries. The toilet paper cost $12. They bought apples, butter, eggs, and a large ham for twice the cost of the toilet paper. Since they still had some leftover money, they called their mother and she gave them permission to buy whatever they wanted for themselves as long as they shared the money evenly. They saw some boots they really liked, but a pair of boots costs 3 times the amount they had left. How much more would Monica and Sheila each have to add of their own money to buy two pairs of boots?\n'

In [126]:
codex_questions[3326]

"Question 3325: Billy's family likes to keep their bicycles stored in the garage when they're not being used.  They own a total of 4 bicycles.  Each bicycle wheel has 10 spokes.  How many spokes are inside the garage?\n"

In [131]:
# Drop position 3325: it repeats the 'Question 3324: ...' entry shown at index 3324.
codex_questions_ = [question for pos, question in enumerate(codex_questions) if pos != 3325]

In [140]:
codex_questions_[3326]

'Question 3326: The largest animal to have ever lived on earth is the blue whale.  The tongue of an adult blue whale can weigh 6000 pounds.  If one ton is 2000 pounds, how many tons can the tongue of an adult blue whale weigh?\n'

In [125]:
codex_answers[3325]

'Answer: Each bicycle has 2 wheels, so there are a total of 4*2=<<4*2=8>>8 wheels in the garage as there are 4 bicycles.\nSince each wheel has 10 spokes, this means there are 8*10=<<8*10=80>>80 spokes in total.\n#### 80\n'

In [129]:
codex_predictions[3325][1]

'Model output 1: \nEach bicycle has two wheels, so there are a total of 4*2 = 8 wheels.\nEach wheel has 10 spokes, so there are a total of 8*10 = 80 spokes in total.\nThe answer is 80\n'

In [78]:
len(codex_prediction_labels)

7473

In [6]:
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")

In [7]:
codex_predictions[0][0]

'Model output 0: \nIf Natalia sold 48 clips in April, she sold 48/2 = 24 clips in May.\nIn total, Natalia sold 48 + 24 = 72 clips in April and May.\nThe answer is 72\n'

In [28]:
# Strip the "Model output i: " prefix. partition keeps any later ': ' in the text
# intact; ''.join(split(': ')[1:]) would silently delete them.
pred = codex_predictions[1][0].partition(': ')[2].strip()

In [29]:
pred

'Weng earns $12 for every hour of babysitting.\nShe babysat for 50 minutes.\n50 minutes is less than an hour, so she earned less than $12.\nShe earned $12/60 * 50 = $10.\nThe answer is 10'

In [30]:
tokens = tokenizer.convert_ids_to_tokens(tokenizer(pred)['input_ids'])[:-1]

In [31]:
tokens

['▁We',
 'ng',
 '▁earn',
 's',
 '▁$12',
 '▁for',
 '▁every',
 '▁hour',
 '▁of',
 '▁baby',
 's',
 'i',
 'tting',
 '.',
 '▁She',
 '▁baby',
 's',
 'at',
 '▁for',
 '▁50',
 '▁minutes',
 '.',
 '▁50',
 '▁minutes',
 '▁is',
 '▁less',
 '▁than',
 '▁an',
 '▁hour',
 ',',
 '▁so',
 '▁she',
 '▁earned',
 '▁less',
 '▁than',
 '▁$12',
 '.',
 '▁She',
 '▁earned',
 '▁$12',
 '/',
 '60',
 '▁*',
 '▁50',
 '▁=',
 '▁$10',
 '.',
 '▁The',
 '▁answer',
 '▁is',
 '▁10']

In [34]:
# Emitted Codex token at each step (entry 0 of each candidate list), skipping the
# ' step' / '\n' prompt prefix.
steps = [s[0][0] for s in codex_per_step_probs[1][0][2:]]

# Drop trailing newline tokens. The previous `steps[:idx + 1]` slice returned an
# empty list when there was no trailing newline (idx == -1 -> steps[:0]).
while steps and steps[-1] == '\n':
    steps.pop()

In [35]:
steps

['W',
 'eng',
 ' earns',
 ' $',
 '12',
 ' for',
 ' every',
 ' hour',
 ' of',
 ' babys',
 'itting',
 '.',
 '\n',
 'She',
 ' babys',
 'at',
 ' for',
 ' 50',
 ' minutes',
 '.',
 '\n',
 '50',
 ' minutes',
 ' is',
 ' less',
 ' than',
 ' an',
 ' hour',
 ',',
 ' so',
 ' she',
 ' earned',
 ' less',
 ' than',
 ' $',
 '12',
 '.',
 '\n',
 'She',
 ' earned',
 ' $',
 '12',
 '/',
 '60',
 ' *',
 ' 50',
 ' =',
 ' $',
 '10',
 '.',
 '\n',
 'The',
 ' answer',
 ' is',
 ' 10']

In [8]:
def dist_fn(codex_a, flan_b):
    """Local cost for `dtw`: edit distance between a Codex token and a FlanT5 token.

    Word-boundary markers are ignored: literal spaces are removed from the Codex
    token and '▁' from the FlanT5 token, so e.g. ' Nat' vs '▁Nat' costs 0.
    """
    normalized_pair = (codex_a.replace(' ', ''), flan_b.replace('▁', ''))
    return editdistance.eval(*normalized_pair)

In [33]:
steps[1]

' Nat'

In [34]:
tokens[1]

'▁Nat'

In [35]:
dist_fn(steps[1], tokens[1])

0

In [21]:
len(steps), len(tokens)

(45, 40)

In [10]:
codex_per_step_probs[1][0][:]

[[(' step', 1.0),
  ('!', 0.0),
  (' stepped', 0.0),
  (' stepping', 0.0),
  (' step', 1.0),
  (' steps', 0.0)],
 [('\n', 0.9878),
  ('.', 0.0067),
  ('\n', 0.9878),
  (' ', 0.0011),
  (':', 0.0027),
  ('\n\n', 0.0003)],
 [('W', 0.3172),
  ('We', 0.0438),
  ('If', 0.1324),
  ('She', 0.0625),
  ('W', 0.3172),
  ('Since', 0.0693)],
 [('eng', 0.9903),
  ('eng', 0.9903),
  ('en', 0.0034),
  ('end', 0.0035),
  ('ong', 0.0004),
  ('ang', 0.0013)],
 [(' earns', 0.7358),
  (' makes', 0.0283),
  (' earns', 0.7358),
  (' gets', 0.0301),
  (' is', 0.0231),
  (' earned', 0.0544)],
 [(' $', 0.8506),
  (' 12', 0.1252),
  (' money', 0.0031),
  (' a', 0.0031),
  (' 1', 0.0033),
  (' $', 0.8506)],
 [('12', 0.9898),
  ('12', 0.9898),
  (' 12', 0.0016),
  ('0', 0.0007),
  ('1', 0.0053),
  ('24', 0.0004)],
 [(' for', 0.1161),
  ('/', 0.0625),
  (' an', 0.4909),
  (' for', 0.1161),
  (' per', 0.2627),
  (' every', 0.0137)],
 [(' every', 0.3063),
  (' each', 0.2902),
  (' an', 0.1321),
  (' 1', 0.0732),
  (

In [36]:
matches, matrix_, mappings_series_1, mappings_series_2, matrix = dtw(steps, tokens, norm_func=dist_fn)

In [37]:
# One line per FlanT5 token: the token, then every Codex token DTW aligned to it.
for flan_pos, codex_positions in enumerate(mappings_series_2):
    head = repr(tokens[flan_pos]) + ' | '
    aligned = ''.join(repr(steps[k]) + ' ' for k in codex_positions)
    print(head + aligned)

'▁We' | 'W' 
'ng' | 'eng' 
'▁earn' | ' earns' 
's' | ' $' 
'▁$12' | '12' 
'▁for' | ' for' 
'▁every' | ' every' 
'▁hour' | ' hour' 
'▁of' | ' of' 
'▁baby' | ' babys' 
's' | ' babys' 
'i' | ' babys' 
'tting' | 'itting' 
'.' | '.' '\n' 
'▁She' | 'She' 
'▁baby' | ' babys' 
's' | 'at' 
'at' | 'at' 
'▁for' | ' for' 
'▁50' | ' 50' 
'▁minutes' | ' minutes' 
'.' | '.' '\n' 
'▁50' | '50' 
'▁minutes' | ' minutes' 
'▁is' | ' is' 
'▁less' | ' less' 
'▁than' | ' than' 
'▁an' | ' an' 
'▁hour' | ' hour' 
',' | ',' 
'▁so' | ' so' 
'▁she' | ' she' 
'▁earned' | ' earned' 
'▁less' | ' less' 
'▁than' | ' than' 
'▁$12' | ' $' '12' 
'.' | '.' '\n' 
'▁She' | 'She' 
'▁earned' | ' earned' 
'▁$12' | ' $' '12' 
'/' | '/' 
'60' | '60' 
'▁*' | ' *' 
'▁50' | ' 50' 
'▁=' | ' =' ' $' 
'▁$10' | '10' 
'.' | '.' '\n' 
'▁The' | 'The' 
'▁answer' | ' answer' 
'▁is' | ' is' 
'▁10' | ' 10' 


In [20]:
matrix

array([[  0.,   3.,   7., ..., 112., 114., 116.],
       [  3.,   0.,   4., ..., 112., 115., 117.],
       [  7.,   4.,   0., ..., 111., 114., 118.],
       ...,
       [122., 123., 124., ...,  15.,   6.,   4.],
       [124., 125., 127., ...,  21.,   8.,   6.],
       [126., 127., 129., ...,  24.,  10.,   8.]])

In [9]:
codex_per_step_probs[0][0]

[[(' step', 1.0),
  ('!', 0.0),
  (' stepped', 0.0),
  (' stepping', 0.0),
  (' step', 1.0),
  (' steps', 0.0)],
 [('\n', 0.9862),
  ('!', 0.0002),
  ('.', 0.009),
  ('\n', 0.9862),
  (' ', 0.001),
  (':', 0.0025)],
 [('If', 0.1009),
  ('If', 0.1009),
  ('She', 0.0416),
  ('The', 0.0394),
  ('Nat', 0.3286),
  ('In', 0.2685)],
 [(' Nat', 0.6018),
  (' Nat', 0.6018),
  (' in', 0.0261),
  (' we', 0.0125),
  (' 48', 0.0335),
  (' she', 0.2588)],
 [('alia', 0.9968),
  ('ale', 0.0001),
  ('ilia', 0.0001),
  ('lia', 0.0002),
  (' sold', 0.0013),
  ('alia', 0.9968)],
 [(' sold', 0.9573),
  (' sells', 0.0217),
  (' sold', 0.9573),
  (' is', 0.0022),
  ("'s", 0.0028),
  (' had', 0.0024)],
 [(' 48', 0.4455),
  (' clips', 0.3353),
  (' half', 0.1318),
  (' to', 0.0351),
  (' 48', 0.4455),
  (' twice', 0.0058)],
 [(' clips', 0.9248),
  (' clips', 0.9248),
  (' friends', 0.0062),
  (' to', 0.0058),
  (' of', 0.0109),
  (' in', 0.0345)],
 [(' in', 0.7597),
  (',', 0.0206),
  (' to', 0.2018),
  (' in'

In [147]:
def transform_step_probs(qi, aj, vocab, pred, per_step_prob):
    """Project one Codex sample's per-step token distributions onto FlanT5 tokens.

    The Codex completion is re-tokenized with the global FlanT5 `tokenizer` and
    aligned to the Codex tokens with `dtw` (local cost `dist_fn`). Each FlanT5
    position then gets a distribution:
      * if it aligns to exactly one Codex token with the same surface form
        (after mapping ' ' -> '▁'): that Codex step's candidates, kept only if
        they are in `vocab`, first occurrence of each token wins;
      * otherwise: a one-hot ``{flan_token: 1.0}``.

    Args:
        qi, aj: question / sample indices; only used in diagnostic messages.
        vocab: FlanT5 vocabulary used for membership tests (e.g. tokenizer.get_vocab()).
        pred: raw Codex output starting with a "Model output i: " prefix.
        per_step_prob: per-step lists of (token, prob) candidates; entry 0 of each
            step is taken as the emitted token. Must start with ' step', '\n' and
            end with a newline step followed by '\n'.

    Returns:
        (1, list_of_dicts): one dict per FlanT5 token, on success.
        (-1, None): if per_step_prob lacks the expected framing or there is
            nothing to align.
    """
    # Framing checks: prompt prefix (' step', '\n') and trailing newline steps.
    # Checked before tokenizing so rejected samples cost nothing.
    if len(per_step_prob) < 4:
        return -1, None
    if per_step_prob[0][0][0] != ' step':
        return -1, None
    if per_step_prob[1][0][0] != '\n':
        return -1, None
    if per_step_prob[-2][0][0] not in ['\n', '\n\n']:
        return -1, None
    if per_step_prob[-1][0][0] != '\n':
        return -1, None

    # Strip the "Model output i: " prefix; partition keeps later ': ' intact
    # (''.join(split(': ')[1:]) used to delete them).
    pred = pred.partition(': ')[2].strip()
    # [:-1] drops the final special token appended by the tokenizer (presumably EOS).
    flan_tokens = tokenizer.convert_ids_to_tokens(tokenizer(pred)['input_ids'])[:-1]

    per_step_prob = per_step_prob[2:-2]
    codex_tokens = [s[0][0] for s in per_step_prob]
    if not codex_tokens or not flan_tokens:
        # Nothing to align.
        return -1, None

    _, _, _, flan2codex, _ = dtw(codex_tokens, flan_tokens, norm_func=dist_fn)

    transformed_step_probs = []
    for i, codex_idx in enumerate(flan2codex):
        flan_token = flan_tokens[i]
        # Fallback: one-to-many alignment or surface mismatch -> trust only the FlanT5 token.
        probs = {flan_token: 1.0}
        if len(codex_idx) == 1:
            j = codex_idx[0]
            if codex_tokens[j].replace(' ', '▁') == flan_token:
                probs = OrderedDict()
                for t, p in per_step_prob[j]:
                    flan_t = t.replace(' ', '▁')
                    # Candidate lists can repeat a token; keep the first occurrence.
                    if flan_t in vocab and flan_t not in probs:
                        probs[flan_t] = p
        transformed_step_probs.append(probs)

    if len(transformed_step_probs) != len(flan_tokens):
        print('q %d a %d debug 5: %d aligned positions for %d FlanT5 tokens'
              % (qi, aj, len(transformed_step_probs), len(flan_tokens)))
    return 1, transformed_step_probs

In [12]:
vocab = tokenizer.get_vocab()

In [38]:
# transform_step_probs returns (status, probs); status == -1 means the sample was
# rejected and probs is None. Unpack so later cells index the list of dicts.
status, transformed_step_probs = transform_step_probs(0, 0, vocab, codex_predictions[0][0], codex_per_step_probs[0][0])

In [33]:
codex_per_step_probs[0][0]

[[(' step', 1.0),
  ('!', 0.0),
  (' stepped', 0.0),
  (' stepping', 0.0),
  (' step', 1.0),
  (' steps', 0.0)],
 [('\n', 0.9862),
  ('!', 0.0002),
  ('.', 0.009),
  ('\n', 0.9862),
  (' ', 0.001),
  (':', 0.0025)],
 [('If', 0.1009),
  ('If', 0.1009),
  ('She', 0.0416),
  ('The', 0.0394),
  ('Nat', 0.3286),
  ('In', 0.2685)],
 [(' Nat', 0.6018),
  (' Nat', 0.6018),
  (' in', 0.0261),
  (' we', 0.0125),
  (' 48', 0.0335),
  (' she', 0.2588)],
 [('alia', 0.9968),
  ('ale', 0.0001),
  ('ilia', 0.0001),
  ('lia', 0.0002),
  (' sold', 0.0013),
  ('alia', 0.9968)],
 [(' sold', 0.9573),
  (' sells', 0.0217),
  (' sold', 0.9573),
  (' is', 0.0022),
  ("'s", 0.0028),
  (' had', 0.0024)],
 [(' 48', 0.4455),
  (' clips', 0.3353),
  (' half', 0.1318),
  (' to', 0.0351),
  (' 48', 0.4455),
  (' twice', 0.0058)],
 [(' clips', 0.9248),
  (' clips', 0.9248),
  (' friends', 0.0062),
  (' to', 0.0058),
  (' of', 0.0109),
  (' in', 0.0345)],
 [(' in', 0.7597),
  (',', 0.0206),
  (' to', 0.2018),
  (' in'

In [28]:
codex_predictions[0][0]

'Model output 0: \nIf Natalia sold 48 clips in April, she sold 48/2 = 24 clips in May.\nIn total, Natalia sold 48 + 24 = 72 clips in April and May.\nThe answer is 72\n'

In [39]:
transformed_step_probs

[{'▁If': 1.0},
 OrderedDict([('▁Nat', 0.6018),
              ('▁in', 0.0261),
              ('▁we', 0.0125),
              ('▁48', 0.0335),
              ('▁she', 0.2588)]),
 OrderedDict([('alia', 0.9968), ('ilia', 0.0001), ('▁sold', 0.0013)]),
 OrderedDict([('▁sold', 0.9573), ('▁is', 0.0022), ('▁had', 0.0024)]),
 OrderedDict([('▁48', 0.4455),
              ('▁clips', 0.3353),
              ('▁half', 0.1318),
              ('▁to', 0.0351),
              ('▁twice', 0.0058)]),
 OrderedDict([('▁clips', 0.9248),
              ('▁friends', 0.0062),
              ('▁to', 0.0058),
              ('▁of', 0.0109),
              ('▁in', 0.0345)]),
 OrderedDict([('▁in', 0.7597),
              (',', 0.0206),
              ('▁to', 0.2018),
              ('▁and', 0.0031),
              ('▁then', 0.0028)]),
 OrderedDict([('▁April', 0.9759),
              ('▁May', 0.0039),
              ('▁total', 0.0025),
              ('▁the', 0.0067)]),
 OrderedDict([(',', 0.7709),
              ('▁to', 0.0033),
   

In [43]:
transformed_step_probs[1]['▁Nat']

0.6018

In [52]:
idx = 20
# Candidate FlanT5 tokens and their Codex probabilities at one aligned position.
for token, prob in transformed_step_probs[idx].items():
    print(token, prob)

▁total 0.8519
▁both 0.0057
▁May 0.0097
▁the 0.0058
▁April 0.1015


In [73]:
len(codex_predictions)

7473

In [74]:
len(codex_questions)

7474

In [148]:
# Re-align every originally-correct Codex sample to FlanT5 tokens (~1h on the full
# set). Correct samples that transform_step_probs rejects are relabeled 0 and
# counted in `modified`; incorrect samples keep None / 0.
# Iterate the prediction lists themselves rather than codex_questions: the
# questions list has contained a duplicated entry (index 3324/3325), and a length
# mismatch would silently misalign or run out of range.
assert len(codex_predictions) == len(codex_per_step_probs) == len(codex_prediction_labels), \
    'prediction / per-step-prob / label lists have different lengths'

codex_dtw_transformed_step_probs = []
codex_updated_labels = []
total_case = 0
total_correct = 0
modified = 0
for qi, (preds_q, probs_q, labels_q) in tqdm(
        enumerate(zip(codex_predictions, codex_per_step_probs, codex_prediction_labels)),
        total=len(codex_predictions)):
    transformed = []
    updated_labels = []
    for ai, (pred, prob, label) in enumerate(zip(preds_q, probs_q, labels_q)):
        total_case += 1
        if label != 1:
            transformed.append(None)
            updated_labels.append(0)
            continue
        total_correct += 1
        ret_code, ret = transform_step_probs(qi, ai, vocab, pred, prob)
        if ret_code == -1:
            transformed.append(None)
            updated_labels.append(0)
            modified += 1
        else:
            transformed.append(ret)
            updated_labels.append(1)
    codex_updated_labels.append(updated_labels)
    codex_dtw_transformed_step_probs.append(transformed)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7473/7473 [58:01<00:00,  2.15it/s]


In [149]:
# `modified` = originally-correct samples relabeled 0 because alignment was rejected.
print('total predictions %d, originally labeled correct %d, relabeled incorrect after failed alignment %d'
      % (total_case, total_correct, modified))

total prediction 323810, original labeled corred 207619, modified 484


In [150]:
# Use context managers so the files are flushed and closed deterministically.
with open('../processed_data/codex_dtw_transformed_step_probs.pkl', 'wb') as f:
    pickle.dump(codex_dtw_transformed_step_probs, f)
with open('../processed_data/codex_updated_labels.pkl', 'wb') as f:
    pickle.dump(codex_updated_labels, f)

In [57]:
codex_predictions[1][1]

'Model output 1: \nJasper buys 2 pounds of cheddar cheese for $10, meaning that 1 pound of cheddar cheese costs $10/2 = $5\nJasper buys a pound of cream cheese that costs half the price of cheddar cheese, meaning that 1 pound of cream cheese costs $5/2 = $2.50\nJasper buys a pack of cold cuts that cost twice the price of cheddar cheese, meaning that 1 pack of cold cuts costs $5*2 = $10\nJasper buys 2 pounds of cheddar cheese, 1 pound of cream cheese, and 1 pack of cold cuts.\nTogether, he spends 2*$5 + $2.50 + $10 = $22.50\nThe answer is 22.5\n'

In [64]:
codex_per_step_probs[32][44]

[[(' step', 1.0),
  ('!', 0.0),
  (' stepped', 0.0),
  (' stepping', 0.0),
  (' step', 1.0),
  (' steps', 0.0)],
 [('\n', 0.9877),
  ('.', 0.0066),
  ('\n', 0.9877),
  (' ', 0.0007),
  (':', 0.0027),
  ('\n\n', 0.0007)],
 [('After', 0.2227),
  ('We', 0.0602),
  ('If', 0.0648),
  ('After', 0.2227),
  ('The', 0.2784),
  ('Let', 0.0527)],
 [(' the', 0.8436),
  (' each', 0.0086),
  (' traveling', 0.01),
  (' turn', 0.0138),
  (' the', 0.8436),
  (' 1', 0.049)],
 [(' first', 0.4074),
  (' third', 0.0225),
  (' 1', 0.4435),
  (' 2', 0.0232),
  (' 3', 0.0537),
  (' first', 0.4074)],
 [(' turn', 0.8979),
  (' three', 0.0075),
  (' turn', 0.8979),
  (' and', 0.0113),
  (' two', 0.0159),
  (' right', 0.0446)],
 [(',', 0.8022),
  (',', 0.8022),
  (' the', 0.1396),
  (' of', 0.0051),
  (' it', 0.0252),
  (' we', 0.0072)],
 [(' the', 0.8069),
  (' the', 0.8069),
  (' it', 0.1321),
  (' we', 0.0212),
  (' there', 0.0043),
  (' 5', 0.0093)],
 [(' car', 0.9658),
  (' car', 0.9658),
  (' total', 0.006)