In [None]:
# Load model directly
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm, trange
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = 'meta-llama/Llama-2-7b-hf'
# model_name = 'princeton-nlp/unsup-simcse-roberta-large'

# Credentials must never be stored in the notebook. The access token is read
# from the HF_TOKEN environment variable; when it is unset, `token=None` lets
# the library fall back to the credentials saved by `huggingface-cli login`.
# SECURITY: a literal access token used to be committed on these lines. It
# must be considered leaked and revoked in the Hugging Face account settings.
hf_token = os.environ.get("HF_TOKEN")

tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token)

# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model.to(device)

In [None]:
from datasets import load_dataset
from sklearn.decomposition import PCA
import os

# --- Experiment configuration -------------------------------------------
# Upper bound on the number of prompts processed per subject
# (get_hidden_states takes min(max_data_count, len(dataset))).
max_data_count = 200
# Prompts with more tokens than this are filtered out; it is also the padded
# sequence length of the hidden-state arrays.
max_length = 512
# Number of principal components kept by the PCA projection.
dim = 3

# Dataset configuration names passed to load_dataset("lukaemon/mmlu", ...).
subjects = [
    'high_school_european_history',
    'business_ethics',
    'clinical_knowledge',
    'medical_genetics',
    'high_school_us_history',
    'high_school_physics',
    'high_school_world_history',
    'virology',
    'high_school_microeconomics',
    'econometrics',
    'college_computer_science',
    'high_school_biology',
    'abstract_algebra',
    'professional_accounting',
    'philosophy',
    'professional_medicine',
    'nutrition',
    'global_facts',
    'machine_learning',
    'security_studies',
    'public_relations',
    'professional_psychology',
    'prehistory',
    'anatomy',
    'human_sexuality',
    'college_medicine',
    'high_school_government_and_politics',
    'college_chemistry',
    'logical_fallacies',
    'high_school_geography',
    'elementary_mathematics',
    'human_aging',
    'college_mathematics',
    'high_school_psychology',
    'formal_logic',
    'high_school_statistics',
    'international_law',
    'high_school_mathematics',
    'high_school_computer_science',
    'conceptual_physics',
    'miscellaneous',
    'high_school_chemistry',
    'marketing',
    'professional_law',
    'management',
    'college_physics',
    'jurisprudence',
    'world_religions',
    'sociology',
    'us_foreign_policy',
    'high_school_macroeconomics',
    'computer_security',
    'moral_scenarios',
    'moral_disputes',
    'electrical_engineering',
    'astronomy',
    'college_biology',
]

# Module-level PCA estimator keeping `dim` components. get_3d_data() refits
# this same instance for every layer and overwrites its `components_`.
pca = PCA(n_components=dim)

def process_datasets(subject):
    """Build tokenized instruction prompts for one MMLU subject.

    Loads the 'test' split of the ``lukaemon/mmlu`` dataset for ``subject``,
    renders every row into a prompt, tokenizes it with the module-level
    ``tokenizer`` and keeps only rows whose prompt has at most ``max_length``
    tokens.

    Args:
        subject: Dataset configuration name (one of ``subjects``).

    Returns:
        The filtered dataset; each row carries the rendered ``text`` plus the
        tokenizer outputs (``input_ids`` is indexed as ``[0]`` below, i.e. it
        holds a leading batch dimension of the single prompt).
    """
    # The rows must expose the fields `input`, `A`, `B`, `C` and `D`.
    # Note: every continuation line of this literal starts with four spaces
    # because the literal is indented with the function body; those spaces
    # are part of the prompt sent to the model.
    prompt_template = """<s>[INST] <<SYS>>
    Answer the following question
    <</SYS>>
    {input}
    A. {A}
    B. {B}
    C. {C}
    D. {D}[/INST]
    The answer is:"""

    def add_prompt(example):
        # Adds a `text` column holding the rendered prompt.
        return {'text': prompt_template.format(**example)}

    def tokenize(example):
        # Adds the tokenizer's output fields to the row.
        return tokenizer(example['text'], return_tensors='pt')

    def fits_context(example):
        return len(example['input_ids'][0]) <= max_length

    questions = load_dataset("lukaemon/mmlu", subject, split='test')
    prompts = questions.map(add_prompt)
    tokenized = prompts.map(tokenize)
    return tokenized.filter(fits_context)

