In [1]:
import torch
from model import TransformerVAE
from dataset import SMILESDataset
from utils import tokenize, pad_sequence, create_vocab, calculate_max_len
from cfg import Config

In [2]:
def infer(model, input_sequence: str, vocab: dict, max_len: int, device) -> str:
    """Run one sequence through ``model`` and decode its arg-max prediction.

    The input is encoded with the project helpers ``tokenize`` and
    ``pad_sequence``, sent through the model as a batch of one, and the
    highest-scoring id at every position is mapped back to its vocabulary
    symbol.

    Args:
        model: Object called as ``model(input_ids, attention_mask)`` that
            returns three values, the first being per-position scores over
            the vocabulary. It is switched to eval mode and left that way.
        input_sequence: Raw string to encode.
        vocab: Mapping from token string to integer id; must contain the
            key ``"<pad>"``.
        max_len: Target length handed to ``pad_sequence``.
        device: Device the input tensors are moved to.

    Returns:
        The predicted symbols joined into a single string. Positions whose
        predicted id equals the padding id are dropped; ids that are not in
        ``vocab`` are rendered as ``"<unk>"``.
    """
    pad_id = vocab["<pad>"]

    # Reverse lookup used when decoding predicted ids.
    index_to_char = {idx: char for char, idx in vocab.items()}

    # Tokenize, pad, and add the batch dimension.
    tokens = tokenize(input_sequence, vocab)
    padded = pad_sequence(tokens, max_len, pad_id)
    input_ids = torch.tensor(padded).unsqueeze(0).to(device)

    # 1 for real tokens, 0 for padding. A single tensor comparison keeps the
    # work on `device`; iterating over the tensor in Python would pull every
    # element back to the host one at a time.
    attention_mask = (input_ids != pad_id).long()

    # Inference only: disable dropout and gradient tracking.
    model.eval()
    with torch.no_grad():
        logits, _, _ = model(input_ids, attention_mask)

    # Greedy decode: best id per position, batch dimension removed.
    predicted_ids = logits.argmax(dim=-1).squeeze(0).tolist()

    return ''.join(
        index_to_char.get(idx, '<unk>') for idx in predicted_ids if idx != pad_id
    )

In [None]:
def main(
    checkpoint_dir="./checkpoints/20241126-0919",
    input_sequence="C.CCCO ~ O=O > CC(=O)C(C)=O ~ [OH-]",
):
    """Load a trained ``TransformerVAE`` checkpoint and run one inference.

    Builds the dataset named by the configuration to obtain the vocabulary
    and maximum sequence length, instantiates the model with the configured
    hyper-parameters, restores ``best_model.pth`` from ``checkpoint_dir``,
    and prints the input next to the sequence returned by ``infer``.

    Args:
        checkpoint_dir: Directory that contains ``best_model.pth``.
        input_sequence: String to feed through the model.
    """
    # Load configuration
    config = Config()

    # The dataset is only needed for its vocabulary and maximum length.
    dataset = SMILESDataset(config.filepath)
    vocab = dataset.vocab
    max_len = dataset.max_len

    # Re-create the architecture the checkpoint was trained with.
    model = TransformerVAE(
        vocab_size=len(vocab),
        embedding_dim=config.input_dim,
        hidden_dim=config.hidden_dim,
        latent_dim=config.latent_dim,
        num_heads=config.num_heads,
        num_layers=config.num_layers,
        dropout=config.dropout
    ).to(config.device)

    # map_location remaps the stored tensors onto the configured device, so
    # a checkpoint written on a GPU still loads on a CPU-only machine.
    state_dict = torch.load(
        f"{checkpoint_dir}/best_model.pth",
        map_location=config.device,
        weights_only=True,
    )
    model.load_state_dict(state_dict)

    # Perform inference
    predicted_sequence = infer(model, input_sequence, vocab, max_len, config.device)

    print(f"Input Sequence: {input_sequence}")
    print(f"Predicted Sequence: {predicted_sequence}")

# Entry point: run the inference demo when this code executes as the main module.
if __name__ == "__main__":
    main()

Input Sequence: C.CCCO ~ O=O > CC(=O)C(C)=O ~ [OH-]
Predicted Sequence: C.CCCO ~ O=O > CC(=O)C(C)=O ~ [OH-]
