# 04_quantum_vqc.ipynb — Variational Quantum Classifier (shallow, re-uploading)

# Cell 0 — perf env

In [1]:
# Normalize underlying BLAS thread counts for reproducible timing
import os
# Default every numeric thread pool to 8 threads. setdefault() leaves any value
# that is already present in the environment untouched.
for _thread_var in ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS"):
    os.environ.setdefault(_thread_var, "8")
# Kept as a bare trailing expression so the cell still echoes the effective value.
os.environ.setdefault("NUMEXPR_NUM_THREADS", "8")

'8'

# Cell 1 — imports & data (multi-dataset + journaling + PCA/scale)

In [2]:
# Load k-mer encodings; reduce dimension with PCA; scale; prepare splits (multi-dataset aware)
from pathlib import Path
import json, warnings, time, os
import numpy as np, pandas as pd
import matplotlib.pyplot as plt
import pennylane as qml
from pennylane import numpy as pnp
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support, roc_auc_score, confusion_matrix,
    balanced_accuracy_score, matthews_corrcoef, classification_report, average_precision_score
)

# Silence all warnings for the remainder of the notebook.
warnings.filterwarnings("ignore")

# Project layout, relative to the current working directory.
ROOT = Path(".")
PROCESSED = ROOT / "data/processed"
RESULTS = ROOT / "results"

# Create the output folders up front so later cells can write without checking.
for _subdir in ("metrics", "plots", "logs"):
    (RESULTS / _subdir).mkdir(parents=True, exist_ok=True)

# Seed NumPy's global RNG and the RNG exposed through pennylane.numpy.
np.random.seed(11)
pnp.random.seed(11)

# ---- Run journal (for documentation) ----
class RunJournal:
    """In-memory journal of notebook run events, exportable to Markdown and JSON.

    Every event is a dict with the keys ``ts``, ``step``, ``status`` and
    ``message`` plus whatever keyword extras were passed to :meth:`log`.
    """

    def __init__(self):
        self.events = []

    @staticmethod
    def _md_cell(value):
        """Return ``value`` as text that is safe inside one Markdown table cell.

        Pipes are escaped and line breaks are collapsed to spaces; otherwise a
        logged exception message could split or terminate the table row.
        """
        text = str(value).replace("|", "\\|")
        return text.replace("\r\n", " ").replace("\n", " ").replace("\r", " ")

    def log(self, step, status, message, **extras):
        """Record one event and echo it to stdout.

        Args:
            step: Short name of the pipeline stage (e.g. ``"load"``).
            status: ``"ok"``, ``"warn"``, or anything else (shown as a failure).
            message: Human-readable description.
            **extras: Additional fields stored verbatim on the event.
        """
        self.events.append({
            "ts": time.strftime("%Y-%m-%d %H:%M:%S"),
            "step": step,
            "status": status,
            "message": message,
            **extras,
        })
        if status == "ok":
            sym = "✅"
        elif status == "warn":
            sym = "⚠️"
        else:
            sym = "❌"
        print(f"{sym} [{step}] {message}")

    def df(self):
        """Return all recorded events as a DataFrame (one row per event)."""
        return pd.DataFrame(self.events)

    def save(self, base):
        """Write the journal to ``<base>.md`` (table) and ``<base>.json`` (records).

        Args:
            base: Path (or path string) without extension; its suffix, if any,
                is replaced by ``.md`` / ``.json``.
        """
        base = Path(base)
        # Do not rely on an earlier cell having created the target folder.
        base.parent.mkdir(parents=True, exist_ok=True)
        md_path = base.with_suffix(".md")
        json_path = base.with_suffix(".json")

        md = ["| ts | step | status | message |", "|---|---|---|---|"]
        for event in self.events:
            cells = (self._md_cell(event[key]) for key in ("ts", "step", "status", "message"))
            md.append("| " + " | ".join(cells) + " |")
        md_path.write_text("\n".join(md), encoding="utf-8")
        json_path.write_text(self.df().to_json(orient="records", indent=2), encoding="utf-8")
        print(f"📝 Saved journal:\n  - {md_path}\n  - {json_path}")
