# GLOW Model for Generating Molecular SMILES Strings

This notebook implements a GLOW (Generative Latent Optimization with Wasserstein GANs) model to generate new, valid molecular structures in SMILES format based on a training dataset.

## 1. Import Dependencies

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

## 2. Define Hyperparameters

In [7]:
# Hyperparameters
LATENT_DIM = 128
HIDDEN_DIM = 512
NUM_FLOWS = 8
BATCH_SIZE = 10
LEARNING_RATE = 0.0001
NUM_EPOCHS = 50

## 3. Dataset Processing

Build vocabulary for SMILES strings and create a PyTorch dataset class.

In [8]:
# Define character vocabulary for SMILES strings
def build_vocabulary(smiles_list):
    chars = set()
    for smiles in smiles_list:
        chars.update(list(smiles))
    return {c: i for i, c in enumerate(sorted(chars))}, {i: c for i, c in enumerate(sorted(chars))}

In [9]:
# Dataset class for SMILES strings
class SMILESDataset(Dataset):
    def __init__(self, smiles_list, max_length=150, char_to_idx=None):
        self.smiles_list = smiles_list
        self.max_length = max_length
        
        if char_to_idx is None:
            self.char_to_idx, self.idx_to_char = build_vocabulary(smiles_list)
        else:
            self.char_to_idx = char_to_idx
            self.idx_to_char = {v: k for k, v in char_to_idx.items()}
        
        self.vocab_size = len(self.char_to_idx)
        
    def __len__(self):
        return len(self.smiles_list)
    
    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        # One-hot encode the SMILES string
        encoded = torch.zeros(self.max_length, self.vocab_size)
        for i, char in enumerate(smiles[:self.max_length]):
            encoded[i, self.char_to_idx[char]] = 1.0
        
        # Pad with zeros
        if len(smiles) < self.max_length:
            encoded[len(smiles):, :] = 0.0
            
        return encoded

## 4. Implementation of the GLOW Model Components

### 4.1 Activation Normalization Layer

In [10]:
# Define ActNorm layer (used in GLOW)
class ActNorm(nn.Module):
    def __init__(self, channels):
        super(ActNorm, self).__init__()
        self.loc = nn.Parameter(torch.zeros(1, channels, 1))
        self.scale = nn.Parameter(torch.ones(1, channels, 1))
        self.initialized = False

    def initialize(self, x):
        with torch.no_grad():
            flatten = x.permute(1, 0, 2).contiguous().view(x.shape[1], -1)
            mean = flatten.mean(1).view(1, -1, 1)
            std = flatten.std(1).view(1, -1, 1)
            
            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))
            
    def forward(self, x, ldj=None, reverse=False):
        if not self.initialized:
            self.initialize(x)
            self.initialized = True
            
        if reverse:
            return (x - self.loc) / self.scale, ldj
        else:
            z = self.scale * x + self.loc
            if ldj is not None:
                ldj = ldj + torch.sum(torch.log(torch.abs(self.scale)))
            return z, ldj
    
    def reverse(self, z, ldj=None):
        return self.forward(z, ldj, reverse=True)

### 4.2 Invertible 1x1 Convolution

In [12]:
# Invertible 1x1 Convolution
class Invertible1x1Conv(nn.Module):
    def __init__(self, channels):
        super(Invertible1x1Conv, self).__init__()
        self.channels = channels
        
        # Initialize with random rotation matrix
        w_init = torch.qr(torch.randn(channels, channels))[0]
        self.weight = nn.Parameter(w_init)
        
    def forward(self, x, ldj=None, reverse=False):
        batch_size, _, seq_len = x.size()
        
        if reverse:
            weight = torch.inverse(self.weight)
        else:
            weight = self.weight
            
        z = F.conv1d(x, weight.unsqueeze(2))
        
        if ldj is not None:
            ldj = ldj + seq_len * torch.slogdet(weight)[1]
            
        return z, ldj
    
    def reverse(self, z, ldj=None):
        return self.forward(z, ldj, reverse=True)

### 4.3 Affine Coupling Layer

