In [5]:
import evoVAE.utils.seq_tools as st
import evoVAE.utils.metrics as mt
from evoVAE.models.seqVAE import SeqVAE
from typing import List, Tuple
from sklearn.model_selection import train_test_split
import pandas as pd
import torch
import numpy as np
import yaml
from evoVAE.loss.standard_loss import KL_divergence, sequence_likelihood, elbo_importance_sampling


This notebook can be used to test new model features locally, without having to use the WandB (Weights & Biases) service.

In [2]:
# Read the hyper-parameters for this model from the YAML config.
with open("extant_config.yaml", "r") as config_file:
    settings = yaml.safe_load(config_file)

seq_len = 770  # A4 Human length
# Model input width: sequence length times the alphabet size from the config.
input_dims = settings["AA_count"] * seq_len

model = SeqVAE(
    input_dims=input_dims,
    latent_dims=settings["latent_dims"],
    hidden_dims=settings["hidden_dims"],
    config=settings,
)

device = "cpu"
# Restore the saved weights, remapping stored tensors onto `device`.
saved_state = torch.load("a4_extants_r1_model_state.pt", map_location=device)
model.load_state_dict(saved_state)
model

SeqVAE(
  (encoder): Sequential(
    (0): Sequential(
      (0): Linear(in_features=16170, out_features=256, bias=True)
      (1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Dropout(p=0.025, inplace=False)
      (3): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Linear(in_features=256, out_features=128, bias=True)
      (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Dropout(p=0.025, inplace=False)
      (3): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Linear(in_features=128, out_features=64, bias=True)
      (1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Dropout(p=0.025, inplace=False)
      (3): LeakyReLU(negative_slope=0.01)
    )
  )
  (z_mu_sampler): Linear(in_features=64, out_features=3, bias=True)
  (z_logvar_sampler): Linear(in_features=64, out_features=3, bias=True)
  (upscale_z): Sequ

In [3]:
# Assay metadata table; later cells filter it by `DMS_id` to find the target sequence.
metadata = pd.read_csv("../data/DMS_substitutions.csv")
# Measurements for the A4_HUMAN_Seuma_2022 assay (one mutated sequence per row).
dms_data = pd.read_csv("A4_HUMAN_Seuma_2022.csv")
# One-hot encode every mutated sequence and keep the result next to its row.
one_hot = dms_data["mutated_sequence"].apply(st.seq_to_one_hot)
dms_data["encoding"] = one_hot



In [4]:
from evoVAE.utils.datasets import DMS_Dataset

# Wrap the DMS columns in a dataset: encoding, mutant id, score and binned score.
dms_dataset = DMS_Dataset(
    dms_data["encoding"],
    dms_data["mutant"],
    dms_data["DMS_score"],
    dms_data["DMS_score_bin"],
)

# One variant per batch, visited in random order.
dms_loader = torch.utils.data.DataLoader(dms_dataset, batch_size=1, shuffle=True)

In [None]:
# Look up the wild-type (target) sequence of this assay in the metadata table.
is_a4_assay = metadata["DMS_id"] == "A4_HUMAN_Seuma_2022"
wt_sequence = metadata[is_a4_assay]["target_seq"].values[0]
# One-hot encode it as a tensor. No batch dimension is added here, so the
# result keeps one row per sequence position.
wild_type_hot = torch.Tensor(st.seq_to_one_hot(wt_sequence))
wild_type_hot.shape

torch.Size([770, 21])

In [None]:

# Passed to elbo_importance_sampling as its sample count
# (presumably the number of latent draws per sequence -- TODO confirm).
n_samples = 3

# Switch Dropout / BatchNorm layers to inference behaviour before scoring.
model.eval()
actual_fitness = []
actual_fitness_binned = []
predicted_fitness = []
ids = []

with torch.no_grad():
    # Encode the wild type once; every variant is scored relative to it.
    # It is cast and placed exactly like the variants below so both go
    # through the model with the same dtype and on the same device.
    wt_elbo_mean = elbo_importance_sampling(
        model, wild_type_hot.float().to(device), n_samples
    )

    for variant_encoding, variant_id, score, score_bin in dms_loader:
        variant_encoding = variant_encoding.float().to(device)
        variant_elbo_mean = elbo_importance_sampling(
            model, variant_encoding, n_samples
        )

        # Predicted fitness is the variant ELBO minus the wild-type ELBO.
        pred_fitness = variant_elbo_mean - wt_elbo_mean

        predicted_fitness.append(pred_fitness.item())
        actual_fitness.append(score.item())
        actual_fitness_binned.append(score_bin.item())
        # The loader uses batch_size=1, so the first entry is the only id.
        ids.append(variant_id[0])






torch.Size([3, 770, 21])
torch.Size([3, 770, 21])
torch.Size([3, 770, 21])
torch.Size([3, 770, 21])
torch.Size([3, 770, 21])


In [None]:
# Show the collected scores; only the first variant id is displayed.
(
    actual_fitness,
    predicted_fitness,
    actual_fitness_binned,
    ids[0],
)

([-5.13664821016175, -0.538036214384878, -5.29969629931686, -0.20128247286108],
 [-11.707142882755875,
  -3.5574731605132683,
  -16.21157337333318,
  -8.225805559365085],
 [0, 1, 0, 1],
 'G709V:I712L')

### Tanh VAE

In [8]:
from evoVAE.models.tanh_vae import tanhVAE

# Reload the assay metadata and the A4_HUMAN DMS measurements for this section.
metadata = pd.read_csv("../data/DMS_substitutions.csv")
dms_data = pd.read_csv("A4_HUMAN_Seuma_2022.csv")

# One-hot encode every mutated sequence and keep it next to its row.
one_hot = dms_data["mutated_sequence"].apply(st.seq_to_one_hot)
dms_data["encoding"] = one_hot

# Wild-type (target) sequence of the assay, one-hot encoded as a tensor.
# No batch dimension is added here.
is_a4_assay = metadata["DMS_id"] == "A4_HUMAN_Seuma_2022"
wt_sequence = metadata[is_a4_assay]["target_seq"].values[0]
wild_type_hot = torch.Tensor(st.seq_to_one_hot(wt_sequence))
wild_type_hot.shape

torch.Size([770, 21])

In [91]:
import os

# Directory holding the GFP alignments. Set the GFP_ALN_DIR environment
# variable to point at your own copy; the fallback is the original local folder.
# os.path.join(..., "") guarantees a trailing separator, so `path` remains a
# string prefix that can be concatenated with a file name.
path = os.path.join(
    os.environ.get(
        "GFP_ALN_DIR",
        "/Users/sebs_mac/uni_OneDrive/honours/data/gfp_alns/independent_runs/no_synthetic/alns/",
    ),
    "",
)

# NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
ancestors_extants_aln = pd.read_pickle(
    path + "GFP_AEQVI_full_04-29-2022_b08_ancestors_extants_no_syn_no_dupes.pkl"
)

# Convert the alignment to a numpy array and derive one weight per sequence.
numpy_aln, _, _ = st.convert_msa_numpy_array(ancestors_extants_aln)
weights = st.position_based_seq_weighting(numpy_aln, n_processes=8)
# Alternative weighting scheme, kept for quick switching:
#weights = st.reweight_by_seq_similarity(numpy_aln, 0.2)
ancestors_extants_aln["weights"] = weights
# one-hot encode
one_hot = ancestors_extants_aln["sequence"].apply(st.seq_to_one_hot)
ancestors_extants_aln["encoding"] = one_hot

Sequence weight numpy array created with shape (num_seqs, columns):  (673, 238)


In [97]:
from evoVAE.utils.datasets import MSA_Dataset

# Fixed seed so the train/validation partition is identical on every run.
SPLIT_SEED = 42
train, val = train_test_split(
    ancestors_extants_aln, test_size=0.2, random_state=SPLIT_SEED
)

# Only the training split is wrapped in a loader here; `val` is held out.
# Batch order is still random because the loader shuffles.
train_dataset = MSA_Dataset(train["encoding"], train["weights"], train["id"])
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=10, shuffle=True
)

In [98]:


seq_len = 238  # number of columns in the GFP alignment loaded above
input_dims = seq_len * 21  # Protein alpha + gap char

# Two latent dimensions and two hidden layers of 150 units each.
model = tanhVAE(
    dim_latent_vars=2,
    dim_msa_vars=input_dims,
    num_hidden_units=[150, 150],
)
model

tanhVAE(
  (encoder_linears): ModuleList(
    (0): Linear(in_features=4998, out_features=150, bias=True)
    (1): Linear(in_features=150, out_features=150, bias=True)
  )
  (encoder_mu): Linear(in_features=150, out_features=2, bias=True)
  (encoder_logsigma): Linear(in_features=150, out_features=2, bias=True)
  (decoder_linears): ModuleList(
    (0): Linear(in_features=2, out_features=150, bias=True)
    (1): Linear(in_features=150, out_features=150, bias=True)
    (2): Linear(in_features=150, out_features=4998, bias=True)
  )
)

In [125]:

# Walk through the weighted ELBO by hand for the first batch only.
for x, weights, _ in train_loader:
    # Collapse each sample to a single flat vector.
    x = x.flatten(start_dim=1)
    mu, sigma = model.encoder(x)

    # Reparameterisation: z = mu + sigma * eps, with eps drawn from N(0, I).
    eps = torch.randn_like(sigma)
    z = mu + sigma * eps

    # Reconstruction term (assumes the decoder returns log-probabilities -- TODO confirm).
    log_p = model.decoder(z)
    log_PxGz = (x * log_p).sum(-1)

    c = 1 / 2

    # compute elbo: reconstruction term minus the Gaussian KL penalty
    # (sigma is treated as a standard deviation here).
    kl = (c * (sigma**2 + mu**2 - 2 * torch.log(sigma) - 1)).sum(-1)
    elbo = log_PxGz - kl

    # Normalise the sequence weights so they sum to one within the batch.
    weight = weights / weights.sum()
    print(weight)
    print(elbo)
    weighted_elbo = elbo * weight
    print(weighted_elbo)
    elbo = weighted_elbo.sum()
    # Negated weighted ELBO.
    print(elbo.item() * -1)

    break


tensor([0.0301, 0.0216, 0.0402, 0.0387, 0.0572, 0.1265, 0.0174, 0.0838, 0.0202,
        0.5645], dtype=torch.float64)
tensor([-728.7629, -727.6421, -726.7874, -729.8826, -731.9568, -724.7950,
        -728.0815, -728.0460, -726.5470, -728.5621], dtype=torch.float64,
       grad_fn=<SubBackward0>)
tensor([ -21.9687,  -15.7301,  -29.1814,  -28.2138,  -41.8637,  -91.6569,
         -12.6578,  -60.9740,  -14.6451, -411.2620], dtype=torch.float64,
       grad_fn=<MulBackward0>)
728.1536274991734


In [124]:
# Per-sample reconstruction term, using `log_p` and `x` left over from the batch loop.
(log_p * x).sum(-1)

tensor([-729.8930, -726.8428, -728.5695, -726.7931, -722.4523, -723.6292,
        -731.6121, -724.1678, -724.1612, -728.6300], dtype=torch.float64,
       grad_fn=<SumBackward1>)

In [103]:
# Wild-type (target) sequence of the A4_HUMAN assay from the metadata table.
is_a4_assay = metadata["DMS_id"] == "A4_HUMAN_Seuma_2022"
wt = metadata[is_a4_assay]["target_seq"].values[0]
# One-hot encode, then add a leading batch dimension so the model can process it.
wt_encoding = torch.Tensor(st.seq_to_one_hot(wt))
wild_one_hot = wt_encoding.unsqueeze(0).float()

In [38]:
# Check the shape of the batched wild-type encoding.
wild_one_hot.size()

torch.Size([1, 770, 21])

In [43]:
# Broadcast the wild-type entry to a batch of 10; -1 keeps the other dimensions as they are.
wild_one_hot = wild_one_hot.expand(10, -1, -1)
wild_one_hot.size()

torch.Size([10, 770, 21])

In [45]:
# Flatten everything after the batch dimension and check the resulting shape.
wild_one_hot.flatten(start_dim=1).shape

torch.Size([10, 16170])