In [125]:
import random
from typing import List, Tuple, Dict


class SyntheticPPIBuilder:
    """Build synthetic protein-pair bags for multiple-instance learning.

    Every bag consists of two random amino-acid sequences, a list of
    non-overlapping domain spans on each sequence, and one instance per
    (domain A, domain B) combination. A bag is made positive with probability
    ``positive_rate``; in a positive bag one to three "key" domain pairs have
    the same motif written into both domains and their instances get label 1.
    All other instances get label 0.

    All randomness comes from a private ``random.Random(seed)``, so two
    builders created with the same arguments produce identical bags.

    Args:
        num_bags: Number of bags produced by ``build``.
        seq_len_range: Inclusive (min, max) length of each random sequence.
        domain_len_range: Inclusive (min, max) length of a sampled domain.
        num_domains_range: Inclusive (min, max) number of domains requested
            per sequence (fewer may be placed, see ``_sample_domains``).
        positive_rate: Probability that a bag receives key pairs.
        num_motifs: Number of random motifs (3-5 residues each) to generate.
        seed: Seed for the private random number generator.
    """

    def __init__(
        self,
        num_bags: int = 1000,
        seq_len_range: Tuple[int, int] = (150, 300),
        domain_len_range: Tuple[int, int] = (20, 50),
        num_domains_range: Tuple[int, int] = (2, 5),
        positive_rate: float = 0.5,
        num_motifs: int = 15,
        seed: int = 42
    ):
        self.num_bags = num_bags
        self.seq_len_range = seq_len_range
        self.domain_len_range = domain_len_range
        self.num_domains_range = num_domains_range
        self.positive_rate = positive_rate
        self.num_motifs = num_motifs
        self.rng = random.Random(seed)
        self.amino_acids = "ACDEFGHIKLMNPQRSTVWY"
        # Motifs are drawn right after seeding, before any sequence is generated.
        self.motifs = self._generate_motifs(self.num_motifs)
        self.bags = []

    def _generate_motifs(self, n: int) -> List[str]:
        """Return ``n`` random motifs, each 3 to 5 residues long."""
        return [
            ''.join(self.rng.choices(self.amino_acids, k=self.rng.randint(3, 5)))
            for _ in range(n)
        ]

    def _generate_sequence(self) -> str:
        """Return a random sequence whose length is drawn from ``seq_len_range``."""
        length = self.rng.randint(*self.seq_len_range)
        return ''.join(self.rng.choices(self.amino_acids, k=length))

    def _sample_domains(self, seq: str) -> List[Tuple[int, int]]:
        """Place up to the requested number of non-overlapping (start, end) spans.

        Placement is by rejection sampling with at most 20 tries in total, so
        the returned list may be shorter than requested (possibly empty).
        """
        n_domains = self.rng.randint(*self.num_domains_range)
        # Clamp at 0: for a sequence shorter than the longest allowed domain the
        # unclamped bound is negative and randint would raise ValueError. Such
        # candidates are instead filtered by the `end <= len(seq)` check below.
        max_start = max(0, len(seq) - self.domain_len_range[1])
        domains = []
        attempts = 0
        max_attempts = 20
        while len(domains) < n_domains and attempts < max_attempts:
            start = self.rng.randint(0, max_start)
            end = start + self.rng.randint(*self.domain_len_range)
            fits = end <= len(seq)
            # Half-open spans: touching at an endpoint does not count as overlap.
            if fits and all(end <= s or start >= e for (s, e) in domains):
                domains.append((start, end))
            attempts += 1
        return domains

    def _insert_motif(self, seq: str, start: int, motif: str) -> Tuple[str, Tuple[int, int]]:
        """Overwrite ``seq`` with ``motif`` at ``start``.

        Returns the new sequence and the span covered by the motif, which the
        caller stores as the domain's new (shorter) span.
        """
        new_seq = seq[:start] + motif + seq[start + len(motif):]
        return new_seq, (start, start + len(motif))

    def _create_bag(self, pid: int) -> Dict:
        """Generate one bag: two sequences, their domains, and labelled instances."""
        seqA = self._generate_sequence()
        seqB = self._generate_sequence()
        domA = self._sample_domains(seqA)
        domB = self._sample_domains(seqB)

        all_pairs = [(i, j) for i in range(len(domA)) for j in range(len(domB))]
        # Always draw the positive/negative coin so the RNG stream does not
        # depend on whether any pairs exist.
        has_positive = self.rng.random() < self.positive_rate
        if has_positive and all_pairs:
            key_pair_count = self.rng.randint(1, min(3, len(all_pairs)))
            key_pairs = self.rng.sample(all_pairs, k=key_pair_count)
        else:
            # Without candidate pairs a bag cannot be positive; previously this
            # case reached randint(1, 0) and raised ValueError.
            key_pairs = []

        # Write the same motif into both domains of each key pair. With at most
        # three key pairs only the first three motifs are ever used.
        # NOTE: when one domain belongs to several key pairs, each later motif
        # is written at the same start and overwrites the earlier one.
        for idx, (i, j) in enumerate(key_pairs):
            motif = self.motifs[idx % len(self.motifs)]

            startA, _ = domA[i]
            seqA, domA[i] = self._insert_motif(seqA, startA, motif)

            startB, _ = domB[j]
            seqB, domB[j] = self._insert_motif(seqB, startB, motif)

        key_lookup = set(key_pairs)
        # all_pairs is already in (i outer, j inner) order, matching instance order.
        instances = [
            {
                "domainA": domA[i],
                "domainB": domB[j],
                "label": int((i, j) in key_lookup)
            }
            for (i, j) in all_pairs
        ]

        return {
            "proteinA": f"ProtA_{pid}",
            "proteinB": f"ProtB_{pid}",
            "sequenceA": seqA,
            "sequenceB": seqB,
            "domainsA": domA,
            "domainsB": domB,
            "instances": instances
        }

    def build(self) -> List[Dict]:
        """Generate ``num_bags`` bags, store them on ``self.bags`` and return them."""
        self.bags = [self._create_bag(i) for i in range(self.num_bags)]
        return self.bags

