Notebook for testing Emergent Abilities

In [1]:
import torch
from torch.nn.utils import prune

from tqdm import tqdm

from transformers import AutoTokenizer, OPTForCausalLM, pipeline
from datasets import load_dataset

from calculate_mask import calculate_mask
from inverse_hessian import calc_inverse_hessian
from input_prehooks import put_input_hooks
from testing_module import calculate_perp

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
#DEVICE
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# model_name = "facebook/opt-125m"
model_name = "facebook/opt-1.3b"

#Load dataset
dataset = load_dataset('c4', 'en', streaming=True)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left', max_new_tokens=30, max_length=100)

# Load model with pre-trained head
model = OPTForCausalLM.from_pretrained(model_name, output_attentions=True, output_hidden_states=True).to(device=device) # type: ignore

# Load generator
generator = pipeline('text-generation', model=model_name, tokenizer=tokenizer)

In [12]:
generator('Hello, my name is', temperature=0.7)

[{'generated_text': 'Hello, my name is John. I am a professional writer and I have been writing for over 10'}]

In [9]:
generator('5 + 5 = 10, 13 + 16 = 29, 10 + 15 =', temperature=0.7)

[{'generated_text': '5 + 5 = 10, 13 + 16 = 29, 10 + 15 = -5*z'}]

In [5]:
# digit addition few shot generator
import random
def generate_addition_few_shot(num_dig, num_examples):
    few_shot_str = ""
    for ex in range(num_examples):
        max_num = 10**num_dig-1
        num1 = random.randint(1, max_num)
        num2 = random.randint(1, max_num)
        sum = num1+num2
        few_shot_str += f'{num1} + {num2} = {sum}, '
    return few_shot_str

In [10]:
inp_str = generate_addition_few_shot(2, 10) + " 10 + 20 = "
generator(inp_str, temperature=.8)

Input length of input_ids is 67, but `max_length` is set to 21. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.


[{'generated_text': '25 + 6 = 31, 68 + 45 = 113, 94 + 74 = 168, 49 + 9 = 58, 10 + 39 = 49, 49 + 41 = 90, 47 + 24 = 71, 76 + 8 = 84, 27 + 62 = 89, 67 + 68 = 135,  10 + 20 =  '}]