In [12]:
import itertools as it
import time
from utils import *
from pysat.solvers import Glucose3

import numpy as np
from scipy import sparse

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
# from torch_geometric.nn import GATConv
from torch_geometric.data import Data

In [13]:
# CELL
import warnings
warnings.filterwarnings('ignore')

import pickle
import numpy as np
import scipy.sparse as sp
from scipy.sparse import load_npz

import torch

from cell.utils import link_prediction_performance
from cell.cell import Cell, EdgeOverlapCriterion, LinkPredictionCriterion
from cell.graph_statistics import compute_graph_statistics

In [14]:
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network returning one embedding per node.

    Args:
        node_features: Size of each input node feature vector, i.e. the last
            dimension of ``data.x``.
        hidden_channels: Output width of the first convolution. Defaults to 128.
        out_channels: Width of the embeddings returned by ``forward``.
            Defaults to 128.
    """

    def __init__(self, node_features, hidden_channels=128, out_channels=128):
        super().__init__()
        # Layer widths are keyword parameters so experiments can vary them;
        # the defaults reproduce the fixed 128/128 configuration.
        # Earlier experiments (not in the code path) tried GATConv layers,
        # ELU/tanh activations, dropout, and a third convolution.
        self.conv1 = GCNConv(node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        """Compute node embeddings for a graph.

        Args:
            data: Object exposing ``x`` (node features) and ``edge_index``
                (edge list); both are passed straight to the convolutions.

        Returns:
            The output of the second convolution. No activation is applied
            after it, so the embeddings are unbounded.
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)

        return x

In [15]:
# --- Load the CNF formula -------------------------------------------------
sat_name = 'ssa2670-141.processed.cnf'
sat_path = f'./dataset/formulas/{sat_name}'
num_vars, num_clauses, sat_instance = read_sat(sat_path)
# Length of the longest clause; used later as the bound passed to
# get_clique_candidates.
max_len = max(map(len, sat_instance))

# --- Build the graph view of the formula ----------------------------------
lig_adjacency_matrix, lig_weighted_adjacency_matrix = sat_to_lig_adjacency_matrix(sat_instance, num_vars)
# Positions of the edges, computed once and shared by the edge list and the
# edge weights so both are guaranteed to be in the same order.
nonzero_positions = lig_adjacency_matrix.nonzero()
edge_index = torch.tensor(np.array(nonzero_positions), dtype=torch.long)
edge_value = lig_weighted_adjacency_matrix[nonzero_positions]

# --- Node features ---------------------------------------------------------
# Pre-computed embeddings are used as fixed inputs: gradients are disabled so
# only the GCN parameters are trained.
embeddings = torch.load(f'./model/embeddings/{sat_name}.pt')
embeddings.requires_grad = False
x = embeddings
data = Data(edge_index=edge_index, x=x)

In [16]:
# For every CELL rank: train a GCN that regresses edge weights, sample a new
# graph with CELL, predict integer weights for the sampled edges, convert the
# weighted graph back into a SAT formula and record its evaluation metrics.
ranks = [8, 10, 12]
graphs = []          # one weighted adjacency matrix per rank
metrics_table = []   # one row per rank: [label, num_vars, num_clauses, *metrics]

# The regression target never changes, so build it once instead of on every
# epoch. reshape(-1) keeps it 1-D like `score`, whatever shape the fancy
# indexing that produced `edge_value` returned.
edge_target = torch.tensor(edge_value, dtype=torch.float).reshape(-1)

for rank in ranks:
    print(f'rank: {rank}')

    # --- 1. Edge-weight model ----------------------------------------------
    # NOTE(review): this training does not depend on `rank`; a freshly
    # initialised GCN is nevertheless trained for every rank.
    # Input width is read from the embeddings rather than hard-coded.
    model = GCN(embeddings.shape[-1])
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    model.train()
    src, dst = edge_index
    for epoch in range(500):
        optimizer.zero_grad()
        out = model(data)
        # Edge score = dot product of the two endpoint embeddings.
        score = (out[src] * out[dst]).sum(dim=-1)
        loss = F.mse_loss(score, edge_target)
        loss.backward()
        optimizer.step()

    # --- 2. Graph generation with CELL -------------------------------------
    sparse_matrix = sparse.csr_matrix(lig_adjacency_matrix)
    cell_model = Cell(A=sparse_matrix,
                H=rank,
                callbacks=[EdgeOverlapCriterion(invoke_every=5, edge_overlap_limit=.80)])
    cell_model.train(steps=400,
                optimizer_fn=torch.optim.Adam,
                optimizer_args={'lr': 0.1,
                                'weight_decay': 1e-7})

    generated_graph = cell_model.sample_graph()
    graph_prime = generated_graph.A
    graph_prime = graph_post_process(graph_prime)

    # --- 3. Predict weights for the sampled edges --------------------------
    # The same index tuple is used to build the edge list and to write the
    # weights back, so score k always belongs to edge k.
    nonzero_prime = graph_prime.nonzero()
    # Going through np.array avoids the slow tensor-from-list-of-ndarrays path
    # and matches how `edge_index` is built for the original graph.
    edge_index_prime = torch.tensor(np.array(nonzero_prime), dtype=torch.long)
    data_prime = Data(x=embeddings, edge_index=edge_index_prime)

    # Inference only: switch to eval mode and skip autograd bookkeeping.
    model.eval()
    with torch.no_grad():
        out = model(data_prime)
        src_prime, dst_prime = edge_index_prime
        score = (out[src_prime] * out[dst_prime]).sum(dim=-1)

    weight = score.detach().numpy()
    # Every sampled edge gets a weight of at least 1, rounded to an integer.
    weight[weight <= 1] = 1
    weight = np.rint(weight).astype(int)

    weighted_graph_prime = np.copy(graph_prime)
    weighted_graph_prime[nonzero_prime] = weight

    graphs.append(weighted_graph_prime)

    # --- 4. Back to SAT and evaluation -------------------------------------
    clique_candidates = get_clique_candidates(graph_prime, max_len)
    # NOTE(review): the third argument, int(num_clauses/1.5), is a magic
    # number; its exact meaning depends on lazy_clique_edge_cover, so confirm.
    current_cliques = lazy_clique_edge_cover(np.copy(weighted_graph_prime), clique_candidates, int(num_clauses/1.5))
    current_sat = cliques_to_sat(current_cliques)

    metrics = eval_solution(current_sat, num_vars)
    item = [f"rank = {rank}", num_vars, num_clauses]
    item.extend(metrics)
    metrics_table.append(item)




