In [1]:
from evoVAE.utils.datasets import MSA_Dataset
import evoVAE.utils.seq_tools as st
from evoVAE.models.seqVAE import SeqVAE
from evoVAE.trainer.seqVAE_train import seq_train
from sklearn.model_selection import train_test_split
import pandas as pd
import torch
import wandb
from pathlib import Path
import os

#### Config

In [2]:
wandb.init(
    project="SeqVAE_training",

    # Hyperparameters: logged to W&B and read back below through `wandb.config`.
    config={

        # Dataset info
        "dataset": "PhoQ",
        "seq_theta": 0.2,  # passed as `theta` to st.encode_and_weight_seqs (sequence reweighting)
        "AA_count": 21,  # standard AA + gap (one-hot alphabet size)

        # Adam optimiser
        "learning_rate": 1e-5,
        "weight_decay": 0.01,

        # Hidden layers
        # NOTE(review): matches BatchNorm1d(momentum=0.1) in the printed model;
        # presumably forwarded there by SeqVAE — TODO confirm.
        "momentum": 0.1,
        "dropout": 0.5,

        # Training loop
        "epochs": 100,
        "batch_size": 128,
        "max_norm": 1.0,  # gradient-clipping norm

        # Model info
        "architecture": "SeqVAE",
        "latent_dims": 2,
        "hidden_dims": [32, 16],

        # Reproducibility
        "seed": 42,
    },
)


