In [1]:
# Enable IPython's autoreload extension so edits to the imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all imported modules before each cell is executed.
%autoreload 2

In [2]:
import os
import pickle
import numpy as np
import pandas as pd
from algo.topics import single_count_setup
from algo.utils import index_vocab
from gpt2.shortlist import shortlist_decode

In [5]:
'''
Optional one-time setup: download the GPT-2 checkpoints into ./models.
Uncomment the lines below to run it.
'''
# from gpt2.utils import download_gpt2_files
# for size in ["124M", "355M", "774M", "1558M"]:
#     download_gpt2_files(size, "./models")

In [6]:
# --- GPT-2 checkpoint selection ---
# Location and size tag handed to shortlist_decode and the model loaders below.
model_dir = "./models"
model_size = "1558M"

# --- Wikipedia statistics used for topic extraction ---
# graph_dir is passed to get_topics / shortlist_solve; single_count_path and
# hist_file are passed to single_count_setup.
graph_dir = "/data/jiapeng/wiki/final_2/10"
single_count_path = '/data/jiapeng/wiki/final_2/10_single.pkl'
hist_file = "/data/jiapeng/wiki/histogram.csv"

# Load the vocabulary inside a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
vocab_path = "/data/jiapeng/wiki/vocab.pkl"
with open(vocab_path, 'rb') as vocab_file:
    wiki_vocab = pickle.load(vocab_file)

# Both values are forwarded to single_count_setup; min_freq is also forwarded
# to get_topics / shortlist_solve in later cells.
min_freq = 100
window_size = 10

# Build the word <-> id lookups from the vocabulary keys.
wiki_vocab2id, wiki_id2vocab = index_vocab(list(wiki_vocab.keys()))

# Object used below to test and translate GPT-2 token ids
# (`k in shortlist_decoder`, `shortlist_decoder[k]`).
shortlist_decoder = shortlist_decode(wiki_vocab2id, model_size,
                                     model_dir, extra_vocab=[])

wiki_num_windows, wiki_single_prob = single_count_setup(hist_file,
                                                        single_count_path,
                                                        window_size, min_freq)

In [8]:
'''
Obtain the logit distribution produced by activating one neuron on its own.
This example does it for a single neuron on Hi-C (mode I).
'''

from gpt2.model import gpt2_mod_i
from gpt2.utils import load_encoder_hparams_and_params
from gpt2.shortlist import shortlist_solve, multi_helper

encoder, hparams, params = load_encoder_hparams_and_params(model_size, model_dir)

# Which layer to inject into and which neuron to switch on.
layer_id, neuron_id = 26, 5201

# Size taken from the first axis of block 0's MLP c_proj weight.
MLP = params['blocks'][0]['mlp']['c_proj']['w'].shape[0]
# Identity matrix: row i is the one-hot vector that activates only neuron i.
batches = np.identity(MLP)

neuron_logits = gpt2_mod_i(
    batches[neuron_id],
    wte=params['wte'],
    blocks=params['blocks'],
    ln_f=params['ln_f'],
    n_head=hparams["n_head"],
    inject=layer_id,
)
test = np.array(neuron_logits)

print(test.shape)

(1, 50257)


In [18]:
'''
Shortlisting step: collect the tau lowest and tau highest activations
'''

# np.argsort sorts ascending, so the first tau indices hold the lowest logits
# and the last tau indices hold the highest ones.
arg_matrix = np.argsort(test[0])
tau = 900

# Indices that shortlist_decoder does not contain are dropped, so each pool
# may end up with fewer than tau entries.
neg = {shortlist_decoder[k] for k in arg_matrix[:tau]
       if k in shortlist_decoder}
pos = {shortlist_decoder[k] for k in arg_matrix[-tau:]
       if k in shortlist_decoder}

In [34]:
'''
Project the neg pool onto the Wikipedia graphs, then solve exactly with MMEKC.
The topics produced from the neg pool are the ones shown in the appendix.
'''

from algo.extract import MMEKC
from algo.topics import get_topics

neg_topics = get_topics(neg, graph_dir, wiki_num_windows, min_freq, wiki_single_prob)

results, isets = MMEKC(
    neg_topics,
    thresholds=(0.1, 1),
    size_limits=(10, 24),
    time_limit_s=120,
)

# Each result is a (topic_ids, score) pair, matched positionally with its iset.
for (topic_ids, score), iset in zip(results, isets):
    print(" ".join(wiki_id2vocab[x] for x in iset), f"({len(iset)})")
    print(" ".join(wiki_id2vocab[x] for x in topic_ids), score)
    print()

lutheran oslo danish stockholm viking norse norwegian sweden swedish helsinki copenhagen scandinavian (12)
copenhagen swedish oslo norwegian helsinki stockholm viking danish scandinavian sweden 0.23735375280168527

archbishop pope worcester premiership earthquake leicester norwich cathedral pilgrim bishop canterbury (11)
archbishop bishop pope earthquake pilgrim canterbury norwich worcester cathedral leicester 0.09824852693207832

