In [1]:
from  numba import njit
import numpy as np
# import pickle
from src.envs.utils import GraphDataset

In [2]:
@njit
def flatten_graph(graph):
    """
    Convert a dense adjacency matrix into CSR-style edge arrays.

    Each non-zero entry ``graph[row, col]`` becomes one stored edge. The edges
    leaving vertex ``row`` occupy the half-open slice ``starts[row]:ends[row]``
    of both edge arrays, ordered by increasing ``col``.

    Parameters:
    - graph (2-D numpy.ndarray): adjacency/weight matrix. Rows and columns are
      both scanned over ``range(graph.shape[0])``, so it is assumed square.

    Returns:
    - numpy.ndarray: neighbour (column) index of every stored edge.
    - numpy.ndarray: weight of every stored edge.
    - numpy.ndarray (int64): offset of each vertex's first edge.
    - numpy.ndarray (int64): offset one past each vertex's last edge.
    """
    size = graph.shape[0]
    neighbours = []
    values = []
    starts = np.zeros(size, dtype=np.int64)
    ends = np.zeros(size, dtype=np.int64)

    for row in range(size):
        starts[row] = len(neighbours)
        for col in range(size):
            entry = graph[row, col]
            if entry == 0:
                # Zero means "no edge"; everything else (incl. NaN) is kept.
                continue
            neighbours.append(col)
            values.append(entry)
        ends[row] = len(neighbours)

    return np.array(neighbours), np.array(values), starts, ends





In [3]:
@njit
def standard_greedy(graph):
    """
    Greedy max-cut baseline starting with every vertex at spin 1.

    Repeatedly takes the spin-1 vertex with the largest flip gain (ranked by
    ``np.argsort``) and moves it to spin 0, as long as that gain is
    non-negative. Each vertex therefore moves at most once and the search
    always terminates.

    Parameters:
    - graph (tuple): ``(adjacency, weights, start, end)`` as produced by
      ``flatten_graph``. Assumed symmetric (every undirected edge stored in
      both directions), which is why the initial cut sum is halved.

    Returns:
    - float: cut value of the final assignment.
    - numpy.ndarray (float): final 0/1 spins.
    """
    neighbours, weights, starts, ends = graph

    num_nodes = len(starts)
    gains = np.zeros(num_nodes)
    spins = np.ones(num_nodes)

    # gains[a]: change in cut value if vertex a is flipped (+w for each
    # same-side neighbour, -w for each opposite-side neighbour).
    cut = 0
    for a in range(num_nodes):
        lo, hi = starts[a], ends[a]
        for b, w in zip(neighbours[lo:hi], weights[lo:hi]):
            gains[a] += w * (2 * spins[a] - 1) * (2 * spins[b] - 1)
            cut += w * (spins[a] + spins[b] - 2 * spins[a] * spins[b])
    cut /= 2

    improved = True
    while improved:
        improved = False
        order = np.argsort(-gains)
        for v in order:
            if not spins[v]:
                # Already moved to side 0; never moved back.
                continue
            if gains[v] >= 0:
                cut += gains[v]
                gains[v] = -gains[v]
                lo, hi = starts[v], ends[v]
                for u, w in zip(neighbours[lo:hi], weights[lo:hi]):
                    # Edge (u, v) toggles cut/uncut, shifting u's gain.
                    gains[u] += w * (2 * spins[u] - 1) * (2 - 4 * spins[v])
                spins[v] = 1 - spins[v]
                improved = True
            # Only the best remaining spin-1 vertex is ever considered.
            break

    return cut, spins

