In [None]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from rdkit import Chem, DataStructs, RDLogger
from rdkit.Chem import Descriptors, AllChem, rdMolDescriptors
from rdkit.Chem.Scaffolds import MurckoScaffold
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
import torch
import pickle

# Silence RDKit's own logging for the whole run.
RDLogger.DisableLog('rdApp.*')

# Input CSV file path
input_csv = "INPUT_CSV_PATH"
# Directory to save label CSV files
label_dir = "LABEL_DIR_PATH"
# Directory to save molecular graph features
feature_dir = "FEATURE_DIR_PATH"
# Number of quantile bins used for the y_true-based labels
n_bins = 10
# Number of k-means clusters used for the clustering-based labels
n_clusters = 10
# Element symbols that element_label does not report
common_elements = {"C","H","O","N"}

os.makedirs(label_dir, exist_ok=True)
os.makedirs(feature_dir, exist_ok=True)
# Registers Series.progress_map / progress_apply used below.
tqdm.pandas()

df = pd.read_csv(input_csv)
# Validate with an explicit raise rather than `assert` (asserts are removed
# under `python -O`), and check every column the rest of the script selects
# so a malformed input fails here with one clear message.
required_columns = ['NO', 'SMILES', 'y_true', 'freq']
missing_columns = [col for col in required_columns if col not in df.columns]
if missing_columns:
    raise ValueError(
        "Input CSV is missing required column(s): " + ", ".join(missing_columns)
    )
base = os.path.splitext(os.path.basename(input_csv))[0]

# Label strategy 1: quantile bins of y_true (duplicate bin edges are dropped,
# so fewer than n_bins distinct labels may result).
df1 = df.copy()
df1['label'] = pd.qcut(df1['y_true'], q=n_bins, labels=False, duplicates='drop')
df1 = df1[['NO','SMILES','y_true','freq','label']]
f1 = os.path.join(label_dir, f"{base}_qcut_{n_bins}_label.csv")
df1.to_csv(f1, index=False)

def element_label(smi):
    """Label a SMILES string by the elements it has beyond ``common_elements``.

    Returns "Invalid" when RDKit cannot parse ``smi``, "_" when every atom
    symbol is in ``common_elements``, and otherwise the remaining symbols
    sorted and joined with commas.
    """
    mol = Chem.MolFromSmiles(smi)
    if not mol:
        return "Invalid"
    symbols = set()
    for atom in mol.GetAtoms():
        symbols.add(atom.GetSymbol())
    uncommon = sorted(symbols.difference(common_elements))
    if uncommon:
        return ",".join(uncommon)
    return "_"

# Label strategy 2: one label per molecule naming its uncommon elements.
df2 = df.assign(label=df['SMILES'].progress_map(element_label))
df2 = df2.loc[:, ['NO','SMILES','y_true','freq','label']]
f2 = os.path.join(label_dir, f"{base}_elem_label.csv")
df2.to_csv(f2, index=False)

def molwt(smi):
    """Return RDKit's molecular weight for ``smi``, or NaN if it does not parse."""
    mol = Chem.MolFromSmiles(smi)
    if not mol:
        return np.nan
    return Descriptors.MolWt(mol)

# Label strategy 3: k-means clusters over molecular weight. Rows whose SMILES
# could not be parsed have a NaN weight and are dropped before clustering.
df3 = df.assign(MolWt=df['SMILES'].progress_map(molwt))
df3 = df3.dropna(subset=['MolWt'])
km3 = KMeans(n_clusters=n_clusters, random_state=42)
km3.fit(df3[['MolWt']])
df3 = df3.assign(label=km3.labels_)
df3 = df3.loc[:, ['NO','SMILES','y_true','freq','label']]
f3 = os.path.join(label_dir, f"{base}_molwt_{n_clusters}_label.csv")
df3.to_csv(f3, index=False)

def smiles_to_fp(smi):
    """Return the 2048-bit, radius-2 Morgan fingerprint of ``smi`` as an int array.

    Returns None when RDKit cannot parse ``smi``.
    """
    mol = Chem.MolFromSmiles(smi)
    if not mol:
        return None
    n_bits = 2048
    bit_vect = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=n_bits)
    bits = np.zeros((n_bits,), dtype=int)
    # Fills ``bits`` in place from the RDKit bit vector.
    DataStructs.ConvertToNumpyArray(bit_vect, bits)
    return bits

