In [1]:
from metrics.utils import hidden_states_collapse
from metrics.query import DataFrameQuery
from common.tensor_storage import TensorStorage
from metrics.utils import  exact_match
#from sklearn.feature_selection import mutual_info_regression MISSIN?
from dadapy.data import Data

from pathlib  import Path

import numpy as np
import pandas as pd

from common.metadata_db import MetadataDB
from common.utils import *
from pathlib import Path
import pickle


## Functions

In [2]:
def set_dataframes(db) -> pd.DataFrame:
    """
    Aggregate in a dataframe the hidden states of all instances
    ----------
    hidden_states: pd.DataFrame(num_instances, num_layers, model_dim)
    """
    df = pd.read_sql("SELECT * FROM metadata", db.conn)
    df["train_instances"] = df["train_instances"].astype(str)
    df.drop(columns=["id"],inplace = True)
    #import pdb; pdb.set_trace()
    df.drop_duplicates(subset = ["id_instance"],inplace = True, ignore_index = True) # why there are duplicates???
    return df

In [3]:
def tensor_retrieve(dict_query):
    query = DataFrameQuery(dict_query)
    hidden_states,logits, hidden_states_df= hidden_states_collapse(metadata_df,query,tensor_storage)
    return hidden_states,logits,hidden_states_df

In [4]:
def constructing_labels(label: str, hidden_states_df: pd.DataFrame, hidden_states: np.ndarray) -> np.ndarray:
    labels_literals = hidden_states_df[label].unique()
    labels_literals.sort()
    
    map_labels = {class_name: n for n,class_name in enumerate(labels_literals)}
    
    label_per_row = hidden_states_df[label].reset_index(drop=True)
    label_per_row = np.array([map_labels[class_name] for class_name in label_per_row])[:hidden_states.shape[0]]
    
    return label_per_row, map_labels

In [5]:
_PATH = Path("/orfeo/scratch/dssc/zenocosini/mmlu_result")
result_path = Path(_PATH,"diego")
result_path.mkdir(exist_ok=True,parents=True)
metadata_db = MetadataDB(_PATH / "metadata.db")
metadata_df = set_dataframes(metadata_db)
tensor_storage = TensorStorage(Path(_PATH, "tensor_files"))



## Tensor Retrieval

### Base model

In [28]:
shot = 0

In [9]:
!ls /orfeo/scratch/dssc/zenocosini/mmlu_result/transposed_dataset/llama-7b-base/0

distances-0.npy   distances-31.npy	dist_indices-22.npy
distances-10.npy  distances-32.npy	dist_indices-23.npy
distances-11.npy  distances-3.npy	dist_indices-24.npy
distances-12.npy  distances-4.npy	dist_indices-25.npy
distances-13.npy  distances-5.npy	dist_indices-26.npy
distances-14.npy  distances-6.npy	dist_indices-27.npy
distances-15.npy  distances-7.npy	dist_indices-28.npy
distances-16.npy  distances-8.npy	dist_indices-29.npy
distances-17.npy  distances-9.npy	dist_indices-2.npy
distances-18.npy  distances-logits.npy	dist_indices-30.npy
distances-19.npy  dist_indices-0.npy	dist_indices-31.npy
distances-1.npy   dist_indices-10.npy	dist_indices-32.npy
distances-20.npy  dist_indices-11.npy	dist_indices-3.npy
distances-21.npy  dist_indices-12.npy	dist_indices-4.npy
distances-22.npy  dist_indices-13.npy	dist_indices-5.npy
distances-23.npy  dist_indices-14.npy	dist_indices-6.npy
distances-24.npy  dist_indices-15.npy	dist_indices-7.npy
distances-25.npy  dist_indices-16.npy	dist_indices-8.n

In [29]:
datasets = list(metadata_df["dataset"].unique())
dict_query = {"dataset":datasets, 
              "method":"last",
              "model_name":"meta-llama/Llama-2-7b-chat-hf",
              "train_instances": shot}
hidden_states,logits, hidden_states_df = tensor_retrieve(dict_query)

 Tensor retrieval took: 197.59465193748474



In [8]:
! mkdir -p /orfeo/scratch/dssc/zenocosini/mmlu_result/transposed_dataset/llama-7b-base

In [11]:
! ls /orfeo/scratch/dssc/zenocosini/mmlu_result/transposed_dataset/llama-7b-base

0


In [10]:
! mkdir -p /orfeo/scratch/dssc/zenocosini/mmlu_result/transposed_dataset/llama-7b-base/0

In [30]:
path = Path(_PATH,"transposed_dataset","llama-7b-chat",str(shot))
path.mkdir(exist_ok=True,parents=True)

In [31]:
labels, map_dict = constructing_labels("dataset",hidden_states_df, hidden_states)
np.save(Path(path,"subjects-labels.pkl"),labels)
with open(Path(path,"subjects-map"),"wb") as f:
    pickle.dump(map_dict,f)
labels, map_dict = constructing_labels("only_ref_pred",hidden_states_df, hidden_states)
with open(Path(path,"letter-map.pkl"),"wb") as f:
    pickle.dump(map_dict,f)
np.save(Path(result_path,"letter-base"),labels)

In [32]:
dict_nn_matrix = {}
dict_nn_matrix_l = {}

