In [1]:
import numpy as np
import tensorflow as tf

In [2]:
# Constants
AA = "ACDEFGHIKLMNPQRSTVWY-"
AA_TO_INT = {a: i for i, a in enumerate(AA)}
UNK_IDX = 20  # Index for "unknown"
MASK_TOKEN = -1.0
NORM_TOKEN = 1.0
PAD_TOKEN = -2.0
PAD_VALUE = 0.0
MASK_VALUE = 0.0

def seq_to_onehot(sequence: str, max_seq_len: int) -> np.ndarray:
    """Convert peptide sequence to one-hot encoding"""
    arr = np.full((max_seq_len, 21), PAD_VALUE, dtype=np.float32) # initialize padding with 0
    for j, aa in enumerate(sequence.upper()[:max_seq_len]):
        arr[j, AA_TO_INT.get(aa, UNK_IDX)] = 1.0
        # print number of UNKs in the sequence
    # num_unks = np.sum(arr[:, UNK_IDX])
    # zero out gaps
    arr[:, AA_TO_INT['-']] = PAD_VALUE  # Set gaps to PAD_VALUE
    # if num_unks > 0:
    #     print(f"Warning: {num_unks} unknown amino acids in sequence '{sequence}'")
    return arr


def OHE_to_seq(ohe: np.ndarray, gap: bool = False) -> list:
    """
    Convert a one-hot encoded matrix back to a peptide sequence.
    # (B, max_pep_len, 21) -> (B, max_pep_len)
    Args:
        ohe: One-hot encoded matrix of shape (B, N, 21).
    Returns:
        sequence: Peptide sequence as a string. (B,)
    """
    sequence = []
    for i in range(ohe.shape[0]):  # Iterate over batch dimension
        seq = []
        for j in range(ohe.shape[1]):  # Iterate over sequence length
            # Use a small tolerance for floating point comparison
            if gap and np.all(np.isclose(ohe[i, j], 0)):
                seq.append('-')
            else:
                aa_index = np.argmax(ohe[i, j])  # Get index of the max value in one-hot encoding
                if aa_index < len(AA):  # Check if it's a valid amino acid index
                    seq.append(AA[aa_index])
                else:
                    seq.append('X')  # Use 'X' for unknown amino acids
        sequence.append(''.join(seq))  # Join the list into a string
    return sequence  # Return list of sequences


def OHE_to_seq_single(ohe: np.ndarray, gap=False, AA_VOCAB="ACDEFGHIKLMNPQRSTVWY-", pad_idx=20) -> str:
    """
    Convert a one-hot encoded matrix back to a peptide sequence.
    Args:
        ohe: One-hot encoded matrix of shape (N, 21).
        gap: Whether to include gap characters.
        AA_VOCAB: The amino acid vocabulary used for one-hot encoding.
        pad_idx: The index used for padding.
    Returns:
        sequence: Peptide sequence as a string.
    """
    seq = []
    for j in range(ohe.shape[0]):  # Iterate over sequence length
        # Check if the sum of the vector is close to 0 (padding) or close to 1 (amino acid)
        if np.isclose(np.sum(ohe[j]), 0) and not gap:
             continue # Skip padding characters if gap is False
        elif gap and np.all(np.isclose(ohe[j], 0)):
            seq.append('-')
        elif np.isclose(np.sum(ohe[j]), 1.0):
            aa_index = np.argmax(ohe[j])  # Get index of the max value in one-hot encoding
            if aa_index < len(AA_VOCAB):  # Check if it's a valid amino acid index
                seq.append(AA_VOCAB[aa_index])
            else:
                seq.append('X')  # Use 'X' for unknown amino acids
        else:
             # Handle cases where the vector is not clearly padding or a one-hot encoding
             # This might indicate an issue with the input data or the one-hot encoding process
             seq.append('?') # Use '?' to indicate an ambiguous case
    return ''.join(seq)  # Join the list into a string


