In [1]:
!pip uninstall -y protobuf
!pip install protobuf==3.20.3

Found existing installation: protobuf 3.20.3
Uninstalling protobuf-3.20.3:
  Successfully uninstalled protobuf-3.20.3
Collecting protobuf==3.20.3
  Using cached protobuf-3.20.3-py2.py3-none-any.whl.metadata (720 bytes)
Using cached protobuf-3.20.3-py2.py3-none-any.whl (162 kB)
Installing collected packages: protobuf
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.26.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
grain 0.2.15 requires protobuf>=5.28.3, but you have protobuf 3.20.3 which is incompatible.
onnx 1.20.0 requires protobuf>=4.25.1, but you have protobuf 3.20.3 which is incompatible.
ray 2.52.1 requires click!=8.3.*,>=7.0, but you have click 8.3.1 which is incompatible.
opentelemetry-proto 1.37.0 requires protobuf<7.0,>=5.0, but you have protobuf 3.20.3 which is incompatible.
tensorflow-metadata 1.17.2 

In [2]:
import os
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

In [3]:
# =========================
# Imports
# =========================
import json
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import (
    DistilBertTokenizerFast,
    DistilBertForTokenClassification,
    get_linear_schedule_with_warmup
)
from pathlib import Path
import os

2025-12-30 04:15:36.082454: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1767068136.105442     303 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1767068136.112746     303 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1767068136.130296     303 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767068136.130319     303 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767068136.130321     303 computation_placer.cc:177] computation placer alr

In [4]:
# =========================
# Paths (Kaggle-safe)
# =========================
DATA_PATH = "/kaggle/input/mindlens-data/span_ner.jsonl"
LABELS_PATH = "/kaggle/input/mindlens-data/labels.json"
CLASS_WEIGHTS_PATH = "/kaggle/input/mindlens-data/class_weights.pt"

CHECKPOINT_DIR = "/kaggle/working/checkpoints"
Path(CHECKPOINT_DIR).mkdir(parents=True, exist_ok=True)

In [5]:
# =========================
# Device
# =========================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [6]:
from transformers import DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained(
    "distilbert-base-uncased"
)

In [7]:
# =========================
# Load label maps
# =========================
with open(LABELS_PATH) as f:
    label2id = json.load(f)

id2label = {v: k for k, v in label2id.items()}
num_labels = len(label2id)

In [8]:
def collate_fn(batch):
    input_ids = [item["input_ids"] for item in batch]
    attention_masks = [item["attention_mask"] for item in batch]
    labels = [item["labels"] for item in batch]

    input_ids = torch.nn.utils.rnn.pad_sequence(
        input_ids,
        batch_first=True,
        padding_value=tokenizer.pad_token_id
    )

    attention_masks = torch.nn.utils.rnn.pad_sequence(
        attention_masks,
        batch_first=True,
        padding_value=0
    )

    labels = torch.nn.utils.rnn.pad_sequence(
        labels,
        batch_first=True,
        padding_value=-100   # VERY IMPORTANT
    )

    return {
        "input_ids": input_ids,
        "attention_mask": attention_masks,
        "labels": labels
    }

In [9]:
# =========================
# Load class weights
# =========================
ckpt = torch.load(CLASS_WEIGHTS_PATH, map_location="cpu")
class_weights = ckpt["class_weights"].to(device)

# =========================
# Dataset
# =========================
class SpanNERDataset(Dataset):
    def __init__(self, data_or_path):
        if isinstance(data_or_path, str):
            # original behavior (path-based)
            self.samples = []
            with open(data_or_path) as f:
                for line in f:
                    self.samples.append(json.loads(line))
        else:
            # new behavior (list-based)
            self.samples = data_or_path

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        return {
            "input_ids": torch.tensor(s["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(s["attention_mask"], dtype=torch.long),
            "labels": torch.tensor(s["labels"], dtype=torch.long),
        }

In [10]:
import random

with open(DATA_PATH) as f:
    all_samples = [json.loads(line) for line in f]

random.seed(42)
random.shuffle(all_samples)

split = int(0.9 * len(all_samples))
train_samples = all_samples[:split]
val_samples   = all_samples[split:]

# =========================
# DataLoader
# =========================
train_dataset = SpanNERDataset(train_samples)
val_dataset   = SpanNERDataset(val_samples)


train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn
)

val_loader = DataLoader(
    val_dataset,
    batch_size=8,
    shuffle=False,   # IMPORTANT
    collate_fn=collate_fn
)


In [11]:
# =========================
# Model
# =========================
model = DistilBertForTokenClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id
)
model.to(device)