J = RunJournal()

# Candidate artifacts in priority order: pooled multi-dataset files first,
# then the original single-dataset files.
enc_candidates = [PROCESSED / fname for fname in ("encodings_all.npz", "encodings.npz")]
spl_candidates = [PROCESSED / fname for fname in ("splits_pooled.json", "splits.json")]

# The first candidate that exists on disk wins; None means nothing was found.
enc_path = next(filter(Path.exists, enc_candidates), None)
spl_path = next(filter(Path.exists, spl_candidates), None)

# Journal each missing artifact separately before aborting the run.
if enc_path is None:
    J.log("load","fail","Encodings not found (tried encodings_all.npz, encodings.npz)")
if spl_path is None:
    J.log("load","fail","Splits not found (tried splits_pooled.json, splits.json)")
if enc_path is None or spl_path is None:
    raise FileNotFoundError("Missing required artifacts in data/processed")

# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code from an untrusted file. Confirm whether the stored
# arrays actually need it before loading archives from outside the project.
data = np.load(enc_path, allow_pickle=True)
with open(spl_path) as f:
    SPL = json.load(f)
J.log("load","ok",f"Loaded encodings from {enc_path.name} and splits from {spl_path.name}")

# Labels as integers, k-mer features as float32.
y = data["y"].astype(int)
X_kmer = data["kmer"].astype(np.float32)

# The per-sample dataset id array is optional; without it the per-dataset
# diagnostics further down are skipped.
ds_idx = None
if "ds_idx" in data.files:
    ds_idx = data["ds_idx"]

# Optional id -> accession lookup, read only when both the ids and the CSV exist.
ds_map = None
_index_csv = PROCESSED / "dataset_index.csv"
if ds_idx is not None and _index_csv.exists():
    _index_table = pd.read_csv(_index_csv)
    ds_map = _index_table.set_index("ds_idx")["accession"].to_dict()
    J.log("datasets","ok",f"Detected {len(set(ds_idx))} dataset(s) with mapping.")

# Row indices of each split, as stored in the splits JSON.
tr_idx, va_idx, te_idx = (np.array(SPL[part]) for part in ("train", "val", "test"))
J.log("splits","ok",f"train={len(tr_idx)}, val={len(va_idx)}, test={len(te_idx)}, pos_rate={y.mean():.4f}")

# PCA -> standardise. Both transforms are fitted on the training rows only and
# then applied unchanged to the validation and test rows.
D = int(os.environ.get("VQC_D", "6"))  # number of PCA components kept
pca = PCA(n_components=D, random_state=11)
X_tr = pca.fit_transform(X_kmer[tr_idx])
X_va, X_te = (pca.transform(X_kmer[part_idx]) for part_idx in (va_idx, te_idx))
J.log("pca","ok",f"PCA D={D}, explained_var={pca.explained_variance_ratio_.sum():.3f}")

scaler = StandardScaler(with_mean=True, with_std=True).fit(X_tr)
Xtr, Xva, Xte = (
    scaler.transform(block).astype(np.float32) for block in (X_tr, X_va, X_te)
)
ytr = y[tr_idx]
yva = y[va_idx]
yte = y[te_idx]
# Quick sanity echo: training matrix shape and training positive rate.
print(Xtr.shape, ytr.mean())

‚úÖ [load] Loaded encodings from encodings_all.npz and splits from splits_pooled.json
‚úÖ [datasets] Detected 13 dataset(s) with mapping.
‚úÖ [splits] train=12336, val=4112, test=4112, pos_rate=0.8654
‚úÖ [pca] PCA D=6, explained_var=0.563
(12336, 6) 0.8654345006485085


# Cell 2 — device + circuit (robust, optional noise)