def get_hidden_states(dataset):
    """Run the model over the dataset and collect every layer's hidden states.

    At most ``max_data_count`` rows are processed, one prompt per forward
    pass. Hidden states are stored right-padded with zeros up to
    ``max_length`` tokens.

    The layer count and hidden width are read from ``model.config`` instead
    of being hard-coded, so the function also works for models other than
    Llama-2-7b (for which they evaluate to 33 and 4096).

    Args:
        dataset: Output of ``process_datasets``; each row's ``input_ids`` is
            presumably a nested list of shape (1, seq_len) with
            seq_len <= ``max_length`` - verify against the caller.

    Returns:
        Tuple ``(raw_layers_data, tokens_mask, tokens_length, data_count)``:
        * raw_layers_data: float32 array
          (data_count, num_layers, max_length, hidden_size).
        * tokens_mask: int32 array (data_count, max_length), 1 on real
          tokens and 0 on padding.
        * tokens_length: number of real tokens per row.
        * data_count: number of rows actually processed.
    """
    data_count = min(max_data_count, len(dataset))
    # `output_hidden_states=True` yields the embedding output plus one tensor
    # per transformer layer, hence the "+ 1".
    num_layers = model.config.num_hidden_layers + 1
    hidden_size = model.config.hidden_size
    raw_layers_data = np.zeros((data_count, num_layers, max_length, hidden_size), dtype=np.float32)
    tokens_mask = np.zeros((data_count, max_length), dtype=np.int32)

    with torch.no_grad():
        # zip() with range() stops after `data_count` rows without an
        # explicit break.
        for i, row in tqdm(zip(range(data_count), dataset), total=data_count):
            input_ids = torch.tensor(row['input_ids']).to(device)
            output = model(input_ids=input_ids, output_hidden_states=True)
            # Stack the per-layer tensors and drop the batch axis (axis 1).
            hidden_states = torch.stack(output.hidden_states).squeeze(1).cpu().numpy()
            seq_len = hidden_states.shape[1]
            raw_layers_data[i, :, :seq_len, :] = hidden_states
            tokens_mask[i, :seq_len] = 1

    tokens_length = tokens_mask.sum(axis=1)

    return raw_layers_data, tokens_mask, tokens_length, data_count

def get_3d_data(raw_layers_data, tokens_mask):
    """Project every layer's token representations onto ``dim`` PCA axes.

    For each layer, the hidden states of all real (non-padding) tokens are
    L2-normalised, the module-level ``pca`` is fitted on them, and the tokens
    are projected onto the resulting components. Component signs are aligned
    with the previous layer so that projections do not flip arbitrarily
    between consecutive layers.

    Layer count and hidden width are taken from ``raw_layers_data.shape``
    rather than hard-coded.

    Args:
        raw_layers_data: array (n_samples, n_layers, seq_len, hidden_size).
        tokens_mask: array (n_samples, seq_len), 1 on real tokens.

    Returns:
        Tuple ``(layers_3d, pca_components, variance_ratios)``:
        * layers_3d: float32 (n_samples, n_layers, seq_len, dim); padding
          positions stay zero.
        * pca_components: float32 (n_layers, dim, hidden_size), sign-aligned.
        * variance_ratios: float32 (n_layers, dim).
    """
    n_samples, n_layers, seq_len, hidden_size = raw_layers_data.shape
    layers_3d = np.zeros((n_samples, n_layers, seq_len, dim), dtype=np.float32)
    pca_components = np.zeros((n_layers, dim, hidden_size), dtype=np.float32)
    variance_ratios = np.zeros((n_layers, dim), dtype=np.float32)

    # 2-D boolean mask; indexing a (n_samples, seq_len, hidden) array with it
    # yields (n_tokens, hidden) directly, without materialising a mask
    # broadcast over the hidden axis.
    valid_tokens = tokens_mask == 1

    for i in trange(n_layers):
        # Boolean indexing returns a copy, so the in-place normalisation
        # below does not modify `raw_layers_data`.
        flattened_data = raw_layers_data[:, i][valid_tokens]
        norms = np.linalg.norm(flattened_data, axis=-1, keepdims=True)
        # Guard against all-zero vectors: dividing by 0 would produce NaNs
        # and make the PCA fit fail. Non-zero rows are unaffected.
        norms[norms == 0] = 1.0
        flattened_data /= norms
        pca.fit(flattened_data)
        pca_components[i] = pca.components_
        if i != 0:
            # make sure the new components are not flipped
            for j in range(dim):
                dot_product = np.dot(pca_components[i, j], pca_components[i - 1, j])
                if dot_product < 0:
                    pca_components[i, j] *= -1
        # Project with the sign-aligned components.
        pca.components_ = pca_components[i]
        # `layers_3d[:, i]` is a view, so the masked assignment writes
        # straight into `layers_3d`.
        layers_3d[:, i][valid_tokens] = pca.transform(flattened_data)
        variance_ratios[i] = pca.explained_variance_ratio_

    return layers_3d, pca_components, variance_ratios

