In [None]:
from transformers import AutoTokenizer, GPTNeoXForCausalLM
import os
import json

In [None]:
def load_json(name: str):
    """Read and parse a JSON file.

    Args:
        name: Path to the JSON file.

    Returns:
        The decoded JSON value (for the override files used below, a dict).
    """
    # Explicit UTF-8: without it, open() uses the locale's preferred encoding,
    # so the same file could decode differently on different machines.
    with open(name, 'r', encoding='utf-8') as f:
        return json.load(f)

In [None]:
# Root directory holding one sub-directory per fine-tuned run, plus the
# "<run>_global_0_local_0_rank_overrides.json" files read by the loading cell.
# Override with the CHEMNLP_300M_MODELS_ROOT environment variable when the
# checkpoints live somewhere other than the default cluster path.
ROOT_300M_MODELS = os.environ.get(
    'CHEMNLP_300M_MODELS_ROOT',
    '/fsx/proj-chemnlp/experiments/checkpoints/finetuned/300M-tokenised-gridsearch-v1',
)

In [None]:
def _load_model_entry(model_path: str) -> dict:
    """Load one fine-tuned run: its final checkpoint and its override config.

    Args:
        model_path: Run directory. Weights are read from its ``checkpoint-final``
            sub-directory; the config is read from the sibling file
            ``f"{model_path}_global_0_local_0_rank_overrides.json"``.

    Returns:
        ``{'model': <GPTNeoXForCausalLM>, 'configs': <parsed overrides JSON>}``.
    """
    return {
        'model': GPTNeoXForCausalLM.from_pretrained(
            pretrained_model_name_or_path=os.path.join(model_path, 'checkpoint-final')
        ),
        'configs': load_json(f"{model_path}_global_0_local_0_rank_overrides.json"),
    }


models = {}
NAME = '1B_fine_tune_1'  # run directory under ROOT_300M_MODELS; only used when COLLECT_ALL is False
COLLECT_ALL = True  # can be expensive (30 seconds per 1B param model)

if COLLECT_ALL:
    # Load every run that has a final checkpoint. sorted() keeps the model
    # order (and therefore the printed output) stable between runs. Checking
    # isdir() on the checkpoint path skips stray files in the root (e.g. the
    # *_overrides.json siblings, logs) instead of calling os.listdir() on a
    # non-directory, which would raise NotADirectoryError.
    for name in sorted(os.listdir(ROOT_300M_MODELS)):
        model_path = os.path.join(ROOT_300M_MODELS, name)
        if os.path.isdir(os.path.join(model_path, 'checkpoint-final')):
            models[name] = _load_model_entry(model_path)
else:
    models[NAME] = _load_model_entry(os.path.join(ROOT_300M_MODELS, NAME))

In [None]:
# One tokeniser shared by every model in `models`.
# NOTE(review): assumes all fine-tuned checkpoints kept the EleutherAI/pythia-1b
# vocabulary — confirm against the training config.
tokeniser = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="EleutherAI/pythia-1b"
)
# Register "<|padding|>" as the tokeniser's pad token.
tokeniser.add_special_tokens(dict(pad_token="<|padding|>"))

In [None]:
# Upper bound on the number of tokens generated after each prompt
# (generation may stop earlier at end-of-sequence).
OUTPUT_LEN = 20

# Prompts run through every loaded model.
TEST_CASES = [
    # plain-English control
    'I enjoy walking with my cute dog',
    # chemistry fact completion
    'The heaviest element in the periodic table is',
    # element symbols each followed by three numbers (apparently xyz coordinates)
    'C 6.39 2.84 -1.46 O 6.12 1.57 -0.86 P 5.14 1.10 0.31',
    # chemistry fact completion
    'The element carbon is denoted with the following symbol',
]

In [None]:
# Greedy-decode every prompt with every loaded model and print the completions.
for case_idx, test_case in enumerate(TEST_CASES, start=1):
    print(f'TEST CASE {case_idx} -> {test_case}')
    # Calling the tokeniser (rather than .encode) also returns the attention
    # mask; without it generate() has to guess the mask and warns about it.
    encoded = tokeniser(test_case, return_tensors='pt')

    for model_name, model_configs in models.items():
        model = model_configs['model']
        # Move inputs to wherever the model lives so this still works if the
        # models are later moved off the CPU.
        greedy_output = model.generate(
            input_ids=encoded['input_ids'].to(model.device),
            attention_mask=encoded['attention_mask'].to(model.device),
            max_new_tokens=OUTPUT_LEN,  # equivalent to max_length=prompt_len + OUTPUT_LEN
            pad_token_id=tokeniser.pad_token_id,  # the "<|padding|>" token registered on the tokeniser
        )

        # The first top-level key of the overrides file is taken as the run's
        # entry — presumably there is exactly one; confirm against the files.
        overrides = model_configs['configs']
        checkpoint_dir = next(iter(overrides))
        dataset_name = overrides[checkpoint_dir]['data']['path'].split('/')[-1]

        # Include the run name: several runs can share a dataset, and the
        # dataset name alone would make their outputs indistinguishable.
        print(f"\nOutput for model {model_name} trained on 300M {dataset_name}\n" + 100 * '-')
        print(tokeniser.decode(greedy_output[0], skip_special_tokens=True))
        print()
    print()