In [3]:
# Define shallow variational circuit (re-uploading style) and prediction routine
def make_device(n_wires, shots=None, use_mixed=False):
    """Create the PennyLane device used by the classifier.

    Args:
        n_wires: Number of wires (qubits) of the device.
        shots: Forwarded to ``qml.device``.
        use_mixed: Request ``default.mixed`` (needed when the circuit contains
            noise channels) instead of ``lightning.qubit``.

    Returns:
        The created device. When ``lightning.qubit`` cannot be created the
        function falls back to ``default.qubit`` and journals a warning.

    Raises:
        Exception: Re-raises the original error when ``use_mixed`` is True and
            ``default.mixed`` cannot be created. Falling back to
            ``default.qubit`` there would hand back a state-vector device for a
            circuit that applies BitFlip / DepolarizingChannel, which is
            expected to fail later with a less helpful error.
    """
    backend = "default.mixed" if use_mixed else "lightning.qubit"
    try:
        dev = qml.device(backend, wires=n_wires, shots=shots)
    except Exception as e:
        if use_mixed:
            J.log("device","fail",f"{backend} unavailable ({e}); noise channels require a mixed-state device, not falling back")
            raise
        J.log("device","warn",f"{backend} unavailable ({e}); fallback to default.qubit")
        return qml.device("default.qubit", wires=n_wires, shots=shots)
    J.log("device","ok",f"{backend} (wires={n_wires}, shots={shots})")
    return dev

# One wire per PCA component.
n_wires = D
# Number of re-uploading layers.
L = int(os.environ.get("VQC_L", "2"))
# Optional per-wire noise strengths; a value of 0 disables the channel.
p_bitflip = float(os.environ.get("VQC_P_BITFLIP", "0.0"))
p_depol = float(os.environ.get("VQC_P_DEPOL", "0.0"))
# Any active noise channel switches the device request to the mixed-state backend.
dev = make_device(
    n_wires,
    shots=None,
    use_mixed=any(strength > 0 for strength in (p_bitflip, p_depol)),
)

# Resolve the entangler template from the top-level namespace when it is
# exposed there, otherwise import it from the templates subpackage.
BasicEntanglerLayers = getattr(qml, "BasicEntanglerLayers", None)
if BasicEntanglerLayers is None:
    from pennylane.templates.layers import BasicEntanglerLayers

# Small random initial weights: one trainable angle per layer and wire.
weights = pnp.random.normal(scale=0.15, size=(L, n_wires), requires_grad=True)

def layer(x, w):
    """Apply one re-uploading layer: embed ``x``, optional noise, then entangle.

    Args:
        x: Feature vector passed to ``AngleEmbedding`` (rotation ``"Y"``).
        w: Weight row for this layer; it is given a leading axis so that
           ``BasicEntanglerLayers`` receives a single-layer weight array.
    """
    wires = range(n_wires)
    qml.AngleEmbedding(x, wires=wires, rotation="Y")
    # Noise channels are only inserted when their strength is positive.
    if p_bitflip > 0:
        for wire in wires:
            qml.BitFlip(p_bitflip, wires=wire)
    if p_depol > 0:
        for wire in wires:
            qml.DepolarizingChannel(p_depol, wires=wire)
    BasicEntanglerLayers(w[None, :], wires=wires)

@qml.qnode(dev, interface="autograd")
def vqc(x, w):
    """QNode: stack ``L`` layers on the same input and measure wire 0.

    Args:
        x: Feature vector re-embedded in every layer.
        w: Weights indexed by layer; row ``k`` parametrises layer ``k``.

    Returns:
        Expectation value of PauliZ on wire 0.
    """
    for depth in range(L):
        layer(x, w[depth])
    return qml.expval(qml.PauliZ(0))

