In [2]:
from evoVAE.utils.datasets import MSA_Dataset
import evoVAE.utils.seq_tools as st
import evoVAE.utils.metrics as mt
from evoVAE.models.seqVAETest import SeqVAETest
import pandas as pd
import torch
import numpy as np

# Remove the row cap on DataFrame display so frames are printed in full.
# NOTE(review): with a large alignment this can flood the notebook output.
pd.set_option("display.max_rows", None)

In [3]:
# Load the extant-sequence alignment (columns used below: "encoding",
# "weights", "id").
# NOTE(review): unpickling can execute arbitrary code; only load trusted files.
extants_aln = pd.read_pickle("../data/gfp/GFP_AEQVI_full_04-29-2022_b08_extants_no_syn_no_dupes.pkl")

# Settings handed to the model constructor. Values are unchanged from the
# original cell.
config = {
    # Dataset info
    # NOTE(review): "tets" looks like a typo for "test"; left as-is because it
    # is a runtime value.
    "alignment": "tets",
    "seq_theta": 0.2,  # reweighting
    "AA_count": 21,  # standard AA + gap
    "test_split": 0.2,
    "max_mutation": 4,  # how many mutations the model will test up to
    # ADAM
    "learning_rate": 1e-2,  # ADAM
    "weight_decay": 1e-4,  # ADAM
    # Hidden units
    "momentum": None,
    "dropout": None,
    # Training loop
    "epochs": 500,
    "batch_size": 128,
    "max_norm": 10,  # gradient clipping
    "patience": 3,
    # Model info - default settings
    "architecture": "SeqVAE_0.25_ancestors_R",
    "latent_dims": 3,
    "hidden_dims": [256, 128, 64],
    # DMS data
    "dms_file": "../data/SPG1_STRSG_Wu_2016.pkl",
    "dms_metadata": "../data/DMS_substitutions.csv",
    "dms_id": "SPG1_STRSG_Wu_2016",
}

# Build the dataset and loader once. The original cell built them twice and
# only the second definition (batch_size=4) was ever in effect.
train_dataset = MSA_Dataset(extants_aln["encoding"], extants_aln["weights"], extants_aln["id"])
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=4, shuffle=False
)

# Index helpers: first dataset item -> its first element -> size of axis 0.
SEQ_LEN = 0
BATCH_ZERO = 0
SEQ_ZERO = 0
seq_len = train_dataset[BATCH_ZERO][SEQ_ZERO].shape[SEQ_LEN]
# Flattened input width; the alphabet size comes from config rather than a
# repeated literal 21.
input_dims = seq_len * config["AA_count"]

model = SeqVAETest(
    input_dims,
    config["latent_dims"],
    hidden_dims=config["hidden_dims"],
    config=config,
)
# map_location="cpu" lets the checkpoint load on machines without the device
# it was saved from.
model.load_state_dict(
    torch.load(
        "../data/gfp/model_weights/gfp_extants_no_duplicates_model_state.pt",
        map_location="cpu",
    )
)
# Switch to evaluation mode; as the last expression this also displays the model.
model.eval()