In [126]:
import numpy as np
from collections import Counter
from typing import List, Dict

class DomainEncoder:
    """Encode a sequence as normalized counts over a fitted k-mer vocabulary."""

    def __init__(self, k: int = 3, vocab_size: int = 100):
        """
        Initialize the encoder.
        Args:
            k: k-mer size
            vocab_size: number of most frequent k-mers to keep
        """
        self.k = k
        self.vocab_size = vocab_size
        self.vocab: List[str] = []
        self.fitted = False

    def _get_kmers(self, sequence: str) -> List[str]:
        """Return every overlapping k-mer of ``sequence`` (empty if it is shorter than k)."""
        return [sequence[i:i+self.k] for i in range(len(sequence) - self.k + 1)]

    def fit(self, domain_sequences: List[str]) -> "DomainEncoder":
        """Build k-mer vocabulary from a list of domain sequences.

        Keeps the ``vocab_size`` most frequent k-mers and returns ``self`` so
        that construction and fitting can be chained.
        """
        kmer_counts = Counter()
        for seq in domain_sequences:
            kmer_counts.update(self._get_kmers(seq))
        self.vocab = [kmer for kmer, _ in kmer_counts.most_common(self.vocab_size)]
        self.fitted = True
        return self

    def encode(self, sequence: str) -> np.ndarray:
        """Encode a domain sequence using the fitted k-mer vocabulary.

        Returns a float32 vector with one entry per vocabulary k-mer: its count
        in ``sequence`` divided by the total number of k-mers in ``sequence``.
        A sequence shorter than k yields an all-zero vector.

        Raises:
            RuntimeError: if ``fit`` has not been called.
        """
        if not self.fitted:
            raise RuntimeError("DomainEncoder must be fitted before encoding.")
        vec = np.zeros(len(self.vocab), dtype=np.float32)
        kmers = self._get_kmers(sequence)
        counts = Counter(kmers)
        for i, kmer in enumerate(self.vocab):
            vec[i] = counts[kmer]
        if len(kmers) > 0:
            vec /= len(kmers)  # normalize by all k-mers, including out-of-vocabulary ones
        return vec

    def encode_many(self, sequences: List[str]) -> np.ndarray:
        """Batch encode a list of domain sequences into a (n, vocab) float32 array.

        An empty input yields an array of shape (0, len(vocab)); np.stack
        would raise on an empty list.

        Raises:
            RuntimeError: if ``fit`` has not been called.
        """
        if not self.fitted:
            raise RuntimeError("DomainEncoder must be fitted before encoding.")
        if len(sequences) == 0:
            return np.zeros((0, len(self.vocab)), dtype=np.float32)
        return np.stack([self.encode(seq) for seq in sequences])

