In [9]:
!pip install torch torch-geometric

Collecting torch-geometric
  Using cached torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
Collecting aiohttp (from torch-geometric)
  Using cached aiohttp-3.11.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Collecting tqdm (from torch-geometric)
  Using cached tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Collecting aiohappyeyeballs>=2.3.0 (from aiohttp->torch-geometric)
  Using cached aiohappyeyeballs-2.4.4-py3-none-any.whl.metadata (6.1 kB)
Collecting aiosignal>=1.1.2 (from aiohttp->torch-geometric)
  Using cached aiosignal-1.3.2-py2.py3-none-any.whl.metadata (3.8 kB)
Collecting frozenlist>=1.1.1 (from aiohttp->torch-geometric)
  Using cached frozenlist-1.5.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (13 kB)
Collecting multidict<7.0,>=4.5 (from aiohttp->torch-geometric)
  Using cached multidict-6.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.0 kB)
Col

In [66]:
import torch
from torch_geometric.nn import GCNConv
from torch.nn import Linear
from torch.nn.functional import relu, sigmoid, binary_cross_entropy
import numpy as np
import json

In [110]:
class GNN(torch.nn.Module):
    """Stack of GCNConv layers, each paired with a bias-free linear map of the same size.

    Every layer computes ``relu(conv(x, edge_index) - B(x))``, where ``conv`` is a
    ``GCNConv`` and ``B`` is a ``torch.nn.Linear`` applied to the layer's input.

    Args:
        layers: sequence of layer widths; each consecutive pair ``(layers[k], layers[k+1])``
            gives the input/output size of one layer.
    """

    def __init__(self,  layers):
        super().__init__()
        self.convs = torch.nn.ModuleList()
        self.Bs = torch.nn.ModuleList()
        for width_in, width_out in zip(layers, layers[1:]):
            # Convolution weights are re-drawn from N(mean=0.01, std=0.3).
            conv = GCNConv(width_in, width_out, bias=False)
            torch.nn.init.normal_(conv.lin.weight, mean=0.01, std=0.3)
            self.convs.append(conv)
            # Weights of the subtracted linear term are re-drawn from N(mean=0.5, std=0.3).
            self_map = torch.nn.Linear(width_in, width_out, bias=False)
            torch.nn.init.normal_(self_map.weight, mean=0.5, std=0.3)
            self.Bs.append(self_map)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Run all layers on node features ``x``; ``edge_index`` is handed to each GCNConv."""
        hidden = x
        for conv, B in zip(self.convs, self.Bs):
            propagated = conv(hidden, edge_index)
            self_term = B(hidden)
            hidden = relu(propagated - self_term)
        return hidden

class EdgesMLP(torch.nn.Module):
    """Bias-free linear layer plus sigmoid that maps each input row to one value in (0, 1).

    Args:
        l3: the linear layer takes ``2 * l3`` input features per row.
    """

    def __init__(self, l3):
        super().__init__()
        scorer = Linear(2 * l3, 1, bias=False)
        # Replace the default initialisation with draws from N(mean=0.5, std=0.3).
        torch.nn.init.normal_(scorer.weight, mean=0.5, std=0.3)
        self.linear = scorer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``sigmoid(linear(x))`` with the size-1 dimension at index 1 removed."""
        logits = self.linear(x)
        probabilities = sigmoid(logits)
        return probabilities.squeeze(1)

def get_models(params):
    """Build the node network and the edge scorer from ``params["count_neuron_layers"]``.

    Returns:
        ``(GNN, EdgesMLP)`` tuple; the EdgesMLP is sized from the last layer width.
    """
    widths = params["count_neuron_layers"]
    # The GNN is constructed first, then the edge scorer (tuple evaluates left to right).
    return GNN(widths), EdgesMLP(widths[-1])

def list_batchs(dataset, batch_size):
    """Lazily yield consecutive slices of ``dataset`` of length ``batch_size``.

    The final slice is shorter when ``len(dataset)`` is not a multiple of ``batch_size``.
    """
    starts = range(0, len(dataset), batch_size)
    yield from (dataset[start:start + batch_size] for start in starts)

