<a href="https://colab.research.google.com/github/RaymndH/Parallelizing_Sparse_Matrix_Annealing/blob/main/Sparse_Matrix.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import time
import networkx as nx
import random

# Path to the graph edge-list file; the setup cell loads this same file with read_graph_txt.
# The name presumably encodes N=1e3 nodes, average degree c=18, q=5 colours — TODO confirm.
matrixpath = "N-1e3_c-18.0_q-5"

In [None]:


def read_graph_txt(path, device):
    """Read an undirected edge-list file into symmetric COO coordinates.

    File format: the first line holds the node count ``N``; every following
    non-blank line holds one edge ``u v`` as two integers.

    Args:
        path: Path of the edge-list text file.
        device: torch device on which the returned tensors are created.

    Returns:
        ``(N, coords, values, edges)`` where ``coords`` is a ``(2, 2E)`` long
        tensor holding every edge in both directions (all ``(u, v)`` pairs
        followed by all ``(v, u)`` pairs), ``values`` is a float32 tensor of
        ones of length ``2E``, and ``edges`` is the list of ``(u, v)`` tuples
        in file order.

    Raises:
        ValueError: if the first line is missing/empty, a count or node id is
            not an integer, or an edge line does not have exactly two fields.
    """
    with open(path, 'r') as f:
        lines = [line.strip() for line in f]

    if not lines or not lines[0]:
        raise ValueError(f"{path}: missing node count on the first line")
    N = int(lines[0])

    edges = []
    for lineno, line in enumerate(lines[1:], start=2):
        if not line:
            continue  # tolerate blank lines, e.g. a trailing empty line
        fields = line.split()
        if len(fields) != 2:
            raise ValueError(f"{path}:{lineno}: expected 'u v', got {line!r}")
        edges.append((int(fields[0]), int(fields[1])))

    src = [u for u, _ in edges]
    dst = [v for _, v in edges]
    # Store each undirected edge in both directions so the matrix is symmetric.
    # Works for an empty edge list too: coords then has shape (2, 0).
    row = torch.tensor(src + dst, dtype=torch.long, device=device)
    col = torch.tensor(dst + src, dtype=torch.long, device=device)
    coords = torch.stack([row, col], dim=0)
    values = torch.ones(row.numel(), dtype=torch.float32, device=device)

    return N, coords, values, edges

def iterative_mis_peeling(G, target_frac=0.1, seed=None):
    """Repeatedly peel maximal independent sets off a copy of ``G``.

    Each round finds a maximal independent set (MIS) of the nodes still
    present and removes it, until at most ``total_nodes * target_frac`` nodes
    remain. The returned sets are pairwise disjoint because every MIS is
    removed before the next one is computed.

    Args:
        G: networkx graph; it is copied and not modified.
        target_frac: stop once the number of remaining nodes is at most this
            fraction of ``G``'s node count (``0.0`` peels until none remain).
        seed: forwarded as the ``seed`` argument of
            ``networkx.maximal_independent_set``; ``None`` (default) leaves
            seeding to networkx, as before.

    Returns:
        ``(mis_list, remaining)``: the MIS node sets in peeling order, and the
        set of nodes that were never peeled.
    """
    total_nodes = G.number_of_nodes()
    remaining = set(G.nodes())
    mis_list = []
    G_sub = G.copy()

    # The emptiness check guarantees termination for any target_frac; a
    # negative value would otherwise request an MIS of an empty graph.
    while remaining and len(remaining) > total_nodes * target_frac:
        mis_set = set(nx.maximal_independent_set(G_sub, seed=seed))
        mis_list.append(mis_set)
        G_sub.remove_nodes_from(mis_set)
        remaining -= mis_set
    return mis_list, remaining