In [13]:
# Affine Coupling Layer
class AffineCoupling(nn.Module):
    def __init__(self, channels, hidden_dim):
        super(AffineCoupling, self).__init__()
        
        self.net = nn.Sequential(
            nn.Conv1d(channels // 2, hidden_dim, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, channels, 3, padding=1)
        )
        
        # Initialize last layer with zeros for stability
        self.net[-1].weight.data.zero_()
        self.net[-1].bias.data.zero_()
        
        self.channels = channels
        
    def forward(self, x, ldj=None, reverse=False):
        x_a, x_b = x.chunk(2, dim=1)
        
        out = self.net(x_a)
        log_s, t = out.chunk(2, dim=1)
        
        # Constrain scaling
        log_s = torch.tanh(log_s) * 0.5
        s = torch.exp(log_s)
        
        if reverse:
            z_b = (x_b - t) / s
            z = torch.cat([x_a, z_b], dim=1)
            
            if ldj is not None:
                ldj = ldj - torch.sum(log_s)
        else:
            z_b = s * x_b + t
            z = torch.cat([x_a, z_b], dim=1)
            
            if ldj is not None:
                ldj = ldj + torch.sum(log_s)
                
        return z, ldj
    
    def reverse(self, z, ldj=None):
        return self.forward(z, ldj, reverse=True)

### 4.4 Single Flow Step

In [14]:
# Single Flow Step
class FlowStep(nn.Module):
    def __init__(self, channels, hidden_dim):
        super(FlowStep, self).__init__()
        
        self.actnorm = ActNorm(channels)
        self.conv = Invertible1x1Conv(channels)
        self.coupling = AffineCoupling(channels, hidden_dim)
        
    def forward(self, x, ldj=None, reverse=False):
        if reverse:
            x, ldj = self.coupling.reverse(x, ldj)
            x, ldj = self.conv.reverse(x, ldj)
            x, ldj = self.actnorm.reverse(x, ldj)
        else:
            x, ldj = self.actnorm(x, ldj)
            x, ldj = self.conv(x, ldj)
            x, ldj = self.coupling(x, ldj)
            
        return x, ldj
    
    def reverse(self, z, ldj=None):
        return self.forward(z, ldj, reverse=True)

### 4.5 Complete GLOW Model

In [15]:
# Complete GLOW model
class GLOW(nn.Module):
    def __init__(self, in_channels, hidden_dim, num_flows):
        super(GLOW, self).__init__()
        
        self.flows = nn.ModuleList([
            FlowStep(in_channels, hidden_dim) for _ in range(num_flows)
        ])
        
        self.in_channels = in_channels
        
    def forward(self, x, ldj=None):
        # Flatten one-hot encoding to channels
        batch_size, seq_len, vocab_size = x.size()
        x = x.transpose(1, 2).contiguous()  # [B, V, L]
        
        if ldj is None:
            ldj = torch.zeros(batch_size, device=x.device)
            
        for flow in self.flows:
            x, ldj = flow(x, ldj)
            
        return x, ldj
    
    def reverse(self, z, ldj=None):
        batch_size, _, seq_len = z.size()
        
        if ldj is None:
            ldj = torch.zeros(batch_size, device=z.device)
            
        for flow in reversed(self.flows):
            z, ldj = flow.reverse(z, ldj)
            
        # Convert back to one-hot format
        z = z.transpose(1, 2).contiguous()  # [B, L, V]
        
        return z, ldj
    
    def sample(self, num_samples, seq_len, device='cpu'):
        z = torch.randn(num_samples, self.in_channels, seq_len, device=device)
        samples, _ = self.reverse(z)
        return samples

## 5. Training and Utility Functions

In [16]:
# Training function
def train_glow_model(model, dataloader, optimizer, device, epoch):
    model.train()
    total_loss = 0
    
    for batch_idx, batch in enumerate(dataloader):
        batch = batch.to(device)
        optimizer.zero_grad()
        
        # Forward pass
        z, ldj = model(batch)
        
        # Compute loss (negative log-likelihood)
        loss = 0.5 * torch.sum(z**2) / len(batch) - torch.mean(ldj)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        if batch_idx % 10 == 0:
            print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {loss.item():.4f}')
    
    return total_loss / len(dataloader)

In [17]:
# Sampling and converting back to SMILES
def sample_and_decode(model, idx_to_char, num_samples, max_length, device='cpu'):
    model.eval()
    samples = model.sample(num_samples, max_length, device)
    
    # Convert to one-hot indices
    samples = samples.argmax(dim=2)
    
    # Convert to SMILES strings
    smiles_list = []
    for sample in samples:
        smiles = ''.join([idx_to_char[idx.item()] for idx in sample if idx.item() in idx_to_char])
        smiles_list.append(smiles)
    
    return smiles_list

In [18]:
# Check validity of generated SMILES
def check_validity(smiles_list):
    valid_smiles = []
    for smiles in smiles_list:
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            valid_smiles.append(smiles)
    
    return valid_smiles, len(valid_smiles) / len(smiles_list) if len(smiles_list) > 0 else 0

## 6. Main Execution Function

In [19]:
# Main function to run the training and sampling
def main(smiles_data, num_epochs=NUM_EPOCHS):
    # Prepare the dataset
    smiles_dataset = SMILESDataset(smiles_data)
    train_data, val_data = train_test_split(smiles_dataset, test_size=0.1, random_state=42)
    
    train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=BATCH_SIZE)
    
    # Initialize model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = GLOW(smiles_dataset.vocab_size, HIDDEN_DIM, NUM_FLOWS).to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    
    # Training loop
    for epoch in range(1, num_epochs + 1):
        train_loss = train_glow_model(model, train_loader, optimizer, device, epoch)
        print(f'Epoch {epoch}, Average loss: {train_loss:.4f}')
        
        # Generate samples every 10 epochs
        if epoch % 10 == 0 or epoch == num_epochs:
            with torch.no_grad():
                samples = sample_and_decode(model, smiles_dataset.idx_to_char, 10, smiles_dataset.max_length, device)
                valid_samples, validity_rate = check_validity(samples)
                
                print(f"Generated samples (validity rate: {validity_rate:.2%}):")
                for i, s in enumerate(samples[:5]):
                    print(f"{i+1}. {s}")
                    
                print(f"Valid samples:")
                for i, s in enumerate(valid_samples[:5]):
                    print(f"{i+1}. {s}")
    
    # Save the trained model
    torch.save({
        'model_state_dict': model.state_dict(),
        'char_to_idx': smiles_dataset.char_to_idx,
        'idx_to_char': smiles_dataset.idx_to_char,
        'max_length': smiles_dataset.max_length
    }, 'glow_smiles_model.pt')
    
    return model, smiles_dataset.char_to_idx, smiles_dataset.idx_to_char

