In [None]:
#!/usr/bin/env python3

#
# Arontier Inc.: Artificial Intelligence in Precision Medicine
# Copyright: 2018-present
#

In [1]:
import os
import torch
from torch import nn
import torch.nn.functional as F
import esm
import json
import numpy as np
import argparse
from einops import rearrange
import string

In [2]:
# GPU index this notebook prefers as the process-wide default CUDA device.
PREFERRED_CUDA_DEVICE = 1

# Selecting a CUDA device that does not exist raises, which would make this cell
# fail on CPU-only or single-GPU machines. Only pin the device when it is present;
# otherwise leave PyTorch's default device selection untouched.
if torch.cuda.is_available() and torch.cuda.device_count() > PREFERRED_CUDA_DEVICE:
    torch.cuda.set_device(PREFERRED_CUDA_DEVICE)

In [3]:
# Name of the protein property handled by this notebook.
PROTEIN_PROPERTY = "disorder"
# Label character associated with disorder — presumably marks disordered residues
# in label strings; not referenced in the cells visible here (TODO confirm usage).
DISORDER_LABEL = '+'

# Size limits for the MSA given to the MSA Transformer.
MAX_MSA_ROW_NUM = 256  # maximum number of MSA rows (sequences)
MAX_MSA_COL_NUM = 1023  # token columns kept by extract_msa_transformer_features (original note: 1023 + 1 start token = 1024)

# Disable autograd globally: this notebook only runs inference.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f17506a5e80>

