In [None]:
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.nn.functional as F
import torch.nn as nn
import pandas as pd
import random
import torch

In [2]:
class cVAE(nn.Module):
    """Conditional VAE over token sequences with an autoregressive decoder.

    ``encoder(x)`` produces a hidden vector that is projected to the mean and
    log-variance of a diagonal Gaussian. A latent ``z`` is drawn with the
    reparameterization trick and, together with the property vector ``y``,
    is passed to ``decoder`` at every decoding step.

    Args:
        encoder: module mapping ``x`` to a ``[batch, gru_dim]`` tensor
            (expected shape; the encoder is supplied by the caller).
        decoder: module called as ``decoder(token, z, y, hidden)`` that returns
            ``(logits [batch, vocab_size], hidden)`` and exposes ``n_layers``
            and ``gru_size`` for building the initial hidden state.
        device: stored as ``self.device``; new tensors are created on the
            device of the inputs instead.
        latent_dim: size of ``z``.
        gru_dim: width of the encoder output.
        vocab_size: number of output logits per decoding step.
        embedding_dim: width of ``self.embedding``.
        teacher_forcing_ratio: per-step probability of feeding the ground-truth
            token instead of the model's own prediction. Only applied while the
            module is in training mode; in eval mode the decoder always
            consumes its own predictions.
        start_token_idx: vocabulary index of the ``<STR>`` token (default 2).
    """

    def __init__(self, encoder, decoder, device,
                 latent_dim, gru_dim, vocab_size, embedding_dim,
                 teacher_forcing_ratio=0.5, start_token_idx=2):
        super(cVAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        self.latent_dim = latent_dim
        self.gru_dim = gru_dim
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.teacher_forcing_ratio = teacher_forcing_ratio
        self.start_token_idx = start_token_idx

        # NOTE(review): never used in this class -- the decoder embeds tokens
        # itself. Kept so existing state_dicts still load; consider removing.
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # Latent space projections (encoder hidden -> Gaussian parameters)
        self.encoder_mu = nn.Linear(gru_dim, latent_dim)
        self.encoder_logvar = nn.Linear(gru_dim, latent_dim)

    def _sample_latent(self, hidden_encoder):
        """Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).

        Stores ``self.z_mean`` and ``self.z_logvar`` for the KL loss term.
        """
        mu = self.encoder_mu(hidden_encoder)
        logvar = self.encoder_logvar(hidden_encoder)
        sigma = torch.exp(0.5 * logvar)
        # randn_like already matches sigma's device and dtype; moving it to
        # self.device could put it on a different device than mu.
        eps = torch.randn_like(sigma)
        z = mu + sigma * eps

        # Save for loss calculation
        self.z_mean = mu
        self.z_logvar = logvar

        return z

    def forward_decoder(self, z, x, y):
        """Autoregressively decode ``x.size(1)`` steps conditioned on ``z`` and ``y``.

        Args:
            z: latent vectors, ``[batch, latent_dim]``.
            x: ground-truth token indices, ``[batch, seq_len]``; used for the
               output length and for teacher forcing.
            y: property vectors, ``[batch, 1]``.

        Returns:
            Logits ``[batch, seq_len, vocab_size]``. Position 0 is not a
            prediction: it holds a one-hot marker on ``start_token_idx`` so
            argmax decodes it as ``<STR>``. Losses should skip ``t == 0``.
        """
        batch_size, target_len = x.size()
        device = x.device

        outputs = torch.zeros(batch_size, target_len, self.vocab_size, device=device)
        outputs[:, 0, self.start_token_idx] = 1

        # Every sequence starts from <STR>.
        input_token = torch.full((batch_size,), self.start_token_idx,
                                 dtype=torch.long, device=device)

        # Initial decoder hidden state (zeros)
        hidden = torch.zeros(self.decoder.n_layers, batch_size,
                             self.decoder.gru_size, device=device)

        for t in range(1, target_len):
            output, hidden = self.decoder(input_token, z, y, hidden)
            outputs[:, t, :] = output

            # Teacher forcing is a training aid only: in eval mode the decoder
            # must run on its own predictions, otherwise inference/validation
            # results are computed from leaked ground-truth tokens.
            use_teacher = self.training and random.random() < self.teacher_forcing_ratio
            if use_teacher:
                input_token = x[:, t]  # ground truth
            else:
                input_token = output.argmax(1).detach()

        return outputs

    def forward(self, x, y):
        """Encode ``x``, sample ``z`` and decode.

        Args:
            x: token indices, ``[batch, seq_len]``.
            y: property vectors, ``[batch, 1]``.

        Returns:
            Reconstruction logits ``[batch, seq_len, vocab_size]``.
        """
        hidden_encoder = self.encoder(x)  # expected [batch, gru_dim]
        z = self._sample_latent(hidden_encoder)
        return self.forward_decoder(z, x, y)

In [3]:
class GRU_Decoder(nn.Module):
    """Single-step GRU decoder conditioned on a latent vector and a property vector.

    Each call consumes one token per sequence and returns next-token logits.
    The GRU input at every step is ``[embed(token), z, y]`` concatenated on
    the feature axis.

    Args:
        vocab_size: number of token classes (embedding rows and output logits).
        latent_dim: size of ``z``.
        gru_size: GRU hidden size.
        n_layers: number of stacked GRU layers.
        embedding_dim: size of the token embedding.
        property_dim: size of the conditioning vector ``y`` (default 1).
    """

    def __init__(self, vocab_size, latent_dim, gru_size, n_layers, embedding_dim,
                 property_dim=1):
        super(GRU_Decoder, self).__init__()
        self.vocab_size = vocab_size
        self.latent_dim = latent_dim
        self.gru_size = gru_size
        self.n_layers = n_layers
        self.embedding_dim = embedding_dim
        self.property_dim = property_dim

        # Embedding layer for input tokens
        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        # GRU input size: token embedding + latent vector + property vector
        self.gru = nn.GRU(embedding_dim + latent_dim + property_dim, gru_size,
                          n_layers, batch_first=True)

        # Output layer: project GRU hidden state to vocab size
        self.fc = nn.Linear(gru_size, vocab_size)

    def forward(self, input_token, z, y, hidden=None):
        """Run one decoding step.

        Args:
            input_token: token indices, ``[batch]``.
            z: latent vectors, ``[batch, latent_dim]``.
            y: property vectors, ``[batch, property_dim]``; a 1-D ``[batch]``
               tensor is accepted when ``property_dim == 1``.
            hidden: ``[n_layers, batch, gru_size]`` state, or ``None`` for the
               nn.GRU default (zeros).

        Returns:
            ``(logits [batch, vocab_size], hidden [n_layers, batch, gru_size])``.
        """
        token_embed = self.embedding(input_token)  # [batch, embed_dim]

        if y.dim() == 1:
            y = y.unsqueeze(1)
        # A float64 property tensor would otherwise promote the GRU input and
        # clash with the float32 weights.
        y = y.to(dtype=token_embed.dtype)

        # [batch, 1, embed + latent + property]
        decoder_input = torch.cat([token_embed, z, y], dim=1).unsqueeze(1)

        output, hidden = self.gru(decoder_input, hidden)  # output: [batch, 1, gru_size]

        logits = self.fc(output.squeeze(1))  # [batch, vocab_size]
        return logits, hidden

In [None]:
# Load QM9 SMILES
# Load QM9 SMILES from a pickled DataFrame with a "SMILES" column.
# NOTE(review): read_pickle unpickles arbitrary objects -- only load trusted files.
QM9_SMILES_PATH = 'data/RDKit/rdkit_only_valid_smiles_qm9.pkl'
df_qm9 = pd.read_pickle(QM9_SMILES_PATH)
smiles_list = df_qm9["SMILES"].tolist()

In [None]:

# Character-level vocabulary: every distinct character in the SMILES strings,
# sorted for a stable index assignment across runs.
# NOTE: multi-character atoms (e.g. "Cl", "Br") would be split into characters.
charset = sorted({ch for smi in smiles_list for ch in smi})

# Special tokens come first so their indices are fixed: <PAD>=0, <END>=1, <STR>=2
special_tokens = ['<PAD>', '<END>', '<STR>']
vocab_list = special_tokens + charset

# Create token <-> index mappings
token2idx = {tok: idx for idx, tok in enumerate(vocab_list)}
idx2token = {idx: tok for tok, idx in token2idx.items()}

# The cVAE decoding loop starts from index 2; fail loudly if the ordering changes.
assert token2idx['<STR>'] == 2, "decoder expects <STR> at index 2"
assert len(token2idx) == len(vocab_list), "duplicate tokens in vocabulary"

print("Vocabulary size:", len(vocab_list))
print("Example tokens:", vocab_list)

['#', '(', ')', '+', '-', '/', '1', '2', '3', '4', '5', '=', '@', 'C', 'F', 'H', 'N', 'O', '[', '\\', ']']
Vocabulary size: 24
Example tokens: ['<PAD>', '<END>', '<STR>', '#', '(', ')', '+', '-', '/', '1', '2', '3', '4', '5', '=', '@', 'C', 'F', 'H', 'N', 'O', '[', '\\', ']']


In [23]:
# Settings -- small sizes so the smoke test runs instantly.
SEED = 0
random.seed(SEED)        # teacher forcing draws from Python's `random`
torch.manual_seed(SEED)  # dummy data and parameter init draw from torch

batch_size = 4
seq_len = 15                     # sequence length for SMILES
vocab_size = len(vocab_list)     # full character vocabulary incl. special tokens
latent_dim = 16
gru_dim = 32
embedding_dim = 8
n_layers = 1

device = 'cpu'      # or 'cuda' if GPU is available

# Dummy input tokens (batch of sequences)
x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)

