In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
import re

In [2]:
# Load the cleaned branded-food table, then work on a fixed-seed sample so the
# rest of the notebook runs quickly and reproducibly.
branded_df = pd.read_csv("../DATA/processed/branded_food_clean.csv")
# Cap the sample size at the table length: DataFrame.sample raises ValueError
# when n exceeds the number of rows (and replace=False), which would break the
# notebook on a smaller extract of the data.
branded_df = branded_df.sample(n=min(1000, len(branded_df)), random_state=42)

# Sanity check: report any column the later cells rely on that is absent.
expected_cols = {'fdc_id', 'product_name', 'ingredients_list_json', 'num_ingredients'}

missing_cols = expected_cols - set(branded_df.columns)

if missing_cols:
  print(f"Missing columns: {missing_cols}")
else:
  print("All expected columns are present")

All expected columns are present


In [3]:
def parse_json(json_record):
  """Turn one raw ingredient field into a list of normalised ingredient tokens.

  Missing values (anything ``pd.isna`` reports as NA) give an empty list.
  Text that looks like a JSON array (``[...]``) is decoded with ``json.loads``;
  if decoding fails, or the text is not bracketed, it is split on commas.

  Each item is then cleaned: a leading ``ingredients:`` prefix is dropped, the
  text is lower-cased, ``made from ...`` / ``made with ...`` and everything
  after it is removed, every character other than a-z, 0-9, ``%``, whitespace
  and ``-`` becomes a space, and runs of whitespace collapse to one space.
  Items that end up empty are discarded.
  """
  if pd.isna(json_record):
    return []

  text = str(json_record).strip()

  # A bracketed string that decodes successfully is always a list, so None is
  # a safe "not decoded" sentinel here.
  raw_items = None
  if text.startswith('[') and text.endswith(']'):
    try:
      raw_items = json.loads(text)
    except Exception:
      pass
  if raw_items is None:
    raw_items = text.split(',')

  def _clean(raw):
    # The prefix is anchored at the start of the *unstripped* item text.
    token = re.sub(r"^ingredients\s*:\s*", "", str(raw), flags=re.IGNORECASE)
    token = token.strip().lower()
    # remove 'made from...' or 'made with...' and everything after it
    token = re.sub(r"\bmade\s+(from|with)\b.*", "", token).strip()
    # keep only letters, digits, %, spaces, and hyphens
    token = re.sub(r"[^a-z0-9%\s\-]", " ", token)
    return re.sub(r"\s+", " ", token).strip()

  return [token for token in map(_clean, raw_items) if token]

# Parse every raw ingredient field into a list of cleaned, lower-cased tokens
# (missing values become an empty list).
branded_df['ingredients_list'] = branded_df['ingredients_list_json'].apply(parse_json)


In [4]:
def rank_weight(rank, alpha = 0.6, scale = 1.0):
  """Return the position-decay weight ``scale / rank ** alpha``.

  Callers in this notebook pass 1-based ranks, so the first ingredient gets
  the full ``scale`` and later ones progressively less. A rank of 0 with a
  positive ``alpha`` divides by zero.
  """
  decay = rank ** alpha
  return scale / decay

def order_weight_tokens(ing_list, alpha=0.6, scale=6):
    """
    Encode ingredient order as repetition and return one " , "-joined string.

    The ingredient at 1-based position ``p`` is emitted
    ``max(1, int(scale / p ** alpha))`` times, so every ingredient appears at
    least once and earlier ones appear more often.

    alpha: how quickly the repetition count decays with position.
    scale: repetition count of the first ingredient (when scale >= 1).
    """
    def repetitions(position):
        decay = 1 / (position ** alpha)  # position decay
        return max(1, int(decay * scale))

    weighted = [
        ingredient
        for position, ingredient in enumerate(ing_list, start=1)
        for _ in range(repetitions(position))
    ]
    return " , ".join(weighted)

# Build the order-weighted text column (earlier ingredients repeated more often,
# default alpha/scale); the TF-IDF pipelines later in the notebook read this column.
branded_df['ordered_text'] = branded_df['ingredients_list'].apply(order_weight_tokens)


