In [1]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import dotenv

# Hugging Face checkpoint whose unembedding matrix is analysed below.
model_name = "google/gemma-2b"
# NOTE(review): g_file_path is not referenced again in the visible cells.
g_file_path = "binary/g_gemma-2b.bin"
# NOTE(review): this is an ASCII underscore, while the helper functions below
# default to "▁" (U+2581); the visible calls never pass this variable.
space_char = "_"

### load model ###
# NOTE(review): `device` is not referenced again in the visible cells; model
# placement is delegated to device_map="auto".
device = torch.device("cuda:0")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name,
                                             torch_dtype=torch.float32,
                                             device_map="auto")

### load unembedding vectors ###
# gamma: output-embedding weight matrix with W rows and d columns.
gamma = model.get_output_embeddings().weight.detach()
W, d = gamma.shape
# Subtract the mean row so the covariance below is taken about the mean.
gamma_bar = torch.mean(gamma, dim = 0)
centered_gamma = gamma - gamma_bar

### compute Cov(gamma) and transform gamma to g ###
Cov_gamma = centered_gamma.T @ centered_gamma / W
# Symmetric eigendecomposition, used to form Cov^{-1/2} and Cov^{1/2}.
eigenvalues, eigenvectors = torch.linalg.eigh(Cov_gamma)
# NOTE(review): no guard against zero or negative eigenvalues before the sqrt/reciprocal.
inv_sqrt_Cov_gamma = eigenvectors @ torch.diag(1/torch.sqrt(eigenvalues)) @ eigenvectors.T
sqrt_Cov_gamma = eigenvectors @ torch.diag(torch.sqrt(eigenvalues)) @ eigenvectors.T
# g: centered unembeddings multiplied by Cov^{-1/2} (whitened rows).
g = centered_gamma @ inv_sqrt_Cov_gamma

import torch
import numpy as np
import json
from transformers import AutoTokenizer
import networkx as nx
import matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection

import warnings
warnings.filterwarnings('ignore')

# Invert the tokenizer's token -> id mapping into a list indexed by id.
vocab_dict = tokenizer.get_vocab() # token: index
# Sized to the largest id; any id missing from vocab_dict keeps the None placeholder.
vocab_list = [None] * (max(vocab_dict.values()) + 1)
for word, index in vocab_dict.items():
    vocab_list[index] = word

