In [1]:
import itertools as it
import time
from utils import *
# from pysat.solvers import Glucose3

import numpy as np
from scipy import sparse

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

In [2]:
# CELL
import warnings
warnings.filterwarnings('ignore')

import pickle
import numpy as np
import scipy.sparse as sp
from scipy.sparse import load_npz

import torch

from cell.utils import link_prediction_performance
from cell.cell import Cell, EdgeOverlapCriterion, LinkPredictionCriterion
from cell.graph_statistics import compute_graph_statistics


In [3]:
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network that maps node features to node embeddings.

    The network applies ``GCNConv -> ReLU -> GCNConv`` and returns the output of
    the second convolution without a final activation.

    Args:
        node_features: Size of each input node feature vector.
        hidden_channels: Output width of the first convolution (default 128).
        out_channels: Output width of the second convolution (default 128).
    """

    def __init__(self, node_features, hidden_channels=128, out_channels=128):
        super().__init__()
        # Layer widths are parameters so the model can be resized without
        # editing the class; the defaults reproduce the original 128/128 setup.
        self.conv1 = GCNConv(node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        """Return per-node embeddings for ``data``.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``
                (edge list); both are passed to each convolution.

        Returns:
            The output of the second convolution.
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)

        return x


In [4]:
# Select the SAT formula to work on (alternatives kept below for quick switching).
sat_name = 'ssa2670-141.processed.cnf'
# sat_name = 'mrpp_4x4#4_5.processed.cnf'
# sat_name = 'countbitsrotate016.processed.cnf'

sat_path = f'./dataset/formulas/{sat_name}'
num_vars, num_clauses, sat_instance = read_sat(sat_path)
# Length of the longest clause; passed to get_clique_candidates when decoding.
max_len = max(len(clause) for clause in sat_instance)

lig_adjacency_matrix, lig_weighted_adjacency_matrix = sat_to_lig_adjacency_matrix(sat_instance, num_vars)
# Compute the edge coordinates once so edge_index and edge_value are guaranteed
# to be listed in the same order.
lig_edges = lig_adjacency_matrix.nonzero()
edge_index = torch.tensor(np.array(lig_edges), dtype=torch.long)
edge_value = lig_weighted_adjacency_matrix[lig_edges]

# map_location='cpu' keeps the embeddings on the same device as the (CPU) model
# and lets checkpoints saved on a GPU load on a machine without one.
embeddings = torch.load(f'./model/embeddings/{sat_name}.pt', map_location='cpu')
# The embeddings are fixed inputs; they are not trained further here.
embeddings.requires_grad = False
x = embeddings
data = Data(x=x, edge_index=edge_index)


In [5]:
# training for GNN
# Size the input layer from the loaded embeddings instead of hard-coding 50,
# so a different embedding file does not require editing this cell.
model = GCN(embeddings.shape[-1])
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# Loop-invariant values: edge endpoints and the regression target (edge weights).
src, dst = edge_index
edge_target = torch.tensor(edge_value, dtype=torch.float)

model.train()
for epoch in range(500):
    optimizer.zero_grad()
    out = model(data)
    # Predicted edge weight = inner product of the two endpoint embeddings.
    score = (out[src] * out[dst]).sum(dim=-1)
    loss = F.mse_loss(score, edge_target)
    loss.backward()
    optimizer.step()
    # print(f'epoch: {epoch}, loss: {loss.item()}')


In [7]:
# training for CELL
sparse_matrix = sparse.csr_matrix(lig_adjacency_matrix)

# Callback invoked every 5 steps with an edge-overlap limit of 0.80
# (presumably ends training once the limit is reached -- confirm in the CELL docs).
overlap_criterion = EdgeOverlapCriterion(invoke_every=5, edge_overlap_limit=.80)
adam_settings = {'lr': 0.1, 'weight_decay': 1e-7}

cell_model = Cell(A=sparse_matrix, H=9, callbacks=[overlap_criterion])
cell_model.train(
    steps=400,
    optimizer_fn=torch.optim.Adam,
    optimizer_args=adam_settings,
)


