In [40]:
import moses
import matplotlib as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [41]:
df_train = pd.read_csv("dataset/train.txt", sep=',')
df_train.head()

Unnamed: 0,SMILES,SPLIT
0,CCCS(=O)c1ccc2[nH]c(=NC(=O)OC)[nH]c2c1,train
1,CC(C)(C)C(=O)C(Oc1ccc(Cl)cc1)n1ccnc1,train
2,Cc1c(Cl)cccc1Nc1ncccc1C(=O)OCC(O)CO,train
3,Cn1cnc2c1c(=O)n(CC(O)CO)c(=O)n2C,train
4,CC1Oc2ccc(Cl)cc2N(CC(O)CO)C1=O,train


In [42]:
# VALID_CHARS = list("@=#$()%1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ[]\\+-/.:")
# char_to_idx = {c: i for i, c in enumerate(VALID_CHARS)}
# idx_to_char = {i: c for c, i in char_to_idx.items()}

In [43]:
# Dataset class for SMILES strings
class SMILESDataset(Dataset):
    def __init__(self, smiles_list, max_length=150, char_to_idx=None):
        self.smiles_list = smiles_list
        self.max_length = max_length

        if char_to_idx is None:
            raise ValueError("Please provide a fixed character-to-index mapping")
            # self.char_to_idx, self.idx_to_char = build_vocabulary(smiles_list)
        else:
            self.char_to_idx = char_to_idx
            self.idx_to_char = {v: k for k, v in char_to_idx.items()}

        self.vocab_size = len(self.char_to_idx)

        original_count = len(smiles_list)
        filtered = []
        invalid_count = 0

        for s in smiles_list:
            s = s.strip()
            if all(c in self.char_to_idx for c in s):
                filtered.append(s)
            else:
                invalid_count += 1
        print(f"Total: {original_count}, Valid: {len(filtered)}, Invalid: {invalid_count}")
        self.smiles_list = filtered

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        # One-hot encode the SMILES string
        encoded = torch.zeros(self.max_length, self.vocab_size)
        for i, char in enumerate(smiles[:self.max_length]):
            encoded[i, self.char_to_idx[char]] = 1.0

        # # Pad with zeros
        # if len(smiles) < self.max_length:
        #     encoded[len(smiles):, :] = 0.0

        return encoded.view(-1) #Flatten into 1D tensor

