# Tutorial: identifying neoantigens in clinical cancer patients with deepAntigen

## Import relevant packages

In [1]:
import pandas as pd
import pickle
import numpy as np
from rdkit import Chem
from tqdm import tqdm
import torch
import torch.backends.cudnn as cudnn
from torch_geometric.data import Batch
from torch_geometric import data as DATA
import torch.multiprocessing as mp
from multiprocessing import Process
from deepAntigen.antigenHLAI.load_dataset.featurizer import MolGraphConvFeaturizer
from deepAntigen.antigenHLAI.networks.pHLAI_seq import DeepGCN as PHLA
from deepAntigen.antigenTCR.networks.pTCR_seq import DeepGCN as PTCR

In [2]:
# Run on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:{}".format(0))
else:
    device = torch.device("cpu")

## Define utilities for loading models and processing data

In [3]:
def set_model(model_path, model_name):
    """Load a pretrained deepAntigen network from a checkpoint file.

    The checkpoint must be a mapping with a ``'model'`` entry (the state dict)
    and an ``'opt'`` entry (the arguments passed to the network constructor).

    Args:
        model_path: Path to the checkpoint written with ``torch.save``.
        model_name: ``'PHLA'`` for the peptide-HLA network or ``'PTCR'`` for
            the peptide-TCR network.

    Returns:
        The network with its pretrained weights loaded, moved to ``device``
        when CUDA is available.

    Raises:
        ValueError: If ``model_name`` is not ``'PHLA'`` or ``'PTCR'``.
    """
    # Validate before touching the file so a typo fails fast with a clear message.
    if model_name == 'PHLA':
        model_cls = PHLA
    elif model_name == 'PTCR':
        model_cls = PTCR
    else:
        raise ValueError(
            f"Unknown model_name {model_name!r}; expected 'PHLA' or 'PTCR'"
        )
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
    # NOTE(review): torch>=2.6 defaults torch.load(weights_only=True); if
    # state['opt'] is a plain Python object (e.g. argparse.Namespace) this
    # trusted checkpoint may need weights_only=False -- confirm for your torch.
    state = torch.load(model_path, map_location=device)
    model = model_cls(state['opt'])
    model.load_state_dict(state['model'])
    if torch.cuda.is_available():
        model = model.to(device)
        # Inputs of recurring shapes benefit from cuDNN autotuning.
        cudnn.benchmark = True
    return model

def predict_PHLA(model, peptide_graph, pseudo_graph):
    """Score a single peptide / HLA pseudo-sequence pair.

    Puts ``model`` in eval mode and runs it without gradient tracking.

    Args:
        model: Network returned by ``set_model(..., 'PHLA')``.
        peptide_graph: Graph of the peptide (as built by ``generateGraph``).
        pseudo_graph: Graph of the HLA pseudo-sequence (as built by
            ``generateGraph``).

    Returns:
        Python float: column 1 of the model output for the pair (presumably
        the positive-class score -- confirm against the network definition).
    """
    model.eval()
    with torch.no_grad():
        # The model consumes batches, so wrap each graph in a batch of one.
        peptide_batch = Batch.from_data_list([peptide_graph]).to(device)
        pseudo_batch = Batch.from_data_list([pseudo_graph]).to(device)
        logits = model(peptide_batch, pseudo_batch)
    return logits[:, 1].item()

def predict_PTCR(model, peptide_graphs, cdr3_graphs):
    """Score a batch of peptide / CDR3 pairs.

    Puts ``model`` in eval mode and runs it without gradient tracking.

    Args:
        model: Network returned by ``set_model(..., 'PTCR')``.
        peptide_graphs: List of peptide graphs (as built by ``generateGraph``).
        cdr3_graphs: List of CDR3 graphs; presumably paired element-wise with
            ``peptide_graphs`` and of the same length -- confirm.

    Returns:
        List of numpy scalars holding column 1 of the model output, one per
        pair, on the CPU. Empty list when either input list is empty.
    """
    if len(peptide_graphs) == 0 or len(cdr3_graphs) == 0:
        # Nothing to score; avoids building an empty Batch.
        return []
    model.eval()
    with torch.no_grad():
        peptide_batch = Batch.from_data_list(peptide_graphs).to(device)
        cdr3_batch = Batch.from_data_list(cdr3_graphs).to(device)
        logits = model(peptide_batch, cdr3_batch)
    return list(logits[:, 1].cpu().numpy())