for layer in range(hidden_states.shape[1]):
    data = Data(hidden_states[:,layer,:])
    data.compute_distances(maxk=150)
    
    np.save(Path(path,f"distances-{layer}"), data.distances)
    np.save(Path(path,f"dist_indices-{layer}"), data.dist_indices)
    
 


This can cause problems in various routines.
We suggest to either perform smearing of distances using
remove_zero_dists()
or remove identical points using
remove_identical_points()).


In [33]:
data_l = Data(logits[:,0,:])
data_l.compute_distances(maxk=150)
np.save(Path(path,f"distances-logits"), data_l.distances)
np.save(Path(path,f"dist_indices-logits"), data_l.dist_indices)

In [23]:
data.distances.shape

(14015, 151)

In [24]:
logits.shape

(14015, 1, 32000)

In [61]:
with open(Path(result_path,"llama-7b-base-5shot-dist-matrix.pkl"),"wb") as f:
    pickle.dump(dict_nn_matrix,f)

### Chat model

In [65]:
datasets = list(metadata_df["dataset"].unique())
dict_query = {"dataset":datasets, 
              "method":"last",
              "model_name":"meta-llama/Llama-2-7b-chat-hf",
              "train_instances": 5}
hidden_states, hidden_states_df = tensor_retrieve(dict_query)

 Tensor retrieval took: 161.6436107158661



In [66]:
np.save(Path(result_path,"llama-7b-chat-5shot"), hidden_states)

In [67]:
labels, map_labels = constructing_labels("dataset",hidden_states_df, hidden_states)
np.save(Path(result_path,"subjects-chat"),labels)
labels = constructing_labels("only_ref_pred",hidden_states_df, hidden_states)
np.save(Path(result_path,"letter-chat"),labels)

In [68]:
dict_nn_matrix = {}
for layer in [6,15,18,29,31]:
    data = Data(hidden_states[:,layer,:])
    data.compute_distances(maxk=150)
    dict_nn_matrix[layer] = (data.distances,data.dist_indices)

In [69]:
with open(Path(result_path,"llama-7b-chat-5shot-dist-matrix.pkl"),"wb") as f:
    pickle.dump(dict_nn_matrix,f)

## Other

In [151]:
exact_matches = metadata_df.apply(lambda r: exact_match(r["std_pred"], r["letter_gold"]), axis=1)
metadata_df["exact_match"] = exact_matches
metadata_df_correct = metadata_df[metadata_df["method"]=="last"].copy()

#WHATCH OUT: some instance in medicine realated subject are repeated.
metadata_df_correct["id_instance"] = metadata_df_correct.apply(lambda r: r["id_instance"][:92]+r["id_instance"][-1], axis=1)
metadata_df_correct.drop_duplicates(subset=["id_instance"],inplace = True)
metadata_df_correct["id_instance"] = metadata_df_correct.apply(lambda r: r["id_instance"][:64], axis=1)

metadata_df_correct.reset_index(inplace=True)
metadata_df_correct.drop(columns=["index"],inplace=True)

#creating pivot table
metadata_df_correct['match_comb'] = metadata_df_correct.apply(lambda row: f"{row['model_name']}_{row['train_instances']}", axis=1)
pivot_df = metadata_df_correct.pivot(index='id_instance', columns='match_comb', values='exact_match')
pivot_df.reset_index(inplace=True)

In [152]:
pivot_df.columns

Index(['id_instance', 'meta-llama/Llama-2-7b-chat-hf_0',
       'meta-llama/Llama-2-7b-chat-hf_2', 'meta-llama/Llama-2-7b-chat-hf_5',
       'meta-llama/Llama-2-7b-hf_0', 'meta-llama/Llama-2-7b-hf_2',
       'meta-llama/Llama-2-7b-hf_5'],
      dtype='object', name='match_comb')

In [187]:
from tabulate import tabulate
cols = ["meta-llama/Llama-2-7b-hf_0","meta-llama/Llama-2-7b-hf_5","meta-llama/Llama-2-7b-chat-hf_0","meta-llama/Llama-2-7b-chat-hf_5"]
matrix = []
for col1 in cols:
    for col2 in cols:
        if col1==col2:
            perc = 1
        else:
            perc = (pivot_df[[col1,col2]].apply(lambda r: r[col1]==r[col2], axis = 1).sum()/len(pivot_df))*100
        matrix.append(perc)
matrix = np.array(matrix).reshape([4,4])
print(tabulate(matrix, headers=cols,showindex=cols))

                                   meta-llama/Llama-2-7b-hf_0    meta-llama/Llama-2-7b-hf_5    meta-llama/Llama-2-7b-chat-hf_0    meta-llama/Llama-2-7b-chat-hf_5
-------------------------------  ----------------------------  ----------------------------  ---------------------------------  ---------------------------------
meta-llama/Llama-2-7b-hf_0                             1                            74.5641                            67.2813                            66.3199
meta-llama/Llama-2-7b-hf_5                            74.5641                        1                                 72.2537                            75.8269
meta-llama/Llama-2-7b-chat-hf_0                       67.2813                       72.2537                             1                                 84.7169
meta-llama/Llama-2-7b-chat-hf_5                       66.3199                       75.8269                            84.7169                             1
