In [1]:
import numpy as np
import tensorflow as tf

In [2]:
# Constants
AA = "ACDEFGHIKLMNPQRSTVWY-"
AA_TO_INT = {a: i for i, a in enumerate(AA)}
UNK_IDX = 20  # Index for "unknown"
MASK_TOKEN = -1.0
NORM_TOKEN = 1.0
PAD_TOKEN = -2.0
PAD_VALUE = 0.0
MASK_VALUE = 0.0

def seq_to_onehot(sequence: str, max_seq_len: int) -> np.ndarray:
    """Convert peptide sequence to one-hot encoding"""
    arr = np.full((max_seq_len, 21), PAD_VALUE, dtype=np.float32) # initialize padding with 0
    for j, aa in enumerate(sequence.upper()[:max_seq_len]):
        arr[j, AA_TO_INT.get(aa, UNK_IDX)] = 1.0
        # print number of UNKs in the sequence
    # num_unks = np.sum(arr[:, UNK_IDX])
    # zero out gaps
    arr[:, AA_TO_INT['-']] = PAD_VALUE  # Set gaps to PAD_VALUE
    # if num_unks > 0:
    #     print(f"Warning: {num_unks} unknown amino acids in sequence '{sequence}'")
    return arr


def OHE_to_seq(ohe: np.ndarray, gap: bool = False) -> list:
    """
    Convert a one-hot encoded matrix back to a peptide sequence.
    # (B, max_pep_len, 21) -> (B, max_pep_len)
    Args:
        ohe: One-hot encoded matrix of shape (B, N, 21).
    Returns:
        sequence: Peptide sequence as a string. (B,)
    """
    sequence = []
    for i in range(ohe.shape[0]):  # Iterate over batch dimension
        seq = []
        for j in range(ohe.shape[1]):  # Iterate over sequence length
            # Use a small tolerance for floating point comparison
            if gap and np.all(np.isclose(ohe[i, j], 0)):
                seq.append('-')
            else:
                aa_index = np.argmax(ohe[i, j])  # Get index of the max value in one-hot encoding
                if aa_index < len(AA):  # Check if it's a valid amino acid index
                    seq.append(AA[aa_index])
                else:
                    seq.append('X')  # Use 'X' for unknown amino acids
        sequence.append(''.join(seq))  # Join the list into a string
    return sequence  # Return list of sequences


def OHE_to_seq_single(ohe: np.ndarray, gap=False, AA_VOCAB="ACDEFGHIKLMNPQRSTVWY-", pad_idx=20) -> str:
    """
    Convert a one-hot encoded matrix back to a peptide sequence.
    Args:
        ohe: One-hot encoded matrix of shape (N, 21).
        gap: Whether to include gap characters.
        AA_VOCAB: The amino acid vocabulary used for one-hot encoding.
        pad_idx: The index used for padding.
    Returns:
        sequence: Peptide sequence as a string.
    """
    seq = []
    for j in range(ohe.shape[0]):  # Iterate over sequence length
        # Check if the sum of the vector is close to 0 (padding) or close to 1 (amino acid)
        if np.isclose(np.sum(ohe[j]), 0) and not gap:
             continue # Skip padding characters if gap is False
        elif gap and np.all(np.isclose(ohe[j], 0)):
            seq.append('-')
        elif np.isclose(np.sum(ohe[j]), 1.0):
            aa_index = np.argmax(ohe[j])  # Get index of the max value in one-hot encoding
            if aa_index < len(AA_VOCAB):  # Check if it's a valid amino acid index
                seq.append(AA_VOCAB[aa_index])
            else:
                seq.append('X')  # Use 'X' for unknown amino acids
        else:
             # Handle cases where the vector is not clearly padding or a one-hot encoding
             # This might indicate an issue with the input data or the one-hot encoding process
             seq.append('?') # Use '?' to indicate an ambiguous case
    return ''.join(seq)  # Join the list into a string


