In [1]:
import torch
import torch.nn.functional as F
import esm

from model import MSAVAE
from config import create_config
from collections import OrderedDict
from torch.nn.parallel import DistributedDataParallel as DDP
from data import read_msa, greedy_select, pad_msa_sequence

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
ckpt = torch.load("weights/model_22.pth")
new_state_dict = OrderedDict()
for k, v in ckpt["model"].items():
    name = k[7:]
    new_state_dict[name] = v

  ckpt = torch.load("weights/model_22.pth")


In [3]:
config = create_config()
model = MSAVAE(config).to("cuda:1")
model.load_state_dict(new_state_dict)

<All keys matched successfully>

In [4]:
raw_msa = read_msa("databases/openfold/scratch/alignments_2/1077246/uniclust30.a3m")
# raw_msa = read_msa("databases/data/a3m/5dik_1_A.a3m")
# raw_msa = read_msa("databases/openfold/scratch/alignments_1/105059/uniclust30.a3m")
msa = [raw_msa[2:]]
seq = [raw_msa[0]]
single_seq_embeddings, pairwise_seq_embeddings, msa_tokens, mask = pad_msa_sequence(config, seq, msa)
_, msa_alphabet = esm.pretrained.esm_msa1b_t12_100M_UR50S()

In [11]:
raw_msa[1][1] == raw_msa[0][1]

True

In [7]:
raw_msa[0]

('tr|A0A150IYD1|A0A150IYD1_9EURY_consensus',
 'MVDRLEEKGIFLTYLSNYAIRSALSCYLNIGKVNLPLKKVEGTIASTSKEELPIDPYEERIIALTELGIPYRKTKSPEEILLKRLQERVFMKNRFLLPSSKRSIVNFEDAFLGNDNLLRKRLEEFGLTKESAEYIICPQNEKCLCKKCDKRYDTSSERIREMRRRIYEIKNFNHEVCPYTH')

In [6]:
with torch.no_grad():
    pred_msa, perm, mu, logvar = model(
        single_seq_embeddings.to("cuda:1"),
        pairwise_seq_embeddings.to("cuda:1"), 
        msa_tokens.unsqueeze(-1).float().to("cuda:1"), 
        mask.to("cuda:1")
    )

mask_expanded = mask.unsqueeze(1)
mask_expanded = mask_expanded.expand(-1, config.data.msa_depth, -1).to("cuda:1")
pred_msa = pred_msa[mask_expanded]
msa_tokens = msa_tokens.to("cuda:1")
size = msa_tokens[mask_expanded].shape[-1]
msa_tokens = msa_tokens[mask_expanded].view(1, 32, size//32)

In [7]:
probs = F.softmax(pred_msa, dim=-1)
sampled_indices = torch.multinomial(probs, num_samples=1).view(1, 32, size//32)

def indices_to_tokens(indices, alphabet):
    return [''.join([alphabet.get_tok(idx.item()) for idx in sequence]) for sequence in indices]

sampled_tokens = indices_to_tokens(sampled_indices[0], msa_alphabet)
true_tokens = indices_to_tokens(msa_tokens[0], msa_alphabet)

In [8]:
for token in true_tokens:
    print(token)

-KNLLKNDGVILTYTSAAPVRSAVVNGLHVGEGPSFGRS-GGTVASLNPEDNPLSTDDERMIALSDAGIPFKDPGSSQDILKRREEERKISRGKIKFSSTVKTPIYLNEKL--EEGRVLNNLKKLGLKSPEARYVVCPQYKECICGGGCENFNNSRERIYEMSHRLRSIVTIND-------
-KRILKSKGLILTYTSSIPVKAGLIAGFHVGDGPVFGRISGGTIASPSYKDKDLSYDEERLIALSDLAIPFRDLSSCETIVENRKQERQIARGNTKISSAVKTPIYLGQDIVD--ERVLRNLTDFNTKSKIVLDLISSQNS-N---DKLNFKNSSRARILDIKQRWNSLLDI---------
-KQIIKKDGLLLTYTSAAPVRSALIKGFYVGEVPPFGRKKGGTVASLSPDSQDLPGEDELMIALSDAGIPYRDQENACLIKDRRIKERKARRGKDKFASTVKTPLYILNEP--EDHRVLRNLQSMGFNLDKSRFIVCPQFNDCICGRGCKIFKNSKERIEEMENRLQSVKN----------
-RDLLTEDGVLLTYTSSAPLRYALIDGLQVGEGPSMGRS-GGTIASPDIKPKPLNNNDERMIALSDAGIPFKDPSLPENIKQQRQKERIKARGNYKIASTVKTPVYLARDIDD--EKALKHLKDVRLDCEKSRYLVCPQFSECICTCKQERLSTSRARIKEMEKRLTNITNSK--------
-KPLLKPDGMISSFSKSHTMRYSLVKGYHIGEGPEFGRS-GGTIASTALEEKPISIDDERVVALSDAGVPLRDLDSSLEILERRDEERENVRGKFKFPSTVRCPVFLGKDL--KESRVLNGLKTIGLHSQKSMYLICPQYQDCVCGNSCKPIDNSRDRIIEMEKRLNILAVNQL-------
-STLLKEDGVILTYTSAAPVRYALLNGLEIGEGPALGRS-GGTMASPSLHTKPLGSVDERMIALSDAGIPYRDPLSANEIIENRHLERIT

In [10]:
for token in sampled_tokens:
    print(token)

FGTLIKENGVLATYTSSIAVRSAFIEGYHIGEVPEFGRKVGGTLASLNFMKKPIGSADE-HIALSNEGVPYRDPGSAREIIDRRCEERINSRGRTKFSSTVKTPIYLGISVV---D-VERNLKTFGLNGEEAMYIVKPQYQN-YCWWQFEKDASSRTR-LEM-RRLKYIATIRM-------
-ADVLQEDGIICTYTAACVMRSALIENLYIGEGNVLGRS-GGTIASCDPERE-LSHADERVIALSDVEIPYRDQ-DMMEILASRLN-RVSMR-KSKFPSTVK-PLYLGNDYV--DGRVERNLESLGPNK-RSLYILGCQKEECVC---EERNNSTRTRIYEMRWRLLEVLNFNE-------
-RRFLKPNGIILTSTSNAPIRSSLSEGLNVGDGRDFGSKSGGTVASLDP-DPNIDSFEERIIAL-DSGIPFRDPLSAQEILKRRQKERHNMRGVTLLASTVRTPVYLG-DLV--Y-SVLGRVESVGTNSLEILYMICPQENDEECKWQEKENNPTRERVLE-MKRLMFLVNL-E-------
-FDILNDDGILLTYLAASNIREAA-CDFQIGKGNKFKRKVEGTIASLDIKKKSLDIND-IKCALSMVGRPYRDPD-GMTIVLNRTKET-AARLKK-ISS--KFPVYLGLD-E-D-DRLLKG---F-LQDEKSL--IE-Q--Y---G---DI-D-S-DRILEMK-RMSILK-SR--------
FRDLLKEDGILLTYLSACAVKSGLREFFHIGKGPEFGRNSGVTIASLSISDKPLSKFDERIIALSDAGIPFKRLASD--IIKRRREERMQIRHKS--NSAVKTPIFFGKDNDG-GKKVIKNL-RVSLTKKRAFYIIEPQFSQCYCGQWEKGPDNSRERISIMTKRLWIIKTIN--------
-KKLLKDDGVIATYSSAAPVRSALSNGSEVGEGPPFGRS-GGTIASPNPSRKPL-WVDERHMALSDAGIPYRDPSDAEDIIKNRILERWV