# Align Codex tokens to FlanT5 tokens using dynamic time warping

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules (e.g. the project's src.utils helpers) are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Put the parent directory on the module search path so that the
# `from src.utils import ...` statement in the imports cell can resolve.
# NOTE(review): assumes the notebook is run from a directory directly below
# the project root -- confirm.
import sys 
sys.path.append('..')

In [3]:
import time
import pickle
import random
import editdistance

from tqdm import tqdm
from transformers import T5Tokenizer, GPT2Tokenizer
from src.utils import (parse_codex_outputs, 
                       parse_flan_t5_outputs, 
                       vis_prob_flow, 
                       vis_heatmap, 
                       test_acc, 
                       majority_vote_acc,
                       ClosestToken,
                       transform_codex_token_to_t5_token,
                       print_transformed_probs,
                       dtw
                      )

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def _load_pickle(path):
    """Return the object stored in the pickle file at `path`.

    The file is opened with a context manager so the handle is closed as soon
    as loading finishes, instead of being left open until garbage collection.

    NOTE(review): pickle.load can execute arbitrary code; only use this on
    files produced by this project's own preprocessing, never untrusted data.
    """
    with open(path, 'rb') as fh:
        return pickle.load(fh)


# Pre-processed Codex outputs. The indexing used later in the notebook is
# [example][sample]; the exact schema of each object is defined by the
# preprocessing step that wrote these files.
codex_questions = _load_pickle('../processed_data/codex_questions.pkl')
codex_answers = _load_pickle('../processed_data/codex_answers.pkl')
codex_predictions = _load_pickle('../processed_data/codex_predictions.pkl')
codex_per_step_probs = _load_pickle('../processed_data/codex_per_step_probs.pkl')
codex_prediction_labels = _load_pickle('../processed_data/codex_prediction_labels.pkl')

In [5]:
# Load the tokenizer of the google/flan-t5-xxl checkpoint; it is used below to
# split Codex's prediction text into FlanT5 tokens (no model weights are
# loaded in this cell, only the tokenizer).
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")

In [6]:
codex_predictions[0][0]

'Model output 0: \nIf Natalia sold 48 clips in April, she sold 48/2 = 24 clips in May.\nIn total, Natalia sold 48 + 24 = 72 clips in April and May.\nThe answer is 72\n'

In [28]:
# Strip the leading "Model output N: " header from example 1, sample 0.
# Only the FIRST ': ' is treated as the separator: the previous
# ''.join(text.split(': ')[1:]) also deleted every later ': ' inside the
# reasoning text, which would desynchronise `pred` from the per-step Codex
# tokens it is aligned against. If no ': ' is present the result is '' (same
# as before).
pred = codex_predictions[1][0].partition(': ')[2].strip()

In [29]:
pred

'Weng earns $12 for every hour of babysitting.\nShe babysat for 50 minutes.\n50 minutes is less than an hour, so she earned less than $12.\nShe earned $12/60 * 50 = $10.\nThe answer is 10'

In [30]:
# Tokenize the cleaned prediction with the FlanT5 tokenizer and map the ids
# back to their string pieces. The final piece is dropped with [:-1]
# (presumably the end-of-sequence token the tokenizer appends -- TODO confirm).
tokens = tokenizer.convert_ids_to_tokens(tokenizer(pred)['input_ids'])[:-1]

In [31]:
tokens

['▁We',
 'ng',
 '▁earn',
 's',
 '▁$12',
 '▁for',
 '▁every',
 '▁hour',
 '▁of',
 '▁baby',
 's',
 'i',
 'tting',
 '.',
 '▁She',
 '▁baby',
 's',
 'at',
 '▁for',
 '▁50',
 '▁minutes',
 '.',
 '▁50',
 '▁minutes',
 '▁is',
 '▁less',
 '▁than',
 '▁an',
 '▁hour',
 ',',
 '▁so',
 '▁she',
 '▁earned',
 '▁less',
 '▁than',
 '▁$12',
 '.',
 '▁She',
 '▁earned',
 '▁$12',
 '/',
 '60',
 '▁*',
 '▁50',
 '▁=',
 '▁$10',
 '.',
 '▁The',
 '▁answer',
 '▁is',
 '▁10']

In [34]:
# Token text of the first candidate at every generation step of example 1,
# sample 0. Each step is a list of (token, prob) pairs; the first pair appears
# to be the token Codex actually emitted -- TODO confirm. The first two steps
# are skipped (in the dump of example 0 they are ' step' and '\n', i.e. they
# look like the tail of the prompt rather than part of the answer).
steps = [s[0][0] for s in codex_per_step_probs[1][0][2:]]

# Drop trailing newline tokens so the sequence ends on the last real token,
# matching the stripped `pred` text.
# The previous index-based version sliced with steps[:idx + 1]; when the last
# token was NOT a newline, idx stayed at -1 and the slice steps[:0] wiped the
# whole list. It also raised IndexError on an empty / all-newline list.
while steps and steps[-1] == '\n':
    steps.pop()

In [35]:
steps

['W',
 'eng',
 ' earns',
 ' $',
 '12',
 ' for',
 ' every',
 ' hour',
 ' of',
 ' babys',
 'itting',
 '.',
 '\n',
 'She',
 ' babys',
 'at',
 ' for',
 ' 50',
 ' minutes',
 '.',
 '\n',
 '50',
 ' minutes',
 ' is',
 ' less',
 ' than',
 ' an',
 ' hour',
 ',',
 ' so',
 ' she',
 ' earned',
 ' less',
 ' than',
 ' $',
 '12',
 '.',
 '\n',
 'She',
 ' earned',
 ' $',
 '12',
 '/',
 '60',
 ' *',
 ' 50',
 ' =',
 ' $',
 '10',
 '.',
 '\n',
 'The',
 ' answer',
 ' is',
 ' 10']

In [12]:
def dist_fn(codex_a, flan_b):
    """Edit distance between a Codex token and a FlanT5 token, ignoring boundary markers.

    Args:
        codex_a: Codex token string; every ' ' character is removed first.
        flan_b: FlanT5 token string; every '▁' character (the prefix seen on
            the FlanT5 tokens above) is removed first.

    Returns:
        The value of editdistance.eval on the two cleaned strings.
    """
    codex_text = "".join(codex_a.split(' '))
    flan_text = "".join(flan_b.split('▁'))
    return editdistance.eval(codex_text, flan_text)

In [33]:
# NOTE(review): the recorded output below (' Nat') does not match `steps` as
# built from example 1 above, where steps[1] is 'eng'. It appears to come from
# an earlier run on example 0 (execution counts are also out of order) --
# re-run the notebook top to bottom to refresh it.
steps[1]

' Nat'

In [34]:
tokens[1]

'▁Nat'

In [35]:
dist_fn(steps[1], tokens[1])

0

In [21]:
# Compare sequence lengths before alignment; they generally differ because the
# two tokenizers split the text differently.
# NOTE(review): the recorded output (45, 40) does not match the 55 steps and
# 51 tokens displayed above for example 1 -- looks stale, re-run to refresh.
len(steps), len(tokens)

(45, 40)

In [39]:
codex_per_step_probs[1][0][2:]

[[('W', 0.3172),
  ('We', 0.0438),
  ('If', 0.1324),
  ('She', 0.0625),
  ('W', 0.3172),
  ('Since', 0.0693)],
 [('eng', 0.9903),
  ('eng', 0.9903),
  ('en', 0.0034),
  ('end', 0.0035),
  ('ong', 0.0004),
  ('ang', 0.0013)],
 [(' earns', 0.7358),
  (' makes', 0.0283),
  (' earns', 0.7358),
  (' gets', 0.0301),
  (' is', 0.0231),
  (' earned', 0.0544)],
 [(' $', 0.8506),
  (' 12', 0.1252),
  (' money', 0.0031),
  (' a', 0.0031),
  (' 1', 0.0033),
  (' $', 0.8506)],
 [('12', 0.9898),
  ('12', 0.9898),
  (' 12', 0.0016),
  ('0', 0.0007),
  ('1', 0.0053),
  ('24', 0.0004)],
 [(' for', 0.1161),
  ('/', 0.0625),
  (' an', 0.4909),
  (' for', 0.1161),
  (' per', 0.2627),
  (' every', 0.0137)],
 [(' every', 0.3063),
  (' each', 0.2902),
  (' an', 0.1321),
  (' 1', 0.0732),
  (' babys', 0.0787),
  (' every', 0.3063)],
 [(' hour', 0.8344),
  (' full', 0.0238),
  (' hour', 0.8344),
  (' 60', 0.0744),
  (' 1', 0.0349),
  (' one', 0.0137)],
 [(' of', 0.3404),
  (',', 0.0729),
  ('.', 0.1042),
  (' 

In [36]:
# Align the Codex token sequence (`steps`) with the FlanT5 token sequence
# (`tokens`) using the project's dtw helper from src.utils, with dist_fn as the
# pairwise cost.
# As used in the next cell, mappings_series_2[i] holds the indices into `steps`
# that were matched to tokens[i]. The meaning of the other return values
# (matches, matrix_, mappings_series_1, matrix) is defined by src.utils.dtw and
# is not visible here -- TODO confirm before relying on them.
matches, matrix_, mappings_series_1, mappings_series_2, matrix = dtw(steps, tokens, norm_func=dist_fn)

In [37]:
# Print one line per FlanT5 token: the token's repr, a ' | ' separator, then the
# repr of every Codex token aligned to it (each followed by a single space).
for token_idx, step_indices in enumerate(mappings_series_2):
    aligned_steps = "".join(repr(steps[step_idx]) + " " for step_idx in step_indices)
    print(repr(tokens[token_idx]) + " | " + aligned_steps)

'▁We' | 'W' 
'ng' | 'eng' 
'▁earn' | ' earns' 
's' | ' $' 
'▁$12' | '12' 
'▁for' | ' for' 
'▁every' | ' every' 
'▁hour' | ' hour' 
'▁of' | ' of' 
'▁baby' | ' babys' 
's' | ' babys' 
'i' | ' babys' 
'tting' | 'itting' 
'.' | '.' '\n' 
'▁She' | 'She' 
'▁baby' | ' babys' 
's' | 'at' 
'at' | 'at' 
'▁for' | ' for' 
'▁50' | ' 50' 
'▁minutes' | ' minutes' 
'.' | '.' '\n' 
'▁50' | '50' 
'▁minutes' | ' minutes' 
'▁is' | ' is' 
'▁less' | ' less' 
'▁than' | ' than' 
'▁an' | ' an' 
'▁hour' | ' hour' 
',' | ',' 
'▁so' | ' so' 
'▁she' | ' she' 
'▁earned' | ' earned' 
'▁less' | ' less' 
'▁than' | ' than' 
'▁$12' | ' $' '12' 
'.' | '.' '\n' 
'▁She' | 'She' 
'▁earned' | ' earned' 
'▁$12' | ' $' '12' 
'/' | '/' 
'60' | '60' 
'▁*' | ' *' 
'▁50' | ' 50' 
'▁=' | ' =' ' $' 
'▁$10' | '10' 
'.' | '.' '\n' 
'▁The' | 'The' 
'▁answer' | ' answer' 
'▁is' | ' is' 
'▁10' | ' 10' 


In [20]:
matrix

array([[  0.,   3.,   7., ..., 112., 114., 116.],
       [  3.,   0.,   4., ..., 112., 115., 117.],
       [  7.,   4.,   0., ..., 111., 114., 118.],
       ...,
       [122., 123., 124., ...,  15.,   6.,   4.],
       [124., 125., 127., ...,  21.,   8.,   6.],
       [126., 127., 129., ...,  24.,  10.,   8.]])

In [9]:
codex_per_step_probs[0][0]

[[(' step', 1.0),
  ('!', 0.0),
  (' stepped', 0.0),
  (' stepping', 0.0),
  (' step', 1.0),
  (' steps', 0.0)],
 [('\n', 0.9862),
  ('!', 0.0002),
  ('.', 0.009),
  ('\n', 0.9862),
  (' ', 0.001),
  (':', 0.0025)],
 [('If', 0.1009),
  ('If', 0.1009),
  ('She', 0.0416),
  ('The', 0.0394),
  ('Nat', 0.3286),
  ('In', 0.2685)],
 [(' Nat', 0.6018),
  (' Nat', 0.6018),
  (' in', 0.0261),
  (' we', 0.0125),
  (' 48', 0.0335),
  (' she', 0.2588)],
 [('alia', 0.9968),
  ('ale', 0.0001),
  ('ilia', 0.0001),
  ('lia', 0.0002),
  (' sold', 0.0013),
  ('alia', 0.9968)],
 [(' sold', 0.9573),
  (' sells', 0.0217),
  (' sold', 0.9573),
  (' is', 0.0022),
  ("'s", 0.0028),
  (' had', 0.0024)],
 [(' 48', 0.4455),
  (' clips', 0.3353),
  (' half', 0.1318),
  (' to', 0.0351),
  (' 48', 0.4455),
  (' twice', 0.0058)],
 [(' clips', 0.9248),
  (' clips', 0.9248),
  (' friends', 0.0062),
  (' to', 0.0058),
  (' of', 0.0109),
  (' in', 0.0345)],
 [(' in', 0.7597),
  (',', 0.0206),
  (' to', 0.2018),
  (' in'