In [2]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torch_geometric.warnings")

import pandas as pd
import numpy as np
from tqdm import tqdm
import time
import os
import joblib
from collections import OrderedDict
import copy
from dataclasses import dataclass

from sklearn.model_selection import KFold, train_test_split, StratifiedKFold
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, ReLU, Dropout, Sequential, BatchNorm1d

from torch_geometric.nn import GINEConv, global_mean_pool, global_add_pool, global_max_pool
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

from skfp.fingerprints import PubChemFingerprint
import optuna

CSV_PATH = '../dataset/MT-thermal.csv'
SMILES_COL = 'smiles'
VALUE_COL = 'val'
PROP_COL = 'prop'

TEST_SPLIT_RATIO = 0.2
VALIDATION_SPLIT_OPTUNA = 0.2
RANDOM_STATE = 0
N_FOLDS = 5
N_TRIALS_OPTUNA = 30
N_EPOCHS_FOLD_TRAINING = 250
PATIENCE_FOLD_TRAINING = 30
BEST_MODEL_SAVE_DIR = './best_models_gine_mixfp_multitask_v3'
TEST_SET_RESULTS_FILE = 'results_gine_mixfp_multitask_v3.txt'
SCALERS_SAVE_FILE = 'multitask_minmax_scalers.joblib'
TARGET_COLS_FILE = 'target_cols_multitask.joblib'
PREGENERATED_DATA_FILE = 'pregenerated_data.joblib'

os.makedirs(BEST_MODEL_SAVE_DIR, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

element_names = [
    "C", "N", "O", "S", "F", "Si", "P", "Cl", "Br", "Mg", "Na",
    "Ca", "Fe", "As", "Al", "I", "B", "V", "K", "Tl", "Yb",
    "Sb", "Sn", "Ag", "Pd", "Co", "Se", "Ti", "Zn", "H",
    "Li", "Ge", "Cu", "Au", "Ni", "Cd", "In", "Mn", "Zr",
    "Cr", "Pt", "Hg", "Pb", "Unknown","*"
]

def one_of_k_encoding(x, allowable_set):
    return list(map(lambda s: x == s, allowable_set))

def one_of_k_encoding_unk(x, allowable_set):
    if x not in allowable_set:
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))

@dataclass
class BondConfig:
    bond_type: bool; conjugation: bool; ring: bool; stereo: bool = False; bond_dir: bool = False
    def __post_init__(self):
        self.n_features = 0; self.feat_names = []
        if self.bond_type: self.n_features += 4; self.feat_names += ["Single", "Double", "Triple", "Aromatic"]
        if self.conjugation: self.n_features += 1; self.feat_names += ["Conjugation"]
        if self.ring: self.n_features += 1; self.feat_names += ["inRing"]
        if self.stereo: self.n_features += 6; self.feat_names += ["any", "cis", "e", "none", "trans", "z"]
        if self.bond_dir: self.n_features += 7; self.feat_names += ["begin_dash", "begin_wedge", "either_double", "end_down_right", "end_up_right", "none", "unknown"]

@dataclass
class AtomConfig:
    element_type: bool; degree: bool; implicit_valence: bool; formal_charge: bool
    num_rad_e: bool; hybridization: bool; combo_hybrid: bool = False
    aromatic: bool = False; chirality: bool = False
    def __post_init__(self):
        self.n_features = 0; self.feat_names = []
        def update(names): self.feat_names += names; self.n_features += len(names)
        if self.element_type: update(element_names)
        if self.degree: update([f"degree{ind}" for ind in range(11)])
        if self.implicit_valence: update([f"implicitValence{ind}" for ind in range(7)])
        if self.formal_charge: update(["formalCharge"])
        if self.num_rad_e: update(["numRadElectons"])
        if self.hybridization:
            options = ["HybridizationType.SP", "HybridizationType.SP2or3", "HybridizationType.SP3D", "HybridizationType.SP3D2"] if self.combo_hybrid \
                 else ["HybridizationType.SP", "HybridizationType.SP2", "HybridizationType.SP3", "HybridizationType.SP3D", "HybridizationType.SP3D2"]
            update(options)
        if self.aromatic: update(["Aromatic"])
        if self.chirality: update(["Unspecified", "Tetrahedral_CW", "Tetrahedral_CCW", "CHI_OTHER", "Tetrahedral", "Allene", "Square_planar", "Trigonal_bipyramidal", "Octahedral"]) # Fixed CHI_OTHER typo and order

def bond_fp(bond, config: BondConfig):
    bond_feats = []
    if config.bond_type: bt = bond.GetBondType(); bond_feats += [bt == Chem.rdchem.BondType.SINGLE, bt == Chem.rdchem.BondType.DOUBLE, bt == Chem.rdchem.BondType.TRIPLE, bt == Chem.rdchem.BondType.AROMATIC]
    if config.conjugation: bond_feats.append(bond.GetIsConjugated())
    if config.ring: bond_feats.append(bond.IsInRing())
    if config.stereo: st = bond.GetStereo(); bond_feats += [st == Chem.rdchem.BondStereo.STEREOANY, st == Chem.rdchem.BondStereo.STEREOCIS, st == Chem.rdchem.BondStereo.STEREOE, st == Chem.rdchem.BondStereo.STEREONONE, st == Chem.rdchem.BondStereo.STEREOTRANS, st == Chem.rdchem.BondStereo.STEREOZ]
    if config.bond_dir: bd = bond.GetBondDir(); bond_feats += [bd == Chem.rdchem.BondDir.BEGINDASH, bd == Chem.rdchem.BondDir.BEGINWEDGE, bd == Chem.rdchem.BondDir.EITHERDOUBLE, bd == Chem.rdchem.BondDir.ENDDOWNRIGHT, bd == Chem.rdchem.BondDir.ENDUPRIGHT, bd == Chem.rdchem.BondDir.NONE, bd == Chem.rdchem.BondDir.UNKNOWN]
    return [float(f) for f in bond_feats]

def atom_fp(atom, config: AtomConfig):
    results = []
    if config.element_type: results += one_of_k_encoding_unk(atom.GetSymbol(), element_names)
    if config.degree: results += one_of_k_encoding(atom.GetDegree(), list(range(11)))
    if config.implicit_valence: results += one_of_k_encoding_unk(atom.GetImplicitValence(), list(range(7)))
    if config.formal_charge: results += [atom.GetFormalCharge()]
    if config.num_rad_e: results += [atom.GetNumRadicalElectrons()]
    if config.hybridization:
        feat = atom.GetHybridization()
        options = ["HybridizationType.SP", "HybridizationType.SP2or3", "HybridizationType.SP3D", "HybridizationType.SP3D2"] if config.combo_hybrid and (feat == Chem.rdchem.HybridizationType.SP2 or feat == Chem.rdchem.HybridizationType.SP3) \
             else ["HybridizationType.SP", "HybridizationType.SP2", "HybridizationType.SP3", "HybridizationType.SP3D", "HybridizationType.SP3D2"]
        current_feat = "SP2/3" if config.combo_hybrid and (feat == Chem.rdchem.HybridizationType.SP2 or feat == Chem.rdchem.HybridizationType.SP3) else feat
        results += one_of_k_encoding_unk(current_feat, options)
    if config.aromatic: results += [atom.GetIsAromatic()]
    if config.chirality: tag = str(atom.GetChiralTag()); results += [tag == "CHI_UNSPECIFIED", tag == "CHI_TETRAHEDRAL_CW", tag == "CHI_TETRAHEDRAL_CCW", tag == "CHI_OTHER", tag == "CHI_TETRAHEDRAL", tag == "CHI_ALLENE", tag == "CHI_SQUAREPLANAR", tag == "CHI_TRIGONALBIPYRAMIDAL", tag == "CHI_OCTAHEDRAL"]
    return [float(f) for f in results]
# --- END OF STRICTLY PROVIDED ATOM_FP FUNCTION ---

# --- Define Feature Configurations ---
polygnn_atom_config = AtomConfig(
    element_type=True, degree=True, implicit_valence=True, formal_charge=True,
    num_rad_e=True, hybridization=True, aromatic=True, chirality=True, combo_hybrid=False # combo_hybrid=False here
)
polygnn_bond_config = BondConfig(bond_type=True, conjugation=True, ring=True)

# ATOM_F_DIM is calculated based on AtomConfig.__post_init__
ATOM_F_DIM = polygnn_atom_config.n_features
EDGE_F_DIM = polygnn_bond_config.n_features
print(f"PolyGNN Atom feature dimension: {ATOM_F_DIM}")
print(f"PolyGNN Bond feature dimension: {EDGE_F_DIM}")