def lf_any(tokens, pattern, base_weight=1.0, alpha=0.6):
    """Score how strongly ``pattern`` fires on an ordered ingredient list.

    Every token that the regex matches (``re.search``) contributes
    ``base_weight * rank_weight(position, alpha=alpha)``, where position is the
    token's 1-based index, so matches near the front of the list count more.
    Returns 0.0 when nothing matches.
    """
    matcher = re.compile(pattern)
    total = 0.0
    for position, token in enumerate(tokens, start=1):
        if matcher.search(token) is None:
            continue
        total += base_weight * rank_weight(position, alpha=alpha)
    return total

# Negative ("less healthy") labelling functions.
# Maps a name to (regex, base_weight). Each regex is searched against single,
# already-normalised ingredient tokens (lower-case; only a-z, 0-9, %, spaces
# and hyphens survive normalisation), and base_weight is what a match
# contributes before position decay is applied.
LFs_negative = {
    # Colors / dyes (incl. Red No. 3). Both spellings and both plurals of
    # "artificial color(s)" / "artificial colour(s)" are covered.
    # NOTE(review): Red 40, Blue 2 and Green 3 are not matched — confirm that is intentional.
    "artificial_colors": (r"\b(red\s*0*3|red\s*no\.?\s*3|red\s*3|yellow\s*5|yellow\s*6|blue\s*1|artificial\s+colou?rs?)\b", 1.5),
    # Titanium dioxide
    "titanium_dioxide": (r"\btitanium\s+dioxide\b", 1.5),
    # Potassium bromate
    "potassium_bromate": (r"\bpotassium\s+bromate\b", 1.6),
    # Brominated vegetable oil
    "bvo": (r"\bbrominated\s+vegetable\s+oil\b", 1.6),
    # Parabens (incl. propylparaben)
    "parabens": (r"\b(parabens?|propylparaben|methylparaben|butylparaben)\b", 1.2),
    # High fructose / corn syrup / added sugars cluster
    "hfcs_syrups_sugars": (r"\b(high\s*fructose\s*corn\s*syrup|hfcs|corn\s*syrup|glucose[-\s]*fructose|invert\s*sugar|sucrose|dextrose|fructose|glucose|molasses)\b", 1.8),
    # Nitrites/nitrates
    "nitrites": (r"\b(sodium|potassium)\s+(nitrite|nitrate)s?\b", 1.6),
    # TBHQ / BHA / BHT antioxidants
    "synthetic_antioxidants": (r"\b(tbhq|tertiary\s*butylhydroquinone|bha|butylated\s*hydroxyanisole|bht|butylated\s*hydroxytoluene)\b", 1.3),
    # Artificial sweeteners
    "non_nutritive_sweeteners": (r"\b(aspartame|acesulfame\s*k?|acesulfamek|sucralose|saccharin|neotame|advantame)\b", 1.2),
    # Hydrogenated/shortening (often avoided; reasonable to include)
    "hydrogenated_fats": (r"\b(partially\s+)?hydrogenated\b|\bshortening\b", 1.6),
    # Salt / sodium chloride (very common; keep weight modest + rely on position)
    "salt": (r"\b(salt|sodium\s+chloride)\b", 1.0),
    # Refined grains (enriched/bleached flour)
    "refined_flour": (r"\b(enriched|bleached)\s+(wheat\s+)?flour\b", 1.0),
    # Palm / tropical saturated oils
    "palm_coconut_oil": (r"\b(palm(\s+kernel)?\s+oil|coconut\s+oil)\b", 1.0),
    # Maltodextrin
    "maltodextrin": (r"\bmaltodextrin\b", 0.8),
    # MSG
    "msg": (r"\b(monosodium\s+glutamate|msg)\b", 1.0),
    # Preservatives (benzoate / sorbate / propionate)
    "preservatives": (r"\b(sodium|potassium|calcium)\s+(benzoate|sorbate|propionate)s?\b", 1.3),
    # Phosphates (texture enhancers)
    "phosphates": (r"\b(sodium|potassium|calcium)\s+\w*phosphate(s)?\b", 0.9),
    # Caramel color / color added
    "caramel_color": (r"\b(caramel\s+color|color\s+added)\b", 0.8),
    # Emulsifiers / thickeners
    "emulsifiers": (r"\b(mono-?\s* and-?\s*diglycerides|polysorbate\s*80|propylene\s+glycol|datem|pgpr|modified\s+starch)\b", 0.9),
    # Gums (treat carrageenan stronger)
    "gums": (r"\b(xanthan\s+gum|guar\s+gum|gellan\s+gum|locust\s+bean\s+gum)\b", 0.6),
    "carrageenan": (r"\bcarrageenan\b", 1.5),
    # Sugar alcohols
    "sugar_alcohols": (r"\b(erythritol|xylitol|sorbitol|mannitol|maltitol|isomalt)\b", 0.9),
    # Artificial / 'natural' flavors (mild but frequent)
    "flavors": (r"\b(artificial\s+flavors?|natural\s+flavors?)\b", 0.6),
}

