In [None]:
## Status: INCOMPLETE (work in progress)

In [None]:
# List the GPUs visible to this runtime.
!nvidia-smi -L

GPU 0: Tesla T4 (UUID: GPU-aed0fc74-9b8d-2199-78f0-2024639666fe)


In [None]:
# | export
import jax
import jax.numpy as jnp
import numpy as np
import networkx as nx  # for making graphs
import optax  # for optimizing GNN with Adam

from flax import linen as nn  # for defining the GNN
from flax.training import train_state  # utility for training
# from pyqubo import Array  # for defining the QUBO
from tqdm.notebook import trange, tqdm  # visualizing notebook progress

In [None]:
# | export
def qubo_approx_cost(probs, Q):
    """Relaxed QUBO energy: the scalar ``sum(probs^T @ Q @ probs)``.

    Args:
        probs: relaxed variable assignment. In this notebook it is the
            (n_nodes, 1) sigmoid output of ``CombGNN``; a 1-D vector also works.
        Q: (n_nodes, n_nodes) QUBO coefficient matrix.

    Returns:
        A JAX scalar holding the quadratic form, differentiable w.r.t. ``probs``.
    """
    p = jnp.asarray(probs)
    quadratic_form = p.T @ Q @ p
    return jnp.sum(quadratic_form)

In [None]:
# | export
def compute_metrics(*, probs, q_matrix):
    """Summarise a forward pass as a metrics dictionary.

    Args:
        probs: relaxed assignment produced by the GNN.
        q_matrix: QUBO coefficient matrix.

    Returns:
        ``{"energy": ...}`` where the value is ``qubo_approx_cost(probs, q_matrix)``.
    """
    return {"energy": qubo_approx_cost(probs=probs, Q=q_matrix)}

In [None]:
# | export
class GraphConvLayer(nn.Module):
    """Single graph convolution: a self term plus a projected neighbour aggregate.

    out = node_feats @ W_self + (adj_matrix @ node_feats) @ W_nbr + b_nbr
    """

    c_out: int  # Output feature size

    @nn.compact
    def __call__(self, node_feats, adj_matrix):
        # Keep creation order: Flax auto-names submodules in call order, so the
        # bias-free self projection must stay the first Dense (Dense_0).
        self_term = nn.Dense(features=self.c_out, use_bias=False)(node_feats)
        # Aggregate node features weighted by adj_matrix, then project (Dense_1).
        aggregated = jax.lax.batch_matmul(adj_matrix, node_feats)
        neighbour_term = nn.Dense(features=self.c_out)(aggregated)
        return self_term + neighbour_term

In [None]:
# | export
class CombGNN(nn.Module):
    """Two-layer graph network that outputs a per-node value in (0, 1).

    Attributes:
        hidden_size: width of the hidden node representation.
        num_classes: output channels per node (``create_train_state`` uses 1).
        dropout_frac: dropout rate after the hidden layer; active only when
            ``train=True``.
    """

    hidden_size: int
    num_classes: int
    dropout_frac: float

    @nn.compact
    def __call__(self, node_feats, adj_matrix, train=False):
        # Hidden layer: graph convolution -> ReLU -> dropout.
        hidden = nn.relu(
            GraphConvLayer(c_out=self.hidden_size)(
                node_feats=node_feats, adj_matrix=adj_matrix
            )
        )
        hidden = nn.Dropout(rate=self.dropout_frac, deterministic=not train)(hidden)
        # Output layer: graph convolution to num_classes logits, squashed by a sigmoid.
        logits = GraphConvLayer(c_out=self.num_classes)(
            node_feats=hidden, adj_matrix=adj_matrix
        )
        return nn.sigmoid(logits)

