In [1]:
# Regular EDA and plotting libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns
import warnings
import rdkit
import mols2grid
from rdkit import Chem

# PyTorch + model building
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

In [2]:
# Hyperparameters
block_size = 128
max_iters = 2000
eval_interval = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 300
n_embd = 384
n_head = 4
n_layer = 4
dropout = 0.2

torch.manual_seed(1337)

<torch._C.Generator at 0x18758b45610>

In [3]:
# Load and preprocess SMILES data
df = pd.read_csv("polymer_tg_tm.csv")
df = df.dropna(subset=["SMILES"]).copy()
df["SMILES"] = df["SMILES"].astype(str)
df = df.sample(n=1000, random_state=66).reset_index(drop=True)
smiles_list = df["SMILES"].tolist()

# Prepare data string and tokenizer
text = "\n".join(smiles_list)
chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# Tokenize all text
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

In [4]:
df

Unnamed: 0,SMILES,Tg,Tm
0,*OC(CCCCCCC=C)CC(*)=O,-50.0,35.0
1,*CCCCOC(=O)NCC(F)(F)C(F)(F)C(F)(F)C(F)(F)CNC(=...,33.0,140.0
2,*CNC(=O)OCc1ccc(COCc2ccc(COC(=O)NCc3ccc(C(C)(C...,39.0,173.0
3,*CCCCCCCN1C(=O)c2ccc(C(=O)Oc3ccc(-c4ccc(OC(=O)...,117.0,290.0
4,*c1ccc(-c2ccc(-c3ccc4nc(Oc5cc(-c6ccccc6)c6cc(*...,308.0,480.0
...,...,...,...
995,*CCOCCOCCOCCOCCOCCOC(=O)NCCCCCCNC(=O)O*,-31.0,65.0
996,*Oc1cc(OCCOC)c(OC(=O)c2ccc(C(*)=O)cc2OCCOCC)cc...,99.0,164.0
997,*CCCCCCOc1ccc(/C=C/c2ccc(O*)c3ccccc23)cc1,78.0,181.0
998,*CCCCCCCCCCOC(=O)CCCCCCCCCCCCCCCCC(=O)O*,-17.0,92.7


In [5]:
# Dataset class
class SMILESDataset(Dataset):
    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        x = self.data[idx:idx + self.block_size]
        y = self.data[idx + 1:idx + self.block_size + 1]
        return x, y

In [6]:
# Create DataLoaders
train_dataset = SMILESDataset(train_data, block_size)
val_dataset = SMILESDataset(val_data, block_size)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

In [7]:
# Model components
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        return wei @ v

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

class FeedFoward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

class GPTLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

In [8]:
# Loss estimation with DataLoader
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split, loader in [('train', train_loader), ('val', val_loader)]:
        losses = torch.zeros(eval_iters)
        for k, (X, Y) in enumerate(loader):
            if k >= eval_iters:
                break
            X, Y = X.to(device), Y.to(device)
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [9]:
# Initialize model
model = GPTLanguageModel().to(device)
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M parameters')

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

7.173159 M parameters


In [10]:
# Training loop
train_iterator = iter(train_loader)
for step in range(max_iters):
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    try:
        X, Y = next(train_iterator)
    except StopIteration:
        train_iterator = iter(train_loader)
        X, Y = next(train_iterator)

    X, Y = X.to(device), Y.to(device)
    logits, loss = model(X, Y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 3.8420, val loss 1.8182
step 100: train loss 1.4135, val loss 0.6550
step 200: train loss 0.9949, val loss 0.4556
step 300: train loss 0.7758, val loss 0.3650
step 400: train loss 0.6535, val loss 0.3163
step 500: train loss 0.5790, val loss 0.2930
step 600: train loss 0.5124, val loss 0.2681
step 700: train loss 0.4663, val loss 0.2524
step 800: train loss 0.4261, val loss 0.2416
step 900: train loss 0.4002, val loss 0.2383
step 1000: train loss 0.3731, val loss 0.2307
step 1100: train loss 0.3519, val loss 0.2272
step 1200: train loss 0.3269, val loss 0.2267
step 1300: train loss 0.3080, val loss 0.2299
step 1400: train loss 0.2945, val loss 0.2311
step 1500: train loss 0.2814, val loss 0.2325
step 1600: train loss 0.2636, val loss 0.2358
step 1700: train loss 0.2510, val loss 0.2319
step 1800: train loss 0.2415, val loss 0.2433
step 1900: train loss 0.2235, val loss 0.2444
step 1999: train loss 0.2222, val loss 0.2488


In [11]:
# Generate monomers
context = torch.zeros((1, 1), dtype=torch.long, device=device)
generated = decode(model.generate(context, max_new_tokens=5000)[0].tolist())
smiles_list = generated.split("\n")
valid_smiles = [s for s in smiles_list if Chem.MolFromSmiles(s)]
print(f"{len(valid_smiles)} valid out of {len(smiles_list)}")

43 valid out of 121


[21:24:25] non-ring atom 4 marked aromatic
[21:24:25] SMILES Parse Error: unclosed ring for input: '*Oc1cc(C(C)c2cccc(OC(*)=O)c2)C(F)F'
[21:24:25] SMILES Parse Error: extra close parentheses while parsing: *CCCCCCC1SC(=O)(=O)c2ccc(C)c(OC(=O)CCCC(=O)N(*)C=O)cc2)cc1
[21:24:25] SMILES Parse Error: check for mistakes around position 55:
[21:24:25] )CCCC(=O)N(*)C=O)cc2)cc1
[21:24:25] ~~~~~~~~~~~~~~~~~~~~^
[21:24:25] SMILES Parse Error: Failed parsing SMILES '*CCCCCCC1SC(=O)(=O)c2ccc(C)c(OC(=O)CCCC(=O)N(*)C=O)cc2)cc1' for input: '*CCCCCCC1SC(=O)(=O)c2ccc(C)c(OC(=O)CCCC(=O)N(*)C=O)cc2)cc1'
[21:24:25] SMILES Parse Error: extra open parentheses while parsing: *CCCCCC(=O)Nc1ccc(Cc2ccc(N*)cc1
[21:24:25] SMILES Parse Error: check for mistakes around position 18:
[21:24:25] *CCCCCC(=O)Nc1ccc(Cc2ccc(N*)cc1
[21:24:25] ~~~~~~~~~~~~~~~~~^
[21:24:25] SMILES Parse Error: Failed parsing SMILES '*CCCCCC(=O)Nc1ccc(Cc2ccc(N*)cc1' for input: '*CCCCCC(=O)Nc1ccc(Cc2ccc(N*)cc1'
[21:24:25] SMILES Parse Error: ext

In [12]:
chems = [Chem.MolFromSmiles(x) for x in valid_smiles]
mols2grid.display(chems)

MolGridWidget()

In [13]:
valid_smiles

['',
 '*CC(*)c1ccc2ccccc2c1',
 '*c1ccc(OC(=O)c2cc(C(=O)Oc3ccc(C4(*)NC(=O)c5ccccc54)cc3)c2)cc1',
 '*CC(*)c1cccc(C)c1',
 '*CCCCCCCC(=O)N*',
 '*CCN(*)C(=O)c1ccccc1',
 '*C(=O)OCCCCCCCCCCCCCCCOC(=O)c1cccc(C(=O)OC(=O)c2ccc(O*)cc2)c1',
 '*CCCCCCC(=O)N*',
 '*CCC(*)(C)C(=O)OCCOC(C)(CC)C',
 '*Oc1ccc(C(C)c2ccc(OC(=O)c3ccc(C(*)=O)cc3)cc2)cc1',
 '*C(=O)Nc1ccc(Oc2ccccc(NC(=O)c3c(C(=O)c4ccc(Oc5ccc(Oc6ccc(N7C(=O)c8ccc(*)cc8C7=O)cc6)cc5)cc3)C4=O)cc2)cc1',
 '*CCCO*',
 '*C(=O)N*',
 '*Oc1ccc(NC(=O)CCCCC(=O)Nc2ccc(*)cc2)cc1',
 '*CCS*',
 '*CC(CSCC)OC(=O)OCCCCC',
 '*N(*)C(=O)CCCC',
 '*CCOC(=O)c2ccccc2C(=O)OCCCCCOc1ccc(-c2ccc(O*)cc2)cc1',
 '*Nc1ccc(CCC(=O)OCCCCOC(=O)c2ccc(*)cc2)cc1',
 '*CCCC(=O)O*',
 '*CC(*)CCCCCCCCCCCCCCCC',
 '*CC(*)CCCC1CCCCC1',
 '*CCCCCOC(=O)NCC(C)C(=O)O*',
 '*Oc1ccc(C(C)(C)c2ccc(OC(*)=O)cc2)cc1',
 '*CCS*',
 '*CCCCOC(=O)OCCCCCOC(=O)O*',
 '*CC(C)(C)CO*',
 '*CCCCCCC(=O)O*',
 '*COc1ccc(C(=O)OCCCCOC(=O)c2ccc(O*)cc2)c(C)c1',
 '*CC(*)(C)C(=O)OCCCCCCCCCCOc1ccc(C(=O)Oc2ccc(C(=O)/C=C/c3cc(O*)ccc3)c