In [1]:
from metrics.utils import hidden_states_collapse
from metrics.query import DataFrameQuery
from common.tensor_storage import TensorStorage
from metrics.utils import  exact_match
#from sklearn.feature_selection import mutual_info_regression MISSIN?
from dadapy.data import Data

from pathlib  import Path

import numpy as np
import pandas as pd

from common.metadata_db import MetadataDB
from common.utils import *
from pathlib import Path
import pickle


## Functions

In [2]:
def set_dataframes(db) -> pd.DataFrame:
    """
    Aggregate in a dataframe the hidden states of all instances
    ----------
    hidden_states: pd.DataFrame(num_instances, num_layers, model_dim)
    """
    df = pd.read_sql("SELECT * FROM metadata", db.conn)
    df["train_instances"] = df["train_instances"].astype(str)
    df.drop(columns=["id"],inplace = True)
    #import pdb; pdb.set_trace()
    df.drop_duplicates(subset = ["id_instance"],inplace = True, ignore_index = True) # why there are duplicates???
    return df

In [3]:
def tensor_retrieve(dict_query):
    query = DataFrameQuery(dict_query)
    hidden_states,_, hidden_states_df= hidden_states_collapse(metadata_df,query,tensor_storage)
    return hidden_states,hidden_states_df

In [4]:
  def constructing_labels(label: str, hidden_states_df: pd.DataFrame, hidden_states: np.ndarray) -> np.ndarray:
      labels_literals = hidden_states_df[label].unique()
      labels_literals.sort()
      
      map_labels = {class_name: n for n,class_name in enumerate(labels_literals)}
      
      label_per_row = hidden_states_df[label].reset_index(drop=True)
      label_per_row = np.array([map_labels[class_name] for class_name in label_per_row])[:hidden_states.shape[0]]
      
      return label_per_row

In [5]:
_PATH = Path("/orfeo/scratch/dssc/zenocosini/mmlu_result")
result_path = Path(_PATH,"diego")
result_path.mkdir(exist_ok=True,parents=True)
metadata_db = MetadataDB(_PATH / "metadata.db")
metadata_df = set_dataframes(metadata_db)
tensor_storage = TensorStorage(Path(_PATH, "tensor_files"))



## Tensor Retrieval

### Base model

In [15]:
datasets = list(metadata_df["dataset"].unique())
dict_query = {"dataset":datasets, 
              "method":"last",
              "model_name":"meta-llama/Llama-2-7b-hf",
              "train_instances": 5}
hidden_states, hidden_states_df = tensor_retrieve(dict_query)

 Tensor retrieval took: 156.5567717552185



In [55]:
! ls /orfeo/scratch/dssc/zenocosini/mmlu_result/diego/

llama-7b-base-5shot.npy  subjects.npy


In [53]:
np.save(Path(result_path,"llama-7b-base-5shot"), hidden_states)

In [64]:
labels = constructing_labels("dataset",hidden_states_df, hidden_states)
np.save(Path(result_path,"subjects-base"),labels)
labels = constructing_labels("only_ref_pred",hidden_states_df, hidden_states)
np.save(Path(result_path,"letter-base"),labels)

In [57]:
dict_nn_matrix = {}
for layer in [6,15,18,29,31]:
    data = Data(hidden_states[:,layer,:])
    data.compute_distances(maxk=150)
    dict_nn_matrix[layer] = (data.distances,data.dist_indices)


In [61]:
with open(Path(result_path,"llama-7b-base-5shot-dist-matrix.pkl"),"wb") as f:
    pickle.dump(dict_nn_matrix,f)

### Chat model

In [65]:
datasets = list(metadata_df["dataset"].unique())
dict_query = {"dataset":datasets, 
              "method":"last",
              "model_name":"meta-llama/Llama-2-7b-chat-hf",
              "train_instances": 5}
hidden_states, hidden_states_df = tensor_retrieve(dict_query)

 Tensor retrieval took: 161.6436107158661