def predict_proba(X, w, as_numpy=False):
    """Evaluate the circuit once per row of ``X`` and return class-1 probabilities.

    Args:
        X: Iterable of feature vectors.
        w: Circuit weights forwarded to ``vqc``.
        as_numpy: When True, convert the result with ``np.asarray``; otherwise
            the ``pnp`` array is returned as produced.

    Returns:
        One value per row, clipped to [1e-6, 1 - 1e-6].
    """
    # (1 + <Z>) / 2 rescales the expectation value from [-1, 1] to [0, 1].
    probs = pnp.stack([(1 + vqc(sample, w)) / 2 for sample in X])
    # Clip away exact 0/1 so that a later log() stays finite.
    probs = pnp.clip(probs, 1e-6, 1 - 1e-6)
    if as_numpy:
        return np.asarray(probs)
    return probs

‚úÖ [device] lightning.qubit (wires=6, shots=None)


# Cell 3 — train (Adam + early stopping + safety + journaling)

In [4]:
# Optimize VQC with mini-batch Adam + early stopping on val BCE; log outcomes/problems
# Optimizer and loop hyperparameters, each overridable via an environment variable.
opt = qml.AdamOptimizer(stepsize=float(os.environ.get("VQC_LR", "0.05")))
batch_size, max_epochs, patience = (
    int(os.environ.get(env_name, fallback))
    for env_name, fallback in (
        ("VQC_BS", "64"),
        ("VQC_EPOCHS", "60"),
        ("VQC_PATIENCE", "6"),
    )
)
param_clip = float(os.environ.get("VQC_PARAM_CLIP", "1.5"))

# Early-stopping bookkeeping: best validation loss so far, the weights that
# produced it, and the number of consecutive epochs without improvement.
best_va = float("inf")
best_w = pnp.array(weights, requires_grad=True)
no_improve = 0
# One dict per completed epoch (epoch number, train loss, validation loss).
history = []

def bce_loss(y_true, p_hat):
    """Mean binary cross-entropy between labels and predicted probabilities.

    Args:
        y_true: Array of 0/1 labels.
        p_hat: Array of predicted class-1 probabilities; must lie strictly
            inside (0, 1) for the logarithms to be finite.

    Returns:
        The negated mean log-likelihood.
    """
    log_likelihood = y_true * pnp.log(p_hat) + (1 - y_true) * pnp.log(1 - p_hat)
    return -pnp.mean(log_likelihood)

def iterate_minibatches(X, y, bs, shuffle=True, rng=None):
    """Yield ``(X_batch, y_batch)`` pairs that together cover every sample once.

    Args:
        X: Indexable feature array; must support fancy indexing with an index array.
        y: Label array with the same length as ``X``.
        bs: Batch size, a positive integer. The last batch may be smaller.
        shuffle: Shuffle the sample order before batching.
        rng: Optional ``numpy.random.Generator`` (or any object with a
            ``shuffle`` method). When omitted, NumPy's global RNG is used, so
            ``np.random.seed`` keeps controlling the order.

    Raises:
        ValueError: If ``bs`` is smaller than 1 or ``X`` and ``y`` differ in
            length. Because this is a generator, the error surfaces on the
            first iteration.
    """
    if bs < 1:
        # A negative step would make range() empty and silently skip training.
        raise ValueError(f"bs must be a positive integer, got {bs}")
    n_samples = len(y)
    if len(X) != n_samples:
        raise ValueError(f"X and y differ in length: {len(X)} != {n_samples}")

    idx = np.arange(n_samples)
    if shuffle:
        if rng is None:
            np.random.shuffle(idx)
        else:
            rng.shuffle(idx)
    for start in range(0, n_samples, bs):
        sl = idx[start:start + bs]
        yield X[sl], y[sl]

