# Install the prerequisites

In [None]:
!sudo apt install zstd -y
!pip install transformers datasets 'torch>=1.9.1'

# Load the dataset

In [None]:
from datasets import list_datasets, load_dataset
# Result of `list_datasets()`.
# NOTE(review): not referenced by any later cell shown here - confirm it is still needed.
datasets_list = list_datasets()

# Dataset path and configuration name handed to `load_dataset`.
src_path, src_name = 'glue', 'mrpc'
# Label stored in every record's 'source' field, e.g. "glue-mrpc".
src = f"{src_path}-{src_name}"

# Only the 'train' split is requested.
dataset_glue = load_dataset(src_path, src_name, split='train')

# Flatten every sentence pair into independent single-sentence records.
# 'index' is "<pair number>.s1" / "<pair number>.s2", so the two halves of a
# pair stay identifiable. `enumerate` supplies the pair number; the previous
# hand-rolled counter was never incremented, so every record was tagged "0.*".
dataset = []
for idx, record in enumerate(dataset_glue):
    if 'sentence1' in record:
        dataset.append({'text': record['sentence1'], 'source': src, 'index': "{}.s1".format(idx)})
    if 'sentence2' in record:
        dataset.append({'text': record['sentence2'], 'source': src, 'index': "{}.s2".format(idx)})

dataset_name = src
print('Dataset has', len(dataset), 'examples.')

# Load the model

In [None]:
import torch

from transformers import AutoTokenizer, GPTJForCausalLM

# Split the transformer blocks evenly across the available GPUs.
# NOTE(review): 28 is assumed to equal model.config.n_layer for GPT-J-6B - confirm.
layers = list(range(28))
device_ids = list(range(2))

# The per-device slices below would silently drop trailing layers unless the
# layer count is an exact multiple of the device count, so verify that first.
# (The previous check compared len(device_ids) with itself and could never fail.)
if len(layers) % len(device_ids) != 0:
    raise ValueError(
        "{} layers cannot be split evenly across {} devices".format(
            len(layers), len(device_ids)))
num_layers_per_device = len(layers) // len(device_ids)

# device id -> list of layer indices hosted on that device. The slice position
# comes from enumerate so device ids need not be 0..n-1.
device_map = {
    device_id: layers[pos * num_layers_per_device:(pos + 1) * num_layers_per_device]
    for pos, device_id in enumerate(device_ids)
}

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = GPTJForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", low_cpu_mem_usage=True)
model.parallelize(device_map)

print("Model Loaded..!")

# Implementation

In [None]:
# register hooks

import torch
import gc
from tqdm import tqdm
from functools import partial


# Hooks registered by an earlier run of this cell are still attached to the
# model; detach them before registering a fresh set.
if 'handles' in globals():
    print('Removing existing hooks...')
    for handle in handles:
        handle.remove()

# Whole transformer blocks: their outputs feed the logit lens.
module_names_to_track_for_logits = [
    f'transformer.h.{i}' for i in range(model.config.n_layer)
]
# The `mlp.fc_out` sub-module of each of those blocks: its input/output are the
# values recorded as neuron activations.
module_names_to_track_for_activations = [
    f'{block_name}.mlp.fc_out' for block_name in module_names_to_track_for_logits
]

# Latest captured tensors per hooked module, keyed by module name.
hidden_states = {}

def save_hidden_states(name, module, input, output):
    """Forward hook: remember what went into and came out of a module.

    Meant to be registered via ``partial(save_hidden_states, name)`` so that
    ``name`` identifies the module. The first positional input and the raw
    output are stored under ``hidden_states[name]``, replacing any earlier
    capture for the same name.
    """
    captured = {
        # the input is what the activation extraction reads
        'input': input[0],
        # the output is what the logit lens reads
        'output': output,
    }
    hidden_states[name] = captured

def clear_hidden_states():
    """Drop every captured tensor from the module-level ``hidden_states`` dict.

    The dict is emptied in place. Assigning a fresh dict inside the function
    would only create a local variable and leave the shared store (and the
    tensors it references) untouched.
    """
    hidden_states.clear()

# Attach the capture hook to every tracked module. The returned handles are
# kept so that a re-run of this cell can detach them again.
handles = []
cnt = 0  # number of modules that received a hook
for name, m in model.named_modules():
    if not (name in module_names_to_track_for_activations
            or name in module_names_to_track_for_logits):
        continue
    cnt += 1
    # bind the module name now so the hook knows where to file its capture
    handle = m.register_forward_hook(partial(save_hidden_states, name))
    handles.append(handle)