config = wandb.config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG before the model is constructed so weight
# initialisation (and any torch-driven sampling) is repeatable across runs.
torch.manual_seed(config.seed)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33msebastian-porras01[0m. Use [1m`wandb login --relogin`[0m to force relogin


#### Data loading and preprocessing

In [None]:

# Directory holding one ancestor alignment per independent tree-inference run.
# Override with the PHOQ_ANCESTOR_DIR environment variable on other machines.
DATA_PATH = Path(
    os.environ.get(
        "PHOQ_ANCESTOR_DIR",
        "/Users/sebs_mac/OneDrive - The University of Queensland/honours/data/phoQ/uniref90_search/nr65_filtering/odseq_tree/independent_runs/ancestors",
    )
)

# Gather all the ancestor sequences into a single dataframe, tagging each row
# with the tree label parsed from its file name.
trees = []
for aln_path in sorted(DATA_PATH.iterdir()):  # sorted -> deterministic row order
    # Skip the `ancestor_trees` entry, sub-directories and hidden files
    # (e.g. macOS/OneDrive `.DS_Store`); none of these are alignments.
    if (
        aln_path.name == "ancestor_trees"
        or aln_path.name.startswith(".")
        or not aln_path.is_file()
    ):
        continue
    run = st.read_aln_file(str(aln_path))
    # assumes file names look like `<prefix>_<tree>_...` — the second
    # underscore-separated token is used as the tree label.
    run["tree"] = aln_path.name.split("_")[1]
    trees.append(run)

if not trees:
    raise FileNotFoundError(f"No ancestor alignment files found in {DATA_PATH}")

# ignore_index gives every row a unique label instead of repeating 0..n per run.
ancestors = pd.concat(trees, ignore_index=True)
anc_encodings, anc_weights = st.encode_and_weight_seqs(
    ancestors["sequence"], theta=config.seq_theta
)
ancestors["weights"] = anc_weights
#ancestors.to_pickle("phoQ_ancestors_weights.pkl")


In [None]:
# Ancestral nodes that come from outgroups and must not be used for training.
OUTGROUP_NODES = ["N0", "N238"]

# Next, drop the outgroup nodes
print(ancestors.shape)
flt_ancestors = ancestors.loc[~ancestors["id"].isin(OUTGROUP_NODES)]
print(flt_ancestors.shape)

# Then remove non-unique sequences. `.copy()` makes an independent frame so
# the column assignments below don't raise SettingWithCopyWarning.
flt_unique_ancestors = flt_ancestors.drop_duplicates(subset="sequence").copy()

# encode_and_weight_seqs computes weights from the whole collection it is
# given, so weights from the unfiltered (outgroup + duplicate) set are stale
# here — presumably duplicates inflate neighbour counts. Recompute on the
# final set of sequences.
flt_unique_encodings, flt_unique_weights = st.encode_and_weight_seqs(
    flt_unique_ancestors["sequence"], theta=config.seq_theta
)
flt_unique_ancestors["weights"] = flt_unique_weights
flt_unique_ancestors["encodings"] = flt_unique_encodings
flt_unique_ancestors.head()


In [3]:
# Set to False to train on the filtered PhoQ ancestors built above
# (`flt_unique_ancestors`); the tiny alignment is a quick smoke test.
USE_TOY_DATA = True
SPLIT_SEED = 42  # fixes the train/val split and the training shuffle order

if USE_TOY_DATA:
    model_seqs = st.read_aln_file("../data/alignments/tiny.aln")
else:
    model_seqs = flt_unique_ancestors.copy()

# create one-hot encodings and calculate reweightings
seq_encodings, seq_weights = st.encode_and_weight_seqs(
    model_seqs["sequence"], theta=config.seq_theta
)
model_seqs["weights"] = seq_weights
model_seqs["encodings"] = seq_encodings

train, val = train_test_split(model_seqs, test_size=0.2, random_state=SPLIT_SEED)
# Reset to a 0..n-1 index so row position and row label agree, whichever way
# MSA_Dataset indexes the Series it receives.
train = train.reset_index(drop=True)
val = val.reset_index(drop=True)

# TRAINING
train_dataset = MSA_Dataset(train["encodings"], train["weights"], train["id"])

# VALIDATION
val_dataset = MSA_Dataset(val["encodings"], val["weights"], val["id"])

# DATA LOADERS #
# Use the logged batch size so W&B records what was actually used.
batch_size = config.batch_size
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    # The model's BatchNorm1d layers cannot train on a single-sample batch,
    # so drop the last batch only when it would be a singleton.
    drop_last=len(train_dataset) % batch_size == 1,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Peek at one batch; draw it once so all three fields come from the same batch.
sample_encodings, sample_weights, sample_ids = next(iter(train_loader))
sample_encodings.shape, sample_weights.shape, sample_ids

Encoding the sequences and calculating weights
The sequence encoding has size: (3,)

The sequence weight array has size: (3,)



(torch.Size([2, 5, 21]),
 torch.Size([2]),
 ('H3RC00_PhoQ_UniRef90', 'A0A0J8VL97_PhoQ_UniRef90'))

#### Create the model

In [4]:
# Flattened input width = alignment columns x alphabet size
# (each position is one-hot over config.AA_count symbols).
seq_len = len(train_dataset[0][0])
input_dims = config.AA_count * seq_len

# Encoder/decoder layer widths are taken from the run config.
model = SeqVAE(
    input_dims=input_dims,
    latent_dims=config.latent_dims,
    hidden_dims=config.hidden_dims,
    config=config,
)
model

SeqVAE(
  (encoder): Sequential(
    (0): Sequential(
      (0): Linear(in_features=105, out_features=32, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Dropout(p=0.5, inplace=False)
      (3): Linear(in_features=32, out_features=32, bias=True)
      (4): LeakyReLU(negative_slope=0.01)
      (5): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): Sequential(
      (0): Linear(in_features=32, out_features=16, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Dropout(p=0.5, inplace=False)
      (3): Linear(in_features=16, out_features=16, bias=True)
      (4): LeakyReLU(negative_slope=0.01)
      (5): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (z_mu_sampler): Linear(in_features=16, out_features=2, bias=True)
  (z_logvar_sampler): Linear(in_features=16, out_features=2, bias=True)
  (upscale_z): Linear(in_features=2, out_features=16, bias=True)
  (decoder): Sequential

In [5]:

# Sanity-check one forward pass + loss per batch before training.
# eval() + no_grad(): in train mode these passes would update the BatchNorm
# running statistics (and build autograd graphs), leaking state into the
# model that is trained in the next section. The previous mode is restored.
was_training = model.training
model_device = next(model.parameters()).device
model.eval()
try:
    with torch.no_grad():
        for encoding, _weights, _names in train_loader:
            encoding = encoding.float().to(model_device)
            output = model(encoding)
            print(encoding.shape, output[0].shape)
            loss, kl, likelihood = model.loss_function(output, encoding)
            print(loss, kl, likelihood)
finally:
    model.train(was_training)
    
 

torch.Size([2, 5, 21]) torch.Size([2, 5, 21])
tensor([[[ -4.1032,  -9.0638,  -6.2700,  -6.5977,  -5.1259,  -6.8972,  -7.2484,
           -5.8833,  -3.8101,  -4.6348,  -4.0005,  -4.3172,  -9.5180,  -8.0663,
           -6.2233,  -8.3761,  -4.0863,  -8.9076,  -5.6542,  -7.4537,  -5.0118],
         [ -5.6450,  -5.5792,  -6.6080,  -8.4601,  -5.6669,  -5.7337,  -3.2229,
           -8.0852,  -6.8210,  -5.7914,  -4.1736,  -6.7751,  -4.5791,  -5.2424,
           -5.6164,  -7.4926,  -7.0016,  -6.7585,  -6.4570,  -7.9730,  -3.9999],
         [ -7.1944,  -5.0256,  -7.3693,  -4.2791,  -7.5264,  -5.3015,  -7.8747,
           -4.1534,  -8.2535,  -4.0323,  -6.5572,  -7.7164,  -8.5688,  -4.5740,
           -4.7727,  -4.1469, -13.7790,  -7.1895,  -4.3674,  -4.0677,  -6.1627],
         [ -8.0979,  -6.7223,  -5.2084,  -5.7357,  -5.3518,  -7.1098,  -6.9843,
           -4.1962,  -5.3615,  -3.7610,  -3.6660,  -6.3922,  -7.8394,  -9.8178,
           -6.6478,  -4.9940,  -6.7737,  -5.6854,  -6.3827,  -5.3157,  

#### Training Loop

In [8]:
# Fit the VAE: seq_train receives the model, both loaders, the target device
# and the run config, and returns the trained model.
# NOTE(review): the logged run summary shows a negative epoch_KLD (about -0.15),
# but a KL divergence is >= 0 — check the sign convention in
# SeqVAE.loss_function / seq_train.
trained_model = seq_train(
    model,
    train_loader=train_loader,
    val_loader=val_loader,
    config=config,
    device=device,
)


tensor([[[-5.8081, -6.2357, -5.4158, -5.3432, -5.4545, -4.9457, -5.5165,
          -5.3137, -5.0300, -5.9836, -5.3193, -5.3024, -6.0173, -6.2752,
          -5.1734, -6.0566, -4.7416, -8.5482, -5.6060, -6.6750, -5.9671],
         [-4.8635, -5.1919, -5.8024, -9.7561, -5.8459, -5.1961, -5.5935,
          -6.2148, -6.0556, -4.9812, -5.0670, -6.1690, -5.3750, -5.1477,
          -5.5498, -5.9197, -5.2417, -6.1683, -5.8246, -5.4799, -5.3349],
         [-5.5036, -5.7419, -5.7478, -6.1066, -5.8866, -4.9317, -6.3637,
          -5.4781, -5.4055, -5.1837, -6.6488, -5.8844, -5.3809, -5.3406,
          -4.7881, -5.3848, -8.8061, -6.1969, -5.4482, -5.0065, -5.4447],
         [-5.3817, -5.7385, -5.2635, -5.5576, -5.8754, -5.7501, -5.5840,
          -4.8757, -5.4930, -5.7564, -5.6691, -5.4452, -5.5464, -6.2417,
          -5.9275, -5.0574, -8.0109, -5.2457, -6.1256, -5.5426, -6.1399],
         [-5.0365, -6.3796, -5.3018, -5.5279, -5.9323, -6.2842, -5.5067,
          -5.4895, -5.7609, -5.0937, -5.2449, -

In [9]:
# Mark the W&B run started by wandb.init() in the config cell as finished and
# upload any remaining logged data.
wandb.finish()

Run history:
epoch_ELBO,▆▅▆█▇▅▆▅█▄▃▅▃▄▄▃▄▃▃▄▃▄▇▃▄▅▂▂▂▂▄▂▂▁▆▁▁▂▂▂
epoch_Gauss_likelihood,▃▄▃▁▂▄▃▄▁▅▆▄▆▅▆▆▅▆▆▅▆▅▃▆▆▄▇▇▇▇▆█▇█▃██▇▇▇
epoch_KLD,▄▄▄▄▄▁▂▂▄▃▁▃▂▄▃▃▃▁▂▁▁▃█▅▅▃▂▁▂▃▇▄▃▂█▂▅▃▆▁
epoch_val_ELBO,▇█▄▄▄▇▄█▅▅▃▅▃▄▃▄▄▃▄▃▂▄▂▃▃▃▃▄▂▂▃▂▂▂▂▂▁▁▁▂
epoch_val_Gauss_likelihood,▂▁▅▅▅▂▅▁▄▄▆▄▆▅▆▅▅▆▅▆▇▅▇▆▆▆▇▆▇▇▆▇▇▇▇▇███▇
epoch_val_KLD,▂▃▂▁▁▃▁▃▂▂▃▁▂▂▁▃▂▁▃▁▁▂▁▁▂▁▅█▁▂▂▃▁▁▁▁▃▁▁▁

Run summary:
epoch_ELBO,601.24847
epoch_Gauss_likelihood,-601.40125
epoch_KLD,-0.15279
epoch_val_ELBO,601.9306
epoch_val_Gauss_likelihood,-602.08191
epoch_val_KLD,-0.15131
