A script to load two models (one untrained, one trained) and run several kinds of experimental analysis on them, making sure this is done robustly. In particular, we are interested in:
- Generation from unprompted, trained model.
- Generation from prompted, untrained model.
- Calculation of logits given a question and answer for:
    - unprompted, untrained model.
    - prompted, untrained model.
    - prompted, trained model.
    - unprompted, trained model.
- Create plots for deviance, correlation, and other possibly interesting measurements.

In [14]:
import gc
import os
import sys

import torch
import torch.nn as nn
from tqdm import tqdm

# Make the parent directory importable so that generate_data can be found.
sys.path.append('../')
from generate_data import format_prompt

In [15]:
# Unbatched
@torch.no_grad()  # inference only: don't build an autograd graph
def calculate_logits(model,
                     tokenizer,
                     x0,
                     question,
                     answer,
                     use_system=False):
    '''This function calculates the logits for a given question and answer pair.
        Inputs:
            model: The model to use for computing logits.
            tokenizer: The tokenizer to use for computing logits.
            x0: The system prompt (string)
            question: The question to ask (string)
            answer: The answer to the question (string)
            use_system: Whether to use the system prompt or not
        Outputs:
            logits: The logits for the prompt + answer sequence, shape (1, seq_len, vocab_size).
                The logits at position t predict the token at position t+1.
            answer_mask: Boolean mask of shape (1, seq_len), True at answer-token positions.'''

    prompt_q_str = format_prompt(x0, question, use_system=use_system)
    # Calling the tokenizer returns a BatchEncoding, which supports ['input_ids'];
    # tokenizer.encode(..., return_tensors='pt') returns a bare tensor instead.
    prompt_q_ids = tokenizer(prompt_q_str, return_tensors='pt')['input_ids'].to(model.device)

    # Skip special tokens so that no extra BOS token is inserted before the answer.
    answer_ids = tokenizer(answer, return_tensors='pt', add_special_tokens=False)['input_ids'].to(model.device)

    # Concatenate along the sequence dimension (dim=1); dim=0 is the batch dimension.
    input_ids = torch.cat([prompt_q_ids, answer_ids], dim=1)
    answer_mask = torch.cat([torch.zeros_like(prompt_q_ids), torch.ones_like(answer_ids)], dim=1)
    answer_mask = answer_mask == 1

    logits = model(input_ids, return_dict=True).logits

    return logits, answer_mask
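
The returned logits are not yet the answer's likelihood: in a causal LM the logits at position t score the token at position t+1, so the mask has to be shifted by one before it is used. Below is a minimal sketch of that step. It assumes you also keep the concatenated `input_ids` (the function above builds them but does not return them), and the helper name `answer_log_probs` is hypothetical.

```python
import torch
import torch.nn.functional as F


def answer_log_probs(logits: torch.Tensor,
                     input_ids: torch.Tensor,
                     answer_mask: torch.Tensor) -> torch.Tensor:
    """Log-probability the model assigns to each answer token.

    logits:      (1, seq_len, vocab_size) from a causal LM
    input_ids:   (1, seq_len) token ids that produced those logits
    answer_mask: (1, seq_len) bool, True at answer-token positions
    Returns a 1-D tensor with one log-prob per answer token.
    """
    # Position t predicts token t+1: drop the last logit and the first token/mask entry.
    log_probs = F.log_softmax(logits[:, :-1].float(), dim=-1)
    targets = input_ids[:, 1:]
    target_mask = answer_mask[:, 1:]
    # Select the log-prob of the token that actually came next at each position.
    token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return token_log_probs[target_mask]
```

Summing the result gives the total log-likelihood of the answer for one model and prompt setting.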

In [16]:
def generate_from_model(model,
                        tokenizer,
                        x0,
                        question,
                        min_length,
                        max_new_tokens,
                        temperature,
                        use_system=False):
    '''This function generates text from the model.
        Inputs:
            model: The model to use for generating text
            tokenizer: The tokenizer to use for generating text
            x0: The system prompt (string)
            question: The question to ask (string)
            min_length: The minimum total length of the output, counting the prompt tokens.
            max_new_tokens: The maximum number of tokens to generate
            temperature: The temperature to use for sampling
            use_system: Whether to use the system prompt or not
        Outputs:
            output: Token ids of the prompt followed by the generated continuation, shape (1, total_len)
    '''
    prompt_q_str = format_prompt(x0, question, use_system=use_system)
    # BatchEncoding.to() moves both input_ids and attention_mask to the model's device.
    inputs = tokenizer(prompt_q_str, return_tensors='pt').to(model.device)

    output = model.generate(
                    inputs['input_ids'],
                    attention_mask = inputs['attention_mask'],
                    do_sample = True,
                    max_new_tokens = max_new_tokens,
                    min_length = min_length,  # includes the prompt tokens
                    temperature = temperature,
                    pad_token_id = tokenizer.eos_token_id
                )
    return output
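
With `do_sample=True`, each new token is drawn from the softmax of the next-token logits divided by `temperature`. The sketch below, built around a hypothetical helper `sample_next_token`, shows only that temperature step. `model.generate` may additionally apply other logits processors, such as top-k filtering, depending on the model's generation config.

```python
import torch


def sample_next_token(next_token_logits: torch.Tensor,
                      temperature: float,
                      generator: torch.Generator = None) -> torch.Tensor:
    """Sample one token id per row from temperature-scaled logits.

    next_token_logits: (batch, vocab_size) logits for the next position.
    Returns a (batch, 1) tensor of sampled token ids.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # T < 1 sharpens the distribution, T > 1 flattens it.
    probs = torch.softmax(next_token_logits / temperature, dim=-1)
    # Draw one sample per row from the categorical distribution.
    return torch.multinomial(probs, num_samples=1, generator=generator)
```

Passing a seeded `torch.Generator` makes the draws reproducible, which helps when comparing generations across models.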

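Next, we sketch the pipeline for the deviance plots: compute logits for every combination of model (untrained or trained) and prompting (with or without the system prompt), freeing the untrained model from GPU memory before loading the trained one.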

In [None]:
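from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical placeholders: point these at the actual checkpoints and data files.
UNTRAINED_MODEL_PATH = 'path/to/untrained_model'
TRAINED_MODEL_PATH = 'path/to/trained_model'
X0_PATH = 'path/to/system_prompt.txt'
QUESTION_PATH = 'path/to/question.txt'
ANSWER_PATH = 'path/to/answer.txt'

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def read_text(path):
    with open(path) as f:
        return f.read().strip()

# Load untrained model (eval mode disables dropout)
model_untrained = AutoModelForCausalLM.from_pretrained(UNTRAINED_MODEL_PATH).to(device).eval()
# Load tokenizer (assumed to be shared by both models)
tokenizer = AutoTokenizer.from_pretrained(UNTRAINED_MODEL_PATH)
# Read the system prompt, question, and answer from their files
x0 = read_text(X0_PATH)
question = read_text(QUESTION_PATH)
answer = read_text(ANSWER_PATH)

# Calculate logits for untrained, unprompted
logits_untrained_unprompted, mask_untrained_unprompted = calculate_logits(model_untrained, tokenizer, x0, question, answer, use_system=False)

# Calculate logits for untrained, prompted
logits_untrained_prompted, mask_untrained_prompted = calculate_logits(model_untrained, tokenizer, x0, question, answer, use_system=True)

# Flush model_untrained from GPU memory
del model_untrained
gc.collect()
torch.cuda.empty_cache()

# Load trained model
model_trained = AutoModelForCausalLM.from_pretrained(TRAINED_MODEL_PATH).to(device).eval()

# Calculate logits for trained, unprompted
logits_trained_unprompted, mask_trained_unprompted = calculate_logits(model_trained, tokenizer, x0, question, answer, use_system=False)

# Calculate logits for trained, prompted
logits_trained_prompted, mask_trained_prompted = calculate_logits(model_trained, tokenizer, x0, question, answer, use_system=True)

# Write plotting code here. Prompted and unprompted runs have different prompt lengths,
# so compare only the answer positions. Logits at position t predict token t+1, so the
# answer predictions are logits[:, :-1][mask[:, 1:]], shape (answer_len, vocab_size).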
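
The notebook does not pin down what "deviance" means. One possible reading, sketched here as an assumption, is the per-token KL divergence between two runs' next-token distributions over the answer, together with the Pearson correlation of per-token values. Assuming the answer is tokenized identically in every run, the answer positions line up one-to-one across runs even though the prompts differ in length. All helper names below are hypothetical.

```python
import torch
import torch.nn.functional as F


def answer_logits(logits: torch.Tensor, answer_mask: torch.Tensor) -> torch.Tensor:
    """Logits that predict the answer tokens, shape (answer_len, vocab_size).

    Position t predicts token t+1, so the mask is shifted left by one.
    """
    return logits[:, :-1][answer_mask[:, 1:]]


def per_token_kl(logits_p: torch.Tensor, logits_q: torch.Tensor) -> torch.Tensor:
    """KL(P || Q) at each position, where P and Q are softmaxes over the last dim."""
    log_p = F.log_softmax(logits_p.float(), dim=-1)
    log_q = F.log_softmax(logits_q.float(), dim=-1)
    return (log_p.exp() * (log_p - log_q)).sum(dim=-1)


def pearson(x: torch.Tensor, y: torch.Tensor) -> float:
    """Pearson correlation of two equal-length 1-D tensors (nan if either is constant)."""
    x = x.float() - x.float().mean()
    y = y.float() - y.float().mean()
    return float((x * y).sum() / (x.norm() * y.norm()))
```

For example, `per_token_kl(answer_logits(logits_trained_unprompted, mask_trained_unprompted), answer_logits(logits_untrained_prompted, mask_untrained_prompted))` would measure, position by position, how far the unprompted trained model is from the prompted untrained model. The tensors may need `.cpu()` before plotting.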

