### Definition: IMQ-Hamming Kernel
Based on Amin et al. (2023), the Inverse Multiquadric Hamming (IMQ-H) kernel is defined for sequences $x, y \in \mathcal{S}$, where $\mathcal{S}$ is the space of finite strings (e.g., DNA, RNA, or protein), as:
$$
k_{\text{IMQ-H}}(x,y) = \frac{1}{\left(1 + d_{H}^{\Phi}(x,y)\right)^{2}} = \frac{1}{\left(1 + |x| \vee |y| - \left(\Phi(x) \mid \Phi(y)\right)\right)^{2}}
$$

This uses: 
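1) $d_H^{\Phi}(x,y)$ (written $d_H(x,y)$ below): a Hamming distance computed over a feature space defined by sliding-window counts (k-mers). Here $|x| \vee |y| = \max(|x|, |y|)$ and $(\Phi(x) \mid \Phi(y))$ is the inner product of the two feature vectors.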
2) The inverse multiquadric kernel form $k(a,b) = \frac{C}{(\alpha + \|a - b\|^2)^{\beta}}$, adapted here for strings by replacing the squared Euclidean distance $\|a - b\|^2$ with $d_H$.

This paper generalizes the kernel with a scale parameter $\alpha > 0$ and a shape parameter $\beta > 0$, yielding:

$$
k_{\text{IMQ-H}}(x,y) = \frac{(1+\alpha)^{\beta}}{(\alpha + d_{H}(x,y))^{\beta}}$$
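To make the generalized formula concrete, here is a minimal NumPy sketch. It rests on an assumed feature map: single-residue one-hot features at each position (no lag), under which $(\Phi(x) \mid \Phi(y))$ counts aligned positions where the characters agree, so $d_H$ becomes a length-padded Hamming distance. The helpers `hamming_phi`, `imq_h_kernel`, and `gram_matrix` are hypothetical illustrations, not the `imq_gp.imq_hamming_kernel` implementation used below, which may build $\Phi$ differently.

```python
import numpy as np

def hamming_phi(x: str, y: str) -> int:
    """Length-padded Hamming distance (assumed one-hot, no-lag features):
    max(|x|, |y|) minus the number of aligned positions that agree."""
    matches = sum(a == b for a, b in zip(x, y))  # zip stops at the shorter string
    return max(len(x), len(y)) - matches

def imq_h_kernel(x: str, y: str, alpha: float = 1.0, beta: float = 0.5) -> float:
    """Generalized IMQ-H kernel: (1 + alpha)^beta / (alpha + d_H(x, y))^beta."""
    d = hamming_phi(x, y)
    return (1.0 + alpha) ** beta / (alpha + d) ** beta

def gram_matrix(X, Y, alpha: float = 1.0, beta: float = 0.5) -> np.ndarray:
    """Dense kernel matrix with entries k(X[i], Y[j])."""
    return np.array([[imq_h_kernel(x, y, alpha, beta) for y in Y] for x in X])

# alpha = 1, beta = 2 gives 4 / (1 + d_H)^2
print(imq_h_kernel("ACGT", "ACGA", alpha=1.0, beta=2.0))  # d_H = 1 -> 1.0
print(imq_h_kernel("ACGT", "ACGT", alpha=1.0, beta=2.0))  # d_H = 0 -> 4.0
print(gram_matrix(["MKV", "MKL", "AKV"], ["MKV", "MKL", "AKV"]).round(3))
```

Under this form $k(x,x) = ((1+\alpha)/\alpha)^{\beta}$ rather than 1, so self-similarity is not normalized; with $\alpha = 1$, $\beta = 2$ the kernel is exactly 4 times the Amin et al. kernel.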


In [1]:
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt
import logging 

In [2]:
df = pd.read_csv('mutation_with_sequences.csv')
df = df.dropna(subset=["mut_seq", "Type"])
# Inputs: mutated sequences
seqs = df["mut_seq"].tolist()
# Labels: 1 = driver, 0 = passenger
labels = df["Type"].apply(lambda x: 1.0 if x.lower() == "driver" else 0.0).values

In [3]:
print(df)

          Source    Gene                ENST Gene Code           ENST.1  \
0     cBioPortal   BRCA1   ENST00000357654.9    P38398  ENST00000357654   
1     cBioPortal   BRCA2   ENST00000380152.8    P51587  ENST00000380152   
2     cBioPortal    CDH1  ENST00000261769.10    P12830  ENST00000261769   
3     cBioPortal    CDH1  ENST00000261769.10    P12830  ENST00000261769   
4     cBioPortal    CDH1  ENST00000261769.10    P12830  ENST00000261769   
...          ...     ...                 ...       ...              ...   
1419      OncoKB  PIK3CA   ENST00000263967.4    P42336  ENST00000263967   
1420      OncoKB    PTEN   ENST00000371953.8    P60484  ENST00000371953   
1421      OncoKB    PTEN   ENST00000371953.8    P60484  ENST00000371953   
1422      OncoKB   SMAD4   ENST00000342988.8    Q13485  ENST00000342988   
1423      OncoKB    TP53   ENST00000269305.9    P04637  ENST00000269305   

        Gene Name Mutation       Type  \
0     BRCA1_HUMAN   G1788V     Driver   
1     BRCA2_HUMAN

In [4]:
# Count number of drivers and passengers
type_counts = df["Type"].str.lower().value_counts()

num_drivers = type_counts.get("driver", 0)

num_passengers = type_counts.get("passenger", 0)

print(f"Number of driver mutations: {num_drivers}")
print(f"Number of passenger mutations: {num_passengers}")

Number of driver mutations: 699
Number of passenger mutations: 725


### Pick 10 Seed Points for Training the GP

In [5]:
import pandas as pd
import numpy as np

df = pd.read_csv("mutation_with_sequences.csv")
df = df.dropna(subset=["mut_seq", "Type"])
df = df.sample(n=200, random_state=42).reset_index(drop=True)

# Sequences and labels
seqs = df["mut_seq"].tolist()
labels = df["Type"].apply(lambda x: 1.0 if x.lower() == "driver" else 0.0).values

In [6]:
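# Draw 10 distinct seed indices with a seeded generator so the run is reproducible
rng = np.random.default_rng(42)
init_idx = rng.choice(len(seqs), size=10, replace=False)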

### Define GP Model and Kernel 

## IMQH-GP and LSE Setup
Let $\mathcal{X} = \{x_1, x_2, \dots, x_N\}$ be a set of mutation sequences. We use Gaussian-process level set estimation (LSE) to decide which sequences lie above a driver threshold. At each iteration $t$, the labeled set is $\mathcal{I}_t \subset \mathcal{X}$ and the unlabeled set is $\mathcal{U}_t = \mathcal{X} \setminus \mathcal{I}_t$.

We model the latent function $f$ over sequences using a Gaussian Process prior:
$$ f(x) \sim \mathcal{GP}(0, k(x, x'))$$
where $k(x, x')$ is the IMQ-Hamming kernel:
$$
k(x, x') = \frac{(1 + s)^\beta}{(s + d_H(x, x'))^\beta}
$$
with scale parameter $s > 0$ (the $\alpha$ of the previous section), shape parameter $\beta > 0$, and $d_H$ the Hamming distance over k-mers.

## Gaussian Process Posterior Predictions

Given the current labeled data $\mathcal{I}_t$, for each $x \in \mathcal{U}_t$, the GP provides:
1) Predictive mean: $\mu_t(x)$
2) Predictive variance: $\sigma^2_t(x)$
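As a reference point, the standard zero-mean GP regression posterior with noise variance $\sigma_n^2$ is $\mu_t(x) = k_t(x)^\top (K_t + \sigma_n^2 I)^{-1} y_t$ and $\sigma_t^2(x) = k(x,x) - k_t(x)^\top (K_t + \sigma_n^2 I)^{-1} k_t(x)$, where $K_t$ is the kernel matrix over $\mathcal{I}_t$, $k_t(x)$ the vector of kernel values between $x$ and $\mathcal{I}_t$, and $y_t$ the observed labels. The NumPy sketch below evaluates both from precomputed kernel blocks through a Cholesky factorization. `gp_posterior` is a hypothetical helper that ignores the amplitude hyperparameter; it is not the `ZeroMeanIMQHGP.predict_f` method used later, whose internals are not shown here.

```python
import numpy as np

def gp_posterior(K_train_train, K_test_train, k_test_diag, y_train, noise=1e-4):
    """Zero-mean GP posterior mean and variance at m test points.

    K_train_train: (n, n) kernel matrix among labeled points
    K_test_train:  (m, n) kernel matrix between test and labeled points
    k_test_diag:   (m,)   prior variances k(x, x) at the test points
    y_train:       (n,)   observed labels
    """
    n = K_train_train.shape[0]
    # Cholesky factor of the noisy training covariance: K + noise * I = L L^T
    L = np.linalg.cholesky(K_train_train + noise * np.eye(n))
    # weights = (K + noise * I)^{-1} y, via two solves against the factor
    weights = np.linalg.solve(L.T, np.linalg.solve(L, y_train))
    mu = K_test_train @ weights
    # V = L^{-1} K_train_test, so k_t(x)^T (K + noise I)^{-1} k_t(x) = column sums of V**2
    V = np.linalg.solve(L, K_test_train.T)
    var = k_test_diag - np.sum(V ** 2, axis=0)
    # Clip tiny negative values caused by round-off
    return mu, np.maximum(var, 0.0)

def se(a, b):
    """Squared-exponential kernel on 1-D inputs, used only for the toy example."""
    return np.exp(-(a[:, None] - b[None, :]) ** 2)

x_all = np.array([0.0, 1.0, 2.0])
x_train, y_obs = x_all[:2], np.array([0.0, 1.0])
mu, var = gp_posterior(se(x_train, x_train), se(x_all, x_train), np.ones(3), y_obs, noise=1e-6)
print(mu, var)
```

Kernel blocks such as `K_train_train` and `K_test_train` computed in the notebook could be passed in after `np.asarray`, with the diagonal of `K_test_test` as `k_test_diag`.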

## Confidence Intervals
We compute a high-probability confidence interval (CI) for $f(x)$:
$$
\mathrm{CI}_t(x) = \left[ \mu_t(x) - \beta_t^{1/2} \sigma_t(x), \, \mu_t(x) + \beta_t^{1/2} \sigma_t(x) \right]
$$
where
$$
\beta_t = 2 \log\left( \frac{\pi^2 (t+1)^2}{6 \delta} \right)
$$
and $\delta \in (0,1)$ is a confidence parameter (typically $\delta = 0.01$).
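The interval computation is short once $\mu_t$ and $\sigma_t$ are available. This sketch implements the $\beta_t$ schedule exactly as written above; `beta_t` and `confidence_interval` are illustrative names, not part of the `lse` module.

```python
import numpy as np

def beta_t(t: int, delta: float = 0.01) -> float:
    """Confidence scaling beta_t = 2 log(pi^2 (t+1)^2 / (6 delta))."""
    return 2.0 * np.log(np.pi ** 2 * (t + 1) ** 2 / (6.0 * delta))

def confidence_interval(mu, sigma, t: int, delta: float = 0.01):
    """Symmetric interval [mu - sqrt(beta_t) sigma, mu + sqrt(beta_t) sigma]."""
    mu = np.asarray(mu, dtype=float)
    half_width = np.sqrt(beta_t(t, delta)) * np.asarray(sigma, dtype=float)
    return mu - half_width, mu + half_width

lower, upper = confidence_interval(mu=[0.6, 0.1], sigma=[0.2, 0.1], t=1)
print(beta_t(1), lower, upper)
```

With $\delta = 0.01$, $\beta_1 \approx 13.0$, so at $t = 1$ the half-width is about $3.6\,\sigma_t(x)$; $\beta_t$ then grows only logarithmically in $t$.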

## Threshold Classification
A threshold $h_t$ is used to decide whether $x$ is classified as a *driver* or a *passenger*. It can be:
1) Explicit: a fixed $h_t = h$, e.g., $h = 0.5$.
2) Implicit: $h_t = \omega \cdot \max_{x \in \mathcal{U}_t} \mu_t(x)$, where $\omega \in (0, 1)$.

Each point $x \in \mathcal{U}_t$ is classified based on its confidence interval, with a tolerance $\epsilon > 0$:

1) **Driver (high)** if $\mu_t(x) - \beta_t^{1/2} \sigma_t(x) > h_t - \epsilon$
2) **Passenger (low)** if $\mu_t(x) + \beta_t^{1/2} \sigma_t(x) < h_t + \epsilon$

All other points remain unclassified.
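Both threshold options and the three-way classification can be written directly on arrays of interval endpoints. In this sketch `threshold` and `classify` are hypothetical helpers; the rules above do not say what happens when an interval satisfies both conditions (possible when it is narrower than $2\epsilon$), so the code assumes the driver rule takes precedence.

```python
import numpy as np

def threshold(mu_unlabeled, h=None, omega=None):
    """Explicit threshold h, or implicit omega * max(mu) over the unlabeled set."""
    if h is not None:
        return float(h)
    return float(omega * np.max(mu_unlabeled))

def classify(lower, upper, h_t, epsilon=0.05):
    """Return +1 (driver/high), -1 (passenger/low), or 0 (still unclassified)."""
    lower = np.asarray(lower, dtype=float)
    upper = np.asarray(upper, dtype=float)
    high = lower > h_t - epsilon
    low = (upper < h_t + epsilon) & ~high  # assumed: driver rule wins ties
    labels = np.zeros(lower.shape, dtype=int)
    labels[high] = 1
    labels[low] = -1
    return labels

mu = np.array([0.9, 0.1, 0.5])
h_t = threshold(mu, omega=0.4)                      # implicit threshold
labels = classify([0.8, 0.0, -0.1], [1.0, 0.2, 1.1], h_t)
print(h_t, labels)                                  # driver, passenger, unclassified
```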

## Acquisition Rule
From the remaining unlabeled points $\mathcal{U}_t$, the next query $x^\ast$ is chosen by one of two rules:

1) **Ambiguity:**

$$x^\ast = \arg\max_{x \in \mathcal{U}_t} \min\left( \mu_t(x) + \beta_t^{1/2} \sigma_t(x) - h_t,\; h_t - (\mu_t(x) - \beta_t^{1/2} \sigma_t(x)) \right)$$

2) **Variance:**

$$x^\ast = \arg\max_{x \in \mathcal{U}_t} \sigma_t(x)$$

The selected point is queried (labeled), added to $\mathcal{I}_t$, and the GP is updated accordingly.
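The two acquisition rules can be sketched the same way. `ambiguity` and `select_next` are hypothetical names; the `rule` strings mirror the `rule="amb"` argument passed to `LSE` below, but whether the `lse` module uses exactly this logic is an assumption.

```python
import numpy as np

def ambiguity(lower, upper, h_t):
    """a(x) = min(upper - h_t, h_t - lower); largest when the CI straddles h_t widely."""
    return np.minimum(np.asarray(upper) - h_t, h_t - np.asarray(lower))

def select_next(lower, upper, sigma, h_t, rule="amb"):
    """Position (within the candidate arrays) of the next point to query."""
    if rule == "amb":
        return int(np.argmax(ambiguity(lower, upper, h_t)))
    if rule == "var":
        return int(np.argmax(np.asarray(sigma)))
    raise ValueError(f"unknown rule: {rule!r}")

lower = np.array([-0.1, 0.3, 0.0])
upper = np.array([1.1, 0.5, 0.2])
sigma = np.array([0.3, 0.05, 0.05])
print(select_next(lower, upper, sigma, h_t=0.4, rule="amb"))
print(select_next(lower, upper, sigma, h_t=0.4, rule="var"))
```

The returned value is a position within the candidate arrays, not a pool index, so the caller has to map it back (for example via an array holding the pool indices of $\mathcal{U}_t$) before querying and removing the point.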

In [7]:
import jax.numpy as jnp
import imq_gp 

from imq_gp import ZeroMeanIMQHGP, IMQHGP_Params

gp_model = ZeroMeanIMQHGP([seqs[i] for i in init_idx], jnp.array(labels[init_idx]))

# Set hyperparameters
params = IMQHGP_Params(
    raw_amplitude=jnp.array(1.0),
    raw_noise=jnp.array(1e-4),
    scale=1.0,
    beta=0.5,
    lag=1,
    alphabet_name="prot"
)

In [8]:
# Precompute the full kernel matrices once, with the same hyperparameters as `params`
import jax.numpy as jnp
from imq_gp import imq_hamming_kernel

kernel_kwargs = dict(alphabet_name="prot", scale=1.0, beta=0.5, lag=1)

# Full train (seed) and test (remaining pool) indices
train_idx = init_idx
train_set = set(int(i) for i in train_idx)
test_idx = [i for i in range(len(seqs)) if i not in train_set]

X_pool = seqs
X_train = [seqs[i] for i in train_idx]
X_test = [seqs[i] for i in test_idx]

K_train_train = imq_hamming_kernel(X_train, X_train, **kernel_kwargs)
K_test_train = imq_hamming_kernel(X_test, X_train, **kernel_kwargs)
K_test_test = imq_hamming_kernel(X_test, X_test, **kernel_kwargs)

# Store in the model
gp_model._K_train_train = jnp.asarray(K_train_train)
gp_model._K_test_train = jnp.asarray(K_test_train)
gp_model._K_test_test = jnp.asarray(K_test_test)
gp_model._X_test = X_test
gp_model._test_idx = test_idx

In [9]:
import importlib
import imq_gp 
from lse import LSE

# Discard any LSE object left over from a previous run of this cell
try:
    del lse
except NameError:
    pass

# Run LSE
lse = LSE(
    model=gp_model,
    params=params,
    X_pool=seqs,
    y_pool=labels,
    init_indices=init_idx,
    omega=0.40,       
    epsilon=0.05,
    rule="amb",      
    verbose=True
)

# Run seed IMQH-GP-LSE algorithm
lse.run(max_iter=1000)

[t=01] Queried idx=1, y=0, μ=0.625, σ=0.237, CI=(-0.758, 0.949)
[t=02] Queried idx=3, y=0, μ=0.454, σ=0.220, CI=(-0.754, 0.930)
[t=03] Queried idx=5, y=0, μ=0.031, σ=0.177, CI=(-0.620, 0.784)
[t=04] Queried idx=8, y=1, μ=0.320, σ=0.219, CI=(-0.812, 0.975)
[t=05] Queried idx=10, y=1, μ=0.483, σ=0.235, CI=(-0.889, 1.067)
[t=06] Queried idx=15, y=0, μ=0.476, σ=0.218, CI=(-0.831, 1.019)
[t=07] Queried idx=16, y=0, μ=0.018, σ=0.146, CI=(-0.539, 0.718)
[t=08] Queried idx=19, y=1, μ=0.388, σ=0.218, CI=(-0.859, 1.038)
[t=09] Queried idx=20, y=1, μ=0.484, σ=0.217, CI=(-0.864, 1.052)
[t=10] Queried idx=24, y=1, μ=0.554, σ=0.217, CI=(-0.869, 1.064)
[t=11] Queried idx=26, y=1, μ=0.580, σ=0.233, CI=(-0.948, 1.147)
[t=12] Queried idx=27, y=0, μ=0.016, σ=0.145, CI=(-0.551, 0.758)
[t=13] Queried idx=30, y=0, μ=0.043, σ=0.231, CI=(-0.949, 1.155)
[t=14] Queried idx=33, y=1, μ=0.607, σ=0.216, CI=(-0.888, 1.094)
[t=15] Queried idx=35, y=1, μ=0.646, σ=0.232, CI=(-0.968, 1.177)
[t=16] Queried idx=38, y=1, μ

IndexError: list index out of range

In [None]:
import matplotlib.pyplot as plt

mu_all, _ = gp_model.predict_f(params, seqs, full_covar=False)
plt.hist(np.array(mu_all), bins=20)
plt.axvline(x=lse._threshold(mu_all), color='red', linestyle='--', label='Threshold')
plt.title("Posterior mean (μ) distribution over all points")
plt.xlabel("μ")
plt.ylabel("Count")
plt.legend()
plt.show()

In [None]:
true_labels = df["Type"].apply(lambda x: 1 if x.lower() == "driver" else 0).values
mutations = df["Mutation"].values

# Queried indices and model predictions
queried_indices = lse.vt
pred_labels = [1 if i in lse.ht else 0 for i in queried_indices]
true_labels_queried = [true_labels[i] for i in queried_indices]

# Accuracy tracking
correct = [int(p == t) for p, t in zip(pred_labels, true_labels_queried)]
correct_mutations = [mutations[i] for i, is_correct in zip(queried_indices, correct) if is_correct]

# Print results
print("Queried mutation predictions (index, mutation, true, predicted, correct?):")
for idx, pred, true in zip(queried_indices, pred_labels, true_labels_queried):
    tag = "✓" if pred == true else "✗"
    print(f"  idx={idx:2d}, {mutations[idx]:>8}, true={true}, pred={pred}  {tag}")

# Summary
num_correct = sum(correct)
print(f"\nCorrect predictions: {num_correct} / {len(queried_indices)}")
print(f"Accuracy on queried points: {num_correct / len(queried_indices):.2%}")


In [None]:
# TODO: consider a sigmoid (logistic) link so that f(x) maps into [0, 1] and the confidence intervals no longer go negative
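The note above suggests a sigmoid (logistic) link. The cheapest variant, sketched here as an assumption rather than a recommendation, leaves the GP regression untouched and maps the interval endpoints through the logistic function: because the map is strictly increasing, the squashed interval has the same coverage as the original one. This only changes the reporting scale, the squashed values are not calibrated probabilities, and the threshold would have to be compared on the same scale (against $\mathrm{sigmoid}(h_t)$). A proper GP classifier with a logistic likelihood needs approximate inference (e.g., Laplace or expectation propagation) and is not shown. `sigmoid` and `squash_interval` are hypothetical helpers.

```python
import numpy as np

def sigmoid(z):
    """Logistic function written via tanh, which avoids overflow for large |z|."""
    return 0.5 * (1.0 + np.tanh(0.5 * np.asarray(z, dtype=float)))

def squash_interval(lower, upper):
    # A strictly increasing map preserves coverage:
    # P(lower <= f <= upper) = P(sigmoid(lower) <= sigmoid(f) <= sigmoid(upper))
    return sigmoid(lower), sigmoid(upper)

# Interval endpoints taken from the first two LSE log lines above
lo, hi = squash_interval([-0.758, -0.754], [0.949, 0.930])
print(lo, hi)
```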