# Different Models

The first aim here is to find tasks that some models can do but others can't, so that we know something interesting is happening rather than something simple such as skip trigrams. [EasyTransformer](https://github.com/neelnanda-io/Easy-Transformer) provides a variety of models to try this with.

## Imports

In [None]:
!pip install git+https://github.com/neelnanda-io/Easy-Transformer ipython

# Clear output
from IPython.display import clear_output
clear_output()

In [38]:
from easy_transformer import EasyTransformer
import torch
from IPython.display import HTML, clear_output
from typing import List
import pandas as pd
import copy


## Models available

In [3]:
from easy_transformer import loading_from_pretrained
", ".join(loading_from_pretrained.OFFICIAL_MODEL_NAMES)

'gpt2, gpt2-medium, gpt2-large, gpt2-xl, distilgpt2, facebook/opt-125m, facebook/opt-1.3b, facebook/opt-2.7b, facebook/opt-6.7b, facebook/opt-13b, facebook/opt-30b, facebook/opt-66b, EleutherAI/gpt-neo-125M, EleutherAI/gpt-neo-1.3B, EleutherAI/gpt-neo-2.7B, EleutherAI/gpt-j-6B, EleutherAI/gpt-neox-20b, stanford-crfm/alias-gpt2-small-x21, stanford-crfm/battlestar-gpt2-small-x49, stanford-crfm/caprica-gpt2-small-x81, stanford-crfm/darkmatter-gpt2-small-x343, stanford-crfm/expanse-gpt2-small-x777, stanford-crfm/arwen-gpt2-medium-x21, stanford-crfm/beren-gpt2-medium-x49, stanford-crfm/celebrimbor-gpt2-medium-x81, stanford-crfm/durin-gpt2-medium-x343, stanford-crfm/eowyn-gpt2-medium-x777, EleutherAI/pythia-19m, EleutherAI/pythia-125m, EleutherAI/pythia-350m, EleutherAI/pythia-800m, EleutherAI/pythia-1.3b, EleutherAI/pythia-6.7b, EleutherAI/pythia-13b, EleutherAI/pythia-125m-deduped, EleutherAI/pythia-800m-deduped, EleutherAI/pythia-1.3b-deduped, EleutherAI/pythia-6.7b-deduped, NeelNanda/SoL

## Introduction

Code completion in language models is pretty sophisticated. For example, GitHub Copilot can complete a problem prompt with a full solution, like this:

### Python prompt

```python
# Two Sum

# Given an array of integers nums and an integer target, return indices of the two numbers such that they add up to target.

# You may assume that each input would have exactly one solution, and you may not use the same element twice.

# You can return the answer in any order.

class Solution(object):
    def twoSum(self, nums, target):
```

### Solution

```python
"""
:type nums1: List[int]
:type nums2: List[int]
:rtype: float
"""
nums1.extend(nums2)
nums1.sort()
if len(nums1)%2 == 0:
    return (nums1[len(nums1)//2] + nums1[len(nums1)//2 - 1])/2
else:
    return nums1[len(nums1)//2]
```

Some interesting observations of learnt skills here include:

- Deep stuff: understanding the problem from the comments, and formulating a high level solution.
- Keeping track of if/else blocks & brackets
- Using previously defined variables (perhaps similar to induction heads)
- Knowing variable types (both primitives and library types)
- Knowing the methods within classes
- Understanding underlying maths (e.g. what modulus and floor divide actually do)
- Understanding broadly how a function works (e.g. it typically returns something)
- Knowing line spacing conventions
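
As a rough sketch of how some of these skills could be turned into probes, the list below writes a few of them in the `[name, prompt, expected_output]` form used later in this notebook. The variable name `skill_tasks`, the prompts, and the expected completions are all illustrative assumptions rather than tested results, and the expected outputs may need adjusting to match each model's token boundaries.

```python
from typing import List

# Hypothetical probe tasks in the form [name, prompt, expected_output].
# Prompts and expected completions are illustrative assumptions, not results.
skill_tasks: List[List[str]] = [
    ["Close brackets", "values = [1, 2, (3, 4", ")]"],
    ["Reuse a defined variable", "total_count = 5\nprint(total", "_count)"],
    ["Return from a function", "def add(a, b):\n    result = a + b\n    return", " result"],
]

# Print each probe so the prompt/completion split is easy to inspect
for name, prompt, expected_output in skill_tasks:
    print(f"{name}: {prompt!r} -> {expected_output!r}")
```

Keeping each probe short and tied to a single skill should make it easier to see which skill a given model is missing.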

## Finding interesting problems for different models

In [34]:
# Tasks in the form [name, prompt, expected_output]
tasks: List[List[str]] = [
    [
        "Convert types correctly", 
        """
            a = "The number is"
            b = 12.145
            concat = a""",
        """ + " " + str(b)"""
    ]
]

# Models to use
models = [
    "NeelNanda/SoLU_1L512W_C4_Code", 
    "NeelNanda/SoLU_2L512W_C4_Code", 
    # "NeelNanda/SoLU_3L512W_C4_Code", 
    # "NeelNanda/SoLU_4L512W_C4_Code", 
    # "NeelNanda/SoLU_6L768W_C4_Code", 
    # "NeelNanda/SoLU_8L1024W_C4_Code", 
    # "NeelNanda/SoLU_10L1280W_C4_Code", 
    # "NeelNanda/SoLU_12L1536W_C4_Code", 
    # "NeelNanda/GELU_1L512W_C4_Code", 
    # "NeelNanda/GELU_2L512W_C4_Code", 
    # "NeelNanda/GELU_3L512W_C4_Code", 
    # "NeelNanda/GELU_4L512W_C4_Code", 
    "NeelNanda/Attn_Only_1L512W_C4_Code", 
    "NeelNanda/Attn_Only_2L512W_C4_Code", 
    # "NeelNanda/Attn_Only_3L512W_C4_Code", 
    # "NeelNanda/Attn_Only_4L512W_C4_Code",
]

In [39]:
def get_model_output(
    model: EasyTransformer, 
    prompt: str, 
    number_new_tokens: int) -> str:
    """Get the model output for a given prompt

    Args:
        model (EasyTransformer): Model
        prompt (str): Prompt
        number_new_tokens (int): Number of output tokens to get (by recursively running the model)

    Returns:
        str: Output tokens as a concatenated string
    """
    next_tokens: List[str] = []

    for _ in range(number_new_tokens):
        # Re-run the model on the prompt plus everything generated so far
        with torch.no_grad():
            logits = model(prompt + "".join(next_tokens))
        # Greedy decoding: take the most likely token at the final position
        prediction = int(logits[0, -1].argmax().item())
        next_token = model.tokenizer.decode(prediction)
        next_tokens.append(next_token)

    return "".join(next_tokens)

tasks_with_model_outputs = copy.deepcopy(tasks)

for model_idx, model_name in enumerate(models):
    # Clear PyTorch GPU memory
    torch.cuda.empty_cache()
    
    # Load up the model
    model = EasyTransformer.from_pretrained(model_name)
    
    # Loop through tasks
    for task_idx, task in enumerate(tasks):

        # Destructure the task
        name, prompt, expected_output = task
    
        # Get the number of tokens in the expected output
        expected_tokens = model.tokenizer.encode(expected_output)
        expected_tokens_length = len(expected_tokens)
        
        # Get the model output
        model_output = get_model_output(model, prompt, expected_tokens_length)
        tasks_with_model_outputs[task_idx].append(model_output)
        

# Display the results
results = pd.DataFrame(tasks_with_model_outputs, columns=["Task", "Prompt", "Expected Output", *models])
clear_output()
results

| | Task | Prompt | Expected Output | NeelNanda/SoLU_1L512W_C4_Code | NeelNanda/SoLU_2L512W_C4_Code | NeelNanda/Attn_Only_1L512W_C4_Code | NeelNanda/Attn_Only_2L512W_C4_Code |
|---|---|---|---|---|---|---|---|
| 0 | Convert types correctly | `\n a = "The number is"\n ...` | `+ " " + str(b)` | `.get_at_at(a` | `.concat(a, b)\n` | `\n concat = concat\n conc` | `+ " " + b + " "` |
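
To spot tasks that some models solve and others don't, it may help to reduce each result row to a pass/fail flag per model. The sketch below assumes rows shaped like `tasks_with_model_outputs` (name, prompt, expected output, then one output per model) and uses a whitespace-insensitive exact match; the helper name `exact_matches` and the matching rule are assumptions, not part of the notebook. The example row is transcribed from the results above with a dedented prompt, and leading whitespace may have been lost in the export.

```python
from typing import Dict, List

def exact_matches(rows: List[List[str]], model_names: List[str]) -> Dict[str, Dict[str, bool]]:
    """Map each task name to {model name: output equals expected (ignoring outer whitespace)}."""
    scores: Dict[str, Dict[str, bool]] = {}
    for name, _prompt, expected_output, *model_outputs in rows:
        scores[name] = {
            model_name: output.strip() == expected_output.strip()
            for model_name, output in zip(model_names, model_outputs)
        }
    return scores

model_names = [
    "NeelNanda/SoLU_1L512W_C4_Code",
    "NeelNanda/SoLU_2L512W_C4_Code",
    "NeelNanda/Attn_Only_1L512W_C4_Code",
    "NeelNanda/Attn_Only_2L512W_C4_Code",
]

# Row transcribed from the results table (prompt dedented for readability)
example_rows = [[
    "Convert types correctly",
    '\na = "The number is"\nb = 12.145\nconcat = a',
    ' + " " + str(b)',
    ".get_at_at(a",
    ".concat(a, b)\n",
    "\n concat = concat\n conc",
    '+ " " + b + " "',
]]

scores = exact_matches(example_rows, model_names)
print(scores["Convert types correctly"])  # every model maps to False for this row
```

Exact match is a strict criterion: the 2-layer attention-only output is close to the expected string but omits the `str(...)` conversion, so it still counts as a miss here.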


### save


In [None]:
# Load the model
model_name = "NeelNanda/SoLU_12L1536W_C4_Code"
model = EasyTransformer.from_pretrained(model_name)

# Complete a LeetCode test
prompt = """# Two Sum

# Given an array of integers nums and an integer target, return indices of the two numbers such that they add up to target.

# You may assume that each input would have exactly one solution, and you may not use the same element twice.

# You can return the answer in any order.

class Solution(object):
    def twoSum(self, nums, target):"""

next_tokens = []
for i in range(10):
    logits = model(prompt + "".join(next_tokens))
    predictions = torch.argmax(logits, 2)
    prediction = int(predictions[0][-1].item())
    next_token = model.tokenizer.decode(prediction)
    next_tokens.append(next_token)
    
HTML("".join(next_tokens).replace("\n", "<br/>"))


        # """
        # :type nums1: List[int]
        # :type nums2: List[int]
        # :rtype: float
        # """
        # nums1.extend(nums2)
        # nums1.sort()
        # if len(nums1)%2 == 0:
        #     return (nums1[len(nums1)//2] + nums1[len(nums1)//2 - 1])/2
        # else:
        #     return nums1[len(nums1)//2]
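
The generation loops above decode each predicted token to text and then re-tokenize the whole string on the next step. A minimal sketch of the same greedy decoding idea kept entirely in token-id space is shown below. It is written against a hypothetical `next_token_logits` callable (standing in for a real model's final-position logits), and `toy_logits` is a made-up 4-token model used only so the example runs without downloading weights.

```python
from typing import Callable, List, Sequence

def greedy_decode(
    next_token_logits: Callable[[Sequence[int]], Sequence[float]],
    prompt_ids: Sequence[int],
    number_new_tokens: int,
) -> List[int]:
    """Repeatedly append the highest-scoring next token id (greedy decoding)."""
    token_ids = list(prompt_ids)
    for _ in range(number_new_tokens):
        logits = next_token_logits(token_ids)
        # Argmax over the vocabulary for the final position
        next_id = max(range(len(logits)), key=logits.__getitem__)
        token_ids.append(next_id)
    # Return only the newly generated ids
    return token_ids[len(prompt_ids):]

def toy_logits(token_ids: Sequence[int]) -> List[float]:
    """Hypothetical stand-in model: always prefers (last token + 1) mod 4."""
    preferred = (token_ids[-1] + 1) % 4
    return [1.0 if i == preferred else 0.0 for i in range(4)]

print(greedy_decode(toy_logits, [0], 5))  # [1, 2, 3, 0, 1]
```

Staying in id space avoids cases where decoding a token and re-encoding the text yields a different tokenization than the one the model produced.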