def generateGraph(seqs, threading_num):
    """Featurize unique amino-acid sequences into graphs using worker processes.

    Args:
        seqs: Iterable of sequence strings; duplicates are featurized once.
        threading_num: Maximum number of worker processes to start (these are
            ``multiprocessing.Process`` workers, despite the parameter name).
            Clamped to at least 1 and at most the number of unique sequences.

    Returns:
        dict mapping each unique sequence to the graph object built for it by
        ``generateGraph_subprocess``. An empty dict for empty input.

    Raises:
        RuntimeError: If any worker process exits with a non-zero exit code.
    """
    seq_set = set(seqs)
    if not seq_set:
        # Nothing to featurize, and chunk_set cannot split into zero chunks.
        return {}
    n_workers = max(1, min(threading_num, len(seq_set)))
    chunks = chunk_set(seq_set, n_workers)
    seq_graph = {}
    # The manager runs its own server process; the context manager shuts it
    # down on exit instead of leaking it on every call.
    with mp.Manager() as seq_manager:
        results = seq_manager.list([])
        processes = [
            Process(target=generateGraph_subprocess, args=(chunk, results))
            for chunk in chunks
        ]
        for process in processes:
            process.start()
        for process in processes:
            process.join()
        failed = [p.exitcode for p in processes if p.exitcode != 0]
        if failed:
            raise RuntimeError(
                f'{len(failed)} graph-featurization worker(s) failed '
                f'(exit codes: {failed})'
            )
        # Read the shared list before the manager (and its proxy) shuts down.
        for graph_dict in results:
            seq_graph.update(pickle.loads(graph_dict))
    return seq_graph

def generateGraph_subprocess(seqs, queue):
    """Featurize a chunk of sequences and append the pickled result to ``queue``.

    Used as the worker target by ``generateGraph``. Each sequence is converted
    with ``Chem.MolFromSequence``, featurized with
    ``MolGraphConvFeaturizer(use_edges=True)`` and wrapped in a torch_geometric
    ``Data`` object with ``x`` (node features), ``edge_index`` and
    ``edge_attr`` (edge features).

    Args:
        seqs: Sized collection of sequence strings (one chunk from ``chunk_set``).
        queue: Shared list-like object (e.g. a ``Manager().list()`` proxy) that
            receives a single pickled ``{sequence: Data}`` dict.

    Raises:
        ValueError: If RDKit returns no molecule for a sequence.
    """
    featurizer = MolGraphConvFeaturizer(use_edges=True)
    graphs = {}
    for seq in tqdm(seqs, total=len(seqs)):
        seq_chem = Chem.MolFromSequence(seq)
        if seq_chem is None:
            # Fail with the offending sequence instead of an opaque featurizer error.
            raise ValueError(f'RDKit could not build a molecule from sequence {seq!r}')
        seq_feature = featurizer._featurize(seq_chem)
        graphs[seq] = DATA.Data(
            x=torch.Tensor(seq_feature.node_features),
            edge_index=torch.LongTensor(seq_feature.edge_index),
            edge_attr=torch.Tensor(seq_feature.edge_features),
        )
    # Send the whole chunk as one bytes object through the shared list.
    queue.append(pickle.dumps(graphs))

def chunk_set(seq_set, n):
    """Split a collection into ``n`` contiguous, near-equal chunks.

    Chunk sizes differ by at most one; the first ``len(seq_set) % n`` chunks
    receive the extra element. Concatenating the chunks reproduces
    ``list(seq_set)`` in order. When ``n`` exceeds the number of elements,
    the trailing chunks are empty.

    Args:
        seq_set: Any iterable (typically a set of sequences).
        n: Number of chunks; must be at least 1.

    Returns:
        List of exactly ``n`` lists.

    Raises:
        ValueError: If ``n`` is less than 1.
    """
    if n < 1:
        raise ValueError(f'n must be a positive integer, got {n!r}')
    seq_list = list(seq_set)
    base_size, extra = divmod(len(seq_list), n)
    chunked = []
    start = 0
    for i in range(n):
        # Spread the remainder over the first chunks so work stays balanced.
        end = start + base_size + (1 if i < extra else 0)
        chunked.append(seq_list[start:end])
        start = end
    return chunked

# The 20 standard one-letter amino-acid codes (upper case).
_STANDARD_AMINO_ACIDS = frozenset('ACDEFGHIKLMNPQRSTVWY')


