### Preparation ###
Create folders and upload fp-dataset into the dataset folder.
Upload the Protein-Bert and predictor model into the models folder.

In [52]:
import torch

seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if(device == "cuda"):
    torch.cuda.synchronize()
print(device)

cpu


In [53]:
!mkdir models

Ein Unterverzeichnis oder eine Datei mit dem Namen "models" existiert bereits.


In [54]:
!mkdir datasets

Ein Unterverzeichnis oder eine Datei mit dem Namen "datasets" existiert bereits.


### Architecture design ###

In [55]:
pip install protein_bert_pytorch

Note: you may need to restart the kernel to use updated packages.


You should consider upgrading via the 'C:\Users\Prophet.DESKTOP-UUFA83J\AppData\Local\Programs\Python\Python39\python.exe -m pip install --upgrade pip' command.


In [56]:
import torch
from protein_bert_pytorch import ProteinBERT, PretrainingWrapper
annotation_max = 18000  # 2350

def proteinBERT_model():
  return ProteinBERT(
      num_tokens = 25,
      num_annotation = annotation_max,
      dim = 512,
      dim_global = 256,
      depth = 6,
      narrow_conv_kernel = 9,
      wide_conv_kernel = 9,
      wide_conv_dilation = 5,
      attn_heads = 8,
      attn_dim_head = 64,
      local_to_global_attn = False,
      local_self_attn = True,
      num_global_tokens = 2,
      glu_conv = False
  )

In [57]:
import torch.nn as nn

class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x
    
class predictorNet(torch.nn.Module):
    def __init__(self, dropout):
        super().__init__()
        self.flatten = torch.nn.Flatten(start_dim=1)
        self.linear1 = torch.nn.Linear(179200, 64)
        self.linear3 = torch.nn.Linear(64, 2)
        self.dropout = torch.nn.Dropout(p=dropout)
        
    def forward(self, x, dropout=True):
        x = self.flatten(x)
        if(dropout): x = self.dropout(x)
        x = self.linear1(x)
        x = torch.nn.functional.softplus(x)
        x = self.linear3(x)
        return x

model = torch.load('models/pretrained_drp2_005_m_ex_em_1.pt', map_location=torch.device(device))
model_predictor = torch.load('models/pretrained_drp2_005_b_ex_em_1.pt', map_location=torch.device(device))

### Data preparation ###

In [58]:
def clean(result):
    if(result!=None):
        return result
    else:
        return -1

In [59]:
_acid2int = {'A' : 1,'R' : 2,'N' : 3,'D' : 4,'C' : 5,'Q' : 6,'E' : 7,'G' : 8,'H' : 9,'I' : 10,
          'L' : 11,'K' : 12,'M' : 13,'F' : 14,'P' : 15,'S' : 16,'T' : 17,'W' : 18,'Y' : 19,
          'V' : 20,'B' : 21,'Z' : 22,'X' : 23,'*' : 24,'-' : 25,'?' : 0}

def acid2int(seq : str) -> list:
    return [(_acid2int[i]) for i in seq]

x = 'ABTZX'
acid2int(x) # returns [1, 21, 17, 22, 23]

[1, 21, 17, 22, 23]

In [60]:
import json
   
# Configuration
min_seq=25
max_seq=350
    
# Opening JSON file
f = open('datasets/fp_database.json', encoding='UTF-8')
normalizer = [0.002, 0.002, 1, 0.01]

fp_dataset = []
name_list = []
# returns JSON object as 
# a dictionary
data = json.load(f)
# Iterating through the json
# list
print(len(data))
for index, i in enumerate(data):
    try:  # To read incomplete files and make this method more robust
        name=i['name']
        seq=i['seq']
        em=i['states'][0]['em_max']
        ex=i['states'][0]['ex_max']
        qy=i['states'][0]['qy']
        br=i['states'][0]['brightness']
        if(index==0):
            print(name)
            print()
            print(seq)
            print(qy)
            print(br)
            print(em)
            print(ex)
        if(seq!=None) and (em!=None) and (ex!=None): #  and (qy!=None):
            if(len(seq)>min_seq) and (len(seq)<max_seq) and (em>0) and (ex>0):
                seq_ = seq + "".join(["?"]*(max_seq-len(seq)))
                label = torch.tensor([em, ex]).float()
                sequence = torch.tensor(acid2int(seq_))
                fp_dataset.append([sequence, torch.tensor([label[0]*normalizer[0], label[1]*normalizer[1]])])
                # fp_dataset.append([sequence, torch.tensor([torch.tensor([em]).float()*normalizer[0]])])
                name_list.append(name)
    except:
        True
