# GLOW Model for Generating Molecular SMILES Strings

This notebook implements a GLOW (Generative Latent Optimization with Wasserstein GANs) model to generate new, valid molecular structures in SMILES format based on a training dataset.

## 1. Import Dependencies

In [39]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

## 2. Define Hyperparameters

In [54]:
# Hyperparameters
LATENT_DIM = 128
HIDDEN_DIM = 512
NUM_FLOWS = 8
BATCH_SIZE = 10
LEARNING_RATE = 0.0001
NUM_EPOCHS = 50

## 3. Dataset Processing

Build vocabulary for SMILES strings and create a PyTorch dataset class.

In [5]:
# Define character vocabulary for SMILES strings
def build_vocabulary(smiles_list):
    chars = set()
    for smiles in smiles_list:
        chars.update(list(smiles))
    return {c: i for i, c in enumerate(sorted(chars))}, {i: c for i, c in enumerate(sorted(chars))}

In [6]:
# Dataset class for SMILES strings
class SMILESDataset(Dataset):
    def __init__(self, smiles_list, max_length=150, char_to_idx=None):
        self.smiles_list = smiles_list
        self.max_length = max_length
        
        if char_to_idx is None:
            self.char_to_idx, self.idx_to_char = build_vocabulary(smiles_list)
        else:
            self.char_to_idx = char_to_idx
            self.idx_to_char = {v: k for k, v in char_to_idx.items()}
        
        self.vocab_size = len(self.char_to_idx)
        
    def __len__(self):
        return len(self.smiles_list)
    
    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        # One-hot encode the SMILES string
        encoded = torch.zeros(self.max_length, self.vocab_size)
        for i, char in enumerate(smiles[:self.max_length]):
            encoded[i, self.char_to_idx[char]] = 1.0
        
        # Pad with zeros
        if len(smiles) < self.max_length:
            encoded[len(smiles):, :] = 0.0
            
        return encoded

## 4. Implementation of the GLOW Model Components

### 4.1 Activation Normalization Layer

In [7]:
# Define ActNorm layer (used in GLOW)
class ActNorm(nn.Module):
    def __init__(self, channels):
        super(ActNorm, self).__init__()
        self.loc = nn.Parameter(torch.zeros(1, channels, 1))
        self.scale = nn.Parameter(torch.ones(1, channels, 1))
        self.initialized = False

    def initialize(self, x):
        with torch.no_grad():
            flatten = x.permute(1, 0, 2).contiguous().view(x.shape[1], -1)
            mean = flatten.mean(1).view(1, -1, 1)
            std = flatten.std(1).view(1, -1, 1)
            
            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))
            
    def forward(self, x, ldj=None, reverse=False):
        if not self.initialized:
            self.initialize(x)
            self.initialized = True
            
        if reverse:
            return (x - self.loc) / self.scale, ldj
        else:
            z = self.scale * x + self.loc
            if ldj is not None:
                ldj = ldj + torch.sum(torch.log(torch.abs(self.scale)))
            return z, ldj
    
    def reverse(self, z, ldj=None):
        return self.forward(z, ldj, reverse=True)

### 4.2 Invertible 1x1 Convolution

In [8]:
# Invertible 1x1 Convolution
class Invertible1x1Conv(nn.Module):
    def __init__(self, channels):
        super(Invertible1x1Conv, self).__init__()
        self.channels = channels
        
        # Initialize with random rotation matrix
        w_init = torch.qr(torch.randn(channels, channels))[0]
        self.weight = nn.Parameter(w_init)
        
    def forward(self, x, ldj=None, reverse=False):
        batch_size, _, seq_len = x.size()
        
        if reverse:
            weight = torch.inverse(self.weight)
        else:
            weight = self.weight
            
        z = F.conv1d(x, weight.unsqueeze(2))
        
        if ldj is not None:
            ldj = ldj + seq_len * torch.slogdet(weight)[1]
            
        return z, ldj
    
    def reverse(self, z, ldj=None):
        return self.forward(z, ldj, reverse=True)

### 4.3 Affine Coupling Layer

