In [1]:
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from datasets import load_dataset
from sklearn import metrics
from sklearn.model_selection import train_test_split
from tqdm import tqdm

In [2]:
N_CLASSES = 7
HIDDEN = 64
INPUT = 1024

In [3]:
class Network(nn.Module):
    def __init__(self, inp, hidden, output, device):
        super(Network,self).__init__()
        
        self.device = device
        
        self.linear1=nn.Linear(inp, hidden)
        self.linear2=nn.Linear(hidden, output)
        
        self.loss = nn.CrossEntropyLoss()
 
        
    def forward(self,x):
        x=self.linear1(x)
        x=self.linear2(x)
        return x
    
    def train_model(self, dataset, epochs):  
        self.train()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.0001) # TODO tune

        for epoch in range(epochs):
            with tqdm(dataset, unit="batch") as tepoch:
                for inputs, targets in tepoch:

                    inputs, targets = inputs.to(self.device), targets.to(self.device)
                    tepoch.set_description(f"Epoch {epoch + 1}")
                    
                    targets = targets[:, 0].long() # TODO what is that?

                    # clear the gradients
                    optimizer.zero_grad()
                    # compute the model output
                    yhat = self(inputs)
                    # calculate accuracy
                    correct = (yhat.argmax(1) == targets).type(torch.float).sum().item()
                    accuracy = correct / len(inputs)
                    # calculate loss
                    loss = self.loss(yhat, targets)
                    # credit assignment
                    loss.backward()
                    # update model weights
                    optimizer.step()

                    tepoch.set_postfix(loss=loss.item(), accuracy=100. * accuracy)
        
    def test(self, dataloader):
        self.eval()
        pred_label, actuals = list(), list()
        
        with torch.no_grad():
            for inputs, targets in dataloader:
                targets = targets[:, 0].long()
                
                inputs = inputs.to(self.device)
                
                # evaluate the model on the test set
                yhat = self(inputs)
                yhat = yhat.cpu().detach().numpy()
                actual = targets.numpy()
                yhat = yhat.argmax(1)
                # reshape for stacking
                actual = actual.reshape((len(actual), 1))
                yhat = yhat.reshape((len(yhat), 1))
                # store
                pred_label.append(yhat)
                actuals.append(actual)
        pred_label, actuals = np.vstack(pred_label), np.vstack(actuals)
        print("Predictions: ", pred_label[:10])
        print("Real labels: ", actuals[:10])
        # calculate accuracy
        acc = metrics.accuracy_score(actuals, pred_label)
        f1 = metrics.f1_score(actuals, pred_label, average='micro', zero_division=0)
        print(f"Test metrics: \n Accuracy: {acc}, F1 score: {float(f1):>6f}\n")
        return acc, f1



In [4]:
class EmbeddingDataset(Dataset):
    def __init__(self, df):
        self.X = np.float32(df.drop(columns=['Family']))
        self.y = np.expand_dims(df['Family'].to_numpy(), axis=1)
        self.len = len(df)
    
    def __len__(self):
        return self.len
    
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [5]:
def family_mapping(family):
    if family == 'SPOUT':
        return 0
    elif family == 'AdoMet synthase':
        return 1
    elif family == 'Carbonic anhydrase':
        return 2
    elif family == 'ATCase/OTCase':
        return 3
    elif family == 'membrane':
        return 4
    elif family == 'VIT':
        return 5
    else:
        return 6

def prepare_AF_dataset(path, type):
    df = pd.read_csv(path, index_col=0)
    
    dss = load_dataset('EvaKlimentova/knots_AF')
    hf = pd.DataFrame(dss[type])
    hf = hf.drop(columns=['uniprotSequence', 'label', 'latestVersion', 'globalMetricValue', 'uniprotStart', 'uniprotEnd', 'Length', 'Domain_architecture', 'InterPro', 'Max_Topology', 'Max Freq', 'Knot Core'])
    df_family = pd.merge(hf, df, on="ID")
    
    # delete unknotted SPOUTs
    df_family = df_family.drop(df_family[(df_family['FamilyName'] == 'SPOUT') & (df_family['label'] == 0)].index)
    
    # sort proteins into a couple of family bins
    df_family['Family'] = df_family['FamilyName'].apply(family_mapping)    
    
    df_family = df_family.drop(columns=['ID', 'label', 'FamilyName'])
    return df_family

