## Loading Packages

In [1]:
!pip install -q transformers
!pip install -q torchinfo
!pip install -U -q datasets fsspec huggingface_hub
!pip install -q evaluate

In [2]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score

import transformers
import evaluate

from datasets import load_dataset, concatenate_datasets
from torchinfo import summary

from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification, AutoConfig
from transformers import TrainingArguments, Trainer, pipeline

from typing import Dict, Any
import copy
from pathlib import Path
import json
import os

## Dataset preparation

In [3]:
imdb_dataset = load_dataset("imdb")

train_tiny = imdb_dataset['train'].shuffle().select(range(100))
train_small = imdb_dataset['train'].shuffle().select(range(1000))
train_medium = imdb_dataset['train'].shuffle().select(range(2500))
train_large = imdb_dataset['train'].shuffle().select(range(10000))


test_tiny = imdb_dataset['test'].shuffle().select(range(100))
test_small = imdb_dataset['test'].shuffle().select(range(1000))
test_medium = imdb_dataset['test'].shuffle().select(range(5000))
test_large = imdb_dataset['test'].shuffle().select(range(10000))

README.md: 0.00B [00:00, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

unsupervised-00000-of-00001.parquet:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

In [4]:
MAX_SEQUENCE_LENGTH = 100

def preprocess_imdb(data, tokenizer):
    review_text = data['text']

    encoded = tokenizer.batch_encode_plus(
            review_text,
            max_length=MAX_SEQUENCE_LENGTH,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=True,
            return_tensors="pt"
        )

    return encoded

In [5]:
rte_dataset = load_dataset("glue", "rte")

README.md: 0.00B [00:00, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/584k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/69.0k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/621k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2490 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/277 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3000 [00:00<?, ? examples/s]

In [6]:
qqp_dataset = load_dataset("glue", "qqp")

qqp_train = qqp_dataset["train"].shuffle().select(range(2500))
qqp_test = qqp_dataset["validation"].shuffle().select(range(4000))

train-00000-of-00001.parquet:   0%|          | 0.00/33.6M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/3.73M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/36.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/363846 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/40430 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/390965 [00:00<?, ? examples/s]

In [7]:
qnli_dataset = load_dataset("glue", "qnli")
qnli_train = qnli_dataset["train"].shuffle().select(range(2500))
qnli_test = qnli_dataset["validation"].shuffle().select(range(4000))

train-00000-of-00001.parquet:   0%|          | 0.00/17.5M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/872k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/877k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/104743 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/5463 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/5463 [00:00<?, ? examples/s]

In [8]:
mnli_dataset = load_dataset("glue", "mnli")
validation_all = concatenate_datasets([
    mnli_dataset["validation_matched"],
    mnli_dataset["validation_mismatched"]
])

mnli_train = mnli_dataset["train"].shuffle().select(range(2500))
mnli_test = validation_all.shuffle().select(range(4000))

train-00000-of-00001.parquet:   0%|          | 0.00/52.2M [00:00<?, ?B/s]

(…)alidation_matched-00000-of-00001.parquet:   0%|          | 0.00/1.21M [00:00<?, ?B/s]

(…)dation_mismatched-00000-of-00001.parquet:   0%|          | 0.00/1.25M [00:00<?, ?B/s]

test_matched-00000-of-00001.parquet:   0%|          | 0.00/1.22M [00:00<?, ?B/s]

test_mismatched-00000-of-00001.parquet:   0%|          | 0.00/1.26M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/392702 [00:00<?, ? examples/s]

Generating validation_matched split:   0%|          | 0/9815 [00:00<?, ? examples/s]

Generating validation_mismatched split:   0%|          | 0/9832 [00:00<?, ? examples/s]

Generating test_matched split:   0%|          | 0/9796 [00:00<?, ? examples/s]

Generating test_mismatched split:   0%|          | 0/9847 [00:00<?, ? examples/s]

In [9]:
wnli_dataset = load_dataset("glue", "wnli")

train-00000-of-00001.parquet:   0%|          | 0.00/38.8k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/11.1k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/13.6k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/635 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/71 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/146 [00:00<?, ? examples/s]

In [10]:
def preprocess_rte(examples, tokenizer):
    return tokenizer(
        text=examples["sentence1"],
        text_pair=examples["sentence2"],
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors='pt' #important, don't get cuda error.
    )

def preprocess_qqp(examples, tokenizer):
    return tokenizer(
        text=examples["question1"],
        text_pair=examples["question2"],
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors='pt'
    )

def preprocess_qnli(examples, tokenizer):
    return tokenizer(
        text=examples["question"],
        text_pair=examples["sentence"],
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors='pt'
    )

def preprocess_mnli(examples, tokenizer):
    return tokenizer(
        text=examples["premise"],
        text_pair=examples["hypothesis"],
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors='pt'
    )

def preprocess_wnli(examples, tokenizer):
    return tokenizer(
        text=examples["sentence1"],
        text_pair=examples["sentence2"],
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors='pt'
    )


## Loading Functions

In [11]:
metric = evaluate.load('accuracy')

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

Downloading builder script: 0.00B [00:00, ?B/s]

In [12]:
def fine_tune_classification_model(model, tokenizer, train_data, dev_data,
                                   num_epochs=1, output_dir="temp",
                                   batch_size=16,
                                   data_preprocessor=preprocess_imdb,
                                   save_path=None, resume_from=None):
    """
    Fine tune model on given dataset.
    """
    # Preprocess data
    train_data_processed = train_data.map(data_preprocessor, batched=True,
                                         fn_kwargs={'tokenizer': tokenizer})
    dev_data_processed = dev_data.map(data_preprocessor, batched=True,
                                     fn_kwargs={'tokenizer': tokenizer})

    # Training configuration
    training_args = TrainingArguments(
        output_dir=output_dir,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=num_epochs,
        eval_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        save_on_each_node=True,
        report_to='none',
        load_best_model_at_end=False,
    )

    # Create trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_data_processed,
        eval_dataset=dev_data_processed,
        compute_metrics=compute_metrics
    )

    # Train with optional resume
    if resume_from:
        trainer.train(resume_from_checkpoint=resume_from)
    else:
        trainer.train()

    # Save final model state
    if save_path:
        # Save FULL state instead of just model
        trainer.save_model(save_path)

    return trainer

In [13]:
def remove_svs_decile(linear_layer, decile_to_remove):
    """
    Remove a specific decile of singular values from a linear layer.
    :param linear_layer: nn.Linear module
    :param decile_to_remove: int in [0, 9], where 0 = top 10%, 9 = bottom 10%
    """
    with torch.no_grad():
        W = linear_layer.weight.data
        U, s, Vh = torch.linalg.svd(W.float(), full_matrices=False)

        n = len(s)
        decile_size = n // 10
        start = decile_to_remove * decile_size
        end = (decile_to_remove + 1) * decile_size if decile_to_remove < 9 else n

        s_trunc = s.clone()
        s_trunc[start:end] = 0

        W_trunc = U @ torch.diag(s_trunc) @ Vh
        linear_layer.weight.data = W_trunc.to(W.dtype)



def remove_svs_from_bert_decile(
                                model,
                                decile_to_remove,
                                layer_type="all",
                                layers="all",  #specify layers to target.
                            ):
    """
    Remove a decile of SVs from specific layers/matrices.

    :param layers: "all", list of indices (e.g., [0,1,2]), or "late" (last 4)
    :param layer_type: "all", "query", "key", "value", "output", "ffn"
    """
    if layers == "late":
        layers = list(range(len(model.bert.encoder.layer) - 4, len(model.bert.encoder.layer)))
    elif layers == "all":
        layers = range(len(model.bert.encoder.layer))

    for layer_idx in layers:
        layer = model.bert.encoder.layer[layer_idx]
        if layer_type in ["all", "query"]:
            remove_svs_decile(layer.attention.self.query, decile_to_remove)
        if layer_type in ["all", "key"]:
            remove_svs_decile(layer.attention.self.key, decile_to_remove)
        if layer_type in ["all", "value"]:
            remove_svs_decile(layer.attention.self.value, decile_to_remove)
        if layer_type in ["all", "output"]:
            remove_svs_decile(layer.attention.output.dense, decile_to_remove)
        if layer_type in ["all", "ffn"]:
            remove_svs_decile(layer.intermediate.dense, decile_to_remove)
            remove_svs_decile(layer.output.dense, decile_to_remove)

def get_minimal_eval_args(model):
    """Create minimal TrainingArguments for evaluation"""
    return TrainingArguments(
        output_dir="./temp_eval",  # Temporary directory
        per_device_eval_batch_size=16,
        report_to='none',  # Disable logging
        no_cuda=False if torch.cuda.is_available() else True,
        fp16=model.config.torch_dtype == torch.float16,
    )

def evaluate_decile_removal(
        control_model,
        tokenizer,
        compute_metrics,
        test_data,
        layer_type="all",
        layers="all"
    ):
    """
    Evaluate validation accuracy after removing each decile of SVs.
    Returns a list of accuracies, one per decile.
    """
    trainer_args = TrainingArguments(
                                      output_dir="./tmp_eval",      # required, but can be temp
                                      per_device_eval_batch_size=32,  # or whatever you used
                                      do_train=False,
                                      do_eval=True,
                                      logging_strategy="no",  # suppress extra logs
                                      save_strategy="no",
                                      report_to="none"
                                      )
    accuracies = []

    for decile in range(10):
        print(f"\n=== Evaluating decile {decile} removal ===")
        # Clone model in memory instead of reloading from disk
        treatment_model = copy.deepcopy(control_model)

        # Apply decile-specific SV removal
        remove_svs_from_bert_decile(
            treatment_model,
            decile_to_remove=decile,
            layer_type=layer_type,
            layers=layers
        )

        treatment_trainer = Trainer(
            model=treatment_model,
            args=trainer_args,
            compute_metrics=compute_metrics,
            eval_dataset=test_data
        )
        results = treatment_trainer.evaluate()
        accuracies.append(results["eval_accuracy"])

    return accuracies

## Training Models

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#loading base stuff.
bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
bert_model = AutoModel.from_pretrained('bert-base-cased')

#assign model name to variable for convenience
model_checkpoint_name = "bert-base-cased"

#initialize tokenizer using model checkpoint name
bert_tokenizer = AutoTokenizer.from_pretrained(model_checkpoint_name)

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

### IMDb (easy)

In [None]:
test_processed = test_large.map(preprocess_imdb, batched=True, fn_kwargs={'tokenizer': bert_tokenizer})

Initial training (1 epoch)

In [None]:
model_imdb = AutoModelForSequenceClassification.from_pretrained('bert-base-cased')
trainer_imdb = fine_tune_classification_model(
    model_imdb,
    bert_tokenizer,
    train_large,
    test_large,
    num_epochs=1,
    output_dir="temp_imdb_1",
    save_path="drive/MyDrive/Summer2025/w266/Project/models/imdb_epoch1"
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.4448,0.345577,0.8481


Continue training (1 -> 2 epochs)

In [None]:
#find latest checkpoint
checkpoint_dir = max(
    [f for f in Path("temp_imdb_1").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_epoch1 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/imdb_epoch1"
)

# Train additional epoch
trainer2 = fine_tune_classification_model(
    model_epoch1,
    bert_tokenizer,
    train_large,
    test_large,
    num_epochs=2, #resume
    output_dir="temp_imdb_2",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/imdb_epoch2",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
2,0.2666,0.391251,0.8542


Continue training (2 --> 3 epochs)

In [None]:
#find latest checkpoint
checkpoint_dir = max(
    [f for f in Path("temp_imdb_2").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_epoch2 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/imdb_epoch2"
)

# Train additional epoch
trainer3 = fine_tune_classification_model(
    model_epoch2,
    bert_tokenizer,
    train_large,
    test_large,
    num_epochs=3,
    output_dir="temp_imdb_3",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/imdb_epoch3",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
3,0.1707,0.497933,0.859


In [None]:
#find latest checkpoint
checkpoint_dir = max(
    [f for f in Path("temp_imdb_3").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_epoch3 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/imdb_epoch3"
)

trainer10 = fine_tune_classification_model(
    model_epoch3,
    bert_tokenizer,
    train_large,
    test_large,
    num_epochs=10, # already trained 3 epochs --> 3 + 7 = 10 total.
    output_dir="temp_imdb_10",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/imdb_epoch10",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
4,0.2036,0.457572,0.8402
5,0.1211,0.679779,0.8347
6,0.0886,0.815664,0.8427
7,0.0518,1.015398,0.8458
8,0.0178,1.155892,0.8482
9,0.0053,1.273104,0.8475
10,0.0026,1.308429,0.8483


### RTE (medium)

In [None]:
model_checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Initial training (1 epoch)

In [None]:
model_rte = AutoModelForSequenceClassification.from_pretrained('bert-base-cased')
trainer_rte = fine_tune_classification_model(
    model_rte,
    bert_tokenizer,
    train_data=rte_dataset["train"],
    dev_data=rte_dataset["validation"],
    data_preprocessor=preprocess_rte,
    num_epochs=1,
    output_dir="temp_rte_1",
    save_path="drive/MyDrive/Summer2025/w266/Project/models/rte_epoch1"
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/2490 [00:00<?, ? examples/s]

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.656533,0.631769


Continue training (1 --> 2 epochs)

In [None]:
#find latest checkpoint
checkpoint_dir = max(
    [f for f in Path("temp_rte_1").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_rte1 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/rte_epoch1"
)

# Train additional epoch
trainer_rte2 = fine_tune_classification_model(
    model_rte1,
    bert_tokenizer,
    train_data=rte_dataset["train"],
    dev_data=rte_dataset["validation"],
    data_preprocessor=preprocess_rte,
    num_epochs=2, #resume
    output_dir="temp_rte_2",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/rte_epoch2",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
2,No log,0.626228,0.65704


Continue training (2 --> 3 epochs)

In [None]:
#find latest checkpoint
checkpoint_dir = max(
    [f for f in Path("temp_rte_2").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_rte2 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/rte_epoch2"
)

trainer_rte3 = fine_tune_classification_model(
    model_rte2,
    bert_tokenizer,
    train_data=rte_dataset["train"],
    dev_data=rte_dataset["validation"],
    data_preprocessor=preprocess_rte,
    num_epochs=3, #resume
    output_dir="temp_rte_3",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/rte_epoch3",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
3,No log,0.772841,0.66426


Continue training (3 --> 10 epochs)

In [None]:
#find latest checkpoint
checkpoint_dir = max(
    [f for f in Path("temp_rte_3").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_rte3 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/rte_epoch3"
)

# Train additional epoch
trainer_rte10 = fine_tune_classification_model(
    model_rte3,
    bert_tokenizer,
    train_data=rte_dataset["train"],
    dev_data=rte_dataset["validation"],
    data_preprocessor=preprocess_rte,
    num_epochs=10, #resume
    output_dir="temp_rte_10",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/rte_epoch10",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
4,0.3546,0.881907,0.628159
5,0.3546,1.273966,0.65704
6,0.3546,1.446766,0.66787
7,0.2196,2.033015,0.635379
8,0.2196,2.485233,0.628159
9,0.2196,2.550749,0.638989
10,0.0469,2.556892,0.642599


### QQP (medium)

In [None]:
#we downsample for time reasons.
model_qqp = AutoModelForSequenceClassification.from_pretrained('bert-base-cased')
trainer_qqp = fine_tune_classification_model(
    model_qqp,
    bert_tokenizer,
    train_data=qqp_train,
    dev_data=qqp_test,
    data_preprocessor=preprocess_qqp,
    num_epochs=1,
    batch_size=32,
    output_dir="temp_qqp_1",
    save_path="drive/MyDrive/Summer2025/w266/Project/models/qqp_epoch1"
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/2500 [00:00<?, ? examples/s]

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.457042,0.77025


In [None]:
checkpoint_dir = max(
    [f for f in Path("temp_qqp_1").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_qqp1 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/qqp_epoch1"
)

# Train additional epoch
trainer_qqp2 = fine_tune_classification_model(
    model_qqp1,
    bert_tokenizer,
    train_data=qqp_train,
    dev_data=qqp_test,
    data_preprocessor=preprocess_qqp,
    num_epochs=2, #resume,
    batch_size=32,
    output_dir="temp_qqp_2",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/qqp_epoch2",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
2,No log,0.449391,0.79325


In [None]:
checkpoint_dir = max(
    [f for f in Path("temp_qqp_2").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_qqp2 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/qqp_epoch2"
)

# Train additional epoch
trainer_qqp3 = fine_tune_classification_model(
    model_qqp2,
    bert_tokenizer,
    train_data=qqp_train,
    dev_data=qqp_test,
    data_preprocessor=preprocess_qqp,
    num_epochs=3, #resume,
    batch_size=32,
    output_dir="temp_qqp_3",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/qqp_epoch3",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
3,No log,0.4958,0.79475


In [None]:
checkpoint_dir = max(
    [f for f in Path("temp_qqp_3").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_qqp3 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/qqp_epoch3"
)

# Train additional epoch
trainer_qqp10 = fine_tune_classification_model(
    model_qqp3,
    bert_tokenizer,
    train_data=qqp_train,
    dev_data=qqp_test,
    data_preprocessor=preprocess_qqp,
    num_epochs=10, #resume,
    batch_size=32,
    output_dir="temp_qqp_10",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/qqp_epoch10",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
4,No log,0.623229,0.76975
5,No log,0.727512,0.78825
6,No log,0.875514,0.7935
7,0.123400,0.953322,0.78325
8,0.123400,0.991377,0.783
9,0.123400,1.009651,0.79325
10,0.123400,1.003528,0.78925


### QNLI (hard)

In [None]:
#we downsample for time reasons.
model_qnli = AutoModelForSequenceClassification.from_pretrained('bert-base-cased')
trainer_qnli = fine_tune_classification_model(
    model_qnli,
    bert_tokenizer,
    train_data=qnli_train,
    dev_data=qnli_test,
    data_preprocessor=preprocess_qnli,
    num_epochs=1,
    batch_size=32,
    output_dir="temp_qnli_1",
    save_path="drive/MyDrive/Summer2025/w266/Project/models/qnli_epoch1"
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/2500 [00:00<?, ? examples/s]

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.406567,0.824


In [None]:
checkpoint_dir = max(
    [f for f in Path("temp_qnli_1").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_qnli1 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/qnli_epoch1"
)

# Train additional epoch
trainer_qnli2 = fine_tune_classification_model(
    model_qnli1,
    bert_tokenizer,
    train_data=qnli_train,
    dev_data=qnli_test,
    data_preprocessor=preprocess_qnli,
    num_epochs=2, #resume,
    batch_size=32,
    output_dir="temp_qnli_2",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/qnli_epoch2",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Map:   0%|          | 0/2500 [00:00<?, ? examples/s]

Epoch,Training Loss,Validation Loss,Accuracy
2,No log,0.411307,0.82475


In [None]:
checkpoint_dir = max(
    [f for f in Path("temp_qnli_2").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_qnli2 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/qnli_epoch2"
)

# Train additional epoch
trainer_qnli3 = fine_tune_classification_model(
    model_qnli2,
    bert_tokenizer,
    train_data=qnli_train,
    dev_data=qnli_test,
    data_preprocessor=preprocess_qnli,
    num_epochs=3, #resume,
    batch_size=32,
    output_dir="temp_qnli_3",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/qnli_epoch3",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
3,No log,0.479552,0.82175


In [None]:
checkpoint_dir = max(
    [f for f in Path("temp_qnli_3").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_qnli3 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/qnli_epoch3"
)

# Train additional epoch
trainer_qnli10 = fine_tune_classification_model(
    model_qnli3,
    bert_tokenizer,
    train_data=qnli_train,
    dev_data=qnli_test,
    data_preprocessor=preprocess_qnli,
    num_epochs=10, #resume,
    batch_size=32,
    output_dir="temp_qnli_10",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/qnli_epoch10",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
4,No log,0.622264,0.819
5,No log,0.747037,0.81375
6,No log,0.791892,0.8175
7,0.089400,0.760011,0.82825
8,0.089400,0.845041,0.82225
9,0.089400,0.865367,0.829
10,0.089400,0.867959,0.82825


### MNLI (hard)

In [None]:
#we downsample for time reasons.
model_mnli = AutoModelForSequenceClassification.from_pretrained('bert-base-cased',
                                                                num_labels=3)
trainer_mnli = fine_tune_classification_model(
    model_mnli,
    bert_tokenizer,
    train_data=mnli_train,
    dev_data=mnli_test,
    data_preprocessor=preprocess_mnli,
    num_epochs=1,
    batch_size=32,
    output_dir="temp_mnli_1",
    save_path="drive/MyDrive/Summer2025/w266/Project/models/mnli_epoch1"
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/2500 [00:00<?, ? examples/s]

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.895806,0.611


In [None]:
checkpoint_dir = max(
    [f for f in Path("temp_mnli_1").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_mnli1 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/mnli_epoch1"
)

# Train additional epoch
trainer_mnli2 = fine_tune_classification_model(
    model_mnli1,
    bert_tokenizer,
    train_data=mnli_train,
    dev_data=mnli_test,
    data_preprocessor=preprocess_mnli,
    num_epochs=2, #resume,
    batch_size=32,
    output_dir="temp_mnli_2",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/mnli_epoch2",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
2,No log,0.821615,0.644


In [None]:
checkpoint_dir = max(
    [f for f in Path("temp_mnli_2").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_mnli2 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/mnli_epoch2"
)

# Train additional epoch
trainer_mnli3 = fine_tune_classification_model(
    model_mnli2,
    bert_tokenizer,
    train_data=mnli_train,
    dev_data=mnli_test,
    data_preprocessor=preprocess_mnli,
    num_epochs=3, #resume,
    batch_size=32,
    output_dir="temp_mnli_3",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/mnli_epoch3",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
3,No log,0.820207,0.66125


In [None]:
checkpoint_dir = max(
    [f for f in Path("temp_mnli_3").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_mnli3 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/mnli_epoch3"
)

# Train additional epoch
trainer_mnli10 = fine_tune_classification_model(
    model_mnli3,
    bert_tokenizer,
    train_data=mnli_train,
    dev_data=mnli_test,
    data_preprocessor=preprocess_mnli,
    num_epochs=10, #resume,
    batch_size=32,
    output_dir="temp_mnli_10",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/mnli_epoch10",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
4,No log,0.884828,0.66675
5,No log,0.953281,0.68475
6,No log,1.072345,0.68075
7,0.291400,1.257914,0.67775
8,0.291400,1.370241,0.677
9,0.291400,1.453277,0.6805
10,0.291400,1.443409,0.68125


### WNLI (very hard, be careful because there might be some bugs with there being devset  == test set in early versions??) also smaller gradient updates

In [None]:
model_wnli = AutoModelForSequenceClassification.from_pretrained('bert-base-cased')
wnli_train = wnli_dataset["train"]
wnli_test  = wnli_dataset["validation"]


trainer_wnli = fine_tune_classification_model(
    model_wnli,
    bert_tokenizer,
    train_data=wnli_train,
    dev_data=wnli_test,
    data_preprocessor=preprocess_wnli,
    num_epochs=1,
    batch_size=8,
    output_dir="temp_wnli_1",
    save_path="drive/MyDrive/Summer2025/w266/Project/models/wnli_epoch1"
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.686209,0.56338


In [None]:
checkpoint_dir = max(
    [f for f in Path("temp_wnli_1").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_wnli1 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/wnli_epoch1"
)

# Train additional epoch
trainer_wnli2 = fine_tune_classification_model(
    model_wnli1,
    bert_tokenizer,
    train_data=wnli_train,
    dev_data=wnli_test,
    data_preprocessor=preprocess_wnli,
    num_epochs=2, #resume,
    batch_size=8,
    output_dir="temp_wnli_2",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/wnli_epoch2",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
2,No log,0.691785,0.56338


In [None]:
checkpoint_dir = max(
    [f for f in Path("temp_wnli_2").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_wnli2 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/wnli_epoch2"
)

# Train additional epoch
trainer_wnli3 = fine_tune_classification_model(
    model_wnli2,
    bert_tokenizer,
    train_data=wnli_train,
    dev_data=wnli_test,
    data_preprocessor=preprocess_wnli,
    num_epochs=3, #resume,
    batch_size=8,
    output_dir="temp_wnli_3",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/wnli_epoch3",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
3,No log,0.695584,0.352113


In [None]:
checkpoint_dir = max(
    [f for f in Path("temp_wnli_3").iterdir() if f.is_dir() and "checkpoint" in f.name],
    key=os.path.getmtime
)

#load initial model
model_wnli3 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/wnli_epoch3"
)

# Train additional epoch
trainer_wnli10 = fine_tune_classification_model(
    model_wnli3,
    bert_tokenizer,
    train_data=wnli_train,
    dev_data=wnli_test,
    data_preprocessor=preprocess_wnli,
    num_epochs=10, #resume,
    batch_size=8,
    output_dir="temp_wnli_10",  # New temp dir
    save_path="drive/MyDrive/Summer2025/w266/Project/models/wnli_epoch10",
    resume_from=str(checkpoint_dir)  # Resume from checkpoint
)

Epoch,Training Loss,Validation Loss,Accuracy
4,No log,0.712666,0.43662
5,No log,0.699393,0.43662
6,No log,0.709221,0.43662
7,0.702300,0.690531,0.56338
8,0.702300,0.709695,0.43662
9,0.702300,0.704419,0.338028
10,0.702300,0.711611,0.28169


## SV Experiments


### Functions & Drive Mount


In [14]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [15]:
def evaluate_model(model, test_data):
    """Compute accuracy of unmodified model"""
    trainer_args = TrainingArguments(
                                      output_dir="./tmp_eval",      # required, but can be temp
                                      per_device_eval_batch_size=32,  # or whatever you used
                                      do_train=False,
                                      do_eval=True,
                                      logging_strategy="no",  # suppress extra logs
                                      save_strategy="no",
                                      report_to="none"
                                      )

    treatment_trainer = Trainer(
                model=model,
                args=trainer_args,
                compute_metrics=compute_metrics,
                eval_dataset=test_data
            )

    res = treatment_trainer.evaluate()
    return res["eval_accuracy"]

In [16]:
def run_experiment(dataset_name: str,
                   test_data,
                   tokenizer):
    EPOCHS = [1, 2, 3, 10]
    LAYER_TYPES = ["all", "query", "key", "value", "output", "ffn"]
    BASE_PATH = "drive/MyDrive/Summer2025/w266/Project"

    results = {}

    print(f"\n=== STARTING {dataset_name.upper()} EXPERIMENTS ===")
    for epoch in EPOCHS:
        epoch_key = f"epoch{epoch}"
        results[epoch_key] = {}
        print(f"\n=== EPOCH {epoch} ===")
        try:
            # 1. Load model and tokenizer
            model_path = f"{BASE_PATH}/models/{dataset_name}_{epoch_key}"
            model = AutoModelForSequenceClassification.from_pretrained(model_path)

            # 2. Evaluate control model (no ablation)
            control_acc = evaluate_model(model, test_data)
            results[epoch_key]["control"] = control_acc
            print(f"  {epoch_key} | Control: {control_acc:.4f}")

            # 3. Run ablation for each layer type
            ablation_results = {}
            for layer_type in LAYER_TYPES:
                accuracies = evaluate_decile_removal(
                    control_model=model,
                    tokenizer=tokenizer,
                    test_data=test_data,
                    compute_metrics=compute_metrics,
                    layer_type=layer_type
                )

                ablation_results[layer_type] = accuracies
                print(f"    {layer_type}: {accuracies}")

            results[epoch_key]["ablation"] = ablation_results

            # Free GPU memory
            del model
            torch.cuda.empty_cache()

        except Exception as e:
            print(f"ERROR: {dataset_name}_{epoch_key} failed - {str(e)}")
            results[epoch_key]["error"] = str(e)

    # 4. Save results (moved outside of the loop)
    output_path = f"{BASE_PATH}/results/{dataset_name}_results.json"
    with open(output_path, "w") as f:
        json.dump(results, f, indent=2)

    print(f"=== SAVED RESULTS TO {output_path} ===")
    return results

In [17]:
def run_ffn_experiment(dataset_name: str,
                   test_data,
                   tokenizer):
    EPOCHS = [1, 2, 3, 10]
    LAYER_GROUPS = {
        "0-3": [0, 1, 2, 3],
        "4-7": [4, 5, 6, 7],
        "8-11": [8, 9, 10, 11]
    }
    BASE_PATH = "drive/MyDrive/Summer2025/w266/Project"

    results = {}

    print(f"\n=== STARTING {dataset_name.upper()} EXPERIMENTS ===")
    for epoch in EPOCHS:
        epoch_key = f"epoch{epoch}"
        results[epoch_key] = {}
        print(f"\n=== EPOCH {epoch} ===")
        try:
            # 1. Load model and tokenizer
            model_path = f"{BASE_PATH}/models/{dataset_name}_{epoch_key}"
            model = AutoModelForSequenceClassification.from_pretrained(model_path)

            # 2. Evaluate control model (no ablation)
            control_acc = evaluate_model(model, test_data)
            results[epoch_key]["control"] = control_acc
            print(f"  {epoch_key} | Control: {control_acc:.4f}")

            # 3. Run ablation for each layer group
            ablation_results = {}
            for group_name, layer_indices in LAYER_GROUPS.items():
                accuracies = evaluate_decile_removal(
                    control_model=model,
                    tokenizer=tokenizer,
                    test_data=test_data,
                    compute_metrics=compute_metrics,
                    layer_type="ffn",  # Fixed to FFN layers
                    layers=layer_indices  # Use current layer group
                )

                ablation_results[group_name] = accuracies
                print(f"    Group {group_name}: {accuracies}")

            results[epoch_key]["ablation"] = ablation_results

            # Free GPU memory
            del model
            torch.cuda.empty_cache()

        except Exception as e:
            print(f"ERROR: {dataset_name}_{epoch_key} failed - {str(e)}")
            results[epoch_key]["error"] = str(e)

    # 4. Save results with new filename
    output_path = f"{BASE_PATH}/results/{dataset_name}_layer_results.json"
    with open(output_path, "w") as f:
        json.dump(results, f, indent=2)

    print(f"=== SAVED RESULTS TO {output_path} ===")
    return results

### Ablation Experiments

In [None]:
bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

below took 49 minutes


In [None]:
imdb_test = test_medium.map(preprocess_imdb, batched=True, fn_kwargs={'tokenizer': bert_tokenizer})
imdb_res = run_experiment(dataset_name="imdb",
                          test_data=imdb_test,
                          tokenizer=bert_tokenizer)

below took 12 minutes


In [None]:
rte_test = rte_dataset["validation"].map(preprocess_rte, batched=True, fn_kwargs={'tokenizer': bert_tokenizer})
rte_res = run_experiment(dataset_name="rte",
                          test_data=rte_test,
                          tokenizer=bert_tokenizer)

Map:   0%|          | 0/277 [00:00<?, ? examples/s]


=== STARTING RTE EXPERIMENTS ===

=== EPOCH 1 ===


  return forward_call(*args, **kwargs)


  epoch1 | Control: 0.6318

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.5270758122743683, 0.4729241877256318, 0.48014440433212996, 0.5523465703971119, 0.5812274368231047, 0.592057761732852, 0.628158844765343, 0.5848375451263538, 0.6028880866425993, 0.48375451263537905]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.47653429602888087, 0.516245487364621, 0.5703971119133574, 0.6137184115523465, 0.628158844765343, 0.631768953068592, 0.6425992779783394, 0.631768953068592, 0.628158844765343, 0.628158844765343]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.5126353790613718, 0.4981949458483754, 0.5740072202166066, 0.6173285198555957, 0.6245487364620939, 0.6353790613718412, 0.6209386281588448, 0.631768953068592, 0.628158844765343, 0.631768953068592]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.4729241877256318, 0.6173285198555957, 0.6389891696750902, 0.6245487364620939, 0.628158844765343, 0.6209386281588448, 0.6209386281588448, 0.628158844765343, 0.628158844765343, 0.6173285198555957]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.5379061371841155, 0.6064981949458483, 0.5848375451263538, 0.6028880866425993, 0.628158844765343, 0.6209386281588448, 0.6245487364620939, 0.628158844765343, 0.628158844765343, 0.628158844765343]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.49097472924187724, 0.5884476534296029, 0.5342960288808665, 0.5595667870036101, 0.6028880866425993, 0.5776173285198556, 0.5992779783393501, 0.5776173285198556, 0.5740072202166066, 0.48375451263537905]

=== EPOCH 2 ===


  return forward_call(*args, **kwargs)


  epoch2 | Control: 0.6570

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.5270758122743683, 0.4729241877256318, 0.4729241877256318, 0.6353790613718412, 0.6534296028880866, 0.6570397111913358, 0.6678700361010831, 0.6425992779783394, 0.6425992779783394, 0.5234657039711191]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.5342960288808665, 0.5703971119133574, 0.631768953068592, 0.6425992779783394, 0.6642599277978339, 0.6570397111913358, 0.6642599277978339, 0.6570397111913358, 0.6606498194945848, 0.6570397111913358]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.5126353790613718, 0.5234657039711191, 0.6137184115523465, 0.6570397111913358, 0.6534296028880866, 0.6606498194945848, 0.6534296028880866, 0.6642599277978339, 0.6570397111913358, 0.6570397111913358]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.4729241877256318, 0.6534296028880866, 0.6425992779783394, 0.6425992779783394, 0.6570397111913358, 0.6425992779783394, 0.6570397111913358, 0.6642599277978339, 0.6570397111913358, 0.6642599277978339]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.5740072202166066, 0.6353790613718412, 0.6425992779783394, 0.6498194945848376, 0.6425992779783394, 0.6425992779783394, 0.6534296028880866, 0.6534296028880866, 0.6534296028880866, 0.6570397111913358]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.4729241877256318, 0.6498194945848376, 0.6137184115523465, 0.631768953068592, 0.6498194945848376, 0.6425992779783394, 0.6462093862815884, 0.6353790613718412, 0.6389891696750902, 0.5126353790613718]

=== EPOCH 3 ===


  return forward_call(*args, **kwargs)


  epoch3 | Control: 0.6643

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.48736462093862815, 0.4729241877256318, 0.49097472924187724, 0.6534296028880866, 0.6425992779783394, 0.6750902527075813, 0.6462093862815884, 0.6389891696750902, 0.6498194945848376, 0.5956678700361011]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.5379061371841155, 0.5740072202166066, 0.6714801444043321, 0.6353790613718412, 0.6606498194945848, 0.6425992779783394, 0.6606498194945848, 0.6606498194945848, 0.6642599277978339, 0.6606498194945848]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.5234657039711191, 0.5812274368231047, 0.6462093862815884, 0.6570397111913358, 0.6534296028880866, 0.6534296028880866, 0.6570397111913358, 0.6678700361010831, 0.6606498194945848, 0.6642599277978339]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.4729241877256318, 0.6570397111913358, 0.6462093862815884, 0.631768953068592, 0.6606498194945848, 0.6462093862815884, 0.6462093862815884, 0.6389891696750902, 0.6353790613718412, 0.6606498194945848]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.5703971119133574, 0.6534296028880866, 0.6534296028880866, 0.6462093862815884, 0.6462093862815884, 0.6534296028880866, 0.6642599277978339, 0.6606498194945848, 0.6606498194945848, 0.6606498194945848]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.4729241877256318, 0.6642599277978339, 0.6606498194945848, 0.6462093862815884, 0.628158844765343, 0.6895306859205776, 0.6714801444043321, 0.6570397111913358, 0.6462093862815884, 0.5776173285198556]

=== EPOCH 10 ===


  return forward_call(*args, **kwargs)


  epoch10 | Control: 0.6426

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.5270758122743683, 0.4729241877256318, 0.592057761732852, 0.6245487364620939, 0.628158844765343, 0.6570397111913358, 0.6209386281588448, 0.6534296028880866, 0.6353790613718412, 0.631768953068592]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.555956678700361, 0.5667870036101083, 0.6389891696750902, 0.6425992779783394, 0.6498194945848376, 0.6462093862815884, 0.631768953068592, 0.6425992779783394, 0.6462093862815884, 0.6425992779783394]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.5523465703971119, 0.5703971119133574, 0.6353790613718412, 0.6498194945848376, 0.6462093862815884, 0.631768953068592, 0.6425992779783394, 0.6498194945848376, 0.6462093862815884, 0.6425992779783394]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.4729241877256318, 0.6425992779783394, 0.6353790613718412, 0.6389891696750902, 0.6425992779783394, 0.6209386281588448, 0.631768953068592, 0.6209386281588448, 0.628158844765343, 0.6353790613718412]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.555956678700361, 0.6498194945848376, 0.6462093862815884, 0.6534296028880866, 0.6425992779783394, 0.6389891696750902, 0.6425992779783394, 0.6462093862815884, 0.6425992779783394, 0.6425992779783394]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.4729241877256318, 0.6498194945848376, 0.6606498194945848, 0.6389891696750902, 0.6389891696750902, 0.6498194945848376, 0.6353790613718412, 0.6534296028880866, 0.6353790613718412, 0.6462093862815884]
=== SAVED RESULTS TO drive/MyDrive/Summer2025/w266/Project/results/rte_results.json ===


below took 2 hours.

In [None]:
#qqp
qqp_test = qqp_test.map(preprocess_qqp, batched=True, fn_kwargs={'tokenizer': bert_tokenizer})
qqp_res = run_experiment(dataset_name="qqp",
                          test_data=qqp_test,
                          tokenizer=bert_tokenizer)

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]


=== STARTING QQP EXPERIMENTS ===

=== EPOCH 1 ===


  return forward_call(*args, **kwargs)


  epoch1 | Control: 0.7600

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.36375, 0.635, 0.732, 0.75, 0.74925, 0.756, 0.74625, 0.7555, 0.75825, 0.64075]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.63625, 0.74175, 0.7595, 0.763, 0.76, 0.76125, 0.76125, 0.76025, 0.76125, 0.75975]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.63625, 0.74975, 0.76175, 0.7605, 0.76275, 0.7635, 0.76025, 0.76025, 0.75975, 0.76075]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.63625, 0.7475, 0.7525, 0.7545, 0.75525, 0.757, 0.7555, 0.75925, 0.75725, 0.757]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.63625, 0.749, 0.75875, 0.75625, 0.75525, 0.757, 0.75825, 0.75625, 0.758, 0.76]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.36375, 0.72725, 0.7435, 0.7525, 0.7545, 0.75675, 0.7535, 0.7585, 0.764, 0.6395]

=== EPOCH 2 ===


  return forward_call(*args, **kwargs)


  epoch2 | Control: 0.7800

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.36375, 0.66425, 0.6995, 0.7565, 0.761, 0.76475, 0.77275, 0.7765, 0.78, 0.691]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.63625, 0.76625, 0.78025, 0.77775, 0.78175, 0.7825, 0.78175, 0.781, 0.7825, 0.78075]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.63625, 0.7725, 0.77575, 0.777, 0.783, 0.78375, 0.78025, 0.78125, 0.78075, 0.78025]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.63625, 0.76375, 0.77125, 0.77375, 0.77675, 0.77875, 0.77975, 0.78075, 0.781, 0.77825]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.63625, 0.76375, 0.77525, 0.77625, 0.77475, 0.7795, 0.7785, 0.7795, 0.77975, 0.7795]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.36375, 0.73325, 0.75125, 0.77175, 0.77175, 0.76875, 0.77825, 0.77575, 0.78125, 0.67975]

=== EPOCH 3 ===


  return forward_call(*args, **kwargs)


  epoch3 | Control: 0.7875

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.36375, 0.62225, 0.648, 0.74825, 0.761, 0.764, 0.77225, 0.7775, 0.7885, 0.745]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.641, 0.7695, 0.781, 0.784, 0.7875, 0.7875, 0.789, 0.7875, 0.7875, 0.78825]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.64375, 0.76875, 0.77875, 0.784, 0.78875, 0.787, 0.7865, 0.7875, 0.78775, 0.7885]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.63625, 0.76625, 0.77625, 0.78075, 0.78425, 0.78275, 0.787, 0.78725, 0.78775, 0.78975]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.63625, 0.7605, 0.7775, 0.78325, 0.7795, 0.78625, 0.783, 0.789, 0.7875, 0.78825]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.36375, 0.71625, 0.74325, 0.77025, 0.775, 0.7715, 0.781, 0.77775, 0.789, 0.734]