# Wall-clock start; only used for the elapsed-minutes figure in the final journal entry.
t0 = time.time()
for epoch in range(1, max_epochs+1):
    # --- one pass over the shuffled training mini-batches ---
    try:
        for Xb, yb in iterate_minibatches(Xtr, ytr, batch_size):
            # Cost closes over the current batch; only `w` is passed by the optimizer.
            def cost(w):
                y_true = pnp.array(yb, dtype=float)
                p_hat  = predict_proba(Xb, w, as_numpy=False)
                return bce_loss(y_true, p_hat)
            w_new = opt.step(cost, weights)
            # parameter clipping to reduce exploding grads / barren-plateau drift
            weights = pnp.clip(w_new, -param_clip, param_clip)
    except Exception as e:
        # Any error inside the batch loop is journaled and ends training; the
        # code after the loop then restores `best_w`.
        J.log("train","fail",f"Exception in optimizer step: {e}")
        break

    # --- end-of-epoch evaluation on the full train and validation splits ---
    # predict_proba runs the circuit once per sample, so this step evaluates
    # len(Xtr) + len(Xva) circuits every epoch.
    p_tr = predict_proba(Xtr, weights); p_va = predict_proba(Xva, weights)
    loss_tr = float(bce_loss(ytr, p_tr)); loss_va = float(bce_loss(yva, p_va))
    if not np.isfinite(loss_tr) or not np.isfinite(loss_va):
        # The epoch is not appended to `history` in this case.
        J.log("train","fail","Non-finite loss encountered; stopping and reverting to best weights so far.")
        break
    history.append({"epoch":epoch, "loss_tr":loss_tr, "loss_va":loss_va})
    print(f"epoch {epoch:02d} | loss_tr={loss_tr:.4f} | loss_va={loss_va:.4f}")

    # --- early stopping on validation loss ---
    # An epoch counts as an improvement only if it beats the best loss by more than 1e-4.
    if loss_va + 1e-4 < best_va:
        best_va = float(loss_va); best_w = pnp.array(weights, requires_grad=False); no_improve = 0
    else:
        no_improve += 1
        if no_improve >= patience:
            print("early stopping."); J.log("train","ok",f"Early stopping at epoch {epoch}"); break

# Restore the best validation weights (the initial weights if no epoch completed),
# then persist the learning curve and the weights.
# NOTE(review): `weights` is overwritten here, so re-running this cell without
# re-running the circuit cell resumes from these weights with the existing `opt`.
weights = best_w
pd.DataFrame(history).to_csv(RESULTS/"metrics/vqc_train_curve.csv", index=False)
np.save(RESULTS/"vqc_weights.npy", np.array(weights, dtype=float))
# NOTE(review): this entry is logged with status "ok" even when the loop ended
# through one of the "fail" branches above.
J.log("train","ok",f"Finished in {((time.time()-t0)/60):.1f} min; best_val_bce={best_va:.4f}; epochs_logged={len(history)}")

epoch 01 | loss_tr=0.6310 | loss_va=0.6297
epoch 02 | loss_tr=0.6310 | loss_va=0.6299
epoch 03 | loss_tr=0.6306 | loss_va=0.6296
epoch 04 | loss_tr=0.6309 | loss_va=0.6295
epoch 05 | loss_tr=0.6307 | loss_va=0.6292
epoch 06 | loss_tr=0.6310 | loss_va=0.6293
epoch 07 | loss_tr=0.6313 | loss_va=0.6303
epoch 08 | loss_tr=0.6319 | loss_va=0.6305
epoch 09 | loss_tr=0.6325 | loss_va=0.6310
epoch 10 | loss_tr=0.6313 | loss_va=0.6303
epoch 11 | loss_tr=0.6308 | loss_va=0.6296
early stopping.
‚úÖ [train] Early stopping at epoch 11
‚úÖ [train] Finished in 23.6 min; best_val_bce=0.6292; epochs_logged=11


# Cell 4 — metrics (val-optimal threshold, full matrices, plots, docs)

In [5]:
# Determine threshold on val; report rich metrics for all splits; save plots & reports
from sklearn.metrics import f1_score

