### Definition: IMQ-Hamming Kernel
Following Amin et al. (2023), the Inverse Multiquadric Hamming (IMQ-H) kernel is defined for sequences $x, y \in \mathcal{S}$, where $\mathcal{S}$ is the space of finite strings (e.g., DNA, RNA, or protein sequences), as:
$$
k_{\text{IMQ-h}}(x,y) = \frac{1}{\left(1+d_{H}^{\Phi}(x,y)\right)^2} = \frac{1}{\left(1+|x| \vee |y| - (\Phi(x) \mid \Phi(y))\right)^2}$$

This uses: 
1) $d_H(x,y)$: a Hamming distance computed over a feature space defined by sliding window counts (k-mers).
2) The inverse multiquadric kernel form $k(a,b) = \frac{C}{(\alpha + \|a - b\|^2)^{\beta}}$, here adapted for string comparison.

The paper generalizes this kernel with a scale parameter $\alpha$ and an exponent $\beta$, yielding:

$$
k_{\text{IMQ-h}}(x,y)= \frac{(1+\alpha)^{\beta}}{(\alpha +d_{H}(x,y))^{\beta}}$$


In [10]:
import pandas as pd 
import numpy as np 
import matplotlib.pyplot as plt
import logging 

In [11]:
# Load the mutation table and keep only rows that have both a mutated
# sequence and a driver/passenger annotation.
df = pd.read_csv('mutation_with_sequences.csv').dropna(subset=["mut_seq", "Type"])

# Model inputs: the mutated sequences, as a plain Python list.
seqs = df["mut_seq"].tolist()

# Targets: 1.0 for "driver" (case-insensitive), 0.0 for everything else.
labels = df["Type"].apply(
    lambda kind: 1.0 if kind.lower() == "driver" else 0.0
).values

### Pick 10 Seed Points for training GP

In [12]:
# Number of labelled seed points the GP starts from.
n_init = 10

# Fix NumPy's global RNG so the same seed points are drawn on every run,
# then sample distinct positions into `seqs`.
np.random.seed(42)
init_idx = np.random.choice(len(seqs), replace=False, size=n_init)

### Define GP Model and Kernel 

In [13]:
from imq_gp import ZeroMeanIMQHGP, IMQHGP_Params
import jax.numpy as jnp

# Training data for the initial GP: the seed sequences and their 0/1 labels.
seed_sequences = [seqs[idx] for idx in init_idx]
seed_targets = jnp.array(labels[init_idx])
gp_model = ZeroMeanIMQHGP(seed_sequences, seed_targets)

# Settings handed to the GP at prediction time.
# NOTE(review): `scale` and `beta` presumably correspond to alpha and beta in
# the generalized IMQ-H kernel described above -- confirm against imq_gp.
kernel_settings = dict(
    raw_amplitude=jnp.array(1.0),
    raw_noise=jnp.array(1e-4),
    scale=1.0,
    beta=0.5,
    lag=1,
    alphabet_name="prot",
)
params = IMQHGP_Params(**kernel_settings)

In [18]:
# Re-import the local `lse` module and reload it so that edits made to its
# source are picked up without restarting the kernel. The order matters:
# the module must be reloaded before `LSE` is pulled out of it.
import importlib
import lse  
importlib.reload(lse)
from lse import LSE

# NOTE(review): the assignment below rebinds `lse` from the module to the LSE
# instance. Re-running this cell still works because `import lse` above binds
# the name back to the module first, but the shadowing is easy to trip over.
# Later cells read `lse.ht` / `lse.lt` from the instance.
lse = LSE(
    model=gp_model,
    params=params,
    X_pool=seqs,
    y_pool=labels,
    init_indices=init_idx,
    h=0.5,  # threshold; labels are 0/1, so presumably the driver/passenger cut-off
    epsilon=0.05,  # presumably the classification tolerance -- confirm in lse.py
    delta=0.01,  # presumably the confidence parameter -- confirm in lse.py
    verbose=True 
)

# Run the active learning loop
lse.run(max_iter=10)


KeyboardInterrupt: 

In [None]:
# Report the mutation names of the pool points that the estimator placed in
# its high set (drivers) and low set (passengers).
for heading, members in (
    ("Confident drivers:", lse.ht),
    ("Confident passengers:", lse.lt),
):
    print(heading, [df.iloc[i]["Mutation"] for i in members])


In [4]:
import jax.numpy as jnp

