# Generation of completions for evaluation

In [None]:
from tqdm import tqdm
import random
import copy
import json

import torch
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from transformers import GPTNeoForCausalLM, AutoTokenizer
from soft_embedding import SoftEmbedding

# Seed both Python's and PyTorch's random number generators with a fixed value.
random.seed(5678)
torch.manual_seed(5678)

In [None]:
import dill as pickle

# NOTE(review): unpickling can execute arbitrary code embedded in the file;
# only load 'data/ops_raw.pickle' when it comes from a trusted source.
# The context manager makes sure the file handle is closed after loading.
with open('data/ops_raw.pickle', "rb") as f:
    ops = pickle.load(f)

# Per-model generation settings, read from JSON.
with open('config/gCONFIG.json', "r") as f:
    CONFIG = json.load(f)

In [None]:
# For every op except 'suggest-rephrase', build the prompts for a random
# sample of 100 of its test examples. Results are keyed by str(op).
egs = {}

for op in ops.keys():

    # 'suggest-rephrase' is left out.
    if op in ['suggest-rephrase']:
        continue

    op_examples = egs[str(op)] = []

    print(op)

    # Get random subset of 100 test samples.
    subset = random.sample(ops[op]['test'], 100)

    make_prompts = ops[op]['make_prompts']

    for example in subset:
        prompts = make_prompts(example)
        # Keep the source example and its id alongside the prompts.
        prompts['ex'] = example
        prompts['id'] = example['id']
        op_examples.append(prompts)

In [None]:
# Tokenizer for the 'EleutherAI/gpt-neo-2.7B' checkpoint.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neo-2.7B')
# NOTE(review): presumably needed because this tokenizer has no pad token by default — confirm.
tokenizer.pad_token = '[PAD]'

In [None]:
def generate_completion(model, prompt, CF, n_soft_tokens=30):
    """Sample one completion for ``prompt`` and return the decoded text.

    Relies on the module-level ``tokenizer`` and ``device``.

    Args:
        model: Model exposing ``generate``.
        prompt: Mapping whose ``'in'`` entry is the prompt text.
        CF: Config entry. When ``CF['checkpoint']`` contains ``'soft'`` the
            tokenised prompt is prefixed with placeholder tokens that stand
            in for the learned soft-prompt positions.
        n_soft_tokens: Number of placeholder positions to prepend for a
            soft-prompt checkpoint. Defaults to 30, the value previously
            hard-coded here.

    Returns:
        The decoded first sequence returned by ``model.generate``.
    """
    sampling_kwargs = dict(do_sample=True, temperature=0.9, max_new_tokens=150)

    inputs = tokenizer(prompt['in'], return_tensors="pt")

    if 'soft' in CF['checkpoint']:
        # need to pad attention_mask and input_ids to be full seq_len + n_learned_tokens
        # even though it does not matter what you pad input_ids with, it's just to make HF happy
        # 50256 is the filler id; the tidy step later strips the '<|endoftext|>'
        # markers these positions decode to.
        ids = inputs['input_ids']
        mask = inputs['attention_mask']
        filler_ids = torch.full((1, n_soft_tokens), 50256, dtype=ids.dtype)
        filler_mask = torch.full((1, n_soft_tokens), 1, dtype=mask.dtype)
        inputs['input_ids'] = torch.cat([filler_ids, ids], 1)
        inputs['attention_mask'] = torch.cat([filler_mask, mask], 1)

        inputs = inputs.to(device)

        # NOTE(review): use_cache=False is presumably required so the soft
        # embedding sees the full placeholder prefix on every step — confirm.
        gen_tokens = model.generate(**inputs, use_cache=False, **sampling_kwargs)
    else:
        # Regular checkpoint: the tokenised prompt is passed through unchanged.
        input_ids = inputs.input_ids.to(device)
        gen_tokens = model.generate(input_ids, **sampling_kwargs)

    return tokenizer.batch_decode(gen_tokens)[0]

In [None]:
# Generate one completion per example for every enabled config entry.
# Results are collected per model name in `gens`.
gens = {}

for idx, CF in enumerate(CONFIG):
    # Only generate for config entries that are switched on.
    if bool(CF['useThis']):

        gens[CF['model']] = []

        print(f'Generating for model #{idx}')
        print(f"{CF['op']} / {CF['checkpoint']} / {CF['prompt_type']}")

        checkpoint = CF['checkpoint']

        print(". . . loading")
        if 'soft' in checkpoint:
            # Soft-prompt checkpoint: start from the base model, swap in the
            # soft embedding, then load the saved weights on top.
            # n_tokens=30 must match the placeholder prefix length that
            # generate_completion prepends to the input.
            model = GPTNeoForCausalLM.from_pretrained("EleutherAI/gpt-neo-2.7B")
            s_wte = SoftEmbedding(model.get_input_embeddings(), n_tokens=30, initialize_from_vocab=True)
            model.set_input_embeddings(s_wte)
            # map_location='cpu' lets a checkpoint saved on a GPU load on a
            # CPU-only host; the model is moved to `device` just below.
            state_dict = torch.load(f"{checkpoint}/pytorch_model.bin", map_location='cpu')
            # strict=False tolerates key mismatches between checkpoint and model.
            model.load_state_dict(state_dict, strict=False)
        else:
            model = GPTNeoForCausalLM.from_pretrained(checkpoint, use_auth_token=True)
        model.to(device)
        model.eval()
        print(". . . model load DONE")

        print(". . . generating")
        for eg in tqdm(egs[CF['op']]):
            # Copy only the prompt variant in use so the shared example
            # record in `egs` is never modified.
            p = copy.deepcopy(eg[CF['prompt_type']])
            gen_text = generate_completion(model, p, CF)
            p['op'] = CF['op']
            p['gen'] = gen_text
            p['model'] = CF['model']
            p['ex'] = copy.deepcopy(eg['ex'])
            p['map_title'] = copy.deepcopy(eg['map_title'])
            p['id'] = copy.deepcopy(eg['id'])
            gens[CF['model']].append(p)

        print(". . . generation DONE")

        # Rewrite gens.json with everything generated so far.
        with open('gens.json', 'w') as f:
            json.dump(gens, f)

        # Move the finished model off `device` before the next one is loaded.
        model.to('cpu')