def check(seq):
    """Return True if ``seq`` is NOT a usable standard amino-acid sequence.

    Note the inverted sense: callers skip an entry with
    ``if check(seq): continue``. A value is rejected when it is not a string
    (e.g. a NaN read from an empty CSV cell), is empty, or contains any
    character outside the 20 standard upper-case one-letter codes.

    Args:
        seq: Candidate sequence (any object).

    Returns:
        True to reject ``seq``, False when it is valid.
    """
    if not isinstance(seq, str) or not seq:
        return True
    return not set(seq) <= _STANDARD_AMINO_ACIDS

## Load the variant peptides and HLA typing of a clinical cancer patient

In [11]:
# Read, in order: the patient's variant peptides, the patient's HLA typing,
# and the allele -> pseudo-sequence table shipped with deepAntigen.
variant_peptides, HLA, hla_pseudo = (
    pd.read_csv(csv_path, header=0)
    for csv_path in (
        './clinical_cancer_patients/#P980589/#P980589_variant_peptides.csv',
        './clinical_cancer_patients/#P980589/#P980589_hla.csv',
        './deepAntigen/antigenHLAI/load_dataset/hlaI_pseudo_seq.csv',
    )
)

## Use deepAntigen to identify candidate peptides that can be presented by the patient's HLA molecules

In [5]:
# Keep variant peptides made only of the 20 standard residues (check() is True
# for entries to skip), de-duplicated in first-seen order.
peptide_list = variant_peptides['MutationPeptide'].values
peptide_list2 = list(dict.fromkeys(p for p in peptide_list if not check(p)))

# Map each HLA-A/B/C/G allele typed for the patient to its pseudo-sequence.
tmHLA = HLA[HLA['class'].isin(['A', 'B', 'C', 'G'])]
pseudo_list = []
for allele in tmHLA['type'].values:
    # Drop an 'HLA-' prefix and keep the first two colon-separated fields,
    # e.g. 'HLA-A*02:01:01' -> 'A*02:01'; this also copes with 4-field names
    # and trailing expression suffixes.
    if allele.startswith('HLA-'):
        allele = allele[len('HLA-'):]
    allele = ':'.join(allele.split(':')[:2])
    matches = hla_pseudo.loc[hla_pseudo['allele'] == allele, 'sequence']
    if matches.empty:
        print(f'No HLA pseudo-sequence found for allele {allele}; skipping it')
        continue
    pseudo = matches.iloc[0]
    if check(pseudo):
        print(pseudo)
        continue
    pseudo_list.append(pseudo)
# The same pseudo-sequence can occur more than once (e.g. homozygous loci);
# score each distinct one once.
unique_pseudos = list(dict.fromkeys(pseudo_list))

if not peptide_list2:
    raise ValueError('No valid variant peptides to score')
if not unique_pseudos:
    raise ValueError('No valid HLA pseudo-sequences found for the patient alleles')

pHLA_model = set_model('./deepAntigen/antigenHLAI/Weights/seq-level_parameters.pt', 'PHLA')
peptide_graphs = generateGraph(peptide_list2, 8)
pseudo_graphs = generateGraph(unique_pseudos, 8)
present_peptides = []
present_scores = []
for peptide in peptide_list2:
    peptide_graph = peptide_graphs[peptide]
    # Keep the best score over all of the patient's pseudo-sequences.
    scores = [
        predict_PHLA(pHLA_model, peptide_graph, pseudo_graphs[pseudo])
        for pseudo in unique_pseudos
    ]
    present_peptides.append(peptide)
    present_scores.append(max(scores))
del pHLA_model
torch.cuda.empty_cache()

# Keep peptides whose best score exceeds 0.5, highest score first.
candidate_peptides = pd.DataFrame({"peptide": present_peptides, "scores": present_scores})
candidate_peptides = candidate_peptides[candidate_peptides['scores'] > 0.5].copy()
candidate_peptides.sort_values('scores', ascending=False, inplace=True)

310it [00:08, 36.44it/s]
310it [00:08, 35.34it/s]
310it [00:08, 34.76it/s]
310it [00:09, 34.41it/s]
301it [00:08, 33.61it/s]
310it [00:08, 35.52it/s]
316it [00:08, 35.26it/s]
310it [00:09, 34.16it/s]
1it [00:00,  1.17it/s]
1it [00:00,  1.24it/s]
1it [00:00,  1.17it/s]
1it [00:00,  1.17it/s]
1it [00:00,  1.26it/s]
1it [00:01,  1.03s/it]