## 7. Example Usage with Sample Data

In [20]:
# Example usage with sample data
# Use this section as a template for your own dataset
df = pd.read_csv('dataset/train.txt')
df = df.drop(columns=["SPLIT"])
print(df.shape)
df.head()

(1584663, 1)


Unnamed: 0,SMILES
0,CCCS(=O)c1ccc2[nH]c(=NC(=O)OC)[nH]c2c1
1,CC(C)(C)C(=O)C(Oc1ccc(Cl)cc1)n1ccnc1
2,Cc1c(Cl)cccc1Nc1ncccc1C(=O)OCC(O)CO
3,Cn1cnc2c1c(=O)n(CC(O)CO)c(=O)n2C
4,CC1Oc2ccc(Cl)cc2N(CC(O)CO)C1=O


## 8. Run Model Training

Uncomment and run the following cell to train the model.

In [21]:
small_df = df.head(990)
print(f"Using {len(small_df)} samples for training")


# Train the model with the smaller dataset
model, char_to_idx, idx_to_char = main(small_df['SMILES'].tolist())

Using 990 samples for training


The boolean parameter 'some' has been replaced with a string parameter 'mode'.
Q, R = torch.qr(A, some)
should be replaced with
Q, R = torch.linalg.qr(A, 'reduced' if some else 'complete') (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\BatchLinearAlgebra.cpp:2428.)
  w_init = torch.qr(torch.randn(channels, channels))[0]


Epoch: 1, Batch: 0, Loss: 1686.4832
Epoch: 1, Batch: 10, Loss: 6697155297280.0000
Epoch: 1, Batch: 20, Loss: 681.7671
Epoch: 1, Batch: 30, Loss: 2611588366336.0000
Epoch: 1, Batch: 40, Loss: 5671071252480.0000
Epoch: 1, Batch: 50, Loss: 1040258695168.0000
Epoch: 1, Batch: 60, Loss: -7356.2964
Epoch: 1, Batch: 70, Loss: 1904146513920.0000
Epoch: 1, Batch: 80, Loss: 667609858048.0000
Epoch 1, Average loss: 5394155552622.4707
Epoch: 2, Batch: 0, Loss: 1362398019584.0000
Epoch: 2, Batch: 10, Loss: 300844646400.0000
Epoch: 2, Batch: 20, Loss: -14022.0156
Epoch: 2, Batch: 30, Loss: 1402845921280.0000
Epoch: 2, Batch: 40, Loss: -14532.3203
Epoch: 2, Batch: 50, Loss: 498803605504.0000
Epoch: 2, Batch: 60, Loss: 266583048192.0000
Epoch: 2, Batch: 70, Loss: -14860.6943
Epoch: 2, Batch: 80, Loss: 280871698432.0000
Epoch 2, Average loss: 554154430897.6296
Epoch: 3, Batch: 0, Loss: 539740471296.0000
Epoch: 3, Batch: 10, Loss: 1180896067584.0000
Epoch: 3, Batch: 20, Loss: 333319405568.0000
Epoch: 3,

