In [3]:
import torch
import esm

# Load the pretrained ESM-2 model (33 layers, 650M parameters) and its alphabet.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
# Inference only: switch to eval mode so dropout is disabled and the
# embeddings are deterministic (a freshly loaded nn.Module is in train mode).
model.eval()

# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
# Each entry is a (label, sequence) pair; "<mask>" marks a masked position.
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3",  "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
# Number of non-padding tokens in each row of the padded batch.
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)
print('batch_lens:',batch_lens)  # batch_lens is the (token) length of each sequence

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
# Representations taken from layer 33, the layer requested via repr_layers.
token_representations = results["representations"][33]
print(token_representations.shape)

# Generate per-sequence representations via averaging
# NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
# The slice 1 : tokens_len - 1 also drops the last non-padding token
# (presumably the end-of-sequence token) before averaging.
sequence_representations = []
for i, tokens_len in enumerate(batch_lens):
    sequence_representations.append(token_representations[i, 1 : tokens_len - 1].mean(0))

print(sequence_representations[1].shape)

batch_lens: tensor([67, 73, 73,  8])
torch.Size([4, 73, 1280])
torch.Size([1280])


In [None]:
import torch
import esm

# NOTE(review): this cell repeats the ESM-2 embedding example from the cell
# above (including a second model load); consider keeping only one copy.

# Load the pretrained ESM-2 model and the batch converter built from its alphabet.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

# Prepare data (first 2 sequences from ESMStructuralSplitDataset superfamily / 4)
# Entries are (label, sequence) pairs; "<mask>" marks a masked position.
data = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),
    ("protein3",  "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
# Per-row count of tokens that are not padding.
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

# Extract per-residue representations (on CPU)
with torch.no_grad():
    results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]
print(token_representations.shape)

# Generate per-sequence representations via averaging
# NOTE: token 0 is always a beginning-of-sequence token, so the first residue is token 1.
# The last non-padding token is excluded as well (slice end is tokens_len - 1).
sequence_representations = [
    token_representations[row, 1 : tokens_len - 1].mean(0)
    for row, tokens_len in enumerate(batch_lens)
]

print(sequence_representations[1].shape)

In [2]:
from loader import MoleculeDataset, SeqDataset,SeqMolDataset#########################
import torch
import torch
from torchvision.models import resnet18

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import argparse

from loader import MoleculeDataset#################
#from torch_geometric.data import DataLoader
from torch_geometric.loader import DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from tqdm import tqdm
import numpy as np

from model import GNN, GNN_graphpred,GNN_graphpred_1
from sklearn.metrics import roc_auc_score

from splitters import scaffold_split,scaffold_split_1
import pandas as pd
import os
import shutil
from tensorboardX import SummaryWriter
import esm

from SeqMolModel import InteractionModel,InteractionModel_1,SequenceModel

In [3]:
# Training settings
# Every "(default: ...)" note uses argparse's %(default)s placeholder so the
# help text always reports the real default (the --batch_size help previously
# claimed 32 while the actual default is 16).
parser = argparse.ArgumentParser(description='PyTorch implementation of pre-training of graph neural networks')
parser.add_argument('--device', type=int, default=0,
                    help='which gpu to use if any (default: %(default)s)')
parser.add_argument('--batch_size', type=int, default=16,
                    help='input batch size for training (default: %(default)s)')
parser.add_argument('--epochs', type=int, default=100,
                    help='number of epochs to train (default: %(default)s)')
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning rate (default: %(default)s)')
parser.add_argument('--lr_scale', type=float, default=1,
                    help='relative learning rate for the feature extraction layer (default: %(default)s)')
parser.add_argument('--decay', type=float, default=0,
                    help='weight decay (default: %(default)s)')
parser.add_argument('--num_layer', type=int, default=5,
                    help='number of GNN message passing layers (default: %(default)s).')
parser.add_argument('--emb_dim', type=int, default=300,
                    help='embedding dimensions (default: %(default)s)')
parser.add_argument('--dropout_ratio', type=float, default=0.5,
                    help='dropout ratio (default: %(default)s)')
parser.add_argument('--graph_pooling', type=str, default="mean",
                    help='graph level pooling (sum, mean, max, set2set, attention)')
parser.add_argument('--JK', type=str, default="last",
                    help='how the node features across layers are combined. last, sum, max or concat')
parser.add_argument('--gnn_type', type=str, default="gin")
parser.add_argument('--dataset', type=str, default = 'sider', help='root directory of dataset. For now, only classification.')
# The default used to be the string 'None'; it now names the 'Mole-BERT' model file.
parser.add_argument('--input_model_file', type=str, default = 'Mole-BERT', help='filename to read the model (if there is any)')
parser.add_argument('--filename', type=str, default = '', help='output filename')
parser.add_argument('--seed', type=int, default=42, help = "Seed for splitting the dataset.")
parser.add_argument('--runseed', type=int, default=0, help = "Seed for minibatch selection, random initialization.")
parser.add_argument('--split', type = str, default="scaffold", help = "random or scaffold or random_scaffold")
parser.add_argument('--eval_train', type=int, default = 1, help='evaluating training or not')
parser.add_argument('--num_workers', type=int, default = 4, help='number of workers for dataset loading')
# Parse an explicit empty argument list: inside a notebook sys.argv belongs to
# the kernel launcher, so every option simply takes its default value.
args = parser.parse_args(args=[])


