<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Molecular_Modeling_with_Graph_Neural_Networks_(GNNs).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch_geometric

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv

# Define the GCN model
class MolecularGNN(nn.Module):
    """Two-layer graph convolutional network that produces one output vector per node.

    Architecture: GCNConv(in -> hidden) -> ReLU -> GCNConv(hidden -> out).
    No activation is applied after the second convolution, so the output can be
    passed directly to ``nn.CrossEntropyLoss`` as per-node logits.

    Args:
        in_channels: Size of each input node-feature vector.
        hidden_channels: Width of the intermediate node representation.
        out_channels: Size of each output vector (e.g. number of classes).
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Submodule names are kept as conv1/conv2 so state_dict keys stay stable.
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        """Run both graph convolutions over the node features.

        Args:
            x: Node-feature tensor; the example below uses shape
                (num_nodes, in_channels).
            edge_index: Graph connectivity as a 2 x num_edges tensor of node ids.

        Returns:
            Per-node outputs of width ``out_channels`` (no final activation).
        """
        hidden = self.conv1(x, edge_index)
        activated = hidden.relu()
        return self.conv2(activated, edge_index)

# Example molecular data (randomly generated for illustration)
num_nodes = 10
in_channels = 5
hidden_channels = 16
out_channels = 2
num_epochs = 100  # Example: 100 epochs

# Seed the global RNG so the random features, labels, weight initialisation
# and loader shuffling are reproducible from run to run.
torch.manual_seed(0)

# Random node features, one row per node.
x = torch.randn((num_nodes, in_channels), dtype=torch.float)

# Hand-written (not random) edge list: each column is a (source, target) pair,
# and every connection is listed in both directions, forming the chain
# 0-1-2-3-4. Nodes 5..9 have no edges in this toy graph.
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3, 3, 4],
    [1, 0, 2, 1, 3, 2, 4, 3]
], dtype=torch.long)

# Random class label for each node, in [0, out_channels).
y = torch.randint(0, out_channels, (num_nodes,), dtype=torch.long)

# Wrap the single graph in a Data object and a (one-element) dataset/loader.
data = Data(x=x, edge_index=edge_index, y=y)
dataset = [data]
loader = DataLoader(dataset, batch_size=1, shuffle=True)

# Initialize model, optimizer, and loss function.
model = MolecularGNN(in_channels, hidden_channels, out_channels)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

# Training loop (node classification over every node in each batch).
model.train()
for epoch in range(num_epochs):
    epoch_loss = 0.0
    epoch_nodes = 0
    for batch in loader:
        optimizer.zero_grad()
        out = model(batch.x, batch.edge_index)
        loss = criterion(out, batch.y)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss (default reduction='mean') averages over the nodes in
        # the batch; weight by node count so the epoch figure is a true per-node
        # mean even when batches differ in size.
        epoch_loss += loss.item() * batch.num_nodes
        epoch_nodes += batch.num_nodes
    # Report the epoch average rather than only the last batch's loss; the
    # max() guard avoids a division by zero (and an unbound `loss`) if the
    # loader is ever empty.
    print(f'Epoch {epoch+1}, Loss: {epoch_loss / max(epoch_nodes, 1):.4f}')

print("Training complete!")