=== EPOCH 10 ===


  return forward_call(*args, **kwargs)


  epoch10 | Control: 0.7863

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.36375, 0.54825, 0.6495, 0.75375, 0.76125, 0.7635, 0.77325, 0.77975, 0.78025, 0.78175]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.638, 0.771, 0.7765, 0.78475, 0.7845, 0.7845, 0.7855, 0.78875, 0.788, 0.78625]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.63775, 0.7695, 0.779, 0.78425, 0.788, 0.7855, 0.788, 0.788, 0.787, 0.787]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.63625, 0.767, 0.778, 0.7825, 0.78125, 0.783, 0.7835, 0.78375, 0.7865, 0.787]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.64025, 0.7635, 0.77575, 0.78175, 0.77825, 0.78675, 0.786, 0.7855, 0.78575, 0.7865]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.36375, 0.73225, 0.7535, 0.78075, 0.77675, 0.7735, 0.78075, 0.77575, 0.7825, 0.77475]
=== SAVED RESULTS TO drive/MyDrive/Summer2025/w266/Project/results/qqp_results.json ===


below took 2 hours

In [None]:
#qnli
qnli_test = qnli_test.map(preprocess_qnli, batched=True, fn_kwargs={'tokenizer': bert_tokenizer})
qnli_res = run_experiment(dataset_name="qnli",
                          test_data=qnli_test,
                          tokenizer=bert_tokenizer)

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]