def get_tensor_from_graph(graph):
    """Convert one graph record into the tensors used by the models.

    Reads the keys ``"A"``, ``"edges_feature"``, ``"true_edges"`` and ``"nodes_feature"``.

    Returns:
        X: float32 tensor built from ``graph["nodes_feature"]``.
        sp_A: float32 sparse COO tensor of size ``(N, N)`` (``N = len(nodes_feature)``)
            whose indices are ``graph["A"]`` and whose values are ``rev_dist`` applied
            to each entry of ``graph["edges_feature"]``.
        E_true: float32 tensor built from ``graph["true_edges"]``.
        The unmodified ``graph["A"]`` object.
    """
    edge_index = graph["A"]
    edge_weights = list(map(rev_dist, graph["edges_feature"]))
    edge_targets = graph["true_edges"]
    node_features = graph["nodes_feature"]
    node_count = len(node_features)

    X = torch.tensor(data=node_features, dtype=torch.float32)
    sp_A = torch.sparse_coo_tensor(
        indices=edge_index,
        values=edge_weights,
        size=(node_count, node_count),
        dtype=torch.float32,
    )
    E_true = torch.tensor(data=edge_targets, dtype=torch.float32)
    return X, sp_A, E_true, edge_index

def validation(models, dataset, criterion):
    """Compute the mean per-graph loss of ``models`` over ``dataset`` without training.

    Args:
        models: indexable pair; ``models[0]`` is called as ``(X, sp_A)`` to produce node
            embeddings and ``models[1]`` is called on the concatenated endpoint embeddings.
        dataset: iterable with ``len()`` of graph records accepted by
            ``get_tensor_from_graph``.
        criterion: callable ``(prediction, target) -> loss tensor``.

    Returns:
        ``np.mean`` of the per-graph loss values.
    """
    my_loss_list = []
    # Evaluation never calls backward(), so disable autograd: no graph is recorded and
    # no activations are kept alive for each validation sample.
    with torch.no_grad():
        for j, graph in enumerate(dataset):
            X, sp_A, E_true, i = get_tensor_from_graph(graph)
            H_end = models[0](X, sp_A)
            # One row per edge: embedding of the first endpoint followed by the second.
            Omega = torch.cat([H_end[i[0]], H_end[i[1]]],dim=1)
            E_pred = models[1](Omega)
            loss = criterion(E_pred, E_true)
            my_loss_list.append(loss.item())
            # Progress line is overwritten in place via the trailing carriage return.
            print(f"{(j+1)/len(dataset)*100:.2f} % loss = {my_loss_list[-1]:.5f} {' '*30}", end='\r')
    return np.mean(my_loss_list)

def split_train_val(dataset, val_split=0.2, shuffle=True, seed=1234):
    """Split ``dataset`` into a leading training part and a trailing validation part.

    Args:
        dataset: sliceable sequence of samples. It is never modified: when ``shuffle``
            is true a shallow copy is shuffled, so calling this repeatedly on the same
            data with the same seed always yields the same split.
        val_split: fraction of samples assigned to the validation part.
        shuffle: whether to shuffle (a copy of) the data before splitting.
        seed: seed for ``np.random.default_rng`` used for the shuffle.

    Returns:
        ``(train_dataset, val_dataset)`` where the training part holds the first
        ``int(len(dataset) * (1 - val_split))`` samples.
    """
    import copy

    if shuffle:
        # Shuffling the caller's object in place would reorder it for every later use
        # and make each subsequent call produce a different train/validation partition.
        dataset = copy.copy(dataset)
        rng = np.random.default_rng(seed)
        rng.shuffle(dataset)
    train_size = int(len(dataset) * (1 - val_split))
    train_dataset = dataset[:train_size]
    val_dataset = dataset[train_size:]
    return train_dataset, val_dataset

def train_step(models, batch, optimizer, criterion):
    """Run one optimizer update over ``batch`` and return the mean per-graph loss.

    Gradients are cleared once, accumulated by calling ``backward()`` for every graph
    in the batch, and applied with a single ``optimizer.step()`` at the end.
    """
    optimizer.zero_grad()
    graph_losses = []

    for graph in batch:
        X, sp_A, E_true, edge_index = get_tensor_from_graph(graph)
        node_embeddings = models[0](X, sp_A)
        # One row per edge: embedding of the first endpoint followed by the second.
        first_endpoint = node_embeddings[edge_index[0]]
        second_endpoint = node_embeddings[edge_index[1]]
        edge_inputs = torch.cat([first_endpoint, second_endpoint], dim=1)
        loss = criterion(models[1](edge_inputs), E_true)
        graph_losses.append(loss.item())
        # Progress line is overwritten in place via the trailing carriage return.
        print(f"Batch loss={graph_losses[-1]:.4f}" + " "*40, end="\r")
        loss.backward()

    optimizer.step()
    return np.mean(graph_losses)