# =========================
# Loss, Optimizer, Scheduler
# =========================
criterion = nn.CrossEntropyLoss(
    weight=class_weights,
    ignore_index=-100
)

optimizer = AdamW(
    model.parameters(),
    lr=5e-5,
    weight_decay=0.01
)

num_epochs = 5
num_training_steps = num_epochs * len(train_loader)
num_warmup_steps = int(0.1 * num_training_steps)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [12]:
def decode_bio(label_ids, id2label):
    """
    Convert BIO label IDs to spans.
    Returns: [(start, end, technique), ...]
    """
    spans = []
    start = None
    curr_tech = None

    for i, lid in enumerate(label_ids):
        if lid == -100:
            continue

        label = id2label[lid]

        if label == "O":
            if start is not None:
                spans.append((start, i - 1, curr_tech))
                start = None
                curr_tech = None
            continue

        prefix, tech = label.split("-", 1)

        if prefix == "B":
            if start is not None:
                spans.append((start, i - 1, curr_tech))
            start = i
            curr_tech = tech

        elif prefix == "I":
            if start is None:
                # illegal I → treat as B
                start = i
                curr_tech = tech
            elif tech != curr_tech:
                spans.append((start, i - 1, curr_tech))
                start = i
                curr_tech = tech

    if start is not None:
        spans.append((start, len(label_ids) - 1, curr_tech))

    return spans

In [13]:
def span_level_metrics(model, dataloader, id2label):
    model.eval()

    TP = FP = FN = 0

    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"]
            )

            preds = outputs.logits.argmax(dim=-1)

            for i in range(preds.size(0)):
                gold_spans = set(
                    decode_bio(batch["labels"][i].tolist(), id2label)
                )
                pred_spans = set(
                    decode_bio(preds[i].tolist(), id2label)
                )

                TP += len(gold_spans & pred_spans)
                FP += len(pred_spans - gold_spans)
                FN += len(gold_spans - pred_spans)

    precision = TP / (TP + FP + 1e-8)
    recall = TP / (TP + FN + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "TP": TP,
        "FP": FP,
        "FN": FN
    }

In [14]:
from collections import defaultdict

def technique_metrics(model, dataloader, id2label):
    model.eval()
    stats = defaultdict(lambda: {"TP": 0, "FP": 0, "FN": 0})

    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"]
            )
            preds = outputs.logits.argmax(dim=-1)

            for i in range(preds.size(0)):
                gold_spans = decode_bio(
                    batch["labels"][i].tolist(), id2label
                )
                pred_spans = decode_bio(
                    preds[i].tolist(), id2label
                )

                gold_by_tech = defaultdict(set)
                pred_by_tech = defaultdict(set)

                for s in gold_spans:
                    gold_by_tech[s[2]].add(s)
                for s in pred_spans:
                    pred_by_tech[s[2]].add(s)

                for tech in set(gold_by_tech) | set(pred_by_tech):
                    g = gold_by_tech[tech]
                    p = pred_by_tech[tech]
                    stats[tech]["TP"] += len(g & p)
                    stats[tech]["FP"] += len(p - g)
                    stats[tech]["FN"] += len(g - p)

    results = {}
    for tech, s in stats.items():
        P = s["TP"] / (s["TP"] + s["FP"] + 1e-8)
        R = s["TP"] / (s["TP"] + s["FN"] + 1e-8)
        F1 = 2 * P * R / (P + R + 1e-8)
        results[tech] = {
            "precision": P,
            "recall": R,
            "f1": F1,
            **s
        }

    return results

In [15]:
# =========================
# Checkpoint utils
# =========================
def save_checkpoint(epoch, step):
    torch.save(
        {
            "epoch": epoch,
            "step": step,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        },
        f"{CHECKPOINT_DIR}/epoch_{epoch}_step_{step}.pt"
    )

import re

def load_latest_checkpoint():
    ckpts = []

    for p in Path(CHECKPOINT_DIR).glob("epoch_*_step_*.pt"):
        m = re.match(r"epoch_(\d+)_step_(\d+)\.pt", p.name)
        if m:
            epoch = int(m.group(1))
            step = int(m.group(2))
            ckpts.append((epoch, step, p))

    if not ckpts:
        return 0, 0

    # sort by epoch, then step
    ckpts.sort(key=lambda x: (x[0], x[1]))
    _, _, ckpt_path = ckpts[-1]

    print("Resuming from:", ckpt_path)

    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])

    return ckpt["epoch"], ckpt["step"] + 1