=== STARTING QNLI EXPERIMENTS ===

=== EPOCH 1 ===


  return forward_call(*args, **kwargs)


  epoch1 | Control: 0.8193

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.494, 0.506, 0.58325, 0.643, 0.70575, 0.76225, 0.76725, 0.7985, 0.79975, 0.6415]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.506, 0.809, 0.80925, 0.817, 0.81975, 0.81625, 0.817, 0.812, 0.817, 0.82025]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.50675, 0.7865, 0.8125, 0.81, 0.8105, 0.81725, 0.81725, 0.819, 0.81675, 0.8185]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.506, 0.812, 0.809, 0.8105, 0.8135, 0.813, 0.8115, 0.81275, 0.81, 0.81325]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.506, 0.7885, 0.8165, 0.8105, 0.816, 0.81825, 0.81525, 0.81925, 0.81825, 0.81875]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.494, 0.58075, 0.673, 0.737, 0.76175, 0.79125, 0.801, 0.80925, 0.81225, 0.625]

=== EPOCH 2 ===


  return forward_call(*args, **kwargs)


  epoch2 | Control: 0.8207

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.494, 0.5075, 0.60525, 0.69775, 0.76825, 0.79825, 0.8115, 0.81975, 0.82325, 0.7135]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.506, 0.81125, 0.8205, 0.8195, 0.82125, 0.8225, 0.821, 0.82, 0.82125, 0.82175]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.513, 0.796, 0.81675, 0.816, 0.8225, 0.81975, 0.82075, 0.82, 0.82075, 0.8215]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.506, 0.81925, 0.81525, 0.8215, 0.81975, 0.81875, 0.82, 0.818, 0.8205, 0.82]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.53325, 0.80375, 0.8205, 0.817, 0.819, 0.82175, 0.821, 0.8205, 0.82, 0.82075]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.494, 0.645, 0.733, 0.78275, 0.80675, 0.8155, 0.82675, 0.8245, 0.82475, 0.696]