def compute_energy(q_idx, Aq):
    """Sum, over all nodes, the ``Aq`` entry in the column of the node's own colour.

    When ``Aq = A @ one_hot(q)`` for a symmetric adjacency ``A`` (as built in
    this notebook), every monochromatic edge is counted from both endpoints,
    so the result is twice the number of such edges.

    Args:
        q_idx: 1-D integer tensor; ``q_idx[i]`` is the colour of node ``i``.
        Aq: 2-D tensor indexed ``Aq[node, colour]``.

    Returns:
        The total as a Python number (via ``.item()``).
    """
    nodes = torch.arange(q_idx.size(0), device=q_idx.device)
    own_colour_counts = Aq[nodes, q_idx]
    total = own_colour_counts.sum()
    return total.item()

In [None]:

# === Parameters ===
num_classes = 5          # number of colour classes
T = 1.5                  # default start temperature (the annealing cell currently sets its own)
dT = 1e-5                # default cooling step per epoch
min_T = 1e-4             # default floor temperature
max_epochs = 100000000   # hard cap on annealing epochs
file_path = matrixpath   # graph file named in the first cell
SEED = 0                 # seed for every RNG used below

# === Setup ===
# Seed the Python and torch RNGs so runs are reproducible.
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
N, coords, values, edge_list = read_graph_txt(file_path, device)
A = torch.sparse_coo_tensor(coords, values, (N, N)).coalesce()

# Build NetworkX graph. Add all N nodes explicitly so nodes without edges are
# still present (and therefore appear in the MIS decompositions).
G = nx.Graph()
G.add_nodes_from(range(N))
G.add_edges_from(edge_list)
assert G.number_of_nodes() == N, "edge endpoints must be 0-based node ids < N"



In [None]:
# Build several independent MIS decompositions of G. Each one peels maximal
# independent sets until no node remains (target_frac=0.0), so its sets are
# disjoint and together cover every node.
n_decompositions = 10
mislists = []
for _ in range(n_decompositions):
    mis_list, leftover = iterative_mis_peeling(G, target_frac=0.0)
    mislists.append(mis_list)

all_nodes = set(G.nodes())
assert all(set().union(*decomp) == all_nodes for decomp in mislists), \
    "every MIS decomposition must cover all nodes"
print(f"Computed {len(mislists)} MIS decompositions "
      f"({sum(len(d) for d in mislists)} MIS sets in total)")

# Random initial colouring.
row_idx = torch.arange(N, device=device)
q_idx = torch.randint(0, num_classes, (N,), device=device)

# Initial energy. Aq[i, c] is the edge-weighted count of i's neighbours with colour c.
q_onehot = torch.nn.functional.one_hot(q_idx, num_classes=num_classes).float()
Aq = torch.sparse.mm(A, q_onehot)
initial_energy = compute_energy(q_idx, Aq)
print(f"Initial system energy: {initial_energy:.4f}")

Computed 10 MIS sets
Initial system energy: 3564.0000


In [None]:
# NOTE(review): removed leftover scratch line `print(set())`; it printed a
# constant and played no role in the analysis.

set()


In [None]:
# TODO(review): unfinished stub — the intended node ordering was never built.
# The original loop only rebound the globals `i` and `v` (to an empty set) on
# every iteration, clobbering state that later scratch cells inspect.
order = []


In [None]:
# === Greedy single-site sweep (zero temperature) ===
# Visit every node once, propose a random different colour, and keep it only if
# it strictly lowers the node's same-colour neighbour count. Aq is updated
# incrementally: moving node i from `color` to `altcolor` only changes those two
# columns, and only in the rows of i's neighbours.

# Rebuild Aq from q_idx so this cell does not depend on stale kernel state.
q_onehot = torch.nn.functional.one_hot(q_idx, num_classes=num_classes).float()
Aq = torch.sparse.mm(A, q_onehot)
print(compute_energy(q_idx, Aq))

