In [None]:
# train the model
import numpy as np
import pandas as pd
import os
from sim import *

# ---- Training configuration ------------------------------------------------
# Label for the active state set; not read again anywhere in this notebook.
state_d_names = 'standard_'
# State definitions forwarded to est_para through its `state_d` argument.
state_d = [[(1,0)],[(0,1)],[(1,1)]]

# Alternative "advanced" state set: swap these in for the two lines above.
# state_d_names = 'advanced_'
# state_d = [[(1,0)],[(0,1)],[(1,1),(1,2),(2,1)]]

# Iteration count handed to est_para. A later cell rebinds `iterations`
# for the sampler, so run the cells in order.
iterations = 500

# NOTE(review): the four settings below are defined but never read in this
# notebook -- in particular est_para is always re-run regardless of
# use_existing_est. Confirm whether they are still needed.
seeds = np.arange(0, 3000, 1000, dtype=int)
all_seed_pred = False
trans_restrict = [(1,0)]
use_existing_est = True

# ---- Load the data ---------------------------------------------------------
temp_data_name = 'full_data_80'
raw_data = pd.read_csv(os.path.join('..','data', temp_data_name+'.csv'),
                       header=0)

# Keep only the three columns used downstream.
train_data = raw_data[['receptor','peptide','binder']]

# ---- Estimate the parameters -----------------------------------------------
# The trailing '' makes the joined path end with a separator.
store_path = os.path.join('..','result','')

# makedirs(exist_ok=True) replaces the exists()/mkdir() pair: it has no
# check-then-create race and also builds any missing parent directories.
os.makedirs(store_path, exist_ok=True)

# Only the rows labelled as binders (binder == 1) are used for estimation.
est_para(train_data[train_data['binder'] == 1], store_path = store_path,
        state_d = state_d, iterations = iterations)

In [None]:
############## get the trained model
# Read the estimated parameters back from `<store_path>/ml` and build the
# model object that every later cell uses as `model1`.
para_path = os.path.join(store_path, 'ml')
(
    init_dis,
    trans_mat,
    px,
    py,
    pm,
    p_mx,
    p_my,
    dire_dis,
) = read_para(para_path)

model1 = sim(
    init_dis = init_dis,
    trans_mat = trans_mat,
    px = px,
    py = py,
    pm = pm,
    p_mx = p_mx,
    p_my = p_my,
    dire_dis = dire_dis,
)

In [None]:
############## generate peptides that bind to a specific TCR
# Convention used below -- x: TCR, y: peptide

# The TCR sequence that stays fixed while the peptide is resampled.
tcr_fix_seq = 'CASSIRSSYEQYF'

# Sampler settings. `iterations` is rebound here (it was 500 for training).
fix_len, burn_in, iterations = 9, 50, 150

# Sampled peptides are written under this directory.
generate_path = os.path.join('..', 'result')

# Residue alphabet as a list so it can be passed to np.random.choice.
# `aa_kinds` is not defined in this notebook; presumably it is provided by
# the star-import from `sim`.
aa_kinds_list = list(aa_kinds)

# Random starting peptide (sequence Y): fix_len residues drawn with replacement.
initial_residues = np.random.choice(aa_kinds_list, size=fix_len, replace = True)
initial_seq = ''.join(initial_residues)
bound_pep_seq = initial_seq


def change_seq(temp_seq:str, change_col, changed_res):
    """Return a copy of ``temp_seq`` with one position replaced.

    Parameters
    ----------
    temp_seq : str
        Sequence to copy; it is not modified.
    change_col : int
        Index of the position to overwrite (list indexing rules apply, so
        negative indices count from the end and out-of-range raises
        ``IndexError``).
    changed_res : str
        Replacement placed at ``change_col``.

    Returns
    -------
    str
        The sequence re-joined after the substitution.
    """
    residues = [*temp_seq]
    residues[change_col] = changed_res
    return ''.join(residues)


def cal_pix(temp_seq:str):
    """Score a candidate peptide against the fixed TCR.

    The pair ``[tcr_fix_seq, temp_seq]`` is encoded with ``model1.trans_seq``
    and the encoded pair is passed to ``model1.cal_one_lik``; that value is
    returned unchanged. Both ``model1`` and ``tcr_fix_seq`` are read from the
    notebook's global namespace at call time.
    """
    encoded_pair = model1.trans_seq([tcr_fix_seq, temp_seq])
    return model1.cal_one_lik(encoded_pair)


def cal_gibbs_p(temp_seq:str, change_col):
    """Conditional distribution over residues for one peptide position.

    Every residue in ``aa_kinds_list`` is substituted at ``change_col`` in
    turn, the resulting peptide is scored with ``cal_pix``, and the scores
    are divided by their sum.

    Returns
    -------
    numpy.ndarray
        One weight per entry of ``aa_kinds_list``, in the same order.
        If all scores are zero the division is 0/0 and the result is not a
        valid probability vector.
    """
    scores = []
    for candidate in aa_kinds_list:
        scores.append(cal_pix(change_seq(temp_seq, change_col, candidate)))
    weights = np.array(scores)
    return weights / weights.sum()