In [4]:
@njit
def mca(graph, spins):
    """
    Steepest-ascent local search for max-cut from a given 0/1 assignment.

    At every step the vertex whose flip gives the largest cut increase is
    flipped. The search stops once no flip strictly increases the cut, so the
    result is a 1-flip local optimum. Ties between equal gains are broken
    towards the lowest vertex index (``np.argmax``).

    Parameters:
    - graph (tuple): ``(adjacency, weights, start, end)`` as produced by
      ``flatten_graph``. Assumed symmetric (every undirected edge stored in
      both directions, hence the halved initial sum) and, presumably, free of
      self-loops -- a diagonal entry would skew the flip gains. TODO confirm.
    - spins (numpy.ndarray): initial 0/1 labels, one per vertex.
      NOTE: updated in place; the same array object is returned.

    Returns:
    - float: cut value of the final assignment.
    - numpy.ndarray: the final spins (same object as ``spins``).
    """
    adj_matrix, weight_matrix, start_list, end_list = graph

    n = len(start_list)
    delta_local_cuts = np.zeros(n)

    # delta_local_cuts[i]: change in cut value if vertex i is flipped
    # (+w for each same-side neighbour, -w for each opposite-side neighbour).
    curr_score = 0.0
    for i in range(n):
        for j, weight in zip(adj_matrix[start_list[i]:end_list[i]],
                             weight_matrix[start_list[i]:end_list[i]]):
            delta_local_cuts[i] += weight * (2 * spins[i] - 1) * (2 * spins[j] - 1)
            curr_score += weight * (spins[i] + spins[j] - 2 * spins[i] * spins[j])
    curr_score /= 2

    if n == 0:
        return curr_score, spins

    while True:
        # Only the single best move is used, so argmax (O(n)) suffices; a
        # full argsort per step was O(n log n) for the same decision.
        v = np.argmax(delta_local_cuts)
        if delta_local_cuts[v] <= 0:
            # No strictly improving flip left: local optimum reached.
            break

        curr_score += delta_local_cuts[v]
        # Flipping v straight back would undo exactly this gain.
        delta_local_cuts[v] = -delta_local_cuts[v]
        for u, weight in zip(adj_matrix[start_list[v]:end_list[v]],
                             weight_matrix[start_list[v]:end_list[v]]):
            # Edge (u, v) toggles between cut and uncut, shifting u's gain;
            # spins[v] still holds the pre-flip value here.
            delta_local_cuts[u] += weight * (2 * spins[u] - 1) * (2 - 4 * spins[v])
        spins[v] = 1 - spins[v]

    return curr_score, spins

In [5]:
@njit
def tabu(graph, spins, tabu_tenure, max_steps):
    """
    Tabu search for max-cut starting from a given 0/1 assignment.

    Each step flips the highest-gain admissible vertex, even when its gain is
    negative. A vertex flipped at step ``s`` stays tabu while
    ``t - s <= tabu_tenure``. The aspiration rule overrides tabu status when
    the flip would beat the best cut seen so far. If no vertex is admissible
    in a step, that step makes no move.

    Parameters:
    - graph (tuple): ``(adjacency, weights, start, end)`` as produced by
      ``flatten_graph``; assumed symmetric (hence the halved initial sum).
    - spins (numpy.ndarray): initial 0/1 labels, one per vertex.
      NOTE: updated in place and left holding the *last* visited assignment,
      which is not necessarily the best one.
    - tabu_tenure (int): number of steps a flipped vertex stays tabu.
    - max_steps (int): number of search steps.

    Returns:
    - float: best cut value seen during the search.
    - numpy.ndarray: copy of the assignment achieving that best cut, matching
      the ``(score, spins)`` convention of ``mca``/``standard_greedy``.
    """
    adj_matrix, weight_matrix, start_list, end_list = graph

    n = len(start_list)
    delta_local_cuts = np.zeros(n)

    # Step at which each vertex was last flipped. Chosen so every vertex is
    # admissible at t=0 for any tenure; a fixed sentinel such as -10000
    # silently made all vertices tabu once tabu_tenure >= 10000.
    tabu_list = np.full(n, -(tabu_tenure + 1.0))

    # delta_local_cuts[i]: change in cut value if vertex i is flipped.
    curr_score = 0.0
    for i in range(n):
        for j, weight in zip(adj_matrix[start_list[i]:end_list[i]],
                             weight_matrix[start_list[i]:end_list[i]]):
            delta_local_cuts[i] += weight * (2 * spins[i] - 1) * (2 * spins[j] - 1)
            curr_score += weight * (spins[i] + spins[j] - 2 * spins[i] * spins[j])
    curr_score /= 2

    best_score = curr_score
    best_spins = spins.copy()

    for t in range(max_steps):
        arg_gain = np.argsort(-delta_local_cuts)
        for v in arg_gain:
            not_tabu = t - tabu_list[v] > tabu_tenure
            aspirates = best_score < curr_score + delta_local_cuts[v]
            if not_tabu or aspirates:
                tabu_list[v] = t

                curr_score += delta_local_cuts[v]
                delta_local_cuts[v] = -delta_local_cuts[v]

                for u, weight in zip(adj_matrix[start_list[v]:end_list[v]],
                                     weight_matrix[start_list[v]:end_list[v]]):
                    # Edge (u, v) toggles cut/uncut; spins[v] is pre-flip here.
                    delta_local_cuts[u] += weight * (2 * spins[u] - 1) * (2 - 4 * spins[v])

                spins[v] = 1 - spins[v]
                break

        if curr_score > best_score:
            best_score = curr_score
            best_spins[:] = spins

    return best_score, best_spins
    
    

In [19]:
# graph_save_loc=f'_graphs/testing/ER_{"15-20"}spin_p{0.15}_50graphs.pkl'
# graphs=load_graph_set(graph_save_loc)

