In [7]:
import itertools as it

import networkx as nx
import numpy as np

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data


In [8]:
class GCN(torch.nn.Module):
    def __init__(self, node_features):
        super().__init__()
        self.conv1 = GCNConv(node_features, 64)
        self.conv2 = GCNConv(64, 32)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)

        return x


In [9]:
# utils 

def read_sat(sat_path):
    with open(sat_path) as f:
        sat_lines = f.readlines()
        header = sat_lines[0]
        header_info = header.replace("\n", "").split(" ")
        num_vars = int(header_info[-2])
        num_clauses = int(header_info[-1])

        sat = [[int(x) for x in line.replace(' 0\n', '').split(' ')]
               for line in sat_lines[1:]]

        return sat, num_vars, num_clauses


def sat_to_lig_adjacency_matrix(sat, num_vars):
    get_literal_idx = lambda x: 2 * x - 2 if x > 0 else 2 * abs(x) - 1
    lig_adjacency_matrix = np.zeros([2*num_vars, 2*num_vars])
    lig_weighted_adjacency_matrix = np.zeros([2*num_vars, 2*num_vars])

    for clause in sat:
        pairs = it.combinations(clause, 2)
#         print(f'clause: {clause}')
        for pair in pairs:
            x_idx = get_literal_idx(pair[0])
            y_idx = get_literal_idx(pair[1])
#             print(f'pair: {(x_idx, y_idx)}')
            lig_adjacency_matrix[x_idx, y_idx] = 1
            lig_adjacency_matrix[y_idx, x_idx] = 1
            lig_weighted_adjacency_matrix[x_idx, y_idx] += 1
            lig_weighted_adjacency_matrix[y_idx, x_idx] += 1    
    return lig_adjacency_matrix, lig_weighted_adjacency_matrix

In [10]:
sat_path = './ssa2670-141.processed.cnf'
sat_instance, num_vars, num_clauses = read_sat(sat_path)

lig_adjacency_matrix, lig_weighted_adjacency_matrix = sat_to_lig_adjacency_matrix(sat_instance, num_vars)

# graph = nx.from_numpy_matrix(lig_adjacency_matrix)
# edges = nx.to_edgelist(graph)
# print(lig_adjacency_matrix.nonzero())

edge_index = torch.tensor(lig_adjacency_matrix.nonzero(), dtype=torch.long)
edge_value = lig_weighted_adjacency_matrix[lig_adjacency_matrix.nonzero()]
print(edge_value)
max_edge_value = max(edge_value)
norm_edge_value = edge_value/max_edge_value
print(norm_edge_value)
embeddings = torch.load('./embeddings.pt')
embeddings.requires_grad = False
# print(embeddings)
x = embeddings

data = Data(x=x, edge_index=edge_index, norm_edge_value=norm_edge_value)


[6. 6. 1. ... 2. 1. 3.]
[1.         1.         0.16666667 ... 0.33333333 0.16666667 0.5       ]


In [11]:
# training
model = GCN(50)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# print(norm_edge_value)
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    src, dst = edge_index
    score = (out[src] * out[dst]).sum(dim=-1)
    score = torch.sigmoid(score)
    loss = F.mse_loss(score, torch.tensor(norm_edge_value, dtype=torch.float))
    print(f'epoch: {epoch}, loss: {loss.item()}')
    loss.backward()
    optimizer.step()


epoch: 0, loss: 0.3352663815021515
epoch: 1, loss: 0.16129252314567566
epoch: 2, loss: 0.11396269500255585
epoch: 3, loss: 0.11215722560882568
epoch: 4, loss: 0.11726173013448715
epoch: 5, loss: 0.118002749979496
epoch: 6, loss: 0.11455845087766647
epoch: 7, loss: 0.10962137579917908
epoch: 8, loss: 0.10485920310020447
epoch: 9, loss: 0.10081926733255386
epoch: 10, loss: 0.09754770994186401
epoch: 11, loss: 0.09503047913312912
epoch: 12, loss: 0.09319749474525452
epoch: 13, loss: 0.09193914383649826
epoch: 14, loss: 0.09112832695245743
epoch: 15, loss: 0.09065461158752441
epoch: 16, loss: 0.09039942920207977
epoch: 17, loss: 0.09026653319597244
epoch: 18, loss: 0.09018757194280624
epoch: 19, loss: 0.09012112766504288
epoch: 20, loss: 0.09004461020231247
epoch: 21, loss: 0.08994879573583603
epoch: 22, loss: 0.08982842415571213
epoch: 23, loss: 0.08969057351350784
epoch: 24, loss: 0.08954218775033951
epoch: 25, loss: 0.08939193189144135
epoch: 26, loss: 0.0892476886510849
epoch: 27, loss