[22:43:09] SMILES Parse Error: syntax error while parsing: ccCcCcccOcc1cC2c(ccnc2Oc1Cc=CONNcNC=CCcccccC1c(CccccOCC1c1ccCccnn(OCccc(c=cOcC(cOCCc=ccccn(=2c)cc=cncc(cccN=1CcC=CcnOOcc(cCcc1c)ccOccCccCcNCNcCccccCCc
[22:43:09] SMILES Parse Error: check for mistakes around position 93:
[22:43:09] c=cOcC(cOCCc=ccccn(=2c)cc=cncc(cccN=1CcC=
[22:43:09] ~~~~~~~~~~~~~~~~~~~~^
[22:43:09] SMILES Parse Error: Failed parsing SMILES 'ccCcCcccOcc1cC2c(ccnc2Oc1Cc=CONNcNC=CCcccccC1c(CccccOCC1c1ccCccnn(OCccc(c=cOcC(cOCCc=ccccn(=2c)cc=cncc(cccN=1CcC=CcnOOcc(cCcc1c)ccOccCccCcNCNcCccccCCc' for input: 'ccCcCcccOcc1cC2c(ccnc2Oc1Cc=CONNcNC=CCcccccC1c(CccccOCC1c1ccCccnn(OCccc(c=cOcC(cOCCc=ccccn(=2c)cc=cncc(cccN=1CcC=CcnOOcc(cCcc1c)ccOccCccCcNCNcCccccCCc'
[22:43:09] SMILES Parse Error: syntax error while parsing: )c)cn)1ccccccccOCcNc)CN)CcccNCcCccOcCcOC1ccC(Nc1CC1cOcc)Nc=c)NC1c)1N22))OCCcccccOcc=cC=cOccccC()=(NccCcNcO)1c1)CO=N)c)=ccc)COcccCccCcOcOcOC)ccC(ccOccc
[22:43:09] SMILES Parse Error: check for mistakes arou

Epoch: 11, Batch: 10, Loss: -18200.0859
Epoch: 11, Batch: 20, Loss: 21676597248.0000
Epoch: 11, Batch: 30, Loss: 6255233024.0000
Epoch: 11, Batch: 40, Loss: 5135295488.0000
Epoch: 11, Batch: 50, Loss: 282131529728.0000
Epoch: 11, Batch: 60, Loss: 38589431808.0000
Epoch: 11, Batch: 70, Loss: 97855070208.0000
Epoch: 11, Batch: 80, Loss: 21473742848.0000
Epoch 11, Average loss: 39793142267.0006
Epoch: 12, Batch: 0, Loss: 45359247360.0000
Epoch: 12, Batch: 10, Loss: 42369036288.0000
Epoch: 12, Batch: 20, Loss: 19987425280.0000
Epoch: 12, Batch: 30, Loss: 29773703168.0000
Epoch: 12, Batch: 40, Loss: 58672726016.0000
Epoch: 12, Batch: 50, Loss: 18969118720.0000
Epoch: 12, Batch: 60, Loss: 51889389568.0000
Epoch: 12, Batch: 70, Loss: 42449203200.0000
Epoch: 12, Batch: 80, Loss: 18761469952.0000
Epoch 12, Average loss: 42089252436.4655
Epoch: 13, Batch: 0, Loss: 9414069248.0000
Epoch: 13, Batch: 10, Loss: 30047248384.0000
Epoch: 13, Batch: 20, Loss: 27926882304.0000
Epoch: 13, Batch: 30, Loss:

[22:43:27] SMILES Parse Error: syntax error while parsing: =Oc)c)ccCC)OccOCOccc)OCCCc=OcOcc=CC1Oc)=cOcccC)cc)c))cccccnccccc=c)ccCCcc=c1ccCcC)()Occ(c3C)OcOcOC(1)cC)OCCccccNcO)cOc)Cc)cOc2)HOcO=ON1)Nc3OCNCccccC)O
[22:43:27] SMILES Parse Error: check for mistakes around position 1:
[22:43:27] =Oc)c)ccCC)OccOCOccc)OCCCc=OcOcc=CC1Oc)=c
[22:43:27] ^
[22:43:27] SMILES Parse Error: Failed parsing SMILES '=Oc)c)ccCC)OccOCOccc)OCCCc=OcOcc=CC1Oc)=cOcccC)cc)c))cccccnccccc=c)ccCCcc=c1ccCcC)()Occ(c3C)OcOcOC(1)cC)OCCccccNcO)cOc)Cc)cOc2)HOcO=ON1)Nc3OCNCccccC)O' for input: '=Oc)c)ccCC)OccOCOccc)OCCCc=OcOcc=CC1Oc)=cOcccC)cc)c))cccccnccccc=c)ccCCcc=c1ccCcC)()Occ(c3C)OcOcOC(1)cC)OCCccccNcO)cOc)Cc)cOc2)HOcO=ON1)Nc3OCNCccccC)O'
[22:43:27] SMILES Parse Error: syntax error while parsing: )c)Oc)c)CcO=ccC(cCccOcc1cccONOCnc1)2OccOCccO1cc1C=cc))C))Occcnc(cOccOcCnOccc3O=Ncc)=cC2CC=)11Oc)c)ccccO)c1O)Oc1n)OCO)2O=NOCOcN)cONON1cc)OcCcOc)ccN)ccc
[22:43:27] SMILES Parse Error: check for mistakes around position 1:
[22:43

Generated samples (validity rate: 0.00%):
1. =Oc)c)ccCC)OccOCOccc)OCCCc=OcOcc=CC1Oc)=cOcccC)cc)c))cccccnccccc=c)ccCCcc=c1ccCcC)()Occ(c3C)OcOcOC(1)cC)OCCccccNcO)cOc)Cc)cOc2)HOcO=ON1)Nc3OCNCccccC)O
2. )c)Oc)c)CcO=ccC(cCccOcc1cccONOCnc1)2OccOCccO1cc1C=cc))C))Occcnc(cOccOcCnOccc3O=Ncc)=cC2CC=)11Oc)c)ccccO)c1O)Oc1n)OCO)2O=NOCOcN)cONON1cc)OcCcOc)ccN)ccc
3. cccOc11))O)C1OOcc)cO)=C)c1cO)CcOcccO(O)O))ccccO1Occ1OOc)Cc)CcccOCc1NO)1Cc)cOC))2OOcOccON)1c2)cccCOc)c)Nc)cCCCOCOc1)()N(CCOSc1ccOOOcO)O))c1Ncc)2ccOO1Cc
4. Ccc=COC1c(C1cCccC)Ncc=1(Oc1O1NONOcNn)cOcccccC==OCCC112)cc)ccCNNCcc)Occccc))Oc()c1c)C(cc1CNc11COO1O)cc(1Occ(ccNccN(11CccC))cccccccHCNNcCOOc)c(O(c1ccc1C
5. c1cOc)COO=1=2(ccC=cccOcCccccc1c1cCcc=c)(cC1()cCc=cC))OO(c))1OcCccOnN)cc)N1O)CC)C=)Ccc))c(c)=Cc)cc)O)cCCCC)c1ccCCcc)c(=))ccc)()c=)O2N(NcO))cc1N(cOnOO)C
Valid samples:
Epoch: 21, Batch: 0, Loss: -18009.6016
Epoch: 21, Batch: 10, Loss: 34166409216.0000
Epoch: 21, Batch: 20, Loss: 37573099520.0000
Epoch: 21, Batch: 30, Loss: -16809.5859
Epoc