def score_negative(tokens, alpha = 0.6):
  """Sum the position-weighted scores of every negative labelling function.

  Walks ``LFs_negative`` in insertion order and adds ``lf_any`` for each
  (pattern, weight) pair; returns 0.0 when no pattern matches any token.
  """
  total = 0.0
  for pattern, weight in LFs_negative.values():
    total += lf_any(tokens, pattern, base_weight=weight, alpha=alpha)
  return total


# Positive ("healthier") labelling functions.
# Maps a name to (regex, base_weight); score_positive() feeds each pair to
# lf_any(), so a regex is searched against one ingredient token at a time and
# the weight is scaled by the token's position in the list.
LFs_positive = {
    # Whole-grain wording anywhere in a token (position weighting, not this
    # pattern, is what rewards the grain appearing first).
    "whole_grain_first": (r"\b(whole\s+grain|whole[-\s]*wheat|whole\s*oats?|whole\s*rye|whole\s*barley)\b", 0.8),
     # Fruit and vegetable names. NOTE(review): "spinach" is listed twice (harmless),
     # and bare "pepper" also matches tokens such as "black pepper".
     "fresh_produce": (r"\b("
                      r"apple|apples|banana|bananas|berry|berries|blueberry|strawberry|raspberry|blackberry|"
                      r"cherry|cherries|grape|grapes|melon|cantaloupe|honeydew|watermelon|mango|mangos|pineapple|pear|pears|"
                      r"peach|peaches|plum|plums|kiwi|oranges?|clementine|tangerine|lemon|lime|"
                      r"spinach|kale|lettuce|romaine|arugula|chard|collard|broccoli|cauliflower|cabbage|"
                      r"carrot|carrots|celery|beet|beets|radish|radishes|turnip|turnips|"
                      r"onion|onions|garlic|ginger|leek|shallot|scallion|"
                      r"pepper|peppers|bell\s*pepper|tomato|tomatoes|cucumber|zucchini|squash|"
                      r"pumpkin|eggplant|brussels\s*sprouts|asparagus|spinach"
                      r")\b", 1.0),
    # Animal and plant protein sources, plus the bare word "protein".
    "protein": (r"\b("
                r"protein|chicken|beef|pork|turkey|lamb|duck|salmon|tuna|cod|trout|shrimp|prawn|"
                r"crab|lobster|fish|anchovy|sardine|mackerel|egg|eggs|egg\s*white|egg\s*yolk|"
                r"tofu|tempeh|seitan|soy\s*protein|pea\s*protein|whey\s*protein|casein"
                r")\b", 0.7),
    # Lentils / peas / chickpeas. NOTE(review): "pea protein" matches both this
    # entry (via "pea") and "protein", so it is counted twice.
    "legumes": (r"\b("
                r"lentil|lentils|chickpea|chickpeas|pea|peas|split\s*peas|black\s*eyed\s*peas|"
                r"legume|legumes"
                r")\b", 0.8),
    # Named bean varieties (kept separate from "legumes").
    "beans": (r"\b("
              r"black\s*bean|black\s*beans|kidney\s*bean|kidney\s*beans|pinto\s*bean|pinto\s*beans|"
              r"navy\s*beans|white\s*beans|cannellini|garbanzo|refried\s*beans|soybeans?|edamame"
              r")\b", 0.8),
    # Tree nuts.
    "nuts": (r"\b("
             r"almond|almonds|walnut|walnuts|cashew|cashews|pecan|pecans|pistachio|pistachios|"
             r"hazelnut|hazelnuts|macadamia|macadamias|brazil\s*nuts?|chestnut|chestnuts"
             r")\b", 0.8),
    # Healthy oils
    "healthy_oils": (r"\b((extra\s*virgin\s*)?olive\s+oil|avocado\s+oil|canola\s+oil)\b", 0.5),
    # Seeds
    "seeds": (r"\b(chia|flax|sesame|sunflower|pumpkin)\s+seeds?\b", 0.8),
    # Whole grains (expand coverage)
    "whole_grains_extra": (r"\b(brown\s+rice|quinoa|buckwheat|oat\s*bran|wheat\s*bran)\b", 0.9),
    # Fiber-rich isolates (small positive)
    "fiber": (r"\b(psyllium|inulin|oat\s*fiber|apple\s*fiber)\b", 0.7),
    # Fermented / live cultures (modest)
    "fermented": (r"\b(live\s+active\s+cultures|yogurt\s+cultures|kefir|sauerkraut|kimchi)\b", 0.6),
    # Herbs/spices (tiny bump; they often ride along with processed foods)
    "herbs_spices": (r"\b(turmeric|cumin|oregano|basil|thyme|rosemary|cilantro|parsley)\b", 0.3),
}