import plotly.express as px
import plotly.graph_objects as go

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
data = {
    'joy': ['mirth', 'thrill', 'bliss', 'relief', 'admiration', 'affection', 'serenity', 'inspiration', 'gladness', 'adoration', 'delight', 'love', 'hilarity', 'buoyancy', 'gaiety', 'zeal', 'vibrancy', 'compassion', 'fulfillment', 'exhilaration', 'happiness', 'gratitude', 'pride', 'triumph', 'tenderness', 'zest', 'rapture', 'euphoria', 'glee', 'blissfulness', 'enthusiasm', 'cheerfulness', 'pleasure', 'excitement', 'hopefulness', 'joviality', 'carefree', 'satisfaction', 'elation', 'lightheartedness', 'comfort', 'contentment', 'ecstasy', 'warmth', 'awe', 'tranquility', 'radiance', 'jubilation', 'playfulness', 'optimism', 'wonder', 'vivacity', 'fondness', 'amusement'], 
    'sadness': ['dejection', 'anguish', 'nostalgia', 'melancholy', 'despondency', 'desperation', 'grieving', 'guilt', 'miserableness', 'gloom', 'loneliness', 'discouragement', 'defeatism', 'hollowness', 'listlessness', 'mourning', 'desolation', 'woe', 'pessimism', 'tearfulness', 'apathy', 'homesickness', 'abandonment', 'numbness', 'sulkiness', 'bleakness', 'yearning', 'pining', 'heartbreak', 'resentment', 'alienation', 'grief', 'regret', 'wistfulness', 'disillusionment', 'sorrow', 'lethargy', 'resignation', 'heartache', 'world-weariness', 'disappointment', 'emptiness', 'depression', 'despair', 'hopelessness', 'isolation', 'longing', 'weariness', 'remorse', 'shame', 'forlornness', 'bitterness', 'misery', 'blues'], 
    'anger': ['displeasure', 'spite', 'irritation', 'disdain', 'disgruntlement', 'rage', 'ire', 'wrath', 'competitiveness', 'hostility', 'fury', 'chagrin', 'petulance', 'indignation', 'exasperation', 'malice', 'contempt', 'venom', 'irascibility', 'ferocity', 'scorn', 'enmity', 'antagonism', 'rivalry', 'irksomeness', 'crabbiness', 'jealousy', 'disgust', 'resentment', 'frustration', 'grumpiness', 'vindictiveness', 'vengefulness', 'animosity', 'antipathy', 'violence', 'loathing', 'outrage', 'cantankerousness', 'envy', 'hatred', 'temperament', 'vexation', 'combativeness', 'aggression', 'pique', 'agitation', 'grudge', 'bitterness', 'huffiness', 'belligerence', 'annoyance'], 
    'fear': ['nervousness', 'paranoia', 'discomfort', 'helplessness', 'restlessness', 'aversion', 'alertness', 'alarm', 'uncertainty', 'foreboding', 'fearfulness', 'panic', 'confusion', 'shock', 'worry', 'insecurity', 'disorientation', 'stress', 'trepidation', 'shyness', 'overwhelm', 'dread', 'startlement', 'revulsion', 'fright', 'self-doubt', 'terror', 'squeamishness', 'phobia', 'distress', 'inadequacy', 'reluctance', 'mistrust', 'anxiety', 'wariness', 'timidity', 'indecision', 'angst', 'hesitation', 'unease', 'apprehension', 'tension', 'vulnerability', 'horror', 'jitters', 'agitation', 'presentiment', 'dismay', 'caution', 'bewilderment', 'suspicion'], 
    'surprise': ['enthrallment', 'unexpectedness', 'revitalization', 'inquisitiveness', 'rejuvenation', 'discovery', 'unpredictability', 'stimulation', 'confusion', 'marvel', 'shock', 'disbelief', 'jolting', 'skepticism', 'dubiety', 'epiphany', 'incredulity', 'puzzlement', 'stupor', 'revelation', 'novelty', 'disorientation', 'amazement', 'realization', 'eureka', 'refreshment', 'stunned', 'awe-struck', 'dumbfounded', 'astonied', 'astonishment', 'startlement', 'excitement', 'fascination', 'animation', 'engrossment', 'intrigue', 'awe', 'flabbergasted', 'captivation', 'vivification', 'startled', 'wonderment', 'staggered', 'wonder', 'eye-opening', 'bewilderment', 'unfamiliarity', 'arousal', 'stupefaction', 'perplexity', 'curiosity', 'invigoration'], 
    'disgust': ['detestation', 'displeasure', 'prudishness', 'disdain', 'aversion', 'bothered', 'perturbed', 'sneering', 'pretentiousness', 'sanctimoniousness', 'mockery', 'contempt', 'scorn', 'dislike', 'arrogance', 'fussiness', 'repulsion', 'snobbishness', 'offensiveness', 'nausea', 'derision', 'abhorrence', 'disapprobation', 'sarcasm', 'disapproval', 'condescension', 'queasiness', 'offense', 'pickiness', 'alienation', 'superciliousness', 'revulsion', 'repugnance', 'finickiness', 'squeamishness', 'self-righteousness', 'antipathy', 'fastidiousness', 'repulsiveness', 'loathing', 'rejection', 'sickness', 'outrage', 'haughtiness', 'primness', 'obnoxiousness', 'pomposity', 'ridicule', 'egotism', 'affectation', 'superiority', 'repellence', 'nauseousness', 'vanity', 'distaste', 'cynicism']
    }

# Category names, in the insertion order of the `data` dict.
categories = list(data.keys())

# Print how many lemmas were listed for each category.
for category in categories:
    print(category, len(data[category]))

joy 54
sadness 54
anger 52
fear 51
surprise 53
disgust 56


In [7]:
import inflect
p = inflect.engine()

