## QQA: Maximum Independent Set and Graph Coloring Examples

In [3]:
import torch
import networkx as nx
from main import utils
from main import qqa
from main import instance

# Use the first CUDA GPU when one is visible to torch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [4]:
torch_type = torch.float32  # dtype for dense tensors built later (e.g. the adjacency matrix)
SEED = 0                    # shared by fix_seed and the random graph generators below
utils.fix_seed(SEED)

## Maximum Independent Set

### Step 1: Generate Problem and Set Hyperparameters

In [5]:
# Graph parameters: a random d-regular graph on N nodes (p is not used by this notebook).
N, d, p = 10000, 100, None
nx_graph = nx.random_regular_graph(d=d, n=N, seed=SEED)

# Penalty weight handed to the MIS QUBO builder; named instead of an inline magic number.
mis_penalty = 2
Q_mat = utils.qubo_dict_to_torch(
    nx_graph,
    instance.gen_q_dict_mis_sym(nx_graph, penalty=mis_penalty),
).to(device)

# QQA Parameters
parallel_size = 100   # presumably the number of samples optimized in parallel (batch dim of the loss)
num_epoch = 10_000
lr = 1
temp = 1e-3
min_bg = -2           # start value of the PARAM shown in the training log
max_bg = 0.1          # end value of the PARAM shown in the training log
curve_rate = 4
div_param = 0.1
check_interval = 1000  # log every check_interval epochs

### Step 2: Define Loss Function

In [6]:
def loss_func(x, *, q_mat=None):
    """QUBO energy of every sample in a batch (MIS problem).

    Computes ``x_b^T Q x_b`` for each row ``x_b`` of ``x``.

    Args:
        x: 2-D tensor of shape (batch, N); N must match the matrix size.
        q_mat: optional (N, N) QUBO matrix. If omitted, the notebook-level
            ``Q_mat`` built in the MIS setup cell is used, so ``loss_func(x)``
            behaves exactly as before.

    Returns:
        1-D tensor of shape (batch,) with one energy per sample.
    """
    # Keyword-only so optimizers that call loss_func(x) are unaffected.
    q = Q_mat if q_mat is None else q_mat
    return torch.einsum('bi,ij,bj->b', x, q, x)

### Step 3: Run QQA

In [7]:
# Binary (bit-string) optimization of the MIS loss defined above.
best_bit_string, runtime = qqa.optimize(
    N,
    loss_func,
    device=device,
    # sampling / optimizer settings
    parallel_size=parallel_size,
    num_epoch=num_epoch,
    lr=lr,
    temp=temp,
    # schedule settings (min_bg/max_bg bound the PARAM printed in the log)
    min_bg=min_bg,
    max_bg=max_bg,
    curve_rate=curve_rate,
    div_param=div_param,
    # logging
    check_interval=check_interval,
)

EPOCH:0, LOSS:49721208.0, PENALTY:799894.125, PARAM:-2.0
EPOCH:1000, LOSS:169569.015625, PENALTY:183611.0, PARAM:-1.79
EPOCH:2000, LOSS:98238.3125, PENALTY:149255.53125, PARAM:-1.58
EPOCH:3000, LOSS:33151.60546875, PENALTY:110212.265625, PARAM:-1.37
EPOCH:4000, LOSS:-16084.150390625, PENALTY:75748.8125, PARAM:-1.1600000000000001
EPOCH:5000, LOSS:-30861.75390625, PENALTY:62788.78125, PARAM:-0.95
EPOCH:6000, LOSS:-40765.265625, PENALTY:51797.40625, PARAM:-0.74
EPOCH:7000, LOSS:-49729.90234375, PENALTY:38191.6875, PARAM:-0.53
EPOCH:8000, LOSS:-62459.09375, PENALTY:6802.5908203125, PARAM:-0.32000000000000006
EPOCH:9000, LOSS:-64448.0, PENALTY:0.0, PARAM:-0.1100000000000001
EPOCH:9999, LOSS:-64448.0, PENALTY:0.0, PARAM:0.09979000000000005


In [8]:
# Post-process the best bit string against the graph; the middle return value is not needed here.
size_mis, _, number_violation = utils.postprocess_mis(
    best_bit_string,
    nx_graph,
)
print(f"Independent set size: {size_mis} Violation: {number_violation}")

