In [1]:
import numpy as np
import pandas as pd
from scipy import sparse
from scipy.signal import find_peaks
from scipy.sparse.linalg import spsolve
from sklearn.preprocessing import MultiLabelBinarizer, minmax_scale
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.model_selection import train_test_split

In [2]:
# ----- Preprocessing & Helpers -----
def baseline_AsLS(y, lam=1e4, p=0.01, niter=10):
    """Estimate a smooth baseline by asymmetric least squares (AsLS).

    Repeatedly solves ``(W + lam * D D^T) b = W y``, where ``D`` is the
    second-difference operator and ``W = diag(w)``. After each solve, points
    above the current baseline get weight ``p``. All other points get
    ``1 - p``, so a small ``p`` pulls the baseline underneath peaks.

    Parameters
    ----------
    y : array_like, 1-D
        Signal to fit.
    lam : float
        Smoothness penalty; larger values give a stiffer baseline.
    p : float
        Weight given to points lying above the current baseline.
    niter : int
        Number of reweighting iterations; must be >= 1.

    Returns
    -------
    numpy.ndarray
        Baseline of the same length as ``y`` (float64).

    Raises
    ------
    ValueError
        If ``niter < 1`` (the original code failed with UnboundLocalError).
    """
    if niter < 1:
        raise ValueError("niter must be >= 1")
    y = np.asarray(y, dtype=float)
    L = y.size
    if L < 3:
        # No second differences exist, so the penalty vanishes and the
        # weighted fit reproduces the signal exactly.
        return y.copy()

    # Sparse (L x L-2) second-difference matrix, equal to np.diff(np.eye(L), 2).
    # The resulting penalty is banded, so each solve is linear in L instead of
    # the O(L^3) dense solve.
    D = sparse.diags([1.0, -2.0, 1.0], [0, -1, -2], shape=(L, L - 2))
    penalty = lam * D.dot(D.transpose())  # loop-invariant

    w = np.ones(L)
    for _ in range(niter):
        W = sparse.diags(w, 0, shape=(L, L))
        b = spsolve((W + penalty).tocsc(), w * y)
        # w_i = p if y_i > b_i, else 1 - p. Points lying exactly on the baseline
        # must keep weight 1 - p: the original ``y < b`` gave them weight 0,
        # which can leave only the singular penalty (flat/linear signals).
        w = np.where(y > b, p, 1 - p)
    return b

def preprocess(arr, lam=1e4, p=0.01, niter=10):
    """Baseline-correct, L2-normalise and rectify every spectrum (row) of ``arr``.

    For each row:
    1. Subtract the baseline returned by ``baseline_AsLS``.
    2. Divide by the Euclidean norm of the residual. Rows whose residual norm
       is 0 are left unscaled.
    3. Take the absolute value.

    Parameters
    ----------
    arr : array_like, 2-D
        One spectrum per row.
    lam, p, niter :
        Forwarded unchanged to ``baseline_AsLS``.

    Returns
    -------
    numpy.ndarray
        Same shape as ``arr``. Non-float input (e.g. integer counts) is
        promoted to float64, because the normalised values lie in [0, 1] and
        would otherwise be truncated to 0. Float input keeps its dtype.
    """
    arr = np.asarray(arr)
    if not np.issubdtype(arr.dtype, np.floating):
        arr = arr.astype(float)
    out = np.empty_like(arr)
    for i, spec in enumerate(arr):
        corr = spec - baseline_AsLS(spec, lam=lam, p=p, niter=niter)
        nrm = np.linalg.norm(corr)
        normed = corr / nrm if nrm else corr
        # NOTE(review): abs() also flips negative residuals (points below the
        # baseline) to positive values, which downstream peak picking will see.
        out[i] = np.abs(normed)
    return out

def smooth(spec, K_smooth=3):
    """Moving-average filter for a 1-D signal.

    Convolves ``spec`` with ``K_smooth`` equal taps that sum to one.
    ``mode='same'`` keeps the output length equal to the input length; samples
    near the edges are averaged together with implicit zeros.
    """
    taps = np.ones(K_smooth)
    taps /= K_smooth
    return np.convolve(spec, taps, mode='same')