pubchem_fp_calculator = PubChemFingerprint(count=False, sparse=False)
temp_mol = Chem.MolFromSmiles('C')

def mixfp(mol):
    fp = []
    fp_maccs = AllChem.GetMACCSKeysFingerprint(mol)
    fp_pubcfp = pubchem_fp_calculator.transform([mol])[0]
    fp.extend(fp_maccs.ToList())
    fp.extend(fp_pubcfp)
    return fp

MIXFP_F_DIM = len(mixfp(temp_mol))
print(f"MixFP fingerprint dimension: {MIXFP_F_DIM}")

def smiles_to_pyg_data(smiles, atom_config=polygnn_atom_config, bond_config=polygnn_bond_config):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None: return None

    atom_features_list = [atom_fp(atom, atom_config) for atom in mol.GetAtoms()]

    x = torch.tensor(atom_features_list, dtype=torch.float)

    edge_indices, edge_features_list = [], []
    if mol.GetNumBonds() > 0:
        for bond in mol.GetBonds():
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edge_indices.extend([(i, j), (j, i)])
            bond_feature = bond_fp(bond, bond_config)
            edge_features_list.extend([bond_feature, bond_feature])
        edge_index = torch.tensor(edge_indices, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_features_list, dtype=torch.float)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, EDGE_F_DIM), dtype=torch.float)

    mixfp_features = mixfp(mol)
    mixfp_vec = torch.tensor(mixfp_features, dtype=torch.float)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr,
                mixfp=mixfp_vec.unsqueeze(0), smiles=smiles)
    return data

class MultiTaskMoleculeDataset(Dataset):
    def __init__(self, dataframe, smiles_data_map, prop_map):
        super().__init__()
        self.dataframe = dataframe
        self.smiles_data_map = smiles_data_map
        self.prop_map = prop_map
        self.prop_selector_dim = len(prop_map)

        self.selectors = np.stack(self.dataframe[PROP_COL].apply(lambda p: np.eye(self.prop_selector_dim)[self.prop_map[p]]).values).astype(np.float32)
        self.targets_scaled = np.array(self.dataframe['scaled_value'].values, dtype=np.float32).reshape(-1, 1)
        self.smiles_list = self.dataframe[SMILES_COL].values

    def len(self): return len(self.dataframe)

    def get(self, idx):
        smiles = self.smiles_list[idx]
        data_obj = self.smiles_data_map[smiles]
        selector = torch.from_numpy(self.selectors[idx])
        target_scaled = torch.from_numpy(self.targets_scaled[idx])
        return data_obj, selector, target_scaled

class CrossModalInteraction(nn.Module):
    def __init__(self, gnn_dim, fp_dim, hidden_dim, dropout):
        super().__init__()
        self.mlp_fp_to_gnn = nn.Sequential(nn.Linear(fp_dim, hidden_dim), nn.ReLU(), Dropout(dropout), nn.Linear(hidden_dim, gnn_dim))
        self.mlp_gnn_to_fp = nn.Sequential(nn.Linear(gnn_dim, hidden_dim), nn.ReLU(), Dropout(dropout), nn.Linear(hidden_dim, fp_dim))
        self.gate_activation = nn.Sigmoid()
    def forward(self, gnn_embedding, fp_embedding):
        refined_gnn, refined_fp = gnn_embedding, fp_embedding
        embeddings_to_concat = []
        if gnn_embedding is not None: embeddings_to_concat.append(gnn_embedding)
        if fp_embedding is not None: embeddings_to_concat.append(fp_embedding)

        if gnn_embedding is not None and fp_embedding is not None:
            gnn_context = self.mlp_fp_to_gnn(fp_embedding); gnn_gate = self.gate_activation(gnn_context); refined_gnn = gnn_embedding * gnn_gate
            fp_context = self.mlp_gnn_to_fp(gnn_embedding); fp_gate = self.gate_activation(fp_context); refined_fp = fp_embedding * fp_gate
            embeddings_to_concat = [refined_gnn, refined_fp]

        fused_vector = embeddings_to_concat[0] if len(embeddings_to_concat)==1 else torch.cat(embeddings_to_concat, dim=-1) if len(embeddings_to_concat) > 1 else None
        return fused_vector, (None, None) # Gates not needed for standard forward