def score_positive(tokens, alpha = 0.6):
  """Sum the position-weighted scores of every positive labelling function.

  Walks ``LFs_positive`` in insertion order and adds ``lf_any`` for each
  (pattern, weight) pair; returns 0.0 when no pattern matches any token.
  """
  total = 0.0
  for pattern, weight in LFs_positive.values():
    total += lf_any(tokens, pattern, base_weight=weight, alpha=alpha)
  return total

def length_penalty(tokens, cutoff = 12, per_extra = 0.03):
  """Penalise long ingredient lists.

  Returns ``per_extra`` for every token beyond ``cutoff``; lists at or below
  the cutoff incur no penalty.
  """
  overflow = len(tokens) - cutoff
  if overflow < 0:
    overflow = 0
  return overflow * per_extra


In [5]:
def weak_score(tokens, alpha=0.6):
  """Combine the labelling functions into one score (higher = less healthy).

  Computes ``score_negative - 0.9 * score_positive`` and then adds the
  length penalty for long ingredient lists.
  """
  negative = score_negative(tokens, alpha=alpha)
  positive = score_positive(tokens, alpha=alpha)
  return (negative - 0.9 * positive) + length_penalty(tokens)


def score_to_label(score, t_healthy = -0.5, t_unhealthy = -0.1):
  """Bucket a weak score into 'healthy', 'intermediate' or 'unhealthy'.

  Scores strictly below ``t_healthy`` are 'healthy', scores strictly above
  ``t_unhealthy`` are 'unhealthy', and everything else (both boundary values
  included) is 'intermediate'. The healthy check takes precedence.
  """
  if score < t_healthy:
    return 'healthy'
  if score > t_unhealthy:
    return 'unhealthy'
  return 'intermediate'

