# Setup

In [8]:
!git clone https://github.com/jmerullo/lm_vector_arithmetic.git

Cloning into 'lm_vector_arithmetic'...
remote: Enumerating objects: 35, done.[K
remote: Counting objects: 100% (35/35), done.[K
remote: Compressing objects: 100% (30/30), done.[K
remote: Total 35 (delta 10), reused 21 (delta 4), pack-reused 0[K
Receiving objects: 100% (35/35), 236.07 KiB | 3.93 MiB/s, done.
Resolving deltas: 100% (10/10), done.


In [9]:
cd lm_vector_arithmetic

/content/lm_vector_arithmetic


In [None]:
%%capture
!pip install -r requirements.txt

In [5]:
# !pip install transformers

# import torch
# import transformers
# from transformers import AutoTokenizer, AutoModelForCausalLM
# import matplotlib.pyplot as plt

# def load_gpt2(version):
#     device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
#     tokenizer = AutoTokenizer.from_pretrained(version)
#     model = AutoModelForCausalLM.from_pretrained(version, torch_dtype=torch.float16).to(device)
#     return model, tokenizer

In [10]:
from modeling import *  #GPT2Wrapper, ModelWrapper

In [11]:
# Load the 'gpt2' checkpoint via `load_gpt2` (presumably from `modeling`), cast its
# floating-point parameters to float32 with `.float()`, and wrap model + tokenizer
# in the global `wrapper` that the helper functions below rely on.
model, tokenizer = load_gpt2('gpt2')
model = model.float()
wrapper = GPT2Wrapper(model, tokenizer)

# Functions

In [12]:
def tokenize(text):
    """Tokenize ``text`` with the global ``wrapper``.

    Returns:
        ``(inp_ids, str_toks)``: the ids from ``wrapper.tokenize`` and the result
        of ``wrapper.list_decode`` on their first entry (presumably the single
        batch row — TODO confirm).
    """
    ids = wrapper.tokenize(text)
    return ids, wrapper.list_decode(ids[0])

def get_tokIDS(test_text):
    """Tokenize ``test_text``, print each layer's top tokens, and return the ids.

    Printing goes through ``wrapper.print_top`` on the output of
    ``wrapper.get_layers``; the first entry is skipped because it is the
    embedding layer, which just reproduces the input.
    """
    ids, _str_toks = tokenize(test_text)
    per_layer = wrapper.get_layers(ids)
    # Index 0 is the embedding layer: decoding it only echoes the input tokens.
    wrapper.print_top(per_layer[1:])
    return ids

In [13]:
def create_layer_list(X, Y):
    """Return row labels ``"Layer X"`` through ``"Layer Y"``, both ends inclusive.

    An empty list is returned when ``Y < X``.
    """
    return [f"Layer {n}" for n in range(X, Y + 1)]

from tabulate import tabulate
def create_table(list1, list2):
    """Render two parallel columns as a ``tabulate`` grid-format string.

    Rows pair ``list1`` and ``list2`` element-wise; as with ``zip``, extra
    items in the longer list are dropped. Column headers are fixed to
    ``Layer`` and ``Orig Top Token``.
    """
    column_names = ["Layer", "Orig Top Token"]
    return tabulate(zip(list1, list2), column_names, tablefmt="grid")

import torch.nn.functional as F
def get_decoded(logits, k=10, tokenizer=None):
    """Decode the ``k`` highest-scoring token ids of each layer into a string.

    Args:
        logits: iterable of per-layer logit tensors, each ranked along its last
            dimension. The caller in this notebook passes a stacked tensor whose
            rows are assumed to be 1-D vocabulary logits — TODO confirm.
        k: number of top-ranked tokens kept per layer.
        tokenizer: object with a ``decode`` method. Defaults to
            ``wrapper.tokenizer`` so existing calls behave as before.

    Returns:
        A list with one decoded string per layer, tokens ordered from highest
        to lowest logit.
    """
    if tokenizer is None:
        tokenizer = wrapper.tokenizer
    # Softmax is monotonic, so ranking raw logits yields the same order without
    # normalizing over the whole vocabulary for every layer.
    return [tokenizer.decode(layer.argsort(descending=True)[:k]) for layer in logits]

def format_to_table(tok_ids):
    """Print a grid table of each layer's top-10 decoded tokens for ``tok_ids``.

    Runs ``wrapper.model`` once with ``output_hidden_states=True`` and drops the
    first hidden state (the embedding output). It then decodes the remaining
    states with ``wrapper.layer_decode`` and ``get_decoded``, and prints the
    result via ``create_table``. Returns ``None``; the table is only printed.
    """
    # Local import: `torch` is not imported explicitly in the notebook's other cells.
    import torch

    # Inference only: no autograd graph is needed for the forward pass or decoding.
    with torch.no_grad():
        out = wrapper.model(input_ids=tok_ids, output_hidden_states=True)
        hidden_states = list(out.hidden_states)[1:]
        orig_layer_logits = wrapper.layer_decode(hidden_states)
    orig_layer_logits = torch.stack(orig_layer_logits).squeeze(-1)

    orig_decoded = get_decoded(orig_layer_logits)
    # Label exactly as many rows as there are layers. The previous hard-coded
    # 0..23 relied on zip() truncation and would cut off models deeper than 24 layers.
    layer_labels = create_layer_list(0, len(orig_decoded) - 1)
    print(create_table(layer_labels, orig_decoded))

# Test Template: digit sequence `1 2 3 4`

In [14]:
# Digit-sequence prompt; get_tokIDS prints each layer's top tokens and returns the ids.
test_text = "1 2 3 4"
test_ids = get_tokIDS(test_text)

0 thteenthx 4 3teen3454 239
1 thteenth54x3934teen384174
2 thx3454GHz388639teen74
3 3454thGHz86ts743979x
4 tsms34thths54 -iel86 4
5 ts 4 - 3 5ths 6 +ms34
6  4★ 3 5 6 2 Tycoon >>> +ts
7  4 5 6 3 0 2 Tycoon 1ths 8
8  4 5 3 6 1 2 8 7★ 9
9  5 4 6 3 75 1 9 0★
10  5 6 4 3 7 1 05 9/
11  5 4 1 6 3 0/ 7
5


In [15]:
# Print the per-layer top-10 decoded tokens for `test_ids` as a grid table.
format_to_table(test_ids)

+----------+---------------------------+
| Layer    | Orig Top Token            |
| Layer 0  | thteenthx 4 3teen3454 239 |
+----------+---------------------------+
| Layer 1  | thteenth54x3934teen384174 |
+----------+---------------------------+
| Layer 2  | thx3454GHz388639teen74    |
+----------+---------------------------+
| Layer 3  | 3454thGHz86ts743979x      |
+----------+---------------------------+
| Layer 4  | tsms34thths54 -iel86 4    |
+----------+---------------------------+
| Layer 5  | ts 4 - 3 5ths 6 +ms34     |
+----------+---------------------------+
| Layer 6  | 4★ 3 5 6 2 Tycoon >>> +ts |
+----------+---------------------------+
| Layer 7  | 4 5 6 3 0 2 Tycoon 1ths 8 |
+----------+---------------------------+
| Layer 8  | 4 5 3 6 1 2 8 7★ 9        |
+----------+---------------------------+
| Layer 9  | 5 4 6 3 75 1 9 0★         |
+----------+---------------------------+
| Layer 10 | 5 6 4 3 7 1 05 9/         |
+----------+---------------------------+
| Layer 11 | 5 4

# Test Template: month sequence `January February March April`

In [16]:
# Month-sequence prompt; get_tokIDS prints each layer's top tokens and returns the ids.
test_text = "January February March April"
test_ids = get_tokIDS(test_text)

0  2015 May 2014 April June Aug 2013 July 2017 2016
1  2015 2014 2017 2013 2016 29 2018 27 May 24
2  2015 2014 29 2017 28 27 30 2013 2018 2016
3  29 2015 27 28 30 24 2014 25 26 23
4  24 29 2014 2015 2018 2017 27 30 28 26
5  29 24 30 28 27 4 3 26 2014 6
6  29 24 30 2014 28 26 2018 31 27 2015
7  2014 29 April 24 2015 2018 2019 2013 4 28
8  April 2014 May March 2015 June September 2017 October Apr
9  May April June 2015 5 July 2005 September March November
10  May June July AprilMay November September 5 March October
11  May June July 5
 April - MarchMay November


In [17]:
# Print the per-layer top-10 decoded tokens for `test_ids` as a grid table.
format_to_table(test_ids)

+----------+-----------------------------------------------------------+
| Layer    | Orig Top Token                                            |
| Layer 0  | 2015 May 2014 April June Aug 2013 July 2017 2016          |
+----------+-----------------------------------------------------------+
| Layer 1  | 2015 2014 2017 2013 2016 29 2018 27 May 24                |
+----------+-----------------------------------------------------------+
| Layer 2  | 2015 2014 29 2017 28 27 30 2013 2018 2016                 |
+----------+-----------------------------------------------------------+
| Layer 3  | 29 2015 27 28 30 24 2014 25 26 23                         |
+----------+-----------------------------------------------------------+
| Layer 4  | 24 29 2014 2015 2018 2017 27 30 28 26                     |
+----------+-----------------------------------------------------------+
| Layer 5  | 29 24 30 28 27 4 3 26 2014 6                              |
+----------+---------------------------------------