# In [None]:
import os
import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, Descriptors, rdMolDescriptors
from sklearn.preprocessing import StandardScaler
from tqdm import tqdm
import pickle

from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

def one_of_k_encoding_unk(x, allowable_set):
    """One-hot encode ``x`` against the ordered candidates in ``allowable_set``.

    A value that is not among the candidates is replaced by the last
    candidate, so the final slot doubles as the "unknown / overflow" bucket.

    Args:
        x: Value to encode.
        allowable_set: Indexable sequence of candidate values.

    Returns:
        List of booleans, one per candidate, True where the candidate equals
        the (possibly substituted) value.
    """
    target = x if x in allowable_set else allowable_set[-1]
    return [target == candidate for candidate in allowable_set]

# Element vocabulary for the symbol one-hot. Symbols outside this list are
# mapped by one_of_k_encoding_unk to the LAST entry ('Pb'), so unknown
# elements share Pb's slot; the list is kept as-is to preserve the width.
_ALLOWABLE_ELEMENTS = [
    'C', 'N', 'O', 'S', 'F', 'H', 'Si', 'P', 'Cl', 'Br',
    'Li', 'Na', 'K', 'Mg', 'Ca', 'Fe', 'As', 'Al', 'I', 'B',
    'V', 'Tl', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn',
    'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'Mn', 'Cr', 'Pt', 'Hg', 'Pb',
]
# Every member of the RDKit enums, in the enum's own order. Comparing against
# the enum members themselves (rather than their integer keys) makes the
# one-hot independent of enum-to-int coercion.
_HYBRIDIZATION_TYPES = list(Chem.rdchem.HybridizationType.values.values())
_CHIRAL_TYPES = list(Chem.rdchem.ChiralType.values.values())


def atom_feature(atom):
    """Return a 1-D float feature vector for an RDKit atom.

    Layout, in order:

    * one-hot element symbol over ``_ALLOWABLE_ELEMENTS`` (unknown symbols
      land in the last slot);
    * one-hot degree (0-5), total H count (0-4) and implicit valence (0-5),
      out-of-range values falling into each list's last bucket;
    * aromaticity flag;
    * atomic number / 100, mass / 100, formal charge, number of radical
      electrons, number of explicit Hs;
    * Gasteiger partial charge read from the ``_GasteigerCharge`` property,
      or 0.0 when the property is absent or not a finite number;
    * one-hot hybridization and one-hot chiral tag over all enum members;
    * in-any-ring flag followed by in-ring-of-size-k flags for k = 3..7.

    Args:
        atom: An ``rdkit.Chem.rdchem.Atom``.

    Returns:
        ``np.ndarray`` with dtype float.
    """
    base = (
        one_of_k_encoding_unk(atom.GetSymbol(), _ALLOWABLE_ELEMENTS)
        + one_of_k_encoding_unk(atom.GetDegree(), [0, 1, 2, 3, 4, 5])
        + one_of_k_encoding_unk(atom.GetTotalNumHs(), [0, 1, 2, 3, 4])
        + one_of_k_encoding_unk(atom.GetImplicitValence(), [0, 1, 2, 3, 4, 5])
        + [atom.GetIsAromatic()]
    )

    # NOTE(review): nothing visible in this file computes Gasteiger charges,
    # so this slot is 0.0 unless the caller ran ComputeGasteigerCharges first.
    gasteiger = 0.0
    if atom.HasProp('_GasteigerCharge'):
        gasteiger = float(atom.GetProp('_GasteigerCharge'))
        # Gasteiger charges can come out NaN/inf; a single non-finite value
        # would corrupt the downstream StandardScaler fit.
        if not np.isfinite(gasteiger):
            gasteiger = 0.0

    extra = [
        atom.GetAtomicNum() / 100.0,
        atom.GetMass() / 100.0,
        atom.GetFormalCharge(),
        atom.GetNumRadicalElectrons(),
        atom.GetNumExplicitHs(),
        gasteiger,
    ]
    extra += one_of_k_encoding_unk(atom.GetHybridization(), _HYBRIDIZATION_TYPES)
    extra += one_of_k_encoding_unk(atom.GetChiralTag(), _CHIRAL_TYPES)
    extra.append(atom.IsInRing())
    extra += [atom.IsInRingSize(ring_size) for ring_size in range(3, 8)]

    return np.array(base + extra, dtype=float)

