In [5]:

import torch  
import pandas as pd
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures


In [6]:
df = pd.read_csv('../data/processed/qm9_clean.csv')
words = df['smiles'].astype('str').tolist()

In [7]:
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: '#', 2: '(', 3: ')', 4: '+', 5: '-', 6: '1', 7: '2', 8: '3', 9: '4', 10: '5', 11: '=', 12: 'C', 13: 'F', 14: 'H', 15: 'N', 16: 'O', 17: '[', 18: ']', 0: '.'}
19


In [8]:
import random
random.seed(42)
random.shuffle(words)

In [21]:
# build the dataset
block_size = 24 # context length: how many characters do we take to predict the next one?

def build_dataset(words):  
  X, Y = [], []
  
  for w in words:
    context = [0] * block_size
    for ch in w + '.':
      ix = stoi[ch]
      X.append(context)
      Y.append(ix)
      context = context[1:] + [ix] # crop and append

  X = torch.tensor(X)
  Y = torch.tensor(Y)
  print(X.shape, Y.shape)
  return X, Y

n1 = int(0.8*len(words))
n2 = int(0.9*len(words))
Xtr,  Ytr  = build_dataset(words[:n1])     
Xdev, Ydev = build_dataset(words[n1:n2])   
Xte,  Yte  = build_dataset(words[n2:])

torch.Size([1737037, 24]) torch.Size([1737037])
torch.Size([217037, 24]) torch.Size([217037])
torch.Size([217214, 24]) torch.Size([217214])


In [22]:
for x,y in zip(Xtr[:20], Ytr[:20]):
  print(''.join(itos[ix.item()] for ix in x), '-->', itos[y.item()])