In [127]:
import numpy as np
from typing import List, Tuple, Dict

def prepare_mil_data(
    bags: List[Dict],
    encoder: DomainEncoder,
) -> Tuple[List[np.ndarray], List[int], List[List[int]]]:
    """
    Turn synthetic bag dictionaries into arrays suitable for MIL models.

    For every bag, each domain of protein A and protein B is encoded once with
    ``encoder.encode``; each instance is then the concatenation of its
    A-domain vector and its B-domain vector.

    Returns:
        X_bags: one 2D array per bag, shape (num_instances, 2 * encoding_dim)
        y_bags: one 0/1 label per bag (1 if any instance label is truthy)
        key_instance_indices: per bag, the positions of instances labelled 1
    """
    X_bags = []
    y_bags = []
    key_instance_indices = []

    for bag in bags:
        seq_a = bag["sequenceA"]
        seq_b = bag["sequenceB"]
        spans_a = bag["domainsA"]
        spans_b = bag["domainsB"]
        bag_instances = bag["instances"]

        # Each domain is encoded a single time and reused by all its instances.
        feats_a = [encoder.encode(seq_a[lo:hi]) for (lo, hi) in spans_a]
        feats_b = [encoder.encode(seq_b[lo:hi]) for (lo, hi) in spans_b]

        rows = []
        inst_labels = []
        for inst in bag_instances:
            # Instances store spans, so map each span back to its position.
            a_pos = spans_a.index(inst["domainA"])
            b_pos = spans_b.index(inst["domainB"])
            rows.append(np.concatenate([feats_a[a_pos], feats_b[b_pos]]))
            inst_labels.append(inst["label"])

        positives = [pos for pos, lab in enumerate(inst_labels) if lab == 1]

        X_bags.append(np.vstack(rows))
        y_bags.append(int(any(inst_labels)))
        key_instance_indices.append(positives)

    return X_bags, y_bags, key_instance_indices


class OneHotDomainEncoder:
    """Encode a sequence as a pooled one-hot vector over the 20 amino acids.

    No fitting is required. Characters outside ``amino_acids`` contribute an
    all-zero row to the one-hot matrix.
    """

    def __init__(self):
        self.amino_acids = "ACDEFGHIKLMNPQRSTVWY"
        self.aa_to_idx = {aa: i for i, aa in enumerate(self.amino_acids)}
        self.dim = len(self.amino_acids)

    def one_hot(self, seq: str) -> np.ndarray:
        """Return a (len(seq), dim) float32 matrix with one row per residue."""
        mat = np.zeros((len(seq), self.dim), dtype=np.float32)
        for i, aa in enumerate(seq):
            if aa in self.aa_to_idx:
                mat[i, self.aa_to_idx[aa]] = 1.0
        return mat

    def encode(self, seq: str, pooling: str = "max") -> np.ndarray:
        """Pool the one-hot matrix of ``seq`` over positions into a (dim,) float32 vector.

        Args:
            seq: amino-acid sequence; an empty string yields an all-zero vector.
            pooling: "max" (presence/absence per residue type) or "mean"
                (fraction of positions per residue type).

        Raises:
            ValueError: if ``pooling`` is neither "mean" nor "max".
        """
        # Validate first so an invalid option is rejected for empty input too.
        if pooling not in ("mean", "max"):
            raise ValueError("Pooling must be 'mean' or 'max'")
        oh = self.one_hot(seq)
        if len(oh) == 0:
            # Same dtype as the pooled paths, so concatenated features stay float32.
            return np.zeros(self.dim, dtype=np.float32)
        if pooling == "mean":
            return oh.mean(axis=0)
        return oh.max(axis=0)