def extract_wavenumber_cols(df):
    """Return the column labels of ``df`` that look like non-negative numbers.

    A label qualifies when its string form is all digits after removing at
    most one ``'.'``, e.g. ``'1000'`` or ``'1000.5'``. Labels such as
    ``'Label 1'`` are skipped.

    Labels are returned unchanged and in column order. Non-string labels (e.g.
    integer column names) are checked through ``str(label)`` instead of
    raising ``AttributeError``.
    """
    return [col for col in df.columns if str(col).replace('.', '', 1).isdigit()]

# ----- Main CaPSim + kNN Function -----
def identify_multilabel_knn(query_df, ref_df,
    crop_max=1700, lam=1e4, p=0.01, niter=10,
    K_smooth=3, N_peak=12, w_max=15,
    height=0.01, prominence=0.01,
    n_neighbors=3):
    """Multi-label identification: characteristic peaks (CaPSim) + kNN.

    Steps
    -----
    1. Wavenumber columns are detected on ``query_df``. Those below
       ``crop_max`` are read from both frames.
    2. Targets are the ``('Label 1', 'Label 2')`` pairs, multi-hot encoded by a
       ``MultiLabelBinarizer`` fitted on ``ref_df``.
    3. Query and reference spectra go through ``preprocess`` with the same
       ``lam``/``p``/``niter``.
    4. Per class, ``find_peaks`` runs on every smoothed reference spectrum of
       that class. Up to ``N_peak`` positions that are peaks most often become
       the class's characteristic peaks. Positions that are never a peak are
       not used.
    5. Each spectrum becomes the window maxima (width ``w_max``) around the
       union of characteristic peaks, min-max scaled per spectrum.
    6. A cosine-metric ``KNeighborsClassifier`` is fitted on the references
       and predicts the queries.

    Returns
    -------
    y_true : numpy.ndarray
        Multi-hot encoding of the query labels (via ``mlb.transform``).
    y_pred : numpy.ndarray
        kNN predictions for the queries.
    mlb : MultiLabelBinarizer
        Fitted encoder; ``mlb.classes_`` names the label columns.

    Raises
    ------
    ValueError
        If no characteristic peak is found for any class (no features).
    """
    wav_cols = extract_wavenumber_cols(query_df)
    wavs = np.array(wav_cols, dtype=float)
    keep_cols = [col for col, w in zip(wav_cols, wavs) if w < crop_max]

    Q_raw = query_df[keep_cols].values.astype(float)
    R_raw = ref_df[keep_cols].values.astype(float)

    # Build multilabel targets
    ref_labels = list(zip(ref_df['Label 1'], ref_df['Label 2']))
    query_labels = list(zip(query_df['Label 1'], query_df['Label 2']))
    mlb = MultiLabelBinarizer()
    y = mlb.fit_transform(ref_labels)

    # Queries and references must share preprocessing parameters, otherwise
    # their features are not comparable (queries previously ignored lam/p/niter).
    Q = preprocess(Q_raw, lam, p, niter)
    R = preprocess(R_raw, lam, p, niter)

    CPs = _characteristic_peaks(R, y, mlb.classes_, K_smooth, N_peak, height, prominence)
    global_cp = sorted({i for idxs in CPs.values() for i in idxs})
    if not global_cp:
        raise ValueError(
            "no characteristic peaks found for any class; "
            "try lowering height/prominence"
        )

    # Train kNN on reference features, then classify the queries.
    knn = KNeighborsClassifier(n_neighbors=n_neighbors, metric='cosine')
    knn.fit(_window_max_features(R, global_cp, w_max), y)
    y_pred = knn.predict(_window_max_features(Q, global_cp, w_max))
    y_true = mlb.transform(query_labels)

    return y_true, y_pred, mlb