def save_file(subject, layers_3d, pca_components, variance_ratios):
    """Persist the PCA results for one subject as ``.npy`` files.

    Creates the directory ``subject`` (including missing parents) if needed
    and writes ``layers_3d.npy``, ``pca_components.npy`` and
    ``variance_ratios.npy`` into it, overwriting existing files.

    Args:
        subject: Directory path to write into (the subject name).
        layers_3d: Array saved as ``layers_3d.npy``.
        pca_components: Array saved as ``pca_components.npy``.
        variance_ratios: Array saved as ``variance_ratios.npy``.
    """
    # exist_ok avoids the check-then-create race of `exists()` + `mkdir()`
    # and also handles nested output paths.
    os.makedirs(subject, exist_ok=True)
    outputs = (
        ('layers_3d.npy', layers_3d),
        ('pca_components.npy', pca_components),
        ('variance_ratios.npy', variance_ratios),
    )
    for filename, array in outputs:
        np.save(os.path.join(subject, filename), array)

def process(subject):
    """Run the full pipeline for one subject and save the results to disk.

    Steps: build the tokenized prompts, collect the model's hidden states,
    reduce them with PCA, and write the outputs into a directory named after
    ``subject``. Progress is reported with ``print``.

    Args:
        subject: Dataset configuration name (one of ``subjects``).
    """
    print(f"Processing {subject}")
    print("Loading dataset")
    filtered_dataset = process_datasets(subject)
    print("Passing to decoder")
    raw_layers_data, tokens_mask, tokens_length, data_count = get_hidden_states(filtered_dataset)
    print("Dimensionality reduction")
    layers_3d, pca_components, variance_ratios = get_3d_data(raw_layers_data, tokens_mask)
    print(f"Saving {subject}")
    # save_file expects the output directory as its first argument; omitting
    # it raised a TypeError after all the expensive work had been done.
    save_file(subject, layers_3d, pca_components, variance_ratios)
    print(f"Finished {subject}")

# process(subjects[0])


In [None]:
# Build the prompts for the first subject in `subjects` and collect the
# model's hidden states for them.
first_subject = subjects[0]
filtered_dataset = process_datasets(first_subject)
hidden_state_outputs = get_hidden_states(filtered_dataset)
raw_layers_data, tokens_mask, tokens_length, data_count = hidden_state_outputs

In [None]:
# Inspect the dimensions of the collected hidden states; get_hidden_states
# allocates them as (rows, layers, max_length, hidden width).
raw_layers_data.shape

In [None]:
from tqdm import trange

# Cosine similarity between each token's hidden state at consecutive layers.
# The output shape is derived from the data instead of being hard-coded, so
# the cell works for any subject, row count, or model size.
n_samples, n_layers, seq_len = raw_layers_data.shape[:3]
similiarities = np.zeros((n_layers - 1, n_samples, seq_len))

# Unit-normalise every hidden state in place (avoids a second full-size
# copy). The epsilon keeps the all-zero padding positions finite.
# NOTE: this mutates `raw_layers_data`; cells run afterwards see the
# normalised values.
raw_layers_data /= np.linalg.norm(raw_layers_data, axis=-1, keepdims=True) + 1e-8

for layer in trange(n_layers - 1):
    l0 = raw_layers_data[:, layer]
    l1 = raw_layers_data[:, layer + 1]
    # Row-wise dot product over the hidden axis -> (n_samples, seq_len).
    similiarity = np.einsum('ijk,ijk->ij', l0, l1)
    similiarities[layer] = similiarity

similiarities.shape

In [None]:
import plotly.express as px

# Heatmap of the layer-to-layer similarities, animated over axis 0 (one
# frame per pair of consecutive layers); the colour scale is fixed to [0.5, 1].
similarity_fig = px.imshow(
    similiarities,
    animation_frame=0,
    zmin=0.5,
    zmax=1,
)
similarity_fig