rank: 8
Step:   5/400 Loss: 4.87521 Edge-Overlap: 0.113 Total-Time: 0
Step:  10/400 Loss: 4.05085 Edge-Overlap: 0.337 Total-Time: 0
Step:  15/400 Loss: 3.66544 Edge-Overlap: 0.480 Total-Time: 0
Step:  20/400 Loss: 3.47320 Edge-Overlap: 0.524 Total-Time: 0
Step:  25/400 Loss: 3.34359 Edge-Overlap: 0.572 Total-Time: 0
Step:  30/400 Loss: 3.24573 Edge-Overlap: 0.575 Total-Time: 0
Step:  35/400 Loss: 3.17591 Edge-Overlap: 0.609 Total-Time: 0
Step:  40/400 Loss: 3.12405 Edge-Overlap: 0.626 Total-Time: 0
Step:  45/400 Loss: 3.08345 Edge-Overlap: 0.647 Total-Time: 0
Step:  50/400 Loss: 3.05053 Edge-Overlap: 0.642 Total-Time: 0
Step:  55/400 Loss: 3.02281 Edge-Overlap: 0.656 Total-Time: 0
Step:  60/400 Loss: 2.99902 Edge-Overlap: 0.669 Total-Time: 0
Step:  65/400 Loss: 2.97830 Edge-Overlap: 0.691 Total-Time: 0
Step:  70/400 Loss: 2.95982 Edge-Overlap: 0.699 Total-Time: 0
Step:  75/400 Loss: 2.94291 Edge-Overlap: 0.707 Total-Time: 0
Step:  80/400 Loss: 2.92792 Edge-Overlap: 0.722 Total-Time: 0


In [18]:
print(graphs)
print(metrics_table)

import csv

# Export every generated graph as an undirected edge list in a single CSV.
# Each edge is written once (upper triangle only) with a running id, and the
# graph's name is stored in both the 'Kind' and 'Label' columns.
graph_names = ['rank-8', 'rank-10', 'rank-12']

# zip() would silently drop graphs if the two lists went out of sync.
if len(graph_names) != len(graphs):
    raise ValueError(
        f'expected {len(graph_names)} graphs (one per name), got {len(graphs)}'
    )

fields = ['Source', 'Target', 'Type', 'Kind', 'Id', 'Label', 'Weight']
# newline='' is required by the csv module: it lets the writer control line
# endings and prevents blank rows between records on Windows.
with open('./analysis/ssa2670-141-rank-8-10-12.csv', 'w', newline='') as csvfile:
    csvwriter = csv.writer(csvfile, delimiter=',')
    csvwriter.writerow(fields)
    idx = 1
    for graph_name, graph in zip(graph_names, graphs):
        triu_adjacency_matrix = np.triu(graph)
        # Named rows/cols (not x/y) so the notebook-level `x` holding the
        # node embeddings is not overwritten by this cell.
        rows, cols = triu_adjacency_matrix.nonzero()
        for i, j in zip(rows, cols):
            csvwriter.writerow([i, j, 'Undirected', graph_name, idx, graph_name, triu_adjacency_matrix[i, j]])
            idx += 1

[array([[0., 0., 4., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [4., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]]), array([[0., 0., 4., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [4., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]]), array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]])]
[['rank = 8', 91, 377, 0.5480501602648356, 0.4607742946230321, 0.5192531112161312, 0.5573882153753698, 0.6514862989941034, 0.5000191184428087], ['rank = 10', 91, 377, 0.5519516442127347, 0.4237207780067638, 0.516029007338207, 0.5571179594654263, 0.6413166614915873, 0.50059

In [19]:
# Emit one LaTeX table row per experiment: the label first, then every value
# formatted with three decimals, cells joined by ' & ' and closed with '\\'.
for row in metrics_table:
    label, values = row[0], row[1:]
    formatted_values = [format(value, '.3f') for value in values]
    print(f'{label} & ' + ' & '.join(formatted_values) + '\\\\')

rank = 8 & 91.000 & 377.000 & 0.548 & 0.461 & 0.519 & 0.557 & 0.651 & 0.500\\
rank = 10 & 91.000 & 377.000 & 0.552 & 0.424 & 0.516 & 0.557 & 0.641 & 0.501\\
rank = 12 & 91.000 & 377.000 & 0.539 & 0.421 & 0.511 & 0.554 & 0.618 & 0.496\\
