In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler, Subset
import pandas as pd
import numpy as np
import ijson
import h5py
import json



In [2]:
# Load the raw E3 / substrate literature table into memory as one string.

path = 'literature.E3.txt'

with open(path, 'r') as file:
    data = file.read()

# Everything before the first newline is the header row; print it to show the columns.
line_0 = data.partition('\n')[0]
print(line_0)

NUMBER	SwissProt ID (E3)	SwissProt ID (Substrate)	SwissProt AC (E3)	SwissProt AC (Substrate)	Gene Symbol (E3)	Gene Symbol (Substrate)	SOURCE	SOURCEID	SENTENCE	E3TYPE	COUNT	type	species


In [3]:
# Drop incomplete rows: keep only rows where both the E3 accession (column 3)
# and the substrate accession (column 4) are present.

E3 = []
substrates = []

for line in data.split('\n')[1:]:   # [1:] skips the header row
    fields = line.split('\t')
    # Rows that are too short to hold both accession columns (e.g. a trailing
    # empty line) are skipped explicitly instead of relying on a bare `except`,
    # which would also hide unrelated errors and KeyboardInterrupt.
    if len(fields) < 5:
        continue
    e3_acc, substrate_acc = fields[3], fields[4]
    if e3_acc != '-' and substrate_acc != '-':
        # Only the first 6 characters of each accession field are kept.
        # NOTE(review): longer accessions would be truncated by this slice -
        # confirm that every accession in the file is 6 characters long.
        E3.append(e3_acc[0:6])
        substrates.append(substrate_acc[0:6])

In [4]:
# Distinct E3 accessions. They come from a set, so the order of the preview
# printed below is arbitrary.
list_E3 = [*set(E3)]
print("length of list_E3: ", len(list_E3))
print("length of data after removing incomplete line: ", len(E3))
print("example of list_E3: ", list_E3[:5])

length of list_E3:  842
length of data after removing incomplete line:  4044
example of list_E3:  ['Q9Y6I7', 'O95835', 'Q3U487', 'Q86TM6', 'Q9UDY8']


In [5]:
# Distinct substrate accessions. They come from a set, so the order of the
# preview printed below is arbitrary.
list_substrates = [*set(substrates)]
print("length of list_substrates: ", len(list_substrates))
print("length of data after removing incomplete line: ", len(substrates))
print("example of list_substrates: ", list_substrates[:5])

length of list_substrates:  2486
length of data after removing incomplete line:  4044
example of list_substrates:  ['Q15555', 'Q9XYF4', 'O43541', 'P49458', 'Q9Y5K5']


In [6]:
# Retrieve the embedding of every substrate from the per-protein HDF5 file.

# One (initially empty) list of embeddings per distinct substrate accession.
dic_sub = {accession: [] for accession in list_substrates}

with h5py.File("../per-protein.h5", "r") as file:
    print(f"number of entries: {len(file.items())}")
    for i, (sequence_id, embedding) in enumerate(file.items(), start=1):
        if i % 50000 == 0:
            print(i)   # progress indicator
        if sequence_id in dic_sub:
            # Stored as a plain Python list of floats.
            dic_sub[sequence_id].append(np.array(embedding).tolist())

number of entries: 570820
50000
100000
150000
200000
250000
300000
350000
400000
450000
500000
550000


In [7]:
# Retrieve the GO ids and the sequence of every substrate from the UniProt JSON dump.
# A substrate found in the dump has its entry replaced by
# [previous entry (list of embeddings), GO ids, sequence].
# Caution: re-running this cell wraps already-converted entries a second time.