# Log-likelihood trace: entry 0 is the random starting peptide, followed by
# one entry per completed sweep.
logliks = []
cur_lik = cal_pix(bound_pep_seq)
logliks.append(np.log(cur_lik))
# Peptides retained after burn-in, one per sweep.
pep_list = []

for i in range(iterations):

    # Gibbs sweep: resample every position in turn, each time conditioning
    # on the peptide as updated so far.
    for change_col in range(fix_len):
        changed_p = cal_gibbs_p(bound_pep_seq, change_col)
        changed_res = np.random.choice(aa_kinds_list, p = changed_p, size=None, replace = True)
        bound_pep_seq = change_seq(bound_pep_seq, change_col, changed_res)

    # Only the end-of-sweep likelihood is recorded, so score once per sweep
    # rather than once per position (assumes cal_pix has no side effects).
    cur_lik = cal_pix(bound_pep_seq)
    logliks.append(np.log(cur_lik))

    # i is zero-based: sweeps 0 .. burn_in-1 are the burn_in discarded sweeps.
    # (The previous `i > burn_in` dropped burn_in + 1 sweeps.)
    if i >= burn_in:
        pep_list.append(bound_pep_seq)


pep_df = pd.DataFrame(pep_list,columns=['pep'])
# Make sure the target directory exists so this cell does not depend on the
# training cell having created it.
os.makedirs(generate_path, exist_ok=True)
pep_df.to_csv(os.path.join(generate_path, tcr_fix_seq+'_pep.csv'), index=False)

In [None]:
############## generate peptides that bind to multiple TCR
# Candidate TCR sequences; the sampler switches between them via an index z.
sel_tcrs = [
    'CASSIRSSYEQYF',
    'CASSWGGGSHYGYTF',
    'CASSFSGNTGELFF',
    'CASSIRSAYEQYF',
    'CASSLRDGSEAFF',
]
tcr_n = len(sel_tcrs)

# Convention used below -- x: TCR, y: peptide
# Sampler settings (same values as the single-TCR cell).
fix_len, burn_in, iterations = 9, 50, 150

# Sampled peptides are written under this directory.
generate_path = os.path.join('..', 'result')

# Residue alphabet as a list so it can be passed to np.random.choice.
aa_kinds_list = list(aa_kinds)

# Random starting state: a peptide of fix_len residues drawn with
# replacement, then a TCR index drawn from range(tcr_n).
initial_residues = np.random.choice(aa_kinds_list, size=fix_len, replace = True)
initial_seq = ''.join(initial_residues)
initial_z = np.random.choice(tcr_n, size=None, replace = True)

bound_pep_seq = initial_seq
bound_z = initial_z


def cal_pixz(temp_seq, temp_z):
    """Score a candidate peptide against the TCR selected by ``temp_z``.

    ``sel_tcrs[temp_z]`` and ``temp_seq`` are encoded as a pair with
    ``model1.trans_seq`` and the encoded pair is passed to
    ``model1.cal_one_lik``; that value is returned unchanged. ``model1`` and
    ``sel_tcrs`` are read from the notebook's global namespace.
    """
    encoded_pair = model1.trans_seq([sel_tcrs[temp_z], temp_seq])
    return model1.cal_one_lik(encoded_pair)


# NOTE: this rebinds `cal_gibbs_p`, replacing the two-argument version
# defined in the single-TCR cell for the rest of the kernel session.
def cal_gibbs_p(temp_seq:str, change_col, temp_z):
    """Conditional distribution over residues for one position, given a TCR.

    Every residue in ``aa_kinds_list`` is substituted at ``change_col`` in
    turn, the resulting peptide is scored with ``cal_pixz`` against the TCR
    indexed by ``temp_z``, and the scores are divided by their sum.

    Returns
    -------
    numpy.ndarray
        One weight per entry of ``aa_kinds_list``, in the same order.
        If all scores are zero the division is 0/0 and the result is not a
        valid probability vector.
    """
    scores = []
    for candidate in aa_kinds_list:
        scores.append(cal_pixz(change_seq(temp_seq, change_col, candidate), temp_z))
    weights = np.array(scores)
    return weights / weights.sum()


# Log-likelihood trace: entry 0 is the random starting state, followed by
# one entry per completed peptide sweep.
logliks = []
cur_lik = cal_pixz(bound_pep_seq, bound_z)
logliks.append(np.log(cur_lik))
# Peptides retained after burn-in, one per sweep.
pep_list = []

