In [58]:
!pip install git+https://github.com/NREL/alfabet.git@0.2.2


Collecting git+https://github.com/NREL/alfabet.git@0.2.2
  Cloning https://github.com/NREL/alfabet.git (to revision 0.2.2) to c:\users\80710\appdata\local\temp\pip-req-build-_y5u9hg8
  Resolved https://github.com/NREL/alfabet.git to commit 9942cbd6fceeed549e8126692b15bb135e103f5a
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'


  Running command git clone --filter=blob:none --quiet https://github.com/NREL/alfabet.git 'C:\Users\80710\AppData\Local\Temp\pip-req-build-_y5u9hg8'
  Running command git checkout -q 9942cbd6fceeed549e8126692b15bb135e103f5a


In [59]:
from alfabet.drawing import draw_mol_outlier
from alfabet.fragment import canonicalize_smiles
from alfabet.neighbors import find_neighbor_bonds
from alfabet.prediction import predict_bdes, check_input

In [60]:
import alfabet
# Display the installed ALFABET version so it can be checked against the 0.2.2 tag
# requested in the install cell.
alfabet.__version__

'0.2.2'

In [61]:
import rdkit

In [62]:
# Display the RDKit version in use, as provenance for the results below.
rdkit.__version__

'2024.03.5'

In [63]:
import networkx as nx
import numpy as np
from rdkit import Chem

In [64]:
import pandas as pd
# Remove pandas' display limits so frames are rendered without column/row truncation.
# NOTE(review): max_rows=None renders every row of any frame that is displayed;
# prefer .head()/.sample() when showing large results.
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)

In [65]:
def create_bde_graph_selective_hs(smiles: str, bde_df) -> nx.Graph:
    """
    Build a NetworkX graph of a molecule annotated with predicted bond energies.

    The graph starts from the RDKit Mol parsed from ``smiles`` *without* AddHs:
    one node per atom (keyed by its RDKit index) and one edge per bond, with the
    prediction attributes initialised to None. Rows of ``bde_df`` are then merged:

      - both indices inside ``[0, num_atoms)``: the bond joins two atoms of the
        parsed Mol; the edge is updated, or created if the Mol did not contain it.
      - exactly one index inside that range: treated as an H-X bond; a hydrogen
        node keyed ``"H_<index>"`` is added and connected to the in-range atom.
      - no index inside that range, or an index that is missing (None/NaN):
        the row is skipped.

    Parameters
    ----------
    smiles : str
        SMILES string of the molecule.
    bde_df : pandas.DataFrame
        Must provide ``start_atom`` and ``end_atom``. ``bde_pred``, ``bdfe_pred``
        and ``bond_index`` are optional; a missing column is stored as None.

    Returns
    -------
    nx.Graph
        Graph attribute ``mol`` holds the RDKit Mol. Nodes carry ``symbol`` and
        ``rdkit_idx`` (None for added hydrogens). Edges carry ``bond_index``,
        ``bde_pred`` and ``bdfe_pred`` (None where no row matched the bond).
        An empty graph is returned when ``smiles`` cannot be parsed.
    """
    base_mol = Chem.MolFromSmiles(smiles)
    if base_mol is None:
        # Unparseable SMILES: return an empty graph rather than raising.
        return nx.Graph()

    # Keep the Mol on the graph for later reference.
    G = nx.Graph(mol=base_mol)

    # Nodes and skeleton edges from the parsed Mol; predictions are filled in below.
    for atom in base_mol.GetAtoms():
        atom_idx = atom.GetIdx()
        G.add_node(atom_idx, symbol=atom.GetSymbol(), rdkit_idx=atom_idx)
    for bond in base_mol.GetBonds():
        G.add_edge(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx(),
                   bond_index=None,
                   bde_pred=None,
                   bdfe_pred=None)

    # Loop-invariant: indices at or beyond this count are treated as hydrogens.
    num_atoms = base_mol.GetNumAtoms()

    def normalize_index(idx):
        """Return None for a missing index, an int for an integral float, else idx."""
        if idx is None or idx != idx:  # NaN is the only value unequal to itself
            return None
        if isinstance(idx, float) and idx.is_integer():
            # A float-typed column (e.g. upcast because of NaNs elsewhere) must not
            # produce labels such as "H_23.0".
            return int(idx)
        return idx

    def is_heavy(idx):
        return 0 <= idx < num_atoms

    for _, row in bde_df.iterrows():
        s = normalize_index(row['start_atom'])
        e = normalize_index(row['end_atom'])
        if s is None or e is None:
            # No usable atom index: nothing to attach the prediction to.
            continue

        attrs = {
            'bond_index': row.get('bond_index', None),
            'bde_pred': row.get('bde_pred', None),
            'bdfe_pred': row.get('bdfe_pred', None),
        }

        s_is_heavy = is_heavy(s)
        e_is_heavy = is_heavy(e)

        if s_is_heavy and e_is_heavy:
            # add_edge updates the attributes of an existing edge and creates the
            # edge otherwise, so a single call covers both situations.
            G.add_edge(s, e, **attrs)
            continue

        if not (s_is_heavy or e_is_heavy):
            # Both indices are outside the parsed Mol: cannot anchor the bond.
            continue

        heavy_idx, hydrogen_idx = (s, e) if s_is_heavy else (e, s)

        # String key so the hydrogen cannot collide with the integer atom keys.
        h_node = f"H_{hydrogen_idx}"
        G.add_node(h_node, symbol='H', rdkit_idx=None)
        G.add_edge(heavy_idx, h_node, **attrs)

    return G


