In [None]:
import torch

# Seed PyTorch's CPU and CUDA generators so repeated runs start from the same state.
seed = 1
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # silently ignored when CUDA is unavailable

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Compare on `device.type`: a `torch.device` never compares equal to a bare
# string, so `device == "cuda"` was always False and this branch never ran.
if device.type == "cuda":
    torch.cuda.synchronize()

In [None]:
# Create the checkpoint directory. Unlike a bare `mkdir`, this does not fail
# when the directory already exists, so the cell can be re-run safely.
from pathlib import Path

Path("models").mkdir(exist_ok=True)

In [None]:
# Create the directory the dataset pickle is read from. Unlike a bare `mkdir`,
# this does not fail when the directory already exists.
from pathlib import Path

Path("datasets").mkdir(exist_ok=True)

In [None]:
pip install protein_bert_pytorch



In [None]:
import torch
from protein_bert_pytorch import ProteinBERT, PretrainingWrapper

# Fresh ProteinBERT instance.
# NOTE: a later cell rebinds `model` to a checkpoint loaded from disk, so this
# instance is only used if that cell is skipped.
model = ProteinBERT(
    # The aa2int encoder in this notebook emits ids 0..25 (26 distinct ids,
    # '-' being 25); with num_tokens = 25 the id 25 was out of range.
    num_tokens = 26,
    num_annotation = 128, # the original inline note lists 8943 as the alternative value
    dim = 512,
    dim_global = 256,
    depth = 6,
    narrow_conv_kernel = 9,
    wide_conv_kernel = 9,
    wide_conv_dilation = 5,
    attn_heads = 8,
    attn_dim_head = 64,
    local_to_global_attn = False,
    local_self_attn = True,
    num_global_tokens = 2,
    glu_conv = False
)

In [None]:
_aa2int = {'A' : 1,'R' : 2,'N' : 3,'D' : 4,'C' : 5,'Q' : 6,'E' : 7,'G' : 8,'H' : 9,'I' : 10,
          'L' : 11,'K' : 12,'M' : 13,'F' : 14,'P' : 15,'S' : 16,'T' : 17,'W' : 18,'Y' : 19,
          'V' : 20,'B' : 21,'Z' : 22,'X' : 23,'*' : 24,'-' : 25,'?' : 0}

def aa2int(seq : str) -> list:
    return [_aa2int[i] for i in seq]

x = 'ABTZX'
aa2int(x) # returns [1, 21, 17, 22, 23]

[1, 21, 17, 22, 23]

In [None]:
# `pickle` is used below but was never imported anywhere in the notebook.
import pickle

import pandas


MAX_SEQ_LEN = 250  # every encoded sequence is brought to exactly this many tokens
PAD_ID = 0         # id used to fill the tail of short sequences ('?' in `_aa2int`)

# dataset_loaded = pandas.read_pickle('datasets/compressed_sps_dataset.pkl')

# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with open('datasets/compressed_sps_dataset2.pkl', 'rb') as filereader:
    dataset_loaded = pickle.load(filereader)

# Each loaded item is indexed as (sequence string, label): item [0] is encoded
# with aa2int, item [1] is carried through unchanged.
dataset = []
for instance in dataset_loaded:
    # Cut over-long sequences first: without this, a sequence longer than
    # MAX_SEQ_LEN got no padding at all and produced a tensor of a different
    # length, which cannot be stacked into a batch with the others.
    seq = aa2int(instance[0])[:MAX_SEQ_LEN]
    seq = seq + [PAD_ID] * (MAX_SEQ_LEN - len(seq))
    dataset.append([torch.tensor(seq), instance[1]])

print(len(dataset))

191734


In [None]:
# Split the encoded dataset into a training part and a held-out part, then
# wrap the training part in a shuffling mini-batch loader.

from torch.utils.data import DataLoader, random_split

n_examples = len(dataset)
test_size = int(0.0 * n_examples)  # fraction 0.0 -> the held-out part is empty
train_size = n_examples - test_size
train_set, test_set = random_split(dataset, lengths=[train_size, test_size])

batch_size=16
train_dataloader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True)

# Number of mini-batches the training loader yields per full pass
print(len(train_dataloader))

11984


In [None]:
# Pre-training loop, run over mini-batches of encoded sequences.
from tqdm.notebook import tqdm
criterion = torch.nn.MSELoss()  # not used by train_unsupervised; kept in case other cells reference it