# Label strategy 4: k-means clusters over Morgan fingerprints. Only rows whose
# SMILES parse are kept, so the fingerprint matrix lines up with df4 row-for-row.
valid = df['SMILES'].apply(lambda s: Chem.MolFromSmiles(s) is not None)
df4 = df.loc[valid].copy()
fps4 = df4['SMILES'].progress_map(smiles_to_fp).dropna()
X4 = np.stack(fps4.tolist())
km4 = KMeans(n_clusters=n_clusters, random_state=42)
km4.fit(X4)
df4 = df4.assign(label=km4.labels_)
df4 = df4.loc[:, ['NO','SMILES','y_true','freq','label']]
f4 = os.path.join(label_dir, f"{base}_fp_{n_clusters}_label.csv")
df4.to_csv(f4, index=False)

def scaffold_smiles(smi):
    """Return the SMILES of the Murcko scaffold of ``smi``.

    Returns None when ``smi`` cannot be parsed or when the scaffold object
    returned by RDKit is falsy.
    """
    mol = Chem.MolFromSmiles(smi)
    if not mol:
        return None
    scaffold = MurckoScaffold.GetScaffoldForMol(mol)
    if not scaffold:
        return None
    return Chem.MolToSmiles(scaffold)

# Label strategy 5: cluster the distinct Murcko scaffolds by fingerprint, then
# give every molecule the cluster id of its scaffold.
df5 = df.assign(Scaffold=df['SMILES'].progress_map(scaffold_smiles))
df5 = df5.dropna(subset=['Scaffold'])
uniq5 = df5['Scaffold'].unique()

fps5 = []
for scaffold in tqdm(uniq5, desc="Scaffold FP"):
    bits = np.zeros((2048,), dtype=int)
    DataStructs.ConvertToNumpyArray(
        AllChem.GetMorganFingerprintAsBitVect(Chem.MolFromSmiles(scaffold), 2, nBits=2048),
        bits,
    )
    fps5.append(bits)
X5 = np.stack(fps5)

km5 = KMeans(n_clusters=n_clusters, random_state=42)
km5.fit(X5)
# Scaffold SMILES -> cluster id, in the same order the scaffolds were fingerprinted.
map5 = {scaffold: cluster for scaffold, cluster in zip(uniq5, km5.labels_)}
df5 = df5.assign(label=df5['Scaffold'].map(map5))
df5 = df5.loc[:, ['NO','SMILES','y_true','freq','label']]
f5 = os.path.join(label_dir, f"{base}_scaffold_{n_clusters}_label.csv")
df5.to_csv(f5, index=False)

# Report the five label files written above, one path per line.
# The check mark was stored as mojibake (UTF-8 bytes decoded as cp1252);
# restored to the intended character.
print("✅ Label CSVs generated:")
print(f1, f2, f3, f4, f5, sep="\n")

def one_of_k(x, choices):
    """One-hot encode ``x`` against ``choices``, using the last slot as a catch-all.

    The result has ``len(choices)`` entries: one equality flag for each choice
    except the last, followed by a flag that is True when ``x`` matches none
    of those earlier choices (so the final choice doubles as "other").
    """
    known = choices[:-1]
    encoding = [x == candidate for candidate in known]
    encoding.append(x not in known)
    return encoding

def atom_feature(atom):
    """Build the numeric feature vector for one RDKit atom.

    The vector concatenates, in order: one-of-k encodings of element symbol,
    degree, total hydrogen count and implicit valence; the aromatic flag;
    scaled atomic number and mass; formal charge, radical electron count and
    explicit hydrogen count; the ``_GasteigerCharge`` property (0.0 when the
    atom does not carry it); one-of-k encodings of hybridization and chiral
    tag; and ring-membership flags for any ring and for ring sizes 3 to 7.

    Returns a 1-D float numpy array.
    """
    elems = ['C','N','O','S','F','H','Si','P','Cl','Br','Li','Na','K','Mg','Ca','Fe','As','Al','I','B','V','Tl','Sb','Sn','Ag','Pd','Co','Se','Ti','Zn','Ge','Cu','Au','Ni','Cd','Mn','Cr','Pt','Hg','Pb']

    # Categorical part: each (value, choices) pair becomes a one-of-k segment.
    categorical = []
    for value, choices in (
        (atom.GetSymbol(), elems),
        (atom.GetDegree(), [0,1,2,3,4,5]),
        (atom.GetTotalNumHs(), [0,1,2,3,4]),
        (atom.GetImplicitValence(), [0,1,2,3,4,5]),
    ):
        categorical.extend(one_of_k(value, choices))
    categorical.append(atom.GetIsAromatic())

    # The charge is only present if it was computed on the molecule beforehand.
    if atom.HasProp('_GasteigerCharge'):
        gasteiger = float(atom.GetProp('_GasteigerCharge'))
    else:
        gasteiger = 0.0

    numeric = [
        atom.GetAtomicNum()/100.0,
        atom.GetMass()/100.0,
        atom.GetFormalCharge(),
        atom.GetNumRadicalElectrons(),
        atom.GetNumExplicitHs(),
        gasteiger,
    ]

    structural = one_of_k(atom.GetHybridization(), list(Chem.rdchem.HybridizationType.values))
    structural.extend(one_of_k(atom.GetChiralTag(), list(Chem.rdchem.ChiralType.values)))
    structural.append(atom.IsInRing())
    structural.extend(atom.IsInRingSize(size) for size in range(3,8))

    return np.array(categorical + numeric + structural, dtype=float)