In [66]:
def graph_to_df(bde_graph: nx.Graph) -> pd.DataFrame:
    """
    Tabulate the edges of ``bde_graph``, one row per edge.

    Parameters
    ----------
    bde_graph : nx.Graph
        Graph whose edges may carry ``bond_index``, ``bde_pred`` and ``bdfe_pred``.

    Returns
    -------
    pd.DataFrame
        Columns ``['u', 'v', 'bond_index', 'graph_bde_pred', 'graph_bdfe_pred']``.
        An edge attribute that is absent is reported as None. The columns are
        present even when the graph has no edges.
    """
    columns = ['u', 'v', 'bond_index', 'graph_bde_pred', 'graph_bdfe_pred']
    rows = [
        {
            'u': u,
            'v': v,
            # .get for all three attributes: an edge lacking one must not raise.
            'bond_index': data.get('bond_index', None),
            'graph_bde_pred': data.get('bde_pred', None),
            'graph_bdfe_pred': data.get('bdfe_pred', None),
        }
        for u, v, data in bde_graph.edges(data=True)
    ]
    # Passing the column list keeps the schema stable for an edgeless graph.
    return pd.DataFrame(rows, columns=columns)

In [67]:
# Input molecules as SMILES strings (20 entries, one per line); each is canonicalized
# and passed to the BDE predictor in the processing loop below.
smiles_list = ['C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([C@@]2(CC3)C)CCCC)(C)C',
       'C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([C@@]2(CC3)C)CCC(C)C)(C)C',
       'C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([C@@]2(CC3)C)CC[C@@H](C)CC)(C)C',
       'C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([C@@]2(CC3)C)CC[C@H](CCC)C)(C)C',
       'C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([C@]2(C)CC3)CC[C@H](C)CCCCC)(C)C',
       'C(CCC)C[C@H](C)CC[C@@H]1[C@H](CC[C@H]2[C@]1(CC[C@@H]3[C@@]2(CCCC3(C)C)C)C)C',
       'C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([C@@]2(CC3)C)CC[C@@H](CCCC(C)C)C)(C)C',
       'C(C[C@@H](CC[C@H]1[C@]3([C@H](CC[C@@H]1C)[C@]2(CCCC(C)(C)[C@@H]2CC3)C)C)C)CC(C)C',
       '[C@]23(CC[C@@H]1[C@@](CCCC1(C)C)(C)[C@H]2CC[C@H]4[C@]3(CC[C@]5([C@@H]4CCC5)C)C)C',
       '[C@]12(CC[C@@H]5[C@@]([C@H]1CC[C@H]3[C@@]2(C)CC[C@H]4[C@@]3(CCC4)C)(CCCC5(C)C)C)C',
       'CC[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCC[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCC(C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CC[C@@H](C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCCC(C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCC[C@@H](C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCCCC(C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCCC[C@@H](C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCCCCC(C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C',
       'CCCCC[C@@H](C)[C@@H]1CC[C@]2(C1CCC3(C2CCC4C3(CCC5C4(CCCC5(C)C)C)C)C)C']

