<a href="https://colab.research.google.com/github/MinaAzizii/Master-Thesis/blob/Generative-Model/VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:


 !pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
 !pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
 !pip install torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
 !pip install torch-geometric


In [None]:
 !pip install pandas openpyxl seaborn matplotlib tqdm scikit-learn
 !pip install numpy==1.24.4

In [None]:
!pip install rdkit-pypi

In [None]:
# PYTHON SCRIPT INITIALIZATION:
# Imports required libraries for deep learning (PyTorch), data handling (pandas, NumPy), data preprocessing (scikit-learn), chemical structure parsing (RDKit), serialization (pickle), and reproducibility (random).

import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from rdkit import Chem
import pickle
from torch_geometric.data import Data
from torch_geometric.nn import NNConv, global_mean_pool
from rdkit.Chem import Descriptors


In [None]:
# Set seeds for reproducibility; use the GPU when one is available, otherwise the CPU.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
# A bare torch.device("cuda") makes every later `.to(device)` fail on a CPU-only runtime;
# fall back to CPU the same way the prediction cells do.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# ==== LOAD DATA ====
# Drop rows missing any required column, then keep only rows whose molecule
# AND solvent SMILES both parse with RDKit.
df = pd.read_excel("DB2.xlsx").dropna(subset=["smiles", "solvent", "abs"])
is_valid_mol = df["smiles"].map(lambda smi: Chem.MolFromSmiles(smi) is not None)
is_valid_sol = df["solvent"].map(lambda smi: Chem.MolFromSmiles(smi) is not None)
df = df.loc[is_valid_mol & is_valid_sol].reset_index(drop=True)


In [None]:
# BUILD SMILES VOCABULARY:
# - Constructs character-level vocabulary from SMILES strings.
# - Adds special tokens: padding, start/end of sequence, unknown character.
# - Provides mappings from characters to indices (stoi) and vice versa (itos).
# - Defines functions for encoding SMILES as fixed-length (padded) lists of token indices and decoding index sequences back to SMILES.
# - Computes the maximum SMILES length (+2 for special tokens) for consistent tensor padding.
def build_vocab(smiles_list):
    """Build a character-level vocabulary from an iterable of SMILES strings.

    Returns ``(tokens, stoi, itos)``:
      * ``tokens`` — the special tokens ``<PAD>``, ``<SOS>``, ``<EOS>``,
        ``<UNK>`` (indices 0-3), followed by every distinct character in
        sorted order;
      * ``stoi`` — token -> index;
      * ``itos`` — index -> token.
    """
    charset = set()
    for smi in smiles_list:
        charset |= set(smi)
    specials = ['<PAD>', '<SOS>', '<EOS>', '<UNK>']
    tokens = specials + sorted(charset)
    itos = dict(enumerate(tokens))
    stoi = {tok: idx for idx, tok in itos.items()}
    return tokens, stoi, itos

def smiles_to_tensor(smiles, stoi, max_len):
    """Encode a SMILES string as a fixed-length list of token indices.

    Layout: ``<SOS>``, one index per character (characters missing from
    ``stoi`` map to ``<UNK>``), ``<EOS>``, then ``<PAD>`` up to ``max_len``.
    The result is cut to ``max_len`` entries, so an over-long string loses its
    tail (including ``<EOS>``). Despite the name, a plain list is returned.
    """
    ids = [stoi['<SOS>'], *(stoi.get(ch, stoi['<UNK>']) for ch in smiles), stoi['<EOS>']]
    n_pad = max(0, max_len - len(ids))
    padded = ids + [stoi['<PAD>']] * n_pad
    return padded[:max_len]

def tensor_to_smiles(tensor, itos):
    """Decode a sequence of token indices back into a SMILES string.

    Decoding stops at the first ``<EOS>``. ``<PAD>`` and ``<SOS>`` are
    skipped, and every other token (``<UNK>`` included) is emitted verbatim.
    Each element goes through ``int()``, so lists, tensors or NumPy arrays of
    indices are all accepted.
    """
    pieces = []
    for tok in (itos[int(i)] for i in tensor):
        if tok == '<EOS>':
            break
        if tok in ('<PAD>', '<SOS>'):
            continue
        pieces.append(tok)
    return ''.join(pieces)

smiles_list = df['smiles'].tolist()
tokens, stoi, itos = build_vocab(smiles_list)
# Character length of each training SMILES; the padded length reserves two
# extra positions for the <SOS> and <EOS> markers.
lengths = list(map(len, smiles_list))
max_len = max(lengths) + 2
print(f"Max SMILES length in data: {max(lengths)}. Using max_len={max_len} for padding.")

In [None]:
# ==== CONDITION: Lambda Absorption + Solvent One-Hot ====
# Each row's condition vector is [standardized absorption, solvent one-hot ...].
solv_enc = OneHotEncoder(sparse_output=False).fit(df[["solvent"]])
solv_oh = solv_enc.transform(df[["solvent"]])
scaler_abs = StandardScaler().fit(df[["abs"]])  # For target conditioning

cond_array = np.concatenate(
    [scaler_abs.transform(df[["abs"]]), solv_oh],  # (N, 1) scaled abs | (N, n_solvents) one-hot
    axis=1,
)