In [44]:
class VAE(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.dropout = nn.Dropout(p=0.2)

        # Encoder
        self.fc1 = nn.Linear(input_dim, 1024)  # Increased input layer size
        self.fc2 = nn.Linear(1024, 512)  # Added an extra hidden layer
        self.fc3 = nn.Linear(512, 256)  # Added another extra hidden layer for more complexity
        self.fc21 = nn.Linear(256, latent_dim)  # Mean of latent distribution
        self.fc22 = nn.Linear(256, latent_dim)  # Log variance of latent distribution

        # Decoder
        self.fc4 = nn.Linear(latent_dim, 256)  # Mirroring Encoder structure
        self.fc5 = nn.Linear(256, 512) # Mirroring Encoder structure
        self.fc6 = nn.Linear(512, 1024) # Mirroring Encoder structure
        self.fc7 = nn.Linear(1024, input_dim)  # Output layer

    # Note for later, changed the architecture to add dropout
    def encode(self, x):
        h1 = self.dropout(F.relu(self.fc1(x)))
        h2 = self.dropout(F.relu(self.fc2(h1)))
        h3 = self.dropout(F.relu(self.fc3(h2)))
        return self.fc21(h3), self.fc22(h3)

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h4 = self.dropout(F.relu(self.fc4(z)))
        h5 = self.dropout(F.relu(self.fc5(h4)))
        h6 = self.dropout(F.relu(self.fc6(h5)))
        return torch.sigmoid(self.fc7(h6))


    def forward(self, x):
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar


In [45]:
def vae_loss(recon_x, x, mu, logvar, beta=0.01):
    BCE = F.binary_cross_entropy_with_logits(recon_x, x.view(-1, recon_x.size(1)), reduction='mean')
    # mean seemed to do better
    KL = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KL

    #return BCE + beta * KL
# add KL annealing factor to hlelp in slowing learning and avoid KL divergence dominating loss early?
# or binary cross entropy with logits to handle the loss
# binary_cross_entropy vs BCE with logits

In [46]:
import csv

def load_smiles_from_csv(path, split_type='train'):
    smiles = []
    with open(path, 'r') as f:
        reader = csv.DictReader(f)
        for row in reader:
            if row['SPLIT'].strip().lower() == split_type:
                smiles.append(row['SMILES'].strip())
    return smiles

In [47]:
def extract_unique_chars(smiles_list):
    unique_chars = set()
    for smiles in smiles_list:
        unique_chars.update(smiles.strip())
    return sorted(unique_chars)

In [52]:
def clean_smiles(smiles):
    # Remove unwanted metadata like ",train" or ",SPLIT"
    return smiles.split(',')[0].strip()

In [49]:
def decode_smiles(one_hot_tensor, idx_to_char):
    smiles = ''
    one_hot_tensor = one_hot_tensor.view(-1, len(idx_to_char))  # unflatten
    for row in one_hot_tensor:
        idx = row.argmax().item()
        smiles += idx_to_char[idx]
    return smiles.strip()

In [None]:
# normalize the data?

# Load SMILES strings
with open('dataset/train.txt', 'r') as f:
    smiles_train = [line.strip() for line in f]

with open('dataset/test.txt', 'r') as f:
    smiles_test = [line.strip() for line in f]

# Apply cleaning to your SMILES
smiles_train = [clean_smiles(smiles) for smiles in smiles_train]
smiles_test = [clean_smiles(smiles) for smiles in smiles_test]


# smiles_train = load_smiles_from_csv('dataset/train.txt', split_type='train')
# smiles_test = load_smiles_from_csv('dataset/test.txt', split_type='test')  # if test rows are in same file

# print(f"Raw SMILES loaded: train={len(smiles_train)}, test={len(smiles_test)}")
all_smiles = smiles_train + smiles_test
unique_chars = extract_unique_chars(all_smiles)

print(f"Total unique characters: {len(unique_chars)}")
print("Unique characters in dataset:")
print(unique_chars)

# Use extracted unique characters to rebuild vocabulary
VALID_CHARS = unique_chars
char_to_idx = {c: i for i, c in enumerate(VALID_CHARS)}
idx_to_char = {i: c for c, i in char_to_idx.items()}

# Create datasets
train_dataset = SMILESDataset(smiles_train, max_length=50, char_to_idx=char_to_idx)
test_dataset = SMILESDataset(smiles_test, max_length=50, char_to_idx=char_to_idx)
print("Training Vocabulary Size:", train_dataset.vocab_size)
print("Test Vocabulary Size:", test_dataset.vocab_size) # Should be the same


print(f"# Train SMILES after filtering: {len(train_dataset)}")
print(f"# Test SMILES after filtering: {len(test_dataset)}")
# train_dataset = SMILESDataset(smiles_train)
# test_dataset = SMILESDataset(smiles_test, char_to_idx=train_dataset.char_to_idx)  # Share vocabulary

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)  # No need to shuffle test data

print(f"Number of batches in train_loader: {len(train_loader)}")
print(f"Number of batches in test_loader: {len(test_loader)}")

Total unique characters: 30
Unique characters in dataset:
['#', '(', ')', '-', '1', '2', '3', '4', '5', '6', '=', 'B', 'C', 'E', 'F', 'H', 'I', 'L', 'M', 'N', 'O', 'S', '[', ']', 'c', 'l', 'n', 'o', 'r', 's']
Total: 1584664, Valid: 1584664, Invalid: 0
Total: 176075, Valid: 176075, Invalid: 0
Training Vocabulary Size: 30
Test Vocabulary Size: 30
# Train SMILES after filtering: 1584664
# Test SMILES after filtering: 176075
Number of batches in train_loader: 198083
Number of batches in test_loader: 22010


In [58]:
# Check a batch of data
for i, data in enumerate(train_loader):
    if i == 0:  # Just visualize the first batch
        print(data)
        break

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])


In [None]:
# import json

# # Save vocab to JSON
# vocab_path = 'char_vocab.json'
# with open(vocab_path, 'w') as f:
#     json.dump({
#         'char_to_idx': char_to_idx,
#         'idx_to_char': idx_to_char
#     }, f)

# print(f"Vocabulary saved to {vocab_path}")

# call with code below

# with open('char_vocab.json', 'r') as f:
#     vocab = json.load(f)
#     char_to_idx = vocab['char_to_idx']
#     idx_to_char = {int(k): v for k, v in vocab['idx_to_char'].items()}


