This code uses the SetFit model to perform Task 2.

In [1]:
!pip install -q setfit datasets scikit-learn

import os
import sys
import shutil

import numpy as np
import pandas as pd

from datasets import load_dataset, Dataset
from setfit import SetFitModel, Trainer, TrainingArguments
from sklearn.metrics import f1_score, classification_report

# Switch off Weights & Biases logging for the whole session via the
# environment. NOTE(review): the captured run log reports that
# WANDB_DISABLED is deprecated in favour of `report_to`.
os.environ.update(WANDB_DISABLED="true")


# Fetch the QEvasion dataset from the Hugging Face Hub
print("Loading QEvasion dataset...")
raw_ds = load_dataset(path="ailsntua/QEvasion")


def combine_text(example):
    """Attach a combined question/answer prompt to a dataset row.

    Reads the ``question`` and ``interview_answer`` fields (a missing or
    falsy value is treated as an empty string), stores the formatted prompt
    under ``example["text"]`` and returns the same, mutated mapping.
    """
    question, answer = (
        example.get(field) or "" for field in ("question", "interview_answer")
    )
    example["text"] = f"Question: {question}\nAnswer: {answer}"
    return example


# Add the combined "text" column to every row of the HF train split
train_all = raw_ds["train"].map(combine_text)

print(f"Rows in HF train split: {len(train_all)}")
print("Columns:", train_all.column_names)

# The evasion label is the training target; abort early if it is absent
if "evasion_label" not in train_all.column_names:
    sys.exit("FATAL: 'evasion_label' column not found in HF train split.")

# Hold out 10% of the HF train split as a dev set (fixed seed)
splits = train_all.train_test_split(test_size=0.1, seed=42)
train_ds, dev_ds = splits["train"], splits["test"]

print(f"Train split: {len(train_ds)} rows | Dev split: {len(dev_ds)} rows")

# The public test split is optional; None tells the export step to skip it
has_public_test = "test" in raw_ds
test_public = raw_ds["test"].map(combine_text) if has_public_test else None
if test_public is None:
    print("No public 'test' split found; skipping CSV later.")
else:
    print(f"Public test split rows: {len(test_public)}")


# Balance train split (strict equalization)
print("\nBalancing train data (strict equalization)...")

df_train = train_ds.to_pandas()
df_train = df_train.dropna(subset=["text", "evasion_label"])

counts = df_train["evasion_label"].value_counts()
print("Original train label counts:\n", counts, "\n")

MIN_REQUIRED = 8     # classes with fewer rows than this are excluded
MAX_PER_CLASS = 100  # upper bound on the rows sampled from each class
valid_labels = counts[counts >= MIN_REQUIRED].index.tolist()

if len(valid_labels) < 2:
    sys.exit("FATAL: Need at least 2 classes with MIN_REQUIRED samples.")

# Every kept class is down-sampled to the size of the smallest kept class
# (capped at MAX_PER_CLASS). Cast to a built-in int so a NumPy integer is
# not passed on to `sample`.
target_size = int(min(counts[valid_labels].min(), MAX_PER_CLASS))
print(f"Target samples per class: {target_size}")
print(f"Classes used for training: {valid_labels}\n")

# No per-class size check is needed: target_size never exceeds the row count
# of any label in valid_labels, so `sample` always has enough rows.
balanced_dfs = [
    df_train[df_train["evasion_label"] == label].sample(
        n=target_size, random_state=42
    )
    for label in valid_labels
]

# Shuffle the concatenated per-class samples and drop the old row index
df_balanced = (
    pd.concat(balanced_dfs)
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
)

train_ds_final = Dataset.from_pandas(df_balanced)

print(
    f"Final balanced train set: {len(train_ds_final)} samples "
    f"({target_size} per class)\n"
)


# Set up SetFit model and trainer
# The sorted label list is given to the model here and reused later as the
# label order for dev evaluation.
valid_labels_sorted = sorted(valid_labels)
model_id = "sentence-transformers/all-MiniLM-L6-v2"

model = SetFitModel.from_pretrained(model_id, labels=valid_labels_sorted)

# Settings applied on every trainer.train() call; evaluation and
# checkpointing by the trainer itself are turned off because the script
# handles both manually.
args = TrainingArguments(
    num_epochs=1,
    num_iterations=10,
    batch_size=16,
    eval_strategy="no",
    save_strategy="no",
    report_to="none",
)

# Dataset column name -> column name handed to the trainer
column_mapping = {"text": "text", "evasion_label": "label"}

