In [8]:
from hydra import compose, initialize

with initialize(
    version_base=None,
    config_path="../berrrt/conf",
):
    cfg = compose(
        config_name="config",
        overrides=[
            "modules=berrrt_early_exit",
            "modules_name=berrrt_early_exit",
            "mode=full"
        ],
    )
cfg
del cfg.modules.additional_prefix

In [5]:
%cd ../

/root/berrrt


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [9]:
from berrrt.modules.base import ModulesFactory
from berrrt.torch_utils import get_default_device, set_seed

set_seed(42)
model = ModulesFactory(cfg.modules_name).create_model(**cfg.modules)
device = get_default_device()
model = model.to(device)



In [10]:
from safetensors.torch import load_model

load_model(model, "model_output/berrrt_early_exit-ee_softmax-layer_range_0_11-mrpc-30epochs-LR2em05-adamw_torch-3o7t38fu/tmp-checkpoint-500/model.safetensors")
model

BERRRTEarlyExitModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementw

In [12]:
from berrrt.dataset import BERRRTDataset

dataset = BERRRTDataset(cfg)
dataset



Map: 100%|██████████| 3668/3668 [00:00<00:00, 6614.54 examples/s]
Map: 100%|██████████| 408/408 [00:00<00:00, 5748.96 examples/s]
Map: 100%|██████████| 1725/1725 [00:00<00:00, 6871.79 examples/s]


<berrrt.dataset.BERRRTDataset at 0x7fc51de666e0>

In [17]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir=f"./{cfg.run_name.run_name}",
    run_name=cfg.run_name.run_name,
    **cfg.train
)

In [18]:
from berrrt.utils import (
    compute_metrics,
    compute_metrics_multi,
)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset.train_encoded,
    eval_dataset=dataset.eval_encoded,
    compute_metrics=compute_metrics
    if cfg.dataset.num_classes == 2
    else compute_metrics_multi,
)

In [19]:
test_results = trainer.predict(test_dataset=dataset.test_encoded)
test_results

