Notebook for testing the VAE, loading the dataset, and creating the latent space dataset

In [2]:
import pandas as pd
import torch
import numpy as np
from tqdm import tqdm

import json
import lightning.pytorch as pl
import torch
from torch.nn.utils.rnn import pad_sequence
import selfies as sf
import time

from molformers.models.BaseTrainer import VAEModule
from molformers.models.BaseVAESwiGLURope import BaseVAE
from typing import List, Union, Optional

# "high" lets PyTorch trade some float32 matmul precision for speed
# (see the torch.set_float32_matmul_precision documentation).
torch.set_float32_matmul_precision("high")

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using: {device}")

Using: cuda


In [3]:
from load_and_sample import *

In [4]:
# Smoke test: check that the VAE is working before processing the full dataset.
# NOTE(review): example_usage is not defined in this notebook; it is presumably
# pulled in by the `from load_and_sample import *` star import — confirm.
example_usage()

loading model from ./saved_models/epoch=447-step=139328.ckpt
Enc params: 1,994,592
Dec params: 277,346


  state_dict = torch.load(path_to_vae_statedict, map_location=device)["state_dict"]


Latent shape: torch.Size([1, 128]), Loss: 0.0002240684552816674
Reconstructed: ['[C][C][C][O]']
Batch latent shape: torch.Size([5, 128])
Reconstructed batch: ['[C][C][C][C][=Branch1][C][=O][N][N][C][=Branch1][C][=O][N][C][=C][C][=C][C][=C][Ring1][=Branch1]', '[C][C][=Branch1][C][=O][N][C][C][C][C][Branch1][C][C][C][Branch2][Ring2][#Branch2][C][C][C][Branch1][C][C][C][Ring1][Branch2][C][=Branch1][C][=O][C][=C][C][C][Branch1][C][C][C][Branch1][C][C][C][C][C][Ring1][Branch2][Branch1][C][C][C][C][C][Ring1][=N][Ring2][Ring1][Ring1][C][C][Ring2][Ring1][=N][Branch1][C][C][C][=Branch1][C][=O][O]', '[C][C][=Branch1][C][=O][N][C][Branch1][C][C][C][C][=C][C][=C][Branch2][Ring1][Branch2][C][#C][C][=C][C][=N][C][Branch1][N][N][C][C][C][C][Branch1][C][F][C][Ring1][#Branch1][=N][Ring1][=N][C][=C][Ring2][Ring1][Branch1]', '[C][C][=C][C][=C][C][Branch2][Ring1][P][C][C][N][C][=Branch1][C][=O][C][C][C][C][=Branch1][C][=O][N][Branch1][=N][C][C][=C][C][=C][Branch1][C][Cl][C][=C][Ring1][#Branch1][C][Ring1][

In [5]:
# Convert the smiles dataset into a selfies one

def smiles_to_selfies(smiles_list : List[str]):
    """Encode SMILES strings (guacamol dataset format) as SELFIES (vae format).

    Each entry is stripped of surrounding whitespace before encoding. An entry
    that raises during encoding is reported on stdout and left out of the
    result, so the returned list may be shorter than ``smiles_list``.
    """
    converted = []

    for raw_smiles in smiles_list:
        try:
            encoded = sf.encoder(raw_smiles.strip())
        except Exception as err:
            # Best-effort conversion: report the problem and move on.
            print(f"Error: {err}\n")
        else:
            converted.append(encoded)

    return converted

def selfies_to_smiles(selfies_list : List[str]):
    """Decode SELFIES strings back into SMILES strings.

    Each entry is stripped of surrounding whitespace before decoding. An entry
    that raises during decoding is reported on stdout and left out of the
    result, so the returned list may be shorter than ``selfies_list``.
    """
    decoded_smiles = []

    for raw_selfies in selfies_list:
        try:
            decoded = sf.decoder(raw_selfies.strip())
        except Exception as err:
            # Best-effort conversion: report the problem and move on.
            print(f"Error: {err}\n")
        else:
            decoded_smiles.append(decoded)

    return decoded_smiles


def smiles_to_selfies_file(input_smiles_file : str, output_selfies_file : Optional[str] = None) -> List[str]:
    """Read a SMILES file, convert every entry to SELFIES, and optionally save the result.

    Args:
        input_smiles_file: Path to a text file with one SMILES string per line.
            Blank lines are ignored and surrounding whitespace is stripped.
        output_selfies_file: Where to write the converted strings as a CSV with
            a single ``SELFIES`` column. Nothing is written when this is
            ``None`` or empty.

    Returns:
        The SELFIES strings produced by ``smiles_to_selfies``. Entries that
        function could not convert are absent, so the list (and the CSV) may be
        shorter than the input file.
    """

    with open(input_smiles_file, 'r') as f:
        smiles_list = [line.strip() for line in f if line.strip()]

    # The inputs are SMILES and the target format is SELFIES (the original
    # message named the wrong target format).
    print(f"Converting {len(smiles_list)} SMILES strings to SELFIES format...")

    selfies_list = smiles_to_selfies(smiles_list)

    if output_selfies_file:
        # Only build the DataFrame when there is actually something to write.
        pd.DataFrame({'SELFIES': selfies_list}).to_csv(output_selfies_file, index=False)

    return selfies_list


In [6]:
import os

# Path to the SMILES file that gets converted to SELFIES (one SMILES string per
# line). The default is the original author's local copy; set the
# GUACAMOL_SMILES_FILE environment variable to point at your own copy instead
# of editing this cell.
input_file = os.environ.get(
    "GUACAMOL_SMILES_FILE",
    "C:/Users/2023r/Documents/GuidedDiffusionProject/guacamol_v1_train.smiles",
)

In [16]:
# Convert the SMILES file to SELFIES and save the result to output_selfies.csv.
# The result is bound to a name and only previewed: leaving the bare call as the
# cell's last expression makes Jupyter render the entire returned list.
selfies_list = smiles_to_selfies_file(input_smiles_file=input_file, output_selfies_file="output_selfies.csv")
print(f"Converted {len(selfies_list)} SMILES strings to SELFIES")
selfies_list[:5]

Converting 1273104 strings to SMILES format...
Error: input violates the currently-set semantic constraints
	SMILES: CC(=O)O[IH2]1NC(=O)c2ccccc21
	Errors:
	[[IH2] with 3 bond(s) - a max. of -1 bond(s) was specified]


Error: input violates the currently-set semantic constraints
	SMILES: CC(=O)OI1(OC(C)=O)(OC(C)=O)OC(=O)c2ccccc21
	Errors:
	[I with 5 bond(s) - a max. of 1 bond(s) was specified]


Error: input violates the currently-set semantic constraints
	SMILES: O=C1OI(=O)(O)c2ccccc21
	Errors:
	[I with 5 bond(s) - a max. of 1 bond(s) was specified]


Error: input violates the currently-set semantic constraints
	SMILES: CCC(=O)N[IH2]1OC(=O)c2ccccc21
	Errors:
	[[IH2] with 3 bond(s) - a max. of -1 bond(s) was specified]


Error: input violates the currently-set semantic constraints
	SMILES: [N-]=[N+]=N[IH2]1OC(=O)c2ccccc21
	Errors:
	[[IH2] with 3 bond(s) - a max. of -1 bond(s) was specified]


Error: input violates the currently-set semantic constraints
	SMILES: O=C(N[IH2]1OC(=O)c2ccccc2

['[C][C][C][Branch1][C][C][Branch1][C][C][Br]',
 '[C][C][C][N][Branch2][Ring1][Branch1][C][C][C][=C][C][=C][C][Branch1][=Branch2][C][=C][C][=C][C][=C][Ring1][=Branch1][=C][Ring1][N][C][=Branch1][C][=O][C][O][C][Branch1][=Branch1][C][=Branch1][C][=O][O][=C][C][Branch1][C][N][C][Ring1][#Branch2][N][C][Branch1][C][C][=O]',
 '[O][C][=C][C][=C][Branch2][Ring1][#C][C][C][C][Branch1][=Branch2][C][=C][C][=C][C][=C][Ring1][=Branch1][=N][N][Ring1][O][C][=Branch1][C][=S][N][C][=C][C][=C][C][=C][Ring1][=Branch1][C][=C][Ring2][Ring1][#Branch2]',
 '[C][C][Branch1][C][C][O][C][C][O][C][Branch1][S][C][O][C][Branch1][C][C][Branch1][C][C][O][C][Ring1][#Branch1][C][O][O][C][Branch1][C][C][Branch1][C][C][O][C][Ring1][S][C][Ring2][Ring1][Ring1][O][Ring2][Ring1][Branch2]',
 '[C][O][C][=Branch1][C][=O][C][=C][C][Branch2][Ring2][Branch1][C][=Branch1][O][=C][C][C][C][C][=Branch1][C][=O][S][C][C][=C][C][Branch1][C][Cl][=C][Branch1][Ring1][O][C][C][Branch1][#Branch1][C][=Branch1][C][=O][O][C][=C][Ring1][=N][=C][

In [7]:
# Now we want to convert all of the output_selfies.csv to their 128 dim latent codes

# Load the VAE from its saved checkpoint.
# NOTE(review): load_vae_selfies is not defined in this notebook; it is
# presumably provided by the `from load_and_sample import *` star import — confirm.
vae = load_vae_selfies("./saved_models/epoch=447-step=139328.ckpt")
# Despite the `_df` suffix this selects the single 'SELFIES' column, so the
# name is bound to a pandas Series rather than a DataFrame.
selfies_df = pd.read_csv("output_selfies.csv")['SELFIES']

loading model from ./saved_models/epoch=447-step=139328.ckpt
Enc params: 1,994,592
Dec params: 277,346


In [None]:
# Encode every SELFIES string into its latent vector, one molecule per call so
# that a single bad string cannot take down the whole run.
selfies_latents = []  # one latent vector per successfully encoded molecule
encoded_rows = []     # (source index, SELFIES) for each row of selfies_latents
failed_selfies = []   # (source index, SELFIES, error message) for each failure

for i, selfie in enumerate(tqdm(selfies_df, desc="Processing SELFIES")):
    try:
        latent = selfies_to_latent([selfie], vae=vae)
        row = latent[0]
    except Exception as e:
        # str() guards against non-string entries (e.g. NaN read back from an
        # empty CSV field); slicing those directly would raise inside this
        # handler and abort the loop.
        tqdm.write(f"Failed SELFIES at index {i}: {str(selfie)[:50]}...")  # Truncate for readability
        tqdm.write(f"Error: {e}")
        failed_selfies.append((i, selfie, str(e)))
    else:
        # NOTE(review): selfies_to_latent is assumed to return torch tensors
        # (example_usage reports torch.Size shapes) — confirm. If so, copy each
        # row to host memory as a NumPy array so a million-plus rows are not
        # kept on the device and the DataFrame holds plain numbers. Any other
        # return type is stored unchanged.
        if isinstance(row, torch.Tensor):
            row = row.detach().cpu().numpy()
        selfies_latents.append(row)
        encoded_rows.append((i, selfie))

df = pd.DataFrame(selfies_latents)
df.to_csv("selfies_latents.csv", index=False)

# Failed molecules are skipped, so row k of selfies_latents.csv is not
# necessarily row k of output_selfies.csv. This companion file records which
# source row (and SELFIES string) each latent row came from; the layout of
# selfies_latents.csv itself is unchanged.
encoded_df = pd.DataFrame(encoded_rows, columns=['index', 'selfies'])
encoded_df.to_csv("selfies_latents_index.csv", index=False)

if failed_selfies:
    print(f"\nTotal failed: {len(failed_selfies)}")
    failed_df = pd.DataFrame(failed_selfies, columns=['index', 'selfies', 'error'])
    failed_df.to_csv("failed_selfies.csv", index=False)

Processing SELFIES:  11%|█         | 141304/1273079 [22:07<4:19:28, 72.70it/s] 

Failed SELFIES at index 141295: [F][P-1][Branch1][C][F][Branch1][C][F][Branch1][C]...
Error: '[P-1]'


Processing SELFIES:  60%|█████▉    | 758786/1273079 [2:07:06<1:10:49, 121.01it/s]