In [1]:

from commons import *
from dgl.nn.pytorch import GATConv, EdgeConv

# Utils code

In [2]:
TORCH_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
TORCH_DTYPE = torch.float32
def get_gnn(n_nodes, gnn_hypers, opt_params, torch_device, torch_dtype):
    """
    Generate GNN instance with specified structure. Creates GNN, retrieves embedding layer,
    and instantiates ADAM optimizer given those.

    Input:
        n_nodes: Problem size (number of nodes in graph)
        gnn_hypers: Hyperparameters relevant to GNN structure
        opt_params: Hyperparameters relevant to ADAM optimizer
        torch_device: Whether to load pytorch variables onto CPU or GPU
        torch_dtype: Datatype to use for pytorch variables
    Output:
        net: GNN instance
        embed: Embedding layer to use as input to GNN
        optimizer: ADAM optimizer instance
    """
    dim_embedding = gnn_hypers['dim_embedding']
    hidden_dim = gnn_hypers['hidden_dim']
    dropout = gnn_hypers['dropout']
    number_classes = gnn_hypers['number_classes']

    # instantiate the GNN
    net = GCNSoftmax(dim_embedding, hidden_dim, number_classes, dropout, torch_device)
    net = net.type(torch_dtype).to(torch_device)
    embed = nn.Embedding(n_nodes, dim_embedding)
    embed = embed.type(torch_dtype).to(torch_device)

    # set up Adam optimizer
    params = chain(net.parameters(), embed.parameters())
    optimizer = torch.optim.Adam(params, **opt_params)
    return net, embed, optimizer

def partition_weight(adj, s):
    """
    Calculates the sum of weights of edges that are in different partitions.

    :param adj: Adjacency matrix of the graph.
    :param s: List indicating the partition of each edge (0 or 1).
    :return: Sum of weights of edges in different partitions.
    """
    s = np.array(s)
    partition_matrix = np.not_equal.outer(s, s).astype(int)
    weight = (adj * partition_matrix).sum() / 2
    return weight

import torch

def partition_weight2(adj, s):
    """
    Calculates the sum of weights of edges that are in different partitions.

    :param adj: Adjacency matrix of the graph as a PyTorch tensor.
    :param s: Tensor indicating the partition of each node (0 or 1).
    :return: Sum of weights of edges in different partitions.
    """
    # Ensure s is a tensor
    # s = torch.tensor(s, dtype=torch.float32)

    # Compute outer difference to create partition matrix
    s = s.unsqueeze(0)  # Convert s to a row vector
    t = s.t()           # Transpose s to a column vector
    partition_matrix = (s != t).float()  # Compute outer product and convert boolean to float

    # Calculate the weight of edges between different partitions
    weight = (adj * partition_matrix).sum() / 2

    return weight

def calculateAllCut(q_torch, s):
    '''

    :param q_torch: The adjacent matrix of the graph
    :param s: The binary output from the neural network. s will be in form of [[prob1, prob2, ..., prob n], ...]
    :return: The calculated cut loss value
    '''
    if len(s) > 0:
        totalCuts = len(s[0])
        CutValue = 0
        for i in range(totalCuts):
            CutValue += partition_weight2(q_torch, s[:,i])
        return CutValue/2
    return 0

def hyperParameters(n = 80, d = 3, p = None, graph_type = 'reg', number_epochs = int(1e5),
                    learning_rate = 1e-4, PROB_THRESHOLD = 0.5, tol = 1e-4, patience = 100):
    dim_embedding = n #int(np.sqrt(4096))    # e.g. 10, used to be the one before # used to be n
    hidden_dim = int(dim_embedding/2)

    return n, d, p, graph_type, number_epochs, learning_rate, PROB_THRESHOLD, tol, patience, dim_embedding, hidden_dim
def FIndAC(graph):
    max_degree = max(dict(graph.degree()).values())
    A_initial = max_degree + 1  # A is set to be one more than the maximum degree
    C_initial = max_degree / 2  # C is set to half the maximum degree

    return A_initial, C_initial

def generate_unique_random_numbers(N):
    if N < 2:
        raise ValueError("N must be at least 2 to generate 3 unique random numbers.")

    # Generate 3 unique random numbers
    random_numbers = random.sample(range(N + 1), 3)

    return random_numbers

N = 10  # Set the value of N
unique_random_numbers = generate_unique_random_numbers(N)
print("Three unique random numbers between 0 and", N, ":", unique_random_numbers)

def extend_matrix_to_N(matrix, N):
    original_size = matrix.shape[0]

    if N < original_size:
        raise ValueError("N should be greater than or equal to the original matrix size.")

    extended_matrix = np.zeros((N, N))
    extended_matrix[:original_size, :original_size] = matrix

    return extended_matrix




Three unique random numbers between 0 and 10 : [0, 5, 10]


## HyperParameters initialization and related functions


In [3]:



def printCombo(orig):
    # Original dictionary
    input_dict = orig

    # Generate all permutations of the dictionary values
    value_permutations = list(permutations(input_dict.values()))

    # Create a list of dictionaries from the permutations
    permuted_dicts = [{key: value for key, value in zip(input_dict.keys(), perm)} for perm in value_permutations]

    return permuted_dicts

def GetOptimalNetValue(net, dgl_graph, inp, q_torch, terminal_dict):
    net.eval()
    best_loss = float('inf')

    if (dgl_graph.number_of_nodes() < 30):
        inp = torch.ones((dgl_graph.number_of_nodes(), 30))

    # find all potential combination of terminal nodes with respective indices

    perm_items = printCombo(terminal_dict)
    for i in perm_items:
        probs = net(dgl_graph, inp, i)
        binary_partitions = (probs >= 0.5).float()
        cut_value_item = calculateAllCut(q_torch, binary_partitions)
        if cut_value_item < best_loss:
            best_loss = cut_value_item
    return best_loss



# Hamiltonian loss function


In [4]:
def terminal_independence_penalty(s, terminal_nodes):
    """
    Calculate a penalty that enforces each terminal node to be in a distinct partition.
    :param s: A probability matrix of size |V| x |K| where s[i][j] is the probability of vertex i being in partition j.
    :param terminal_nodes: A list of indices for terminal nodes.
    :return: The penalty term.
    """
    penalty = 0
    num_terminals = len(terminal_nodes)
    # Compare each pair of terminal nodes
    for i in range(num_terminals):
        for j in range(i + 1, num_terminals):
            # Calculate the dot product of the probability vectors for the two terminals
            dot_product = torch.dot(s[terminal_nodes[i]], s[terminal_nodes[j]])
            # Penalize the similarity in their partition assignments (dot product should be close to 0)
            penalty += dot_product
    return penalty

In [37]:
def calculate_HA_vectorized(s):
    """
    Vectorized calculation of HA.
    :param s: A binary matrix of size |V| x |K| where s[i][j] is 1 if vertex i is in partition j.
    :return: The HA value.
    """
    # HA = ∑v∈V(∑k∈K(sv,k)−1)^2
    HA = torch.sum((torch.sum(s, axis=1) - 1) ** 2)
    return HA

def calculate_HC_min_cut_intra_inter(s, adjacency_matrix):
    """
    Vectorized calculation of HC to minimize cut size.
    :param s: A probability matrix of size |V| x |K| where s[i][j] is the probability of vertex i being in partition j.
    :param adjacency_matrix: A matrix representing the graph where the value at [i][j] is the weight of the edge between i and j.
    :return: The HC value focusing on minimizing edge weights between partitions.
    """
    HC = 0
    K = s.shape[1]
    for k in range(K):
        for l in range(k + 1, K):
            partition_k = s[:, k].unsqueeze(1) * s[:, k].unsqueeze(0)  # Probability node pair both in partition k
            partition_l = s[:, l].unsqueeze(1) * s[:, l].unsqueeze(0)  # Probability node pair both in partition l
            # Edges between partitions k and l
            inter_partition_edges = adjacency_matrix * (partition_k + partition_l)
            HC += torch.sum(inter_partition_edges)

    return HC

def calculate_HC_min_cut_intra_inter2(s, adjacency_matrix):
    """
    Vectorized calculation of HC to minimize cut size.
    :param s: A probability matrix of size |V| x |K| where s[i][j] is the probability of vertex i being in partition j.
    :param adjacency_matrix: A matrix representing the graph where the value at [i][j] is the weight of the edge between i and j.
    :return: The HC value focusing on minimizing edge weights between partitions.
    """
    HC = 0
    K = s.shape[1]
    for k in range(K):
        for l in range(k + 1, K):
            partition_k = s[:, k].unsqueeze(1) * s[:, k].unsqueeze(0)  # Probability node pair both in partition k
            partition_l = s[:, l].unsqueeze(1) * s[:, l].unsqueeze(0)  # Probability node pair both in partition l
            # Edges between partitions k and l
            inter_partition_edges = adjacency_matrix * (partition_k + partition_l)
            HC += torch.sum(inter_partition_edges)

    return HC

def calculate_HC_min_cut_new(s, adjacency_matrix):
    """
    Differentiable calculation of HC for minimizing edge weights between different partitions.
    :param s: A probability matrix of size |V| x |K| where s[i][j] is the probability of vertex i being in partition j.
    :param adjacency_matrix: A matrix representing the graph where the value at [i][j] is the weight of the edge between i and j.
    :return: The HC value, focusing on minimizing edge weights between partitions.
    """
    K = s.shape[1]
    V = s.shape[0]

    # Create a full partition matrix indicating the likelihood of each node pair being in the same partition
    partition_matrix = torch.matmul(s, s.T)

    # Calculate the complement matrix, which indicates the likelihood of node pairs being in different partitions
    complement_matrix = 1 - partition_matrix

    # Apply adjacency matrix to only consider actual edges and their weights
    inter_partition_edges = adjacency_matrix * complement_matrix

    # Summing up all contributions for edges between different partitions
    HC = torch.sum(inter_partition_edges)

    return HC

def calculate_HC_vectorized_old(s, adjacency_matrix):
    """
    Vectorized calculation of HC.
    :param s: A binary matrix of size |V| x |K|.
    :param adjacency_matrix: A matrix representing the graph where the value at [i][j] is the weight of the edge between i and j.
    :return: The HC value.
    """
    # HC = ∑(u,v)∈E(1−∑k∈K(su,k*sv,k))*adjacency_matrix[u,v]
    K = s.shape[1]
    # Outer product to find pairs of vertices in the same partition and then weight by the adjacency matrix
    prod = adjacency_matrix * (1 - s @ s.T)
    HC = torch.sum(prod)
    return HC
import torch

def min_cut_loss(s, adjacency_matrix):
    """
    Compute a differentiable min-cut loss for a graph given node partition probabilities.

    :param s: A probability matrix of size |V| x |K| where s[i][j] is the probability of vertex i being in partition j.
    :param adjacency_matrix: A matrix representing the graph where the value at [i][j] is the weight of the edge between i and j.
    :return: The expected min-cut value, computed as a differentiable loss.
    """
    V = s.size(0)  # Number of nodes
    K = s.size(1)  # Number of partitions

    # Ensure the partition matrix s sums to 1 over partitions
    s = torch.softmax(s, dim=1)

    # Compute the expected weight of edges within each partition
    intra_partition_cut = torch.zeros((K, K), dtype=torch.float32)
    for k in range(K):
        for l in range(k + 1, K):
            # Probability that a node pair (i, j) is split between partitions k and l
            partition_k = s[:, k].unsqueeze(1)  # Shape: V x 1
            partition_l = s[:, l].unsqueeze(0)  # Shape: 1 x V

            # Compute the expected weight of the cut edges between partitions k and l
            cut_weight = adjacency_matrix * (partition_k @ partition_l)
            intra_partition_cut[k, l] = torch.sum(cut_weight)

    # Sum up all contributions to get the total expected min-cut value
    total_cut_weight = torch.sum(intra_partition_cut)

    return total_cut_weight

import torch

# def min_cut_loss(s, adjacency_matrix):
#     """
#     Compute a differentiable min-cut loss for a graph given node partition probabilities.
#
#     :param s: A probability matrix of size |V| x |K| where s[i][j] is the probability of vertex i being in partition j.
#     :param adjacency_matrix: A matrix representing the graph where the value at [i][j] is the weight of the edge between i and j.
#     :return: The expected min-cut value, computed as a differentiable loss.
#     """
#     V = s.size(0)  # Number of nodes
#     K = s.size(1)  # Number of partitions
#
#     # Ensure the partition matrix s sums to 1 over partitions
#     # s = torch.softmax(s, dim=1)
#
#     # Compute the expected weight of cut edges between each pair of partitions
#     total_cut_weight = 0
#     for k in range(K):
#         for l in range(k + 1, K):
#             # Probability that a node pair (i, j) is split between partitions k and l
#             partition_k = s[:, k].unsqueeze(1)  # Shape: V x 1
#             partition_l = s[:, l].unsqueeze(0)  # Shape: 1 x V
#
#             # Compute the expected weight of the cut edges between partitions k and l
#             cut_weight = adjacency_matrix * (partition_k @ partition_l)
#             total_cut_weight += torch.sum(cut_weight)
#
#     return total_cut_weight


def calculate_HC_vectorized(s, adjacency_matrix):
    """
    Vectorized calculation of HC for soft partitioning.
    :param s: A probability matrix of size |V| x |K| where s[i][j] is the probability of vertex i being in partition j.
    :param adjacency_matrix: A matrix representing the graph where the value at [i][j] is the weight of the edge between i and j.
    :return: The HC value.
    """
    # Initialize HC to 0
    HC = 0

    # Iterate over each partition to calculate its contribution to HC
    for k in range(s.shape[1]):
        # Compute the probability matrix for partition k
        partition_prob_matrix = s[:, k].unsqueeze(1) * s[:, k].unsqueeze(0)

        # Compute the contribution to HC for partition k
        HC_k =adjacency_matrix * (1 - partition_prob_matrix)
        # Sum up the contributions for partition k
        HC += torch.sum(HC_k, dim=(0, 1))

    # Since we've summed up the partition contributions twice (due to symmetry), divide by 2
    HC = HC / 2

    return HC

TORCH_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
TORCH_DTYPE = torch.float32
def swap_graph_nodes(graph, mapping):
    """
    Swap the nodes of a NetworkX graph according to a given mapping dictionary, ensuring mutual swaps.

    :param graph: NetworkX graph
    :param mapping: dictionary where keys and values represent nodes to be swapped
    :return: None; the graph is modified in place
    """
    # Create a temporary mapping to intermediate labels
    temp_mapping = {old_label: max(graph.nodes) + 1 + i for i, old_label in enumerate(mapping)}

    # First step: move all nodes to temporary labels
    nx.relabel_nodes(graph, mapping=temp_mapping, copy=False)

    # Create the reverse of the initial mapping to complete the swap
    reverse_mapping = {value: key for key, value in mapping.items()}

    # Second step: from temporary labels to final labels
    nx.relabel_nodes(graph, mapping={temp_mapping[old]: reverse_mapping[old] for old in mapping}, copy=False)

def swap_all_terminal_nodes(graph_list):
    graph_listV2 = {}
    for key, (dgl_graph, adjacency_matrix,graph, terminals) in graph_list.items():
        graph_listV2[key] = (dgl_graph, adjacency_matrix,graph, terminals)

    return graph_listV2