def noun_to_gemma_vocab_elements(word, vocab_set, space_char: str = "▁"):
    """Return the vocabulary tokens that spell ``word`` in one of four surface forms.

    The forms tried are the lower-cased word, its capitalized version, the
    plural produced by the module-level inflect engine ``p``, and the
    capitalized plural. Each form is prefixed with ``space_char`` before the
    lookup.

    Args:
        word: Lemma to look up; it is lower-cased first.
        vocab_set: Set of vocabulary token strings.
        space_char: Prefix prepended to every candidate form.

    Returns:
        The set of candidate strings that are members of ``vocab_set``.
    """
    singular = word.lower()
    pluralized = p.plural(singular)
    candidates = [
        space_char + form
        for base in (singular, pluralized)
        for form in (base, base.capitalize())
    ]
    return vocab_set.intersection(candidates)

def get_emotion_category(data, categories, vocab_dict, g, space_char: str = "▁"):
    """Collect, per category, the vocabulary tokens matching its lemmas and their rows of ``g``.

    For every category the lemmas in ``data[category]`` are expanded with
    ``noun_to_gemma_vocab_elements`` and the matching tokens are looked up in
    ``vocab_dict``.

    Args:
        data: Mapping from category name to a list of lemma strings.
        categories: Category names to process (keys of ``data``).
        vocab_dict: Mapping from token string to vocabulary index.
        g: Indexable object; ``g[list_of_indices]`` selects the rows for a category.
        space_char: Prefix forwarded to ``noun_to_gemma_vocab_elements``.

    Returns:
        A tuple ``(emotions_token, emotions_ind, emotions_g)`` of dicts keyed by
        category: the matched tokens, their vocabulary indices, and
        ``g`` indexed by those indices. A category with no matched token keeps
        an empty list in all three dicts.
    """
    vocab_set = set(vocab_dict.keys())

    emotions_token = {}
    emotions_ind = {}
    emotions_g = {}

    for category in categories:
        tokens = []
        for lemma in data[category]:
            matches = noun_to_gemma_vocab_elements(lemma, vocab_set, space_char=space_char)
            # Sets of str iterate in an order that can change between interpreter
            # runs (hash randomization); sorting makes the token order reproducible.
            tokens.extend(sorted(matches))

        indices = [vocab_dict[token] for token in tokens]
        emotions_token[category] = tokens
        emotions_ind[category] = indices
        # Index g once per category rather than once per appended token.
        emotions_g[category] = g[indices] if indices else []
    return emotions_token, emotions_ind, emotions_g

In [8]:
# Per-category matched tokens, their vocabulary ids, and the corresponding rows of g.
# space_char is left at the function default ("▁").
emotions_token, emotions_ind, emotions_g = get_emotion_category(data, categories,  vocab_dict, g)

In [9]:
# Sanity check: number of matched 'anger' tokens and the row width of g.
emotions_g['anger'].shape

torch.Size([47, 2048])

In [10]:
from sklearn.covariance import ledoit_wolf

def category_to_indices(category, vocab_dict):
    """Translate each token in ``category`` to its index in ``vocab_dict``.

    Args:
        category: Iterable of token strings.
        vocab_dict: Mapping from token string to vocabulary index.

    Returns:
        A list of indices in the same order as ``category``.

    Raises:
        KeyError: If a token is not a key of ``vocab_dict``.
    """
    indices = []
    for token in category:
        indices.append(vocab_dict[token])
    return indices

def get_words_sim_to_vec(query: torch.Tensor, unembed, vocab_list, k=300):
    """Return the vocabulary entries whose rows of ``unembed`` score highest against ``query``.

    Scores are the products ``unembed @ query``; the entries are returned in
    descending score order.

    Args:
        query: Vector multiplied against ``unembed``.
        unembed: Matrix with one row per vocabulary entry.
        vocab_list: Sequence mapping a row index to its token.
        k: Maximum number of entries to return. It is capped at the number of
            available scores, because ``torch.topk`` raises when ``k`` is larger.

    Returns:
        A list of at most ``k`` items taken from ``vocab_list``.
    """
    scores = unembed @ query
    # topk works along the last dimension, so that is the size k must not exceed.
    k = min(k, scores.shape[-1])
    top_ids = torch.topk(scores, k, largest=True).indices.tolist()
    return [vocab_list[idx] for idx in top_ids]