In [66]:
np.save(Path(result_path,"llama-7b-chat-5shot"), hidden_states)

In [67]:
labels = constructing_labels("dataset",hidden_states_df, hidden_states)
np.save(Path(result_path,"subjects-chat"),labels)
labels = constructing_labels("only_ref_pred",hidden_states_df, hidden_states)
np.save(Path(result_path,"letter-chat"),labels)

In [68]:
dict_nn_matrix = {}
for layer in [6,15,18,29,31]:
    data = Data(hidden_states[:,layer,:])
    data.compute_distances(maxk=150)
    dict_nn_matrix[layer] = (data.distances,data.dist_indices)

In [69]:
with open(Path(result_path,"llama-7b-chat-5shot-dist-matrix.pkl"),"wb") as f:
    pickle.dump(dict_nn_matrix,f)

## Other

In [151]:
exact_matches = metadata_df.apply(lambda r: exact_match(r["std_pred"], r["letter_gold"]), axis=1)
metadata_df["exact_match"] = exact_matches
metadata_df_correct = metadata_df[metadata_df["method"]=="last"].copy()

#WHATCH OUT: some instance in medicine realated subject are repeated.
metadata_df_correct["id_instance"] = metadata_df_correct.apply(lambda r: r["id_instance"][:92]+r["id_instance"][-1], axis=1)
metadata_df_correct.drop_duplicates(subset=["id_instance"],inplace = True)
metadata_df_correct["id_instance"] = metadata_df_correct.apply(lambda r: r["id_instance"][:64], axis=1)

metadata_df_correct.reset_index(inplace=True)
metadata_df_correct.drop(columns=["index"],inplace=True)

#creating pivot table
metadata_df_correct['match_comb'] = metadata_df_correct.apply(lambda row: f"{row['model_name']}_{row['train_instances']}", axis=1)
pivot_df = metadata_df_correct.pivot(index='id_instance', columns='match_comb', values='exact_match')
pivot_df.reset_index(inplace=True)

In [152]:
pivot_df.columns

Index(['id_instance', 'meta-llama/Llama-2-7b-chat-hf_0',
       'meta-llama/Llama-2-7b-chat-hf_2', 'meta-llama/Llama-2-7b-chat-hf_5',
       'meta-llama/Llama-2-7b-hf_0', 'meta-llama/Llama-2-7b-hf_2',
       'meta-llama/Llama-2-7b-hf_5'],
      dtype='object', name='match_comb')

In [187]:
from tabulate import tabulate
cols = ["meta-llama/Llama-2-7b-hf_0","meta-llama/Llama-2-7b-hf_5","meta-llama/Llama-2-7b-chat-hf_0","meta-llama/Llama-2-7b-chat-hf_5"]
matrix = []
for col1 in cols:
    for col2 in cols:
        if col1==col2:
            perc = 1
        else:
            perc = (pivot_df[[col1,col2]].apply(lambda r: r[col1]==r[col2], axis = 1).sum()/len(pivot_df))*100
        matrix.append(perc)
matrix = np.array(matrix).reshape([4,4])
print(tabulate(matrix, headers=cols,showindex=cols))

                                   meta-llama/Llama-2-7b-hf_0    meta-llama/Llama-2-7b-hf_5    meta-llama/Llama-2-7b-chat-hf_0    meta-llama/Llama-2-7b-chat-hf_5
-------------------------------  ----------------------------  ----------------------------  ---------------------------------  ---------------------------------
meta-llama/Llama-2-7b-hf_0                             1                            74.5641                            67.2813                            66.3199
meta-llama/Llama-2-7b-hf_5                            74.5641                        1                                 72.2537                            75.8269
meta-llama/Llama-2-7b-chat-hf_0                       67.2813                       72.2537                             1                                 84.7169
meta-llama/Llama-2-7b-chat-hf_5                       66.3199                       75.8269                            84.7169                             1
