# Align Codex tokens to FlanT5 tokens

In [1]:
# Enable IPython's autoreload extension in mode 2, so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# Put the parent directory on the import path; presumably this is where the
# `src` package imported below lives -- verify against the repository layout.
import sys 
sys.path.append('..')

In [3]:
import time
import pickle
import random
import editdistance

from tqdm import tqdm
from transformers import T5Tokenizer, GPT2Tokenizer
from src.utils import (parse_codex_outputs, 
                       parse_flan_t5_outputs, 
                       vis_prob_flow, 
                       vis_heatmap, 
                       test_acc, 
                       majority_vote_acc,
                       ClosestToken,
                       transform_codex_token_to_t5_token,
                       print_transformed_probs
                      )

  from .autonotebook import tqdm as notebook_tqdm


# Read Codex and FlanT5 predictions

In [4]:
# Load the cached Codex artifacts. Each file is opened in a context manager so
# its handle is closed as soon as the object has been read.
# NOTE: unpickling can execute arbitrary code -- only load files produced by
# this project, never pickles from an untrusted source.
with open('codex_questions.pkl', 'rb') as fin:
    codex_questions = pickle.load(fin)
with open('codex_answers.pkl', 'rb') as fin:
    codex_answers = pickle.load(fin)
with open('codex_predictions.pkl', 'rb') as fin:
    codex_predictions = pickle.load(fin)
with open('codex_per_step_probs.pkl', 'rb') as fin:
    codex_per_step_probs = pickle.load(fin)
with open('codex_prediction_labels.pkl', 'rb') as fin:
    codex_prediction_labels = pickle.load(fin)

In [72]:
# pickle.dump(codex_prediction_labels, open('codex_prediction_labels.pkl', 'wb'))

In [75]:
# codex_prediction_labels = pickle.load(open('codex_prediction_labels.pkl', 'rb'))

In [65]:
# Recompute `codex_prediction_labels` with majority_vote_acc; its first return
# value is discarded.
# NOTE(review): this overwrites the labels already loaded from
# 'codex_prediction_labels.pkl' in the loading cell above.
_, codex_prediction_labels = majority_vote_acc(codex_predictions, codex_answers)

total 7473, pred 6262, acc 0.8379


In [71]:
# Sanity check: number of top-level entries in the Codex label list.
len(codex_prediction_labels)

7473

In [28]:
# Sanity check: number of top-level entries in the Codex per-step probabilities.
len(codex_per_step_probs)

7473

In [5]:
# Load the cached FlanT5 artifacts. Each file is opened in a context manager so
# its handle is closed as soon as the object has been read.
# NOTE: unpickling can execute arbitrary code -- only load files produced by
# this project, never pickles from an untrusted source.
with open('flan_questions.pkl', 'rb') as fin:
    flan_questions = pickle.load(fin)
with open('flan_answers.pkl', 'rb') as fin:
    flan_answers = pickle.load(fin)
with open('flan_predictions.pkl', 'rb') as fin:
    flan_predictions = pickle.load(fin)
with open('flan_per_step_probs.pkl', 'rb') as fin:
    flan_per_step_probs = pickle.load(fin)
with open('flan_prediction_labels.pkl', 'rb') as fin:
    flan_prediction_labels = pickle.load(fin)

In [76]:
# _, flan_prediction_labels = majority_vote_acc(flan_predictions, flan_answers)

total 7473, pred 1512, acc 0.2023


In [77]:
# pickle.dump(flan_prediction_labels, open('flan_prediction_labels.pkl', 'wb'))

In [27]:
# Sanity check: number of top-level entries in the FlanT5 per-step probabilities.
len(flan_per_step_probs)

7473

In [5]:
# Peek at one raw Codex prediction string (it still carries the
# 'Model output 0:' prefix).
codex_predictions[0][0]

'Model output 0: \nIf Natalia sold 48 clips in April, she sold 48/2 = 24 clips in May.\nIn total, Natalia sold 48 + 24 = 72 clips in April and May.\nThe answer is 72\n'

