In [None]:
# ============================================================================
# Cell 1: Install dependencies (Colab)
# ============================================================================

!pip install -q transformers accelerate datasets torchmetrics

In [None]:
# ============================================================================
# Cell 2: Imports, plotting style, and deterministic seeding
# ============================================================================

import os
import json
import random
from pathlib import Path
from typing import List, Dict

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import torch
from torch.utils.data import Dataset

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    f1_score,
    precision_recall_fscore_support,
    roc_auc_score,
    average_precision_score,
)

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    set_seed,
)

import warnings
warnings.filterwarnings("ignore")

# --------------------------------------------------------------------------
# Plot style: seaborn grid theme plus shared figure/font defaults
# --------------------------------------------------------------------------
sns.set_style("whitegrid")
plt.rcParams.update(
    {
        "figure.figsize": (8, 6),
        "font.size": 11,
        "axes.titlesize": 13,
        "axes.labelsize": 12,
    }
)

# --------------------------------------------------------------------------
# Reproducibility: one seed shared by every random number generator in use
# --------------------------------------------------------------------------
RANDOM_STATE = 42

for seed_fn in (np.random.seed, random.seed, set_seed, torch.manual_seed):
    seed_fn(RANDOM_STATE)

# The device choice and the CUDA-specific determinism switches hinge on the
# same availability check, so they are handled together.
if torch.cuda.is_available():
    DEVICE = "cuda"
    torch.cuda.manual_seed_all(RANDOM_STATE)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
else:
    DEVICE = "cpu"

print(f"â®• Libraries imported. Using device: {DEVICE}")

â®• Libraries imported. Using device: cpu


In [None]:
# ============================================================================
# Cell 3: Disable Weights & Biases logging globally
# ============================================================================

import os

# Set all three Weights & Biases switches in a single update of the process
# environment.
os.environ.update(
    {
        "WANDB_DISABLED": "true",
        "WANDB_MODE": "disabled",
        "WANDB_SILENT": "true",
    }
)

print("â®• Weights & Biases logging disabled.")


â®• Weights & Biases logging disabled.


In [None]:
# ============================================================================
# Cell 4: Workspace and load train/test
# ============================================================================

import os
import subprocess
from pathlib import Path
import pandas as pd

# === CONFIGURATION: your GitHub repo ===
# If you changed the username/repo name, update this URL.
GITHUB_REPO = "https://github.com/aayushis1203/dietcheck.git"
# REPO_NAME is the local directory name used for the clone: the last path
# segment of the URL. A trailing "/" is tolerated and only a trailing ".git"
# is removed, so a ".git" occurring elsewhere in the name is left intact.
REPO_NAME = GITHUB_REPO.rstrip("/").rsplit("/", 1)[-1]
if REPO_NAME.endswith(".git"):
    REPO_NAME = REPO_NAME[: -len(".git")]

# === Helper functions (copied from 00) ===

def find_repo_root():
    """
    Locate the enclosing git repository of the current working directory.

    Starting at the working directory, at most five directories (the start
    directory and up to four ancestors) are inspected for a `.git` entry.

    Returns:
        Absolute path (str) of the first directory containing `.git`, or
        None when none is found within the search depth or the filesystem
        root is reached first.
    """
    candidate = os.path.abspath(os.getcwd())
    levels_left = 5  # search depth, counting the starting directory

    while levels_left > 0:
        if os.path.exists(os.path.join(candidate, '.git')):
            return candidate

        one_up = os.path.dirname(candidate)
        if one_up == candidate:
            # dirname() no longer changes the path: filesystem root reached.
            return None

        candidate = one_up
        levels_left -= 1

    return None


def setup_workspace():
    """
    Prepare the working directory for either Google Colab or a local run.

    In Colab: if the working directory is already inside a git repository it
    becomes the working directory; otherwise the repository is cloned (unless
    a directory named REPO_NAME already exists) and entered. Locally: the
    enclosing git repository, if any, becomes the working directory; otherwise
    the current directory is used as-is. `data/` and `results/` are then
    created under the resulting root.

    Returns:
        Tuple (repo_root, data_dir, results_dir) of absolute path strings.

    Raises:
        RuntimeError: if `git clone` exits with a non-zero status.
    """
    # Only the import itself decides which environment we are in.
    try:
        import google.colab  # type: ignore
    except ImportError:
        running_in_colab = False
    else:
        running_in_colab = True

    if running_in_colab:
        print("â®• Running in Google Colab")

        # Check if already inside repo (prevents nested cloning)
        repo_root = find_repo_root()

        if repo_root:
            print(f"â®• Already inside repo at: {repo_root}")
            os.chdir(repo_root)
        else:
            if not os.path.exists(REPO_NAME):
                print(f"â®• Cloning {GITHUB_REPO}...")
                clone_proc = subprocess.run(
                    ['git', 'clone', GITHUB_REPO],
                    capture_output=True,
                    text=True,
                )
                if clone_proc.returncode != 0:
                    raise RuntimeError(f"Git clone failed: {clone_proc.stderr}")

            os.chdir(REPO_NAME)
    else:
        print("ðŸ”§ Running locally")

        repo_root = find_repo_root()

        if repo_root:
            os.chdir(repo_root)
        else:
            print("â®•  Warning: Not in a git repository, using current directory")

    # Whatever directory we ended up in is treated as the repository root.
    repo_root = os.path.abspath(os.getcwd())
    data_dir, results_dir = (
        os.path.join(repo_root, subdir) for subdir in ('data', 'results')
    )

    for folder in (data_dir, results_dir):
        os.makedirs(folder, exist_ok=True)

    print(f"â®• Repo root: {repo_root}")
    print(f"âž¤ Data: {data_dir}")
    print(f"âž¤ Results: {results_dir}")

    return repo_root, data_dir, results_dir


# === Execute setup and store paths ===
# setup_workspace() hands back plain strings; the data/results locations are
# wrapped in Path so the `/` operator can be used below.
REPO_ROOT, data_dir_str, results_dir_str = setup_workspace()
DATA_DIR, RESULTS_DIR = Path(data_dir_str), Path(results_dir_str)
MODELS_DIR = RESULTS_DIR / "models"
MODELS_DIR.mkdir(parents=True, exist_ok=True)

print(f"\nâ®• Final REPO_ROOT : {REPO_ROOT}")
print(f"âž¤ DATA_DIR        : {DATA_DIR}")
print(f"âž¤ RESULTS_DIR     : {RESULTS_DIR}")
print(f"âž¤ MODELS_DIR      : {MODELS_DIR}")

# === Load dataset (train + optional test) ===

train_path, test_path = DATA_DIR / "train.csv", DATA_DIR / "test.csv"

# train.csv is mandatory; stop with setup instructions when it is absent.
if train_path.exists():
    df_train = pd.read_csv(train_path)
else:
    raise FileNotFoundError(
        f"train.csv not found at {train_path}.\n"
        f"REPO_ROOT = {REPO_ROOT}\n\n"
        "To fix this, make sure that:\n"
        "  1) `train.csv` is committed under `data/train.csv` in your GitHub repo, OR\n"
        "  2) You have already run 00 in this same environment so it saved\n"
        "     `data/train.csv` and `data/test.csv`.\n"
        "Then re-run this cell."
    )

# test.csv is optional: df_test stays None when the file is missing.
if test_path.exists():
    df_test = pd.read_csv(test_path)
else:
    df_test = None

print("\nâ®• Loaded datasets:")
print(f"   Train shape : {df_train.shape}")
if df_test is None:
    print("   Test shape  : None (test.csv not found; using train/val only)")
else:
    print(f"   Test shape  : {df_test.shape}")

df_train.head()


â®• Running in Google Colab
â®• Cloning https://github.com/aayushis1203/dietcheck.git...
â®• Repo root: /content/dietcheck
âž¤ Data: /content/dietcheck/data
âž¤ Results: /content/dietcheck/results

â®• Final REPO_ROOT : /content/dietcheck
âž¤ DATA_DIR        : /content/dietcheck/data
âž¤ RESULTS_DIR     : /content/dietcheck/results
âž¤ MODELS_DIR      : /content/dietcheck/results/models

â®• Loaded datasets:
   Train shape : (223, 28)
   Test shape  : (56, 28)


Unnamed: 0,product_id,name,brand,category,ingredients,serving_size_g,energy_100g,fat_100g,saturated_fat_100g,carbs_100g,...,carbs_per_serving,fiber_per_serving,sugars_per_serving,protein_per_serving,sodium_per_serving,net_carbs_per_serving,keto_compliant,high_protein,low_sodium,low_fat
0,20103644,Kokosmilch,Freshona,en:plant-based-foods-and-beverages,"91% coconut extract, water, guar gum stabilizer,",1400.0,194.0,20.5,16.8,1.4,...,19.6,8.4,12.6,11.2,280.0,11.2,0,1,0,0
1,6111184004129,Mayonnaise recette originale,Star,en:condiments,"Huile de soja, eau, vinaigre de table, jaune d...",100.0,592.0,65.21,9.8,0.22,...,0.22,0.0,0.05,1.12,700.0,0.22,1,0,0,0
2,8422174010029,Gazpacho Original,"Alvalle, PepsiCo",en:plant-based-foods-and-beverages,"Verdures fresques (94%) (tomÃ quet, pebrot verm...",100.0,42.0,2.4,0.4,3.5,...,3.5,1.2,3.3,0.9,248.0,2.3,1,0,0,1
3,6111128000163,Ain Saiss Eau Minerale Naturelle,Danone,en:beverages-and-beverages-preparations,"Calcium : 63,5\r\nMagnÃ©sium : 35,5\r\nNitratex...",33.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1,0,1,1
4,50457243,Tomato Ketchup,Heinz,en:condiments,"Tomatoes, Spirit Vinegar, Sugar, Salt, Spice a...",15.0,102.0,0.1,0.0,23.2,...,3.48,0.0,3.42,0.18,108.0,3.48,1,0,1,1


In [None]:
# ============================================================================
# Cell 5: Load dataset (train + optional test) and basic inspection
# ============================================================================

# NOTE: this repeats the train/test load already performed at the end of
# Cell 4; running it re-reads both CSV files from disk.
train_path = DATA_DIR.joinpath("train.csv")
test_path = DATA_DIR.joinpath("test.csv")

# The training split is required.
if train_path.exists():
    df_train = pd.read_csv(train_path)
else:
    raise FileNotFoundError(
        f"train.csv not found at {train_path}.\n"
        "Run 00_data_collection_and_automatic_labels.ipynb first."
    )

# The test split is optional; None marks its absence for later cells.
df_test = None
if test_path.exists():
    df_test = pd.read_csv(test_path)

print("â®• Loaded datasets:")
print(f"   Train shape : {df_train.shape}")
if df_test is None:
    print("   Test shape  : None (test.csv not found; using train/val only)")
else:
    print(f"   Test shape  : {df_test.shape}")

df_train.head()



â®• Loaded datasets:
   Train shape : (223, 28)
   Test shape  : (56, 28)


Unnamed: 0,product_id,name,brand,category,ingredients,serving_size_g,energy_100g,fat_100g,saturated_fat_100g,carbs_100g,...,carbs_per_serving,fiber_per_serving,sugars_per_serving,protein_per_serving,sodium_per_serving,net_carbs_per_serving,keto_compliant,high_protein,low_sodium,low_fat
0,20103644,Kokosmilch,Freshona,en:plant-based-foods-and-beverages,"91% coconut extract, water, guar gum stabilizer,",1400.0,194.0,20.5,16.8,1.4,...,19.6,8.4,12.6,11.2,280.0,11.2,0,1,0,0
1,6111184004129,Mayonnaise recette originale,Star,en:condiments,"Huile de soja, eau, vinaigre de table, jaune d...",100.0,592.0,65.21,9.8,0.22,...,0.22,0.0,0.05,1.12,700.0,0.22,1,0,0,0
2,8422174010029,Gazpacho Original,"Alvalle, PepsiCo",en:plant-based-foods-and-beverages,"Verdures fresques (94%) (tomÃ quet, pebrot verm...",100.0,42.0,2.4,0.4,3.5,...,3.5,1.2,3.3,0.9,248.0,2.3,1,0,0,1
3,6111128000163,Ain Saiss Eau Minerale Naturelle,Danone,en:beverages-and-beverages-preparations,"Calcium : 63,5\r\nMagnÃ©sium : 35,5\r\nNitratex...",33.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,1,0,1,1
4,50457243,Tomato Ketchup,Heinz,en:condiments,"Tomatoes, Spirit Vinegar, Sugar, Salt, Spice a...",15.0,102.0,0.1,0.0,23.2,...,3.48,0.0,3.42,0.18,108.0,3.48,1,0,1,1


In [None]:
# ============================================================================
# Cell 6 : Define labels and text feature (name + ingredients)
# ============================================================================

# Multi-label targets for Task 1
LABEL_COLS: List[str] = ["keto_compliant", "high_protein", "low_sodium", "low_fat"]

# Fail early when the training frame lacks a text source or a label column.
required_cols = ["name", "ingredients", *LABEL_COLS]
missing = list(filter(lambda col: col not in df_train.columns, required_cols))
if len(missing) > 0:
    raise ValueError(f"Missing required columns in train.csv: {missing}")

# --------------------------------------------------------------------------
# Build combined text feature: name + " | " + ingredients
# --------------------------------------------------------------------------
def build_text_all(df: pd.DataFrame) -> pd.Series:
    """
    Build the combined text feature from product name and ingredients.

    Missing values are treated as empty strings. Each part is stripped, the
    two parts are joined as "Name | ingredients...", the result is stripped
    again, and every run of whitespace is collapsed to a single space.

    Args:
        df: Frame containing "name" and "ingredients" columns.

    Returns:
        Series of combined strings, one per row of `df`.
    """

    def _clean(column: str) -> pd.Series:
        # Empty string for missing values, then trim surrounding whitespace.
        return df[column].fillna("").astype(str).str.strip()

    joined = _clean("name") + " | " + _clean("ingredients")
    return joined.str.strip().str.replace(r"\s+", " ", regex=True)

# This is now the ONLY text feature we use
TEXT_COL: str = "text_all"

df_train[TEXT_COL] = build_text_all(df_train)

# The test frame gets the same feature only when both source columns exist.
if df_test is not None and {"name", "ingredients"}.issubset(df_test.columns):
    df_test[TEXT_COL] = build_text_all(df_test)
elif df_test is not None:
    print("âž¤ Test set is missing 'name' or 'ingredients'; "
          "skipping text_all construction for test.")

# --------------------------------------------------------------------------
# Prepare features and labels for quick inspection
# --------------------------------------------------------------------------
X_text = df_train[TEXT_COL].fillna("")
y_labels = df_train[LABEL_COLS].astype(int)

print("âž¤ Features and labels prepared (name + ingredients).")
print(f"   Text column : {TEXT_COL}")
print(f"   Label cols  : {LABEL_COLS}")

# Positive counts per label and their share of the training rows, both
# ordered from most to least frequent.
n_train_rows = len(df_train)
label_counts = y_labels.sum().sort_values(ascending=False)
label_props = (label_counts / n_train_rows).sort_values(ascending=False)

label_summary = pd.DataFrame(
    {"positive_count": label_counts, "positive_proportion": label_props}
)
print("\nâ®• Label distribution (train):")
display(label_summary)

print("\nâ®• Example combined text (name + ingredients):")
preview_cols = ["name", "ingredients", "text_all"]
display(df_train[preview_cols].head(5))


âž¤ Features and labels prepared (name + ingredients).
   Text column : text_all
   Label cols  : ['keto_compliant', 'high_protein', 'low_sodium', 'low_fat']

â®• Label distribution (train):


Unnamed: 0,positive_count,positive_proportion
low_sodium,97,0.434978
high_protein,89,0.399103
low_fat,77,0.345291
keto_compliant,72,0.32287



â®• Example combined text (name + ingredients):


Unnamed: 0,name,ingredients,text_all
0,Kokosmilch,"91% coconut extract, water, guar gum stabilizer,","Kokosmilch | 91% coconut extract, water, guar ..."
1,Mayonnaise recette originale,"Huile de soja, eau, vinaigre de table, jaune d...","Mayonnaise recette originale | Huile de soja, ..."
2,Gazpacho Original,"Verdures fresques (94%) (tomÃ quet, pebrot verm...",Gazpacho Original | Verdures fresques (94%) (t...
3,Ain Saiss Eau Minerale Naturelle,"Calcium : 63,5\r\nMagnÃ©sium : 35,5\r\nNitratex...",Ain Saiss Eau Minerale Naturelle | Calcium : 6...
4,Tomato Ketchup,"Tomatoes, Spirit Vinegar, Sugar, Salt, Spice a...","Tomato Ketchup | Tomatoes, Spirit Vinegar, Sug..."


In [None]:
# ============================================================================
# Cell 7: Train/validation split for BERT
# ============================================================================

# Stratify on a coarse proxy: 1 when a row has at least one positive label,
# 0 otherwise.
has_any_positive = y_labels.sum(axis=1).gt(0)
stratify_vec = has_any_positive.astype(int)

split_options = {
    "test_size": 0.2,
    "random_state": RANDOM_STATE,
    "stratify": stratify_vec,
}
train_df, val_df = train_test_split(df_train, **split_options)

print("â®• Created train/validation split.")
print(f"   Train subset: {train_df.shape}")
print(f"   Val subset  : {val_df.shape}")



â®• Created train/validation split.
   Train subset: (178, 29)
   Val subset  : (45, 29)


In [None]:
# ============================================================================
# Cell 8: Dataset class for BERT multi-label training
# ============================================================================

class DietCheckTextDataset(Dataset):
    """
    Torch dataset pairing one text column with multi-label targets.

    Each item is the tokenizer encoding of a single text, padded/truncated to
    `max_length`, plus a float `labels` tensor built from `label_cols`.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        tokenizer,
        text_col: str,
        label_cols: List[str],
        max_length: int = 256,
    ):
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Missing texts become empty strings so the tokenizer always gets a str.
        self.texts = df[text_col].fillna("").tolist()
        # One row of float targets per example, in `label_cols` order.
        self.labels = df[label_cols].astype(float).to_numpy()

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        encoded = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        # The tokenizer is asked for "pt" tensors with a leading axis of size
        # one; squeeze(0) removes it so each field is a per-example tensor.
        sample = {}
        for field, tensor in encoded.items():
            sample[field] = tensor.squeeze(0)

        sample["labels"] = torch.tensor(self.labels[idx], dtype=torch.float)
        return sample


print("â®• Dataset class defined.")



â®• Dataset class defined.


In [None]:
# ============================================================================
# Cell 9: Load tokenizer and create Dataset objects
# ============================================================================

MODEL_NAME = "bert-base-uncased"
MAX_LENGTH = 256

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Both splits share every dataset setting except the source frame.
dataset_options = {
    "tokenizer": tokenizer,
    "text_col": TEXT_COL,
    "label_cols": LABEL_COLS,
    "max_length": MAX_LENGTH,
}
train_dataset, val_dataset = (
    DietCheckTextDataset(df=split_df, **dataset_options)
    for split_df in (train_df, val_df)
)

print("â®• Tokenizer and datasets ready.")
print(f"   Train examples: {len(train_dataset)}")
print(f"   Val examples  : {len(val_dataset)}")



tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

â®• Tokenizer and datasets ready.
   Train examples: 178
   Val examples  : 45


In [None]:
# ============================================================================
# Cell 10: Define BERT model for multi-label classification
# ============================================================================

# One output per label; "multi_label_classification" is passed through as the
# model's problem type.
classifier_options = {
    "num_labels": len(LABEL_COLS),
    "problem_type": "multi_label_classification",
}
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME, **classifier_options
)

model.to(DEVICE)
print("â®• BERT model loaded and moved to device.")



model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


â®• BERT model loaded and moved to device.


In [None]:
# ============================================================================
# Cell 11: Helper - multi-label metric computation
# ============================================================================

def sigmoid_array(x: np.ndarray) -> np.ndarray:
    """
    Element-wise logistic sigmoid, 1 / (1 + exp(-x)), computed without overflow.

    Args:
        x: Array-like of logits (any shape).

    Returns:
        ndarray of the same shape with values in [0, 1]. Floating inputs keep
        their dtype; a scalar input yields a 0-d array.
    """
    x = np.asarray(x)
    # exp() is only ever applied to non-positive values, so it stays in (0, 1]
    # and cannot overflow, unlike exp(-x) for large negative logits.
    decay = np.exp(-np.abs(x))
    # x >= 0: 1 / (1 + e^-x);  x < 0: e^x / (1 + e^x) -- the same function.
    return np.where(x >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))


def compute_multilabel_metrics_from_probs(
    y_true: np.ndarray,
    probs: np.ndarray,
    threshold: float = 0.5,
    label_names: List[str] = None,
) -> Dict[str, float]:
    """
    Compute micro/macro F1 and per-label metrics from predicted probabilities.

    Args:
        y_true: Binary indicator matrix of shape (n_samples, n_labels).
        probs: Predicted probabilities with the same shape as `y_true`.
        threshold: A probability >= threshold counts as a positive prediction.
        label_names: One name per label column, used as metric-key prefixes.
            When None, names default to "label_0", "label_1", ...

    Returns:
        Dict with "micro_f1", "macro_f1" and, for every label,
        "<name>_precision", "<name>_recall", "<name>_f1", "<name>_roc_auc"
        (NaN when ROC AUC cannot be computed) and "<name>_pr_auc".

    Raises:
        ValueError: if `probs` is not 2-D, if the shapes of `y_true` and
            `probs` differ, or if `label_names` does not have exactly one
            entry per label column.
    """
    y_true = np.asarray(y_true).astype(int)
    probs = np.asarray(probs)

    # Validate inputs before anything indexes probs.shape[1], so malformed
    # input always surfaces as a descriptive ValueError.
    if probs.ndim != 2:
        raise ValueError(
            f"Expected 2-D probs of shape (n_samples, n_labels), got shape {probs.shape}."
        )
    if probs.shape != y_true.shape:
        raise ValueError(
            f"Shape mismatch between labels {y_true.shape} and probs {probs.shape}."
        )

    n_labels = probs.shape[1]
    if label_names is None:
        label_names = [f"label_{i}" for i in range(n_labels)]
    elif len(label_names) != n_labels:
        raise ValueError(
            f"Expected {n_labels} label names, got {len(label_names)}: {list(label_names)}."
        )

    y_pred = (probs >= threshold).astype(int)

    metrics: Dict[str, float] = {}

    # Overall micro/macro F1
    metrics["micro_f1"] = f1_score(
        y_true, y_pred, average="micro", zero_division=0
    )
    metrics["macro_f1"] = f1_score(
        y_true, y_pred, average="macro", zero_division=0
    )

    # Per-label metrics
    for idx, name in enumerate(label_names):
        y_true_col = y_true[:, idx]
        y_pred_col = y_pred[:, idx]
        prob_col = probs[:, idx]

        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true_col,
            y_pred_col,
            average="binary",
            zero_division=0,
        )

        # ROC AUC can fail if only one class is present in y_true_col
        try:
            roc_auc = roc_auc_score(y_true_col, prob_col)
        except ValueError:
            roc_auc = np.nan

        ap = average_precision_score(y_true_col, prob_col)

        metrics[f"{name}_precision"] = precision
        metrics[f"{name}_recall"] = recall
        metrics[f"{name}_f1"] = f1
        metrics[f"{name}_roc_auc"] = roc_auc
        metrics[f"{name}_pr_auc"] = ap

    return metrics


def hf_compute_metrics(eval_pred) -> Dict[str, float]:
    """
    Metric callback passed to the HuggingFace Trainer.

    Unpacks `eval_pred` into (logits, labels), converts the logits to
    probabilities with a sigmoid and scores them at a 0.5 threshold, using
    LABEL_COLS as the metric-key prefixes.
    """
    raw_logits, gold_labels = eval_pred
    return compute_multilabel_metrics_from_probs(
        y_true=gold_labels,
        probs=sigmoid_array(raw_logits),
        threshold=0.5,
        label_names=LABEL_COLS,
    )


print("â®• Metric helpers defined.")



â®• Metric helpers defined.


In [None]:
# ============================================================================
# Cell 12: TrainingArguments and Trainer (version-compatible)
# ============================================================================

import inspect

BERT_MODEL_DIR = MODELS_DIR / "task1_bert_text_only"
BERT_MODEL_DIR.mkdir(parents=True, exist_ok=True)

# Mixed precision is only requested when a CUDA device is present.
use_fp16 = torch.cuda.is_available()


def _accepted_kwargs(cls) -> set:
    """Return the keyword names accepted by ``cls.__init__``."""
    return set(inspect.signature(cls.__init__).parameters)


# The recorded run of this cell rejected `evaluation_strategy` with a
# TypeError, and the previous fallback then silently dropped per-epoch
# evaluation, checkpointing, best-model loading, `report_to` and fp16.
# Recent transformers releases name the option `eval_strategy`, so the
# accepted spelling is chosen from the signature and the full configuration
# is always used.
eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in _accepted_kwargs(TrainingArguments)
    else "evaluation_strategy"
)

# NOTE(review): with 178 training rows and batch size 16 a 4-epoch run has
# about 48 optimizer steps, so logging_steps=50 never logs a training loss
# (the recorded "Step,Training Loss" table is empty) -- consider lowering it.
training_args = TrainingArguments(
    output_dir=str(BERT_MODEL_DIR / "checkpoints"),
    save_strategy="epoch",            # must match the evaluation strategy
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=4,
    weight_decay=0.01,
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="micro_f1",
    greater_is_better=True,
    save_total_limit=2,
    report_to="none",                 # disable wandb / TB
    fp16=use_fp16,
    seed=RANDOM_STATE,                # keep the Trainer seed tied to the notebook seed
    **{eval_strategy_kwarg: "epoch"},
)

# Newer Trainer versions take the tokenizer as `processing_class`; older ones
# only know `tokenizer`. Pass whichever this installation accepts.
tokenizer_kwarg = (
    "processing_class"
    if "processing_class" in _accepted_kwargs(Trainer)
    else "tokenizer"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=hf_compute_metrics,
    **{tokenizer_kwarg: tokenizer},
)

print("â®• Trainer configured.")
print(f"   fp16 available: {use_fp16}")
print(f"   eval strategy : {eval_strategy_kwarg}='epoch'")
print(f"   tokenizer arg : {tokenizer_kwarg}")


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


â®• Detected older transformers.TrainingArguments signature.
   Reason: TrainingArguments.__init__() got an unexpected keyword argument 'evaluation_strategy'
   Falling back to a minimal, compatible configuration.

â®• Trainer configured.
   fp16 available: False


In [None]:
# ============================================================================
# Cell 13: Train BERT model
# ============================================================================

train_result = trainer.train()

# Model weights and tokenizer files are written to the same directory.
final_model_path = BERT_MODEL_DIR / "final_model"
final_model_dir = str(final_model_path)
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)

print("\n  Training complete.")
print(f"   Final model saved to: {final_model_path}")

Step,Training Loss



  Training complete.
   Final model saved to: /content/dietcheck/results/models/task1_bert_text_only/final_model


In [None]:
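# ============================================================================
# Cell 14: Evaluation on validation set
# ============================================================================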

# Raw metrics as reported by the Trainer on its eval dataset.
eval_results = trainer.evaluate()
print("â®• Trainer eval metrics (raw):")
for k, v in eval_results.items():
    if not isinstance(v, (float, int)):
        print(f"  {k}: {v}")
        continue
    print(f"  {k}: {v:.4f}")

# Get logits on validation set for structured analysis
predictions = trainer.predict(val_dataset)
logits, labels = predictions.predictions, predictions.label_ids
probs = sigmoid_array(logits)

metrics_val = compute_multilabel_metrics_from_probs(
    y_true=labels,
    probs=probs,
    threshold=0.5,
    label_names=LABEL_COLS,
)

print("\nâ®• Detailed validation metrics:")
for k, v in metrics_val.items():
    if not isinstance(v, (float, int)):
        print(f"{k}: {v}")
        continue
    print(f"{k:25s}: {v:.4f}")

# Per-label summary table: one row per label, one column per metric.
per_label_metric_names = ("precision", "recall", "f1", "roc_auc", "pr_auc")
rows = [
    {
        "label": label_name,
        **{
            metric_name: metrics_val[f"{label_name}_{metric_name}"]
            for metric_name in per_label_metric_names
        },
    }
    for label_name in LABEL_COLS
]

label_metrics_val_df = pd.DataFrame(rows).set_index("label")

print("\nâ®• Per-label validation metrics (BERT text-only):")
display(label_metrics_val_df)

overall_val_df = pd.DataFrame(
    {overall_key: [metrics_val[overall_key]] for overall_key in ("micro_f1", "macro_f1")},
    index=["overall"],
)

print("\nâ®• Overall validation F1 (BERT text-only):")
display(overall_val_df)



â®• Trainer eval metrics (raw):
  eval_loss: 0.6191
  eval_micro_f1: 0.3011
  eval_macro_f1: 0.2616
  eval_keto_compliant_precision: 0.0000
  eval_keto_compliant_recall: 0.0000
  eval_keto_compliant_f1: 0.0000
  eval_keto_compliant_roc_auc: 0.6822
  eval_keto_compliant_pr_auc: 0.4676
  eval_high_protein_precision: 0.4000
  eval_high_protein_recall: 0.2353
  eval_high_protein_f1: 0.2963
  eval_high_protein_roc_auc: 0.5735
  eval_high_protein_pr_auc: 0.4484
  eval_low_sodium_precision: 0.3333
  eval_low_sodium_recall: 0.0476
  eval_low_sodium_f1: 0.0833
  eval_low_sodium_roc_auc: 0.7321
  eval_low_sodium_pr_auc: 0.6516
  eval_low_fat_precision: 0.7500
  eval_low_fat_recall: 0.6000
  eval_low_fat_f1: 0.6667
  eval_low_fat_roc_auc: 0.8200
  eval_low_fat_pr_auc: 0.6050
  eval_runtime: 45.0145
  eval_samples_per_second: 1.0000
  eval_steps_per_second: 0.0440
  epoch: 4.0000

â®• Detailed validation metrics:
micro_f1                 : 0.3011
macro_f1                 : 0.2616
keto_compliant_pr

Unnamed: 0_level_0,precision,recall,f1,roc_auc,pr_auc
label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
keto_compliant,0.0,0.0,0.0,0.682222,0.467592
high_protein,0.4,0.235294,0.296296,0.573529,0.448419
low_sodium,0.333333,0.047619,0.083333,0.732143,0.651637
low_fat,0.75,0.6,0.666667,0.82,0.605001



â®• Overall validation F1 (BERT text-only):


Unnamed: 0,micro_f1,macro_f1
overall,0.301075,0.261574


In [None]:
# ============================================================================
# Cell 15: Optional evaluation on held-out test.csv (if labels present)
# ============================================================================

# Both stay None when the held-out evaluation is skipped; the checkpoint cell
# relies on that to write null entries.
bert_test_metrics_df = None
metrics_test = None

# Held-out evaluation needs the test frame, every label column AND the text
# column. Cell 6 skips building `text_all` for a test set that lacks
# name/ingredients, so TEXT_COL is checked here as well instead of failing
# with a KeyError inside DietCheckTextDataset.
test_is_evaluable = (
    df_test is not None
    and TEXT_COL in df_test.columns
    and all(col in df_test.columns for col in LABEL_COLS)
)

if test_is_evaluable:
    print("â–¶ Evaluating BERT model on held-out test.csv")

    test_dataset = DietCheckTextDataset(
        df=df_test,
        tokenizer=tokenizer,
        text_col=TEXT_COL,
        label_cols=LABEL_COLS,
        max_length=MAX_LENGTH,
    )

    # Logits -> probabilities -> the same metric set used for validation.
    test_pred = trainer.predict(test_dataset)
    test_logits = test_pred.predictions
    test_labels = test_pred.label_ids
    test_probs = sigmoid_array(test_logits)

    metrics_test = compute_multilabel_metrics_from_probs(
        y_true=test_labels,
        probs=test_probs,
        threshold=0.5,
        label_names=LABEL_COLS,
    )

    # One row per label for the per-label table.
    rows = []
    for label in LABEL_COLS:
        rows.append(
            {
                "label": label,
                "precision": metrics_test[f"{label}_precision"],
                "recall": metrics_test[f"{label}_recall"],
                "f1": metrics_test[f"{label}_f1"],
                "roc_auc": metrics_test[f"{label}_roc_auc"],
                "pr_auc": metrics_test[f"{label}_pr_auc"],
            }
        )

    bert_test_metrics_df = pd.DataFrame(rows).set_index("label")

    print("\nâ®• BERT test-set per-label metrics:")
    display(bert_test_metrics_df)

    overall_test_df = pd.DataFrame(
        {
            "micro_f1": [metrics_test["micro_f1"]],
            "macro_f1": [metrics_test["macro_f1"]],
        },
        index=["overall"],
    )

    print("\nâ®• BERT test-set overall F1:")
    display(overall_test_df)
else:
    # Reached when test.csv is absent or lacks the text column or any label.
    print("â®• No fully-labeled test.csv detected â€“ skipping held-out evaluation.")



â–¶ Evaluating BERT model on held-out test.csv



â®• BERT test-set per-label metrics:


Unnamed: 0_level_0,precision,recall,f1,roc_auc,pr_auc
label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
keto_compliant,0.0,0.0,0.0,0.694444,0.570116
high_protein,0.333333,0.1875,0.24,0.421875,0.293245
low_sodium,0.5,0.111111,0.181818,0.59387,0.543964
low_fat,0.615385,0.307692,0.410256,0.619231,0.631878



â®• BERT test-set overall F1:


Unnamed: 0,micro_f1,macro_f1
overall,0.243478,0.208019


In [None]:
# ============================================================================
# Cell 16: Compare BERT vs TF-IDF and Logistic Regression baselines (from 01)
# ============================================================================

text_logreg_metrics_path = RESULTS_DIR / "task1_text_logreg_metrics.json"

if not text_logreg_metrics_path.exists():
    print(
        f"â®• Did not find {text_logreg_metrics_path}.\n"
        "   Run 01_task1_baselines.ipynb first if you want baseline comparison."
    )
else:
    try:
        baseline_obj = json.loads(text_logreg_metrics_path.read_text())

        # Expected structure (from 01):
        # {
        #   "label_metrics": {
        #       "text": {
        #           "<label>": {"f1": ..., "roc_auc": ..., "pr_auc": ...},
        #           ...
        #       },
        #       ...
        #   },
        #   ...
        # }
        label_metrics_block = baseline_obj.get("label_metrics", {})
        text_family = label_metrics_block.get("text", None)

        if text_family is None:
            raise KeyError("Could not find 'label_metrics' -> 'text' in baseline JSON.")

        # One row per label present in the baseline file: the three baseline
        # columns first, then the matching BERT validation columns. Metrics
        # missing on either side become NaN.
        compared_metrics = ("roc_auc", "pr_auc", "f1")
        rows = []
        for label in LABEL_COLS:
            baseline_metrics = text_family.get(label, None)
            if baseline_metrics is not None:
                row = {"label": label}
                for metric in compared_metrics:
                    row[f"logreg_text_{metric}"] = baseline_metrics.get(metric, np.nan)
                for metric in compared_metrics:
                    row[f"bert_text_{metric}"] = metrics_val.get(f"{label}_{metric}", np.nan)
                rows.append(row)

        if not rows:
            print(
                "â®• Baseline JSON found but did not contain per-label metrics "
                f"for {LABEL_COLS} under 'text'."
            )
        else:
            comparison_df = pd.DataFrame(rows).set_index("label")
            print("â®• BERT vs TFâ€“IDF + Logistic Regression (validation):")
            display(comparison_df)

    except Exception as e:
        # Best-effort comparison: report the problem and let the notebook
        # continue instead of aborting.
        print(
            "â®• Found task1_text_logreg_metrics.json but could not parse it "
            "with the expected structure."
        )
        print(f"   Reason: {e}")



â®• Did not find /content/dietcheck/results/task1_text_logreg_metrics.json.
   Run 01_task1_baselines.ipynb first if you want baseline comparison.


In [None]:
# ============================================================================
# Cell 17: Save BERT metrics + configuration as JSON checkpoint
# ============================================================================

# Safely extract training args (attribute names differ between transformers
# versions, so every lookup goes through getattr with a default).

# The evaluation strategy is exposed as `eval_strategy` on newer releases and
# as `evaluation_strategy` on older ones; reading only the old name records
# null on newer installs even when evaluation is configured.
eval_strategy_value = getattr(training_args, "eval_strategy", None)
if eval_strategy_value is None:
    eval_strategy_value = getattr(training_args, "evaluation_strategy", None)
if eval_strategy_value is not None:
    # Store the plain string (e.g. "epoch") rather than an enum member.
    eval_strategy_value = str(getattr(eval_strategy_value, "value", eval_strategy_value))

training_args_config = {
    "learning_rate": getattr(training_args, "learning_rate", None),
    "num_train_epochs": getattr(training_args, "num_train_epochs", None),
    "per_device_train_batch_size": getattr(training_args, "per_device_train_batch_size", None),
    "per_device_eval_batch_size": getattr(training_args, "per_device_eval_batch_size", None),
    "weight_decay": getattr(training_args, "weight_decay", None),
    # Key name kept as "evaluation_strategy" so existing readers of this JSON
    # keep working.
    "evaluation_strategy": eval_strategy_value,
    "metric_for_best_model": getattr(training_args, "metric_for_best_model", None),
    "seed": RANDOM_STATE,
    "fp16": getattr(training_args, "fp16", use_fp16),
}

# Test entries are None (JSON null) when the held-out evaluation was skipped.
bert_metrics_checkpoint = {
    "label_metrics": {
        "val": label_metrics_val_df.to_dict(orient="index"),
        "test": bert_test_metrics_df.to_dict(orient="index")
        if bert_test_metrics_df is not None
        else None,
    },
    "overall": {
        "val": {
            "micro_f1": metrics_val["micro_f1"],
            "macro_f1": metrics_val["macro_f1"],
        },
        "test": {
            "micro_f1": metrics_test["micro_f1"],
            "macro_f1": metrics_test["macro_f1"],
        }
        if metrics_test is not None
        else None,
    },
    "model_config": {
        "model_name": MODEL_NAME,
        "num_labels": len(LABEL_COLS),
        "label_cols": LABEL_COLS,
        "text_col": TEXT_COL,
        "max_length": MAX_LENGTH,
        "training_args": training_args_config,
    },
}

bert_metrics_path = RESULTS_DIR / "task1_bert_text_only_metrics.json"
with open(bert_metrics_path, "w") as f:
    json.dump(bert_metrics_checkpoint, f, indent=2)

print(f"âž¤ Saved BERT metrics and configuration to: {bert_metrics_path}")


âž¤ Saved BERT metrics and configuration to: /content/dietcheck/results/task1_bert_text_only_metrics.json
