Testing if the VAEs Work (Molecule and Peptide):

In [1]:
# Changes working directory to root to allow imports to work

import os

def change_directory_to_root(root_name: str = "GDProject"):
    """Change the working directory to the project root so project imports resolve.

    Starting at the current working directory, walks upward and switches to the
    nearest directory (the current one included) whose folder name equals
    ``root_name``. This works regardless of how deeply the notebook is nested
    and regardless of the platform's path separator.

    If no directory on the way up is named ``root_name``, the function falls
    back to moving up exactly one level, matching the previous behaviour.

    Args:
        root_name: Folder name of the project root directory.

    Returns:
        None. The resulting working directory is printed.
    """
    start = os.path.abspath(os.getcwd())

    target = start
    while os.path.basename(target) != root_name:
        parent = os.path.dirname(target)
        if parent == target:
            # Reached the filesystem root without a match: keep the legacy
            # behaviour of stepping up a single level from where we started.
            target = os.path.dirname(start)
            break
        target = parent

    # Only touch the process state when we actually need to move.
    if target != start:
        os.chdir(target)

    print(f"New Current Directory is {os.getcwd()}")

# Run once, before the project-local imports in the following cells, using the default root name.
change_directory_to_root()

New Current Directory is /home/alden/Research/GDProject


In [None]:
import util
import gdiffusion as gd

# Pick the compute device through the project helper; with print_device=True the
# choice is echoed (the recorded output below shows "Device: cuda").
device = util.get_device(print_device=True)

Device: cuda


In [None]:
# VAE Test Function

def vae_test(vae: gd.MoleculeVAE | gd.PeptideVAE, string_name: str):
    """Smoke-test a VAE by sampling strings and round-tripping each one.

    Draws a batch of four random strings from ``vae``, prints them, then
    encodes and decodes every string individually and prints the original
    next to its reconstruction so the two can be compared by eye.

    Args:
        vae: The VAE under test; must provide ``sample_random``, ``encode``
            and ``decode``.
        string_name: Label used in the printed headings (e.g. the kind of
            string the VAE produces).
    """
    samples = vae.sample_random(batch_size=4)

    # Part 1: show what the VAE generates from random draws.
    print(f"--- Random {string_name} ---")
    for sample in samples:
        print(sample)
    print("")

    # Part 2: encode -> decode each sample and show it beside the original.
    print("Checking if VAE is stable. Pairs should match")
    for sample in samples:
        encoded = vae.encode(sample)
        # decode's result is indexed; the first entry is taken as the reconstruction
        reconstruction = vae.decode(encoded)[0]
        print(f"Original    {string_name}: {sample}")
        print(f"Transformed {string_name}: {reconstruction}", end='\n\n')

In [9]:
# Build the molecule VAE with its default arguments and run the sample / round-trip check,
# labelling the printed strings as SELFIES.
molecule_vae = gd.MoleculeVAE() # load from default path
vae_test(molecule_vae, string_name="SELFIES")


 Loading Molecule VAE:
------------------------------------------------
Loaded VAE Vocab from saved_models/selfies_vae/vocab.json
Getting State Dict...
Loading model from saved_models/selfies_vae/selfies-vae.ckpt
Enc params: 1,994,592
Dec params: 277,346
------------------------------------------------

--- Random SELFIES ---
[O][=C][N][C][C][Branch2][Ring1][Branch1][N][C][=C][C][=C][Branch1][Ring1][C][O][C][=Branch1][C][=O][O][N][C][Branch1][Ring1][C][=O][=N][C][Ring2][Ring1][Branch2][=O][N][C][=Branch1][=Branch2][=C][Ring2][Ring1][Branch2][C][=O][C][C][C][Ring2][Ring1][Ring1][=O]
[C][C][N][C][C][C][N][Branch1][Ring1][C][O][C][C][C][N][C][=C][O][C][Branch1][=Branch2][C][Branch1][C][C][Branch1][C][C][C][N][C][=Branch1][C][=O][C][=C][C][=C][C][=C][Ring1][=Branch1][C][C][C][C][C][Ring1][#Branch1]
[C][C][=C][C][=C][Branch1][#Branch2][C][=C][C][=C][C][=C][Ring1][=Branch1][Cl][C][=C][C][=C][Branch1][#Branch2][C][=C][C][Branch1][C][F][=C][Ring1][=Branch2][C][=N][C][Ring2][Ring1][Ring2]
[C][

In [10]:
# Build the peptide VAE with its default arguments and run the same sample / round-trip check,
# labelling the printed strings as Peptides.
peptide_vae = gd.PeptideVAE() # load from default path
vae_test(peptide_vae, string_name="Peptides")


 Loading Peptide VAE:
------------------------------------------------
Loaded VAE Vocab from saved_models/peptide_vae/vocab.json
Getting State Dict...
Loading model from saved_models/peptide_vae/peptide-vae.ckpt
Enc params: 2,675,904
Dec params: 360,349
------------------------------------------------

--- Random Peptides ---
PQREVHAHETELPILTEFILLKEFILLKMRLPTKGDRLPKWQRPTP
NAYLVGAYALRK
DRKKPDRKPALEALEKEAVERPPDLPFDL
VQYLDGSKMITIWLAAGGMLKLGGMLAGGHSDRPRQPETVGL

Checking if VAE is stable. Pairs should match


  state_dict = torch.load(vae_statedict_path, map_location=self.device)["state_dict"]


Original    Peptides: PQREVHAHETELPILTEFILLKEFILLKMRLPTKGDRLPKWQRPTP
Transformed Peptides: PQREVHAHETELPILTEFILLKEFILLKMRLPTKGDRLPKWQRPTP

Original    Peptides: NAYLVGAYALRK
Transformed Peptides: NAYLVGAYALRK

Original    Peptides: DRKKPDRKPALEALEKEAVERPPDLPFDL
Transformed Peptides: DRKKPDRKPALEALEKEAVERPPDLPFDL

Original    Peptides: VQYLDGSKMITIWLAAGGMLKLGGMLAGGHSDRPRQPETVGL
Transformed Peptides: VQYLDGSKMITIWLAAGGMLKLGGMLAGGHSDRPRQPETVGL