In [None]:
# | export
def create_train_state(
    n_vertices, embedding_size, hidden_size, rng, learning_rate, dropout_frac=0.0
):
    """Build a ``CombGNN``, initialise it and wrap it in an Adam ``TrainState``.

    Args:
        n_vertices: number of graph nodes used for the dummy init inputs.
        embedding_size: width of each node's input embedding.
        hidden_size: width of the hidden GNN layer.
        rng: PRNG key used for parameter initialisation.
        learning_rate: Adam step size.
        dropout_frac: dropout rate inside the network.

    Returns:
        A flax ``TrainState`` holding ``model.apply``, the parameters and the optimiser.
    """
    model = CombGNN(hidden_size=hidden_size, num_classes=1, dropout_frac=dropout_frac)
    # init runs with train=True, so the Dropout layer needs a "dropout" key too.
    init_rngs = {"params": rng, "dropout": jax.random.PRNGKey(0)}
    dummy_feats = jnp.ones([n_vertices, embedding_size])
    dummy_adj = jnp.ones([n_vertices, n_vertices])
    variables = model.init(
        rngs=init_rngs, node_feats=dummy_feats, adj_matrix=dummy_adj, train=True
    )
    optimizer = optax.adam(learning_rate)
    return train_state.TrainState.create(
        apply_fn=model.apply, params=variables["params"], tx=optimizer
    )

In [None]:
# | export
@jax.jit
def train_step(state, node_embeddings, adj_matrix, q_matrix, dropout_rng):
    """Take one optimiser step on the relaxed QUBO energy.

    Args:
        state: flax ``TrainState`` holding the GNN parameters and optimiser.
        node_embeddings: input node features.
        adj_matrix: adjacency matrix fed to the graph convolutions.
        q_matrix: QUBO coefficient matrix defining the loss.
        dropout_rng: PRNG key for the dropout layer (forward pass uses train=True).

    Returns:
        ``(new_state, metrics)``; ``metrics["energy"]`` is the relaxed energy of the
        probabilities from this step's forward pass (i.e. before the update).
    """

    def loss_and_probs(params):
        probs = state.apply_fn(
            {"params": params},
            node_embeddings,
            adj_matrix,
            rngs={"dropout": dropout_rng},
            train=True,
        )
        return qubo_approx_cost(probs=probs, Q=q_matrix), probs

    grads, probs = jax.grad(loss_and_probs, has_aux=True)(state.params)
    new_state = state.apply_gradients(grads=grads)
    return new_state, compute_metrics(probs=probs, q_matrix=q_matrix)

In [None]:
# | export
def get_classification(apply_fn, params, node_embeddings, adj_matrix, threshold=0.5):
    """Turn the model's per-node probabilities into hard 0/1 labels.

    ``apply_fn`` is called without a ``train`` argument, so for ``CombGNN`` the
    default ``train=False`` applies and dropout is disabled.

    Args:
        apply_fn: model apply function, e.g. ``state.apply_fn``.
        params: model parameters, wrapped as ``{"params": params}``.
        node_embeddings: input node features.
        adj_matrix: adjacency matrix for the graph convolutions.
        threshold: probabilities ``>= threshold`` map to 1, others to 0
            (default 0.5, the previous hard-coded value).

    Returns:
        A flat NumPy integer array with one 0/1 label per output element.
    """
    pred_probs = apply_fn({"params": params}, node_embeddings, adj_matrix)
    labels = jnp.where(pred_probs >= threshold, 1, 0)
    return np.ravel(labels)

In [None]:
import os

import scipy.io

# Directory holding the .mat inputs. Defaults to the Colab upload location and
# can be overridden with the DATA_DIR environment variable.
DATA_DIR = os.environ.get("DATA_DIR", "/content")
mat = scipy.io.loadmat(os.path.join(DATA_DIR, "fault_resolution_set_TE_50fault.mat"))

In [None]:
np.shape(mat['B'])  # (rows, columns) of the incidence matrix B

(561, 50)

In [None]:
from scipy import sparse
# CSR copy of B; below, its rows become element nodes and its columns subset nodes.
matrix = sparse.csr_matrix(np.asarray(mat['B']))