def bond_feature(bond):
    """Return a 7-element float vector describing an RDKit bond.

    Entries, in order: SINGLE, DOUBLE, TRIPLE, AROMATIC bond-type flags (all
    zero for any other bond type), ring membership, conjugation, and a flag
    that is 1.0 whenever the bond's stereo is anything other than STEREONONE.
    """
    kind = bond.GetBondType()
    tracked_types = (
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    )
    flags = [kind == candidate for candidate in tracked_types]
    flags.append(bond.IsInRing())
    flags.append(bond.GetIsConjugated())
    flags.append(bond.GetStereo() != Chem.rdchem.BondStereo.STEREONONE)
    return np.array(flags, dtype=float)

def mol_global_features(mol):
    """Return a flat list of molecule-level features for an RDKit molecule.

    The first eight entries come from RDKit descriptor functions: MolWt,
    MolLogP, TPSA, NumHAcceptors, NumHDonors, CalcNumRotatableBonds,
    CalcNumHeteroatoms and CalcFractionCSP3. The rest are the bits of a
    radius-2 Morgan fingerprint folded to 64 bits.
    """
    descriptor_fns = (
        Descriptors.MolWt,
        Descriptors.MolLogP,
        Descriptors.TPSA,
        Descriptors.NumHAcceptors,
        Descriptors.NumHDonors,
        rdMolDescriptors.CalcNumRotatableBonds,
        rdMolDescriptors.CalcNumHeteroatoms,
        rdMolDescriptors.CalcFractionCSP3,
    )
    values = [fn(mol) for fn in descriptor_fns]
    morgan_bits = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=64)
    values.extend(morgan_bits)
    return values

class MolGraphBuilder:
    """Build padded, tensor-valued molecular graphs from SMILES strings.

    Args:
        atom_scaler: Optional fitted transformer exposing ``transform`` (e.g. a
            scikit-learn ``StandardScaler``). When given, it is applied to the
            real atom rows of ``x``; zero-padding rows are left untouched.
        edge_scaler: Optional fitted transformer applied to the bond rows of
            ``edge_attr``; self-loop rows are left as zeros.
    """

    def __init__(self, atom_scaler=None, edge_scaler=None):
        self.atom_scaler = atom_scaler
        self.edge_scaler = edge_scaler

    def build_graph(self, smiles, max_atoms, use_edge_features=True, use_global_features=True):
        """Featurize one SMILES string after adding explicit hydrogens.

        Args:
            smiles: SMILES string to parse.
            max_atoms: Number of rows in the zero-padded node matrix ``x``.
            use_edge_features: Whether to build ``edge_attr``.
            use_global_features: Whether to build the molecule-level ``u``.

        Returns:
            ``None`` if ``smiles`` does not parse or yields no atoms;
            otherwise a dict with:

            * ``'x'``: float tensor ``(max_atoms, n_atom_features)``;
            * ``'edge_index'``: long tensor ``(2, E)`` holding both directions
              of every bond followed by one self-loop per atom;
            * ``'edge_attr'``: float tensor ``(E, n_bond_features)`` aligned
              with ``edge_index`` (self-loop rows are zero), or ``[]`` when
              ``use_edge_features`` is False;
            * ``'u'``: float tensor ``(1, n_global)``, or ``None`` when
              disabled or when descriptor computation failed;
            * ``'smiles'`` and ``'num_atoms'``.

        Raises:
            ValueError: if the molecule (with Hs) has more than ``max_atoms``
                atoms.
        """
        # MolFromSmiles signals failure by returning None; check it before
        # AddHs, which cannot take None.
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        mol = Chem.AddHs(mol)

        n = mol.GetNumAtoms()
        if n == 0:
            return None
        if n > max_atoms:
            raise ValueError(
                f"Molecule {smiles!r} has {n} atoms (including H), "
                f"which exceeds max_atoms={max_atoms}"
            )

        atom_matrix = np.vstack([atom_feature(atom) for atom in mol.GetAtoms()])
        if self.atom_scaler is not None:
            atom_matrix = self.atom_scaler.transform(atom_matrix)
        x = np.zeros((max_atoms, atom_matrix.shape[1]), dtype=float)
        x[:n] = atom_matrix

        edge_index = []
        bond_rows = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()
            # Each undirected bond is stored in both directions.
            edge_index += [[i, j], [j, i]]
            if use_edge_features:
                bf = bond_feature(bond)
                bond_rows += [bf, bf]
        # One self-loop per atom, appended after all bond edges.
        edge_index += [[k, k] for k in range(n)]
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()

        # Kept as an empty list when edge features are disabled, as before.
        edge_attr = []
        if use_edge_features:
            if bond_rows:
                bond_matrix = np.vstack(bond_rows)
                if self.edge_scaler is not None:
                    bond_matrix = self.edge_scaler.transform(bond_matrix)
            else:
                # Width 7 matches bond_feature's output length.
                bond_matrix = np.zeros((0, 7))
            self_loop_rows = np.zeros((n, bond_matrix.shape[1]))
            edge_attr = torch.tensor(
                np.vstack([bond_matrix, self_loop_rows]), dtype=torch.float
            )

        u = None
        if use_global_features:
            try:
                u = torch.tensor(mol_global_features(mol), dtype=torch.float).view(1, -1)
            except Exception:
                # Best effort: keep the graph without global descriptors.
                u = None

        return {
            'x': torch.tensor(x, dtype=torch.float),
            'edge_index': edge_index,
            'edge_attr': edge_attr,
            'u': u,
            'smiles': smiles,
            'num_atoms': n
        }

