# In [None]:
import json
import torch
from transformers import (
    DistilBertForTokenClassification,
    DistilBertTokenizerFast,
    Trainer,
    TrainingArguments
)
from datasets import Dataset
import numpy as np

# Configuration
# Checkpoint name passed to from_pretrained() for both the tokenizer and the model.
MODEL_NAME = "distilbert-base-uncased"
# JSON Lines file read by load_dataset(); one JSON record per line.
TRAIN_DATA_PATH = "./tasteset_final.jsonl"  # Your cleaned dataset
# Directory used for Trainer output and for the final saved model/tokenizer.
OUTPUT_DIR = "./recipe_ner_model"
# Label names; a label's position in this list is its integer class id
# (used for the training labels and for the model's id2label/label2id maps).
LABELS = ["O", "B-AMOUNT", "B-UNIT", "B-INGREDIENT"]

# Load dataset
def load_dataset(file_path):
    """Read a JSON Lines file into a ``datasets.Dataset``.

    Every non-blank line is parsed as one JSON document; the parsed records
    are handed to ``Dataset.from_list`` in file order.

    Args:
        file_path: Path to the ``.jsonl`` file.

    Returns:
        The ``Dataset`` built from the parsed records.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's default
    # locale encoding, which can fail or mis-decode non-ASCII text.
    with open(file_path, "r", encoding="utf-8") as f:
        # Skip whitespace-only lines (e.g. stray empty lines between records)
        # rather than crashing in json.loads on an empty document.
        data = [json.loads(line) for line in f if line.strip()]
    return Dataset.from_list(data)

# Read the JSONL training data into a Dataset.
dataset = load_dataset(TRAIN_DATA_PATH)

# Initialize tokenizer
# Loaded from the same checkpoint name as the model; tokenize_and_align_labels
# calls word_ids() on this tokenizer's output.
tokenizer = DistilBertTokenizerFast.from_pretrained(MODEL_NAME)

# Tokenize and align labels
def tokenize_and_align_labels(examples):
    """Tokenize a batch of texts and build one label id per token.

    Each text is tokenized to a fixed length of 128 (truncated / padded).
    Special and padding tokens (those with no word id) get the label -100.
    Every other token gets the id of the first entity whose span overlaps the
    token's character span in the original text, or the id of ``"O"`` when no
    entity overlaps it.

    NOTE(review): this assumes ``entity["start"]`` / ``entity["end"]`` are
    character offsets into ``text`` with an exclusive end. The previous
    implementation compared them against word indices from ``word_ids()``,
    which only lines up if the spans are word indices — confirm against the
    dataset schema.

    Args:
        examples: Batch mapping with a ``"text"`` list of strings and an
            ``"entities"`` list where each item is a list of dicts carrying
            ``"start"``, ``"end"`` and ``"label"`` keys.

    Returns:
        The tokenizer output for the batch with an added ``"labels"`` entry
        (one list of label ids per example). The offset mapping requested for
        alignment is removed before returning.

    Raises:
        ValueError: If an entity's label is not present in ``LABELS``.
    """
    tokenized_inputs = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=128,
        is_split_into_words=False,
        return_offsets_mapping=True,
    )
    # Only needed for alignment; drop it so it does not become a dataset column.
    offset_mapping = tokenized_inputs.pop("offset_mapping")

    # Build the name -> id lookup once instead of scanning LABELS per token.
    label2id = {name: idx for idx, name in enumerate(LABELS)}
    outside_id = label2id["O"]

    labels = []
    for i, entities in enumerate(examples["entities"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = []

        for word_idx, (tok_start, tok_end) in zip(word_ids, offset_mapping[i]):
            # Special tokens get -100
            if word_idx is None:
                label_ids.append(-100)
                continue

            # First entity whose [start, end) span overlaps this token's
            # character span wins, matching the original first-match order.
            for entity in entities:
                if tok_start < entity["end"] and tok_end > entity["start"]:
                    name = entity["label"]
                    if name not in label2id:
                        raise ValueError(f"{name!r} is not in LABELS")
                    label_ids.append(label2id[name])
                    break
            else:  # No entity found
                label_ids.append(outside_id)

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

# Process dataset
# batched=True: tokenize_and_align_labels is called with batches, which is why
# it indexes examples["text"] / examples["entities"] as lists. All original
# columns are removed, so only the fields returned by the function remain.
tokenized_dataset = dataset.map(
    tokenize_and_align_labels,
    batched=True,
    remove_columns=dataset.column_names
)

# Build the token-classification model; class ids follow the order of LABELS
model = DistilBertForTokenClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(LABELS),
    id2label=dict(enumerate(LABELS)),
    label2id=dict(zip(LABELS, range(len(LABELS)))),
)

# Training arguments
# Checkpoints and logs are written under OUTPUT_DIR. Values below are fixed
# hyperparameters for this run (3 epochs, batch size 16 per device).
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=3,
    per_device_train_batch_size=16,
    save_steps=500,
    save_total_limit=2,
    logging_steps=100,
    learning_rate=2e-5,
    weight_decay=0.01,
)

# Create trainer
# Only a training dataset is supplied; no evaluation dataset or metrics
# function is passed here.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)

# Start training
print("Starting training...")
trainer.train()

# Save final model
# The tokenizer is saved into the same directory as the model so both can be
# reloaded from OUTPUT_DIR.
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"Model saved to {OUTPUT_DIR}")