def choose_threshold(y_val, p_val, name="VQC"):
    """Pick the decision threshold that maximises F1 on the validation split.

    A grid of 37 thresholds between 0.05 and 0.95 is scanned; on ties the
    lowest threshold wins. The outcome is recorded in the run journal.

    Args:
        y_val: Validation labels (0/1).
        p_val: Predicted class-1 probabilities for the validation split.
        name: Model name used in the journal messages.

    Returns:
        The selected threshold, or 0.5 when no threshold achieves a positive
        F1 (for example when the validation split has no positive samples).
    """
    y_val = np.asarray(y_val)
    p_val = np.asarray(p_val)
    grid = np.linspace(0.05, 0.95, 37)
    best_thr, best_f1 = 0.5, -1.0
    for t in grid:
        f1 = f1_score(y_val, (p_val >= t).astype(int), zero_division=0)
        if f1 > best_f1:
            best_f1, best_thr = float(f1), float(t)

    # With zero_division=0 the score is never NaN, so an uninformative search
    # shows up as a best F1 of 0 (or the untouched -1 sentinel), not as NaN.
    if not np.isfinite(best_f1) or best_f1 <= 0:
        J.log("threshold","warn",f"{name}: F1 undefined on val; using thr=0.5")
        return 0.5
    J.log("threshold","ok",f"{name}: thr={best_thr:.2f} (val F1={best_f1:.3f})")

    # F1 on a heavily imbalanced split can be maximised by predicting a single
    # class; surface that in the journal instead of letting it pass silently.
    predicted_pos = p_val >= best_thr
    if predicted_pos.all() or not predicted_pos.any():
        J.log("threshold","warn",
              f"{name}: thr={best_thr:.2f} assigns every validation sample to one class")
    return best_thr

def extended_metrics(y_true, p, thr):
    """Compute threshold-based and ranking metrics for one set of predictions.

    Args:
        y_true: True 0/1 labels.
        p: Predicted class-1 probabilities (array supporting ``>=`` and ``astype``).
        thr: Decision threshold; ``p >= thr`` is predicted as class 1.

    Returns:
        A ``(metrics, cm, report)`` tuple: a flat dict of scalar metrics, the
        2x2 confusion matrix with label order [0, 1], and the
        ``classification_report`` dict.
    """
    yhat = (p >= thr).astype(int)

    def _score_or_nan(scorer):
        # Ranking metrics can raise (e.g. a single class in y_true); report NaN instead.
        try:
            return scorer(y_true, p)
        except Exception:
            return float("nan")

    acc = accuracy_score(y_true, yhat)
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, yhat, average="binary", zero_division=0
    )
    auc = _score_or_nan(roc_auc_score)
    ap = _score_or_nan(average_precision_score)

    cm = confusion_matrix(y_true, yhat, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()
    negatives = tn + fp
    if negatives:
        tnr = tn / negatives
    else:
        tnr = float("nan")

    bal = balanced_accuracy_score(y_true, yhat)
    if len(np.unique(y_true)) == 2:
        mcc = matthews_corrcoef(y_true, yhat)
    else:
        mcc = float("nan")
    rep = classification_report(y_true, yhat, output_dict=True, zero_division=0)

    # Key order is kept stable because it determines the CSV column order downstream.
    metrics = {
        "acc": acc, "prec": prec, "rec": rec, "f1": f1, "roc_auc": auc, "pr_auc": ap,
        "specificity": tnr, "balanced_acc": bal, "mcc": mcc, "thr": thr,
        "tp": int(tp), "tn": int(tn), "fp": int(fp), "fn": int(fn),
        "support": int(len(y_true)),
    }
    return metrics, cm, rep

def save_cm_csv(cm, out_csv, normalized=False):
    """Write a 2x2 confusion matrix to CSV with ``true_*`` rows and ``pred_*`` columns.

    Args:
        cm: 2x2 array of counts (rows = true class, columns = predicted class).
        out_csv: Destination path.
        normalized: When True, each row is divided by its sum; rows that sum
            to zero are left as zeros.
    """
    table = cm
    if normalized:
        counts = cm.astype(np.float64)
        row_totals = counts.sum(axis=1, keepdims=True)
        # Replace empty-row totals by 1 so the division leaves those rows at 0.
        row_totals[row_totals == 0] = 1
        table = counts / row_totals
    frame = pd.DataFrame(table, index=["true_0","true_1"], columns=["pred_0","pred_1"])
    frame.to_csv(out_csv, index=True)

def plot_roc(y_true, p, title, out_png):
    """Save a ROC curve (with the chance diagonal) as a PNG.

    Failures are journaled as a warning instead of being raised, so a plotting
    problem does not abort the evaluation cell.

    Args:
        y_true: True 0/1 labels.
        p: Predicted class-1 probabilities.
        title: Figure title.
        out_png: Destination path of the image.
    """
    from sklearn.metrics import roc_curve, auc
    fig = None
    try:
        fpr, tpr, _ = roc_curve(y_true, p)
        roc_auc = auc(fpr, tpr)
        fig, ax = plt.subplots()
        ax.plot(fpr, tpr, label=f"AUC={roc_auc:.3f}")
        ax.plot([0, 1], [0, 1], '--')
        ax.set_xlabel("FPR")
        ax.set_ylabel("TPR")
        ax.set_title(title)
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_png, dpi=150)
    except Exception as e:
        J.log("plot","warn",f"ROC plot skipped: {e}")
    finally:
        # Close the figure on the failure path too, so it is not left open.
        if fig is not None:
            plt.close(fig)