In [128]:
# Generate 1000 synthetic protein-pair bags; every other setting, including
# the seed, uses the builder's defaults.
builder = SyntheticPPIBuilder(num_bags=1000)
bags = builder.build()
# Show which fields each bag dictionary carries.
print(bags[0].keys())

dict_keys(['proteinA', 'proteinB', 'sequenceA', 'sequenceB', 'domainsA', 'domainsB', 'instances'])


In [129]:
# Collect the amino-acid substring of every domain in every bag (all A-side
# domains first, then all B-side domains). This corpus is only consumed by the
# k-mer DomainEncoder.fit call that is commented out below; the one-hot
# encoder currently in use needs no fitting.
all_domains = [
    bag["sequenceA"][start:end]
    for bag in bags
    for (start, end) in bag["domainsA"]
] + [
    bag["sequenceB"][start:end]
    for bag in bags
    for (start, end) in bag["domainsB"]
]

# Alternative encoder (disabled): k-mer frequency features fitted on all_domains.
# encoder = DomainEncoder(k=3, vocab_size=1000)
# encoder.fit(all_domains)


# Active encoder: pooled one-hot amino-acid composition per domain.
encoder = OneHotDomainEncoder()

In [130]:
# Encode from builder.bags (the raw bag dictionaries kept by the builder)
# rather than from `bags`: this cell rebinds `bags` to the encoded arrays, so
# reading the name it overwrites would make a second run of the cell fail.
bags, labels, key_pos = prepare_mil_data(builder.bags, encoder)

In [131]:
# Sanity check: after encoding, bags[0] is the first bag's feature matrix of
# shape (num_instances, 2 * encoding_dim).
bags[0].shape

(12, 40)

In [132]:
import pandas as pd

from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

from sklearn.ensemble import RandomForestRegressor

from qsarmil.mil.wrapper import InstanceWrapper, BagWrapper
from qsarmil.mil.network.regressor import InstanceNetworkRegressor, BagNetworkRegressor

from qsarmil.mil.network.classifier import (AttentionNetworkClassifier,
                                            GatedAttentionNetworkClassifier,
                                            SelfAttentionNetworkClassifier,
                                            TempAttentionNetworkClassifier,
                                            GaussianPoolingNetworkClassifier,
                                            DynamicPoolingNetworkClassifier)

from qsarmil.mil.preprocessing import BagMinMaxScaler

In [133]:
from typing import List

def kid_top1(
    key_indices: List[List[int]],
    predictions: List[List[float]]
) -> float:
    """
    Top-1 key instance detection accuracy, allowing several key instances per bag.

    A bag counts as a hit when its highest-scoring instance (first one on
    ties) is one of that bag's true key instances. Bags with no key instances
    are left out of both numerator and denominator.

    Args:
        key_indices: For each bag, the indices of its true key instances.
        predictions: For each bag, one predicted score per instance.

    Returns:
        hits / number of bags that have key instances, or 0.0 if there are none.
    """
    assert len(key_indices) == len(predictions)

    hits = 0
    evaluated = 0

    for true_keys, bag_scores in zip(key_indices, predictions):
        if true_keys:
            evaluated += 1
            # Position of the largest score.
            best_pos = max(range(len(bag_scores)), key=bag_scores.__getitem__)
            hits += 1 if best_pos in true_keys else 0

    if evaluated > 0:
        return hits / evaluated
    return 0.0


In [134]:
# Keyword arguments passed unchanged to every network classifier built below.
network_hparams = dict(
    hidden_layer_sizes=(256, 128, 64),
    num_epoch=300,
    batch_size=128,
    learning_rate=0.001,
    weight_decay=0.001,
    instance_weight_dropout=0.01,
    init_cuda=False,
    verbose=False,
)

In [135]:
# Split bags, bag labels and key-instance indices with one shared shuffle.
# random_state pins that shuffle so the same bags land in train/test on every
# run; without it each execution evaluates on a different test set.
x_train, x_test, y_train, y_test, key_train, key_test = train_test_split(
    bags, labels, key_pos, random_state=42
)

In [159]:
# Train a single gated-attention classifier on the training bags using the
# shared hyperparameters. The fit call is the cell's last expression, so its
# return value is what the notebook displays.
model = GatedAttentionNetworkClassifier(**network_hparams)
model.fit(x_train, y_train)

