In [1]:
import itertools as it
import time
import os
from utils import *
from pysat.solvers import Glucose3

import numpy as np
from scipy import sparse

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

# from torch_geometric.nn import GATConv
from torch_geometric.data import Data

# CELL
import warnings

warnings.filterwarnings("ignore")

import pickle
import scipy.sparse as sp
import torch

from cell.utils import link_prediction_performance
from cell.cell import Cell, EdgeOverlapCriterion, LinkPredictionCriterion
from cell.graph_statistics import compute_graph_statistics


class GCN(torch.nn.Module):
    def __init__(self, node_features):
        super().__init__()
        # GCN initialization
        self.conv1 = GCNConv(node_features, 128)
        self.conv2 = GCNConv(128, 128)
        # self.conv1 = GATConv(node_features, 64, 5)
        # self.conv2 = GATConv(64 * 5, 128)
        # self.conv3 = GCNConv(128, 128)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        # x = F.elu(x)
        # x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        # x = F.tanh(x)
        # x = self.conv3(x, edge_index)

        return x




In [2]:
metrics_table = []
owc_table = []


formulas_path = "./dataset/formulas/"
sat_names = os.listdir(formulas_path)

for sat_name in sat_names:
    print(sat_name)
    sat_path = f"./dataset/formulas/{sat_name}"
    num_vars, num_clauses, sat_instance = read_sat(sat_path)
    max_len = max([len(clause) for clause in sat_instance])

    # metric for original instance
    metrics = eval_solution(sat_instance, num_vars)
    item = [sat_name.split(".")[0], num_vars, num_clauses]
    item.extend(metrics)
    metrics_table.append(item)

    lig_adjacency_matrix, lig_weighted_adjacency_matrix = sat_to_lig_adjacency_matrix(
        sat_instance, num_vars
    )

    start_time = time.time()

    clique_candidates = get_clique_candidates(lig_adjacency_matrix, max_len)
    current_cliques = lazy_clique_edge_cover(
        np.copy(lig_weighted_adjacency_matrix),
        clique_candidates,
        int(num_clauses / 1.5),
    )

    # metric of owc incstance
    current_sat = cliques_to_sat(current_cliques)
    metrics = eval_solution(current_sat, num_vars)
    item = ["OWC for origin", num_vars, num_clauses]
    item.extend(metrics)
    metrics_table.append(item)

    owc_time = time.time() - start_time
    owc_table.append([sat_name.split(".")[0], num_vars, num_clauses, owc_time])

    edge_index = torch.tensor(
        np.array(lig_adjacency_matrix.nonzero()), dtype=torch.long
    )
    print(edge_index.shape)
    edge_value = lig_weighted_adjacency_matrix[lig_adjacency_matrix.nonzero()]

    embeddings = torch.load(f"./model/embeddings/{sat_name}.pt")
    embeddings.requires_grad = False
    # print(embeddings)
    x = embeddings
    data = Data(x=x, edge_index=edge_index)

    # training
    model = GCN(50)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    model.train()
    for epoch in range(500):
        optimizer.zero_grad()
        out = model(data)
        src, dst = edge_index
        score = (out[src] * out[dst]).sum(dim=-1)
        # score = torch.sigmoid(score)
        loss = F.mse_loss(score, torch.tensor(edge_value, dtype=torch.float))
        loss.backward()
        optimizer.step()
        # print(f'epoch: {epoch}, loss: {loss.item()}')

    out = model(data)
    src, dst = edge_index
    score = (out[src] * out[dst]).sum(dim=-1)

    sparse_matrix = sparse.csr_matrix(lig_adjacency_matrix)
    cell_model = Cell(
        A=sparse_matrix,
        H=12,
        callbacks=[EdgeOverlapCriterion(invoke_every=10, edge_overlap_limit=0.80)],
    )

    cell_model.train(
        steps=400,
        optimizer_fn=torch.optim.Adam,
        optimizer_args={"lr": 0.1, "weight_decay": 1e-7},
    )

    generated_graph = cell_model.sample_graph()
    graph_prime = generated_graph.A
    graph_prime = graph_post_process(graph_prime)
    edge_index_prime = torch.tensor(graph_prime.nonzero(), dtype=torch.long)
    x = embeddings
    data_prime = Data(x=x, edge_index=edge_index_prime)
    out = model(data_prime)
    src, dst = edge_index_prime
    score = (out[src] * out[dst]).sum(dim=-1)
    weight = score.detach().numpy()
    weight[weight <= 1] = 1
    weight = np.rint(weight).astype(int)

    weighted_graph_prime = np.copy(graph_prime)
    weighted_graph_prime[weighted_graph_prime.nonzero()] = weight

    clique_candidates = get_clique_candidates(graph_prime, max_len)
    current_cliques = lazy_clique_edge_cover(
        np.copy(weighted_graph_prime), clique_candidates, int(num_clauses / 1.5)
    )
    current_sat = cliques_to_sat(current_cliques)
    metrics = eval_solution(current_sat, num_vars)
    item = ["OWC for gen", num_vars, num_clauses]
    item.extend(metrics)
    metrics_table.append(item)


