# Ingredient Normalization (Data-Driven, PMI-Based)

This notebook implements a **fully data-driven** normalization of ingredient phrases at scale (~2M rows). It:

1. Streams the dataset in chunks to count unigrams, bigrams, and trigrams
2. Computes PMI-style association scores
3. Builds a canonical vocabulary from those statistics (no hard-coded ingredient lists)
4. Optionally spell-corrects and fuzzy-matches each NER item against that vocabulary, then segments it by greedy longest match
5. Writes the cleaned column (`NER_clean`) to a Parquet file

### Designed for large CSVs: reads the input in chunks and saves the learned vocabulary to JSON so later runs can reuse it without recounting.

In [None]:
# # Optional installs (run once if missing)
# !pip install pyarrow pyspellchecker rapidfuzz tqdm

# # Update requirements.txt 
# !pip freeze > ../requirements.txt

## Config

In [None]:
from pathlib import Path


DATA_PATH = Path("../data/wilmerarltstrmberg_data.csv")      # your raw CSV
OUTPUT_PATH = Path("../data/recipes_data_clean.parquet")  # Parquet output written by PASS 3
VOCAB_JSON = Path("../data/ingredient_vocab_stats.json")  # canon saved by PASS 2, reloaded by the fast path
NER_COL = "NER"  # column whose cells are parsed into ingredient items
CHUNK_SIZE = 200_000  # rows per pandas.read_csv chunk

# Minimum raw counts an n-gram needs before it can enter the canonical vocabulary.
MIN_UNIGRAM = 50
MIN_BIGRAM = 50
MIN_TRIGRAM = 30
PMI_BIGRAM = 3.5      # min bigram PMI (natural log): ln(p(ab) / (p(a) * p(b)))
PMI_TRIGRAM = 2.0     # min trigram score: mean of the PMIs of its two bigrams

import ast, math, re, json, gc
from collections import Counter
from typing import List, Tuple, Iterable
import pandas as pd
from tqdm import tqdm
import numpy as np

# Word tokenizer pattern used by tok(): runs of lowercase ASCII letters and
# apostrophes. Digits, punctuation and non-ASCII letters are not matched.
_WORD_RE = re.compile(r"[a-z']+")



In [None]:
def parse_ner_entry(entry) -> List[str]:
    """Turn one raw NER cell into a list of stripped, non-empty item strings.

    * ``None`` / float NaN -> ``[]``.
    * An already-parsed ``list`` / ``tuple`` -> its items, each ``str()``-ed.
    * A string holding a Python/JSON-style list or tuple literal, e.g.
      ``'["brown sugar", "milk"]'`` -> its items.
    * Any other string -> split on commas as a fallback.

    Items that are empty after ``strip()`` are dropped in every case.
    """
    def _clean(items):
        return [str(x).strip() for x in items if str(x).strip()]

    if entry is None or (isinstance(entry, float) and pd.isna(entry)):
        return []
    if isinstance(entry, (list, tuple)):
        # In-memory lists (e.g. a column built by earlier processing) need no
        # str() -> literal_eval round trip.
        return _clean(entry)
    s = str(entry).strip()
    if not s:
        return []
    try:
        # literal_eval only evaluates Python literals (no code execution).
        parsed = ast.literal_eval(s)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        parsed = None
    if isinstance(parsed, (list, tuple)):
        return _clean(parsed)
    return [x.strip() for x in s.split(',') if x.strip()]

def tok(s: str) -> List[str]:
    """Return the lowercase word tokens of *s*, as matched by ``_WORD_RE``.

    ``_WORD_RE`` only matches ASCII letters and ``'``. To keep accented words
    whole ("jalapeño" -> "jalapeno" rather than "jalape" + "o"), non-ASCII
    text is first NFKD-decomposed with its combining marks dropped, and curly
    single quotes are mapped to ``'``. Pure-ASCII input is tokenized exactly
    as before (lowercase + regex).
    """
    text = str(s)
    if not text.isascii():
        import unicodedata
        text = text.replace("\u2018", "'").replace("\u2019", "'")
        text = "".join(
            ch for ch in unicodedata.normalize("NFKD", text)
            if not unicodedata.combining(ch)
        )
    return _WORD_RE.findall(text.lower())

def ngrams(tokens: List[str], n: int):
    """Yield each contiguous window of *n* tokens as a tuple, from left to right."""
    last_start = len(tokens) - n
    start = 0
    while start <= last_start:
        yield tuple(tokens[start:start + n])
        start += 1