In [16]:
O_ID = label2id["O"]

# identify B-* labels
B_LABEL_IDS = {
    v for k, v in label2id.items()
    if k.startswith("B-")
}

In [17]:
def evaluate(model, dataloader, criterion):
    model.eval()

    total_loss = 0.0

    # token counts
    total_tokens = 0
    total_correct = 0

    non_o_tokens = 0
    non_o_correct = 0

    b_tokens = 0
    b_correct = 0

    b_pred_as_i = 0
    b_pred_as_o = 0

    with torch.no_grad():
        for batch in dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}

            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"]
            )

            logits = outputs.logits
            labels = batch["labels"]

            loss = criterion(
                logits.view(-1, logits.size(-1)),
                labels.view(-1)
            )
            total_loss += loss.item()

            preds = logits.argmax(dim=-1)

            mask = labels != -100
            valid_labels = labels[mask]
            valid_preds = preds[mask]

            # overall accuracy
            total_correct += (valid_preds == valid_labels).sum().item()
            total_tokens += valid_labels.numel()

            # non-O accuracy
            non_o_mask = valid_labels != O_ID
            non_o_correct += (
                valid_preds[non_o_mask] == valid_labels[non_o_mask]
            ).sum().item()
            non_o_tokens += non_o_mask.sum().item()

            # B-only accuracy
            b_mask = torch.isin(valid_labels, torch.tensor(
                list(B_LABEL_IDS), device=device
            ))

            b_correct += (
                valid_preds[b_mask] == valid_labels[b_mask]
            ).sum().item()
            b_tokens += b_mask.sum().item()

            # B confusion
            b_pred_as_i += (
                torch.isin(valid_labels, torch.tensor(list(B_LABEL_IDS), device=device)) &
                torch.isin(valid_preds, torch.tensor(
                    [v for k, v in label2id.items() if k.startswith("I-")],
                    device=device
                ))
            ).sum().item()

            b_pred_as_o += (
                (valid_labels != O_ID) &
                torch.isin(valid_labels, torch.tensor(list(B_LABEL_IDS), device=device)) &
                (valid_preds == O_ID)
            ).sum().item()

    return {
        "val_loss": total_loss / len(dataloader),
        "acc_all": total_correct / total_tokens,
        "acc_non_o": non_o_correct / max(1, non_o_tokens),
        "acc_b": b_correct / max(1, b_tokens),
        "b_as_i": b_pred_as_i,
        "b_as_o": b_pred_as_o,
        "b_total": b_tokens,
    }


In [21]:
# =========================
# Resume if possible
# =========================
start_epoch, start_step = load_latest_checkpoint()

# =========================
# Metric history (Day 4)
# =========================
train_losses = []
val_losses = []
acc_all_hist = []
acc_non_o_hist = []
acc_b_hist = []

# =========================
# Day 5 metrics dir
# =========================
METRICS_DIR = "/kaggle/working/metrics"
Path(METRICS_DIR).mkdir(exist_ok=True)

# =========================
# Training Loop
# =========================
global_step = start_step

for epoch in range(start_epoch, num_epochs):
    model.train()
    epoch_loss = 0.0
    print(f"\nEpoch {epoch}")

    for step, batch in enumerate(train_loader):
        if epoch == start_epoch and step < start_step:
            continue

        batch = {k: v.to(device) for k, v in batch.items()}

        optimizer.zero_grad()

        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"]
        )

        loss = criterion(
            outputs.logits.view(-1, outputs.logits.size(-1)),
            batch["labels"].view(-1)
        )

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        epoch_loss += loss.item()

        if global_step % 200 == 0:
            print(f"step {global_step} | loss {loss.item():.4f}")

        if global_step % 1000 == 0:
            save_checkpoint(epoch, global_step)

        global_step += 1

    # =========================
    # End of epoch — Day 4 eval
    # =========================
    avg_train_loss = epoch_loss / len(train_loader)

    metrics = evaluate(model, val_loader, criterion)

    train_losses.append(avg_train_loss)
    val_losses.append(metrics["val_loss"])
    acc_all_hist.append(metrics["acc_all"])
    acc_non_o_hist.append(metrics["acc_non_o"])
    acc_b_hist.append(metrics["acc_b"])

    print(
        f"Epoch {epoch} | "
        f"train {avg_train_loss:.4f} | "
        f"val {metrics['val_loss']:.4f} | "
        f"acc {metrics['acc_all']:.4f} | "
        f"non-O {metrics['acc_non_o']:.4f} | "
        f"B {metrics['acc_b']:.4f}"
    )

    print(
        f"    B confusion: "
        f"B→I {metrics['b_as_i']} | "
        f"B→O {metrics['b_as_o']} / {metrics['b_total']}"
    )

    # =========================
    # Day 5 — Span-level metrics
    # =========================
    span_metrics = span_level_metrics(model, val_loader, id2label)
    tech_metrics = technique_metrics(model, val_loader, id2label)

    print(
        f"[Span] "
        f"P {span_metrics['precision']:.3f} | "
        f"R {span_metrics['recall']:.3f} | "
        f"F1 {span_metrics['f1']:.3f}"
    )

    # =========================
    # Save metrics (Day 5)
    # =========================
    with open(f"{METRICS_DIR}/epoch_{epoch}.json", "w") as f:
        json.dump(
            {
                "epoch": epoch,
                "train_loss": avg_train_loss,
                "val_loss": metrics["val_loss"],
                "token_metrics": {
                    "acc_all": metrics["acc_all"],
                    "acc_non_o": metrics["acc_non_o"],
                    "acc_b": metrics["acc_b"],
                    "b_as_i": metrics["b_as_i"],
                    "b_as_o": metrics["b_as_o"],
                    "b_total": metrics["b_total"],
                },
                "span_metrics": span_metrics,
                "technique_metrics": tech_metrics,
            },
            f,
            indent=2
        )

    save_checkpoint(epoch, global_step)

