# Use MolDecod

This notebook is for generating molecules using MolDecod. You need the model checkpoint and the tokenizer in the /models folder.

## Setup

In [3]:
# pip install torch sentencepiece

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import sentencepiece as spm

In [5]:
# Use GPU if available, CPU if not
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


## Load tokenizer

In [7]:
sp = spm.SentencePieceProcessor()
sp.load('models/moldecod_tokenizer.model')

True

## Define the model architecture and load MolDecod

In [8]:
class RotaryPositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(RotaryPositionalEncoding, self).__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout)
        
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        self.register_buffer('sin_pos', torch.sin(position * div_term))
        self.register_buffer('cos_pos', torch.cos(position * div_term))

    def forward(self, x):
        seq_len = x.size(1)
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
        x = torch.cat([
            x1 * self.cos_pos[:seq_len] - x2 * self.sin_pos[:seq_len],
            x1 * self.sin_pos[:seq_len] + x2 * self.cos_pos[:seq_len]
        ], dim=-1)
        return self.dropout(x)

class DecoderOnlyTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, dropout=0.1):
        super(DecoderOnlyTransformer, self).__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = RotaryPositionalEncoding(d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward=4*d_model, dropout=dropout, activation='gelu', batch_first=True, norm_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self.dropout = nn.Dropout(p=dropout)
        
    def forward(self, src, src_mask=None):
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, mask=src_mask)
        output = self.fc_out(output)
        return self.dropout(output)

In [9]:
vocab_size = sp.get_piece_size()  # Use the tokenizer to get the vocabulary size
d_model = 256
nhead = 4
num_encoder_layers = 4
dropout = 0.25

model = DecoderOnlyTransformer(vocab_size, d_model, nhead, num_layers=num_encoder_layers, dropout=dropout)
model = model.to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)



In [10]:
# Load the model checkpoint
if torch.cuda.is_available():
    checkpoint = torch.load('models/moldecod_transformer.pth')
else:
    checkpoint = torch.load('models/moldecod_transformer.pth', map_location=torch.device('cpu'))

model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

## Use the model to generate molecules

In [11]:
def create_mask(size):
    mask = torch.triu(torch.ones(size, size) * float('-inf'), diagonal=1)
    return mask

def generate_molecule(model, start_seq, sp_model, max_length=150, temperature=0.7):
    model.eval()
    with torch.no_grad():
        current_seq = start_seq.to(device).unsqueeze(0)  # Add batch dimension
        for _ in range(max_length):
            src_mask = create_mask(current_seq.size(1)).to(device)
            output = model(current_seq, src_mask)
            logits = output[0, -1, :] / temperature  # Select last time step
            next_token_idx = torch.multinomial(torch.softmax(logits, dim=-1), 1).item()
            
            if next_token_idx == sp_model.piece_to_id('<EOS>'):
                break

            next_token_tensor = torch.tensor([[next_token_idx]], device=device)
            current_seq = torch.cat([current_seq, next_token_tensor], dim=1)
    
    # Decode using the tokenizer
    generated_sequence = sp_model.decode_ids(current_seq[0].cpu().tolist())
    return generated_sequence.replace('<SOS>', '', 1)

In [12]:
# Example: Generate a molecule starting with a carbon atom
start_seq = torch.tensor([sp.piece_to_id('<SOS>'), sp.piece_to_id('C')], device=device)  # Start with <SOS> and a carbon atom
generated_molecule = generate_molecule(model, start_seq, sp)
print("Generated molecule:", generated_molecule)

Generated molecule: C(=O)N1CCCc2ccc(OCC(=O)Nc3ccccc3C)cc21


You can also stream the generation

In [15]:
from IPython.display import clear_output, display

def generate_molecule_streaming(model, start_seq, sp_model, max_length=150, temperature=0.7):
    model.eval()
    with torch.no_grad():
        current_seq = start_seq.to(device).unsqueeze(0)  # Add batch dimension
        generated_tokens = []
        
        for _ in range(max_length):
            src_mask = create_mask(current_seq.size(1)).to(device)
            output = model(current_seq, src_mask)
            logits = output[0, -1, :] / temperature  # Select last time step
            next_token_idx = torch.multinomial(torch.softmax(logits, dim=-1), 1).item()
            
            if next_token_idx == sp_model.piece_to_id('<EOS>'):
                break

            # Add the token to the list
            generated_tokens.append(next_token_idx)

            # Display the current sequence
            current_seq_display = sp_model.decode_ids(generated_tokens)
            clear_output(wait=True)
            display(f"Generated molecule: {current_seq_display}")

            next_token_tensor = torch.tensor([[next_token_idx]], device=device)
            current_seq = torch.cat([current_seq, next_token_tensor], dim=1)
    
    # Decode using the tokenizer
    generated_sequence = sp_model.decode_ids(current_seq[0].cpu().tolist())
    return generated_sequence.replace('<SOS>', '', 1)

In [16]:
# Example: Generate a molecule starting with a carbon atom
start_seq = torch.tensor([sp.piece_to_id('<SOS>'), sp.piece_to_id('C')], device=device)  # Start with <SOS> and a carbon atom
generated_molecule = generate_molecule_streaming(model, start_seq, sp)

'Generated molecule: (C)C(CC)NC(=O)Nc1ccc2c(c1)COC2'