In [None]:
from spellchecker import SpellChecker
from rapidfuzz import process, fuzz

class SpellCorrector:
    """Word-level spell corrector + phrase-level fuzzy matcher on top of PMI canon.

    Each raw item is lower-cased and tokenized with ``tok``. Every token goes
    through ``SpellChecker`` (edit distance 2), whose word list is primed with
    the canon's tokens. The corrected phrase is then fuzzy-matched with
    rapidfuzz ``WRatio`` against the canonical phrases.

    Token corrections and best-match lookups are memoised per instance with
    bounded LRU caches: ingredient strings repeat heavily across rows, and both
    the spell-checker lookup and the full-canon fuzzy scan are expensive.
    """

    def __init__(self, known_phrases, fuzzy_threshold=88, cache_size=200_000):
        """
        Args:
            known_phrases: iterable of canonical phrases, e.g. ['brown sugar', 'olive oil'].
            fuzzy_threshold: minimum WRatio score for snapping an item to a canon phrase.
            cache_size: maximum entries in each of the token / phrase LRU caches.
        """
        from functools import lru_cache

        # Materialise once: a generator would otherwise be exhausted by the
        # token priming below, leaving nothing to match against.
        self._canon = list(known_phrases)
        self._spell = SpellChecker(distance=2)
        # prime spellchecker vocabulary with every token of the canon
        vocab_tokens = [t for p in self._canon for t in p.split()]
        self._spell.word_frequency.load_words(vocab_tokens)
        self.fuzzy_threshold = fuzzy_threshold
        # Per-instance caches wrap bound callables (not methods at class level),
        # so each corrector owns its cache. The threshold is applied outside the
        # cache, so changing fuzzy_threshold later still takes effect.
        self._correct_token = lru_cache(maxsize=cache_size)(self._correct_token_uncached)
        self._best_match = lru_cache(maxsize=cache_size)(self._best_match_uncached)

    def _correct_token_uncached(self, token):
        # correction() may return None when it has no candidate; keep the token then.
        return self._spell.correction(token) or token

    def _best_match_uncached(self, phrase):
        """Return ``(canon_phrase, score)`` of the best fuzzy match, or None if there is none."""
        if not self._canon:
            return None
        result = process.extractOne(phrase, self._canon, scorer=fuzz.WRatio)
        if result is None:
            return None
        match, score, _ = result
        return match, score

    def _spell_correct_phrase(self, text):
        return " ".join(self._correct_token(t) for t in tok(text))

    def correct_and_match(self, raw_item):
        """Return best-matched canonical phrase and similarity score.

        Returns:
            ``(canon_phrase, int_score)`` when the best match scores at least
            ``fuzzy_threshold``; otherwise ``(corrected_phrase, int_score)``.
            ``("", None)`` for an empty item, and ``(corrected_phrase, None)``
            when there is no canon to match against.
        """
        if not raw_item:
            return "", None
        corrected = self._spell_correct_phrase(raw_item)
        best = self._best_match(corrected)
        if best is None:
            return corrected, None
        match, score = best
        if score >= self.fuzzy_threshold:
            return match, int(score)
        return corrected, int(score)


