# In [None]:
import os
import torch
import numpy as np
import pandas as pd
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import hashlib

from torch_geometric.nn import GCNConv, global_mean_pool

class SimpleGNN(torch.nn.Module):
    """GCN regressor producing one scalar prediction per graph.

    Node features are batch-normalised and passed through a stack of
    ``GCNConv`` layers, with ReLU and dropout after each. The result is
    mean-pooled per graph. When the model was built with ``global_dim``, the
    pooled vector is concatenated with an MLP embedding of the batch-normalised
    graph-level features ``data.u``. A two-layer MLP then maps this
    representation to a single output per graph.

    Args:
        node_dim: Number of node features (columns of ``data.x``).
        edge_dim: Number of edge features. Only used to build ``edge_norm``;
            the GCN layers here do not consume edge features. Falsy to omit.
        global_dim: Number of graph-level features in ``data.u``; falsy to
            build a model without the global branch.
        hidden_dims: Output width of each ``GCNConv`` layer (non-empty).
        dropout: Dropout probability used after every conv and in the MLPs.
    """

    def __init__(self, node_dim, edge_dim, global_dim, hidden_dims, dropout=0.2):
        super().__init__()
        # Registration order fixes the state_dict key order. Parameter hashes
        # are computed over that order, so keep it unchanged. Optional
        # sub-modules fall back to ``None``: a plain attribute that adds no
        # state_dict entries, so saved checkpoints still load with strict=True.
        self.node_norm = torch.nn.BatchNorm1d(node_dim)
        self.edge_norm = torch.nn.BatchNorm1d(edge_dim) if edge_dim else None
        if global_dim:
            self.global_norm = torch.nn.BatchNorm1d(global_dim)
            self.global_mlp = torch.nn.Sequential(
                torch.nn.Linear(global_dim, hidden_dims[-1]),
                torch.nn.ReLU(),
                torch.nn.Dropout(dropout),
            )
        else:
            self.global_norm = None
            self.global_mlp = None
        self.convs = torch.nn.ModuleList()
        in_dim = node_dim
        for h_dim in hidden_dims:
            self.convs.append(GCNConv(in_dim, h_dim))
            in_dim = h_dim
        self.dropout = torch.nn.Dropout(dropout)
        # Pooled node embedding, plus the global embedding when that branch exists.
        self.final_dim = hidden_dims[-1] * (2 if global_dim else 1)
        self.output_mlp = torch.nn.Sequential(
            torch.nn.Linear(self.final_dim, self.final_dim // 2),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(self.final_dim // 2, 1),
        )

    def forward(self, data, return_feat=False):
        """Predict one value per graph in ``data``.

        Args:
            data: A PyG ``Data``/``Batch`` with ``x``, ``edge_index`` and
                ``batch``, and optionally ``edge_attr`` and ``u``. ``u`` is
                ignored when the model has no global branch.
            return_feat: If true, also return the representation fed to the
                output MLP.

        Returns:
            Tensor of shape ``(num_graphs,)``, or ``(out, h)`` when
            ``return_feat`` is true.

        Raises:
            ValueError: If the model has a global branch but ``data`` has no ``u``.
        """
        x = self.node_norm(data.x)

        edge_attr = getattr(data, 'edge_attr', None)
        if self.training and self.edge_norm is not None and edge_attr is not None:
            # The normalised edge features are not used by GCNConv. In training
            # mode this call only updates edge_norm's running statistics. In
            # eval mode it would have no effect at all, so it is skipped there.
            self.edge_norm(edge_attr)

        for conv in self.convs:
            x = torch.nn.functional.relu(conv(x, data.edge_index))
            x = self.dropout(x)

        h = global_mean_pool(x, data.batch)

        if self.global_mlp is not None:
            u = getattr(data, 'u', None)
            if u is None:
                raise ValueError(
                    "SimpleGNN was built with global_dim but data has no 'u' attribute"
                )
            # PyG concatenates attributes along dim 0 when batching, so 1-D
            # per-graph vectors arrive flattened. Restore one row per graph
            # (a no-op when u is already (num_graphs, global_dim)).
            u = self.global_norm(u.reshape(h.size(0), -1))
            h = torch.cat([h, self.global_mlp(u)], dim=1)

        out = self.output_mlp(h).squeeze(-1)
        return (out, h) if return_feat else out

def get_model_hash(model):
    """Return an MD5 fingerprint of a model's parameters and buffers.

    Every ``state_dict`` entry except BatchNorm's ``num_batches_tracked``
    counters is converted to float32 on the CPU and flattened. The entries are
    concatenated in ``state_dict`` order, and the digest is taken over the raw
    bytes of the result. Intended for comparing checkpoints, not for security.

    Args:
        model: Any ``torch.nn.Module``.

    Returns:
        Hexadecimal MD5 digest string. A model with no hashed entries yields
        the digest of empty input.
    """
    # reshape (not view) also accepts non-contiguous tensors. For contiguous
    # tensors it yields exactly the bytes view(-1) would, so digests stay
    # comparable with hashes computed that way.
    flat = [
        tensor.detach().cpu().float().reshape(-1)
        for key, tensor in model.state_dict().items()
        if "num_batches_tracked" not in key
    ]
    if not flat:
        return hashlib.md5(b"").hexdigest()
    return hashlib.md5(torch.cat(flat).numpy().tobytes()).hexdigest()

def create_loader(graph_data, batch_size=32):
    """Build an order-preserving PyG ``DataLoader`` from graph dictionaries.

    Each record must provide ``'x'`` (node features) and ``'edge_index'``.
    ``'edge_attr'`` and ``'u'`` (graph-level features) are optional. Values may
    be tensors or anything ``torch.as_tensor`` accepts (e.g. NumPy arrays).
    Other keys, such as ``'smiles'``, are not copied into the ``Data`` objects.

    Args:
        graph_data: Iterable of dict-like graph records.
        batch_size: Number of graphs per mini-batch.

    Returns:
        A ``DataLoader`` with ``shuffle=False``, so batches follow the order of
        ``graph_data``. Predictions are matched to inputs by position.
    """
    def _maybe_float(value):
        # Optional float features: keep None as None, otherwise a float tensor.
        return None if value is None else torch.as_tensor(value, dtype=torch.float)

    data_list = []
    for graph in graph_data:
        u = _maybe_float(graph.get('u', None))
        if u is not None:
            # PyG concatenates such attributes along dim 0 when batching.
            # Storing u as a single row yields a (num_graphs, global_dim) batch
            # instead of one flattened vector.
            u = u.reshape(1, -1)
        data_list.append(Data(
            x=torch.as_tensor(graph['x'], dtype=torch.float),
            edge_index=torch.as_tensor(graph['edge_index'], dtype=torch.long),
            edge_attr=_maybe_float(graph.get('edge_attr', None)),
            u=u,
        ))
    return DataLoader(data_list, batch_size=batch_size, shuffle=False, num_workers=0)

def _select_state_dict(checkpoint):
    """Return ``(key, state_dict)`` for the model weights stored in ``checkpoint``.

    Well-known keys are tried first. After that, the first dict-valued key whose
    name contains ``'state'`` is used, skipping optimizer states.

    Raises:
        KeyError: If no candidate is found. Without this, prediction would
            silently run with randomly initialised weights.
    """
    for key in ('student_state_dict', 'state_dict', 'model_state_dict'):
        if key in checkpoint:
            return key, checkpoint[key]
    for key, value in checkpoint.items():
        lowered = key.lower()
        if 'state' in lowered and 'optim' not in lowered and isinstance(value, dict):
            return key, value
    raise KeyError(f"No model state_dict found in checkpoint keys: {list(checkpoint.keys())}")


def _to_host(value):
    """Move a tensor statistic to a CPU NumPy array; return other values unchanged."""
    return value.detach().cpu().numpy() if torch.is_tensor(value) else value


def predict_single_model(model_path, data_dir):
    """Run one saved ``SimpleGNN`` checkpoint over ``<data_dir>/graph_data.pt``.

    The checkpoint must hold:

    - the constructor arguments (``node_dim``, ``edge_dim``, ``global_dim``,
      ``hidden_dims``, ``dropout``);
    - the target normalisation statistics ``y_mean`` and ``y_std``;
    - a model state_dict.

    ``batch_size`` is optional (default 64). An optional ``student_param_hash``
    or ``param_hash`` entry is compared against the loaded weights.

    Args:
        model_path: Path to a ``.pt`` checkpoint file.
        data_dir: Directory containing ``graph_data.pt``. That file holds a list
            of graph dicts as consumed by ``create_loader``, each optionally
            carrying a ``'smiles'`` entry.

    Returns:
        ``(model_name, preds_raw, smiles_list)``:

        - ``model_name``: the checkpoint file name without its extension.
        - ``preds_raw``: a NumPy array of de-normalised predictions, in input order.
        - ``smiles_list``: each graph's ``'smiles'`` value (``None`` if absent).

    Raises:
        FileNotFoundError: If ``graph_data.pt`` is missing.
        KeyError: If the checkpoint lacks a required entry or a state_dict.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # splitext strips only the final extension ('a.pt.pt' -> 'a.pt'), unlike
    # str.replace, which removed every '.pt' in the name.
    model_name = os.path.splitext(os.path.basename(model_path))[0]
    print(f"===== Processing Model: {model_name} =====")
    print(f"Using Device: {device}")

    # NOTE(review): torch.load may unpickle arbitrary objects (depending on the
    # torch version's weights_only default); only load trusted files.
    checkpoint = torch.load(model_path, map_location=device)
    batch_size = checkpoint.get('batch_size', 64)
    print(f"Using Batch Size: {batch_size}")
    print(f"Checkpoint Keys: {list(checkpoint.keys())}")

    model = SimpleGNN(
        node_dim=checkpoint['node_dim'],
        edge_dim=checkpoint['edge_dim'],
        global_dim=checkpoint['global_dim'],
        hidden_dims=checkpoint['hidden_dims'],
        dropout=checkpoint['dropout'],
    ).to(device)
    state_key, state_dict = _select_state_dict(checkpoint)
    model.load_state_dict(state_dict, strict=True)
    print(f"Loaded model parameters using key: '{state_key}'")
    model.eval()

    current_hash = get_model_hash(model)
    hash_key = next((k for k in ('student_param_hash', 'param_hash') if k in checkpoint), None)
    if hash_key is None:
        print(f"✅ Model Parameter Hash: {current_hash[:8]}...")
    else:
        training_hash = str(checkpoint[hash_key])
        if current_hash == training_hash:
            print(f"✅ Model Parameter Hash: {current_hash[:8]}... (matches training hash {training_hash[:8]}...)")
        else:
            # Warn rather than abort: the stored hash may come from a different
            # point in training, but the mismatch must be visible.
            print(f"⚠️ Model Parameter Hash: {current_hash[:8]}... does NOT match training hash {training_hash[:8]}... ('{hash_key}')")

    data_path = os.path.join(data_dir, "graph_data.pt")
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"Feature data not found: {data_path}")
    graph_data = torch.load(data_path, map_location='cpu')
    print(f"Loaded data from {data_dir}, total {len(graph_data)} samples")

    data_loader = create_loader(graph_data, batch_size=batch_size)

    all_preds = []
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Predicting"):
            all_preds.append(model(batch.to(device)).cpu().numpy())
    # np.concatenate rejects an empty list, so an empty dataset gets an empty array.
    preds_norm = np.concatenate(all_preds) if all_preds else np.empty(0, dtype=np.float32)

    # Tensor statistics were mapped to `device` by torch.load; bring them to
    # host NumPy so they combine with the NumPy predictions.
    y_mean, y_std = _to_host(checkpoint['y_mean']), _to_host(checkpoint['y_std'])
    preds_raw = preds_norm * y_std + y_mean
    print(f"✅ Prediction Completed! Generated {len(preds_raw)} prediction values\n")

    smiles_list = [graph.get('smiles') for graph in graph_data]
    return model_name, preds_raw, smiles_list

if __name__ == "__main__":
    import traceback

    # Input/Output Path Configuration (modify according to actual needs)
    model_dir = "PATH/TO/MODEL_DIRECTORY"
    data_dir = "PATH/TO/DATA_DIRECTORY"
    output_dir = "PATH/TO/OUTPUT_DIRECTORY"

    # Sorted because os.listdir order is arbitrary; this gives the model
    # columns in the summary CSV a reproducible order.
    model_paths = sorted(
        os.path.join(model_dir, f) for f in os.listdir(model_dir) if f.endswith('.pt')
    )
    print(f"Found {len(model_paths)} model files, prediction results will be saved to: {output_dir}\n")

    all_predictions = {}
    smiles_list = None

    for i, model_path in enumerate(model_paths, 1):
        try:
            model_name, preds_raw, slist = predict_single_model(model_path, data_dir)
        except Exception as e:
            # Best-effort batch run: report the failure and continue with the next model.
            print(f"❌ Error processing model {os.path.basename(model_path)}: {str(e)}")
            traceback.print_exc()
            print()
        else:
            if smiles_list is None:
                smiles_list = slist
            # A length mismatch would make the DataFrame column assignment
            # below raise and lose every model's results, so drop that model.
            if len(preds_raw) == len(smiles_list):
                all_predictions[model_name] = preds_raw
            else:
                print(f"❌ Skipping model {model_name}: {len(preds_raw)} predictions for {len(smiles_list)} samples\n")

        if i < len(model_paths):
            print("-" * 80 + "\n")

    if smiles_list is not None and all_predictions:
        pred_df = pd.DataFrame({
            'Sample_Index': np.arange(len(smiles_list)),
            'SMILES': smiles_list,
        })
        for model_name, preds in all_predictions.items():
            pred_df[model_name] = preds

        os.makedirs(output_dir, exist_ok=True)
        save_path = os.path.join(output_dir, "all_model_predictions_summary_base_P-SMILES_rev(freq3+)_ultra_high_temp_PI_multi_freq_permittivity_20_feat.csv")
        pred_df.to_csv(save_path, index=False, encoding='utf-8-sig')
        print(f"✅ All model prediction results have been summarized and saved to: {save_path}")
    else:
        print("❌ No valid prediction results obtained, cannot generate summary file")