In [6]:
# Display the candidate peptides that passed the HLA-presentation filter
# (best score > 0.5), sorted by score, highest first.
candidate_peptides

Unnamed: 0,peptide,scores
2193,RADFDDTVTY,0.999678
16,SLLAHIWSL,0.999384
767,VFDCIINM,0.999355
2217,FLWLLPVQL,0.999256
1095,ASYTVQAKY,0.999150
...,...,...
191,LLMKQHGSA,0.508551
99,ASPAAAIPA,0.506944
1377,KNLLVLCVI,0.504052
1682,VSPAAAPQA,0.500519


## Load the TCR repertoire of the clinical cancer patient

In [7]:
# TCR repertoire of the same patient, ordered by readFraction, then readCount,
# then cloneId, all descending.
cdr3df = pd.read_csv(
    './clinical_cancer_patients/#P980589/#P980589_tcrs.csv', header=0
).sort_values(by=['readFraction', 'readCount', 'cloneId'], ascending=False)

## Use deepAntigen to identify immunogenic peptides that can be recognized by the patient's TCRs

In [8]:
# Peptides that passed the HLA-presentation filter.
peptide_list = candidate_peptides['peptide'].values

# From the first 20000 rows of the sorted repertoire, keep CDR3s made of the
# 20 standard residues, de-duplicate (first occurrence wins) and cap at 10000.
# A set tracks seen sequences so de-duplication stays linear.
max_cdr3s = 10000
cdr3_list = []
seen_cdr3s = set()
for cdr3 in cdr3df['aaSeqCDR3'].values[:20000]:
    if check(cdr3) or cdr3 in seen_cdr3s:
        continue
    seen_cdr3s.add(cdr3)
    cdr3_list.append(cdr3)
    if len(cdr3_list) == max_cdr3s:
        break

if len(peptide_list) == 0:
    raise ValueError('No candidate peptides passed the HLA presentation filter')
if not cdr3_list:
    raise ValueError('No valid CDR3 sequences in the TCR repertoire')

pTCR_model = set_model('./deepAntigen/antigenTCR/Weights/seq-level_parameters.pt', 'PTCR')
peptide_graphs = generateGraph(peptide_list, 8)
cdr3_graphs = generateGraph(cdr3_list, 8)
batchsize = 5000
immuno_peptides = []
immuno_scores = []
for peptide in dict.fromkeys(peptide_list):
    peptide_graph = peptide_graphs[peptide]
    scores = []
    # Score the peptide against every selected CDR3, `batchsize` pairs at a time.
    for start in range(0, len(cdr3_list), batchsize):
        batch_cdr3_graphs = [cdr3_graphs[cdr3] for cdr3 in cdr3_list[start:start + batchsize]]
        batch_peptide_graphs = [peptide_graph] * len(batch_cdr3_graphs)
        scores.extend(predict_PTCR(pTCR_model, batch_peptide_graphs, batch_cdr3_graphs))
    immuno_peptides.append(peptide)
    # Per-peptide score is the mean over all selected CDR3s.
    immuno_scores.append(np.mean(np.array(scores)))
rdf = pd.DataFrame({"peptide": immuno_peptides, "scores": immuno_scores})
rdf.sort_values('scores', ascending=False, inplace=True)
del pTCR_model
torch.cuda.empty_cache()

57it [00:01, 34.91it/s]
57it [00:01, 33.79it/s]
57it [00:01, 34.93it/s]
57it [00:02, 27.16it/s]
57it [00:01, 32.58it/s]
57it [00:01, 33.81it/s]
57it [00:01, 32.64it/s]
60it [00:01, 32.94it/s]
1250it [01:35, 13.07it/s]
1250it [01:36, 13.01it/s]
1250it [01:36, 12.99it/s]
1250it [01:37, 12.79it/s]
1250it [01:38, 12.74it/s]
1250it [01:38, 12.73it/s]
1241it [01:38, 13.17it/s]
1250it [01:38, 12.66it/s]


In [9]:
# Display each candidate peptide's mean pTCR model score over the selected
# CDR3 sequences, sorted highest first.
rdf

Unnamed: 0,peptide,scores
13,LLVTQEILRT,0.629452
180,VADVDISR,0.627754
217,VADVDISRR,0.624487
329,NRADFDDTV,0.623177
131,TVQAKLTL,0.622699
...,...,...
309,LQGGLWFLWL,0.402385
30,AVKWGPNTW,0.388742
24,HVCWGPPA,0.381513
52,FLWGLQGGLW,0.311504
