In [None]:
from evoVAE.utils.datasets import MSA_Dataset
import evoVAE.utils.seq_tools as st
from evoVAE.models.seqVAE import SeqVAE
from evoVAE.trainer.seqVAE_train import seq_train
from sklearn.model_selection import train_test_split
import pandas as pd
import torch
import wandb
from pathlib import Path
import os

#### Config

In [None]:
# Hyperparameters for this run; they are handed to Weights & Biases and read
# back through `config` by the cells below.
hyperparameters = {
    # Dataset info
    "dataset": "PhoQ",
    "seq_theta": 0.2,  # reweighting
    "AA_count": 21,  # standard AA + gap

    # ADAM
    "learning_rate": 1e-5,
    "weight_decay": 0.01,

    # Hidden units
    "momentum": 0.9,
    "dropout": 0.5,

    # Training loop
    "epochs": 1,
    "batch_size": 2,
    "max_norm": 1.0,  # gradient clipping

    # Model info
    "architecture": "SeqVAE",
    "latent_dims": 2,
    "hidden_dims": [32, 16],
}

wandb.init(project="SeqVAE_training", config=hyperparameters)

# Configuration object of the active wandb run; later cells read settings
# from it as attributes (e.g. config.batch_size).
config = wandb.config

# Prefer CUDA when torch reports it as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

#### Data loading and preprocessing

In [None]:

# Directory containing the ancestor alignment files. Set the PHOQ_ANCESTOR_DIR
# environment variable to use a different location without editing the
# notebook; the fallback is the original hard-coded path.
DATA_PATH = Path(
    os.environ.get(
        "PHOQ_ANCESTOR_DIR",
        "/Users/sebs_mac/OneDrive - The University of Queensland/honours/data/phoQ/uniref90_search/nr65_filtering/odseq_tree/independent_runs/ancestors",
    )
)

# Gather all the ancestor sequences into a single dataframe
trees = []
# os.listdir returns entries in arbitrary order; sorting keeps the row order of
# `ancestors` (and therefore which duplicate survives later de-duplication)
# the same on every run.
for file in sorted(os.listdir(DATA_PATH)):
    entry = DATA_PATH / file
    # Skip the "ancestor_trees" entry, hidden files such as .DS_Store, and
    # anything that is not a regular file - none of these are alignments.
    if file == "ancestor_trees" or file.startswith(".") or not entry.is_file():
        continue
    run = st.read_aln_file(str(entry))
    # Label every row with the second "_"-separated token of the file name.
    # NOTE(review): assumes every alignment file name contains an underscore.
    run["tree"] = file.split("_")[1]
    trees.append(run)

# Fail with a clear message instead of pd.concat's generic
# "No objects to concatenate" when the directory holds no alignments.
if not trees:
    raise ValueError(f"No ancestor alignment files found in {DATA_PATH}")

ancestors = pd.concat(trees)
#ancestors.to_pickle("phoQ_ancestors.pkl")


In [None]:
# Ancestors N0 and N238 come from outgroups, so they are excluded.
OUTGROUP_IDS = ["N0", "N238"]

print(ancestors.shape)
is_outgroup = ancestors["id"].isin(OUTGROUP_IDS)
flt_ancestors = ancestors.loc[~is_outgroup]
print(flt_ancestors.shape)

# Collapse rows that share the same sequence down to a single row each.
flt_unique_ancestors = flt_ancestors.drop_duplicates(subset="sequence")
flt_unique_ancestors


In [None]:
# Fixed seed so the train/validation partition is identical on every run;
# without it train_test_split draws a new random split each time and results
# cannot be reproduced.
SPLIT_SEED = 42

train, val = train_test_split(flt_unique_ancestors, test_size=0.2, random_state=SPLIT_SEED)

# create one-hot encodings and calculate reweightings
# (each split is encoded and weighted independently of the other)

# TRAINING
train_encodings, train_weights = st.encode_and_weight_seqs(train["sequence"], theta=config.seq_theta)
train_ids = train["id"].values # just the seq identifiers
train_dataset = MSA_Dataset(train_encodings, train_weights, train_ids)

# VALIDATION
val_encodings, val_weights = st.encode_and_weight_seqs(val["sequence"], theta=config.seq_theta)
val_ids = val["id"].values
val_dataset = MSA_Dataset(val_encodings, val_weights, val_ids)


# DATA LOADERS #
# NOTE(review): the training loader is not shuffled, so batches are served in
# the same order every epoch - confirm this is intended.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size, shuffle=False)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
#next(iter(train_loader))[0].shape,next(iter(train_loader))[1].shape, next(iter(train_loader))[2]

In [None]:
#encoding, weights, id = train_dataset[0]
#print(encoding.shape, weights, id)

# translation = st.one_hot_to_seq(encoding)
# print(translation)

#### Create the model

In [None]:
# The first element of a dataset item is its encoding; its leading dimension
# is taken as the sequence length.
first_sample = train_dataset[0]
first_encoding = first_sample[0]
seq_len = first_encoding.shape[0]

# Size of the model input: sequence length times the configured AA_count.
input_dims = seq_len * config.AA_count

# Build the VAE with the latent and hidden layer sizes taken from the run config.
model = SeqVAE(
    input_dims=input_dims,
    latent_dims=config.latent_dims,
    hidden_dims=config.hidden_dims,
    config=config,
)
# model

#### Training Loop

In [None]:
# Train the model, and always close the wandb run afterwards. Without the
# try/finally an exception or a kernel interrupt inside seq_train would skip
# wandb.finish() and leave the run open in this kernel.
run_exit_code = 1  # reported to wandb as a failed run unless training completes
try:
    trained_model = seq_train(
        model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        config=config,
    )
    run_exit_code = 0
finally:
    wandb.finish(exit_code=run_exit_code)