# Inference Examples

This notebook demonstrates two inference workflows:

1. **MAE-based imputation**: use the pretrained masked autoencoder (MAE) to fill in missing brine-chemistry features.
2. **Label prediction**: use the MAE encoder plus a regression head to predict experimental targets.


In [2]:
from __future__ import annotations

from pathlib import Path

import numpy as np
import pandas as pd

from src.constants import BRINE_FEATURE_COLUMNS
from src.models.inference import (
    auto_device,
    load_artifacts,
    mae_impute_brine_features,
    predict_labels,
)

PROJECT_ROOT = Path.cwd().parent if Path.cwd().name == "notebooks" else Path.cwd()
PROCESSED_DIR = PROJECT_ROOT / "data" / "processed"
MODELS_DIR = PROJECT_ROOT / "models"

MAE_CKPT = MODELS_DIR / "mae_pretrained.pth"
HEAD_CKPT = MODELS_DIR / "downstream_head.pth"
SCALER = PROCESSED_DIR / "feature_scaler.joblib"

device = auto_device()
device

device(type='mps')

In [3]:
assert PROCESSED_DIR.exists(), f"Missing: {PROCESSED_DIR}"
assert (PROCESSED_DIR / "brines.csv").exists(), "Run make_dataset first"
assert (PROCESSED_DIR / "X_lake.npy").exists(), "Run build_features first"
assert MAE_CKPT.exists(), "Train MAE first (src/models/train_mae.py)"
assert HEAD_CKPT.exists(), "Fine-tune head first (src/models/finetune_regression.py)"
assert SCALER.exists(), "Run build_features first"

artifacts = load_artifacts(mae_path=MAE_CKPT, head_path=HEAD_CKPT, scaler_path=SCALER, device=device)
print("Loaded MAE + head + scaler")

Loaded MAE + head + scaler

## 1) MAE imputation example

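We take the first five rows of `data/processed/brines.csv`, set a few chemistry values to `NaN` to simulate missing measurements, and impute them with the MAE. Some rows already have gaps in the source data (`TDS_gL` in rows 1 and 2, `SO4_gL` in row 3, `K_gL` in row 4), and the MAE fills those too.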
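Because the masked cells have known true values, the same masking idea can be used to score an imputer. Below is a minimal sketch of masked-cell evaluation, assuming an imputer is any function that maps an array containing `NaN`s to a filled array of the same shape. The helpers `mask_observed`, `masked_rmse`, and `column_mean_impute` are hypothetical, and the column-mean imputer is only a stand-in baseline, not the MAE.

```python
import numpy as np


def mask_observed(x: np.ndarray, frac: float, rng: np.random.Generator):
    """Hide a random fraction of the *observed* entries of x.

    Returns (masked copy of x, boolean array marking the hidden cells).
    Cells that are already NaN are never selected.
    """
    observed = ~np.isnan(x)
    hide = observed & (rng.random(x.shape) < frac)
    x_masked = x.copy()
    x_masked[hide] = np.nan
    return x_masked, hide


def masked_rmse(x_true: np.ndarray, x_imputed: np.ndarray, hide: np.ndarray) -> float:
    """RMSE over the deliberately hidden cells only (NaN if nothing was hidden)."""
    if not hide.any():
        return float("nan")
    diff = x_imputed[hide] - x_true[hide]
    return float(np.sqrt(np.mean(diff**2)))


def column_mean_impute(x: np.ndarray) -> np.ndarray:
    """Baseline imputer: replace each NaN with its column's observed mean."""
    return np.where(np.isnan(x), np.nanmean(x, axis=0), x)


rng = np.random.default_rng(0)
x = np.array([[1.0, 10.0], [2.0, np.nan], [3.0, 30.0], [4.0, 40.0]])
x_masked, hide = mask_observed(x, frac=0.25, rng=rng)
print(masked_rmse(x, column_mean_impute(x_masked), hide))
```

To score the MAE instead of the baseline, one would pass a wrapper around `mae_impute_brine_features` in place of `column_mean_impute`, and average the error over several random masks.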


In [4]:
brines = pd.read_csv(PROCESSED_DIR / "brines.csv")
brines = brines[list(BRINE_FEATURE_COLUMNS)].copy()

example = brines.head(5).copy()
# Simulate missing values
example.loc[0, "Cl_gL"] = np.nan
example.loc[1, "Mg_gL"] = np.nan
example.loc[2, ["Na_gL", "K_gL"]] = np.nan

example

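|   | Li_gL | Mg_gL | Na_gL | K_gL | Ca_gL | SO4_gL | Cl_gL | MLR | TDS_gL | Light_kW_m2 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.018 | 33.31 | 19.225 | 7.01 | 12.33 | 0.61 | NaN | 1850.555556 | 225.82 | 0.235208 |
| 1 | 0.02 | NaN | 33.0 | 8.0 | 18.6 | 0.35 | 224.0 | 2600.0 | NaN | 0.235208 |
| 2 | 0.012 | 30.9 | NaN | NaN | 12.9 | 0.61 | 161.0 | 2575.0 | NaN | 0.235208 |
| 3 | 0.0467 | 65.194 | 1.696 | 1.356 | 24.614 | NaN | 233.94 | 1396.017131 | 326.8467 | 0.251708 |
| 4 | 0.22 | 17.1 | 53.7 | NaN | 26.3 | 0.12 | 152.0 | 77.727273 | 249.44 | 0.188042 |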
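The `MLR` column appears to be the Mg/Li mass ratio: in every row above where both concentrations are present, `MLR` equals `Mg_gL / Li_gL` (e.g. 33.31 / 0.018 ≈ 1850.56 in row 0). If that relationship holds across the dataset, it gives a deterministic cross-check for imputed magnesium or lithium values. A minimal sketch, assuming that relationship; `mg_from_mlr` and `mlr_mismatch` are hypothetical helpers, and the 0.1% tolerance is an arbitrary choice.

```python
import math


def mg_from_mlr(li_gL: float, mlr: float) -> float:
    """Mg concentration (g/L) implied by the ratio, ASSUMING MLR = Mg_gL / Li_gL."""
    return mlr * li_gL


def mlr_mismatch(li_gL: float, mg_gL: float, mlr: float, rel_tol: float = 1e-3) -> bool:
    """True if an (imputed) Mg value disagrees with the MLR-implied one beyond rel_tol."""
    return not math.isclose(mg_gL, mg_from_mlr(li_gL, mlr), rel_tol=rel_tol)


# Row 0 is fully observed and consistent with the assumed ratio.
print(mlr_mismatch(li_gL=0.018, mg_gL=33.31, mlr=1850.555556))  # False
# Row 1 had Mg_gL masked; its Li_gL and MLR imply roughly 52 g/L.
print(mg_from_mlr(li_gL=0.02, mlr=2600.0))
```

Applied to the imputed table in the next cell, the MAE's value for row 1 `Mg_gL` (about 14.2) would fail this check.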


In [5]:
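imputed_raw, _imputed_std = mae_impute_brine_features(
    artifacts.mae,
    brine_raw=example.to_numpy(dtype=np.float32),  # raw units; NaN marks missing entries
    scaler=artifacts.scaler,
    preserve_observed=True,  # keep observed entries and fill only the NaNs
)
# The second return value is not used in this example.

imputed_df = pd.DataFrame(imputed_raw, columns=list(BRINE_FEATURE_COLUMNS), index=example.index)
imputed_df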

|   | Li_gL | Mg_gL | Na_gL | K_gL | Ca_gL | SO4_gL | Cl_gL | MLR | TDS_gL | Light_kW_m2 |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.018 | 33.310001 | 19.224998 | 7.01 | 12.33 | 0.61 | 273.262177 | 1850.555664 | 225.820007 | 0.235208 |
| 1 | 0.02 | 14.202797 | 33.0 | 8.0 | 18.6 | 0.35 | 224.0 | 2600.0 | 5612.61084 | 0.235208 |
| 2 | 0.012 | 30.9 | 68.727341 | 7.489736 | 12.9 | 0.61 | 161.0 | 2575.0 | 2202.933838 | 0.235208 |
| 3 | 0.0467 | 65.194008 | 1.695999 | 1.356 | 24.614 | 16.153591 | 233.940002 | 1396.01709 | 326.84671 | 0.251708 |
| 4 | 0.22 | 17.1 | 53.700001 | 3.866351 | 26.299999 | 0.12 | 152.0 | 77.727272 | 249.440002 | 0.188042 |
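With `preserve_observed=True`, entries present in the input come back unchanged up to small round-off (e.g. `19.224998` for `19.225`), while `NaN` cells take the MAE's reconstruction. That drift suggests the merge happens in standardized space, in float32, before unscaling. A minimal sketch of such a merge, assuming the MAE returns a full reconstruction in standardized units and the scaler is a per-feature z-score (like scikit-learn's `StandardScaler`); `merge_reconstruction` is hypothetical and not the actual `src.models.inference` code.

```python
import numpy as np


def merge_reconstruction(
    x_raw: np.ndarray,
    recon_std: np.ndarray,
    mean: np.ndarray,
    std: np.ndarray,
    preserve_observed: bool = True,
) -> np.ndarray:
    """Combine raw inputs with a standardized reconstruction and return raw units.

    x_raw:     (n, d) features in raw units, NaN where missing
    recon_std: (n, d) model reconstruction in standardized units
    mean, std: (d,) per-feature scaler statistics
    """
    x_std = (x_raw - mean) / std  # NaN entries stay NaN
    if preserve_observed:
        # Keep observed cells; use the reconstruction only where the input was missing.
        merged_std = np.where(np.isnan(x_std), recon_std, x_std)
    else:
        # Replace every cell with the model's reconstruction.
        merged_std = recon_std
    return merged_std * std + mean  # undo the z-scoring


x = np.array([[1.0, np.nan], [np.nan, 4.0]])
recon = np.zeros_like(x)  # a standardized value of 0 maps back to the feature mean
print(merge_reconstruction(x, recon, mean=np.array([10.0, 20.0]), std=np.array([2.0, 5.0])))
```

Merging before unscaling keeps the model's output and the observed values on the same footing; the cost is the tiny float round-off visible in the table above.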


## 2) Predict labels example

We provide two example inputs that specify only `TDS_gL`, `MLR`, and `Light_kW_m2` (in raw units) and predict:
- `Selectivity`
- `Li_Crystallization_mg_m2_h`
- `Evap_kg_m2_h`

Two modes:
- `impute_missing_chemistry=False`: encode only the fields provided in each sample; all other chemistry fields are treated as missing
- `impute_missing_chemistry=True`: first MAE-impute the full chemistry vector, then encode
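A minimal sketch of how dict-style samples could be laid out for the two modes. The column order is copied from the table headers above and assumed to match `BRINE_FEATURE_COLUMNS`; `samples_to_array`, `prepare_inputs`, and the `impute_fn` callable are hypothetical, not the internals of `predict_labels`.

```python
import numpy as np

# Column order as printed in the tables above (ASSUMED to match BRINE_FEATURE_COLUMNS).
COLUMNS = (
    "Li_gL", "Mg_gL", "Na_gL", "K_gL", "Ca_gL",
    "SO4_gL", "Cl_gL", "MLR", "TDS_gL", "Light_kW_m2",
)


def samples_to_array(samples, columns=COLUMNS):
    """Place each sample's fields in a fixed column order; absent fields become NaN."""
    unknown = set().union(*samples) - set(columns)
    if unknown:
        raise KeyError(f"Unknown feature(s): {sorted(unknown)}")
    return np.array(
        [[s.get(c, np.nan) for c in columns] for s in samples], dtype=np.float32
    )


def prepare_inputs(samples, impute_missing_chemistry, impute_fn=None):
    """Return (features, observed_mask) for a hypothetical encoder call.

    False: absent fields stay NaN and are flagged as unobserved.
    True:  impute_fn fills the NaNs first, so every filled cell counts as observed.
    """
    x = samples_to_array(samples)
    if impute_missing_chemistry:
        if impute_fn is None:
            raise ValueError("impute_fn is required when impute_missing_chemistry=True")
        x = impute_fn(x)
    return x, ~np.isnan(x)


samples = [
    {"TDS_gL": 120.0, "MLR": 10.0, "Light_kW_m2": 0.2},
    {"TDS_gL": 200.0, "MLR": 25.0, "Light_kW_m2": 0.5},
]
x, observed = prepare_inputs(samples, impute_missing_chemistry=False)
print(observed.sum(axis=1))  # [3 3]: only the three provided fields are observed
```

In the `True` mode, `impute_fn` would wrap the MAE imputation step; rejecting unknown keys early guards against a misspelled field silently becoming `NaN`.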


In [6]:
samples = [
    {"TDS_gL": 120.0, "MLR": 10.0, "Light_kW_m2": 0.2},
    {"TDS_gL": 200.0, "MLR": 25.0, "Light_kW_m2": 0.5},
]

# Mode 1: encode only the provided fields. Mode 2: MAE-impute the remaining chemistry first.
y_pred_no_impute = predict_labels(artifacts, samples=samples, impute_missing_chemistry=False)
y_pred_impute = predict_labels(artifacts, samples=samples, impute_missing_chemistry=True)

out = pd.DataFrame(samples)
# pred_* columns: no imputation; pred2_* columns: with MAE imputation
out[["pred_Selectivity", "pred_Li_Crystallization_mg_m2_h", "pred_Evap_kg_m2_h"]] = y_pred_no_impute
out[["pred2_Selectivity", "pred2_Li_Crystallization_mg_m2_h", "pred2_Evap_kg_m2_h"]] = y_pred_impute
out

|   | TDS_gL | MLR | Light_kW_m2 | pred_Selectivity | pred_Li_Crystallization_mg_m2_h | pred_Evap_kg_m2_h | pred2_Selectivity | pred2_Li_Crystallization_mg_m2_h | pred2_Evap_kg_m2_h |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 120.0 | 10.0 | 0.2 | 4.577223 | 0.191326 | 1.248876 | 7.380245 | 0.479662 | 1.241591 |
| 1 | 200.0 | 25.0 | 0.5 | 1.482707 | 0.36061 | 1.192362 | 10.598277 | 0.404069 | 1.258562 |
