In [11]:
from alfabet.drawing import draw_mol_outlier
from alfabet.fragment import canonicalize_smiles
from alfabet.neighbors import find_neighbor_bonds
from alfabet.prediction import predict_bdes, check_input














In [5]:
import alfabet
# Display the installed alfabet version so the results can be tied to a release.
alfabet.__version__

'0.2.2'

In [6]:
import rdkit

In [7]:
# Display the installed RDKit version so the results can be tied to a release.
rdkit.__version__

'2024.03.5'

In [1]:
import pandas as pd
# Lift pandas' display truncation: None means "no limit", so every column and
# every row of a displayed frame is rendered. Prefer .head() on large frames.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

In [2]:
import networkx as nx
def create_bde_graph_selective_hs(smiles: str, bde_df) -> nx.Graph:
    """
    Build a NetworkX graph of a molecule annotated with predicted bond energies.

    The graph starts from the heavy-atom RDKit Mol parsed from ``smiles``
    (no explicit hydrogens are added). Hydrogen nodes are created only for the
    H-X bonds that actually appear in ``bde_df``.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule.
    bde_df : pandas.DataFrame
        One row per predicted bond. Must contain ``start_atom`` and
        ``end_atom``; ``bde_pred``, ``bdfe_pred`` and ``bond_index`` are read
        when present and default to ``None`` otherwise.

    Returns
    -------
    nx.Graph
        * Heavy atoms are nodes keyed by their integer RDKit index, with
          attributes ``symbol`` and ``rdkit_idx``.
        * Hydrogens are nodes keyed by the string ``"H_<index>"`` with
          ``symbol='H'`` and ``rdkit_idx=None``.
        * Every edge carries ``bond_index``, ``bde_pred`` and ``bdfe_pred``
          (all ``None`` for skeleton bonds that have no row in ``bde_df``).
        * The parsed Mol is stored as the graph attribute ``mol``.
        An empty graph is returned when the SMILES cannot be parsed.

    Notes
    -----
    An index is treated as a heavy atom when it lies in
    ``[0, base_mol.GetNumAtoms())``; any other integer is treated as a
    hydrogen. Rows whose indexes are missing (None/NaN) or not integral are
    skipped, as are rows where neither end is a heavy atom.
    """
    # Parse once, without AddHs, so only heavy atoms receive integer node ids.
    base_mol = Chem.MolFromSmiles(smiles)
    if base_mol is None:
        return nx.Graph()

    G = nx.Graph(mol=base_mol)

    # Heavy-atom nodes.
    for atom in base_mol.GetAtoms():
        atom_idx = atom.GetIdx()
        G.add_node(atom_idx,
                   symbol=atom.GetSymbol(),
                   rdkit_idx=atom_idx)

    # Heavy-atom skeleton; predictions are filled in from bde_df below.
    for bond in base_mol.GetBonds():
        G.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(),
                   bond_index=None,
                   bde_pred=None,
                   bdfe_pred=None)

    # Loop-invariant: evaluate the atom count once instead of per row.
    num_heavy_atoms = base_mol.GetNumAtoms()

    def as_atom_index(value):
        """Return ``value`` as a plain int, or None if it is missing or not integral."""
        try:
            number = float(value)
        except (TypeError, ValueError):
            return None
        # float.is_integer() is False for NaN and +/-inf, so those are rejected too.
        return int(number) if number.is_integer() else None

    def is_heavy(idx):
        return 0 <= idx < num_heavy_atoms

    for _, row in bde_df.iterrows():
        # pandas hands back floats when a column contains NaN; normalising to
        # int keeps hydrogen node names in the form "H_23" rather than "H_23.0".
        s = as_atom_index(row['start_atom'])
        e = as_atom_index(row['end_atom'])
        if s is None or e is None:
            continue

        predictions = {
            'bond_index': row.get('bond_index', None),
            'bde_pred': row.get('bde_pred', None),
            'bdfe_pred': row.get('bdfe_pred', None),
        }

        s_is_heavy = is_heavy(s)
        e_is_heavy = is_heavy(e)

        if s_is_heavy and e_is_heavy:
            # Heavy-heavy bond; normally already present from the skeleton.
            u, v = s, e
        elif s_is_heavy:
            u, v = s, f"H_{e}"
        elif e_is_heavy:
            u, v = e, f"H_{s}"
        else:
            # Neither end maps to a heavy atom: nothing to attach the bond to.
            continue

        # String keys cannot collide with the integer heavy-atom keys.
        if isinstance(v, str) and not G.has_node(v):
            G.add_node(v,
                       symbol='H',
                       rdkit_idx=None)

        # add_edge creates the edge or, if it already exists, updates its
        # attribute dict in place, so one call covers both cases.
        G.add_edge(u, v, **predictions)

    return G


