In [14]:
from evoVAE.utils.datasets import MSA_Dataset
import evoVAE.utils.seq_tools as st
import evoVAE.utils.metrics as mt
from evoVAE.models.seqVAETest import SeqVAETest
from evoVAE.trainer.seqVAE_train import seq_train
from sklearn.model_selection import train_test_split
import pandas as pd
import torch
import numpy as np

# Lift pandas' row-display limit so displayed DataFrames are never truncated.
# NOTE(review): this also applies to large frames, so prefer .head() when inspecting them.
pd.set_option("display.max_rows", None)

# Config

In [2]:
# Single source of truth for the hyper-parameters used in this notebook.
# The whole dict is also handed to the model constructor via `config=config`.
config = {
    # -- dataset --
    "dataset": "playground",
    "seq_theta": 0.2,        # sequence reweighting
    "AA_count": 21,          # standard amino acids + gap

    # -- optimiser (ADAM) --
    "learning_rate": 1e-5,
    "weight_decay": 0.01,

    # -- hidden units --
    "momentum": 0.1,
    "dropout": 0.5,

    # -- training loop --
    "epochs": 100,
    "batch_size": 128,
    "max_norm": 1.0,         # gradient clipping

    # -- model --
    "architecture": "SeqVAETest",
    "latent_dims": 2,
    "hidden_dims": [32, 16],
}

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

# Read in data

In [3]:
from pathlib import Path
# Directory holding the ancestor alignments (the trailing slash is required
# because the file name is appended directly).
# NOTE(review): this is a machine-specific absolute path; the notebook will not
# run elsewhere until it is pointed at a local copy of the data.
DATA_PATH = "/Users/sebs_mac/OneDrive - The University of Queensland/honours/data/gfp_alns/independent_runs/no_synthetic/ancestors/seqs/"
filepath = f"{DATA_PATH}run_14_ancestors.fa"

# Load the alignment; later cells read its "encoding" and "id" columns.
aln = st.read_aln_file(filepath)
#aln


Reading the alignment: /Users/sebs_mac/OneDrive - The University of Queensland/honours/data/gfp_alns/independent_runs/no_synthetic/ancestors/seqs/run_14_ancestors.fa
Checking for bad characters: ['B', 'J', 'X', 'Z']
Performing one hot encoding
Number of seqs: 359


In [4]:

# Hold out 20% of the alignment for validation. The split is seeded so that
# Restart & Run All always reproduces the same train/validation partition
# (an unseeded split silently changes every downstream number between runs).
SPLIT_SEED = 42
train, val = train_test_split(aln, test_size=0.2, random_state=SPLIT_SEED)

# Wrap each partition in the project dataset class, passing the "encoding"
# column, the frame index and the "id" column, in that order.
train_dataset = MSA_Dataset(train["encoding"], train.index, train["id"])
val_dataset = MSA_Dataset(val["encoding"], val.index, val["id"])

# NOTE(review): batch_size is fixed at 2 here rather than config["batch_size"];
# presumably deliberate for this playground so the shapes below stay tiny. Confirm
# before reusing these loaders for real training.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=False)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=2, shuffle=False)

# Sanity check: number of batches per loader, then the shape of the first
# element of the first training batch.
print(len(train_loader), len(val_loader))
next(iter(train_loader))[0].shape

144 36


torch.Size([2, 238, 21])

# Build model

In [5]:
# Index constants used to dig the sequence length out of the dataset:
# item 0 of the dataset -> element 0 of that item -> axis 0 of its shape.
SEQ_LEN = 0
BATCH_ZERO = 0
SEQ_ZERO = 0

first_item = train_dataset[BATCH_ZERO]
first_encoding = first_item[SEQ_ZERO]
seq_len = first_encoding.shape[SEQ_LEN]

# The model consumes a flattened encoding: one slot per (position, symbol) pair.
input_dims = config['AA_count'] * seq_len