=== EPOCH 3 ===


  return forward_call(*args, **kwargs)


  epoch3 | Control: 0.8185

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.494, 0.502, 0.60725, 0.71, 0.77525, 0.7985, 0.81475, 0.82075, 0.8225, 0.74775]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.50975, 0.805, 0.815, 0.82025, 0.8195, 0.818, 0.82, 0.819, 0.82025, 0.81975]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.52975, 0.7835, 0.8145, 0.8175, 0.81925, 0.8205, 0.81925, 0.81925, 0.8185, 0.81825]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.50625, 0.8185, 0.8155, 0.8165, 0.81675, 0.8165, 0.81725, 0.81575, 0.8195, 0.8175]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.64175, 0.8, 0.8135, 0.814, 0.81775, 0.81975, 0.8195, 0.81875, 0.8175, 0.819]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.494, 0.65575, 0.74525, 0.78375, 0.8095, 0.82, 0.82025, 0.82375, 0.82425, 0.73775]

=== EPOCH 10 ===


  return forward_call(*args, **kwargs)


  epoch10 | Control: 0.8257

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.494, 0.497, 0.62875, 0.7115, 0.769, 0.791, 0.8055, 0.815, 0.818, 0.8305]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.52775, 0.79275, 0.81125, 0.81675, 0.824, 0.81875, 0.82425, 0.823, 0.8235, 0.824]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.51, 0.77275, 0.8075, 0.81025, 0.81725, 0.8195, 0.82425, 0.82375, 0.8245, 0.825]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.50725, 0.81225, 0.816, 0.82, 0.82075, 0.8225, 0.8205, 0.823, 0.824, 0.823]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.5765, 0.79925, 0.8155, 0.819, 0.8205, 0.822, 0.82275, 0.82575, 0.8255, 0.82575]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.494, 0.63375, 0.7525, 0.78425, 0.80475, 0.812, 0.81375, 0.81875, 0.82225, 0.82725]
=== SAVED RESULTS TO drive/MyDrive/Summer2025/w266/Project/results/qnli_results.json ===