class MultiTaskPredictor(nn.Module):
    def __init__(self, node_feature_dim, mixfp_feature_dim, prop_selector_dim, gnn_hidden_dim, gnn_layers, fp_hidden_dim, fp_layers, interaction_hidden_dim, fusion_hidden_dim, dropout=0.3):
        super().__init__()
        self.dropout_rate = dropout; self.node_feature_dim = node_feature_dim; self.mixfp_feature_dim = mixfp_feature_dim; self.prop_selector_dim = prop_selector_dim
        self.gnn_readout_dim, self.final_fp_dim = 0, 0
        self.gnn_input_mlp, self.gine_convs, self.gnn_batch_norms = None, nn.ModuleList(), nn.ModuleList()
        if node_feature_dim > 0 and gnn_layers > 0:
            self.gnn_input_mlp = Linear(node_feature_dim, gnn_hidden_dim)
            for _ in range(gnn_layers):
                mlp = Sequential(Linear(gnn_hidden_dim, gnn_hidden_dim * 2), BatchNorm1d(gnn_hidden_dim * 2), ReLU(), Linear(gnn_hidden_dim * 2, gnn_hidden_dim))
                self.gine_convs.append(GINEConv(nn=mlp, edge_dim=EDGE_F_DIM))
                self.gnn_batch_norms.append(BatchNorm1d(gnn_hidden_dim))
            self.gnn_readout_dim = gnn_hidden_dim * 2 # Global Add + Global Max
        self.fp_layers_list, self.fp_batch_norms = nn.ModuleList(), nn.ModuleList()
        if mixfp_feature_dim > 0:
            if fp_layers > 0:
                self.fp_layers_list.append(Linear(mixfp_feature_dim, fp_hidden_dim)); self.fp_batch_norms.append(BatchNorm1d(fp_hidden_dim))
                for _ in range(fp_layers - 1): self.fp_layers_list.append(Linear(fp_hidden_dim, fp_hidden_dim)); self.fp_batch_norms.append(BatchNorm1d(fp_hidden_dim))
                self.final_fp_dim = fp_hidden_dim
            else: self.final_fp_dim = mixfp_feature_dim
        self.interaction_module = None;
        fusion_base_dim = 0
        if self.gnn_readout_dim > 0 and self.final_fp_dim > 0 and interaction_hidden_dim > 0:
             self.interaction_module = CrossModalInteraction(self.gnn_readout_dim, self.final_fp_dim, interaction_hidden_dim, dropout)
             fusion_base_dim = self.gnn_readout_dim + self.final_fp_dim
        elif self.gnn_readout_dim > 0: fusion_base_dim = self.gnn_readout_dim
        elif self.final_fp_dim > 0: fusion_base_dim = self.final_fp_dim

        fusion_input_dim = fusion_base_dim + prop_selector_dim

        self.fusion_mlp = Sequential(Linear(fusion_input_dim, fusion_hidden_dim), ReLU(), BatchNorm1d(fusion_hidden_dim), Dropout(dropout),
                                   Linear(fusion_hidden_dim, fusion_hidden_dim // 2), ReLU(), Dropout(dropout), Linear(fusion_hidden_dim // 2, 1))

    def forward(self, data, selector):
        x, edge_index, edge_attr, batch, mixfp = data.x, data.edge_index, data.edge_attr, data.batch, data.mixfp
        graph_embedding, fp_embedding = None, None

        if self.gnn_input_mlp is not None:
            gnn_x = self.gnn_input_mlp(x)
            for i in range(len(self.gine_convs)):
                gnn_x = self.gine_convs[i](gnn_x, edge_index, edge_attr=edge_attr); bn = self.gnn_batch_norms[i]
                gnn_x = F.relu(bn(gnn_x)); gnn_x = F.dropout(gnn_x, p=self.dropout_rate, training=self.training)
            graph_embedding = torch.cat([global_add_pool(gnn_x, batch), global_max_pool(gnn_x, batch)], dim=-1)

        if self.final_fp_dim > 0:
             fp = mixfp.squeeze(1)
             if len(self.fp_layers_list) > 0:
                 for i in range(len(self.fp_layers_list)): bn_fp = self.fp_batch_norms[i]; fp = F.relu(bn_fp(self.fp_layers_list[i](fp))); fp = F.dropout(fp, p=self.dropout_rate, training=self.training)
             fp_embedding = fp

        fused_vector = None
        if self.interaction_module is not None: fused_vector, _ = self.interaction_module(graph_embedding, fp_embedding)
        else:
            embeddings_to_use = []
            if graph_embedding is not None: embeddings_to_use.append(graph_embedding)
            if fp_embedding is not None: embeddings_to_use.append(fp_embedding)
            fused_vector = embeddings_to_use[0] if len(embeddings_to_use) == 1 else torch.cat(embeddings_to_use, dim=-1) if len(embeddings_to_use) > 1 else None

        if fused_vector is not None:
             final_input = torch.cat([fused_vector, selector], dim=-1)
        else:
             final_input = selector

        output = self.fusion_mlp(final_input)
        return output

def train_epoch(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0.0
    num_samples = 0
    for data_batch, selector_batch, target_batch in loader:
        data_batch = data_batch.to(device)
        selector_batch = selector_batch.to(device)
        target_batch = target_batch.to(device)

        optimizer.zero_grad()
        out = model(data_batch, selector_batch)
        loss = loss_fn(out, target_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data_batch.num_graphs
        num_samples += data_batch.num_graphs
    return total_loss / num_samples

@torch.no_grad()
def evaluate(model, loader, loss_fn, device, scalers, prop_names):
    model.eval()
    total_loss = 0.0
    all_preds_scaled, all_targets_scaled, all_props = [], [], []

    for data_batch, selector_batch, target_batch in loader:
        data_batch = data_batch.to(device)
        selector_batch = selector_batch.to(device)
        target_batch = target_batch.to(device)

        out_scaled = model(data_batch, selector_batch)

        total_loss += loss_fn(out_scaled, target_batch).item() * data_batch.num_graphs

        all_preds_scaled.append(out_scaled.cpu().numpy())
        all_targets_scaled.append(target_batch.cpu().numpy())

        prop_indices = torch.argmax(selector_batch.cpu(), dim=1).numpy()
        batch_props = [prop_names[i] for i in prop_indices]
        all_props.extend(batch_props)

    avg_loss_scaled = total_loss / sum(len(batch) for batch in all_preds_scaled)

    preds_scaled_np = np.concatenate(all_preds_scaled, axis=0)
    targets_scaled_np = np.concatenate(all_targets_scaled, axis=0)

    metrics = {}
    preds_orig_dict = {}
    targets_orig_dict = {}

    for prop_idx, prop in enumerate(prop_names):
        prop_mask = np.array(all_props) == prop
        if np.sum(prop_mask) > 0:
            prop_preds_scaled = preds_scaled_np[prop_mask]
            prop_targets_scaled = targets_scaled_np[prop_mask]

            # --- FIX: Access scaler by index from the list ---
            scaler = scalers[prop_idx]

            if scaler is not None and hasattr(scaler, 'scale_') and scaler.scale_ is not None:
                 prop_preds_orig = scaler.inverse_transform(prop_preds_scaled)
                 prop_targets_orig = scaler.inverse_transform(prop_targets_scaled)

                 preds_orig_dict[prop] = prop_preds_orig.flatten().tolist()
                 targets_orig_dict[prop] = prop_targets_orig.flatten().tolist()

                 mae_orig = mean_absolute_error(prop_targets_orig, prop_preds_orig)
                 rmse_orig = np.sqrt(mean_squared_error(prop_targets_orig, prop_preds_orig))
                 r2 = r2_score(prop_targets_orig, prop_preds_orig)
                 metrics[prop] = {'MAE': mae_orig, 'RMSE': rmse_orig, 'R2': r2, 'count': int(np.sum(prop_mask))}
            else:
                 metrics[prop] = {'MAE': np.nan, 'RMSE': np.nan, 'R2': np.nan, 'count': int(np.sum(prop_mask))}
        else:
             metrics[prop] = {'MAE': np.nan, 'RMSE': np.nan, 'R2': np.nan, 'count': 0}
             preds_orig_dict[prop] = []
             targets_orig_dict[prop] = []


    return avg_loss_scaled, metrics, targets_orig_dict, preds_orig_dict

# --- Data Loading and Preparation ---
print("Loading data...")
df = pd.read_csv(CSV_PATH)
df.dropna(subset=[SMILES_COL, VALUE_COL, PROP_COL], inplace=True)
print(f"Data loaded and cleaned, shape: {df.shape}")

TARGET_COLS = sorted(df[PROP_COL].unique().tolist())
PROP_MAP = {prop: i for i, prop in enumerate(TARGET_COLS)}
NUM_TASKS = len(TARGET_COLS)
PROP_SELECTOR_DIM = NUM_TASKS
print(f"Detected target properties: {TARGET_COLS}")
print(f"Number of tasks (properties): {NUM_TASKS}")

print("\nPre-generating PyG Data objects for unique SMILES...")
unique_smiles = df[SMILES_COL].unique().tolist()
smiles_data_map = {}
pregenerated_data_path = os.path.join(BEST_MODEL_SAVE_DIR, PREGENERATED_DATA_FILE)

if os.path.exists(pregenerated_data_path):
    print(f"Loading pre-generated data from {pregenerated_data_path}")
    smiles_data_map = joblib.load(pregenerated_data_path)
    smiles_data_map = {s: d for s, d in smiles_data_map.items() if d is not None}
    print(f"Loaded data for {len(smiles_data_map)} unique SMILES.")
else:
    print("Pre-generated data not found. Generating...")
    for smiles in tqdm(unique_smiles, desc="Generating Data objects"):
        data_obj = smiles_to_pyg_data(smiles)
        if data_obj is not None:
            smiles_data_map[smiles] = data_obj

    print(f"Generated data for {len(smiles_data_map)} unique SMILES.")
    print(f"Saving pre-generated data to {pregenerated_data_path}")
    joblib.dump(smiles_data_map, pregenerated_data_path)

df = df[df[SMILES_COL].isin(smiles_data_map.keys())].reset_index(drop=True)
print(f"Filtered dataframe based on successful data generation, shape: {df.shape}")

print("\nSplitting data (stratified by property)...")
train_val_df, test_df = train_test_split(df, test_size=TEST_SPLIT_RATIO, stratify=df[PROP_COL], random_state=RANDOM_STATE)

train_val_df = train_val_df.copy().reset_index(drop=True)
test_df = test_df.copy().reset_index(drop=True)

print(f"Train/Val set size: {len(train_val_df)}, Test set size: {len(test_df)}")
print("Train/Val property distribution:\n", train_val_df[PROP_COL].value_counts(normalize=True).sort_index())
print("Test property distribution:\n", test_df[PROP_COL].value_counts(normalize=True).sort_index())

print("\nFitting main scalers (MinMaxScaler) per property...")
main_scalers = [] # Store scalers in a list
for prop in TARGET_COLS:
    scaler = MinMaxScaler()
    prop_train_val_data = train_val_df[train_val_df[PROP_COL] == prop][VALUE_COL].values.reshape(-1, 1)
    if len(prop_train_val_data) > 0:
        scaler.fit(prop_train_val_data)
    main_scalers.append(scaler)

main_scalers_path = os.path.join(BEST_MODEL_SAVE_DIR, SCALERS_SAVE_FILE)
target_cols_path = os.path.join(BEST_MODEL_SAVE_DIR, TARGET_COLS_FILE)
joblib.dump(main_scalers, main_scalers_path)
joblib.dump(TARGET_COLS, target_cols_path)
print(f"Main scalers and target columns saved to: {main_scalers_path}, {target_cols_path}")

print(f"\n--- Starting {N_FOLDS}-Fold Stratified CV ---")
kf_mtl = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDOM_STATE + 1)

fold_model_paths = []
fold_best_hyperparams = []

train_val_indices = train_val_df.index.tolist()
train_val_stratify_target = train_val_df[PROP_COL]

for fold, (fold_train_indices, fold_val_indices) in enumerate(kf_mtl.split(train_val_indices, train_val_stratify_target)):
    print(f"\n===== Fold {fold + 1}/{N_FOLDS} =====")
    train_df_fold = train_val_df.iloc[fold_train_indices].copy().reset_index(drop=True)
    val_df_fold = train_val_df.iloc[fold_val_indices].copy().reset_index(drop=True)

    print(f"Fold Train size: {len(train_df_fold)}, Fold Val size: {len(val_df_fold)}")

    print(f"\n--- Fold {fold + 1}: Fitting fold scalers per property ---")
    fold_scalers = [] 
    for prop in TARGET_COLS:
        scaler = MinMaxScaler()
        prop_train_data = train_df_fold[train_df_fold[PROP_COL] == prop][VALUE_COL].values.reshape(-1, 1)
        if len(prop_train_data) > 0:
            scaler.fit(prop_train_data)
        fold_scalers.append(scaler)

    train_df_fold['scaled_value'] = np.nan
    val_df_fold['scaled_value'] = np.nan

    for prop_idx, prop in enumerate(TARGET_COLS):
         scaler = fold_scalers[prop_idx]
         train_prop_mask = train_df_fold[PROP_COL] == prop
         train_prop_data = train_df_fold.loc[train_prop_mask, VALUE_COL].values.reshape(-1, 1)
         if hasattr(scaler, 'scale_') and scaler.scale_ is not None and len(train_prop_data) > 0:
             train_df_fold.loc[train_prop_mask, 'scaled_value'] = scaler.transform(train_prop_data).flatten()

         val_prop_mask = val_df_fold[PROP_COL] == prop
         val_prop_data = val_df_fold.loc[val_prop_mask, VALUE_COL].values.reshape(-1, 1)
         if hasattr(scaler, 'scale_') and scaler.scale_ is not None and len(val_prop_data) > 0:
              val_df_fold.loc[val_prop_mask, 'scaled_value'] = scaler.transform(val_prop_data).flatten()


    train_df_fold.dropna(subset=['scaled_value'], inplace=True)
    val_df_fold.dropna(subset=['scaled_value'], inplace=True)
    print(f"Fold Train size after dropping unscalable samples: {len(train_df_fold)}")
    print(f"Fold Val size after dropping unscalable samples: {len(val_df_fold)}")


    # --- Optuna Split (from train_df_fold) ---
    print(f"\n--- Fold {fold + 1}: Running Optuna ---")
    if len(train_df_fold) == 0:
        print(f"  Fold {fold + 1}: No training data after scaling filtering. Skipping Optuna.")
        fold_best_hyperparams.append({})
        continue

    optuna_train_df, optuna_val_df = train_test_split(train_df_fold, test_size=VALIDATION_SPLIT_OPTUNA, stratify=train_df_fold[PROP_COL], random_state=RANDOM_STATE + fold + 2)

    optuna_train_df = optuna_train_df.copy().reset_index(drop=True)
    optuna_val_df = optuna_val_df.copy().reset_index(drop=True)

    optuna_train_dataset = MultiTaskMoleculeDataset(optuna_train_df, smiles_data_map, PROP_MAP)
    optuna_val_dataset = MultiTaskMoleculeDataset(optuna_val_df, smiles_data_map, PROP_MAP)

    def objective(trial):
         lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
         weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-4, log=True)
         dropout = trial.suggest_float("dropout", 0.1, 0.5)
         batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
         fusion_hidden_dim = trial.suggest_categorical("fusion_hidden_dim", [64, 128, 256, 512])
         gnn_hidden_dim, gnn_layers = (trial.suggest_categorical("gnn_hidden_dim", [32, 64, 128, 256]), trial.suggest_int("gnn_layers", 1, 5)) if ATOM_F_DIM > 0 else (0, 0)
         fp_hidden_dim, fp_layers = (trial.suggest_categorical("fp_hidden_dim", [128, 256, 512, 1024]), trial.suggest_int("fp_layers", 1, 4)) if MIXFP_F_DIM > 0 else (0, 0)
         interaction_hidden_dim = trial.suggest_categorical("interaction_hidden_dim", [32, 64, 128]) if (gnn_layers > 0 and fp_layers >= 0) and (ATOM_F_DIM > 0 and MIXFP_F_DIM > 0) else 0


         model = MultiTaskPredictor(
             node_feature_dim=ATOM_F_DIM, mixfp_feature_dim=MIXFP_F_DIM, prop_selector_dim=PROP_SELECTOR_DIM,
             gnn_hidden_dim=gnn_hidden_dim, gnn_layers=gnn_layers,
             fp_hidden_dim=fp_hidden_dim, fp_layers=fp_layers,
             interaction_hidden_dim=interaction_hidden_dim,
             fusion_hidden_dim=fusion_hidden_dim,
             dropout=dropout
         ).to(device)

         optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=5)
         loss_fn = torch.nn.MSELoss()

         train_drop_last_optuna = len(optuna_train_dataset) > 1 and len(optuna_train_dataset) % batch_size == 1
         temp_train_loader = DataLoader(optuna_train_dataset, batch_size=batch_size, shuffle=True, drop_last=train_drop_last_optuna)

         temp_val_loader = DataLoader(optuna_val_dataset, batch_size=256, shuffle=False) if len(optuna_val_dataset) > 0 else None

         optuna_epochs, optuna_patience = 150, 15
         best_optuna_val_loss, epochs_no_improve = float('inf'), 0

         for epoch in range(1, optuna_epochs + 1):
             train_loss_scaled = train_epoch(model, temp_train_loader, optimizer, loss_fn, device)
             current_val_loss = float('inf')

             if temp_val_loader is not None and len(temp_val_loader) > 0:
                 val_loss_scaled, _, _, _ = evaluate(model, temp_val_loader, loss_fn, device, fold_scalers, TARGET_COLS)
                 current_val_loss = val_loss_scaled

                 scheduler.step(current_val_loss)

                 if current_val_loss < best_optuna_val_loss:
                     best_optuna_val_loss = current_val_loss; epochs_no_improve = 0
                 else:
                     epochs_no_improve += 1
                     if epochs_no_improve >= optuna_patience: break
             elif len(optuna_val_dataset) == 0:
                  current_val_loss = train_loss_scaled
                  best_optuna_val_loss = current_val_loss
                  epochs_no_improve = 0

             trial.report(current_val_loss, epoch)
             if trial.should_prune(): raise optuna.exceptions.TrialPruned()

         return best_optuna_val_loss


    start_time_optuna = time.time()
    study = optuna.create_study(direction='minimize', pruner=optuna.pruners.MedianPruner())
    study.optimize(objective, n_trials=N_TRIALS_OPTUNA, n_jobs=1)
    end_time_optuna = time.time()
    print(f"Fold {fold + 1}: Optuna finished. Duration: {end_time_optuna - start_time_optuna:.2f} sec")

    current_fold_best_params = study.best_trial.params
    if ATOM_F_DIM == 0 or current_fold_best_params.get('gnn_layers', 0) == 0: current_fold_best_params['gnn_hidden_dim'], current_fold_best_params['gnn_layers'] = 0, 0
    if MIXFP_F_DIM == 0: current_fold_best_params['fp_hidden_dim'], current_fold_best_params['fp_layers'] = 0, 0
    if not (ATOM_F_DIM > 0 and MIXFP_F_DIM > 0 and current_fold_best_params.get('gnn_layers', 0) > 0 and current_fold_best_params.get('fp_layers', -1) >= 0):
         current_fold_best_params['interaction_hidden_dim'] = 0

    fold_best_hyperparams.append(current_fold_best_params)
    print(f"Fold {fold + 1}: Optuna best hyperparameters:\n {current_fold_best_params}")
    print(f"Fold {fold + 1}: Best validation loss (Average MSE scaled): {study.best_trial.value:.6f}")

    # --- Final Fold Training ---
    print(f"\n--- Fold {fold + 1}: Training final model ---")

    if len(train_df_fold) == 0:
         print(f"  Fold {fold + 1}: No training data available for final training. Skipping model save.")
         continue

    train_dataset_fold = MultiTaskMoleculeDataset(train_df_fold, smiles_data_map, PROP_MAP)
    val_dataset_fold = MultiTaskMoleculeDataset(val_df_fold, smiles_data_map, PROP_MAP)

    fold_batch_size = current_fold_best_params.get('batch_size', 64)

    train_drop_last_fold = len(train_dataset_fold) > 1 and len(train_dataset_fold) % fold_batch_size == 1
    train_loader_fold = DataLoader(train_dataset_fold, batch_size=fold_batch_size, shuffle=True, drop_last=train_drop_last_fold)

    val_loader_fold = DataLoader(val_dataset_fold, batch_size=256, shuffle=False) if len(val_dataset_fold) > 0 else None

    final_model = MultiTaskPredictor(
        node_feature_dim=ATOM_F_DIM, mixfp_feature_dim=MIXFP_F_DIM, prop_selector_dim=PROP_SELECTOR_DIM,
        gnn_hidden_dim=current_fold_best_params['gnn_hidden_dim'],
        gnn_layers=current_fold_best_params['gnn_layers'],
        fp_hidden_dim=current_fold_best_params['fp_hidden_dim'],
        fp_layers=current_fold_best_params['fp_layers'],
        interaction_hidden_dim=current_fold_best_params['interaction_hidden_dim'],
        fusion_hidden_dim=current_fold_best_params['fusion_hidden_dim'],
        dropout=current_fold_best_params['dropout']
    ).to(device)

    optimizer_final = torch.optim.Adam(final_model.parameters(), lr=current_fold_best_params['lr'], weight_decay=current_fold_best_params['weight_decay'])
    scheduler_final = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer_final, 'min', factor=0.5, patience=max(1, PATIENCE_FOLD_TRAINING // 2))
    loss_fn_final = torch.nn.MSELoss()

    fold_best_final_val_loss, fold_final_epochs_no_improve, best_epoch_fold = float('inf'), 0, 0
    best_model_fold_state = None
    train_start_time_fold = time.time()

    for epoch in range(1, N_EPOCHS_FOLD_TRAINING + 1):
        train_loss_scaled = train_epoch(final_model, train_loader_fold, optimizer_final, loss_fn_final, device)
        epoch_summary_str = f" F{fold + 1} E{epoch:03d}/{N_EPOCHS_FOLD_TRAINING} | Tr L(S): {train_loss_scaled:.5f}"

        if val_loader_fold is not None and len(val_loader_fold) > 0:
            val_loss_scaled, val_metrics, _, _ = evaluate(final_model, val_loader_fold, loss_fn_final, device, fold_scalers, TARGET_COLS)
            epoch_summary_str += f" | V L(S): {val_loss_scaled:.5f}"
            if TARGET_COLS and len(val_metrics) > 0:
                 first_prop = TARGET_COLS[0]
                 if first_prop in val_metrics and 'MAE' in val_metrics[first_prop] and not np.isnan(val_metrics[first_prop]['MAE']):
                     epoch_summary_str += f" | V {first_prop} MAE: {val_metrics[first_prop]['MAE']:.4f}"
                 else:
                      for prop in TARGET_COLS:
                           if prop in val_metrics and val_metrics[prop]['count'] > 0 and not np.isnan(val_metrics[prop]['MAE']):
                                epoch_summary_str += f" | V {prop} MAE: {val_metrics[prop]['MAE']:.4f}"
                                break

            scheduler_final.step(val_loss_scaled)
            if val_loss_scaled < fold_best_final_val_loss:
                 fold_best_final_val_loss = val_loss_scaled; fold_final_epochs_no_improve = 0; best_epoch_fold = epoch
                 best_model_fold_state = copy.deepcopy(final_model.state_dict())
            else:
                 fold_final_epochs_no_improve += 1
                 if fold_final_epochs_no_improve >= PATIENCE_FOLD_TRAINING:
                     print(epoch_summary_str); print(f"  Early stopping. Best Epoch: {best_epoch_fold}"); break
        else:
             best_model_fold_state = copy.deepcopy(final_model.state_dict()); best_epoch_fold = epoch

        if epoch % 10 == 0 or epoch == N_EPOCHS_FOLD_TRAINING: print(epoch_summary_str)

    train_end_time_fold = time.time()
    print(f"  Fold {fold + 1} training finished. Duration: {train_end_time_fold - train_start_time_fold:.2f} sec.")

    current_fold_model_path = os.path.join(BEST_MODEL_SAVE_DIR, f'model_fold_{fold + 1}_best.pth')
    if best_model_fold_state is not None:
        torch.save(best_model_fold_state, current_fold_model_path)
        print(f"  Fold {fold + 1} best model saved to: {current_fold_model_path} (from epoch {best_epoch_fold})")
        fold_model_paths.append(current_fold_model_path)
    else:
        print(f"  Fold {fold + 1}: No model state saved, skipping model path.")

Using device: cuda
PolyGNN Atom feature dimension: 80
PolyGNN Bond feature dimension: 6
MixFP fingerprint dimension: 1048
Loading data...
Data loaded and cleaned, shape: (16420, 3)
Detected target properties: ['SP', 'Tc', 'Td', 'Tg', 'Tm']
Number of tasks (properties): 5

Pre-generating PyG Data objects for unique SMILES...
Pre-generated data not found. Generating...


Ge


Generated data for 10755 unique SMILES.
Saving pre-generated data to ./best_models_gine_mixfp_multitask_v3\pregenerated_data.joblib


[I 2025-05-12 10:31:31,100] A new study created in memory with name: no-name-6f9dfebf-911b-42a7-8789-fb3d683a0397


Filtered dataframe based on successful data generation, shape: (16420, 3)

Splitting data (stratified by property)...
Train/Val set size: 13136, Test set size: 3284
Train/Val property distribution:
 prop
SP    0.042707
Tc    0.026035
Td    0.311663
Tg    0.442296
Tm    0.177299
Name: proportion, dtype: float64
Test property distribution:
 prop
SP    0.042935
Tc    0.025883
Td    0.311815
Tg    0.442144
Tm    0.177223
Name: proportion, dtype: float64

Fitting main scalers (MinMaxScaler) per property...
Main scalers and target columns saved to: ./best_models_gine_mixfp_multitask_v3\multitask_minmax_scalers.joblib, ./best_models_gine_mixfp_multitask_v3\target_cols_multitask.joblib

--- Starting 5-Fold Stratified CV ---

===== Fold 1/5 =====
Fold Train size: 10508, Fold Val size: 2628

--- Fold 1: Fitting fold scalers per property ---
Fold Train size after dropping unscalable samples: 10508
Fold Val size after dropping unscalable samples: 2628

--- Fold 1: Running Optuna ---


[I 2025-05-12 10:33:13,273] Trial 0 finished with value: 0.017327703500111594 and parameters: {'lr': 4.3764485735513555e-05, 'weight_decay': 7.809930704286869e-06, 'dropout': 0.4084202758591543, 'batch_size': 128, 'fusion_hidden_dim': 512, 'gnn_hidden_dim': 32, 'gnn_layers': 2, 'fp_hidden_dim': 512, 'fp_layers': 3, 'interaction_hidden_dim': 64}. Best is trial 0 with value: 0.017327703500111594.
[I 2025-05-12 10:37:35,307] Trial 1 finished with value: 0.0068787906051291725 and parameters: {'lr': 0.0004228989678120472, 'weight_decay': 5.763601079050454e-06, 'dropout': 0.30211311728566836, 'batch_size': 128, 'fusion_hidden_dim': 256, 'gnn_hidden_dim': 32, 'gnn_layers': 3, 'fp_hidden_dim': 256, 'fp_layers': 4, 'interaction_hidden_dim': 32}. Best is trial 1 with value: 0.0068787906051291725.
[I 2025-05-12 10:42:16,493] Trial 2 finished with value: 0.00701896073677702 and parameters: {'lr': 0.00036531980776648355, 'weight_decay': 1.5525776127379146e-06, 'dropout': 0.34974794485206195, 'batch

Fold 1: Optuna finished. Duration: 2802.22 sec
Fold 1: Optuna best hyperparameters:
 {'lr': 0.0009733776929225491, 'weight_decay': 4.230581572624207e-06, 'dropout': 0.15422077302580298, 'batch_size': 128, 'fusion_hidden_dim': 256, 'gnn_hidden_dim': 32, 'gnn_layers': 4, 'fp_hidden_dim': 256, 'fp_layers': 4, 'interaction_hidden_dim': 32}
Fold 1: Best validation loss (Average MSE scaled): 0.005348

--- Fold 1: Training final model ---
 F1 E010/250 | Tr L(S): 0.01063 | V L(S): 0.01083 | V SP MAE: 51.3563
 F1 E020/250 | Tr L(S): 0.00723 | V L(S): 0.00701 | V SP MAE: 40.6567
 F1 E030/250 | Tr L(S): 0.00583 | V L(S): 0.00677 | V SP MAE: 40.5349
 F1 E040/250 | Tr L(S): 0.00504 | V L(S): 0.00562 | V SP MAE: 36.9266
 F1 E050/250 | Tr L(S): 0.00470 | V L(S): 0.00620 | V SP MAE: 38.1552
 F1 E060/250 | Tr L(S): 0.00412 | V L(S): 0.00546 | V SP MAE: 37.5301
 F1 E070/250 | Tr L(S): 0.00394 | V L(S): 0.00627 | V SP MAE: 36.5227
 F1 E080/250 | Tr L(S): 0.00434 | V L(S): 0.00559 | V SP MAE: 35.3901
 F1 

[I 2025-05-12 11:26:51,039] A new study created in memory with name: no-name-19511b5b-294c-4c07-a352-e2fccb032d75


 F1 E191/250 | Tr L(S): 0.00145 | V L(S): 0.00420 | V SP MAE: 30.5450
  Early stopping. Best Epoch: 161
  Fold 1 training finished. Duration: 517.56 sec.
  Fold 1 best model saved to: ./best_models_gine_mixfp_multitask_v3\model_fold_1_best.pth (from epoch 161)

===== Fold 2/5 =====
Fold Train size: 10509, Fold Val size: 2627

--- Fold 2: Fitting fold scalers per property ---
Fold Train size after dropping unscalable samples: 10509
Fold Val size after dropping unscalable samples: 2627

--- Fold 2: Running Optuna ---


[I 2025-05-12 11:33:48,306] Trial 0 finished with value: 0.01763689745976089 and parameters: {'lr': 2.6026002493402733e-05, 'weight_decay': 1.4459507379353897e-05, 'dropout': 0.38997659152464814, 'batch_size': 128, 'fusion_hidden_dim': 256, 'gnn_hidden_dim': 256, 'gnn_layers': 4, 'fp_hidden_dim': 256, 'fp_layers': 1, 'interaction_hidden_dim': 128}. Best is trial 0 with value: 0.01763689745976089.
[I 2025-05-12 11:36:27,766] Trial 1 finished with value: 0.023089348362469876 and parameters: {'lr': 2.5824444261579055e-05, 'weight_decay': 7.389741765055624e-05, 'dropout': 0.4500882884233117, 'batch_size': 64, 'fusion_hidden_dim': 128, 'gnn_hidden_dim': 32, 'gnn_layers': 3, 'fp_hidden_dim': 1024, 'fp_layers': 2, 'interaction_hidden_dim': 32}. Best is trial 0 with value: 0.01763689745976089.
[I 2025-05-12 11:41:06,828] Trial 2 finished with value: 0.005692866564170944 and parameters: {'lr': 0.0004351188867563945, 'weight_decay': 2.999758373587731e-06, 'dropout': 0.2771296810111603, 'batch_si

Fold 2: Optuna finished. Duration: 5403.45 sec
Fold 2: Optuna best hyperparameters:
 {'lr': 0.0005520319369083573, 'weight_decay': 7.960133733548129e-06, 'dropout': 0.14451703122058876, 'batch_size': 128, 'fusion_hidden_dim': 128, 'gnn_hidden_dim': 128, 'gnn_layers': 4, 'fp_hidden_dim': 512, 'fp_layers': 2, 'interaction_hidden_dim': 32}
Fold 2: Best validation loss (Average MSE scaled): 0.005099

--- Fold 2: Training final model ---
 F2 E010/250 | Tr L(S): 0.01426 | V L(S): 0.01228 | V SP MAE: 48.1482
 F2 E020/250 | Tr L(S): 0.00902 | V L(S): 0.00920 | V SP MAE: 43.1212
 F2 E030/250 | Tr L(S): 0.00615 | V L(S): 0.00851 | V SP MAE: 40.7687
 F2 E040/250 | Tr L(S): 0.00576 | V L(S): 0.00665 | V SP MAE: 41.1809
 F2 E050/250 | Tr L(S): 0.00471 | V L(S): 0.00604 | V SP MAE: 39.4237
 F2 E060/250 | Tr L(S): 0.00439 | V L(S): 0.00617 | V SP MAE: 37.6305
 F2 E070/250 | Tr L(S): 0.00391 | V L(S): 0.00611 | V SP MAE: 36.2376
 F2 E080/250 | Tr L(S): 0.00336 | V L(S): 0.00554 | V SP MAE: 36.1335
 F2

[I 2025-05-12 13:08:01,117] A new study created in memory with name: no-name-209b15d8-5ef5-49f1-ab5a-f0a03e8abf62


 F2 E210/250 | Tr L(S): 0.00123 | V L(S): 0.00468 | V SP MAE: 36.1258
  Early stopping. Best Epoch: 180
  Fold 2 training finished. Duration: 666.50 sec.
  Fold 2 best model saved to: ./best_models_gine_mixfp_multitask_v3\model_fold_2_best.pth (from epoch 180)

===== Fold 3/5 =====
Fold Train size: 10509, Fold Val size: 2627

--- Fold 3: Fitting fold scalers per property ---
Fold Train size after dropping unscalable samples: 10509
Fold Val size after dropping unscalable samples: 2627

--- Fold 3: Running Optuna ---


[I 2025-05-12 13:11:50,248] Trial 0 finished with value: 0.02514918297741892 and parameters: {'lr': 3.56600614763924e-05, 'weight_decay': 5.263335943985726e-06, 'dropout': 0.36800556082648483, 'batch_size': 64, 'fusion_hidden_dim': 64, 'gnn_hidden_dim': 128, 'gnn_layers': 4, 'fp_hidden_dim': 256, 'fp_layers': 3, 'interaction_hidden_dim': 128}. Best is trial 0 with value: 0.02514918297741892.
[I 2025-05-12 13:17:34,378] Trial 1 finished with value: 0.00550929355664581 and parameters: {'lr': 8.60737798853748e-05, 'weight_decay': 6.903532698361291e-05, 'dropout': 0.11712187600411372, 'batch_size': 32, 'fusion_hidden_dim': 256, 'gnn_hidden_dim': 32, 'gnn_layers': 1, 'fp_hidden_dim': 1024, 'fp_layers': 2, 'interaction_hidden_dim': 64}. Best is trial 1 with value: 0.00550929355664581.
[I 2025-05-12 13:19:41,820] Trial 2 finished with value: 0.010440456307892795 and parameters: {'lr': 4.9955111691236275e-05, 'weight_decay': 5.402478911933109e-05, 'dropout': 0.1995128935121068, 'batch_size': 1

Fold 3: Optuna finished. Duration: 3819.14 sec
Fold 3: Optuna best hyperparameters:
 {'lr': 8.60737798853748e-05, 'weight_decay': 6.903532698361291e-05, 'dropout': 0.11712187600411372, 'batch_size': 32, 'fusion_hidden_dim': 256, 'gnn_hidden_dim': 32, 'gnn_layers': 1, 'fp_hidden_dim': 1024, 'fp_layers': 2, 'interaction_hidden_dim': 64}
Fold 3: Best validation loss (Average MSE scaled): 0.005509

--- Fold 3: Training final model ---
 F3 E010/250 | Tr L(S): 0.01228 | V L(S): 0.01144 | V SP MAE: 44.0870
 F3 E020/250 | Tr L(S): 0.00767 | V L(S): 0.00827 | V SP MAE: 40.8398
 F3 E030/250 | Tr L(S): 0.00564 | V L(S): 0.01565 | V SP MAE: 42.3285
 F3 E040/250 | Tr L(S): 0.00482 | V L(S): 0.00684 | V SP MAE: 37.8271
 F3 E050/250 | Tr L(S): 0.00409 | V L(S): 0.00568 | V SP MAE: 37.0381
 F3 E060/250 | Tr L(S): 0.00388 | V L(S): 0.00658 | V SP MAE: 33.1386
 F3 E070/250 | Tr L(S): 0.00362 | V L(S): 0.00647 | V SP MAE: 37.0322
 F3 E080/250 | Tr L(S): 0.00326 | V L(S): 0.00552 | V SP MAE: 31.0699
 F3 E

[I 2025-05-12 14:30:48,242] A new study created in memory with name: no-name-90ad06b7-f77f-4746-a175-13ba0061054e


 F3 E229/250 | Tr L(S): 0.00139 | V L(S): 0.00458 | V SP MAE: 29.2085
  Early stopping. Best Epoch: 199
  Fold 3 training finished. Duration: 1147.81 sec.
  Fold 3 best model saved to: ./best_models_gine_mixfp_multitask_v3\model_fold_3_best.pth (from epoch 199)

===== Fold 4/5 =====
Fold Train size: 10509, Fold Val size: 2627

--- Fold 4: Fitting fold scalers per property ---
Fold Train size after dropping unscalable samples: 10509
Fold Val size after dropping unscalable samples: 2627

--- Fold 4: Running Optuna ---


[I 2025-05-12 14:35:32,061] Trial 0 finished with value: 0.010883994106306109 and parameters: {'lr': 1.53529146628892e-05, 'weight_decay': 6.177681746502137e-06, 'dropout': 0.15963447435887332, 'batch_size': 128, 'fusion_hidden_dim': 64, 'gnn_hidden_dim': 32, 'gnn_layers': 4, 'fp_hidden_dim': 512, 'fp_layers': 3, 'interaction_hidden_dim': 64}. Best is trial 0 with value: 0.010883994106306109.
[I 2025-05-12 14:41:47,486] Trial 1 finished with value: 0.005009107056430955 and parameters: {'lr': 0.0005916049746876524, 'weight_decay': 6.9141419954794405e-06, 'dropout': 0.14756607171305475, 'batch_size': 64, 'fusion_hidden_dim': 512, 'gnn_hidden_dim': 64, 'gnn_layers': 5, 'fp_hidden_dim': 128, 'fp_layers': 2, 'interaction_hidden_dim': 128}. Best is trial 1 with value: 0.005009107056430955.
[I 2025-05-12 14:52:15,293] Trial 2 finished with value: 0.008135188152068972 and parameters: {'lr': 0.00017640254686812477, 'weight_decay': 4.376744164109263e-05, 'dropout': 0.23999763764610363, 'batch_si

Fold 4: Optuna finished. Duration: 4863.18 sec
Fold 4: Optuna best hyperparameters:
 {'lr': 0.00042154979130198494, 'weight_decay': 1.659471617105183e-06, 'dropout': 0.13855465762222516, 'batch_size': 64, 'fusion_hidden_dim': 128, 'gnn_hidden_dim': 64, 'gnn_layers': 1, 'fp_hidden_dim': 128, 'fp_layers': 2, 'interaction_hidden_dim': 128}
Fold 4: Best validation loss (Average MSE scaled): 0.004785

--- Fold 4: Training final model ---
 F4 E010/250 | Tr L(S): 0.01365 | V L(S): 0.01063 | V SP MAE: 50.6282
 F4 E020/250 | Tr L(S): 0.00866 | V L(S): 0.00991 | V SP MAE: 47.2972
 F4 E030/250 | Tr L(S): 0.00634 | V L(S): 0.00676 | V SP MAE: 40.7468
 F4 E040/250 | Tr L(S): 0.00540 | V L(S): 0.00606 | V SP MAE: 41.6347
 F4 E050/250 | Tr L(S): 0.00448 | V L(S): 0.00632 | V SP MAE: 37.7267
 F4 E060/250 | Tr L(S): 0.00400 | V L(S): 0.00542 | V SP MAE: 36.8854
 F4 E070/250 | Tr L(S): 0.00377 | V L(S): 0.00561 | V SP MAE: 34.5121
 F4 E080/250 | Tr L(S): 0.00337 | V L(S): 0.00531 | V SP MAE: 33.9822
 F4

[I 2025-05-12 16:00:29,158] A new study created in memory with name: no-name-858d2a23-348a-4594-a22d-9d541ca5f8ee


 F4 E195/250 | Tr L(S): 0.00187 | V L(S): 0.00475 | V SP MAE: 33.0476
  Early stopping. Best Epoch: 165
  Fold 4 training finished. Duration: 517.63 sec.
  Fold 4 best model saved to: ./best_models_gine_mixfp_multitask_v3\model_fold_4_best.pth (from epoch 165)

===== Fold 5/5 =====
Fold Train size: 10509, Fold Val size: 2627

--- Fold 5: Fitting fold scalers per property ---
Fold Train size after dropping unscalable samples: 10509
Fold Val size after dropping unscalable samples: 2627

--- Fold 5: Running Optuna ---


[I 2025-05-12 16:07:22,199] Trial 0 finished with value: 0.005956766324084106 and parameters: {'lr': 0.00031559294561399104, 'weight_decay': 9.450062288052563e-05, 'dropout': 0.40468435373710976, 'batch_size': 64, 'fusion_hidden_dim': 128, 'gnn_hidden_dim': 32, 'gnn_layers': 3, 'fp_hidden_dim': 512, 'fp_layers': 1, 'interaction_hidden_dim': 32}. Best is trial 0 with value: 0.005956766324084106.
[I 2025-05-12 16:08:56,017] Trial 1 finished with value: 0.021631675119291136 and parameters: {'lr': 3.5216853698199856e-05, 'weight_decay': 2.264600647087025e-06, 'dropout': 0.3553321198958326, 'batch_size': 64, 'fusion_hidden_dim': 512, 'gnn_hidden_dim': 128, 'gnn_layers': 5, 'fp_hidden_dim': 1024, 'fp_layers': 3, 'interaction_hidden_dim': 32}. Best is trial 0 with value: 0.005956766324084106.
[I 2025-05-12 16:09:32,131] Trial 2 finished with value: 0.02182381589862418 and parameters: {'lr': 3.959557718503262e-05, 'weight_decay': 5.095304580030697e-05, 'dropout': 0.45913993407846143, 'batch_si

Fold 5: Optuna finished. Duration: 2099.11 sec
Fold 5: Optuna best hyperparameters:
 {'lr': 0.0009263862243706357, 'weight_decay': 9.453274739860927e-05, 'dropout': 0.23233595885106373, 'batch_size': 128, 'fusion_hidden_dim': 64, 'gnn_hidden_dim': 32, 'gnn_layers': 3, 'fp_hidden_dim': 128, 'fp_layers': 2, 'interaction_hidden_dim': 64}
Fold 5: Best validation loss (Average MSE scaled): 0.005777

--- Fold 5: Training final model ---
 F5 E010/250 | Tr L(S): 0.01620 | V L(S): 0.01543 | V SP MAE: 49.8122
 F5 E020/250 | Tr L(S): 0.01147 | V L(S): 0.00918 | V SP MAE: 45.9640
 F5 E030/250 | Tr L(S): 0.00956 | V L(S): 0.00820 | V SP MAE: 42.0860
 F5 E040/250 | Tr L(S): 0.00845 | V L(S): 0.00692 | V SP MAE: 44.0542
 F5 E050/250 | Tr L(S): 0.00767 | V L(S): 0.00664 | V SP MAE: 43.7780
 F5 E060/250 | Tr L(S): 0.00730 | V L(S): 0.00714 | V SP MAE: 41.4413
 F5 E070/250 | Tr L(S): 0.00684 | V L(S): 0.00654 | V SP MAE: 40.0093
 F5 E080/250 | Tr L(S): 0.00690 | V L(S): 0.00659 | V SP MAE: 42.4421
 F5 E

In [3]:
# --- Test Set Evaluation ---
print("\n\n--- Test Set Evaluation ---")
num_valid_models = len(fold_model_paths)
if num_valid_models == 0:
    print("No valid models saved from CV folds. Cannot perform test set evaluation.")
else:
    print(f"Using {num_valid_models} models for ensemble prediction on the test set.")

    main_scalers_loaded = joblib.load(os.path.join(BEST_MODEL_SAVE_DIR, SCALERS_SAVE_FILE))
    loaded_target_cols = joblib.load(os.path.join(BEST_MODEL_SAVE_DIR, TARGET_COLS_FILE))
    print("Main scalers and target columns loaded for test evaluation.")

    print("\nScaling test data using main scalers...")

    test_df_scaled = test_df.copy() 
    test_df_scaled['scaled_value'] = np.nan

    main_scalers_list = main_scalers_loaded 
    for prop_idx, prop in enumerate(loaded_target_cols):
         scaler = main_scalers_list[prop_idx]
         test_prop_mask = test_df_scaled[PROP_COL] == prop
         test_prop_data = test_df_scaled.loc[test_prop_mask, VALUE_COL].values.reshape(-1, 1)

         if hasattr(scaler, 'scale_') and scaler.scale_ is not None and len(test_prop_data) > 0:
             test_df_scaled.loc[test_prop_mask, 'scaled_value'] = scaler.transform(test_prop_data).flatten()

    test_df_scaled.dropna(subset=['scaled_value'], inplace=True) 
    print(f"Test size after scaling and filtering: {len(test_df_scaled)}")

    if len(test_df_scaled) == 0:
        print("No test data available after scaling. Skipping test set evaluation.")
    else:
        test_dataset = MultiTaskMoleculeDataset(test_df_scaled, smiles_data_map, PROP_MAP)
        test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)
        print(f"Test dataset size (from loader): {len(test_dataset)}")

        models = []
        print("Loading models for ensemble...")

        loaded_fold_hps = {}
        for saved_path in fold_model_paths:
            try:
                fold_num_str = os.path.basename(saved_path).split('_')[-2]
                fold_num = int(fold_num_str)
                if fold_num -1 >= 0 and fold_num - 1 < len(fold_best_hyperparams): 
                    loaded_fold_hps[fold_num] = fold_best_hyperparams[fold_num - 1]
                else:
                    print(f"Warning: Hyperparameters for fold {fold_num} (path: {saved_path}) not found or index out of bounds. Skipping this model.")
            except (ValueError, IndexError) as e:
                print(f"Warning: Could not parse fold number from path {saved_path}: {e}. Skipping this model.")


        loaded_model_paths_filtered = []
        for fold_num, hps in loaded_fold_hps.items():
             model_path = os.path.join(BEST_MODEL_SAVE_DIR, f'model_fold_{fold_num}_best.pth')
             if os.path.exists(model_path):
                  gnn_hidden_dim = hps.get('gnn_hidden_dim', 0)
                  gnn_layers = hps.get('gnn_layers', 0)
                  fp_hidden_dim = hps.get('fp_hidden_dim', 0)
                  fp_layers = hps.get('fp_layers', 0)
                  interaction_hidden_dim = hps.get('interaction_hidden_dim', 0)
                  fusion_hidden_dim = hps.get('fusion_hidden_dim', 128) 
                  dropout = hps.get('dropout', 0.3) 

                  if ATOM_F_DIM == 0 or gnn_layers == 0: gnn_hidden_dim, gnn_layers = 0,0
                  if MIXFP_F_DIM == 0 or fp_layers == 0: fp_hidden_dim, fp_layers = 0,0 
                  if not (ATOM_F_DIM > 0 and MIXFP_F_DIM > 0 and gnn_layers > 0 and fp_layers >=0):
                       interaction_hidden_dim = 0

                  model = MultiTaskPredictor(
                      node_feature_dim=ATOM_F_DIM, mixfp_feature_dim=MIXFP_F_DIM, prop_selector_dim=PROP_SELECTOR_DIM,
                      gnn_hidden_dim=gnn_hidden_dim, gnn_layers=gnn_layers,
                      fp_hidden_dim=fp_hidden_dim, fp_layers=fp_layers,
                      interaction_hidden_dim=interaction_hidden_dim,
                      fusion_hidden_dim=fusion_hidden_dim,
                      dropout=dropout
                  ).to(device)
                  try:
                      model.load_state_dict(torch.load(model_path, map_location=device))
                      model.eval(); models.append(model)
                      loaded_model_paths_filtered.append(model_path)
                  except Exception as e:
                      print(f"Error loading state dict for model {model_path}: {e}. Skipping this model.")
             else:
                 print(f"Warning: Model path {model_path} does not exist. Skipping this model.")


        num_valid_models_after_load = len(models)
        if num_valid_models_after_load == 0:
            print("No models could be loaded. Cannot perform ensemble prediction.")
        else:
            print(f"\nStarting ensemble prediction on test set ({num_valid_models_after_load} models)...")
            all_ensemble_preds_scaled = []
            all_test_targets_scaled = [] 
            all_test_props = []

            with torch.no_grad():
                for data_batch, selector_batch, target_batch_scaled in tqdm(test_loader, desc="Test Prediction"): # MODIFIED: Renamed variable
                    data_batch = data_batch.to(device)
                    selector_batch = selector_batch.to(device)
                    target_batch_scaled = target_batch_scaled.cpu()


                    batch_preds_list_scaled = [model(data_batch, selector_batch) for model in models]
                    stacked_preds_scaled = torch.stack(batch_preds_list_scaled)
                    mean_preds_scaled = torch.mean(stacked_preds_scaled, dim=0)

                    all_ensemble_preds_scaled.append(mean_preds_scaled.cpu().numpy())
                    all_test_targets_scaled.append(target_batch_scaled.numpy()) 

                    prop_indices = torch.argmax(selector_batch.cpu(), dim=1).numpy()
                    batch_props = [loaded_target_cols[i] for i in prop_indices]
                    all_test_props.extend(batch_props)

            final_ensemble_preds_scaled_np = np.concatenate(all_ensemble_preds_scaled, axis=0)
            final_test_targets_scaled_np = np.concatenate(all_test_targets_scaled, axis=0) 

            test_metrics = {}
            main_scalers_list = main_scalers_loaded 

            print("\n--- Test Set Results (Original Scale) ---")
            print(f"(Based on {len(test_dataset)} samples in test_loader)")
            print(f"Ensemble prediction used {num_valid_models_after_load} models.\n")

            for prop_idx, prop in enumerate(loaded_target_cols):
                prop_mask = np.array(all_test_props) == prop
                if np.sum(prop_mask) > 0:
                    current_prop_preds_scaled = final_ensemble_preds_scaled_np[prop_mask]
                    current_prop_targets_scaled = final_test_targets_scaled_np[prop_mask] 

                    scaler = main_scalers_list[prop_idx] 

                    if hasattr(scaler, 'scale_') and scaler.scale_ is not None and scaler.scale_ !=0 : 
                         prop_preds_orig = scaler.inverse_transform(current_prop_preds_scaled)
                         prop_targets_orig = scaler.inverse_transform(current_prop_targets_scaled) 

                         mae_orig = mean_absolute_error(prop_targets_orig, prop_preds_orig)
                         rmse_orig = np.sqrt(mean_squared_error(prop_targets_orig, prop_preds_orig))
                         r2 = r2_score(prop_targets_orig, prop_preds_orig)
                         test_metrics[prop] = {'MAE': mae_orig, 'RMSE': rmse_orig, 'R2': r2, 'count': int(np.sum(prop_mask))}
                    else:
                         print(f"    Warning: Scaler for property '{prop}' is not properly fitted. Cannot compute original scale metrics.")
                         test_metrics[prop] = {'MAE': np.nan, 'RMSE': np.nan, 'R2': np.nan, 'count': int(np.sum(prop_mask))}
                else:
                     test_metrics[prop] = {'MAE': np.nan, 'RMSE': np.nan, 'R2': np.nan, 'count': 0}

                # --- MODIFIED: Print results directly ---
                metrics_prop = test_metrics.get(prop, {})
                print(f"  Property: {prop} (N={metrics_prop.get('count', 0)})")
                if metrics_prop.get('count', 0) > 0:
                    print(f"    R² Score: {metrics_prop.get('R2', float('nan')):.4f}")
                    print(f"    MAE:      {metrics_prop.get('MAE', float('nan')):.4f}")
                    print(f"    RMSE:     {metrics_prop.get('RMSE', float('nan')):.4f}")
                else:
                    print("    No samples for this property in the test set.")
                print("-" * 20)

    HYPERPARAMS_SAVE_PATH = os.path.join(BEST_MODEL_SAVE_DIR, 'fold_best_hyperparams_multitask.joblib')
    joblib.dump(fold_best_hyperparams, HYPERPARAMS_SAVE_PATH)
    print(f"Hyperparameters saved to: {HYPERPARAMS_SAVE_PATH}")

print("\nScript execution finished.")



--- Test Set Evaluation ---
Using 5 models for ensemble prediction on the test set.
Main scalers and target columns loaded for test evaluation.

Scaling test data using main scalers...
Test size after scaling and filtering: 3284
Test dataset size (from loader): 3284
Loading models for ensemble...

Starting ensemble prediction on test set (5 models)...


Test Prediction: 100%|█████████████████████████████████████████████████████████████████| 13/13 [00:00<00:00, 20.15it/s]


--- Test Set Results (Original Scale) ---
(Based on 3284 samples in test_loader)
Ensemble prediction used 5 models.

  Property: SP (N=141)
    R² Score: 0.7708
    MAE:      30.0630
    RMSE:     44.9146
--------------------
  Property: Tc (N=85)
    R² Score: 0.7970
    MAE:      31.1843
    RMSE:     40.2944
--------------------
  Property: Td (N=1024)
    R² Score: 0.8519
    MAE:      26.1224
    RMSE:     38.4488
--------------------
  Property: Tg (N=1452)
    R² Score: 0.9347
    MAE:      19.4773
    RMSE:     27.7733
--------------------
  Property: Tm (N=582)
    R² Score: 0.8713
    MAE:      24.5684
    RMSE:     34.4425
--------------------
Hyperparameters saved to: ./best_models_gine_mixfp_multitask_v3\fold_best_hyperparams_multitask.joblib

Script execution finished.