In [6]:
# Evaluation set (per the directory name: dense, unweighted max-cut graphs
# with 100-200 vertices). Alternatives used in earlier runs:
#   '../data/testing/wishart_100vertices_m50'
#   '../data/validation/Uniform Random-3-SAT'
DATASET_PATH = '../data/validation/dense_MC_100_200vertices_unweighted'
test_dataset = GraphDataset(DATASET_PATH, ordered=True)



In [7]:
# Tabu search from a random 0/1 start on every graph in the dataset.
# Seeded so the reported cuts reproduce under Restart & Run All.
TABU_SEED = 0
TABU_TENURE = 20
np.random.seed(TABU_SEED)

tabu_cuts = []
for _ in range(len(test_dataset)):
    # NOTE(review): get() is stateful; this cell and the next only evaluate
    # the same graphs if get() cycles through the dataset in order -- confirm
    # against GraphDataset.
    graph = test_dataset.get()
    g = flatten_graph(graph)
    spins = np.random.randint(2, size=graph.shape[0])
    # Step budget: twice the vertex count. Only the best cut is recorded.
    cut, _ = tabu(g, spins, tabu_tenure=TABU_TENURE, max_steps=graph.shape[0] * 2)
    tabu_cuts.append(cut)

tabu_cuts = np.array(tabu_cuts)

In [8]:
# Steepest-ascent local search (mca) from a random start, plus the
# deterministic standard-greedy baseline, on every graph in the dataset.
# Seeded so the reported cuts reproduce under Restart & Run All.
MCA_SEED = 0
np.random.seed(MCA_SEED)

mca_cuts = []
sg_cuts = []
for _ in range(len(test_dataset)):
    # NOTE(review): get() is stateful -- see GraphDataset for iteration order.
    graph = test_dataset.get()
    g = flatten_graph(graph)
    spins = np.random.randint(2, size=graph.shape[0])
    cut, spins = mca(g, spins)
    sg_cut, sg_spins = standard_greedy(g)

    mca_cuts.append(cut)
    sg_cuts.append(sg_cut)

sg_cuts = np.array(sg_cuts)
mca_cuts = np.array(mca_cuts)

In [9]:
for label, cuts in (('Tabu:', tabu_cuts), ('Max Cut Approx:', mca_cuts)):
    # Mean per-graph cut ratio relative to the standard-greedy baseline.
    print(label, (cuts / sg_cuts).mean())


Tabu: 1.0142646050644357
Max Cut Approx: 1.0009382285892885


In [10]:
# Per-graph ratio of the tabu cut to the mca cut (>1 means tabu did better).
np.divide(tabu_cuts, mca_cuts)

array([1.03253796, 1.02589928, 1.02325581, 1.00984529, 1.03078024,
       1.01203966, 1.01962209, 1.00209937, 1.00843289, 1.00489853,
       1.02794118, 1.0184136 , 1.00777385, 1.00070274, 1.01424501,
       1.01453488, 1.02074392, 1.02098408, 1.01506456, 1.01789549,
       1.01548205, 1.00694927, 1.01987225, 1.00292398, 1.00355114,
       1.0128041 , 1.00544959, 1.01239293, 1.01428833, 1.02175544,
       1.00127947, 1.00935094, 1.00765306, 1.02097642, 1.01022271,
       1.00657895, 1.00772485, 1.01053994, 1.01247478, 1.00902394,
       1.0088659 , 1.01438053, 1.00906618, 1.01532008, 1.01480933,
       1.01357051, 1.01285819, 1.0179099 , 1.00675306, 1.01551247])

In [11]:
# Best cut value found by tabu search, one entry per graph in test_dataset.
tabu_cuts

array([1428., 1426., 1408., 1436., 1440., 1429., 1403., 1432., 1435.,
       1436., 1398., 1438., 1426., 1424., 1424., 1396., 1427., 1411.,
       1415., 1422., 1443., 1449., 1437., 1372., 1413., 5537., 5535.,
       5555., 5537., 5448., 5478., 5505., 5530., 5500., 5534., 5508.,
       5479., 5465., 5519., 5479., 5462., 5502., 5565., 5567., 5482.,
       5527., 5514., 5513., 5516., 5499.])

In [12]:
# Final cut value from mca local search, one entry per graph in test_dataset.
mca_cuts

array([1383., 1390., 1376., 1422., 1397., 1412., 1376., 1429., 1423.,
       1429., 1360., 1412., 1415., 1423., 1404., 1376., 1398., 1382.,
       1394., 1397., 1421., 1439., 1409., 1368., 1408., 5467., 5505.,
       5487., 5459., 5332., 5471., 5454., 5488., 5387., 5478., 5472.,
       5437., 5408., 5451., 5430., 5414., 5424., 5515., 5483., 5402.,
       5453., 5444., 5416., 5479., 5415.])