# CuteSmileyBert : a toy transformer for SMILES strings

Chemical data has such underlying complexity that it is quite complicated to find representations of molecules that machine learning models can work with. Indeed, we could try different types of human-engineered featurizations, especially for the protein structures. However, they will all inevitably lose some information, or become noisy. For this reason, I believe that the best encoding for this specific task will be a BERT-like transformer of protein structures. ESM2 demonstrated an understanding of protein structures despite only being trained on sequences. We can definitely expect similar results on SMILES strings.

CuteSmileyBert will have only 1 million parameters, which is absolutely **tiny**. I am not expecting it to work very well, but I am curious if it will work at all, and if we can demonstrate some form of scaling laws.

In [None]:
import sys
sys.path.append("..")

In [68]:
import re
from pathlib import Path
import random

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

from src.download_dataset import download_datasets, extract_files

ROOT_DIR = Path("../")
DATA_DIR = Path(ROOT_DIR, "data")

load_dotenv()

# Define the SMILES dataset to download
smiles_datasets = {
    'SMILES_Big_Dataset.csv.zip': 'https://www.kaggle.com/api/v1/datasets/download/yanmaksi/big-molecules-smiles-dataset'
}

# Download the dataset
download_datasets(smiles_datasets)
extract_files(smiles_datasets)

filename = list(smiles_datasets.keys())[0]
file_path = Path(DATA_DIR, filename)
df = pd.read_csv(file_path)

SMILES_Big_Dataset.csv.zip est déjà présent.
Extraction de SMILES_Big_Dataset.csv.zip...


Extracting SMILES_Big_Dataset.csv.zip: 100%|██████████| 1/1 [00:00<00:00, 38.01it/s]

SMILES_Big_Dataset.csv.zip extrait dans /var/home/marcos/Code/Cheminformatics_molecule_property_project/data/SMILES_Big_Dataset.csv





In [63]:
# This is the column that we are interested in
smiles_list = df["SMILES"].to_list()

df["SMILES"].head()

0           O=S(=O)(Nc1cccc(-c2cnc3ccccc3n2)c1)c1cccs1
1    O=c1cc(-c2nc(-c3ccc(-c4cn(CCP(=O)(O)O)nn4)cc3)...
2               NC(=O)c1ccc2c(c1)nc(C1CCC(O)CC1)n2CCCO
3                  NCCCn1c(C2CCNCC2)nc2cc(C(N)=O)ccc21
4                    CNC(=S)Nc1cccc(-c2cnc3ccccc3n2)c1
Name: SMILES, dtype: object

In [None]:
def tokenize_smiles(smiles: str):
    return SMILES_REGEX.findall(smiles)

MASK_TOKEN = "<MASK>"
SPECIAL_TOKENS = [PAD_TOKEN, BOS_TOKEN, EOS_TOKEN, UNK_TOKEN, MASK_TOKEN]

def build_vocab(smiles_list):
    tokens = set(SPECIAL_TOKENS)
    for s in smiles_list:
        tokens.update(tokenize_smiles(s))
    vocab = {tok: i for i, tok in enumerate(sorted(tokens))}
    inv_vocab = {i: tok for tok, i in vocab.items()}
    return vocab, inv_vocab

MAX_LEN = 150

def encode_smiles(smiles, vocab, max_len=MAX_LEN):
    tokens = [BOS_TOKEN] + tokenize_smiles(smiles) + [EOS_TOKEN]
    token_ids = [vocab.get(t, vocab[UNK_TOKEN]) for t in tokens]
    token_ids = token_ids[:max_len] + [vocab[PAD_TOKEN]] * (max_len - len(token_ids))
    return token_ids

def mask_tokens(input_ids, vocab, mask_prob=0.15):
    input_ids = input_ids.clone()
    labels = input_ids.clone()

    mask_token_id = vocab[MASK_TOKEN]
    pad_token_id = vocab[PAD_TOKEN]
    vocab_size = len(vocab)

    # Do not mask padding tokens
    maskable = input_ids != pad_token_id
    masked_indices = torch.bernoulli(torch.full(input_ids.shape, mask_prob)).bool() & maskable
    labels[~masked_indices] = -100  # ignore loss on unmasked tokens

    # 80% replace with <MASK>
    replace_mask = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
    input_ids[replace_mask] = mask_token_id

    # 10% replace with random token
    random_mask = torch.bernoulli(torch.full(input_ids.shape, 0.1)).bool() & masked_indices & ~replace_mask
    random_tokens = torch.randint(vocab_size, input_ids.shape, dtype=torch.long)
    input_ids[random_mask] = random_tokens[random_mask]

    # 10% keep unchanged (no-op for the rest)
    return input_ids, labels

In [None]:
class SMILESMaskedDataset(torch.utils.data.Dataset):
    def __init__(self, smiles_list, vocab, max_len=150):
        self.vocab = vocab
        self.max_len = max_len
        self.data = [encode_smiles(s, vocab, max_len) for s in smiles_list]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        input_ids = torch.tensor(self.data[idx], dtype=torch.long)
        masked_input, labels = mask_tokens(input_ids, self.vocab)
        return masked_input, labels
    
vocab, inv_vocab = build_vocab(smiles_list)

dataset = SMILESDataset(smiles_list, vocab, max_len=150)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for src, tgt in loader:
    print(src.shape, tgt.shape)
    break
    

torch.Size([4, 149]) torch.Size([4, 149])


In [None]:
import torch.nn as nn

class SMILESBERT(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=8, num_layers=6, dim_feedforward=1024, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids):
        x = self.embed(input_ids)
        x = self.encoder(x)
        logits = self.lm_head(x)
        return logits


In [None]:
vocab, inv_vocab = build_vocab(smiles_list)

dataset = SMILESMaskedDataset(smiles_list, vocab, max_len=150)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

model = SMILESBERT(len(vocab))
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

for epoch in range(3):
    for masked_input, labels in loader:
        optimizer.zero_grad()
        logits = model(masked_input)
        loss = criterion(logits.view(-1, len(vocab)), labels.view(-1))
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}: loss = {loss.item():.4f}")