In [32]:
# For each step of codex_per_step_probs[0][0], print the token of the first
# listed (token, prob) pair.
print([tk[0][0] for tk in codex_per_step_probs[0][0]])

[' step', '\n', 'If', ' Nat', 'alia', ' sold', ' 48', ' clips', ' in', ' April', ',', ' she', ' sold', ' 48', '/', '2', ' =', ' 24', ' clips', ' in', ' May', '.', '\n', 'In', ' total', ',', ' Nat', 'alia', ' sold', ' 48', ' +', ' 24', ' =', ' 72', ' clips', ' in', ' April', ' and', ' May', '.', '\n', 'The', ' answer', ' is', ' 72', '\n', '\n']


In [7]:
# Inspect the (token, prob) pairs stored for the first step of
# flan_per_step_probs[0][0]; FlanT5 tokens use the '▁' word-start marker.
flan_per_step_probs[0][0][0]

[('▁In', 0.2814),
 ('▁Nat', 0.5805),
 ('▁In', 0.2814),
 ('▁She', 0.036),
 ('▁First', 0.0167),
 ('▁If', 0.0123)]

In [6]:
# Load the FlanT5 tokenizer; its vocabulary is the target of the token alignment.
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-xxl")

In [15]:
# List the tokenizer's special tokens (eos/unk/pad plus the <extra_id_*> entries).
tokenizer.special_tokens_map