salvation souls doctrine orthodox resurrection mercy sins conversion sacrament repentance (10)
salvation souls doctrine repentance sins mercy conversion sacrament resurrection orthodox 0.17949054874666126



In [36]:
'''
The same procedure, repeated for the pos pool.
Again, these topics are the ones shown in the appendix.
'''

pos_topics = get_topics(pos, graph_dir, wiki_num_windows, min_freq, wiki_single_prob)

results, isets = MMEKC(
    pos_topics,
    thresholds=(0.1, 1),
    size_limits=(10, 24),
    time_limit_s=120,
)

for result, iset in zip(results, isets):
    topic_ids, score = result
    # First line: the full iset and its size.
    iset_words = " ".join(wiki_id2vocab[x] for x in iset)
    print(iset_words, f"({len(iset)})")
    # Second line: the selected topic words and their score.
    topic_words = " ".join(wiki_id2vocab[x] for x in topic_ids)
    print(topic_words, score)
    print()

wednesday saturday tuesday monday month morning february good daily day night thursday sunday tonight friday (15)
sunday morning monday wednesday night friday saturday thursday day tuesday 0.33148646093748896

correctly crowd drive got decided pass flip bat throw throws tie beat tossed bet toss (15)
bat throws drive tossed got crowd pass tie toss throw 0.10929527706674895

modules boot array client command install installed tracking control smart function functions module (13)
module modules client smart control functions install tracking installed array 0.11403861688252051

don let guys exactly going tell feel ought understand need answer getting folks better come pretty like certainly people sort feeling fun little know (24)
let pretty tell going like fun don feel know guys 0.17885697409360674



In [47]:
'''
We compare against a random pool to check that what we got was meaningful.
A single example does not reflect the entire random distribution, but
we can observe that very good topics can still be found.
Setting a very large $tau$ increases the advantage of the random baseline.
'''
import random

# Fixed seed so the random baseline is reproducible across kernel restarts;
# a dedicated Random instance avoids touching the global RNG state.
rand_seed = 0
rng = random.Random(rand_seed)

# 50257 matches the width of the logit vector printed earlier, (1, 50257).
idx = list(range(50257))
rng.shuffle(idx)

# Same pool size (tau) and same decoder filter as the neg/pos pools.
rand_pool = {shortlist_decoder[k] for k in idx[-tau:] if k in shortlist_decoder}
rand_topics = get_topics(rand_pool, graph_dir, wiki_num_windows, min_freq, wiki_single_prob)
results, isets = MMEKC(rand_topics, thresholds=(0.1,1), size_limits=(10,24), time_limit_s=120)
for result, iset in zip(results,isets):
    topic_ids, score = result
    print(" ".join([wiki_id2vocab[x] for x in iset]), f"({len(iset)})")
    print(" ".join([wiki_id2vocab[x] for x in topic_ids]), score)
    print()

diverse diversity reef dolphins prey eggs shrimp food rays fishes (10)
diversity fishes eggs food reef rays dolphins shrimp diverse prey 0.08609391579800717

miz tara hogan team lost win html url demolition successfully hardy tag (12)
hogan win miz lost tag team demolition tara hardy successfully 0.03405867872378255

target undergoing fatigue administered stem respond aids protocols proven complications medical node physicians healthy treatments chemotherapy (16)
undergoing administered complications aids medical healthy treatments physicians chemotherapy respond 0.12153163523996449

fusion tap artery rehabilitation activation wound fiber medial ankle function spinal (11)
medial fusion function fiber rehabilitation ankle activation spinal wound artery -0.011409897917609442

fuzzy county dots crush dutch narrower alert shirt fruits orange (10)
shirt narrower county fruits dutch alert fuzzy crush orange dots -0.18895207988312487



In [52]:
"""
The code below repeats the procedure above on a larger scale.
"""
from gpt2.model import gpt2_crawl

# NOTE: this rebinds model_dir / model_size, so every cell executed after
# this one sees the "124M" checkpoint instead of "1558M".
model_dir, model_size = "./models", "124M"

layer_outputs = gpt2_crawl(
    model_dir,
    model_size,
    3,
    layers=[0, 1],
    batch_size=32,
)

# layer_outputs is indexed by layer id below; show its size, type and one shape.
print(len(layer_outputs), type(layer_outputs))
print(layer_outputs[0].shape)

100%|██████████| 24/24 [00:25<00:00,  1.07s/it]
100%|██████████| 24/24 [00:23<00:00,  1.02it/s]

2 <class 'dict'>
(768, 50257)





In [59]:
'''
shortlist_solve uses Pools to run several MMEKC calls
passing in names is optional
'''
from gpt2.shortlist import shortlist_solve

# First 8 rows of the layer-0 output; note this rebinds `test`.
test = layer_outputs[0][:8, :]

shortlist_solve(
    test,
    dest_dir='./outputs',
    tau=900,
    thresholds=(0.1, 1),
    size_limits=(10, 24),
    graph_dir=graph_dir,
    num_windows=wiki_num_windows,
    min_freq=min_freq,
    single_prob=wiki_single_prob,
    shortlist_decoder=shortlist_decoder,
    verbose=1,
)