def plot_pr(y_true, p, title, out_png):
    """Save a precision-recall curve as a PNG.

    Failures are journaled as a warning instead of being raised, so a plotting
    problem does not abort the evaluation cell.

    Args:
        y_true: True 0/1 labels.
        p: Predicted class-1 probabilities.
        title: Figure title.
        out_png: Destination path of the image.
    """
    from sklearn.metrics import precision_recall_curve, average_precision_score
    fig = None
    try:
        pr, rc, _ = precision_recall_curve(y_true, p)
        ap = average_precision_score(y_true, p)
        fig, ax = plt.subplots()
        ax.plot(rc, pr, label=f"AP={ap:.3f}")
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.set_title(title)
        ax.legend()
        fig.tight_layout()
        fig.savefig(out_png, dpi=150)
    except Exception as e:
        J.log("plot","warn",f"PR plot skipped: {e}")
    finally:
        # Close the figure on the failure path too, so it is not left open.
        if fig is not None:
            plt.close(fig)

# Class-1 probabilities for every split, using the weights restored after training.
p_tr = predict_proba(Xtr, weights, as_numpy=True)
p_va = predict_proba(Xva, weights, as_numpy=True)
p_te = predict_proba(Xte, weights, as_numpy=True)

# The threshold is selected on validation only and reused unchanged for all splits.
thr = choose_threshold(yva, p_va, name="VQC")

# Metrics, confusion matrix and classification report per split.
splits = {"train": (ytr, p_tr), "val": (yva, p_va), "test": (yte, p_te)}
rows, reports, cms = [], {}, {}
for split, (yt, pt) in splits.items():
    m, cm, rep = extended_metrics(yt, pt, thr)
    m.update({"model": "VQC", "split": split})
    rows.append(m)
    cms[split] = cm
    reports[split] = rep
df = pd.DataFrame(rows)
df.to_csv(RESULTS/"metrics/vqc_metrics.csv", index=False)

# Confusion matrices: raw counts plus a row-normalised copy per split.
for split, cm in cms.items():
    save_cm_csv(cm, RESULTS/f"metrics/vqc_cm_{split}.csv", normalized=False)
    save_cm_csv(cm, RESULTS/f"metrics/vqc_cm_{split}_norm.csv", normalized=True)