In [None]:
class StatsNormalizer:
    """Data-driven ingredient-phrase normalizer built on n-gram statistics.

    Typical flow:
      1. ``ingest_csv`` / ``ingest_df`` count unigrams, bigrams and trigrams of
         the tokenized NER items. They also record, for every bigram head
         (a, b), which tokens c follow it inside a trigram.
      2. ``build_vocab`` keeps the n-grams passing the count and PMI thresholds
         (and, for trigrams, the child-share and right-entropy thresholds) as
         the canon.
      3. ``segment_item`` / ``transform_df`` / ``transform_csv_to_parquet``
         split items by greedy longest match against that canon.

    ``save_vocab`` / ``load_vocab`` persist only the canon and ``token_total``.
    The n-gram counters are not saved, so calling ``build_vocab`` on a loaded
    instance would yield an empty canon.
    """

    def __init__(self,
                 max_ngram=3,
                 min_unigram=50, min_bigram=50, min_trigram=30,
                 pmi_bigram=3.5, pmi_trigram=2.0,
                 min_child_share=0.12,          # keep trigram only if >= this share of its (a, b) head
                 max_right_entropy=1.0):        # keep trigram only if its head's right entropy (nats) <= this
        from collections import defaultdict, Counter

        self.max_ngram = max_ngram
        self.min_unigram = min_unigram
        self.min_bigram = min_bigram
        self.min_trigram = min_trigram
        self.pmi_bigram = pmi_bigram
        self.pmi_trigram = pmi_trigram
        self.min_child_share = min_child_share
        self.max_right_entropy = max_right_entropy

        self.c1, self.c2, self.c3 = Counter(), Counter(), Counter()
        self.token_total = 0
        self.canon, self._canon_ready = set(), False
        # (a, b) -> Counter of tokens c observed right after the bigram in a trigram
        self._followers = defaultdict(Counter)

    #  counting
    def ingest_df(self, df, ner_col="NER"):
        """Accumulate unigram/bigram/trigram counts from every item of ``df[ner_col]``."""
        for entry in df[ner_col]:
            for item in parse_ner_entry(entry):
                t = tok(item)
                if not t:
                    continue
                self.c1.update(t)
                self.token_total += len(t)
                if self.max_ngram >= 2 and len(t) >= 2:
                    self.c2.update(ngrams(t, 2))
                if self.max_ngram >= 3 and len(t) >= 3:
                    # count trigrams and the follower distribution of their (a, b) head
                    for i in range(len(t) - 2):
                        a, b, c = t[i], t[i + 1], t[i + 2]
                        self.c3[(a, b, c)] += 1
                        self._followers[(a, b)][c] += 1

    def _right_entropy(self, ab):
        """Natural-log entropy of the tokens that follow bigram *ab* (0.0 if none seen)."""
        foll = self._followers.get(ab)
        if not foll:
            return 0.0
        tot = sum(foll.values())
        if tot == 0:
            return 0.0
        H = 0.0
        for v in foll.values():
            p = v / tot
            H -= p * math.log(p + 1e-12)
        return H

    def _child_share(self, abc):
        """Count of trigram *abc* divided by the count of its head bigram (0.0 if unseen)."""
        ab = abc[:2]
        cabc = self.c3[abc]
        cab = self.c2[ab]
        if cab == 0:
            return 0.0
        return cabc / cab

    @staticmethod
    def _read_csv_chunks(csv_path, chunksize, read_csv_kwargs=None):
        """Open a chunked ``pd.read_csv`` reader (usable as a context manager).

        Columns are read as ``str`` unless *read_csv_kwargs* overrides ``dtype``.
        Any other ``read_csv`` option (``engine``, ``usecols``, ``on_bad_lines``,
        ``encoding``, ...) can be passed the same way.
        """
        options = {"dtype": str, **(read_csv_kwargs or {})}
        return pd.read_csv(csv_path, chunksize=chunksize, **options)

    def ingest_csv(self, csv_path, ner_col="NER", chunksize=200_000, read_csv_kwargs=None):
        """Stream *csv_path* chunk by chunk into ``ingest_df`` with a tqdm progress bar."""
        # `with` closes the underlying file even if parsing fails mid-stream.
        with self._read_csv_chunks(csv_path, chunksize, read_csv_kwargs) as reader:
            for chunk in tqdm(reader, desc="Counting"):
                self.ingest_df(chunk, ner_col=ner_col)
                del chunk
                gc.collect()

    #  PMI
    def _pmi_bigram(self, ab):
        """ln(p(ab) / (p(a) p(b))) over token frequencies; -1e9 if *ab* was never seen."""
        a, b = ab
        cab = self.c2[ab]
        if cab == 0 or self.token_total == 0:
            return -1e9
        pa, pb = self.c1[a] / self.token_total, self.c1[b] / self.token_total
        pab = cab / self.token_total
        return math.log((pab / (pa * pb)) + 1e-12)

    def _pmi_trigram(self, abc):
        """Mean of the PMIs of the trigram's two overlapping bigrams."""
        a, b, c = abc
        return (self._pmi_bigram((a, b)) + self._pmi_bigram((b, c))) / 2.0

    #  vocab
    def build_vocab(self):
        """Rebuild ``self.canon`` from the current counts using the configured thresholds."""
        self.canon.clear()

        # strong unigrams
        for w, c in self.c1.items():
            if c >= self.min_unigram:
                self.canon.add((w,))

        # strong bigrams
        for ab, c in self.c2.items():
            if c >= self.min_bigram and self._pmi_bigram(ab) >= self.pmi_bigram:
                self.canon.add(ab)

        # strong trigrams (apply PMI + child share + low branching entropy)
        for abc, c in self.c3.items():
            if c < self.min_trigram:
                continue
            if self._pmi_trigram(abc) < self.pmi_trigram:
                continue
            if self._child_share(abc) < self.min_child_share:
                continue
            if self._right_entropy(abc[:2]) > self.max_right_entropy:
                continue
            self.canon.add(abc)

        self._canon_ready = True

    #  segmentation
    def _longest_match(self, toks, i):
        """Return ``(phrase, length)`` for the longest canon n-gram (3, then 2 tokens) at ``toks[i]``.

        Falls back to the single token ``(toks[i],)`` whether or not it is in the canon.

        Raises:
            RuntimeError: if no canon has been built or loaded yet.
        """
        if not self._canon_ready:
            raise RuntimeError("build_vocab() first")
        for n in (3, 2):
            if i + n <= len(toks):
                candidate = tuple(toks[i:i + n])
                if candidate in self.canon:
                    return candidate, n
        return (toks[i],), 1

    def segment_item(self, text):
        """Split one item into canon phrases by greedy longest match.

        Duplicate phrases are removed (first occurrence kept). A single-word phrase
        equal to the last word of the preceding multi-word phrase is also dropped.
        """
        t = tok(text)
        out, i = [], 0
        while i < len(t):
            phrase, k = self._longest_match(t, i)
            out.append(" ".join(phrase))
            i += k

        # dedupe while preserving order
        clean = list(dict.fromkeys(out))

        # drop immediate repetition of the tail of previous phrase
        pruned = []
        for x in clean:
            if pruned:
                prev = pruned[-1].split()
                if len(prev) >= 2 and x == prev[-1]:
                    # e.g., ["brown sugar", "sugar"] -> drop "sugar"
                    continue
            pruned.append(x)
        return pruned

    #  transform
    def transform_df(self, df, ner_col="NER", out_col="NER_clean"):
        """Store the flattened segments of every item of ``df[ner_col]`` in ``df[out_col]``.

        Mutates *df* in place and also returns it.
        """
        df[out_col] = [[seg for item in parse_ner_entry(v) for seg in self.segment_item(item)]
                       for v in df[ner_col]]
        return df

    @staticmethod
    def _sanitize_for_arrow(df: pd.DataFrame, list_col: str = "NER_clean") -> pd.DataFrame:
        """Return a copy of *df* whose columns convert to Arrow with predictable types.

        * ``list_col`` becomes a list of str per row. Lists, tuples and arrays are
          stringified element-wise; None/NA becomes ``[]``; JSON-list strings are
          decoded; any other value becomes ``[str(value)]``.
        * Period columns become pandas "string". Tz-aware datetimes become UTC
          ISO-8601 strings ending in "Z". Naive datetimes become ISO-8601 strings.
        * Nullable-boolean columns that contain NA become plain bool (NA -> False).
        * Object columns become str: lists/tuples/dicts/sets are JSON-encoded,
          None/NaN become "".
        """
        from pandas.api.types import is_datetime64_any_dtype, is_scalar

        df = df.copy()

        # ensure list[str]
        if list_col in df.columns:
            def _to_list_of_str(x):
                if isinstance(x, (list, tuple, np.ndarray)):
                    return [str(y) for y in x]
                # is_scalar guard: pd.isna on a container returns an array, not a bool
                if x is None or (is_scalar(x) and pd.isna(x)):
                    return []
                try:
                    val = json.loads(x)
                except (TypeError, ValueError):
                    val = None
                if isinstance(val, list):
                    return [str(y) for y in val]
                return [str(x)]
            df[list_col] = df[list_col].apply(_to_list_of_str)

        def _to_scalar_str(v):
            if isinstance(v, (set, frozenset)):
                # json cannot encode sets; use a deterministic list instead
                v = sorted(v, key=str)
            if isinstance(v, (list, tuple, dict)):
                return json.dumps(v, ensure_ascii=False)
            return "" if v is None or (isinstance(v, float) and np.isnan(v)) else str(v)

        for col in df.columns:
            if col == list_col:
                continue
            s = df[col]
            if isinstance(s.dtype, pd.PeriodDtype):
                df[col] = s.astype("string")
            elif isinstance(s.dtype, pd.DatetimeTZDtype):
                df[col] = (pd.to_datetime(s, errors="coerce").dt.tz_convert("UTC")
                           .dt.strftime("%Y-%m-%dT%H:%M:%S.%fZ"))
            elif is_datetime64_any_dtype(s.dtype):
                # naive (tz-less) datetimes; tz-aware ones were handled above
                df[col] = pd.to_datetime(s, errors="coerce").dt.strftime("%Y-%m-%dT%H:%M:%S.%f")
            elif isinstance(s.dtype, pd.BooleanDtype):
                if s.isna().any():
                    df[col] = s.fillna(False).astype(bool)
            elif s.dtype == object:
                df[col] = s.map(_to_scalar_str)

        return df

    def transform_csv_to_parquet(self, csv_path, out_path, ner_col="NER", chunksize=200_000,
                                 use_spellcheck=True, fuzzy_threshold=88, read_csv_kwargs=None):
        """Stream *csv_path* in chunks, segment ``ner_col`` and write one Parquet file.

        For each chunk:
          1. Optionally spell-correct and fuzzy-match each item against the canon.
             The ``(phrase, score)`` pairs are kept in ``NER_spellchecked``.
          2. Segment the items (or the matched phrases) into ``NER_clean``.
          3. Sanitize the column types.
          4. Append the chunk through a single zstd ``ParquetWriter``.

        Every chunk is written with the first chunk's schema, and ``NER_clean``
        is always ``list<string>``.

        Args:
            csv_path: input CSV.
            out_path: output Parquet file; parent directories are created.
            ner_col: column holding the raw NER cells.
            chunksize: rows per CSV chunk.
            use_spellcheck: run ``SpellCorrector`` before segmentation.
            fuzzy_threshold: minimum fuzzy score for snapping to a canon phrase.
            read_csv_kwargs: extra keyword arguments for ``pandas.read_csv``.

        Raises:
            RuntimeError: if no canon has been built or loaded.
        """
        import pyarrow as pa
        import pyarrow.parquet as pq
        from pathlib import Path

        if not self._canon_ready:
            raise RuntimeError("build_vocab() first")

        out_path = Path(out_path)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        list_type = pa.list_(pa.string())

        # Build a spell-corrector from the current canon (if requested)
        spell = None
        if use_spellcheck:
            canon_phrases = [" ".join(p) for p in self.canon]
            spell = SpellCorrector(canon_phrases, fuzzy_threshold=fuzzy_threshold)

        writer = None
        schema = None
        try:
            with self._read_csv_chunks(csv_path, chunksize, read_csv_kwargs) as reader:
                for chunk in reader:
                    # 1) + 2) (optional) spell-check / fuzzy match, then segment with the PMI canon
                    if spell is not None:
                        matched = [[spell.correct_and_match(item) for item in parse_ner_entry(entry)]
                                   for entry in chunk[ner_col]]
                        chunk["NER_spellchecked"] = matched
                        # correct_and_match returns (phrase, score): segment the phrase only,
                        # never the tuple's repr (which would yield tokens like "'brown").
                        chunk["NER_clean"] = [[seg for phrase, _score in pairs
                                               for seg in self.segment_item(phrase)]
                                              for pairs in matched]
                    else:
                        chunk = self.transform_df(chunk, ner_col=ner_col, out_col="NER_clean")

                    # 3) sanitize types for Arrow
                    chunk = self._sanitize_for_arrow(chunk, list_col="NER_clean")

                    # 4) to Arrow, strip pandas metadata, force list<string> for NER_clean.
                    # A chunk whose lists are all empty would otherwise infer list<null>
                    # and no longer match the writer's schema.
                    table = pa.Table.from_pandas(chunk, preserve_index=False).replace_schema_metadata(None)
                    ner_idx = table.schema.get_field_index("NER_clean")
                    ner_values = pa.array(chunk["NER_clean"].tolist(), type=list_type)
                    table = table.set_column(ner_idx, pa.field("NER_clean", list_type), ner_values)

                    # 5) write/append with a single ParquetWriter
                    if writer is None:
                        schema = table.schema
                        writer = pq.ParquetWriter(str(out_path), schema, compression="zstd")
                    elif not table.schema.equals(schema):
                        table = table.cast(schema)
                    writer.write_table(table)

                    del chunk, table, ner_values
                    gc.collect()
        finally:
            if writer is not None:
                writer.close()

    # save/load
    def save_vocab(self, path):
        """Write ``token_total`` and the sorted canon (space-joined phrases) to *path* as JSON."""
        from pathlib import Path

        path = Path(path)
        data = {
            "token_total": self.token_total,
            "canon": [" ".join(p) for p in sorted(self.canon)],
        }
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)

    @classmethod
    def load_vocab(cls, path):
        """Return a default-configured instance whose canon (and token_total) come from *path*.

        The n-gram counters are not stored in the file and stay empty.
        """
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        obj = cls()
        obj.canon = {tuple(p.split()) for p in data["canon"]}
        obj.token_total = data.get("token_total", 0)
        obj._canon_ready = True
        return obj

