In [1]:
import numpy as np
import tensorflow as tf

In [2]:
# Constants
AA = "ACDEFGHIKLMNPQRSTVWY-"
AA_TO_INT = {a: i for i, a in enumerate(AA)}
UNK_IDX = 20  # Index for "unknown"
MASK_TOKEN = -1.0
NORM_TOKEN = 1.0
PAD_TOKEN = -2.0
PAD_VALUE = 0.0
MASK_VALUE = 0.0

def seq_to_onehot(sequence: str, max_seq_len: int) -> np.ndarray:
    """Convert peptide sequence to one-hot encoding"""
    arr = np.full((max_seq_len, 21), PAD_VALUE, dtype=np.float32) # initialize padding with 0
    for j, aa in enumerate(sequence.upper()[:max_seq_len]):
        arr[j, AA_TO_INT.get(aa, UNK_IDX)] = 1.0
        # print number of UNKs in the sequence
    # num_unks = np.sum(arr[:, UNK_IDX])
    # zero out gaps
    arr[:, AA_TO_INT['-']] = PAD_VALUE  # Set gaps to PAD_VALUE
    # if num_unks > 0:
    #     print(f"Warning: {num_unks} unknown amino acids in sequence '{sequence}'")
    return arr


def OHE_to_seq(ohe: np.ndarray, gap: bool = False) -> list:
    """
    Convert a one-hot encoded matrix back to a peptide sequence.
    # (B, max_pep_len, 21) -> (B, max_pep_len)
    Args:
        ohe: One-hot encoded matrix of shape (B, N, 21).
    Returns:
        sequence: Peptide sequence as a string. (B,)
    """
    sequence = []
    for i in range(ohe.shape[0]):  # Iterate over batch dimension
        seq = []
        for j in range(ohe.shape[1]):  # Iterate over sequence length
            if gap and np.all(ohe[i, j] == 0):
                seq.append('-')
            else:
                aa_index = np.argmax(ohe[i, j])  # Get index of the max value in one-hot encoding
                if aa_index < len(AA):  # Check if it's a valid amino acid index
                    seq.append(AA[aa_index])
                else:
                    seq.append('X')  # Use 'X' for unknown amino acids
        sequence.append(''.join(seq))  # Join the list into a string
    return sequence  # Return list of sequences


def OHE_to_seq_single(ohe: np.ndarray, gap=False) -> str:
    """
    Convert a one-hot encoded matrix back to a peptide sequence.
    Args:
        ohe: One-hot encoded matrix of shape (N, 21).
    Returns:
        sequence: Peptide sequence as a string.
    """
    seq = []
    for j in range(ohe.shape[0]):  # Iterate over sequence length
        if gap and np.all(ohe[j] == 0):
            seq.append('-')
        else:
            aa_index = np.argmax(ohe[j])  # Get index of the max value in one-hot encoding
            seq.append(AA[aa_index])
    return ''.join(seq)  # Join the list into a string


def peptides_to_onehot_kmer_windows(seq, max_seq_len, k=9, pad_token=-1.0) -> np.ndarray:
    """
    Converts a peptide sequence into a sliding window of k-mers, one-hot encoded.
    Output shape: (RF, k, 21), where RF = max_seq_len - k + 1
    """
    RF = max_seq_len - k + 1
    RFs = np.zeros((RF, k, 21), dtype=np.float32)
    for window in range(RF):
        if window + k <= len(seq):
            kmer = seq[window:window + k]
            for i, aa in enumerate(kmer):
                idx = AA_TO_INT.get(aa, pad_token)
                RFs[window, i, idx] = 1.0
            # Pad remaining positions in k-mer if sequence is too short
            for i in range(len(kmer), k):
                RFs[window, i, pad_token] = 1.0
        else:
            # Entire k-mer is padding if out of sequence
            RFs[window, :, pad_token] = 1.0
    return np.array(RFs)

In [13]:
# generate dummy data
# generate a data that is pep one hot encoding with N=14, d=21.
# mhc_emb with N= 312 and d =1152. a contact map that is a soft weights over 312 mhc positions
# Create dummy data
num_samples = 1000
# Peptide One-Hot Encoding (N=14, d=21)
pep_len = 14
pep_alphabet_size = 20
pep_indices = np.random.randint(0, pep_alphabet_size, size=(num_samples, pep_len))
pep_OHE = tf.one_hot(pep_indices, depth=pep_alphabet_size, dtype=tf.float32)
pep_seq = [OHE_to_seq_single(pep_OHE[i].numpy()) for i in range(num_samples)]
pep_RFs = np.array([peptides_to_onehot_kmer_windows(pep_seq[i], max_seq_len=pep_len, k=9, pad_token=20) for i in range(num_samples)], dtype=np.float32)
pep_RFs_seq = [
    [OHE_to_seq_single(pep_RFs[i][j], gap=True) for j in range(pep_RFs.shape[1])]
    for i in range(num_samples)
]

In [14]:
print(pep_RFs_seq)

[['DEDWYQKEG', 'EDWYQKEGE', 'DWYQKEGET', 'WYQKEGETR', 'YQKEGETRV', 'QKEGETRVP'], ['SEAHSDFFL', 'EAHSDFFLA', 'AHSDFFLAS', 'HSDFFLASY', 'SDFFLASYD', 'DFFLASYDE'], ['ICDGQKPKS', 'CDGQKPKSY', 'DGQKPKSYM', 'GQKPKSYMS', 'QKPKSYMSV', 'KPKSYMSVE'], ['IVENWQFQL', 'VENWQFQLF', 'ENWQFQLFL', 'NWQFQLFLP', 'WQFQLFLPM', 'QFQLFLPME'], ['EVQVECEPI', 'VQVECEPIH', 'QVECEPIHE', 'VECEPIHEV', 'ECEPIHEVG', 'CEPIHEVGK'], ['TDPSIHQMW', 'DPSIHQMWT', 'PSIHQMWTL', 'SIHQMWTLG', 'IHQMWTLGI', 'HQMWTLGIM'], ['SHWMDLYSM', 'HWMDLYSMH', 'WMDLYSMHK', 'MDLYSMHKC', 'DLYSMHKCM', 'LYSMHKCMR'], ['FEWNQACGC', 'EWNQACGCH', 'WNQACGCHQ', 'NQACGCHQL', 'QACGCHQLT', 'ACGCHQLTA'], ['AQCNYFHIR', 'QCNYFHIRC', 'CNYFHIRCY', 'NYFHIRCYI', 'YFHIRCYIR', 'FHIRCYIRH'], ['PWSRWILYN', 'WSRWILYNY', 'SRWILYNYT', 'RWILYNYTS', 'WILYNYTSR', 'ILYNYTSRS'], ['VNPMQPNGK', 'NPMQPNGKC', 'PMQPNGKCA', 'MQPNGKCAA', 'QPNGKCAAC', 'PNGKCAACI'], ['KNNDEYPDE', 'NNDEYPDEM', 'NDEYPDEMM', 'DEYPDEMMM', 'EYPDEMMMA', 'YPDEMMMAD'], ['HGEEICGMA', 'GEEICGMAP', 'EEICGMAPA',