processing: 8
STOP : ./outputs/1 : 15s
STOP : ./outputs/2 : 16s
STOP : ./outputs/0 : 25s
STOP : ./outputs/3 : 28s
STOP : ./outputs/4 : 20s
STOP : ./outputs/7 : 10s
STOP : ./outputs/5 : 23s
STOP : ./outputs/6 : 15s


In [81]:
from algo.utils import read_topics, read_isets

# Read back the CSV files under ./outputs for examples 0..7 (the 8 rows
# passed to shortlist_solve) and print the NEG pool followed by the POS pool.
for i in range(8):
    print('Example', i)
    pos_isets = read_isets(f"./outputs/{i}_pos_isets.csv", wiki_id2vocab)
    neg_isets = read_isets(f"./outputs/{i}_neg_isets.csv", wiki_id2vocab)
    pos_topics = read_topics(f"./outputs/{i}_pos_topics.csv", wiki_id2vocab)
    neg_topics = read_topics(f"./outputs/{i}_neg_topics.csv", wiki_id2vocab)

    # One shared printing loop for both pools instead of two copy-pasted
    # branches; the emitted text is unchanged.
    pools = (('NEG', neg_isets, neg_topics),
             ('POS', pos_isets, pos_topics))
    for label, pool_isets, pool_topics in pools:
        print(f'{label} ({len(pool_isets)})')
        if len(pool_isets) == 0:
            print(' EMPTY')
        # Each topics entry is indexed as (words, score).
        for iset, topics in zip(pool_isets, pool_topics):
            print(" ".join(iset), f"({len(iset)})")
            print(' ', " ".join(topics[0]), topics[1])
            print()
    print('\n','='*20)
    

Example 0
NEG (0)
 EMPTY
POS (21)
high children science lunch sports girls secondary curriculum closed prep grade college program education math law school (17)
  math high girls secondary curriculum grade college school education prep 0.1940944472231845

items cut pair price kick sharp thread hand tool piece running tail scissors (13)
  thread sharp tool tail hand piece running cut scissors pair 0.0893952992701257

childhood mother girl family daughter young sex sexual laptop inquiry life birth child (13)
  daughter young sex child mother childhood girl life family birth 0.1536742521138479

talent miss reality charm truth love art earth natural nature products shop perfect innocence sleeping beauty (16)
  miss charm perfect reality nature love innocence beauty truth talent 0.1281501138763905

sorted poll house ign built stock table list sale listed (10)
  list listed poll stock table ign built sorted sale house -0.0084403733247289

drive code adventure file default land inside skip sl

In [86]:
'''
Count activations in the model for ZSTM.

To narrow the vocabulary further, pass extra_vocab:
    shortlist_decoder = shortlist_decode(wiki_vocab2id, model_size, 
                                              model_dir, extra_vocab=[])

NPMI values from another corpus might not guarantee a score that correlates
with human judgment. In most cases, a subset of common vocabulary is enough
for estimation.
'''
from gpt2.model import gpt2_get_votes

# Three short example sentences to run through the model.
corpus = [
    'The flying fox jumps over the lazy dog',
    'Buffalo buffalo Buffalo buffalo buffalo buffalo Buffalo buffalo',
    'Pack my box with five dozen liquor jugs',
]

top_activated = gpt2_get_votes(model_dir, model_size, corpus, votes=5)

100%|██████████| 3/3 [00:00<00:00,  4.54it/s]


In [92]:
# Inspect the entry at index 1 of top_activated: its type, shape, and values.
votes_entry = top_activated[1]
print(type(votes_entry), votes_entry.shape)
votes_entry

<class 'numpy.ndarray'> (12, 9, 5)


array([[[300, 670, 326, 447, 138],
        [101, 288, 326, 266, 447],
        [ 99, 326, 288, 266, 447],
        [101, 326, 288, 266, 447],
        [ 99, 326, 288, 266, 447],
        [ 99, 326, 288, 266, 447],
        [ 99, 326, 288, 266, 447],
        [608, 326, 288, 266, 447],
        [ 99, 326, 288, 266, 447]],

       [[300, 670, 326, 447, 138],
        [566, 288, 326, 266, 447],
        [480, 326, 288, 266, 447],
        [480, 326, 288, 266, 447],
        [480, 326, 288, 266, 447],
        [480, 326, 288, 266, 447],
        [480, 326, 288, 266, 447],
        [480, 288, 326, 447, 266],
        [480, 326, 288, 447, 266]],

       [[300, 670, 326, 138, 447],
        [101, 288, 326, 266, 447],
        [480, 326, 288, 266, 447],
        [480, 326, 288, 266, 447],
        [480, 326, 288, 266, 447],
        [480, 326, 288, 266, 447],
        [480, 326, 288, 266, 447],
        [480, 326, 288, 266, 447],
        [480, 326, 288, 266, 447]],

       [[300, 326, 670, 138, 447],
        [566, 