def bond_feature(bond):
    """Build the 7-element float feature vector for one RDKit bond.

    Entries, in order: is-single, is-double, is-triple, is-aromatic,
    in-ring, conjugated, and has-stereo (stereo other than STEREONONE).
    """
    bond_types = Chem.rdchem.BondType
    kind = bond.GetBondType()
    flags = [
        kind == candidate
        for candidate in (
            bond_types.SINGLE,
            bond_types.DOUBLE,
            bond_types.TRIPLE,
            bond_types.AROMATIC,
        )
    ]
    flags.append(bond.IsInRing())
    flags.append(bond.GetIsConjugated())
    flags.append(bond.GetStereo() != Chem.rdchem.BondStereo.STEREONONE)
    return np.array(flags, dtype=float)

def mol_global_features(mol, freq_value=None):
    """Return molecule-level features as a flat Python list.

    The list holds eight RDKit descriptors (molecular weight, logP, TPSA,
    H-bond acceptors, H-bond donors, rotatable bonds, heteroatoms, fraction
    of sp3 carbons), then the 64 bits of a radius-2 Morgan fingerprint, and
    finally ``freq_value`` when it is not None.
    """
    descriptor_fns = (
        Descriptors.MolWt,
        Descriptors.MolLogP,
        Descriptors.TPSA,
        Descriptors.NumHAcceptors,
        Descriptors.NumHDonors,
        rdMolDescriptors.CalcNumRotatableBonds,
        rdMolDescriptors.CalcNumHeteroatoms,
        rdMolDescriptors.CalcFractionCSP3,
    )
    features = [compute(mol) for compute in descriptor_fns]
    fingerprint = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=64)
    features.extend(list(fingerprint))
    if freq_value is None:
        return features
    features.append(freq_value)
    return features

class MolGraphBuilder:
    """Convert SMILES strings into padded graph dictionaries of torch tensors.

    The two scalers are stored on the instance but are not applied by
    ``build_graph``; node and edge features are emitted unscaled.
    """

    def __init__(self, atom_scaler, edge_scaler):
        self.atom_scaler = atom_scaler
        self.edge_scaler = edge_scaler

    @staticmethod
    def _bond_feature_width():
        """Return the length of a ``bond_feature`` vector, probed on ethane's C-C bond."""
        probe = Chem.MolFromSmiles('CC').GetBondWithIdx(0)
        return len(bond_feature(probe))

    def build_graph(self, smiles, max_atoms, freq_value=None, use_edge=True, use_global=True):
        """Build the graph dictionary for one molecule.

        Args:
            smiles: SMILES string to parse; explicit hydrogens are added.
            max_atoms: Number of rows in the zero-padded node feature matrix.
            freq_value: Optional value appended to the global feature vector.
            use_edge: When False, no bond features are computed and
                ``edge_attr`` is None.
            use_global: When False, ``u`` is None.

        Returns:
            None when the SMILES cannot be parsed or yields no atoms.
            Otherwise a dict with:
              * ``x``: float tensor of shape (max_atoms, n_atom_features)
              * ``edge_index``: tensor of shape (2, E) holding both directions
                of every bond followed by one self-loop per atom
              * ``edge_attr``: float tensor of shape (E, n_bond_features),
                zero rows for self-loops, or None
              * ``u``: float tensor of shape (1, n_global_features), or None
                (also None if the global descriptors failed to compute)
              * ``smiles``: the input string

        Raises:
            ValueError: If the molecule has more atoms than ``max_atoms``.
        """
        parsed = Chem.MolFromSmiles(smiles)
        # Chem.AddHs raises on None, so unparseable SMILES must be rejected
        # before hydrogens are added.
        if parsed is None:
            return None
        mol = Chem.AddHs(parsed)

        # NOTE(review): nothing here computes Gasteiger charges, so the charge
        # slot read by atom_feature stays at its 0.0 fallback - confirm intent.
        atom_feats = [atom_feature(a) for a in mol.GetAtoms()]
        n = len(atom_feats)
        if n == 0:
            # An atom-less molecule (e.g. an empty SMILES) has no graph.
            return None
        if n > max_atoms:
            raise ValueError(
                f"Molecule {smiles!r} has {n} atoms, exceeding max_atoms={max_atoms}"
            )
        x = np.zeros((max_atoms, len(atom_feats[0])), dtype=float)
        x[:n] = np.vstack(atom_feats)

        edge_idx, edge_attr = [], []
        for b in mol.GetBonds():
            i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
            # Store each bond in both directions.
            edge_idx += [[i, j], [j, i]]
            if use_edge:
                bf = bond_feature(b)
                edge_attr += [bf, bf]

        # One self-loop per real atom, carrying an all-zero bond feature.
        edge_idx.extend([i, i] for i in range(n))
        if use_edge:
            # A bond-less molecule (single atom or ion) has no bond vector to
            # copy the width from, so probe bond_feature instead of indexing
            # an empty list.
            width = len(edge_attr[0]) if edge_attr else self._bond_feature_width()
            edge_attr.extend(np.zeros(width, dtype=float) for _ in range(n))

        edge_index = torch.tensor(edge_idx, dtype=torch.long).t().contiguous()
        # Stack into one ndarray first; building a tensor from a list of
        # ndarrays is slow and triggers a torch warning.
        edge_attr = torch.tensor(np.array(edge_attr), dtype=torch.float) if use_edge else None

        u = None
        if use_global:
            try:
                u = torch.tensor(mol_global_features(mol, freq_value), dtype=torch.float).view(1, -1)
            except Exception:
                # Best effort: keep the graph without global features, but do
                # not swallow KeyboardInterrupt/SystemExit as a bare except would.
                u = None
        return {'x': torch.tensor(x, dtype=torch.float), 'edge_index': edge_index, 'edge_attr': edge_attr, 'u': u, 'smiles': smiles}