In [None]:
# MOLECULE DATASET AND DATALOADER:
# - Defines custom PyTorch dataset class to encode SMILES strings as padded tensor sequences.
# - Each dataset item consists of encoded SMILES tensor and associated conditional features.
# - Creates DataLoader for efficient batch training.
class MolDataset(Dataset):
    """Map-style dataset pairing encoded SMILES with their condition vectors.

    Item ``i`` is ``(tokens, cond)``:
      * ``tokens`` — a ``torch.long`` tensor built from
        ``smiles_to_tensor(smiles_list[i], stoi, max_len)``;
      * ``cond`` — ``cond_list[i]`` as a ``torch.float`` tensor.
    """

    def __init__(self, smiles_list, cond_list, stoi, max_len):
        self.data, self.cond = smiles_list, cond_list
        self.stoi, self.max_len = stoi, max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        encoded = smiles_to_tensor(self.data[idx], self.stoi, self.max_len)
        tokens = torch.tensor(encoded, dtype=torch.long)
        cond = torch.tensor(self.cond[idx], dtype=torch.float)
        return tokens, cond

# Whole training set as one dataset, served in shuffled mini-batches of 128.
dataset = MolDataset(smiles_list, cond_array, stoi, max_len)
loader = DataLoader(dataset, shuffle=True, batch_size=128)