{'eos_token': '</s>',
 'unk_token': '<unk>',
 'pad_token': '<pad>',
 'additional_special_tokens': ['<extra_id_0>',
  '<extra_id_1>',
  '<extra_id_2>',
  '<extra_id_3>',
  '<extra_id_4>',
  '<extra_id_5>',
  '<extra_id_6>',
  '<extra_id_7>',
  '<extra_id_8>',
  '<extra_id_9>',
  '<extra_id_10>',
  '<extra_id_11>',
  '<extra_id_12>',
  '<extra_id_13>',
  '<extra_id_14>',
  '<extra_id_15>',
  '<extra_id_16>',
  '<extra_id_17>',
  '<extra_id_18>',
  '<extra_id_19>',
  '<extra_id_20>',
  '<extra_id_21>',
  '<extra_id_22>',
  '<extra_id_23>',
  '<extra_id_24>',
  '<extra_id_25>',
  '<extra_id_26>',
  '<extra_id_27>',
  '<extra_id_28>',
  '<extra_id_29>',
  '<extra_id_30>',
  '<extra_id_31>',
  '<extra_id_32>',
  '<extra_id_33>',
  '<extra_id_34>',
  '<extra_id_35>',
  '<extra_id_36>',
  '<extra_id_37>',
  '<extra_id_38>',
  '<extra_id_39>',
  '<extra_id_40>',
  '<extra_id_41>',
  '<extra_id_42>',
  '<extra_id_43>',
  '<extra_id_44>',
  '<extra_id_45>',
  '<extra_id_46>',
  '<extra_id_47>',
 

In [38]:
# Same raw prediction string as displayed earlier, repeated here for reference
# next to its per-step probabilities.
codex_predictions[0][0]

'Model output 0: \nIf Natalia sold 48 clips in April, she sold 48/2 = 24 clips in May.\nIn total, Natalia sold 48 + 24 = 72 clips in April and May.\nThe answer is 72\n'

In [13]:
# Inspect the per-step (token, prob) lists of codex_per_step_probs[0][0];
# tokens use GPT-style leading spaces rather than the T5 '▁' marker.
codex_per_step_probs[0][0]

[[(' step', 1.0),
  ('!', 0.0),
  (' stepped', 0.0),
  (' stepping', 0.0),
  (' step', 1.0),
  (' steps', 0.0)],
 [('\n', 0.9862),
  ('!', 0.0002),
  ('.', 0.009),
  ('\n', 0.9862),
  (' ', 0.001),
  (':', 0.0025)],
 [('If', 0.1009),
  ('If', 0.1009),
  ('She', 0.0416),
  ('The', 0.0394),
  ('Nat', 0.3286),
  ('In', 0.2685)],
 [(' Nat', 0.6018),
  (' Nat', 0.6018),
  (' in', 0.0261),
  (' we', 0.0125),
  (' 48', 0.0335),
  (' she', 0.2588)],
 [('alia', 0.9968),
  ('ale', 0.0001),
  ('ilia', 0.0001),
  ('lia', 0.0002),
  (' sold', 0.0013),
  ('alia', 0.9968)],
 [(' sold', 0.9573),
  (' sells', 0.0217),
  (' sold', 0.9573),
  (' is', 0.0022),
  ("'s", 0.0028),
  (' had', 0.0024)],
 [(' 48', 0.4455),
  (' clips', 0.3353),
  (' half', 0.1318),
  (' to', 0.0351),
  (' 48', 0.4455),
  (' twice', 0.0058)],
 [(' clips', 0.9248),
  (' clips', 0.9248),
  (' friends', 0.0062),
  (' to', 0.0058),
  (' of', 0.0109),
  (' in', 0.0345)],
 [(' in', 0.7597),
  (',', 0.0206),
  (' to', 0.2018),
  (' in'

In [9]:
# Build the ClosestToken helper over the T5 vocabulary; it is handed to
# transform_codex_token_to_t5_token in the cells below.
closest_token = ClosestToken(tokenizer.get_vocab())

In [109]:
# Input to the single-sample transform below: the per-step (token, prob) lists
# of codex_per_step_probs[0][0].
codex_per_step_probs[0][0]

[[(' step', 1.0),
  ('!', 0.0),
  (' stepped', 0.0),
  (' stepping', 0.0),
  (' step', 1.0),
  (' steps', 0.0)],
 [('\n', 0.9862),
  ('!', 0.0002),
  ('.', 0.009),
  ('\n', 0.9862),
  (' ', 0.001),
  (':', 0.0025)],
 [('If', 0.1009),
  ('If', 0.1009),
  ('She', 0.0416),
  ('The', 0.0394),
  ('Nat', 0.3286),
  ('In', 0.2685)],
 [(' Nat', 0.6018),
  (' Nat', 0.6018),
  (' in', 0.0261),
  (' we', 0.0125),
  (' 48', 0.0335),
  (' she', 0.2588)],
 [('alia', 0.9968),
  ('ale', 0.0001),
  ('ilia', 0.0001),
  ('lia', 0.0002),
  (' sold', 0.0013),
  ('alia', 0.9968)],
 [(' sold', 0.9573),
  (' sells', 0.0217),
  (' sold', 0.9573),
  (' is', 0.0022),
  ("'s", 0.0028),
  (' had', 0.0024)],
 [(' 48', 0.4455),
  (' clips', 0.3353),
  (' half', 0.1318),
  (' to', 0.0351),
  (' 48', 0.4455),
  (' twice', 0.0058)],
 [(' clips', 0.9248),
  (' clips', 0.9248),
  (' friends', 0.0062),
  (' to', 0.0058),
  (' of', 0.0109),
  (' in', 0.0345)],
 [(' in', 0.7597),
  (',', 0.0206),
  (' to', 0.2018),
  (' in'

In [39]:
# Transform one Codex sample into T5 tokens, keeping the transformed per-step
# probabilities, a per-step mask and a statistics dict.
# NOTE(review): this call passes (sample, closest_token), whereas the batch loop
# further down calls the same helper as (qi, aid, ai, closest_token). Only one
# arity can match the current src.utils signature -- confirm and update the other.
transferred_per_step_probs, mask, transform_result = transform_codex_token_to_t5_token(codex_per_step_probs[0][0], closest_token)

In [9]:
# Statistics dict returned by the transform (blank/number handling counters).
transform_result

{'blank_before_number': 6,
 'blank_after_number': 3,
 'blank_step_before_number': 0,
 'blank_step_after_number': 1}

In [10]:
# Per-step mask returned by the transform: a list of 0/1 flags.
# NOTE(review): presumably 0 marks a step that should be ignored downstream -- confirm in src.utils.
mask

[1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1]

In [17]:
# Transformed per-step (token, prob) lists, now using T5-style tokens ('▁' prefixes).
transferred_per_step_probs

[[('If', 0.1009),
  ('If', 0.1009),
  ('She', 0.0416),
  ('The', 0.0394),
  ('▁at', 0.3286),
  ('In', 0.2685)],
 [('▁Nat', 0.6018),
  ('▁Nat', 0.6018),
  ('▁in', 0.0261),
  ('▁we', 0.0125),
  ('48', 0.0335),
  ('▁she', 0.2588)],
 [('alia', 0.9968),
  ('▁le', 0.0001),
  ('ilia', 0.0001),
  ('la', 0.0002),
  ('▁sold', 0.0013),
  ('alia', 0.9968)],
 [('▁sold', 0.9573),
  ('▁sell', 0.0217),
  ('▁sold', 0.9573),
  ('▁is', 0.0022),
  ('s', 0.0028),
  ('▁had', 0.0024)],
 [('▁48', 0.4455),
  ('▁clips', 0.3353),
  ('▁half', 0.1318),
  ('▁to', 0.0351),
  ('48', 0.4455),
  ('▁twice', 0.0058)],
 [('▁clips', 0.9248),
  ('▁clips', 0.9248),
  ('▁friends', 0.0062),
  ('▁to', 0.0058),
  ('▁of', 0.0109),
  ('▁in', 0.0345)],
 [('▁in', 0.7597),
  (',', 0.0206),
  ('▁to', 0.2018),
  ('▁in', 0.7597),
  ('▁and', 0.0031),
  ('▁then', 0.0028)],
 [('▁April', 0.9759),
  ('▁May', 0.0039),
  ('▁total', 0.0025),
  ('▁the', 0.0067),
  ('▁April', 0.9759),
  ('ar', 0.0026)],
 [(',', 0.7709),
  (',', 0.7709),
  ('▁to',

In [88]:
# Token of the first listed (token, prob) pair at every transformed step.
[step[0][0] for step in transferred_per_step_probs]

['▁step',
 '▁',
 'If',
 '▁Nat',
 'alia',
 '▁sold',
 '48',
 '▁clips',
 '▁in',
 '▁April',
 ',',
 '▁she',
 '▁sold',
 '48',
 '/',
 '2',
 '=',
 '▁24',
 '▁clips',
 '▁in',
 '▁May',
 '.',
 '▁',
 'In',
 '▁total',
 ',',
 '▁Nat',
 'alia',
 '▁sold',
 '48',
 '+',
 '▁24',
 '=',
 '72',
 '▁clips',
 '▁in',
 '▁April',
 '▁and',
 '▁May',
 '.',
 '▁',
 'The',
 '▁answer',
 '▁is',
 '72',
 '▁',
 '▁']

In [40]:
# Round-trip check: map the first listed token of every step to its vocabulary
# id and decode the ids back into text.
tokenizer.decode(
    tokenizer.convert_tokens_to_ids(
        [step[0][0] for step in transferred_per_step_probs]
    )
)

'If Natalia sold 48 clips in April, she sold 48 / 2 = 24 clips in May. In total, Natalia sold 48 + 24 = 72 clips in April and May. The answer is 72</s>'

In [19]:
# Scratch value for the slicing check in the next cell.
x = 'abcdef'

In [20]:
# Scratch check: slicing with [:-2] drops the last two characters.
x[:-2]

'abcd'

# Transform the Codex decoded tokens into FlanT5 tokens

In [7]:
# Two independent, empty accumulators: one for the transformed per-step
# probabilities and one for the matching masks.
transformed_codex_per_step_probs, transformed_mask = [], []

In [80]:
# Check how many questions have been accumulated so far.
len(transformed_codex_per_step_probs)

672

In [10]:
# Transform every Codex sample of every question into T5 tokens.
#
# The accumulators are re-created here on purpose: the loop always walks the
# whole of `codex_per_step_probs` from the start, so appending to lists left
# over from an earlier (possibly interrupted) run would duplicate entries and
# break the index alignment with the questions.
transformed_codex_per_step_probs = []
transformed_mask = []
for qi, q in tqdm(enumerate(codex_per_step_probs), total=len(codex_per_step_probs)):
    transformed_q = []
    transformed_m = []
    for aid, ai in enumerate(q):
        # The third return value (the statistics dict) is not needed for the batch run.
        probs, mask, _ = transform_codex_token_to_t5_token(qi, aid, ai, closest_token)
        transformed_q.append(probs)
        transformed_m.append(mask)
    # One entry per question, each holding one entry per sampled answer.
    transformed_codex_per_step_probs.append(transformed_q)
    transformed_mask.append(transformed_m)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7473/7473 [20:18<00:00,  6.13it/s]


In [64]:
# Print one transformed sample via print_transformed_probs.
# NOTE(review): this reads `transferred_per_step_probs` (the single-sample
# variable), not the batch result `transformed_codex_per_step_probs`; the text
# shown depends on whichever sample was assigned to that variable last.
print_transformed_probs(tokenizer, transferred_per_step_probs)

If Tin works 10 hours, she will work 8 hours at her regular pay of $ 18.00 and 2 hours of overtime. The overtime pay is calculated by adding the regular pay plus half the regular pay. So, she will earn $ 18.00 + 1 / 2 ( 18.00 )= $ 18.00 + $ 9.00 = $ 27.00 per hour for her overtime pay. She will earn $ 18.00 x 8 = $ 144.00 for her regular pay and $ 27.00 x 2 = $ 54.00 for her overtime pay. In total, she will earn $ 144.00 + $ 54.00 = $ 198.00 for the day. If she works 10 hours every day for 5 days, she will earn $ 198.00 x 5 = $ 990.00 The answer is 990</s>


In [22]:
# Persist the transformed per-step probabilities. The context manager
# guarantees the file is flushed and closed once the dump finishes.
with open('codex_transformed_per_step_probs.pkl', 'wb') as fout:
    pickle.dump(transformed_codex_per_step_probs, fout)

In [44]:
# Persist the masks produced by the transform. The context manager guarantees
# the file is flushed and closed once the dump finishes.
with open('codex_mask_after_transform.pkl', 'wb') as fout:
    pickle.dump(transformed_mask, fout)

# Use the tokenizer to convert everything to vocabulary indices

In [29]:
# Keep only the text after the first ':' of each raw question (i.e. drop the
# leading label), trimmed of surrounding whitespace. A string without any ':'
# becomes ''.
questions = [raw.partition(':')[2].strip() for raw in codex_questions]

In [96]:
# Check the first cleaned question string.
questions[0]

'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?'

In [30]:
# Tokenize the cleaned questions into T5 input ids (attention masks are not requested).
# NOTE(review): `questions[:-1]` drops the final entry -- presumably a trailing
# artifact in codex_questions, since the result still has 7473 items. Confirm
# this is intended rather than an off-by-one.
codex_questions_idx = tokenizer(questions[:-1], return_attention_mask=False)['input_ids']

In [31]:
# Sanity check: number of tokenized questions.
len(codex_questions_idx)

7473

In [32]:
# Persist the tokenized questions. The context manager guarantees the file is
# flushed and closed once the dump finishes.
with open('codex_questions_idx.pkl', 'wb') as fout:
    pickle.dump(codex_questions_idx, fout)

In [102]:
# Round-trip check: decode the first tokenized question back to text
# (the tokenizer appends '</s>').
tokenizer.decode(codex_questions_idx[0])

'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?</s>'

In [104]:
# Raw reference answer for the first question; it starts with an 'Answer:'
# prefix and contains '<<...>>' calculator annotations.
codex_answers[0]

'Answer: Natalia sold 48/2 = <<48/2=24>>24 clips in May.\nNatalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n#### 72\n'

In [33]:
# Keep only the text after the first ':' of each raw answer (i.e. drop the
# leading label), trimmed of surrounding whitespace. A string without any ':'
# becomes ''.
answers = [raw.partition(':')[2].strip() for raw in codex_answers]

In [34]:
# Tokenize the cleaned answers into T5 input ids (attention masks are not requested).
# NOTE(review): `answers[:-1]` drops the final entry, mirroring the question
# tokenization -- confirm the last element really is a trailing artifact.
codex_answers_idx = tokenizer(answers[:-1], return_attention_mask=False)['input_ids']

In [35]:
# Persist the tokenized answers. The context manager guarantees the file is
# flushed and closed once the dump finishes.
with open('codex_answers_idx.pkl', 'wb') as fout:
    pickle.dump(codex_answers_idx, fout)

In [108]:
# Round-trip check on the first tokenized answer.
# NOTE(review): the decoded text shows '<<' coming back as '<unk>', so this
# tokenizer cannot represent the calculator annotations losslessly.
tokenizer.decode(codex_answers_idx[0])

'Natalia sold 48/2 = <unk> 48/2=24>>24 clips in May. Natalia sold 48+24 = <unk> 48+24=72>>72 clips altogether in April and May. #### 72</s>'

In [24]:
# Sanity check: the batch transform produced one entry per question.
len(transformed_codex_per_step_probs)

7473

In [19]:
# Replace every token string in the transformed per-step probabilities with its
# T5 vocabulary id, keeping the nesting question -> sample -> step -> (id, prob).
# A token missing from the vocabulary raises KeyError.
codex_per_step_probs_idx = []
vocab = tokenizer.get_vocab()
for question_samples in tqdm(transformed_codex_per_step_probs):
    codex_per_step_probs_idx.append([
        [
            [(vocab[token], prob) for token, prob in candidates]
            for candidates in sample_steps
        ]
        for sample_steps in question_samples
    ])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7473/7473 [03:39<00:00, 34.11it/s]


In [36]:
# Token of the first listed (token, prob) pair at every step of sample [7472][39].
[step[0][0] for step in transformed_codex_per_step_probs[7472][39]]

['▁',
 't',
 '▁30',
 '▁',
 ',',
 '▁An',
 'ika',
 '▁is',
 '▁4',
 '▁',
 '/',
 '▁3',
 '▁the',
 'age',
 '▁of',
 '▁add',
 'e',
 '.',
 '▁',
 'So',
 ',',
 '▁add',
 'e',
 '▁is',
 '▁30',
 '▁',
 'x',
 '▁3',
 '▁',
 '/',
 '▁4',
 '▁=',
 '▁22',
 '.',
 '5',
 '▁years',
 '▁old',
 '.',
 '▁',
 'In',
 '▁15',
 '▁years',
 ',',
 '▁An',
 'ika',
 '▁would',
 '▁be',
 '▁30',
 '▁+',
 '▁15',
 '▁=',
 '▁45',
 '▁years',
 '▁old',
 '.',
 '▁',
 'And',
 '▁add',
 'e',
 '▁would',
 '▁be',
 '▁22',
 '.',
 '5',
 '▁+',
 '▁15',
 '▁=',
 '▁37',
 '.',
 '5',
 '▁years',
 '▁old',
 '.',
 '▁',
 '▁Their',
 '▁average',
 'age',
 '▁in',
 '▁15',
 '▁years',
 '▁would',
 '▁be',
 '▁(',
 '▁45',
 '▁+',
 '▁37',
 '.',
 '5',
 '▁',
 ')',
 '/',
 '▁2',
 '▁=',
 '▁41',
 '.',
 '25',
 '▁',
 'The',
 '▁answer',
 '▁is',
 '▁41',
 '.',
 '25',
 '▁',
 '</s>']

In [17]:
# Map the first listed id of every step back to its token string.
# The ids are materialised into a list because the tokenizer API is documented
# to take an int or a list of ints, not an arbitrary generator.
tokenizer.convert_ids_to_tokens([tp[0][0] for tp in codex_per_step_probs_idx[0][0]])

['▁',
 'If',
 '▁Nat',
 'alia',
 '▁sold',
 '▁48',
 '▁clips',
 '▁in',
 '▁April',
 ',',
 '▁she',
 '▁sold',
 '▁48',
 '▁',
 '/',
 '▁2',
 '▁=',
 '▁24',
 '▁clips',
 '▁in',
 '▁May',
 '.',
 '▁',
 'In',
 '▁total',
 ',',
 '▁Nat',
 'alia',
 '▁sold',
 '▁48',
 '▁+',
 '▁24',
 '▁=',
 '▁72',
 '▁clips',
 '▁in',
 '▁April',
 '▁and',
 '▁May',
 '.',
 '▁',
 'The',
 '▁answer',
 '▁is',
 '▁72',
 '▁',
 '</s>']

In [37]:
# Decode the first listed id of every step of sample [7472][39] back to text.
# The ids are materialised into a list because `decode` is documented to take
# an int, a list of ints or an array/tensor, not an arbitrary generator.
tokenizer.decode([tp[0][0] for tp in codex_per_step_probs_idx[7472][39]])

't 30, Anika is 4 / 3 theage of adde. So, adde is 30 x 3 / 4 = 22.5 years old. In 15 years, Anika would be 30 + 15 = 45 years old. And adde would be 22.5 + 15 = 37.5 years old.  Their averageage in 15 years would be ( 45 + 37.5 )/ 2 = 41.25 The answer is 41.25 </s>'

In [39]:
# Reassemble the original Codex text of the same sample ([7472][39]) for
# comparison with the decoded T5 version in the previous cell.
# NOTE(review): the two outputs differ in content, not only spacing -- e.g.
# 'At' -> 't' and 'Maddie' -> 'adde' -- so the token transform is lossy here.
''.join(tp[0][0] for tp in codex_per_step_probs[7472][39])

' step\nAt 30, Anika is 4/3 the age of Maddie.\nSo, Maddie is 30 x 3/4 = 22.5 years old.\nIn 15 years, Anika would be 30 + 15 = 45 years old.\nAnd Maddie would be 22.5 + 15 = 37.5 years old.\nTheir average age in 15 years would be (45 + 37.5) / 2 = 41.25\nThe answer is 41.25\n\n'

In [43]:
# Persist the id-based per-step probabilities. The context manager guarantees
# the file is flushed and closed once the dump finishes.
with open('codex_per_step_probs_idx.pkl', 'wb') as fout:
    pickle.dump(codex_per_step_probs_idx, fout)

# TODO: use the tokenizer to tokenize the FlanT5 generations

In [45]:
# Size of the T5 vocabulary, i.e. the number of distinct token ids.
len(tokenizer.get_vocab())

32100

# Generation index permutation

In [25]:
# Build a random permutation of the question indices.
# A dedicated, seeded generator makes the permutation reproducible across
# kernel restarts and leaves the global `random` state untouched.
# NOTE: a 'permuted_idx.pkl' written by the earlier unseeded version will not
# match this permutation; regenerate dependent artifacts together.
PERMUTATION_SEED = 0
idx = list(range(len(transformed_codex_per_step_probs)))
random.Random(PERMUTATION_SEED).shuffle(idx)

In [27]:
# Persist the index permutation. The context manager guarantees the file is
# flushed and closed once the dump finishes.
with open('permuted_idx.pkl', 'wb') as fout:
    pickle.dump(idx, fout)