def _characteristic_peaks(R, y, classes, K_smooth, N_peak, height, prominence):
    """Map each class name to the sorted indices of its most frequent reference peaks."""
    cps = {}
    for i, class_name in enumerate(classes):
        counts = np.zeros(R.shape[1], dtype=int)
        for s in R[y[:, i] == 1]:
            pks, _ = find_peaks(smooth(s, K_smooth), height=height, prominence=prominence)
            counts[pks] += 1
        # A stable sort makes tie-breaking deterministic across numpy builds.
        # Slicing from len - N_peak (not -N_peak) keeps N_peak=0 from selecting
        # every position.
        order = np.argsort(counts, kind='stable')
        top = order[max(len(order) - N_peak, 0):]
        # Drop positions that were never a peak; they appear when a class has
        # fewer than N_peak distinct peak positions.
        cps[class_name] = sorted(int(j) for j in top[counts[top] > 0])
    return cps


def _window_max_features(spectra, centers, w_max):
    """Per spectrum: max over a ``w_max``-wide window at each center, min-max scaled."""
    half = w_max // 2
    rows = [
        minmax_scale([np.max(s[max(0, c - half):c + half + 1]) for c in centers])
        for s in spectra
    ]
    return np.vstack(rows)


In [None]:
# Labelled reference library and the separate query-only mixtures
# (the latter is used below as the external validation set).
ref_df_full = pd.read_csv("mixtures_dataset.csv")
query_df = pd.read_csv("query_only_mixed.csv")

# 50/50 split of the reference library, stratified on the
# (Label 1, Label 2) combination; fixed seed for reproducibility.
ref_train, ref_test = train_test_split(
    ref_df_full,
    test_size=0.5,
    stratify=ref_df_full[['Label 1', 'Label 2']],
    random_state=42,
)

In [6]:
# --- Test Set ---
# Held-out half of the reference library, classified against the training half.
print("---- Test Set Evaluation ----")
y_true_test, y_pred_test, mlb_test = identify_multilabel_knn(ref_test, ref_train)
report_test = classification_report(
    y_true_test,
    y_pred_test,
    target_names=mlb_test.classes_,
    zero_division=0,
    output_dict=True,
)
print(
    pd.DataFrame(report_test)
    .T
    .reset_index()
    .rename(columns={'index': 'Class'})
)

---- Test Set Evaluation ----
                    Class  precision  recall  f1-score  support
0         1-dodecanethiol        1.0     1.0       1.0    122.0
1    6-mercapto-1-hexanol        1.0     1.0       1.0     54.0
2                 benzene        1.0     1.0       1.0     97.0
3            benzenethiol        1.0     1.0       1.0     36.0
4                    etoh        1.0     1.0       1.0     60.0
5                    meoh        1.0     1.0       1.0    121.0
6   n,n-dimethylformamide        1.0     1.0       1.0     36.0
7                pyridine        1.0     1.0       1.0     54.0
8               micro avg        1.0     1.0       1.0    580.0
9               macro avg        1.0     1.0       1.0    580.0
10           weighted avg        1.0     1.0       1.0    580.0
11            samples avg        1.0     1.0       1.0    580.0


In [7]:
# --- Validation Set ---
# External query-only mixtures, classified against the training half.
print("\n---- Validation Set Evaluation ----")
y_true_val, y_pred_val, mlb_val = identify_multilabel_knn(query_df, ref_train)
report_val = classification_report(
    y_true_val,
    y_pred_val,
    target_names=mlb_val.classes_,
    zero_division=0,
    output_dict=True,
)
print(
    pd.DataFrame(report_val)
    .T
    .reset_index()
    .rename(columns={'index': 'Class'})
)


---- Validation Set Evaluation ----
                    Class  precision  recall  f1-score  support
0         1-dodecanethiol        1.0     1.0       1.0      9.0
1    6-mercapto-1-hexanol        1.0     1.0       1.0      3.0
2                 benzene        1.0     1.0       1.0      5.0
3            benzenethiol        1.0     1.0       1.0      2.0
4                    etoh        1.0     1.0       1.0      3.0
5                    meoh        1.0     1.0       1.0      9.0
6   n,n-dimethylformamide        1.0     1.0       1.0      2.0
7                pyridine        1.0     1.0       1.0      3.0
8               micro avg        1.0     1.0       1.0     36.0
9               macro avg        1.0     1.0       1.0     36.0
10           weighted avg        1.0     1.0       1.0     36.0
11            samples avg        1.0     1.0       1.0     36.0