# CSR view of A: neighbours of node i are nbr_cols[crow[i]:crow[i + 1]].
A_csr = A.to_sparse_csr()
crow = A_csr.crow_indices().tolist()
nbr_cols = A_csr.col_indices()
nbr_vals = A_csr.values()
# One random non-zero colour offset per node, drawn up front.
offsets = torch.randint(1, num_classes, (N,)).tolist()

# CUDA events only exist on GPU; fall back to wall-clock timing on CPU.
use_cuda_timing = device.type == 'cuda'
if use_cuda_timing:
    torch.cuda.synchronize()
    start_evt = torch.cuda.Event(enable_timing=True)
    end_evt = torch.cuda.Event(enable_timing=True)
    start_evt.record()
else:
    t_start = time.perf_counter()

for i in range(N):
    color = int(q_idx[i])
    altcolor = (color + offsets[i]) % num_classes
    de = (Aq[i, altcolor] - Aq[i, color]).item()
    if de < 0:
        nbrs = nbr_cols[crow[i]:crow[i + 1]]
        weights = nbr_vals[crow[i]:crow[i + 1]]
        # Coalesced A has no duplicate columns per row, so these scatters are exact.
        Aq[nbrs, color] -= weights
        Aq[nbrs, altcolor] += weights
        q_idx[i] = altcolor

if use_cuda_timing:
    end_evt.record()
    torch.cuda.synchronize()
    elapsed_ms = start_evt.elapsed_time(end_evt)
else:
    elapsed_ms = (time.perf_counter() - t_start) * 1e3
print(elapsed_ms)

print(compute_energy(q_idx, Aq))
# The incremental updates must agree with a full recomputation.
assert torch.allclose(
    Aq, torch.sparse.mm(A, torch.nn.functional.one_hot(q_idx, num_classes=num_classes).float())
), "incremental Aq drifted from A @ one_hot(q_idx)"


360504.0


NameError: name 'rows' is not defined

In [None]:
# === Parallel Metropolis annealing over MIS batches ===
# Nodes inside one maximal independent set share no edges, so a whole batch can
# be proposed and accepted at once: flipping one batch node does not change the
# Aq row of any other batch node.
# NOTE: this cell sets its own schedule, overriding T / dT / min_T from the
# parameter cell.
T_start = 1.4
dT = 1e-3
min_T = 0
T = T_start
# Number of cooling steps until T reaches min_T. round() avoids the float
# truncation of int(1.4 / 1e-3) == 1399.
n_cooling_steps = round((T_start - min_T) / dT)
# CUDA events only exist on GPU; fall back to wall-clock timing on CPU.
use_cuda_timing = device.type == 'cuda'

# Pre-build (node set, boolean mask) pairs for every MIS batch once instead of
# re-allocating the mask on every step of every epoch.
mis_batches = []
for decomposition in mislists:
    batches = []
    for nodes in decomposition:
        mask = torch.zeros(N, dtype=torch.bool, device=device)
        mask[list(nodes)] = True
        batches.append((nodes, mask))
    mis_batches.append(batches)