sat_prob_83.processed.cnf
torch.Size([2, 37774])
Step:  10/400 Loss: 5.64842 Edge-Overlap: 0.126 Total-Time: 2
Step:  20/400 Loss: 3.88866 Edge-Overlap: 0.489 Total-Time: 4
Step:  30/400 Loss: 3.20061 Edge-Overlap: 0.599 Total-Time: 7
Step:  40/400 Loss: 2.90125 Edge-Overlap: 0.688 Total-Time: 9
Step:  50/400 Loss: 2.75813 Edge-Overlap: 0.757 Total-Time: 11
Step:  60/400 Loss: 2.67814 Edge-Overlap: 0.805 Total-Time: 14
mrpp_4x4#4_5.processed.cnf
torch.Size([2, 14472])
Step:  10/400 Loss: 4.66522 Edge-Overlap: 0.402 Total-Time: 0
Step:  20/400 Loss: 4.02191 Edge-Overlap: 0.613 Total-Time: 0
Step:  30/400 Loss: 3.78049 Edge-Overlap: 0.665 Total-Time: 0
Step:  40/400 Loss: 3.65811 Edge-Overlap: 0.716 Total-Time: 1
Step:  50/400 Loss: 3.59050 Edge-Overlap: 0.749 Total-Time: 1
Step:  60/400 Loss: 3.54844 Edge-Overlap: 0.768 Total-Time: 1
Step:  70/400 Loss: 3.52029 Edge-Overlap: 0.785 Total-Time: 1
Step:  80/400 Loss: 3.50037 Edge-Overlap: 0.796 Total-Time: 2
Step:  90/400 Loss: 3.48556 Edg

In [None]:
metrics_table
for line in metrics_table:
    # print(line)
    print(f'{line[0]} & ' + ' & '.join([f'{x:.3f}' for x in line[1:]]) + '\\\\')

sat_prob_83 & 1759.000 & 8012.000 & 0.441 & 0.388 & 0.760 & 0.796 & 0.871 & 0.757\\
OWC for origin & 1759.000 & 8012.000 & 0.505 & 0.440 & 0.773 & 0.812 & 0.893 & 0.717\\
OWC for gen & 1759.000 & 8012.000 & 0.497 & 0.423 & 0.739 & 0.810 & 0.887 & 0.671\\
mrpp_4x4#4_5 & 309.000 & 2517.000 & 0.428 & 0.357 & 0.469 & 0.538 & 0.784 & 0.719\\
OWC for origin & 309.000 & 2517.000 & 0.427 & 0.359 & 0.472 & 0.553 & 0.754 & 0.603\\
OWC for gen & 309.000 & 2517.000 & 0.429 & 0.356 & 0.458 & 0.543 & 0.718 & 0.558\\
aes_64_1_keyfind_1 & 320.000 & 2088.000 & 0.457 & 0.393 & 0.655 & 0.661 & 0.754 & 0.672\\
OWC for origin & 320.000 & 2088.000 & 0.789 & 0.654 & 0.789 & 0.782 & 0.820 & 0.666\\
OWC for gen & 320.000 & 2088.000 & 0.551 & 0.437 & 0.710 & 0.714 & 0.765 & 0.587\\
bmc-ibm-2 & 119.000 & 573.000 & 0.627 & 0.357 & 0.617 & 0.625 & 0.661 & 0.651\\
OWC for origin & 119.000 & 573.000 & 0.646 & 0.455 & 0.613 & 0.631 & 0.644 & 0.537\\
OWC for gen & 119.000 & 573.000 & 0.597 & 0.496 & 0.599 & 0.631 & 0.