In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler, Subset
import pandas as pd
import numpy as np
import ijson
import h5py
import json



In [2]:
#data recovery

path = "Kinase_Substrate_Dataset"

with open(path, encoding="latin-1") as f:
    for i, line in enumerate(f):
        if i < 5:
            print(line, end='')

051724
Data extracted from PhosphoSitePlus(R), created by Cell Signaling Technology Inc. PhosphoSitePlus is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License. Attribution must be given in written, oral and digital presentations to PhosphoSitePlus, www.phosphosite.org. Written documents should additionally cite Hornbeck PV, Kornhauser JM, Tkachev S, Zhang B, Skrzypek E, Murray B, Latham V, Sullivan M (2012) PhosphoSitePlus: a comprehensive resource for investigating the structure and function of experimentally determined post-translational modifications in man and mouse. Nucleic Acids Res. 40, D261Ð70.; www.phosphosite.org.

GENE	KINASE	KIN_ACC_ID	KIN_ORGANISM	SUBSTRATE	SUB_GENE_ID	SUB_ACC_ID	SUB_GENE	SUB_ORGANISM	SUB_MOD_RSD	SITE_GRP_ID	SITE_+/-7_AA	DOMAIN	IN_VIVO_RXN	IN_VITRO_RXN	CST_CAT#
Dyrk2	DYRK2	Q5U4C9	mouse	NDEL1	83431	Q9ERR1	Ndel1	mouse	S336	1869686801	LGSsRPSsAPGMLPL		 	X	


In [3]:
#recuperation of data with complete information

count = 0

kin = []
sub = []
site = []

with open(path, encoding="latin-1") as f:
    for i, line in enumerate(f):
        if i > 3:
            mots = line.split()
            mots_de_longueur_6 = [mot for mot in mots if len(mot) == 6]
            #mot de longueur 6 qui contienne une lettre puis 5 chiffres
            mots_de_longueur_6_lettre_chiffre = [mot for mot in mots_de_longueur_6 if mot[0].isalpha() and mot[1].isdigit()]
            if len(mots_de_longueur_6_lettre_chiffre) == 2:
                kin_acc = mots_de_longueur_6_lettre_chiffre[0]
                sub_acc = mots_de_longueur_6_lettre_chiffre[1]
                id_sub_acc = mots.index(sub_acc)
                try:
                    site_ = mots[id_sub_acc + 3]
                    try:
                        site_ = int(site_[1:])
                        kin.append(kin_acc)
                        sub.append(sub_acc)
                        site.append(site_)
                    except ValueError:
                        continue
                except IndexError:
                    continue
            count += 1

print(count)
print(len(kin))

23470
21612


In [4]:
list_kinase = []

for i in range(len(kin)):
    list_kinase.append(kin[i])
        
list_kinase = list(set(list_kinase))
print("length of list_kinase: ", len(list_kinase))
print("length of data after removing incomplete line: ", len(kin))
print("example of list_kinase: ", list_kinase[0:5])

length of list_kinase:  803
length of data after removing incomplete line:  21612
example of list_kinase:  ['Q9BXM7', 'Q63644', 'P11799', 'P47197', 'Q923T9']


In [5]:
list_substrates = list(set(sub))
print("length of list_substrates: ", len(list_substrates))
print("length of data after removing incomplete line: ", len(sub))
print("example of list_substrates: ", list_substrates[0:5])

length of list_substrates:  4977
length of data after removing incomplete line:  21612
example of list_substrates:  ['Q9WTU3', 'P15336', 'B0LPN4', 'Q969T9', 'O35245']


In [6]:
#recuperation of data on each substrate (embedding here)

dic_sub = {}

for i in range(len(list_substrates)):
    dic_sub[list_substrates[i]] = []

i = 0
with h5py.File("../per-protein.h5", "r") as file:
    print(f"number of entries: {len(file.items())}")
    for sequence_id, embedding in file.items():
        i += 1
        if i % 50000 == 0:
            print(i)
        if sequence_id in dic_sub:
            dic_sub[sequence_id].append(np.array(embedding).tolist())

number of entries: 570820
50000
100000
150000
200000
250000
300000
350000
400000
450000
500000
550000


In [7]:
#recuperation of data on each substrate (GO and sequence here)

