In [None]:
import os
import random

import numpy as np
import pandas as pd
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, f1_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer
# --- Configuration -----------------------------------------------------------
# Column of the CSVs that TokenizedTextDataset reads as the integer class label.
LABEL = 'mistake_identification'
# Hugging Face checkpoint used for both the tokenizer and the encoder.
MODEL = "answerdotai/ModernBERT-large"
SEED = 42

# Apply SEED to every RNG the notebook imports. Previously SEED was declared but
# never used, so any stochastic step added later would not be reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Data --------------------------------------------------------------------
# Train / validation splits; paths are relative to the notebook's working directory.
train_df = pd.read_csv("mistake_identification_train.csv")
val_df = pd.read_csv("mistake_identification_val.csv")


In [None]:
from torch.utils.data import Dataset

class TokenizedTextDataset(Dataset):
    """Dataset that tokenizes the whole ``response`` column once, at construction.

    Args:
        df: DataFrame providing a ``"response"`` text column and the column
            named by the module-level ``LABEL`` constant (cast to ``int``).
        tokenizer: callable invoked once on the list of texts with
            ``truncation=True, padding=True, max_length=max_length,
            return_tensors="pt"``; its result is indexed by ``"input_ids"``
            and ``"attention_mask"``.
        max_length: forwarded to the tokenizer as ``max_length``.

    Each item is a dict with keys ``"input_ids"``, ``"attention_mask"`` and
    ``"labels"`` (a ``torch.long`` scalar tensor).
    """

    # Keys copied straight from the tokenizer output into every sample.
    _ENCODING_KEYS = ("input_ids", "attention_mask")

    def __init__(self, df, tokenizer, max_length=160):
        self.labels = df[LABEL].astype(int).tolist()

        tokenizer_kwargs = dict(
            truncation=True,
            padding=True,
            max_length=max_length,
            return_tensors="pt",
        )
        self.encodings = tokenizer(df["response"].tolist(), **tokenizer_kwargs)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sample = {key: self.encodings[key][idx] for key in self._ENCODING_KEYS}
        sample["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample

# One tokenizer instance, shared by both splits.
tokenizer = AutoTokenizer.from_pretrained(MODEL)

# Pre-tokenized datasets built from the "response" column of each split.
train_ds_nocontext, val_ds_nocontext = (
    TokenizedTextDataset(split_df, tokenizer) for split_df in (train_df, val_df)
)

# No shuffling: batches are served in dataset order for both splits.
train_loader, val_loader = (
    DataLoader(split_ds, batch_size=32, shuffle=False)
    for split_ds in (train_ds_nocontext, val_ds_nocontext)
)


In [None]:
from transformers import AutoModel, AutoTokenizer
import torch
import os
import numpy as np
import random
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, f1_score
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader
from tqdm import tqdm
import pandas as pd

# Output directory for the per-layer embedding files written by the probing loop.
os.makedirs("layerwise_embeddings", exist_ok=True)

# SETTINGS
model_name = MODEL
# Indices into the encoder's `hidden_states` output that get probed.
# Rationale (original note): Merchant et al., 2020 and related BERTology papers
# report the most useful layers around the middle-to-late region.
# NOTE(review): the list tops out at 24 — confirm this matches the depth of
# MODEL (model.config.num_hidden_layers), otherwise the last layers are never probed.
layers_to_test = [1, 6, 9, 11, 12, 13, 14, 15, 17, 18, 21, 23, 24]
# NOTE(review): the DataLoaders are built with a literal batch_size=32 and the
# probing loop passes pooling="mean" explicitly, so the two settings below are
# not read by the code shown in this notebook.
batch_size = 32
pooling = "cls"  # or "mean"

# Load model + tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use the GPU when one is present, but fall back to CPU instead of crashing
# (the previous unconditional .cuda() raised on CPU-only machines).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModel.from_pretrained(model_name, output_hidden_states=True).eval().to(device)


def extract_layer_embeddings(dataloader, layer_idx, pooling="cls", encoder=None):
    """Run the encoder over ``dataloader`` and pool one hidden layer per example.

    Args:
        dataloader: iterable of batches; each batch is a mapping with tensor
            entries ``"input_ids"``, ``"attention_mask"`` and ``"labels"``.
        layer_idx: index into the ``hidden_states`` sequence returned by the
            encoder.
        pooling: ``"cls"`` keeps the vector at sequence position 0;
            ``"mean"`` takes the attention-mask-weighted average over the
            sequence dimension.
        encoder: model to run. Defaults to the notebook-level ``model``. It must
            expose ``.device`` and ``.eval()`` and, when called with
            ``input_ids``/``attention_mask``, return an object that has a
            ``hidden_states`` attribute.

    Returns:
        ``(embeddings, labels)``: the pooled vectors of all batches concatenated
        (moved to CPU) and the concatenated labels, both in dataloader order.

    Raises:
        ValueError: if ``pooling`` is neither ``"cls"`` nor ``"mean"``, or if
            the dataloader yields no batches.
    """
    # Fail fast: reject an unknown strategy before paying for any forward pass.
    if pooling not in ("cls", "mean"):
        raise ValueError("Invalid pooling strategy")

    if encoder is None:
        encoder = model  # notebook-level model

    print("pooling:", pooling)
    all_embeds = []
    all_labels = []
    encoder.eval()
    device = encoder.device

    with torch.no_grad():
        for batch in tqdm(dataloader):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].clone().detach()

            outputs = encoder(input_ids=input_ids, attention_mask=attention_mask)
            selected = outputs.hidden_states[layer_idx]  # indexed as [batch, position, feature]

            if pooling == "cls":
                pooled = selected[:, 0, :]
            else:  # "mean"
                # The trailing singleton axis broadcasts over the feature axis.
                mask = attention_mask.unsqueeze(-1).float()
                summed = torch.sum(selected * mask, 1)
                # Clamp so a fully-masked row divides by a tiny number, not zero.
                counts = torch.clamp(mask.sum(1), min=1e-9)
                pooled = summed / counts

            all_embeds.append(pooled.cpu())
            all_labels.append(labels)

    if not all_embeds:
        # torch.cat on an empty list fails with an opaque message; be explicit.
        raise ValueError("dataloader yielded no batches")

    return torch.cat(all_embeds), torch.cat(all_labels)