def peptides_to_onehot_kmer_windows(seq, max_seq_len, k=9) -> np.ndarray:
    """
    Converts a peptide sequence into a sliding window of k-mers, one-hot encoded.
    Output shape: (RF, k, 21), where RF = max_seq_len - k + 1
    """
    RF = max_seq_len - k + 1
    RFs = np.zeros((RF, k, 21), dtype=np.float32)
    for window in range(RF):
        if window + k <= len(seq):
            kmer = seq[window:window + k]
            for i, aa in enumerate(kmer):
                idx = AA_TO_INT.get(aa, UNK_IDX)
                RFs[window, i, idx] = 1.0
            # Pad remaining positions in k-mer if sequence is too short
            for i in range(len(kmer), k):
                RFs[window, i, UNK_IDX] = 1.0
        else:
            # Entire k-mer is padding if out of sequence
            RFs[window, :, UNK_IDX] = 1.0
    return np.array(RFs)

In [13]:
# generate dummy data
# generate a data that is pep one hot encoding with N=14, d=21.
# mhc_emb with N= 312 and d =1152. a contact map that is a soft weights over 312 mhc positions
# Create dummy data
num_samples = 1000
# Peptide One-Hot Encoding (N=14, d=21)
pep_len = 14
pep_alphabet_size = 20
pep_indices = np.random.randint(0, pep_alphabet_size, size=(num_samples, pep_len))
pep_OHE = tf.one_hot(pep_indices, depth=pep_alphabet_size, dtype=tf.float32)
pep_seq = [OHE_to_seq_single(pep_OHE[i].numpy()) for i in range(num_samples)]
pep_RFs = np.array([peptides_to_onehot_kmer_windows(pep_seq[i], max_seq_len=pep_len, k=9) for i in range(num_samples)], dtype=np.float32)
pep_RFs_seq = [
    [OHE_to_seq_single(pep_RFs[i][j], gap=True) for j in range(pep_RFs.shape[1])]
    for i in range(num_samples)
]

In [14]:
print(pep_RFs_seq)