GatedAttentionNetworkClassifier(
  (main_net): Sequential(
    (0): Linear(in_features=40, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=64, bias=True)
    (5): ReLU()
  )
  (attention_V): Sequential(
    (0): Linear(in_features=64, out_features=128, bias=True)
    (1): Tanh()
  )
  (attention_U): Sequential(
    (0): Linear(in_features=64, out_features=128, bias=True)
    (1): Sigmoid()
  )
  (detector): Linear(in_features=128, out_features=1, bias=True)
  (estimator): Linear(in_features=64, out_features=1, bias=True)
)

In [160]:
# Bag-level scores on the held-out split, thresholded at 0.5 into hard 0/1
# labels (presumably predict returns probabilities — confirm against qsarmil).
y_prob = model.predict(x_test)
y_pred = np.where(y_prob > 0.5, 1, 0)
# Per-instance weights for each test bag; used as instance scores by kid_top1.
w_pred = model.get_instance_weights(x_test)

In [161]:
# Bag-level accuracy on the training split, using the same 0.5 threshold.
accuracy_score(y_train, np.where(model.predict(x_train) > 0.5, 1, 0))

1.0

In [162]:
# Bag-level accuracy on the held-out test split.
accuracy_score(y_test, y_pred)

1.0

In [163]:
# Fraction of test bags with key instances whose highest-weighted instance is
# one of the true key instances.
kid_top1(key_test, w_pred)

0.9366197183098591

In [184]:
# Index of the test bag inspected in the following cells.
N = 6

In [185]:
# For test bag N: true bag label, predicted bag label, true key-instance indices.
print(y_test[N])
print(y_pred[N].item())
print(list(key_test[N]))

1
1
[5, 6]


In [186]:
# Instance weights predicted for test bag N, rounded for readability.
w_pred[N].round(2)

array([0.01, 0.04, 0.02, 0.01, 0.06, 0.47, 0.27, 0.02, 0.01, 0.05, 0.02,
       0.01], dtype=float32)

In [144]:
# (display name, untrained model) pairs — one fresh instance per architecture,
# all constructed with the same network_hparams.
network_list = [
    (label, net_cls(**network_hparams))
    for label, net_cls in (
        ("AttentionNetworkClassifier", AttentionNetworkClassifier),
        ("GatedAttentionNetworkClassifier", GatedAttentionNetworkClassifier),
        ("SelfAttentionNetworkClassifier", SelfAttentionNetworkClassifier),
        ("TemperatureAttentionNetworkClassifier", TempAttentionNetworkClassifier),
        ("GaussianPoolingNetworkClassifier", GaussianPoolingNetworkClassifier),
        ("DynamicPoolingNetworkClassifier", DynamicPoolingNetworkClassifier),
    )
]

In [145]:
# Train and evaluate every architecture on the same split. One record is
# collected per model and the results table is built once at the end, instead
# of growing an empty DataFrame cell-by-cell through .loc.
metric_rows = []
for name, model in network_list:
    # train model
    model.fit(x_train, y_train)
    # predict bag labels (0.5 threshold) and per-instance weights
    y_prob = model.predict(x_test)
    y_pred = np.where(y_prob > 0.5, 1, 0)
    w_pred = model.get_instance_weights(x_test)
    # PRED_ACC: bag-level accuracy; KID_ACC: top-1 key instance detection accuracy
    metric_rows.append({
        "PRED_ACC": accuracy_score(y_test, y_pred),
        "KID_ACC": kid_top1(key_test, w_pred),
    })

res = pd.DataFrame(
    metric_rows,
    index=[label for label, _net in network_list],
    columns=["PRED_ACC", "KID_ACC"],
)

In [146]:
# Per-architecture comparison: bag accuracy and key-instance detection accuracy.
res.round(2)

Unnamed: 0,PRED_ACC,KID_ACC
AttentionNetworkClassifier,1.0,0.7
GatedAttentionNetworkClassifier,1.0,0.94
SelfAttentionNetworkClassifier,1.0,0.04
TemperatureAttentionNetworkClassifier,1.0,0.87
GaussianPoolingNetworkClassifier,0.57,0.62
DynamicPoolingNetworkClassifier,1.0,0.94