# Classification report JSON (one entry per split).
with open(RESULTS/"metrics/vqc_classification_reports.json","w",encoding="utf-8") as f:
    json.dump(reports, f, indent=2)

# Plots (test). Titles use a real em dash; the mis-encoded one was rendered into the PNGs.
plot_roc(yte, p_te, "VQC — ROC (test)", RESULTS/"plots/vqc_roc_test.png")
plot_pr(yte, p_te, "VQC — PR (test)", RESULTS/"plots/vqc_pr_test.png")

# Console + journal summary (no backslashes in f-strings)
row_test = df.loc[df["split"]=="test"].iloc[0]
print(row_test.to_string())
J.log("eval","ok",
      f"VQC: test F1={row_test['f1']:.3f}, AUC={row_test['roc_auc']:.3f}, PR-AUC={row_test['pr_auc']:.3f}, thr={thr:.2f}")

‚úÖ [threshold] VQC: thr=0.05 (val F1=0.922)
acc             0.865516
prec            0.865516
rec                  1.0
f1               0.92791
roc_auc         0.540335
pr_auc          0.880499
specificity          0.0
balanced_acc         0.5
mcc                  0.0
thr                 0.05
tp                  3559
tn                     0
fp                   553
fn                     0
support             4112
model                VQC
split               test
‚úÖ [eval] VQC: test F1=0.928, AUC=0.540, PR-AUC=0.880, thr=0.05


# (Optional) Cell 5 — Per-dataset diagnostics (if ds_idx available)

In [6]:
# Evaluate generalization per dataset on TEST split only (reuse global thr)
if ds_idx is None:
    J.log("per-dataset","warn","ds_idx not found — skipping per-dataset diagnostics.")
else:
    rows = []
    # Dataset id of every test sample, computed once instead of once per dataset.
    ds_test = ds_idx[te_idx]
    uniq = sorted(np.unique(ds_test))
    for d in uniq:
        # Use the accession when a mapping was loaded, otherwise a generic "ds_<id>" label.
        name = ds_map.get(d, f"ds_{d}") if ds_map else f"ds_{d}"
        mask = (ds_test == d)
        if mask.sum() == 0: continue
        y_true = yte[mask]; p_sub = p_te[mask]
        # The global validation-selected threshold is reused for every dataset.
        m, cm, _ = extended_metrics(y_true, p_sub, thr)
        m.update({"model":"VQC","dataset":name,"n":int(mask.sum())})
        rows.append(m)
    df_per = pd.DataFrame(rows)
    df_per.to_csv(RESULTS/"metrics/vqc_per_dataset_test.csv", index=False)
    J.log("per-dataset","ok",f"Saved per-dataset test metrics for {len(df_per)} dataset(s).")

‚úÖ [per-dataset] Saved per-dataset test metrics for 13 dataset(s).


# Cell 6 — Save run journal (what worked, what didn’t, and why)

In [7]:
# One timestamp shared by the journal files and the summary file of this run.
ts = time.strftime("%Y%m%d_%H%M%S")
base = RESULTS/"logs"/f"vqc_{ts}"

# Roll up every warning/failure recorded during the run into a bullet list.
issues = [
    f"- [{e['step']}] {e['message']}"
    for e in J.events
    if e["status"] in ("warn","fail")
]
rollup = "No warnings or failures." if not issues else "Issues observed:\n" + "\n".join(issues)
print("\n=== RUN SUMMARY ===\n" + rollup)

J.save(base)
(RESULTS/"logs"/f"vqc_{ts}_summary.txt").write_text(rollup, encoding="utf-8")
print(f"📦 Metrics in: {RESULTS/'metrics'}  |  Plots in: {RESULTS/'plots'}")


=== RUN SUMMARY ===
üìù Saved journal:
  - results\logs\vqc_20250917_184000.md
  - results\logs\vqc_20250917_184000.json
üì¶ Metrics in: results\metrics  |  Plots in: results\plots