In [None]:
#mnli
mnli_test = mnli_test.map(preprocess_mnli, batched=True, fn_kwargs={'tokenizer': bert_tokenizer})
mnli_res = run_experiment(dataset_name="mnli",
                          test_data=mnli_test,
                          tokenizer=bert_tokenizer)

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]


=== STARTING MNLI EXPERIMENTS ===

=== EPOCH 1 ===


  return forward_call(*args, **kwargs)


  epoch1 | Control: 0.6170

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.33925, 0.37875, 0.5625, 0.58925, 0.597, 0.6125, 0.62575, 0.62575, 0.62725, 0.514]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.38875, 0.608, 0.6145, 0.6235, 0.6175, 0.6175, 0.619, 0.61725, 0.61775, 0.61725]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.36125, 0.614, 0.62625, 0.6205, 0.62475, 0.617, 0.622, 0.61825, 0.6185, 0.618]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.357, 0.61, 0.61525, 0.62, 0.6145, 0.6195, 0.6245, 0.62375, 0.6255, 0.621]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.39475, 0.61075, 0.61825, 0.621, 0.6165, 0.61625, 0.61675, 0.61625, 0.6185, 0.618]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.36, 0.47425, 0.596, 0.6125, 0.61175, 0.62125, 0.6265, 0.6245, 0.61825, 0.497]