[22:43:45] SMILES Parse Error: extra close parentheses while parsing: c)CCc2O((1cCcc((c)ccO1)cccc=O))Oc=[OCc)1c((CCcCc2ONc1)C(c)C)C(2c=cccc2OCN))NcOcc2OcC)cCcCcc2ccc=O2cc=c2CC2cCOc(2cO)1O12c)1)C)=c1))ncc1)c)O)ccCcCCcN)NN
[22:43:45] SMILES Parse Error: check for mistakes around position 2:
[22:43:45] c)CCc2O((1cCcc((c)ccO1)cccc=O))Oc=[OCc)1c
[22:43:45] ~^
[22:43:45] SMILES Parse Error: Failed parsing SMILES 'c)CCc2O((1cCcc((c)ccO1)cccc=O))Oc=[OCc)1c((CCcCc2ONc1)C(c)C)C(2c=cccc2OCN))NcOcc2OcC)cCcCcc2ccc=O2cc=c2CC2cCOc(2cO)1O12c)1)C)=c1))ncc1)c)O)ccCcCCcN)NN' for input: 'c)CCc2O((1cCcc((c)ccO1)cccc=O))Oc=[OCc)1c((CCcCc2ONc1)C(c)C)C(2c=cccc2OCN))NcOcc2OcC)cCcCcc2ccc=O2cc=c2CC2cCOc(2cO)1O12c)1)C)=c1))ncc1)c)O)ccCcCCcN)NN'
[22:43:45] SMILES Parse Error: syntax error while parsing: cNOcCN1CccOc=-Cc(NccccCncCcC)ccccCOC2CccOc2OCc11cccc2)cncOc=Cc=c(CScOCOC))O))c))Cc1N(cCcOcN[cH=((cCNN)O)c=2(C)cc=c)c11cOCOc21OcccFOcCCNN)O1cc(c2OcCcO=
[22:43:45] SMILES Parse Error: check for mistakes around positi

Generated samples (validity rate: 0.00%):
1. c)CCc2O((1cCcc((c)ccO1)cccc=O))Oc=[OCc)1c((CCcCc2ONc1)C(c)C)C(2c=cccc2OCN))NcOcc2OcC)cCcCcc2ccc=O2cc=c2CC2cCOc(2cO)1O12c)1)C)=c1))ncc1)c)O)ccCcCCcN)NN
2. cNOcCN1CccOc=-Cc(NccccCncCcC)ccccCOC2CccOc2OCc11cccc2)cncOc=Cc=c(CScOCOC))O))c))Cc1N(cCcOcN[cH=((cCNN)O)c=2(C)cc=c)c11cOCOc21OcccFOcCCNN)O1cc(c2OcCcO=
3. c)CCc2c)1ccc1N)cccCcOCCCc12c(c2c=S2O)c2O1)c2c=c)1(cccOcOcCCN)ccCc1COc1c)2c1=1Occ))11CcCc2=NS(c)nOc)Occc2c)cc(Nc21OC1cN1O2CON2ccONNc2)2c)c2c2cCNCcc(11c
4. OCC(2cCOCCcCc)1OC=c(2CCH))Nc=C(NcC=1(1C=)Ocn2c(c2c3cc2(1CCc)cc()cO1Cc(ccO-Cc))O(ccccO==C2c1c1c12cc1CC(c)1c2cnc)NcccCOONc)cS1cCc(ccOcO2O2C)CO=O)c)OCccc
5. Nc1cCcc2c=2O2c2c2C=C)2)(cc1c2c1C)cOcccccCON1ccc()CO(cc)Cc2=O)cONcC=(]11(c1CC)=NOO1)C2c32((CCcc((NOcc)1]NC)c2ccccCc)C=Ocn1=CO)cc-cNNcCcc(c2c2c=COc2)ONc
Valid samples:
Epoch: 31, Batch: 0, Loss: 10309518336.0000
Epoch: 31, Batch: 10, Loss: 27131799552.0000
Epoch: 31, Batch: 20, Loss: 8700445696.0000
Epoch: 31, Batch: 30, Loss: 67510304768.