# Closing file
f.close()

print(len(fp_dataset))

797
617


In [61]:
import pickle
normalizer = [0.002, 0.002, 1, 0.01]
with (open("datasets/test_dataset_split_fp.pkl", "rb")) as openfile:
    [train_set, test_set] = pickle.load(openfile)  
# print(test_set[0])

In [62]:
# Get all corresponding protein names for the test-set used as the unseen reference during training process
protein_name_list = []
for protein_index, test_protein in enumerate(test_set):
    for index, element in enumerate(fp_dataset):
        if torch.all(test_protein[0].eq(element[0])):
            protein_name_list.append([protein_index, name_list[index]])
            continue
print(protein_name_list)

[[0, 'mcavRFP'], [1, 'cgfTagRFP'], [2, 'AQ14'], [3, 'rsFusionRed1'], [4, 'mRhubarb713'], [5, 'super-TagRFP'], [6, 'G1'], [7, 'Superfolder CFP'], [8, 'LanFP2'], [9, 'mEosFP-F173S'], [10, 'GFPxm18uv'], [11, 'mStable'], [12, 'mNeptune2.5'], [13, 'iRFP682'], [14, 'ShyRFP'], [15, 'deGFP1'], [16, 'mMiCy'], [17, 'sarcGFP'], [18, 'mEos4b'], [19, 'mEGFP'], [20, 'Folding Reporter GFP'], [21, 'mStrawberry'], [22, 'mCherry-XL'], [23, 'mEos2-A69T'], [24, 'mEYFP'], [25, 'Topaz'], [26, 'mNeptune2'], [27, 'd-RFP618'], [28, 'GFP(E222G)'], [29, 'h2-3'], [30, 'mGrape1'], [31, 'Skylan-NS'], [32, 'mGeos-S'], [33, 'mCherry'], [34, 'BDFP1.6'], [35, 'D10'], [36, 'RFP630'], [37, 'ptilGFP'], [38, 'eforCP'], [39, 'EBFP1.2'], [40, 'roGFP1-R8'], [41, 'scubRFP'], [42, 'mNeonGreen'], [43, 'rsFolder'], [44, 'Enhanced Cyan-Emitting GFP'], [45, 'avGFP'], [46, 'Sapphire'], [47, 'W1C'], [48, 'mRFP1-Q66S'], [49, 'mCardinal']]


In [63]:
from torch.utils.data import DataLoader
batch_size=16
train_dataloader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_set, batch_size=batch_size, shuffle=True)

# Sizes of resulting train and test-set
print(len(train_dataloader))
print(len(test_dataloader))

36
4


In [64]:
for index, element in enumerate(test_dataloader):
    print(len(element[0]))

16
16
16
2


### Training process ###

In [65]:
global Epoch, lc_train, lc_test
Epoch = 0
lc_train, lc_test = [], []

In [66]:
lr_factor = 1.0

In [67]:
# do the following in a loop for a lot of sequences and annotations
from tqdm.notebook import tqdm
criterion = torch.nn.MSELoss()

# only spectra
def train_supervised(_model, _predictor, dataloader_train, dataloader_test, lr=1, epochs=10, dropout=True, dropout_flipping=False):
    dropout = dropout
    global Epoch
    optimizer = torch.optim.Adam(_model.parameters(),    lr=lr*0.0000001*lr_factor)
    optimizer2 = torch.optim.Adam(_predictor.parameters(), lr=lr*0.0000002*lr_factor)  
    summed_loss = 0.0
    for epoch in range(epochs):
        Epoch+=1
        if(Epoch%10==1):
            print("epoch: " + str(Epoch-1))
        summed_loss = 0.0
        with tqdm(total=len(dataloader_train)) as pbar:
            train_index=0
            # training process
            for index, (inputs,labels) in enumerate(dataloader_train):
                optimizer.zero_grad()
                optimizer2.zero_grad()

                annotation_list = [[0]*512]*inputs.shape[0]
                annotation = torch.tensor(annotation_list).float()
                mask       = torch.ones(inputs.shape).bool()
                
                # print(inputs.shape, annotation.shape, mask.shape)
                representation, annotation_repr = _model(inputs.to(device), annotation.to(device), mask = mask.to(device))
                if(dropout_flipping==True):
                  dropout = not dropout
                x = _predictor(representation, dropout=dropout)
                # print(labels.shape, x.shape)
                # print(inputs)
                loss = criterion(labels.to(device), x.to(device))
                summed_loss+=loss.detach()
                train_index += inputs.shape[0]
                
                loss.backward()
                optimizer.step()
                optimizer2.step()
                pbar.set_description(f'train_loss: {"%.5f" % (summed_loss/train_index*100)}')
                pbar.update(1)
            
            # evaluation process
            test_index = 0
            summed_test_loss = 0
            optimizer.zero_grad()
            optimizer2.zero_grad()   
            for _, (inputs,labels) in enumerate(dataloader_test):
                
                annotation_list = [[0]*512]*inputs.shape[0]
                annotation = torch.tensor(annotation_list).float()
                mask       = torch.ones(inputs.shape).bool()
                representation, annotation_repr = _model(inputs.to(device), annotation.to(device), mask = mask.to(device))
                x = _predictor(representation, dropout=False)
                loss = criterion(labels.to(device), x.to(device))
                test_index += inputs.shape[0]
                summed_test_loss+=loss.detach()
            
            lc_train.append(summed_loss/train_index) 
            lc_test.append(summed_test_loss/test_index)    
            pbar.set_description(f'train_loss: {"%.5f" % (summed_loss/train_index*100)}, test_loss: {"%.5f" % (summed_test_loss/test_index*100)}')
            pbar.close()


