In [None]:
import numpy as np
import random
import os
import time
import torch
import networkx as nx

from classical import get_gw_result
from classical import get_rank2_result
from rqaoa import get_rqaoa_result
from qrao import QRAO
from rqrao import RQRAO
from utils import get_graph

# Reproducibility: derive one seed per RNG (numpy, hash seed, torch, torch CUDA)
# from a single master seed, so every library is seeded deterministically from `seed`.
seed = 42
random.seed(seed)
seed_numpy = random.randint(1, 4294967295)
seed_os_py = random.randint(1, 4294967295)
seed_torch = random.randint(1, 4294967295)
seed_torch_cuda = random.randint(1, 4294967295)
np.random.seed(seed_numpy)
# NOTE: PYTHONHASHSEED is read only at interpreter start-up, so setting it here does
# not change str/bytes hashing in this kernel; it only takes effect in child processes.
os.environ['PYTHONHASHSEED'] = str(seed_os_py)
torch.manual_seed(seed_torch)
torch.cuda.manual_seed_all(seed_torch_cuda)

# Limit torch and the OpenMP/MKL thread pools to `nb_threads` threads
# (presumably so the per-method timings are single-threaded and comparable).
nb_threads = 1
torch.set_num_threads(nb_threads)
# torch allows set_num_interop_threads only once per process (and only before any
# inter-op parallel work); skip it when the value is already in effect so that
# re-running this cell does not raise RuntimeError.
if torch.get_num_interop_threads() != nb_threads:
    torch.set_num_interop_threads(nb_threads)
# NOTE(review): OpenMP/MKL usually read these variables when their runtimes are
# loaded; numpy and torch are already imported above, so these settings may only
# affect child processes — TODO confirm, or set them before the imports.
os.environ["OMP_NUM_THREADS"] = str(nb_threads)
os.environ["MKL_NUM_THREADS"] = str(nb_threads)

In [None]:
# Solver / benchmark configuration shared by the experiment cell below.
# Each value is forwarded by keyword to the project solvers (classical, rqaoa,
# qrao, rqrao modules); see those modules for the exact semantics.

nb_sampling = 10000          # forwarded to get_gw_result as `nb_sampling`

edge_energy_thresh = 1e-5    # forwarded to RQRAO as `edge_energy_thresh`
nb_ensemble = 20             # forwarded to RQRAO as `nb_ensemble`
sigma = 2.                   # forwarded to RQRAO as `sigma`
brute_force_threshold = 10   # forwarded to get_rqaoa_result and RQRAO as `brute_force_threshold`
bond_dim = 2                 # forwarded to QRAO and RQRAO as `bond_dim`
ldf = True                   # forwarded to the solvers as the `ldf` flag
device = 'cpu'               # device string forwarded to the solvers as `device`
thresh = 1e-2                # forwarded to QRAO and RQRAO as `thresh`

edge_noise = 1e-5            # forwarded to get_rqaoa_result and RQRAO as `edge_noise`

nb_div, div_level = 50, 2    # forwarded to get_rqaoa_result as `nb_div` / `div_level`

# External tools used by the 'gw' baseline (get_gw_result). Point these at local
# installs via environment variables instead of editing the notebook; the
# placeholders below are used when the variables are unset.
sdpt3_path = os.environ.get('SDPT3_DIR', '<SDPT3_INSTALL_DIRECTORY>/SDPT3-4.0/')
matlab_path = os.environ.get('MATLAB_EXECUTABLE', '<MATLAB_INSTALL_DIRECTORY>/bin/matlab')

In [None]:
# Graph sizes to benchmark: mantissa * 10**exponent, keeping the 18 values
# 20, 30, ..., 90, 100, 200, ..., 900, 1000 (positions 1..18 of the full sweep).
nb_nodes_list = [
    mantissa * 10**exponent
    for exponent in range(1, 6)
    for mantissa in range(1, 10)
][1:19]
nb_nodes_list

In [None]:
# Select the solver to benchmark (exactly one assignment should be active).
# method = 'gw'
# method = 'rank2' # CirCut
# method = 'rqaoa'
# method = 'qrao'
method = 'rqrao'