def estimate_single_dir_from_embeddings(category_embeddings):
    """Estimate an LDA-style direction for one category from its embedding rows.

    The Ledoit-Wolf covariance estimate of the rows is pseudo-inverted and
    applied to the row mean. The result is normalized to unit length and then
    rescaled by the mean's projection onto it.

    Args:
        category_embeddings: Tensor with one embedding per row.

    Returns:
        A tuple ``(lda_dir, category_mean)``: the scaled direction and the
        mean of the rows.
    """
    mean_vec = category_embeddings.mean(dim=0)

    # ledoit_wolf works on NumPy data and returns (covariance, shrinkage).
    shrunk_cov, _shrinkage = ledoit_wolf(category_embeddings.cpu().numpy())
    cov_t = torch.tensor(shrunk_cov, device = category_embeddings.device)

    direction = torch.linalg.pinv(cov_t) @ mean_vec
    unit_dir = direction / torch.norm(direction)

    # Scale the unit direction by the mean's component along it.
    return (mean_vec @ unit_dir) * unit_dir, mean_vec

def estimate_cat_dir(category_lemmas, unembed, vocab_dict):
    """Estimate the direction and mean for a category given by its token strings.

    Args:
        category_lemmas: Token strings; each must be a key of ``vocab_dict``.
        unembed: Matrix indexed by vocabulary index to obtain embedding rows.
        vocab_dict: Mapping from token string to vocabulary index.

    Returns:
        A dict with keys ``'lda'`` and ``'mean'`` holding the two values
        returned by ``estimate_single_dir_from_embeddings``.
    """
    row_ids = category_to_indices(category_lemmas, vocab_dict)
    direction, mean_vec = estimate_single_dir_from_embeddings(unembed[row_ids])
    return {'lda': direction, 'mean': mean_vec}

In [11]:
# One {'lda', 'mean'} estimate per emotion category, computed from its matched tokens' rows of g.
dirs = {k: estimate_cat_dir(v, g, vocab_dict) for k, v in emotions_token.items()}

In [13]:
# Pool the tokens of the base categories into a single 'emotion' category.
# The derived keys ('emotion' here, 'all' from the random-sample cell) are
# skipped, so re-running this cell cannot fold previously derived entries
# back into the pool.
all_emotions_tokens = [
    token
    for name, tokens in emotions_token.items()
    if name not in ('emotion', 'all')
    for token in tokens
]
dirs['emotion'] = estimate_cat_dir(all_emotions_tokens, g, vocab_dict)
emotions_token['emotion'] = all_emotions_tokens

In [14]:
# Print the shape of every estimated 'lda' direction as a quick check.
for k, v in dirs.items():
    print(k, v['lda'].shape)

joy torch.Size([2048])
sadness torch.Size([2048])
anger torch.Size([2048])
fear torch.Size([2048])
surprise torch.Size([2048])
disgust torch.Size([2048])
emotion torch.Size([2048])


In [15]:
# Add the pooled ('emotion') and background ('all') categories to the list.
# Any earlier copies are filtered out first, so re-running this cell does not
# append the two names a second time.
categories = [c for c in categories if c not in ('emotion', 'all')] + ['emotion', 'all']

In [16]:
def get_cat_lemma_dirs(category_lemmas, unembed, vocab_dict):
    """Return the rows of ``unembed`` that belong to the given tokens.

    Args:
        category_lemmas: Token strings; each must be a key of ``vocab_dict``.
        unembed: Matrix indexed by vocabulary index.
        vocab_dict: Mapping from token string to vocabulary index.

    Returns:
        ``unembed`` indexed by the tokens' vocabulary indices, in input order.
    """
    row_ids = category_to_indices(category_lemmas, vocab_dict)
    return unembed[row_ids]

In [26]:
import random
# Sample 10000 tokens from vocab_list as the background 'all' category.
# A dedicated generator with a fixed seed makes the sample identical across
# kernel restarts and leaves the global `random` state untouched.
emotions_token['all'] = random.Random(0).sample(vocab_list, 10000)