In [12]:
def graph_to_df(bde_graph: nx.Graph) -> pd.DataFrame:
    """
    Flatten the edges of ``bde_graph`` into a DataFrame, one row per edge.

    Parameters
    ----------
    bde_graph : nx.Graph
        Graph whose edges may carry ``bond_index``, ``bde_pred`` and
        ``bdfe_pred`` attributes.

    Returns
    -------
    pd.DataFrame
        Columns ``['u', 'v', 'bond_index', 'graph_bde_pred',
        'graph_bdfe_pred']``. Attributes missing on an edge become ``None``.
        A graph without edges yields an empty frame that still has these
        columns.
    """
    columns = ['u', 'v', 'bond_index', 'graph_bde_pred', 'graph_bdfe_pred']
    rows = [
        {
            'u': u,
            'v': v,
            # .get for all three attributes: an edge lacking bond_index
            # should not raise KeyError while its siblings default to None.
            'bond_index': data.get('bond_index', None),
            'graph_bde_pred': data.get('bde_pred', None),
            'graph_bdfe_pred': data.get('bdfe_pred', None),
        }
        for u, v, data in bde_graph.edges(data=True)
    ]
    # Passing the column list keeps the schema stable when there are no rows.
    return pd.DataFrame(rows, columns=columns)

In [4]:
# Input molecules as SMILES strings (20 entries, most carrying stereocentre
# annotations). Each entry is canonicalised and fed to the BDE prediction loop.
smiles_list = ['C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([C@@]2(CC3)C)CCCC)(C)C',
       'C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([C@@]2(CC3)C)CCC(C)C)(C)C',
       'C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([C@@]2(CC3)C)CC[C@@H](C)CC)(C)C',
       'C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([C@@]2(CC3)C)CC[C@H](CCC)C)(C)C',
       'C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([C@]2(C)CC3)CC[C@H](C)CCCCC)(C)C',
       'C(CCC)C[C@H](C)CC[C@@H]1[C@H](CC[C@H]2[C@]1(CC[C@@H]3[C@@]2(CCCC3(C)C)C)C)C',
       'C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([C@@]2(CC3)C)CC[C@@H](CCCC(C)C)C)(C)C',
       'C(C[C@@H](CC[C@H]1[C@]3([C@H](CC[C@@H]1C)[C@]2(CCCC(C)(C)[C@@H]2CC3)C)C)C)CC(C)C',
       '[C@]23(CC[C@@H]1[C@@](CCCC1(C)C)(C)[C@H]2CC[C@H]4[C@]3(CC[C@]5([C@@H]4CCC5)C)C)C',
       '[C@]12(CC[C@@H]5[C@@]([C@H]1CC[C@H]3[C@@]2(C)CC[C@H]4[C@@]3(CCC4)C)(CCCC5(C)C)C)C',
       'CC[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCC[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCC(C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CC[C@@H](C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCCC(C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCC[C@@H](C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCCCC(C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCCC[C@@H](C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCCCCC(C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCCCC[C@@H](C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C']

In [5]:
import urllib.parse
def quote(x):
    """Percent-encode ``x`` for use in a URL, leaving no reserved character (not even '/') unescaped."""
    encoded = urllib.parse.quote(x, safe='')
    return encoded