In [None]:
# some settings

# maximum number of tokens kept per record; longer inputs are truncated to this
num_tokens = 10000
# number of top predictions kept per position by the logit lens (passed as `k`)
top_k = 5
# threshold handed to extract_neuron_values for the sparse (non-dense) form
activation_threshold = 3
# passed as `dense` to extract_neuron_values: True returns every activation
dense_activations = False

# stop after this many records in the dataset have been processed
dataset_limit = 10

# how many bytes to pickle into a single file before compressing it and starting a new file
file_size_goal = 128 * 1024 * 1024

# device that inputs are placed on and captured tensors are moved to before processing
main_device = torch.device('cuda:0')

In [None]:
!rm -rf output
!mkdir output

In [None]:
from datetime import datetime
import os
import pickle
import subprocess
import time
import torch.nn.functional as F

class batch_pickler:
    """Pickle a stream of records into numbered, size-capped batch files.

    Records are appended to ``"<file_name_base>.<n>"``. When a batch is closed
    it is handed to ``save_file_func`` and the local file is then deleted.
    A batch is closed by the first ``dump`` that finds it larger than
    ``file_bytes_goal`` bytes, or by an explicit ``close``.
    """

    # Class-level defaults; instances overwrite them as they go.
    file_name_base = None
    file_bytes_goal = 0
    save_file_func = None

    # suffix of the next batch file to be created
    next_file_idx = 0

    # currently open batch (file object and its path), or None between batches
    file = None
    cur_file_name = None

    def __init__(self, file_name_base, file_bytes_goal, save_file_func):
        """Store the path prefix, the per-file size goal and the callback
        invoked with the path of every finished batch file."""
        self.file_name_base = file_name_base
        self.file_bytes_goal = file_bytes_goal
        self.save_file_func = save_file_func

    def dump(self, record):
        """Pickle ``record`` into the current batch, rolling over first if the
        batch has already grown past the size goal."""
        batch_is_full = (
            self.file is not None
            and self.file.tell() > self.file_bytes_goal
        )
        if batch_is_full:
            self.close()

        # no batch open (first call, or we just rolled over): start the next one
        if self.file is None:
            self.cur_file_name = f"{self.file_name_base}.{self.next_file_idx}"
            self.file = open(self.cur_file_name, 'wb')
            self.next_file_idx += 1

        pickle.dump(record, self.file)

    def close(self):
        """Finish the open batch, if any: close it, pass it to
        ``save_file_func`` and delete the local file. Does nothing when no
        batch is open."""
        if self.file is None:
            return

        self.file.close()
        self.file = None

        # hand the finished file to the callback before deleting it
        self.save_file_func(self.cur_file_name)
        os.remove(self.cur_file_name)

        # mark the batch as finished
        self.cur_file_name = None

def extract_neuron_values(key, threshold, dense=True):
    """
    Collect the captured MLP values of every tracked layer for the last forward pass.

    key: which side of the hooked module to read, 'input' or 'output'
    threshold: a neuron is kept in the sparse form if its value exceeds this
        at any position of the sequence
    dense: if True, return all values as a nested list [n_layers, seq, features]

    With dense=False the output is a list of dicts resembling individual neurons
    with fields:
        l: the layer of the neuron
        f: the index of the neuron in the feature dimension
        a: a list of activations equal to the length of the sequence

    Entries are ordered by (layer, feature).

    NOTE(review): an earlier description said neurons that fire only on the
    first token are skipped as noise; no such filter exists in the code -
    confirm whether it is wanted.
    """

    # [n_layers, seq, features]; the [0] picks the first batch element
    neurons = torch.stack([
        hidden_states[name][key][0].to(main_device)
        for name in module_names_to_track_for_activations
    ])

    if dense:
        return neurons.tolist()

    # A neuron qualifies if it exceeds the threshold anywhere along the sequence.
    fired = (neurons > threshold).any(dim=1)  # [n_layers, features]
    pairs = fired.nonzero()  # one (layer, feature) row per qualifying neuron
    layer_idx, feature_idx = pairs[:, 0], pairs[:, 1]

    # Gather the full per-position trace of every qualifying neuron in a single
    # indexing operation instead of one device round-trip per element.
    traces = neurons.permute(0, 2, 1)[layer_idx, feature_idx]  # [n_neurons, seq]

    return [
        {'l': layer, 'f': feature, 'a': trace}
        for layer, feature, trace in zip(
            layer_idx.tolist(), feature_idx.tolist(), traces.tolist())
    ]