# Build the VAE with the hidden-layer sizes preset in the config.
model = SeqVAETest(
    input_dims=input_dims,
    latent_dims=config['latent_dims'],
    hidden_dims=config['hidden_dims'],
    config=config,
)
model

SeqVAETest(
  (encoder): Sequential(
    (0): Sequential(
      (0): Linear(in_features=4998, out_features=32, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Dropout(p=0.5, inplace=False)
      (3): Linear(in_features=32, out_features=32, bias=True)
      (4): LeakyReLU(negative_slope=0.01)
      (5): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Sequential(
      (0): Linear(in_features=32, out_features=16, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Dropout(p=0.5, inplace=False)
      (3): Linear(in_features=16, out_features=16, bias=True)
      (4): LeakyReLU(negative_slope=0.01)
      (5): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (z_mu_sampler): Linear(in_features=16, out_features=2, bias=True)
  (z_logvar_sampler): Linear(in_features=16, out_features=2, bias=True)
  (upscale_z): Linear(in_features=2, out_features=16, bias=True)
  (decoder): Seque

In [6]:
# Smoke test: push one very small dummy batch through the model.
first_batch = next(iter(train_loader))
dummy = first_batch[0].float()
print(dummy.shape)

# Reconstruct the input; note that the reconstruction comes back flattened.
log_p, z_sample, z_mu, z_logvar = model(dummy)

# Remember every leading dimension (all but the flattened last one) for reshaping.
orig_shape = log_p.shape[0:-1]

# Add a trailing dim, then fold back into the one-hot layout
# (obs, seq_len, AA_count).
log_p = log_p.unsqueeze(-1).view(orig_shape + (-1, config['AA_count']))

torch.Size([2, 238, 21])


In [7]:
# Load the GFP deep-mutational-scanning table.
mut_data = pd.read_csv('../data/dms_data/GFP_AEQVI_Sarkisyan_2016.csv')

# Work on the first 10 variants only. Slice first and copy second: this copies
# 10 rows instead of the whole table, and leaves `subset` as an independent
# frame so adding columns to it later does not raise SettingWithCopyWarning.
subset = mut_data.iloc[0:10].copy()

In [8]:
# Encode the mutated sequences and compute their weights.
# Theta is read from the shared config rather than repeated as a literal, so
# this cell cannot drift from the "seq_theta" value set there.
encoding, weights = st.encode_and_weight_seqs(subset['mutated_sequence'], config["seq_theta"])

# Attach both results to the variant table as new columns.
subset['encoding'] = encoding
subset['weights'] = weights
subset

Encoding the sequences and calculating weights
The sequence encoding has size: (10,)

The sequence weight array has size: (10,)



Unnamed: 0,mutant,mutated_sequence,DMS_score,DMS_score_bin,encoding,weights
0,K3R:V55A:Q94R:A110T:D117G:M153K:D216A,MSRGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,1.30103,0,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.1
1,K3Q:V16A:I167T:L195Q,MSQGEELFTGVVPILAELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,3.13735,1,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.1
2,K3Q:Y143C:N164D:S205P:A227T,MSQGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,1.553913,0,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.1
3,K3Q:Y143N:V193A,MSQGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,3.404237,1,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.1
4,K3R,MSRGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,3.738586,1,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.1
5,K3R:A87T:D173G,MSRGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,3.851893,1,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.1
6,K3R:A87T:N144S:T225S,MSRGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,3.551648,1,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.1
7,K3R:C48R:D76G:M218K,MSRGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,1.480047,0,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.1
8,K3R:D102G:N185D:L195P:H231P:E235K,MSRGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,3.388889,1,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.1
9,K3R:D102G:Y151C:N170D:I229T,MSRGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,3.499501,1,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",0.1


In [9]:
# Load the DMS reference table and keep only the rows whose DMS_id mentions GFP.
dms_reference = pd.read_csv("../data/dms_data/DMS_substitutions.csv")
is_gfp_assay = dms_reference["DMS_id"].str.contains("GFP")
metadata = dms_reference.loc[is_gfp_assay]
metadata

Unnamed: 0,DMS_id,DMS_filename,UniProt_ID,taxon,source_organism,target_seq,seq_len,includes_multiple_mutants,DMS_total_number_mutants,DMS_number_single_mutants,...,MSA_num_significant_L,raw_DMS_filename,raw_DMS_phenotype_name,raw_DMS_directionality,raw_DMS_mutant_column,weight_file_name,pdb_file,ProteinGym_version,raw_mut_offset,coarse_selection_type
67,GFP_AEQVI_Sarkisyan_2016,GFP_AEQVI_Sarkisyan_2016.csv,GFP_AEQVI,Eukaryote,Aequorea victoria,MSKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKF...,238,True,51714,1084,...,0.0,GFP_AEQVI_Sarkisyan_2016.csv,mean_medianBrightness_per_aaseq,1,mutant,GFP_AEQVI_theta_0.2.npy,GFP_AEQVI.pdb,0.1,,Activity


In [47]:

import sklearn
import sklearn.preprocessing


# Score each DMS variant with the model and line the scores up with the
# measured DMS values.

# Wild-type sequence: first row of the filtered metadata, one-hot encoded and
# given a leading batch dimension of 1.
wild_type = metadata['target_seq'].to_numpy()[0]
wild_one_hot = torch.Tensor(st.seq_to_one_hot(wild_type)).unsqueeze(0)

# Stack the per-variant encodings into a single batch tensor.
variant_encodings = torch.Tensor(np.stack(subset['encoding'].values))

# Evaluation mode: the model contains Dropout and BatchNorm1d layers.
model.eval()

# Inference only, so run the forward passes without autograd. This avoids
# building a graph and keeps the outputs detached.
with torch.no_grad():
    wild_model_encoding, _, _, _ = model(wild_one_hot)
    variant_model_outputs, _, _, _ = model(variant_encodings)

# Fold the wild-type output back into (batch, seq_len, AA_count), then drop the
# batch dimension.
# NOTE(review): the wild-type output is not used in the scores below; if scores
# relative to wild type were intended, that subtraction is still missing.
orig_shape = wild_model_encoding.shape[0:-1]
wild_model_encoding = wild_model_encoding.view(orig_shape + (-1, model.AA_COUNT))
wild_model_encoding = wild_model_encoding.squeeze(0)
one_hot = wild_one_hot.squeeze(0)

model_scores = []
for variant, var_one_hot in zip(variant_model_outputs, variant_encodings):
    # Reshape each flattened output directly to (seq_len, AA_count), with no
    # dependence on the wild-type batch shape.
    var_model_encoding = variant.view(-1, model.AA_COUNT)
    # presumably the log-probability of the variant under the model output —
    # confirm against evoVAE.utils.metrics.seq_log_probability
    log_prob = mt.seq_log_probability(var_one_hot, var_model_encoding)
    model_scores.append(log_prob)

model_scores = pd.Series(model_scores)
actual_scores = subset['DMS_score']
model_scores, actual_scores



(0   -723.424438
 1   -723.565063
 2   -722.048584
 3   -722.012512
 4   -722.300354
 5   -722.447144
 6   -722.689453
 7   -723.892822
 8   -722.536987
 9   -722.414734
 dtype: float64,
 0    1.301030
 1    3.137350
 2    1.553913
 3    3.404237
 4    3.738586
 5    3.851893
 6    3.551648
 7    1.480047
 8    3.388889
 9    3.499501
 Name: DMS_score, dtype: float64)

In [None]:
# Compare model scores with the measured DMS scores; the binned labels are
# passed separately as `actual_binned`.
binned_truth = subset['DMS_score_bin']
summary = mt.summary_stats(
    predictions=model_scores,
    actual=actual_scores,
    actual_binned=binned_truth,
)
spear_rho, k_recall, ndcg, roc_auc = summary
(spear_rho, k_recall, ndcg, roc_auc)