[22:44:03] SMILES Parse Error: syntax error while parsing: =cO2)2cc2OccN)cO2ccc(CccOcccCNc1ccc)2N)2cc(c(cC1NnccccCcNccO=C1CcCcCcc(c2cCNc2cc))cnCccccOO)1O)(2c1cCOcCcccccccC2)2Cc(c(cccc2cc(CccCc)ccCccCc)cc2cCCcC
[22:44:03] SMILES Parse Error: check for mistakes around position 1:
[22:44:03] =cO2)2cc2OccN)cO2ccc(CccOcccCNc1ccc)2N)2c
[22:44:03] ^
[22:44:03] SMILES Parse Error: Failed parsing SMILES '=cO2)2cc2OccN)cO2ccc(CccOcccCNc1ccc)2N)2cc(c(cC1NnccccCcNccO=C1CcCcCcc(c2cCNc2cc))cnCccccOO)1O)(2c1cCOcCcccccccC2)2Cc(c(cccc2cc(CccCc)ccCccCc)cc2cCCcC' for input: '=cO2)2cc2OccN)cO2ccc(CccOcccCNc1ccc)2N)2cc(c(cC1NnccccCcNccO=C1CcCcCcc(c2cCNc2cc))cnCccccOO)1O)(2c1cCOcCcccccccC2)2Cc(c(cccc2cc(CccCc)ccCccCc)cc2cCCcC'
[22:44:03] SMILES Parse Error: syntax error while parsing: 12cc(c2cc))ccc)cC)ccCCc=CcOO2(cCc=2Cc22c11cc1NN2NCCCcC2C2ccccCc(ccCc2O2N=2))2cCONccN)ccCCcC1C2ccc(c2c12cc)cC2C2Nc)112cNcccccc)2=cC2cCcCc2(OOcOc)cCccCc
[22:44:03] SMILES Parse Error: check for mistakes around position 1:
[22:44

Generated samples (validity rate: 0.00%):
1. =cO2)2cc2OccN)cO2ccc(CccOcccCNc1ccc)2N)2cc(c(cC1NnccccCcNccO=C1CcCcCcc(c2cCNc2cc))cnCccccOO)1O)(2c1cCOcCcccccccC2)2Cc(c(cccc2cc(CccCc)ccCccCc)cc2cCCcC
2. 12cc(c2cc))ccc)cC)ccCCc=CcOO2(cCc=2Cc22c11cc1NN2NCCCcC2C2ccccCc(ccCc2O2N=2))2cCONccN)ccCCcC1C2ccc(c2c12cc)cC2C2Nc)112cNcccccc)2=cC2cCcCc2(OOcOc)cCccCc
3. )2cC(c)12ccc2cCCOccOcCc)NcO1(c(cccc1(2)cc1ccOONccc2cCONcc(cccONcCcccccCC()(=c2CcOC(ccCcccCC2c2=ccCccCcccC2C=c2cc2cCcccccN)N)1)(ccccNccccNO)Ccc(c(2cccO
4. c2ccCc=Cc2)1N2)1ccC=2NccC2cC)cO12Nc1(NcN2c2CCncCcSCC)Ccc(1)OcC(CCcCC))OC2(ccc2c1cccc2Occcc2cc(c1ccCcCcccc12c2Cc21CO)2cNccCcC1cCOOO2c((ccccC2c2c22O1N(=
5. (1c(C1ccc3c)(ccc)c(cC)12(1(cCcccC2CcN2))(cc1N(CCcc(cCCcCcCClNC2O2=(=CCCccccC1c=2cccC)c22c2C2cCcCccccCcCOC2NCcCcCC=cc2c)CCCcccCCC2Cccc2(CcCNC111Cc(cccc
Valid samples:
Epoch: 41, Batch: 0, Loss: 2592327680.0000
Epoch: 41, Batch: 10, Loss: 2252408320.0000
Epoch: 41, Batch: 20, Loss: 1924909952.0000
Epoch: 41, Batch: 30, Loss: 1241822080.000

[22:44:22] SMILES Parse Error: syntax error while parsing: =1)=)CNc22C)cC(cccN2(1)OCcC(cO)Occ)c()N)(cNO=)c(cncc2C1(cc(cCccNONCcc=2)Ccnc)cccOc)==Ccc(ccCOcc)Cc1CcCOCCNcCc1cc(cO=Cccccc=cc=)cCccc)cnccScc(c2OCccC)1
[22:44:22] SMILES Parse Error: check for mistakes around position 1:
[22:44:22] =1)=)CNc22C)cC(cccN2(1)OCcC(cO)Occ)c()N)(
[22:44:22] ^
[22:44:22] SMILES Parse Error: Failed parsing SMILES '=1)=)CNc22C)cC(cccN2(1)OCcC(cO)Occ)c()N)(cNO=)c(cncc2C1(cc(cCccNONCcc=2)Ccnc)cccOc)==Ccc(ccCOcc)Cc1CcCOCCNcCc1cc(cO=Cccccc=cc=)cCccc)cnccScc(c2OCccC)1' for input: '=1)=)CNc22C)cC(cccN2(1)OCcC(cO)Occ)c()N)(cNO=)c(cncc2C1(cc(cCccNONCcc=2)Ccnc)cccOc)==Ccc(ccCOcc)Cc1CcCOCCNcCc1cc(cO=Cccccc=cc=)cCccc)cnccScc(c2OCccC)1'
[22:44:22] SMILES Parse Error: syntax error while parsing: OcOcCc=()cCc)NccCOCccc)cC2ccc(CcccccNcCc(C)C)c)O)((C)c)C()=c2O1NO=cC=Ccc2cCCc2(cCOc)C)==c)c=((cC)N)ccccC)OcOcc())NCccc(cn)cccCc(nc(2ccc)Cc2C=O)n)OcC)O
[22:44:22] SMILES Parse Error: check for mistakes around position 8:
[22:44

## 9. Generate New SMILES Strings

Once the model is trained, uncomment and run this cell to generate new molecules.

In [22]:
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
with torch.no_grad():
	generated_smiles = sample_and_decode(model, idx_to_char, 20, 150, device)
	valid_samples, validity_rate = check_validity(generated_smiles)

print(f"\nFinal generated samples (validity rate: {validity_rate:.2%}):")
for i, s in enumerate(valid_samples[:10]):
	print(f"{i+1}. {s}")


Final generated samples (validity rate: 0.00%):


[22:44:22] SMILES Parse Error: syntax error while parsing: (c)c)2cc(N=C2NccO(C)=CcCccccc=cCc))O)ccCS1c)Nc=2(22NO1=cNO)cCc)C)CcCCO2(2(cc2cC(ccc)C==cO2)c(Oc2)ccccc2cCNc2cCcCCCCc)cnccc)CccccOcccCc2ccccNOCcc)=cCCN
[22:44:22] SMILES Parse Error: check for mistakes around position 1:
[22:44:22] (c)c)2cc(N=C2NccO(C)=CcCccccc=cCc))O)ccCS
[22:44:22] ^
[22:44:22] SMILES Parse Error: Failed parsing SMILES '(c)c)2cc(N=C2NccO(C)=CcCccccc=cCc))O)ccCS1c)Nc=2(22NO1=cNO)cCc)C)CcCCO2(2(cc2cC(ccc)C==cO2)c(Oc2)ccccc2cCNc2cCcCCCCc)cnccc)CccccOcccCc2ccccNOCcc)=cCCN' for input: '(c)c)2cc(N=C2NccO(C)=CcCccccc=cCc))O)ccCS1c)Nc=2(22NO1=cNO)cCc)C)CcCCO2(2(cc2cC(ccc)C==cO2)c(Oc2)ccccc2cCNc2cCcCCCCc)cnccc)CccccOcccCc2ccccNOCcc)=cCCN'
[22:44:22] SMILES Parse Error: syntax error while parsing: )cNNSccOCCOnNccC=2O2OcCCNOOccO2=1CNc)c(cO=nCO21cOcc=)c2c2cOC)c()cc))COccCc(c)C)1)c)C(Ccccc(cn)=cNN=cc)Oc1N1c)=)ccOcccC=CCccCcc(CcON))cccc))CCO2=C=c(cN
[22:44:22] SMILES Parse Error: check for mistakes around position 1:
[22:44

## 10. Using Your Own Dataset

You can load your own SMILES data and train the model on it.

In [23]:
# Load your own data
# Example 1: Loading from a CSV file with a 'smiles' column
# your_df = pd.read_csv('your_smiles_data.csv')
# smiles_list = your_df['smiles'].tolist()

# Example 2: Directly define from your data string
# long_smiles = "CCCS(=O)c1ccc2[nH]c(=NC(=O)OC)[nH]c2c11CC(C)(C)C(=O)C(Oc1ccc(Cl)cc1)n1ccnc12Cc1c(Cl)cccc1Nc1ncccc1C(=O)OCC(O)CO3Cn1cnc2c1c(=O)n(CC(O)CO)c(=O)n2C4CC1Oc2ccc(Cl)cc2N(CC(O)CO)C1=O"
# # Parse into individual compounds if needed
# smiles_list = [long_smiles]  # or split it if it contains multiple compounds

# Run the model
# model, char_to_idx, idx_to_char = main(smiles_list, num_epochs=50)

## 11. Visualize Generated Molecules

This optional section shows how to visualize generated molecules.

In [24]:
# Uncomment to install RDKit visualization tools
# !pip install rdkit-pypi

from rdkit.Chem import Draw
import matplotlib.pyplot as plt
from IPython.display import display

def visualize_molecules(smiles_list, ncols=3):
    mols = [Chem.MolFromSmiles(s) for s in smiles_list if Chem.MolFromSmiles(s) is not None]
    if len(mols) == 0:
        print("No valid molecules to visualize")
        return
        
    # Calculate number of rows needed
    nrows = (len(mols) + ncols - 1) // ncols
    
    # Create subplot
    fig = plt.figure(figsize=(ncols*3, nrows*3))
    
    for i, mol in enumerate(mols):
        ax = fig.add_subplot(nrows, ncols, i+1)
        img = Draw.MolToImage(mol, size=(300, 300))
        ax.imshow(img)
        ax.set_title(f"Molecule {i+1}")
        ax.axis('off')
    
    plt.tight_layout()
    plt.show()

# Example usage:
visualize_molecules(valid_samples[:9])

No valid molecules to visualize


## 12. Save and Load Models