In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
import os
# Process-wide runtime switches, applied in a single call.
# NOTE(review): presumably meant to turn off torch.compile and hide every GPU
# so the demo runs on CPU -- confirm that both variable names are honoured by
# the installed PyTorch, and that setting them after `import torch` is early
# enough to take effect.
os.environ.update({
    'TORCH_COMPILE': '0',
    'CUDA_VISIBLE_DEVICES': '-1',
})


class MoleBlendModel(nn.Module):
    """Graph-level classifier: linear node encoder, two GCN layers, mean pooling.

    Node features are projected to ``hidden_dim``, passed through two
    ``GCNConv`` layers with ReLU activations, averaged per graph with
    ``global_mean_pool`` and mapped to ``num_classes`` scores by a final
    linear layer.

    Note:
        Edge features do not influence the prediction. Both ``GCNConv`` calls
        receive only ``(x, edge_index)``. ``edge_encoder`` is still created so
        that the constructor signature, the parameter-initialisation order and
        the ``state_dict`` keys stay compatible with existing callers and
        checkpoints, but ``forward`` does not invoke it.
    """

    def __init__(
        self,
        node_feat_dim: int,
        edge_feat_dim: int,
        hidden_dim: int,
        num_classes: int,
    ) -> None:
        """Create the encoder, convolution and classification layers.

        Args:
            node_feat_dim: Size of each input node feature vector.
            edge_feat_dim: Size of each input edge feature vector; only used
                to size ``edge_encoder``.
            hidden_dim: Width of the encoders and of both GCN layers.
            num_classes: Number of output scores per graph.
        """
        super().__init__()
        self.node_encoder = nn.Linear(node_feat_dim, hidden_dim)
        # Unused by forward(); retained for interface/checkpoint compatibility.
        self.edge_encoder = nn.Linear(edge_feat_dim, hidden_dim)
        self.conv1 = GCNConv(hidden_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, data) -> torch.Tensor:
        """Compute per-graph class scores for a (batched) graph object.

        Args:
            data: Object exposing ``x`` (node features), ``edge_index`` and
                ``batch`` (graph assignment of each node). ``edge_attr`` is
                not required because edge features are not consumed.

        Returns:
            The output of the final linear layer applied to the mean-pooled
            node embeddings (raw scores, no softmax applied).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Encode node features. The edge features used to be encoded here as
        # well, but the result was thrown away, so that step is skipped.
        x = self.node_encoder(x)

        # GCN layers
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))

        # Global pooling and classification
        x = global_mean_pool(x, batch)
        return self.fc(x)


# Example usage
if __name__ == "__main__":
    from torch_geometric.data import Data

    # Toy input: 10 nodes with 4 features each, split evenly over two graphs
    # (nodes 0-4 -> graph 0, nodes 5-9 -> graph 1). The only edges are
    # 0->1, 1->2 and 2->0, each carrying a 4-dimensional feature vector.
    # The two torch.rand calls stay in this order so the random stream is
    # consumed exactly as before.
    node_features = torch.rand(10, 4)
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])
    edge_features = torch.rand(3, 4)
    batch = torch.tensor([0] * 5 + [1] * 5)
    labels = torch.tensor([0, 1])  # One target class per graph

    data = Data(
        x=node_features,
        edge_index=edge_index,
        edge_attr=edge_features,
        batch=batch,
    )

    # Model, loss and optimiser
    model = MoleBlendModel(
        node_feat_dim=4,
        edge_feat_dim=4,
        hidden_dim=128,
        num_classes=2,
    )
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # A single optimisation step
    model.train()
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, labels)
    loss.backward()
    optimizer.step()

    print(f"Loss: {loss.item()}")

Loss: 0.6969464421272278