In [68]:
import urllib.parse
def quote(x):
    """
    Percent-encode ``x`` for embedding in a URL.

    ``safe`` is empty, so ``/`` is escaped as well (the stdlib default would
    leave it untouched).
    """
    encoded = urllib.parse.quote(x, safe="")
    return encoded

In [69]:
# Run the BDE predictor on every input SMILES, collecting one DataFrame and one
# annotated graph per molecule (kept in the same order in both lists).
dfs = []
graphs = []  # graphs[i] is the graph built for smiles_list[i]

for smiles in smiles_list:
    # 1) Canonicalize and sanity-check input
    # NOTE(review): is_outlier / missing_atom / missing_bond are not read anywhere
    # in this cell -- confirm whether flagged molecules should be skipped or logged.
    can_smiles = canonicalize_smiles(smiles)
    is_outlier, missing_atom, missing_bond = check_input(can_smiles)

    # 2) Get DataFrame of predicted BDE/BDFE for each bond
    #    raw_smiles keeps the string exactly as given, next to the canonical form.
    bde_df = predict_bdes(can_smiles, draw=True)
    bde_df['raw_smiles'] = smiles

    # 3) Keep one row per distinct (fragment1, fragment2) pair, then add a
    #    percent-encoded copy of the molecule string
    bde_df = bde_df.drop_duplicates(['fragment1', 'fragment2']).reset_index(drop=True)
    bde_df['smiles_link'] = bde_df.molecule.apply(quote)

    # 4) Build a NetworkX graph containing predicted BDE/BDFE
    #    The graph is built from the de-duplicated frame, so a bond whose row was
    #    dropped above keeps its None placeholders on the corresponding edge.
    bde_graph = create_bde_graph_selective_hs(can_smiles, bde_df)

    # 5) (Optional) store the graph in the DataFrame if you want
    #    the same graph for all rows (one per entire molecule)
    #    Every row references the same graph object; it is not copied per row.
    bde_df['nx_graph'] = [bde_graph] * len(bde_df)

    # 6) Append to your results
    dfs.append(bde_df)
    graphs.append(bde_graph)   # kept in parallel with dfs




In [70]:
# Stack the per-molecule prediction frames into one table; ignore_index=True
# discards the per-frame indices and assigns a fresh 0..N-1 index.
alfabet_results_022 = pd.concat(dfs, ignore_index=True)


In [71]:
# Tabulate the edges of the first molecule's graph as a quick sanity check.
graph_to_df(graphs[0])

Unnamed: 0,u,v,bond_index,graph_bde_pred,graph_bdfe_pred
0,0,1,0.0,89.382645,75.711853
1,0,H_23,25.0,100.077187,91.049133
2,1,2,1.0,85.872467,71.412849
3,1,H_27,29.0,97.163109,87.689636
4,2,3,2.0,85.041306,70.000275
5,2,H_28,30.0,95.392189,86.257256
6,3,4,3.0,83.115479,66.99527
7,3,H_30,32.0,94.518456,84.748627
8,4,5,,,
9,4,10,,,


In [72]:
!pip install torch torchvision torchaudio




In [73]:
!pip install torch-geometric




In [77]:
# Attach synthetic environment covariates and a synthetic regression target.
# Every value written in this cell is random placeholder data for demonstration;
# none of it is a measurement.

# Seeded generator so that Restart & Run All reproduces the same synthetic columns.
rng = np.random.default_rng(42)

# 1) Add random environment columns for demonstration
num_rows = len(alfabet_results_022)

# Temperature in degrees C, uniform on [10, 40)
alfabet_results_022['temperature'] = rng.uniform(10, 40, size=num_rows)

# Concentration in mg/L, uniform on [1, 100)
alfabet_results_022['Concentration'] = rng.uniform(1, 100, size=num_rows)

# Time in hours, uniform on [0, 120)
alfabet_results_022['Time'] = rng.uniform(0, 120, size=num_rows)

