In [2]:
# Cell 2: Imports
import torch
from torch_geometric.data import Data
from rdkit import Chem
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


ModuleNotFoundError: No module named 'torch_geometric'

In [None]:
# Cell 3: Model + Graph Conversion
# Define model class (same as model.py)
from torch_geometric.nn import GCNConv, global_mean_pool
import torch.nn.functional as F
import torch.nn as nn

class GNNClassifier(nn.Module):
    """Two-layer GCN that maps a graph to a single sigmoid score per graph.

    Node features are expected to be 1-dimensional (the first ``GCNConv``
    takes 1 input channel).

    Args:
        hidden_dim: Width of both GCN layers and input width of the MLP head.
    """

    def __init__(self, hidden_dim=64):
        super().__init__()
        # Attribute names and the Sequential ordering define the state_dict keys
        # (conv1.*, conv2.*, classifier.0.*, classifier.2.*) that a saved
        # checkpoint is loaded against -- do not rename or reorder them.
        self.conv1 = GCNConv(1, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, data):
        """Score each graph in ``data``.

        Args:
            data: Graph object exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Sigmoid probabilities, one row (of width 1) per graph id in
            ``data.batch``.
        """
        edges = data.edge_index
        hidden = F.relu(self.conv1(data.x, edges))
        hidden = self.conv2(hidden, edges)
        # Average node embeddings within each graph to get one vector per graph.
        per_graph = global_mean_pool(hidden, data.batch)
        return torch.sigmoid(self.classifier(per_graph))

def smiles_to_graph(smiles):
    """Convert a SMILES string into a graph ``Data`` object.

    Node features are atomic numbers as a float tensor of shape
    ``(num_atoms, 1)``. Each bond contributes two directed edges (i->j and
    j->i), so ``edge_index`` has shape ``(2, 2 * num_bonds)``.

    Args:
        smiles: SMILES string describing one molecule.

    Returns:
        ``Data(x=..., edge_index=...)``.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # RDKit reports a parse failure by returning None instead of raising;
        # fail here with a clear message rather than an AttributeError below.
        raise ValueError(f"Invalid SMILES string: {smiles!r}")

    x = torch.tensor(
        [atom.GetAtomicNum() for atom in mol.GetAtoms()], dtype=torch.float
    ).unsqueeze(1)

    pairs = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        pairs.extend(([i, j], [j, i]))

    if pairs:
        edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
    else:
        # Bond-free molecules (e.g. "C"): torch.tensor([]).t() stays 1-D with
        # shape (0,), not the (2, 0) layout an edge_index must have.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    return Data(x=x, edge_index=edge_index)


In [None]:
# Cell 4: Load model
MODEL_PATH = "ddi_model.pt"

model = GNNClassifier()
# The checkpoint is only a state_dict, so restrict unpickling to tensors and
# plain containers (weights_only=True, torch >= 1.13) instead of allowing
# arbitrary Python objects to be reconstructed from the file.
state_dict = torch.load(MODEL_PATH, map_location=torch.device("cpu"), weights_only=True)
model.load_state_dict(state_dict)
model.eval()


In [None]:
# Cell 5: Predict on user input
# Input your SMILES here
smiles1 = "CCO"
smiles2 = "CN"

g1 = smiles_to_graph(smiles1)
g2 = smiles_to_graph(smiles2)

# Score the pair as ONE disconnected graph: g2's node ids are shifted by g1's
# atom count so its edges keep pointing at its own atoms, and every node is
# assigned to graph 0, so the model's mean-pool averages over both molecules.
offset = g1.x.size(0)
x = torch.cat((g1.x, g2.x))
edge_index = torch.cat((g1.edge_index, g2.edge_index + offset), dim=1)
data = Data(x=x, edge_index=edge_index)
data.batch = torch.zeros(x.size(0), dtype=torch.long)

with torch.no_grad():
    output = model(data)
confidence = output.item()

print(f"Prediction Confidence: {confidence:.2f}")


In [None]:
# Cell 6: Visual Dashboard
# Bands are (strict lower bound, label), checked from highest to lowest.
RISK_BANDS = [(0.75, "HIGH"), (0.5, "MEDIUM")]         # otherwise "LOW"
EVIDENCE_BANDS = [(0.9, "Strong"), (0.6, "Moderate")]  # otherwise "Weak"
RISK_MARKERS = {"HIGH": "🔴", "MEDIUM": "🟠", "LOW": "🟢"}


def label_for(score, bands, default):
    """Return the label of the first band whose threshold ``score`` strictly exceeds.

    Args:
        score: Value to classify (here, the model confidence in [0, 1]).
        bands: Sequence of ``(threshold, label)`` pairs, highest threshold first.
        default: Label returned when no threshold is exceeded.
    """
    for threshold, label in bands:
        if score > threshold:
            return label
    return default


risk = label_for(confidence, RISK_BANDS, "LOW")
evidence = label_for(confidence, EVIDENCE_BANDS, "Weak")

# Plot bars
fig, ax = plt.subplots(figsize=(6, 2))
ax.barh(["Prediction Confidence"], [confidence], color="royalblue")
ax.set_xlim(0, 1)
# Round (not truncate) so the percentage agrees with the printed score; e.g.
# int(0.57 * 100) == 56 because of floating-point error.
ax.set_title(f"Confidence: {confidence:.0%} | Risk Level: {risk}")
ax.set_xlabel("Confidence Score")
ax.grid(True)
plt.show()

# Show Risk Box -- marker colour follows the risk level.
print(f"{RISK_MARKERS[risk]} Risk Level: {risk}")
print(f"📘 Evidence Strength: {evidence}")