Independent set size: 656 Violation: 0


## Graph Coloring

### Step 1: Generate Problem and Set Hyperparameters

In [9]:
# Graph parameters: a random d-regular graph on N nodes (p is not used by this notebook).
N, d, p = 1000, 10, None
nx_graph = nx.random_regular_graph(d=d, n=N, seed=SEED)

# Dense adjacency matrix with rows/columns ordered by node label 0..N-1.
# NOTE: `adj_maxtrix` is misspelled but kept as-is because the loss cell reads this name.
adj_maxtrix = torch.tensor(
    nx.adjacency_matrix(nx_graph, nodelist=list(range(N))).toarray(),
    device=device,
    dtype=torch_type,
)
num_color = 5

# QQA Parameters (identical to the MIS run except for the smaller learning rate)
parallel_size = 100
num_epoch = 10_000
lr = 0.1
temp = 1e-3
min_bg = -2
max_bg = 0.1
curve_rate = 4
div_param = 0.1
check_interval = 1000

### Step 2: Define Loss Function

In [10]:
def loss_func(x, *, adj=None):
    """Graph-coloring conflict loss for every sample in a batch.

    For each colour ``s`` computes ``x[:, :, s]^T A x[:, :, s] / 2`` and sums over
    colours. The /2 compensates for each undirected edge appearing twice in a
    symmetric adjacency matrix.

    Args:
        x: 3-D tensor of shape (batch, N, num_colors).
        adj: optional (N, N) adjacency matrix. If omitted, the notebook-level
            ``adj_maxtrix`` (sic) from the coloring setup cell is used, so
            ``loss_func(x)`` behaves exactly as before.

    Returns:
        1-D tensor of shape (batch,).
    """
    # Keyword-only so optimizers that call loss_func(x) are unaffected.
    adj_mat = adj_maxtrix if adj is None else adj
    per_color = torch.einsum('bis,ij,bjs->bs', x, adj_mat, x) / 2
    return per_color.sum(dim=1)

### Step 3: Run QQA

In [11]:
# Categorical (one colour per node) optimization of the coloring loss defined above.
best_string, runtime = qqa.optimize_categorical(
    N,
    num_color,
    loss_func,
    device=device,
    # sampling / optimizer settings
    parallel_size=parallel_size,
    num_epoch=num_epoch,
    lr=lr,
    temp=temp,
    # schedule settings (min_bg/max_bg bound the PARAM printed in the log)
    min_bg=min_bg,
    max_bg=max_bg,
    curve_rate=curve_rate,
    div_param=div_param,
    # logging
    check_interval=check_interval,
)

EPOCH:0, LOSS:100018.640625, PENALTY:99416.59375, PARAM:-2.0
EPOCH:1000, LOSS:34982.125, PENALTY:93113.625, PARAM:-1.79
EPOCH:2000, LOSS:34066.03515625, PENALTY:92589.15625, PARAM:-1.58
EPOCH:3000, LOSS:32909.99609375, PENALTY:91831.6796875, PARAM:-1.37
EPOCH:4000, LOSS:31343.701171875, PENALTY:90639.6484375, PARAM:-1.1600000000000001
EPOCH:5000, LOSS:29045.580078125, PENALTY:88571.625, PARAM:-0.95
EPOCH:6000, LOSS:25075.625, PENALTY:84250.0078125, PARAM:-0.74
EPOCH:7000, LOSS:16940.814453125, PENALTY:72897.4375, PARAM:-0.53
EPOCH:8000, LOSS:4910.4140625, PENALTY:45458.3828125, PARAM:-0.32000000000000006
EPOCH:9000, LOSS:374.71246337890625, PENALTY:26472.13671875, PARAM:-0.1100000000000001
EPOCH:9999, LOSS:87.61386108398438, PENALTY:35.251853942871094, PARAM:0.09979000000000005


In [12]:
# Check the best colouring found against the graph and report the violation count.
number_violation = utils.postprocess_coloring(
    best_string,
    num_color,
    nx_graph,
)
print(f"Violation: {number_violation}")

Violation: 0
