<a href="https://colab.research.google.com/github/DanielBautz/gnn4nmr/blob/main/src/gnn_correct.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

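# GNN

This notebook trains a two-layer GCN on molecular graphs to predict the per-atom `shift_high-low` value, saves the weights, and applies the saved model to a new molecule.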

In [None]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m53.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv
import networkx as nx
import pickle

In [None]:
# Convert a NetworkX graph into a PyTorch Geometric Data object
def nx_to_pyg_data(G, feature_keys=("atomic_num", "formal_charge", "degree")):
    """Convert a NetworkX graph into a PyTorch Geometric ``Data`` object.

    Args:
        G: Graph whose node attribute dicts hold the per-atom descriptors.
        feature_keys: Node attribute names used as input features, in column
            order. Attributes missing on a node default to 0. Every selected
            attribute must be numeric. The default yields three columns,
            matching ``GNNModel``'s ``in_channels=3``, and keeps the atomic
            number in column 0, which the training loop reads to select
            H and C atoms.

    Returns:
        ``Data`` with ``x`` of shape (num_nodes, len(feature_keys)), ``y`` of
        shape (num_nodes,) holding ``shift_high-low`` (0 where the attribute
        is absent) and ``edge_index`` of shape (2, num_edges).
    """
    # Why the default is only three attributes: the previous 43-entry list
    # included the string-valued 'hybridization' (torch.tensor cannot build a
    # float tensor from it, so the conversion raised on every non-empty
    # graph) and the regression target 'shift_high-low' itself (label leak).
    feature_keys = tuple(feature_keys)
    nodes = list(G.nodes(data=True))

    # Edge endpoints must be row positions in x, not arbitrary node labels.
    index_of = {node: position for position, (node, _) in enumerate(nodes)}

    x = torch.tensor(
        [[attrs.get(key, 0) for key in feature_keys] for _, attrs in nodes],
        dtype=torch.float,
    ).reshape(len(nodes), len(feature_keys))

    # Target attribute; nodes without a value get 0, so the consumer has to
    # mask them out (the training loop restricts the loss to H and C atoms).
    y = torch.tensor(
        [attrs.get('shift_high-low', 0) for _, attrs in nodes],
        dtype=torch.float,
    )

    # NOTE(review): each edge is stored in one direction only. PyG treats
    # edge_index as directed, so an undirected graph would normally list both
    # directions. Change this together with the converter used at inference
    # time, otherwise training and prediction see different graphs.
    edge_pairs = [
        [index_of[edge[0]], index_of[edge[1]]] for edge in G.edges(data=True)
    ]
    # reshape(-1, 2) keeps the result at shape (2, 0) for a graph without edges.
    edge_index = (
        torch.tensor(edge_pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
    )

    return Data(x=x, edge_index=edge_index, y=y)

# Read the pickled graphs from disk.
# NOTE: pickle.load can execute arbitrary code from the file; only open graph
# files that come from a trusted source.
graph_file = "all_graphs.pkl"
with open(graph_file, mode='rb') as f:
    all_graphs = pickle.load(f)

# Convert every loaded graph into a PyTorch Geometric Data object.
pyg_data_list = list(map(nx_to_pyg_data, all_graphs))

# Mini-batch loader for training: 16 graphs per batch, shuffling enabled.
# NOTE(review): DataLoader is imported from torch_geometric.data, which PyG
# documents as deprecated in favour of torch_geometric.loader.DataLoader.
train_loader = DataLoader(pyg_data_list, shuffle=True, batch_size=16)




In [None]:
# Define the GNN model
class GNNModel(torch.nn.Module):
    """Two-layer GCN that produces ``out_channels`` values per node.

    Args:
        in_channels: Number of input features per node. Must equal the number
            of columns in ``data.x``. Defaults to 3.
        hidden_channels: Width of the hidden representation. Defaults to 16.
        out_channels: Number of values predicted per node. Defaults to 1.

    The defaults reproduce the previously hard-coded 3 -> 16 -> 1 layout, and
    the layer attribute names are unchanged, so existing state dicts load.
    """

    def __init__(self, in_channels=3, hidden_channels=16, out_channels=1):
        super().__init__()
        self.conv1 = GCNConv(in_channels=in_channels, out_channels=hidden_channels)
        self.conv2 = GCNConv(in_channels=hidden_channels, out_channels=out_channels)

    def forward(self, data):
        """Run both convolutions on ``data.x`` / ``data.edge_index``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``.

        Returns:
            Output of the second convolution, one row per node. No activation
            is applied after the last layer.
        """
        hidden = F.relu(self.conv1(data.x, data.edge_index))
        return self.conv2(hidden, data.edge_index)

In [None]:
# Train the model
SEED = 42
torch.manual_seed(SEED)  # seed weight initialisation and batch shuffling

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GNNModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

model.train()
for epoch in range(100):  # adjust the number of epochs as needed
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(device)

        # Only nodes whose first feature column equals 1 or 6 (hydrogen or
        # carbon) contribute to the loss, and only if their target is finite,
        # so a NaN target cannot turn the whole batch loss into NaN.
        first_column = batch.x[:, 0]
        mask = ((first_column == 1) | (first_column == 6)) & torch.isfinite(batch.y)
        if not mask.any():
            # mse_loss over an empty selection is NaN; skip such batches.
            continue

        optimizer.zero_grad()
        output = model(batch)
        # squeeze(-1) drops only the channel axis; a bare squeeze() would also
        # collapse a single selected node to a 0-d tensor and mismatch the target.
        loss = F.mse_loss(output[mask].squeeze(-1), batch.y[mask])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Reported value is the sum of the per-batch losses of this epoch.
    print(f'Epoch {epoch+1}, Loss: {total_loss:.4f}')

# Save the trained weights
torch.save(model.state_dict(), "gnn_model.pth")

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
import networkx as nx
import pandas as pd

# Convert an RDKit molecule into a NetworkX graph
def mol_to_nx(mol, properties, compound_id):
    """Build an undirected NetworkX graph from a molecule.

    Args:
        mol: Molecule object exposing ``GetAtoms()`` and ``GetBonds()``
            (an RDKit molecule, per the surrounding notebook).
        properties: pandas DataFrame with at least the columns ``'atom'`` and
            ``'compound'``. A row belongs to an atom when ``'atom'`` equals
            the atom's index plus one and ``'compound'`` equals
            ``compound_id``.
        compound_id: Value compared against ``properties['compound']``.

    Returns:
        ``nx.Graph`` with one node per atom (keyed by the atom index) and one
        edge per bond. Node attributes are the base descriptors below, updated
        with every column of the first matching property row; edge attributes
        are ``bond_type``, ``is_conjugated``, ``is_aromatic``, ``bond_order``.
    """
    G = nx.Graph()

    # The compound filter does not depend on the atom, so apply it once here
    # instead of re-scanning the full table for every atom.
    compound_props = properties[properties['compound'] == compound_id]

    # Add atoms as nodes
    for atom in mol.GetAtoms():
        atom_idx = atom.GetIdx()
        atom_symbol = atom.GetSymbol()

        # Base attributes
        attributes = {
            'label': atom.GetAtomicNum(),  # atomic number doubles as the node label
            'atomic_num': atom.GetAtomicNum(),
            'symbol': atom_symbol,  # kept so the element is easy to look up later
            'formal_charge': atom.GetFormalCharge(),
            'hybridization': str(atom.GetHybridization()),
            'aromatic': atom.GetIsAromatic(),
            'num_explicit_hs': atom.GetNumExplicitHs(),
            'num_implicit_hs': atom.GetNumImplicitHs(),
            'degree': atom.GetDegree()
        }

        # Merge the extra properties for this atom; the table is looked up
        # with the atom index plus one.
        props = compound_props[compound_props['atom'] == atom_idx + 1]
        if not props.empty:
            # Property columns override base attributes of the same name.
            attributes.update(props.iloc[0].to_dict())
        else:
            # Debug output for atoms without a matching property row
            print(f"Debugging: Keine passenden Eigenschaften gefunden für Atom {atom_symbol} mit Index {atom_idx + 1} und compound_id {compound_id}")

        G.add_node(atom_idx, **attributes)

    # Add bonds as edges
    for bond in mol.GetBonds():
        G.add_edge(bond.GetBeginAtomIdx(),
                   bond.GetEndAtomIdx(),
                   bond_type=bond.GetBondType(),
                   is_conjugated=bond.GetIsConjugated(),
                   is_aromatic=bond.GetIsAromatic(),
                   bond_order=bond.GetBondTypeAsDouble())

    return G

# Convert a NetworkX graph into a PyTorch Geometric Data object
def nx_to_pyg_data(G):
    """Convert a NetworkX graph into a PyTorch Geometric ``Data`` object.

    Running this cell rebinds ``nx_to_pyg_data`` and thereby replaces the
    definition given earlier in the notebook.

    Args:
        G: Graph whose node attribute dicts may contain ``'atomic_num'``,
            ``'formal_charge'``, ``'degree'`` and ``'shift_high-low'``.

    Returns:
        ``Data`` with ``x`` of shape (num_nodes, 3) holding atomic number,
        formal charge and degree (0 where missing), ``y`` of shape
        (num_nodes,) holding ``shift_high-low`` (NaN where missing) and
        ``edge_index`` of shape (2, num_edges).
    """
    nodes = list(G.nodes(data=True))

    # Edge endpoints must be row positions in x, not arbitrary node labels.
    index_of = {node: position for position, (node, _) in enumerate(nodes)}

    x = []
    y = []
    for _, attrs in nodes:
        # Input features (atomic number first)
        x.append([
            attrs.get('atomic_num', 0),
            attrs.get('formal_charge', 0),
            attrs.get('degree', 0)
        ])

        # Target attribute; absent or NaN/None values become NaN so the
        # caller can tell "no ground truth" apart from a real value.
        target = attrs.get('shift_high-low', float('nan'))
        y.append(float('nan') if pd.isna(target) else target)

    # reshape keeps x two-dimensional even for a graph without nodes.
    x = torch.tensor(x, dtype=torch.float).reshape(len(nodes), 3)
    y = torch.tensor(y, dtype=torch.float)

    # NOTE(review): each edge is stored in one direction only. PyG treats
    # edge_index as directed, so an undirected graph would normally list both
    # directions. Change this together with the converter used for training,
    # otherwise training and prediction see different graphs.
    edge_pairs = [
        [index_of[edge[0]], index_of[edge[1]]] for edge in G.edges(data=True)
    ]
    # reshape(-1, 2) keeps the result at shape (2, 0) for a graph without edges.
    edge_index = (
        torch.tensor(edge_pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
    )

    return Data(x=x, edge_index=edge_index, y=y)

# Example: prediction for a new molecule
h_file = "/data/ml_pbe0_pcSseg-2_h.dat"
c_file = "/data/ml_pbe0_pcSseg-2_c.dat"

# Read the properties from both files and combine them.
# NOTE(review): load_properties and read_sdf_file are not defined in the
# visible cells of this notebook; they must already exist in the kernel.
h_properties, c_properties = load_properties(h_file, c_file)
properties = pd.concat([h_properties, c_properties])

# Use a compound id that exists in the property table.
# NOTE(review): this takes the first id in the table, which is not tied to the
# SDF file loaded below. The recorded output prints ground-truth values for N
# atoms although only the *_h.dat / *_c.dat tables are loaded, so the id
# presumably belongs to a different molecule; confirm and derive it from the
# file instead.
valid_compound_id = properties['compound'].unique()[0]

new_mol_file = "012_00.sdf"
new_mol = read_sdf_file(new_mol_file)[0]
new_graph = mol_to_nx(new_mol, properties, compound_id=valid_compound_id)

# Convert the NetworkX graph to a PyTorch Geometric Data object
new_data = nx_to_pyg_data(new_graph)

# Instantiate the model on the available device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GNNModel().to(device)

# Load the trained weights. map_location lets weights saved on a GPU runtime
# load on a CPU-only one; weights_only=True restricts unpickling to tensors
# and plain containers, which is all a state dict needs.
checkpoint = torch.load("gnn_model.pth", map_location=device, weights_only=True)
model.load_state_dict(checkpoint)
model.eval()

# Run the prediction
new_data = new_data.to(device)
with torch.no_grad():
    prediction = model(new_data)

# Print predicted shift_high-low, ground truth and element for every atom
for idx, (pred, true) in enumerate(zip(prediction, new_data.y)):
    # Element symbol stored on the graph node; fall back when it is missing or empty
    element_symbol = new_graph.nodes[idx].get('symbol', None) or "Unbekannt"

    # A NaN target means no ground truth is available for this atom
    if torch.isnan(true):
        print(f"Atom {idx} ({element_symbol}): keine Ground Truth verfügbar, vorhergesagtes shift_high-low = {pred.item()}")
    else:
        print(f"Atom {idx} ({element_symbol}): vorhergesagtes shift_high-low = {pred.item()}, Ground Truth = {true.item()}")

Atom 0 (C): vorhergesagtes shift_high-low = -2.882751226425171, Ground Truth = -2.1809728145599365
Atom 1 (N): vorhergesagtes shift_high-low = -3.8705499172210693, Ground Truth = -16.9322566986084
Atom 2 (C): vorhergesagtes shift_high-low = -3.2838070392608643, Ground Truth = -0.17985659837722778
Atom 3 (C): vorhergesagtes shift_high-low = -2.8302741050720215, Ground Truth = -0.14272424578666687
Atom 4 (N): vorhergesagtes shift_high-low = -4.03288459777832, Ground Truth = -0.18340036273002625


  checkpoint = torch.load("gnn_model.pth")