In [None]:
from networkx.algorithms import bipartite

In [None]:
# Bipartite graph: networkx numbers the rows of `matrix` as nodes 0..n_rows-1 and
# its columns as nodes n_rows..n_rows+n_cols-1.
gg = bipartite.from_biadjacency_matrix(matrix)
n_rows, n_cols = matrix.shape
node_1 = set(range(n_rows))                    # row (element) nodes
node_2 = list(range(n_rows, n_rows + n_cols))  # column (subset) nodes

In [None]:
# For each column node, the list of row nodes it is connected to in gg.
sub_set = [list(gg.neighbors(col_node)) for col_node in node_2]

In [None]:
# Distinct subsets as hashable tuples (identical columns collapse to one entry).
my_set = {tuple(x) for x in sub_set}
# Print a summary rather than dumping every subset; the full set is very large.
print(f"{len(my_set)} distinct subsets among {len(sub_set)} columns")

{(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 24, 25, 26, 27, 28, 29, 53, 54, 55, 62, 63, 64, 84, 85, 86, 93, 94, 95, 114, 115, 116, 123, 124, 125, 143, 144, 145, 152, 153, 154, 171, 172, 173, 180, 181, 182, 198, 199, 200, 207, 208, 209, 217, 224, 225, 226, 233, 234, 235, 249, 250, 251, 258, 259, 260, 273, 274, 275, 282, 283, 284, 296, 297, 298, 305, 306, 307, 318, 319, 320, 327, 328, 329, 339, 340, 341, 348, 349, 350, 359, 360, 361, 368, 369, 370, 378, 379, 380, 387, 388, 389, 396, 397, 398, 405, 406, 407, 413, 414, 415, 422, 423, 424, 429, 430, 431, 438, 439, 440, 444, 445, 446, 453, 454, 455, 458, 459, 460, 467, 468, 469, 471, 472, 473, 480, 481, 482, 483, 484, 485, 492, 493, 494, 497, 498, 499, 500, 501, 502, 507, 508, 509, 510, 511, 512, 516, 517, 518, 519, 520, 521, 530, 531, 532, 537, 538, 539, 543, 544, 545, 548, 549, 550, 552, 553, 554, 555, 556, 557), (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 24, 25, 26, 27, 28, 2

In [None]:
!pip install qubovert

In [None]:
from qubovert.problems import SetCover
import qubovert

In [None]:
# qubovert set-cover instance built from the element nodes (node_1) and the
# per-column subsets of those nodes (sub_set).
problem = SetCover(node_1, sub_set)

In [None]:
# Convert the set-cover problem to a QUBO. The resulting matrix has 3977
# variables for 50 subsets, so the conversion introduces additional variables.
Q_dic = problem.to_qubo()

In [None]:
from qubovert.utils import qubo_to_matrix
# Dense QUBO coefficient matrix.
# NOTE(review): called without symmetric=True. Check whether the result is
# triangular before using its sparsity pattern as a graph adjacency.
Q_n = qubo_to_matrix(Q_dic.Q)

In [None]:
# The QUBO matrix must be square: one row/column per binary variable.
assert Q_n.ndim == 2 and Q_n.shape[0] == Q_n.shape[1], Q_n.shape
Q_n.shape

(3977, 3977)

In [None]:
# Largest number of subsets that any single element of node_1 appears in.
M = max(sum(1 for subset in sub_set if alpha in subset) for alpha in node_1)


In [None]:
M  # max number of subsets containing any one element

50

In [None]:
# A (the GNN adjacency) is built in the following cell from the sparsity
# pattern of Q_n. Pre-allocating a 3977x3977 float64 zeros array here was dead
# work, because that cell rebinds A.

In [None]:
# GNN adjacency = sparsity pattern of the QUBO: nodes i and j are linked when
# Q_n couples them in either orientation. Symmetrise because qubo_to_matrix was
# called without symmetric=True, so Q_n may store each coupling in only one
# triangle, which would let messages flow in one direction only.
B = (Q_n != 0) | (Q_n.T != 0)
A = B.astype(np.float32)

In [None]:
# Dense adjacency of the bipartite graph gg, as a JAX array.
A_graph = jnp.asarray(nx.to_numpy_array(gg))

In [None]:
# Overwrite the top-left block of A with the bipartite graph's adjacency. A
# single slice assignment replaces ~n^2 individual JAX element reads, and the
# block size now follows A_graph instead of a hard-coded 611.
# NOTE(review): this assumes the first n_graph QUBO variables correspond to the
# bipartite graph's nodes; confirm against qubovert's variable labelling.
n_graph = A_graph.shape[0]
A[:n_graph, :n_graph] = np.asarray(A_graph)

In [None]:
learning_rate = 1e-3  # Adam step size

In [None]:
# Input-embedding width: cube root of the number of QUBO variables (GNN nodes),
# derived from Q_n instead of a hard-coded 3977. The hidden width is half of that.
embedding_d0 = int(np.cbrt(Q_n.shape[0]))
embedding_d1 = embedding_d0 // 2

In [None]:
# Separate keys for parameter initialisation and for the random node embeddings.
rng = jax.random.PRNGKey(2023)
rng, init_rng = jax.random.split(rng)
rng, embed_rng = jax.random.split(rng)
n_nodes = Q_n.shape[0]  # one GNN node per QUBO variable (was hard-coded 3977)
state = create_train_state(
    n_nodes,
    embedding_d0,
    embedding_d1,
    init_rng,
    learning_rate,
    dropout_frac=0.01,
)

In [None]:
# Fixed random input features: one embedding_d0-wide vector per QUBO variable.
node_embeddings = jax.random.uniform(embed_rng, [Q_n.shape[0], embedding_d0])

In [None]:
# NOTE(review): not referenced by the training loop shown in this notebook.
num_epochs = 1_000

In [None]:
# First optimisation step (also triggers JIT compilation of train_step). The
# loop below compares each later step's energy against the previous one.
rng, dropout_rng = jax.random.split(rng)
state, metrics = train_step(
    state=state,
    node_embeddings=node_embeddings,
    adj_matrix=A,
    q_matrix=Q_n,
    dropout_rng=dropout_rng,
)
loss_, prev_loss = metrics["energy"], 0

In [None]:
# Early stopping: keep stepping until the energy has failed to decrease by more
# than 0.001 for 100 consecutive steps.
# NOTE(review): dropout_rng is reused on every step (the per-step re-split was
# commented out), so every step applies the same dropout mask.
# NOTE(review): there is no upper bound on the number of steps.
i, count = 1, 0
while count < 100:
    prev_loss = metrics["energy"]
    state, metrics = train_step(
        state=state,
        node_embeddings=node_embeddings,
        adj_matrix=A,
        q_matrix=Q_n,
        dropout_rng=dropout_rng,
    )
    loss_ = metrics["energy"]
    delta = loss_ - prev_loss
    stalled = (abs(delta) <= 0.001) or (delta > 0)
    count = count + 1 if stalled else 0
    i += 1
    if not i % 1000:
        print("train epoch: %d, cost: %.2f" % (i, metrics["energy"]))


In [None]:
# | export
# `get_classification` is already defined (and exported) earlier in this
# notebook. A byte-identical redefinition here silently shadowed that one and
# exported the function twice, so the earlier definition is reused instead.

In [None]:
# Hard 0/1 label per node from the trained parameters.
classification = get_classification(
    apply_fn=state.apply_fn,
    params=state.params,
    node_embeddings=node_embeddings,
    adj_matrix=A,
)

In [None]:
# Indices of the nodes labelled 1. np.flatnonzero is vectorised and always
# returns integer indices, including an empty integer array when no node is
# selected.
set_1 = np.flatnonzero(np.asarray(classification) == 1)

In [None]:
set_1  # indices of nodes whose predicted probability is >= 0.5