SeqVAETest(
  (encoder): Sequential(
    (0): Sequential(
      (0): Linear(in_features=4998, out_features=256, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Linear(in_features=256, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Linear(in_features=128, out_features=64, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
    )
  )
  (z_mu_sampler): Linear(in_features=64, out_features=3, bias=True)
  (z_logvar_sampler): Linear(in_features=64, out_features=3, bias=True)
  (upscale_z): Linear(in_features=3, out_features=64, bias=True)
  (decoder): Sequential(
    (0): Seque

In [4]:
# Small working sample: the first rows of the alignment table.
sub = extants_aln.head()

# Flatten each row's encoding to 1-D and stack them into a single array.
ens = np.stack([enc.flatten() for enc in sub["encoding"].values])

# Display the sample.
sub

Unnamed: 0,id,sequence,encoding,weights
0,GFP_AEQVI/1-238,MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.003861
1,UniRef100_UPI0011C34247/2-231,VSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.004115
2,UniRef100_UPI0011C34247/384-556,VSKGEELFTGVVPILVELDGDVNGHKFSVRGEGEGDATNGKLTLKL...,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.0625
4,UniRef100_UPI0011C3426C/384-556,VSKGEELFTGVVPILVELDGDVNGHKFSVRGEGEGDATNGKLTLKL...,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.0625
5,UniRef100_UPI001C2E920B/3-240,MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.003861


In [5]:
from evoVAE.utils.seq_tools import IDX_TO_AA, AA_TO_IDX


def convert_msa_numpy_array(aln, aa_to_idx=None):
    """Convert an alignment into an integer-encoded numpy array for the VAE.

    Every sequence is translated residue by residue through ``aa_to_idx``.
    Each distinct encoded sequence also receives an integer label, assigned
    in order of first appearance (0, 1, 2, ...).

    Args:
        aln: Object indexable by ``'id'`` and ``'sequence'`` (for example a
            DataFrame or a dict of lists); the two entries are iterated in
            parallel.
        aa_to_idx: Optional residue -> index mapping. When omitted, the
            ``AA_TO_IDX`` table imported in this cell is used.

    Returns:
        A tuple ``(seq_msa, seq_key, seq_label)``:
        ``seq_msa`` is ``np.array`` of the encoded sequences,
        ``seq_key`` is the list of ids in input order, and
        ``seq_label`` is the list of per-sequence pattern labels.

    Raises:
        KeyError: If a residue is missing from the mapping (when the mapping
            is a dict).
    """
    if aa_to_idx is None:
        aa_to_idx = AA_TO_IDX

    pattern_to_label = {}
    seq_msa = []
    seq_key = []
    seq_label = []

    for seq_id, seq in zip(aln["id"], aln["sequence"]):
        encoded = [aa_to_idx[residue] for residue in seq]
        seq_msa.append(encoded)
        seq_key.append(seq_id)

        # Key on the tuple of indices. Joining the indices into one string
        # without a separator is ambiguous (1,12 and 11,2 both give "112"),
        # which made different sequences share a label.
        label = pattern_to_label.setdefault(tuple(encoded), len(pattern_to_label))
        seq_label.append(label)

    seq_msa = np.array(seq_msa)

    print("Sequence converted to numpy array with shape",seq_msa.shape)
    return seq_msa,seq_key,seq_label

# Encode the small sample and show the shape of the resulting array.
seq_msa, seq_key, seq_label = convert_msa_numpy_array(sub)
seq_msa.shape

Sequence converted to numpy array with shape (5, 238)


(5, 238)

In [7]:

from evoVAE.loss.standard_loss import KL_divergence

# TODO: these lists are never filled in this cell. The lines that appended
# name[0] and the detached z sample to them were commented out.
names = []
z_vals = []

# Inspect the ELBO terms for the first batch only (the loop exits after one
# pass). This is pure inspection, so autograd is switched off to avoid
# building a graph.
with torch.no_grad():
    for encoding, weights, name in train_loader:
        # Call the module itself rather than .forward() so that any
        # registered hooks run.
        xHat, zSample, zMu, zLogvar = model(encoding.float())

        # NOTE(review): the captured output shows one KLD value per sequence,
        # so this looks per-sample rather than a batch average - confirm
        # against KL_divergence.
        KLD = KL_divergence(zMu, zLogvar, zSample, weights)

        # Reconstruction term: flatten each input and sum its elementwise
        # product with the model output.
        # NOTE(review): presumably xHat holds log-probabilities - confirm.
        flat_input = torch.flatten(encoding, start_dim=1)
        log_PxGz = torch.sum(flat_input * xHat, -1)
        print(log_PxGz)
        print(KLD)
        print(weights)

        # Per-sequence ELBO, then the weighted sum using weights normalised
        # to sum to one within the batch.
        elbo = log_PxGz - KLD
        norm_weight = weights / torch.sum(weights)
        print(elbo, norm_weight)
        elbo = torch.sum(elbo * norm_weight)
        print(torch.sum(log_PxGz * norm_weight))
        print(torch.sum(KLD * norm_weight))
        print(elbo)

        break




tensor([-696.0979, -696.5344, -699.2644, -699.3359], dtype=torch.float64,
       grad_fn=<SumBackward1>)
tensor([0.0131, 0.0131, 0.0131, 0.0131], grad_fn=<SumBackward1>)
tensor([0.0039, 0.0041, 0.0625, 0.0625], dtype=torch.float64)
tensor([-696.1110, -696.5474, -699.2774, -699.3490], dtype=torch.float64,
       grad_fn=<SubBackward0>) tensor([0.0290, 0.0309, 0.4700, 0.4700], dtype=torch.float64)
tensor(-699.1216, dtype=torch.float64, grad_fn=<SumBackward0>)
tensor(0.0131, dtype=torch.float64, grad_fn=<SumBackward0>)
tensor(-699.1346, dtype=torch.float64, grad_fn=<SumBackward0>)


In [4]:

import torch.nn.functional as F
import evoVAE.utils.seq_tools as st
import evoVAE.utils.metrics as mt
from itertools import islice

N_PREVIEW = 5  # how many reconstructions to inspect
SEQ_POSITIONS = 0  # first entry along the leading axis of the decoded output

recons = []
# EVALUATE DIFFERENCES BETWEEN THE RECONSTRUCTIONS AND INPUT
# NOTE(review): id_to_z is not defined in any visible cell; this cell relies
# on it already existing in the kernel - confirm where it is built.
with torch.no_grad():  # decoding for inspection only; no gradients needed
    # seq_id (not `id`) so the builtin id() is not shadowed notebook-wide.
    for seq_id, z in islice(zip(id_to_z['id'], id_to_z['z']), N_PREVIEW):
        x_hat = model.decode(torch.tensor(z))

        # Split the flat last axis into (positions, alphabet). A direct view
        # gives the same result as the previous unsqueeze-then-view.
        leading_shape = tuple(x_hat.shape[:-1])
        x_hat = x_hat.view(leading_shape + (-1, config["AA_count"]))

        print(x_hat[:, :1, :])

        # Highest-scoring alphabet index at every position.
        indices = x_hat.max(dim=-1).indices.tolist()

        # Original sequence for this id, printed above its reconstruction.
        original_seq = extants_aln.loc[extants_aln['id'] == seq_id, 'sequence'].values[0]
        print(original_seq)
        recon = "".join(st.GAPPY_PROTEIN_ALPHABET[idx] for idx in indices[SEQ_POSITIONS])
        print(recon)
        print()
        recons.append(recon)


In [9]:
# Show the last five reconstructions collected in `recons`.
recons[-5:]

['-SKGAELFTGVVPILVELDGDVNGHKFSVRGEGEGDATTGKLTLKFICTTGKLPVPWPTLVTTLTYGVLCFARYPDHMK-HDFFKSAMPEGYVQERTISFKDDGNYKTRAEVKFEGGTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYILADKQKNGIKVNFNIRHNVEDGSVQLADHYQQNTPIGDGPVLLPDYHYLSTQ-ALSKDPNEKRDHMVLLEFVTAAGIT--------',
 '-SKGAELFTGVVPILVELDGDVNGHKFSVRGEGEGDATTGKLTLKFICTTGKLPVPWPTLVTTLTYGVLCFARYPDHMK-HDFFKSAMPEGYVQERTISFKDDGNYKTRAEVKFEGGTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYILADKQKNGIKVNFNIRHNVEDGSVQLADHYQQNTPIGDGPVLLPDYHYLSTQ-ALSKDPNEKRDHMVLLEFVTAAGIT--------',
 '-SKGAELFTGVVPILVELDGDVNGHKFSVRGEGEGDATTGKLTLKFICTTGKLPVPWPTLVTTLTYGVLCFARYPDHMK-HDFFKSAMPEGYVQERTISFKDDGNYKTRAEVKFEGGTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYILADKQKNGIKVNFNIRHNVEDGSVQLADHYQQNTPIGDGPVLLPDYHYLSTQ-ALSKDPNEKRDHMVLLEFVTAAGIT--------',
 '-SKGAELFTGVVPILVELDGDVNGHKFSVRGEGEGDATTGKLTLKFICTTGKLPVPWPTLVTTLTYGVLCFARYPDHMK-HDFFKSAMPEGYVQERTISFKDDGNYKTRAEVKFEGGTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYILADKQKNGIKVNFNIRHNVEDGSVQLADHYQQNTPIGDGPVLLPDYHYLSTQ-ALSKDPNEKRDHMVLLEFVTAAGIT--------',
 '-SKGAELFTGVVPILVELDGDVNGHK

In [10]:
# Show the first five reconstructions collected in `recons`.
recons[:5]

['-SKGAELFTGVVPILVELDGDVNGHKFSVRGEGEGDATTGKLTLKFICTTGKLPVPWPTLVTTLTYGVLCFARYPDHMK-HDFFKSAMPEGYVQERTISFKDDGNYKTRAEVKFEGGTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYILADKQKNGIKVNFNIRHNVEDGSVQLADHYQQNTPIGDGPVLLPDYHYLSTQ-ALSKDPNEKRDHMVLLEFVTAAGIT--------',
 '-SKGAELFTGVVPILVELDGDVNGHKFSVRGEGEGDATTGKLTLKFICTTGKLPVPWPTLVTTLTYGVLCFARYPDHMK-HDFFKSAMPEGYVQERTISFKDDGNYKTRAEVKFEGGTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYILADKQKNGIKVNFNIRHNVEDGSVQLADHYQQNTPIGDGPVLLPDYHYLSTQ-ALSKDPNEKRDHMVLLEFVTAAGIT--------',
 '-SKGAELFTGVVPILVELDGDVNGHKFSVRGEGEGDATTGKLTLKFICTTGKLPVPWPTLVTTLTYGVLCFARYPDHMK-HDFFKSAMPEGYVQERTISFKDDGNYKTRAEVKFEGGTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYILADKQKNGIKVNFNIRHNVEDGSVQLADHYQQNTPIGDGPVLLPDYHYLSTQ-ALSKDPNEKRDHMVLLEFVTAAGIT--------',
 '-SKGAELFTGVVPILVELDGDVNGHKFSVRGEGEGDATTGKLTLKFICTTGKLPVPWPTLVTTLTYGVLCFARYPDHMK-HDFFKSAMPEGYVQERTISFKDDGNYKTRAEVKFEGGTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYILADKQKNGIKVNFNIRHNVEDGSVQLADHYQQNTPIGDGPVLLPDYHYLSTQ-ALSKDPNEKRDHMVLLEFVTAAGIT--------',
 '-SKGAELFTGVVPILVELDGDVNGHK