Step:   5/400 Loss: 4.79573 Edge-Overlap: 0.142 Total-Time: 0
Step:  10/400 Loss: 3.92074 Edge-Overlap: 0.400 Total-Time: 0
Step:  15/400 Loss: 3.58263 Edge-Overlap: 0.508 Total-Time: 0
Step:  20/400 Loss: 3.38414 Edge-Overlap: 0.567 Total-Time: 0
Step:  25/400 Loss: 3.25506 Edge-Overlap: 0.599 Total-Time: 0
Step:  30/400 Loss: 3.15880 Edge-Overlap: 0.630 Total-Time: 0
Step:  35/400 Loss: 3.08748 Edge-Overlap: 0.644 Total-Time: 0
Step:  40/400 Loss: 3.03414 Edge-Overlap: 0.667 Total-Time: 0
Step:  45/400 Loss: 2.98964 Edge-Overlap: 0.688 Total-Time: 0
Step:  50/400 Loss: 2.95379 Edge-Overlap: 0.705 Total-Time: 0
Step:  55/400 Loss: 2.92406 Edge-Overlap: 0.729 Total-Time: 0
Step:  60/400 Loss: 2.89797 Edge-Overlap: 0.732 Total-Time: 0
Step:  65/400 Loss: 2.87539 Edge-Overlap: 0.728 Total-Time: 0
Step:  70/400 Loss: 2.85603 Edge-Overlap: 0.741 Total-Time: 0
Step:  75/400 Loss: 2.83946 Edge-Overlap: 0.732 Total-Time: 0
Step:  80/400 Loss: 2.82511 Edge-Overlap: 0.768 Total-Time: 0
Step:  8

In [8]:
# generate WLIG
generated_graph = cell_model.sample_graph()
graph_prime = generated_graph.A
graph_prime = graph_post_process(graph_prime)

# Edge coordinates of the sampled graph, computed once and reused below so the
# predicted weights are written back in exactly the order they were scored.
edges_prime = graph_prime.nonzero()
# Stack into a single ndarray first (as is done for edge_index): building a
# tensor directly from a tuple of arrays is slow and triggers a warning.
edge_index_prime = torch.tensor(np.array(edges_prime), dtype=torch.long)
x = embeddings
data_prime = Data(x=x, edge_index=edge_index_prime)

# Inference only: switch to eval mode and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    out = model(data_prime)
    src, dst = edge_index_prime
    score = (out[src] * out[dst]).sum(dim=-1)

# Clamp predicted weights to at least 1 and round to integers. np.maximum
# returns a new array, so `score` is no longer modified through shared memory.
weight = np.rint(np.maximum(score.numpy(), 1)).astype(int)
weighted_graph_prime = np.copy(graph_prime)
weighted_graph_prime[edges_prime] = weight

# decode formulas from WLIG
clique_candidates = get_clique_candidates(graph_prime, max_len)
# NOTE(review): the third argument is num_clauses scaled by 1/1.5; the meaning
# of this factor is not visible here -- confirm against lazy_clique_edge_cover.
current_cliques = lazy_clique_edge_cover(np.copy(weighted_graph_prime), clique_candidates, int(num_clauses/1.5))
current_sat = cliques_to_sat(current_cliques)

# Optional solver check (needs the pysat import that is commented out in the
# first cell):
# g = Glucose3(bootstrap_with=current_sat)
# g = Glucose3(bootstrap_with=sat_instance)
# %time print(g.solve())

In [9]:
# evaluate the graph metrics for the generation instance
features = ["clu. VIG", "clu. LIG", "mod. VIG", "mod. LIG", "mod. VCG", "mod. LCG"]

metrics = eval_solution(current_sat, num_vars)
# Pair each label with the metric at the same position
# (assumes eval_solution returns its values in this order -- confirm).
labelled_metrics = zip(features, metrics)
for metric_label, metric_value in labelled_metrics:
    print(f'{metric_label}: {metric_value}')

clu. VIG: 0.5516900582687099
clu. LIG: 0.4281033441155783
mod. VIG: 0.5400413223140497
mod. LIG: 0.5879060545758404
mod. VCG: 0.6727841963231573
mod. LCG: 0.49821268090759574


In [10]:
# evaluate the graph metrics for the original instance
features = ["clu. VIG", "clu. LIG", "mod. VIG", "mod. LIG", "mod. VCG", "mod. LCG"]

metrics = eval_solution(sat_instance, num_vars)
# Pair each label with the metric at the same position
# (assumes eval_solution returns its values in this order -- confirm).
labelled_metrics = zip(features, metrics)
for metric_label, metric_value in labelled_metrics:
    print(f'{metric_label}: {metric_value}')

clu. VIG: 0.5821450231177787
clu. LIG: 0.35092989061568536
mod. VIG: 0.5199306960289521
mod. LIG: 0.5607388965140568
mod. VCG: 0.643919413674837
mod. LCG: 0.5846741230677764