In [None]:
class lstm_net(nn.Module):
    """Bidirectional-LSTM head mapping per-residue MSA features to per-residue logits.

    The query-row embeddings are projected down to ``input_feature_size // 4``
    features, optionally concatenated with row-attention summaries, passed
    through a 2-layer bidirectional LSTM and finally projected to ``class_num``
    outputs per residue (raw logits, no activation applied).

    Args:
        input_feature_size: Feature size of the incoming query embeddings.
        hidden_node: Hidden size of each LSTM direction.
        dropout: Dropout applied between the LSTM layers (training mode only).
        need_row_attention: If True, ``forward`` also consumes row attentions.
        class_num: Number of output values per residue.
        row_attention_channels: Number of attention maps after merging the
            layer and head axes (``layers * heads``). The default of 144
            reproduces the previously hard-coded ``144*2`` LSTM input widening.
    """

    def __init__(self, input_feature_size=768, hidden_node=256, dropout=0.25, need_row_attention=False, class_num=8,
                 row_attention_channels=144):
        super().__init__()
        self.need_row_attention = need_row_attention
        self.row_attention_channels = row_attention_channels

        # NOTE(review): the tensors reaching these InstanceNorm1d layers are laid out
        # as (batch, length, features), so the normalisation runs over the feature
        # axis of each position rather than over the length axis. Left untouched so
        # that behaviour stays compatible with already-trained weights.
        self.linear_proj = nn.Sequential(
            nn.Linear(input_feature_size, input_feature_size // 2),
            nn.InstanceNorm1d(input_feature_size // 2),
            nn.ReLU(),
            nn.Linear(input_feature_size // 2, input_feature_size // 4),
            nn.InstanceNorm1d(input_feature_size // 4),
            nn.ReLU(),
            nn.Linear(input_feature_size // 4, input_feature_size // 4),
        )

        if self.need_row_attention:
            # Two summaries (mean over each of the two position axes) per attention map.
            lstm_input_feature_size = input_feature_size // 4 + row_attention_channels * 2
        else:
            lstm_input_feature_size = input_feature_size // 4

        self.lstm = nn.LSTM(
            input_size=lstm_input_feature_size,
            hidden_size=hidden_node,
            num_layers=2,
            bidirectional=True,
            dropout=dropout,
        )

        self.to_property = nn.Sequential(
            nn.Linear(hidden_node * 2, hidden_node * 2),
            nn.InstanceNorm1d(hidden_node * 2),
            nn.ReLU(),
            nn.Linear(hidden_node * 2, class_num),
        )

    def forward(self, msa_query_embeddings, msa_row_attentions=None):
        """Compute per-residue logits.

        Args:
            msa_query_embeddings: Tensor of shape (batch, length, input_feature_size).
            msa_row_attentions: Tensor of shape (batch, layers, heads, length, length).
                Required when the module was built with ``need_row_attention=True``;
                ignored (and may be omitted) otherwise.

        Returns:
            Tensor of shape (batch, length, class_num) holding raw logits.

        Raises:
            ValueError: If row attentions are needed but missing or not 5-dimensional.
        """
        msa_query_embeddings = self.linear_proj(msa_query_embeddings)

        if self.need_row_attention:
            if msa_row_attentions is None:
                raise ValueError("msa_row_attentions is required when need_row_attention=True")
            if msa_row_attentions.dim() != 5:
                raise ValueError(
                    "msa_row_attentions must have 5 dimensions (batch, layers, heads, length, length), "
                    f"got {msa_row_attentions.dim()}"
                )

            # Merge the layer and head axes: (b, l, h, i, j) -> (b, l*h, i, j).
            msa_row_attentions = msa_row_attentions.flatten(1, 2)
            # Summarise every attention map by its mean over each position axis and
            # stack both summaries along the channel axis: (b, 2*l*h, length).
            msa_attention_features = torch.cat(
                (torch.mean(msa_row_attentions, dim=2), torch.mean(msa_row_attentions, dim=3)),
                dim=1,
            )
            # Channels last, to line up with the embeddings: (b, length, 2*l*h).
            msa_attention_features = msa_attention_features.permute((0, 2, 1))

            lstm_input = torch.cat([msa_query_embeddings, msa_attention_features], dim=2)
        else:
            lstm_input = msa_query_embeddings

        # nn.LSTM is configured sequence-first, so swap to (length, batch, features) and back.
        lstm_input = lstm_input.permute((1, 0, 2))
        lstm_output, _ = self.lstm(lstm_input)
        lstm_output = lstm_output.permute((1, 0, 2))

        return self.to_property(lstm_output)

In [None]:
def extract_msa_transformer_features(msa_seq, msa_transformer, msa_batch_converter, device=torch.device("cpu")):
    """Run the MSA Transformer on a single MSA and return query-row features.

    The MSA is tokenised with ``msa_batch_converter`` as a batch of one, moved to
    ``device``, cut to at most ``MAX_MSA_COL_NUM`` token columns and fed to
    ``msa_transformer`` requesting layer-12 representations and head weights.

    Returns:
        A pair ``(query_representation, row_attentions)``: the layer-12
        representation of the first MSA row and the row attentions, both with
        the leading (start-token) position removed.
    """
    labels, _unused_strs, tokens = msa_batch_converter([msa_seq])
    tokens = tokens.to(device)

    token_shape = tokens.shape
    n_rows, n_cols = token_shape[1], token_shape[2]
    print(f"{labels[0][0]}, msa_row: {n_rows}, msa_col: {n_cols}")

    if n_cols > MAX_MSA_COL_NUM:
        print(f"msa col num should less than {MAX_MSA_COL_NUM}. This program force the msa col to under {MAX_MSA_COL_NUM}")
    # The cut is applied unconditionally; it leaves shorter inputs unchanged.
    tokens = tokens[:, :, :MAX_MSA_COL_NUM]

    # Output keys: ['logits', 'representations', 'col_attentions', 'row_attentions', 'contacts']
    outputs = msa_transformer(
        tokens,
        repr_layers=[12],
        need_head_weights=True,
        return_contacts=True,
    )

    # First MSA row only, dropping the start-token column.
    query_representation = outputs['representations'][12][:, 0, 1:, :]
    # Drop the start token along both trailing position axes.
    row_attentions = outputs['row_attentions'][..., 1:, 1:]

    return query_representation, row_attentions

In [None]:
def str_find_ch(s, ch):
    """Return every index at which an element of ``s`` equals ``ch``, in ascending order."""
    positions = []
    for index, letter in enumerate(s):
        if letter == ch:
            positions.append(index)
    return positions



def save_property_to_json(out_property_json, output_property, query_seq):
    """Write per-residue disorder values and the query sequence to a JSON file.

    Args:
        out_property_json: Path of the JSON file to create or overwrite.
        output_property: Array of per-residue values. A 0-d array (e.g. a
            fully squeezed single-residue prediction) is accepted and treated
            as a one-element sequence.
        query_seq: Query sequence stored alongside the values.
    """
    precision = 4  # decimals kept in the file; also reported in the metadata

    # A 0-d array's .tolist() is a bare float, which cannot be iterated;
    # promote it to one dimension so single-residue inputs are handled too.
    raw_values = np.atleast_1d(output_property).tolist()
    output_property_list = [round(x, precision) for x in raw_values]

    json_dict = {
        "disorder_data": output_property_list,
        "query_seq": query_seq,
        "metadata": {
            "precision": precision,
            "title": "disorder",
            "data-min": 0,
            "data-max": 1,
        }
    }

    with open(out_property_json, "w") as f:
        json.dump(json_dict, f, indent=4)

In [None]:
if __name__ == '__main__':
    # Script entry point: parse options, load the MSA Transformer and the trained
    # LSTM head, predict per-residue disorder probabilities for the input MSA and
    # write them to "<output_path>.idr.json".

    parser = argparse.ArgumentParser(description='S-Pred for intrinsically disordered region (IDR) prediction: input msa and output idr (.json)')
    parser.add_argument('-i', '--input_path', type=str, default='examples/s_pred_idr.a3m',
                        help='input msa path (.json or .a3m)')
    parser.add_argument('-o', '--output_path', type=str, default='s_pred_idr.out',
                        help='output predicted idr probability')
    parser.add_argument('--conv_model_path', type=str,
                        default='s_pred_idr_weights.pth',
                        help='model weight path')

    msa_args = parser.add_argument_group('MSA')

    msa_args.add_argument('--msa_method', type=str, help='input msa method')
    msa_args.add_argument('--msa_row_num', type=int, default=MAX_MSA_ROW_NUM,
                          help='input msa row num to msa transformer')
    parser.add_argument('--device', type=str, default='cpu', choices=['cpu', 'gpu'],
                        help='choose device: cpu or gpu')

    # A Jupyter kernel is started with its own argv (e.g. "-f <connection file>"),
    # which parse_args() rejects with SystemExit. Accept and report unknown
    # arguments instead so the cell also runs inside a notebook.
    args, unknown_args = parser.parse_known_args()
    if unknown_args:
        print(f"Ignoring unrecognized arguments: {unknown_args}")

    print("===================================")
    print("Print Arguments:")
    print("===================================")

    print(' '.join(f'{k} = {v}\n' for k, v in vars(args).items()))

    if args.device == 'cpu':
        device = torch.device("cpu")
    else:
        if torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            print("gpu is not available, run on cpu")
            device = torch.device("cpu")

    ## if already have the msa_transformer_weight
    ## msa_transformer, msa_alphabet = esm.pretrained.load_model_and_alphabet_local(msa_transformer_weight_path)

    msa_transformer, msa_alphabet = esm.pretrained.esm_msa1_t12_100M_UR50S()
    msa_batch_converter = msa_alphabet.get_batch_converter()

    msa_transformer.to(device)
    msa_transformer.eval()

    conv_model = lstm_net(input_feature_size=768, hidden_node=256, dropout=0.25, need_row_attention=True, class_num=1)

    # map_location=device places the checkpoint tensors on the selected device no
    # matter which device they were saved from (previously only done for CPU).
    # NOTE: torch.load unpickles the file; only load weight files from a trusted source.
    ch = torch.load(args.conv_model_path, map_location=device)

    conv_model.load_state_dict(ch['conv_model'])
    conv_model.to(device)
    conv_model.eval()

    for param in msa_transformer.parameters():
        param.requires_grad = False
    for param in conv_model.parameters():
        param.requires_grad = False

    print("===================================")
    print("Extract msa transformer features")
    print("===================================")

    # NOTE(review): read_msa_json / read_msa_file are not defined in the cells shown
    # here; they must be defined or imported before this cell is executed.
    if args.input_path.endswith('.json'):
        msa_seq, query_seq = read_msa_json(args.input_path, args.msa_method, args.msa_row_num)
    else:
        msa_seq, query_seq = read_msa_file(args.input_path, args.msa_row_num)

    msa_row_num = len(msa_seq)
    msa_col_num = len(query_seq)

    print(f"msa row number: {msa_row_num}")
    print(f"msa column number: {msa_col_num}")

    # Run inference under no_grad so this cell does not depend on the global
    # torch.set_grad_enabled(False) having been executed in an earlier cell.
    with torch.no_grad():
        msa_query_representation, msa_row_attentions = extract_msa_transformer_features(
            msa_seq,
            msa_transformer,
            msa_batch_converter,
            device=device,
        )

        # Tensor.to() returns a new tensor rather than moving in place, so the
        # result has to be assigned back for the call to have any effect.
        msa_query_representation = msa_query_representation.to(device)
        msa_row_attentions = msa_row_attentions.to(device)

        output_property = conv_model(msa_query_representation, msa_row_attentions)
        output_property_sigmoid = torch.sigmoid(output_property)

    # The head output is (1, length, 1); reshape(-1) yields a 1-D array even for a
    # single-residue query, where squeeze() would collapse it to a 0-d scalar.
    output_property_sigmoid_np = output_property_sigmoid.detach().cpu().numpy().reshape(-1)

    # The feature extractor cuts over-long MSAs, so there can be fewer predictions
    # than residues in the query; keep the saved sequence aligned with the values.
    predicted_residue_num = output_property_sigmoid_np.shape[0]
    if len(query_seq) > predicted_residue_num:
        print(f"query sequence is cut to the {predicted_residue_num} predicted residues")
        query_seq = query_seq[:predicted_residue_num]

    output_property_json_path = args.output_path + '.idr.json'
    save_property_to_json(output_property_json_path, output_property_sigmoid_np, query_seq)

    print("===================================")
    print("Done")
    print("===================================")