In [108]:
# The two categories whose directions span the plot: first -> x axis, second -> y axis.
indices = ["emotion", "sadness"]

In [109]:
# Vocabulary ids of the tokens in each selected category.
# NOTE(review): inds0 is not referenced again in the visible cells.
inds0 = {idx: category_to_indices(emotions_token[idx], vocab_dict) for idx in indices}

# Unpack the 'lda' directions of the selected categories; this requires
# `indices` to contain exactly two names.
dir1, dir2 = (dirs[idx]["lda"] for idx in indices)

In [110]:
# Rows of g for every category's tokens. Each name in `categories` must be a
# key of emotions_token, including 'emotion' and 'all' added by earlier cells.
cat_lemma_dirs = {
    category: get_cat_lemma_dirs(emotions_token[category], g, vocab_dict)
    for category in categories
}

In [111]:
def process_dirs(dir1, dir2):
    """Turn two directions into an orthonormal pair (one Gram-Schmidt step).

    Args:
        dir1: First direction; it is only rescaled to unit norm.
        dir2: Second direction; its component along ``dir1`` is removed
            before it is rescaled to unit norm.

    Returns:
        A tuple ``(unit1, unit2)`` of unit-norm, mutually orthogonal vectors.
    """
    unit1 = dir1 / torch.norm(dir1)
    # Remove the part of dir2 that lies along unit1, then normalize the remainder.
    residual = dir2 - (dir2 @ unit1) * unit1
    unit2 = residual / torch.norm(residual)
    return unit1, unit2

In [112]:
# Replace the two selected directions with their normalized, orthogonalized versions.
dir1, dir2 = process_dirs(dir1, dir2)

In [113]:
# Display the current category list (base categories plus 'emotion' and 'all').
categories

['joy', 'sadness', 'anger', 'fear', 'surprise', 'disgust', 'emotion', 'all']

In [114]:
# give appropriate colors to each category
# Marker color for each category, as hex RGB strings; looked up by the plotting cell.
colors_hex = {
    "joy": "#FFFF00",        # yellow
    "sadness": "#0000FF",    # blue
    "anger": "#FF0000",      # red
    "fear": "#800080",       # purple
    "surprise": "#008000",   # green
    "disgust": "#FFA500",    # orange
    "emotion": "#000000",    # black
    "all": "#808080"         # gray
}

In [115]:
# Categories to draw; the plotting cell adds one trace per entry, in this order.
categories_to_plot = ["emotion", "fear", "all", "anger", "disgust", "joy", "sadness", "surprise"]
# Smaller alternative selection, kept for quick experiments:
# categories_to_plot = ["emotion", "fear", "all"]

In [116]:
# Coordinates of every token along the two orthonormal directions
# (x: projection on dir1, y: projection on dir2), as NumPy arrays per category.
x = {category: (cat_lemma_dirs[category] @ dir1).cpu().numpy() for category in categories_to_plot}
y = {category: (cat_lemma_dirs[category] @ dir2).cpu().numpy() for category in categories_to_plot}

# NOTE(review): these two lines rebind dir1/dir2 from tensors to NumPy arrays,
# so re-running this cell would multiply tensors by arrays. TODO confirm
# whether the NumPy copies are needed, and consider separate names.
dir1 = dir1.cpu().numpy()
dir2 = dir2.cpu().numpy()

In [117]:
lim = 13  # half-width of the square plotting window on both axes

fig = go.Figure()

# One marker trace per category; the random 'all' sample is drawn smaller and fainter.
for category in categories_to_plot:
    is_background = category == "all"
    marker_style = dict(
        size=1 if is_background else 5,
        opacity=0.3 if is_background else 1,
        color=colors_hex[category],
    )
    fig.add_trace(go.Scatter(x=x[category], y=y[category], mode='markers',
                             name=category, marker=marker_style))

# Canvas: white background, black grid lines, fixed 800x800 size.
fig.update_layout(
    plot_bgcolor='white',
    xaxis=dict(showgrid=True, gridcolor='black', gridwidth=0.1),
    yaxis=dict(showgrid=True, gridcolor='black', gridwidth=0.1),
    width=800,
    height=800,
)

