In [86]:
# Imports
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel, GPT2Model
import numpy as np
import itertools

In [87]:
# Initialize Tokenizer and Model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

In [89]:
# Define Embedding Matrix
embeddings = model.transformer.wte.weight
emb_size = embeddings.size(dim=1)

In [90]:
# Define Word Bank
word_bank = ["academy", "advance", "aircraft", "ally", "ammo", "ammunition", "armor", "arms", "army", "arrow", "arsenal", "artillery", "attack", "attention", "ballistic", "barracks", "base", "battalion", "battery", "battle", "battlefield", "bomb", "bombard", "bombardment", "brig", "brigade", "bullet", "camouflage", "camp", "cannon", "captain", "capture", "carrier", "casualty", "catapult", "cavalry", "colonel", "combat", "command", "commander", "commission", "company", "conflict", "conquest", "convoy", "corps", "covert", "crew", "decode", "defeat", "defend", "defense", "destroyer", "division", "draft", "encode", "enemy", "engage", "enlist", "evacuate", "explosive", "fight", "fire", "fleet", "force", "formation", "fort", "front", "garrison", "general", "grenade", "grunt", "guerrilla", "gun", "headquarters", "helmet", "honor", "hospital", "infantry", "injury", "intelligence", "invade", "invasion", "jet", "kill", "leave", "lieutenant", "major", "maneuver", "marines", "MIA", "mid", "military", "mine", "missile", "mortar", "navy", "neutral", "offense", "officer", "ordinance", "parachute", "peace", "plane", "platoon", "private", "radar", "rank", "recruit", "regiment", "rescue", "reserves", "retreat", "ribbon", "sabotage", "sailor", "salute", "section", "sergeant", "service", "shell", "shoot", "shot", "siege", "sniper", "soldier", "spear", "specialist", "squad", "squadron", "staff", "submarine", "surrender", "tactical", "tactics", "tank", "torpedo", "troops", "truce", "uniform", "unit", "veteran", "volley", "war", "warfare", "warrior", "weapon", "win", "wound"]

In [146]:
word_bank

In [149]:
# Get respective tokens 

tokens = tokenizer(word_bank)["input_ids"]
tokens = list(itertools.chain.from_iterable(tokens))
len(tokenizer.tokenize(test))

151

In [121]:
# Get embeddings for word bank
wb_embeddings = embeddings[tokens]
wb_embeddings.size()

torch.Size([268, 768])

In [122]:
# Define Softmax function
# (Taken from class assignment 2)
def softmax(x):
    orig_shape = x.shape

    if len(x.shape) > 1:
        # Matrix
        tmp = np.max(x, axis=1)
        x -= tmp.reshape((x.shape[0], 1))
        x = np.exp(x)
        tmp = np.sum(x, axis=1)
        x /= tmp.reshape((x.shape[0], 1))
    else:
        # Vector
        tmp = np.max(x)
        x -= tmp
        x = np.exp(x)
        tmp = np.sum(x)
        x /= tmp

    assert x.shape == orig_shape
    return x

In [141]:
word = "hi"
token = tokenizer(word)["input_ids"]
emb = embeddings[token[0]]
print(torch.matmul(emb, wb_embeddings[5]).sum())
print(torch.linalg.norm(wb_embeddings[5] - emb, ord=1))

tensor(3.8325, grad_fn=<SumBackward0>)
tensor(106.4792, grad_fn=<LinalgVectorNormBackward0>)


In [104]:
prompt = "In the garden, I usually"

embeddings = model.transformer.wte.weight   # replace this matrix with the matrix we make

# currently only supports cluster size of 1 (no looping/averaging)
cluster = "hate"
cluster_tokens = tokenizer(cluster, return_tensors="pt")['input_ids'][0]
cluster_embedding = embeddings[cluster_tokens]# if we had multiple members, would have to average here

top_k_val = 5  # use top-p instead

inputs = tokenizer(prompt, return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
logits = outputs.logits
#hidden_states = outputs.last_hidden_state
#print(hidden_state)
next_token_scores = logits[:, -1, :].softmax(dim=-1)
# change values in next_token_scores[0] and then torch.sort to get indices
#print(next_token_scores.shape)



sorted_vals, indices = torch.sort(next_token_scores[0])
# generate embedding for each index position
# have 3 vectors, sorted_vals, indices, and embeddings. 
# turn embedding column into distance to cluster column
# merge sorted_vals and embeddings columns (re-weight)
# re-sort indices according to sorted_vals weights
sorted_vals = sorted_vals[-top_k_val:]
indices = indices[-top_k_val:]
#print(indices)
top_embeddings = embeddings[indices]
#print(top_embeddings[0])
#print(cluster_embedding)

dist_score = [torch.linalg.norm(embed-cluster_embedding) for embed in top_embeddings]

hyper_weight = .5

checkpoint = sorted_vals.detach().clone()
for i in range(len(sorted_vals)):
    sorted_vals[i] += (1/dist_score[i])*hyper_weight

sort_indices = torch.argsort(sorted_vals)
final_ranked_indices = [indices[s] for s in sort_indices]
# best result is at the back of indices



#print(sorted_vals)
#print(sort_indices)
print("Original: ")
for idx in range(1, top_k_val+1):
    #print(tokenizer.decode(indices[-idx]))
    print(f'{checkpoint[-idx]:5f} | {tokenizer.decode(indices[-idx]):8s}')

print()
print("After Weighting: ")
s_vals = sorted_vals[sort_indices]
for idx in range(1, len(final_ranked_indices)+1):
    #print(s_vals)
    #print(idx)
    print(f'{s_vals[-idx]:5f} | {tokenizer.decode(final_ranked_indices[-idx]):8s}')
print()
print([tokenizer.decode(word) for word in final_ranked_indices])
print(s_vals)

#next_token = next_token_scores.argmax().unsqueeze(0).unsqueeze(0)
#print(tokenizer.decode(next_token[0]))

Original: 
0.076533 |  have   
0.035735 |  use    
0.035064 |  get    
0.030945 |  take   
0.029046 |  find   

After Weighting: 
0.187508 |  have   
0.145201 |  get    
0.139643 |  take   
0.139212 |  use    
0.138318 |  find   

[' find', ' use', ' take', ' get', ' have']
tensor([0.1383, 0.1392, 0.1396, 0.1452, 0.1875], grad_fn=<IndexBackward0>)


In [81]:
"""
word = "he left"
inputs = tokenizer(word, return_tensors="pt")

gpt2_model = GPT2Model.from_pretrained("gpt2")
outputs = gpt2_model(**inputs)

#word_embedding = outputs.last_hidden_state[:, -1, :]
print(inputs['input_ids'])


#print(embeddings[inputs['input_ids'][0][0]].shape)
#print(word_embedding[0].shape)
embeddings = []
for i in range(len(inputs['input_ids'][0])):
    word_embedding = outputs.last_hidden_state[i][0]
    embeddings.append(word_embedding)

print(embeddings)
    
#print(word_embedding[0])
#print(embeddings[inputs['input_ids'][0][0], :])
print(inputs['input_ids'])
print(tokenizer.decode(inputs['input_ids'][0][0]))
"""

AttributeError: 'list' object has no attribute 'size'