# Dummy target property (HOMO-LUMO gap, for example)
y = torch.rand(batch_size, 1, device=device)

In [24]:
# Show the dummy token batch followed by its property targets.
print(x, y, sep="\n")

tensor([[10,  1,  2, 18, 19,  0, 18,  6,  5,  2,  6,  1, 13, 23, 22],
        [17, 17, 17, 13,  0, 13, 17,  4,  0, 17,  5,  3, 17, 22, 14],
        [ 3,  9,  1, 17,  0, 21, 13, 16, 20, 21, 16,  2, 11,  0,  8],
        [ 9,  5, 22, 14,  8,  9, 10,  8, 14,  1, 23, 21, 10, 19, 20]])
tensor([[0.0325],
        [0.5752],
        [0.3310],
        [0.5041]])


In [25]:
class DummyEncoder(nn.Module):
    """Stand-in encoder that ignores token values and returns uniform noise.

    Used only to exercise the cVAE plumbing before a real encoder exists.

    Args:
        gru_dim: width of the returned hidden representation.
    """

    def __init__(self, gru_dim):
        super().__init__()
        self.gru_dim = gru_dim

    def forward(self, x):
        """Return ``U[0, 1)`` noise of shape ``[x.size(0), gru_dim]`` on ``x``'s device."""
        # Allocate on x's device: a bare torch.rand always lands on the CPU and
        # would break the model as soon as it runs on a GPU.
        return torch.rand(x.size(0), self.gru_dim, device=x.device)