def train_model(params, models, dataset, path_save, save_frequency=5):
    """Train ``models`` on ``dataset`` and write checkpoints and a text log.

    Args:
        params: dict with ``"epochs"`` and ``"batch_size"``. An optional
            ``"learning_rate"`` entry sets the Adam step size; when it is absent the
            notebook-level ``learning_rate`` global is used, as before.
        models: indexable pair of ``torch.nn.Module`` (node network, edge scorer).
        dataset: sequence of graph records; 10% is held out for validation via
            ``split_train_val``.
        path_save: filename prefix for the saved state dicts.
        save_frequency: save a numbered checkpoint every this many epochs. A falsy
            value (0 or None) disables the periodic checkpoints; the final ``_end``
            checkpoint is always written.

    Side effects:
        Appends one line per epoch to ``log.txt`` and writes ``state_dict`` files named
        ``<path_save>_node_gnn_<n>`` / ``<path_save>_edge_linear_<n>`` plus the ``_end`` pair.
    """
    # Prefer an explicit per-run learning rate over hidden notebook state.
    lr = params["learning_rate"] if "learning_rate" in params else learning_rate
    optimizer = torch.optim.Adam(
        list(models[0].parameters()) + list(models[1].parameters()),
        lr=lr,
    )
    criterion = torch.nn.BCELoss()
    train_dataset, val_dataset = split_train_val(dataset, val_split=0.1)
    for k in range(params["epochs"]):
        epoch = k + 1  # 1-based epoch number, used for both the console and the log
        my_loss_list = []

        for l, batch in enumerate(list_batchs(train_dataset, params["batch_size"])):
            batch_loss = train_step(models, batch, optimizer, criterion)
            my_loss_list.append(batch_loss)
            print(f"Batch # {l+1} loss={my_loss_list[-1]:.4f}" + " "*40)
        train_val = np.mean(my_loss_list)
        validation_val = validation(models, val_dataset, criterion)
        print("="*10, f"EPOCH #{epoch}","="*10, f"({train_val:.4f}/{validation_val:.4f})")
        with open('log.txt', 'a') as f:
            # Log the same 1-based epoch number that is printed above.
            f.write(f"EPOCH #{epoch}\t {train_val:.8f} (VAL: {validation_val:.8f})\n")
        if save_frequency and epoch % save_frequency == 0:
            # Checkpoint suffix is the 0-based index of the checkpoint (0, 1, 2, ...).
            num = k//save_frequency
            torch.save(models[0].state_dict(), path_save+f"_node_gnn_{num}")
            torch.save(models[1].state_dict(), path_save+f"_edge_linear_{num}")
    torch.save(models[0].state_dict(), path_save+"_node_gnn_end")
    torch.save(models[1].state_dict(), path_save+"_edge_linear_end")


In [161]:
import json

# Load the list of graph records stored under the top-level "dataset" key.
with open("../dataset.json", "r") as f:
    dataset = json.load(f)['dataset']
# with open("../delaunay_seg.json", "r") as f:
#     dataset = json.load(f)['dataset']

# Summarise the collection: its size, then the keys and array shapes of the
# first and the last record.
print("DATASET INFO:")
print("count row:", len(dataset))
for label, graph in (("first", dataset[0]), ("end", dataset[-1])):
    print(f"{label}:", graph.keys())
    for key in ("A", "nodes_feature", "edges_feature", "true_edges"):
        print(f"\t {key}:", np.shape(graph[key]))


DATASET INFO:
count row: 1557
first: dict_keys(['A', 'nodes_feature', 'edges_feature', 'true_edges'])
	 A: (2, 779)
	 nodes_feature: (385, 9)
	 edges_feature: (779,)
	 true_edges: (779,)
end: dict_keys(['A', 'nodes_feature', 'edges_feature', 'true_edges'])
	 A: (2, 2142)
	 nodes_feature: (1039, 9)
	 edges_feature: (2142,)
	 true_edges: (2142,)


In [162]:
def rev_dist(a):
    """Return ``1 / a``, or ``0`` when ``a`` equals zero (avoids division by zero)."""
    return 0 if a == 0 else 1 / a
        
# Build the tensors for the first graph of the dataset, used by the smoke test below.
i = dataset[0]["A"]
v_in = [rev_dist(e) for e in dataset[0]["edges_feature"]]
v_true = dataset[0]["true_edges"]
x = dataset[0]["nodes_feature"]
N = len(x)

# Use torch.tensor with an explicit float32 dtype (instead of the legacy torch.Tensor
# constructor and an inferred sparse dtype) so these tensors are built exactly like
# the ones get_tensor_from_graph feeds to the models during training.
X = torch.tensor(x, dtype=torch.float32)
sp_A = torch.sparse_coo_tensor(i, v_in, (N, N), dtype=torch.float32)
E_true = torch.tensor(v_true, dtype=torch.float32)