## Tidy

In [None]:
import re

In [None]:
def tidy(gen):
    """Clean up the raw generated text of one record.

    Reads ``gen['gen']``, ``gen['in']`` and ``gen['op']``. The cleaned text
    is written back to ``gen['gen']``; the same dict is modified in place
    and returned.
    """
    # Drop <|endoftext|> markers (soft prompts put 30 of them at the start).
    cleaned = gen['gen'].replace("<|endoftext|>", '')

    # Cut off as many leading characters as the prompt is long, leaving the
    # newly generated text.
    prompt_length = len(gen['in'])
    cleaned = cleaned[prompt_length:]

    # Strip parenthesised spans, then the symbols *, _ and `.
    cleaned = re.sub(r"\(.*?\)", '', cleaned)
    cleaned = re.sub(r"(\*|\_|`)", '', cleaned)

    # Strip a leading single-digit list marker, then any other leading
    # characters that are not letters, digits or quote marks.
    cleaned = re.sub(r"^[0-9]\.\s", '', cleaned)
    cleaned = re.sub(r"^[^a-zA-Z0-9\"']*", '', cleaned)

    # Strip trailing whitespace.
    cleaned = re.sub(r"\s*$", '', cleaned)

    # Keep the first line only; except for 'suggest-intermediary-claims',
    # also cut that line down to its first sentence.
    first_line = cleaned.split('\n')[0]
    if gen['op'] == 'suggest-intermediary-claims':
        cleaned = first_line
    else:
        head, separator, _ = first_line.partition('. ')
        # Put back the full stop consumed by the '. ' separator.
        cleaned = head + '.' if separator else head

    # Upper-case the first character, leaving the rest untouched.
    if cleaned:
        cleaned = cleaned[0].upper() + cleaned[1:]

    gen['gen'] = cleaned
    return gen

# Tidy every record of every model. tidy() edits each record in place, so
# running this more than once would strip the prompt length again.
for model in tqdm(gens.keys()):
    gens[model] = [tidy(gen) for gen in gens[model]]

In [None]:
# Save the tidied generations, still keyed by model name.
with open('gens_tidy.json', 'w') as f:
    json.dump(gens, f)

In [None]:
# Make a gens_only copy for ease of interactive inspection.

import copy
# Copy only the 'gen' field of each record, rather than deep-copying whole
# records and then discarding everything else.
gens_only = {}
for model in tqdm(gens.keys()):
    gens_only[model] = [copy.deepcopy(gen['gen']) for gen in gens[model]]

In [None]:
# Insert human responses so we have a benchmark.
# Each human entry starts as a copy of the matching '<letter>1' model's
# records, with the generation replaced by the record's 'out' value.
# NOTE(review): 'out' is presumably the human-written reference — confirm.

for letter in ['a', 'b', 'c', 'd']:
    human_model = f'{letter}H'
    gens[human_model] = copy.deepcopy(gens[f'{letter}1'])
    for gen in gens[human_model]:
        gen['gen'] = gen['out']
        # Relabel the copy; otherwise it keeps the source model's name and
        # cannot be told apart from that model's own records once the
        # per-model lists are flattened.
        gen['model'] = human_model

In [None]:
# For models whose op is 'suggest-intermediary-claims', split each generation
# on ' ~ ' into a list of claims and keep only records with more than two.
for model in tqdm(gens.keys()):
    records = gens[model]
    # The op is read from the first record, so an empty list is skipped.
    if not records or records[0]['op'] != 'suggest-intermediary-claims':
        continue
    for gen in records:
        # Already a list when this cell has been run before; only split
        # generations that are still raw strings.
        if isinstance(gen['gen'], str):
            gen['gen'] = gen['gen'].split(' ~ ')
    gens[model] = [gen for gen in records if len(gen['gen']) > 2]

In [None]:
# Collapse into single list, randomly (but nicely) ordered.

# Concatenate the per-model lists in dict order, then shuffle in place.
gens_flat = []
for model in gens:
    gens_flat.extend(gens[model])

random.shuffle(gens_flat)

In [None]:
# Group records by example id, with the ids in random order; within one id,
# order by the first character of the model name.
# dict.fromkeys de-duplicates while keeping first-seen order. A set's
# iteration order can differ between interpreter runs (e.g. for string ids),
# which would make the shuffle below unrepeatable despite the fixed seed.
ids = list(dict.fromkeys(gen['id'] for gen in gens_flat))
random.shuffle(ids)
# Look up each id's position in O(1) instead of scanning `ids` per record.
id_position = {gen_id: position for position, gen_id in enumerate(ids)}
gens_flat = sorted(gens_flat, key=lambda gen: (id_position[gen['id']], gen['model'][0]))

In [None]:
# Save the flattened, ordered list of records.
with open('gens_tidy_official.json', 'w') as f:
    json.dump(gens_flat, f)