# Forwarded to get_rank2_result; only used by the 'rank2' method.
npert = 10
# Number of random graph instances (seeds 0..nb_seeds-1) generated per graph size.
nb_seeds = 10
# nb_sampling, bond_dim, thresh and the other solver settings come from the
# configuration cell above.

VALID_METHODS = ('gw', 'rank2', 'rqaoa', 'qrao', 'rqrao')


def run_method(method, edges, edge_weights):
    """Run the selected solver on one graph and return ``(cut, bits, elapsed_seconds)``.

    For 'gw' the elapsed time is the sum of the timing values reported by
    get_gw_result; for every other method it is the wall-clock duration of the
    solver call, measured with time.perf_counter().

    Solver settings are read from the notebook's configuration variables.

    Raises:
        ValueError: if `method` is not one of VALID_METHODS.
    """
    if method == 'gw':
        cut, bits, stage_times = get_gw_result(
            edges=edges,
            edge_weights=edge_weights,
            sdpt3_path=sdpt3_path,
            matlab_path=matlab_path,
            nb_sampling=nb_sampling,
        )
        # get_gw_result reports its own timings as a dict; use their total.
        return cut, bits, np.sum(list(stage_times.values()))

    start = time.perf_counter()
    if method == 'rank2':  # CirCut
        cut, bits = get_rank2_result(
            edges=edges,
            edge_weights=edge_weights,
            npert=npert,
        )
    elif method == 'rqaoa':
        # Note the (bits, cut, times) return order, unlike the other solvers.
        bits, cut, _rqaoa_times = get_rqaoa_result(
            edges=edges,
            edge_weights=edge_weights,
            brute_force_threshold=brute_force_threshold,
            edge_noise=edge_noise,
            nb_div=nb_div,
            div_level=div_level,
            fninp='graphdata_scal.txt',
            fnhyp='hyps_scal.txt',
            fnout='result_scal.txt',
        )
    elif method == 'qrao':
        qrao_ins = QRAO(
            edges=edges,
            edge_weights=edge_weights,
            bond_dim=bond_dim,
            device=device,
            thresh=thresh,
            mode='31',
            paulis='XYZ',
            ldf=ldf,
        )
        cut, bits = qrao_ins.get_qrao_result()
    elif method == 'rqrao':
        rqrao_ins = RQRAO(
            edges=edges,
            edge_weights=edge_weights,
            bond_dim=bond_dim,
            device=device,
            edge_energy_thresh=edge_energy_thresh,
            edge_noise=edge_noise,
            thresh=thresh,
            brute_force_threshold=brute_force_threshold,
            nb_ensemble=nb_ensemble,
            sigma=sigma,
            ldf=ldf,
        )
        # Keep the solver's own timing breakdown local: binding it to `time_dict`
        # would overwrite the per-size results accumulator below.
        cut, bits, _rqrao_times = rqrao_ins.get_rqrao_result()
    else:
        raise ValueError(f"unknown method {method!r}; expected one of {VALID_METHODS}")
    return cut, bits, time.perf_counter() - start


if method not in VALID_METHODS:
    # Fail before generating any graphs.
    raise ValueError(f"unknown method {method!r}; expected one of {VALID_METHODS}")

cut_dict = {}   # nb_nodes -> list of cut values, one per seed
time_dict = {}  # nb_nodes -> list of elapsed seconds, one per seed
for nb_nodes in nb_nodes_list:

    cut_list = []
    time_list = []

    for iseed in range(nb_seeds):

        edges, edge_weights = get_graph(
            nb_nodes=nb_nodes,
            weight_type='pm1',
            degree=3,
            seed=iseed,
        )

        cut, bits, elapsed = run_method(method, edges, edge_weights)
        cut_list.append(cut)
        time_list.append(elapsed)

        print(nb_nodes, iseed, cut, elapsed)

    cut_dict[nb_nodes] = cut_list
    time_dict[nb_nodes] = time_list

    # Checkpoint after every graph size so an interrupted run keeps finished sizes.
    # The dicts are stored as 0-d object arrays; reload with
    # np.load(path, allow_pickle=True)['cut_dict'].item().
    np.savez('scalability_3regular_' + method + '.npz', cut_dict=cut_dict, time_dict=time_dict)