In [42]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import RGCNConv, global_mean_pool
from torch_geometric.data import Data
#from graph_builder import GraphBuilder  # <-- External builder
import pandas as pd
from torch.nn import Linear, ReLU, Sequential
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_add_pool
from sklearn.model_selection import train_test_split
import ast
from torch_geometric.utils import degree


In [None]:
def parse_edge_line(line):
    # Split by comma only *outside* curly braces
    edges_str = line.strip().split('","')  # Handles entries like "{1, 2}","{1, 3}",...
    cleaned = [edge.replace('{', '(').replace('}', ')').replace('"', '') for edge in edges_str]
    return [tuple(ast.literal_eval(edge)) for edge in cleaned]

# Load the file manually
edges_per_graph = []
with open("/home/rigers/Documents/GitHub/ML-correlator/Rigers/GNN/data_8_Loop/edges8Loop.csv", 'r') as file:
    for line in file:
        edges = parse_edge_line(line)
        edges_per_graph.append(edges)

labels = []
with open("/home/rigers/Documents/GitHub/ML-correlator/Rigers/GNN/data_8_Loop/coeffs8Loop.csv", 'r') as f:
    for line in f:
        labels.append(int(line.strip()))

In [None]:
data8Loop=pd.read_csv("../Graph_Edge_Data/den_graph_data_8.csv")

In [36]:
class GraphBuilder:
    def __init__(self, solid_edges, node_labels=None):
        # Auto-infer node labels if not provided
        if node_labels is None:
            node_labels = sorted(set(u for e in solid_edges for u in e))
        self.node_labels = node_labels
        self.label2idx = {label: i for i, label in enumerate(node_labels)}

        self.solid_edges = solid_edges
        self.num_nodes = len(self.node_labels)

    def build(self, extra_node_features=None):
        edge_list = []

        for u, v in self.solid_edges:
            i, j = self.label2idx[u], self.label2idx[v]
            edge_list += [[i, j], [j, i]]  # bidirectional

        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()

        # Basic node feature: degree
        degree_feat = degree(edge_index[0], num_nodes=self.num_nodes).view(-1, 1)

        # Combine degree with extra features if provided
        if extra_node_features is not None:
            assert extra_node_features.shape[0] == self.num_nodes, \
                "extra_node_features must match number of nodes"
            x = torch.cat([degree_feat, extra_node_features], dim=1)
        else:
            x = degree_feat

        return Data(x=x, edge_index=edge_index, num_nodes=self.num_nodes)

In [49]:
graph_list=[GraphBuilder(x).build() for x in edges_per_graph]

In [43]:
class SimpleGNN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        batch = torch.zeros(data.num_nodes, dtype=torch.long)  # single graph
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.lin(x)


In [53]:
in_channels = graph_list[0].x.shape[1]
S_gnn=SimpleGNN(in_channels = in_channels,hidden_channels=32)

In [None]:
S_gnn.eval()
[S_gnn(x) for x in graph_list]

[tensor([[-0.4003]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4291]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4145]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4144]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4145]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4145]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4145]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4159]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4291]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4003]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4145]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4145]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4145]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4003]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4003]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4003]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4145]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4159]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4429]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4291]], grad_fn=<AddmmBackward0>),
 tensor([[-0.4435]], grad_fn=<AddmmBackw