=== EPOCH 2 ===


  return forward_call(*args, **kwargs)


  epoch2 | Control: 0.6565

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.36025, 0.37475, 0.57625, 0.611, 0.63225, 0.65525, 0.6565, 0.66675, 0.66625, 0.59575]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.3905, 0.63075, 0.6535, 0.657, 0.65525, 0.6555, 0.65775, 0.6575, 0.6545, 0.65425]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.35875, 0.635, 0.6525, 0.6555, 0.65975, 0.6585, 0.65625, 0.6555, 0.65675, 0.656]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.31225, 0.6485, 0.65625, 0.65775, 0.66075, 0.6595, 0.6615, 0.66075, 0.6605, 0.6555]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.4195, 0.64725, 0.6475, 0.65725, 0.65925, 0.659, 0.6555, 0.654, 0.655, 0.65525]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.36, 0.5175, 0.632, 0.64925, 0.65675, 0.6635, 0.6595, 0.662, 0.661, 0.56625]

=== EPOCH 3 ===


  return forward_call(*args, **kwargs)


  epoch3 | Control: 0.6690

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.3595, 0.3695, 0.55575, 0.601, 0.6315, 0.6695, 0.67025, 0.67975, 0.67625, 0.627]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.37175, 0.63425, 0.65325, 0.66275, 0.66775, 0.668, 0.669, 0.67, 0.67, 0.665]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.36, 0.626, 0.65675, 0.664, 0.665, 0.6695, 0.66875, 0.667, 0.6685, 0.66775]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.31125, 0.65325, 0.66375, 0.66625, 0.6645, 0.6695, 0.674, 0.675, 0.67725, 0.6715]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.4135, 0.662, 0.65225, 0.661, 0.665, 0.66, 0.66625, 0.6655, 0.6705, 0.668]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.36, 0.51625, 0.6355, 0.658, 0.67075, 0.67525, 0.67225, 0.67625, 0.6715, 0.6165]

=== EPOCH 10 ===


  return forward_call(*args, **kwargs)


  epoch10 | Control: 0.6895

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    all: [0.31175, 0.374, 0.5775, 0.6385, 0.657, 0.67975, 0.6905, 0.70075, 0.6915, 0.66875]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    query: [0.3715, 0.64725, 0.684, 0.6935, 0.6925, 0.694, 0.69575, 0.69475, 0.691, 0.69]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    key: [0.36075, 0.65225, 0.6875, 0.69125, 0.69425, 0.69325, 0.69275, 0.69225, 0.6925, 0.69075]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    value: [0.31075, 0.6845, 0.69575, 0.69575, 0.697, 0.68875, 0.69275, 0.69375, 0.6925, 0.692]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    output: [0.4645, 0.69025, 0.6875, 0.693, 0.69575, 0.69025, 0.68825, 0.691, 0.69325, 0.6915]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    ffn: [0.36, 0.5265, 0.651, 0.66775, 0.6835, 0.69075, 0.69675, 0.69725, 0.68925, 0.67175]
=== SAVED RESULTS TO drive/MyDrive/Summer2025/w266/Project/results/mnli_results.json ===


In [None]:
model_rte1 = AutoModelForSequenceClassification.from_pretrained(
    "drive/MyDrive/Summer2025/w266/Project/models/rte_epoch1"
)

In [None]:
test_processed = rte_dataset['validation'].map(preprocess_rte, batched=True, fn_kwargs={'tokenizer': bert_tokenizer})

accuracies = evaluate_decile_removal(
    control_model=model_rte1,
    tokenizer=bert_tokenizer,
    test_data=test_processed,
    compute_metrics=compute_metrics,
    layer_type="query", #all, query, key, value, output, ffn
    layers=[0, 1, 2, 3]
)



=== Evaluating decile 0 removal ===



=== Evaluating decile 1 removal ===



=== Evaluating decile 2 removal ===



=== Evaluating decile 3 removal ===



=== Evaluating decile 4 removal ===



=== Evaluating decile 5 removal ===



=== Evaluating decile 6 removal ===



=== Evaluating decile 7 removal ===



=== Evaluating decile 8 removal ===



=== Evaluating decile 9 removal ===


