In [1]:
# return to parent directory
import sys 
if ".." not in sys.path:
    sys.path.append("..")

# prepare hla_seq_dict and model
import torch
import pandas as pd
from tape import TAPETokenizer, ProteinBertConfig
from model_ft import meanTAPE

# Build the HLA lookup: the CSV is indexed by "HLA_name" and each allele name is
# mapped to the value in its "clip" column (presumably the clipped HLA sequence).
data_path = "/data/lujd/neoag_data/"
hla_seq_dict = pd.read_csv(
    data_path+"main_task/HLA_sequence_dict_ABCEG.csv",
    index_col=0
    ).set_index(["HLA_name"])["clip"].to_dict()

# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_path = "/data/lujd/neoag_model/main_task/TAPE_ft/cat_mean_2mlp/"
model_name = "main_finetune_plm_tape_B32_LR3e-05_seq_clip_fold4_ep51_221104.pkl"

tokenizer = TAPETokenizer(vocab='iupac')
tape_config = ProteinBertConfig.from_pretrained('bert-base')
model = meanTAPE(tape_config, "2mlp").to(device)
# map_location makes the CPU fallback above actually usable: without it, a
# checkpoint whose tensors were saved from a GPU cannot be deserialized on a
# machine without CUDA.
state_dict = torch.load(model_path + model_name, map_location=device)
model.load_state_dict(state_dict, strict = True)
# Inference only: switch the model to evaluation mode.
model = model.eval()

In [2]:
import numpy as np

def seq2token(tokenizer, hla_seq, pep_seq, hla_max_len, pep_max_len):
    """Tokenize one HLA sequence together with one or more peptides.

    Both the HLA sequence and every peptide are right-padded with 'X' up to
    their respective maximum length before being encoded. Sequences that are
    already longer than the maximum are left as they are (no truncation).

    Args:
        tokenizer: object exposing ``encode(str)`` and returning a token sequence.
        hla_seq: HLA sequence as a string (``str`` subclasses such as
            ``numpy.str_`` are accepted).
        pep_seq: a single peptide string, or a list/tuple of peptide strings.
        hla_max_len: length the HLA sequence is padded to.
        pep_max_len: length every peptide is padded to.

    Returns:
        Tuple ``(hla_token, pep_tokens, hla_pep_tokens)`` of numpy arrays:
        the encoded padded HLA sequence, one encoded padded peptide per row,
        and one encoded ``padded HLA + padded peptide`` concatenation per row.

    Raises:
        AssertionError: if ``hla_seq`` is not a string, or ``pep_seq`` is
            neither a string nor a list/tuple.
    """
    # isinstance instead of `type(...) == str`: values taken from pandas/numpy
    # containers are often numpy.str_, a str subclass the exact-type check rejected.
    assert isinstance(hla_seq, str), "hla_seq must be a string"
    hla_seq = hla_seq.ljust(hla_max_len, 'X')
    hla_token = tokenizer.encode(hla_seq)

    # Accept a single peptide as a convenience and normalise it to a list.
    if isinstance(pep_seq, str):
        pep_seq = [pep_seq]
    assert isinstance(pep_seq, (list, tuple)), \
        "pep_seq must be a string or a list/tuple of strings"

    pep_tokens, hla_pep_tokens = [], []
    for seq in pep_seq:
        seq = seq.ljust(pep_max_len, 'X')
        pep_tokens.append(tokenizer.encode(seq))
        # The model input is the padded HLA immediately followed by the padded peptide.
        hla_pep_tokens.append(tokenizer.encode(hla_seq + seq))

    return np.array(hla_token), np.array(pep_tokens), np.array(hla_pep_tokens)

In [3]:
# Allele / peptide pair whose predicted binding is the starting point of the search.
given_HLA = "HLA-A*01:01"
init_peptide = "FLCSRRGHL"

# Padding lengths passed to seq2token for the HLA sequence and the peptide.
hla_max_len = 182
pep_max_len = 15
HLA_seq = hla_seq_dict[given_HLA]
hla_token, pep_tokens, hla_pep_tokens = seq2token(
    tokenizer, HLA_seq, init_peptide, hla_max_len, pep_max_len
)

import torch.nn as nn

# Forward pass without gradient tracking; the model returns one row of logits per input.
with torch.no_grad():
    hla_pep_inputs = torch.LongTensor(hla_pep_tokens).to(device)
    init_prob = model(hla_pep_inputs)

# Softmax over the class axis; column 1 is reported as the binding score.
init_prob = torch.softmax(init_prob, dim=1)[:, 1]
init_prob = init_prob.cpu().detach().numpy()   # 1-D
print("HLA: {}, peptide: {} | binding affinity: {:.4f}".format(given_HLA, init_peptide, init_prob.item()))

HLA: HLA-A*01:01, peptide: FLCSRRGHL | binding affinity: 0.0000


---
algorithm 1

In [28]:
# One-letter codes of the 20 amino acids used as substitution candidates.
amino_acid_list = list("GAVLI" "PFYWS" "TCMNQ" "DEKRH")

# Every peptide that differs from init_peptide by exactly one substitution,
# de-duplicated and sorted so that pool indices are reproducible.
mutate_pool = sorted({
    init_peptide[:site] + replacement + init_peptide[site+1:]
    for site, residue in enumerate(init_peptide)
    for replacement in amino_acid_list
    if replacement != residue
})
print("Iteration-1, mutate_pool size:",len(mutate_pool))

Iteration-1, mutate_pool size: 171


In [34]:
# Tokenize every candidate in the pool together with the HLA sequence.
hla_token, pep_tokens, hla_pep_tokens = seq2token(
    tokenizer, HLA_seq, mutate_pool, hla_max_len, pep_max_len
)
print(hla_pep_tokens.shape)

# Score the whole pool in a single forward pass (no gradients needed).
with torch.no_grad():
    hla_pep_inputs = torch.LongTensor(hla_pep_tokens).to(device)
    model_output = model(hla_pep_inputs)

# Column 1 of the softmax is used as the binding probability of each candidate.
prob = torch.softmax(model_output, dim=1)[:, 1]
prob = prob.cpu().detach().numpy()   # 1-D

(171, 199)


In [58]:
beam_width = 5
# argsort is ascending, so the last beam_width indices are the best candidates
# (listed from lowest to highest probability).
topk_id = np.argsort(prob)[-beam_width:]
source_peptides = [init_peptide]
mutate_peptides = []
mutate_table = []               # (source peptide, mutated peptide, mutated position, source amino, substitution, probability)
# The loop variable is `idx`, not `id`: at notebook top level `for id in ...`
# rebinds the builtin id() for the rest of the kernel session.
for idx in topk_id:
    mutate_peptide = mutate_pool[idx]
    mutate_peptides.append(mutate_peptide)
    # Look for a source peptide that differs from the mutant at exactly one position.
    for source_peptide in source_peptides:
        num_mutation, mutate_position = 0, 0
        source_amino, mutate_amino = "", ""
        assert len(source_peptide)==len(mutate_peptide)
        for position, amino in enumerate(source_peptide):
            if amino != mutate_peptide[position]:
                num_mutation += 1
                mutate_position = position+1        # reported 1-based
                source_amino = amino
                mutate_amino = mutate_peptide[position]

        if num_mutation==1:     # means that we find its father, so no need to continue
            break

    mutate_table.append((source_peptide, mutate_peptide, mutate_position, source_amino, mutate_amino, prob[idx]))

for mutate_info in mutate_table:
    print("source peptide: {}, mutated peptide: {} | {} {}->{} | binding probability: {:.4f}".format(
        *mutate_info))
                

source peptide: TIQQCQSPT, mutated peptide: TIQQCQSNT | 8 P->N | binding probability: 0.4029
source peptide: TIQQCQSPT, mutated peptide: TIQQIQSPT | 5 C->I | binding probability: 0.6952
source peptide: TIQQCQSPT, mutated peptide: TIQNCQSPT | 4 Q->N | binding probability: 0.7288
source peptide: TIQQCQSPT, mutated peptide: TIQQCQSYT | 8 P->Y | binding probability: 0.7901
source peptide: TIQQCQSPT, mutated peptide: TIQQCQSFT | 8 P->F | binding probability: 0.8279


In [66]:
# The winners of the previous iteration become the parents of this one.
source_peptides = mutate_peptides
mutate_pool = []
for source_peptide in source_peptides:
    for ind, amino in enumerate(source_peptide):
        if amino == init_peptide[ind]:              # find non-mutated position
            for sub_amino in amino_acid_list:       # replace amino at the position
                if sub_amino != amino:
                    new_peptide = source_peptide[:ind] + sub_amino + source_peptide[ind+1:]
                    mutate_pool.append(new_peptide)
        else:
            print(ind+1)                            # 1-based position that is already mutated
# De-duplicate and sort once, after all parents have contributed; doing it inside
# the loop re-sorted the growing pool for every parent and gave the same result.
mutate_pool = sorted(set(mutate_pool))
print("Iteration-2, mutate_pool size:",len(mutate_pool))

batch_size = 64
hla_token, pep_tokens, hla_pep_tokens = seq2token(tokenizer, HLA_seq, mutate_pool, hla_max_len, pep_max_len)
print(hla_pep_tokens.shape)
prob_all = []
with torch.no_grad():
    # Step through the pool in chunks of batch_size; the final slice is simply
    # shorter when the pool size is not a multiple of batch_size.
    for start_index in range(0, len(mutate_pool), batch_size):
        end_index = start_index + batch_size
        hla_pep_inputs = torch.LongTensor(hla_pep_tokens[start_index:end_index]).to(device)
        print(hla_pep_inputs.shape)
        model_output = model(hla_pep_inputs)
        prob = nn.Softmax(dim=1)(model_output)[:, 1].cpu().detach().numpy() # 1-D
        prob_all.append(prob)

    # One probability per pool entry, in pool order.
    prob_all = np.concatenate(prob_all)

8
5
4
8
8
Iteration-2, mutate_pool size: 753
(753, 199)
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([49, 199])


In [69]:
beam_width = 10
# argsort is ascending, so the last beam_width indices are the best candidates
# (listed from lowest to highest probability).
topk_id = np.argsort(prob_all)[-beam_width:]
mutate_peptides = []
mutate_table = []               # (source peptide, mutated peptide, mutated position, source amino, substitution, probability)
# The loop variable is `idx`, not `id`: at notebook top level `for id in ...`
# rebinds the builtin id() for the rest of the kernel session.
for idx in topk_id:
    mutate_peptide = mutate_pool[idx]
    mutate_peptides.append(mutate_peptide)
    # Look for a source peptide that differs from the mutant at exactly one position.
    for source_peptide in source_peptides:
        num_mutation, mutate_position = 0, 0
        source_amino, mutate_amino = "", ""
        assert len(source_peptide)==len(mutate_peptide)
        for position, amino in enumerate(source_peptide):
            if amino != mutate_peptide[position]:
                num_mutation += 1
                mutate_position = position+1        # reported 1-based
                source_amino = amino
                mutate_amino = mutate_peptide[position]

        if num_mutation==1:     # means that we find its father, so no need to continue
            break

    mutate_table.append((source_peptide, mutate_peptide, mutate_position, source_amino, mutate_amino, prob_all[idx]))

for mutate_info in mutate_table:
    print("source peptide: {}, mutated peptide: {} | {} {}->{} | binding probability: {:.4f}".format(
        *mutate_info))

source peptide: TIQQCQSNT, mutated peptide: TIQQIQSNT | 5 C->I | binding probability: 0.9923
source peptide: TIQQCQSFT, mutated peptide: TIQQVQSFT | 5 C->V | binding probability: 0.9933
source peptide: TIQQCQSFT, mutated peptide: TIQQNQSFT | 5 C->N | binding probability: 0.9934
source peptide: TIQQCQSNT, mutated peptide: TIQQVQSNT | 5 C->V | binding probability: 0.9939
source peptide: TIQQCQSYT, mutated peptide: TIQQMQSYT | 5 C->M | binding probability: 0.9941
source peptide: TIQQCQSFT, mutated peptide: TIQQMQSFT | 5 C->M | binding probability: 0.9941
source peptide: TIQQCQSYT, mutated peptide: TIQQVQSYT | 5 C->V | binding probability: 0.9956
source peptide: TIQQIQSPT, mutated peptide: TIQQIQSVT | 8 P->V | binding probability: 0.9956
source peptide: TIQQCQSFT, mutated peptide: TIQQLQSFT | 5 C->L | binding probability: 0.9961
source peptide: TIQQCQSYT, mutated peptide: TIQQLQSYT | 5 C->L | binding probability: 0.9967


In [47]:
# Spot check of the diff logic: compare one pool entry against the initial peptide.
mutate_peptide = mutate_pool[89]
source_peptide = init_peptide
assert len(source_peptide)==len(mutate_peptide)

# (1-based position, source residue, mutant residue) for every mismatching site.
mismatches = [
    (position+1, amino, mutate_peptide[position])
    for position, amino in enumerate(source_peptide)
    if amino != mutate_peptide[position]
]
num_mutation = len(mismatches)
# Report the last mismatch; fall back to the "nothing differs" placeholders.
mutate_position, source_amino, mutate_amino = mismatches[-1] if mismatches else (0, "", "")
print(num_mutation,"{} {}->{}".format(mutate_position,source_amino,mutate_amino))

1 8 P->N


In [1]:
import torch
from mutation import get_mutated_peptides

# Run on GPU index 6 when CUDA is available, otherwise on the CPU.
# NOTE(review): the GPU index is hard-coded — presumably specific to the author's machine; confirm before reuse.
device = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")

# Allele and starting peptide handed to the search.
given_HLA = "HLA-A*02:02"
init_peptide = "TIQQCQSPT"

# get_mutated_peptides comes from the project-local `mutation` module; its implementation
# is not shown here. Judging by the keyword names, the call presumably asks for 3 rounds
# of mutation, 5 peptides per round, using algorithm 1 — TODO confirm against mutation.py.
mutate_peptides = get_mutated_peptides(
                                        given_HLA, init_peptide, device,
                                        num_mutation=3, num_peptides=5, algorithm=1
                                        )

HLA_seq_dict preparing
Model preparing
given HLA: HLA-A*02:02, given peptide: TIQQCQSPT | binding porbability: 0.0035
************** Run algorithm-1 **************
Iteration-1, mutate_pool size: 171
(171, 199)
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([43, 199])
source peptide: TIQQCQSPT, mutated peptide: TIQQCQSNT | 8 P->N | binding probability: 0.4029
source peptide: TIQQCQSPT, mutated peptide: TIQQIQSPT | 5 C->I | binding probability: 0.6952
source peptide: TIQQCQSPT, mutated peptide: TIQNCQSPT | 4 Q->N | binding probability: 0.7288
source peptide: TIQQCQSPT, mutated peptide: TIQQCQSYT | 8 P->Y | binding probability: 0.7901
source peptide: TIQQCQSPT, mutated peptide: TIQQCQSFT | 8 P->F | binding probability: 0.8279
8
5
4
8
8
Iteration-2, mutate_pool size: 753
(753, 199)
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch.Size([64, 199])
torch

---
algorithm 2a: saliency-guided search (each source peptide is mutated at its own most salient position)

In [7]:
# Algorithm 2a (single debug iteration): for every source peptide, mask each
# still-unmutated position with <unk>, pick the position whose masking raises the
# predicted binding probability the most, and try every substitution there.
iteration = 1
beam_width = 5
batch_size = 16
amino_acid_list = [ "G", "A", "V", "L", "I", "P", "F", "Y", "W", "S",
                    "T", "C", "M", "N", "Q", "D", "E", "K", "R", "H"]
hla_max_len = 182
pep_max_len = 15

source_peptides = [init_peptide]

for i in range(iteration):
    # 1. Make a pool of candidate mutated peptides
    mutate_pool = []
    for source_peptide in source_peptides:
        # (1) replace amino with <unk> and calculate saliency
        saliency_all = []
        _, _, hla_spep_token = seq2token(tokenizer, HLA_seq, source_peptide, hla_max_len, pep_max_len)
        for ind, amino in enumerate(source_peptide):
            if amino == init_peptide[ind]:              # find non-mutated position against given peptide
                mask_token = hla_spep_token.copy()      # copy() is necessary
                # Offset = padded HLA length + 1; the +1 presumably skips the start
                # token the tokenizer prepends (token rows are 2 longer than the text).
                mask_token[0][hla_max_len+1+ind] = 4    # <unk> is 4
                print(hla_spep_token, mask_token)
                print(hla_spep_token.shape, mask_token.shape)

                with torch.no_grad():
                    # Row 0 is the unmasked input, row 1 the masked one.
                    hla_pep_inputs = torch.LongTensor(
                        np.concatenate((hla_spep_token, mask_token), axis=0)
                        ).to(device)
                    print(hla_pep_inputs.shape)
                    model_output = model(hla_pep_inputs)
                    prob = nn.Softmax(dim=1)(model_output)[:, 1].cpu().detach().numpy() # 1-D
                    print(prob)
                    saliency = prob[1] - prob[0]        # masked - origin
                    print(saliency)
                    saliency_all.append(saliency)
            else:
                print(ind+1)
                # Already-mutated positions get the lowest possible score so they rank last.
                saliency_all.append(-1)

        # (2) saliency ranking
        saliency_all = np.array(saliency_all)
        print(saliency_all, len(saliency_all))
        mask_position = np.argsort(saliency_all)[-1]    # best position to be replaced
        # Printed 1-based, like every other position printed in this notebook.
        print(mask_position+1)

        # (3) replace amino at the mask_position
        for sub_amino in amino_acid_list:
            if sub_amino != source_peptide[mask_position]:
                new_peptide = source_peptide[:mask_position] + sub_amino + source_peptide[mask_position+1:]
                mutate_pool.append(new_peptide)

    mutate_pool = sorted(set(mutate_pool))
    print("Iteration-{}, mutate_pool size: {}".format(i+1, len(mutate_pool)))

    # 2. Use our finetuned TAPE to calculate binding probability
    # between all peptides in mutate_pool and the given HLA
    _, _, hla_pep_tokens = seq2token(tokenizer, HLA_seq, mutate_pool, hla_max_len, pep_max_len)
    prob_all = []
    with torch.no_grad():
        # Chunks of batch_size; the final slice is shorter when the pool size is
        # not a multiple of batch_size.
        for start_index in range(0, len(mutate_pool), batch_size):
            end_index = start_index + batch_size
            hla_pep_inputs = torch.LongTensor(hla_pep_tokens[start_index:end_index]).to(device)
            model_output = model(hla_pep_inputs)
            prob = nn.Softmax(dim=1)(model_output)[:, 1].cpu().detach().numpy() # 1-D
            prob_all.append(prob)

        prob_all = np.concatenate(prob_all)

    # 3. Rank and print "topk" messages
    topk_id = np.argsort(prob_all)[-beam_width:]
    mutate_peptides = []
    mutate_table = []               # (source peptide, mutated peptide, mutated position, source amino, substitution, probability)
    # `idx`, not `id`: a top-level `for id in ...` rebinds the builtin id() in the kernel.
    for idx in topk_id:
        mutate_peptide = mutate_pool[idx]
        mutate_peptides.append(mutate_peptide)

        # find a father of the mutated peptide
        for source_peptide in source_peptides:
            num_mutation, mutate_position = 0, 0
            source_amino, mutate_amino = "", ""
            assert len(source_peptide)==len(mutate_peptide)
            for position, amino in enumerate(source_peptide):
                if amino != mutate_peptide[position]:
                    num_mutation += 1
                    mutate_position = position+1
                    source_amino = amino
                    mutate_amino = mutate_peptide[position]
            if num_mutation==1:     # means that we find its father, so no need to continue
                break
        # record "topk" messages
        mutate_table.append((source_peptide, mutate_peptide, mutate_position, source_amino, mutate_amino, prob_all[idx]))

    for mutate_info in mutate_table:
        print("source peptide: {}, mutated peptide: {} | {} {}->{} | binding probability: {:.4f}".format(
            *mutate_info))

    source_peptides = mutate_peptides       # for next iteration

[[ 2 11 22 12 22 16 21 28 10 10 23 22 25 22 21 19 11 21 11  9 19 21 10 13
   5 25 11 28 25  8  8 23 20 10 25 21 10  8 22  8  5  5 22 20 14 16  9 19
  21  5 19 26 13  9 20  9 11 19  9 28 26  8 20  9 23 21 17 16 14  5 12 22
  20 23  8 21  5 17 15 11 23 15 21 11 28 28 17 20 22  9  8 11 22 12 23 13
  20 13 16 28 11  7  8 25 11 19  8 11 21 10 15 21 11 28 21 20  8  5 28  8
  11 14  8 28 13  5 15 17  9  8 15 21 22 26 23  5  5  8 16  5  5 20 13 23
  14 21 14 26  9  5 25 12  5  5  9 20 21 21 25 28 15  9 11 21  7 25  8 11
  15 21 21 28 15  9 17 11 14  9 23 15 20 21 23 10 15  7 22 21 21 11 12 15
  27 27 27 27 27 27  3]] [[ 2 11 22 12 22 16 21 28 10 10 23 22 25 22 21 19 11 21 11  9 19 21 10 13
   5 25 11 28 25  8  8 23 20 10 25 21 10  8 22  8  5  5 22 20 14 16  9 19
  21  5 19 26 13  9 20  9 11 19  9 28 26  8 20  9 23 21 17 16 14  5 12 22
  20 23  8 21  5 17 15 11 23 15 21 11 28 28 17 20 22  9  8 11 22 12 23 13
  20 13 16 28 11  7  8 25 11 19  8 11 21 10 15 21 11 28 21 20  8  5 28  8
  11 14  8 28

In [9]:
# Algorithm 2a: for every source peptide, mask each still-unmutated position with
# <unk>, pick the position whose masking raises the predicted binding probability
# the most, try every substitution there, and keep the beam_width best mutants.
iteration = 3
beam_width = 5
batch_size = 16
amino_acid_list = [ "G", "A", "V", "L", "I", "P", "F", "Y", "W", "S",
                    "T", "C", "M", "N", "Q", "D", "E", "K", "R", "H"]
hla_max_len = 182
pep_max_len = 15

source_peptides = [init_peptide]

for i in range(iteration):
    # 1. Make a pool of candidate mutated peptides
    mutate_pool = []
    for source_peptide in source_peptides:
        # (1) replace amino with <unk> and calculate saliency
        saliency_all = []
        _, _, hla_spep_token = seq2token(tokenizer, HLA_seq, source_peptide, hla_max_len, pep_max_len)
        for ind, amino in enumerate(source_peptide):
            if amino == init_peptide[ind]:              # find non-mutated position against given peptide
                mask_token = hla_spep_token.copy()      # copy() is necessary
                # Offset = padded HLA length + 1; the +1 presumably skips the start
                # token the tokenizer prepends (token rows are 2 longer than the text).
                mask_token[0][hla_max_len+1+ind] = 4    # <unk> is 4

                with torch.no_grad():
                    # Row 0 is the unmasked input, row 1 the masked one.
                    hla_pep_inputs = torch.LongTensor(
                        np.concatenate((hla_spep_token, mask_token), axis=0)
                        ).to(device)
                    print(hla_pep_inputs.shape)
                    model_output = model(hla_pep_inputs)
                    prob = nn.Softmax(dim=1)(model_output)[:, 1].cpu().detach().numpy() # 1-D
                    print(prob)
                    saliency = prob[1] - prob[0]        # masked - origin
                    print(saliency)
                    saliency_all.append(saliency)
            else:
                print(ind+1)
                # Already-mutated positions get the lowest possible score so they rank last.
                saliency_all.append(-1)

        # (2) saliency ranking
        saliency_all = np.array(saliency_all)
        print(saliency_all, len(saliency_all))
        mask_position = np.argsort(saliency_all)[-1]    # best position to be replaced
        print(mask_position+1)

        # (3) replace amino at the mask_position
        for sub_amino in amino_acid_list:
            if sub_amino != source_peptide[mask_position]:
                new_peptide = source_peptide[:mask_position] + sub_amino + source_peptide[mask_position+1:]
                mutate_pool.append(new_peptide)

    mutate_pool = sorted(set(mutate_pool))
    print("Iteration-{}, mutate_pool size: {}".format(i+1, len(mutate_pool)))

    # 2. Use our finetuned TAPE to calculate binding probability
    # between all peptides in mutate_pool and the given HLA
    _, _, hla_pep_tokens = seq2token(tokenizer, HLA_seq, mutate_pool, hla_max_len, pep_max_len)
    prob_all = []
    with torch.no_grad():
        # Chunks of batch_size; the final slice is shorter when the pool size is
        # not a multiple of batch_size.
        for start_index in range(0, len(mutate_pool), batch_size):
            end_index = start_index + batch_size
            hla_pep_inputs = torch.LongTensor(hla_pep_tokens[start_index:end_index]).to(device)
            model_output = model(hla_pep_inputs)
            prob = nn.Softmax(dim=1)(model_output)[:, 1].cpu().detach().numpy() # 1-D
            prob_all.append(prob)

        prob_all = np.concatenate(prob_all)

    # 3. Rank and print "topk" messages
    topk_id = np.argsort(prob_all)[-beam_width:]
    mutate_peptides = []
    mutate_table = []               # (source peptide, mutated peptide, mutated position, source amino, substitution, probability)
    # `idx`, not `id`: a top-level `for id in ...` rebinds the builtin id() in the kernel.
    for idx in topk_id:
        mutate_peptide = mutate_pool[idx]
        mutate_peptides.append(mutate_peptide)

        # find a father of the mutated peptide
        for source_peptide in source_peptides:
            num_mutation, mutate_position = 0, 0
            source_amino, mutate_amino = "", ""
            assert len(source_peptide)==len(mutate_peptide)
            for position, amino in enumerate(source_peptide):
                if amino != mutate_peptide[position]:
                    num_mutation += 1
                    mutate_position = position+1
                    source_amino = amino
                    mutate_amino = mutate_peptide[position]
            if num_mutation==1:     # means that we find its father, so no need to continue
                break
        # record "topk" messages
        mutate_table.append((source_peptide, mutate_peptide, mutate_position, source_amino, mutate_amino, prob_all[idx]))

    for mutate_info in mutate_table:
        print("source peptide: {}, mutated peptide: {} | {} {}->{} | binding probability: {:.4f}".format(
            *mutate_info))

    source_peptides = mutate_peptides       # for next iteration

torch.Size([2, 199])
[1.0547332e-05 2.3249698e-04]
0.00022194965
torch.Size([2, 199])
[1.0547332e-05 4.4355297e-06]
-6.111802e-06
torch.Size([2, 199])
[1.0547332e-05 2.2422773e-04]
0.0002136804
torch.Size([2, 199])
[1.0547332e-05 8.5823649e-06]
-1.964967e-06
torch.Size([2, 199])
[1.0547332e-05 5.7079604e-05]
4.6532274e-05
torch.Size([2, 199])
[1.05473318e-05 1.43992065e-05]
3.8518747e-06
torch.Size([2, 199])
[1.0547332e-05 1.0267263e-05]
-2.800689e-07
torch.Size([2, 199])
[1.0547332e-05 8.0203975e-04]
0.0007914924
torch.Size([2, 199])
[1.0547332e-05 6.7708906e-06]
-3.7764412e-06
[ 2.2194965e-04 -6.1118021e-06  2.1368040e-04 -1.9649669e-06
  4.6532274e-05  3.8518747e-06 -2.8006889e-07  7.9149241e-04
 -3.7764412e-06] 9
8
Iteration-1, mutate_pool size: 19
source peptide: FLCSRRGHL, mutated peptide: FLCSRRGVL | 8 H->V | binding probability: 0.0003
source peptide: FLCSRRGHL, mutated peptide: FLCSRRGIL | 8 H->I | binding probability: 0.0005
source peptide: FLCSRRGHL, mutated peptide: FLCSRRG

---
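algorithm 2b: saliency-guided search (saliency is averaged over the source peptides, so all of them are mutated at the same position)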

In [9]:
# Algorithm 2b: compute the masking saliency of every position for every source
# peptide, average it over the source peptides, and mutate ALL source peptides at
# the single position with the highest average saliency.
iteration = 5
beam_width = 5
batch_size = 16
amino_acid_list = [ "G", "A", "V", "L", "I", "P", "F", "Y", "W", "S",
                    "T", "C", "M", "N", "Q", "D", "E", "K", "R", "H"]
hla_max_len = 182
pep_max_len = 15

source_peptides = [init_peptide]
output_peptides = []

for i in range(iteration):
    # 1. Make a pool of candidate mutated peptides
    # (1) replace amino with <unk> and calculate saliency
    mutant_pool = []
    saliency_all = []
    for source_peptide in source_peptides:
        saliency_single = []
        _, _, hla_spep_token = seq2token(tokenizer, HLA_seq, source_peptide, hla_max_len, pep_max_len)
        for ind, amino in enumerate(source_peptide):
            if amino == init_peptide[ind]:              # find non-mutated position against given peptide
                mask_token = hla_spep_token.copy()      # copy() is necessary
                # Offset = padded HLA length + 1; the +1 presumably skips the start
                # token the tokenizer prepends (token rows are 2 longer than the text).
                mask_token[0][hla_max_len+1+ind] = 4    # <unk> is 4

                with torch.no_grad():
                    # Row 0 is the unmasked input, row 1 the masked one.
                    hla_pep_inputs = torch.LongTensor(
                        np.concatenate((hla_spep_token, mask_token), axis=0)
                        ).to(device)
                    model_output = model(hla_pep_inputs)
                    prob = nn.Softmax(dim=1)(model_output)[:, 1].cpu().detach().numpy() # 1-D
                    saliency = prob[1] - prob[0]        # masked - origin
                    saliency_single.append(saliency)
            else:
                print(ind+1)
                # Already-mutated positions get the lowest possible score so they rank last.
                saliency_single.append(-1)
        saliency_all.append(saliency_single)

    # (2) calculate average saliency and rank
    saliency_all = np.array(saliency_all)               # one row per source peptide
    print(saliency_all, saliency_all.shape)
    saliency_all = np.mean(saliency_all, axis=0)        # average over source peptides
    print(saliency_all, saliency_all.shape)
    mask_position = np.argsort(saliency_all)[-1]    # best position to be replaced
    print(mask_position+1)

    # (3) replace amino at the mask_position
    for source_peptide in source_peptides:
        for sub_amino in amino_acid_list:
            if sub_amino != source_peptide[mask_position]:
                new_peptide = source_peptide[:mask_position] + sub_amino + source_peptide[mask_position+1:]
                mutant_pool.append(new_peptide)

    mutant_pool = sorted(set(mutant_pool))
    print("Iteration-{}, mutant_pool size: {}".format(i+1, len(mutant_pool)))

    # 2. Use our finetuned TAPE to calculate binding probability
    # between all peptides in mutant_pool and the given HLA
    _, _, hla_pep_tokens = seq2token(tokenizer, HLA_seq, mutant_pool, hla_max_len, pep_max_len)
    prob_all = []
    with torch.no_grad():
        # Chunks of batch_size; the final slice is shorter when the pool size is
        # not a multiple of batch_size.
        for start_index in range(0, len(mutant_pool), batch_size):
            end_index = start_index + batch_size
            hla_pep_inputs = torch.LongTensor(hla_pep_tokens[start_index:end_index]).to(device)
            model_output = model(hla_pep_inputs)
            prob = nn.Softmax(dim=1)(model_output)[:, 1].cpu().detach().numpy() # 1-D
            prob_all.append(prob)

        prob_all = np.concatenate(prob_all)

    # 3. Rank and print "topk" messages
    topk_id = np.argsort(prob_all)[-beam_width:]
    mutant_peptides = []
    mutate_table = []               # (source peptide, mutated peptide, mutated position, source amino, substitution, probability)
    # `idx`, not `id`: a top-level `for id in ...` rebinds the builtin id() in the kernel.
    for idx in topk_id:
        mutant_peptide = mutant_pool[idx]
        mutant_peptides.append(mutant_peptide)

        # find a father of the mutated peptide
        for source_peptide in source_peptides:
            num_mutation, mutate_position = 0, 0
            source_amino, mutate_amino = "", ""
            assert len(source_peptide)==len(mutant_peptide)
            for position, amino in enumerate(source_peptide):
                if amino != mutant_peptide[position]:
                    num_mutation += 1
                    mutate_position = position+1
                    source_amino = amino
                    mutate_amino = mutant_peptide[position]
            if num_mutation==1:     # means that we find its father, so no need to continue
                break
        # record "topk" messages
        mutate_table.append((source_peptide, mutant_peptide, mutate_position, source_amino, mutate_amino, prob_all[idx]))

    for mutate_info in mutate_table:
        print("source peptide: {}, mutated peptide: {} | {} {}->{} | binding probability: {:.4f}".format(
            *mutate_info))

    source_peptides = mutant_peptides       # for next iteration
    # Accumulate the winners of every iteration.
    output_peptides = output_peptides+source_peptides

[[ 2.2194965e-04 -6.1118021e-06  2.1368040e-04 -1.9649669e-06
   4.6532274e-05  3.8518747e-06 -2.8006889e-07  7.9149241e-04
  -3.7764412e-06]] (1, 9)
[ 2.2194965e-04 -6.1118021e-06  2.1368040e-04 -1.9649669e-06
  4.6532274e-05  3.8518747e-06 -2.8006889e-07  7.9149241e-04
 -3.7764412e-06] (9,)
8
Iteration-1, mutant_pool size: 19
source peptide: FLCSRRGHL, mutated peptide: FLCSRRGVL | 8 H->V | binding probability: 0.0003
source peptide: FLCSRRGHL, mutated peptide: FLCSRRGIL | 8 H->I | binding probability: 0.0005
source peptide: FLCSRRGHL, mutated peptide: FLCSRRGEL | 8 H->E | binding probability: 0.0005
source peptide: FLCSRRGHL, mutated peptide: FLCSRRGLL | 8 H->L | binding probability: 0.0010
source peptide: FLCSRRGHL, mutated peptide: FLCSRRGML | 8 H->M | binding probability: 0.0015
8
8
8
8
8
[[ 2.48277723e-03 -5.12296101e-05  6.24459470e-04 -7.24025886e-05
   7.70053011e-04 -3.24750727e-05 -9.63359344e-05 -1.00000000e+00
  -1.24187805e-04]
 [ 1.57261733e-03 -2.85052665e-04  1.8273319

---
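algorithm 1b: exhaustive substitution, keeping the top candidates per position and choosing the position with the best mean probability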

In [23]:
iteration = 3
beam_width = 5
batch_size = 64
amino_acid_list = [ "G", "A", "V", "L", "I", "P", "F", "Y", "W", "S",
                    "T", "C", "M", "N", "Q", "D", "E", "K", "R", "H"]
hla_max_len = 182
pep_max_len = 15

source_peptides = [init_peptide]
output_peptides = []

for i in range(iteration):
    # 1. Make a pool of candidate mutated peptides
    mutant_pool = []
    flag = 0
    frozen_position = []
    for source_peptide in source_peptides:
        for ind, amino in enumerate(source_peptide):
            if amino == init_peptide[ind]:              # find non-mutated position against given peptide
                for sub_amino in amino_acid_list:       # replace amino at the position
                    if sub_amino != amino:
                        new_peptide = source_peptide[:ind] + sub_amino + source_peptide[ind+1:]
                        mutant_pool.append(new_peptide)
            else:
                if flag == 0:
                    frozen_position.append(ind)
        flag = 1
    print("Iteration-{}, mutant_pool size: {}".format(i+1, len(mutant_pool)))
    print(frozen_position)

    # 2. Use our finetuned TAPE to calculate binding porbability
    # between all peptides in mutant_pool and the given HLA
    _, _, hla_pep_tokens = seq2token(tokenizer, HLA_seq, mutant_pool, hla_max_len, pep_max_len)
    # print(hla_pep_tokens.shape)
    prob_all, score_all = [], []
    with torch.no_grad():
        start_index = 0
        end_index = batch_size if len(mutant_pool) > batch_size else len(mutant_pool)

        while end_index <= len(mutant_pool) and start_index < end_index:
            hla_pep_inputs = torch.LongTensor(hla_pep_tokens[start_index:end_index]).to(device)
            # print(hla_pep_inputs.shape)
            model_output = model(hla_pep_inputs)
            prob = nn.Softmax(dim=1)(model_output)[:, 1].cpu().detach().numpy() # 1-D
            score = (model_output[:, 1] - model_output[:, 0]).cpu().detach().numpy()
            prob_all.append(prob)
            score_all.append(score)

            start_index = end_index
            if end_index + batch_size < len(mutant_pool):
                end_index += batch_size
            else:
                end_index = len(mutant_pool)

        prob_all = np.concatenate(prob_all)
        score_all = np.concatenate(score_all)
    
    # 3. Rank and choose topk at each position, then average and choose the best position
    sorted_id = np.argsort(-score_all)
    id_table = np.zeros((len(init_peptide), beam_width), dtype=int)
    prob_table = np.zeros((len(init_peptide), beam_width))
    num_record = np.zeros(len(init_peptide), dtype=int)
    for pos in frozen_position:
        num_record[pos] = beam_width        # when sum(num_record==beam_width)==len(init_peptide), stop searching
    print(num_record, id_table.shape)
    for id in sorted_id:
        mutant_peptide = mutant_pool[id]
        
        # find a father of the mutated peptide
        num_mutation, mutate_position = 0, 0
        for source_peptide in source_peptides:
            for position, amino in enumerate(source_peptide):
                if amino != mutant_peptide[position]:
                    num_mutation += 1
                    mutate_position = position
            if num_mutation==1:     # means that we find its father, so no need to continue
                break
        
        # record
        if num_record[mutate_position] < beam_width:
            id_table[mutate_position, num_record[mutate_position]] = id
            prob_table[mutate_position, num_record[mutate_position]] = prob_all[id]
            num_record[mutate_position] += 1
        
        # when to stop
        if np.sum(num_record==beam_width)==len(init_peptide):
            break
    print(id_table)
    print(prob_table)
    prob_table = np.mean(prob_table, axis=1)
    print(prob_table)
    best_position = np.argsort(prob_table)[-1]
    print(best_position, id_table[best_position])

    # 4. print "topk" messages
    mutant_peptides = []
    mutate_table = []               # (source peptide, mutate peptide, mutate position, source amino, substitution, probablity)
    for id in id_table[best_position]:
        mutant_peptide = mutant_pool[id]
        mutant_peptides.append(mutant_peptide)

        # find a father of the mutated peptide
        for source_peptide in source_peptides:
            num_mutation, mutate_position = 0, 0
            source_amino, mutate_amino = "", ""
            assert len(source_peptide)==len(mutant_peptide)
            for position, amino in enumerate(source_peptide):
                if amino != mutant_peptide[position]:
                    num_mutation += 1
                    mutate_position = position+1
                    source_amino = amino
                    mutate_amino = mutant_peptide[position]
            if num_mutation==1:     # means that we find its father, so no need to continue
                break
        # record "topk" messages
        mutate_table.append((source_peptide, mutant_peptide, mutate_position, source_amino, mutate_amino, prob_all[id]))
    
    for order, mutate_info in enumerate(mutate_table):
        print("source peptide: {}, mutated peptide: {} | {} {}->{} | binding probability: {:.6f}".format(
            mutate_info[0], mutate_info[1], mutate_info[2], mutate_info[3], mutate_info[4], mutate_info[5]))
        # record_file.write(f"\n>Mutate{i+1}_{order+1}\n{mutate_info[1]}")

    source_peptides = mutant_peptides       # for next iteration
    output_peptides = output_peptides+source_peptides

Iteration-1, mutant_pool size: 171
[]
[0 0 0 0 0 0 0 0 0] (9, 5)
[[  8  12  15   1  13]
 [ 30  26  22  21  24]
 [ 52  53  46  42  47]
 [ 63  71  58  68  65]
 [ 80  78  79  82  84]
 [ 99  98 110 109 100]
 [114 119 122 115 120]
 [145 136 149 137 135]
 [158 157 159 163 154]]
[[1.34379568e-03 1.23728963e-03 5.92322845e-04 4.06541309e-04
  1.81514290e-04]
 [2.76961837e-05 2.38546181e-05 1.85516037e-05 1.27769790e-05
  1.21431840e-05]
 [8.20583463e-01 5.51627390e-02 1.18687516e-02 4.45287535e-03
  2.36564665e-03]
 [2.50874054e-05 2.28368808e-05 1.72949894e-05 1.50619626e-05
  1.15439543e-05]
 [5.00832126e-03 4.63556685e-03 3.44622019e-03 4.12000401e-04
  3.09500843e-04]
 [1.79030690e-02 6.97583251e-04 1.85329613e-04 1.03554572e-04
  4.06395047e-05]
 [9.92968329e-04 3.64434556e-04 8.10366109e-05 5.77706342e-05
  5.06014658e-05]
 [1.47930451e-03 1.03686389e-03 5.38847758e-04 5.31861093e-04
  2.73193989e-04]
 [9.72742662e-02 3.12433491e-04 2.40415466e-05 9.83095742e-06
  4.24423979e-06]]
[7.522

-----
# algorithm2+

In [19]:
# Beam search over peptide edits (substitution / deletion / insertion).
# Each round, every source peptide is probed with <unk> to find the single
# position whose masking raises the predicted binding probability the most;
# candidates are generated at that position, scored by the fine-tuned model,
# and the `beam_width` best-scoring candidates seed the next round.
# Relies on names defined in earlier cells: model, device, tokenizer,
# seq2token, HLA_seq, init_peptide, torch, nn, np.

# ---- configuration ----
iteration = 5           # number of search rounds
beam_width = 5          # candidates kept after each round
amino_acid_list = [ "G", "A", "V", "L", "I", "P", "F", "Y", "W", "S",
                    "T", "C", "M", "N", "Q", "D", "E", "K", "R", "H"]
hla_max_len = 182
pep_max_len = 15

batch_size = 16
min_len = 8             # deletion is only tried on peptides longer than this
max_len = 14            # insertion is only tried on peptides shorter than this
unk_token_id = 4        # <unk> is 4


def _predict_binding(token_batch):
    """Run the model on a 2-D array of token ids.

    Returns a pair of 1-D numpy arrays (prob, score): `prob` is the softmax
    value of output column 1, `score` is output column 1 minus column 0.
    """
    with torch.no_grad():
        inputs = torch.LongTensor(token_batch).to(device)
        logits = model(inputs)
        prob = nn.Softmax(dim=1)(logits)[:, 1].cpu().detach().numpy()
        score = (logits[:, 1] - logits[:, 0]).cpu().detach().numpy()
    return prob, score


def _masking_gain(origin_token, masked_token):
    """Return prob(masked) - prob(origin) for one origin/masked token pair."""
    # Both rows go through the model together, in the order (origin, masked).
    prob, _ = _predict_binding(np.concatenate((origin_token, masked_token), axis=0))
    return prob[1] - prob[0]        # masked - origin


def _score_pool(token_array, chunk_size):
    """Score every row of `token_array` in consecutive chunks of `chunk_size`.

    Returns (prob_all, score_all), each the concatenation of the per-chunk
    results of `_predict_binding`.
    """
    probs, scores = [], []
    for start in range(0, len(token_array), chunk_size):
        prob, score = _predict_binding(token_array[start:start + chunk_size])
        probs.append(prob)
        scores.append(score)
    return np.concatenate(probs), np.concatenate(scores)


# Column of the first peptide residue inside the encoded HLA+peptide row.
# NOTE(review): the +1 presumably skips a start token prepended by
# tokenizer.encode -- confirm against the tokenizer.
pep_start = hla_max_len + 1

source_peptides = [init_peptide]
output_peptides = []

for i in range(iteration):
    # 1. Make a pool of candidate mutated peptides
    mutant_pool = []
    for source_peptide in source_peptides:
        _, _, hla_spep_token = seq2token(tokenizer, HLA_seq, source_peptide, hla_max_len, pep_max_len)

        # 1.1 replace and delete
        # (1) replace each residue with <unk> in turn and measure the saliency
        saliency1_all = []
        for ind in range(len(source_peptide)):
            mask_token = hla_spep_token.copy()      # copy() is necessary: assignment below is in place
            mask_token[0][pep_start + ind] = unk_token_id
            saliency1_all.append(_masking_gain(hla_spep_token, mask_token))

        # (2) saliency ranking: largest gain = best position to be replaced
        saliency1_all = np.array(saliency1_all)
        mask_position = np.argsort(saliency1_all)[-1]

        # (3) replace the amino at mask_position with every other amino acid
        head, tail = source_peptide[:mask_position], source_peptide[mask_position+1:]
        for sub_amino in amino_acid_list:
            if sub_amino != source_peptide[mask_position]:
                mutant_pool.append(head + sub_amino + tail)

        # (4) delete the amino at mask_position
        if len(source_peptide) > min_len:
            mutant_pool.append(head + tail)

        # 1.2 insert
        if len(source_peptide) < max_len:
            # (1) insert <unk> at every gap (including both ends) and measure the saliency
            saliency2_all = []
            for ind in range(len(source_peptide)+1):
                # np.insert returns a new array, so the source tokens are untouched.
                mask_token = np.insert(hla_spep_token, pep_start + ind, unk_token_id, axis=1)
                # Remove one column so the row width matches hla_spep_token again
                # (required by the row-wise concatenation in _masking_gain).
                # NOTE(review): this column should hold peptide padding given
                # len(source_peptide) < max_len < pep_max_len -- confirm.
                mask_token = np.delete(mask_token, hla_max_len+pep_max_len, axis=1)
                saliency2_all.append(_masking_gain(hla_spep_token, mask_token))

            # (2) saliency ranking: largest gain = best gap to insert into
            saliency2_all = np.array(saliency2_all)
            mask_position = np.argsort(saliency2_all)[-1]

            # (3) insert every amino acid at mask_position
            for sub_amino in amino_acid_list:
                mutant_pool.append(source_peptide[:mask_position] + sub_amino + source_peptide[mask_position:])

    # De-duplicate; sorting keeps the pool order deterministic.
    mutant_pool = sorted(set(mutant_pool))
    print("Iteration-{}, mutant_pool size: {}".format(i+1, len(mutant_pool)))

    # 2. Use our finetuned TAPE to calculate the binding probability
    # between all peptides in mutant_pool and the given HLA
    _, _, hla_pep_tokens = seq2token(tokenizer, HLA_seq, mutant_pool, hla_max_len, pep_max_len)
    prob_all, score_all = _score_pool(hla_pep_tokens, batch_size)

    # 3. Rank(score, not prob) and print "topk" messages
    # argsort is ascending, so the best candidate is printed last.
    topk_id = np.argsort(score_all)[-beam_width:]
    mutant_peptides = []
    for pool_idx in topk_id:
        mutant_peptides.append(mutant_pool[pool_idx])
        print("mutated peptide: {} | length: {} | binding probability: {:.4f}".format(
            mutant_pool[pool_idx], len(mutant_pool[pool_idx]), prob_all[pool_idx]))

    source_peptides = mutant_peptides       # for next iteration
    output_peptides = output_peptides+source_peptides


Iteration-1, mutant_pool size: 40
mutated peptide: FLCSRRGLHL | length: 10 | binding probability: 0.0007
mutated peptide: FLCSRRGLL | length: 9 | binding probability: 0.0010
mutated peptide: FLCSRRGML | length: 9 | binding probability: 0.0015
mutated peptide: FLCSRRGAHL | length: 10 | binding probability: 0.0023
mutated peptide: FLCSRRGIHL | length: 10 | binding probability: 0.0025
Iteration-2, mutant_pool size: 198
mutated peptide: FLCSRRGAHLL | length: 11 | binding probability: 0.6326
mutated peptide: FLCSRRPGLHL | length: 11 | binding probability: 0.7582
mutated peptide: FLCSRRPGIHL | length: 11 | binding probability: 0.8923
mutated peptide: FLCSRRGAHLF | length: 11 | binding probability: 0.9602
mutated peptide: FLCSRRGAHLY | length: 11 | binding probability: 0.9984
Iteration-3, mutant_pool size: 198
mutated peptide: FLCSRIGAHLY | length: 11 | binding probability: 1.0000
mutated peptide: FLDSRRPGLHL | length: 11 | binding probability: 1.0000
mutated peptide: FLDSRRGAHLL | length: 11

In [10]:
# Scratch check: np.insert with axis=1 adds one new column (filled with 4)
# at column index 1, widening the 5x5 array to 5x6.
a = np.ones(shape=(5, 5))
b = np.insert(a, obj=1, values=4, axis=1)
b

array([[1., 4., 1., 1., 1., 1.],
       [1., 4., 1., 1., 1., 1.],
       [1., 4., 1., 1., 1., 1.],
       [1., 4., 1., 1., 1., 1.],
       [1., 4., 1., 1., 1., 1.]])

In [8]:
# Scratch check: np.delete with axis=1 removes the column at index 1 of `b`
# (`b` comes from the np.insert scratch cell).
c = np.delete(b, obj=1, axis=1)
c

array([[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]])

In [14]:
# Scratch check: slicing a 9-element list from index 9 yields an empty list,
# so the concatenation below equals the original list.
a = [1] * 9
a + a[9:]

[1, 1, 1, 1, 1, 1, 1, 1, 1]

In [20]:
# Scratch check: range(8, 14) yields 8 through 13 (the upper bound is exclusive).
# NOTE(review): presumably mirrors the min_len=8 / max_len=14 peptide-length limits -- confirm.
for i in range(8,14):
    print(i)

8
9
10
11
12
13