[['MECHIADVC', 'ECHIADVCI', 'CHIADVCIP', 'HIADVCIPF', 'IADVCIPFP', 'ADVCIPFPI'], ['WRWQPQWAW', 'RWQPQWAWW', 'WQPQWAWWM', 'QPQWAWWME', 'PQWAWWMEC', 'QWAWWMECC'], ['FLWGAWQAT', 'LWGAWQATS', 'WGAWQATSM', 'GAWQATSMS', 'AWQATSMSA', 'WQATSMSAT'], ['QDDMDQAML', 'DDMDQAMLF', 'DMDQAMLFH', 'MDQAMLFHN', 'DQAMLFHNP', 'QAMLFHNPN'], ['HPGKTQDRA', 'PGKTQDRAK', 'GKTQDRAKW', 'KTQDRAKWR', 'TQDRAKWRF', 'QDRAKWRFW'], ['KQYCANIMC', 'QYCANIMCY', 'YCANIMCYF', 'CANIMCYFN', 'ANIMCYFNS', 'NIMCYFNSD'], ['HFPICCFDL', 'FPICCFDLD', 'PICCFDLDN', 'ICCFDLDNM', 'CCFDLDNMK', 'CFDLDNMKR'], ['HPRKELAHR', 'PRKELAHRF', 'RKELAHRFY', 'KELAHRFYK', 'ELAHRFYKS', 'LAHRFYKSK'], ['ISAKNFNHC', 'SAKNFNHCA', 'AKNFNHCAG', 'KNFNHCAGG', 'NFNHCAGGQ', 'FNHCAGGQS'], ['MNEEAMLHP', 'NEEAMLHPV', 'EEAMLHPVH', 'EAMLHPVHD', 'AMLHPVHDG', 'MLHPVHDGC'], ['GITLVYSIM', 'ITLVYSIMP', 'TLVYSIMPT', 'LVYSIMPTN', 'VYSIMPTNT', 'YSIMPTNTV'], ['SNFEMKGIV', 'NFEMKGIVD', 'FEMKGIVDI', 'EMKGIVDIC', 'MKGIVDICL', 'KGIVDICLG'], ['EDVGHWHYM', 'DVGHWHYMC', 'VGHWHYMCF',

In [15]:
class OneHotKmerWindows(tf.keras.layers.Layer):
    """
    A TensorFlow layer that converts a batch of peptide sequences into
    one-hot encoded k-mer windows.

    This layer takes a tensor of strings as input and produces a 4D tensor
    of one-hot encoded sliding windows.
    """
    def __init__(self, max_seq_len: int, k: int = 9, **kwargs):
        """
        Initializes the layer.

        Args:
            max_seq_len (int): The maximum length to which sequences are padded or truncated.
            k (int, optional): The size of the k-mer window. Defaults to 9.
        """
        super().__init__(**kwargs)
        self.max_seq_len = max_seq_len
        self.k = k
        self.rf = max_seq_len - k + 1  # Number of receptive fields (windows)

        # --- Constants ---
        # Define the amino acid alphabet including the "unknown" token.
        self.aa_alphabet = tf.constant(list("ACDEFGHIKLMNPQRSTVWY-"), dtype=tf.string)
        self.unk_idx = 20  # The integer index for the unknown token.

        # --- Lookup Table ---
        # Create an efficient lookup table to map amino acid characters to integers.
        # Any character not in the alphabet will be mapped to `unk_idx`.
        keys = self.aa_alphabet
        values = tf.range(len(self.aa_alphabet), dtype=tf.int64)
        self.char_to_int_table = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(keys, values),
            default_value=self.unk_idx
        )

    def call(self, inputs: tf.Tensor) -> tf.Tensor:
        """
        Processes the input tensor of peptide strings.

        Args:
            inputs (tf.Tensor): A 1D tensor of strings (peptide sequences).
                                Shape: (batch_size,)

        Returns:
            tf.Tensor: The one-hot encoded k-mer windows.
                       Shape: (batch_size, self.rf, self.k, 21)
        """
        # 1. Split strings into characters
        # Input: ["PEPTIDE", "SEQ"] -> RaggedTensor([['P', 'E',...], ['S', 'E', 'Q']])
        sequences_ragged = tf.strings.bytes_split(inputs)

        # 2. Map characters to integers using the lookup table
        # RaggedTensor([['P', 'E',...], ...]) -> RaggedTensor([[11, 3,...], ...])
        int_sequences_ragged = self.char_to_int_table.lookup(sequences_ragged)

        # 3. Convert to a dense tensor and pad/truncate to `max_seq_len`
        # This ensures all sequences have a uniform length for batch processing.
        # We pad with the unknown index, which will be correctly one-hot encoded later.
        # Shape: (batch_size, max_seq_len)
        int_sequences_padded = int_sequences_ragged.to_tensor(
            default_value=self.unk_idx,
            shape=[None, self.max_seq_len]
        )

        # 4. Create sliding windows (k-mers) using `extract_patches`
        # This is an efficient, vectorized way to create sliding windows.
        # We temporarily add dimensions to treat the sequence like a 1D "image".
        # Shape before: (batch_size, max_seq_len)
        # Shape after reshape: (batch_size, 1, max_seq_len, 1)
        images = tf.reshape(int_sequences_padded, [-1, 1, self.max_seq_len, 1])

        # Extract patches of size k.
        # Shape: (batch_size, 1, self.rf, k)
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, 1, self.k, 1],
            strides=[1, 1, 1, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )

        # Reshape to the desired k-mer window format
        # Shape: (batch_size, self.rf, k)
        kmer_windows = tf.reshape(patches, [-1, self.rf, self.k])

        # 5. One-hot encode the integer k-mers
        # The depth is 21 to account for the 20 standard amino acids plus one unknown token.
        # Shape: (batch_size, self.rf, k, 21)
        one_hot_windows = tf.one_hot(kmer_windows, depth=21, dtype=tf.float32)

        return one_hot_windows

    def get_config(self):
        """Enables layer serialization."""
        config = super().get_config()
        config.update({
            "max_seq_len": self.max_seq_len,
            "k": self.k,
        })
        return config

