In [44]:
from torch.utils.data import DataLoader


from refactor.utils.data import FilePaths, load_antibiotic_data
from refactor.utils.hooking import get_activations as get_activations_new
from refactor.utils.compatibility import ModelConfig
from refactor.probes import model_setup
from utils.probe_confidence_intervals import bootstrap

"""This function runs an entire pipeline that bootstraps, trains and creates confidence intervals showing
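    the probes' F1 scores on different labels and across layers.

    We bootstrap 10 times.
    Results are saved in this folder: results/data/probe_confidence_intervals/*model_name*_reg_lambda_*reg_lambda*

Args:
    model_name (str): path of the model to probe (assigned further down in this cell).
    reg_lambdas (number): presumably the regularisation strength; this notebook sets a single value in meta_data["reg_lambda"].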
"""

# Path handed to model_setup; the last path component is the checkpoint's own name.
model_name = "downloaded_models/gpt_gptsw3_en_is_da_356m_gbs1024"


# loads model
print("Load model")
model, tokenizer, device = model_setup(model_name)


# loads data
print("Load data")
ds = load_antibiotic_data(
    file_paths=FilePaths.antibiotic,
    file_extension='txt'
)
# NOTE(review): shuffle=True and no seed is set in this cell, so the batches that reach
# get_activations_new can differ between runs — confirm whether a fixed seed is wanted.
loader = DataLoader(ds, batch_size=32, shuffle=True)



# sets training parameters
meta_data = {}
meta_data["hidden_size"] = ModelConfig.hidden_size(model)
meta_data["hidden_layers"] = ModelConfig.hidden_layers(model)
# Use the LAST path component (the checkpoint name). The first component is only the shared
# "downloaded_models" directory, which would be the same string for every model stored there.
meta_data["model_name"] = model_name.rstrip("/").split("/")[-1]
meta_data["learning_rate"] = 0.001
meta_data["reg_lambda"] = 10
meta_data["amount_epochs"] = 1


# extracts activation from forward passes on data
# We use hooks to extract the different layer activations that will be used to train our probes

print("Extract activations")
# hook_addresses=None / layers=None are passed through unchanged; presumably they select the
# helper's defaults — TODO confirm against refactor.utils.hooking.
# NOTE(review): max_batches=2 and sampling_prob=0.1 look like quick-run settings — confirm
# before using the results for a full experiment.
activations = get_activations_new(
    loader=loader,
    model=model,
    tokenizer=tokenizer,
    hook_addresses=None,
    layers=None,
    max_batches=2,
    sampling_prob=0.1
)

Load model
found device: cpu
Load data
Extract activations


  2%|▏         | 3/130 [01:02<44:16, 20.92s/it]


In [46]:
# Derive the hook-position names from the layer-0 activation keys, e.g.
# "layer.0.<position>" -> "<position>".
# Only keys that actually start with the layer-0 prefix are considered, so a key of
# another layer can never leak into `positions` with its prefix still attached
# (which would later produce lookups such as "layer.3.layer.1.<position>").
layer0_prefix = "layer.0."
# The cap of 6 is kept from the original cell.
# NOTE(review): presumably there are six hook positions per layer — confirm.
max_positions = 6
positions = [
    key[len(layer0_prefix):]
    for key in activations
    if key.startswith(layer0_prefix)
][:max_positions]

In [None]:
# Regroup the flat activations mapping (keys "layer.<layer>.<position>") into
# d[position][layer]. `acts_ds_by_layer` stays bound after the loop to the mapping
# built for the last entry of `positions`.
d = {}
for pos in positions:
    acts_ds_by_layer = {
        layer: activations[f"layer.{layer}.{pos}"]
        for layer in range(meta_data["hidden_layers"])
    }
    d[pos] = acts_ds_by_layer



In [None]:
# extracts activation from forward passes on data
# We use hooks to extract the different layer activations that will be used to train our probes
# NOTE(review): the call below is commented out, so this cell currently only imports
# get_activations and prints a banner. It looks like an older extraction path kept for
# reference — confirm whether the cell can be removed.
from utils.probe_confidence_intervals import get_activations

print("Extract activations")
# Disabled call to the imported get_activations; nothing in this cell assigns activation_ds_by_layer.
#activation_ds_by_layer = get_activations(meta_data,loader, tokenizer, device, model)



Extract activations


  5%|▍         | 6/130 [00:01<00:37,  3.33it/s]


In [50]:
import numpy as np

In [51]:
# Count the distinct label values that occur across all layers' activation datasets
# and record that count in meta_data["number_labels"].
s = set()
for i in range(meta_data["hidden_layers"]):
    # set.update adds every element of the label array directly; the previous
    # `[s.add(x) for x in ...]` built a throwaway list of None just for its side effect.
    s.update(np.array(acts_ds_by_layer[i].labels))
number_labels = len(s)
meta_data["number_labels"] = number_labels

In [52]:
# Number of bootstrap rounds passed as the first argument to bootstrap().
n_bootstraps = 10
# NOTE(review): acts_ds_by_layer is whatever an earlier cell left bound. If that is the
# per-position loop, only the last position's activations are bootstrapped — confirm intended.
boot = bootstrap(n_bootstraps, meta_data, acts_ds_by_layer, device)