# In [None]:
# === 1. Imports ===
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
import selfies as sf
from sklearn.preprocessing import MinMaxScaler
from rdkit import Chem
from rdkit.Chem import Descriptors
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM
import random

# === 2. Dataset Preparation ===
# Load the training split and pull out the two columns used below: the SMILES
# strings and their LogP values.
data = load_dataset("maykcaldas/smiles-transformers", split="train")
smiles_list, logp_list = data['smiles'], data['logP']

# Re-encode every SMILES string as SELFIES. No tokenisation happens at this
# step; the strings are kept as plain text.
selfies_list = list(map(sf.encoder, smiles_list))

# Normalise the property (LogP for now) with a min-max scaler fitted on the
# whole column. The scaler takes 2-D input, hence one single-element row per
# molecule.
scaler = MinMaxScaler()
logp_column = [[value] for value in logp_list]
logp_scaled = scaler.fit_transform(logp_column)

# === 3. SELF-BART Molecule Encoder ===
# Tokenizer and seq2seq model come from the same checkpoint; the model is put
# into evaluation mode immediately after loading.
SELF_BART_CHECKPOINT = "pschwllr/selfies-bart"
tokenizer = AutoTokenizer.from_pretrained(SELF_BART_CHECKPOINT)
self_bart = AutoModelForSeq2SeqLM.from_pretrained(SELF_BART_CHECKPOINT)
self_bart.eval()

class SelfBARTEncoder(nn.Module):
    """Embed SELFIES strings with the encoder of the module-level SELF-BART model.

    The wrapper reuses the module-level ``self_bart`` model and ``tokenizer``.
    The encoder is always run under ``torch.no_grad()``, so the returned
    embeddings carry no gradient and the wrapped weights are not trained
    through this module.
    """

    def __init__(self):
        super().__init__()
        self.model = self_bart
        self.tokenizer = tokenizer

    def forward(self, selfies_batch):
        """Return one embedding vector per input string.

        Args:
            selfies_batch: SELFIES strings to embed (the training loop passes
                list slices).

        Returns:
            A 2-D tensor holding, for every input, the encoder's last hidden
            state at sequence position 0.
        """
        inputs = self.tokenizer(selfies_batch, return_tensors="pt", padding=True, truncation=True)
        # The tokenizer builds its tensors on the CPU. Move them to wherever
        # the model's weights live so the encoder also works after the model
        # has been moved to another device (no-op when everything is on CPU).
        device = next(self.model.parameters()).device
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        with torch.no_grad():
            encoder_outputs = self.model.model.encoder(**inputs)
        return encoder_outputs.last_hidden_state[:, 0, :]  # CLS-like representation

