In [1]:
# Checkpoint version tags for this run.
# NOTE(review): `load_model` is presumably the tag of an earlier checkpoint to
# start from, with None meaning freshly initialised weights — confirm.
load_model = None
# `save_model` is interpolated into the checkpoint file names written by the
# final cell (model/HLA_model_v<tag>.pt and model/pept_model_v<tag>.pt).
save_model = "0805"

In [2]:
from pdeep.mhc.mhc_binding_model import *
from pdeep.mhc.mhc_utils import *

In [3]:
# FASTA file(s) passed to HlaDataSet when the dataset is built.
# Paths are relative to the notebook's working directory.
fasta_list = [
    "uniprotkb_UP000005640_AND_reviewed_true_2024_03_01.fasta",
]

In [4]:
# Build the two encoders with freshly initialised weights.
pept_encoder = ModelSeqEncoder()
hla_encoder = ModelHlaEncoder()

# Optional warm start: when `load_model` holds a version tag, restore both
# encoders from the checkpoints of that run. The file names follow the same
# scheme the save cell at the end of this notebook uses. With
# `load_model = None` nothing is loaded and training starts from scratch.
if load_model is not None:
    # map_location="cpu" lets checkpoints written on a GPU machine be read
    # anywhere; load_state_dict copies the values into the existing parameters.
    hla_encoder.load_state_dict(
        torch.load(f"model/HLA_model_v{load_model}.pt", map_location="cpu")
    )
    pept_encoder.load_state_dict(
        torch.load(f"model/pept_model_v{load_model}.pt", map_location="cpu")
    )

In [5]:
import torchinfo

# Layer-by-layer summary of the peptide encoder, traced with a dummy input of
# shape (256, 14) and dtype torch.long. The summary object is the cell's last
# expression so the notebook renders it.
torchinfo.summary(
    pept_encoder,
    input_size=[(256, 14)],
    dtypes=[torch.long],
)

Layer (type:depth-idx)                                  Output Shape              Param #
ModelSeqEncoder                                         [256, 480]                --
├─Embedding: 1-1                                        [256, 14, 480]            61,440
├─PositionalEncoding: 1-2                               [256, 14, 480]            --
├─Hidden_HFace_Transformer: 1-3                         [256, 14, 480]            --
│    └─BertEncoder: 2-1                                 [256, 14, 480]            --
│    │    └─ModuleList: 3-1                             --                        11,084,160
├─SeqAttentionSum: 1-4                                  [256, 480]                --
│    └─Sequential: 2-2                                  [256, 14, 1]              --
│    │    └─Linear: 3-2                                 [256, 14, 1]              480
│    │    └─Softmax: 3-3                                [256, 14, 1]              --
Total params: 11,146,080
Trainable params: 11,1

In [6]:
# Layer-by-layer summary of the HLA encoder, traced with a dummy float32 input
# of shape (256, 400, 480). Kept as the last expression so it is displayed.
torchinfo.summary(
    hla_encoder,
    input_size=[(256, 400, 480)],
    dtypes=[torch.float32],
)

Layer (type:depth-idx)                                  Output Shape              Param #
ModelHlaEncoder                                         [256, 480]                --
├─Hidden_HFace_Transformer: 1-1                         [256, 400, 480]           --
│    └─BertEncoder: 2-1                                 [256, 400, 480]           --
│    │    └─ModuleList: 3-1                             --                        2,771,040
├─SeqAttentionSum: 1-2                                  [256, 480]                --
│    └─Sequential: 2-2                                  [256, 400, 1]             --
│    │    └─Linear: 3-2                                 [256, 400, 1]             480
│    │    └─Softmax: 3-3                                [256, 400, 1]             --
Total params: 2,771,520
Trainable params: 2,771,520
Non-trainable params: 0
Total mult-adds (M): 709.51
Input size (MB): 196.61
Forward/backward pass size (MB): 4326.20
Params size (MB): 11.09
Estimated Total Size (MB): 45

In [7]:
import pandas as pd

# Training table, stored as tab-separated text. The rendered output shows the
# columns `Unnamed: 0` (a serialised index), `sequence` and `allele`.
train_df = pd.read_csv(
    "all_alleles/train_data/leave_one_ABC_type/train_df.tsv",
    sep="\t",
)
train_df