def apply_snorkel_labels(df, ing_col="ingredients_list", alpha=0.6,
                         t_healthy=None, t_unhealthy=None, use_quantiles=True):
    """Score every row with ``weak_score`` and bucket it into a label.

    Parameters
    ----------
    df : DataFrame holding the ingredient lists; it is copied, never mutated.
    ing_col : name of the column containing each row's list of ingredient tokens.
    alpha : position-decay exponent forwarded to ``weak_score``.
    t_healthy, t_unhealthy : explicit thresholds. A threshold that is given is
        always honoured; a threshold left as ``None`` is filled independently
        of the other one.
    use_quantiles : when True, a missing threshold is taken from the 0.33
        (healthy) / 0.67 (unhealthy) quantile of the computed scores; when
        False, from the fixed fallbacks -0.4 / 0.6.

    Returns
    -------
    (labelled_df, thresholds) where ``labelled_df`` is the copy with added
    ``weak_score`` and ``label`` columns, and ``thresholds`` is
    ``{"t_healthy": ..., "t_unhealthy": ...}`` with the values actually used.
    """
    df = df.copy()
    df["weak_score"] = df[ing_col].apply(lambda lst: weak_score(lst, alpha=alpha))

    if t_healthy is None or t_unhealthy is None:
        if use_quantiles:
            default_healthy, default_unhealthy = df["weak_score"].quantile([0.33, 0.67])
        else:
            # sensible fallbacks
            default_healthy, default_unhealthy = -0.4, 0.6
        # Fill only what the caller left out, so supplying a single threshold
        # never causes that value to be silently replaced.
        if t_healthy is None:
            t_healthy = default_healthy
        if t_unhealthy is None:
            t_unhealthy = default_unhealthy

    df["label"] = df["weak_score"].apply(lambda s: score_to_label(s, t_healthy, t_unhealthy))
    return df, {"t_healthy": t_healthy, "t_unhealthy": t_unhealthy}


# First labelling pass: no thresholds are supplied, so they default to the
# 0.33 / 0.67 quantiles of the weak-score distribution.
labeled_df, thresholds = apply_snorkel_labels(branded_df, ing_col="ingredients_list", alpha=0.6)
print(thresholds)
# Two-class subset (healthy vs. unhealthy). Labels are strings, so the middle
# class has to be excluded by name; comparing against an integer such as 2
# never matches a string label and would leave every row in place.
binary_df = labeled_df[labeled_df['label'] != 'intermediate'].copy()

{'t_healthy': -0.2748080089387155, 't_unhealthy': 0.6597539553864472}


In [7]:
# Threshold Tuning
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
import numpy as np

import warnings
from sklearn.exceptions import ConvergenceWarning

def relabel(df, th, tu, alpha=0.6):
    """Return a copy of ``df`` labelled with fixed thresholds.

    ``th`` is used as the healthy threshold and ``tu`` as the unhealthy one;
    quantile-based thresholds are disabled. The thresholds dict returned by
    ``apply_snorkel_labels`` is dropped.
    """
    labeling_kwargs = dict(
        ing_col="ingredients_list",
        alpha=alpha,
        t_healthy=th,
        t_unhealthy=tu,
        use_quantiles=False,
    )
    relabeled, _thresholds = apply_snorkel_labels(df, **labeling_kwargs)
    return relabeled

def quick_f1(df):
    """Macro-F1 of a quick TF-IDF + logistic-regression model on ``df``.

    Reads the ``ordered_text`` (features) and ``label`` (target) columns,
    drops rows whose text is blank, holds out a stratified 20% validation
    split (seed 42) and returns the macro-averaged F1 on it.

    Returns 0.0 without fitting when the labelling is unusable for a
    stratified split, i.e. fewer than two classes remain or some class has
    fewer than two rows. This lets a threshold sweep skip degenerate
    threshold pairs instead of crashing on them.
    """
    X = df["ordered_text"].fillna("").astype(str)
    y = df["label"].astype(str)
    m = X.str.strip().str.len() > 0
    X, y = X[m], y[m]

    # train_test_split(stratify=...) needs at least two members per class, and
    # the classifier needs at least two classes.
    class_counts = y.value_counts()
    if len(class_counts) < 2 or class_counts.min() < 2:
        return 0.0

    X_tr, X_va, y_tr, y_va = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)
    pipe = Pipeline([
        ("tfidf", TfidfVectorizer(ngram_range=(1,2), min_df=5, token_pattern=r"(?u)\b\w+\b")),
        ("lr", LogisticRegression(max_iter=2000, solver="saga", class_weight="balanced", n_jobs=-1)),
    ])

    # Ignore noisy warnings
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", ConvergenceWarning)
        pipe.fit(X_tr, y_tr)
        y_hat = pipe.predict(X_va)

    # Score the predictions computed above rather than predicting a second time.
    return f1_score(y_va, y_hat, average="macro", zero_division=0)