class LSE:
    """Level-set estimation (LSE) active-learning loop over a finite pool.

    Every pool point is tracked by its integer position in ``X_pool`` and
    lives in exactly one of four lists:

    * ``vt`` -- visited points whose labels have been given to the model,
    * ``ut`` -- points that are still unlabeled and unclassified,
    * ``ht`` -- points classified as above the threshold ``h``,
    * ``lt`` -- points classified as below the threshold ``h``.

    The model only needs two methods: ``set_training_data(X, y)`` and
    ``predict_f(params, X, full_covar=False)`` returning ``(mean, std)``.

    NOTE(review): the classification rule and ambiguity-based selection look
    like the LSE algorithm of Gotovos et al. (2013); confirm ``_beta``
    against the intended variant, since it is kept exactly as written here.

    Args:
        model: Surrogate model exposing the two methods above.
        params: Opaque parameters forwarded to ``model.predict_f``.
        X_pool: Indexable collection of candidate inputs.
        y_pool: Labels aligned with ``X_pool``.
        init_indices: Positions in ``X_pool`` that start out labeled.
        h: Threshold level separating the high set from the low set.
        epsilon: Tolerance added to / subtracted from the confidence bounds
            when classifying.
        delta: Used only inside ``_beta`` to scale the confidence interval.
        verbose: When true, print a one-line summary after every query.
    """

    def __init__(self, model, params, X_pool, y_pool, init_indices, h=0.5, epsilon=0.05, delta=0.01, verbose=False):
        self.model = model
        self.params = params
        self.h = h
        self.epsilon = epsilon
        self.delta = delta
        self.verbose = verbose
        self.t = 1
        self.X_pool = X_pool
        self.y_pool = y_pool

        # Visited (labeled) indices, stored as plain ints.
        self.vt = [int(i) for i in init_indices]
        # Unlabeled pool, kept in ascending order so runs are reproducible
        # (a set difference would not guarantee any particular order).
        visited = set(self.vt)
        self.ut = [i for i in range(len(X_pool)) if i not in visited]
        self.ht = []  # High (driver)
        self.lt = []  # Low (passenger)

        self._update_gp()  # Fit GP on initial seed

    def _beta(self):
        """Return the confidence multiplier 2*log(pi^2 (t+1)^2 / (6*delta))."""
        return 2 * np.log((np.pi**2 * (self.t + 1)**2) / (6 * self.delta))

    def _update_gp(self):
        """Hand every visited point and its label to the model."""
        X_train = [self.X_pool[i] for i in self.vt]
        # Coerce so that list-valued label pools can be fancy-indexed too.
        y_train = np.asarray(self.y_pool)[self.vt]
        self.model.set_training_data(X_train, y_train)

    def step(self):
        """Classify what can be classified, then query one more point.

        Returns:
            bool: ``True`` if a new point was queried and the model updated,
            ``False`` if no unlabeled points remain (before or after the
            classification pass).
        """
        if not self.ut:
            return False  # Done

        candidates = list(self.ut)
        X_test = [self.X_pool[i] for i in candidates]
        mu, std = self.model.predict_f(self.params, X_test, full_covar=False)
        mu = np.ravel(np.asarray(mu))
        std = np.ravel(np.asarray(std))
        b = np.sqrt(self._beta())

        lower = mu - b * std
        upper = mu + b * std

        # "High" takes priority; a point can only be "low" if it is not
        # already "high". Without this a narrow interval straddling h would
        # be added to both sets.
        high_mask = lower + self.epsilon > self.h
        low_mask = ~high_mask & (upper - self.epsilon <= self.h)
        open_mask = ~(high_mask | low_mask)

        self.ht.extend(i for i, hit in zip(candidates, high_mask) if hit)
        self.lt.extend(i for i, hit in zip(candidates, low_mask) if hit)
        self.ut = [i for i, keep in zip(candidates, open_mask) if keep]

        if not self.ut:
            return False

        # Ambiguity-based selection, restricted to the still-open points so
        # the argmax position lines up with the shrunken ``self.ut``.
        ambiguity = np.minimum(upper - self.h, self.h - lower)[open_mask]
        next_global_idx = self.ut[int(np.argmax(ambiguity))]

        self.vt.append(next_global_idx)
        self.ut.remove(next_global_idx)
        self._update_gp()
        self.t += 1

        if self.verbose:
            print(
                f"[LSE] t={self.t} queried={next_global_idx} "
                f"high={len(self.ht)} low={len(self.lt)} unlabeled={len(self.ut)}"
            )
        return True

    def run(self, max_iter=100):
        """Call ``step`` up to ``max_iter`` times, stopping early when done."""
        for _ in range(max_iter):
            if not self.step():
                break