PredictionOutput(predictions=(array([[-5.886644 ,  5.678859 ],
       [-5.749083 ,  5.659796 ],
       [-4.455859 ,  4.3362823],
       ...,
       [-3.5405414,  3.3606794],
       [-5.279708 ,  5.1101136],
       [-4.3429074,  4.223866 ]], dtype=float32), [array([[-0.3454004 ,  0.30549675],
       [-0.3454004 ,  0.30549675],
       [-0.3454004 ,  0.30549675],
       ...,
       [-0.3454004 ,  0.30549675],
       [-0.3454004 ,  0.30549675],
       [-0.3454004 ,  0.30549675]], dtype=float32), array([[-0.4152403 ,  0.4761602 ],
       [-0.41384202,  0.51524574],
       [-0.42469728,  0.483114  ],
       ...,
       [-0.37960333,  0.45543227],
       [-0.42341247,  0.5256171 ],
       [-0.445907  ,  0.5334038 ]], dtype=float32), array([[-0.26779693,  0.27924356],
       [-0.7137309 ,  0.56444055],
       [-0.4179542 ,  0.15583484],
       ...,
       [-0.22090133, -0.07902987],
       [-1.0041832 ,  1.1339313 ],
       [-1.1885852 ,  1.1498477 ]], dtype=float32), array([[-0.541353  ,  0.4

In [23]:
test_results.metrics

{'test_loss': 12.51516056060791,
 'test_accuracy': 0.7884057971014493,
 'test_f1': 0.8494845360824742,
 'test_precision': 0.8059467918622848,
 'test_recall': 0.8979947689625108,
 'test_final_logits_last_sample': [-4.342907428741455, 4.223865985870361],
 'test_final_labels': 1,
 'test_all_accs': [0.664927536231884,
  0.664927536231884,
  0.696231884057971,
  0.6944927536231884,
  0.7101449275362319,
  0.735072463768116,
  0.7669565217391304,
  0.7814492753623189,
  0.7855072463768116,
  0.784927536231884,
  0.7860869565217391,
  0.7866666666666666],
 'test_all_logits_table': <wandb.data_types.Table at 0x7fc46c783e50>,
 'test_runtime': 6.7027,
 'test_samples_per_second': 257.36,
 'test_steps_per_second': 8.056}

In [39]:
import torch
import random

def summarize_predictions(logits, labels, layer_logits, correct_samples=3, incorrect_samples=2, layers_taken=12):
    # Ensure inputs are torch tensors
    logits = torch.tensor(logits)
    labels = torch.tensor(labels)
    layer_logits = [torch.tensor(layer) for layer in layer_logits]

    # Calculate the predicted class from main logits
    predicted_classes = torch.argmax(logits, dim=1)

    # Initialize lists for the summaries
    full_summary = []
    correct_summary = []
    incorrect_summary = []

    # Generate a summary for each data point
    for i in range(len(labels)):
        # Main summary for each data point
        summary = {
            'Index': i,
            'Main Predicted Class': predicted_classes[i].item(),
            'Main Logits': logits[i].tolist(),
            'Actual Label': labels[i].item(),
            'Layer Logits': {}
        }

        # Adding layer logits and predicted classes for each layer up to layers_taken
        for layer_index in range(min(layers_taken, len(layer_logits))):
            layer_pred_class = torch.argmax(layer_logits[layer_index][i], dim=0).item()
            summary['Layer Logits'][f'Layer {layer_index + 1}'] = {
                'Predicted Class': layer_pred_class,
                'Logits': layer_logits[layer_index][i].tolist()
            }

        full_summary.append(summary)

        # Separate correct and incorrect predictions based on main logits
        if predicted_classes[i] == labels[i]:
            correct_summary.append(summary)
        else:
            incorrect_summary.append(summary)

    # Randomly select the requested number of correct and incorrect samples
    random_correct_samples = random.sample(correct_summary, min(correct_samples, len(correct_summary)))
    random_incorrect_samples = random.sample(incorrect_summary, min(incorrect_samples, len(incorrect_summary)))

    return full_summary, random_correct_samples, random_incorrect_samples

In [45]:
test_results.predictions[1][0]

array([[-0.3454004 ,  0.30549675],
       [-0.3454004 ,  0.30549675],
       [-0.3454004 ,  0.30549675],
       ...,
       [-0.3454004 ,  0.30549675],
       [-0.3454004 ,  0.30549675],
       [-0.3454004 ,  0.30549675]], dtype=float32)

In [71]:
# ensemble
for i in range(12):
    test_results_tensor = torch.tensor(test_results.predictions[1])[i:, :, :]

    all_classes = test_results_tensor.argmax(-1)

    all_classes_voted = torch.mode(all_classes, 0).values
    result_acc = (all_classes_voted == torch.tensor(test_results.label_ids)).sum() / all_classes_voted.shape[0]
    print(f"{i}: {result_acc}, Greater than MoE : {result_acc > test_results.metrics['test_accuracy']}")

0: 0.7744927406311035, Greater than MoE : False
1: 0.7744927406311035, Greater than MoE : False
2: 0.782608687877655, Greater than MoE : False
3: 0.7831884026527405, Greater than MoE : False
4: 0.7837681174278259, Greater than MoE : False
5: 0.7866666913032532, Greater than MoE : False
6: 0.7866666913032532, Greater than MoE : False
7: 0.7837681174278259, Greater than MoE : False
8: 0.7849275469779968, Greater than MoE : False
9: 0.7855072617530823, Greater than MoE : False
10: 0.7872464060783386, Greater than MoE : False
11: 0.7866666913032532, Greater than MoE : False


In [None]:
summary, correct_samples, incorrect_samples = summarize_predictions(test_results.predictions[0], test_results.label_ids, test_results.predictions[1])

In [42]:
from pprint import pprint
pprint(correct_samples)

[{'Actual Label': 1,
  'Index': 571,
  'Layer Logits': {'Layer 1': {'Logits': [-0.3454003930091858,
                                          0.3054967522621155],
                               'Predicted Class': 1},
                   'Layer 10': {'Logits': [-4.8850836753845215,
                                           4.675403118133545],
                                'Predicted Class': 1},
                   'Layer 11': {'Logits': [-4.3408122062683105,
                                           4.275904178619385],
                                'Predicted Class': 1},
                   'Layer 12': {'Logits': [-4.808088302612305,
                                           4.175318241119385],
                                'Predicted Class': 1},
                   'Layer 2': {'Logits': [-0.3772399127483368,
                                          0.4976996183395386],
                               'Predicted Class': 1},
                   'Layer 3': {'Logits': [-0.7099351882934