In [6]:
import networkx as nx
from rdkit import Chem


In [13]:
# Per-molecule results, kept in parallel: dfs[i] is the bond table and
# graphs[i] the annotated graph for smiles_list[i].
dfs = []
graphs = []  # one NetworkX graph per molecule

for smiles in smiles_list:
    # 1) Canonicalize the input and run alfabet's input check.
    #    NOTE(review): is_outlier / missing_atom / missing_bond are not used
    #    anywhere in this cell, so flagged molecules are still processed.
    can_smiles = canonicalize_smiles(smiles)
    is_outlier, missing_atom, missing_bond = check_input(can_smiles)

    # 2) Predicted BDE/BDFE table for this molecule; keep the original
    #    (non-canonical) SMILES alongside for traceability.
    bde_df = predict_bdes(can_smiles, draw=True)
    bde_df['raw_smiles'] = smiles

    # 3) Keep one row per distinct (fragment1, fragment2) pair and add a
    #    percent-encoded copy of the molecule string.
    bde_df = bde_df.drop_duplicates(['fragment1', 'fragment2']).reset_index(drop=True)
    bde_df['smiles_link'] = bde_df.molecule.apply(quote)

    # 4) Build the NetworkX graph carrying the predictions as edge attributes.
    bde_graph = create_bde_graph_selective_hs(can_smiles, bde_df)

    # 5) Attach the graph to every row. List multiplication repeats the
    #    reference, so all rows of this molecule share one graph object.
    bde_df['nx_graph'] = [bde_graph] * len(bde_df)

    # 6) Collect the results.
    dfs.append(bde_df)
    graphs.append(bde_graph)   # same order as dfs




In [62]:
# Stack the per-molecule bond tables into a single frame; ignore_index=True
# discards the per-molecule indexes and renumbers the rows 0..N-1.
alfabet_results_022 = pd.concat(dfs, ignore_index=True)


In [63]:
# Inspect the edge table of the first molecule's graph.
graph_to_df(graphs[0])

Unnamed: 0,u,v,bond_index,graph_bde_pred,graph_bdfe_pred
0,0,1,0.0,89.382645,75.711853
1,0,H_23,25.0,100.077187,91.049133
2,1,2,1.0,85.872467,71.412849
3,1,H_27,29.0,97.163109,87.689636
4,2,3,2.0,85.041306,70.000275
5,2,H_28,30.0,95.392189,86.257256
6,3,4,3.0,83.115479,66.99527
7,3,H_30,32.0,94.518456,84.748627
8,4,5,,,
9,4,10,,,


In [67]:
import pandas as pd

# Load the environmental data from Excel.
# NOTE(review): absolute, user-specific path; the notebook will not run on
# another machine until this points at a local copy of the workbook.
env_file = r"C:\Users\80710\OneDrive - Imperial College London\2025 engineering\GNN molecules\graph_pickles\dataset02.xlsx"
env_df = pd.read_excel(env_file, engine='openpyxl')

# Columns describing the experimental conditions plus the measured target.
env_columns = ["temperature", "seawater", "time", "component","concentration", "degradation_rate"]

# Selecting by label raises KeyError if any of these columns is absent;
# .copy() detaches the selection from env_df so the edits below don't touch it.
env_var = env_df[env_columns].copy()

# Encode the categorical "seawater" column: "sea" -> 1, "art" -> 0.
# Any other value maps to NaN and is removed by the dropna() below.
env_var["seawater"] = env_var["seawater"].map({"sea": 1, "art": 0})  # Map "sea" → 1, "art" → 0

# Drop rows with a missing value in any selected column and renumber.
env_var = env_var.dropna().reset_index(drop=True)

# Report how many rows survived and preview them (no comparison is made here).
print(f"Loaded {len(env_var)} environment rows")
print(env_var.head())


Loaded 1023 environment rows
   temperature  seawater  time component  concentration  degradation_rate