# === 4. Property Encoder ===
class PropertyMLP(nn.Module):
    """Two-layer perceptron that projects property vectors into the latent space.

    Args:
        input_dim: Number of property values per sample.
        latent_dim: Size of the produced embedding.
    """

    def __init__(self, input_dim=1, latent_dim=768):
        super().__init__()
        hidden_dim = 128
        # Layers are created in forward order and kept in an nn.Sequential
        # stored under ``self.model``.
        layers = (
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, latent_dim),
        )
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` (last dimension ``input_dim``) to a ``latent_dim`` embedding."""
        projected = self.model(x)
        return projected

# === 5. Contrastive Loss ===
def contrastive_loss(z_mol, z_prop, temp=0.07):
    """Symmetric cross-entropy loss over pairwise cosine similarities.

    Row ``i`` of ``z_mol`` and row ``i`` of ``z_prop`` form the positive pair;
    all other rows of the batch act as negatives. The loss is the mean of the
    molecule-to-property and property-to-molecule cross-entropies.

    Args:
        z_mol: Molecule embeddings, shape ``(batch, dim)``.
        z_prop: Property embeddings, same shape as ``z_mol``.
        temp: Temperature dividing the similarities before the softmax.

    Returns:
        Scalar tensor with the averaged loss.

    Raises:
        ValueError: If the inputs are not 2-D or their shapes differ.
    """
    if z_mol.dim() != 2 or z_mol.shape != z_prop.shape:
        raise ValueError(
            "z_mol and z_prop must be 2-D tensors of identical shape, got "
            f"{tuple(z_mol.shape)} and {tuple(z_prop.shape)}"
        )

    # Cosine similarity as a matrix product of unit vectors. This avoids the
    # (batch, batch, dim) intermediate that a broadcasted cosine_similarity
    # would allocate. eps matches F.cosine_similarity's default.
    mol_unit = F.normalize(z_mol, dim=1, eps=1e-8)
    prop_unit = F.normalize(z_prop, dim=1, eps=1e-8)
    logits = (mol_unit @ prop_unit.T) / temp

    # Positive pairs sit on the diagonal; build the targets on the same device.
    labels = torch.arange(z_mol.size(0), device=z_mol.device)
    loss_i = F.cross_entropy(logits, labels)
    loss_j = F.cross_entropy(logits.T, labels)
    return (loss_i + loss_j) / 2

# === 6. Training Loop ===
# Only the property MLP's parameters are handed to the optimiser; the molecule
# embeddings are detached before they enter the loss.
mol_enc = SelfBARTEncoder()
prop_enc = PropertyMLP()
opt = torch.optim.Adam(prop_enc.parameters(), lr=1e-3)

NUM_EPOCHS = 5
BATCH_SIZE = 32
batch_starts = range(0, len(selfies_list), BATCH_SIZE)

for epoch in range(NUM_EPOCHS):
    for i in batch_starts:
        # The same window selects the SELFIES strings and their scaled LogP.
        window = slice(i, i + BATCH_SIZE)
        selfies_batch = selfies_list[window]
        props = torch.tensor(logp_scaled[window], dtype=torch.float)

        opt.zero_grad()
        zmol = mol_enc(selfies_batch).detach()
        zprop = prop_enc(props)
        loss = contrastive_loss(zmol, zprop)
        loss.backward()
        opt.step()

        print(f"Epoch {epoch}, Step {i}, Loss {loss.item():.4f}")

# === 7. Inference ===
def decode_from_latent(z):
    """Generate a SELFIES string with the SELF-BART decoder conditioned on ``z``.

    The latent is handed to ``self_bart.generate`` as a precomputed encoder
    output, so the decoder cross-attends to ``z`` instead of to an encoded
    text prompt.

    Args:
        z: Latent tensor of shape ``(hidden,)``, ``(batch, hidden)`` or
            ``(batch, seq, hidden)``. ``hidden`` must equal the model's
            hidden size.

    Returns:
        The decoded text of the first sequence in the batch, with special
        tokens removed.

    Raises:
        ValueError: If ``z`` has an unsupported rank or its last dimension
            does not match the model's hidden size.

    NOTE(review): the decoder was not trained on single-vector encoder
    states; generation quality presumably depends on how well ``z`` matches
    the encoder's own hidden states. Confirm empirically.
    """
    # Local import: the file's top-level import block does not expose it.
    from transformers.modeling_outputs import BaseModelOutput

    latent = z.detach()
    if latent.dim() == 1:
        latent = latent.unsqueeze(0)  # single vector -> batch of one
    if latent.dim() == 2:
        latent = latent.unsqueeze(1)  # one encoder position per sample
    if latent.dim() != 3:
        raise ValueError(
            f"z must have 1, 2 or 3 dimensions, got shape {tuple(z.shape)}"
        )

    hidden_size = getattr(self_bart.config, "d_model", None)
    if hidden_size is not None and latent.size(-1) != hidden_size:
        raise ValueError(
            f"z has hidden size {latent.size(-1)}, but the model expects {hidden_size}"
        )

    # Match the model's device and dtype so generation works wherever the
    # model has been placed.
    reference = next(self_bart.parameters())
    latent = latent.to(device=reference.device, dtype=reference.dtype)
    attention_mask = torch.ones(latent.shape[:2], dtype=torch.long, device=latent.device)

    with torch.no_grad():
        generated_ids = self_bart.generate(
            encoder_outputs=BaseModelOutput(last_hidden_state=latent),
            attention_mask=attention_mask,
            max_length=40,
        )
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return decoded

# Embed a single property value (0.6) with the property encoder, pass the
# embedding to decode_from_latent, and report the result as SELFIES and as
# SMILES.
test_prop = torch.tensor([[0.6]])
z_test = prop_enc(test_prop)
generated_selfies = decode_from_latent(z_test)
print("Generated SELFIES:", generated_selfies)
decoded_smiles = sf.decoder(generated_selfies)
print("Decoded SMILES:", decoded_smiles)

# === 8. Evaluation ===
def is_valid(smiles):
    """Return True when ``smiles`` parses into a molecule with at least one atom.

    Args:
        smiles: Candidate SMILES string. Non-string values (e.g. ``None``)
            and empty/whitespace-only strings are reported as invalid.

    Returns:
        ``True`` if RDKit parses the string into a non-empty molecule,
        ``False`` otherwise.
    """
    # RDKit parses the empty string into an atom-less molecule rather than
    # returning None, so it has to be rejected explicitly.
    if not isinstance(smiles, str) or not smiles.strip():
        return False
    mol = Chem.MolFromSmiles(smiles)
    return mol is not None and mol.GetNumAtoms() > 0

def compute_logp(smiles):
    """Return the RDKit ``MolLogP`` value for ``smiles``, or None if unusable.

    Args:
        smiles: SMILES string of the molecule.

    Returns:
        The value of ``Descriptors.MolLogP`` for the parsed molecule, or
        ``None`` when the input is not a string, is empty, fails to parse,
        or parses into a molecule without atoms.
    """
    if not isinstance(smiles, str) or not smiles.strip():
        return None
    mol = Chem.MolFromSmiles(smiles)
    # Explicit None check instead of relying on the truthiness of a Mol
    # object; an atom-less molecule has no meaningful LogP either.
    if mol is None or mol.GetNumAtoms() == 0:
        return None
    return Descriptors.MolLogP(mol)