In [26]:
# Build the components in decoder -> encoder -> cVAE order so a seeded run
# initialises parameters identically.
shared_dims = dict(
    vocab_size=vocab_size,
    latent_dim=latent_dim,
    embedding_dim=embedding_dim,
)

decoder = GRU_Decoder(gru_size=gru_dim, n_layers=n_layers, **shared_dims).to(device)
encoder = DummyEncoder(gru_dim=gru_dim).to(device)
model = cVAE(
    encoder=encoder,
    decoder=decoder,
    device=device,
    gru_dim=gru_dim,
    teacher_forcing_ratio=0.5,
    **shared_dims
).to(device)

In [27]:
# Inference: eval() clears the `training` flag on every submodule so any
# train-only behaviour is disabled; no_grad() skips building the autograd graph.
model.eval()
with torch.no_grad():
    outputs = model(x, y)  # [batch, seq_len, vocab_size]

print("Output shape:", outputs.shape)

Output shape: torch.Size([4, 15, 24])


In [28]:
# pick the most probable token at each step
pred_tokens = outputs.argmax(-1)
print("Predicted token indices:\n", pred_tokens)

Predicted token indices:
 tensor([[ 2,  3,  3,  3,  3,  3, 13,  3,  3, 13,  3,  3,  3,  3, 13],
        [ 2,  9,  9,  9,  9, 11, 11,  9,  9, 11,  9,  9,  9,  9, 11],
        [ 2,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4],
        [ 2, 12,  3,  6,  6,  3, 18,  6,  3, 18,  3,  3,  3, 18,  3]])


In [51]:
# Decode predicted sample 3, dropping position 0 (the <STR> slot).
# .tolist() yields plain ints and, unlike .numpy(), also works on CUDA tensors.
"".join(idx2token[i] for i in pred_tokens[3, 1:].tolist())

'4#++#H+#H###H#'

In [52]:
# Decode input sample 3 the same way for comparison (position 0 dropped).
# NOTE: the dummy x is random, so position 0 is not <STR> and a real token is dropped.
# .tolist() yields plain ints and, unlike .numpy(), also works on CUDA tensors.
"".join(idx2token[i] for i in x[3, 1:].tolist())

')\\=/12/=<END>][2NO'