# Collect atom counts and raw atom/bond features in one pass. Each SMILES is
# parsed once, and unparseable ones are skipped *before* Chem.AddHs, which
# raises when handed None.
counts = []
all_atom_feats, all_edge_feats = [], []
for s in tqdm(df['SMILES'], desc='Collect atoms/bonds'):
    parsed = Chem.MolFromSmiles(s)
    if parsed is None:
        continue
    m = Chem.AddHs(parsed)
    counts.append(m.GetNumAtoms())
    all_atom_feats.extend(atom_feature(a) for a in m.GetAtoms())
    all_edge_feats.extend(bond_feature(b) for b in m.GetBonds())

if not counts:
    raise ValueError("No parseable SMILES in input; cannot build graph features.")

# Padding size for node feature matrices: largest molecule plus a margin of 5.
max_atoms = max(counts) + 5
print(f"max_atoms = {max_atoms}")

# Scalers are fitted on the unpadded features of every parseable molecule.
atom_scaler = StandardScaler().fit(np.vstack(all_atom_feats))
edge_scaler = StandardScaler().fit(np.vstack(all_edge_feats))

builder = MolGraphBuilder(atom_scaler, edge_scaler)

# For every label CSV: build one graph per row and save the graphs together
# with the scalers and label encoding needed to interpret them.
for csv_path in [f1, f2, f3, f4, f5]:
    sub = pd.read_csv(csv_path, usecols=['NO','SMILES','y_true','freq','label'])

    # Reset for every file. Previously the encoder built for a string-label
    # file stayed in scope and was pickled alongside later numeric-label files.
    label_encoder = None
    # Any non-numeric label column (object or a dedicated string dtype) is
    # mapped to integer ids in order of first appearance.
    if not pd.api.types.is_numeric_dtype(sub['label']):
        label_encoder = {k: v for v, k in enumerate(sub['label'].unique())}
        sub['label'] = sub['label'].map(label_encoder)

    name = os.path.splitext(os.path.basename(csv_path))[0]
    out_subdir = os.path.join(feature_dir, name + '_feat')
    os.makedirs(out_subdir, exist_ok=True)

    graphs = []
    for _, row in tqdm(sub.iterrows(), total=len(sub), desc=f'Build {name}'):
        g = builder.build_graph(row['SMILES'], max_atoms, freq_value=row['freq'])
        if g is None:
            # build_graph signals an unusable molecule with None.
            continue
        g['y'] = torch.tensor([row['y_true']], dtype=torch.float)
        g['label'] = torch.tensor([row['label']], dtype=torch.long)
        # Same label value as above, stored as float.
        g['y_soft'] = torch.tensor([row['label']], dtype=torch.float)
        graphs.append(g)

    torch.save(graphs, os.path.join(out_subdir, 'graph_data.pt'))
    with open(os.path.join(out_subdir, 'scalers.pkl'), 'wb') as f:
        pickle.dump({
            'atom_scaler': atom_scaler,
            'edge_scaler': edge_scaler,
            'max_atoms': max_atoms,
            'label_encoder': label_encoder,
        }, f)

    print(f"Features saved in {out_subdir}")