In [16]:
pep_seq_tf = tf.constant(pep_seq)
print(pep_seq_tf)
print(pep_len)
pep_RFs_tf = OneHotKmerWindows(max_seq_len=pep_len, k=9)(pep_seq_tf)
print(pep_RFs_tf.shape)
a_rf = pep_RFs_tf


tf.Tensor(
[b'MECHIADVCIPFPI' b'WRWQPQWAWWMECC' b'FLWGAWQATSMSAT' b'QDDMDQAMLFHNPN'
 b'HPGKTQDRAKWRFW' b'KQYCANIMCYFNSD' b'HFPICCFDLDNMKR' b'HPRKELAHRFYKSK'
 b'ISAKNFNHCAGGQS' b'MNEEAMLHPVHDGC' b'GITLVYSIMPTNTV' b'SNFEMKGIVDICLG'
 b'EDVGHWHYMCFLGD' b'IRAWVLPKSHLYSK' b'KRGQPHDPLHAEVC' b'WNEAQDCVPIEKYP'
 b'IDRLGPHGSCRVAY' b'NFWRLQMLRNIPGY' b'MNWAIETKLNPKRD' b'TEPNEHIISHRSVN'
 b'VLEKNEATMMSIDY' b'VVWMEMWFWDSDVM' b'RCYVEIGTMSRTVP' b'NKVWFENCSVCCHQ'
 b'TFWACSTVWIDCRI' b'QMNYANFDPYLCAK' b'DKVVPEQITLYGLA' b'RFYFAVCLKGHPGG'
 b'FCPSIGKTGIPVNE' b'NFWGDGASYGQSGD' b'GCEPLITGMCNHFM' b'YNISQIVYDTNPRE'
 b'DVIGYAAHPAWTEG' b'CCIHGLWGVTEFRD' b'CSMGIHNKQMCATV' b'VSDTVKGYFDDDRF'
 b'FIAEFNNWLTVQEA' b'PKDYMITKAPHSPC' b'YGALWDIELMMPGS' b'STVRSIVQCNNDER'
 b'MSEHPPGGECPIHP' b'ICWWEYSWRYRVLC' b'NTGLCHMLFHVIHI' b'WHGHESWPVPCRMP'
 b'CDAQHDRDGHPFYA' b'FCIPMNKHFHPDPM' b'TFIQWTTTHQMFWE' b'IMGYPWNIFRYGGS'
 b'VSHEVGIISEYIAQ' b'VSMAMGTVFKDKNT' b'KFEDGAWSVPHGEK' b'HGGMSDYNWKQATD'
 b'LNIQHMSYDRPRNC' b'VPLGHFHCDHLMAY' b'K

In [17]:
# get seqs of a_rf
a_rf_seq = [
    [OHE_to_seq_single(a_rf[i][j], gap=True) for j in range(a_rf.shape[1])]
    for i in range(num_samples)
]
print(a_rf_seq)

[['MECHIADVC', 'ECHIADVCI', 'CHIADVCIP', 'HIADVCIPF', 'IADVCIPFP', 'ADVCIPFPI'], ['WRWQPQWAW', 'RWQPQWAWW', 'WQPQWAWWM', 'QPQWAWWME', 'PQWAWWMEC', 'QWAWWMECC'], ['FLWGAWQAT', 'LWGAWQATS', 'WGAWQATSM', 'GAWQATSMS', 'AWQATSMSA', 'WQATSMSAT'], ['QDDMDQAML', 'DDMDQAMLF', 'DMDQAMLFH', 'MDQAMLFHN', 'DQAMLFHNP', 'QAMLFHNPN'], ['HPGKTQDRA', 'PGKTQDRAK', 'GKTQDRAKW', 'KTQDRAKWR', 'TQDRAKWRF', 'QDRAKWRFW'], ['KQYCANIMC', 'QYCANIMCY', 'YCANIMCYF', 'CANIMCYFN', 'ANIMCYFNS', 'NIMCYFNSD'], ['HFPICCFDL', 'FPICCFDLD', 'PICCFDLDN', 'ICCFDLDNM', 'CCFDLDNMK', 'CFDLDNMKR'], ['HPRKELAHR', 'PRKELAHRF', 'RKELAHRFY', 'KELAHRFYK', 'ELAHRFYKS', 'LAHRFYKSK'], ['ISAKNFNHC', 'SAKNFNHCA', 'AKNFNHCAG', 'KNFNHCAGG', 'NFNHCAGGQ', 'FNHCAGGQS'], ['MNEEAMLHP', 'NEEAMLHPV', 'EEAMLHPVH', 'EAMLHPVHD', 'AMLHPVHDG', 'MLHPVHDGC'], ['GITLVYSIM', 'ITLVYSIMP', 'TLVYSIMPT', 'LVYSIMPTN', 'VYSIMPTNT', 'YSIMPTNTV'], ['SNFEMKGIV', 'NFEMKGIVD', 'FEMKGIVDI', 'EMKGIVDIC', 'MKGIVDICL', 'KGIVDICLG'], ['EDVGHWHYM', 'DVGHWHYMC', 'VGHWHYMCF',

In [18]:
# compare the long_mer vs windows by printing them in pairs
for i in range(num_samples):
    print(pep_seq[i])
    print(pep_RFs_seq[i])
    print(a_rf_seq[i])

MECHIADVCIPFPI
['MECHIADVC', 'ECHIADVCI', 'CHIADVCIP', 'HIADVCIPF', 'IADVCIPFP', 'ADVCIPFPI']
['MECHIADVC', 'ECHIADVCI', 'CHIADVCIP', 'HIADVCIPF', 'IADVCIPFP', 'ADVCIPFPI']
WRWQPQWAWWMECC
['WRWQPQWAW', 'RWQPQWAWW', 'WQPQWAWWM', 'QPQWAWWME', 'PQWAWWMEC', 'QWAWWMECC']
['WRWQPQWAW', 'RWQPQWAWW', 'WQPQWAWWM', 'QPQWAWWME', 'PQWAWWMEC', 'QWAWWMECC']
FLWGAWQATSMSAT
['FLWGAWQAT', 'LWGAWQATS', 'WGAWQATSM', 'GAWQATSMS', 'AWQATSMSA', 'WQATSMSAT']
['FLWGAWQAT', 'LWGAWQATS', 'WGAWQATSM', 'GAWQATSMS', 'AWQATSMSA', 'WQATSMSAT']
QDDMDQAMLFHNPN
['QDDMDQAML', 'DDMDQAMLF', 'DMDQAMLFH', 'MDQAMLFHN', 'DQAMLFHNP', 'QAMLFHNPN']
['QDDMDQAML', 'DDMDQAMLF', 'DMDQAMLFH', 'MDQAMLFHN', 'DQAMLFHNP', 'QAMLFHNPN']
HPGKTQDRAKWRFW
['HPGKTQDRA', 'PGKTQDRAK', 'GKTQDRAKW', 'KTQDRAKWR', 'TQDRAKWRF', 'QDRAKWRFW']
['HPGKTQDRA', 'PGKTQDRAK', 'GKTQDRAKW', 'KTQDRAKWR', 'TQDRAKWRF', 'QDRAKWRFW']
KQYCANIMCYFNSD
['KQYCANIMC', 'QYCANIMCY', 'YCANIMCYF', 'CANIMCYFN', 'ANIMCYFNS', 'NIMCYFNSD']
['KQYCANIMC', 'QYCANIMCY', 'YCANIMCYF', '