def extract_logit_lens(k=10):
    """
    Project the captured output of every tracked block through the LM head
    and keep the k strongest predictions at each position.

    Returns a nested list structure of shape [n_layers, n_seq, k]
    where each element is a dict containing:
        tok: the predicted token, decoded to text
        prob: softmax weight of this token, computed over the k selected
              logits only (not over the whole vocabulary)

    Note: because the softmax is taken over the top k logits alone, the k
    probabilities at each position sum to (very nearly) 1.
    """

    per_layer_tokens = []
    for name in module_names_to_track_for_logits:
        # the hooked output is indexed with [0] to get the hidden states
        block_hidden = hidden_states[name]['output'][0]
        with torch.no_grad():
            # trailing [0] selects the first batch element -> [seq, vocab]
            layer_logits = model.lm_head(block_hidden.to(main_device)).detach()[0]

        top_values, top_indices = torch.topk(layer_logits, k=k)
        top_probs = F.softmax(top_values, dim=-1)

        # move to host once and convert to plain Python numbers
        token_rows = top_indices.cpu().tolist()
        prob_rows = top_probs.cpu().tolist()

        per_layer_tokens.append([
            [
                {'tok': tokenizer.decode([tok]), 'prob': prob}
                for tok, prob in zip(position_tokens, position_probs)
            ]
            for position_tokens, position_probs in zip(token_rows, prob_rows)
        ])
    return per_layer_tokens

def compress_upload(file_name):
    """Compress ``file_name`` into ``file_name + ".zst"`` using the zstd CLI.

    An existing output file is overwritten (``-f``). Despite the name, nothing
    is uploaded here; the function only compresses.

    Raises:
        Exception: if zstd exits with a non-zero status.
    """
    file_name_zst = file_name + ".zst"

    zstd_cmd = ["zstd", file_name, "-f", "-o", file_name_zst]
    exit_code = subprocess.run(zstd_cmd).returncode
    if exit_code == 0:
        return
    raise Exception("zstd compression failed. ec={}".format(exit_code))

# Main extraction loop: run each record through the model, collect the hooked
# activations, the logit lens and the attention maps, and pickle one record per
# input text into size-capped batch files under output/.
model.eval()

total_records = 0
total_tokens = 0

try:
    # NOTE(review): pickler is created inside the try, so if construction ever
    # failed the finally clause would raise NameError - consider moving it above.
    pickler = batch_pickler("output/neurons.pickle", file_size_goal, compress_upload)

    with torch.no_grad():
        for row in tqdm(dataset):
            inputs = tokenizer(row['text'], return_tensors="pt")
            # first (only) sequence of the batch, truncated to num_tokens
            context = inputs["input_ids"][0][:num_tokens]
            context = context.to(main_device)

            total_records = total_records + 1
            total_tokens = total_tokens + len(context)

            # intended to discard the previous record's captures before the forward pass
            clear_hidden_states()

            # the forward hooks fill hidden_states as a side effect of this call
            output = model(context, return_dict=True, output_attentions=True)

            # each stage below is timed separately and the duration printed
            st = time.time()
            activations_in = extract_neuron_values('input', activation_threshold, dense_activations)
            en = time.time()
            print('process in:', en-st)
            st = time.time()
            activations_out = extract_neuron_values('output', activation_threshold, dense_activations)
            en = time.time()
            print('process out:', en-st)
            st = time.time()
            logits = extract_logit_lens(k=top_k)  # [n_layers, seq, top_k]
            en = time.time()
            print('logit lens:', en-st)
            st = time.time()
            # gather the per-layer attention tensors on one device, then convert to nested lists
            attentions = torch.stack([a.to(main_device) for a in output.attentions]).tolist()
            en = time.time()
            print('attn:', en-st)

            # NOTE(review): 'tokens' is stored as a tensor on main_device while the
            # other fields are plain Python data - confirm readers expect that.
            record = {
                'text': row["text"],
                'source': row['source'],
                'tokens': context,
                'activationsIn': activations_in,
                'activationsOut': activations_out,
                'logits': logits,
                'attentions': attentions,
            }

            pickler.dump(record)
            # stop once the configured number of records has been written
            if total_records == dataset_limit:
                break
finally:
    # flush the last, partially filled batch even if the loop failed
    pickler.close()

print("Records: ", total_records)
print("Processed tokens: ", total_tokens)