In [68]:
def print_test_instances(dataloader, model, model_predictor):
  for index, (inputs,labels) in enumerate(dataloader): 
      annotation_list = [[0]*512]*inputs.shape[0]
      annotation = torch.tensor(annotation_list).float()
      mask       = torch.ones(inputs.shape).bool()
      token, annotation_logits = model(inputs.to(device), annotation.to(device), mask = mask.to(device))
      x = model_predictor(token, dropout=False)
      loss = criterion(labels.to(device), x.to(device))
      value_list = []
      for index in range(inputs.shape[0]):
          value_list.append([round((labels[index].cpu().detach().numpy()/normalizer[0])[0],3), 
                             round((x[index].cpu().detach().numpy()/normalizer[0])[0],3)])  # :2]) # , loss)
          if(index>5):
            break
      print(value_list)
      break

print("train_data")
print_test_instances(train_dataloader, model, model_predictor)
print("test_data")
print_test_instances(test_dataloader, model, model_predictor)

train_data
[[637.0, 637.501], [518.0, 515.618], [625.0, 607.567], [513.0, 510.135], [670.0, 662.628], [490.0, 490.154], [610.0, 604.756]]
test_data
[[479.0, 501.236], [596.0, 565.441], [682.0, 688.02], [516.0, 528.403], [511.0, 512.878], [485.0, 512.042], [530.0, 512.09]]


In [90]:
# Calculate mean square error over the full test-set

def calculate_mse(dataloader, model, model_predictor, root=False):
    instance_number=0
    summed_squared_error_em=0
    summed_squared_error_ex=0
    with tqdm(total=len(dataloader)) as pbar:
        for index, (inputs,labels) in enumerate(dataloader): 
            annotation = torch.tensor([[0]*512]*inputs.shape[0]).float()
            mask       = torch.ones(inputs.shape).bool()
            token, annotation_logits = model(inputs.to(device), annotation.to(device), mask = mask.to(device))
            x = model_predictor(token, dropout=False)
            # print(x.shape)
            for index_batch in range(x.shape[0]-1):
                instance_number+=1
                em_label=(labels[index_batch].cpu().detach().numpy()/normalizer[0])[0]
                em_output=(x[index_batch].cpu().detach().numpy()/normalizer[0])[0]
                ex_label=(labels[index_batch].cpu().detach().numpy()/normalizer[0])[1]
                ex_output=(x[index_batch].cpu().detach().numpy()/normalizer[0])[1]
                summed_squared_error_em+=(em_label-em_output)**2
                summed_squared_error_ex+=(ex_label-ex_output)**2
            mse_em=(summed_squared_error_em/instance_number)
            mse_ex=(summed_squared_error_ex/instance_number)
            pbar.set_description(f'MSE emission: {"%.3f" % mse_em}, MSE excitation: {"%.3f" % mse_ex}')
            pbar.update(1)
        return mse_em, mse_ex
    
print("train_data")
mse_em, mse_ex= calculate_mse(train_dataloader, model, model_predictor)
print("mean error: " + str((round(mse_em**0.5,3), round(mse_ex**0.5,3))))
print("")
print("")
print("test_data")
mse_em, mse_ex= calculate_mse(test_dataloader, model, model_predictor)
print("mean error: " + str((round(mse_em**0.5,3), round(mse_ex**0.5,3))))

train_data


  0%|          | 0/36 [00:00<?, ?it/s]

mean error: (12.681, 20.737)


test_data


  0%|          | 0/4 [00:00<?, ?it/s]

mean error: (28.89, 37.341)
