In [191]:
from micro_config import MetaConfig, deep_replace, parse_args
from base_configs import project_root
from data import NatInstSeq2SeqConfig, NatInstSeq2SeqGeneratorConfig
from models.t5_config import T5ModelConfig
from core import TKInferenceConfig
from tkinstruct_eval_inference import TKInstructEvaluationConfig, tk_instruct_evaluate

model = T5ModelConfig(
    # model_str="google/t5-v1_1-xl", 
    # model_str="t5-3b", 
    #model_str="t5-small",
    # model_str="google/ul2", 
    model_str="allenai/tk-instruct-large-def-pos",
    #model_str="google/t5-large-lm-adapt",
    # model_str="allenai/tk-instruct-11b-def-pos-neg-expl", 
    # checkpoint_path='outputs/T5_11B_random_nat_inst_finetune_test1/model_18854/', 
    # checkpoint_path='outputs/tk_model_full/',
    checkpoint_path = 'outputs/T5_large_gpt_dist_finetune_test1/model_4903',
    #checkpoint_path=None,
    from_pretrained=False, 
    use_fp16=True,
    gradient_checkpoint=True, 
)

eval_dataset = NatInstSeq2SeqConfig(
    #tsv_path='data/nat_inst/text2text/defintion_pos_2/test.tsv', 
    tsv_path='data/gpt_dist/text2text/io/test.tsv',
    enc_len=256, 
    dec_len=1024,
    add_ar_sentinal=False, 
    target_prepend_pad=True, 
    model_tokenizer=model, 
)

inference = TKInferenceConfig(
    model=model, 
    pjit=True, 
    verbose=True, 
)

In [192]:
from data import dataloader
from base_configs import project_root
from tqdm import tqdm

In [193]:
metaconfig = MetaConfig(
    project_root=project_root, 
    verbose=False, 
)

In [194]:
import jax
rng = jax.random.PRNGKey(0)

generation_kwargs={
    'max_length': 1024, 
    'do_sample': False,
    'num_beams': 1, 
}

In [195]:
ds = eval_dataset.unroll(metaconfig)

tcmalloc: large alloc 3132399616 bytes == 0x5620adcf8000 @  0x7f7a7a0c5680 0x7f7a7a0e6824 0x561d624c153b 0x561d625020ba 0x561d625d8a58 0x561d6253448d 0x561d6240e328 0x561d625ee66d 0x561d62534825 0x561d624922da 0x561d6252abe4 0x561d62491088 0x561d6252abe4 0x561d62491088 0x561d6252abe4 0x561d62491088 0x561d62529fe3 0x561d6252acb4 0x561d624922da 0x561d6252abe4 0x561d62491088 0x561d62529fe3 0x561d6252acb4 0x561d624922da 0x561d62529fe3 0x561d625d6a7c 0x561d6252adbb 0x561d6260d33e 0x561d62534571 0x561d62491088 0x561d6251f7cb
Token indices sequence length is longer than the specified maximum sequence length for this model (513 > 512). Running this sequence through the model will result in indexing errors


In [196]:
inference, _, mesh = inference.unroll(metaconfig)

unmatches keys: set()
using mesh shape: (1, 8)
full mesh: [[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0)
  TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1)
  TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0)
  TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1)
  TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0)
  TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1)
  TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0)
  TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]]


In [197]:
inference.tokenizer.decode(ds[153][0]['input_ids'], skip_special_tokens=True)

'Input: What are some fun and unique ways to exercise? Output:'

In [198]:
import jax.numpy as jnp

In [199]:
len(ds) / 51

195.0

In [200]:
len(ds)

9945

In [201]:
inputs, predictions = [], []
with mesh:
    for i in tqdm(range(0, int(len(ds)/51), 2)):
        s=i*51
        e=(i+1)*51
        if e == len(ds):
            e = s

        newInput = jnp.stack((ds[s][0]['input_ids'], ds[e][0]['input_ids']), axis=0)
        rng, new_rng = jax.random.split(rng)
        model_outputs = inference.generate_from_tokens(newInput, new_rng, **generation_kwargs)
        
        inputs.extend(inference.tokenizer.batch_decode(newInput, skip_special_tokens=True))
        predictions.extend(inference.tokenizer.batch_decode(model_outputs, skip_special_tokens=True))

  scopes, treedef = jax.tree_flatten(scope_tree)
  jax.tree_leaves(tree)))
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 98/98 [10:43<00:00,  6.56s/it]


In [202]:
with open('../data/gpt_dist/splits/default/test_tasks.txt', 'r') as file:
    test_tasks = file.readlines()
    test_tasks = [line.rstrip() for line in test_tasks]

In [203]:
test_tasks