def peptides_to_onehot_kmer_windows(seq, max_seq_len, k=9) -> np.ndarray:
    """
    Converts a peptide sequence into a sliding window of k-mers, one-hot encoded.
    Output shape: (RF, k, 21), where RF = max_seq_len - k + 1
    """
    RF = max_seq_len - k + 1
    RFs = np.zeros((RF, k, 21), dtype=np.float32)
    for window in range(RF):
        if window + k <= len(seq):
            kmer = seq[window:window + k]
            for i, aa in enumerate(kmer):
                idx = AA_TO_INT.get(aa, UNK_IDX)
                RFs[window, i, idx] = 1.0
            # Pad remaining positions in k-mer if sequence is too short
            for i in range(len(kmer), k):
                RFs[window, i, UNK_IDX] = 1.0
        else:
            # Entire k-mer is padding if out of sequence
            RFs[window, :, UNK_IDX] = 1.0
    return np.array(RFs)

In [9]:
# generate dummy data
# generate a data that is pep one hot encoding with N=14, d=21.
# mhc_emb with N= 312 and d =1152. a contact map that is a soft weights over 312 mhc positions
# Create dummy data
num_samples = 1000
# Peptide One-Hot Encoding (N=14, d=21)
pep_len = 14
pep_alphabet_size = 21
pep_indices = np.random.randint(0, pep_alphabet_size, size=(num_samples, pep_len))
pep_OHE = tf.one_hot(pep_indices, depth=pep_alphabet_size, dtype=tf.float32)
pep_seq = [OHE_to_seq_single(pep_OHE[i].numpy()) for i in range(num_samples)]
pep_RFs = np.array([peptides_to_onehot_kmer_windows(pep_seq[i], max_seq_len=pep_len, k=9) for i in range(num_samples)], dtype=np.float32)
pep_RFs_seq = [
    [OHE_to_seq_single(pep_RFs[i][j], gap=True) for j in range(pep_RFs.shape[1])]
    for i in range(num_samples)
]

In [10]:
print(pep_RFs_seq)

