In [4]:
# Save protein sequences as a FASTA file as required by the embedding generator
import pickle5 as pickle
from tqdm import tqdm
# Load the id -> sequence mapping.
# NOTE(review): unpickling can execute arbitrary code; only load trusted files here.
with open('data/datasetB_all_sequences.pkl', 'rb') as f:
    data = pickle.load(f)

# Write protein sequences in FASTA format: a '>id' header line followed by the sequence line.
# The `with` statement closes the file on exit, so no explicit f.close() is needed.
with open("data/datasetB_all_sequences.fasta", "w") as f:
    for key in tqdm(data):
        f.write('>' + key + '\n' + data[key] + '\n')

In [4]:

# ESM model names available for embedding extraction (one entry per pretrained checkpoint).
models = [
    'esm2_t33_650M_UR50D',
    'esm2_t36_3B_UR50D',
    'esm1_t34_670M_UR50S',
    'esm1_t6_43M_UR50S',
    'esm2_t30_150M_UR50D',
    'esm1v_t33_650M_UR90S_1',
    'esm1b_t33_650M_UR50S',
    'esm2_t12_35M_UR50D',
    'esm_msa1b_t12_100M_UR50S',
    'esm1_t34_670M_UR50D',
    'esm_msa1_t12_100M_UR50S',
    'esm2_t6_8M_UR50D',
    'esm_if1_gvp4_t16_142M_UR50',
]

import os, torch
# Process extracted emb
from Bio import SeqIO
# Sanity-check one stored embedding against the length of its sequence.
fasta_path = 'data/datasetB_all_sequences.fasta'
model_name = 'esm1_t6_43M_UR50S'

# Pass the path so Biopython manages the file handle, instead of leaking a bare open().
fasta_sequences = SeqIO.parse(fasta_path, 'fasta')

for fasta in fasta_sequences:
    print(fasta.id, len(fasta.seq))
    emb = torch.load(os.path.join(model_name, fasta.id + '.pt'))
    key = list(emb['representations'].keys())
    # Tensor stored under the first representation key; in the output below its
    # row count matches the sequence length (196 residues -> [196, 768]).
    print(emb['representations'][key[0]].shape)
    print(emb['representations'][key[0]])
    break  # inspect only the first record

1g58_A 196
torch.Size([196, 768])
tensor([[-0.7342,  0.7721, -0.2252,  ..., -0.7896,  1.2874,  0.9320],
        [-0.0884,  0.7979, -0.2805,  ...,  0.1627,  0.5936,  0.3086],
        [ 0.7058, -0.5040, -0.4714,  ..., -0.1014,  1.6905, -0.0378],
        ...,
        [ 0.0742, -0.1984,  0.2745,  ..., -0.9772,  1.1318, -0.0934],
        [-0.0326,  0.4182, -0.3380,  ..., -1.2327, -0.1898,  1.3206],
        [-0.5076, -0.3772,  0.8009,  ..., -1.1297,  0.0080, -0.9833]])


In [None]:
from tqdm import tqdm
import torch, os
# Stack every per-protein embedding in `model_name` into one (total_residues, dim) tensor.
concatenated_dataset = []
model_name = 'esm1_t6_43M_UR50S'
# Key under emb['representations'] to read; presumably the final layer of the 6-layer
# esm1_t6 model (TODO confirm against the extraction's repr_layers setting).
REPR_LAYER = 6

# Sort for a deterministic concatenation order (os.listdir order is arbitrary) and
# skip anything that is not a saved embedding file.
emb_files = sorted(fn for fn in os.listdir(model_name) if fn.endswith('.pt'))
for datapoint in tqdm(emb_files):
    emb = torch.load(os.path.join(model_name, datapoint))['representations'][REPR_LAYER]
    concatenated_dataset.append(emb)

final_dataset = torch.cat(concatenated_dataset)

In [9]:
# Reuse the tensor built above instead of concatenating the whole list a second time.
final_dataset.shape, emb.shape

(torch.Size([451563, 768]), torch.Size([48, 768]))

In [59]:
#import pandas as pd
from torch.utils.data import Dataset, DataLoader

class ProtEmbDataset(Dataset):
    """Per-protein embedding dataset backed by a directory of ``.pt`` files.

    Each file is loaded with ``torch.load`` and must contain a dict with a
    ``'representations'`` mapping; the tensor stored under its first key is
    returned as the sample (optionally passed through ``transform``).
    """

    def __init__(self, model_name, transform=None):
        """
        Args:
            model_name (string): Directory containing one ``.pt`` embedding file per
                protein (in this notebook it is named after the embedding model).
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.model_name = model_name
        self.transform = transform
        # Sorted so index -> file is stable across runs/platforms; non-.pt files
        # (e.g. checkpoints or notes) would otherwise crash torch.load.
        self.filenames = sorted(fn for fn in os.listdir(model_name) if fn.endswith('.pt'))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        file = self.filenames[idx]
        # Read from this instance's directory, not a global variable of the same name.
        emb = torch.load(os.path.join(self.model_name, file))
        representations = emb['representations']
        sample = representations[next(iter(representations))]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

In [60]:
# NOTE(review): root_dir is not consumed by ProtEmbDataset; kept for later cells that may read it.
root_dir = 'data'
model_name = 'esm1_t6_43M_UR50S'
# One dataset item per embedding file found in the `model_name` directory.
protein_dataset = ProtEmbDataset(model_name=model_name)

In [61]:
# batch_size=1: embeddings differ in length (e.g. 196 vs 162 residues in the outputs),
# which the default collate function cannot stack into one batch tensor.
dataloader = DataLoader(
    protein_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=4,
)

In [62]:
# Look at a single shuffled batch: with and without the leading batch dimension of 1.
for datapoint in dataloader:
    squeezed_shape = datapoint.squeeze().shape
    print(squeezed_shape)
    print(datapoint.shape)
    break

torch.Size([162, 768])
torch.Size([1, 162, 768])


In [69]:
from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler, WeightedRandomSampler
import torch.utils.data as data
import numpy as np

# Create a validation and training set by randomly splitting dataset indices.
SPLIT_SEED = 42  # fixed so the split is reproducible across kernel restarts
val_ratio = 0.2  # fraction of proteins held out for validation
al_ratio = val_ratio  # backward-compatible alias for the previously misspelled name

samples_count = len(protein_dataset)
all_samples_indexes = list(range(samples_count))
np.random.default_rng(SPLIT_SEED).shuffle(all_samples_indexes)

val_end = int(samples_count * val_ratio)
val_indexes = all_samples_indexes[:val_end]
train_indexes = all_samples_indexes[val_end:]
assert len(val_indexes) + len(train_indexes) == samples_count, 'the split is not valid'

# Each sampler re-shuffles its own subset every epoch.
sampler_train = data.SubsetRandomSampler(train_indexes)
sampler_val = data.SubsetRandomSampler(val_indexes)

dataloader_train = DataLoader(protein_dataset, batch_size=1, sampler=sampler_train, num_workers=4)
dataloader_val = DataLoader(protein_dataset, batch_size=1, sampler=sampler_val, num_workers=4)
# dataloader_test = DataLoader(protein_dataset, batch_size=1,
#                         shuffle=True, num_workers=4)

In [None]:
from torch import nn
from collections import OrderedDict
class IonNet(nn.Module):
    """Shared linear projection followed by four independent single-output heads.

    The heads are labelled Ion1, Ion2, Ion3 and Null. Each head ends in a plain
    ``nn.Linear(..., 1)`` with no activation, so it emits one raw score (logit)
    per input row.

    Args:
        d_in: size of the last dimension of the input.
        n_features: width of the shared projection and of each head's hidden layer.
    """

    def __init__(self, d_in, n_features=256):
        super(IonNet, self).__init__()
        self.n_features = n_features
        self.linear1 = nn.Linear(d_in, self.n_features)
        # Kept for backward compatibility (attribute access); not used in forward().
        self.fc = nn.Identity()
        self.fc1 = self._make_head(self.n_features)  # Ion1
        self.fc2 = self._make_head(self.n_features)  # Ion2
        self.fc3 = self._make_head(self.n_features)  # Ion3
        self.fc4 = self._make_head(self.n_features)  # Null

    @staticmethod
    def _make_head(n_features):
        """Build one Linear -> ReLU -> Linear(n_features, 1) head.

        Sub-module names ('linear', 'relu1', 'final') are kept so state_dict keys are unchanged.
        """
        return nn.Sequential(OrderedDict([
            ('linear', nn.Linear(n_features, n_features)),
            ('relu1', nn.ReLU()),
            ('final', nn.Linear(n_features, 1)),
        ]))

    def forward(self, x):
        # The shared projection is deterministic, so compute it once and reuse it for every head.
        shared = self.linear1(x)
        return self.fc1(shared), self.fc2(shared), self.fc3(shared), self.fc4(shared)

import torch.nn
class CNN2Layers(torch.nn.Module):
    """Two Conv1d layers, a shared linear layer, then four single-output heads.

    conv1: Conv1d(in_channels -> 512) -> ELU -> Dropout -> Conv1d(512 -> n_features).
    The conv output is flattened from dim 1, passed through ``linear1`` and fed to
    heads fc1..fc4 (Ion1, Ion2, Ion3, Null), each ending in ``nn.Linear(..., 1)``
    with no activation (raw logits).

    NOTE(review): in this notebook the input is built as (residues, in_channels, 1),
    so each residue is a length-1 sequence. ``linear1`` expects exactly n_features
    inputs, which only holds when the conv output length is 1 (e.g. kernel_size=3 with
    padding=1, or kernel_size=1 with padding=0, stride 1) — TODO confirm for other inputs.
    """

    def __init__(self, in_channels, n_features, kernel_size, stride, padding, dropout):
        super(CNN2Layers, self).__init__()
        self.n_features = n_features
        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=in_channels, out_channels=512, kernel_size=kernel_size,
                            stride=stride, padding=padding),
            torch.nn.ELU(),
            torch.nn.Dropout(dropout),

            torch.nn.Conv1d(in_channels=512, out_channels=self.n_features, kernel_size=kernel_size,
                            stride=stride, padding=padding),
        )

        self.linear1 = nn.Linear(self.n_features, self.n_features)

        self.fc1 = self._make_head(self.n_features)  # Ion1
        self.fc2 = self._make_head(self.n_features)  # Ion2
        self.fc3 = self._make_head(self.n_features)  # Ion3
        self.fc4 = self._make_head(self.n_features)  # Null

    @staticmethod
    def _make_head(n_features):
        """Build one Linear -> ReLU -> Linear(n_features, 1) head.

        Sub-module names ('linear', 'relu1', 'final') are kept so state_dict keys are unchanged.
        """
        return nn.Sequential(OrderedDict([
            ('linear', nn.Linear(n_features, n_features)),
            ('relu1', nn.ReLU()),
            ('final', nn.Linear(n_features, 1)),
        ]))

    def forward(self, x):
        x = self.conv1(x)
        x = torch.flatten(x, 1)
        # linear1 has no randomness, so compute the shared features once for all heads.
        shared = self.linear1(x)
        return self.fc1(shared), self.fc2(shared), self.fc3(shared), self.fc4(shared)

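# Known pitfall: Conv1d expects input shaped (batch, channels, length). Feeding a (1, 674, 768)
# embedding directly (sequence positions on dim 1, the 768 embedding features on dim 2) fails with:
# RuntimeError: Given groups=1, weight of size [512, 768, 3], expected input[1, 674, 768] to have 768 channels, but got 674 channels instead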

In [None]:
# Smoke-test: build a small-kernel CNN on whichever device is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = CNN2Layers(
    in_channels=768,
    n_features=256,
    kernel_size=1,
    stride=1,
    padding=0,
    dropout=0.7,
).to(device)

In [None]:
# Smoke-test one batch: run the model and compute per-head accuracy.
# NOTE(review): this expects dataset items of the form (embedding, {field: labels}); the
# ProtEmbDataset defined in this notebook returns only the embedding tensor — confirm
# which dataset `dataloader` wraps before running.
fields = list(dataloader.dataset[0][1].keys())
accuracies = [0.0] * len(fields)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for i, (embs, labels) in tqdm(enumerate(dataloader)):
    # assumes batch_size=1: (1, L, C) -> (L, C) -> (L, C, 1), one length-1 sequence per residue
    embs = torch.unsqueeze(torch.squeeze(embs), 2).to(device)
    # Flatten labels and predictions to 1-D so they are compared elementwise; comparing an
    # (L, 1) prediction with an (L,) label would broadcast into an (L, L) matrix.
    labels = [torch.as_tensor(labels[key]).to(device=device, dtype=torch.float32).reshape(-1)
              for key in fields]
    (lbl_ion1, lbl_ion2, lbl_ion3, lbl_null) = labels
    preds = tuple(p.reshape(-1) for p in model(embs))
    (prd_ion1, prd_ion2, prd_ion3, prd_null) = preds

    # The heads emit raw logits (training uses BCEWithLogitsLoss), so threshold at 0,
    # which is equivalent to sigmoid(logit) > 0.5; rounding the logit itself is wrong.
    for k, (prd, lbl) in enumerate(zip(preds, labels)):
        accuracies[k] = torch.mean(((prd > 0).float() == lbl).float())
    break

In [None]:
# Positions labelled Ion1 vs. positions predicted Ion1. Predictions are logits, so a
# positive value corresponds to sigmoid(logit) > 0.5 (rounding a logit is not a class).
print((lbl_ion1 == 1).nonzero(as_tuple=True)[0])
print((prd_ion1 > 0).nonzero(as_tuple=True)[0])

In [None]:
def _labels_to_tensor(value, device):
    """Convert one collated label entry to a flat float32 tensor on `device`.

    Tensors are moved/cast directly (``torch.tensor(tensor)`` copies and warns);
    any other value (e.g. a list) goes through ``torch.tensor`` as before.
    """
    if isinstance(value, torch.Tensor):
        labels = value.to(device=device, dtype=torch.float32)
    else:
        labels = torch.tensor(value, dtype=torch.float32, device=device)
    return labels.reshape(-1)


def train_val(model, dataloader, optimizer, criterion_1, is_training, device, topk, interval):
    """Run one pass over `dataloader`, training or evaluating `model`.

    Args:
        model: module whose forward returns one logit tensor per label field
            (the networks in this notebook return ion1, ion2, ion3, null).
        dataloader: yields (embs, labels) batches where `labels` is a dict keyed by
            label field; `dataloader.dataset[0][1]` supplies the field order.
        optimizer: stepped only when `is_training` is True.
        criterion_1: loss applied to every head (this notebook uses BCEWithLogitsLoss).
        is_training: True for a training pass (grad enabled, optimizer steps).
        device: device batches are moved to.
        topk: currently unused; kept for call compatibility.
        interval: print progress every `interval` batches.

    Returns:
        list[float]: mean per-batch accuracy for each label field over the whole pass.
    """
    batch_cnt = len(dataloader)
    fields = list(dataloader.dataset[0][1].keys())
    accuracies = [0.0] * len(fields)
    status = 'training' if is_training else 'validation'

    with torch.set_grad_enabled(is_training):
        model.train(is_training)  # same as model.train() / model.eval()

        for i, (embs, labels) in enumerate(dataloader):
            # assumes batch_size=1: (1, L, C) -> (L, C) -> (L, C, 1), one length-1 sequence per residue
            embs = torch.unsqueeze(torch.squeeze(embs), 2).to(device)
            # Flatten targets and predictions to 1-D so the loss and the accuracy compare
            # per-residue values elementwise instead of broadcasting (L, 1) against (L,).
            targets = [_labels_to_tensor(labels[key], device) for key in fields]
            preds = [p.reshape(-1) for p in model(embs)]

            losses = [criterion_1(p, t) for p, t in zip(preds, targets)]
            loss_final = sum(losses)

            # Predictions are logits, so logit > 0 <=> sigmoid(logit) > 0.5 (class 1).
            for k, (p, t) in enumerate(zip(preds, targets)):
                accuracies[k] += ((p > 0).float() == t).float().mean().item()

            if is_training:
                optimizer.zero_grad()
                loss_final.backward()
                optimizer.step()

            if i % interval == 0:
                # Running mean over the batches seen so far, not over the whole epoch.
                accs = [acc / (i + 1) for acc in accuracies]
                print(f'[{status}] iter: {i} loss: {loss_final.item():6f}')
                print(' ,'.join(f'{f}: {x:.4f}' for f, x in zip(fields, accs)))

    if batch_cnt == 0:
        return accuracies
    return [acc / batch_cnt for acc in accuracies]

def train_loop(model, epochs, dataloader_train, dataloader_val,
               optimizer, lr_scheduler, criterion_1, criterion_2=None, interval=10):
    """Alternate a training pass and a validation pass for `epochs` epochs.

    Args:
        model: network to train; moved to CUDA when available, else CPU.
        epochs: number of epochs to run.
        dataloader_train: loader for the training pass.
        dataloader_val: loader for the validation pass (progress printed every batch).
        optimizer: optimizer stepped during training passes.
        lr_scheduler: stepped once at the end of every epoch.
        criterion_1: loss used for every head.
        criterion_2: currently unused; optional and kept for call compatibility.
        interval: progress-print interval (in batches) for the training pass.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    for e in range(epochs):
        # get_last_lr() reports the rates in use; get_lr() is the scheduler's internal
        # update rule and can return an already-decayed value when called directly.
        lrs = [f'{lr:.6f}' for lr in lr_scheduler.get_last_lr()]
        print(f'epoch {e} : lrs : {" ".join(lrs)}')
        train_val(model, dataloader_train, optimizer, criterion_1, True, device, 1, interval)
        train_val(model, dataloader_val, optimizer, criterion_1, False, device, 1, 1)
        lr_scheduler.step()


In [None]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Alternative baseline: model = IonNet(768).to(device=device)
model = CNN2Layers(in_channels=768, n_features=512, kernel_size=3, stride=1,
                   padding=1, dropout=0.7).to(device=device)
# The heads emit raw logits, so use the sigmoid-fused binary cross-entropy.
criterion_1 = nn.BCEWithLogitsLoss()
epochs = 10
lr = 0.0001

optimizer = torch.optim.SGD(model.parameters(), lr=lr)

# NOTE(review): step_size equals `epochs`, so the LR is only decayed after the last epoch.
lrsched = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Pass `interval` by keyword: positionally, 5 bound to the unused `criterion_2`
# parameter and the progress interval silently stayed at its default of 10.
train_loop(model, epochs, dataloader_train, dataloader_val, optimizer, lrsched,
           criterion_1, None, interval=5)

# ion1_loss = nn.MSELoss() # Includes Softmax
# ion2_loss = nn.MSELoss() # Doesn't include Softmax
# ion3_loss = nn.MSELoss()
# optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.09)
# sig = nn.Sigmoid()