In [1]:
import json
import random
import numpy as np
import torch
from datasets import load_dataset, Dataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer, pipeline


from repe import repe_pipeline_registry
repe_pipeline_registry()

# Configurationsrandom.seed(0)
np.random.seed(0)


2023-10-11 07:12:09.388515: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


[2023-10-11 07:12:12,827] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [3]:
model_name_or_path =  "microsoft/deberta-xxlarge-v2-mnli"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModel.from_pretrained(model_name_or_path).half().cuda()

rep_pipeline =  pipeline("rep-reading", model=model, tokenizer=tokenizer, device=model.device)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [9]:
n_difference = 1
hidden_layers = list(range(-1, -model.config.num_hidden_layers, -1))
direction_method = "pca"
direction_finder_kwargs= {"n_components": 1}


## RTE

In [10]:
template_str = "Consider the {concept} of the sentences: Hypothesis: {sent1}\Premise: {sent2}"
rep_token = 3

In [13]:
def samples(ds, test_set=False):
    pairs = []
    for e in ds:
        concepts = ['entailment', 'contradiction']
                
        if e['label'] == 1 and test_set: # flip the concepts in test for easy eval
            concepts = concepts[::-1]
    
        pair = [template_str.format(concept=c, sent1=e['sentence1'], sent2=e['sentence2']) for c in concepts]
        pairs.extend(pair)
    return pairs

dataset = load_dataset("glue", "rte")
max_train_samples = 64
train_dataset, test_dataset = dataset['train'], dataset['validation']
train_dataset =  train_dataset.shuffle(seed=0)
train_dataset = train_dataset.select(range(max_train_samples))

train_data, train_labels = samples(train_dataset), train_dataset['label']
train_labels = [[1, 0] if label == 0 else [0, 1] for label in train_labels]
test_data =  samples(test_dataset, test_set=True)