In [None]:
res = {}
for layer_type in ["query", "key", "value", "output", "ffn"]:
  accs = evaluate_decile_removal(
    control_model=trainer3.model,
    tokenizer=tokenizer,
    trainer_args=trainer3.args,
    compute_metrics=trainer3.compute_metrics,
    test_data=test_processed,
    layer_type=layer_type, #all, query, key, value, output, ffn
    layers="all"
)
  res[layer_type] = accs


=== Evaluating decile 0 removal ===



=== Evaluating decile 1 removal ===



=== Evaluating decile 2 removal ===



=== Evaluating decile 3 removal ===



=== Evaluating decile 4 removal ===



=== Evaluating decile 5 removal ===



=== Evaluating decile 6 removal ===



=== Evaluating decile 7 removal ===



=== Evaluating decile 8 removal ===



=== Evaluating decile 9 removal ===



=== Evaluating decile 0 removal ===



=== Evaluating decile 1 removal ===



=== Evaluating decile 2 removal ===



=== Evaluating decile 3 removal ===



=== Evaluating decile 4 removal ===



=== Evaluating decile 5 removal ===



=== Evaluating decile 6 removal ===



=== Evaluating decile 7 removal ===



=== Evaluating decile 8 removal ===



=== Evaluating decile 9 removal ===



=== Evaluating decile 0 removal ===



=== Evaluating decile 1 removal ===



=== Evaluating decile 2 removal ===



=== Evaluating decile 3 removal ===



=== Evaluating decile 4 removal ===



=== Evaluating decile 5 removal ===



=== Evaluating decile 6 removal ===



=== Evaluating decile 7 removal ===



=== Evaluating decile 8 removal ===



=== Evaluating decile 9 removal ===



=== Evaluating decile 0 removal ===



=== Evaluating decile 1 removal ===



=== Evaluating decile 2 removal ===



=== Evaluating decile 3 removal ===



=== Evaluating decile 4 removal ===



=== Evaluating decile 5 removal ===



=== Evaluating decile 6 removal ===



=== Evaluating decile 7 removal ===



=== Evaluating decile 8 removal ===



=== Evaluating decile 9 removal ===



=== Evaluating decile 0 removal ===



=== Evaluating decile 1 removal ===



=== Evaluating decile 2 removal ===



=== Evaluating decile 3 removal ===



=== Evaluating decile 4 removal ===



=== Evaluating decile 5 removal ===



=== Evaluating decile 6 removal ===



=== Evaluating decile 7 removal ===



=== Evaluating decile 8 removal ===



=== Evaluating decile 9 removal ===


### FFN Layer Experiments

In [18]:
bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

In [19]:
imdb_test = test_medium.map(preprocess_imdb, batched=True, fn_kwargs={'tokenizer': bert_tokenizer})
imdb_res = run_ffn_experiment(dataset_name="imdb",
                          test_data=imdb_test,
                          tokenizer=bert_tokenizer)

Map:   0%|          | 0/5000 [00:00<?, ? examples/s]


=== STARTING IMDB EXPERIMENTS ===

=== EPOCH 1 ===


  return forward_call(*args, **kwargs)


  epoch1 | Control: 0.8474

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.6358, 0.8252, 0.8374, 0.8462, 0.844, 0.8478, 0.8456, 0.8476, 0.8488, 0.8456]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.612, 0.8258, 0.8386, 0.8422, 0.8438, 0.8464, 0.8458, 0.8456, 0.8462, 0.8434]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.7184, 0.845, 0.8478, 0.8468, 0.85, 0.8462, 0.848, 0.8478, 0.8486, 0.8474]

=== EPOCH 2 ===


  return forward_call(*args, **kwargs)


  epoch2 | Control: 0.8496

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.6886, 0.8434, 0.8474, 0.8482, 0.8454, 0.8506, 0.8498, 0.851, 0.8496, 0.849]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.5824, 0.839, 0.8404, 0.8432, 0.8478, 0.8474, 0.8488, 0.8484, 0.8494, 0.845]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.7616, 0.8456, 0.8472, 0.85, 0.847, 0.8484, 0.848, 0.8498, 0.8516, 0.851]

=== EPOCH 3 ===


  return forward_call(*args, **kwargs)


  epoch3 | Control: 0.8530

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.6674, 0.849, 0.8462, 0.8506, 0.8486, 0.8466, 0.8514, 0.8512, 0.8504, 0.8452]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.532, 0.8438, 0.8408, 0.8396, 0.8454, 0.8472, 0.8494, 0.849, 0.8492, 0.8482]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.7642, 0.845, 0.843, 0.85, 0.8466, 0.8522, 0.8502, 0.8528, 0.8532, 0.8504]

=== EPOCH 10 ===


  return forward_call(*args, **kwargs)


  epoch10 | Control: 0.8456

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.5734, 0.8416, 0.8424, 0.8394, 0.845, 0.8412, 0.8436, 0.846, 0.8454, 0.8442]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.4892, 0.8424, 0.8376, 0.8426, 0.8444, 0.8446, 0.8442, 0.8456, 0.8432, 0.8452]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.8214, 0.8432, 0.8442, 0.8442, 0.8442, 0.8432, 0.8442, 0.8436, 0.8454, 0.8462]
=== SAVED RESULTS TO drive/MyDrive/Summer2025/w266/Project/results/imdb_layer_results.json ===


In [20]:
rte_test = rte_dataset["validation"].map(preprocess_rte, batched=True, fn_kwargs={'tokenizer': bert_tokenizer})
rte_res = run_ffn_experiment(dataset_name="rte",
                          test_data=rte_test,
                          tokenizer=bert_tokenizer)

Map:   0%|          | 0/277 [00:00<?, ? examples/s]


=== STARTING RTE EXPERIMENTS ===

=== EPOCH 1 ===


  return forward_call(*args, **kwargs)


  epoch1 | Control: 0.6318

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.4729241877256318, 0.6245487364620939, 0.631768953068592, 0.6353790613718412, 0.628158844765343, 0.6173285198555957, 0.631768953068592, 0.6245487364620939, 0.6209386281588448, 0.5234657039711191]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.4729241877256318, 0.6173285198555957, 0.6137184115523465, 0.5956678700361011, 0.6209386281588448, 0.628158844765343, 0.6245487364620939, 0.6245487364620939, 0.6173285198555957, 0.6028880866425993]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.5812274368231047, 0.5703971119133574, 0.5631768953068592, 0.5487364620938628, 0.5812274368231047, 0.5776173285198556, 0.5956678700361011, 0.5884476534296029, 0.5848375451263538, 0.6209386281588448]

=== EPOCH 2 ===


  return forward_call(*args, **kwargs)


  epoch2 | Control: 0.6570

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.4729241877256318, 0.6570397111913358, 0.6534296028880866, 0.6498194945848376, 0.6606498194945848, 0.6606498194945848, 0.6787003610108303, 0.6642599277978339, 0.6498194945848376, 0.6064981949458483]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.4729241877256318, 0.6498194945848376, 0.6498194945848376, 0.6425992779783394, 0.6534296028880866, 0.6425992779783394, 0.6678700361010831, 0.6462093862815884, 0.6425992779783394, 0.6389891696750902]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.6101083032490975, 0.628158844765343, 0.6353790613718412, 0.6245487364620939, 0.628158844765343, 0.6245487364620939, 0.631768953068592, 0.6462093862815884, 0.6425992779783394, 0.6353790613718412]

=== EPOCH 3 ===


  return forward_call(*args, **kwargs)


  epoch3 | Control: 0.6643

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.47653429602888087, 0.6389891696750902, 0.631768953068592, 0.6425992779783394, 0.631768953068592, 0.6462093862815884, 0.6642599277978339, 0.6462093862815884, 0.6534296028880866, 0.6353790613718412]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.4729241877256318, 0.6714801444043321, 0.6714801444043321, 0.6498194945848376, 0.6389891696750902, 0.6534296028880866, 0.6462093862815884, 0.6462093862815884, 0.6534296028880866, 0.6353790613718412]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.49458483754512633, 0.6642599277978339, 0.6678700361010831, 0.6462093862815884, 0.6678700361010831, 0.6714801444043321, 0.6642599277978339, 0.6534296028880866, 0.6570397111913358, 0.6570397111913358]

=== EPOCH 10 ===


  return forward_call(*args, **kwargs)


  epoch10 | Control: 0.6426

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.47653429602888087, 0.6137184115523465, 0.631768953068592, 0.6209386281588448, 0.631768953068592, 0.6425992779783394, 0.628158844765343, 0.6389891696750902, 0.6425992779783394, 0.6498194945848376]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.4729241877256318, 0.6570397111913358, 0.6534296028880866, 0.6534296028880866, 0.6498194945848376, 0.6425992779783394, 0.6498194945848376, 0.6389891696750902, 0.6498194945848376, 0.6462093862815884]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.4981949458483754, 0.6570397111913358, 0.6534296028880866, 0.6570397111913358, 0.6534296028880866, 0.6570397111913358, 0.6534296028880866, 0.6534296028880866, 0.6498194945848376, 0.631768953068592]
=== SAVED RESULTS TO drive/MyDrive/Summer2025/w266/Project/results/rte_layer_results.json ===