def _mol_with_hs(smiles):
    """Parse ``smiles`` and add explicit hydrogens; return None if it does not parse.

    Non-string values (e.g. NaN from an empty CSV cell) are treated as
    unparseable instead of being handed to RDKit.
    """
    if not isinstance(smiles, str):
        return None
    mol = Chem.MolFromSmiles(smiles)
    return None if mol is None else Chem.AddHs(mol)


def extract_features(
    df,
    soft_cols,
    output_dir="feature_data",
    max_atoms=None,
    use_edge_features=True,
    use_global_features=True
):
    """Featurize every row of ``df`` into a graph dict and save the results.

    Writes ``graph_data.pt`` (list of graph dicts via ``torch.save``),
    ``scalers.pkl`` (fitted atom/edge/freq scalers and ``max_atoms``) and
    ``metadata.txt`` into ``output_dir``.

    Args:
        df: DataFrame with ``SMILES``, ``freq`` and ``y_true`` columns plus
            every column listed in ``soft_cols``.
        soft_cols: Column names copied into each sample's ``y_soft`` tensor of
            shape ``(1, len(soft_cols))``.
        output_dir: Destination directory (created if missing).
        max_atoms: Padding size of the node matrix. If None, it is the largest
            atom count (with explicit H) among parseable SMILES, plus 5.
        use_edge_features: Build ``edge_attr`` and fit an edge scaler.
        use_global_features: Build molecule-level descriptors ``u``.

    Returns:
        ``output_dir``.

    Raises:
        ValueError: if a required column is missing, or if no parseable SMILES
            contains any atom (so the atom scaler cannot be fitted).

    Rows whose SMILES do not parse are skipped silently. Rows that fail during
    graph building (including a missing ``freq`` value) are counted, the first
    ``max_errors`` are printed, and they are skipped.
    """
    # Validate the schema up front. A missing 'y_true' or soft-label column
    # used to be detected only per row, so every row failed on its own and
    # an empty dataset was still written to disk.
    if "freq" not in df.columns:
        raise ValueError("'freq' column is missing in the dataframe")
    if "y_true" not in df.columns:
        raise ValueError("'y_true' target column is missing in the dataframe")
    for col in soft_cols:
        if col not in df.columns:
            raise ValueError(f"Soft label column '{col}' is missing in the dataframe")

    os.makedirs(output_dir, exist_ok=True)

    # A single parsing pass collects atom counts and the rows used to fit the
    # atom and bond scalers.
    atom_counts = []
    all_atom_feats = []
    all_edge_feats = []
    for smiles in tqdm(df["SMILES"], desc="Atom/Bond Features"):
        mol = _mol_with_hs(smiles)
        if mol is None:
            continue
        atom_counts.append(mol.GetNumAtoms())
        all_atom_feats.extend(atom_feature(atom) for atom in mol.GetAtoms())
        if use_edge_features:
            all_edge_feats.extend(bond_feature(bond) for bond in mol.GetBonds())

    if max_atoms is None:
        max_atoms = max(atom_counts, default=0) + 5
        print(f"Auto-calculated max atoms: {max_atoms}")

    if not all_atom_feats:
        raise ValueError("No parseable SMILES with atoms found; cannot fit the atom scaler")
    atom_scaler = StandardScaler().fit(np.vstack(all_atom_feats))

    edge_scaler = None
    if all_edge_feats:
        edge_scaler = StandardScaler().fit(np.vstack(all_edge_feats))

    # StandardScaler disregards NaN during fit; rows with a NaN 'freq' are
    # rejected individually in the loop below.
    freq_scaler = StandardScaler().fit(df[["freq"]].values)

    graph_builder = MolGraphBuilder(atom_scaler, edge_scaler)
    graph_data = []
    error_count = 0
    max_errors = 20

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Building Graphs"):
        try:
            data = graph_builder.build_graph(
                row["SMILES"],
                max_atoms,
                use_edge_features,
                use_global_features
            )
            if data is None:
                continue

            data['y'] = torch.tensor([row['y_true']], dtype=torch.float)
            data['y_soft'] = torch.tensor(
                [row[col] for col in soft_cols], dtype=torch.float
            ).view(1, -1)

            if pd.isna(row["freq"]):
                raise ValueError("Missing values exist in 'freq' column")
            freq_tensor = torch.tensor(
                freq_scaler.transform([[row["freq"]]]), dtype=torch.float
            )
            # The standardized frequency becomes the last global feature.
            if data['u'] is None:
                data['u'] = freq_tensor
            else:
                data['u'] = torch.cat([data['u'], freq_tensor], dim=1)

            graph_data.append(data)

        except Exception as e:
            # Per-row failures are tolerated; only the first max_errors are shown.
            error_count += 1
            if error_count <= max_errors:
                print(f"Error processing SMILES: {row['SMILES']}, Error: {str(e)}")
            elif error_count == max_errors + 1:
                print(f"Exceeded maximum error display count ({max_errors}), subsequent errors will not be shown...")

    print(f"Processing completed, {error_count} errors encountered, valid samples: {len(graph_data)}")

    torch.save(graph_data, os.path.join(output_dir, "graph_data.pt"))

    with open(os.path.join(output_dir, "scalers.pkl"), 'wb') as f:
        pickle.dump({
            'atom_scaler': atom_scaler,
            'edge_scaler': edge_scaler,
            'freq_scaler': freq_scaler,
            'max_atoms': max_atoms
        }, f)

    with open(os.path.join(output_dir, "metadata.txt"), 'w', encoding="utf-8") as f:
        f.write(f"max_atoms={max_atoms}\n")
        f.write(f"use_edge_features={use_edge_features}\n")
        f.write(f"use_global_features={use_global_features}\n")
        f.write(f"num_samples={len(graph_data)}\n")
        f.write("extra_feature=freq (standardized)\n")

    print("Feature extraction completed!")
    return output_dir


if __name__ == "__main__":
    # Placeholder locations: replace with the real input CSV and the
    # directory that should receive graph_data.pt / scalers.pkl / metadata.txt.
    input_csv = "INPUT_CSV_PATH"
    output_feature_dir = "OUTPUT_FEATURE_DIR_PATH"

    df = pd.read_csv(input_csv)

    # Soft-label column names copied into each sample's y_soft; none are
    # configured here.
    soft_cols = []

    feature_dir = extract_features(
        df,
        soft_cols,
        output_dir=output_feature_dir,
        use_edge_features=True,
        use_global_features=True,
    )
    print(f"Features saved to: {feature_dir}")