i = 0
with open('../uniprotkb_AND_reviewed_true_2024_03_26.json', "rb") as f:
    # ijson streams the records one at a time instead of loading the whole file.
    for record in ijson.items(f, "results.item"):
        try:
            i += 1
            refs = record.get("uniProtKBCrossReferences", [])
            accession = record["primaryAccession"]
            if accession in dic_sub:
                GO = [ref["id"] for ref in refs if ref.get("database") == "GO"]
                sequence = record["sequence"]["value"]
                dic_sub[accession] = [dic_sub[accession], GO, sequence]

            if i % 10000 == 0:
                print(i)   # progress indicator

        except Exception as record_error:
            # A malformed record is reported and skipped; the scan continues.
            print("Error processing record:", record_error)

10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000
160000
170000
180000
190000
200000
210000
220000
230000
240000
250000
260000
270000
280000
290000
300000
310000
320000
330000
340000
350000
360000
370000
380000
390000
400000
410000
420000
430000
440000
450000
460000
470000
480000
490000
500000
510000
520000
530000
540000
550000
560000
570000


In [8]:
# Count how many interaction rows each E3 has, so that E3s with too little
# data can be filtered out later.
#
# The counts are built directly from `E3` rather than from `list_E3`:
# `list_E3` is rebound to a filtered list further down the notebook, so
# seeding the counters from it made this cell raise KeyError when re-run.

count_E3_dic = {}

for e3_acc in E3:
    count_E3_dic[e3_acc] = count_E3_dic.get(e3_acc, 0) + 1

In [9]:
# Parallel sample lists: index n of every list describes the same interaction row.
X_acc = []    # substrate accession number
X_seq = []    # substrate sequence
X_GO = []     # substrate GO ids
X_emb = []    # substrate embedding
Y = []        # E3 accession

for substrate_acc, e3_acc in zip(substrates, E3):
    entry = dic_sub[substrate_acc]
    # A usable entry has the 3-slot layout indexed below and at least one embedding.
    if len(entry) != 3 or len(entry[0]) == 0:
        continue
    # Keep only the E3s for which we have enough data (at least 5 rows).
    if count_E3_dic[e3_acc] < 5:
        continue
    embeddings, go_ids, sequence = entry[0], entry[1], entry[2]
    X_acc.append(substrate_acc)
    Y.append(e3_acc)
    X_emb.append(embeddings[0])
    X_GO.append(go_ids)
    X_seq.append(sequence)

In [10]:
# Sanity check: show every field of the first sample, then the number of samples.
for preview in (X_emb, X_GO, X_seq, X_acc, Y):
    print(preview[0])
print(len(X_emb))