In [165]:
# Hyper-parameters. The first layer width (9) has to match the number of features
# per node in the dataset.
params = {
    "count_neuron_layers": [9, 27, 18],
    "epochs": 30,
    "batch_size": 60,
}

# Read by train_model as a notebook-level global.
learning_rate = 0.02

node_gnn, edge_linear = get_models(params)

# Smoke test: one forward pass of the untrained networks on the first graph.
# No optimizer is needed here, so none is created; train_model builds its own.
criterion = torch.nn.BCELoss()

H_end = node_gnn(X, sp_A)
# One row per edge: embedding of the first endpoint followed by the second.
Omega = torch.cat([H_end[i[0]], H_end[i[1]]],dim=1)
E_pred = edge_linear(Omega)
print(f"E_pred:\n{E_pred}", f"\nE_true:\n{E_true}")
print("Loss = ", criterion(E_pred, E_true))

# Keep the temporary loss object out of the notebook namespace.
del criterion

E_pred:
tensor([0.5000, 0.9523, 1.0000, 0.9966, 0.9421, 1.0000, 0.5447, 0.5212, 1.0000,
        0.5000, 0.5000, 1.0000, 0.5000, 0.5000, 1.0000, 0.5000, 0.5000, 1.0000,
        0.5000, 0.5000, 1.0000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.7894,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        1.0000, 1.0000, 1.0000, 0.5000, 0.5000, 0.7894, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 0.9994, 1.0000, 0.7894, 0.9997, 0.7666, 0.9999, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 0.9993, 1.0000, 0.5000, 0.9739,
        0.5000, 0.5000, 0.5000, 

In [167]:
# Bind the networks created by get_models(params) above so this cell runs on a fresh
# kernel. The same objects are reused on re-run, so executing the cell again continues
# training; re-run the get_models cell to start from newly initialised weights.
models = (node_gnn, edge_linear)
train_model(params,models, dataset, "z_torch", save_frequency=10)

Batch # 1 loss=0.3822                                        
Batch # 2 loss=0.7642                                        
Batch # 3 loss=0.4809                                        
Batch # 4 loss=0.3877                                        
Batch # 5 loss=0.4833                                        
Batch # 6 loss=0.5117                                        
Batch # 7 loss=0.4665                                        
Batch # 8 loss=0.4245                                        
Batch # 9 loss=0.4286                                        
Batch # 10 loss=0.4377                                        
Batch # 11 loss=0.4206                                        
Batch # 12 loss=0.3638                                        
Batch # 13 loss=0.4228                                        
Batch # 14 loss=0.3739                                        
Batch # 15 loss=0.4099                                        
Batch # 16 loss=0.3874                                        
B

In [116]:
# Print every learnable tensor of both networks with its position in the list.
# The loop variables deliberately avoid the name `i`, which holds the edge index of
# the first graph at notebook level and is read by the smoke-test cell.
for idx, param in enumerate(list(models[0].parameters()) + list(models[1].parameters())):
    print(idx, param)

0 Parameter containing:
tensor([[-0.1759, -0.0254,  0.2722,  0.0184,  0.1136, -0.2754,  0.2124,  0.1935,
          0.5147],
        [-0.1288, -0.4236, -0.1713, -0.2819, -0.1806, -0.0876,  0.2570,  0.0384,
          0.6041],
        [ 0.0236, -0.6083,  0.1191, -0.2275,  0.1009,  0.3368,  0.0306, -0.0403,
         -0.2145],
        [-0.3864, -0.2112, -0.0350,  0.4570, -0.1351,  0.7280, -0.0679, -0.1624,
          0.1483],
        [-0.9543, -0.0974, -0.1931,  0.2187,  0.0159,  0.0479,  0.0109, -0.0153,
          0.5601],
        [-0.1470, -0.1628, -0.2523, -0.6196, -0.3178,  0.0788,  0.0615, -0.3432,
          0.1215],
        [-0.4620,  0.2739, -0.5049, -0.0629, -0.6254,  0.0990, -0.1985, -0.0373,
         -0.0076],
        [ 0.4061, -0.1382,  0.4136,  0.3187,  0.3797,  0.0730,  0.4111,  0.6153,
          0.3442],
        [ 0.0625, -0.2081, -0.0831,  0.3625,  0.2356, -0.1471,  0.1947,  0.3956,
         -0.0047],
        [ 0.4516,  0.8717,  0.9890,  0.7712,  0.4401,  0.6299,  0.2625,  0.5