In [None]:
# Automatically reload edited modules (e.g. the project-local `relations` / `util`
# packages) before each cell executes, so library edits show up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
# NOTE(review): numpy, random, tqdm, transformers, os, nethook and itemgetter are not
# used by any cell visible in this notebook — candidates for removal. `random` is
# imported but no seed is set anywhere here.
import torch
import numpy as np
import json
from tqdm.auto import tqdm
import random
import transformers

import os
import sys
# Put the parent directory on the import path so the project-local packages
# (`relations`, `util`) resolve. '..' is relative to the kernel's working directory,
# so this presumably assumes the notebook is launched from a repo subdirectory — confirm.
sys.path.append('..')

# The `relations` / `util` imports below depend on the sys.path tweak above.
from relations import estimate
from util import model_utils
from baukit import nethook
from operator import itemgetter
from relations.evaluate import evaluate
from relations.corner import CornerEstimator

In [None]:
# NOTE(review): dead cell — `CounterFactDataset` is never imported in this notebook;
# the data is instead loaded from ../data/counterfact.json with `json` further down.
# Safe to delete.
# counterfact = CounterFactDataset("../data/")

In [None]:
# Model under study. Alternatives: gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B.
MODEL_NAME = "EleutherAI/gpt-j-6B"
mt = model_utils.ModelAndTokenizer(
    MODEL_NAME,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)
model, tokenizer = mt.model, mt.tokenizer

# Reuse EOS as the padding token (presumably because GPT-2/GPT-J tokenizers define
# no pad token of their own — confirm for other checkpoints).
tokenizer.pad_token = tokenizer.eos_token

print(f"{MODEL_NAME} ==> device: {model.device}, memory: {model.get_memory_footprint()}")

In [None]:
#################################################
relation_id = "P101"
precision_at = 3
#################################################

with open("../data/counterfact.json") as f:
    counterfact = json.load(f)

objects = [c['requested_rewrite'] for c in counterfact if c["requested_rewrite"]['relation_id'] == relation_id]
objects = [" "+ o['target_true']['str'] for o in objects]
objects = list(set(objects))
print("unique objects: ", len(objects), objects[0:5])

In [None]:
# One estimator instance, shared by both corner-construction cells below.
corner_estimator = CornerEstimator(
    model=model,
    tokenizer=tokenizer,
)

In [None]:
# Operator with an identity weight and the "simple" corner as its bias
# (scale_up=70 is a hand-picked value; layer 15 is the hidden layer passed to the operator).
simple_corner = corner_estimator.estimate_simple_corner(objects, scale_up=70)
simple_corner_norm = simple_corner.norm().item()
print(simple_corner_norm, corner_estimator.get_vocab_representation(simple_corner))

# Build the identity directly with the model's dtype/device (same values as eye().to().to()).
simple_identity = torch.eye(model.config.n_embd, dtype=model.dtype, device=model.device)

relation_operator = estimate.RelationOperator(
    model=model,
    tokenizer=tokenizer,
    relation='{} works in the field of',
    layer=15,
    weight=simple_identity,
    bias=simple_corner,
)

In [None]:
# Evaluate the simple-corner operator. `ret_dict["validation_set"]` is reused by the
# lin-inv evaluation further down so both runs score the same examples.
precision, ret_dict = evaluate(
    relation_id=relation_id,
    relation_operator=relation_operator,
    precision_at=precision_at,  # honor the config cell instead of a hard-coded 3
)

In [None]:
# Display the precision returned by `evaluate` for the simple-corner operator.
precision

In [None]:
# Same identity-weight operator as before, but biased by the linear-inverse corner
# (target_logit_value=50 is a hand-picked value).
lin_inv_corner = corner_estimator.estimate_lin_inv_corner(objects, target_logit_value=50)
lin_inv_norm = lin_inv_corner.norm().item()
print(lin_inv_norm, corner_estimator.get_vocab_representation(lin_inv_corner))

# Build the identity directly with the model's dtype/device (same values as eye().to().to()).
lin_inv_identity = torch.eye(model.config.n_embd, dtype=model.dtype, device=model.device)

relation_lin_inv = estimate.RelationOperator(
    model=model,
    tokenizer=tokenizer,
    relation='{} works in the field of',
    layer=15,
    weight=lin_inv_identity,
    bias=lin_inv_corner,
)

In [None]:
# Evaluate the lin-inv-corner operator on the SAME validation set as the first
# evaluation (ret_dict["validation_set"]), so the two precisions are comparable.
# NOTE: this rebinds `precision`; re-running the earlier display cell will show this value.
precision, ret_dict_2 = evaluate(
    # Must be the relation the operator's prompt and `objects` were built for;
    # a hard-coded id here (previously "P17") silently scores the wrong relation.
    relation_id=relation_id,
    relation_operator=relation_lin_inv,
    precision_at=precision_at,
    validation_set=ret_dict["validation_set"],
)

precision