In [21]:
#qqp
qqp_test = qqp_test.map(preprocess_qqp, batched=True, fn_kwargs={'tokenizer': bert_tokenizer})
qqp_res = run_ffn_experiment(dataset_name="qqp",
                          test_data=qqp_test,
                          tokenizer=bert_tokenizer)

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]


=== STARTING QQP EXPERIMENTS ===

=== EPOCH 1 ===


  return forward_call(*args, **kwargs)


  epoch1 | Control: 0.7600

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.63175, 0.74925, 0.758, 0.75525, 0.75825, 0.75975, 0.75875, 0.75825, 0.758, 0.7315]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.63175, 0.755, 0.75375, 0.75875, 0.76025, 0.7565, 0.75775, 0.76, 0.762, 0.76375]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.4085, 0.75275, 0.75775, 0.759, 0.75725, 0.757, 0.756, 0.75775, 0.761, 0.75275]

=== EPOCH 2 ===


  return forward_call(*args, **kwargs)


  epoch2 | Control: 0.7760

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.63175, 0.76975, 0.773, 0.77425, 0.773, 0.7755, 0.77825, 0.7725, 0.7765, 0.7615]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.63175, 0.77225, 0.767, 0.775, 0.77425, 0.7735, 0.7725, 0.77975, 0.77525, 0.779]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.3685, 0.766, 0.768, 0.76975, 0.7725, 0.7715, 0.772, 0.77125, 0.77475, 0.76375]

=== EPOCH 3 ===


  return forward_call(*args, **kwargs)


  epoch3 | Control: 0.7805

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.63175, 0.773, 0.77875, 0.78, 0.77675, 0.77775, 0.77725, 0.78, 0.776, 0.76375]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.63175, 0.76825, 0.76925, 0.77825, 0.77325, 0.7785, 0.778, 0.7775, 0.781, 0.7745]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.36825, 0.76875, 0.76225, 0.774, 0.77375, 0.77525, 0.7765, 0.78, 0.78375, 0.77225]

=== EPOCH 10 ===


  return forward_call(*args, **kwargs)


  epoch10 | Control: 0.7802

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.6315, 0.7775, 0.77925, 0.7775, 0.78, 0.78025, 0.7815, 0.78075, 0.7775, 0.7815]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.63425, 0.77675, 0.77475, 0.7795, 0.78, 0.77975, 0.77625, 0.77875, 0.77625, 0.7715]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.36825, 0.774, 0.766, 0.77425, 0.77725, 0.77225, 0.77325, 0.78125, 0.78325, 0.7825]
=== SAVED RESULTS TO drive/MyDrive/Summer2025/w266/Project/results/qqp_layer_results.json ===


In [22]:
#qnli
qnli_test = qnli_test.map(preprocess_qnli, batched=True, fn_kwargs={'tokenizer': bert_tokenizer})
qnli_res = run_ffn_experiment(dataset_name="qnli",
                          test_data=qnli_test,
                          tokenizer=bert_tokenizer)

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]


=== STARTING QNLI EXPERIMENTS ===

=== EPOCH 1 ===


  return forward_call(*args, **kwargs)


  epoch1 | Control: 0.8233

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.502, 0.72625, 0.76675, 0.771, 0.79525, 0.8095, 0.8155, 0.818, 0.8175, 0.8235]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.51675, 0.7785, 0.80475, 0.812, 0.813, 0.8095, 0.81875, 0.8165, 0.819, 0.805]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.49875, 0.81, 0.80875, 0.818, 0.818, 0.8235, 0.8225, 0.8285, 0.8255, 0.826]

=== EPOCH 2 ===


  return forward_call(*args, **kwargs)


  epoch2 | Control: 0.8283

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.50125, 0.78625, 0.808, 0.81225, 0.82775, 0.8325, 0.8315, 0.8285, 0.8285, 0.82975]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.50025, 0.80425, 0.8265, 0.82925, 0.8265, 0.8265, 0.8315, 0.82575, 0.82875, 0.8245]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.49875, 0.8235, 0.823, 0.823, 0.8265, 0.82975, 0.82625, 0.8315, 0.8285, 0.83275]

=== EPOCH 3 ===


  return forward_call(*args, **kwargs)


  epoch3 | Control: 0.8235

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.48875, 0.7965, 0.81325, 0.8215, 0.827, 0.8305, 0.82575, 0.82625, 0.82525, 0.82775]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.49875, 0.79975, 0.82025, 0.82325, 0.8265, 0.82675, 0.83, 0.82375, 0.826, 0.81975]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.49925, 0.82575, 0.82025, 0.827, 0.8275, 0.827, 0.82525, 0.8275, 0.8295, 0.82725]

=== EPOCH 10 ===


  return forward_call(*args, **kwargs)


  epoch10 | Control: 0.8337

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.49825, 0.77925, 0.79775, 0.81275, 0.82475, 0.82875, 0.83175, 0.8315, 0.82975, 0.83325]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.49875, 0.793, 0.82325, 0.8265, 0.82225, 0.82525, 0.82875, 0.8245, 0.83225, 0.81925]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.49925, 0.8275, 0.829, 0.8255, 0.83, 0.83, 0.829, 0.838, 0.8345, 0.8395]
=== SAVED RESULTS TO drive/MyDrive/Summer2025/w266/Project/results/qnli_layer_results.json ===


In [23]:
#mnli
mnli_test = mnli_test.map(preprocess_mnli, batched=True, fn_kwargs={'tokenizer': bert_tokenizer})
mnli_res = run_ffn_experiment(dataset_name="mnli",
                          test_data=mnli_test,
                          tokenizer=bert_tokenizer)

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]


=== STARTING MNLI EXPERIMENTS ===

=== EPOCH 1 ===


  return forward_call(*args, **kwargs)


  epoch1 | Control: 0.6035

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.3535, 0.59775, 0.606, 0.60675, 0.61275, 0.612, 0.61225, 0.6045, 0.60925, 0.55425]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.31475, 0.60075, 0.604, 0.61225, 0.61175, 0.6095, 0.6075, 0.6085, 0.60975, 0.6165]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.37275, 0.60525, 0.6045, 0.60225, 0.607, 0.605, 0.60325, 0.6025, 0.602, 0.60475]

=== EPOCH 2 ===


  return forward_call(*args, **kwargs)


  epoch2 | Control: 0.6498

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.36425, 0.6405, 0.65625, 0.65625, 0.656, 0.64825, 0.6565, 0.653, 0.64825, 0.59675]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.31825, 0.63325, 0.654, 0.654, 0.65125, 0.65375, 0.65275, 0.64975, 0.65175, 0.65275]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.3675, 0.647, 0.642, 0.648, 0.649, 0.64875, 0.649, 0.64525, 0.64825, 0.63775]

=== EPOCH 3 ===


  return forward_call(*args, **kwargs)


  epoch3 | Control: 0.6610

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.36625, 0.66075, 0.665, 0.66925, 0.6725, 0.668, 0.6705, 0.66825, 0.66625, 0.62375]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.33125, 0.63375, 0.669, 0.663, 0.6705, 0.66825, 0.6645, 0.6605, 0.66275, 0.661]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.3675, 0.65375, 0.65625, 0.659, 0.661, 0.66, 0.6625, 0.66175, 0.6555, 0.6525]

=== EPOCH 10 ===


  return forward_call(*args, **kwargs)


  epoch10 | Control: 0.6787

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 0-3: [0.34925, 0.6605, 0.68, 0.6785, 0.6855, 0.68575, 0.68725, 0.68175, 0.6795, 0.6775]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 4-7: [0.31375, 0.65375, 0.67575, 0.67125, 0.6815, 0.6775, 0.68025, 0.6815, 0.6835, 0.68325]

=== Evaluating decile 0 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 1 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 2 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 3 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 4 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 5 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 6 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 7 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 8 removal ===


  return forward_call(*args, **kwargs)



=== Evaluating decile 9 removal ===


  return forward_call(*args, **kwargs)


    Group 8-11: [0.4195, 0.66825, 0.66975, 0.679, 0.6785, 0.68375, 0.67475, 0.679, 0.6755, 0.67075]
=== SAVED RESULTS TO drive/MyDrive/Summer2025/w266/Project/results/mnli_layer_results.json ===