0         35.6         1    30       C23             70          0.670914
1         35.6         1    30       C24             70          0.680071
2         35.6         1    30       C25             70          0.655230
3         35.6         1    30       C26             70          0.625193
4         35.6         1    30      C28a             70          0.605853


---

In [113]:
import pandas as pd

# 1) Check that there are at least as many environment rows as bond rows.
num_rows = len(alfabet_results_022)
if len(env_var) < num_rows:
    raise ValueError("环境数据行数不足，无法覆盖全部 alfabet_results_022 ！")

# 2) Copy the environment columns onto alfabet_results_022.
#    NOTE(review): this is a purely positional assignment - row i of the bond
#    table receives row i of env_var (.values drops the index, the slice keeps
#    the first num_rows rows). No key is used, and env_var's "component"
#    column is ignored, so a bond row is not tied to the molecule the
#    measurement belongs to. Confirm this pairing is intended.
alfabet_results_022["temperature"]       = env_var["temperature"][:num_rows].values
alfabet_results_022["seawater"]          = env_var["seawater"][:num_rows].values
alfabet_results_022["time"]              = env_var["time"][:num_rows].values
alfabet_results_022["concentration"]     = env_var["concentration"][:num_rows].values
alfabet_results_022["degradation_rate"]  = env_var["degradation_rate"][:num_rows].values

# Show the resulting columns and shape. This cell mutates alfabet_results_022
# in place, so columns added by earlier runs stay until the frame is rebuilt.
print("合并后 DataFrame 列：", alfabet_results_022.columns.tolist())
print("合并后 DataFrame 大小：", alfabet_results_022.shape)


合并后 DataFrame 列： ['molecule', 'bond_index', 'bond_type', 'start_atom', 'end_atom', 'fragment1', 'fragment2', 'is_valid_stereo', 'bde_pred', 'bdfe_pred', 'bde', 'bdfe', 'set', 'svg', 'has_dft_bde', 'raw_smiles', 'smiles_link', 'nx_graph', 'temperature', 'Concentration', 'Time', 'Seawater', 'degradation_rate', 'concentration', 'time', 'seawater']
合并后 DataFrame 大小： (676, 26)


In [114]:
import torch
from torch.utils.data import Dataset

class MoleculeEnvDataset(Dataset):
    """
    Dataset yielding ``(nx_graph, env_features, target)`` triples for GNN training.

    Each item corresponds to one row of the supplied DataFrame:
      * ``nx_graph``     - the object stored in the graph column, returned as-is
      * ``env_features`` - float32 tensor of shape ``(len(env_cols),)``
      * ``target``       - 0-dim float32 tensor holding the target value
    """

    # Environment columns used when the caller does not pass ``env_cols``.
    DEFAULT_ENV_COLS = ("temperature", "seawater", "time", "concentration")

    def __init__(self, df, env_cols=None, target_col="degradation_rate",
                 graph_col="nx_graph"):
        """
        Parameters
        ----------
        df : pandas.DataFrame
            Must contain ``graph_col``, every column in ``env_cols`` and
            ``target_col``. The caller's frame is not modified.
        env_cols : sequence of str, optional
            Columns packed, in this order, into ``env_features``.
            Defaults to ``DEFAULT_ENV_COLS``.
        target_col : str
            Column holding the regression target.
        graph_col : str
            Column holding the per-row NetworkX graph.

        Raises
        ------
        KeyError
            If any required column is missing from ``df``.
        """
        self.env_cols = list(self.DEFAULT_ENV_COLS if env_cols is None else env_cols)
        self.target_col = target_col
        self.graph_col = graph_col

        # Fail early, naming every missing column at once.
        required = [self.graph_col] + self.env_cols + [self.target_col]
        missing = [col for col in required if col not in df.columns]
        if missing:
            raise KeyError(f"MoleculeEnvDataset: missing required column(s): {missing}")

        # reset_index returns a new frame, so the coercion below is local.
        self.df = df.reset_index(drop=True)

        # Coerce to numeric (unparseable entries become NaN), then drop every
        # row that lacks a usable environment value or target.
        numeric_cols = self.env_cols + [self.target_col]
        for col in numeric_cols:
            self.df[col] = pd.to_numeric(self.df[col], errors='coerce')
        self.df = self.df.dropna(subset=numeric_cols).reset_index(drop=True)

    def __len__(self):
        """Number of rows that survived the numeric clean-up."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(nx_graph, env_features, target)`` for row ``idx``."""
        row = self.df.iloc[idx]

        nx_graph = row[self.graph_col]

        # Driven by self.env_cols so the feature order always matches the
        # columns that were validated and cleaned in __init__.
        env_features = torch.tensor(
            [float(row[col]) for col in self.env_cols],
            dtype=torch.float32,
        )

        target = torch.tensor(float(row[self.target_col]), dtype=torch.float32)

        return nx_graph, env_features, target