[['YGLYAVTG-', 'GLYAVTG-A', 'LYAVTG-AA', 'YAVTG-AAE', 'AVTG-AAER', 'VTG-AAERY'], ['QSIFGE-F-', 'SIFGE-F-W', 'IFGE-F-WV', 'FGE-F-WVA', 'GE-F-WVAR', 'E-F-WVARA'], ['GGCKKMYKR', 'GCKKMYKRE', 'CKKMYKREK', 'KKMYKREKR', 'KMYKREKRE', 'MYKREKREP'], ['EVHDPFSKC', 'VHDPFSKCL', 'HDPFSKCLI', 'DPFSKCLI-', 'PFSKCLI-A', 'FSKCLI-AK'], ['DDK-TH-GG', 'DK-TH-GGW', 'K-TH-GGWP', '-TH-GGWPW', 'TH-GGWPWP', 'H-GGWPWPF'], ['SQLVHPVAC', 'QLVHPVACR', 'LVHPVACRA', 'VHPVACRAR', 'HPVACRARQ', 'PVACRARQW'], ['PSKMEFHIS', 'SKMEFHIST', 'KMEFHISTG', 'MEFHISTGP', 'EFHISTGPA', 'FHISTGPAI'], ['WHLMETSDI', 'HLMETSDIE', 'LMETSDIET', 'METSDIETI', 'ETSDIETIH', 'TSDIETIHV'], ['FKSKLYTNP', 'KSKLYTNPI', 'SKLYTNPIT', 'KLYTNPITQ', 'LYTNPITQY', 'YTNPITQYD'], ['AEIDDSGQG', 'EIDDSGQG-', 'IDDSGQG-H', 'DDSGQG-HC', 'DSGQG-HCC', 'SGQG-HCCT'], ['HTFVKRLMQ', 'TFVKRLMQV', 'FVKRLMQVF', 'VKRLMQVFW', 'KRLMQVFWN', 'RLMQVFWNG'], ['LWCHRDRRA', 'WCHRDRRAT', 'CHRDRRATR', 'HRDRRATRQ', 'RDRRATRQK', 'DRRATRQKV'], ['NTPEIPKVL', 'TPEIPKVLS', 'PEIPKVLSC',

In [11]:
import tensorflow as tf

class OHE_KmerWindows(tf.keras.layers.Layer):
    """
    A TensorFlow layer that converts a batch of one-hot encoded sequences
    into one-hot encoded k-mer windows.

    This layer takes a 3D tensor of one-hot encoded sequences and produces a
    4D tensor of one-hot encoded sliding windows. It is a more direct alternative
    to the string-based version when data is already pre-processed.
    """
    def __init__(self, max_seq_len: int, k: int = 9, alphabet_size: int = 21, **kwargs):
        """
        Initializes the layer.

        Args:
            max_seq_len (int): The length of the input sequences.
            k (int, optional): The size of the k-mer window. Defaults to 9.
            alphabet_size (int, optional): The depth of the one-hot encoding
                                         (e.g., 21 for amino acids + unk). Defaults to 21.
        """
        super().__init__(**kwargs)
        self.max_seq_len = max_seq_len
        self.k = k
        self.alphabet_size = alphabet_size
        self.rf = max_seq_len - k + 1  # Number of receptive fields (windows)

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """
        Processes the input tensor of one-hot encoded sequences.

        Args:
            inputs (tf.Tensor): A 3D tensor of one-hot encoded sequences.
                                Shape: (batch_size, max_seq_len, alphabet_size)

        Returns:
            tf.Tensor: The one-hot encoded k-mer windows.
                       Shape: (batch_size, self.rf, self.k, self.alphabet_size)
        """
        # --- Input Validation (Optional but Recommended) ---
        input_shape = tf.shape(inputs)
        tf.debugging.assert_equal(input_shape[1], self.max_seq_len,
            message=f"Input sequence length must be {self.max_seq_len}")
        tf.debugging.assert_equal(input_shape[2], self.alphabet_size,
            message=f"Input alphabet size must be {self.alphabet_size}")

        # 1. Reshape the input to be compatible with `extract_patches`
        # We treat the sequence as a 1D "image" with `alphabet_size` channels.
        # Shape: (batch, max_seq_len, alphabet_size) -> (batch, max_seq_len, 1, alphabet_size)
        images = tf.expand_dims(inputs, axis=2)

        # 2. Extract sliding windows (patches) of size k
        # We slide a window of size `k` along the `max_seq_len` dimension.
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.k, 1, 1],      # Window size: (batch, height, width, channels)
            strides=[1, 1, 1, 1],    # Slide one step at a time
            rates=[1, 1, 1, 1],
            padding='VALID'          # 'VALID' ensures we only take full windows
        )
        # The output shape of extract_patches is (batch, num_windows_h, num_windows_w, k * 1 * alphabet_size)
        # In our case: (batch_size, self.rf, 1, k * self.alphabet_size)

        # 3. Reshape the patches into the desired final format
        # Shape: (batch_size, self.rf, 1, k * alphabet_size) -> (batch_size, self.rf, k, alphabet_size)
        kmer_windows = tf.reshape(
            patches,
            [-1, self.rf, self.k, self.alphabet_size]
        )

        return kmer_windows

    def get_config(self):
        """Enables layer serialization."""
        config = super().get_config()
        config.update({
            "max_seq_len": self.max_seq_len,
            "k": self.k,
            "alphabet_size": self.alphabet_size,
        })
        return config

In [12]:
pep_seq_tf = tf.constant(pep_seq)
# print(pep_seq_tf)
print(pep_len)
pep_ohe_tf = tf.constant(pep_OHE)
print(pep_ohe_tf.shape)
pep_RFs_tf = OHE_KmerWindows(max_seq_len=pep_len, k=9)(pep_ohe_tf)
print(pep_RFs_tf.shape)
a_rf = pep_RFs_tf


14
(1000, 14, 21)
(1000, 6, 9, 21)


In [13]:
# get seqs of a_rf
a_rf_seq = [
    [OHE_to_seq_single(a_rf[i][j], gap=True) for j in range(a_rf.shape[1])]
    for i in range(num_samples)
]
print(a_rf_seq)

