In [1]:
!pip install -e ./alchemy-main


Obtaining file:///C:/Users/5yx/OneDrive%20-%20Oak%20Ridge%20National%20Laboratory/alchemy-main
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Checking if build backend supports build_editable: started
  Checking if build backend supports build_editable: finished with status 'done'
  Getting requirements to build editable: started
  Getting requirements to build editable: finished with status 'done'
  Preparing editable metadata (pyproject.toml): started
  Preparing editable metadata (pyproject.toml): finished with status 'done'
Building wheels for collected packages: alchemy
  Building editable for alchemy (pyproject.toml): started
  Building editable for alchemy (pyproject.toml): finished with status 'done'
  Created wheel for alchemy: filename=alchemy-0.1.0-0.editable-py3-none-any.whl size=3231 sha256=090de77990c9cb05ed751231c1c96a523a7b9fccedfb51b4e8f68227441d6a9f
  Stored in directory: C:\Users\5yx\AppData\Local\Temp\pip-ephe

In [2]:
!pip install torch



In [3]:
from skimage.color import lab2rgb

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.quasirandom import SobolEngine
from color_picker_optimization.metrics import calculate_error, rgb2cmyk
from color_picker_optimization.plate_image_simulation import simulate_mixing, simulate_plate_image
from alchemy import step_surrogate

 
# === User Configurable Parameters ===
# Edit values only in this section (down to "End User Configs").
# Entries tagged "# change" are the ones typically varied between runs.
lab_targets_path = "uniform_lab_from_rgb.npy"  # path to your Lab-space targets, shape (N, 3)
save_filename    = "results_RF_EI_4.npz"       # output file

total_vol        = 275     # µL per well
max_wells        = 96      # total number of wells to sample per target
batch_size       = 4       # number of samples per batch  # change

sobol_seed       = 42      # seed for reproducible Sobol initial batch
sobol_scramble   = True    # whether to scramble Sobol sequence

# Surrogate & acquisition settings
surrogate = {  # change
    "name": "rf",
    "fit_kwargs": {},
}
acquisition = {
    "name": "EI",  # change
    "kwargs": {"q": batch_size},  # proposals requested per batch
    "discrete": True,
}

# Candidate grid resolution (per dimension)
num_points = 20  # how many grid points along each axis

# Fail fast on an invalid edit instead of partway through the experiment.
assert total_vol > 0, "total_vol must be positive"
assert max_wells >= 1, "max_wells must be at least 1"
assert batch_size >= 1, "batch_size must be at least 1"
assert num_points >= 2, "num_points must be at least 2 to span [0, total_vol]"

# === End User Configs ===
 
def lab_to_cmyk(lab: np.ndarray) -> np.ndarray:
    """Convert a single Lab color → CMYK (4,).

    Parameters
    ----------
    lab : np.ndarray
        One Lab color; anything with exactly 3 elements (list, tuple, or array).

    Returns
    -------
    np.ndarray
        The CMYK value of the single pixel as produced by ``rgb2cmyk``
        (assumed shape (4,) — confirm against ``rgb2cmyk``).
    """
    # lab2rgb operates on images, so wrap the color as a 1x1 image.
    rgb01 = lab2rgb(np.asarray(lab, dtype=float).reshape(1, 1, 3))
    # Round to the nearest 8-bit level instead of truncating (plain astype floors,
    # biasing every channel down); clip guards the uint8 cast against out-of-range values.
    rgb_uint8 = np.clip(np.rint(rgb01 * 255), 0, 255).astype(np.uint8)
    cmyk = rgb2cmyk(rgb_uint8)
    return cmyk[0, 0]
 
# 1) Load & convert targets
# atleast_2d lets a file holding a single (3,) Lab color be treated as one target.
lab_targets = np.atleast_2d(np.load(lab_targets_path))  # shape (N, 3)
if lab_targets.ndim != 2 or lab_targets.shape[1] != 3:
    raise ValueError(f"Expected Lab targets of shape (N, 3), got {lab_targets.shape}")
cmyk_targets = np.array([lab_to_cmyk(lab) for lab in lab_targets])
 
# 2) Prepare Sobol for initial batch
# A single engine is shared by all targets, so each target's initial batch
# continues the same Sobol sequence.
sobol = SobolEngine(dimension=4, scramble=sobol_scramble, seed=sobol_seed)

# 3) Prepare bounds & X_test grid in µL
bounds_frac = torch.tensor([[0., 0., 0., 0.], [1., 1., 1., 1.]])
bounds      = bounds_frac * total_vol  # [[0,0,0,0],[total_vol,...]]

if num_points < 2:
    raise ValueError(f"num_points must be >= 2 to span [0, total_vol], got {num_points}")

# Uniform grid from 0 to total_vol along each of the 4 ink axes (C, M, Y, K).
vs = [
    torch.linspace(bounds[0, i], bounds[1, i], steps=num_points)
    for i in range(4)
]
C, M, Y, K = torch.meshgrid(*vs, indexing="ij")
X_all = torch.stack([C.reshape(-1), M.reshape(-1), Y.reshape(-1), K.reshape(-1)], dim=1)