trainer = Trainer(
    train_dataset=train_ds_final,
    column_mapping=column_mapping,
    args=args,
    model=model,
)


# 4. Dev evaluation (macro F1)
def evaluate_on_dev(current_model, dev_ds, label_list):
    """Score a model on the dev split with macro F1 and a per-class report.

    Only dev rows whose ``evasion_label`` appears in ``label_list`` are
    predicted and scored.

    Args:
        current_model: Object with a ``predict(list_of_texts)`` method that
            returns one label per text; every prediction must be a member of
            ``label_list``.
        dev_ds: Dataset exposing ``to_pandas()`` whose frame has ``text`` and
            ``evasion_label`` columns.
        label_list: Ordered label names; a label's position is used as its
            integer class id.

    Returns:
        ``(macro_f1, report)`` where ``report`` is the text produced by
        ``classification_report``, or ``(None, None)`` when no dev row
        carries a label from ``label_list``.
    """
    dev_df = dev_ds.to_pandas()
    mask = dev_df["evasion_label"].isin(set(label_list))

    if not mask.any():
        print("No dev samples match the training labels; cannot compute F1.")
        print("Training labels:", label_list)
        print("Dev label counts:\n", dev_df["evasion_label"].value_counts(dropna=False))
        return None, None

    # Predict only for the rows that are actually scored.
    dev_df_eval = dev_df[mask]
    dev_preds = current_model.predict(dev_df_eval["text"].tolist())

    label2id = {label: idx for idx, label in enumerate(label_list)}
    y_true = [label2id[label] for label in dev_df_eval["evasion_label"]]
    y_pred = [label2id[pred] for pred in dev_preds]

    # Restrict the report to class ids that occur in truth or predictions.
    # Passing all of label_list as target_names without `labels` raises a
    # ValueError as soon as one training label is missing from both. This is
    # also the label set f1_score averages over by default, so the report's
    # macro average and macro_f1 stay consistent.
    present_ids = sorted(set(y_true) | set(y_pred))

    macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    report = classification_report(
        y_true,
        y_pred,
        labels=present_ids,
        target_names=[label_list[i] for i in present_ids],
        zero_division=0,
    )

    return macro_f1, report


# 5. Manual training loop with early stopping
max_epochs, patience = 5, 2

# Best dev score so far, the epoch it was reached in, and the number of
# consecutive epochs without a new best.
best_f1, best_epoch, epochs_no_improve = -1.0, 0, 0

# Start from a clean checkpoint directory so a stale model is never reloaded
best_model_dir = "best_setfit_model_tmp"
if os.path.isdir(best_model_dir):
    shutil.rmtree(best_model_dir)

print("Starting manual training with early stopping.\n")

for epoch in range(1, max_epochs + 1):
    print(f"Epoch {epoch}/{max_epochs}")

    # Each call continues training the same model object
    trainer.train()

    macro_f1, report = evaluate_on_dev(model, dev_ds, valid_labels_sorted)
    if macro_f1 is None:
        print("Stopping because dev evaluation is not possible.")
        break

    print(f"\n[Epoch {epoch}] Dev macro F1: {macro_f1:.4f}")
    print(report)

    # A new best must beat the previous one by more than 1e-4
    improved = macro_f1 > best_f1 + 1e-4
    if not improved:
        epochs_no_improve += 1
        print(f"No improvement for {epochs_no_improve} epoch(s).")
        if epochs_no_improve >= patience:
            print(f"\nEarly stopping triggered (patience={patience}).")
            break
        continue

    best_f1, best_epoch, epochs_no_improve = macro_f1, epoch, 0
    print(f"New best F1 = {best_f1:.4f} at epoch {epoch}.")

    # Replace the previous checkpoint with the current model
    if os.path.isdir(best_model_dir):
        shutil.rmtree(best_model_dir)
    os.makedirs(best_model_dir, exist_ok=True)
    model.save_pretrained(best_model_dir)

# Use the saved checkpoint if one was written; otherwise keep the live model
has_checkpoint = os.path.isdir(best_model_dir) and bool(os.listdir(best_model_dir))
if not has_checkpoint:
    print("\nNo best model directory found; using last model as best.")
    best_model = model
else:
    print(f"\nReloading best model from epoch {best_epoch} (F1={best_f1:.4f})")
    best_model = SetFitModel.from_pretrained(best_model_dir)


# 6. Final dev evaluation
print("\nFinal eval on dev (best model)")
final_f1, final_report = evaluate_on_dev(best_model, dev_ds, valid_labels_sorted)
# A None score means evaluate_on_dev could not score any dev row
if final_f1 is not None:
    print(f"\nFinal dev macro F1: {final_f1:.4f}", final_report, sep="\n")