[['YGLYAVTG-', 'GLYAVTG-A', 'LYAVTG-AA', 'YAVTG-AAE', 'AVTG-AAER', 'VTG-AAERY'], ['QSIFGE-F-', 'SIFGE-F-W', 'IFGE-F-WV', 'FGE-F-WVA', 'GE-F-WVAR', 'E-F-WVARA'], ['GGCKKMYKR', 'GCKKMYKRE', 'CKKMYKREK', 'KKMYKREKR', 'KMYKREKRE', 'MYKREKREP'], ['EVHDPFSKC', 'VHDPFSKCL', 'HDPFSKCLI', 'DPFSKCLI-', 'PFSKCLI-A', 'FSKCLI-AK'], ['DDK-TH-GG', 'DK-TH-GGW', 'K-TH-GGWP', '-TH-GGWPW', 'TH-GGWPWP', 'H-GGWPWPF'], ['SQLVHPVAC', 'QLVHPVACR', 'LVHPVACRA', 'VHPVACRAR', 'HPVACRARQ', 'PVACRARQW'], ['PSKMEFHIS', 'SKMEFHIST', 'KMEFHISTG', 'MEFHISTGP', 'EFHISTGPA', 'FHISTGPAI'], ['WHLMETSDI', 'HLMETSDIE', 'LMETSDIET', 'METSDIETI', 'ETSDIETIH', 'TSDIETIHV'], ['FKSKLYTNP', 'KSKLYTNPI', 'SKLYTNPIT', 'KLYTNPITQ', 'LYTNPITQY', 'YTNPITQYD'], ['AEIDDSGQG', 'EIDDSGQG-', 'IDDSGQG-H', 'DDSGQG-HC', 'DSGQG-HCC', 'SGQG-HCCT'], ['HTFVKRLMQ', 'TFVKRLMQV', 'FVKRLMQVF', 'VKRLMQVFW', 'KRLMQVFWN', 'RLMQVFWNG'], ['LWCHRDRRA', 'WCHRDRRAT', 'CHRDRRATR', 'HRDRRATRQ', 'RDRRATRQK', 'DRRATRQKV'], ['NTPEIPKVL', 'TPEIPKVLS', 'PEIPKVLSC',

In [14]:
# compare the long_mer vs windows by printing them in pairs
for i in range(num_samples):
    print(pep_seq[i])
    print(pep_RFs_seq[i])
    print(a_rf_seq[i])

YGLYAVTG-AAERY
['YGLYAVTG-', 'GLYAVTG-A', 'LYAVTG-AA', 'YAVTG-AAE', 'AVTG-AAER', 'VTG-AAERY']
['YGLYAVTG-', 'GLYAVTG-A', 'LYAVTG-AA', 'YAVTG-AAE', 'AVTG-AAER', 'VTG-AAERY']
QSIFGE-F-WVARA
['QSIFGE-F-', 'SIFGE-F-W', 'IFGE-F-WV', 'FGE-F-WVA', 'GE-F-WVAR', 'E-F-WVARA']
['QSIFGE-F-', 'SIFGE-F-W', 'IFGE-F-WV', 'FGE-F-WVA', 'GE-F-WVAR', 'E-F-WVARA']
GGCKKMYKREKREP
['GGCKKMYKR', 'GCKKMYKRE', 'CKKMYKREK', 'KKMYKREKR', 'KMYKREKRE', 'MYKREKREP']
['GGCKKMYKR', 'GCKKMYKRE', 'CKKMYKREK', 'KKMYKREKR', 'KMYKREKRE', 'MYKREKREP']
EVHDPFSKCLI-AK
['EVHDPFSKC', 'VHDPFSKCL', 'HDPFSKCLI', 'DPFSKCLI-', 'PFSKCLI-A', 'FSKCLI-AK']
['EVHDPFSKC', 'VHDPFSKCL', 'HDPFSKCLI', 'DPFSKCLI-', 'PFSKCLI-A', 'FSKCLI-AK']
DDK-TH-GGWPWPF
['DDK-TH-GG', 'DK-TH-GGW', 'K-TH-GGWP', '-TH-GGWPW', 'TH-GGWPWP', 'H-GGWPWPF']
['DDK-TH-GG', 'DK-TH-GGW', 'K-TH-GGWP', '-TH-GGWPW', 'TH-GGWPWP', 'H-GGWPWPF']
SQLVHPVACRARQW
['SQLVHPVAC', 'QLVHPVACR', 'LVHPVACRA', 'VHPVACRAR', 'HPVACRARQ', 'PVACRARQW']
['SQLVHPVAC', 'QLVHPVACR', 'LVHPVACRA', '