# Keep only mixtures on the simplex C+M+Y+K == total_vol.
# Grid value k along an axis equals total_vol * k / (num_points - 1), so a point sums
# to total_vol exactly when its integer indices sum to num_points - 1. Testing the
# indices avoids float round-off. It also stays correct for a non-integer total_vol,
# where comparing a rounded float sum would select nothing.
grid_idx = torch.arange(num_points)
iC, iM, iY, iK = torch.meshgrid(grid_idx, grid_idx, grid_idx, grid_idx, indexing="ij")
mask   = (iC + iM + iY + iK).reshape(-1) == (num_points - 1)
X_test = X_all[mask]  # candidate set in µL

print(f"Total grid points before mask: {X_all.shape[0]}")
print(f"Grid points on simplex:         {X_test.shape[0]}")
assert X_test.shape[0] > 0, "candidate grid on the simplex is empty"
 
all_results = []

# 4) Active learning loop
# For each CMYK target:
#   1. seed the surrogate with a Sobol batch;
#   2. repeatedly ask step_surrogate for new mixtures (minimizing ΔE) until max_wells
#      wells have been simulated.
# Per-target results (volumes, errors, running best) are collected in all_results.
for idx, tgt in enumerate(cmyk_targets):
    print(f"\n=== Target #{idx} ===")
    X_vols       = []   # every sampled mixture, in µL
    Y_errs       = []   # ΔE of each sampled mixture
    best_history = []   # running best ΔE after the initial batch and after each main batch

    # Initial Sobol batch (never more wells than max_wells)
    n_init    = min(batch_size, max_wells)
    sobol_pts = sobol.draw(n_init).numpy()                              # fractions
    row_sums  = sobol_pts.sum(axis=1, keepdims=True)
    # An all-zero row cannot be normalized; the first point of an unscrambled Sobol
    # sequence is exactly that. Such rows fall back to an equal split instead of NaN.
    mixtures  = np.divide(
        sobol_pts,
        row_sums,
        out=np.full_like(sobol_pts, 1.0 / sobol_pts.shape[1]),
        where=row_sums > 0,
    )
    vols      = (mixtures * total_vol).tolist()                        # µL
    obs       = simulate_mixing(vols)
    dEs       = calculate_error(obs, tgt)

    X_vols.extend(vols)
    Y_errs.extend([float(d) for d in dEs])
    best_dE = float(np.min(dEs))
    best_history.append(best_dE)
    taken = len(vols)

    print(f"Init vols: {vols}")
    print(f"Init ΔE:   {dEs}, best: {best_dE}")

    # Main batches
    while taken < max_wells:
        bs = min(batch_size, max_wells - taken)  # final batch may be smaller

        # Ask for exactly `bs` proposals so the last batch does not overshoot max_wells.
        step_acquisition = dict(acquisition)
        if "q" in acquisition.get("kwargs", {}):
            step_acquisition["kwargs"] = {**acquisition["kwargs"], "q": bs}

        X_t = torch.tensor(X_vols, dtype=torch.float32)
        Y_t = torch.tensor(Y_errs, dtype=torch.float32)
        print(f"\nTaken {taken}/{max_wells} → X_t.shape={X_t.shape}, Y_t.shape={Y_t.shape}")

        cand, mean, var, acq_vals = step_surrogate(
            X=X_t,
            Y=Y_t,
            X_test=X_test,
            bounds=bounds,
            surrogate=surrogate,
            acquisition=step_acquisition,
            return_acq_vals=False,
            maximize=False
        )
        # One row per proposal; keep at most `bs` in case more were returned.
        proposals = np.atleast_2d(cand.detach().cpu().numpy())[:bs]  # µL

        print(f"Proposed vols:\n{proposals}")
        if len(proposals) == 0:
            # Without this, `taken` would never advance and the loop would spin forever.
            raise RuntimeError(f"step_surrogate returned no candidates for target #{idx}")

        vols       = proposals.tolist()
        obs_colors = simulate_mixing(vols)
        dEs        = calculate_error(obs_colors, tgt)

        X_vols.extend(vols)
        Y_errs.extend([float(d) for d in dEs])
        print(f"Observed ΔE: {dEs}")

        best_dE = min(best_dE, float(np.min(dEs)))
        best_history.append(best_dE)
        print(f"Running best ΔE: {best_dE}")

        taken += len(vols)

        # Show the plate with every well sampled so far for this target.
        # Close each figure so repeated batches do not accumulate open figures.
        fig, ax = plt.subplots()
        ax.imshow(simulate_plate_image(X_vols))
        ax.axis('off')
        ax.set_title(f"Target #{idx} – after {taken} wells")
        plt.show()
        plt.close(fig)

    X_arr    = np.vstack(X_vols)
    Y_arr    = np.array(Y_errs)
    best_arr = np.array(best_history)
    all_results.append((X_arr, Y_arr, best_arr))

    print(f"Done target #{idx}: X_arr={X_arr.shape}, Y_arr={Y_arr.shape}, best_arr={best_arr.shape}")
 
# 5) Save results
# all_results holds one (X_arr, Y_arr, best_arr) tuple per target. Transpose it
# into three per-field sequences, then stack each along a new leading target axis.
X_parts, Y_parts, best_parts = zip(*all_results)

Xs    = np.stack(X_parts, axis=0)     # (n_targets, wells per target, 4)
Ys    = np.stack(Y_parts, axis=0)     # (n_targets, wells per target)
Bests = np.stack(best_parts, axis=0)  # (n_targets, entries in best_history)

np.savez(save_filename, Xs=Xs, Ys=Ys, bests=Bests, max_wells=max_wells, batch_size=batch_size)

print(f"\nSaved all_results to {save_filename}")