i = 0
with open('../uniprotkb_AND_reviewed_true_2024_03_26.json', "rb") as f:
    for record in ijson.items(f, "results.item"):
        try:
            i += 1
            refs = record.get("uniProtKBCrossReferences", [])
            if record["primaryAccession"] in dic_sub:
                GO = [ref["id"] for ref in refs if ref.get("database") == "GO"]
                sequence = record["sequence"]["value"]
                dic_sub[record["primaryAccession"]] = [dic_sub[record["primaryAccession"]], GO, sequence]
                    
            if i % 50000 == 0:
                print(i)
                
        except Exception as record_error:
            print("Error processing record:", record_error)

50000
100000
150000
200000
250000
300000
350000
400000
450000
500000
550000


In [8]:
#count the number of each output to keep the kinase for which we have enough data

count_kinase_dic = {}

for enzyme in list_kinase:
    count_kinase_dic[enzyme] = 0

for i in range(len(kin)):
    count_kinase_dic[kin[i]] += 1

In [9]:
X_acc = []    #substrate
X_seq = []    #sequence
X_site = []   #site
X_GO = []     #GO
X_emb = []    #embedding
Y = []        #kinase

for i in range(len(sub)):
    if len(dic_sub[sub[i]]) == 3 and len(dic_sub[sub[i]][0]) > 0 and count_kinase_dic[kin[i]] >= 5: #we keep only the E3 for which we have enough data
        X_acc.append(sub[i])
        X_seq.append(dic_sub[sub[i]][2])
        X_site.append(site[i])
        X_GO.append(dic_sub[sub[i]][1])
        X_emb.append(dic_sub[sub[i]][0][0])
        Y.append(kin[i])

In [10]:
print(len(X_acc))
print(X_acc[0])
print(X_seq[0])
print(X_site[0])
print(X_GO[0])
print(X_emb[0])
print(Y[0])