[0.0362548828125, 0.10919189453125, -0.031890869140625, 0.04132080078125, 0.0831298828125, 0.0731201171875, 0.04193115234375, -0.110595703125, 0.043426513671875, -0.06585693359375, -0.0304718017578125, 0.002597808837890625, -0.0450439453125, -0.055511474609375, 0.0287017822265625, -0.03497314453125, 0.060699462890625, 0.058258056640625, 0.0038394927978515625, -0.044219970703125, -0.01068115234375, -0.007808685302734375, -0.0202178955078125, 0.0249481201171875, -0.0081634521484375, 0.039215087890625, 0.0058135986328125, -0.03350830078125, 0.03619384765625, 1.1086463928222656e-05, -0.0185699462890625, 0.06439208984375, -0.11669921875, 0.0584716796875, -0.1175537109375, -0.0640869140625, 0.0328369140625, -0.019927978515625, -0.037872314453125, 0.04486083984375, -0.041961669921875, 0.01537322998046875, -0.0198516845703125, 0.0240478515625, 0.056365966796875, -0.033843994140625, 0.002620697021484375, 0.060089111328125, 0.0246429443359375, 0.04595947265625, -0.0887451171875, -0.0419921875, 0

In [11]:
# Multi-hot encoding of the GO terms.

# Number of samples in which each GO id appears.
count_go = {}

for go_terms in X_GO:
    for go in go_terms:
        count_go[go] = count_go.get(go, 0) + 1

print("number of go terms: ", len(count_go))
print("example: ", list(count_go.items())[0:5])

# Keep only the most frequent GO terms.
number_go = 2000

# The sort is stable, so ties keep their first-seen order.
ranked_go = sorted(count_go.items(), key=lambda item: item[1], reverse=True)
most_frequent_go = dict(ranked_go[0:number_go])

# Per sample, drop the GO ids that did not make the cut.
X_GO_filtered = [
    [go for go in go_terms if go in most_frequent_go]
    for go_terms in X_GO
]

# Encode X_GO_filtered: GO id -> column index, most frequent first.
list_go = list(most_frequent_go.keys())
dic_GO = {go: column for column, go in enumerate(list_go)}

X_GO_filtered_int = []

for go_terms in X_GO_filtered:
    go_int = [0] * len(list_go)
    for go in go_terms:
        go_int[dic_GO[go]] = 1
    X_GO_filtered_int.append(go_int)

print(X_GO_filtered_int[0])
print(len(X_GO_filtered_int[0]))

# Save dic_GO so the same column layout can be reused later.
with open('dic_GO_problem_2.json', 'w') as fp:
    json.dump(dic_GO, fp)

number of go terms:  9355
example:  [('GO:0005634', 1776), ('GO:0003700', 345), ('GO:1990841', 114), ('GO:0043565', 168), ('GO:0009740', 3)]
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [12]:
# Encoding of the E3 accessions: class indices and multi-hot target vectors.

# Sorted so that the accession -> index mapping is the same on every run.
# A bare list(set(Y)) follows set iteration order, which for strings can change
# between interpreter sessions; that would make the saved dic_E3.json and the
# model checkpoints from different runs incompatible.
# Note: this rebinds `list_E3` to the E3s that survived the filtering.
list_E3 = sorted(set(Y))

dic_E3 = {e3_acc: index for index, e3_acc in enumerate(list_E3)}

# substrate accession -> multi-hot vector flagging every E3 listed for that substrate
dic_substrates_E3 = {}
for substrate_acc, e3_acc in zip(X_acc, Y):
    if substrate_acc not in dic_substrates_E3:
        dic_substrates_E3[substrate_acc] = [0] * len(list_E3)
    dic_substrates_E3[substrate_acc][dic_E3[e3_acc]] = 1

# One target per sample; every row of the same substrate gets the same vector.
Y_int = [dic_substrates_E3[substrate_acc] for substrate_acc in X_acc]

# Save dic_E3 so predictions can be mapped back to accessions later.
with open('dic_E3.json', 'w') as fp:
    json.dump(dic_E3, fp)

In [13]:
# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Convert the Python lists to float32 tensors and move them to the chosen device.
X_embedding_t = torch.tensor(X_emb, dtype=torch.float32).to(device)
X_go_t = torch.tensor(X_GO_filtered_int, dtype=torch.float32).to(device)
y_t = torch.tensor(Y_int, dtype=torch.float32).to(device)

In [14]:
class model_embedding(nn.Module):
    """Three-layer fully connected network that scores E3s from an embedding alone.

    Args:
        embed_dim: size of the input embedding vector.
        output_dim: number of output scores per sample.

    NOTE(review): the output is a softmax (each row sums to 1), yet the notebook
    trains it with BCELoss against multi-hot targets - confirm this pairing is
    intended rather than independent sigmoid outputs.
    """

    def __init__(self, embed_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(embed_dim, 2048)
        self.fc1 = nn.Linear(2048, 2048)
        self.fc2 = nn.Linear(2048, output_dim)

    def forward(self, embed):
        """Return softmax scores over dim 1 for a 2-D batch of embeddings."""
        hidden = torch.relu(self.fc(embed))
        hidden = torch.relu(self.fc1(hidden))
        logits = self.fc2(hidden)
        return torch.softmax(logits, dim=1)
    
class model_embedding_go(nn.Module):
    """Two-branch network that scores E3s from an embedding and a GO vector.

    Each input goes through its own linear layer; the two hidden vectors are
    concatenated and passed through two further linear layers.

    Args:
        embed_dim: size of the input embedding vector.
        go_dim: size of the GO input vector.
        output_dim: number of output scores per sample.

    NOTE(review): the output is a softmax (each row sums to 1), yet the notebook
    trains it with BCELoss against multi-hot targets - confirm this pairing is
    intended rather than independent sigmoid outputs.
    """

    def __init__(self, embed_dim, go_dim, output_dim):
        super().__init__()
        self.fc_embed = nn.Linear(embed_dim, 2048)
        self.fc_go = nn.Linear(go_dim, 2048)
        self.fc1 = nn.Linear(4096, 2048)   # 4096 = the two 2048-wide branches concatenated
        self.fc2 = nn.Linear(2048, output_dim)

    def forward(self, embed, go):
        """Return softmax scores over dim 1 for 2-D batches `embed` and `go`."""
        branches = (
            torch.relu(self.fc_embed(embed)),
            torch.relu(self.fc_go(go)),
        )
        merged = torch.cat(branches, dim=1)
        hidden = torch.relu(self.fc1(merged))
        return torch.softmax(self.fc2(hidden), dim=1)

In [15]:
def train_with_go(model, dataloader, optimizer, criterion):
    """Run one training epoch of a model that takes (embedding, GO) inputs.

    Args:
        model: module called as ``model(embed, go)``.
        dataloader: sized iterable of ``(embed, go, labels)`` batches.
        optimizer: optimizer holding the model's parameters.
        criterion: loss called as ``criterion(outputs, labels)``.

    Returns:
        The mean of the per-batch loss values over the epoch.
    """
    model.train()
    batch_losses = []
    for embed, go, labels in dataloader:
        optimizer.zero_grad()
        batch_loss = criterion(model(embed, go), labels)
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(dataloader)

def test_accuracy_with_go(model, dataloader, criterion):
    model.eval()
    correct = 0
    correct3 = 0
    total = 0
    running_loss = 0.0
    with torch.no_grad():
        for data in dataloader:
            embed, go, labels = data
            outputs = model(embed, go)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            for i in range(len(predicted)):
                if labels[i][predicted[i]] == 1:
                    correct += 1
            #top 3 predictions
            _, predicted3 = torch.topk(outputs.data, 3, dim=1)
            for i in range(len(labels)):
                for j in range(3):
                    if labels[i][predicted3[i][j]] == 1:
                        correct3 += 1
                        break
    return correct / total, correct3 / total, running_loss / len(dataloader)

In [16]:
# Train the embedding + GO model with early stopping on top-1 accuracy.
#
# NOTE(review): the same held-out split drives early stopping, checkpoint
# selection and the printed accuracy, so the reported test figures are
# optimistic. Consider a separate validation split.
# NOTE(review): samples are interaction rows - confirm whether a substrate that
# occurs in several rows can land in both the train and the test split.
X_embedding_train, X_embedding_test, X_go_train, X_go_test, y_train, y_test = train_test_split(X_embedding_t, X_go_t, y_t, test_size=0.2, random_state=42)

train_dataset = TensorDataset(X_embedding_train, X_go_train, y_train)
test_dataset = TensorDataset(X_embedding_test, X_go_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# parameters
embed_dim = X_embedding_t.shape[1]   # embedding dimension (1024 in the original run)
go_dim = X_go_t.shape[1]   # GO terms dimension (2000 in the original run)
output_dim = y_t.shape[1]   # number of enzymes (535 in the original run)

# Seed torch before creating the model so that weight initialisation and batch
# shuffling are repeatable from run to run (as far as the backend allows).
torch.manual_seed(42)

model = model_embedding_go(embed_dim, go_dim, output_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.BCELoss()

# early stopping

n_epochs = 100
patience = 15
best_acc = 0
# Must be initialised here: otherwise a first epoch that does not improve on
# best_acc hits `patience_counter += 1` on an undefined name (fresh kernel) or
# on a stale value left behind by an earlier training cell.
patience_counter = 0

for epoch in range(n_epochs):
    train_loss = train_with_go(model, train_loader, optimizer, criterion)
    acc, acc3, test_loss = test_accuracy_with_go(model, test_loader, criterion)
    print(f"Epoch {epoch+1}/{n_epochs} | Train loss: {train_loss:.3f} | Test loss: {test_loss:.3f} | Test accuracy: {acc:.3f} | Test top 3 accuracy: {acc3:.3f}")
    if acc > best_acc:
        # New best top-1 accuracy: checkpoint the weights and reset the patience window.
        best_acc = acc
        patience_counter = 0
        torch.save(model.state_dict(), "model_embedding_go_problem_2_final.pt")
    else:
        patience_counter += 1
    # Stop once more than `patience` consecutive epochs brought no improvement.
    if patience_counter > patience:
        break

  from .autonotebook import tqdm as notebook_tqdm


Epoch 1/100 | Train loss: 0.080 | Test loss: 0.075 | Test accuracy: 0.202 | Test top 3 accuracy: 0.338
Epoch 2/100 | Train loss: 0.071 | Test loss: 0.068 | Test accuracy: 0.323 | Test top 3 accuracy: 0.441
Epoch 3/100 | Train loss: 0.061 | Test loss: 0.062 | Test accuracy: 0.378 | Test top 3 accuracy: 0.497
Epoch 4/100 | Train loss: 0.056 | Test loss: 0.058 | Test accuracy: 0.398 | Test top 3 accuracy: 0.532
Epoch 5/100 | Train loss: 0.051 | Test loss: 0.055 | Test accuracy: 0.450 | Test top 3 accuracy: 0.557
Epoch 6/100 | Train loss: 0.048 | Test loss: 0.053 | Test accuracy: 0.477 | Test top 3 accuracy: 0.598
Epoch 7/100 | Train loss: 0.045 | Test loss: 0.051 | Test accuracy: 0.523 | Test top 3 accuracy: 0.602
Epoch 8/100 | Train loss: 0.043 | Test loss: 0.050 | Test accuracy: 0.526 | Test top 3 accuracy: 0.617
Epoch 9/100 | Train loss: 0.041 | Test loss: 0.048 | Test accuracy: 0.540 | Test top 3 accuracy: 0.638
Epoch 10/100 | Train loss: 0.043 | Test loss: 0.047 | Test accuracy: 0.57

In [17]:
def train_without_go(model, dataloader, optimizer, criterion):
    """Run one training epoch of a model that takes only the embedding as input.

    Args:
        model: module called as ``model(embed)``.
        dataloader: sized iterable of ``(embed, go, labels)`` batches; the GO
            tensor is unpacked but not used.
        optimizer: optimizer holding the model's parameters.
        criterion: loss called as ``criterion(outputs, labels)``.

    Returns:
        The mean of the per-batch loss values over the epoch.
    """
    model.train()
    batch_losses = []
    for embed, _go, labels in dataloader:
        optimizer.zero_grad()
        batch_loss = criterion(model(embed), labels)
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / len(dataloader)

def test_accuracy_without_go(model, dataloader, criterion):
    model.eval()
    correct = 0
    correct3 = 0
    total = 0
    running_loss = 0.0
    with torch.no_grad():
        for data in dataloader:
            embed, go, labels = data
            outputs = model(embed)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            for i in range(len(predicted)):
                if labels[i][predicted[i]] == 1:
                    correct += 1
            #top 3 predictions
            _, predicted3 = torch.topk(outputs.data, 3, dim=1)
            for i in range(len(labels)):
                for j in range(3):
                    if labels[i][predicted3[i][j]] == 1:
                        correct3 += 1
                        break
    return correct / total, correct3 / total, running_loss / len(dataloader)

In [18]:
# Train the embedding-only model with early stopping on top-1 accuracy.
#
# NOTE(review): the same held-out split drives early stopping, checkpoint
# selection and the printed accuracy, so the reported test figures are
# optimistic. Consider a separate validation split.
# NOTE(review): samples are interaction rows - confirm whether a substrate that
# occurs in several rows can land in both the train and the test split.
X_embedding_train, X_embedding_test, X_go_train, X_go_test, y_train, y_test = train_test_split(X_embedding_t, X_go_t, y_t, test_size=0.2, random_state=42)

train_dataset = TensorDataset(X_embedding_train, X_go_train, y_train)
test_dataset = TensorDataset(X_embedding_test, X_go_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# parameters
embed_dim = X_embedding_t.shape[1]   # embedding dimension (1024 in the original run)
go_dim = X_go_t.shape[1]   # GO terms dimension (2000 in the original run)
output_dim = y_t.shape[1]   # number of enzymes (535 in the original run)

# Seed torch before creating the model so that weight initialisation and batch
# shuffling are repeatable from run to run (as far as the backend allows).
torch.manual_seed(42)

model = model_embedding(embed_dim, output_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.BCELoss()

# early stopping

n_epochs = 150
patience = 15
best_acc = 0
# Must be initialised here: otherwise this run starts from whatever value the
# previous training cell left in `patience_counter` (or from an undefined name
# on a fresh kernel) if the first epoch does not improve on best_acc.
patience_counter = 0

for epoch in range(n_epochs):
    train_loss = train_without_go(model, train_loader, optimizer, criterion)
    acc, acc3, test_loss = test_accuracy_without_go(model, test_loader, criterion)
    print(f"Epoch {epoch+1}/{n_epochs} | Train loss: {train_loss:.3f} | Test loss: {test_loss:.3f} | Test accuracy: {acc:.3f} | Test top 3 accuracy: {acc3:.3f}")
    if acc > best_acc:
        # New best top-1 accuracy: checkpoint the weights and reset the patience window.
        best_acc = acc
        patience_counter = 0
        torch.save(model.state_dict(), "model_embedding_problem_2_final.pt")
    else:
        patience_counter += 1
    # Stop once more than `patience` consecutive epochs brought no improvement.
    if patience_counter > patience:
        break

Epoch 1/150 | Train loss: 0.083 | Test loss: 0.080 | Test accuracy: 0.168 | Test top 3 accuracy: 0.262
Epoch 2/150 | Train loss: 0.078 | Test loss: 0.077 | Test accuracy: 0.198 | Test top 3 accuracy: 0.347
Epoch 3/150 | Train loss: 0.074 | Test loss: 0.073 | Test accuracy: 0.222 | Test top 3 accuracy: 0.379
Epoch 4/150 | Train loss: 0.070 | Test loss: 0.071 | Test accuracy: 0.241 | Test top 3 accuracy: 0.398
Epoch 5/150 | Train loss: 0.068 | Test loss: 0.069 | Test accuracy: 0.275 | Test top 3 accuracy: 0.405
Epoch 6/150 | Train loss: 0.067 | Test loss: 0.068 | Test accuracy: 0.274 | Test top 3 accuracy: 0.421
Epoch 7/150 | Train loss: 0.065 | Test loss: 0.067 | Test accuracy: 0.294 | Test top 3 accuracy: 0.403
Epoch 8/150 | Train loss: 0.063 | Test loss: 0.066 | Test accuracy: 0.315 | Test top 3 accuracy: 0.441
Epoch 9/150 | Train loss: 0.062 | Test loss: 0.065 | Test accuracy: 0.335 | Test top 3 accuracy: 0.470
Epoch 10/150 | Train loss: 0.060 | Test loss: 0.064 | Test accuracy: 0.33