print("✅ Training + Day 5 evaluation complete")



Epoch 0
step 0 | loss 1.3875
Epoch 0 | train 1.5619 | val 2.3032 | acc 0.8148 | non-O 0.1381 | B 0.0420
    B confusion: B→I 70 | B→O 135 / 238
[Span] P 0.016 | R 0.052 | F1 0.024

Epoch 1
Epoch 1 | train 1.5863 | val 2.3032 | acc 0.8148 | non-O 0.1381 | B 0.0420
    B confusion: B→I 70 | B→O 135 / 238
[Span] P 0.016 | R 0.052 | F1 0.024

Epoch 2
Epoch 2 | train 1.5900 | val 2.3032 | acc 0.8148 | non-O 0.1381 | B 0.0420
    B confusion: B→I 70 | B→O 135 / 238
[Span] P 0.016 | R 0.052 | F1 0.024

Epoch 3
Epoch 3 | train 1.5774 | val 2.3032 | acc 0.8148 | non-O 0.1381 | B 0.0420
    B confusion: B→I 70 | B→O 135 / 238
[Span] P 0.016 | R 0.052 | F1 0.024

Epoch 4
step 200 | loss 1.8423
Epoch 4 | train 1.5798 | val 2.3032 | acc 0.8148 | non-O 0.1381 | B 0.0420
    B confusion: B→I 70 | B→O 135 / 238
[Span] P 0.016 | R 0.052 | F1 0.024
✅ Training + Day 5 evaluation complete


In [1]:
def render_spans(text, spans, tag):
    """
    spans: List[(char_start, char_end, technique)]
    tag: 'GOLD' or 'PRED'
    """
    spans = sorted(spans, key=lambda x: x[0])
    rendered = ""
    last = 0

    for start, end, tech in spans:
        rendered += text[last:start]
        rendered += f"[{tag}|{tech}]" + text[start:end] + f"[/{tag}]"
        last = end

    rendered += text[last:]
    return rendered

In [2]:
def token_span_to_char(span, offsets):
    """
    span: (start_token, end_token, technique)
    offsets: tokenizer offset_mapping
    """
    start_tok, end_tok, tech = span

    char_start = offsets[start_tok][0]
    char_end = offsets[end_tok][1]

    return (char_start, char_end, tech)

In [3]:
def visualize_sample(text, offsets, gold_spans, pred_spans):
    gold_char_spans = [
        token_span_to_char(s, offsets) for s in gold_spans
    ]
    pred_char_spans = [
        token_span_to_char(s, offsets) for s in pred_spans
    ]

    gold_view = render_spans(text, gold_char_spans, "GOLD")
    pred_view = render_spans(text, pred_char_spans, "PRED")

    print("\n===== GOLD =====\n")
    print(gold_view)
    print("\n===== PRED =====\n")
    print(pred_view)

In [7]:
failure_cases = []