Unnamed: 0,sequence,allele
0,ALNPYQYQY,A29_02
1,QTSEKALLR,A34_02
2,TPRSTVGVAVL,B07_06
3,IYAKLFNW,A24_02
4,FSIAGTVKR,A66_01
...,...,...
528419,QQGKIAASY,B15_01
528420,TLVTSQATTL,A02_01
528421,KRHFRRDSF,B27_05
528422,VEKPQEFTI,B40_02


In [8]:
# load_esm_pkl comes from the star imports at the top of the notebook and
# returns a pair. The first element is a table whose rendered output shows the
# columns `sequence`, `allele` and `allele_detail`.
# NOTE(review): `hla_esm_list` presumably holds one precomputed ESM embedding
# per row of `hla_df` — confirm against load_esm_pkl.
hla_df, hla_esm_list = load_esm_pkl()
# Bare last expression so the notebook displays the table.
hla_df

Unnamed: 0,sequence,allele,allele_detail
0,AHSMRYFYTAVSRPGRGEPHFIAVGYVDDTQFVRFDSDAASPRGEP...,C03_159,C*03:159
1,ALALTETWAGSHSMRYFYTAMSRPGRGEPRFIAVGYVDDTQFVRFD...,B15_128,B*15:128
2,ALALTETWAGSHSMRYFYTSVSRPGRGEPRFISVGYVDDTQFVRFD...,B14_12,B*14:12
3,APRTLLLLLSGALALTQTWAGSHSMRYFYTSVSRPGRGEPRFIAVG...,A03_12,A*03:12
4,APRTLVLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRFIAVG...,A24_79,A*24:79
...,...,...,...
16449,TLLLLLSGALALTQTWAGSHSMRYFSTSVSRPGRGEPRFIAVGYVD...,A24_26,A*24:26
16450,TLVLLLSGALALTQTWAGSHSMRYFFTSVSRPGRGEPRFIAVGYVD...,A02_01,A*02:01:03
16451,VTAPRTLLLLLWGAVALTETWAGSHSMRYFYTAMSRPGRGEPRFIT...,B44_43,B*44:43:01
16452,VTAPRTVLLLLSGALALTETWAGSHSMRYFYTAMSRPGRGEPRFIA...,B15_57,B*15:57


In [9]:
# Bundle the HLA table, its companion list from load_esm_pkl, the training
# peptides and the FASTA file list into the dataset object that train() consumes.
dataset = HlaDataSet(hla_df, hla_esm_list, train_df, fasta_list)

In [11]:
# Train both encoders together on the dataset built above.
# With these settings the logged learning rate climbs to 1e-4 over the first
# five epochs (warmup_epoch=5) and then decays to zero by the last of the ten.
train(
    hla_encoder,
    pept_encoder,
    dataset,
    epoch=10,
    warmup_epoch=5,
    batch_size=200,
    lr=1e-4,
    verbose=True,
    device="cuda",
)

2024-08-05 12:09:32> 298801 training samples
2024-08-05 12:15:06> [Epoch=0] loss=0.02818, lr=2.000e-05
2024-08-05 12:20:40> [Epoch=1] loss=0.02651, lr=4.000e-05
2024-08-05 12:26:14> [Epoch=2] loss=0.02650, lr=6.000e-05
2024-08-05 12:31:48> [Epoch=3] loss=0.02686, lr=8.000e-05
2024-08-05 12:37:22> [Epoch=4] loss=0.02710, lr=1.000e-04
2024-08-05 12:42:56> [Epoch=5] loss=0.02733, lr=9.045e-05
2024-08-05 12:48:31> [Epoch=6] loss=0.02692, lr=6.545e-05
2024-08-05 12:54:05> [Epoch=7] loss=0.02598, lr=3.455e-05
2024-08-05 12:59:39> [Epoch=8] loss=0.02530, lr=9.549e-06
2024-08-05 13:05:13> [Epoch=9] loss=0.02463, lr=0.000e+00


In [12]:
import os

# torch.save does not create missing parent directories. Make sure `model/`
# exists first, so a fresh checkout does not lose the trained weights to a
# missing-directory error after the training run.
os.makedirs("model", exist_ok=True)

# Persist only the learned parameters (state dicts) of each encoder. The
# `save_model` tag from the first cell versions the file names.
torch.save(hla_encoder.state_dict(), f"model/HLA_model_v{save_model}.pt")
torch.save(pept_encoder.state_dict(), f"model/pept_model_v{save_model}.pt")