In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import random
import sys
from operator import itemgetter

import numpy as np
import torch
import transformers
from tqdm.auto import tqdm

# Make the parent directory importable so the project's `src` package resolves.
sys.path.append('..')

from baukit import nethook
from src.relations import estimate
from src.relations.corner import CornerEstimator
from src.util import model_utils
from operator import itemgetter

In [None]:
MODEL_NAME = "EleutherAI/gpt-j-6B"  # gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B

# Load model + tokenizer through the project wrapper, in full fp32.
mt = model_utils.ModelAndTokenizer(
    MODEL_NAME,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)
model, tokenizer = mt.model, mt.tokenizer

# Pad with EOS (presumably the tokenizer has no dedicated pad token, as with GPT-2/GPT-J).
tokenizer.pad_token = tokenizer.eos_token

memory_footprint = model.get_memory_footprint()  # presumably in bytes — TODO confirm
print(f"{MODEL_NAME} ==> device: {model.device}, memory: {memory_footprint}")

In [None]:
#################################################
relation_id = "P17"
precision_at = 3
SEED = 42
#################################################

# Seed every RNG the notebook imports. The corner estimators below are presumably
# stochastic (gradient-descent / averaged variants), so this keeps
# Restart & Run All reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Read as UTF-8 explicitly instead of the platform default encoding
# (the dataset presumably contains non-ASCII entity names).
with open("../data/counterfact.json", encoding="utf-8") as f:
    counterfact = json.load(f)

# `requested_rewrite` entries for this relation, then their unique true objects.
rewrites = [
    c["requested_rewrite"]
    for c in counterfact
    if c["requested_rewrite"]["relation_id"] == relation_id
]
# The leading space presumably matches how the object is tokenized mid-sentence.
# sorted() instead of list(set(...)): set iteration order over strings changes
# between interpreter runs (hash randomization), which made `objects` non-reproducible.
objects = sorted({" " + r["target_true"]["str"] for r in rewrites})
print("unique objects: ", len(objects), objects[0:5])

In [None]:
# One shared estimator for all the corner-estimation cells that follow.
corner_estimator = CornerEstimator(model=model, tokenizer=tokenizer)

In [None]:
# Simple corner over the relation's objects (scale_up=70); report its norm and
# its vocabulary-space representation.
simple_corner = corner_estimator.estimate_simple_corner(
    objects,
    scale_up=70,
)
simple_corner_norm = simple_corner.norm().item()
print(simple_corner_norm, corner_estimator.get_vocab_representation(simple_corner))

In [None]:
# Linear-inverse corner targeting a logit of 50; report its norm and
# its vocabulary-space representation.
lin_inv_corner = corner_estimator.estimate_lin_inv_corner(
    objects,
    target_logit_value=50,
)
lin_inv_corner_norm = lin_inv_corner.norm().item()
print(lin_inv_corner_norm, corner_estimator.get_vocab_representation(lin_inv_corner))

In [None]:
# Gradient-descent corner targeting a logit of 50 (verbose shows optimization progress).
grad_dsc_corner = corner_estimator.estimate_corner_with_gradient_descent(
    objects, target_logit_value=50, verbose=True
)
# Also print the vocab representation, like the other corner cells, so all
# estimators can be compared side by side.
print(grad_dsc_corner.norm().item(), corner_estimator.get_vocab_representation(grad_dsc_corner))

In [None]:
# Gradient-descent corner averaged over 5 estimates (target logit 50, quiet);
# report its norm and its vocabulary-space representation.
avg_corner = corner_estimator.estimate_average_corner_with_gradient_descent(
    objects,
    average_on=5,
    target_logit_value=50,
    verbose=False,
)
avg_corner_norm = avg_corner.norm().item()
print(avg_corner_norm, corner_estimator.get_vocab_representation(avg_corner))