In [30]:
# Affine Coupling Layer
class AffineCoupling(nn.Module):
    def __init__(self, channels, hidden_dim):
        super(AffineCoupling, self).__init__()
        
        self.net = nn.Sequential(
            nn.Conv1d(channels // 2, hidden_dim, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, channels, 3, padding=1)
        )
        
        # Initialize last layer with zeros for stability
        self.net[-1].weight.data.zero_()
        self.net[-1].bias.data.zero_()
        
        self.channels = channels
        
    def forward(self, x, ldj=None, reverse=False):
        x_a, x_b = x.chunk(2, dim=1)
        
        out = self.net(x_a)
        log_s, t = out.chunk(2, dim=1)
        
        # Constrain scaling
        log_s = torch.tanh(log_s) * 0.5
        s = torch.exp(log_s)
        
        if reverse:
            z_b = (x_b - t) / s
            z = torch.cat([x_a, z_b], dim=1)
            
            if ldj is not None:
                ldj = ldj - torch.sum(log_s)
        else:
            z_b = s * x_b + t
            z = torch.cat([x_a, z_b], dim=1)
            
            if ldj is not None:
                ldj = ldj + torch.sum(log_s)
                
        return z, ldj
    
    def reverse(self, z, ldj=None):
        return self.forward(z, ldj, reverse=True)

### 4.4 Single Flow Step

In [10]:
# Single Flow Step
class FlowStep(nn.Module):
    def __init__(self, channels, hidden_dim):
        super(FlowStep, self).__init__()
        
        self.actnorm = ActNorm(channels)
        self.conv = Invertible1x1Conv(channels)
        self.coupling = AffineCoupling(channels, hidden_dim)
        
    def forward(self, x, ldj=None, reverse=False):
        if reverse:
            x, ldj = self.coupling.reverse(x, ldj)
            x, ldj = self.conv.reverse(x, ldj)
            x, ldj = self.actnorm.reverse(x, ldj)
        else:
            x, ldj = self.actnorm(x, ldj)
            x, ldj = self.conv(x, ldj)
            x, ldj = self.coupling(x, ldj)
            
        return x, ldj
    
    def reverse(self, z, ldj=None):
        return self.forward(z, ldj, reverse=True)

### 4.5 Complete GLOW Model

In [11]:
# Complete GLOW model
class GLOW(nn.Module):
    def __init__(self, in_channels, hidden_dim, num_flows):
        super(GLOW, self).__init__()
        
        self.flows = nn.ModuleList([
            FlowStep(in_channels, hidden_dim) for _ in range(num_flows)
        ])
        
        self.in_channels = in_channels
        
    def forward(self, x, ldj=None):
        # Flatten one-hot encoding to channels
        batch_size, seq_len, vocab_size = x.size()
        x = x.transpose(1, 2).contiguous()  # [B, V, L]
        
        if ldj is None:
            ldj = torch.zeros(batch_size, device=x.device)
            
        for flow in self.flows:
            x, ldj = flow(x, ldj)
            
        return x, ldj
    
    def reverse(self, z, ldj=None):
        batch_size, _, seq_len = z.size()
        
        if ldj is None:
            ldj = torch.zeros(batch_size, device=z.device)
            
        for flow in reversed(self.flows):
            z, ldj = flow.reverse(z, ldj)
            
        # Convert back to one-hot format
        z = z.transpose(1, 2).contiguous()  # [B, L, V]
        
        return z, ldj
    
    def sample(self, num_samples, seq_len, device='cpu'):
        z = torch.randn(num_samples, self.in_channels, seq_len, device=device)
        samples, _ = self.reverse(z)
        return samples

## 5. Training and Utility Functions

In [12]:
# Training function
def train_glow_model(model, dataloader, optimizer, device, epoch):
    model.train()
    total_loss = 0
    
    for batch_idx, batch in enumerate(dataloader):
        batch = batch.to(device)
        optimizer.zero_grad()
        
        # Forward pass
        z, ldj = model(batch)
        
        # Compute loss (negative log-likelihood)
        loss = 0.5 * torch.sum(z**2) / len(batch) - torch.mean(ldj)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        if batch_idx % 10 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
    
    return total_loss / len(dataloader)

In [13]:
# Sampling and converting back to SMILES
def sample_and_decode(model, idx_to_char, num_samples, max_length, device='cpu'):
    model.eval()
    samples = model.sample(num_samples, max_length, device)
    
    # Convert to one-hot indices
    samples = samples.argmax(dim=2)
    
    # Convert to SMILES strings
    smiles_list = []
    for sample in samples:
        smiles = ''.join([idx_to_char[idx.item()] for idx in sample if idx.item() in idx_to_char])
        smiles_list.append(smiles)
    
    return smiles_list

In [14]:
# Check validity of generated SMILES
def check_validity(smiles_list):
    valid_smiles = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            valid_smiles.append(smiles)
    
    return valid_smiles, len(valid_smiles) / len(smiles_list) if len(smiles_list) > 0 else 0

## 6. Main Execution Function

In [31]:
# Main function to run the training and sampling
def main(smiles_data, num_epochs=NUM_EPOCHS):
    # Prepare the dataset
    smiles_dataset = SMILESDataset(smiles_data)
    train_data, val_data = train_test_split(smiles_dataset, test_size=0.1, random_state=42)
    
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)
    
    # Initialize model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GLOW(smiles_dataset.vocab_size, HIDDEN_DIM, NUM_FLOWS).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    # Training loop
    for epoch in range(1, num_epochs + 1):
        train_loss = train_glow_model(model, train_loader, optimizer, device, epoch)
        print(f'Epoch {epoch}, Average loss: {train_loss:.4f}')
        
        # Generate samples every 10 epochs
        if epoch % 10 == 0 or epoch == num_epochs:
            with torch.no_grad():
                samples = sample_and_decode(model, smiles_dataset.idx_to_char, 10, smiles_dataset.max_length, device)
                valid_samples, validity_rate = check_validity(samples)
                
                print(f"Generated samples (validity rate: {validity_rate:.2%}):")
                for i, s in enumerate(samples[:5]):
                    print(f"{i+1}. {s}")
                    
                print(f"Valid samples:")
                for i, s in enumerate(valid_samples[:5]):
                    print(f"{i+1}. {s}")
    
    # Save the trained model
    torch.save({
        'model_state_dict': model.state_dict(),
        'char_to_idx': smiles_dataset.char_to_idx,
        'idx_to_char': smiles_dataset.idx_to_char,
        'max_length': smiles_dataset.max_length
    }, 'glow_smiles_model.pt')
    
    return model, smiles_dataset.char_to_idx, smiles_dataset.idx_to_char

