In [1]:
import pandas as pd
from embeds import multi_inner_align, standardize
from transformers import AutoTokenizer, AutoModel
import torch
import pickle
import sys
from tqdm.notebook import tqdm_notebook as tqdm

sys.path.append('..')
from rca.rca import run_rca

In [2]:
norms_voc = set(
    pd.read_csv('../../data/psychNorms/psychNorms.zip', index_col=0, low_memory=False, compression='zip').index
)
with open('../../data/brain_behav_union.pkl', 'rb') as f:
    brain_behav_union = pickle.load(f)

# Extract intersection of norms and brain_behavior_union
to_extract = list(norms_voc & brain_behav_union)
len(to_extract)

46246

## Extracting representations

In [3]:
torch.random.manual_seed(42)

if torch.cuda.is_available():
    device = torch.device("cuda")
    print('CUDA is available. Using GPU.')
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("MPS is available. Using Apple's Metal.")
else:
    device = torch.device("cpu")
    print("No GPU or MPS available. Using CPU.")


model_name = 'meta-llama/Llama-3.2-1B'

# Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Model
model = AutoModel.from_pretrained(model_name).to(device)
model.eval()

MPS is available. Using Apple's Metal.


LlamaModel(
  (embed_tokens): Embedding(128256, 2048)
  (layers): ModuleList(
    (0-15): 16 x LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
        (k_proj): Linear(in_features=2048, out_features=512, bias=False)
        (v_proj): Linear(in_features=2048, out_features=512, bias=False)
        (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
        (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
        (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
    )
  )
  (norm): LlamaRMSNorm((2048,), eps=1e-05)
  (rotary_emb): LlamaRotaryEmbedding()
)

In [6]:
batch_size = 16
embeds = {}
with torch.no_grad():
    # Loop through the data in chunks of `batch_size`
    for i in tqdm(range(0, len(to_extract), batch_size)):
        # Create a batch from the large list
        batch_words = to_extract[i:i+batch_size]

        # --- The rest of the logic is the same, just applied to `batch_words` ---
        inputs = tokenizer(batch_words, return_tensors='pt', padding=True, truncation=True)
        all_word_ids = [inputs.word_ids(j) for j in range(len(batch_words))]
        inputs = {key: val.to(device) for key, val in inputs.items()}

        outputs = model(**inputs)
        last_hidden_state = outputs.last_hidden_state.cpu()

        for k, word in enumerate(batch_words):
            word_ids = all_word_ids[k]
            word_token_indices = [j for j, wid in enumerate(word_ids) if wid is not None]

            word_hidden_states = last_hidden_state[k, word_token_indices, :]
            averaged_word_representation = torch.mean(word_hidden_states, dim=0)

            embeds[word] = averaged_word_representation

embeds = pd.DataFrame(embeds).T.astype(float)
embeds

  0%|          | 0/2891 [00:00<?, ?it/s]

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,2038,2039,2040,2041,2042,2043,2044,2045,2046,2047
baddie,0.617632,3.677328,2.646231,1.066725,1.290422,-0.615204,0.670158,1.840440,-1.621486,1.008968,...,-0.722906,2.831227,-2.334162,-0.939548,1.865410,-2.632090,4.180321,-2.140123,-1.898184,-1.185073
unblushingly,-0.780811,4.523852,3.454993,0.153665,0.672691,-3.011597,1.730513,2.812676,-0.904837,2.997128,...,-1.004646,3.110042,-2.377111,-1.208966,0.706215,-3.520344,1.515211,-3.186903,-2.715932,-2.195801
insensibly,-0.447130,4.485464,1.775921,0.256286,1.148121,-1.091852,0.998185,3.691133,0.363665,1.720003,...,-3.294988,5.081615,-1.825445,0.847631,-0.042128,-2.111433,0.839492,-3.386798,-4.152276,0.360679
parenthesize,-1.802119,3.565008,4.939508,0.087262,1.035947,-2.990864,0.296555,2.672847,-1.835698,3.704694,...,-1.855984,1.303200,-1.128216,-1.122571,2.355582,-0.972989,-0.019990,-2.025034,-1.709530,-1.170313
insatiate,-0.485329,3.673390,2.670759,0.127345,0.541745,-2.871100,1.666120,2.899798,0.152642,1.901984,...,-1.974949,2.901847,-1.563233,0.150040,-0.821032,-1.307971,2.909365,-1.170126,-4.494101,-0.393104
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
popularly,1.089802,4.918963,2.048598,-0.072903,1.971993,-2.231852,2.272493,2.100999,0.033533,3.741592,...,-2.732045,3.377771,-3.337066,0.894094,1.947781,-2.443411,0.744422,-2.858145,-3.858637,-0.828052
unable,0.111263,6.062789,3.109928,1.477149,0.907181,-3.777168,3.525682,2.249653,0.286584,2.333573,...,-1.833538,5.179575,-3.770607,-0.577586,-0.271112,-1.049749,1.283552,-2.703760,-4.763577,0.064372
wench,1.319315,3.614148,3.027516,-0.368783,0.113268,-1.586929,2.954784,3.827778,-0.045522,1.795469,...,-0.864732,2.597534,-1.394352,-2.037887,0.309999,-3.030828,4.684218,-3.293120,-3.587358,-1.720351
fanaticize,-0.334161,4.565280,2.965562,0.823084,0.807213,-0.316923,0.539709,2.069281,-0.749703,3.892582,...,-2.658002,2.438370,-1.762910,0.149773,-0.147724,-1.470575,1.400764,-2.487698,-2.895077,-1.743315


In [7]:
ft_baseline = pd.read_csv('../../data/embeds/fastText_CommonCrawl.csv', index_col=0)

# Comparing
to_compare = {
    'ft_baseline': ft_baseline,
    'Llama_3.2_1B': embeds
}

# Aligning for fair comparison
to_compare = dict(zip(to_compare.keys(), multi_inner_align(to_compare.values())))

# Standardizing
to_compare = {name: standardize(embed) for name, embed in to_compare.items()}

# Loading norm data
norms = pd.read_csv('../../data/psychNorms/psychNorms_processed.zip', index_col=0, low_memory=False, compression='zip')
norms_meta = pd.read_csv('../../data/psychNorms/psychNorms_metadata_processed.csv', index_col='norm')
norms

Unnamed: 0_level_0,frequency_lund,frequency_kucera,frequency_subtlexus,frequency_subtlexuk,frequency_blog_gimenes,frequency_twitter_gimenes,frequency_news_gimenes,frequency_written_cobuild,frequency_spoken_cobuild,context_diversity_subtlexus,...,person_vanarsdall,goals_vanarsdall,movement_vanarsdall,concreteness_vanarsdall,familiarity_vanarsdall,imageability_vanarsdall,familiarity_fear,aoa_fear,imageability_fear,sensory_experience_juhasz2013
word,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
'em,0.0,,,,,,,1.3617,1.9138,,...,,,,,,,,,,
'neath,0.0,,,,,,,0.0000,0.0000,,...,,,,,,,,,,
're,0.0,,,,,,,0.9031,1.6335,,...,,,,,,,,,,
'shun,0.0,,,,,,,0.0000,0.0000,,...,,,,,,,,,,
'tis,0.0,,,,,,,0.4771,0.6021,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
shrick,,,,,,,,,,,...,,,,,,,2.62,4.38,2.93,
post office,,,,,,,,,,,...,,,,,,,3.79,3.07,5.29,
fishing rod,,,,,,,,,,,...,,,,,,,2.29,3.38,5.64,
March,,,,,,,,,,,...,,,,,,,3.43,2.76,3.50,


In [None]:
results = run_rca(to_compare, norms, norms_meta, n_jobs=10)
results

  0%|          | 0/2 [00:00<?, ?it/s]

ft_baseline:   0%|          | 0/291 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
TOKENIZERS_PARALLELISM=(true | false)
TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [None]:
# Adding norm category
results['norm_category'] = (
    results['norm']
    .apply(lambda norm: norms_meta.loc[norm]['category'])
    .replace({'_': ' '}, regex=True)
)

results_avg = (
    results[['norm_category', 'embed', 'r2_mean']]
    .groupby(['norm_category', 'embed'], as_index=False).median()
    .dropna()
)

results_avg_piv = results_avg.pivot(columns='embed', index='norm_category', values='r2_mean')
results_avg_piv.round(2)

In [None]:
# Finding the top-performing fmri_text_denoise
sorted_overall = results_avg_piv.mean().sort_values(ascending=False)
sorted_overall

## Saving

In [None]:
top_performer = None
top_performer.to_csv('../../data/embeds/Llama_X_XB.csv')