# Hybrid Gibbs Sampling: alternate between resampling the peptide given the
# current TCR index and resampling the TCR index given the peptide.
for i in range(iterations):

    # update x Gibbs: resample every peptide position in turn, conditioning
    # on the peptide as updated so far and on the current TCR index.
    for change_col in range(fix_len):
        changed_p = cal_gibbs_p(bound_pep_seq, change_col, bound_z)
        changed_res = np.random.choice(aa_kinds_list, p = changed_p, size=None, replace = True)
        bound_pep_seq = change_seq(bound_pep_seq, change_col, changed_res)

    # Only the end-of-sweep likelihood is recorded, so score once per sweep
    # rather than once per position (assumes cal_pixz has no side effects).
    # It is evaluated under the TCR index used for this sweep, before z moves.
    cur_lik = cal_pixz(bound_pep_seq, bound_z)
    logliks.append(np.log(cur_lik))

    # i is zero-based: sweeps 0 .. burn_in-1 are the burn_in discarded sweeps.
    # (The previous `i > burn_in` dropped burn_in + 1 sweeps.)
    if i >= burn_in:
        pep_list.append(bound_pep_seq)

    # update z Gibbs: score the current peptide against every candidate TCR
    # and draw the next TCR index in proportion to those scores.
    pz = np.array([cal_pixz(bound_pep_seq, j) for j in range(tcr_n)])
    pz = pz/pz.sum()
    bound_z = np.random.choice(tcr_n, p=pz,size=None, replace = True)


pep_df = pd.DataFrame(pep_list,columns=['pep'])
# Use a dedicated name for the output file stem. Rebinding `tcr_fix_seq`
# here would make cal_pix (which reads that global as the TCR sequence)
# score against the string 'mix_...' if it were called afterwards.
mix_file_stem = 'mix_' + ''.join(sel_tcrs)
# Make sure the target directory exists so this cell does not depend on the
# training cell having created it.
os.makedirs(generate_path, exist_ok=True)
pep_df.to_csv(os.path.join(generate_path, mix_file_stem+'_pep.csv'), index=False)


In [None]:
############## generate negative data
# NOTE(review): pos_df is the whole of train_data, which was not filtered on
# `binder` when it was built -- confirm the input CSV holds binders only.
pos_df = train_data
pos_n = len(train_data)

# Generate as many negative pairs as there are rows in train_data; the
# comment in the original describes them as drawn from px and py.
neg_data = model1.generate_neg_data(seq_len = 10, data_size = pos_n)

# Each generated item is indexed as (x, y). The x part is wrapped in the
# literal prefix 'CASS' and suffix 'F'; the y part is used as-is.
neg_pairs = [('CASS'+seqs[0]+'F', seqs[1]) for seqs in neg_data]
x_seqs = [pair[0] for pair in neg_pairs]
y_seqs = [pair[1] for pair in neg_pairs]

# Every generated pair is labelled binder = 0.
neg_df = pd.DataFrame({'receptor':x_seqs,'peptide':y_seqs,'binder':0})

# Positive rows first, then the generated negatives, with a fresh index.
total_result = pd.concat([pos_df, neg_df],ignore_index=True)
# total_result.to_csv(os.path.join('..','data', data_name+'pos_neg.csv'), 
#                     index=False, doublequote = False)

In [None]:
##############  predict the binding affinity
test_data = total_result

# Build the [receptor, peptide] pairs by column name. The earlier
# `for _, i in test_data.iterrows(): [i[0], i[1]]` relied on integer keys
# falling back to positional access on a label-indexed Series, which pandas
# has deprecated; selecting the columns explicitly is also much faster than
# iterating row by row.
seq_data = test_data[['receptor', 'peptide']].values.tolist()
trans_all_test_data = model1.trans_all_seq(seq_data)


# calculate the likelihood ratio (binding ratio) to predict the binding affinity
pred_score = model1.cal_all_lr(trans_all_test_data)

# Attach the scores to a copy so test_data / total_result stay untouched.
temp_result = test_data.copy()
temp_result['pred_score'] = pred_score
temp_result
# temp_result.to_csv(store_seed_path+'_'+'pred_score.csv', index=False, doublequote = False)

In [None]:
############## predict the binding residues
# Use the Viterbi algorithm to predict the state sequence for every pair.
# The state result contains the columns x_state, y_state and all_state:
#   - all_state: the state sequence for the pair of sequences as a whole
#   - x_state:   the states assigned to sequence X
#   - y_state:   the states assigned to sequence Y
#
# Positions marked 'M' in x_state are the predicted binding residues of
# sequence X; positions marked 'M' in y_state are those of sequence Y.
#
# Requires trans_all_test_data from the binding-affinity cell above.
state_result = model1.pred_all_seq_state(trans_all_test_data)
state_result
# state_result.to_csv(store_seed_path+'_'+'pred_states.csv', index=False, doublequote = False)