## 7. Example Usage with Sample Data

In [32]:
# Example usage with sample data
# Use this section as a template for your own dataset
df = pd.read_csv('dataset/train.txt')
df = df.drop(columns=["SPLIT"])
print(df.shape)
df.head()

(1584663, 1)


Unnamed: 0,SMILES
0,CCCS(=O)c1ccc2[nH]c(=NC(=O)OC)[nH]c2c1
1,CC(C)(C)C(=O)C(Oc1ccc(Cl)cc1)n1ccnc1
2,Cc1c(Cl)cccc1Nc1ncccc1C(=O)OCC(O)CO
3,Cn1cnc2c1c(=O)n(CC(O)CO)c(=O)n2C
4,CC1Oc2ccc(Cl)cc2N(CC(O)CO)C1=O


## 8. Run Model Training

Uncomment and run the following cell to train the model.

In [None]:
small_df = df.head(990)
print(f"Using {len(small_df)} samples for training")


# Train the model with the smaller dataset
model, char_to_idx, idx_to_char = main(small_df['SMILES'].tolist())

Using 990 samples for training
Epoch: 1, Batch: 0, Loss: 1720.0635
Epoch: 1, Batch: 10, Loss: -676.7039
Epoch: 1, Batch: 20, Loss: 2116061757440.0000
Epoch: 1, Batch: 30, Loss: -2665.0054
Epoch: 1, Batch: 40, Loss: -5578.4004
Epoch: 1, Batch: 50, Loss: 375670800384.0000
Epoch: 1, Batch: 60, Loss: -9393.4727
Epoch: 1, Batch: 70, Loss: -10748.1328
Epoch: 1, Batch: 80, Loss: -12065.9932
Epoch 1, Average loss: 874248795762.2327
Epoch: 2, Batch: 0, Loss: 302141374464.0000
Epoch: 2, Batch: 10, Loss: -14882.5625
Epoch: 2, Batch: 20, Loss: -15272.7344
Epoch: 2, Batch: 30, Loss: 165434884096.0000
Epoch: 2, Batch: 40, Loss: -19446.9023
Epoch: 2, Batch: 50, Loss: -20000.6172
Epoch: 2, Batch: 60, Loss: 107590041600.0000
Epoch: 2, Batch: 70, Loss: -20778.1543
Epoch: 2, Batch: 80, Loss: -22818.6426
Epoch 2, Average loss: 66312142860.4783
Epoch: 3, Batch: 0, Loss: -24074.9648
Epoch: 3, Batch: 10, Loss: -24758.0059
Epoch: 3, Batch: 20, Loss: -24831.8535
Epoch: 3, Batch: 30, Loss: -25292.5234
Epoch: 3,