## PASS 1: Count n-grams

In [7]:
# Trigram filters: a smaller min_child_share or a larger max_right_entropy keeps more trigrams.
trigram_filters = {"min_child_share": 0.01, "max_right_entropy": 1.0}
normalizer = StatsNormalizer(
    max_ngram=3,
    min_unigram=MIN_UNIGRAM, min_bigram=MIN_BIGRAM, min_trigram=MIN_TRIGRAM,
    pmi_bigram=PMI_BIGRAM, pmi_trigram=PMI_TRIGRAM,
    **trigram_filters,
)
print("Streaming counts from", DATA_PATH)
normalizer.ingest_csv(DATA_PATH, ner_col=NER_COL, chunksize=CHUNK_SIZE)
print("Total tokens:", normalizer.token_total)

Counting: 1it [00:05,  5.11s/it]


ParserError: Error tokenizing data. C error: Calling read(nbytes) on source failed. Try engine='python'.

## PASS 2: Build vocabulary 

In [None]:
# Select the canonical n-grams from the PASS 1 counts and persist them as JSON.
normalizer.build_vocab()
normalizer.save_vocab(VOCAB_JSON)
canon_size = len(normalizer.canon)
print("Canonical phrases:", canon_size)
print("Saved vocab:", VOCAB_JSON)