with torch.no_grad():
    for batch in val_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"]
        )
        preds = outputs.logits.argmax(dim=-1)

        for i in range(preds.size(0)):
            gold_spans = decode_bio(
                batch["labels"][i].tolist(), id2label
            )
            pred_spans = decode_bio(
                preds[i].tolist(), id2label
            )

            if set(gold_spans) != set(pred_spans):
                print("FOUND FAILURE")
                print("Gold:", gold_spans)
                print("Pred:", pred_spans)

                failure_cases.append({
                    "gold": gold_spans,
                    "pred": pred_spans,
                    "input_ids": batch["input_ids"][i],
                })
                break

        if failure_cases:
            break

FOUND FAILURE
Gold: [(63, 84, 'Appeal_to_fear-prejudice'), (114, 126, 'Loaded_Language'), (130, 142, 'Loaded_Language'), (167, 168, 'Loaded_Language'), (225, 232, 'Loaded_Language'), (402, 412, 'Doubt'), (413, 428, 'Doubt')]
Pred: [(62, 68, 'Appeal_to_fear-prejudice'), (69, 71, 'Black-and-White_Fallacy'), (72, 72, 'Appeal_to_fear-prejudice'), (73, 74, 'Black-and-White_Fallacy'), (75, 79, 'Appeal_to_fear-prejudice'), (80, 80, 'Black-and-White_Fallacy'), (81, 81, 'Appeal_to_fear-prejudice'), (113, 113, 'Exaggeration,Minimisation'), (114, 117, 'Name_Calling,Labeling'), (118, 119, 'Exaggeration,Minimisation'), (120, 121, 'Loaded_Language'), (122, 124, 'Loaded_Language'), (125, 126, 'Exaggeration,Minimisation'), (130, 130, 'Causal_Oversimplification'), (131, 134, 'Slogans'), (135, 135, 'Flag-Waving'), (136, 136, 'Slogans'), (137, 137, 'Flag-Waving'), (138, 138, 'Causal_Oversimplification'), (139, 139, 'Slogans'), (140, 140, 'Causal_Oversimplification'), (141, 141, 'Flag-Waving'), (142, 142,

In [8]:
ex = failure_cases[0]
print("Gold spans:", ex["gold"])
print("Pred spans:", ex["pred"])

Gold spans: [(63, 84, 'Appeal_to_fear-prejudice'), (114, 126, 'Loaded_Language'), (130, 142, 'Loaded_Language'), (167, 168, 'Loaded_Language'), (225, 232, 'Loaded_Language'), (402, 412, 'Doubt'), (413, 428, 'Doubt')]
Pred spans: [(62, 68, 'Appeal_to_fear-prejudice'), (69, 71, 'Black-and-White_Fallacy'), (72, 72, 'Appeal_to_fear-prejudice'), (73, 74, 'Black-and-White_Fallacy'), (75, 79, 'Appeal_to_fear-prejudice'), (80, 80, 'Black-and-White_Fallacy'), (81, 81, 'Appeal_to_fear-prejudice'), (113, 113, 'Exaggeration,Minimisation'), (114, 117, 'Name_Calling,Labeling'), (118, 119, 'Exaggeration,Minimisation'), (120, 121, 'Loaded_Language'), (122, 124, 'Loaded_Language'), (125, 126, 'Exaggeration,Minimisation'), (130, 130, 'Causal_Oversimplification'), (131, 134, 'Slogans'), (135, 135, 'Flag-Waving'), (136, 136, 'Slogans'), (137, 137, 'Flag-Waving'), (138, 138, 'Causal_Oversimplification'), (139, 139, 'Slogans'), (140, 140, 'Causal_Oversimplification'), (141, 141, 'Flag-Waving'), (142, 142, '

In [9]:
import os

for f in os.listdir("/kaggle/working/checkpoints"):
    print(f)

epoch_2_step_123.pt
epoch_0_step_41.pt
epoch_4_step_205.pt
epoch_0_step_0.pt
epoch_3_step_164.pt
epoch_1_step_82.pt


In [10]:
import torch

ckpt_path = "/kaggle/working/checkpoints/epoch_4_step_205.pt"

ckpt = torch.load(ckpt_path, map_location="cpu")

print(ckpt.keys())


dict_keys(['epoch', 'step', 'model', 'optimizer', 'scheduler'])


In [11]:
MODEL_EXPORT_DIR = "/kaggle/working/export"
os.makedirs(MODEL_EXPORT_DIR, exist_ok=True)

torch.save(
    {
        "model_state_dict": ckpt["model"],
        "label2id": label2id,
        "id2label": id2label,
        "num_labels": num_labels,
        "base_model": "distilbert-base-uncased",
        "notes": "MindLens Span Detector v1.0 — boundary-fragmentation baseline"
    },
    f"{MODEL_EXPORT_DIR}/span_detector_v1.pt"
)
