# The Ultimate Guide to ml-assert
This notebook provides a step-by-step, deeply annotated exploration of every feature in the **ml-assert** library.
We will cover:
1. DataFrameAssertion (DFA) DSL
2. Low-level distribution tests (KS, Chi-square, Wasserstein)
3. High-level drift detection
4. Model performance assertions
5. Plugin system (file_exists, dvc_check)
6. CLI execution and reporting

## What is DataFrameAssertion (DFA)?
DataFrameAssertion provides a fluent, chainable API for asserting properties of pandas DataFrames,
such as schema compliance (`schema`), absence of nulls (`no_nulls`), uniqueness (`unique`),
value ranges (`in_range`), and membership in an allowed set (`values_in_set`).

## What are Distribution Tests?
ml-assert offers low-level statistical tests:
- **Kolmogorov–Smirnov (KS)** for comparing numeric distributions
- **Chi-squared** for comparing categorical distributions
- **Wasserstein distance** for measuring distributional shift magnitude

## What is Drift Detection?
High-level drift detection (`assert_no_drift`) combines KS tests for numeric columns
and Chi-squared tests for categorical columns. In our example, the first drift check is **expected to fail**
because we deliberately shift the numeric mean and alter category proportions between the two datasets.

## What are Model Performance Assertions?
The `assert_model` DSL lets you chain checks on classification metrics:
accuracy, precision, recall, F1, and ROC AUC. The chain fails immediately if any metric is below its threshold.

## What is the Plugin System?
Plugins extend ml-assert via Python entry points. Built-in plugins include:
- `file_exists`: verifies a file exists at a given path
- `dvc_check`: ensures DVC-tracked files remain in sync with their metadata
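
Entry-point discovery of this kind can be sketched with the standard library alone. The group name `ml_assert.plugins` is an assumption for illustration, since this guide does not state which group ml-assert registers under, and `discover_plugins` is a hypothetical helper rather than part of the library.

```python
from importlib.metadata import entry_points


def discover_plugins(group):
    """Return {name: EntryPoint} for every entry point registered under `group`."""
    eps = entry_points()
    if hasattr(eps, "select"):
        # Python 3.10+: entry points can be selected by group
        selected = eps.select(group=group)
    else:
        # Python 3.8/3.9: entry_points() returns a dict of group -> entry points
        selected = eps.get(group, [])
    return {ep.name: ep for ep in selected}


# "ml_assert.plugins" is an assumed group name, not confirmed by the library
plugins = discover_plugins("ml_assert.plugins")
print(sorted(plugins))
```

Calling `.load()` on an entry point imports the object it refers to, which is how a host application would typically obtain the plugin callable.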

## CLI Usage
The `ml_assert run config.yaml` command executes a series of assertions defined
in a YAML file, generating both a machine-readable JSON report and a
human-friendly HTML report.

## 1. DataFrameAssertion (DFA)
The **DataFrameAssertion** DSL lets you chain checks about a pandas DataFrame:
- Schema (columns + dtypes)
- No nulls
- Uniqueness
- Value ranges
- Values in a specific set

If any check fails, an `AssertionError` is raised immediately, stopping execution.
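
For orientation, the same five kinds of checks can be written in plain pandas. This is a sketch of what the assertions verify, not ml-assert's implementation; the small DataFrame and the `checks` dictionary are made up for illustration.

```python
import pandas as pd

df = pd.DataFrame(
    {
        "user_id": [100, 101, 102],
        "age": [25, 30, 22],
        "plan_type": ["basic", "premium", "free"],
    }
).astype({"user_id": "int64", "age": "int64"})

# One boolean per check, keyed by a readable description
checks = {
    "user_id is int64": str(df["user_id"].dtype) == "int64",
    "user_id is unique": bool(df["user_id"].is_unique),
    "age within [18, 70]": bool(df["age"].between(18, 70).all()),
    "no nulls": not df.isna().any().any(),
    "plan_type in allowed set": bool(
        df["plan_type"].isin(["basic", "premium", "free"]).all()
    ),
}

# Collect the failing checks and raise if there are any
failed = [name for name, ok in checks.items() if not ok]
assert not failed, f"Failed checks: {failed}"
print("All checks passed.")
```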

In [1]:
import os
from pathlib import Path

import numpy as np
import pandas as pd

In [2]:
from ml_assert import Assertion, schema

In [3]:
# Set up a directory for the artifacts this guide creates
artifact_dir = Path("ultimate_guide_artifacts")
artifact_dir.mkdir(exist_ok=True)

### 1.1. Full DataFrame Validation (Expect Failure)
We create a DataFrame with a column `empty_col` full of nulls.
We then run a chain of DFA checks, including `no_nulls()` on **all** columns.
Because `empty_col` contains 10 nulls, this should **fail** immediately.

In [4]:
# Create sample DataFrame
data = {
    "user_id": list(range(100, 110)),
    "age": [25, 30, 22, 45, 30, 50, 60, 22, 33, 41],
    "city": [
        "New York",
        "London",
        "Paris",
        "Tokyo",
        "London",
        "New York",
        "Sydney",
        "Paris",
        "London",
        "Tokyo",
    ],
    "plan_type": [
        "basic",
        "premium",
        "basic",
        "premium",
        "premium",
        "basic",
        "free",
        "free",
        "premium",
        "basic",
    ],
    "monthly_spend": [50, 100, 55, 110, 105, 45, 0, 0, 120, 60],
    "empty_col": [np.nan] * 10,
}
df = pd.DataFrame(data)
print("Sample DataFrame:")
print(df)

Sample DataFrame:
   user_id  age      city plan_type  monthly_spend  empty_col
0      100   25  New York     basic             50        NaN
1      101   30    London   premium            100        NaN
2      102   22     Paris     basic             55        NaN
3      103   45     Tokyo   premium            110        NaN
4      104   30    London   premium            105        NaN
5      105   50  New York     basic             45        NaN
6      106   60    Sydney      free              0        NaN
7      107   22     Paris      free              0        NaN
8      108   33    London   premium            120        NaN
9      109   41     Tokyo     basic             60        NaN


In [5]:
print("\nRunning full DFA validation (this will fail):")
try:
    s = schema()
    s.col("user_id").is_type("int64").is_unique()
    s.col("age").is_type("int64").in_range(18, 70)
    s.col("city").is_type("object")

    Assertion(df).satisfies(s).no_nulls().validate()
    print("ERROR: DFA did not fail when it should have.")
except AssertionError as e:
    print(f"As expected, DFA failed: {e}")


Running full DFA validation (this will fail):
As expected, DFA failed: Column empty_col contains 10 null values



**Explanation:** The `no_nulls()` step checks **all** columns. Since `empty_col` has 10 null values, it raises an error listing that column and the count.

### 1.2. Partial Column Validation (Expect Success)
We can pass a list to `no_nulls()` to restrict the check to specific columns.
Here we omit `empty_col`, so the validation should **pass**.

In [6]:
print("Running partial DFA validation (should succeed):")
try:
    Assertion(df).no_nulls(
        ["user_id", "age", "city", "plan_type", "monthly_spend"]
    ).validate()
    print("Partial DFA validation passed.")
except AssertionError as e:
    print(f"ERROR: Partial DFA validation failed: {e}")

Running partial DFA validation (should succeed):
Partial DFA validation passed.


**Explanation:** By specifying only the non-null columns, we bypass the failing `empty_col` check and the chain completes successfully.
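
The difference between the full and the partial check comes down to which columns are inspected for nulls. A minimal pandas sketch, using a made-up three-row DataFrame rather than the one above, and not ml-assert's own code:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"user_id": [100, 101, 102], "empty_col": [np.nan] * 3})

# Full check: count nulls in every column
null_counts = df.isna().sum()
print(null_counts[null_counts > 0])  # only empty_col is reported

# Partial check: restrict the count to the chosen columns
subset_counts = df[["user_id"]].isna().sum()
assert subset_counts.sum() == 0
```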

## 2. Low-level Distribution Tests
ml-assert exposes individual distribution test functions for fine-grained control:
- `assert_ks_test(sample1, sample2, alpha)`
- `assert_chi2_test(observed, expected, alpha)`
- `assert_wasserstein_distance(sample1, sample2, max_distance)`

Each raises an `AssertionError` when the test condition is not met.

In [7]:
from ml_assert.stats.distribution import (
    assert_chi2_test,
    assert_ks_test,
    assert_wasserstein_distance,
)

### 2.1. KS Test
- **Pass case:** identical samples → no error
- **Fail case:** sample shifted by +10 → p-value < alpha → error

In [8]:
arr1 = np.array([1, 2, 3, 4, 5])
arr2 = arr1.copy()
print("KS test pass (identical):")
try:
    assert_ks_test(arr1, arr2, alpha=0.05)
    print("  Passed.")
except AssertionError as e:
    print(f"  ERROR: {e}")

KS test pass (identical):
  Passed.


In [9]:
arr3 = arr1 + 10
print("KS test fail (shifted):")
try:
    assert_ks_test(arr1, arr3, alpha=0.05)
    print("  ERROR: Should have failed.")
except AssertionError as e:
    print(f"  As expected: {e}")

KS test fail (shifted):
  As expected: KS test failed (statistic=1.0000, p-value=0.0079 < 0.05)
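
The recorded statistic and p-value match what SciPy's two-sample KS test reports for these arrays. Whether ml-assert calls SciPy internally is an assumption; the sketch only reproduces the numbers. With two samples of five values that do not overlap, the statistic is 1 and the exact two-sided p-value is 2 / C(10, 5) = 2 / 252 ≈ 0.0079.

```python
import numpy as np
from scipy.stats import ks_2samp

arr1 = np.array([1, 2, 3, 4, 5])
arr3 = arr1 + 10  # every value lies above every value in arr1

result = ks_2samp(arr1, arr3)
print(f"statistic={result.statistic:.4f}, p-value={result.pvalue:.4f}")
# statistic=1.0000, p-value=0.0079
```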


### 2.2. Chi-square Test
- **Pass case:** observed == expected → no error
- **Fail case:** reversed counts → p-value < alpha → error

In [10]:
obs = np.array([10, 20, 30])
exp = obs.copy()
print("Chi-square pass (same counts):")
try:
    assert_chi2_test(obs, exp, alpha=0.05)
    print("  Passed.")
except AssertionError as e:
    print(f"  ERROR: {e}")

Chi-square pass (same counts):
  Passed.


In [11]:
exp2 = np.array([30, 20, 10])
print("Chi-square fail (reversed):")
try:
    assert_chi2_test(obs, exp2, alpha=0.05)
    print("  ERROR: Should have failed.")
except AssertionError as e:
    print(f"  As expected: {e}")

Chi-square fail (reversed):
  As expected: Chi-square test failed: p-value 2.6230937696693e-12 < alpha 0.05
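
The p-value above can be checked by hand: the statistic is (10 - 30)²/30 + (20 - 20)²/20 + (30 - 10)²/10 = 13.33 + 0 + 40 = 53.33 with 2 degrees of freedom. A sketch with SciPy's `chisquare` gives the same figure; that ml-assert uses this function is an assumption.

```python
import numpy as np
from scipy.stats import chisquare

obs = np.array([10, 20, 30])
exp2 = np.array([30, 20, 10])  # same total as obs, reversed counts

result = chisquare(f_obs=obs, f_exp=exp2)
print(f"statistic={result.statistic:.4f}, p-value={result.pvalue:.3e}")
# statistic=53.3333, p-value=2.623e-12
```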


### 2.3. Wasserstein Distance
- **Pass case:** identical arrays → distance=0 ≤ max_distance → no error
- **Fail case:** arrays differ by 10 units → exceeds max_distance → error

In [12]:
print("Wasserstein pass (identical):")
try:
    assert_wasserstein_distance(arr1, arr2, max_distance=0.0)
    print("  Passed.")
except AssertionError as e:
    print(f"  ERROR: {e}")

Wasserstein pass (identical):
  Passed.


In [13]:
print("Wasserstein fail (distance >1):")
try:
    assert_wasserstein_distance(arr1, arr3, max_distance=1.0)
    print("  ERROR: Should have failed.")
except AssertionError as e:
    print(f"  As expected: {e}")

Wasserstein fail (distance >1):
  As expected: Wasserstein distance 10.0000 exceeds max 1.0000
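
For two equal-sized samples, the Wasserstein distance is the mean absolute difference between their sorted values, so shifting every value by 10 gives a distance of exactly 10. A sketch with SciPy's `wasserstein_distance`, assumed here to agree with what ml-assert computes:

```python
import numpy as np
from scipy.stats import wasserstein_distance

arr1 = np.array([1, 2, 3, 4, 5])
arr3 = arr1 + 10

distance = wasserstein_distance(arr1, arr3)
print(f"{distance:.4f}")  # 10.0000

# Same result from the sorted-values definition
manual = np.mean(np.abs(np.sort(arr1) - np.sort(arr3)))
print(f"{manual:.4f}")  # 10.0000
```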


## 3. High-level Drift Detection
The `assert_no_drift(df1, df2, alpha)` function runs KS tests on numeric columns and Chi-square tests on categorical columns.
It stops on the first failing column.

**Failing example:** We deliberately shift the numeric mean from ~20 to ~30 and change city distribution
(NY:50%→20%, SF:20%→50%), so drift is correctly detected in the first assertion.

In [14]:
import pandas as pd

from ml_assert.stats.drift import assert_no_drift

In [15]:
# Reference: N(20, 5); cities weighted NY 50%, LA 30%, SF 20%
df_ref = pd.DataFrame(
    {
        "temperature": np.random.normal(20, 5, 500),
        "city": np.random.choice(["NY", "LA", "SF"], 500, p=[0.5, 0.3, 0.2]),
    }
)
# Drift: mean shifted +10, city distribution changed
df_cur = pd.DataFrame(
    {
        "temperature": np.random.normal(30, 5, 500),
        "city": np.random.choice(["NY", "LA", "SF"], 500, p=[0.2, 0.3, 0.5]),
    }
)

In [16]:
print("Drift case (expect failure):")
try:
    assert_no_drift(df_ref, df_cur, alpha=0.05)
    print("  ERROR: Drift not detected.")
except AssertionError as e:
    print(f"  As expected: {e}")

Drift case (expect failure):
  As expected: KS test failed for series: p-value 0.0000 < alpha 0.05


In [17]:
# No drift: identical data
print("No-drift case (expect success):")
df_cur2 = df_ref.copy()
try:
    assert_no_drift(df_ref, df_cur2, alpha=0.05)
    print("  Passed.")
except AssertionError as e:
    print(f"  ERROR: False positive: {e}")

No-drift case (expect success):
  Passed.
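
The behaviour described in this section (KS for numeric columns, chi-square for categorical ones, stopping at the first failure) can be sketched as follows. `find_first_drift` is a hypothetical helper written for this guide, not the body of `assert_no_drift`, and the choice of `chi2_contingency` on a table of category counts is an assumption about how the categorical comparison might be done.

```python
import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency, ks_2samp


def find_first_drift(ref, cur, alpha=0.05):
    """Return (column, p_value) for the first drifting column, else None."""
    for col in ref.columns:
        if pd.api.types.is_numeric_dtype(ref[col]):
            # Numeric column: two-sample KS test
            p_value = ks_2samp(ref[col], cur[col]).pvalue
        else:
            # Categorical column: chi-square test on a 2 x k table of counts
            counts = pd.concat(
                [ref[col].value_counts(), cur[col].value_counts()], axis=1
            ).fillna(0)
            p_value = chi2_contingency(counts.T.to_numpy())[1]
        if p_value < alpha:
            return col, p_value
    return None


rng = np.random.default_rng(0)  # seeded so the sketch is repeatable
ref = pd.DataFrame({
    "temperature": rng.normal(20, 5, 500),
    "city": rng.choice(["NY", "LA", "SF"], 500, p=[0.5, 0.3, 0.2]),
})
cur = pd.DataFrame({
    "temperature": rng.normal(30, 5, 500),
    "city": rng.choice(["NY", "LA", "SF"], 500, p=[0.2, 0.3, 0.5]),
})

print(find_first_drift(ref, cur))         # flags "temperature" first
print(find_first_drift(ref, ref.copy()))  # None: identical data
```

Because the loop returns at the first failing column, the changed city proportions are never reported in the drift case; testing the `city` column on its own flags it as well.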


## 4. Model Performance Assertions
The `assert_model(y_true, y_pred, y_scores)` DSL lets you chain checks on accuracy, precision, recall, F1, and ROC AUC.
It raises on the first metric below its threshold.

In [18]:
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from ml_assert import assert_model

In [19]:
# Prepare Titanic data
titanic = sns.load_dataset("titanic")
titanic["age"] = titanic["age"].fillna(titanic["age"].median())
titanic.drop(["deck", "embark_town", "alive"], axis=1, inplace=True)
titanic = pd.get_dummies(
    titanic,
    columns=["sex", "class", "who", "adult_male", "alone", "embarked"],
    drop_first=True,
)
X = titanic.drop("survived", axis=1)
y = titanic["survived"]
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)
model = LogisticRegression(max_iter=1000, random_state=42).fit(X_train, y_train)
y_pred = model.predict(X_test)
y_scores = model.predict_proba(X_test)[:, 1]

In [20]:
print("Running model assertions (thresholds should be met):")
try:
    # Create model assertion
    model_assertion = assert_model(y_test, y_pred)

    # Chain metric assertions with their thresholds
    model_assertion.accuracy(min_score=0.75).precision(min_score=0.65).recall(
        min_score=0.60
    ).f1(min_score=0.70).validate()
    print("  All metrics passed.")
except AssertionError as e:
    print(f"  ERROR: {e}")

Running model assertions (thresholds should be met):
  All metrics passed.


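**Explanation:** If any metric fell below its threshold, an `AssertionError` would be raised at that metric. This chain asserts accuracy, precision, recall, and F1 only; `y_scores` is computed above but not passed to `assert_model` here, so ROC AUC is not checked in this cell.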
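
The threshold logic can be illustrated with scikit-learn's metric functions directly. This is a sketch on a made-up eight-sample example, not the Titanic model above and not ml-assert's implementation; the `roc_auc` threshold of 0.90 is chosen only for the illustration.

```python
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    roc_auc_score,
)

# Toy labels: 3 true positives, 3 true negatives, 1 false positive, 1 false negative
y_true = [0, 0, 1, 1, 1, 0, 1, 0]
y_pred = [0, 0, 1, 1, 0, 0, 1, 1]
y_scores = [0.1, 0.2, 0.9, 0.8, 0.4, 0.3, 0.7, 0.6]

metrics = {
    "accuracy": accuracy_score(y_true, y_pred),
    "precision": precision_score(y_true, y_pred),
    "recall": recall_score(y_true, y_pred),
    "f1": f1_score(y_true, y_pred),
    "roc_auc": roc_auc_score(y_true, y_scores),  # needs scores, not labels
}
thresholds = {
    "accuracy": 0.75,
    "precision": 0.65,
    "recall": 0.60,
    "f1": 0.70,
    "roc_auc": 0.90,
}

# Fail on the first metric that is below its minimum
for name, minimum in thresholds.items():
    assert metrics[name] >= minimum, f"{name} {metrics[name]:.4f} < {minimum}"
print("All metrics passed.")
```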

## 5. Plugin System
We ship two plugins by default:
1. **file_exists**: check that a file exists.
2. **dvc_check**: check that a DVC-tracked file is in sync.

In [21]:
# file_exists demo: create an empty placeholder file for the plugin to find (checked in the CLI run in section 6)
file_exists_path = artifact_dir / "my_model.pkl"
file_exists_path.touch()
print(f"file_exists plugin: Created {file_exists_path}")

file_exists plugin: Created ultimate_guide_artifacts/my_model.pkl


In [22]:
# dvc_check demo
import subprocess

dvc_data = artifact_dir / "model_data.csv"
dvc_data.write_text("a,b\n1,2")
if not (artifact_dir / ".dvc").exists():
    subprocess.run(["dvc", "init", "--no-scm"], cwd=artifact_dir, check=True)
    subprocess.run(["dvc", "add", dvc_data.name], cwd=artifact_dir, check=True)
    print(f"dvc_check plugin: DVC setup and added {dvc_data.name}")
else:
    print("dvc_check plugin: DVC already initialized")

Initialized DVC repository.

+---------------------------------------------------------------------+
|                                                                     |
|        DVC has enabled anonymous aggregate usage analytics.         |
|     Read the analytics documentation (and how to opt-out) here:     |
|             <https://dvc.org/doc/user-guide/analytics>              |
|                                                                     |
+---------------------------------------------------------------------+

What's next?
------------
- Check out the documentation: <https://dvc.org/doc>
- Get help and share ideas: <https://dvc.org/chat>
- Star us on GitHub: <https://github.com/iterative/dvc>
dvc_check plugin: DVC setup and added model_data.csv

## 6. CLI: End-to-End Run
We can run all steps via a YAML config and the `ml_assert run` command.
This produces a JSON and HTML report.

In [23]:
# CLI: Save artifacts and run inside artifacts directory
import yaml

# Change into artifact_dir so CSVs and config go there
old_cwd = os.getcwd()
os.chdir(artifact_dir)
# Write CSV artifacts for CLI
df_ref.to_csv("ref.csv", index=False)
df_cur.to_csv("cur.csv", index=False)
pd.Series(y_test).to_csv("y_true.csv", index=False, header=False)
pd.Series(y_pred).to_csv("y_pred.csv", index=False, header=False)
pd.Series(y_scores).to_csv("y_scores.csv", index=False, header=False)
# Create config.yaml referencing local files
config = {
    "steps": [
        {"type": "drift", "train": "ref.csv", "test": "cur.csv", "alpha": 0.05},
        {
            "type": "model_performance",
            "y_true": "y_true.csv",
            "y_pred": "y_pred.csv",
            "y_scores": "y_scores.csv",
            "assertions": {"accuracy": 0.75},
        },
        {"type": "file_exists", "path": "my_model.pkl"},
        {"type": "dvc_check", "path": "model_data.csv"},
    ]
}
with open("config.yaml", "w") as f:
    yaml.dump(config, f)
print("Wrote config.yaml and artifact CSVs in:", os.getcwd())

Wrote config.yaml and artifact CSVs in: /Users/shinde/Documents/Projects/ml-assert/examples/ultimate_guide_artifacts


In [24]:
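# Run the CLI through Poetry so it uses the project's virtual environment
print("Running: ml_assert run config.yaml")
# On POSIX, os.system returns the raw wait status: an exit status of 1 appears as 256
exit_code = os.system("poetry run ml_assert run config.yaml")
print(f"CLI exit code: {exit_code}")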

Running: ml_assert run config.yaml
CLI exit code: 256


Command not found: ml_assert
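
The value 256 is the raw wait status that `os.system` returns on POSIX, which keeps the exit status (here 1) in the high byte. A minimal sketch of reading the exit status directly with `subprocess.run` follows; the child command is a stand-in Python one-liner that exits with status 1, not the ml-assert CLI.

```python
import subprocess
import sys

# Stand-in child process that exits with status 1 (not the real CLI)
result = subprocess.run([sys.executable, "-c", "import sys; sys.exit(1)"])
print(f"exit status: {result.returncode}")  # exit status: 1

# Decoding the POSIX wait status reported by os.system
wait_status = 256
print(wait_status >> 8)  # 1
```

A non-zero exit status is what a CI job would normally use to decide that an assertion run did not succeed.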


In [25]:
# Return to original directory
os.chdir(old_cwd)

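In the run recorded above, `poetry run` could not find the `ml_assert` command, so no reports were produced. Once the command resolves in your environment, **inspect the generated reports**:
- `ultimate_guide_artifacts/config.report.json`
- `ultimate_guide_artifacts/config.report.html`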

## Conclusion
You now have a detailed, cell-by-cell guide illustrating exactly how ml-assert works,
with both passing and failing examples, and end-to-end automation via the CLI.