# Coarse 8 x 8 sweep over (t_healthy, t_unhealthy); tweak the ranges if needed.
# The first pair with the highest macro-F1 wins (ties keep the earlier pair).
healthy_grid = np.linspace(-1.2, -0.1, 8)    # t_healthy candidates
unhealthy_grid = np.linspace(0.0, 0.8, 8)    # t_unhealthy candidates

best = (-1, None, None)
for th, tu in [(lo, hi) for lo in healthy_grid for hi in unhealthy_grid]:
    f1 = quick_f1(relabel(branded_df, th, tu, alpha=0.6))
    best = (f1, th, tu) if f1 > best[0] else best

print(f"[tuning] best macro-F1={best[0]:.3f} at t_healthy={best[1]:.3f}, t_unhealthy={best[2]:.3f}")

# Re-label with the tuned thresholds; the rest of the notebook uses this frame.
labeled_df, _ = apply_snorkel_labels(
    branded_df,
    ing_col="ingredients_list",
    alpha=0.6,
    t_healthy=best[1],
    t_unhealthy=best[2],
    use_quantiles=False,
)


[tuning] best macro-F1=0.817 at t_healthy=-0.257, t_unhealthy=0.686


In [8]:
# Work on a copy so the labelled frame itself is left untouched.
multi_df = labeled_df.copy()

# features and labels
X = multi_df["ordered_text"].fillna("")
y = multi_df["label"]   # {'healthy','intermediate','unhealthy'}

# Drop rows whose text is empty or whitespace-only; they carry no features.
mask = X.str.strip().str.len() > 0
X, y = X[mask], y[mask]

print("Class counts:", y.value_counts().to_dict())

# Stratified three-way split: first hold out 15% as the test set ...
from sklearn.model_selection import train_test_split
X_temp, X_test, y_temp, y_test = train_test_split(
    X, y, test_size=0.15, stratify=y, random_state=42
)
# ... then take 0.15 / 0.85 of the remaining 85% as validation, i.e. 15% of the
# original rows, leaving roughly 70% for training.
rel_val = 0.15 / 0.85
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, test_size=rel_val, stratify=y_temp, random_state=42
)

Class counts: {'intermediate': 349, 'healthy': 336, 'unhealthy': 315}


In [9]:
# Spot-check products weakly labelled 'healthy'. The sample is seeded so the
# displayed rows are reproducible, and capped at the class size so a class with
# fewer than 20 rows does not make sample() raise.
healthy_rows = labeled_df[labeled_df['label'] == 'healthy']
healthy_rows.sample(n=min(20, len(healthy_rows)), random_state=42)[['product_name', 'ingredients_list_json']]

Unnamed: 0,product_name,ingredients_list_json
1442000,PETITE DICED TOMATOES IN JUICE,"[""tomatoes"", ""tomato juice"", ""sea salt"", ""calc..."
1335801,"CINNAMON CRUNCH WHOLE GRAIN CRISPY SNACK BARS,...","[""tapioca syrup*"", ""brown rice flour*"", ""whey ..."
610607,"DOCTOR KRACKER, CRISPBREADS, KLASSIC 3 SEED, A...","[""organic wheat flour"", ""organic whole wheat f..."
1642740,"CHERRY ALMOND A GRANOLA SNACK, CHERRY ALMOND","[""gluten free oats"", ""roasted almonds"", ""agave..."
1844192,"ANGUS BEEF PATTIES, ANGUS BEEF","[""beef""]"
1861082,"FLYING FRUIT PUNCH MIGHTY JUICE BEVERAGE, FLYI...","[""water"", ""apple and cherry juice concentrates..."
1933099,THAI FIRED RICE,"[""water"", ""organic long grain rice"", ""sunflowe..."
1618704,"French Toast Sticks, Whole Grain (approx. 140-...","[""whole wheat bread"", ""water"", ""whole wheat ba..."
719393,VEGETARIAN MEAT SUBSTITUTE BITS,"[""textured vegetable protein"", ""soybean oil"", ..."
489715,"DELICIAS GLORIA SPICY SOUR BELTS, DELICIAS GLORIA","[""sour power strawberry belts"", ""spicy seasoni..."