In [4]:
# Build the sequence dataset from the processed affinity CSV.
# NOTE(review): presumably each item is one protein sequence string - confirm
# against loader.SeqDataset (inspect a sample with e.g. seq_dataset[2]).
seq_dataset = SeqDataset(
    'dataset/affinity/processed/sequence.csv'
)

In [None]:
# Batch the sequences of seq_dataset; each batch is printed below as a list
# of sequence strings.
seq_dataloader = DataLoader(seq_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)

# Load the pretrained ESM-2 model together with its own alphabet.
protein_model, protein_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# Inference only: eval mode disables dropout so embeddings are deterministic.
protein_model.eval()
batch_converter = protein_alphabet.get_batch_converter()

for batch_idx, seq in enumerate(seq_dataloader):
    print(batch_idx, seq)

    # The batch converter expects (label, sequence) pairs; the position of the
    # sequence inside the batch serves as its label.
    seq_data_list = [(str(m), s) for m, s in enumerate(seq)]

    batch_labels, batch_strs, batch_tokens = batch_converter(seq_data_list)
    # Use the alphabet that belongs to protein_model. The previous code read
    # `alphabet`, a variable that only exists if an unrelated earlier cell ran.
    batch_lens = (batch_tokens != protein_alphabet.padding_idx).sum(1)
    print('batch_lens:', batch_lens)

    with torch.no_grad():
        # Only the layer-33 representations are consumed below, so contact
        # prediction is switched off to avoid computing unused contact maps.
        results = protein_model(batch_tokens, repr_layers=[33], return_contacts=False)

    protein_token_representations = results["representations"][33]

    # Mean-pool each sequence over its tokens, skipping token 0 and the last
    # non-padding token. A separate index name keeps the outer batch counter
    # from being overwritten by this inner loop.
    # NOTE(review): this list is rebuilt for every batch, so only the last
    # batch's embeddings survive the loop - confirm that is intended.
    protein_sequence_representations = []
    for seq_idx, tokens_len in enumerate(batch_lens):
        print(seq_idx)
        protein_sequence_representations.append(protein_token_representations[seq_idx, 1 : tokens_len - 1].mean(0))

0 ['SLQPPPQQLIVQNKTIDLPAVYQLNGGEEANPHAVKVLKELLSGKQSSKKGMLISIGEKGDKSVRKYSRQIPDHKEGYYLSVNEKEIVLAGNDERGTYYALQTFAQLLKDGKLPEVEIKDYPSVRYRGVVEGFYGTPWSHQARLSQLKFYGKNKMNTYIYGPKDDPYHSAPNWRLPYPDKEAAQLQELVAVANENEVDFVWAIHPGQDIKWNKEDRDLLLAKFEKMYQLGVRSFAVFFDDISGEGTNPQKQAELLNYIDEKFAQVKPDINQLVMCPTEYNKSWSNPNGNYLTTLGDKLNPSIQIMWTGDRVISDITRDGISWINERIKRPAYIWWNFPVSDYVRDHLLLGPVYGNDTTIAKEMSGFVTNPMEHAESSKIAIYSVASYAWNPAKYDTWQTWKDAIRTILPSAAEELECFAMHNSDLGPNGHGYRREESMDIQPAAERFLKAFKEG', 'GDIVSEKKPATEVDPTHFEKRFLKRIRDLGEGHFGKVELCRYDPEGDNTGEQVAVKSLKP', 'IIGGEFTTIENQPWFAAIYRRHRGGSVTYVCGGSLMSPCWVISATHCFIDYPKKEDYIVYLGRSRLNSNTQGEMKFEVENLILHKDYSADTLAHHNDIALLKIRSKEGRCAQPSRTIQTICLPSMYNDPQFGTSCEITGFGKEQSTDYLYPEQLKMTVVKLISHRECQQPHYYGSEVTTKMLCAADPQWKTDSCQGDSGGPLVCSLQGRMTLTGIVSWGRGCALKDKPGVYTRVSHFLPWIRSHT', 'AVPFVEDWDLVQTLGEGAYGEVQLAVNRVTEEAVAVKIVDMKR', 'IVGGYTCGANTVPYQVSLNSGYHFCGGSLINSQWV', 'MDVFLMIRRHKTTIFTDAKESSTVFELKRIVEGILKRPPDEQRLYKDDQLLDDGKTLGE', 'VDPTHFEKRFLKRIRDLGEGHFGKVELCRYDPEGDNTGEQVAVKSLKP', 'HMVIDPSELTFVQEIGSGQFGLVH