def extend_matrix(matrix, N):
    original_size = matrix.shape[0]

    if N < original_size:
        raise ValueError("N should be greater than or equal to the original matrix size.")

    extended_matrix = np.zeros((N, N))
    extended_matrix[:original_size, :original_size] = matrix

    return extended_matrix

def extend_matrix_torch(matrix, N, torch_dtype=None, torch_device=None):
    original_size = matrix.shape[0]

    if N < original_size:
        raise ValueError("N should be greater than or equal to the original matrix size.")

    extended_matrix = torch.zeros(original_size, N)
    extended_matrix[:original_size, :original_size] = matrix

    if torch_dtype is not None:
        extended_matrix = extended_matrix.type(torch_dtype)

    if torch_device is not None:
        extended_matrix = extended_matrix.to(torch_device)

    return extended_matrix

def extend_matrix_torch_2(matrix, N, torch_dtype=None, torch_device=None):
    original_size = matrix.shape[0]

    if N < original_size:
        raise ValueError("N should be greater than or equal to the original matrix size.")

    # Use the dtype and device from the original matrix if not provided
    torch_dtype = torch_dtype or matrix.dtype
    torch_device = torch_device or matrix.device

    # Initialize an empty matrix with the specified size, dtype, and device
    extended_matrix = torch.empty((original_size, N), dtype=torch_dtype, device=torch_device)

    # Copy the original matrix into the extended matrix
    extended_matrix[:, :original_size] = matrix

    # Zero out the remaining columns
    if N > original_size:
        extended_matrix[:, original_size:] = 0

    return extended_matrix

def createGraphFromFolder(all_graphs, all_terminals, max_nodes):

    # Example usage
    # directory =  dst # Replace this with your directory path
    # all_graphs, all_terminals = process_all_files(directory)

    datasetItem = {}
    i = 0
    skipped = 0
    try:
        # Print out some details about the graphs (optional)
        for filename, graph in all_graphs.items():
            # print(f"Graph for {filename}: Nodes = {graph.nodes()}, Edges = {graph.edges(data=True)}")
            # print(f"Terminals for {filename}: {all_terminals[filename]}")
            terminals = all_terminals[filename]
            if 0 not in terminals and 1 not in terminals and 2 not in terminals:
                swap_graph_nodes(graph, {terminals[0]:0, terminals[1]:1, terminals[2]:2, 0:terminals[0], 1:terminals[1], 2:terminals[2]})
            elif 0 not in terminals and 1 not in terminals and 2 in terminals:
                terminals.sort()
                swap_graph_nodes(graph, {terminals[1]:0, terminals[2]:1,  0:terminals[1], 1:terminals[2]})
            elif 0 not in terminals and 1  in terminals and 2 not in terminals:
                terminals.sort()
                swap_graph_nodes(graph, {terminals[1]:0, terminals[2]:2,  0:terminals[1], 2:terminals[2]})
            elif 0  in terminals and 1 not in terminals and 2 not in terminals:
                terminals.sort()
                swap_graph_nodes(graph, {terminals[1]:1, terminals[2]:2,  1:terminals[1], 2:terminals[2]})
            else:
                skipped +=1
                continue
            graph_dgl = dgl.from_networkx(nx_graph=graph)
            graph_dgl = graph_dgl.to(TORCH_DEVICE)

            q_torch = qubo_dict_to_torch(graph, gen_adj_matrix(graph), torch_dtype=TORCH_DTYPE, torch_device=TORCH_DEVICE)

            # full_matrix = extend_matrix_torch(q_torch, max_nodes, torch_dtype=TORCH_DTYPE, torch_device=TORCH_DEVICE)

            # datasetItem[i] = [graph_dgl, q_torch, graph, all_terminals[filename]]
            datasetItem[i] = [graph_dgl, q_torch, graph, [0,1,2]]
            i+=1

    except:
        print(i, filename, terminals, graph.edges())

    print(skipped)
    return datasetItem

def createGraphFromFolder_full(all_graphs, all_terminals, max_nodes):

    # Example usage
    # directory =  dst # Replace this with your directory path
    # all_graphs, all_terminals = process_all_files(directory)

    datasetItem = {}
    i = 0
    skipped = 0
    try:
        # Print out some details about the graphs (optional)
        for filename, graph in all_graphs.items():
            # print(f"Graph for {filename}: Nodes = {graph.nodes()}, Edges = {graph.edges(data=True)}")
            # print(f"Terminals for {filename}: {all_terminals[filename]}")
            terminals = all_terminals[filename]
            if 0 not in terminals and 1 not in terminals and 2 not in terminals:
                swap_graph_nodes(graph, {terminals[0]:0, terminals[1]:1, terminals[2]:2, 0:terminals[0], 1:terminals[1], 2:terminals[2]})
            elif 0 not in terminals and 1 not in terminals and 2 in terminals:
                terminals.sort()
                swap_graph_nodes(graph, {terminals[1]:0, terminals[2]:1,  0:terminals[1], 1:terminals[2]})
            elif 0 not in terminals and 1  in terminals and 2 not in terminals:
                terminals.sort()
                swap_graph_nodes(graph, {terminals[1]:0, terminals[2]:2,  0:terminals[1], 2:terminals[2]})
            elif 0  in terminals and 1 not in terminals and 2 not in terminals:
                terminals.sort()
                swap_graph_nodes(graph, {terminals[1]:1, terminals[2]:2,  1:terminals[1], 2:terminals[2]})
            else:
                skipped +=1
                continue
            print("Terminal swapped ", i)
            graph_dgl = dgl.from_networkx(nx_graph=graph)
            graph_dgl = graph_dgl.to(TORCH_DEVICE)

            q_torch = qubo_dict_to_torch(graph, gen_adj_matrix(graph), torch_dtype=TORCH_DTYPE, torch_device=TORCH_DEVICE)

            full_matrix = extend_matrix_torch_2(q_torch, max_nodes, torch_dtype=TORCH_DTYPE, torch_device=TORCH_DEVICE)

            # datasetItem[i] = [graph_dgl, q_torch, graph, all_terminals[filename]]
            datasetItem[i] = [graph_dgl, full_matrix, graph, [0,1,2]]
            i+=1

            print("graph finished: ", i)

    except:
        print("Exception Occured ", i, filename, terminals)

    print("skipped Items:", skipped)
    return datasetItem

# def graph_with_max_edges(graph_list):
#     # Initialize variables to keep track of the graph with the max number of edges
#     max_nodes = -1
#     max_graph = None
#
#     # Iterate over all graphs in the list
#     for key, graph in graph_list.items():
#         num_nodes = graph.number_of_nodes()
#         if num_nodes > max_nodes:
#             max_nodes = num_nodes
#             max_graph = graph
#
#     return max_graph, max_nodes

# N = 250  # Desired size of the extended matrix
# original_matrix = np.random.randint(1, 10, size=(200, 200))  # Creating a random 200 x 200 matrix
#
# extended_matrix = extend_matrix(original_matrix, N)
# print("Original Matrix:")
# print(original_matrix)
# print("\nExtended Matrix:")
# print(extended_matrix)

In [6]:
s = torch.Tensor([[0,1,0],[0,1,0],[0,0,1]])
# print(calculate_HA_vectorized(s))
# print(calculate_HA_vectorized(torch.Tensor([[0,0.9,0.9],[0.9,0.9,0],[0,0,0.9]])))
terminal_loss = torch.abs(s[0] - s[1]-s[2])
# print(terminal_loss)
# print(10 * (1 - terminal_loss))
# print(torch.sum(10 * (1 - terminal_loss)))
print(torch.abs(s[0] - s[1]))
print(torch.abs(s[0] - s[2]))
print(torch.abs(s[2] - s[1]))

print(torch.sum(10 * (1-torch.abs(s[0] - s[1]))))
print(torch.sum(10 * (1-torch.abs(s[0] - s[2]))))
print(torch.sum(10 * (1-torch.abs(s[2] - s[1]))))
print(terminal_independence_penalty(s, [0,1,2]))

tensor([0., 0., 0.])
tensor([0., 1., 1.])
tensor([0., 1., 1.])
tensor(30.)
tensor(10.)
tensor(10.)
tensor(1.)


In [7]:
def generate_graph(n, d=None, p=None, graph_type='reg', random_seed=0):
    """
    Helper function to generate a NetworkX random graph of specified type,
    given specified parameters (e.g. d-regular, d=3). Must provide one of
    d or p, d with graph_type='reg', and p with graph_type in ['prob', 'erdos'].

    Input:
        n: Problem size
        d: [Optional] Degree of each node in graph
        p: [Optional] Probability of edge between two nodes
        graph_type: Specifies graph type to generate
        random_seed: Seed value for random generator
    Output:
        nx_graph: NetworkX OrderedGraph of specified type and parameters
    """
    if graph_type == 'reg':
        print(f'Generating d-regular graph with n={n}, d={d}, seed={random_seed}')
        nx_temp = nx.random_regular_graph(d=d, n=n, seed=random_seed)
    elif graph_type == 'reg_random':
        print(f'Generating d-regular random graph with n={n}, d={d}')
        nx_temp = nx.random_regular_graph(d=d, n=n)
    elif graph_type == 'prob':
        print(f'Generating p-probabilistic graph with n={n}, p={p}, seed={random_seed}')
        nx_temp = nx.fast_gnp_random_graph(n, p, seed=random_seed)
    elif graph_type == 'erdos':
        print(f'Generating erdos-renyi graph with n={n}, p={p}, seed={random_seed}')
        nx_temp = nx.erdos_renyi_graph(n, p, seed=random_seed)
    else:
        raise NotImplementedError(f'!! Graph type {graph_type} not handled !!')

    # Networkx does not enforce node order by default
    nx_temp = nx.relabel.convert_node_labels_to_integers(nx_temp)
    # Need to pull nx graph into OrderedGraph so training will work properly
    nx_graph = nx.Graph()
    nx_graph.add_nodes_from(sorted(nx_temp.nodes()))
    nx_graph.add_edges_from(nx_temp.edges)
    nx_graph.order()
    return nx_graph

# Generating Graphs

### Generating graph 200 training graph

- Nodes = 80
- Degree  = 3

In [8]:
# nx_generated_graph = {}
#
# for i in range (200):
#     nx_graph = generate_graph(n=80, d=3, p=None, graph_type='reg', random_seed=i)
#
#     for u, v, d in nx_graph.edges(data=True):
#         d['weight'] = 1
#         d['capacity'] = 1
#
#     graph_dgl = dgl.from_networkx(nx_graph=nx_graph)
#     graph_dgl = graph_dgl.to(TORCH_DEVICE)
#     q_torch = qubo_dict_to_torch(nx_graph, gen_adj_matrix(nx_graph), torch_dtype=TORCH_DTYPE, torch_device=TORCH_DEVICE)
#     terminals = [10,40,70]
#     nx_generated_graph[i] = [graph_dgl, q_torch, nx_graph, terminals]
#


In [9]:
# save_object(nx_generated_graph, './testData/nx_generated_graph_n80_d3_t200.pkl')

### Generating graph 200 training graph

- Nodes = 500
- Degree  = 3

In [10]:
# nx_generated_graph = {}
#
# for i in range (200):
#     nx_graph = generate_graph(n=500, d=3, p=None, graph_type='reg', random_seed=i)
#
#     for u, v, d in nx_graph.edges(data=True):
#         d['weight'] = 1
#         d['capacity'] = 1
#
#     graph_dgl = dgl.from_networkx(nx_graph=nx_graph)
#     graph_dgl = graph_dgl.to(TORCH_DEVICE)
#     q_torch = qubo_dict_to_torch(nx_graph, gen_adj_matrix(nx_graph), torch_dtype=TORCH_DTYPE, torch_device=TORCH_DEVICE)
#     terminals = [100,450,700]
#     nx_generated_graph[i] = [graph_dgl, q_torch, nx_graph, terminals]



In [11]:
# save_object(nx_generated_graph, './testData/nx_generated_graph_n500_d3_t200.pkl')

### Generating graph 200 training graph

- Nodes = 1000
- Degree  = 3

In [12]:
# nx_generated_graph = {}
#
# for i in range (200):
#     nx_graph = generate_graph(n=1000, d=3, p=None, graph_type='reg', random_seed=i)
#
#     for u, v, d in nx_graph.edges(data=True):
#         d['weight'] = 1
#         d['capacity'] = 1
#
#     graph_dgl = dgl.from_networkx(nx_graph=nx_graph)
#     graph_dgl = graph_dgl.to(TORCH_DEVICE)
#     q_torch = qubo_dict_to_torch(nx_graph, gen_adj_matrix(nx_graph), torch_dtype=TORCH_DTYPE, torch_device=TORCH_DEVICE)
#     terminals = [200,400,700]
#     nx_generated_graph[i] = [graph_dgl, q_torch, nx_graph, terminals]



In [13]:
# save_object(nx_generated_graph, './testData/nx_generated_graph_n1000_d3_t200.pkl')

# Generating random graph with nodes 200-500 and edges 7-8

In [8]:
nx_generated_graph = {}
terminals = {}

for i in range (400,420,1):
    nodes = random.randint(600,800)
    degree = random.randint(6,8)
    if (nodes * degree) % 2 != 0:
        i-=1
        continue
    nx_graph = generate_graph(n=nodes, d=degree, p=None, graph_type='reg', random_seed=i)

    for u, v, d in nx_graph.edges(data=True):
        d['weight'] = 1
        d['capacity'] = 1

    # graph_dgl = dgl.from_networkx(nx_graph=nx_graph)
    # graph_dgl = graph_dgl.to(TORCH_DEVICE)
    # q_torch = qubo_dict_to_torch(nx_graph, gen_adj_matrix(nx_graph), torch_dtype=TORCH_DTYPE, torch_device=TORCH_DEVICE)
    unique_random_numbers = generate_unique_random_numbers(nodes)

    # nx_generated_graph[i] = [graph_dgl, q_torch, nx_graph, terminals]
    nx_generated_graph[i] = nx_graph
    terminals[i] = unique_random_numbers
# #


Generating d-regular graph with n=671, d=6, seed=400
Generating d-regular graph with n=751, d=8, seed=401
Generating d-regular graph with n=666, d=7, seed=402
Generating d-regular graph with n=646, d=7, seed=403
Generating d-regular graph with n=780, d=6, seed=405
Generating d-regular graph with n=641, d=6, seed=406
Generating d-regular graph with n=694, d=8, seed=407
Generating d-regular graph with n=769, d=8, seed=408
Generating d-regular graph with n=661, d=8, seed=409
Generating d-regular graph with n=650, d=8, seed=410
Generating d-regular graph with n=636, d=8, seed=411
Generating d-regular graph with n=608, d=7, seed=412
Generating d-regular graph with n=624, d=8, seed=413
Generating d-regular graph with n=764, d=7, seed=414
Generating d-regular graph with n=756, d=6, seed=415
Generating d-regular graph with n=610, d=6, seed=416
Generating d-regular graph with n=681, d=8, seed=417
Generating d-regular graph with n=733, d=6, seed=418
Generating d-regular graph with n=645, d=8, se

In [9]:
# graph_with_max_edges(nx_generated_graph)

In [10]:
ds = createGraphFromFolder_full(nx_generated_graph, terminals, 800)
ds