# Both axes share tick spacing, zero line and range; only the titles differ.
for update_axis, axis_title in ((fig.update_xaxes, indices[0]),
                                (fig.update_yaxes, indices[1])):
    update_axis(
        dtick=1,
        zeroline=True, zerolinecolor='black', zerolinewidth=3,
        title_text=axis_title,
        range=[-lim, lim],
    )

fig.show()

## Random Data

In [119]:
data = {
    "moon": ["toaster", "penguin", "jelly", "cactus", "submarine", "marble", "accordion", "cheese", "volcano", "pepper", "trombone", "zebra", "bicycle", "napkin", "iguana", "waterfall", "spatula", "pyramid", "cloud", "lantern", "shoebox", "squid", "rubber", "iceberg", "chocolate", "boomerang", "spaghetti", "eraser", "ladle", "tsunami", "windmill", "comb", "raven", "alarm", "cucumber", "lightning", "quartz", "flip-flop", "parachute", "snowman", "bottle", "puzzle"],
    "elephant": ["sandwich", "yo-yo", "plank", "rainbow", "monocle", "snowflake", "ladder", "coconut", "spaceship", "umbrella", "pebble", "keyboard", "straw", "pineapple", "screw", "mountain", "seashell", "whisk", "blender", "radio", "crater", "toothpaste", "zeppelin", "otter", "paintbrush", "sundial", "mirrorball", "bonsai", "handkerchief", "muffin", "telescope", "pothole", "skillet", "salmon", "glue", "jungle", "elevator", "compass", "butter", "drumstick", "satellite"],
    "pillow": ["kiwi", "tornado", "chopstick", "helicopter", "sunflower", "giraffe", "whistle", "bookcase", "crayon", "dragon", "pogo-stick", "barnacle", "treasure", "firefly", "dinosaur", "shoelace", "lava", "violin", "mailbox", "saxophone", "parrot", "paperclip", "tinsel", "whirlpool", "cookie", "speedboat", "broomstick", "potato", "scarecrow", "goblin", "clownfish", "icecream", "fossil", "flagpole", "firework", "basket", "thimble", "beacon", "spoon", "ostrich", "canoe"],
    "donut": ["ocean", "microscope", "tiger", "pasta", "umbrella", "helicopter", "mirror", "popcorn", "dolphin", "soap", "sailboat", "tangerine", "grapefruit", "pencil", "chandelier", "toothbrush", "igloo", "skateboard", "canoe", "lava-lamp", "mango", "suitcase", "spider", "bubble", "sprinkler", "yo-yo", "helmet", "carousel", "sandpaper", "hurricane", "microphone", "pogo-stick", "flamingo", "matchstick", "grape", "shoelace", "hedgehog", "rocket", "geyser", "yo-yo", "quicksand"],
    "fridge": ["banjo", "skyscraper", "avocado", "sphinx", "teacup", "hammock", "kite", "gorilla", "pinecone", "marshmallow", "carpet", "lighthouse", "moose", "cookie", "train", "firetruck", "jigsaw", "tulip", "skeleton", "bobsled", "coconut", "lamppost", "kangaroo", "pear", "raccoon", "toolbox", "bubblegum", "picnic", "pottery", "spoon", "toadstool", "velvet", "cloud", "sunscreen", "bathtub", "apple", "lawnmower", "bat", "harp", "airplane", "pickle"]
}


# Control experiment: repeat the category pipeline on arbitrary word lists.
categories = list(data.keys())

for category in categories:
    print(category, len(data[category]))

# Matched tokens, their vocabulary ids and rows of g for each arbitrary category.
nonsenses_token, nonsenses_ind, nonsenses_g = get_emotion_category(data, categories,  vocab_dict, g)
dirs = {k: estimate_cat_dir(v, g, vocab_dict) for k, v in nonsenses_token.items()}

# Pool every category's tokens into a single 'nonsense' category.
all_nonsenses_tokens = [a for k, v in nonsenses_token.items() for a in v]
dirs['nonsense'] = estimate_cat_dir(all_nonsenses_tokens, g, vocab_dict)
nonsenses_token['nonsense'] = all_nonsenses_tokens
categories = categories + ['nonsense', 'all']

