In [2]:
import itertools as it

import networkx as nx
import numpy as np

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data


In [3]:
class GCN(torch.nn.Module):
    def __init__(self, node_features):
        super().__init__()
        self.conv1 = GCNConv(node_features, 64)
        self.conv2 = GCNConv(64, 32)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)

        return torch.sigmoid(x)


In [4]:
# utils 

def read_sat(sat_path):
    with open(sat_path) as f:
        sat_lines = f.readlines()
        header = sat_lines[0]
        header_info = header.replace("\n", "").split(" ")
        num_vars = int(header_info[-2])
        num_clauses = int(header_info[-1])

        sat = [[int(x) for x in line.replace(' 0\n', '').split(' ')]
               for line in sat_lines[1:]]

        return sat, num_vars, num_clauses


def sat_to_lig_adjacency_matrix(sat, num_vars):
    get_literal_idx = lambda x: 2 * x - 2 if x > 0 else 2 * abs(x) - 1
    lig_adjacency_matrix = np.zeros([2*num_vars, 2*num_vars])
    lig_weighted_adjacency_matrix = np.zeros([2*num_vars, 2*num_vars])

    for clause in sat:
        pairs = it.combinations(clause, 2)
#         print(f'clause: {clause}')
        for pair in pairs:
            x_idx = get_literal_idx(pair[0])
            y_idx = get_literal_idx(pair[1])
#             print(f'pair: {(x_idx, y_idx)}')
            lig_adjacency_matrix[x_idx, y_idx] = 1
            lig_adjacency_matrix[y_idx, x_idx] = 1
            lig_weighted_adjacency_matrix[x_idx, y_idx] += 1
            lig_weighted_adjacency_matrix[y_idx, x_idx] += 1    
    return lig_adjacency_matrix, lig_weighted_adjacency_matrix

In [5]:
sat_path = './ssa2670-141.processed.cnf'
sat_instance, num_vars, num_clauses = read_sat(sat_path)

lig_adjacency_matrix, lig_weighted_adjacency_matrix = sat_to_lig_adjacency_matrix(sat_instance, num_vars)

# graph = nx.from_numpy_matrix(lig_adjacency_matrix)
# edges = nx.to_edgelist(graph)
# print(lig_adjacency_matrix.nonzero())

edge_index = torch.tensor(lig_adjacency_matrix.nonzero(), dtype=torch.long)
edge_value = lig_weighted_adjacency_matrix[lig_adjacency_matrix.nonzero()]
print(edge_value)
max_edge_value = max(edge_value)
norm_edge_value = edge_value/max_edge_value
print(norm_edge_value)
embeddings = torch.load('./embeddings.pt')
embeddings.requires_grad = False
# print(embeddings)
x = embeddings

data = Data(x=x, edge_index=edge_index, norm_edge_value=norm_edge_value)


[6. 6. 1. ... 2. 1. 3.]
[1.         1.         0.16666667 ... 0.33333333 0.16666667 0.5       ]


  edge_index = torch.tensor(lig_adjacency_matrix.nonzero(), dtype=torch.long)


In [6]:
# training
model = GCN(50)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# print(norm_edge_value)
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    src, dst = edge_index
    score = (out[src] * out[dst]).sum(dim=-1)
    loss = F.mse_loss(score, torch.tensor(norm_edge_value, dtype=torch.float))
    print(f'epoch: {epoch}, loss: {loss.item()}')
    loss.backward()
    optimizer.step()


epoch: 0, loss: 61.42732620239258
epoch: 1, loss: 48.326412200927734
epoch: 2, loss: 37.22016906738281
epoch: 3, loss: 27.406429290771484
epoch: 4, loss: 18.952856063842773
epoch: 5, loss: 12.169828414916992
epoch: 6, loss: 7.223531723022461
epoch: 7, loss: 3.9763169288635254
epoch: 8, loss: 2.052755117416382
epoch: 9, loss: 1.0098166465759277
epoch: 10, loss: 0.483527809381485
epoch: 11, loss: 0.23297622799873352
epoch: 12, loss: 0.12008927762508392
epoch: 13, loss: 0.07266464084386826
epoch: 14, loss: 0.05514013394713402
epoch: 15, loss: 0.05061937868595123
epoch: 16, loss: 0.05136674642562866
epoch: 17, loss: 0.05394947528839111
epoch: 18, loss: 0.05691593512892723
epoch: 19, loss: 0.05969525873661041
epoch: 20, loss: 0.06210372969508171
epoch: 21, loss: 0.06411907821893692
epoch: 22, loss: 0.0657782182097435
epoch: 23, loss: 0.06713581085205078
epoch: 24, loss: 0.0682462528347969
epoch: 25, loss: 0.06915657967329025
epoch: 26, loss: 0.06990573555231094
epoch: 27, loss: 0.0705254077