graph finished:  1
graph finished:  2
graph finished:  3
graph finished:  4
graph finished:  5
graph finished:  6
graph finished:  7
graph finished:  8
graph finished:  9
graph finished:  10
graph finished:  11
graph finished:  12
graph finished:  13
graph finished:  14
graph finished:  15
graph finished:  16
graph finished:  17
graph finished:  18
graph finished:  19
0


{0: [Graph(num_nodes=671, num_edges=4026,
        ndata_schemes={}
        edata_schemes={}),
  tensor([[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]),
  <networkx.classes.graph.Graph at 0x105828640>,
  [0, 1, 2]],
 1: [Graph(num_nodes=751, num_edges=6008,
        ndata_schemes={}
        edata_schemes={}),
  tensor([[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]),
  <networkx.classes.graph.Graph at 0x14b7459c0>,
  [0, 1, 2]],
 2: [Graph(num_nodes=666, num_edges=4662,
        ndata_schemes={}
        edata_schemes={}),
  tensor([[0., 0., 0.,  ..., 0., 0., 0.],
          [0

In [11]:
len(nx_generated_graph)

19

In [16]:
# import os
# mypath = './testData/'
# if not os.path.isdir(mypath):
#     print("creates")
#     os.makedirs(mypath)

In [17]:
save_object(ds, 'testData/NX_testingGraph_800.pkl')

In [21]:
import os
cwd = os.getcwd()
cwd

'/Users/javaad/Documents/research/COP'

# Generating random graph with nodes 500-800 and edges 6-10


In [83]:
nx_generated_graph = {}
terminals = {}

for i in range (400):
    nodes = random.randint(500,800)
    degree = random.randint(6,8)
    if (nodes * degree) % 2 != 0:
        i-=1
        continue
    nx_graph = generate_graph(n=nodes, d=degree, p=None, graph_type='reg', random_seed=i)

    for u, v, d in nx_graph.edges(data=True):
        d['weight'] = 1
        d['capacity'] = 1

    unique_random_numbers = generate_unique_random_numbers(nodes)

    nx_generated_graph[i] = nx_graph
    terminals[i] = unique_random_numbers
# #


Generating d-regular graph with n=619, d=6, seed=0
Generating d-regular graph with n=784, d=6, seed=1
Generating d-regular graph with n=766, d=8, seed=6
Generating d-regular graph with n=530, d=7, seed=7
Generating d-regular graph with n=626, d=8, seed=8
Generating d-regular graph with n=796, d=7, seed=9
Generating d-regular graph with n=553, d=8, seed=10
Generating d-regular graph with n=659, d=6, seed=11
Generating d-regular graph with n=699, d=6, seed=12
Generating d-regular graph with n=712, d=6, seed=13
Generating d-regular graph with n=705, d=8, seed=14
Generating d-regular graph with n=569, d=6, seed=15
Generating d-regular graph with n=640, d=7, seed=16
Generating d-regular graph with n=618, d=7, seed=18
Generating d-regular graph with n=792, d=8, seed=19
Generating d-regular graph with n=569, d=8, seed=22
Generating d-regular graph with n=652, d=6, seed=23
Generating d-regular graph with n=763, d=8, seed=24
Generating d-regular graph with n=705, d=8, seed=25
Generating d-regul

In [91]:
ds = createGraphFromFolder_full(nx_generated_graph, terminals, 800)


17 24 [654, 264, 763] [(3, 626), (3, 633), (3, 118), (3, 94), (3, 5), (3, 23), (3, 763), (3, 1), (4, 5), (4, 258), (4, 619), (4, 730), (4, 525), (4, 219), (4, 167), (4, 715), (5, 311), (5, 601), (5, 195), (5, 548), (5, 741), (5, 112), (6, 7), (6, 344), (6, 228), (6, 593), (6, 381), (6, 496), (6, 171), (6, 681), (7, 194), (7, 173), (7, 233), (7, 757), (7, 109), (7, 662), (7, 1), (8, 9), (8, 440), (8, 412), (8, 165), (8, 702), (8, 286), (8, 100), (8, 612), (9, 383), (9, 36), (9, 544), (9, 352), (9, 229), (9, 759), (9, 678), (10, 11), (10, 200), (10, 251), (10, 595), (10, 342), (10, 96), (10, 278), (10, 253), (11, 358), (11, 605), (11, 737), (11, 105), (11, 286), (11, 294), (11, 680), (12, 13), (12, 305), (12, 17), (12, 604), (12, 119), (12, 489), (12, 710), (12, 263), (13, 389), (13, 129), (13, 460), (13, 141), (13, 756), (13, 242), (13, 345), (14, 15), (14, 53), (14, 510), (14, 355), (14, 738), (14, 385), (14, 449), (14, 483), (15, 534), (15, 235), (15, 649), (15, 262), (15, 653), (15, 

In [92]:
ds[0][1].size()

torch.Size([619, 800])

In [93]:
save_object(ds, './testData/nx_generated_graph_n800_500_d6_8_t300.pkl')

## Generating random graph with nodes 2000-4000 and edges 8-12


In [18]:
nx_generated_graph = {}
terminals = {}

for i in range (600):
    nodes = random.randint(2000,4000)
    degree = random.randint(8,12)
    if (nodes * degree) % 2 != 0:
        i-=1
        continue
    nx_graph = generate_graph(n=nodes, d=degree, p=None, graph_type='reg', random_seed=i)

    for u, v, d in nx_graph.edges(data=True):
        d['weight'] = 1
        d['capacity'] = 1

    unique_random_numbers = generate_unique_random_numbers(nodes)

    nx_generated_graph[i] = nx_graph
    terminals[i] = unique_random_numbers
# #


Generating d-regular graph with n=2488, d=10, seed=0
Generating d-regular graph with n=3779, d=8, seed=1
Generating d-regular graph with n=2192, d=10, seed=2
Generating d-regular graph with n=3524, d=9, seed=3
Generating d-regular graph with n=2883, d=10, seed=4
Generating d-regular graph with n=2692, d=9, seed=6
Generating d-regular graph with n=3846, d=10, seed=8
Generating d-regular graph with n=3168, d=8, seed=9
Generating d-regular graph with n=2021, d=10, seed=11
Generating d-regular graph with n=2502, d=11, seed=12
Generating d-regular graph with n=3468, d=11, seed=13
Generating d-regular graph with n=2832, d=9, seed=14
Generating d-regular graph with n=2839, d=10, seed=15
Generating d-regular graph with n=2619, d=10, seed=16
Generating d-regular graph with n=2569, d=12, seed=17
Generating d-regular graph with n=2639, d=8, seed=18
Generating d-regular graph with n=3453, d=8, seed=19
Generating d-regular graph with n=3089, d=8, seed=20
Generating d-regular graph with n=2261, d=12

KeyboardInterrupt: 

In [15]:
ds = createGraphFromFolder_full(nx_generated_graph, terminals, 4000)

0


In [10]:
save_object(ds, './testData/nx_generated_graph_n2000_4000_d8_12_t300.pkl')

In [14]:
len(nx_generated_graph)

238

5

## Generating random Testing graph with nodes 2000-4000 and edges 8-12


In [8]:
nx_generated_graph = {}
terminals = {}

for i in range (400, 405, 1):
    nodes = random.randint(2000,4000)
    degree = random.randint(8,12)
    if (nodes * degree) % 2 != 0:
        i-=1
        continue
    nx_graph = generate_graph(n=nodes, d=degree, p=None, graph_type='reg', random_seed=i)

    for u, v, d in nx_graph.edges(data=True):
        d['weight'] = 1
        d['capacity'] = 1

    unique_random_numbers = generate_unique_random_numbers(nodes)

    nx_generated_graph[i] = nx_graph
    terminals[i] = unique_random_numbers
# #


Generating d-regular graph with n=2999, d=12, seed=400
Generating d-regular graph with n=3767, d=12, seed=401
Generating d-regular graph with n=3457, d=8, seed=402
Generating d-regular graph with n=3478, d=12, seed=403
Generating d-regular graph with n=3600, d=12, seed=404


In [9]:
ds = createGraphFromFolder_full(nx_generated_graph, terminals, 4000)

0


In [11]:
save_object(ds, './testData/nx_test_generated_graph_n2000_4000_d8_12_t300.pkl')

In [10]:
ds

{0: [Graph(num_nodes=2999, num_edges=35988,
        ndata_schemes={}
        edata_schemes={}),
  tensor([[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]),
  <networkx.classes.graph.Graph at 0x105afe950>,
  [0, 1, 2]],
 1: [Graph(num_nodes=3767, num_edges=45204,
        ndata_schemes={}
        edata_schemes={}),
  tensor([[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]),
  <networkx.classes.graph.Graph at 0x105afea40>,
  [0, 1, 2]],
 2: [Graph(num_nodes=3457, num_edges=27656,
        ndata_schemes={}
        edata_schemes={}),
  tensor([[0., 0., 0.,  ..., 0., 0., 0.],
      

## Generating random Testing graph with nodes 2000-5000 and edges 8-12

In [19]:
nx_generated_graph = {}
terminals = {}

for i in range (0, 500, 1):
    nodes = random.randint(2000,5000)
    degree = random.randint(8,12)
    if (nodes * degree) % 2 != 0:
        i-=1
        continue
    nx_graph = generate_graph(n=nodes, d=degree, p=None, graph_type='reg', random_seed=i)

    for u, v, d in nx_graph.edges(data=True):
        d['weight'] = 1
        d['capacity'] = 1

    unique_random_numbers = generate_unique_random_numbers(nodes)

    nx_generated_graph[i] = nx_graph
    terminals[i] = unique_random_numbers
# #


Generating d-regular graph with n=4428, d=12, seed=0
Generating d-regular graph with n=2860, d=9, seed=1
Generating d-regular graph with n=4252, d=11, seed=2
Generating d-regular graph with n=3365, d=10, seed=3
Generating d-regular graph with n=4871, d=8, seed=4
Generating d-regular graph with n=4567, d=10, seed=5
Generating d-regular graph with n=4710, d=11, seed=6
Generating d-regular graph with n=2379, d=8, seed=7
Generating d-regular graph with n=3981, d=10, seed=8
Generating d-regular graph with n=4142, d=12, seed=9
Generating d-regular graph with n=4037, d=12, seed=10
Generating d-regular graph with n=3456, d=8, seed=11
Generating d-regular graph with n=2904, d=8, seed=12
Generating d-regular graph with n=4217, d=12, seed=14
Generating d-regular graph with n=4311, d=10, seed=15
Generating d-regular graph with n=3109, d=8, seed=16
Generating d-regular graph with n=3072, d=8, seed=17
Generating d-regular graph with n=2446, d=12, seed=18
Generating d-regular graph with n=3378, d=10,

In [22]:
len(nx_generated_graph)

575

In [21]:
# nx_generated_graph = {}
# terminals = {}

for i in range (501, 700, 1):
    nodes = random.randint(300,800)
    degree = random.randint(8,12)
    if (nodes * degree) % 2 != 0:
        i-=1
        continue
    nx_graph = generate_graph(n=nodes, d=degree, p=None, graph_type='reg', random_seed=i)

    for u, v, d in nx_graph.edges(data=True):
        d['weight'] = 1
        d['capacity'] = 1

    unique_random_numbers = generate_unique_random_numbers(nodes)

    nx_generated_graph[i] = nx_graph
    terminals[i] = unique_random_numbers
# #


Generating d-regular graph with n=394, d=10, seed=502
Generating d-regular graph with n=758, d=8, seed=503
Generating d-regular graph with n=509, d=10, seed=504
Generating d-regular graph with n=355, d=12, seed=505
Generating d-regular graph with n=308, d=8, seed=506
Generating d-regular graph with n=418, d=10, seed=507
Generating d-regular graph with n=772, d=8, seed=508
Generating d-regular graph with n=333, d=12, seed=509
Generating d-regular graph with n=541, d=12, seed=510
Generating d-regular graph with n=366, d=10, seed=511
Generating d-regular graph with n=502, d=10, seed=513
Generating d-regular graph with n=594, d=8, seed=514
Generating d-regular graph with n=764, d=8, seed=515
Generating d-regular graph with n=418, d=10, seed=517
Generating d-regular graph with n=682, d=12, seed=518
Generating d-regular graph with n=776, d=8, seed=519
Generating d-regular graph with n=512, d=11, seed=521
Generating d-regular graph with n=730, d=9, seed=522
Generating d-regular graph with n=5

In [38]:
ds = createGraphFromFolder_full(nx_generated_graph, terminals, 10000)

Terminal swapped  0
graph finished:  1
Terminal swapped  1
graph finished:  2
Terminal swapped  2
graph finished:  3
Terminal swapped  3
graph finished:  4
Terminal swapped  4
graph finished:  5
Terminal swapped  5
graph finished:  6
Terminal swapped  6
graph finished:  7
Terminal swapped  7
graph finished:  8
Terminal swapped  8
graph finished:  9
Terminal swapped  9
graph finished:  10
Terminal swapped  10
graph finished:  11
Terminal swapped  11
graph finished:  12
Terminal swapped  12
graph finished:  13
Exception Occured  13 14 [2232, 270, 1913]
skipped Items: 0


In [33]:
len(nx_generated_graph)

575

In [39]:
len(ds)

13

In [30]:
terminals

{0: [3240, 2826, 943],
 1: [1382, 1531, 2600],
 2: [232, 3712, 1061],
 3: [3209, 1204, 937],
 4: [2928, 899, 4780],
 5: [1101, 2692, 2621],
 6: [3955, 826, 1262],
 7: [96, 815, 2188],
 8: [1810, 2088, 1495],
 9: [895, 485, 3230],
 10: [1078, 3609, 2647],
 11: [1644, 3050, 1548],
 12: [2232, 270, 1913],
 14: [3236, 1861, 1405],
 15: [2214, 1527, 1402],
 16: [983, 593, 341],
 17: [1836, 1425, 1208],
 18: [2353, 249, 921],
 19: [1390, 222, 1110],
 20: [1225, 1381, 648],
 22: [1662, 1749, 2095],
 23: [1726, 2043, 2016],
 24: [100, 4476, 3181],
 25: [1320, 2730, 2048],
 26: [1039, 311, 562],
 27: [4255, 2352, 439],
 30: [2525, 2757, 2684],
 31: [3275, 3909, 4171],
 32: [3651, 1802, 2725],
 33: [4932, 3051, 3723],
 34: [1909, 1148, 1701],
 35: [1995, 133, 1413],
 36: [718, 862, 2406],
 37: [3501, 2764, 3163],
 39: [4050, 26, 3426],
 40: [2967, 78, 268],
 41: [2579, 3218, 2903],
 43: [4276, 1356, 2923],
 44: [1498, 1476, 1953],
 45: [2035, 1141, 605],
 46: [1052, 3050, 3249],
 47: [1336, 1058

In [23]:
save_object(ds, './testData/nx_test_generated_graph_n800_4000_d8_12_t500.pkl')

# Training Code

In [8]:
def train1(modelName, filename = './testData/nx_generated_graph_n80_d3_t200.pkl', n = 80):
    n, d, p, graph_type, number_epochs, learning_rate, PROB_THRESHOLD, tol, patience, dim_embedding, hidden_dim = hyperParameters(learning_rate=0.001, n=n,patience=20)

    # Establish pytorch GNN + optimizer
    opt_params = {'lr': learning_rate}
    gnn_hypers = {
        'dim_embedding': dim_embedding,
        'hidden_dim': hidden_dim,
        'dropout': 0.0,
        'number_classes': 3,
        'prob_threshold': PROB_THRESHOLD,
        'number_epochs': number_epochs,
        'tolerance': tol,
        'patience': patience,
        'nodes':n
    }
    datasetItem = open_file(filename)
    # print(datasetItem)
    # datasetItem_all = {}
    # for key, (dgl_graph, adjacency_matrix,graph) in datasetItem.items():
    #     A, C = FIndAC(graph)
    #     datasetItem_all[key] = [dgl_graph, adjacency_matrix, graph, A, C]

    # print(len(datasetItem), datasetItem[0][3])
    # datasetItem_2 = {}
    # datasetItem_2[0]=datasetItem[1]
    # print(datasetItem_2)

    net, embed, optimizer = get_gnn(n, gnn_hypers, opt_params, TORCH_DEVICE, TORCH_DTYPE)


    # print(datasetItem[1][2].nodes)
    # # Visualize graph
    # pos = nx.kamada_kawai_layout(datasetItem[1][2])
    # nx.draw(datasetItem[1][2], pos, with_labels=True, node_color=[[.7, .7, .7]])
    # cut_value, (part_1, part_2) = nx.minimum_cut(datasetItem_2[0][2], datasetItem_2[0][3][1], datasetItem_2[0][3][0], flow_func=shortest_augmenting_path)

    # print(cut_value, len(part_1), len(part_2))

    # resultList = []
    # all_indexes = sorted(part_1.union(part_2))
    # # Check membership for each index and append the appropriate pair to the result list
    # for index in all_indexes:
    #     if index in part_1:
    #         resultList.append([1, 0])
    #     elif index in part_2:
    #         resultList.append([0, 1])

    #
    trained_net, bestLost, epoch, inp, lossList= run_gnn_training2(
        datasetItem, net, optimizer, int(1000),
        gnn_hypers['tolerance'], gnn_hypers['patience'], loss_terminal,gnn_hypers['dim_embedding'], gnn_hypers['number_classes'], modelName,  TORCH_DTYPE,  TORCH_DEVICE)

    return trained_net, bestLost, epoch, inp, lossList

def train_2wayNeural(modelName, filename='./testData/prepareDS.pkl'):
    n, d, p, graph_type, number_epochs, learning_rate, PROB_THRESHOLD, tol, patience, dim_embedding, hidden_dim = hyperParameters(learning_rate=0.001, n=4096,patience=20)

    # Establish pytorch GNN + optimizer
    opt_params = {'lr': learning_rate}
    gnn_hypers = {
        'dim_embedding': dim_embedding,
        'hidden_dim': hidden_dim,
        'dropout': 0.0,
        'number_classes': 2,
        'prob_threshold': PROB_THRESHOLD,
        'number_epochs': number_epochs,
        'tolerance': tol,
        'patience': patience,
        'nodes':n
    }
    datasetItem = open_file(filename)
    # print(datasetItem)
    # datasetItem_all = {}
    # for key, (dgl_graph, adjacency_matrix,graph) in datasetItem.items():
    #     A, C = FIndAC(graph)
    #     datasetItem_all[key] = [dgl_graph, adjacency_matrix, graph, A, C]

    # print(len(datasetItem), datasetItem[0][3])
    # datasetItem_2 = {}
    # datasetItem_2[0]=datasetItem[1]
    # print(datasetItem_2)

    net, embed, optimizer = get_gnn(n, gnn_hypers, opt_params, TORCH_DEVICE, TORCH_DTYPE)


    # print(datasetItem[1][2].nodes)
    # # Visualize graph
    # pos = nx.kamada_kawai_layout(datasetItem[1][2])
    # nx.draw(datasetItem[1][2], pos, with_labels=True, node_color=[[.7, .7, .7]])
    # cut_value, (part_1, part_2) = nx.minimum_cut(datasetItem_2[0][2], datasetItem_2[0][3][1], datasetItem_2[0][3][0], flow_func=shortest_augmenting_path)

    # print(cut_value, len(part_1), len(part_2))

    # resultList = []
    # all_indexes = sorted(part_1.union(part_2))
    # # Check membership for each index and append the appropriate pair to the result list
    # for index in all_indexes:
    #     if index in part_1:
    #         resultList.append([1, 0])
    #     elif index in part_2:
    #         resultList.append([0, 1])

    #
    trained_net, bestLost, epoch, inp, lossList= run_gnn_training2(
        datasetItem, net, optimizer, int(500),
        gnn_hypers['tolerance'], gnn_hypers['patience'], loss_terminal,gnn_hypers['dim_embedding'], gnn_hypers['number_classes'], modelName,  TORCH_DTYPE,  TORCH_DEVICE)

    return trained_net, bestLost, epoch, inp, lossList


## Exp 1 - loss

- expriment 5 of modifying the loss function (purely binary input) and find exact loss value (vectorized)
- removing terminal loss

In [24]:
import torch
import torch.nn.functional as F

def max_to_one_hot(tensor):
    # Find the index of the maximum value
    max_index = torch.argmax(tensor)

    # Create a one-hot encoded tensor
    one_hot_tensor = torch.zeros_like(tensor)
    one_hot_tensor[max_index] = 1.0

    one_hot_tensor = one_hot_tensor + tensor - tensor.detach()

    return one_hot_tensor

def apply_max_to_one_hot(output):
    return torch.stack([max_to_one_hot(output[i]) for i in range(output.size(0))])


def run_gnn_training2(dataset, net, optimizer, number_epochs, tol, patience, loss_func, dim_embedding, total_classes=3, save_directory=None, torch_dtype = TORCH_DTYPE, torch_device = TORCH_DEVICE, labels=None):
    """
    Train a GCN model with early stopping.
    """
    # loss for a whole epoch
    prev_loss = float('inf')  # Set initial loss to infinity for comparison
    prev_cummulative_loss = float('inf')
    cummulativeCount = 0
    count = 0  # Patience counter
    best_loss = float('inf')  # Initialize best loss to infinity
    best_model_state = None  # Placeholder for the best model state
    loss_list = []
    epochList = []
    cumulative_loss = 0

    t_gnn_start = time()

    # contains information regarding all terminal nodes for the dataset
    terminal_configs = {}
    epochCount = 0
    criterion = nn.BCELoss()
    A = nn.Parameter(torch.tensor([65.0]))
    C = nn.Parameter(torch.tensor([32.5]))

    embed = nn.Embedding(80, dim_embedding)
    embed = embed.type(torch_dtype).to(torch_device)
    inputs = embed.weight

    for epoch in range(number_epochs):

        cumulative_loss = 0.0  # Reset cumulative loss for each epoch

        for key, (dgl_graph, adjacency_matrix,graph, terminals) in dataset.items():
            epochCount +=1


            # Ensure model is in training mode
            net.train()

            # Pass the graph and the input features to the model
            logits = net(dgl_graph, adjacency_matrix)
            logits = override_fixed_nodes(logits)
            # Apply max to one-hot encoding
            one_hot_output = apply_max_to_one_hot(logits)
            # Compute the loss
            # loss = loss_func(criterion, logits, labels, terminals[0], terminals[1])

            loss = loss_func( one_hot_output, adjacency_matrix)


            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update cumulative loss
            cumulative_loss += loss.item()



            # # Check for early stopping
            if epoch > 0 and (cumulative_loss > prev_loss or abs(prev_loss - cumulative_loss) <= tol):
                count += 1
                if count >= patience: # play around with patience value, try lower one
                    print(f'Stopping early at epoch {epoch}')
                    break
            else:
                count = 0  # Reset patience counter if loss decreases

            # Update best model
            if cumulative_loss < best_loss:
                best_loss = cumulative_loss
                best_model_state = net.state_dict()  # Save the best model state

        loss_list.append(loss)

        # # Early stopping break from the outer loop
        # if count >= patience:
        #     count=0

        prev_loss = cumulative_loss  # Update previous loss

        if epoch % 100 == 0:  # Adjust printing frequency as needed
            print(f'Epoch: {epoch}, Cumulative Loss: {cumulative_loss}')

            if save_directory != None:
                checkpoint = {
                    'epoch': epoch,
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lossList':loss_list,
                    'inputs':inputs}
                torch.save(checkpoint, './epoch'+str(epoch)+'loss'+str(cumulative_loss)+ save_directory)

            if (prev_cummulative_loss == cummulativeCount):
                cummulativeCount+=1

                if cummulativeCount > 4:
                    break
            else:
                prev_cummulative_loss = cumulative_loss


    t_gnn = time() - t_gnn_start

    # Load the best model state
    if best_model_state is not None:
        net.load_state_dict(best_model_state)

    print(f'GNN training took {round(t_gnn, 3)} seconds.')
    print(f'Best cumulative loss: {best_loss}')
    loss = loss_func(logits, adjacency_matrix)
    if save_directory != None:
        checkpoint = {
            'epoch': epoch,
            'model': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lossList':loss_list,
            'inputs':inputs}
        torch.save(checkpoint, './final_'+save_directory)

    return net, best_loss, epoch, inputs, loss_list

class GCNSoftmax(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes, dropout, device):
        super(GCNSoftmax, self).__init__()
        self.dropout_frac = dropout
        self.conv1 = GraphConv(in_feats, hidden_size).to(device)
        self.conv2 = GraphConv(hidden_size, num_classes).to(device)

    def forward(self, g, inputs):
        # Basic forward pass
        h = self.conv1(g, inputs)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout_frac, training=self.training)
        h = self.conv2(g, h)
        h = F.softmax(h, dim=1)  # Apply softmax over the classes dimension
        # h = F.sigmoid(h)
        # h = override_fixed_nodes(h)

        return h

def override_fixed_nodes(h):
    output = h.clone()
    # Set the output for node 0 to [1, 0, 0]
    output[10] = torch.tensor([1.0, 0.0, 0.0],requires_grad=True) + h[10] - h[10].detach()
    # Set the output for node 1 to [0, 1, 0]
    output[40] = torch.tensor([0.0, 1.0, 0.0],requires_grad=True)+ h[40] - h[40].detach()
    # Set the output for node 2 to [0, 0, 1]
    output[70] = torch.tensor([0.0, 0.0, 1.0],requires_grad=True)+ h[70] - h[70].detach()
    return output

def calculate_HC_vectorized(s, adjacency_matrix):
    """
    Compute the minimum cut loss, which is the total weight of edges cut between partitions using vectorized operations.

    Parameters:
    s (torch.Tensor): Binary partition matrix of shape (num_nodes, num_partitions)
    adjacency_matrix (torch.Tensor): Adjacency matrix of the graph of shape (num_nodes, num_nodes)

    Returns:
    torch.Tensor: Scalar loss value representing the total weight of edges cut
    """
    num_nodes, num_partitions = s.shape

    # Compute the partition probability matrix for all partitions
    partition_prob_matrix = s @ s.T

    # Compute the cut value by summing weights of edges that connect nodes in different partitions
    cut_value = adjacency_matrix * (1 - partition_prob_matrix)

    # Sum up the contributions for all edges
    loss = torch.sum(cut_value) / 2  # Divide by 2 to correct for double-counting

    return loss

def Loss(s, adjacency_matrix,  A=1, C=1):
    HC = -1*calculate_HC_vectorized(s, adjacency_matrix)
    return C * HC


def loss_terminal(s, adjacency_matrix,  A=0, C=1, penalty=1000):
    loss = Loss(s, adjacency_matrix, A, C)
    # loss += penalty* terminal_independence_penalty(s, [0,1,2])
    return loss


In [25]:
trained_net, bestLost, epoch, inp, lossList = train1('_80MaxwayCut_LossExp1_loss.pth')


Epoch: 0, Cumulative Loss: -17853.0
Stopping early at epoch 1
Stopping early at epoch 2
Stopping early at epoch 3
Stopping early at epoch 4
Stopping early at epoch 5
Stopping early at epoch 7
Stopping early at epoch 8
Stopping early at epoch 9
Stopping early at epoch 10
Stopping early at epoch 11
Stopping early at epoch 12
Stopping early at epoch 13
Stopping early at epoch 14
Stopping early at epoch 15
Stopping early at epoch 17
Stopping early at epoch 18
Stopping early at epoch 19
Stopping early at epoch 20
Stopping early at epoch 22
Stopping early at epoch 23
Stopping early at epoch 24
Stopping early at epoch 25
Stopping early at epoch 26
Stopping early at epoch 28
Stopping early at epoch 29
Stopping early at epoch 30
Stopping early at epoch 31
Stopping early at epoch 32
Stopping early at epoch 34
Stopping early at epoch 35
Stopping early at epoch 36
Stopping early at epoch 37
Stopping early at epoch 39
Stopping early at epoch 40
Stopping early at epoch 41
Stopping early at epoch 42


## Exp 2 - loss

- expriment 2 of modifying the loss function (purely binary input) and find exact loss value (vectorized)
- removing terminal loss
- graph n=500, d=3

In [29]:
import torch
import torch.nn.functional as F

def max_to_one_hot(tensor):
    # Find the index of the maximum value
    max_index = torch.argmax(tensor)

    # Create a one-hot encoded tensor
    one_hot_tensor = torch.zeros_like(tensor)
    one_hot_tensor[max_index] = 1.0

    one_hot_tensor = one_hot_tensor + tensor - tensor.detach()

    return one_hot_tensor

def apply_max_to_one_hot(output):
    return torch.stack([max_to_one_hot(output[i]) for i in range(output.size(0))])


def run_gnn_training2(dataset, net, optimizer, number_epochs, tol, patience, loss_func, dim_embedding, total_classes=3, save_directory=None, torch_dtype = TORCH_DTYPE, torch_device = TORCH_DEVICE, labels=None):
    """
    Train a GCN model with early stopping.
    """
    # loss for a whole epoch
    prev_loss = float('inf')  # Set initial loss to infinity for comparison
    prev_cummulative_loss = float('inf')
    cummulativeCount = 0
    count = 0  # Patience counter
    best_loss = float('inf')  # Initialize best loss to infinity
    best_model_state = None  # Placeholder for the best model state
    loss_list = []
    epochList = []
    cumulative_loss = 0

    t_gnn_start = time()

    # contains information regarding all terminal nodes for the dataset
    terminal_configs = {}
    epochCount = 0
    criterion = nn.BCELoss()
    A = nn.Parameter(torch.tensor([65.0]))
    C = nn.Parameter(torch.tensor([32.5]))

    embed = nn.Embedding(500, dim_embedding)
    embed = embed.type(torch_dtype).to(torch_device)
    inputs = embed.weight

    for epoch in range(number_epochs):

        cumulative_loss = 0.0  # Reset cumulative loss for each epoch

        for key, (dgl_graph, adjacency_matrix,graph, terminals) in dataset.items():
            epochCount +=1


            # Ensure model is in training mode
            net.train()

            # Pass the graph and the input features to the model
            logits = net(dgl_graph, adjacency_matrix)
            logits = override_fixed_nodes(logits)
            # Apply max to one-hot encoding
            one_hot_output = apply_max_to_one_hot(logits)
            # Compute the loss
            # loss = loss_func(criterion, logits, labels, terminals[0], terminals[1])

            loss = loss_func( one_hot_output, adjacency_matrix)


            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update cumulative loss
            cumulative_loss += loss.item()



            # # Check for early stopping
            if epoch > 0 and (cumulative_loss > prev_loss or abs(prev_loss - cumulative_loss) <= tol):
                count += 1
                if count >= patience: # play around with patience value, try lower one
                    print(f'Stopping early at epoch {epoch}')
                    break
            else:
                count = 0  # Reset patience counter if loss decreases

            # Update best model
            if cumulative_loss < best_loss:
                best_loss = cumulative_loss
                best_model_state = net.state_dict()  # Save the best model state

        loss_list.append(loss)

        # # Early stopping break from the outer loop
        # if count >= patience:
        #     count=0

        prev_loss = cumulative_loss  # Update previous loss

        if epoch % 100 == 0:  # Adjust printing frequency as needed
            print(f'Epoch: {epoch}, Cumulative Loss: {cumulative_loss}')

            if save_directory != None:
                checkpoint = {
                    'epoch': epoch,
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lossList':loss_list,
                    'inputs':inputs}
                torch.save(checkpoint, './epoch'+str(epoch)+'loss'+str(cumulative_loss)+ save_directory)

            if (prev_cummulative_loss == cummulativeCount):
                cummulativeCount+=1

                if cummulativeCount > 4:
                    break
            else:
                prev_cummulative_loss = cumulative_loss


    t_gnn = time() - t_gnn_start

    # Load the best model state
    if best_model_state is not None:
        net.load_state_dict(best_model_state)

    print(f'GNN training took {round(t_gnn, 3)} seconds.')
    print(f'Best cumulative loss: {best_loss}')
    loss = loss_func(logits, adjacency_matrix)
    if save_directory != None:
        checkpoint = {
            'epoch': epoch,
            'model': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lossList':loss_list,
            'inputs':inputs}
        torch.save(checkpoint, './final_'+save_directory)

    return net, best_loss, epoch, inputs, loss_list

class GCNSoftmax(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes, dropout, device):
        super(GCNSoftmax, self).__init__()
        self.dropout_frac = dropout
        self.conv1 = GraphConv(in_feats, hidden_size).to(device)
        self.conv2 = GraphConv(hidden_size, num_classes).to(device)

    def forward(self, g, inputs):
        # Basic forward pass
        h = self.conv1(g, inputs)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout_frac, training=self.training)
        h = self.conv2(g, h)
        h = F.softmax(h, dim=1)  # Apply softmax over the classes dimension
        # h = F.sigmoid(h)
        # h = override_fixed_nodes(h)

        return h

def override_fixed_nodes(h):
    output = h.clone()
    # Set the output for node 0 to [1, 0, 0]
    output[100] = torch.tensor([1.0, 0.0, 0.0],requires_grad=True) + h[10] - h[10].detach()
    # Set the output for node 1 to [0, 1, 0]
    output[300] = torch.tensor([0.0, 1.0, 0.0],requires_grad=True)+ h[40] - h[40].detach()
    # Set the output for node 2 to [0, 0, 1]
    output[450] = torch.tensor([0.0, 0.0, 1.0],requires_grad=True)+ h[70] - h[70].detach()
    return output

def calculate_HC_vectorized(s, adjacency_matrix):
    """
    Compute the minimum cut loss, which is the total weight of edges cut between partitions using vectorized operations.

    Parameters:
    s (torch.Tensor): Binary partition matrix of shape (num_nodes, num_partitions)
    adjacency_matrix (torch.Tensor): Adjacency matrix of the graph of shape (num_nodes, num_nodes)

    Returns:
    torch.Tensor: Scalar loss value representing the total weight of edges cut
    """
    num_nodes, num_partitions = s.shape

    # Compute the partition probability matrix for all partitions
    partition_prob_matrix = s @ s.T

    # Compute the cut value by summing weights of edges that connect nodes in different partitions
    cut_value = adjacency_matrix * (1 - partition_prob_matrix)

    # Sum up the contributions for all edges
    loss = torch.sum(cut_value) / 2  # Divide by 2 to correct for double-counting

    return loss

def Loss(s, adjacency_matrix,  A=1, C=1):
    HC = -1*calculate_HC_vectorized(s, adjacency_matrix)
    return C * HC


def loss_terminal(s, adjacency_matrix,  A=0, C=1, penalty=1000):
    loss = Loss(s, adjacency_matrix, A, C)
    # loss += penalty* terminal_independence_penalty(s, [0,1,2])
    return loss


In [30]:
trained_net, bestLost, epoch, inp, lossList = train1('_80MaxwayCut_LossExp2_loss.pth', './testData/nx_generated_graph_n500_d3_t200.pkl', 500)


Epoch: 0, Cumulative Loss: -119754.0
Stopping early at epoch 1
Stopping early at epoch 2
Stopping early at epoch 3
Stopping early at epoch 5
Stopping early at epoch 6
Stopping early at epoch 7
Stopping early at epoch 9
Stopping early at epoch 10
Stopping early at epoch 12
Stopping early at epoch 13
Stopping early at epoch 14
Stopping early at epoch 16
Stopping early at epoch 17
Stopping early at epoch 18
Stopping early at epoch 20
Stopping early at epoch 21
Stopping early at epoch 23
Stopping early at epoch 24
Stopping early at epoch 25
Stopping early at epoch 27
Stopping early at epoch 28
Stopping early at epoch 29
Stopping early at epoch 30
Stopping early at epoch 32
Stopping early at epoch 33
Stopping early at epoch 34
Stopping early at epoch 36
Stopping early at epoch 37
Stopping early at epoch 38
Stopping early at epoch 40
Stopping early at epoch 41
Stopping early at epoch 42
Stopping early at epoch 43
Stopping early at epoch 44
Stopping early at epoch 45
Stopping early at epoch 4

## Exp 3 - loss

- expriment 3 of modifying the loss function (purely binary input) and find exact loss value (vectorized)
- removing terminal loss
- graph n=1000, d=3

In [9]:
import torch
import torch.nn.functional as F

def max_to_one_hot(tensor):
    # Find the index of the maximum value
    max_index = torch.argmax(tensor)

    # Create a one-hot encoded tensor
    one_hot_tensor = torch.zeros_like(tensor)
    one_hot_tensor[max_index] = 1.0

    one_hot_tensor = one_hot_tensor + tensor - tensor.detach()

    return one_hot_tensor

def apply_max_to_one_hot(output):
    return torch.stack([max_to_one_hot(output[i]) for i in range(output.size(0))])


def run_gnn_training2(dataset, net, optimizer, number_epochs, tol, patience, loss_func, dim_embedding, total_classes=3, save_directory=None, torch_dtype = TORCH_DTYPE, torch_device = TORCH_DEVICE, labels=None):
    """
    Train a GCN model with early stopping.
    """
    # loss for a whole epoch
    prev_loss = float('inf')  # Set initial loss to infinity for comparison
    prev_cummulative_loss = float('inf')
    cummulativeCount = 0
    count = 0  # Patience counter
    best_loss = float('inf')  # Initialize best loss to infinity
    best_model_state = None  # Placeholder for the best model state
    loss_list = []
    epochList = []
    cumulative_loss = 0

    t_gnn_start = time()

    # contains information regarding all terminal nodes for the dataset
    terminal_configs = {}
    epochCount = 0
    criterion = nn.BCELoss()
    A = nn.Parameter(torch.tensor([65.0]))
    C = nn.Parameter(torch.tensor([32.5]))

    embed = nn.Embedding(1000, dim_embedding)
    embed = embed.type(torch_dtype).to(torch_device)
    inputs = embed.weight

    for epoch in range(number_epochs):

        cumulative_loss = 0.0  # Reset cumulative loss for each epoch

        for key, (dgl_graph, adjacency_matrix,graph, terminals) in dataset.items():
            epochCount +=1


            # Ensure model is in training mode
            net.train()

            # Pass the graph and the input features to the model
            logits = net(dgl_graph, adjacency_matrix)
            logits = override_fixed_nodes(logits)
            # Apply max to one-hot encoding
            one_hot_output = apply_max_to_one_hot(logits)
            # Compute the loss
            # loss = loss_func(criterion, logits, labels, terminals[0], terminals[1])

            loss = loss_func( one_hot_output, adjacency_matrix)


            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update cumulative loss
            cumulative_loss += loss.item()



            # # Check for early stopping
            if epoch > 0 and (cumulative_loss > prev_loss or abs(prev_loss - cumulative_loss) <= tol):
                count += 1
                if count >= patience: # play around with patience value, try lower one
                    print(f'Stopping early at epoch {epoch}')
                    break
            else:
                count = 0  # Reset patience counter if loss decreases

            # Update best model
            if cumulative_loss < best_loss:
                best_loss = cumulative_loss
                best_model_state = net.state_dict()  # Save the best model state

        loss_list.append(loss)

        # # Early stopping break from the outer loop
        # if count >= patience:
        #     count=0

        prev_loss = cumulative_loss  # Update previous loss

        if epoch % 100 == 0:  # Adjust printing frequency as needed
            print(f'Epoch: {epoch}, Cumulative Loss: {cumulative_loss}')

            if save_directory != None:
                checkpoint = {
                    'epoch': epoch,
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lossList':loss_list,
                    'inputs':inputs}
                torch.save(checkpoint, './epoch'+str(epoch)+'loss'+str(cumulative_loss)+ save_directory)

            if (prev_cummulative_loss == cummulativeCount):
                cummulativeCount+=1

                if cummulativeCount > 4:
                    break
            else:
                prev_cummulative_loss = cumulative_loss


    t_gnn = time() - t_gnn_start

    # Load the best model state
    if best_model_state is not None:
        net.load_state_dict(best_model_state)

    print(f'GNN training took {round(t_gnn, 3)} seconds.')
    print(f'Best cumulative loss: {best_loss}')
    loss = loss_func(logits, adjacency_matrix)
    if save_directory != None:
        checkpoint = {
            'epoch': epoch,
            'model': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lossList':loss_list,
            'inputs':inputs}
        torch.save(checkpoint, './final_'+save_directory)

    return net, best_loss, epoch, inputs, loss_list

class GCNSoftmax(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes, dropout, device):
        super(GCNSoftmax, self).__init__()
        self.dropout_frac = dropout
        self.conv1 = GraphConv(in_feats, hidden_size).to(device)
        self.conv2 = GraphConv(hidden_size, num_classes).to(device)

    def forward(self, g, inputs):
        # Basic forward pass
        h = self.conv1(g, inputs)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout_frac, training=self.training)
        h = self.conv2(g, h)
        h = F.softmax(h, dim=1)  # Apply softmax over the classes dimension
        # h = F.sigmoid(h)
        # h = override_fixed_nodes(h)

        return h

def override_fixed_nodes(h):
    output = h.clone()
    # Set the output for node 0 to [1, 0, 0]
    output[200] = torch.tensor([1.0, 0.0, 0.0],requires_grad=True) + h[200] - h[200].detach()
    # Set the output for node 1 to [0, 1, 0]
    output[400] = torch.tensor([0.0, 1.0, 0.0],requires_grad=True)+ h[400] - h[400].detach()
    # Set the output for node 2 to [0, 0, 1]
    output[700] = torch.tensor([0.0, 0.0, 1.0],requires_grad=True)+ h[700] - h[700].detach()
    return output

def calculate_HC_vectorized(s, adjacency_matrix):
    """
    Compute the minimum cut loss, which is the total weight of edges cut between partitions using vectorized operations.

    Parameters:
    s (torch.Tensor): Binary partition matrix of shape (num_nodes, num_partitions)
    adjacency_matrix (torch.Tensor): Adjacency matrix of the graph of shape (num_nodes, num_nodes)

    Returns:
    torch.Tensor: Scalar loss value representing the total weight of edges cut
    """
    num_nodes, num_partitions = s.shape

    # Compute the partition probability matrix for all partitions
    partition_prob_matrix = s @ s.T

    # Compute the cut value by summing weights of edges that connect nodes in different partitions
    cut_value = adjacency_matrix * (1 - partition_prob_matrix)

    # Sum up the contributions for all edges
    loss = torch.sum(cut_value) / 2  # Divide by 2 to correct for double-counting

    return loss

def Loss(s, adjacency_matrix,  A=1, C=1):
    HC = -1*calculate_HC_vectorized(s, adjacency_matrix)
    return C * HC


def loss_terminal(s, adjacency_matrix,  A=0, C=1, penalty=1000):
    loss = Loss(s, adjacency_matrix, A, C)
    # loss += penalty* terminal_independence_penalty(s, [0,1,2])
    return loss


In [10]:
trained_net, bestLost, epoch, inp, lossList = train1('_10000MaxwayCut_LossExp3_loss.pth', './testData/nx_generated_graph_n1000_d3_t200.pkl', 1000)


Epoch: 0, Cumulative Loss: -240338.0
Stopping early at epoch 1
Stopping early at epoch 2
Stopping early at epoch 4
Stopping early at epoch 5
Stopping early at epoch 7
Stopping early at epoch 8
Stopping early at epoch 9
Stopping early at epoch 11
Stopping early at epoch 12
Stopping early at epoch 14
Stopping early at epoch 15
Stopping early at epoch 16
Stopping early at epoch 18
Stopping early at epoch 19
Stopping early at epoch 21
Stopping early at epoch 22
Stopping early at epoch 23
Stopping early at epoch 24
Stopping early at epoch 26
Stopping early at epoch 27
Stopping early at epoch 29
Stopping early at epoch 30
Stopping early at epoch 32
Stopping early at epoch 33
Stopping early at epoch 34
Stopping early at epoch 36
Stopping early at epoch 37
Stopping early at epoch 38
Stopping early at epoch 39
Stopping early at epoch 41
Stopping early at epoch 42
Stopping early at epoch 44
Stopping early at epoch 45
Stopping early at epoch 47
Stopping early at epoch 48
Stopping early at epoch 5

## Exp 4 - loss

- expriment 4 of modifying the loss function (purely binary input) and find exact loss value (vectorized)
- removing terminal loss
- graph n=200-500, d=6-8

In [52]:
import torch
import torch.nn.functional as F

def max_to_one_hot(tensor):
    # Find the index of the maximum value
    max_index = torch.argmax(tensor)

    # Create a one-hot encoded tensor
    one_hot_tensor = torch.zeros_like(tensor)
    one_hot_tensor[max_index] = 1.0

    one_hot_tensor = one_hot_tensor + tensor - tensor.detach()

    return one_hot_tensor

def apply_max_to_one_hot(output):
    return torch.stack([max_to_one_hot(output[i]) for i in range(output.size(0))])


def run_gnn_training2(dataset, net, optimizer, number_epochs, tol, patience, loss_func, dim_embedding, total_classes=3, save_directory=None, torch_dtype = TORCH_DTYPE, torch_device = TORCH_DEVICE, labels=None):
    """
    Train a GCN model with early stopping.
    """
    # loss for a whole epoch
    prev_loss = float('inf')  # Set initial loss to infinity for comparison
    prev_cummulative_loss = float('inf')
    cummulativeCount = 0
    count = 0  # Patience counter
    best_loss = float('inf')  # Initialize best loss to infinity
    best_model_state = None  # Placeholder for the best model state
    loss_list = []
    epochList = []
    cumulative_loss = 0

    t_gnn_start = time()

    # contains information regarding all terminal nodes for the dataset
    terminal_configs = {}
    epochCount = 0
    criterion = nn.BCELoss()
    A = nn.Parameter(torch.tensor([65.0]))
    C = nn.Parameter(torch.tensor([32.5]))

    embed = nn.Embedding(1000, dim_embedding)
    embed = embed.type(torch_dtype).to(torch_device)
    inputs = embed.weight

    for epoch in range(number_epochs):

        cumulative_loss = 0.0  # Reset cumulative loss for each epoch

        for key, (dgl_graph, adjacency_matrix,graph, terminals) in dataset.items():
            epochCount +=1


            # Ensure model is in training mode
            net.train()

            # Pass the graph and the input features to the model
            logits = net(dgl_graph, adjacency_matrix)
            logits = override_fixed_nodes(logits)
            # Apply max to one-hot encoding
            one_hot_output = apply_max_to_one_hot(logits)
            # Compute the loss
            # loss = loss_func(criterion, logits, labels, terminals[0], terminals[1])

            loss = loss_func( one_hot_output, adjacency_matrix)


            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update cumulative loss
            cumulative_loss += loss.item()



            # # Check for early stopping
            if epoch > 0 and (cumulative_loss > prev_loss or abs(prev_loss - cumulative_loss) <= tol):
                count += 1
                if count >= patience: # play around with patience value, try lower one
                    print(f'Stopping early at epoch {epoch}')
                    break
            else:
                count = 0  # Reset patience counter if loss decreases

            # Update best model
            if cumulative_loss < best_loss:
                best_loss = cumulative_loss
                best_model_state = net.state_dict()  # Save the best model state

        loss_list.append(loss)

        # # Early stopping break from the outer loop
        # if count >= patience:
        #     count=0

        prev_loss = cumulative_loss  # Update previous loss

        if epoch % 100 == 0:  # Adjust printing frequency as needed
            print(f'Epoch: {epoch}, Cumulative Loss: {cumulative_loss}')

            if save_directory != None:
                checkpoint = {
                    'epoch': epoch,
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lossList':loss_list,
                    'inputs':inputs}
                torch.save(checkpoint, './epoch'+str(epoch)+'loss'+str(cumulative_loss)+ save_directory)

            if (prev_cummulative_loss == cummulativeCount):
                cummulativeCount+=1

                if cummulativeCount > 4:
                    break
            else:
                prev_cummulative_loss = cumulative_loss


    t_gnn = time() - t_gnn_start

    # Load the best model state
    if best_model_state is not None:
        net.load_state_dict(best_model_state)

    print(f'GNN training took {round(t_gnn, 3)} seconds.')
    print(f'Best cumulative loss: {best_loss}')
    loss = loss_func(logits, adjacency_matrix)
    if save_directory != None:
        checkpoint = {
            'epoch': epoch,
            'model': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lossList':loss_list,
            'inputs':inputs}
        torch.save(checkpoint, './final_'+save_directory)

    return net, best_loss, epoch, inputs, loss_list

class GCNSoftmax(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes, dropout, device):
        super(GCNSoftmax, self).__init__()
        self.dropout_frac = dropout
        self.conv1 = GraphConv(in_feats, hidden_size).to(device)
        self.conv2 = GraphConv(hidden_size, num_classes).to(device)

    def forward(self, g, inputs):
        # Basic forward pass
        h = self.conv1(g, inputs)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout_frac, training=self.training)
        h = self.conv2(g, h)
        h = F.softmax(h, dim=1)  # Apply softmax over the classes dimension
        # h = F.sigmoid(h)
        # h = override_fixed_nodes(h)

        return h

def override_fixed_nodes(h):
    output = h.clone()
    # Set the output for node 0 to [1, 0, 0]
    output[0] = torch.tensor([1.0, 0.0, 0.0],requires_grad=True) + h[0] - h[0].detach()
    # Set the output for node 1 to [0, 1, 0]
    output[1] = torch.tensor([0.0, 1.0, 0.0],requires_grad=True)+ h[1] - h[1].detach()
    # Set the output for node 2 to [0, 0, 1]
    output[2] = torch.tensor([0.0, 0.0, 1.0],requires_grad=True)+ h[2] - h[2].detach()
    return output

def calculate_HC_vectorized(s, adjacency_matrix):
    """
    Compute the minimum cut loss, which is the total weight of edges cut between partitions using vectorized operations.

    Parameters:
    s (torch.Tensor): Binary partition matrix of shape (num_nodes, num_partitions)
    adjacency_matrix (torch.Tensor): Adjacency matrix of the graph of shape (num_nodes, num_nodes)

    Returns:
    torch.Tensor: Scalar loss value representing the total weight of edges cut
    """
    num_nodes, num_partitions = s.shape

    # Compute the partition probability matrix for all partitions
    partition_prob_matrix = s @ s.T

    # Compute the cut value by summing weights of edges that connect nodes in different partitions
    cut_value = adjacency_matrix * (1 - partition_prob_matrix)

    # Sum up the contributions for all edges
    loss = torch.sum(cut_value) / 2  # Divide by 2 to correct for double-counting

    return loss

def Loss(s, adjacency_matrix,  A=1, C=1):
    HC = -1*calculate_HC_vectorized(s, adjacency_matrix)
    return C * HC


def loss_terminal(s, adjacency_matrix,  A=0, C=1, penalty=1000):
    loss = Loss(s, adjacency_matrix, A, C)
    # loss += penalty* terminal_independence_penalty(s, [0,1,2])
    return loss


In [53]:
trained_net, bestLost, epoch, inp, lossList = train1('_500MaxwayCut_LossExp4_loss.pth', './testData/nx_generated_graph_n500_d6_8_t300.pkl', 500)


Epoch: 0, Cumulative Loss: -16523.00001525879
Stopping early at epoch 76
Stopping early at epoch 77
Stopping early at epoch 89
Stopping early at epoch 90
Stopping early at epoch 98
Stopping early at epoch 99
Epoch: 100, Cumulative Loss: -28108.0
Stopping early at epoch 104
Stopping early at epoch 105
Stopping early at epoch 106
Stopping early at epoch 107
Stopping early at epoch 108
Stopping early at epoch 109
Stopping early at epoch 113
Stopping early at epoch 114
Stopping early at epoch 115
Stopping early at epoch 116
Stopping early at epoch 120
Stopping early at epoch 121
Stopping early at epoch 122
Stopping early at epoch 123
Stopping early at epoch 124
Stopping early at epoch 125
Stopping early at epoch 126
Stopping early at epoch 127
Stopping early at epoch 128
Stopping early at epoch 129
Stopping early at epoch 130
Stopping early at epoch 137
Stopping early at epoch 138
Stopping early at epoch 139
Stopping early at epoch 140
Stopping early at epoch 141
Stopping early at epoch 14

## Exp 5 - loss

- expriment 5 of modifying the loss function (purely binary input) and find exact loss value (vectorized)
- removing terminal loss
- graph n=200-500, d=6-8

In [81]:
import torch
import torch.nn.functional as F

def max_to_one_hot(tensor):
    # Find the index of the maximum value
    max_index = torch.argmax(tensor)

    # Create a one-hot encoded tensor
    one_hot_tensor = torch.zeros_like(tensor)
    one_hot_tensor[max_index] = 1.0

    one_hot_tensor = one_hot_tensor + tensor - tensor.detach()

    return one_hot_tensor

def apply_max_to_one_hot(output):
    return torch.stack([max_to_one_hot(output[i]) for i in range(output.size(0))])


def run_gnn_training2(dataset, net, optimizer, number_epochs, tol, patience, loss_func, dim_embedding, total_classes=3, save_directory=None, torch_dtype = TORCH_DTYPE, torch_device = TORCH_DEVICE, labels=None):
    """
    Train a GCN model with early stopping.
    """
    # loss for a whole epoch
    prev_loss = float('inf')  # Set initial loss to infinity for comparison
    prev_cummulative_loss = float('inf')
    cummulativeCount = 0
    count = 0  # Patience counter
    best_loss = float('inf')  # Initialize best loss to infinity
    best_model_state = None  # Placeholder for the best model state
    loss_list = []
    epochList = []
    cumulative_loss = 0

    t_gnn_start = time()

    # contains information regarding all terminal nodes for the dataset
    terminal_configs = {}
    epochCount = 0
    criterion = nn.BCELoss()
    A = nn.Parameter(torch.tensor([65.0]))
    C = nn.Parameter(torch.tensor([32.5]))

    embed = nn.Embedding(1000, dim_embedding)
    embed = embed.type(torch_dtype).to(torch_device)
    inputs = embed.weight

    for epoch in range(number_epochs):

        cumulative_loss = 0.0  # Reset cumulative loss for each epoch

        for key, (dgl_graph, adjacency_matrix,graph, terminals) in dataset.items():
            epochCount +=1


            # Ensure model is in training mode
            net.train()

            # Pass the graph and the input features to the model
            logits = net(dgl_graph, adjacency_matrix)
            logits = override_fixed_nodes(logits)
            # Apply max to one-hot encoding
            one_hot_output = apply_max_to_one_hot(logits)
            # Compute the loss
            # loss = loss_func(criterion, logits, labels, terminals[0], terminals[1])

            loss = loss_func( one_hot_output, adjacency_matrix)


            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update cumulative loss
            cumulative_loss += loss.item()



            # # Check for early stopping
            if epoch > 0 and (cumulative_loss > prev_loss or abs(prev_loss - cumulative_loss) <= tol):
                count += 1
                if count >= patience: # play around with patience value, try lower one
                    print(f'Stopping early at epoch {epoch}')
                    break
            else:
                count = 0  # Reset patience counter if loss decreases

            # Update best model
            if cumulative_loss < best_loss:
                best_loss = cumulative_loss
                best_model_state = net.state_dict()  # Save the best model state

        loss_list.append(loss)

        # # Early stopping break from the outer loop
        # if count >= patience:
        #     count=0

        prev_loss = cumulative_loss  # Update previous loss

        if epoch % 100 == 0:  # Adjust printing frequency as needed
            print(f'Epoch: {epoch}, Cumulative Loss: {cumulative_loss}')

            if save_directory != None:
                checkpoint = {
                    'epoch': epoch,
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lossList':loss_list,
                    'inputs':inputs}
                torch.save(checkpoint, './epoch'+str(epoch)+'loss'+str(cumulative_loss)+ save_directory)

            if (prev_cummulative_loss == cummulativeCount):
                cummulativeCount+=1

                if cummulativeCount > 4:
                    break
            else:
                prev_cummulative_loss = cumulative_loss


    t_gnn = time() - t_gnn_start

    # Load the best model state
    if best_model_state is not None:
        net.load_state_dict(best_model_state)

    print(f'GNN training took {round(t_gnn, 3)} seconds.')
    print(f'Best cumulative loss: {best_loss}')
    loss = loss_func(logits, adjacency_matrix)
    if save_directory != None:
        checkpoint = {
            'epoch': epoch,
            'model': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lossList':loss_list,
            'inputs':inputs}
        torch.save(checkpoint, './final_'+save_directory)

    return net, best_loss, epoch, inputs, loss_list

class GCNSoftmax(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes, dropout, device):
        super(GCNSoftmax, self).__init__()
        self.dropout_frac = dropout
        self.conv1 = GraphConv(in_feats, hidden_size).to(device)
        self.conv2 = GraphConv(hidden_size, num_classes).to(device)

    def forward(self, g, inputs):
        # Basic forward pass
        h = self.conv1(g, inputs)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout_frac, training=self.training)
        h = self.conv2(g, h)
        h = F.softmax(h, dim=1)  # Apply softmax over the classes dimension
        # h = F.sigmoid(h)
        # h = override_fixed_nodes(h)

        return h

def override_fixed_nodes(h):
    output = h.clone()
    # Set the output for node 0 to [1, 0, 0]
    output[0] = torch.tensor([1.0, 0.0, 0.0],requires_grad=True) + h[0] - h[0].detach()
    # Set the output for node 1 to [0, 1, 0]
    output[1] = torch.tensor([0.0, 1.0, 0.0],requires_grad=True)+ h[1] - h[1].detach()
    # Set the output for node 2 to [0, 0, 1]
    output[2] = torch.tensor([0.0, 0.0, 1.0],requires_grad=True)+ h[2] - h[2].detach()
    return output

def calculate_HC_vectorized(s, adjacency_matrix):
    """
    Compute the minimum cut loss, which is the total weight of edges cut between partitions using vectorized operations.

    Parameters:
    s (torch.Tensor): Binary partition matrix of shape (num_nodes, num_partitions)
    adjacency_matrix (torch.Tensor): Adjacency matrix of the graph of shape (num_nodes, num_nodes)

    Returns:
    torch.Tensor: Scalar loss value representing the total weight of edges cut
    """
    num_nodes, num_partitions = s.shape

    # Compute the partition probability matrix for all partitions
    partition_prob_matrix = s @ s.T

    # Compute the cut value by summing weights of edges that connect nodes in different partitions
    cut_value = adjacency_matrix * (1 - extend_matrix_torch(partition_prob_matrix, 500))

    # Sum up the contributions for all edges
    loss = torch.sum(cut_value) / 2  # Divide by 2 to correct for double-counting

    return loss

def Loss(s, adjacency_matrix,  A=1, C=1):
    HC = -1*calculate_HC_vectorized(s, adjacency_matrix)
    return C * HC


def loss_terminal(s, adjacency_matrix,  A=0, C=1, penalty=1000):
    loss = Loss(s, adjacency_matrix, A, C)
    # loss += penalty* terminal_independence_penalty(s, [0,1,2])
    return loss


In [82]:
trained_net, bestLost, epoch, inp, lossList = train1('_500MaxwayCut_LossExp5_loss.pth', './testData/nx_generated_graph_n500_200_d6_8_t300.pkl', 500)


Epoch: 0, Cumulative Loss: -158774.0
Stopping early at epoch 1
Stopping early at epoch 2
Stopping early at epoch 3
Stopping early at epoch 4
Stopping early at epoch 5
Stopping early at epoch 7
Stopping early at epoch 8
Stopping early at epoch 10
Stopping early at epoch 11
Stopping early at epoch 12
Stopping early at epoch 14
Stopping early at epoch 15
Stopping early at epoch 16
Stopping early at epoch 17
Stopping early at epoch 18
Stopping early at epoch 19
Stopping early at epoch 21
Stopping early at epoch 22
Stopping early at epoch 23
Stopping early at epoch 25
Stopping early at epoch 26
Stopping early at epoch 27
Stopping early at epoch 28
Stopping early at epoch 30
Stopping early at epoch 31
Stopping early at epoch 33
Stopping early at epoch 34
Stopping early at epoch 35
Stopping early at epoch 37
Stopping early at epoch 38
Stopping early at epoch 40
Stopping early at epoch 41
Stopping early at epoch 42
Stopping early at epoch 43
Stopping early at epoch 44
Stopping early at epoch 4

## Exp 6 - loss

- expriment 5 of modifying the loss function (purely binary input) and find exact loss value (vectorized)
- removing terminal loss
- adding HC as a parameter to cause chaotic-ness
- graph n=200-500, d=6-8

In [None]:
import torch
import torch.nn.functional as F

def max_to_one_hot(tensor):
    # Find the index of the maximum value
    max_index = torch.argmax(tensor)

    # Create a one-hot encoded tensor
    one_hot_tensor = torch.zeros_like(tensor)
    one_hot_tensor[max_index] = 1.0

    one_hot_tensor = one_hot_tensor + tensor - tensor.detach()

    return one_hot_tensor

def apply_max_to_one_hot(output):
    return torch.stack([max_to_one_hot(output[i]) for i in range(output.size(0))])


def run_gnn_training2(dataset, net, optimizer, number_epochs, tol, patience, loss_func, dim_embedding, total_classes=3, save_directory=None, torch_dtype = TORCH_DTYPE, torch_device = TORCH_DEVICE, labels=None):
    """
    Train a GCN model with early stopping.
    """
    # loss for a whole epoch
    prev_loss = float('inf')  # Set initial loss to infinity for comparison
    prev_cummulative_loss = float('inf')
    cummulativeCount = 0
    count = 0  # Patience counter
    best_loss = float('inf')  # Initialize best loss to infinity
    best_model_state = None  # Placeholder for the best model state
    loss_list = []
    epochList = []
    cumulative_loss = 0

    t_gnn_start = time()

    # contains information regarding all terminal nodes for the dataset
    terminal_configs = {}
    epochCount = 0
    criterion = nn.BCELoss()
    A = nn.Parameter(torch.tensor([65.0]))
    C = nn.Parameter(torch.tensor([32.5]))

    embed = nn.Embedding(1000, dim_embedding)
    embed = embed.type(torch_dtype).to(torch_device)
    inputs = embed.weight

    for epoch in range(number_epochs):

        cumulative_loss = 0.0  # Reset cumulative loss for each epoch

        for key, (dgl_graph, adjacency_matrix,graph, terminals) in dataset.items():
            epochCount +=1


            # Ensure model is in training mode
            net.train()

            # Pass the graph and the input features to the model
            logits = net(dgl_graph, adjacency_matrix)
            logits = override_fixed_nodes(logits)
            # Apply max to one-hot encoding
            one_hot_output = apply_max_to_one_hot(logits)
            # Compute the loss
            # loss = loss_func(criterion, logits, labels, terminals[0], terminals[1])

            loss = loss_func( one_hot_output, adjacency_matrix)


            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update cumulative loss
            cumulative_loss += loss.item()



            # # Check for early stopping
            if epoch > 0 and (cumulative_loss > prev_loss or abs(prev_loss - cumulative_loss) <= tol):
                count += 1
                if count >= patience: # play around with patience value, try lower one
                    print(f'Stopping early at epoch {epoch}')
                    break
            else:
                count = 0  # Reset patience counter if loss decreases

            # Update best model
            if cumulative_loss < best_loss:
                best_loss = cumulative_loss
                best_model_state = net.state_dict()  # Save the best model state

        loss_list.append(loss)

        # # Early stopping break from the outer loop
        # if count >= patience:
        #     count=0

        prev_loss = cumulative_loss  # Update previous loss

        if epoch % 100 == 0:  # Adjust printing frequency as needed
            print(f'Epoch: {epoch}, Cumulative Loss: {cumulative_loss}')

            if save_directory != None:
                checkpoint = {
                    'epoch': epoch,
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lossList':loss_list,
                    'inputs':inputs}
                torch.save(checkpoint, './epoch'+str(epoch)+'loss'+str(cumulative_loss)+ save_directory)

            if (prev_cummulative_loss == cummulativeCount):
                cummulativeCount+=1

                if cummulativeCount > 4:
                    break
            else:
                prev_cummulative_loss = cumulative_loss


    t_gnn = time() - t_gnn_start

    # Load the best model state
    if best_model_state is not None:
        net.load_state_dict(best_model_state)

    print(f'GNN training took {round(t_gnn, 3)} seconds.')
    print(f'Best cumulative loss: {best_loss}')
    loss = loss_func(logits, adjacency_matrix)
    if save_directory != None:
        checkpoint = {
            'epoch': epoch,
            'model': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lossList':loss_list,
            'inputs':inputs}
        torch.save(checkpoint, './final_'+save_directory)

    return net, best_loss, epoch, inputs, loss_list

class GCNSoftmax(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes, dropout, device):
        super(GCNSoftmax, self).__init__()
        self.dropout_frac = dropout
        self.conv1 = GraphConv(in_feats, hidden_size).to(device)
        self.conv2 = GraphConv(hidden_size, num_classes).to(device)

    def forward(self, g, inputs):
        # Basic forward pass
        h = self.conv1(g, inputs)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout_frac, training=self.training)
        h = self.conv2(g, h)
        h = F.softmax(h, dim=1)  # Apply softmax over the classes dimension
        # h = F.sigmoid(h)
        # h = override_fixed_nodes(h)

        return h

def override_fixed_nodes(h):
    output = h.clone()
    # Set the output for node 0 to [1, 0, 0]
    output[0] = torch.tensor([1.0, 0.0, 0.0],requires_grad=True) + h[0] - h[0].detach()
    # Set the output for node 1 to [0, 1, 0]
    output[1] = torch.tensor([0.0, 1.0, 0.0],requires_grad=True)+ h[1] - h[1].detach()
    # Set the output for node 2 to [0, 0, 1]
    output[2] = torch.tensor([0.0, 0.0, 1.0],requires_grad=True)+ h[2] - h[2].detach()
    return output

def calculate_HC_vectorized(s, adjacency_matrix):
    """
    Compute the minimum cut loss, which is the total weight of edges cut between partitions using vectorized operations.

    Parameters:
    s (torch.Tensor): Binary partition matrix of shape (num_nodes, num_partitions)
    adjacency_matrix (torch.Tensor): Adjacency matrix of the graph of shape (num_nodes, num_nodes)

    Returns:
    torch.Tensor: Scalar loss value representing the total weight of edges cut
    """
    num_nodes, num_partitions = s.shape

    # Compute the partition probability matrix for all partitions
    partition_prob_matrix = s @ s.T

    # Compute the cut value by summing weights of edges that connect nodes in different partitions
    cut_value = adjacency_matrix * (1 - extend_matrix_torch(partition_prob_matrix, 500))

    # Sum up the contributions for all edges
    loss = torch.sum(cut_value) / 2  # Divide by 2 to correct for double-counting

    return loss

def Loss(s, adjacency_matrix,  A=1, C=1):
    HC = -1*calculate_HC_vectorized(s, adjacency_matrix)
    return C * HC


def loss_terminal(s, adjacency_matrix,  A=0, C=1, penalty=1000):
    loss = Loss(s, adjacency_matrix, A, C)
    # loss += penalty* terminal_independence_penalty(s, [0,1,2])
    return loss


In [None]:
trained_net, bestLost, epoch, inp, lossList = train1('_500MaxwayCut_LossExp6_loss.pth', './testData/nx_generated_graph_n500_200_d6_8_t300.pkl', 500)


## Exp 7 - loss

- expriment 7 of modifying the loss function (purely binary input) and find exact loss value (vectorized)
- removing terminal loss
- adding HC as a parameter to cause chaotic-ness
- graph n=500-800, d=6-8

In [96]:
import torch
import torch.nn.functional as F

def max_to_one_hot(tensor):
    # Find the index of the maximum value
    max_index = torch.argmax(tensor)

    # Create a one-hot encoded tensor
    one_hot_tensor = torch.zeros_like(tensor)
    one_hot_tensor[max_index] = 1.0

    one_hot_tensor = one_hot_tensor + tensor - tensor.detach()

    return one_hot_tensor

def apply_max_to_one_hot(output):
    return torch.stack([max_to_one_hot(output[i]) for i in range(output.size(0))])


def run_gnn_training2(dataset, net, optimizer, number_epochs, tol, patience, loss_func, dim_embedding, total_classes=3, save_directory=None, torch_dtype = TORCH_DTYPE, torch_device = TORCH_DEVICE, labels=None):
    """
    Train a GCN model with early stopping.
    """
    # loss for a whole epoch
    prev_loss = float('inf')  # Set initial loss to infinity for comparison
    prev_cummulative_loss = float('inf')
    cummulativeCount = 0
    count = 0  # Patience counter
    best_loss = float('inf')  # Initialize best loss to infinity
    best_model_state = None  # Placeholder for the best model state
    loss_list = []
    epochList = []
    cumulative_loss = 0

    t_gnn_start = time()

    # contains information regarding all terminal nodes for the dataset
    terminal_configs = {}
    epochCount = 0
    criterion = nn.BCELoss()
    A = nn.Parameter(torch.tensor([65.0]))
    C = nn.Parameter(torch.tensor([32.5]))

    embed = nn.Embedding(1000, dim_embedding)
    embed = embed.type(torch_dtype).to(torch_device)
    inputs = embed.weight

    for epoch in range(number_epochs):

        cumulative_loss = 0.0  # Reset cumulative loss for each epoch

        for key, (dgl_graph, adjacency_matrix,graph, terminals) in dataset.items():
            epochCount +=1


            # Ensure model is in training mode
            net.train()

            # Pass the graph and the input features to the model
            logits = net(dgl_graph, adjacency_matrix)
            logits = override_fixed_nodes(logits)
            # Apply max to one-hot encoding
            one_hot_output = apply_max_to_one_hot(logits)
            # Compute the loss
            # loss = loss_func(criterion, logits, labels, terminals[0], terminals[1])

            loss = loss_func( one_hot_output, adjacency_matrix)


            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update cumulative loss
            cumulative_loss += loss.item()



            # # Check for early stopping
            if epoch > 0 and (cumulative_loss > prev_loss or abs(prev_loss - cumulative_loss) <= tol):
                count += 1
                if count >= patience: # play around with patience value, try lower one
                    print(f'Stopping early at epoch {epoch}')
                    break
            else:
                count = 0  # Reset patience counter if loss decreases

            # Update best model
            if cumulative_loss < best_loss:
                best_loss = cumulative_loss
                best_model_state = net.state_dict()  # Save the best model state

        loss_list.append(loss)

        # # Early stopping break from the outer loop
        # if count >= patience:
        #     count=0

        prev_loss = cumulative_loss  # Update previous loss

        if epoch % 100 == 0:  # Adjust printing frequency as needed
            print(f'Epoch: {epoch}, Cumulative Loss: {cumulative_loss}')

            if save_directory != None:
                checkpoint = {
                    'epoch': epoch,
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lossList':loss_list,
                    'inputs':inputs}
                torch.save(checkpoint, './epoch'+str(epoch)+'loss'+str(cumulative_loss)+ save_directory)

            if (prev_cummulative_loss == cummulativeCount):
                cummulativeCount+=1

                if cummulativeCount > 4:
                    break
            else:
                prev_cummulative_loss = cumulative_loss


    t_gnn = time() - t_gnn_start

    # Load the best model state
    if best_model_state is not None:
        net.load_state_dict(best_model_state)

    print(f'GNN training took {round(t_gnn, 3)} seconds.')
    print(f'Best cumulative loss: {best_loss}')
    loss = loss_func(logits, adjacency_matrix)
    if save_directory != None:
        checkpoint = {
            'epoch': epoch,
            'model': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lossList':loss_list,
            'inputs':inputs}
        torch.save(checkpoint, './final_'+save_directory)

    return net, best_loss, epoch, inputs, loss_list

class GCNSoftmax(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes, dropout, device):
        super(GCNSoftmax, self).__init__()
        self.dropout_frac = dropout
        self.conv1 = GraphConv(in_feats, hidden_size).to(device)
        self.conv2 = GraphConv(hidden_size, num_classes).to(device)

    def forward(self, g, inputs):
        # Basic forward pass
        h = self.conv1(g, inputs)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout_frac, training=self.training)
        h = self.conv2(g, h)
        h = F.softmax(h, dim=1)  # Apply softmax over the classes dimension
        # h = F.sigmoid(h)
        # h = override_fixed_nodes(h)

        return h

def override_fixed_nodes(h):
    output = h.clone()
    # Set the output for node 0 to [1, 0, 0]
    output[0] = torch.tensor([1.0, 0.0, 0.0],requires_grad=True) + h[0] - h[0].detach()
    # Set the output for node 1 to [0, 1, 0]
    output[1] = torch.tensor([0.0, 1.0, 0.0],requires_grad=True)+ h[1] - h[1].detach()
    # Set the output for node 2 to [0, 0, 1]
    output[2] = torch.tensor([0.0, 0.0, 1.0],requires_grad=True)+ h[2] - h[2].detach()
    return output

def calculate_HC_vectorized(s, adjacency_matrix):
    """
    Compute the minimum cut loss, which is the total weight of edges cut between partitions using vectorized operations.

    Parameters:
    s (torch.Tensor): Binary partition matrix of shape (num_nodes, num_partitions)
    adjacency_matrix (torch.Tensor): Adjacency matrix of the graph of shape (num_nodes, num_nodes)

    Returns:
    torch.Tensor: Scalar loss value representing the total weight of edges cut
    """
    num_nodes, num_partitions = s.shape

    # Compute the partition probability matrix for all partitions
    partition_prob_matrix = s @ s.T

    # Compute the cut value by summing weights of edges that connect nodes in different partitions
    cut_value = adjacency_matrix * (1 - extend_matrix_torch(partition_prob_matrix, 800))

    # Sum up the contributions for all edges
    loss = torch.sum(cut_value) / 2  # Divide by 2 to correct for double-counting

    return loss

def Loss(s, adjacency_matrix,  A=1, C=1):
    HC = -1*calculate_HC_vectorized(s, adjacency_matrix)
    return C * HC


def loss_terminal(s, adjacency_matrix,  A=0, C=1, penalty=1000):
    loss = Loss(s, adjacency_matrix, A, C)
    # loss += penalty* terminal_independence_penalty(s, [0,1,2])
    return loss


In [97]:
trained_net, bestLost, epoch, inp, lossList = train1('_800MaxwayCut_LossExp7_loss.pth', './testData/nx_generated_graph_n800_500_d6_8_t300.pkl', 800)


Epoch: 0, Cumulative Loss: -16414.0
Stopping early at epoch 11
Stopping early at epoch 12
Stopping early at epoch 25
Stopping early at epoch 26
Stopping early at epoch 29
Stopping early at epoch 30
Stopping early at epoch 31
Stopping early at epoch 36
Stopping early at epoch 37
Stopping early at epoch 65
Stopping early at epoch 66
Stopping early at epoch 67
Epoch: 100, Cumulative Loss: -34858.0
Stopping early at epoch 104
Stopping early at epoch 105
Stopping early at epoch 125
Stopping early at epoch 126
Stopping early at epoch 127
Stopping early at epoch 128
Stopping early at epoch 129
Stopping early at epoch 130
Stopping early at epoch 131
Stopping early at epoch 132
Stopping early at epoch 133
Stopping early at epoch 134
Stopping early at epoch 135
Stopping early at epoch 136
Stopping early at epoch 137
Stopping early at epoch 138
Stopping early at epoch 139
Stopping early at epoch 140
Stopping early at epoch 141
Stopping early at epoch 151
Stopping early at epoch 152
Stopping early

## Exp 8 - loss

- expriment 8 of modifying the loss function (purely binary input) and find exact loss value (vectorized)
- removing terminal loss
- adding HC as a parameter to cause chaotic-ness
- graph n=2000-4000, d=10-12

In [26]:
import torch
import torch.nn.functional as F

def max_to_one_hot(tensor):
    # Find the index of the maximum value
    max_index = torch.argmax(tensor)

    # Create a one-hot encoded tensor
    one_hot_tensor = torch.zeros_like(tensor)
    one_hot_tensor[max_index] = 1.0

    one_hot_tensor = one_hot_tensor + tensor - tensor.detach()

    return one_hot_tensor

def apply_max_to_one_hot(output):
    return torch.stack([max_to_one_hot(output[i]) for i in range(output.size(0))])


def run_gnn_training2(dataset, net, optimizer, number_epochs, tol, patience, loss_func, dim_embedding, total_classes=3, save_directory=None, torch_dtype = TORCH_DTYPE, torch_device = TORCH_DEVICE, labels=None):
    """
    Train a GCN model with early stopping.
    """
    # loss for a whole epoch
    prev_loss = float('inf')  # Set initial loss to infinity for comparison
    prev_cummulative_loss = float('inf')
    cummulativeCount = 0
    count = 0  # Patience counter
    best_loss = float('inf')  # Initialize best loss to infinity
    best_model_state = None  # Placeholder for the best model state
    loss_list = []
    epochList = []
    cumulative_loss = 0

    t_gnn_start = time()

    # contains information regarding all terminal nodes for the dataset
    terminal_configs = {}
    epochCount = 0
    criterion = nn.BCELoss()
    A = nn.Parameter(torch.tensor([65.0]))
    C = nn.Parameter(torch.tensor([32.5]))

    embed = nn.Embedding(1000, dim_embedding)
    embed = embed.type(torch_dtype).to(torch_device)
    inputs = embed.weight

    for epoch in range(number_epochs):

        cumulative_loss = 0.0  # Reset cumulative loss for each epoch

        for key, (dgl_graph, adjacency_matrix,graph, terminals) in dataset.items():
            epochCount +=1


            # Ensure model is in training mode
            net.train()

            # Pass the graph and the input features to the model
            logits = net(dgl_graph, adjacency_matrix)
            logits = override_fixed_nodes(logits)
            # Apply max to one-hot encoding
            one_hot_output = apply_max_to_one_hot(logits)
            # Compute the loss
            # loss = loss_func(criterion, logits, labels, terminals[0], terminals[1])

            loss = loss_func( one_hot_output, adjacency_matrix)


            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update cumulative loss
            cumulative_loss += loss.item()



            # # Check for early stopping
            if epoch > 0 and (cumulative_loss > prev_loss or abs(prev_loss - cumulative_loss) <= tol):
                count += 1
                if count >= patience: # play around with patience value, try lower one
                    print(f'Stopping early at epoch {epoch}')
                    break
            else:
                count = 0  # Reset patience counter if loss decreases

            # Update best model
            if cumulative_loss < best_loss:
                best_loss = cumulative_loss
                best_model_state = net.state_dict()  # Save the best model state

        loss_list.append(loss)

        # # Early stopping break from the outer loop
        # if count >= patience:
        #     count=0

        prev_loss = cumulative_loss  # Update previous loss

        if epoch % 100 == 0:  # Adjust printing frequency as needed
            print(f'Epoch: {epoch}, Cumulative Loss: {cumulative_loss}')

            if save_directory != None:
                checkpoint = {
                    'epoch': epoch,
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lossList':loss_list,
                    'inputs':inputs}
                torch.save(checkpoint, './epoch'+str(epoch)+'loss'+str(cumulative_loss)+ save_directory)

            if (prev_cummulative_loss == cummulativeCount):
                cummulativeCount+=1

                if cummulativeCount > 4:
                    break
            else:
                prev_cummulative_loss = cumulative_loss


    t_gnn = time() - t_gnn_start

    # Load the best model state
    if best_model_state is not None:
        net.load_state_dict(best_model_state)

    print(f'GNN training took {round(t_gnn, 3)} seconds.')
    print(f'Best cumulative loss: {best_loss}')
    loss = loss_func(logits, adjacency_matrix)
    if save_directory != None:
        checkpoint = {
            'epoch': epoch,
            'model': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lossList':loss_list,
            'inputs':inputs}
        torch.save(checkpoint, './final_'+save_directory)

    return net, best_loss, epoch, inputs, loss_list

class GCNSoftmax(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes, dropout, device):
        super(GCNSoftmax, self).__init__()
        self.dropout_frac = dropout
        self.conv1 = GraphConv(in_feats, hidden_size).to(device)
        self.conv2 = GraphConv(hidden_size, num_classes).to(device)

    def forward(self, g, inputs):
        # Basic forward pass
        h = self.conv1(g, inputs)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout_frac, training=self.training)
        h = self.conv2(g, h)
        h = F.softmax(h, dim=1)  # Apply softmax over the classes dimension
        # h = F.sigmoid(h)
        # h = override_fixed_nodes(h)

        return h

def override_fixed_nodes(h):
    output = h.clone()
    # Set the output for node 0 to [1, 0, 0]
    output[0] = torch.tensor([1.0, 0.0, 0.0],requires_grad=True) + h[0] - h[0].detach()
    # Set the output for node 1 to [0, 1, 0]
    output[1] = torch.tensor([0.0, 1.0, 0.0],requires_grad=True)+ h[1] - h[1].detach()
    # Set the output for node 2 to [0, 0, 1]
    output[2] = torch.tensor([0.0, 0.0, 1.0],requires_grad=True)+ h[2] - h[2].detach()
    return output

def calculate_HC_vectorized(s, adjacency_matrix):
    """
    Compute the minimum cut loss, which is the total weight of edges cut between partitions using vectorized operations.

    Parameters:
    s (torch.Tensor): Binary partition matrix of shape (num_nodes, num_partitions)
    adjacency_matrix (torch.Tensor): Adjacency matrix of the graph of shape (num_nodes, num_nodes)

    Returns:
    torch.Tensor: Scalar loss value representing the total weight of edges cut
    """
    num_nodes, num_partitions = s.shape

    # Compute the partition probability matrix for all partitions
    partition_prob_matrix = s @ s.T

    # Compute the cut value by summing weights of edges that connect nodes in different partitions
    cut_value = adjacency_matrix * (1 - extend_matrix_torch(partition_prob_matrix, 4000))

    # Sum up the contributions for all edges
    loss = torch.sum(cut_value) / 2  # Divide by 2 to correct for double-counting

    return loss

def Loss(s, adjacency_matrix,  A=1, C=1):
    HC = -1*calculate_HC_vectorized(s, adjacency_matrix)
    return C * HC


def loss_terminal(s, adjacency_matrix,  A=0, C=1, penalty=1000):
    loss = Loss(s, adjacency_matrix, A, C)
    # loss += penalty* terminal_independence_penalty(s, [0,1,2])
    return loss


In [27]:
trained_net, bestLost, epoch, inp, lossList = train1('_4000MaxwayCut_LossExp8_loss.pth', './testData/nx_generated_graph_n2000_4000_d8_12_t300.pkl', 4000)


Epoch: 0, Cumulative Loss: -208927.01397705078
Stopping early at epoch 1
Stopping early at epoch 2
Stopping early at epoch 3
Stopping early at epoch 5
Stopping early at epoch 6
Stopping early at epoch 8
Stopping early at epoch 9
Stopping early at epoch 10
Stopping early at epoch 12
Stopping early at epoch 13
Stopping early at epoch 15
Stopping early at epoch 16
Stopping early at epoch 18
Stopping early at epoch 19
Stopping early at epoch 20
Stopping early at epoch 22
Stopping early at epoch 23
Stopping early at epoch 25
Stopping early at epoch 26
Stopping early at epoch 28
Stopping early at epoch 29
Stopping early at epoch 31
Stopping early at epoch 32
Stopping early at epoch 34
Stopping early at epoch 35
Stopping early at epoch 36
Stopping early at epoch 38
Stopping early at epoch 39
Stopping early at epoch 40
Stopping early at epoch 42
Stopping early at epoch 43
Stopping early at epoch 44
Stopping early at epoch 46
Stopping early at epoch 47
Stopping early at epoch 49
Stopping early 

Exp 9 - loss
- expriment 9 of modifying the loss function (purely binary input) and find exact loss value (vectorized)
- removing terminal loss
- Max Graph Node 10000
- graph n=2000-4000, d=10-12


In [9]:
import torch
import torch.nn.functional as F

def max_to_one_hot(tensor):
    # Find the index of the maximum value
    max_index = torch.argmax(tensor)

    # Create a one-hot encoded tensor
    one_hot_tensor = torch.zeros_like(tensor)
    one_hot_tensor[max_index] = 1.0

    one_hot_tensor = one_hot_tensor + tensor - tensor.detach()

    return one_hot_tensor

def apply_max_to_one_hot(output):
    return torch.stack([max_to_one_hot(output[i]) for i in range(output.size(0))])


def run_gnn_training2(dataset, net, optimizer, number_epochs, tol, patience, loss_func, dim_embedding, total_classes=3, save_directory=None, torch_dtype = TORCH_DTYPE, torch_device = TORCH_DEVICE, labels=None):
    """
    Train a GCN model with early stopping.
    """
    # loss for a whole epoch
    prev_loss = float('inf')  # Set initial loss to infinity for comparison
    prev_cummulative_loss = float('inf')
    cummulativeCount = 0
    count = 0  # Patience counter
    best_loss = float('inf')  # Initialize best loss to infinity
    best_model_state = None  # Placeholder for the best model state
    loss_list = []
    epochList = []
    cumulative_loss = 0

    t_gnn_start = time()

    # contains information regarding all terminal nodes for the dataset
    terminal_configs = {}
    epochCount = 0
    criterion = nn.BCELoss()
    A = nn.Parameter(torch.tensor([65.0]))
    C = nn.Parameter(torch.tensor([32.5]))

    embed = nn.Embedding(1000, dim_embedding)
    embed = embed.type(torch_dtype).to(torch_device)
    inputs = embed.weight

    for epoch in range(number_epochs):

        cumulative_loss = 0.0  # Reset cumulative loss for each epoch

        for key, (dgl_graph, adjacency_matrix,graph, terminals) in dataset.items():
            epochCount +=1


            # Ensure model is in training mode
            net.train()

            # Pass the graph and the input features to the model
            logits = net(dgl_graph, adjacency_matrix)
            logits = override_fixed_nodes(logits)
            # Apply max to one-hot encoding
            one_hot_output = apply_max_to_one_hot(logits)
            # Compute the loss
            # loss = loss_func(criterion, logits, labels, terminals[0], terminals[1])

            loss = loss_func( one_hot_output, adjacency_matrix)


            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update cumulative loss
            cumulative_loss += loss.item()



            # # Check for early stopping
            if epoch > 0 and (cumulative_loss > prev_loss or abs(prev_loss - cumulative_loss) <= tol):
                count += 1
                if count >= patience: # play around with patience value, try lower one
                    print(f'Stopping early at epoch {epoch}')
                    break
            else:
                count = 0  # Reset patience counter if loss decreases

            # Update best model
            if cumulative_loss < best_loss:
                best_loss = cumulative_loss
                best_model_state = net.state_dict()  # Save the best model state

        loss_list.append(loss)

        # # Early stopping break from the outer loop
        # if count >= patience:
        #     count=0

        prev_loss = cumulative_loss  # Update previous loss

        if epoch % 100 == 0:  # Adjust printing frequency as needed
            print(f'Epoch: {epoch}, Cumulative Loss: {cumulative_loss}')

            if save_directory != None:
                checkpoint = {
                    'epoch': epoch,
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lossList':loss_list,
                    'inputs':inputs}
                torch.save(checkpoint, './epoch'+str(epoch)+'loss'+str(cumulative_loss)+ save_directory)

            if (prev_cummulative_loss == cummulativeCount):
                cummulativeCount+=1

                if cummulativeCount > 4:
                    break
            else:
                prev_cummulative_loss = cumulative_loss


    t_gnn = time() - t_gnn_start

    # Load the best model state
    if best_model_state is not None:
        net.load_state_dict(best_model_state)

    print(f'GNN training took {round(t_gnn, 3)} seconds.')
    print(f'Best cumulative loss: {best_loss}')
    loss = loss_func(logits, adjacency_matrix)
    if save_directory != None:
        checkpoint = {
            'epoch': epoch,
            'model': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lossList':loss_list,
            'inputs':inputs}
        torch.save(checkpoint, './final_'+save_directory)

    return net, best_loss, epoch, inputs, loss_list

class GCNSoftmax(nn.Module):
    def __init__(self, in_feats, hidden_size, num_classes, dropout, device):
        super(GCNSoftmax, self).__init__()
        self.dropout_frac = dropout
        self.conv1 = GraphConv(in_feats, hidden_size).to(device)
        self.conv2 = GraphConv(hidden_size, num_classes).to(device)

    def forward(self, g, inputs):
        # Basic forward pass
        h = self.conv1(g, inputs)
        h = F.relu(h)
        h = F.dropout(h, p=self.dropout_frac, training=self.training)
        h = self.conv2(g, h)
        h = F.softmax(h, dim=1)  # Apply softmax over the classes dimension
        # h = F.sigmoid(h)
        # h = override_fixed_nodes(h)

        return h

def override_fixed_nodes(h):
    output = h.clone()
    # Set the output for node 0 to [1, 0, 0]
    output[0] = torch.tensor([1.0, 0.0, 0.0],requires_grad=True) + h[0] - h[0].detach()
    # Set the output for node 1 to [0, 1, 0]
    output[1] = torch.tensor([0.0, 1.0, 0.0],requires_grad=True)+ h[1] - h[1].detach()
    # Set the output for node 2 to [0, 0, 1]
    output[2] = torch.tensor([0.0, 0.0, 1.0],requires_grad=True)+ h[2] - h[2].detach()
    return output

def calculate_HC_vectorized(s, adjacency_matrix):
    """
    Compute the minimum cut loss, which is the total weight of edges cut between partitions using vectorized operations.

    Parameters:
    s (torch.Tensor): Binary partition matrix of shape (num_nodes, num_partitions)
    adjacency_matrix (torch.Tensor): Adjacency matrix of the graph of shape (num_nodes, num_nodes)

    Returns:
    torch.Tensor: Scalar loss value representing the total weight of edges cut
    """
    num_nodes, num_partitions = s.shape

    # Compute the partition probability matrix for all partitions
    partition_prob_matrix = s @ s.T

    # Compute the cut value by summing weights of edges that connect nodes in different partitions
    cut_value = adjacency_matrix * (1 - extend_matrix_torch(partition_prob_matrix, 10000))

    # Sum up the contributions for all edges
    loss = torch.sum(cut_value) / 2  # Divide by 2 to correct for double-counting

    return loss

def Loss(s, adjacency_matrix,  A=1, C=1):
    HC = -1*calculate_HC_vectorized(s, adjacency_matrix)
    return C * HC


def loss_terminal(s, adjacency_matrix,  A=0, C=1, penalty=1000):
    loss = Loss(s, adjacency_matrix, A, C)
    # loss += penalty* terminal_independence_penalty(s, [0,1,2])
    return loss


In [None]:
trained_net, bestLost, epoch, inp, lossList = train1('_10000MaxwayCut_LossExp8_loss.pth', './testData/nx_test_generated_graph_n800_4000_d8_12_t500.pkl', 10000)