In [None]:
# (subject, subject_token_index, expected country) probes for relation P17.
# subject_token_index is forwarded unchanged to the relation operator;
# presumably it selects which subject token is read (-1 = last token).
P17_TEST_CASES = (
    ("The Great Wall", -1, "China"),
    ("Niagara Falls", -2, "Canada"),
    ("Valdemarsvik", -1, "Sweden"),
    ("Kyoto University", -2, "Japan"),
    ("Hattfjelldal", -1, "Norway"),
    ("Ginza", -1, "Japan"),
    ("Sydney Hospital", -2, "Australia"),
    ("Mahalangur Himal", -1, "Nepal"),  # NOTE(review): border range; "China" arguably also correct
    ("Higashikagawa", -1, "Japan"),
    ("Trento", -1, "Italy"),
    ("Taj Mahal", -1, "India"),
    ("Hagia Sophia", -1, "Turkey"),
    ("Colosseum", -1, "Italy"),
    ("Mount Everest", -1, "Nepal"),  # NOTE(review): on the Nepal/China border
    ("Valencia", -1, "Spain"),
    ("Lake Baikal", -1, "Russia"),
    ("Merlion Park", -1, "Singapore"),
    ("Cologne Cathedral", -1, "Germany"),
    ("Buda Castle", -1, "Hungary"),
)


def P17__check_with_test_cases(relation_operator, device=None, top_k=5, test_cases=None):
    """Print the top-k predictions of ``relation_operator`` on P17 (country) probes.

    Args:
        relation_operator: callable invoked as
            ``relation_operator(subject, subject_token_index=..., device=..., return_top_k=...)``,
            e.g. an ``estimate.RelationOperator``.
        device: device forwarded to the operator. Defaults to the notebook's
            global ``model.device``, looked up at call time.
        top_k: number of predictions requested via ``return_top_k``.
        test_cases: iterable of ``(subject, subject_token_index, target)``
            triples; defaults to ``P17_TEST_CASES``.

    Returns:
        None. One line per probe is printed so predictions can be eyeballed
        against the expected country. Returning None also keeps Jupyter from
        echoing a value when this call is the last line of a cell.
    """
    if device is None:
        device = model.device
    if test_cases is None:
        test_cases = P17_TEST_CASES

    for subject, subject_token_index, target in test_cases:
        # Named `predictions` so it does not shadow the notebook-level `objects` list.
        predictions = relation_operator(
            subject,
            subject_token_index=subject_token_index,
            device=device,
            return_top_k=top_k,
        )
        print(f"{subject}, target: {target}   ==>   predicted: {predictions}")

In [None]:
relation = estimate.RelationOperator(
    model=model,
    tokenizer=tokenizer,
    relation='{} is located in the country of',
    layer=15,
    # Identity weight, so presumably the estimated corner (bias) is the only
    # learned component being evaluated — TODO confirm in estimate.RelationOperator.
    # Built directly on the model's device/dtype to avoid a CPU allocation + copy.
    weight=torch.eye(model.config.n_embd, dtype=model.dtype, device=model.device),
    bias=simple_corner,
)
P17__check_with_test_cases(relation)

In [None]:
relation = estimate.RelationOperator(
    model=model,
    tokenizer=tokenizer,
    relation='{} is located in the country of',
    layer=15,
    # Identity weight, so presumably the estimated corner (bias) is the only
    # learned component being evaluated — TODO confirm in estimate.RelationOperator.
    # Built directly on the model's device/dtype to avoid a CPU allocation + copy.
    weight=torch.eye(model.config.n_embd, dtype=model.dtype, device=model.device),
    bias=lin_inv_corner,
)
P17__check_with_test_cases(relation)

In [None]:
relation = estimate.RelationOperator(
    model=model,
    tokenizer=tokenizer,
    relation='{} is located in the country of',
    layer=15,
    # Identity weight, so presumably the estimated corner (bias) is the only
    # learned component being evaluated — TODO confirm in estimate.RelationOperator.
    # Built directly on the model's device/dtype to avoid a CPU allocation + copy.
    weight=torch.eye(model.config.n_embd, dtype=model.dtype, device=model.device),
    bias=grad_dsc_corner,
)
P17__check_with_test_cases(relation)

In [None]:
relation = estimate.RelationOperator(
    model=model,
    tokenizer=tokenizer,
    relation='{} is located in the country of',
    layer=15,
    # Identity weight, so presumably the estimated corner (bias) is the only
    # learned component being evaluated — TODO confirm in estimate.RelationOperator.
    # Built directly on the model's device/dtype to avoid a CPU allocation + copy.
    weight=torch.eye(model.config.n_embd, dtype=model.dtype, device=model.device),
    bias=avg_corner,
)
P17__check_with_test_cases(relation)