[19:06:35] SMILES Parse Error: syntax error while parsing: ()(2(((c1((cCcn(c1Cn1(c(1(2((((1)cccnc(1(Nc1NCc(c1(((cc1)((C2cn(ncn)11n1(c(cCCcCnccn((Nnccc1n2Cc(ncnc(2cc(cCn=)cCCN(n1CC((cn1c=)1)cO(n((c(cc12c(c=(1cc
[19:06:35] SMILES Parse Error: Failed parsing SMILES '()(2(((c1((cCcn(c1Cn1(c(1(2((((1)cccnc(1(Nc1NCc(c1(((cc1)((C2cn(ncn)11n1(c(cCCcCnccn((Nnccc1n2Cc(ncnc(2cc(cCn=)cCCN(n1CC((cn1c=)1)cO(n((c(cc12c(c=(1cc' for input: '()(2(((c1((cCcn(c1Cn1(c(1(2((((1)cccnc(1(Nc1NCc(c1(((cc1)((C2cn(ncn)11n1(c(cCCcCnccn((Nnccc1n2Cc(ncnc(2cc(cCn=)cCCN(n1CC((cn1c=)1)cO(n((c(cc12c(c=(1cc'
[19:06:35] SMILES Parse Error: syntax error while parsing: C(O(Cccncc1n2()((ccc(Cncncc((1((cnc(Ncc11cCNc1nccccn1ncc((C(((n)(n(OnC1n)Ccn(ccc2nc)cn2n1cn(cC2=(cc(2ccc((Cc(nN(((1((Cc(1(N(cn(Oc(C2n1CcCOc=(O(1=ccc(c
[19:06:35] SMILES Parse Error: Failed parsing SMILES 'C(O(Cccncc1n2()((ccc(Cncncc((1((cnc(Ncc11cCNc1nccccn1ncc((C(((n)(n(OnC1n)Ccn(ccc2nc)cn2n1cn(cC2=(cc(2ccc((Cc(nN(((1((Cc(1(N(cn(Oc(C2n1CcCOc=(O(1=ccc(c' for 

Epoch: 11, Batch: 0, Loss: -44123.0469
Epoch: 11, Batch: 10, Loss: -42796.4492
Epoch: 11, Batch: 20, Loss: -43946.2617
Epoch: 11, Batch: 30, Loss: -43768.4922
Epoch: 11, Batch: 40, Loss: 238570831872.0000
Epoch: 11, Batch: 50, Loss: -44875.3555
Epoch: 11, Batch: 60, Loss: -43483.9922
Epoch: 11, Batch: 70, Loss: -42797.2500
Epoch: 11, Batch: 80, Loss: -44322.8281
Epoch 11, Average loss: 10287345576.3877
Epoch: 12, Batch: 0, Loss: -43955.0898
Epoch: 12, Batch: 10, Loss: -44658.6836
Epoch: 12, Batch: 20, Loss: -46277.8711
Epoch: 12, Batch: 30, Loss: -44650.1133
Epoch: 12, Batch: 40, Loss: -45182.5195
Epoch: 12, Batch: 50, Loss: -45049.4062
Epoch: 12, Batch: 60, Loss: -45804.4180
Epoch: 12, Batch: 70, Loss: 2855435776.0000
Epoch: 12, Batch: 80, Loss: -45531.1602
Epoch 12, Average loss: 2010897854.1387
Epoch: 13, Batch: 0, Loss: -44901.4297
Epoch: 13, Batch: 10, Loss: -41162.4102
Epoch: 13, Batch: 20, Loss: -41364.7109
Epoch: 13, Batch: 30, Loss: 50065612800.0000
Epoch: 13, Batch: 40, Loss:

[19:09:07] SMILES Parse Error: syntax error while parsing: cCccnC1cccCccc=c(c(c((ccccCc(OcCOO1ccCcO=Ccccc)c=cC1cccOc(=(Cccc)cc(c(3cccc(C31ccCccc3Cccccccc(cCC(ccCc1CcCcC=cccccOccccc1cCCO)ccCcCc1OcCCc2cc=cc1cccc
[19:09:07] SMILES Parse Error: Failed parsing SMILES 'cCccnC1cccCccc=c(c(c((ccccCc(OcCOO1ccCcO=Ccccc)c=cC1cccOc(=(Cccc)cc(c(3cccc(C31ccCccc3Cccccccc(cCC(ccCc1CcCcC=cccccOccccc1cCCO)ccCcCc1OcCCc2cc=cc1cccc' for input: 'cCccnC1cccCccc=c(c(c((ccccCc(OcCOO1ccCcO=Ccccc)c=cC1cccOc(=(Cccc)cc(c(3cccc(C31ccCccc3Cccccccc(cCC(ccCc1CcCcC=cccccOccccc1cCCO)ccCcCc1OcCCc2cc=cc1cccc'
[19:09:07] SMILES Parse Error: syntax error while parsing: cc=C2cccOccccc(c(ccccccCnc((cc3OCccc=c2[CCCncc2ccccccccCcnOcc2n(cccccC1(cO()c=COc(==Ccc(c=(Cc1cCcccccc1c1c2cccccCc)C2ccccccccc=cOOcccCc)c)cC(=Ccccc(Cc
[19:09:07] SMILES Parse Error: Failed parsing SMILES 'cc=C2cccOccccc(c(ccccccCnc((cc3OCccc=c2[CCCncc2ccccccccCcnOcc2n(cccccC1(cO()c=COc(==Ccc(c=(Cc1cCcccccc1c1c2cccccCc)C2ccccccccc=cOOcccCc)c)cC(=Ccccc(Cc' for 

Epoch: 21, Batch: 0, Loss: -41914.1719
Epoch: 21, Batch: 10, Loss: -41375.5000
Epoch: 21, Batch: 20, Loss: 8343726080.0000
Epoch: 21, Batch: 30, Loss: 9511826432.0000
Epoch: 21, Batch: 40, Loss: -44363.7617
Epoch: 21, Batch: 50, Loss: -46230.8320
Epoch: 21, Batch: 60, Loss: 4935693824.0000
Epoch: 21, Batch: 70, Loss: -44953.6211
Epoch: 21, Batch: 80, Loss: -45063.1016
Epoch 21, Average loss: 2788253098.9565
Epoch: 22, Batch: 0, Loss: 24225054720.0000
Epoch: 22, Batch: 10, Loss: 8042054656.0000
Epoch: 22, Batch: 20, Loss: -46184.6328
Epoch: 22, Batch: 30, Loss: -42292.4023
Epoch: 22, Batch: 40, Loss: -45832.8672
Epoch: 22, Batch: 50, Loss: -45952.2734
Epoch: 22, Batch: 60, Loss: -44306.5078
Epoch: 22, Batch: 70, Loss: 4020610304.0000
Epoch: 22, Batch: 80, Loss: -43491.1367
Epoch 22, Average loss: 2854718845.6847
Epoch: 23, Batch: 0, Loss: -46071.0938
Epoch: 23, Batch: 10, Loss: 3269872128.0000
Epoch: 23, Batch: 20, Loss: -46094.1680
Epoch: 23, Batch: 30, Loss: 2329979392.0000
Epoch: 23,

[19:10:55] SMILES Parse Error: syntax error while parsing: c(C(OO((cCcCCCC((c(CCccnC(c2cc(cc((C(c(CnCCC(C)((2cncCcC((CO(C(CcCCnCc(Cc(CcCCC((C(C((cCCC2(Ccc(c1n(NCCCcc((1CnCc=(CccCc((CnC)CCcC(C(cCcc((ccnCcn(CC(c
[19:10:55] SMILES Parse Error: Failed parsing SMILES 'c(C(OO((cCcCCCC((c(CCccnC(c2cc(cc((C(c(CnCCC(C)((2cncCcC((CO(C(CcCCnCc(Cc(CcCCC((C(C((cCCC2(Ccc(c1n(NCCCcc((1CnCc=(CccCc((CnC)CCcC(C(cCcc((ccnCcn(CC(c' for input: 'c(C(OO((cCcCCCC((c(CCccnC(c2cc(cc((C(c(CnCCC(C)((2cncCcC((CO(C(CcCCnCc(Cc(CcCCC((C(C((cCCC2(Ccc(c1n(NCCCcc((1CnCc=(CccCc((CnC)CCcC(C(cCcc((ccnCcn(CC(c'
[19:10:55] SMILES Parse Error: syntax error while parsing: c1CCN(C1((CC2(cCcccCC((1((nC(CCcc(cCC(CcC(((CCCC(c(((C)CCNCncC1c=nCnCCn(cC((CC(ccCncCCCc(2c(CCCNC(((Cc(c(nnccC(CCcC(c2(C(((c(SCCc(Cc(((Cc((c(C((cc(((=
[19:10:55] SMILES Parse Error: Failed parsing SMILES 'c1CCN(C1((CC2(cCcccCC((1((nC(CCcc(cCC(CcC(((CCCC(c(((C)CCNCncC1c=nCnCCn(cC((CC(ccCncCCCc(2c(CCCNC(((Cc(c(nnccC(CCcC(c2(C(((c(SCCc(Cc(((Cc((c(C((cc(((=' for 

Epoch: 31, Batch: 10, Loss: -48182.2812
Epoch: 31, Batch: 20, Loss: -47804.9258
Epoch: 31, Batch: 30, Loss: -48654.7344
Epoch: 31, Batch: 40, Loss: -47965.2969
Epoch: 31, Batch: 50, Loss: -48804.1094
Epoch: 31, Batch: 60, Loss: 576044672.0000
Epoch: 31, Batch: 70, Loss: -46981.9883
Epoch: 31, Batch: 80, Loss: -48402.1758
Epoch 31, Average loss: 227672875.8388
Epoch: 32, Batch: 0, Loss: -49136.6094
Epoch: 32, Batch: 10, Loss: -48523.3750
Epoch: 32, Batch: 20, Loss: -49426.7031
Epoch: 32, Batch: 30, Loss: -48724.5352
Epoch: 32, Batch: 40, Loss: -50461.2305
Epoch: 32, Batch: 50, Loss: -47881.8789
Epoch: 32, Batch: 60, Loss: -49412.3750
Epoch: 32, Batch: 70, Loss: -48070.7422
Epoch: 32, Batch: 80, Loss: -49218.9766
Epoch 32, Average loss: 1945321422.2831
Epoch: 33, Batch: 0, Loss: -49929.0469
Epoch: 33, Batch: 10, Loss: -49122.9805
Epoch: 33, Batch: 20, Loss: 33070110720.0000
Epoch: 33, Batch: 30, Loss: -48400.4805
Epoch: 33, Batch: 40, Loss: 6078891008.0000
Epoch: 33, Batch: 50, Loss: -48

[19:13:03] SMILES Parse Error: syntax error while parsing: 2CCC=c2CCCOcOcC2(CcCcCC(Cn)C2cNc=C(CCOCO=2ccn=C21CC())COCccCCC(c((CcCC2n=ccCCC(CCcCC(CCOn2=CC11n)cccCC1CNCCnc)=C(=COCCC((c=2CCCO2C=CCC=C(CCnC1222C((=)
[19:13:03] SMILES Parse Error: Failed parsing SMILES '2CCC=c2CCCOcOcC2(CcCcCC(Cn)C2cNc=C(CCOCO=2ccn=C21CC())COCccCCC(c((CcCC2n=ccCCC(CCcCC(CCOn2=CC11n)cccCC1CNCCnc)=C(=COCCC((c=2CCCO2C=CCC=C(CCnC1222C((=)' for input: '2CCC=c2CCCOcOcC2(CcCcCC(Cn)C2cNc=C(CCOCO=2ccn=C21CC())COCccCCC(c((CcCC2n=ccCCC(CCcCC(CCOn2=CC11n)cccCC1CNCCnc)=C(=COCCC((c=2CCCO2C=CCC=C(CCnC1222C((=)'
[19:13:03] SMILES Parse Error: extra close parentheses while parsing: OcCCC2C)(1=2Cc=cn=C()O1CC(cCCCc2cC1(C1CO=(Cnccc(CCCCcCCC=c(=ncNC1nCCCCC=CC1)=)CCcc2(cCc(=CCOO2CccCCnC1cOC(OCCOCCO21c(c=CCC((CcccCCOcCC22C)(Cc)=(ccc(cC
[19:13:03] SMILES Parse Error: Failed parsing SMILES 'OcCCC2C)(1=2Cc=cn=C()O1CC(cCCCc2cC1(C1CO=(Cnccc(CCCCcCCC=c(=ncNC1nCCCCC=CC1)=)CCcc2(cCc(=CCOO2CccCCnC1cOC(OCCOCCO21c(c=CCC((CcccCCOcCC22C)(Cc)=(c

Epoch: 41, Batch: 0, Loss: -50804.6914
Epoch: 41, Batch: 10, Loss: 329183776.0000
Epoch: 41, Batch: 20, Loss: -50571.0312
Epoch: 41, Batch: 30, Loss: -50741.7539
Epoch: 41, Batch: 40, Loss: -50526.8633
Epoch: 41, Batch: 50, Loss: -49902.5781
Epoch: 41, Batch: 60, Loss: -50769.7930
Epoch: 41, Batch: 70, Loss: -52028.5156
Epoch: 41, Batch: 80, Loss: -49954.1406
Epoch 41, Average loss: 97702651.7593
Epoch: 42, Batch: 0, Loss: -50720.9336
Epoch: 42, Batch: 10, Loss: -51146.8867
Epoch: 42, Batch: 20, Loss: 206853184.0000
Epoch: 42, Batch: 30, Loss: -49789.8203
Epoch: 42, Batch: 40, Loss: -51241.8047
Epoch: 42, Batch: 50, Loss: -51997.1719
Epoch: 42, Batch: 60, Loss: 240928672.0000
Epoch: 42, Batch: 70, Loss: 223446880.0000
Epoch: 42, Batch: 80, Loss: -52251.3711
Epoch 42, Average loss: 93071762.7725
Epoch: 43, Batch: 0, Loss: -49905.4336
Epoch: 43, Batch: 10, Loss: 28308144128.0000
Epoch: 43, Batch: 20, Loss: -47807.1250
Epoch: 43, Batch: 30, Loss: 1701910272.0000
Epoch: 43, Batch: 40, Loss

## 9. Generate New SMILES Strings

Once the model is trained, uncomment and run this cell to generate new molecules.

In [43]:
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
with torch.no_grad():
	generated_smiles = sample_and_decode(model, idx_to_char, 20, 150, device)
	valid_samples, validity_rate = check_validity(generated_smiles)

print(f"\nFinal generated samples (validity rate: {validity_rate:.2%}):")
for i, s in enumerate(valid_samples[:10]):
	print(f"{i+1}. {s}")


Final generated samples (validity rate: 0.00%):


[18:48:36] SMILES Parse Error: syntax error while parsing: Oc(ccc-)NOOccc2Ccc)OO)Oc1ccc2n)cccC)ccccc)c1))Cc1)ccccOncc2Cc2CccO)C2)COC)OO)Oc1c)ccCC2cCcOCccc)Oc)c2cc3)Occc)2COCccc))2CCc2(2N2cccc2C)OO)c)ccO)cCccOc
[18:48:36] SMILES Parse Error: Failed parsing SMILES 'Oc(ccc-)NOOccc2Ccc)OO)Oc1ccc2n)cccC)ccccc)c1))Cc1)ccccOncc2Cc2CccO)C2)COC)OO)Oc1c)ccCC2cCcOCccc)Oc)c2cc3)Occc)2COCccc))2CCc2(2N2cccc2C)OO)c)ccO)cCccOc' for input: 'Oc(ccc-)NOOccc2Ccc)OO)Oc1ccc2n)cccC)ccccc)c1))Cc1)ccccOncc2Cc2CccO)C2)COC)OO)Oc1c)ccCC2cCcOCccc)Oc)c2cc3)Occc)2COCccc))2CCc2(2N2cccc2C)OO)c)ccO)cCccOc'
[18:48:36] SMILES Parse Error: extra close parentheses while parsing: cc))cOCcccc)cOc)CNcCcOOc)ON)OO2cc)ccccc)cCcOcc)c=()cOcc)ncOOOccC1cC=nccCcOc2ccOcCccO()OcO)Oc2cON2O)))cc2cON)Oc22ccccO)CcccCcc)ccOCcccO2c)cN-cOO)2c2COC
[18:48:36] SMILES Parse Error: Failed parsing SMILES 'cc))cOCcccc)cOc)CNcCcOOc)ON)OO2cc)ccccc)cCcOcc)c=()cOcc)ncOOOccC1cC=nccCcOc2ccOcCccO()OcO)Oc2cON2O)))cc2cON)Oc22ccccO)CcccCcc)ccOCcccO2c)cN-cOO)2

## 10. Using Your Own Dataset

You can load your own SMILES data and train the model on it.

In [None]:
# Load your own data
# Example 1: Loading from a CSV file with a 'smiles' column
# your_df = pd.read_csv('your_smiles_data.csv')
# smiles_list = your_df['smiles'].tolist()

# Example 2: Directly define from your data string
# long_smiles = "CCCS(=O)c1ccc2[nH]c(=NC(=O)OC)[nH]c2c11CC(C)(C)C(=O)C(Oc1ccc(Cl)cc1)n1ccnc12Cc1c(Cl)cccc1Nc1ncccc1C(=O)OCC(O)CO3Cn1cnc2c1c(=O)n(CC(O)CO)c(=O)n2C4CC1Oc2ccc(Cl)cc2N(CC(O)CO)C1=O"
# # Parse into individual compounds if needed
# smiles_list = [long_smiles]  # or split it if it contains multiple compounds

# Run the model
# model, char_to_idx, idx_to_char = main(smiles_list, num_epochs=50)

## 11. Visualize Generated Molecules

This optional section shows how to visualize generated molecules.

In [46]:
# Uncomment to install RDKit visualization tools
# !pip install rdkit-pypi

from rdkit.Chem import Draw
import matplotlib.pyplot as plt
from IPython.display import display

def visualize_molecules(smiles_list, ncols=3):
    mols = [Chem.MolFromSmiles(s) for s in smiles_list if Chem.MolFromSmiles(s) is not None]
    if len(mols) == 0:
        print("No valid molecules to visualize")
        return
        
    # Calculate number of rows needed
    nrows = (len(mols) + ncols - 1) // ncols
    
    # Create subplot
    fig = plt.figure(figsize=(ncols*3, nrows*3))
    
    for i, mol in enumerate(mols):
        ax = fig.add_subplot(nrows, ncols, i+1)
        img = Draw.MolToImage(mol, size=(300, 300))
        ax.imshow(img)
        ax.set_title(f"Molecule {i+1}")
        ax.axis('off')
    
    plt.tight_layout()
    plt.show()

# Example usage:
visualize_molecules(valid_samples[:9])

No valid molecules to visualize


## 12. Save and Load Models