20638
P34901
MAPVCLFAPLLLLLLGGFPVAPGESIRETEVIDPQDLLEGRYFSGALPDDEDAGGLEQDSDFELSGSGDLDDTEEPRTFPEVISPLVPLDNHIPENAQPGIRVPSEPKELEENEVIPKRVPSDVGDDDVSNKVSMSSTSQGSNIFERTEVLAALIVGGVVGILFAVFLILLLVYRMKKKDEGSYDLGKKPIYKKAPTNEFYA
183
['GO:0009986', 'GO:0043034', 'GO:0005576', 'GO:0005925', 'GO:0005886', 'GO:0001968', 'GO:0042802', 'GO:0005080', 'GO:0070053', 'GO:0007155', 'GO:0016477', 'GO:0007267', 'GO:0060122', 'GO:0042130', 'GO:0001843', 'GO:1903543', 'GO:1903553', 'GO:0051894', 'GO:0051496', 'GO:0010762', 'GO:0001657', 'GO:0042060']
[0.050628662109375, -0.043853759765625, 0.06488037109375, 0.019989013671875, 0.00875091552734375, 0.073486328125, -0.03387451171875, -0.039215087890625, 0.0170745849609375, 0.04632568359375, 0.00888824462890625, -0.031402587890625, 0.01019287109375, -0.0276947021484375, 0.0203399658203125, 0.0518798828125, 0.0135040283203125, 0.01422119140625, -0.0292816162109375, -0.0277099609375, -0.0294189453125, -0.0189361572265625, -0.045654296875, -0.0225372314453125, -0.015060

In [11]:
#encoding of the go terms

count_go = {}
for i in range(len(X_GO)):
    for go in X_GO[i]:
        if go in count_go:
            count_go[go] += 1
        else:
            count_go[go] = 1

print("number of go terms: ", len(count_go))
print("example: ", list(count_go.items())[0:5])

#we keep the most frequent go terms
number_go = 2000

most_frequent_go = dict(sorted(count_go.items(), key=lambda item: item[1], reverse=True))

most_frequent_go = dict(list(most_frequent_go.items())[0:number_go])

X_GO_filtered = []

for i in range(len(X_GO)):
    go_filtered = []
    for go in X_GO[i]:
        if go in most_frequent_go:
            go_filtered.append(go)
    X_GO_filtered.append(go_filtered)

# encoder X_GO_filtered

list_go = list(most_frequent_go.keys())
dic_GO = {}
for i in range(len(list_go)):
    dic_GO[list_go[i]] = i

X_GO_filtered_int = []

for i in range(len(X_GO_filtered)):
    go = X_GO_filtered[i]
    go_int = [0]*len(list_go)
    for j in range(len(go)):
        go_int[dic_GO[go[j]]] = 1
    X_GO_filtered_int.append(go_int)

print(X_GO_filtered_int[0])
print(len(X_GO_filtered_int[0]))

#enregistrement of dic_GO

with open('dic_GO_problem_3.json', 'w') as fp:
    json.dump(dic_GO, fp)

number of go terms:  12744
example:  [('GO:0009986', 1217), ('GO:0043034', 56), ('GO:0005576', 943), ('GO:0005925', 1644), ('GO:0005886', 7479)]
[0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [12]:
#encoding of the cleavage environement

window_size = 10

X_pep = []
for i in range(len(X_seq)):
    seq = X_seq[i]
    p1 = X_site[i]
    pep = []
    min = p1-window_size
    max = p1+window_size
    for j in range(min,max):
        if j < 0:
            pep.append('X')
        elif j >= len(seq):
            pep.append('X')
        else:
            pep.append(seq[j])
    X_pep.append(pep)

#tranformation of peptides into vocab
vocab = ['A','B','C','D','E','F','G','H','I','K','L','M','N','0','P','Q','R','S','T','U','V','W','X','Y','Z']
vocab_dict = {}
for i in range(len(vocab)):
    vocab_dict[vocab[i]] = i

X_pep_int = []
for i in range(len(X_pep)):
    pep = X_pep[i]
    pep_int = []
    for j in range(len(pep)):
        pep_int.append(vocab_dict[pep[j]])
    X_pep_int.append(pep_int)

In [13]:
#encoding of kin names
list_kinase = list(set(Y))
dic_kinase = {}
for i in range(len(list_kinase)):
    dic_kinase[list_kinase[i]] = i

y_int = []
for i in range(len(Y)):
    y_encoded = [0]*len(list_kinase)
    y_encoded[dic_kinase[Y[i]]] = 1
    y_int.append(y_encoded)

#enregistrement of dic_kinase

with open('dic_kinase.json', 'w') as fp:
    json.dump(dic_kinase, fp)

In [14]:
print(len(X_pep_int))
print(len(y_int))
print(y_int[0])
print(len(y_int[0]))

20638
20638
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_embedding_t = torch.tensor(X_emb, dtype=torch.float32).to(device)
X_pep_t = torch.tensor(X_pep_int, dtype=torch.int64).to(device)
X_site_t = torch.tensor(X_site, dtype=torch.float32).to(device)
X_site_t = X_site_t.unsqueeze(1).to(device)
X_go_t = torch.tensor(X_GO_filtered_int, dtype=torch.float32).to(device)
y_t = torch.tensor(y_int, dtype=torch.float32).to(device)

In [20]:
class model_embedding_pep(nn.Module):
    def __init__(self, embed_dim, output_dim):
        super(model_embedding_pep, self).__init__()
        self.fc_embed = nn.Linear(embed_dim, 2048)
        self.embed = nn.Embedding(26, 128)
        self.gru = nn.GRU(128, 256, 2, batch_first=True)
        self.fc_pep = nn.Linear(256, 2048)
        self.fc1 = nn.Linear(2048 * 2, 2048)
        self.fc2 = nn.Linear(2048, output_dim)
    
    def forward(self, embed, pep, site):
        x_embed = F.relu(self.fc_embed(embed))
        x_pep = self.embed(pep)
        x_pep, _ = self.gru(x_pep)
        x_pep = x_pep[:,-1,:]
        x_pep = F.relu(self.fc_pep(x_pep))
        x = torch.cat((x_embed, x_pep), dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = F.softmax(x, dim=1)
        return x
    

class model_embedding_pep_go(nn.Module):
    def __init__(self, embed_dim, go_dim, output_dim):
        super(model_embedding_pep_go, self).__init__()
        self.fc_embed = nn.Linear(embed_dim, 2048)
        self.embed = nn.Embedding(26, 128)
        self.gru = nn.GRU(128, 256, 2, batch_first=True)
        self.fc_pep = nn.Linear(256, 2048)
        self.fc_go = nn.Linear(go_dim, 2048)
        self.fc1 = nn.Linear(2048 * 3, 2048)
        self.fc2 = nn.Linear(2048, output_dim)
    
    def forward(self, embed, pep, site, go):
        x_embed = F.relu(self.fc_embed(embed))
        x_pep = self.embed(pep)
        x_pep, _ = self.gru(x_pep)
        x_pep = x_pep[:,-1,:]
        x_pep = F.relu(self.fc_pep(x_pep))
        x_go = F.relu(self.fc_go(go))
        x = torch.cat((x_embed, x_pep, x_go), dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        x = F.softmax(x, dim=1)
        return x

In [21]:
def train_with_go(model, dataloader, optimizer, criterion):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(dataloader):
        embed, pep, site, go, labels = data
        optimizer.zero_grad()
        outputs = model(embed, pep, site, go)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(dataloader)

def test_accuracy_with_go(model, dataloader, criterion):
    model.eval()
    correct = 0
    correct3 = 0
    total = 0
    running_loss = 0.0
    with torch.no_grad():
        for data in dataloader:
            embed, pep, site, go, labels = data
            outputs = model(embed, pep, site, go)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            _, labels = torch.max(labels.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            #top 3 predictions
            _, predicted3 = torch.topk(outputs.data, 3, dim=1)
            for i in range(len(labels)):
                if labels[i] in predicted3[i]:
                    correct3 += 1
    return correct / total, correct3 / total, running_loss / len(dataloader)

In [22]:
X_embedding_train, X_embedding_test, X_pep_train, X_pep_test, X_site_train, X_site_test, X_go_train, X_go_test, y_train, y_test = train_test_split(X_embedding_t, X_pep_t, X_site_t, X_go_t, y_t, test_size=0.2, random_state=42)

train_dataset = TensorDataset(X_embedding_train, X_pep_train, X_site_train, X_go_train, y_train)
test_dataset = TensorDataset(X_embedding_test, X_pep_test, X_site_test, X_go_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# parameters
embed_dim = X_embedding_t.shape[1]   #embedding dimension = 1024
go_dim = X_go_t.shape[1]   #go terms dimension = 2000
output_dim = y_t.shape[1]   #number of kinases = 455

model = model_embedding_pep_go(embed_dim, go_dim, output_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.BCELoss()

# early stopping

n_epochs = 100
patience = 5
best_acc = 0

for epoch in range(n_epochs):
    train_loss = train_with_go(model, train_loader, optimizer, criterion)
    acc, acc3, test_loss = test_accuracy_with_go(model, test_loader, criterion)
    print(f"Epoch {epoch+1}/{n_epochs} | Train loss: {train_loss:.3f} | Test loss: {test_loss:.3f} | Test accuracy: {acc:.3f} | Test top 3 accuracy: {acc3:.3f}")
    if acc > best_acc:
        best_acc = acc
        patience_counter = 0
        torch.save(model.state_dict(), "model_embedding_pep_go_site_problem_3_final.pt")
    else:
        patience_counter += 1
    if patience_counter > patience:
        break

Epoch 1/100 | Train loss: 0.013 | Test loss: 0.012 | Test accuracy: 0.141 | Test top 3 accuracy: 0.287
Epoch 2/100 | Train loss: 0.011 | Test loss: 0.010 | Test accuracy: 0.201 | Test top 3 accuracy: 0.374
Epoch 3/100 | Train loss: 0.009 | Test loss: 0.010 | Test accuracy: 0.223 | Test top 3 accuracy: 0.411
Epoch 4/100 | Train loss: 0.008 | Test loss: 0.009 | Test accuracy: 0.254 | Test top 3 accuracy: 0.452
Epoch 5/100 | Train loss: 0.008 | Test loss: 0.009 | Test accuracy: 0.264 | Test top 3 accuracy: 0.482
Epoch 6/100 | Train loss: 0.007 | Test loss: 0.009 | Test accuracy: 0.284 | Test top 3 accuracy: 0.507
Epoch 7/100 | Train loss: 0.006 | Test loss: 0.009 | Test accuracy: 0.300 | Test top 3 accuracy: 0.517
Epoch 8/100 | Train loss: 0.006 | Test loss: 0.009 | Test accuracy: 0.300 | Test top 3 accuracy: 0.514
Epoch 9/100 | Train loss: 0.006 | Test loss: 0.009 | Test accuracy: 0.302 | Test top 3 accuracy: 0.522
Epoch 10/100 | Train loss: 0.005 | Test loss: 0.009 | Test accuracy: 0.30

In [23]:
def train_without_go(model, dataloader, optimizer, criterion):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(dataloader):
        embed, pep, site, go, labels = data
        optimizer.zero_grad()
        outputs = model(embed, pep, site)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(dataloader)

def test_accuracy_without_go(model, dataloader, criterion):
    model.eval()
    correct = 0
    correct3 = 0
    total = 0
    running_loss = 0.0
    with torch.no_grad():
        for data in dataloader:
            embed, pep, site, go, labels = data
            outputs = model(embed, pep, site)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            _, labels = torch.max(labels.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            #top 3 predictions
            _, predicted3 = torch.topk(outputs.data, 3, dim=1)
            for i in range(len(labels)):
                if labels[i] in predicted3[i]:
                    correct3 += 1
    return correct / total, correct3 / total, running_loss / len(dataloader)

In [24]:
X_embedding_train, X_embedding_test, X_pep_train, X_pep_test, X_site_train, X_site_test, X_go_train, X_go_test, y_train, y_test = train_test_split(X_embedding_t, X_pep_t, X_site_t, X_go_t, y_t, test_size=0.2, random_state=42)

train_dataset = TensorDataset(X_embedding_train, X_pep_train, X_site_train, X_go_train, y_train)
test_dataset = TensorDataset(X_embedding_test, X_pep_test, X_site_test, X_go_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# parameters
embed_dim = X_embedding_t.shape[1]   #embedding dimension = 1024
go_dim = X_go_t.shape[1]   #go terms dimension = 2000
output_dim = y_t.shape[1]   #number of kinase = 455

model = model_embedding_pep(embed_dim, output_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.BCELoss()

# early stopping

n_epochs = 100
patience = 5
best_acc = 0

for epoch in range(n_epochs):
    train_loss = train_without_go(model, train_loader, optimizer, criterion)
    acc, acc3, test_loss = test_accuracy_without_go(model, test_loader, criterion)
    print(f"Epoch {epoch+1}/{n_epochs} | Train loss: {train_loss:.3f} | Test loss: {test_loss:.3f} | Test accuracy: {acc:.3f} | Test top 3 accuracy: {acc3:.3f}")
    if acc > best_acc:
        best_acc = acc
        patience_counter = 0
        torch.save(model.state_dict(), "model_embedding_pep_site_problem_3_final.pt")
    else:
        patience_counter += 1
    if patience_counter > patience:
        break

Epoch 1/100 | Train loss: 0.013 | Test loss: 0.012 | Test accuracy: 0.111 | Test top 3 accuracy: 0.242
Epoch 2/100 | Train loss: 0.012 | Test loss: 0.011 | Test accuracy: 0.160 | Test top 3 accuracy: 0.308
Epoch 3/100 | Train loss: 0.011 | Test loss: 0.011 | Test accuracy: 0.166 | Test top 3 accuracy: 0.328
Epoch 4/100 | Train loss: 0.010 | Test loss: 0.011 | Test accuracy: 0.180 | Test top 3 accuracy: 0.344
Epoch 5/100 | Train loss: 0.010 | Test loss: 0.010 | Test accuracy: 0.197 | Test top 3 accuracy: 0.371
Epoch 6/100 | Train loss: 0.009 | Test loss: 0.010 | Test accuracy: 0.206 | Test top 3 accuracy: 0.381
Epoch 7/100 | Train loss: 0.009 | Test loss: 0.010 | Test accuracy: 0.211 | Test top 3 accuracy: 0.393
Epoch 8/100 | Train loss: 0.009 | Test loss: 0.010 | Test accuracy: 0.210 | Test top 3 accuracy: 0.399
Epoch 9/100 | Train loss: 0.008 | Test loss: 0.010 | Test accuracy: 0.222 | Test top 3 accuracy: 0.412
Epoch 10/100 | Train loss: 0.008 | Test loss: 0.010 | Test accuracy: 0.23