In [115]:
from torch.utils.data.dataloader import default_collate

def my_collate(batch):
    """
    Collate a list of ``(nx_graph, env, target)`` samples into one batch.

    The graphs cannot be stacked, so they are returned as a plain list; the
    environment and target tensors are merged with ``default_collate``.

    Returns: ``(list_of_nx_graphs, env_tensors, target_tensors)``
    """
    nx_graph_list = []
    env_items = []
    target_items = []
    for sample in batch:
        nx_graph_list.append(sample[0])
        env_items.append(sample[1])
        target_items.append(sample[2])

    # Stacked along a new leading axis: (batch_size, ...) for the environment
    # features and (batch_size,) for scalar targets.
    env_batch = default_collate(env_items)
    tgt_batch = default_collate(target_items)

    return nx_graph_list, env_batch, tgt_batch


In [119]:
from torch.utils.data import DataLoader

# 1) Build the Dataset from the merged bond/environment table.
dataset = MoleculeEnvDataset(alfabet_results_022)

# 2) DataLoader: shuffled mini-batches of 8, collated with my_collate so the
#    NetworkX graphs stay a Python list.
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=my_collate)

# 3) Instantiate the GNN model.
#    NOTE(review): SimpleGraphModel and device are not defined in the visible
#    part of this notebook; they must already exist in the kernel.

num_node_features = 1  
model = SimpleGraphModel(
    num_node_features=num_node_features,
    env_input_dim=4,       # temperature, seawater, time, concentration
    hidden_dim=128,
    output_dim=1
)
model.to(device)