# 7. Public test predictions CSV
if test_public is not None:
    print("\nGenerating test CSV (no F1).")

    test_df_public = test_public.to_pandas()

    # Hand predict() a plain Python list, the same input type used for dev
    # evaluation. Indexing the Dataset by column name may return a
    # version-dependent column object rather than a list (presumably a lazy
    # column in newer `datasets` releases - verify against the installed
    # version), so the texts are taken from the DataFrame instead.
    test_texts = test_df_public["text"].tolist()
    test_preds = best_model.predict(test_texts)

    # Prefer the dataset's own "index" column; fall back to row positions
    if "index" in test_df_public.columns:
        test_indices = test_df_public["index"]
    else:
        test_indices = np.arange(len(test_df_public))

    df_res = pd.DataFrame({
        "index": test_indices,
        "evasion_label": test_preds,
    })

    csv_path = "setfit_minilm_predictions.csv"
    df_res.to_csv(csv_path, index=False)
    print(f"CSV saved locally as: {csv_path}")
else:
    print("\nNo public test split; skipping CSV export.")


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/75.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m75.6/75.6 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25hLoading QEvasion dataset...


  return datetime.utcnow().replace(tzinfo=utc)
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/3.90M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/259k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3448 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/308 [00:00<?, ? examples/s]

Map:   0%|          | 0/3448 [00:00<?, ? examples/s]

Rows in HF train split: 3448
Columns: ['title', 'date', 'president', 'url', 'question_order', 'interview_question', 'interview_answer', 'gpt3.5_summary', 'gpt3.5_prediction', 'question', 'annotator_id', 'annotator1', 'annotator2', 'annotator3', 'inaudible', 'multiple_questions', 'affirmative_questions', 'index', 'clarity_label', 'evasion_label', 'text']
Train split: 3103 rows | Dev split: 345 rows


Map:   0%|          | 0/308 [00:00<?, ? examples/s]

Public test split rows: 308

Balancing train data (strict equalization)...
Original train label counts:
 evasion_label
Explicit               954
Dodging                623
Implicit               425
Deflection             355
General                348
Declining to answer    131
Claims ignorance       110
Clarification           83
Partial/half-answer     74
Name: count, dtype: int64 

Target samples per class: 74
Classes used for training: ['Explicit', 'Dodging', 'Implicit', 'Deflection', 'General', 'Declining to answer', 'Claims ignorance', 'Clarification', 'Partial/half-answer']

Final balanced train set: 666 samples (74 per class)



  return datetime.utcnow().replace(tzinfo=utc)


config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

model_head.pkl not found on HuggingFace Hub, initialising classification head with random weights. You should TRAIN this model on a downstream task to use it for predictions and inference.
Applying column mapping to the training dataset
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
  return datetime.utcnow().replace(tzinfo=utc)
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Map:   0%|          | 0/666 [00:00<?, ? examples/s]

Starting manual training with early stopping.

Epoch 1/5


***** Running training *****
  Num unique pairs = 13320
  Batch size = 16
  Num epochs = 1
  return data.pin_memory(device)
  return data.pin_memory(device)


