In [1]:
# Put the Kaggle API token where the CLI looks for it, then fetch the dataset archive.
!mkdir -p ~/.kaggle
# Expects kaggle.json in the current working directory; because the file is moved
# (not copied), re-running this cell without uploading it again makes this line fail.
!mv kaggle.json ~/.kaggle/
# Restrict the token to owner read/write only.
!chmod 600 ~/.kaggle/kaggle.json
# --force is passed so the archive is downloaded even if a copy is already present (per the Kaggle CLI help).
!kaggle datasets download -d shay2030/processed-graphs-torch-filefor-evaluating-models --force

Dataset URL: https://www.kaggle.com/datasets/shay2030/processed-graphs-torch-filefor-evaluating-models
License(s): unknown
Downloading processed-graphs-torch-filefor-evaluating-models.zip to /content
100% 427M/428M [00:04<00:00, 91.9MB/s]
100% 428M/428M [00:04<00:00, 95.2MB/s]


In [2]:
import os
import zipfile

# Location of the downloaded archive and the directory its members are unpacked into.
zip_path = "/content/processed-graphs-torch-filefor-evaluating-models.zip"
extract_path = "/content/processed_graphs"

# Make sure the destination exists (no error if it already does), then unpack everything.
os.makedirs(extract_path, exist_ok=True)
with zipfile.ZipFile(zip_path, mode="r") as archive:
    archive.extractall(path=extract_path)

print(f"Files extracted to {extract_path}")

Files extracted to /content/processed_graphs


In [3]:
!pip install torch_geometric
import torch
file_path = os.path.join(extract_path, "processed_graphs.pt")
graphs = torch.load(file_path)
print(f"Loaded {len(graphs)} graphs successfully!")

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.1/1.1 MB[0m [31m48.4 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m31.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


  graphs = torch.load(file_path)


Loaded 679269 graphs successfully!


# EGNN

In [27]:
from torch_geometric.loader import DataLoader
import torch

# Seed torch's global RNG before any loader is constructed.
torch.manual_seed(42)

num_samples = len(graphs)

# Fractions of the dataset assigned to each split.
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1

# Train/val sizes are truncated to ints; the test split absorbs whatever is left over.
train_size = int(train_ratio * num_samples)
val_size = int(val_ratio * num_samples)
test_size = num_samples - (train_size + val_size)

# NOTE(review): the splits are contiguous slices in stored order - nothing is shuffled
# before slicing. If the file is ordered in any meaningful way the three splits will
# not share a distribution; confirm this is intended.
val_stop = train_size + val_size
train_graphs = graphs[:train_size]
val_graphs = graphs[train_size:val_stop]
test_graphs = graphs[val_stop:]

# Only the training loader reshuffles; validation and test keep a fixed order.
train_loader, val_loader, test_loader = (
    DataLoader(subset, batch_size=32, shuffle=reshuffle)
    for subset, reshuffle in (
        (train_graphs, True),
        (val_graphs, False),
        (test_graphs, False),
    )
)

In [17]:
# Quick sanity check of the tensor shapes carried by the first graph in the dataset.
first_graph = graphs[0]
print("Node Feature Shape:", first_graph.x.shape)

# Edge attributes are optional, so fall back to a message when they are absent.
if first_graph.edge_attr is not None:
    edge_shape = first_graph.edge_attr.shape
else:
    edge_shape = "No edge attributes"
print("Edge Feature Shape:", edge_shape)

print("Graph Target Shape:", first_graph.y.shape)

Node Feature Shape: torch.Size([25, 6])
Edge Feature Shape: torch.Size([54, 4])
Graph Target Shape: torch.Size([1])


In [22]:
!pip install e3nn

Collecting e3nn
  Using cached e3nn-0.5.5-py3-none-any.whl.metadata (5.4 kB)
Collecting opt_einsum_fx>=0.1.4 (from e3nn)
  Using cached opt_einsum_fx-0.1.4-py3-none-any.whl.metadata (3.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.8.0->e3nn)
  Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.8.0->e3nn)
  Using cached nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.8.0->e3nn)
  Using cached nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.8.0->e3nn)
  Using cached nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.8.0->e3nn)
  Using cached nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.

In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, global_mean_pool