In [6]:
train_df = prepare_AF_dataset("../Alphafold_dataset/ProtBertBFD_train_embedding_af_v3.csv", 'train')
test_df = prepare_AF_dataset("../Alphafold_dataset/ProtBertBFD_test_embedding_af_v3.csv", 'test')

Using custom data configuration EvaKlimentova--knots_AF-293560de9ceccb3f
Found cached dataset parquet (/home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-293560de9ceccb3f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

Using custom data configuration EvaKlimentova--knots_AF-293560de9ceccb3f
Found cached dataset parquet (/home/jovyan/.cache/huggingface/datasets/EvaKlimentova___parquet/EvaKlimentova--knots_AF-293560de9ceccb3f/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
print(f"Train dataset size: {len(train_df)}")
print(train_df.head(5))
print(f"Test dataset size: {len(test_df)}\n")

Train dataset size: 155624
         f0        f1        f2        f3        f4        f5        f6  \
0  0.026727 -0.032709 -0.027460  0.023039  0.018919  0.020549  0.009272   
1 -0.006818 -0.011500 -0.000443 -0.002459 -0.000295  0.000784 -0.004652   
2  0.014304 -0.017095 -0.003410  0.000148  0.022026  0.002258  0.015533   
3  0.005448 -0.001238  0.002784  0.006303  0.003013  0.010680 -0.014094   
4  0.007396  0.012624  0.000298  0.007900  0.010098  0.012833 -0.004366   

         f7        f8        f9  ...     f1015     f1016     f1017     f1018  \
0 -0.005804 -0.021650  0.027531  ... -0.043956 -0.042798  0.006104 -0.021828   
1  0.005460 -0.007820  0.005914  ... -0.016875  0.004226  0.005704 -0.005163   
2  0.006351 -0.022573  0.021902  ... -0.011653 -0.021954 -0.009635  0.029997   
3 -0.006538 -0.005705 -0.002041  ... -0.008526 -0.009459  0.004119 -0.001539   
4  0.011860 -0.006266 -0.001095  ... -0.023012 -0.002771  0.008065  0.001363   

      f1019     f1020     f1021     f1022

In [8]:
train_dset = EmbeddingDataset(train_df)
train_loader = DataLoader(train_dset, batch_size=32, shuffle=True)
test_dset = EmbeddingDataset(test_df)
test_loader = DataLoader(test_dset, batch_size=1, shuffle=False)

In [9]:
# Run on GPU or CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

Using cuda device


## Model training

In [10]:
model = Network(INPUT, HIDDEN, N_CLASSES, device=device).to(device)

In [11]:
print(model)

Network(
  (linear1): Linear(in_features=1024, out_features=64, bias=True)
  (linear2): Linear(in_features=64, out_features=7, bias=True)
  (loss): CrossEntropyLoss()
)


In [12]:
type(train_dset.X[0][0])

numpy.float32

In [13]:
model.train_model(train_loader, 10)

Epoch 1: 100%|██████████| 4864/4864 [00:15<00:00, 304.94batch/s, accuracy=100, loss=0.261] 
Epoch 2: 100%|██████████| 4864/4864 [00:18<00:00, 263.95batch/s, accuracy=100, loss=0.0177] 
Epoch 3: 100%|██████████| 4864/4864 [00:18<00:00, 269.90batch/s, accuracy=87.5, loss=0.19]  
Epoch 8:  53%|█████▎    | 2577/4864 [00:11<00:10, 222.91batch/s, accuracy=100, loss=0.00884] IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch 8: 100%|██████████| 4864/4864 [00:19<00:00, 247.56batch/s, accuracy=100, loss=0.00329] 
Epoch 9:   7%|▋         | 338/4864 [00:01<00:18, 241.77batch/s, accuracy=100, loss=0.0122]  IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing

In [15]:
model.test(test_loader)

Predictions:  [[5]
 [0]
 [0]
 [5]
 [0]
 [4]
 [0]
 [5]
 [5]
 [0]]
Real labels:  [[5]
 [0]
 [0]
 [5]
 [0]
 [4]
 [0]
 [5]
 [5]
 [0]]
Test metrics: 
 Accuracy: 0.9945447995471154, F1 score: 0.994545



(0.9945447995471154, 0.9945447995471154)

In [14]:
torch.save(model, "ProtBertBFD_embedding_CNN_family.pth")