Step,Training Loss
1,0.3845
50,0.2887
100,0.2498
150,0.2435
200,0.2467
250,0.2302
300,0.2351
350,0.2324
400,0.2319
450,0.2375


  opt_res = optimize.minimize(
  return datetime.utcnow().replace(tzinfo=utc)



[Epoch 1] Dev macro F1: 0.2453
                     precision    recall  f1-score   support

   Claims ignorance       0.14      0.56      0.23         9
      Clarification       0.54      0.78      0.64         9
Declining to answer       0.11      0.21      0.14        14
         Deflection       0.20      0.46      0.28        26
            Dodging       0.46      0.19      0.27        83
           Explicit       0.48      0.23      0.32        98
            General       0.12      0.13      0.13        38
           Implicit       0.24      0.19      0.21        63
Partial/half-answer       0.00      0.00      0.00         5

           accuracy                           0.24       345
          macro avg       0.25      0.31      0.25       345
       weighted avg       0.34      0.24      0.26       345

New best F1 = 0.2453 at epoch 1.


  return datetime.utcnow().replace(tzinfo=utc)


Epoch 2/5


***** Running training *****
  Num unique pairs = 13320
  Batch size = 16
  Num epochs = 1
  return data.pin_memory(device)
  return data.pin_memory(device)


Step,Training Loss
1,0.1687
50,0.2148
100,0.2061
150,0.1966
200,0.1983
250,0.1787
300,0.1733
350,0.1743
400,0.1811
450,0.1787


  opt_res = optimize.minimize(
  return datetime.utcnow().replace(tzinfo=utc)



[Epoch 2] Dev macro F1: 0.2991
                     precision    recall  f1-score   support

   Claims ignorance       0.26      0.67      0.38         9
      Clarification       0.50      0.78      0.61         9
Declining to answer       0.20      0.29      0.24        14
         Deflection       0.19      0.42      0.26        26
            Dodging       0.44      0.33      0.38        83
           Explicit       0.57      0.26      0.35        98
            General       0.28      0.26      0.27        38
           Implicit       0.25      0.14      0.18        63
Partial/half-answer       0.02      0.20      0.04         5

           accuracy                           0.29       345
          macro avg       0.30      0.37      0.30       345
       weighted avg       0.39      0.29      0.31       345

New best F1 = 0.2991 at epoch 2.


  return datetime.utcnow().replace(tzinfo=utc)


Epoch 3/5


***** Running training *****
  Num unique pairs = 13320
  Batch size = 16
  Num epochs = 1
  return data.pin_memory(device)
  return data.pin_memory(device)


Step,Training Loss
1,0.1334
50,0.1706
100,0.1568
150,0.1553
200,0.1586
250,0.1462
300,0.1422
350,0.15
400,0.1477
450,0.1521


  opt_res = optimize.minimize(
  return datetime.utcnow().replace(tzinfo=utc)



[Epoch 3] Dev macro F1: 0.2794
                     precision    recall  f1-score   support

   Claims ignorance       0.22      0.44      0.30         9
      Clarification       0.55      0.67      0.60         9
Declining to answer       0.18      0.29      0.22        14
         Deflection       0.22      0.38      0.28        26
            Dodging       0.40      0.37      0.39        83
           Explicit       0.53      0.19      0.28        98
            General       0.19      0.21      0.20        38
           Implicit       0.27      0.14      0.19        63
Partial/half-answer       0.03      0.40      0.06         5

           accuracy                           0.27       345
          macro avg       0.29      0.34      0.28       345
       weighted avg       0.36      0.27      0.28       345

No improvement for 1 epoch(s).
Epoch 4/5


  return datetime.utcnow().replace(tzinfo=utc)
***** Running training *****
  Num unique pairs = 13320
  Batch size = 16
  Num epochs = 1
  return data.pin_memory(device)
  return data.pin_memory(device)


Step,Training Loss
1,0.1086
50,0.151
100,0.1329
150,0.1324
200,0.1388
250,0.1275
300,0.1198
350,0.1195
400,0.1223
450,0.1274


  opt_res = optimize.minimize(
  return datetime.utcnow().replace(tzinfo=utc)



[Epoch 4] Dev macro F1: 0.2853
                     precision    recall  f1-score   support

   Claims ignorance       0.26      0.56      0.36         9
      Clarification       0.60      0.67      0.63         9
Declining to answer       0.24      0.36      0.29        14
         Deflection       0.23      0.38      0.29        26
            Dodging       0.40      0.25      0.31        83
           Explicit       0.47      0.24      0.32        98
            General       0.20      0.21      0.21        38
           Implicit       0.21      0.11      0.15        63
Partial/half-answer       0.01      0.20      0.03         5

           accuracy                           0.25       345
          macro avg       0.29      0.33      0.29       345
       weighted avg       0.34      0.25      0.27       345

No improvement for 2 epoch(s).

Early stopping triggered (patience=2).

Reloading best model from epoch 2 (F1=0.2991)


  return datetime.utcnow().replace(tzinfo=utc)



Final eval on dev (best model)

Final dev macro F1: 0.2991
                     precision    recall  f1-score   support

   Claims ignorance       0.26      0.67      0.38         9
      Clarification       0.50      0.78      0.61         9
Declining to answer       0.20      0.29      0.24        14
         Deflection       0.19      0.42      0.26        26
            Dodging       0.44      0.33      0.38        83
           Explicit       0.57      0.26      0.35        98
            General       0.28      0.26      0.27        38
           Implicit       0.25      0.14      0.18        63
Partial/half-answer       0.02      0.20      0.04         5

           accuracy                           0.29       345
          macro avg       0.30      0.37      0.30       345
       weighted avg       0.39      0.29      0.31       345


Generating test CSV (no F1).
CSV saved locally as: setfit_minilm_predictions.csv


  return datetime.utcnow().replace(tzinfo=utc)
  return datetime.utcnow().replace(tzinfo=utc)