## PASS 3: Segment & write

In [None]:
# Segment every row against the canon (spell-check + fuzzy matching on) and write Parquet.
pass3_options = dict(
    csv_path=DATA_PATH,
    out_path=OUTPUT_PATH,
    ner_col=NER_COL,
    chunksize=CHUNK_SIZE,
    use_spellcheck=True,
    fuzzy_threshold=88,
)
normalizer.transform_csv_to_parquet(**pass3_options)
print("Wrote cleaned file:", OUTPUT_PATH)

## Fast path + diagnostic: reload the saved vocab, re-run PASS 3, then inspect the top "brown sugar ..." trigrams


In [None]:
# Fast path: rebuild a normalizer from the vocabulary saved in PASS 2 (no re-counting)
# and re-run PASS 3 with it. The trigram thresholds are only used by build_vocab(),
# so load_vocab()'s defaults are fine here. load_vocab() restores the canon but NOT
# the n-gram counters, so the result gets its own name instead of overwriting the
# PASS 1 `normalizer`, whose counts the diagnostic below reads.
fast_normalizer = StatsNormalizer.load_vocab(VOCAB_JSON)
fast_normalizer.transform_csv_to_parquet(
    csv_path=DATA_PATH,
    out_path=OUTPUT_PATH,
    ner_col=NER_COL,
    chunksize=CHUNK_SIZE,
    use_spellcheck=True,
    fuzzy_threshold=88
)

# Diagnostic: how the bigram "brown sugar" branches into trigrams.
head = ("brown", "sugar")
if "normalizer" not in globals() or not normalizer.c3:
    print("No n-gram counts in memory -- run PASS 1 before this diagnostic.")
else:
    H = normalizer._right_entropy(head)
    tot = normalizer.c2[head]
    # _followers[head][c] is incremented together with c3[(brown, sugar, c)],
    # so it gives the same counts without scanning every trigram.
    followers = normalizer._followers.get(head)
    print("H_right(brown sugar) =", H, "total bigram count:", tot)
    for c, cnt in (followers.most_common(20) if followers else []):
        s = f"{head[0]} {head[1]} {c}"
        share = cnt / max(1, tot)
        print(f"{s:<30} {cnt:>6}  share={share:.3f}")


## Quick check


In [None]:
# Read only the two columns being compared: the output file can be large, and
# skipping the other columns avoids decoding them just to show 10 rows.
df_check = pd.read_parquet(OUTPUT_PATH, columns=[NER_COL, "NER_clean"]).head(10)
df_check