In [None]:
import torch
import numpy as np
import json
from transformers import AutoTokenizer
import networkx as nx
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
import hierarchical as hrc
import einops

# Use the first CUDA GPU when one is present; otherwise fall back to the CPU so
# the notebook still runs (slowly) on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "google/gemma-2b"
# `g` is indexed by vocabulary id in get_logit_diff (g[token_ids]); the second
# return value is discarded and `sqrt_Cov_gamma` is not used by the cells
# visible here.
g, _, sqrt_Cov_gamma = hrc.get_g(MODEL_NAME, device)
# `vocab_dict` is used below to look up the id of a token string
# (vocab_dict[token]); `vocab_list` is not used by the cells visible here.
vocab_dict, vocab_list = hrc.get_vocab(MODEL_NAME)

In [2]:
# `cats` is indexed by category name and yields the tokens of that category
# (see get_logit_diff); `G` and `sorted_keys` are not used by the cells
# visible here.
cats, G, sorted_keys = hrc.get_categories('noun', 'gemma')
# `vec_reps` is read as vec_reps['original']['non_split'][category]['lda']
# in show_logit_diff.
# NOTE(review): 'FILE_PATH' looks like a placeholder -- point it at the saved
# representations file before running. torch.load may unpickle arbitrary
# objects, so only load files from a trusted source.
vec_reps = torch.load('FILE_PATH')

In [12]:
def get_logit_diff(features, vector):
    """Project pairwise unembedding differences between two categories onto ``vector``.

    For every token ``b`` of ``cats[features[1]]`` and every token ``a`` of
    ``cats[features[0]]`` the result contains ``(g[b] - g[a]) @ vector``.

    Relies on the module-level ``cats``, ``vocab_dict`` and ``g``.

    Args:
        features: Sequence of category names (keys of ``cats``). The
            difference is taken as ``features[1]`` minus ``features[0]``.
        vector: Direction to project onto; its leading dimension must match
            the width of the rows of ``g`` for the matrix product to work.

    Returns:
        A flattened tensor with one entry per (features[1] token,
        features[0] token) pair; the features[1] token index varies slowest.

    Raises:
        KeyError: If a category is missing from ``cats`` or one of its tokens
            is missing from ``vocab_dict``.
    """
    # Project each category's unembedding rows first. By linearity
    # (g[b] - g[a]) @ v == g[b] @ v - g[a] @ v, so the (n1, n0, d) tensor of
    # pairwise differences never has to be materialised.
    projected = {
        cat: g[[vocab_dict[tok] for tok in cats[cat]]] @ vector
        for cat in features
    }

    # Broadcast to every (features[1] token, features[0] token) pair.
    logits = projected[features[1]].unsqueeze(1) - projected[features[0]].unsqueeze(0)
    return logits.flatten()


def show_logit_diff(parents, children):
    """Print mean and standard deviation of logit differences along the parent contrast.

    The direction is the unit-normalised difference between the ``'lda'``
    vectors of ``parents[1]`` and ``parents[0]``, read from the module-level
    ``vec_reps['original']['non_split']``. ``get_logit_diff`` is then evaluated
    along that direction for the ``children`` pair and for the ``parents``
    pair, and one summary line is printed for each.

    Args:
        parents: Pair of category names defining the direction.
        children: Pair of category names evaluated along that direction.

    Returns:
        None. The two summary lines are written to stdout.
    """
    reps = vec_reps['original']['non_split']
    upper, lower = reps[parents[1]]['lda'], reps[parents[0]]['lda']

    # Unit-normalise so both summaries are measured on the same scale.
    contrast = upper - lower
    direction = contrast / contrast.norm()

    child_diff = get_logit_diff(children, direction)
    parent_diff = get_logit_diff(parents, direction)

    print(f"child: {child_diff.mean().item():.4f} \\pm {child_diff.std().item():.4f}")
    print(f"parent: {parent_diff.mean().item():.4f} \\pm {parent_diff.std().item():.4f}")

In [13]:
# Parent contrast: plant vs. animal. Child contrast: mammal vs. reptile
# (presumably both subcategories of animal.n.01 -- confirm against the hierarchy).
parents = ['plant.n.02', 'animal.n.01']
children = ['mammal.n.01', 'reptile.n.01']

show_logit_diff(parents, children)

child: -0.0600 \pm 1.2190
parent: 5.1265 \pm 1.1731


In [14]:
# Parent contrast: fluid vs. solid. Child contrast: crystal vs. food
# (presumably both subcategories of solid.n.01 -- confirm against the hierarchy).
parents = ['fluid.n.02', 'solid.n.01']
children = ['crystal.n.01', 'food.n.02']

show_logit_diff(parents, children)

child: 0.3770 \pm 1.5410
parent: 9.8296 \pm 1.1099


In [15]:
# Parent contrast: scientist vs. contestant. Child contrast: athlete vs. player
# (presumably both subcategories of contestant.n.01 -- confirm against the hierarchy).
parents = ['scientist.n.01', 'contestant.n.01']
children = ['athlete.n.01', 'player.n.01']

show_logit_diff(parents, children)

child: -0.1545 \pm 1.1426
parent: 14.4222 \pm 0.9458