# Categorical water type: 'sea' or 'fresh', drawn with equal probability
alfabet_results_022['Seawater'] = rng.choice(['sea', 'fresh'], size=num_rows)

# Random target 'degradation_rate', uniform on [0.1, 1.0) (arbitrary range)
alfabet_results_022['degradation_rate'] = rng.uniform(0.1, 1.0, size=num_rows)

# 2) Inspect the updated DataFrame
alfabet_results_022.head(5)

Unnamed: 0,molecule,bond_index,bond_type,start_atom,end_atom,fragment1,fragment2,is_valid_stereo,bde_pred,bdfe_pred,bde,bdfe,set,svg,has_dft_bde,raw_smiles,smiles_link,nx_graph,temperature,Concentration,Time,Seawater,degradation_rate
0,CCCC[C@@H]1[C@@H](C)CC[C@H]2[C@@]1(C)CC[C@H]1C...,10,C-C,10,11,CCCC[C@H]1[C]2CC[C@H]3C(C)(C)CCC[C@]3(C)[C@H]2...,[CH3],True,79.460541,64.25386,,,,<?xml version='1.0' encoding='iso-8859-1'?>\n<...,False,C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([...,CCCC%5BC%40%40H%5D1%5BC%40%40H%5D%28C%29CC%5BC...,"(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",17.1004,74.768569,105.82965,sea,0.56574
1,CCCC[C@@H]1[C@@H](C)CC[C@H]2[C@@]1(C)CC[C@H]1C...,21,C-C,21,22,CCCC[C@@H]1[C@@H](C)CC[C@@H]2[C]3CCCC(C)(C)[C@...,[CH3],True,79.644073,64.361122,,,,<?xml version='1.0' encoding='iso-8859-1'?>\n<...,False,C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([...,CCCC%5BC%40%40H%5D1%5BC%40%40H%5D%28C%29CC%5BC...,"(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",37.02122,14.677846,47.242317,fresh,0.756477
2,CCCC[C@@H]1[C@@H](C)CC[C@H]2[C@@]1(C)CC[C@H]1C...,15,C-C,15,16,CCCC[C@@H]1[C@@H](C)CC[C@H]2[C@@]1(C)CC[C@H]1[...,[CH3],True,82.40863,67.338509,,,,<?xml version='1.0' encoding='iso-8859-1'?>\n<...,False,C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([...,CCCC%5BC%40%40H%5D1%5BC%40%40H%5D%28C%29CC%5BC...,"(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",19.830141,55.215225,4.104363,sea,0.650062
3,CCCC[C@@H]1[C@@H](C)CC[C@H]2[C@@]1(C)CC[C@H]1C...,3,C-C,3,4,[CH2]CCC,C[C@@H]1[CH][C@]2(C)CC[C@H]3C(C)(C)CCC[C@]3(C)...,True,83.115479,66.99527,,,,<?xml version='1.0' encoding='iso-8859-1'?>\n<...,False,C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([...,CCCC%5BC%40%40H%5D1%5BC%40%40H%5D%28C%29CC%5BC...,"(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",18.2693,77.855803,70.923872,sea,0.887023
4,CCCC[C@@H]1[C@@H](C)CC[C@H]2[C@@]1(C)CC[C@H]1C...,2,C-C,2,3,[CH2]CC,[CH2][C@@H]1[C@@H](C)CC[C@H]2[C@@]1(C)CC[C@H]1...,True,85.041306,70.000275,,,,<?xml version='1.0' encoding='iso-8859-1'?>\n<...,False,C1CC([C@H]3[C@@](C1)(C)[C@H]2CC[C@H](C)[C@H]([...,CCCC%5BC%40%40H%5D1%5BC%40%40H%5D%28C%29CC%5BC...,"(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",27.142742,76.564347,22.651805,sea,0.748602


In [98]:
import numpy as np
import torch

def encode_environment(df):
    """
    Return a copy of ``df`` with a numeric ``sea_numeric`` column.

    ``sea_numeric`` is derived from the ``Seawater`` column: ``'sea'`` becomes 1.0
    and ``'fresh'`` becomes 0.0. Any other value maps to NaN. The input frame is
    left unmodified.
    """
    seawater_codes = {'sea': 1.0, 'fresh': 0.0}
    return df.assign(sea_numeric=df['Seawater'].map(seawater_codes))

# Encode 'Seawater' as numeric
# (the name is rebound to the returned copy, which carries the extra 'sea_numeric' column)
alfabet_results_022 = encode_environment(alfabet_results_022)

# Define the environment columns we want to use as numeric
# len(env_columns) is what the training cell passes as env_dim.
env_columns = ['temperature', 'Concentration', 'Time', 'sea_numeric']


2. Custom Environmental “Positional” Encoding
Below is a small module, `EnvPositionalEncoding`, that converts the environment variables into an embedding of size `d_model`. We then add that embedding to each node embedding in a batch (similar to how standard Transformers add a sinusoidal vector to each token).

In [92]:
import torch.nn as nn

class EnvPositionalEncoding(nn.Module):
    """
    Project environment variables to a ``d_model``-sized vector and add that
    vector to every token (node) embedding of the corresponding batch item.
    """

    def __init__(self, env_dim, d_model):
        super().__init__()
        # Two-layer MLP: env_dim -> d_model -> d_model with a ReLU in between.
        layers = [
            nn.Linear(env_dim, d_model),
            nn.ReLU(),
            nn.Linear(d_model, d_model),
        ]
        self.linear = nn.Sequential(*layers)

    def forward(self, x, env):
        """
        x:   (batch_size, seq_len, d_model) token embeddings
        env: (batch_size, env_dim) environment variables
        Returns a (batch_size, seq_len, d_model) tensor: x plus the per-item
        environment offset.
        """
        offset = self.linear(env)  # (batch_size, d_model)
        # Insert a length-1 axis at position 1 so the offset broadcasts over seq_len.
        return x + offset[:, None]


3. Building the Model (Transformer + Environmental Encoding)
Below is a toy example of a Transformer-based model in plain PyTorch (not PyTorch Geometric) that:

- Converts node symbols into embeddings (`nn.Embedding`).
- Adds your environment-based encoding via `EnvPositionalEncoding`.
- Passes the result through a standard `TransformerEncoder`.
- Pools (e.g., averages) the node embeddings, or takes a [CLS]-like node, as the molecule representation.
- Predicts the degradation rate via a small MLP.

In [93]:
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer

class MoleculeTransformerModel(nn.Module):
    """
    Transformer regressor over a sequence of atom indices, conditioned on
    environment variables.

    Pipeline: atom embedding -> add environment offset -> Transformer encoder
    -> mean over the sequence axis -> MLP producing one value per molecule.
    """

    def __init__(self,
                 num_atom_types,    # number of rows in the atom embedding table
                 d_model=128,
                 nhead=4,
                 dim_feedforward=256,
                 num_encoder_layers=3,
                 env_dim=4,         # number of environment variables per sample
                 dropout=0.1):
        super().__init__()

        # Atom index -> d_model vector
        self.atom_emb = nn.Embedding(num_atom_types, d_model)

        # Environment variables -> additive d_model offset
        self.env_pos_enc = EnvPositionalEncoding(env_dim, d_model)

        # Encoder stack; batch_first=True means inputs are (batch, seq, feature)
        layer = TransformerEncoderLayer(d_model=d_model,
                                        nhead=nhead,
                                        dim_feedforward=dim_feedforward,
                                        dropout=dropout,
                                        batch_first=True)
        self.transformer_encoder = TransformerEncoder(layer,
                                                      num_layers=num_encoder_layers)

        # Regression head: d_model -> d_model -> 1
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 1),
        )

    def forward(self, node_features, env_data):
        """
        node_features: (batch_size, seq_len) atom indices
        env_data:      (batch_size, env_dim)
        Returns a (batch_size,) tensor of predictions.
        """
        tokens = self.atom_emb(node_features)            # (batch, seq, d_model)
        tokens = self.env_pos_enc(tokens, env_data)      # same shape, offset added
        encoded = self.transformer_encoder(tokens)       # (batch, seq, d_model)

        # Mean over every sequence position; no padding mask is applied here.
        pooled = encoded.mean(dim=1)                     # (batch, d_model)

        # Drop the trailing size-1 axis of the head's (batch, 1) output.
        return self.mlp(pooled).squeeze(-1)


4. Training Loop (MSE Loss, RMSE Metric)
Below is an example of how you might train this model on your dataset. We’ll define a small function to compute RMSE for logging/evaluation.

In [94]:
def rmse(pred, target):
    """Root-mean-square error between ``pred`` and ``target``, as a 0-dim tensor."""
    squared_error = (pred - target) ** 2
    return squared_error.mean().sqrt()

def train_one_epoch(model, dataloader, optimizer, device='cpu'):
    """
    Run one optimisation pass over ``dataloader`` and return the mean MSE per sample.

    Parameters
    ----------
    model : callable as ``model(node_feats, env_data)`` returning predictions
    dataloader : iterable of ``(node_feats, env_data, y)`` batches
    optimizer : optimizer stepping the model's parameters
    device : device the batch tensors are moved to before the forward pass

    Returns
    -------
    float
        Sum of per-batch MSE weighted by batch size, divided by the number of
        samples actually iterated. NaN when the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_count = 0
    for node_feats, env_data, y in dataloader:
        node_feats = node_feats.to(device)
        env_data = env_data.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        preds = model(node_feats, env_data)
        loss = F.mse_loss(preds, y)  # mean over the batch
        loss.backward()
        optimizer.step()

        # Undo the per-batch mean so that batches of different size weigh correctly.
        batch_count = len(y)
        total_loss += loss.item() * batch_count
        total_count += batch_count

    # Divide by the samples seen (as evaluate() does) rather than
    # len(dataloader.dataset): the two differ with drop_last or a sampler, and a
    # plain iterable of batches has no .dataset attribute at all.
    if total_count == 0:
        return float('nan')
    return total_loss / total_count

@torch.no_grad()
def evaluate(model, dataloader, device='cpu'):
    """
    Score ``model`` on ``dataloader`` without tracking gradients.

    Returns ``(mse, rmse)``: the summed squared error divided by the number of
    samples iterated, and its square root.
    """
    model.eval()
    squared_error_sum, samples_seen = 0.0, 0
    for batch in dataloader:
        # Each batch is (node_feats, env_data, y); move all three to the device.
        nodes, env, target = (tensor.to(device) for tensor in batch)
        batch_sse = F.mse_loss(model(nodes, env), target, reduction='sum')
        squared_error_sum += batch_sse.item()
        samples_seen += len(target)

    mean_squared_error = squared_error_sum / samples_seen
    return mean_squared_error, np.sqrt(mean_squared_error)


5. Training

In [95]:
# Create dataset
# NOTE(review): MoleculeEnvDataset is not defined in any cell visible in this part of
# the notebook -- confirm it is defined earlier in a fresh-kernel run.
dataset = MoleculeEnvDataset(alfabet_results_022)

# Create DataLoader
# NOTE(review): no import of DataLoader is visible here either (presumably
# torch.utils.data.DataLoader) -- confirm.
batch_size = 16
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [96]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [99]:
# Build the model; env_dim follows the number of environment columns selected above.
# NOTE(review): num_atom_types=10 fixes the embedding table size -- confirm it covers
# every atom index the dataset emits.
model = MoleculeTransformerModel(num_atom_types=10, d_model=64, env_dim=len(env_columns))
model.to(device)
    
# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    
# Train for 20 epochs
# NOTE(review): evaluate() is given the same dataloader used for training, so the
# "Eval" numbers printed below are training-set metrics, not held-out metrics.
# NOTE(review): the recorded run of this cell stopped with a default_collate TypeError
# mentioning networkx Graph -- presumably the dataset items include the nx_graph
# objects; verify against MoleculeEnvDataset.
for epoch in range(20):
    train_loss = train_one_epoch(model, dataloader, optimizer, device)
    mse_score, rmse_score = evaluate(model, dataloader, device)
        
    print(f"Epoch {epoch:02d} | Train MSE: {train_loss:.4f} | Eval MSE: {mse_score:.4f} | Eval RMSE: {rmse_score:.4f}")



TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'networkx.classes.graph.Graph'>