In [10]:
# Spot-check products weakly labelled 'unhealthy'. The sample is seeded so the
# displayed rows are reproducible, and capped at the class size so a class with
# fewer than 20 rows does not make sample() raise.
unhealthy_rows = labeled_df[labeled_df['label'] == 'unhealthy']
unhealthy_rows.sample(n=min(20, len(unhealthy_rows)), random_state=42)[['product_name', 'ingredients_list_json']]

Unnamed: 0,product_name,ingredients_list_json
1583369,"ORIGINAL NO-STICK COOKING SPRAY, ORIGINAL","[""canola oil*"", ""coconut oil*"", ""palm oil*"", ""..."
1731025,SOUTHWESTERN STYLE SEASONED FULLY COOKED CHICK...,"[""boneless"", ""skinless chicken breasts with ri..."
1182657,"POO-TIN* POWER ENERGY FUEL, COCONUT MANGO","[""carbonated water"", ""high fructose corn syrup..."
459205,"RASPBERRY TEA, RASPBERRY","[""water"", ""corn syrup"", ""sugar"", ""citric acid""..."
731884,ITALIAN SAUSAGE,"[""pork"", ""water"", ""corn syrup"", ""the following..."
480573,"SOUR CREAM & ONION POTATO CHIPS, SOUR CREAM & ...","[""potatoes"", ""vegetable oil"", ""salt"", ""nonfat ..."
1923851,"BLACK CHERRY SPARKLING WATER BEVERAGE, BLACK C...","[""carbonated water"", ""citric acid"", ""natural f..."
397486,ORGANIC STEVIA GRANULAR SWEETENER PACKETS,"[""organic erythritol"", ""organic stevia leaf ex..."
1716444,"BUTTERY HOMESTYLE MASHED POTATOES, BUTTERY HOM...","[""potatoes"", ""maltodextrin"", ""coconut oil"", ""s..."
418522,"FRESH!, FUDGE BROWNIE CHEESECAKE, FUDGE BROWNIE","[""sugar"", ""water"", ""soybean oil"", ""wheat flour..."


In [11]:
# Spot-check products weakly labelled 'intermediate'. The sample is seeded so
# the displayed rows are reproducible, and capped at the class size so a class
# with fewer than 20 rows does not make sample() raise.
intermediate_rows = labeled_df[labeled_df['label'] == 'intermediate']
intermediate_rows.sample(n=min(20, len(intermediate_rows)), random_state=42)[['product_name', 'ingredients_list_json']]

Unnamed: 0,product_name,ingredients_list_json
21062,BABY SHIITAKE MUSHROOMS,"[""organic baby shiitake mushrooms""]"
84473,"RICHIN, DRIED ANCHOVY","[""anchovy"", ""salt""]"
883276,CANNELLINI BEANS,"[""prepared cannellini beans"", ""water"", ""salt"",..."
1785724,"SHREDDED PARMESAN CHEESE, SHREDDED PARMESAN","[""parmesan cheese"", ""powdered cellulose""]"
1620160,Pork Loin Boneless Strap Off,"[""basted with up to 16% added solution of wate..."
1811413,Golden Grahams Cereal,"[""whole grain wheat"", ""corn meal"", ""sugar"", ""b..."
827033,"LEMON LIME SPARKLING NATURAL SPRING WATER, LEM...","[""spring water"", ""co2"", ""natural flavors""]"
1231351,"MILKSHAKE CINNABON, CINNAMON ROLL","[""ice cream"", ""milk""]"
1868635,ORIGINAL GOAT CHEESE,"[""cultured pasteurized goat milk"", ""salt"", ""en..."
1616108,"CRISPY ROLLED AND FAN WAFERS, VANILLA","[""wheat flour"", ""sugar"", ""anhydrous milk fat"",..."