Found cached dataset glue (/data/long_phan/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /data/long_phan/.cache/huggingface/datasets/glue/rte/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-dddfb666e65c61ad.arrow


In [14]:
batch_size=128
rep_reader = rep_pipeline.get_directions(
    train_data,
    rep_token=rep_token, 
    hidden_layers=hidden_layers, 
    n_difference=n_difference, 
    train_labels=train_labels, 
    direction_method=direction_method,
    direction_finder_kwargs=direction_finder_kwargs,
    batch_size=batch_size,
    
    max_length=2048,
    padding="longest",
)

batch_size=128
results_val = {layer: {} for layer in hidden_layers}

H_tests = rep_pipeline(
                    test_data,
                    rep_token=rep_token, 
                    hidden_layers=hidden_layers, 
                    rep_reader=rep_reader,
                    batch_size=batch_size,
                    max_length=2048,
                    padding="longest")

n_choices=2
for layer in hidden_layers:
    H_test = [H[layer] for H in H_tests] 
    H_test = [H_test[i:i+n_choices] for i in range(0, len(H_test), n_choices)]

    sign = rep_reader.direction_signs[layer]
    eval_func = min if sign == -1 else max
    
    cors = np.mean([eval_func(H) == H[0] for H in H_test])
    results_val[layer] = cors
    
    print(f"{layer}: {cors}")
    print("=====")


-1: 0.8989169675090253
=====
-2: 0.6209386281588448
=====
-3: 0.7725631768953068
=====
-4: 0.8953068592057761
=====
-5: 0.8700361010830325
=====
-6: 0.7581227436823105
=====
-7: 0.7292418772563177
=====
-8: 0.6931407942238267
=====
-9: 0.628158844765343
=====
-10: 0.7364620938628159
=====
-11: 0.7220216606498195
=====
-12: 0.7509025270758123
=====
-13: 0.6967509025270758
=====
-14: 0.7184115523465704
=====
-15: 0.6823104693140795
=====
-16: 0.7509025270758123
=====
-17: 0.5631768953068592
=====
-18: 0.516245487364621
=====
-19: 0.5018050541516246
=====
-20: 0.48375451263537905
=====
-21: 0.5126353790613718
=====
-22: 0.4981949458483754
=====
-23: 0.51985559566787
=====
-24: 0.5054151624548736
=====
-25: 0.47653429602888087
=====
-26: 0.4548736462093863
=====
-27: 0.5054151624548736
=====
-28: 0.516245487364621
=====
-29: 0.4729241877256318
=====
-30: 0.4368231046931408
=====
-31: 0.4584837545126354
=====
-32: 0.4404332129963899
=====
-33: 0.44404332129963897
=====
-34: 0.43682310469314

## Boolq

In [16]:
template_str = "Consider the {concept} of answering Yes to the question:\nQuestion: {question}?\nContext: {context}"
rep_token = 3

In [17]:
def samples(ds, test_set=False):
    pairs = []
    for e in ds:
        concepts = ['correctness', 'incorrectness']
        question = e['question']
        context = e['passage']
                
        if e['answer'] == False and test_set: # flip the concepts in test for easy eval
            concepts = concepts[::-1]

        pair = [template_str.format(concept=c, question=question, context=context) for c in concepts]
        pairs.extend(pair)
    return pairs

dataset = load_dataset("boolq")
max_train_samples = 64
train_dataset, test_dataset = dataset['train'], dataset['validation']
train_dataset =  train_dataset.shuffle(seed=0)
train_dataset = train_dataset.select(range(max_train_samples))
test_dataset = test_dataset.select(range(500)) # comment out for full set

train_data = samples(train_dataset)
train_labels = [[1,0] if l else [0,1] for l in train_dataset['answer']] # we get the index of true answer to decide the signs here, not for train
test_data =  samples(test_dataset, test_set=True)

Found cached dataset boolq (/data/long_phan/.cache/huggingface/datasets/boolq/default/0.1.0/bf0dd57da941c50de94ae3ce3cef7fea48c08f337a4b7aac484e9dddc5aa24e5)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /data/long_phan/.cache/huggingface/datasets/boolq/default/0.1.0/bf0dd57da941c50de94ae3ce3cef7fea48c08f337a4b7aac484e9dddc5aa24e5/cache-deec76edf279b5b0.arrow


In [18]:
batch_size=128
rep_reader = rep_pipeline.get_directions(
    train_data,
    rep_token=rep_token, 
    hidden_layers=hidden_layers, 
    n_difference=n_difference, 
    train_labels=train_labels, 
    direction_method=direction_method,
    direction_finder_kwargs=direction_finder_kwargs,
    batch_size=batch_size,
    
    max_length=2048,
    padding="longest",
)

batch_size=128
results_val = {layer: {} for layer in hidden_layers}

H_tests = rep_pipeline(
                    test_data,
                    rep_token=rep_token, 
                    hidden_layers=hidden_layers, 
                    rep_reader=rep_reader,
                    batch_size=batch_size,
                    max_length=2048,
                    padding="longest")

n_choices=2
for layer in hidden_layers:
    H_test = [H[layer] for H in H_tests] 
    H_test = [H_test[i:i+n_choices] for i in range(0, len(H_test), n_choices)]

    sign = rep_reader.direction_signs[layer]
    eval_func = min if sign == -1 else max
    
    cors = np.mean([eval_func(H) == H[0] for H in H_test])
    results_val[layer] = cors
    
    print(f"{layer}: {cors}")
    print("=====")


-1: 0.788
=====
-2: 0.626
=====
-3: 0.794
=====
-4: 0.788
=====
-5: 0.776
=====
-6: 0.766
=====
-7: 0.744
=====
-8: 0.748
=====
-9: 0.748
=====
-10: 0.752
=====
-11: 0.762
=====
-12: 0.768
=====
-13: 0.756
=====
-14: 0.772
=====
-15: 0.78
=====
-16: 0.766
=====
-17: 0.678
=====
-18: 0.638
=====
-19: 0.62
=====
-20: 0.616
=====
-21: 0.552
=====
-22: 0.444
=====
-23: 0.63
=====
-24: 0.474
=====
-25: 0.61
=====
-26: 0.588
=====
-27: 0.564
=====
-28: 0.44
=====
-29: 0.56
=====
-30: 0.622
=====
-31: 0.594
=====
-32: 0.41
=====
-33: 0.588
=====
-34: 0.416
=====
-35: 0.576
=====
-36: 0.49
=====
-37: 0.472
=====
-38: 0.486
=====
-39: 0.558
=====
-40: 0.566
=====
-41: 0.564
=====
-42: 0.564
=====
-43: 0.546
=====
-44: 0.49
=====
-45: 0.494
=====
-46: 0.484
=====
-47: 0.472
=====


## QNLI

In [19]:
template_str = 'Consider the {concept} of the answer to the question:\nQuestion: {question}\nAnswer: {sentence}'
rep_token = 3

In [20]:
def sample(ds, eval_set=False):
    pairs = []
    for e in ds:
        concepts = ['plausibility', 'implausibility']
        sentence  = e['sentence']
        question = e['question']
                
        if e['label'] == 1 and eval_set: # flip the concepts in test for easy eval
            concepts = concepts[::-1]
        pair = [template_str.format(concept=c, question=question, sentence=sentence) for c in concepts]
        pairs.extend(pair)
    return pairs

max_train_samples = 128

dataset = load_dataset("glue", "qnli")
train_dataset, test_dataset = dataset['train'], dataset['validation']
train_dataset =  train_dataset.shuffle(seed=0)
train_dataset = train_dataset.select(range(max_train_samples))

train_data, train_labels = sample(train_dataset), train_dataset['label']
train_labels = [[1, 0] if label == 0 else [0, 1] for label in train_labels]

test_data = sample(test_dataset, eval_set=True)


Found cached dataset glue (/data/long_phan/.cache/huggingface/datasets/glue/qnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /data/long_phan/.cache/huggingface/datasets/glue/qnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-b7167382490e7670.arrow


In [21]:
rep_reader = rep_pipeline.get_directions(
    train_data,
    rep_token=rep_token, 
    hidden_layers=hidden_layers, 
    n_difference=n_difference, 
    train_labels=train_labels, 
    direction_method=direction_method,
    direction_finder_kwargs=direction_finder_kwargs,
    batch_size=8,
    
    max_length=2048,
    padding="longest",
)

In [22]:
batch_size=128
results_val = {layer: {} for layer in hidden_layers}

H_tests = rep_pipeline(test_data, 
                    rep_token=rep_token, 
                    hidden_layers=hidden_layers, 
                    rep_reader=rep_reader,
                    batch_size=batch_size,
                    max_length=2048,
                    padding="longest")


In [23]:
n_choices=2
for layer in hidden_layers:
    H_test = [H[layer] for H in H_tests] 
    H_test = [H_test[i:i+n_choices] for i in range(0, len(H_test), n_choices)]

    sign = rep_reader.direction_signs[layer]
    eval_func = min if sign == -1 else max
    
    cors = np.mean([eval_func(H) == H[0] for H in H_test])
    results_val[layer] = cors
    
    print(f"{layer}: {cors}")
    print("=====")

-1: 0.6976020501555922
=====
-2: 0.4946000366099213
=====
-3: 0.5773384587223137
=====
-4: 0.6686802123375435
=====
-5: 0.6132161815852096
=====
-6: 0.5464030752333883
=====
-7: 0.5374336445176643
=====
-8: 0.5405454878272011
=====
-9: 0.5251693208859601
=====
-10: 0.5174812374153396
=====
-11: 0.5370675453047776
=====
-12: 0.5291964122277137
=====
-13: 0.5403624382207578
=====
-14: 0.5648910854841662
=====
-15: 0.5606809445359693
=====
-16: 0.556836902800659
=====
-17: 0.528281164195497
=====
-18: 0.5207761303313198
=====
-19: 0.49679663188724144
=====
-20: 0.48288486179754714
=====
-21: 0.47995606809445357
=====
-22: 0.4885593995972909
=====
-23: 0.48929159802306427
=====
-24: 0.4995423759838916
=====
-25: 0.5143693941058026
=====
-26: 0.5022881200805418
=====
-27: 0.5088779059125023
=====
-28: 0.5081457074867289
=====
-29: 0.4993593263774483
=====
-30: 0.5015559216547685
=====
-31: 0.5046677649643053
=====
-32: 0.48416620904265056
=====
-33: 0.4925864909390445
=====
-34: 0.491488193

## PIQA

In [28]:
template_str = "Consider the amount of plausible reasoning in the scenario:\n{goal} {sol}"
rep_token = 5

In [29]:
def samples(ds, test_set=False):
    pairs = []
    for e in ds:
        solutions  = [e['sol1'], e['sol2']]
        goal = e['goal']
                
        if e['label'] == 1 and test_set: # flip the true sol in test for easy eval
            solutions = solutions[::-1]

        pair = [template_str.format(goal=goal, sol=sol) for sol in solutions]
        pairs.extend(pair)
    return pairs

max_train_samples = 64

dataset = load_dataset("piqa")
train_dataset, test_dataset = dataset['train'], dataset['validation']
train_dataset =  train_dataset.shuffle(seed=0)
train_dataset = train_dataset.select(range(max_train_samples))

train_data, train_labels = samples(train_dataset), train_dataset['label']
train_labels = [[1, 0] if label == 0 else [0, 1] for label in train_labels]

test_data = samples(test_dataset, test_set=True)

Found cached dataset piqa (/data/long_phan/.cache/huggingface/datasets/piqa/plain_text/1.1.0/6c611c1a9bf220943c4174e117d3b660859665baf1d43156230116185312d011)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /data/long_phan/.cache/huggingface/datasets/piqa/plain_text/1.1.0/6c611c1a9bf220943c4174e117d3b660859665baf1d43156230116185312d011/cache-ed09e117332c4dbb.arrow


In [30]:
batch_size=128
rep_reader = rep_pipeline.get_directions(
    train_data,
    rep_token=rep_token, 
    hidden_layers=hidden_layers, 
    n_difference=n_difference, 
    train_labels=train_labels, 
    direction_method=direction_method,
    direction_finder_kwargs=direction_finder_kwargs,
    batch_size=batch_size,
    
    max_length=2048,
    padding="longest",
)

batch_size=128
results_val = {layer: {} for layer in hidden_layers}

H_tests = rep_pipeline(
                    test_data,
                    rep_token=rep_token, 
                    hidden_layers=hidden_layers, 
                    rep_reader=rep_reader,
                    batch_size=batch_size,
                    max_length=2048,
                    padding="longest")

n_choices=2
for layer in hidden_layers:
    H_test = [H[layer] for H in H_tests] 
    H_test = [H_test[i:i+n_choices] for i in range(0, len(H_test), n_choices)]

    sign = rep_reader.direction_signs[layer]
    eval_func = min if sign == -1 else max
    
    cors = np.mean([eval_func(H) == H[0] for H in H_test])
    results_val[layer] = cors
    
    print(f"{layer}: {cors}")
    print("=====")


-1: 0.6996735582154516
=====
-2: 0.5712731229597389
=====
-3: 0.6964091403699674
=====
-4: 0.6991294885745375
=====
-5: 0.6936887921653971
=====
-6: 0.6866158868335147
=====
-7: 0.6877040261153428
=====
-8: 0.6942328618063112
=====
-9: 0.6893362350380848
=====
-10: 0.6877040261153428
=====
-11: 0.6871599564744287
=====
-12: 0.6964091403699674
=====
-13: 0.6947769314472253
=====
-14: 0.6708378672470077
=====
-15: 0.6942328618063112
=====
-16: 0.6822633297062024
=====
-17: 0.6343852013057671
=====
-18: 0.6066376496191512
=====
-19: 0.5990206746463548
=====
-20: 0.5794341675734495
=====
-21: 0.5413492927094669
=====
-22: 0.47551686615886835
=====
-23: 0.4591947769314472
=====
-24: 0.4733405875952122
=====
-25: 0.4820457018498368
=====
-26: 0.4619151251360174
=====
-27: 0.49347116430903154
=====
-28: 0.4923830250272035
=====
-29: 0.4836779107725789
=====
-30: 0.48639825897714906
=====
-31: 0.514145810663765
=====
-32: 0.5457018498367792
=====
-33: 0.45212187159956474
=====
-34: 0.460282916

## COPA

In [51]:
template_str = "Consider the amount of plausible reasoning in the scenario:\n{premise} {hook} {alternative}"
rep_token = 5

In [52]:
copa_hook = {
    'cause': 'because',
    'effect': 'then',
}
def samples(ds, test_set=False):
    pairs = []
    for e in ds:
        alternatives  = [e['choice1'], e['choice2']]
        premise = e['premise']
        hook = copa_hook[e['question']]
                
        if e['label'] == 1 and test_set: # flip the true alt in test for easy eval
            alternatives = alternatives[::-1]
        pair = [template_str.format(premise=premise, hook=hook, alternative=alternative) for alternative in alternatives]
        pairs.extend(pair)
    return pairs

max_train_samples = 64
dataset = load_dataset("pkavumba/balanced-copa")
# Dataset is from https://github.com/Balanced-COPA/Balanced-COPA
# Need to filter out mirrored samples to get original COPA
dataset['train'] =  dataset['train'].filter(lambda e: not e['mirrored'])

train_dataset, test_dataset= dataset['train'], dataset['test']
train_dataset =  train_dataset.shuffle(seed=0)
train_dataset = train_dataset.select(range(max_train_samples))

train_data, train_labels = samples(train_dataset), train_dataset['label']
train_labels = [[1, 0] if label == 0 else [0, 1] for label in train_labels]

test_data = samples(test_dataset, test_set=True)

Found cached dataset csv (/data/long_phan/.cache/huggingface/datasets/pkavumba___csv/pkavumba--balanced-copa-81fb5dd3c6eeb4d7/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at /data/long_phan/.cache/huggingface/datasets/pkavumba___csv/pkavumba--balanced-copa-81fb5dd3c6eeb4d7/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-f6b83c6a27e35715.arrow
Loading cached shuffled indices for dataset at /data/long_phan/.cache/huggingface/datasets/pkavumba___csv/pkavumba--balanced-copa-81fb5dd3c6eeb4d7/0.0.0/eea64c71ca8b46dd3f537ed218fc9bf495d5707789152eb2764f5c78fa66d59d/cache-b4e8b29247fbf014.arrow


In [53]:
batch_size=128
rep_reader = rep_pipeline.get_directions(
    train_data,
    rep_token=rep_token, 
    hidden_layers=hidden_layers, 
    n_difference=n_difference, 
    train_labels=train_labels, 
    direction_method=direction_method,
    direction_finder_kwargs=direction_finder_kwargs,
    batch_size=batch_size,
    
    max_length=2048,
    padding="longest",
)

batch_size=128
results_val = {layer: {} for layer in hidden_layers}

H_tests = rep_pipeline(
                    test_data,
                    rep_token=rep_token, 
                    hidden_layers=hidden_layers, 
                    rep_reader=rep_reader,
                    batch_size=batch_size,
                    max_length=2048,
                    padding="longest")

n_choices=2
for layer in hidden_layers:
    H_test = [H[layer] for H in H_tests] 
    H_test = [H_test[i:i+n_choices] for i in range(0, len(H_test), n_choices)]

    sign = rep_reader.direction_signs[layer]
    eval_func = min if sign == -1 else max
    
    cors = np.mean([eval_func(H) == H[0] for H in H_test])
    results_val[layer] = cors
    
    print(f"{layer}: {cors}")
    print("=====")


-1: 0.906
=====
-2: 0.722
=====
-3: 0.896
=====
-4: 0.894
=====
-5: 0.892
=====
-6: 0.886
=====
-7: 0.868
=====
-8: 0.872
=====
-9: 0.866
=====
-10: 0.872
=====
-11: 0.87
=====
-12: 0.862
=====
-13: 0.868
=====
-14: 0.856
=====
-15: 0.86
=====
-16: 0.85
=====
-17: 0.786
=====
-18: 0.746
=====
-19: 0.748
=====
-20: 0.732
=====
-21: 0.732
=====
-22: 0.724
=====
-23: 0.736
=====
-24: 0.668
=====
-25: 0.682
=====
-26: 0.702
=====
-27: 0.702
=====
-28: 0.682
=====
-29: 0.644
=====
-30: 0.614
=====
-31: 0.562
=====
-32: 0.544
=====
-33: 0.546
=====
-34: 0.532
=====
-35: 0.514
=====
-36: 0.512
=====
-37: 0.518
=====
-38: 0.544
=====
-39: 0.59
=====
-40: 0.568
=====
-41: 0.57
=====
-42: 0.564
=====
-43: 0.538
=====
-44: 0.54
=====
-45: 0.546
=====
-46: 0.566
=====
-47: 0.542
=====


## Story Cloze

In [47]:
template_str = "Consider the plausibility of the scenario:\n{scenario}"
rep_token = 3

In [48]:
def samples(ds, test_set=False):
    pairs = []
    for e in ds:
        concepts = ['plausibility']
        sentences = ' '.join([e[f'input_sentence_{i}'] for i in range(1, 5)])
        quiz_sentences = [e['sentence_quiz1'], e['sentence_quiz2']]
                
        if e['answer_right_ending'] == 2 and test_set: # flip to have true sentence on first in test set for easy evaluation 
            quiz_sentences = quiz_sentences[::-1]

        pair = [template_str.format(scenario=' '.join([sentences, s])) for s in quiz_sentences]
        pairs.extend(pair)
    return pairs

max_train_samples = 64

data_dir="" # See https://huggingface.co/datasets/story_cloze to download data manually
data_dir='/data/private_models/cais_models/misc/'
dataset = load_dataset("story_cloze", '2016', data_dir=data_dir)

train_dataset, test_dataset = dataset['validation'], dataset['test']
train_dataset =  train_dataset.shuffle(seed=0)
train_dataset = train_dataset.select(range(max_train_samples))

train_data = samples(train_dataset)
train_labels = [[1, 0] if label == 1 else [0, 1] for label in train_dataset['answer_right_ending']]

test_data = samples(test_dataset, test_set=True)

Found cached dataset story_cloze (/data/long_phan/.cache/huggingface/datasets/story_cloze/2016-4349b0206f129d71/0.0.0/45cead0538c3deb72d731a7990e60835c2c9c5d5d5b1e95a7dd47ccf593671e4)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /data/long_phan/.cache/huggingface/datasets/story_cloze/2016-4349b0206f129d71/0.0.0/45cead0538c3deb72d731a7990e60835c2c9c5d5d5b1e95a7dd47ccf593671e4/cache-cd5f729fe91c7970.arrow


In [49]:
batch_size=128
rep_reader = rep_pipeline.get_directions(
    train_data,
    rep_token=rep_token, 
    hidden_layers=hidden_layers, 
    n_difference=n_difference, 
    train_labels=train_labels, 
    direction_method=direction_method,
    direction_finder_kwargs=direction_finder_kwargs,
    batch_size=batch_size,
    
    max_length=2048,
    padding="longest",
)

batch_size=128
results_val = {layer: {} for layer in hidden_layers}

H_tests = rep_pipeline(
                    test_data,
                    rep_token=rep_token, 
                    hidden_layers=hidden_layers, 
                    rep_reader=rep_reader,
                    batch_size=batch_size,
                    max_length=2048,
                    padding="longest")

n_choices=2
for layer in hidden_layers:
    H_test = [H[layer] for H in H_tests] 
    H_test = [H_test[i:i+n_choices] for i in range(0, len(H_test), n_choices)]

    sign = rep_reader.direction_signs[layer]
    eval_func = min if sign == -1 else max
    
    cors = np.mean([eval_func(H) == H[0] for H in H_test])
    results_val[layer] = cors
    
    print(f"{layer}: {cors}")
    print("=====")


-1: 0.9716729021913415
=====
-2: 0.9294494922501336
=====
-3: 0.9652592196686264
=====
-4: 0.9700694815606627
=====
-5: 0.9679315873864244
=====
-6: 0.965793693212186
=====
-7: 0.9529663281667557
=====
-8: 0.9497594869053981
=====
-9: 0.9492250133618386
=====
-10: 0.9428113308391235
=====
-11: 0.943345804382683
=====
-12: 0.9363976483164084
=====
-13: 0.9390700160342063
=====
-14: 0.928915018706574
=====
-15: 0.9208979155531801
=====
-16: 0.8727952966328166
=====
-17: 0.8551576696953501
=====
-18: 0.8268305718866916
=====
-19: 0.8001068947087119
=====
-20: 0.8295029396044896
=====
-21: 0.8134687332977018
=====
-22: 0.8348476750400855
=====
-23: 0.8107963655799038
=====
-24: 0.7771245323356494
=====
-25: 0.7648316408337787
=====
-26: 0.772314270443613
=====
-27: 0.7509353287012293
=====
-28: 0.7482629609834314
=====
-29: 0.7274184927846071
=====
-30: 0.7541421699625869
=====
-31: 0.7509353287012293
=====
-32: 0.7338321753073223
=====
-33: 0.7274184927846071
=====
-34: 0.6932121859967931