class EGNNLayer(MessagePassing):
    """Message-passing layer that combines neighbour features with edge attributes.

    For every edge the message is ``node_mlp(x_j) + edge_mlp(edge_attr)``; messages
    arriving at a node are summed (``aggr="add"``). The aggregated sum is returned
    as the new node representation - the node's own previous features are not
    added back in by this layer.

    NOTE(review): no coordinate input or coordinate update appears here, so if the
    name is meant to imply E(n)-equivariance, this layer does not provide it - confirm.

    Args:
        in_channels: Size of the incoming node feature vectors.
        out_channels: Size of the produced node feature vectors.
        edge_attr_dim: Size of the per-edge attribute vectors.
    """

    def __init__(self, in_channels, out_channels, edge_attr_dim):
        super().__init__(aggr="add")
        # Built in this order (node first, then edge) so parameter initialisation
        # consumes the RNG in the same sequence as before.
        self.node_mlp = self._two_layer_mlp(in_channels, out_channels)
        self.edge_mlp = self._two_layer_mlp(edge_attr_dim, out_channels)

    @staticmethod
    def _two_layer_mlp(in_features, out_features):
        """Return a Linear -> ReLU -> Linear stack mapping in_features to out_features."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.ReLU(),
            nn.Linear(out_features, out_features),
        )

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing and return the aggregated node features."""
        aggregated = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        return aggregated

    def message(self, x_j, edge_attr):
        """Build the per-edge message from neighbour features and edge attributes."""
        node_term = self.node_mlp(x_j)
        edge_term = self.edge_mlp(edge_attr)
        return node_term + edge_term

class EGNN(nn.Module):
    """Stack of ``EGNNLayer`` blocks followed by mean pooling and a linear head.

    Args:
        input_dim: Size of the raw node feature vectors.
        hidden_dim: Width of every message-passing layer's output.
        output_dim: Size of the per-graph prediction.
        num_layers: Number of message-passing layers. The first layer is always
            built, so values below 1 still yield a single layer.
        edge_attr_dim: Size of the per-edge attribute vectors. Defaults to 4, the
            value that was previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=3, edge_attr_dim=4):
        super().__init__()
        # The first layer maps input_dim -> hidden_dim; the rest map hidden_dim -> hidden_dim.
        layer_input_dims = [input_dim] + [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(
            [
                EGNNLayer(in_dim, hidden_dim, edge_attr_dim=edge_attr_dim)
                for in_dim in layer_input_dims
            ]
        )
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index, batch, edge_attr):
        """Return one prediction row per graph in the batch.

        Args:
            x: Node feature matrix.
            edge_index: Edge connectivity passed through to each layer.
            batch: Per-node graph assignment used by ``global_mean_pool``.
            edge_attr: Per-edge attributes; the same tensor is fed to every layer.
        """
        for layer in self.layers:
            x = layer(x, edge_index, edge_attr)
            x = F.relu(x)
        # Collapse node embeddings to one vector per graph before the linear head.
        x = global_mean_pool(x, batch)
        return self.fc_out(x)


In [29]:
@torch.no_grad()
def validate(model, loader, loss_fn, device):
    """Return the mean per-batch loss of ``model`` over ``loader`` without tracking gradients.

    Args:
        model: Callable invoked as ``model(x, edge_index, batch, edge_attr)``; it is
            switched to eval mode first.
        loader: Iterable of batches exposing ``to(device)``, ``x``, ``edge_index``,
            ``batch``, ``edge_attr`` and ``y``. It does not need to support ``len()``.
        loss_fn: Callable ``loss_fn(output, target)`` whose result has ``.item()``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        The unweighted average of the per-batch loss values, or ``float("nan")`` when
        the loader yields no batches (previously this raised ``ZeroDivisionError``).
        NaN never compares as smaller than a stored best loss, so an empty validation
        set cannot be mistaken for an improvement.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0  # counted while iterating so loaders without __len__ also work
    for batch in loader:
        batch = batch.to(device)
        output = model(batch.x, batch.edge_index, batch.batch, batch.edge_attr)
        # NOTE(review): the printed per-graph target shape is [1], so a batched target
        # is presumably [B] while a 1-unit linear head would give [B, 1]. Elementwise
        # losses would broadcast that pair to [B, B] - confirm the shapes agree for
        # the loss_fn actually used.
        loss = loss_fn(output, batch.y)
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches
# Lowest validation loss seen so far; only an epoch that beats it writes a checkpoint.
best_val_loss = float("inf")

for epoch in range(epochs):
    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    val_loss = validate(model, val_loader, loss_fn, device)

    print(f"Epoch {epoch+1}/{epochs}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

    # Skip checkpointing unless this epoch is a strict improvement.
    if not (val_loss < best_val_loss):
        continue

    best_val_loss = val_loss
    torch.save(model.state_dict(), "egnn_model_final.pth")



Epoch 1/10: Train Loss = 2.1989, Val Loss = 2.0766




Epoch 2/10: Train Loss = 2.1923, Val Loss = 2.0791




Epoch 3/10: Train Loss = 2.1881, Val Loss = 2.0406




Epoch 4/10: Train Loss = 2.1851, Val Loss = 2.0258




Epoch 5/10: Train Loss = 2.1820, Val Loss = 2.0364




Epoch 6/10: Train Loss = 2.1799, Val Loss = 2.0298




Epoch 7/10: Train Loss = 2.1760, Val Loss = 2.0363




Epoch 8/10: Train Loss = 2.1743, Val Loss = 2.0225




Epoch 9/10: Train Loss = 2.1726, Val Loss = 2.0359




Epoch 10/10: Train Loss = 2.1705, Val Loss = 2.0458