In [54]:
# Visualize 3 samples
print("\nSample SMILES visualizations:")
for i in range(3):
    encoded = train_dataset[i]
    original = train_dataset.smiles_list[i]
    decoded = decode_smiles(encoded, train_dataset.idx_to_char)

    print(f"\nSample {i+1}")
    print(f"Original : {original}")
    print(f"Decoded  : {decoded}")
    print(f"Shape    : {encoded.shape}")


Sample SMILES visualizations:

Sample 1
Original : SMILES
Decoded  : SMILES############################################
Shape    : torch.Size([1500])

Sample 2
Original : CCCS(=O)c1ccc2[nH]c(=NC(=O)OC)[nH]c2c1
Decoded  : CCCS(=O)c1ccc2[nH]c(=NC(=O)OC)[nH]c2c1############
Shape    : torch.Size([1500])

Sample 3
Original : CC(C)(C)C(=O)C(Oc1ccc(Cl)cc1)n1ccnc1
Decoded  : CC(C)(C)C(=O)C(Oc1ccc(Cl)cc1)n1ccnc1##############
Shape    : torch.Size([1500])


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the VAE model
input_dim = train_dataset.vocab_size * train_dataset.max_length  # Flatten the input (max_length x vocab_size)
latent_dim = 128

vocab_size = train_dataset.vocab_size
max_length = train_dataset.max_length

print("Vocab size:", train_dataset.vocab_size)
print("max_length:", train_dataset.max_length)
print("Input dim:", input_dim)

vae = VAE(input_dim, latent_dim)
vae.to(device)

# Optimizer
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-5)

# # Learning rate scheduler
# scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# Training and Evaluation loop
epochs = 100
for epoch in range(epochs):
    vae.train()  # Set model to training mode
    train_loss = 0
    for data in train_loader:  # Iterate over training data
        optimizer.zero_grad()
        # Flatten the input here before passing to the model
        data = data.view(-1, input_dim).to(device)
        recon_batch, mu, logvar = vae(data)
        #print(f"Reconstructed output: {recon_batch[:5]}") # testing
        #break
        loss = vae_loss(recon_batch, data, mu, logvar)
        #print(f"batch loss: {loss.item()}") # for testing
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss/len(train_loader)}')

    # Evaluation on test set
    vae.eval()  # Set model to evaluation mode
    test_loss = 0
    with torch.no_grad():  # No need to calculate gradients during evaluation
        for data in test_loader:  # Iterate over test data
            # Flatten the input here as well
            data = data.view(-1, input_dim).to(device)
            recon_batch, mu, logvar = vae(data)
            loss = vae_loss(recon_batch, data, mu, logvar)
            test_loss += loss.item()

    print(f'Epoch [{epoch+1}/{epochs}], Test Loss: {test_loss/len(test_loader)}')

Vocab size: 30
max_length: 50
Input dim: 1500


KeyboardInterrupt: 

In [None]:
# Generate a new molecule from VAE by sampling from the latent space
def generate_smiles(model, latent_dim=64, idx_to_char=None, temperature=1.0):
    z = torch.randn(1, latent_dim).to(model.fc1.weight.device)  # Ensure z is on the same device as the model
    with torch.no_grad():
        generated = model.decode(z)  # Use decode instead of decoder
    # Add postprocessing to convert to SMILES
    # generated_tokens_indices = torch.argmax(generated, dim=-1).cpu().numpy().flatten()
    probs = F.softmax(generated / temperature, dim=-1)

    # Sample the next character from the probability distribution
    generated_tokens_indices = torch.multinomial(probs, 1).cpu().numpy().flatten()
    
    # Print generated tokens and indices for debugging
    print("Generated tokens indices:", generated_tokens_indices)
    print("Generated tokens:", [idx_to_char.get(i, "<UNK>") for i in generated_tokens_indices])
    
    # Iterate through indices to build the SMILES string
    generated_smiles = "".join([idx_to_char.get(i, "") for i in generated_tokens_indices])

    return generated_smiles

In [None]:
# Generate a new molecule from VAE by sampling from the latent space
generated_smiles = generate_smiles(vae, latent_dim, train_dataset.idx_to_char)  # pass idx_to_char

print(f"Generated SMILES: {generated_smiles}")

Generated tokens indices: [2329]
Generated tokens: ['<UNK>']
Generated SMILES: 