['brainstorming_42',
 'brainstorming_77',
 'brainstorming_70',
 'brainstorming_68',
 'brainstorming_101',
 'brainstorming_39',
 'brainstorming_121',
 'brainstorming_3',
 'brainstorming_72',
 'brainstorming_61',
 'brainstorming_1',
 'brainstorming_28',
 'brainstorming_106',
 'brainstorming_80',
 'brainstorming_97',
 'brainstorming_69',
 'brainstorming_66',
 'brainstorming_96',
 'brainstorming_118',
 'brainstorming_114',
 'brainstorming_12',
 'brainstorming_99',
 'brainstorming_94',
 'brainstorming_51',
 'brainstorming_20',
 'chat_11',
 'chat_28',
 'chat_68',
 'chat_67',
 'chat_66',
 'chat_79',
 'chat_30',
 'chat_19',
 'chat_21',
 'chat_54',
 'chat_18',
 'chat_78',
 'chat_40',
 'chat_10',
 'chat_37',
 'chat_13',
 'chat_9',
 'closedqa_23',
 'closedqa_2',
 'closedqa_15',
 'closedqa_20',
 'closedqa_35',
 'closedqa_6',
 'closedqa_1',
 'closedqa_11',
 'closedqa_3',
 'extract_6',
 'extract_24',
 'extract_11',
 'extract_10',
 'extract_18',
 'extract_8',
 'extract_29',
 'extract_32',
 'generatio

### get GPT likelihoods for each io pair

In [204]:
import openai
openai.api_key="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"

def get_response(input_prompt, num_responses=1, temp=0):
    response = openai.Completion.create(
        model='text-davinci-001',
        prompt=input_prompt,
        max_tokens=0,
        echo=True,
        temperature=temp,
        logprobs=1,
        n=num_responses
    )
    return response

In [205]:
def get_summed_logprob(start_token_idx, dict_ex):
    try:
        stop_token_idx = dict_ex['tokens'].index("<|endoftext|>") + 1
    except:
        stop_token_idx = len(dict_ex['tokens'])
    logprobs_sum = 0
    for i in range(start_token_idx, stop_token_idx):
        if not dict_ex['token_logprobs'][i]:
            continue
        logprobs_sum += dict_ex['token_logprobs'][i]
    
    return logprobs_sum, (stop_token_idx-start_token_idx)

In [206]:
gpt_logprobs = []
token_lengths = []
for query, pred in tqdm(zip(inputs, predictions)):
    prompt = query + ' ' + pred
    res = get_response(prompt)
    start_idx = res['choices'][0]['logprobs']['tokens'].index(' Output') + 2
    logprobs, length = get_summed_logprob(start_idx, res['choices'][0]['logprobs'])
    gpt_logprobs.append(logprobs)
    token_lengths.append(length)

196it [00:43,  4.51it/s]


### Write results to file

In [3]:
import csv
import pandas as pd

In [207]:
with open('finetuned_tk_outputs.csv', 'w') as f:
    writer = csv.writer(f, delimiter =",")
    writer.writerow(['task', 'prompt', 't5 output', 'output tokens', 'gpt likelihood'])
    
for task, query, pred, num_tokens, logprob in zip(test_tasks, inputs, predictions, token_lengths, gpt_logprobs):
    with open('finetuned_tk_outputs.csv', 'a') as f:
        writer = csv.writer(f, delimiter =",")
        writer.writerow([task, query, pred, num_tokens, logprob])

In [4]:
with open('finetuned_tk_outputs.csv', 'r') as f:
    fine_tk_df = pd.read_csv(f)

with open('baseline_tk_outputs.csv', 'r') as f:
    base_tk_df = pd.read_csv(f)

In [5]:
fine_tk_df = fine_tk_df.set_index('task')
base_tk_df = base_tk_df.set_index('task')

In [6]:
idx = 'extract_29'

prompt = base_tk_df.loc[idx]['prompt']
print(prompt, '\n')

base_out = base_tk_df.loc[idx]['t5 output']
fine_out = fine_tk_df.loc[idx]['t5 output']

print("Baseline:", base_out)
print(base_tk_df.loc[idx]['gpt likelihood'], base_tk_df.loc[idx]['output tokens'], base_tk_df.loc[idx]['gpt likelihood'] / base_tk_df.loc[idx]['output tokens'], '\n')

print("Finetuned:", fine_out)
print(fine_tk_df.loc[idx]['gpt likelihood'], fine_tk_df.loc[idx]['output tokens'], fine_tk_df.loc[idx]['gpt likelihood'] / fine_tk_df.loc[idx]['output tokens'], '\n')

Input: Given this set of instructions, follow them to make a cake. Output: 

Baseline: Give this set of instructions to make a cake.
-22.0567200893 10 -2.20567200893 

Finetuned: The cake should be made of a cake mix and butter. The cake should be made of a cake mix and butter.
-52.086296759016 24 -2.170262364959 