In [None]:
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, GridSearchCV

# Model: word uni-/bi-gram TF-IDF over the order-weighted text, followed by a
# random forest. min_df=3 drops terms seen in fewer than three documents.
pipe = Pipeline([
    ("tfidf", TfidfVectorizer(ngram_range=(1,2), min_df=3,
                              token_pattern=r"(?u)\b\w+\b")),
    ("rf", RandomForestClassifier(random_state=42, n_jobs=-1))
])

# 3 x 3 x 3 x 3 = 81 random-forest configurations; the "rf__" prefix routes
# each parameter to the pipeline's "rf" step.
param_grid = {
    "rf__n_estimators": [200, 400, 600],
    "rf__max_depth": [None, 20, 40],
    "rf__min_samples_leaf": [1, 2, 4],
    "rf__max_features": ["sqrt", "log2", None],
}

from sklearn.metrics import classification_report
from sklearn.model_selection import StratifiedKFold

# Seeded, shuffled 5-fold stratified CV (81 candidates x 5 folds = 405 fits).
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# error_score="raise" makes a failing fit abort the search instead of being
# recorded as a score.
grid = GridSearchCV(
    pipe, param_grid, cv=cv,
    scoring="f1_macro",   # or "f1_weighted"
    n_jobs=-1, verbose=1, error_score="raise"
)

# Cross-validation does its own splitting, so train and validation are pooled
# here; the test split stays held out.
grid.fit(pd.concat([X_train, X_val]), pd.concat([y_train, y_val]))

print("Best params:", grid.best_params_)
print("Best CV f1_macro:", grid.best_score_)

# Held-out evaluation of the selected configuration.
y_pred = grid.predict(X_test)
print(classification_report(y_test, y_pred, zero_division=0))

Fitting 5 folds for each of 81 candidates, totalling 405 fits


In [None]:
from sklearn.metrics import (
    confusion_matrix, ConfusionMatrixDisplay,
    classification_report, f1_score, accuracy_score
)
import matplotlib.pyplot as plt
from pathlib import Path

def _draw_confusion(matrix, class_names, title, values_format, colorbar):
    """Plot one confusion matrix on a new figure and return (fig, ax, display)."""
    figure, axes = plt.subplots(figsize=(5.5, 4.5))
    display_obj = ConfusionMatrixDisplay(matrix, display_labels=class_names)
    display_obj.plot(ax=axes, cmap="Blues", values_format=values_format, colorbar=colorbar)
    axes.set_title(title)
    plt.tight_layout()
    return figure, axes, display_obj


# Evaluate the best pipeline found by the grid search on the held-out test split.
best_model = grid.best_estimator_
y_pred = best_model.predict(X_test)

# scores
print("Test Accuracy:", accuracy_score(y_test, y_pred))
print("Test F1 (macro):", f1_score(y_test, y_pred, average="macro", zero_division=0))
print("\nClassification Report:\n", classification_report(y_test, y_pred, zero_division=0))

# confusion matrix (counts); a fixed label order keeps rows/columns consistent
labels = ["healthy", "intermediate", "unhealthy"]
cm = confusion_matrix(y_test, y_pred, labels=labels)
fig, ax, disp = _draw_confusion(cm, labels, "Confusion Matrix (Test)", "d", False)

# save file
outdir = Path("../OUTPUT")
outdir.mkdir(parents=True, exist_ok=True)
fig.savefig(outdir / "confusion_matrix_counts.png", dpi=300, bbox_inches="tight")
plt.show()

# Normalized confusion matrix (each row, i.e. true class, sums to 1)
cm_norm = confusion_matrix(y_test, y_pred, labels=labels, normalize="true")
fig2, ax2, disp2 = _draw_confusion(
    cm_norm, labels, "Confusion Matrix (Normalized by True Class)", ".2f", True
)

# save file
fig2.savefig(outdir / "confusion_matrix_normalized.png", dpi=300, bbox_inches="tight")
plt.show()

print("Saved figures to:", outdir.resolve())