# Background sample for the 'all' category. A dedicated generator with a fixed
# seed makes the sample identical across kernel restarts and leaves the global
# `random` state untouched.
nonsenses_token['all'] = random.Random(0).sample(vocab_list, 10000)

moon 42
elephant 41
pillow 41
donut 41
fridge 41


In [120]:
# The two categories whose directions span the plot: first -> x axis, second -> y axis.
indices = ["nonsense", "moon"]

In [121]:
# Vocabulary ids of the tokens in each selected category.
# NOTE(review): inds0 is not referenced again in the visible cells.
inds0 = {idx: category_to_indices(nonsenses_token[idx], vocab_dict) for idx in indices}

# Unpack the 'lda' directions of the selected categories; this requires
# `indices` to contain exactly two names.
dir1, dir2 = (dirs[idx]["lda"] for idx in indices)

In [122]:
# Rows of g for every category's tokens. Each name in `categories` must be a
# key of nonsenses_token, including the derived 'nonsense' and 'all' entries.
cat_lemma_dirs = {
    category: get_cat_lemma_dirs(nonsenses_token[category], g, vocab_dict)
    for category in categories
}

In [123]:
# Replace the two selected directions with their normalized, orthogonalized versions.
dir1, dir2 = process_dirs(dir1, dir2)

In [128]:
# Marker color for each arbitrary category, as hex RGB strings.
# This rebinds the colors_hex mapping defined for the emotion plot.
colors_hex = {
    "moon": "#FFFF00",        # yellow
    "elephant": "#0000FF",    # blue
    "pillow": "#FF0000",      # red
    "donut": "#800080",       # purple
    "fridge": "#008000",      # green
    "nonsense": "#000000",    # black
    "all": "#808080"          # gray
}

In [124]:
# Categories to draw; the plotting cell adds one trace per entry, in this order.
categories_to_plot = ["nonsense", "moon", "all", "elephant", "pillow", "donut", "fridge"]
# NOTE(review): the commented-out line below lists emotion categories, which are
# not keys of the colors_hex mapping defined for this section.
# categories_to_plot = ["emotion", "fear", "all"]

In [125]:
# Coordinates of every token along the two orthonormal directions
# (x: projection on dir1, y: projection on dir2), as NumPy arrays per category.
x = {category: (cat_lemma_dirs[category] @ dir1).cpu().numpy() for category in categories_to_plot}
y = {category: (cat_lemma_dirs[category] @ dir2).cpu().numpy() for category in categories_to_plot}

# NOTE(review): these two lines rebind dir1/dir2 from tensors to NumPy arrays,
# so re-running this cell would multiply tensors by arrays. TODO confirm
# whether the NumPy copies are needed, and consider separate names.
dir1 = dir1.cpu().numpy()
dir2 = dir2.cpu().numpy()

In [131]:
lim = 8  # half-width of the square plotting window on both axes

fig = go.Figure()

# One marker trace per category; the random 'all' sample is drawn smaller and slightly transparent.
for category in categories_to_plot:
    is_background = category == "all"
    marker_style = dict(
        size=1 if is_background else 5,
        opacity=0.8 if is_background else 1,
        color=colors_hex[category],
    )
    fig.add_trace(go.Scatter(x=x[category], y=y[category], mode='markers',
                             name=category, marker=marker_style))

# Canvas: white background, black grid lines, fixed 800x800 size.
fig.update_layout(
    plot_bgcolor='white',
    xaxis=dict(showgrid=True, gridcolor='black', gridwidth=0.1),
    yaxis=dict(showgrid=True, gridcolor='black', gridwidth=0.1),
    width=800,
    height=800,
)

# Both axes share tick spacing, zero line and range; only the titles differ.
for update_axis, axis_title in ((fig.update_xaxes, indices[0]),
                                (fig.update_yaxes, indices[1])):
    update_axis(
        dtick=1,
        zeroline=True, zerolinecolor='black', zerolinewidth=3,
        title_text=axis_title,
        range=[-lim, lim],
    )

fig.show()