total_time_ms = 0.0
epoch = -1  # keeps the summary valid if max_epochs == 0
for epoch in range(max_epochs):
    epoch_updates = 0
    epoch_time_ms = 0.0

    for mis_nodes, mis_mask in random.choice(mis_batches):
        # Propose a different colour for every node; only batch nodes are used.
        offsets = torch.randint(1, num_classes, (N,), device=device)
        q_hat_idx = (q_idx + offsets) % num_classes

        if use_cuda_timing:
            torch.cuda.synchronize()
            start_evt = torch.cuda.Event(enable_timing=True)
            end_evt = torch.cuda.Event(enable_timing=True)
            start_evt.record()
        else:
            t_start = time.perf_counter()

        q_onehot = torch.nn.functional.one_hot(q_idx, num_classes=num_classes).float()
        Aq = torch.sparse.mm(A, q_onehot)  # Aq[i, c]: weighted count of i's neighbours coloured c

        a_q = Aq[row_idx, q_idx]  # same-colour neighbours under the current colour
        a_q_hat = torch.zeros(N, device=device)
        # same-colour neighbours under the proposed colour (batch nodes only)
        a_q_hat[mis_mask] = Aq[row_idx[mis_mask], q_hat_idx[mis_mask]]
        v = a_q_hat - a_q  # energy change dE; only meaningful where mis_mask is True

        num_candidates = len(mis_nodes)  # == mis_mask.sum(), without a device sync
        if num_candidates > 0:
            r = torch.rand(num_candidates, device=device)
            # Metropolis rule: accept with probability min(1, exp(-dE / T)).
            accept_mask = torch.exp(-v[mis_mask] / T) > r
            full_accept_mask = torch.zeros_like(mis_mask)
            full_accept_mask[mis_mask] = accept_mask
            q_idx[full_accept_mask] = q_hat_idx[full_accept_mask]
            updated_count = full_accept_mask.sum().item()
        else:
            updated_count = 0

        if use_cuda_timing:
            end_evt.record()
            torch.cuda.synchronize()
            elapsed_ms = start_evt.elapsed_time(end_evt)
        else:
            elapsed_ms = (time.perf_counter() - t_start) * 1e3

        epoch_time_ms += elapsed_ms
        epoch_updates += updated_count

    if epoch % 10 == 0:
        # The Aq from the last batch predates that batch's flips; recompute it
        # so the reported energy matches the current colouring.
        Aq_now = torch.sparse.mm(
            A, torch.nn.functional.one_hot(q_idx, num_classes=num_classes).float())
        print(f"\r epoch {epoch} of {n_cooling_steps} done:",
              f"Temperature = {T:.6f},",
              f"Time = {epoch_time_ms:.3f} ms,",
              f"Total updates = {epoch_updates},",
              f"energy = {compute_energy(q_idx, Aq_now)}",
              "                     ", end="", sep=" ")
    total_time_ms += epoch_time_ms

    T = max(min_T, T - dT)
    if T <= min_T:
        print("\nMinimum temperature reached.")
        break

# Final energy, from a fresh Aq consistent with the final colouring.
q_onehot = torch.nn.functional.one_hot(q_idx, num_classes=num_classes).float()
Aq = torch.sparse.mm(A, q_onehot)
final_energy = compute_energy(q_idx, Aq)

print(f"\nFinal system energy: {final_energy:.4f}")
print(f"Initial energy:       {initial_energy:.4f}")
print(f"Energy change:        {final_energy - initial_energy:.4f}")
print(f"Total annealing time: {total_time_ms:.3f} ms over {epoch+1} epochs")


 epoch 1400 of 1399 done: Temperature = 0.000000, Time = 10.165 ms, Total updates = 14, energy = 370.0                      Minimum temperature reached.

Final system energy: 370.0000
Initial energy:       3564.0000
Energy change:        -3194.0000
Total annealing time: 15371.228 ms over 1401 epochs


In [None]:
# NOTE(review): removed a scratch cell that only demoed "\r" progress printing
# with time.sleep(1) ten times, adding ~10 s to every Restart & Run All.

this one is 9

In [None]:
# `mylist` is never defined in this notebook, so the original `v[mylist[0]]`
# raised NameError on a fresh kernel. Inspect the energy change dE at one node
# of the last MIS batch processed by the annealing loop instead.
# TODO(review): confirm this matches what `mylist` was meant to hold.
v[min(mis_nodes)]

tensor(-1., device='cuda:0')

In [None]:
# Show only the first entries instead of dumping all N values; entries outside
# the last MIS batch are zero.
a_q_hat[:20]

tensor([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  6.,  5.,  0.,  5.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  4.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  5.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  5.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  8.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  3.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  6.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
         5.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  

In [None]:
# Same-colour neighbour count of node 0 under its current colour
# (from the last annealing step).
a_q.select(0, 0)

tensor(1., device='cuda:0')