results = []
# Pooling strategy for this sweep; the summary frame is named after it.
probe_pooling = "mean"

# NOTE(review): each layer triggers a fresh forward pass over both splits; a
# single pass that collects every requested layer would avoid the repeated work.
for layer in layers_to_test:
    print(f"\n==> Extracting features from layer {layer}")

    X_train, y_train = extract_layer_embeddings(train_loader, layer_idx=layer, pooling=probe_pooling)
    X_val, y_val = extract_layer_embeddings(val_loader, layer_idx=layer, pooling=probe_pooling)

    # Persist the raw (unscaled) embeddings so other probes (kNN, SVM, MLP)
    # can be tried later without re-running the encoder.
    torch.save({
        "X_train": X_train,
        "y_train": y_train,
        "X_val": X_val,
        "y_val": y_val
    }, f"layerwise_embeddings/layer_{layer}.pt")

    X_train_np, y_train_np = np.asarray(X_train), np.asarray(y_train)
    X_val_np, y_val_np = np.asarray(X_val), np.asarray(y_val)

    # Linear probe. Features are standardized first: on raw embeddings the solver
    # stopped at the max_iter cap for the later layers, which makes their scores
    # unreliable. The scaler is fitted on the training split only.
    clf = make_pipeline(StandardScaler(), LogisticRegression(max_iter=10000))
    clf.fit(X_train_np, y_train_np)
    y_pred = clf.predict(X_val_np)
    f1 = f1_score(y_val_np, y_pred, average="macro")

    results.append({
        "layer": layer,
        "macro_f1": f1
    })

df_layers_mean = pd.DataFrame(results)




==> Extracting features from layer 1
pooling: mean


100%|██████████| 62/62 [00:13<00:00,  4.47it/s]


pooling: mean


100%|██████████| 16/16 [00:03<00:00,  4.68it/s]



==> Extracting features from layer 6
pooling: mean


100%|██████████| 62/62 [00:13<00:00,  4.51it/s]


pooling: mean


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]



==> Extracting features from layer 9
pooling: mean


100%|██████████| 62/62 [00:13<00:00,  4.51it/s]


pooling: mean


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]



==> Extracting features from layer 11
pooling: mean


100%|██████████| 62/62 [00:13<00:00,  4.52it/s]


pooling: mean


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]



==> Extracting features from layer 12
pooling: mean


100%|██████████| 62/62 [00:13<00:00,  4.51it/s]


pooling: mean


100%|██████████| 16/16 [00:03<00:00,  4.70it/s]



==> Extracting features from layer 13
pooling: mean


100%|██████████| 62/62 [00:13<00:00,  4.51it/s]


pooling: mean


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]



==> Extracting features from layer 14
pooling: mean


100%|██████████| 62/62 [00:13<00:00,  4.52it/s]


pooling: mean


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]



==> Extracting features from layer 15
pooling: mean


100%|██████████| 62/62 [00:13<00:00,  4.51it/s]


pooling: mean


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]



==> Extracting features from layer 17
pooling: mean


100%|██████████| 62/62 [00:13<00:00,  4.51it/s]


pooling: mean


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]



==> Extracting features from layer 18
pooling: mean


100%|██████████| 62/62 [00:13<00:00,  4.51it/s]


pooling: mean


100%|██████████| 16/16 [00:03<00:00,  4.70it/s]



==> Extracting features from layer 21
pooling: mean


100%|██████████| 62/62 [00:13<00:00,  4.52it/s]


pooling: mean


100%|██████████| 16/16 [00:03<00:00,  4.70it/s]
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(



==> Extracting features from layer 23
pooling: mean


100%|██████████| 62/62 [00:13<00:00,  4.53it/s]


pooling: mean


100%|██████████| 16/16 [00:03<00:00,  4.70it/s]
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(



==> Extracting features from layer 24
pooling: mean


100%|██████████| 62/62 [00:13<00:00,  4.51it/s]


pooling: mean


100%|██████████| 16/16 [00:03<00:00,  4.69it/s]
STOP: TOTAL NO. OF ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [None]:
# Validation macro-F1 of the logistic-regression probe per layer (mean-pooling sweep).
# NOTE(review): the saved output below stops at layer 18 although layers_to_test
# also contains 21, 23 and 24 — presumably stale; re-run this cell to refresh it.
df_layers_mean

Unnamed: 0,layer,macro_f1
0,1,0.624561
1,6,0.64783
2,9,0.649626
3,11,0.622133
4,12,0.610932
5,13,0.623624
6,14,0.609708
7,15,0.591431
8,17,0.576206
9,18,0.610633


In [None]:
# NOTE(review): df_layers_cls is not assigned in any cell shown in this notebook
# (the probing cell only builds df_layers_mean), and the saved output lists a
# layer 0 that is absent from layers_to_test. Presumably it comes from an earlier
# run with pooling="cls"; on a fresh kernel this line would raise NameError — confirm.
df_layers_cls

Unnamed: 0,layer,macro_f1
0,0,0.294029
1,1,0.632581
2,6,0.602586
3,9,0.62688
4,11,0.620278
5,12,0.63936
6,15,0.645377
7,17,0.609596
8,18,0.630895
9,21,0.596523