SimpleGraphModel(
  (env_encoder): EnvPositionalEncoder(
    (mlp): Sequential(
      (0): Linear(in_features=4, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=128, bias=True)
    )
  )
  (node_encoder): Linear(in_features=1, out_features=128, bias=True)
  (conv1): GCNConv(128, 128)
  (conv2): GCNConv(128, 128)
  (fc_out): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [129]:
import torch
from torch_geometric.data import Data
import networkx as nx

def convert_nx_to_pyg(nx_graph, undirected=False):
    """
    Convert a graph built by ``create_bde_graph_selective_hs()`` into a PyG ``Data``.

    Parameters
    ----------
    nx_graph : nx.Graph
        Nodes may carry a ``symbol`` attribute; edges may carry ``bde_pred``
        and ``bdfe_pred``.
    undirected : bool, default False
        When False (the historical behaviour) each NetworkX edge is emitted
        once, as ``(u, v)``. When True the reverse ``(v, u)`` entry is emitted
        as well, with the same features.
        NOTE(review): PyG message passing treats ``edge_index`` as directed,
        so the default makes information flow one way along each bond;
        consider passing ``undirected=True`` from the training code.

    Returns
    -------
    torch_geometric.data.Data
        * ``x``          - ``(num_nodes, 1)`` float; 1.0 if the node's symbol
                           is ``'C'`` (also the default when the attribute is
                           absent), else 0.0. This is an is-carbon flag, not
                           a one-hot encoding.
        * ``edge_index`` - ``(2, num_edges)`` long, row positions into ``x``.
        * ``edge_attr``  - ``(num_edges, 2)`` float, ``[bde_pred, bdfe_pred]``
                           with missing values (None or NaN) replaced by 0.0.
        An empty graph yields correctly shaped zero-length tensors.

    Raises
    ------
    ValueError
        If ``nx_graph`` is not a NetworkX ``Graph``.
    """
    if not isinstance(nx_graph, nx.Graph):
        raise ValueError("Input must be a valid NetworkX Graph!")

    # 1) Deterministic node order: integer (heavy-atom) nodes ascending, then
    #    "H_<k>" nodes ascending by k, then any other node in graph order so
    #    that every node referenced by an edge has a row in x.
    int_nodes = sorted(n for n in nx_graph.nodes if isinstance(n, int))
    h_nodes = sorted(
        (n for n in nx_graph.nodes if isinstance(n, str) and n.startswith("H_")),
        key=lambda name: int(name.split("_")[1]),  # "H_5" -> 5
    )
    nodes = int_nodes + h_nodes
    already_placed = set(nodes)
    nodes.extend(n for n in nx_graph.nodes if n not in already_placed)

    # Node -> row lookup. Replaces a list.index() scan per edge endpoint and
    # also maps integer nodes explicitly instead of assuming they are 0..N-1.
    position = {node: row for row, node in enumerate(nodes)}

    # 2) Node features.
    node_features = [
        [1.0 if nx_graph.nodes[n].get('symbol', 'C') == 'C' else 0.0]
        for n in nodes
    ]
    # reshape keeps the (num_nodes, 1) layout when there are no nodes.
    x = torch.tensor(node_features, dtype=torch.float).reshape(-1, 1)

    def as_feature(value):
        """Map None/NaN to 0.0; anything else to float."""
        if value is None:
            return 0.0
        number = float(value)
        # NaN is the only float that differs from itself; pandas uses it for
        # missing predictions, and it would poison any computation on edge_attr.
        return 0.0 if number != number else number

    # 3) Edge indices and 4) edge features, built in one pass so they stay aligned.
    edge_pairs = []
    edge_features = []
    for u, v, attrs in nx_graph.edges(data=True):
        feature = [as_feature(attrs.get("bde_pred")), as_feature(attrs.get("bdfe_pred"))]
        edge_pairs.append([position[u], position[v]])
        edge_features.append(feature)
        if undirected and u != v:
            edge_pairs.append([position[v], position[u]])
            edge_features.append(feature)

    # reshape handles the edgeless case: (0,) -> (0, 2) -> (2, 0).
    edge_index = torch.tensor(edge_pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
    edge_attr = torch.tensor(edge_features, dtype=torch.float).reshape(-1, 2)

    # 5) Assemble the PyG Data object.
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

    return data


In [130]:
# Convert every NetworkX graph (one per molecule) into a PyG Data object.
pytorch_graphs = [convert_nx_to_pyg(g) for g in graphs]

# Report how many graphs were converted and show the first one.
print(f"Total converted PyG graphs: {len(pytorch_graphs)}")
print(pytorch_graphs[0])  # summary of the first PyG graph


Total converted PyG graphs: 20
Data(x=[42, 1], edge_index=[2, 44], edge_attr=[44, 2])


In [134]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
import torch.optim as optim
from sklearn.metrics import r2_score

def train_loop(model, dataloader, device='cpu', epochs=5):
    """
    Train ``model`` with Adam (lr=1e-3) and MSE loss, printing metrics per epoch.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(pyg_batch, env_batch)``; must already live on ``device``.
    dataloader : DataLoader
        Yields ``(list_of_nx_graphs, env_batch, tgt_batch)``.
    device : str or torch.device
        Device the batches are moved to.
    epochs : int
        Number of passes over ``dataloader``.

    Notes
    -----
    * A new optimizer is created on every call, so calling this again on the
      same model continues from its weights but with fresh Adam state.
    * The printed MSE/RMSE are sample-weighted averages over the epoch's
      training batches. R² is computed by ``evaluate_r2`` on the *same*
      ``dataloader``, i.e. on training data, not a held-out set.

    Raises
    ------
    ValueError
        If ``dataloader`` yields no samples.
    """
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    for epoch in range(1, epochs+1):
        model.train()
        total_loss = 0.0
        total_samples = 0

        for nx_graph_list, env_batch, tgt_batch in dataloader:
            # 1) NetworkX graphs -> one batched PyG graph on the target device.
            pyg_data_list = [convert_nx_to_pyg(nxg) for nxg in nx_graph_list]
            data_batch = Batch.from_data_list(pyg_data_list).to(device)

            # 2) Move the remaining tensors to the same device.
            env_batch = env_batch.to(device)
            tgt_batch = tgt_batch.to(device)

            optimizer.zero_grad()

            # 3) Forward pass and loss.
            out = model(data_batch, env_batch)
            # Align the prediction with the (batch_size,) target. If the model
            # returns (batch_size, 1), MSELoss would otherwise broadcast the
            # pair to (batch_size, batch_size) and optimise a wrong objective.
            # NOTE(review): the model's forward() is not visible here; this is
            # a no-op when the shapes already match.
            out = out.reshape(tgt_batch.shape)
            loss = criterion(out, tgt_batch)
            loss.backward()
            optimizer.step()

            # loss is a batch mean; weight it by batch size so the epoch
            # average is exact even when the last batch is smaller.
            batch_size = tgt_batch.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size

        if total_samples == 0:
            raise ValueError("train_loop received an empty dataloader; nothing to train on.")

        # 4) Epoch-level MSE and RMSE.
        avg_loss = total_loss / total_samples
        rmse = math.sqrt(avg_loss)

        # 5) R² on the same dataloader (switches the model to eval mode).
        r2_val = evaluate_r2(model, dataloader, device)

        print(f"Epoch {epoch}/{epochs} - MSE: {avg_loss:.4f}, RMSE: {rmse:.4f}, R²: {r2_val:.4f}")

def evaluate_r2(model, dataloader, device):
    """
    Compute the R² score of ``model``'s predictions over ``dataloader``.

    Parameters
    ----------
    model : torch.nn.Module
        Called as ``model(pyg_batch, env_batch)``.
    dataloader : DataLoader
        Yields ``(list_of_nx_graphs, env_batch, tgt_batch)``.
    device : str or torch.device
        Device the batches are moved to ('cpu' or 'cuda').

    Returns
    -------
    float
        R² as computed by ``sklearn.metrics.r2_score``; ``nan`` when the
        dataloader yields no batches.

    Notes
    -----
    The model is switched to eval mode and left in that mode on return.
    """
    model.eval()
    preds = []
    gts = []

    with torch.no_grad():
        for nx_graph_list, env_batch, tgt_batch in dataloader:
            # NetworkX graphs -> one batched PyG graph.
            pyg_data_list = [convert_nx_to_pyg(nxg) for nxg in nx_graph_list]
            data_batch = Batch.from_data_list(pyg_data_list).to(device)

            env_batch = env_batch.to(device)
            tgt_batch = tgt_batch.to(device)

            out = model(data_batch, env_batch)

            # Give predictions the same shape as the targets: flattens a
            # (batch_size, 1) output and lifts a 0-dim output to (1,), which
            # torch.cat below requires.
            preds.append(out.reshape(tgt_batch.shape).cpu())
            gts.append(tgt_batch.cpu())

    if not preds:
        # R² is undefined without samples; torch.cat would raise on an empty list.
        return float('nan')

    # Concatenate all batches and score.
    preds = torch.cat(preds).numpy()
    gts = torch.cat(gts).numpy()

    r2_val = r2_score(gts, preds)
    return r2_val


In [131]:
# 1) Build the dataset and dataloader (shuffled mini-batches of 8).
dataset = MoleculeEnvDataset(alfabet_results_022)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=my_collate)

# 2) Create a fresh model and move it to the device.
#    NOTE(review): SimpleGraphModel and device are not defined in the visible
#    part of this notebook; they must already exist in the kernel.
num_node_features = 1  
model = SimpleGraphModel(
    num_node_features=num_node_features,
    env_input_dim=4,
    hidden_dim=128,
    output_dim=1
).to(device)

# 3) Train for 20 epochs.
#    NOTE(review): the output recorded under this cell uses an "MSE Loss" line
#    format that the train_loop shown in this notebook does not print, so it
#    presumably comes from an earlier version of the function.
train_loop(model, dataloader, device=device, epochs=20)


Epoch 1/20 - MSE Loss: 0.4695
Epoch 2/20 - MSE Loss: 0.0753
Epoch 3/20 - MSE Loss: 0.0610
Epoch 4/20 - MSE Loss: 0.0471
Epoch 5/20 - MSE Loss: 0.0637
Epoch 6/20 - MSE Loss: 0.0509
Epoch 7/20 - MSE Loss: 0.0400
Epoch 8/20 - MSE Loss: 0.0423
Epoch 9/20 - MSE Loss: 0.0542
Epoch 10/20 - MSE Loss: 0.0449
Epoch 11/20 - MSE Loss: 0.0471
Epoch 12/20 - MSE Loss: 0.0434
Epoch 13/20 - MSE Loss: 0.0486
Epoch 14/20 - MSE Loss: 0.0557
Epoch 15/20 - MSE Loss: 0.0651
Epoch 16/20 - MSE Loss: 0.0574
Epoch 17/20 - MSE Loss: 0.0425
Epoch 18/20 - MSE Loss: 0.0472
Epoch 19/20 - MSE Loss: 0.0410
Epoch 20/20 - MSE Loss: 0.0599


In [137]:
# Train the existing model for another 20 epochs. The model is not
# re-created in this cell, while train_loop builds a new Adam optimizer on
# every call, so the weights carry over but the optimizer state restarts.
train_loop(model, dataloader, device=device, epochs=20)


Epoch 1/20 - MSE: 0.0618, RMSE: 0.2485, R²: 0.2448
Epoch 2/20 - MSE: 0.0464, RMSE: 0.2154, R²: -0.1347
Epoch 3/20 - MSE: 0.0826, RMSE: 0.2874, R²: -0.7995
Epoch 4/20 - MSE: 0.0571, RMSE: 0.2390, R²: -0.3639
Epoch 5/20 - MSE: 0.0452, RMSE: 0.2126, R²: 0.0303
Epoch 6/20 - MSE: 0.0540, RMSE: 0.2324, R²: 0.2870
Epoch 7/20 - MSE: 0.0424, RMSE: 0.2060, R²: 0.3037
Epoch 8/20 - MSE: 0.0394, RMSE: 0.1984, R²: 0.1955
Epoch 9/20 - MSE: 0.0437, RMSE: 0.2091, R²: 0.2520
Epoch 10/20 - MSE: 0.0403, RMSE: 0.2008, R²: 0.0830
Epoch 11/20 - MSE: 0.0433, RMSE: 0.2080, R²: 0.1108
Epoch 12/20 - MSE: 0.0401, RMSE: 0.2002, R²: 0.3046
Epoch 13/20 - MSE: 0.0386, RMSE: 0.1963, R²: -1.0961
Epoch 14/20 - MSE: 0.0492, RMSE: 0.2217, R²: -0.1278
Epoch 15/20 - MSE: 0.0417, RMSE: 0.2042, R²: 0.2549
Epoch 16/20 - MSE: 0.0400, RMSE: 0.1999, R²: 0.2663
Epoch 17/20 - MSE: 0.0366, RMSE: 0.1913, R²: 0.2773
Epoch 18/20 - MSE: 0.0420, RMSE: 0.2050, R²: -0.0209
Epoch 19/20 - MSE: 0.0396, RMSE: 0.1989, R²: 0.2160
Epoch 20/20 - M