........................ --> C
.......................C --> N
......................CN --> =
.....................CN= --> C
....................CN=C --> (
...................CN=C( --> O
..................CN=C(O --> C
.................CN=C(OC --> C
................CN=C(OCC --> =
...............CN=C(OCC= --> O
..............CN=C(OCC=O --> )
.............CN=C(OCC=O) --> C
............CN=C(OCC=O)C --> #
...........CN=C(OCC=O)C# --> N
..........CN=C(OCC=O)C#N --> .
........................ --> C
.......................C --> C
......................CC --> #
.....................CC# --> C
....................CC#C --> C


In [23]:
class Linear:
  
  def __init__(self, fan_in, fan_out, bias=True):
    self.weight = torch.randn((fan_in, fan_out)) / fan_in**0.5 # note: kaiming init
    self.bias = torch.zeros(fan_out) if bias else None
  
  def __call__(self, x):
    self.out = x @ self.weight
    if self.bias is not None:
      self.out += self.bias
    return self.out
  
  def parameters(self):
    return [self.weight] + ([] if self.bias is None else [self.bias])
  
class BatchNorm1d:
  
    def __init__(self, dim, eps=1e-5, momentum=0.1):
      self.eps = eps
      self.momentum = momentum
      self.training = True
      # parameters (trained with backprop)
      self.gamma = torch.ones(dim)
      self.beta = torch.zeros(dim)
      # buffers (trained with a running 'momentum update')
      self.running_mean = torch.zeros(dim)
      self.running_var = torch.ones(dim)
    
    def __call__(self, x):
      # calculate the forward pass
      if self.training:
        if x.ndim == 2:
          dim = 0
        elif x.ndim == 3:
          dim = (0,1)
        xmean = x.mean(dim, keepdim=True) # batch mean
        xvar = x.var(dim, keepdim=True) # batch variance
      else:
        xmean = self.running_mean
        xvar = self.running_var
      xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance
      self.out = self.gamma * xhat + self.beta
      # update the buffers
      if self.training:
        with torch.no_grad():
          self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * xmean
          self.running_var = (1 - self.momentum) * self.running_var + self.momentum * xvar
      return self.out
    
    def parameters(self):
      return [self.gamma, self.beta]

class Tanh:
  def __call__(self, x):
    self.out = torch.tanh(x)
    return self.out
  def parameters(self):
    return []


class Embedding:
  
  def __init__(self, num_embeddings, embedding_dim):
    self.weight = torch.randn((num_embeddings, embedding_dim))
    
  def __call__(self, IX):
    self.out = self.weight[IX]
    return self.out
  
  def parameters(self):
    return [self.weight]


class FlattenConsecutive:
  
  def __init__(self, n):
    self.n = n
    
  def __call__(self, x):
    B, T, C = x.shape
    x = x.view(B, T//self.n, C*self.n)
    if x.shape[1] == 1:
      x = x.squeeze(1)
    self.out = x
    return self.out
  
  def parameters(self):
    return []
  
class Sequential:
  
  def __init__(self, layers):
    self.layers = layers
  
  def __call__(self, x):
    for layer in self.layers:
      x = layer(x)
    self.out = x
    return self.out
  
  def parameters(self):
    # get parameters of all layers and stretch them out into one list
    return [p for layer in self.layers for p in layer.parameters()]
  
class Flatten:
  def __call__(self, x):
    self.out = x.view(x.shape[0], -1)
    return self.out
  
  def parameters(self):
    return []
  
  
    

In [24]:
n_embd = 24 # the dimensionality of the character embedding vectors
n_hidden = 512 # the number of neurons in the hidden layer of the MLP
model = Sequential([
  Embedding(vocab_size, n_embd),
  FlattenConsecutive(2), Linear(n_embd * 2, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
  FlattenConsecutive(2), Linear(n_hidden*2, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
  FlattenConsecutive(2), Linear(n_hidden*2, n_hidden, bias=False), BatchNorm1d(n_hidden), Tanh(),
  Linear(n_hidden, vocab_size),
])

# parameter init
with torch.no_grad():
  model.layers[-1].weight *= 0.1 # last layer make less confident

parameters = model.parameters()
print(sum(p.nelement() for p in parameters)) # number of parameters in total
for p in parameters:
  p.requires_grad = True

1086427


In [26]:
max_steps = 200000
batch_size = 32
lossi = []

for i in range(max_steps):
  
  # minibatch construct
  ix = torch.randint(0, Xtr.shape[0], (batch_size,))
  Xb, Yb = Xtr[ix], Ytr[ix] # batch X,Y
  
  # forward pass
  logits = model(Xb)
  if logits.ndim == 3:
    logits = logits[:, -1, :]
  loss = F.cross_entropy(logits, Yb) # loss function
  
  # backward pass
  for p in parameters:
    p.grad = None
  loss.backward()
  
  # update: simple SGD
  lr = 0.1 if i < 150000 else 0.01 # step learning rate decay
  for p in parameters:
    p.data += -lr * p.grad

  # track stats
  if i % 10000 == 0: # print every once in a while
    print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
  lossi.append(loss.log10().item())
  

      0/ 200000: 2.9517
  10000/ 200000: 0.9083
  20000/ 200000: 0.9893
  30000/ 200000: 1.4097
  40000/ 200000: 0.6834
  50000/ 200000: 1.1716
  60000/ 200000: 1.2522
  70000/ 200000: 1.2257
  80000/ 200000: 1.6226
  90000/ 200000: 1.3331
 100000/ 200000: 1.1530
 110000/ 200000: 1.2361
 120000/ 200000: 1.1351
 130000/ 200000: 0.9021
 140000/ 200000: 1.3176
 150000/ 200000: 0.9717
 160000/ 200000: 0.8621
 170000/ 200000: 1.0302
 180000/ 200000: 0.9525
 190000/ 200000: 0.8405


In [28]:

for layer in model.layers:
  layer.training = False

In [33]:
@torch.no_grad()
def split_loss(split):
    # Set to evaluation mode
    for layer in model.layers:
        if hasattr(layer, 'training'):
            layer.training = False
            
    x, y = {
        'train': (Xtr, Ytr),
        'val': (Xdev, Ydev),
        'test': (Xte, Yte),
    }[split]
    
    batch_size = 1024 # Small enough to not crash RAM
    losses = []
    
    for i in range(0, x.shape[0], batch_size):
        Xb, Yb = x[i:i+batch_size], y[i:i+batch_size]
        logits = model(Xb)
        if logits.ndim == 3:
            logits = logits[:, -1, :]
        loss = F.cross_entropy(logits, Yb.long())
        losses.append(loss.item())
    
    print(f'{split:5s} loss: {sum(losses)/len(losses):.4f}')

# Reset layers to training mode after eval if needed
def set_train_mode():
    for layer in model.layers:
        if hasattr(layer, 'training'):
            layer.training = True
            
split_loss('train')
split_loss('val')

train loss: 0.9399
val   loss: 0.9468


In [36]:

# sample from the model
for _ in range(20):
    
    out = []
    context = [0] * block_size # initialize with all ...
    while True:
      # forward pass the neural net
      logits = model(torch.tensor([context]))
      if logits.ndim == 3:
          logits = logits[:, -1, :]
      probs = F.softmax(logits, dim=1)
      # print(logits.shape)
      # print(Yb.shape)
      
      # sample from the distribution
      ix = torch.multinomial(probs, num_samples=1).item()
      # shift the context window and track the samples
      context = context[1:] + [ix]
      out.append(ix)
      # if we sample the special '.' token, break
      if ix == 0:
        break
    
    print(''.join(itos[i] for i in out))

CC1OCOC1C1NCC1=NOC(=O)CO.
C#CC1C2CCC3C2C13.
CN1CC(C)C1(CC1)C1COC1.
C#CCOC1COC2CC2C1.
COC1CC2(CC3C1)C2(C)O.
OC=NC1=O.
N=C(CC=O)O1.
OC12CC1C(C=O)N1CC2C([O-])=O.
CC1OC(CO1)C#N.
O=CC1CCC(O)C=C1C.
OCC12CC3C4C3N1C24O.
CCC(C#N)C12.
CC1OC(C)C1=NC=COC1=O.
CC1CC2(C)C(C)=O.
CC1CC(C)CO1.
CN(C=N)C(=O)C#C.
CC1(C)OC(C1)C=O.
O=CNC1=CON=C1N.
CC12CC(C1)O2.
CN1CC1C=CNC1=N.


In [37]:
for layer in model.layers:
    print(layer.__class__.__name__, ':', tuple(layer.out.shape))

Embedding : (1, 24, 24)
FlattenConsecutive : (1, 12, 48)
Linear : (1, 12, 512)
BatchNorm1d : (1, 12, 512)
Tanh : (1, 12, 512)
FlattenConsecutive : (1, 6, 1024)
Linear : (1, 6, 512)
BatchNorm1d : (1, 6, 512)
Tanh : (1, 6, 512)
FlattenConsecutive : (1, 3, 1024)
Linear : (1, 3, 512)
BatchNorm1d : (1, 3, 512)
Tanh : (1, 3, 512)
Linear : (1, 3, 19)


In [38]:
#pip install rdkit

In [39]:
from rdkit import Chem
from rdkit.Chem import Draw

def validate_smiles(smiles_str):
    """
    Attempts to parse a SMILES string. 
    Returns the Molecule object if valid, None if invalid.
    """
    # Remove the trailing stop token '.' if it's there
    smiles_str = smiles_str.replace('.', '')
    
    # MolFromSmiles returns None if the string is chemically impossible
    mol = Chem.MolFromSmiles(smiles_str)
    
    if mol:
        try:
            # Performs valency and aromaticity checks
            Chem.SanitizeMol(mol)
            return mol
        except:
            return None
    return None

In [41]:
n_samples = 500
valid_molecules = []
generated_smiles = []

# 1. Generate samples from your model
for _ in range(n_samples):
    out = []
    context = [0] * block_size
    while True:
        logits = model(torch.tensor([context]))
        if logits.ndim == 3:
            logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=1)
        ix = torch.multinomial(probs, num_samples=1).item()
        context = context[1:] + [ix]
        if ix == 0: break
        out.append(itos[ix])
    
    smiles = "".join(out)
    generated_smiles.append(smiles)
    
    # 2. Validate with RDKit
    mol = validate_smiles(smiles)
    if mol:
        valid_molecules.append(mol)

# 3. Calculate Results
validity_rate = (len(valid_molecules) / n_samples) * 100
print(f"Total Samples: {n_samples}")
print(f"Valid Molecules: {len(valid_molecules)}")
print(f"Validity Rate: {validity_rate:.2f}%")

[02:12:31] SMILES Parse Error: unclosed ring for input: 'CN1C2CC2(O)C13'
[02:12:31] SMILES Parse Error: unclosed ring for input: 'CC1CCC1(O)C2NC2C1(C)C3O2'
[02:12:31] SMILES Parse Error: unclosed ring for input: 'O=C1CNC=NC1(C)COC=N1'
[02:12:31] SMILES Parse Error: unclosed ring for input: 'CC1(C)C2CC1(O)COC1CC1C'
[02:12:31] SMILES Parse Error: unclosed ring for input: 'CN=COCCOCOC1CCC2C1'
[02:12:32] SMILES Parse Error: unclosed ring for input: 'CC1(CC=O)C=O'
[02:12:32] SMILES Parse Error: unclosed ring for input: 'O=CC1=CN=CC(F)=CC(=O)C1OC2C(O)CC12O'
[02:12:32] SMILES Parse Error: unclosed ring for input: 'C1C2C=CC3C4C1C24C5'
[02:12:32] SMILES Parse Error: unclosed ring for input: 'CC12OC3CCC13C24'
[02:12:32] SMILES Parse Error: unclosed ring for input: 'OC1C2NC2C2OCC13'
[02:12:32] SMILES Parse Error: unclosed ring for input: 'C#CC1OC(=N)C2'
[02:12:32] SMILES Parse Error: unclosed ring for input: 'CC1C2CC(=O)C2CC12'
[02:12:32] SMILES Parse Error: unclosed ring for input: 'C1C2CC1C21CO

Total Samples: 500
Valid Molecules: 263
Validity Rate: 52.60%


[02:12:46] SMILES Parse Error: unclosed ring for input: 'O=C1CC2C(C)O1'
[02:12:47] SMILES Parse Error: unclosed ring for input: 'COC1C2C1C(C#N)C2CC1C2'


In [None]:
from rdkit.Chem import Draw

# Create the grid as a PNG image directly
img = Draw.MolsToGridImage(valid_molecules[:12], molsPerRow=4, subImgSize=(300,300))

if hasattr(img, 'data'):
    with open('generated_molecules.png', 'wb') as f:
        f.write(img.data)
    print("Saved using .data attribute!")
else:
    from rdkit.Chem.Draw import RdMolDrawing
    Draw.MolToFile(valid_molecules[0], 'single_molecule.png') # Test a single one
    print("Saved a single molecule to test.")

Saved using .data attribute!