def train_unsupervised(model, dataloader, epochs = 1, lr=0.000001, data_divisor=10):
    """Run pre-training steps on a fraction of ``dataloader`` for each epoch.

    Relies on two notebook-level globals that must exist before the call:
    ``device`` (where tensors and the model are placed) and ``learner`` (the
    callable that returns the loss). The optimizer is built from
    ``model.parameters()``, so ``learner`` has to wrap this same ``model``
    for the updates to have any effect.

    Args:
        model: module whose parameters are optimised; moved to ``device``.
        dataloader: iterable of ``(inputs, labels)`` batches; ``labels`` are ignored.
        epochs: number of passes to run.
        lr: Adam learning rate.
        data_divisor: each epoch processes ``len(dataloader) / data_divisor``
            batches (rounded down, at least one), e.g. 10 -> about 10 % of the data.

    Returns:
        None. The model is updated in place; progress is shown with tqdm.
    """
    # Width of the all-zero annotation input. It has to equal the model's
    # annotation size (128 where ProteinBERT is constructed in this notebook).
    num_annotations = 128

    model.to(device)
    # Build the optimizer once: re-creating Adam at the start of every epoch
    # discarded its running moment estimates each time.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Integer batch budget per epoch. The previous float total rendered as
    # "0/1198.4" in the progress bar, and the break test ran only after the
    # step, so one or two batches beyond the budget were processed.
    max_batches = max(1, int(len(dataloader) / data_divisor))

    for epoch in range(epochs):
        summed_loss = 0.0
        seen = 0
        with tqdm(total=max_batches) as pbar:
            for index_data, (inputs, _labels) in enumerate(dataloader):
                if index_data >= max_batches:
                    break

                inputs = inputs.to(device)
                optimizer.zero_grad()

                annotation = torch.zeros(inputs.shape[0], num_annotations, device=device)
                # NOTE(review): every position is marked valid, so padded
                # positions are included as well; confirm that this is intended.
                mask = torch.ones_like(inputs, dtype=torch.bool)

                loss = learner(inputs, annotation, mask = mask)
                loss.backward()
                optimizer.step()

                # Running loss per sequence, counted from the batches actually
                # seen instead of the global `batch_size`.
                summed_loss += loss.item()
                seen += inputs.shape[0]
                pbar.set_description(f'loss: {summed_loss / seen:.5f}')
                pbar.update(1)

In [None]:
# Resume from a saved checkpoint. `map_location` lets a checkpoint written on
# a GPU machine load on a CPU-only one (and vice versa).
# NOTE(review): the file holds a whole pickled module, not a state_dict.
# PyTorch >= 2.6 defaults `torch.load` to `weights_only=True`, which rejects
# such files; pass `weights_only=False` there, and only for trusted files.
model = torch.load('models/pretrained_sps_20.pt', map_location=device)
model.to(device)
learner = PretrainingWrapper(
    model,
    random_replace_token_prob = 0.05,    # what percentage of the tokens to replace with a random one, defaults to 5% as in paper
    remove_annotation_prob = 0.0,       # what percentage of annotations to remove, defaults to 25%
    add_annotation_prob = 0.00,          # probability to add an annotation randomly, defaults to 1%
    remove_all_annotations_prob = 0.0,   # what percentage of batch items to remove annotations for completely, defaults to 50%
    seq_loss_weight = 1.,                # weight on loss of sequence
    annotation_loss_weight = 0.,         # weight on loss of annotation
    # Only id 0 is a non-residue id here: it is the padding value. In this
    # notebook's vocabulary ids 1 and 2 are the residues 'A' and 'R', not
    # start/end tokens, so excluding them kept those residues from ever being
    # replaced.
    exclude_token_ids = (0,)
)

In [None]:
# Continue pre-training for 5 epochs, then snapshot the whole model object.
# The numeric suffix of the file name presumably tracks the cumulative epoch count.
train_unsupervised(model=model, dataloader=train_dataloader, epochs=5)
torch.save(obj=model, f='models/pretrained_sps_25.pt')

  0%|          | 0/1198.4 [00:00<?, ?it/s]

In [None]:
# Another 5 epochs of pre-training, followed by a snapshot of the model object.
train_unsupervised(model=model, dataloader=train_dataloader, epochs=5)
torch.save(obj=model, f='models/pretrained_sps_30.pt')

In [None]:
# Another 5 epochs of pre-training, followed by a snapshot of the model object.
train_unsupervised(model=model, dataloader=train_dataloader, epochs=5)
torch.save(obj=model, f='models/pretrained_sps_35.pt')

In [None]:
# Another 5 epochs of pre-training, followed by a snapshot of the model object.
train_unsupervised(model=model, dataloader=train_dataloader, epochs=5)
torch.save(obj=model, f='models/pretrained_sps_40.pt')

In [None]:
# Another 5 epochs of pre-training, followed by a snapshot of the model object.
train_unsupervised(model=model, dataloader=train_dataloader, epochs=5)
torch.save(obj=model, f='models/pretrained_sps_45.pt')

In [None]:
# Final 5-epoch round in this notebook, followed by a snapshot of the model object.
train_unsupervised(model=model, dataloader=train_dataloader, epochs=5)
torch.save(obj=model, f='models/pretrained_sps_50.pt')