In [None]:
# CONDITIONAL SMILES VAE MODEL:
# - Defines SMILES-based Variational Autoencoder (VAE) using GRU encoder-decoder.
# - Encodes SMILES sequences to latent space and decodes back to sequences.
# - Conditions latent vectors on external features (absorption & solvent).
# - Uses reparameterization trick for differentiable sampling.
# - Trains using CrossEntropy loss (reconstruction) plus KLD (latent regularization).
# - Implements teacher forcing (80%) for efficient training.
# - Trains for 40 epochs, saving trained model, vocabulary, and scalers.
class SMILESVAE(nn.Module):
    def __init__(self, vocab_size, emb_dim=128, hidden_dim=256, latent_dim=128, max_len=120, cond_dim=0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.encoder = nn.GRU(emb_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc_mu = nn.Linear(hidden_dim*2, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim*2, latent_dim)
        self.fc_z = nn.Linear(latent_dim+cond_dim, hidden_dim*2)
        self.decoder = nn.GRU(emb_dim, hidden_dim*2, batch_first=True)
        self.output = nn.Linear(hidden_dim*2, vocab_size)
        self.max_len = max_len

    def encode(self, x):
        x = self.embedding(x)
        _, h = self.encoder(x)
        h = h.transpose(0,1).reshape(x.size(0), -1)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(std)

    def decode(self, z, cond=None, x=None, teacher_forcing=0.5):
        if cond is not None:
            z = torch.cat([z, cond], dim=1)
        batch = z.size(0)
        hidden = self.fc_z(z).unsqueeze(0)
        input_token = torch.full((batch, 1), 1, device=z.device, dtype=torch.long)  # <SOS>
        outputs = []
        for t in range(self.max_len):
            emb = self.embedding(input_token)
            out, hidden = self.decoder(emb, hidden)
            logits = self.output(out.squeeze(1))
            outputs.append(logits)
            if x is not None and torch.rand(1).item() < teacher_forcing:
                input_token = x[:, t].unsqueeze(1)
            else:
                input_token = logits.argmax(dim=-1, keepdim=True)
        return torch.stack(outputs, dim=1)

    def forward(self, x, cond=None, teacher_forcing=0.5):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        out = self.decode(z, cond, x, teacher_forcing)
        return out, mu, logvar

vae_cond_dim = cond_array.shape[1]  # 1 scaled wavelength + solvent one-hot columns
vae = SMILESVAE(len(tokens), max_len=max_len, cond_dim=vae_cond_dim)
vae.to(device)
opt = torch.optim.Adam(vae.parameters(), lr=1e-3)
# Padding positions are excluded from the reconstruction loss.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=stoi['<PAD>'])

print("Training Conditional SMILES VAE (lambda absorption + solvent)...")
for epoch in range(40):
    vae.train()
    total_loss = 0
    for batch, cond in loader:
        batch = batch.to(device)
        cond = cond.to(device)
        opt.zero_grad()
        # teacher_forcing=0.8: at each decoding step the ground-truth token is fed with p=0.8.
        out, mu, logvar = vae(batch, cond, teacher_forcing=0.8)
        # Flatten (B, max_len, V) logits against (B, max_len) targets.
        loss_recon = loss_fn(out.view(-1, out.size(-1)), batch.view(-1))
        # KL(q(z|x) || N(0, I)), summed over latent dims and averaged over the batch.
        kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch.size(0)
        loss = loss_recon + 0.1 * kld  # fixed KL weight of 0.1
        loss.backward()
        opt.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}: Loss={total_loss/len(loader):.3f}")

# Persist weights, vocabulary and condition encoders for the generation cell.
torch.save(vae.state_dict(), "smiles_vae_lambdaabs.pt")
with open("vocab_lambdaabs.pkl", "wb") as f:
    pickle.dump((tokens, stoi, itos, max_len), f)
# Context manager guarantees the file is flushed and closed (the original left the handle open).
with open("cond_scalers_lambdaabs.pkl", "wb") as f:
    pickle.dump({"solv_enc": solv_enc, "scaler_abs": scaler_abs}, f)
print("Saved VAE and vocab/scalers.")

In [None]:
# SMILES GENERATION USING TRAINED VAE:
# - Loads trained Conditional SMILES VAE, vocab, and scalers.
# - Defines target absorption wavelength ranges and solvents.
# - Generates specified number of molecules per solvent and wavelength.
# - Saves generated molecules (SMILES) with associated solvent and target λ-absorption values to CSV.


# Reload everything the training cell saved, so this cell does not depend on
# training-time globals (only `df`, from the data-loading cell, is reused).
with open("vocab_lambdaabs.pkl", "rb") as f:
    tokens, stoi, itos, max_len = pickle.load(f)
with open("cond_scalers_lambdaabs.pkl", "rb") as f:
    scalers = pickle.load(f)
solv_enc = scalers["solv_enc"]
scaler_abs = scalers["scaler_abs"]

# Condition width = 1 scaled wavelength + one one-hot column per solvent category the
# encoder was fitted on (the same width as cond_array in the training cell).
vae_cond_dim = 1 + sum(len(cats) for cats in solv_enc.categories_)
# Fall back to CPU instead of hard-coding "cuda".
gen_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae = SMILESVAE(len(tokens), max_len=max_len, cond_dim=vae_cond_dim)
vae.load_state_dict(torch.load("smiles_vae_lambdaabs.pt", map_location="cpu"))
vae = vae.to(gen_device)
vae.eval()

# --- User-defined generation targets ---
solvents_list = df['solvent'].unique()  # solvents present in the training data (`df`)

ranges = [
    (400, 410, "400_410"),
    (490, 500, "490_500"),
    (550, 560, "550_560"),
    (640, 650, "640_650"),
]

num_per_condition = 100  # molecules per solvent per absorption target
batch_size = 20
points_per_range = 20  # evenly spaced absorption targets within each range

# Each condition's quota as full batches plus, if needed, one smaller remainder batch.
chunk_sizes = [batch_size] * (num_per_condition // batch_size)
if num_per_condition % batch_size:
    chunk_sizes.append(num_per_condition % batch_size)

all_generated = []


def _sample_smiles(n, cond_row):
    """Greedy-decode ``n`` SMILES from z ~ N(0, I) under one (1, cond_dim) condition row."""
    z = torch.randn(n, vae.fc_mu.out_features, device=gen_device)
    with torch.no_grad():
        logits = vae.decode(z, cond=cond_row.repeat(n, 1))
    preds = logits.argmax(dim=-1).cpu().numpy()
    return [tensor_to_smiles(pred, itos) for pred in preds]


for (low, high, range_label) in ranges:
    lambda_abs_range = np.linspace(low, high, points_per_range)
    print(f"\nGenerating for range: {low}-{high} nm")
    for i, solvent in enumerate(solvents_list):
        print(f"  Solvent {i+1}/{len(solvents_list)}: {solvent}")
        solv_oh = solv_enc.transform(pd.DataFrame([[solvent]], columns=["solvent"]))[0]
        for j, target_abs in enumerate(lambda_abs_range):
            if j % 10 == 0:
                print(f"    Target abs {j+1}/{len(lambda_abs_range)}: {target_abs:.2f}")
            target_cond_np = np.hstack([
                scaler_abs.transform(pd.DataFrame([[target_abs]], columns=["abs"]))[0],
                solv_oh
            ])
            target_cond = torch.tensor(target_cond_np, dtype=torch.float, device=gen_device).unsqueeze(0)
            for chunk in chunk_sizes:
                for smi in _sample_smiles(chunk, target_cond):
                    all_generated.append((solvent, float(target_abs), smi, range_label))

print(f"\nGeneration complete! Total molecules: {len(all_generated)}")

# Save as a DataFrame and CSV
gen_df = pd.DataFrame(all_generated, columns=['solvent', 'target_abs', 'smiles', 'abs_range_label'])
gen_df.to_csv("generated_molecules_multiple_ranges.csv", index=False)
print("Saved all generated molecules to generated_molecules_multiple_ranges.csv")

In [None]:
# VALIDATE AND FILTER GENERATED MOLECULES:
def is_valid(smi):
    """Return True if ``smi`` is a non-empty string that RDKit can parse.

    Non-string values and blank strings are rejected up front. Non-string values
    include the NaN that ``pd.read_csv`` yields for an empty generated SMILES.
    RDKit raises on non-string input and parses ``""`` into an empty (atom-less)
    molecule, which should not count as valid.
    """
    if not isinstance(smi, str) or not smi.strip():
        return False
    return Chem.MolFromSmiles(smi) is not None

def canonicalize(smi):
    """Return RDKit's canonical SMILES for ``smi``, or None if it cannot be parsed.

    Non-string and blank inputs return None instead of raising inside RDKit or
    producing an empty canonical string.
    """
    if not isinstance(smi, str) or not smi.strip():
        return None
    mol = Chem.MolFromSmiles(smi)
    return Chem.MolToSmiles(mol) if mol is not None else None

# --- Load DB2 and prepare canonical set ---
df_db2 = pd.read_excel("DB2.xlsx")
db2_smiles = df_db2["smiles"].dropna()
db2_cano_set = set(filter(None, (canonicalize(s) for s in db2_smiles)))

# --- Load generated molecules ---
df_gen = pd.read_csv("generated_molecules_multiple_ranges.csv")
# An empty decoded SMILES is written as an empty CSV field and read back as NaN,
# which RDKit cannot accept; drop those rows before validation.
df_gen = df_gen.dropna(subset=["smiles"])

# ✅ Step 1: Keep only valid SMILES
df_gen["is_valid"] = df_gen["smiles"].apply(is_valid)
df_gen_valid = df_gen[df_gen["is_valid"]].copy()

# ✅ Step 2: Canonicalize valid SMILES
df_gen_valid["canonical"] = df_gen_valid["smiles"].apply(canonicalize)

# ✅ Step 3: Remove duplicates within generated molecules (first occurrence kept,
# regardless of which solvent/target it was generated for)
df_gen_unique = df_gen_valid.drop_duplicates(subset=["canonical"]).reset_index(drop=True)

# ✅ Step 4: Remove SMILES that already exist in training set
df_gen_final = df_gen_unique[~df_gen_unique["canonical"].isin(db2_cano_set)].copy()

# --- Clean and save ---
df_gen_final = df_gen_final.drop(columns=["canonical", "is_valid"])
df_gen_final.to_csv("generated_molecules_multiple_ranges_valid_novel.csv", index=False)

print(f"Saved {len(df_gen_final)} valid, novel, unique molecules to generated_molecules_multiple_ranges_valid_novel.csv")


In [None]:
 #PREDICT λ-ABSORPTION FOR GENERATED MOLECULES:
# Descriptor recipes. mol_funcs yields the 9 features consumed by
# SolvationPredictor.mlp_mol (nn.Linear(9, ...)); solvent_funcs yields the 4
# features consumed by mlp_sol (nn.Linear(4, ...)). compute_descriptors ignores
# the labels and keeps list order, so the order must match the order used when
# the scalers/models were fitted (TODO confirm against the training notebook).
mol_funcs = [
    ("Mol_MolWt", Descriptors.MolWt),
    ("Mol_TPSA", Descriptors.TPSA),
    ("Mol_NumRotatableBonds", Descriptors.NumRotatableBonds),
    ("Mol_LogP", Descriptors.MolLogP),
    ("Mol_Aromaticity", Descriptors.NumAromaticRings),  # NOTE(review): a count of aromatic rings, not an aromaticity score
    ("Mol_NumHDonors", Descriptors.NumHDonors),
    ("Mol_NumHAcceptors", Descriptors.NumHAcceptors),
    ("Mol_FractionCSP3", Descriptors.FractionCSP3),
    ("Mol_HeteroatomCount", Descriptors.HeavyAtomCount),  # NOTE(review): heavy-atom count, not heteroatoms; label is cosmetic only
]
solvent_funcs = [
    ("Solv_MolWt", Descriptors.MolWt),
    ("Solv_TPSA", Descriptors.TPSA),
    ("Solv_MolLogP", Descriptors.MolLogP),
    ("Solv_NumHDonors", Descriptors.NumHDonors),
]

def compute_descriptors(smiles, func_list):
    """Evaluate each RDKit descriptor function in ``func_list`` on ``smiles``.

    Args:
        smiles: SMILES string of the molecule or solvent.
        func_list: sequence of ``(label, fn)`` pairs; only ``fn`` is used and
            the output keeps the list's order.

    Returns:
        1-D float array with one value per descriptor. A descriptor that raises,
        or returns a non-finite value, is recorded as 0.0 (best effort).

    Raises:
        ValueError: if RDKit cannot parse ``smiles``. Otherwise every descriptor
            would fail on ``None`` and the row would silently become all zeros.
    """
    m = Chem.MolFromSmiles(smiles)
    if m is None:
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    vals = []
    for _, fn in func_list:
        try:
            v = fn(m)
            vals.append(v if np.isfinite(v) else 0.0)
        except Exception:  # best-effort per descriptor; no longer traps KeyboardInterrupt/SystemExit
            vals.append(0.0)
    return np.array(vals, dtype=float)

def mol_to_graph(smiles):
    """Convert a SMILES string into a PyG ``Data`` graph over its heavy atoms.

    Nodes have float features: atomic number, formal charge and
    ``GetNumExplicitHs()``. Atoms with atomic number <= 1 are dropped, unless
    that would leave no atoms, in which case every atom is kept.

    Edges are the bonds between kept atoms, stored in both directions. Each edge
    carries a 4-way one-hot bond type (single, double, triple, aromatic); any
    other bond type gets an all-zero attribute.

    Raises:
        ValueError: if RDKit cannot parse ``smiles`` (instead of an opaque
            AttributeError on ``None``).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    atoms = mol.GetAtoms()
    heavy = [i for i, a in enumerate(atoms) if a.GetAtomicNum() > 1] or list(range(len(atoms)))
    # Original atom index -> node row; doubles as the O(1) membership test below
    # (the list scan `i in heavy` was O(n) per bond).
    idx_map = {old: i for i, old in enumerate(heavy)}
    x = torch.tensor([[atoms[i].GetAtomicNum(), atoms[i].GetFormalCharge(), atoms[i].GetNumExplicitHs()]
                      for i in heavy], dtype=torch.float)
    bond_types = (Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                  Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC)
    edges, attrs = [], []
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        if i in idx_map and j in idx_map:
            ei, ej = idx_map[i], idx_map[j]
            onehot = [int(b.GetBondType() == t) for t in bond_types]
            edges += [[ei, ej], [ej, ei]]
            attrs += [onehot, onehot]
    if not edges:
        # Correctly shaped empty tensors for molecules without kept bonds (e.g. a single ion).
        edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, 4), dtype=torch.float)
    else:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(attrs, dtype=torch.float)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

# --- Model definition (must match training!) ---
# Attribute names and layer sizes define the state_dict layout of the saved
# checkpoints, so they must stay exactly as in the training code.
class GCNEncoder(nn.Module):
    """Three stacked edge-conditioned convolutions (PyG ``NNConv``) + mean pooling.

    Args:
        in_dim: number of node features per atom.
        hid: hidden width of the first two convolutions.
        out_dim: width of the pooled per-graph embedding.
    """
    def __init__(self, in_dim, hid, out_dim):
        super().__init__()
        # Edge networks: map the 4-dim bond-type edge attribute to a flattened
        # (in_channels * out_channels) weight for the matching NNConv layer.
        self.e1 = nn.Sequential(nn.Linear(4, hid * in_dim), nn.ReLU(), nn.Linear(hid * in_dim, hid * in_dim))
        self.e2 = nn.Sequential(nn.Linear(4, hid * hid), nn.ReLU(), nn.Linear(hid * hid, hid * hid))
        self.e3 = nn.Sequential(nn.Linear(4, out_dim * hid), nn.ReLU(), nn.Linear(out_dim * hid, out_dim * hid))
        # Neighbour messages are mean-aggregated.
        self.c1 = NNConv(in_dim, hid, self.e1, aggr='mean')
        self.c2 = NNConv(hid, hid, self.e2, aggr='mean')
        self.c3 = NNConv(hid, out_dim, self.e3, aggr='mean')
    def forward(self, x, ei, batch, ea):
        """Embed a batch of graphs.

        Args:
            x: node feature matrix.
            ei: edge_index.
            batch: graph id of each node (used for pooling).
            ea: edge attributes (bond-type one-hots).

        Returns:
            One ``out_dim`` vector per graph: the mean of its final node states.
        """
        x = self.c1(x, ei, ea).relu()
        x = self.c2(x, ei, ea).relu()
        x = self.c3(x, ei, ea).relu()
        return global_mean_pool(x, batch)

class SolvationPredictor(nn.Module):
    """Scalar regressor for a (molecule, solvent) pair.

    The molecule graph (GCN hidden width 64) and the solvent graph (hidden width
    32) are each pooled to 64 dims. ``mlp_gcn`` fuses the two embeddings to 64.
    That result is concatenated with the 64-dim embeddings of the solvent
    descriptors (4 inputs) and molecule descriptors (9 inputs), and ``fuse``
    maps the concatenation to one output per pair.

    NOTE: layer names and sizes must match the saved checkpoints; do not change.
    """
    def __init__(self):
        super().__init__()
        self.gcn_mol = GCNEncoder(3, 64, 64)  # 3 node features, as produced by mol_to_graph
        self.gcn_sol = GCNEncoder(3, 32, 64)
        self.mlp_sol = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 64))  # len(solvent_funcs) inputs
        self.mlp_mol = nn.Sequential(nn.Linear(9, 64), nn.ReLU(), nn.Linear(64, 64))  # len(mol_funcs) inputs
        self.mlp_gcn = nn.Sequential(
            nn.Linear(64 + 64, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU()
        )
        self.fuse = nn.Sequential(
            nn.Linear(64 + 64 + 64, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)
        )
    def forward(self, mg, sg, sdesc, mdesc):
        """Predict one value per (molecule, solvent) pair.

        Args:
            mg: molecule graph batch (``x``, ``edge_index``, ``edge_attr``, ``batch``).
            sg: solvent graph batch (same fields).
            sdesc: scaled SOLVENT descriptors; note they come before ``mdesc``.
            mdesc: scaled MOLECULE descriptors.

        Returns:
            Tensor of shape (batch,) (the trailing size-1 dim is squeezed).
        """
        me = self.gcn_mol(mg.x, mg.edge_index, mg.batch, mg.edge_attr)
        se = self.gcn_sol(sg.x, sg.edge_index, sg.batch, sg.edge_attr)
        sf = self.mlp_sol(sdesc)
        mf = self.mlp_mol(mdesc)
        gcn_cat = torch.cat([me, se], dim=-1)
        gcn_out = self.mlp_gcn(gcn_cat)
        cat = torch.cat([gcn_out, sf, mf], dim=-1)
        return self.fuse(cat).squeeze(-1)

# --- Load model and scalers ---
# Absorption model, run on the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = SolvationPredictor().to(device)
model.load_state_dict(torch.load('model_abs (1).pt', map_location=device))
model.eval()  # inference only: disables the Dropout layers

# Training-time scalers for the solvent descriptors, the molecule descriptors
# and the absorption target, in that order.
with open("solv_scaler.pkl", "rb") as f:
    solv_scaler = pickle.load(f)

with open("mol_scaler.pkl", "rb") as f:
    mol_scaler = pickle.load(f)

with open("target_scaler.pkl", "rb") as f:
    target_scaler = pickle.load(f)

# --- Load input dataset ---
generated_df = pd.read_csv("generated_molecules_multiple_ranges_valid_novel.csv")  # All columns preserved!

# Solvent inputs depend only on the solvent SMILES, and the same few solvents repeat
# across many rows, so each solvent is featurized once and reused.
solvent_inputs_cache = {}


def _solvent_inputs(solvent):
    """Return ``(graph, scaled_descriptors)`` for a solvent SMILES, memoized.

    The graph is moved to ``device`` and given an all-zero ``batch`` vector (a
    single graph). Descriptors are scaled with ``solv_scaler``. Failures are not
    cached, so they are re-raised for every row with that solvent.
    """
    if solvent not in solvent_inputs_cache:
        graph = mol_to_graph(solvent).to(device)
        graph.batch = torch.zeros(graph.num_nodes, dtype=torch.long, device=device)
        desc = compute_descriptors(solvent, solvent_funcs)
        desc_scaled = torch.tensor(solv_scaler.transform([desc])[0], dtype=torch.float).unsqueeze(0).to(device)
        solvent_inputs_cache[solvent] = (graph, desc_scaled)
    return solvent_inputs_cache[solvent]


# --- Predict lambda absorption for each row ---
preds = []
for smi, solvent in zip(generated_df['smiles'], generated_df['solvent']):
    try:
        # Molecule graph (single graph => batch of 0s) and scaled descriptors
        mol_graph = mol_to_graph(smi).to(device)
        mol_graph.batch = torch.zeros(mol_graph.num_nodes, dtype=torch.long, device=device)
        mol_desc = compute_descriptors(smi, mol_funcs)
        mol_desc_scaled = torch.tensor(mol_scaler.transform([mol_desc])[0], dtype=torch.float).unsqueeze(0).to(device)
        solvent_graph, solvent_desc_scaled = _solvent_inputs(solvent)
        # Predict in scaled target space, then map back to the original target units
        with torch.no_grad():
            pred_scaled = model(mol_graph, solvent_graph, solvent_desc_scaled, mol_desc_scaled).cpu().numpy().item()
        preds.append(target_scaler.inverse_transform([[pred_scaled]])[0, 0])
    except Exception as e:
        print(f"Failed for SMILES: {smi}, Solvent: {solvent}, error: {e}")
        preds.append(np.nan)

# --- Add predictions as new column and save ---
generated_df["pred_lambda_abs"] = preds
generated_df.to_csv("generated_molecules_with_pred_lambda_abs.csv", index=False)
print("Saved predictions to generated_molecules_with_pred_lambda_abs.csv")
print(generated_df.head())

In [None]:


#Predict other optical properties
# Descriptor recipes (redefined for this cell). mol_funcs yields the 9 features
# consumed by SolvationPredictor.mlp_mol (nn.Linear(9, ...)); solvent_funcs yields
# the 4 features consumed by mlp_sol (nn.Linear(4, ...)). compute_descriptors
# ignores the labels and keeps list order, so the order must match the order
# used when each property's scalers/models were fitted (TODO confirm).
mol_funcs = [
    ("Mol_MolWt", Descriptors.MolWt),
    ("Mol_TPSA", Descriptors.TPSA),
    ("Mol_NumRotatableBonds", Descriptors.NumRotatableBonds),
    ("Mol_LogP", Descriptors.MolLogP),
    ("Mol_Aromaticity", Descriptors.NumAromaticRings),  # NOTE(review): a count of aromatic rings, not an aromaticity score
    ("Mol_NumHDonors", Descriptors.NumHDonors),
    ("Mol_NumHAcceptors", Descriptors.NumHAcceptors),
    ("Mol_FractionCSP3", Descriptors.FractionCSP3),
    ("Mol_HeteroatomCount", Descriptors.HeavyAtomCount),  # NOTE(review): heavy-atom count, not heteroatoms; label is cosmetic only
]
solvent_funcs = [
    ("Solv_MolWt", Descriptors.MolWt),
    ("Solv_TPSA", Descriptors.TPSA),
    ("Solv_MolLogP", Descriptors.MolLogP),
    ("Solv_NumHDonors", Descriptors.NumHDonors),
]

def compute_descriptors(smiles, func_list):
    """Evaluate each RDKit descriptor function in ``func_list`` on ``smiles``.

    Args:
        smiles: SMILES string of the molecule or solvent.
        func_list: sequence of ``(label, fn)`` pairs; only ``fn`` is used and
            the output keeps the list's order.

    Returns:
        1-D float array with one value per descriptor. A descriptor that raises,
        or returns a non-finite value, is recorded as 0.0 (best effort).

    Raises:
        ValueError: if RDKit cannot parse ``smiles``. Otherwise every descriptor
            would fail on ``None`` and the row would silently become all zeros.
    """
    m = Chem.MolFromSmiles(smiles)
    if m is None:
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    vals = []
    for _, fn in func_list:
        try:
            v = fn(m)
            vals.append(v if np.isfinite(v) else 0.0)
        except Exception:  # best-effort per descriptor; no longer traps KeyboardInterrupt/SystemExit
            vals.append(0.0)
    return np.array(vals, dtype=float)

def mol_to_graph(smiles):
    """Convert a SMILES string into a PyG ``Data`` graph over its heavy atoms.

    Nodes have float features: atomic number, formal charge and
    ``GetNumExplicitHs()``. Atoms with atomic number <= 1 are dropped, unless
    that would leave no atoms, in which case every atom is kept.

    Edges are the bonds between kept atoms, stored in both directions. Each edge
    carries a 4-way one-hot bond type (single, double, triple, aromatic); any
    other bond type gets an all-zero attribute.

    Raises:
        ValueError: if RDKit cannot parse ``smiles`` (instead of an opaque
            AttributeError on ``None``).
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
    atoms = mol.GetAtoms()
    heavy = [i for i, a in enumerate(atoms) if a.GetAtomicNum() > 1] or list(range(len(atoms)))
    # Original atom index -> node row; doubles as the O(1) membership test below
    # (the list scan `i in heavy` was O(n) per bond).
    idx_map = {old: i for i, old in enumerate(heavy)}
    x = torch.tensor([[atoms[i].GetAtomicNum(), atoms[i].GetFormalCharge(), atoms[i].GetNumExplicitHs()]
                      for i in heavy], dtype=torch.float)
    bond_types = (Chem.rdchem.BondType.SINGLE, Chem.rdchem.BondType.DOUBLE,
                  Chem.rdchem.BondType.TRIPLE, Chem.rdchem.BondType.AROMATIC)
    edges, attrs = [], []
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        if i in idx_map and j in idx_map:
            ei, ej = idx_map[i], idx_map[j]
            onehot = [int(b.GetBondType() == t) for t in bond_types]
            edges += [[ei, ej], [ej, ei]]
            attrs += [onehot, onehot]
    if not edges:
        # Correctly shaped empty tensors for molecules without kept bonds (e.g. a single ion).
        edge_index = torch.zeros((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, 4), dtype=torch.float)
    else:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(attrs, dtype=torch.float)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

# --- 2. MODEL DEFINITION (as used in training) ---
# Attribute names and layer sizes define the state_dict layout of the saved
# checkpoints, so they must stay exactly as in the training code.
class GCNEncoder(nn.Module):
    """Three stacked edge-conditioned convolutions (PyG ``NNConv``) + mean pooling.

    Args:
        in_dim: number of node features per atom.
        hid: hidden width of the first two convolutions.
        out_dim: width of the pooled per-graph embedding.
    """
    def __init__(self, in_dim, hid, out_dim):
        super().__init__()
        # Edge networks: map the 4-dim bond-type edge attribute to a flattened
        # (in_channels * out_channels) weight for the matching NNConv layer.
        self.e1 = nn.Sequential(nn.Linear(4, hid * in_dim), nn.ReLU(), nn.Linear(hid * in_dim, hid * in_dim))
        self.e2 = nn.Sequential(nn.Linear(4, hid * hid), nn.ReLU(), nn.Linear(hid * hid, hid * hid))
        self.e3 = nn.Sequential(nn.Linear(4, out_dim * hid), nn.ReLU(), nn.Linear(out_dim * hid, out_dim * hid))
        # Neighbour messages are mean-aggregated.
        self.c1 = NNConv(in_dim, hid, self.e1, aggr='mean')
        self.c2 = NNConv(hid, hid, self.e2, aggr='mean')
        self.c3 = NNConv(hid, out_dim, self.e3, aggr='mean')
    def forward(self, x, ei, batch, ea):
        """Embed a batch of graphs.

        Args:
            x: node feature matrix.
            ei: edge_index.
            batch: graph id of each node (used for pooling).
            ea: edge attributes (bond-type one-hots).

        Returns:
            One ``out_dim`` vector per graph: the mean of its final node states.
        """
        x = self.c1(x, ei, ea).relu()
        x = self.c2(x, ei, ea).relu()
        x = self.c3(x, ei, ea).relu()
        return global_mean_pool(x, batch)

class SolvationPredictor(nn.Module):
    """Scalar regressor for a (molecule, solvent) pair; one instance per property.

    The molecule graph (GCN hidden width 64) and the solvent graph (hidden width
    32) are each pooled to 64 dims. ``mlp_gcn`` fuses the two embeddings to 64.
    That result is concatenated with the 64-dim embeddings of the solvent
    descriptors (4 inputs) and molecule descriptors (9 inputs), and ``fuse``
    maps the concatenation to one output per pair.

    NOTE: layer names and sizes must match the saved checkpoints; do not change.
    """
    def __init__(self):
        super().__init__()
        self.gcn_mol = GCNEncoder(3, 64, 64)  # 3 node features, as produced by mol_to_graph
        self.gcn_sol = GCNEncoder(3, 32, 64)
        self.mlp_sol = nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 64))  # len(solvent_funcs) inputs
        self.mlp_mol = nn.Sequential(nn.Linear(9, 64), nn.ReLU(), nn.Linear(64, 64))  # len(mol_funcs) inputs
        self.mlp_gcn = nn.Sequential(
            nn.Linear(64 + 64, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU()
        )
        self.fuse = nn.Sequential(
            nn.Linear(64 + 64 + 64, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 1)
        )
    def forward(self, mg, sg, sdesc, mdesc):
        """Predict one value per (molecule, solvent) pair.

        Args:
            mg: molecule graph batch (``x``, ``edge_index``, ``edge_attr``, ``batch``).
            sg: solvent graph batch (same fields).
            sdesc: scaled SOLVENT descriptors; note they come before ``mdesc``.
            mdesc: scaled MOLECULE descriptors.

        Returns:
            Tensor of shape (batch,) (the trailing size-1 dim is squeezed).
        """
        me = self.gcn_mol(mg.x, mg.edge_index, mg.batch, mg.edge_attr)
        se = self.gcn_sol(sg.x, sg.edge_index, sg.batch, sg.edge_attr)
        sf = self.mlp_sol(sdesc)
        mf = self.mlp_mol(mdesc)
        gcn_cat = torch.cat([me, se], dim=-1)
        gcn_out = self.mlp_gcn(gcn_cat)
        cat = torch.cat([gcn_out, sf, mf], dim=-1)
        return self.fuse(cat).squeeze(-1)

# --- 3. LOAD INPUT DATASET ---
input_file = "generated_molecules_with_pred_lambda_abs.csv"
df = pd.read_csv(input_file)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- 4. LOAD ALL MODELS AND SCALERS ---
# Every property has its own SolvationPredictor checkpoint plus the descriptor
# scalers it was trained with; all models run in eval mode (Dropout off).

# Emission (target is scaled; inverse-transformed in predict_emission)
model_em = SolvationPredictor().to(device)
model_em.load_state_dict(torch.load('model_emi.pt', map_location=device))
model_em.eval()
with open("mol_scaler_emission.pkl", "rb") as f:
    mol_scaler_em = pickle.load(f)
with open("solv_scaler_emission.pkl", "rb") as f:
    solv_scaler_em = pickle.load(f)
with open("target_scaler_emission.pkl", "rb") as f:
    target_scaler_em = pickle.load(f)

# Epsilon (no target scaler is loaded; the raw model output is used)
model_eps = SolvationPredictor().to(device)
model_eps.load_state_dict(torch.load('model_epsilon.pt', map_location=device))
model_eps.eval()
with open("mol_scaler_Epsilon.pkl", "rb") as f:
    mol_scaler_eps = pickle.load(f)
with open("solv_scaler_Epsilon.pkl", "rb") as f:
    solv_scaler_eps = pickle.load(f)

# QY (target is NOT scaled)
model_qy = SolvationPredictor().to(device)
model_qy.load_state_dict(torch.load('model_QY.pt', map_location=device))
model_qy.eval()
with open("mol_scaler_QY.pkl", "rb") as f:
    mol_scaler_qy = pickle.load(f)
with open("solv_scaler_QY.pkl", "rb") as f:
    solv_scaler_qy = pickle.load(f)

# --- 5. PREDICTION FUNCTIONS ---
def predict_emission(row):
    """Predict the emission wavelength for one DataFrame row.

    Args:
        row: mapping with ``'smiles'`` (molecule) and ``'solvent'`` SMILES.

    Returns:
        The emission model's output mapped back to target units through
        ``target_scaler_em.inverse_transform``. On any featurization or inference
        failure the error is printed and NaN is returned. NaN (not None) keeps the
        column numeric, as in the absorption-prediction loop.
    """
    smi, solv = row['smiles'], row['solvent']
    try:
        mol_graph = mol_to_graph(smi)
        mol_desc = compute_descriptors(smi, mol_funcs)
        mol_desc_scaled = torch.tensor(mol_scaler_em.transform([mol_desc])[0], dtype=torch.float).unsqueeze(0).to(device)
        solv_graph = mol_to_graph(solv)
        solv_desc = compute_descriptors(solv, solvent_funcs)
        solv_desc_scaled = torch.tensor(solv_scaler_em.transform([solv_desc])[0], dtype=torch.float).unsqueeze(0).to(device)
        mol_graph = mol_graph.to(device)
        solv_graph = solv_graph.to(device)
        # Single graph per input: every node belongs to graph 0.
        mol_graph.batch = torch.zeros(mol_graph.num_nodes, dtype=torch.long, device=device)
        solv_graph.batch = torch.zeros(solv_graph.num_nodes, dtype=torch.long, device=device)
        with torch.no_grad():
            pred_scaled = model_em(mol_graph, solv_graph, solv_desc_scaled, mol_desc_scaled).cpu().numpy().item()
            pred_real = target_scaler_em.inverse_transform([[pred_scaled]])[0, 0]
        return pred_real
    except Exception as e:
        print(f"Emission prediction failed for {smi} in {solv}: {e}")
        return np.nan

def predict_epsilon(row):
    """Predict the molar absorption coefficient term for one DataFrame row.

    Args:
        row: mapping with ``'smiles'`` (molecule) and ``'solvent'`` SMILES.

    Returns:
        The raw model output; no target scaler is applied. It is treated as
        log(ε). NOTE(review): presumably log10 — confirm against the training
        target. On failure the error is printed and NaN (not None) is returned,
        so the column stays numeric.
    """
    smi, solv = row['smiles'], row['solvent']
    try:
        mol_graph = mol_to_graph(smi)
        mol_desc = compute_descriptors(smi, mol_funcs)
        mol_desc_scaled = torch.tensor(mol_scaler_eps.transform([mol_desc])[0], dtype=torch.float).unsqueeze(0).to(device)
        solv_graph = mol_to_graph(solv)
        solv_desc = compute_descriptors(solv, solvent_funcs)
        solv_desc_scaled = torch.tensor(solv_scaler_eps.transform([solv_desc])[0], dtype=torch.float).unsqueeze(0).to(device)
        mol_graph = mol_graph.to(device)
        solv_graph = solv_graph.to(device)
        # Single graph per input: every node belongs to graph 0.
        mol_graph.batch = torch.zeros(mol_graph.num_nodes, dtype=torch.long, device=device)
        solv_graph.batch = torch.zeros(solv_graph.num_nodes, dtype=torch.long, device=device)
        with torch.no_grad():
            pred_logeps = model_eps(mol_graph, solv_graph, solv_desc_scaled, mol_desc_scaled).cpu().numpy().item()
        return pred_logeps
    except Exception as e:
        print(f"Epsilon prediction failed for {smi} in {solv}: {e}")
        return np.nan

def predict_QY(row):
    """Predict the fluorescence quantum yield for one DataFrame row.

    Args:
        row: mapping with ``'smiles'`` (molecule) and ``'solvent'`` SMILES.

    Returns:
        The raw model output (the QY target was not scaled). On failure the
        error is printed and NaN (not None) is returned, so the column stays
        numeric.
    """
    smi, solv = row['smiles'], row['solvent']
    try:
        mol_graph = mol_to_graph(smi)
        mol_desc = compute_descriptors(smi, mol_funcs)
        mol_desc_scaled = torch.tensor(mol_scaler_qy.transform([mol_desc])[0], dtype=torch.float).unsqueeze(0).to(device)
        solv_graph = mol_to_graph(solv)
        solv_desc = compute_descriptors(solv, solvent_funcs)
        solv_desc_scaled = torch.tensor(solv_scaler_qy.transform([solv_desc])[0], dtype=torch.float).unsqueeze(0).to(device)
        mol_graph = mol_graph.to(device)
        solv_graph = solv_graph.to(device)
        # Single graph per input: every node belongs to graph 0.
        mol_graph.batch = torch.zeros(mol_graph.num_nodes, dtype=torch.long, device=device)
        solv_graph.batch = torch.zeros(solv_graph.num_nodes, dtype=torch.long, device=device)
        with torch.no_grad():
            pred_real = model_qy(mol_graph, solv_graph, solv_desc_scaled, mol_desc_scaled).cpu().numpy().item()
        return pred_real
    except Exception as e:
        print(f"QY prediction failed for {smi} in {solv}: {e}")
        return np.nan

# --- 6. APPLY PREDICTIONS ---
print("Predicting emission...")
df['pred_lambda_em'] = df.apply(predict_emission, axis=1)
print("Predicting epsilon...")
df['pred_epsilon'] = df.apply(predict_epsilon, axis=1)
print("Predicting QY...")
df['pred_QY'] = df.apply(predict_QY, axis=1)

# Brightness = ε · Φ. predict_epsilon returns log(ε) (its local is `pred_logeps`), so it
# must be exponentiated first; the previous log(ε) · Φ product is not a brightness.
# NOTE(review): assumes a base-10 log — confirm against the epsilon training target.
# pd.to_numeric(errors="coerce") turns any failed (None) predictions into NaN.
log10_eps = pd.to_numeric(df['pred_epsilon'], errors='coerce')
qy = pd.to_numeric(df['pred_QY'], errors='coerce')
df['brightness'] = np.power(10.0, log10_eps) * qy

# --- 7. SAVE FINAL PREDICTIONS ---
df.to_csv("generated_molecules_with_all_predictions.csv", index=False)
print("Saved all predictions to generated_molecules_with_all_predictions.csv")
print(df.head())