In [4]:

from commons import *

Separate code base from heurestics
# Utils code

In [23]:
TORCH_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
TORCH_DTYPE = torch.float32


def partition_weight(adj, s):
    """
    Calculates the sum of weights of edges that are in different partitions.

    :param adj: Adjacency matrix of the graph.
    :param s: List indicating the partition of each edge (0 or 1).
    :return: Sum of weights of edges in different partitions.
    """
    s = np.array(s)
    partition_matrix = np.not_equal.outer(s, s).astype(int)
    weight = (adj * partition_matrix).sum() / 2
    return weight

import torch

def partition_weight2(adj, s):
    """
    Calculates the sum of weights of edges that are in different partitions.

    :param adj: Adjacency matrix of the graph as a PyTorch tensor.
    :param s: Tensor indicating the partition of each node (0 or 1).
    :return: Sum of weights of edges in different partitions.
    """
    # Ensure s is a tensor
    # s = torch.tensor(s, dtype=torch.float32)

    # Compute outer difference to create partition matrix
    s = s.unsqueeze(0)  # Convert s to a row vector
    t = s.t()           # Transpose s to a column vector
    partition_matrix = (s != t).float()  # Compute outer product and convert boolean to float

    # Calculate the weight of edges between different partitions
    weight = (adj * partition_matrix).sum() / 2

    return weight

def calculateAllCut(q_torch, s):
    '''

    :param q_torch: The adjacent matrix of the graph
    :param s: The binary output from the neural network. s will be in form of [[prob1, prob2, ..., prob n], ...]
    :return: The calculated cut loss value
    '''
    if len(s) > 0:
        totalCuts = len(s[0])
        CutValue = 0
        for i in range(totalCuts):
            CutValue += partition_weight2(q_torch, s[:,i])
        return CutValue/2
    return 0

def hyperParameters(n = 100, d = 3, p = None, graph_type = 'reg', number_epochs = int(1e5),
                    learning_rate = 1e-4, PROB_THRESHOLD = 0.5, tol = 1e-4, patience = 100):
    dim_embedding = 80 #int(np.sqrt(4096))    # e.g. 10, used to be the one before
    hidden_dim = int(dim_embedding/2)

    return n, d, p, graph_type, number_epochs, learning_rate, PROB_THRESHOLD, tol, patience, dim_embedding, hidden_dim
def FIndAC(graph):
    max_degree = max(dict(graph.degree()).values())
    A_initial = max_degree + 1  # A is set to be one more than the maximum degree
    C_initial = max_degree / 2  # C is set to half the maximum degree

    return A_initial, C_initial



# Neural Network Model

# Training Neural network

In [20]:


def run_gnn_training2(dataset, net, optimizer, number_epochs, tol, patience, loss_func, dim_embedding, total_classes=3, save_directory=None, torch_dtype = TORCH_DTYPE, torch_device = TORCH_DEVICE, labels=None):
    """
    Train a GCN model with early stopping.
    """
    # loss for a whole epoch
    prev_loss = float('inf')  # Set initial loss to infinity for comparison
    prev_cummulative_loss = float('inf')
    cummulativeCount = 0
    count = 0  # Patience counter
    best_loss = float('inf')  # Initialize best loss to infinity
    best_model_state = None  # Placeholder for the best model state
    loss_list = []
    epochList = []
    cumulative_loss = 0

    t_gnn_start = time()

    # contains information regarding all terminal nodes for the dataset
    terminal_configs = {}
    epochCount = 0
    criterion = nn.BCELoss()
    A = nn.Parameter(torch.tensor([65.0]))
    C = nn.Parameter(torch.tensor([32.5]))

    embed = nn.Embedding(80, dim_embedding)
    embed = embed.type(torch_dtype).to(torch_device)
    inputs = embed.weight

    for epoch in range(number_epochs):

        cumulative_loss = 0.0  # Reset cumulative loss for each epoch

        for key, (dgl_graph, adjacency_matrix,graph, terminals) in dataset.items():
            epochCount +=1


            # Ensure model is in training mode
            net.train()

            # Pass the graph and the input features to the model
            logits = net(dgl_graph, adjacency_matrix)

            # Compute the loss
            # loss = loss_func(criterion, logits, labels, terminals[0], terminals[1])

            loss = loss_func( logits, adjacency_matrix)


            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update cumulative loss
            cumulative_loss += loss.item()



            # # Check for early stopping
            if epoch > 0 and (cumulative_loss > prev_loss or abs(prev_loss - cumulative_loss) <= tol):
                count += 1
                if count >= patience: # play around with patience value, try lower one
                    print(f'Stopping early at epoch {epoch}')
                    break
            else:
                count = 0  # Reset patience counter if loss decreases

            # Update best model
            if cumulative_loss < best_loss:
                best_loss = cumulative_loss
                best_model_state = net.state_dict()  # Save the best model state

        loss_list.append(loss)

        # # Early stopping break from the outer loop
        # if count >= patience:
        #     count=0

        prev_loss = cumulative_loss  # Update previous loss

        if epoch % 100 == 0:  # Adjust printing frequency as needed
            print(f'Epoch: {epoch}, Cumulative Loss: {cumulative_loss}')

            if save_directory != None:
                checkpoint = {
                    'epoch': epoch,
                    'model': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lossList':loss_list,
                    'inputs':inputs}
                torch.save(checkpoint, './epoch'+str(epoch)+'loss'+str(cumulative_loss)+ save_directory)

            if (prev_cummulative_loss == cummulativeCount):
                cummulativeCount+=1

                if cummulativeCount > 4:
                    break
            else:
                prev_cummulative_loss = cumulative_loss


    t_gnn = time() - t_gnn_start

    # Load the best model state
    if best_model_state is not None:
        net.load_state_dict(best_model_state)

    print(f'GNN training took {round(t_gnn, 3)} seconds.')
    print(f'Best cumulative loss: {best_loss}')
    loss = loss_func(logits, adjacency_matrix)
    if save_directory != None:
        checkpoint = {
            'epoch': epoch,
            'model': net.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lossList':loss_list,
            'inputs':inputs}
        torch.save(checkpoint, './final_'+save_directory)

    return net, best_loss, epoch, inputs, loss_list

## HyperParameters initialization and related functions

In [7]:



def printCombo(orig):
    # Original dictionary
    input_dict = orig

    # Generate all permutations of the dictionary values
    value_permutations = list(permutations(input_dict.values()))

    # Create a list of dictionaries from the permutations
    permuted_dicts = [{key: value for key, value in zip(input_dict.keys(), perm)} for perm in value_permutations]

    return permuted_dicts

def GetOptimalNetValue(net, dgl_graph, inp, q_torch, terminal_dict):
    net.eval()
    best_loss = float('inf')

    if (dgl_graph.number_of_nodes() < 30):
        inp = torch.ones((dgl_graph.number_of_nodes(), 30))

    # find all potential combination of terminal nodes with respective indices

    perm_items = printCombo(terminal_dict)
    for i in perm_items:
        probs = net(dgl_graph, inp, i)
        binary_partitions = (probs >= 0.5).float()
        cut_value_item = calculateAllCut(q_torch, binary_partitions)
        if cut_value_item < best_loss:
            best_loss = cut_value_item
    return best_loss



# Hamiltonian loss function

In [8]:
def terminal_independence_penalty(s, terminal_nodes):
    """
    Calculate a penalty that enforces each terminal node to be in a distinct partition.
    :param s: A probability matrix of size |V| x |K| where s[i][j] is the probability of vertex i being in partition j.
    :param terminal_nodes: A list of indices for terminal nodes.
    :return: The penalty term.
    """
    penalty = 0
    num_terminals = len(terminal_nodes)
    # Compare each pair of terminal nodes
    for i in range(num_terminals):
        for j in range(i + 1, num_terminals):
            # Calculate the dot product of the probability vectors for the two terminals
            dot_product = torch.dot(s[terminal_nodes[i]], s[terminal_nodes[j]])
            # Penalize the similarity in their partition assignments (dot product should be close to 0)
            penalty += dot_product
    return penalty

In [9]:
def calculate_HA_vectorized(s):
    """
    Vectorized calculation of HA.
    :param s: A binary matrix of size |V| x |K| where s[i][j] is 1 if vertex i is in partition j.
    :return: The HA value.
    """
    # HA = ∑v∈V(∑k∈K(sv,k)−1)^2
    HA = torch.sum((torch.sum(s, axis=1) - 1) ** 2)
    return HA

def calculate_HC_min_cut_intra_inter(s, adjacency_matrix):
    """
    Vectorized calculation of HC to minimize cut size.
    :param s: A probability matrix of size |V| x |K| where s[i][j] is the probability of vertex i being in partition j.
    :param adjacency_matrix: A matrix representing the graph where the value at [i][j] is the weight of the edge between i and j.
    :return: The HC value focusing on minimizing edge weights between partitions.
    """
    HC = 0
    K = s.shape[1]
    for k in range(K):
        for l in range(k + 1, K):
            partition_k = s[:, k].unsqueeze(1) * s[:, k].unsqueeze(0)  # Probability node pair both in partition k
            partition_l = s[:, l].unsqueeze(1) * s[:, l].unsqueeze(0)  # Probability node pair both in partition l
            # Edges between partitions k and l
            inter_partition_edges = adjacency_matrix * (partition_k + partition_l)
            HC += torch.sum(inter_partition_edges)

    return HC

def calculate_HC_min_cut_intra_inter2(s, adjacency_matrix):
    """
    Vectorized calculation of HC to minimize cut size.
    :param s: A probability matrix of size |V| x |K| where s[i][j] is the probability of vertex i being in partition j.
    :param adjacency_matrix: A matrix representing the graph where the value at [i][j] is the weight of the edge between i and j.
    :return: The HC value focusing on minimizing edge weights between partitions.
    """
    HC = 0
    K = s.shape[1]
    for k in range(K):
        for l in range(k + 1, K):
            partition_k = s[:, k].unsqueeze(1) * s[:, k].unsqueeze(0)  # Probability node pair both in partition k
            partition_l = s[:, l].unsqueeze(1) * s[:, l].unsqueeze(0)  # Probability node pair both in partition l
            # Edges between partitions k and l
            inter_partition_edges = adjacency_matrix * (partition_k + partition_l)
            HC += torch.sum(inter_partition_edges)

    return HC

def calculate_HC_min_cut_new(s, adjacency_matrix):
    """
    Differentiable calculation of HC for minimizing edge weights between different partitions.
    :param s: A probability matrix of size |V| x |K| where s[i][j] is the probability of vertex i being in partition j.
    :param adjacency_matrix: A matrix representing the graph where the value at [i][j] is the weight of the edge between i and j.
    :return: The HC value, focusing on minimizing edge weights between partitions.
    """
    K = s.shape[1]
    V = s.shape[0]

    # Create a full partition matrix indicating the likelihood of each node pair being in the same partition
    partition_matrix = torch.matmul(s, s.T)

    # Calculate the complement matrix, which indicates the likelihood of node pairs being in different partitions
    complement_matrix = 1 - partition_matrix

    # Apply adjacency matrix to only consider actual edges and their weights
    inter_partition_edges = adjacency_matrix * complement_matrix

    # Summing up all contributions for edges between different partitions
    HC = torch.sum(inter_partition_edges)

    return HC

def calculate_HC_vectorized_old(s, adjacency_matrix):
    """
    Vectorized calculation of HC.
    :param s: A binary matrix of size |V| x |K|.
    :param adjacency_matrix: A matrix representing the graph where the value at [i][j] is the weight of the edge between i and j.
    :return: The HC value.
    """
    # HC = ∑(u,v)∈E(1−∑k∈K(su,k*sv,k))*adjacency_matrix[u,v]
    K = s.shape[1]
    # Outer product to find pairs of vertices in the same partition and then weight by the adjacency matrix
    prod = adjacency_matrix * (1 - s @ s.T)
    HC = torch.sum(prod)
    return HC
import torch

def min_cut_loss(s, adjacency_matrix):
    """
    Compute a differentiable min-cut loss for a graph given node partition probabilities.

    :param s: A probability matrix of size |V| x |K| where s[i][j] is the probability of vertex i being in partition j.
    :param adjacency_matrix: A matrix representing the graph where the value at [i][j] is the weight of the edge between i and j.
    :return: The expected min-cut value, computed as a differentiable loss.
    """
    V = s.size(0)  # Number of nodes
    K = s.size(1)  # Number of partitions

    # Ensure the partition matrix s sums to 1 over partitions
    s = torch.softmax(s, dim=1)

    # Compute the expected weight of edges within each partition
    intra_partition_cut = torch.zeros((K, K), dtype=torch.float32)
    for k in range(K):
        for l in range(k + 1, K):
            # Probability that a node pair (i, j) is split between partitions k and l
            partition_k = s[:, k].unsqueeze(1)  # Shape: V x 1
            partition_l = s[:, l].unsqueeze(0)  # Shape: 1 x V

            # Compute the expected weight of the cut edges between partitions k and l
            cut_weight = adjacency_matrix * (partition_k @ partition_l)
            intra_partition_cut[k, l] = torch.sum(cut_weight)

    # Sum up all contributions to get the total expected min-cut value
    total_cut_weight = torch.sum(intra_partition_cut)

    return total_cut_weight

import torch

# def min_cut_loss(s, adjacency_matrix):
#     """
#     Compute a differentiable min-cut loss for a graph given node partition probabilities.
#
#     :param s: A probability matrix of size |V| x |K| where s[i][j] is the probability of vertex i being in partition j.
#     :param adjacency_matrix: A matrix representing the graph where the value at [i][j] is the weight of the edge between i and j.
#     :return: The expected min-cut value, computed as a differentiable loss.
#     """
#     V = s.size(0)  # Number of nodes
#     K = s.size(1)  # Number of partitions
#
#     # Ensure the partition matrix s sums to 1 over partitions
#     # s = torch.softmax(s, dim=1)
#
#     # Compute the expected weight of cut edges between each pair of partitions
#     total_cut_weight = 0
#     for k in range(K):
#         for l in range(k + 1, K):
#             # Probability that a node pair (i, j) is split between partitions k and l
#             partition_k = s[:, k].unsqueeze(1)  # Shape: V x 1
#             partition_l = s[:, l].unsqueeze(0)  # Shape: 1 x V
#
#             # Compute the expected weight of the cut edges between partitions k and l
#             cut_weight = adjacency_matrix * (partition_k @ partition_l)
#             total_cut_weight += torch.sum(cut_weight)
#
#     return total_cut_weight


def calculate_HC_vectorized(s, adjacency_matrix):
    """
    Vectorized calculation of HC for soft partitioning.
    :param s: A probability matrix of size |V| x |K| where s[i][j] is the probability of vertex i being in partition j.
    :param adjacency_matrix: A matrix representing the graph where the value at [i][j] is the weight of the edge between i and j.
    :return: The HC value.
    """
    # Initialize HC to 0
    HC = 0

    # Iterate over each partition to calculate its contribution to HC
    for k in range(s.shape[1]):
        # Compute the probability matrix for partition k
        partition_prob_matrix = s[:, k].unsqueeze(1) * s[:, k].unsqueeze(0)

        # Compute the contribution to HC for partition k
        HC_k =adjacency_matrix * (1 - partition_prob_matrix)
        # Sum up the contributions for partition k
        HC += torch.sum(HC_k, dim=(0, 1))

    # Since we've summed up the partition contributions twice (due to symmetry), divide by 2
    HC = HC / 2

    return HC




In [10]:
s = torch.Tensor([[0,1,0],[0,1,0],[0,0,1]])
# print(calculate_HA_vectorized(s))
# print(calculate_HA_vectorized(torch.Tensor([[0,0.9,0.9],[0.9,0.9,0],[0,0,0.9]])))
terminal_loss = torch.abs(s[0] - s[1]-s[2])
# print(terminal_loss)
# print(10 * (1 - terminal_loss))
# print(torch.sum(10 * (1 - terminal_loss)))
print(torch.abs(s[0] - s[1]))
print(torch.abs(s[0] - s[2]))
print(torch.abs(s[2] - s[1]))

print(torch.sum(10 * (1-torch.abs(s[0] - s[1]))))
print(torch.sum(10 * (1-torch.abs(s[0] - s[2]))))
print(torch.sum(10 * (1-torch.abs(s[2] - s[1]))))
print(terminal_independence_penalty(s, [0,1,2]))

tensor([0., 0., 0.])
tensor([0., 1., 1.])
tensor([0., 1., 1.])
tensor(30.)
tensor(10.)
tensor(10.)
tensor(1.)


In [11]:
def train1(modelName):
    n, d, p, graph_type, number_epochs, learning_rate, PROB_THRESHOLD, tol, patience, dim_embedding, hidden_dim = hyperParameters(learning_rate=0.001, n=4096,patience=20)

    # Establish pytorch GNN + optimizer
    opt_params = {'lr': learning_rate}
    gnn_hypers = {
        'dim_embedding': dim_embedding,
        'hidden_dim': hidden_dim,
        'dropout': 0.0,
        'number_classes': 3,
        'prob_threshold': PROB_THRESHOLD,
        'number_epochs': number_epochs,
        'tolerance': tol,
        'patience': patience,
        'nodes':n
    }
    datasetItem = open_file('./testData/prepareDS.pkl')
    # print(datasetItem)
    # datasetItem_all = {}
    # for key, (dgl_graph, adjacency_matrix,graph) in datasetItem.items():
    #     A, C = FIndAC(graph)
    #     datasetItem_all[key] = [dgl_graph, adjacency_matrix, graph, A, C]

    # print(len(datasetItem), datasetItem[0][3])
    # datasetItem_2 = {}
    # datasetItem_2[0]=datasetItem[1]
    # print(datasetItem_2)

    net, embed, optimizer = get_gnn(n, gnn_hypers, opt_params, TORCH_DEVICE, TORCH_DTYPE)


    # print(datasetItem[1][2].nodes)
    # # Visualize graph
    # pos = nx.kamada_kawai_layout(datasetItem[1][2])
    # nx.draw(datasetItem[1][2], pos, with_labels=True, node_color=[[.7, .7, .7]])
    # cut_value, (part_1, part_2) = nx.minimum_cut(datasetItem_2[0][2], datasetItem_2[0][3][1], datasetItem_2[0][3][0], flow_func=shortest_augmenting_path)

    # print(cut_value, len(part_1), len(part_2))

    # resultList = []
    # all_indexes = sorted(part_1.union(part_2))
    # # Check membership for each index and append the appropriate pair to the result list
    # for index in all_indexes:
    #     if index in part_1:
    #         resultList.append([1, 0])
    #     elif index in part_2:
    #         resultList.append([0, 1])

    #
    trained_net, bestLost, epoch, inp, lossList= run_gnn_training2(
        datasetItem, net, optimizer, int(500),
        gnn_hypers['tolerance'], gnn_hypers['patience'], loss_terminal,gnn_hypers['dim_embedding'], gnn_hypers['number_classes'], modelName,  TORCH_DTYPE,  TORCH_DEVICE)

    return trained_net, bestLost, epoch, inp, lossList


### Neural Network Training, Setting A to 0

In [22]:
def Loss(s, adjacency_matrix,  A=1, C=1):
    # HA = calculate_HA_vectorized(s)
    HC = calculate_HC_vectorized(s, adjacency_matrix)
    # HC = calculate_HC_min_cut_new(s, adjacency_matrix)
    # HC = calculate_HC_min_cut_intra_inter(s, adjacency_matrix)
    return C * HC


def loss_terminal(s, adjacency_matrix,  A=0, C=1, penalty=10000):
    loss = Loss(s, adjacency_matrix, A, C)
    loss += penalty* terminal_independence_penalty(s, [0,1,2])
    return loss

trained_net, bestLost, epoch, inp, lossList = train1('_80wayCut_LossOrig_2.pth')


Epoch: 0, Cumulative Loss: 4467581.311767578
Epoch: 100, Cumulative Loss: 1668383.3168945312
Stopping early at epoch 109
Epoch: 200, Cumulative Loss: 1490733.4799804688
Stopping early at epoch 206
Stopping early at epoch 234
Epoch: 300, Cumulative Loss: 1417838.9104003906
Epoch: 400, Cumulative Loss: 1397359.0704345703
Stopping early at epoch 465
Stopping early at epoch 466
Stopping early at epoch 498
GNN training took 221.145 seconds.
Best cumulative loss: 4581.91259765625


In [26]:
import torch
import torch.nn.functional as F

def soft_min_cut_loss(s, adjacency_matrix):
    """
    Calculate a soft min-cut loss that maintains differentiability by penalizing
    the sum of squared differences from binary values (0 or 1).
    """
    s = torch.softmax(s, dim=1)  # Ensure that s is a proper probability distribution
    V, K = s.shape

    min_cut_loss = 0
    for k in range(K):
        for l in range(k + 1, K):
            # Use probabilities directly for nodes being in partitions k and l
            # partition_k = s[:, k].unsqueeze(1)
            # partition_l = s[:, l].unsqueeze(0)

            partition_k = s[:, k].unsqueeze(1) * s[:, k].unsqueeze(0)
            partition_l = s[:, l].unsqueeze(1) * s[:, l].unsqueeze(0)
            # partition_l = s[:, l].unsqueeze(0)
            # Edge weights between partitions
            inter_partition_edges = adjacency_matrix * (partition_k @ partition_l)
            min_cut_loss += torch.sum(inter_partition_edges)

    return min_cut_loss


def loss_terminal(s, adjacency_matrix,  A= 0, C=1, T=100):
    """
    Compute the overall loss including cut loss and terminal independence.

    :param s: Node partition probabilities.
    :param adjacency_matrix: Graph adjacency matrix.
    :param terminals: List of terminal node indices.
    :param C: Weight for the cut loss.
    :param T: Weight for the terminal independence penalty.
    :return: Total loss.
    """
    cut_loss = soft_min_cut_loss(s, adjacency_matrix)
    terminal_loss = terminal_independence_penalty(s, [0,1,2])
    total_loss = C * cut_loss + T * terminal_loss
    return total_loss


trained_net, bestLost, epoch, inp, lossList = train1('_80wayCut_Lossinter_min_cut_loss_9_new.pth')

Epoch: 0, Cumulative Loss: 2671551.08203125
Epoch: 100, Cumulative Loss: 2461037.6333007812
Epoch: 200, Cumulative Loss: 2450037.1362304688
Epoch: 300, Cumulative Loss: 2447161.0883789062
Epoch: 400, Cumulative Loss: 2449458.3696289062
GNN training took 1144.661 seconds.
Best cumulative loss: 10010.3642578125


In [32]:
import torch
import torch.nn.functional as F

def soft_min_cut_loss(s, adjacency_matrix):
    """
    Calculate a soft min-cut loss that maintains differentiability by penalizing
    the sum of squared differences from binary values (0 or 1).
    """
    s = torch.softmax(s, dim=1)  # Ensure that s is a proper probability distribution
    V, K = s.shape

    min_cut_loss = 0
    for k in range(V):
        for l in range(k + 1, V):
            # Use probabilities directly for nodes being in partitions k and l
            partition_k = s[k]
            partition_l = s[l]
            # Edge weights between partitions
            summation = torch.sum(torch.abs_( partition_k-partition_l))

            min_cut_loss += torch.sum(summation * adjacency_matrix[k][l])

    return min_cut_loss


def loss_terminal(s, adjacency_matrix,  A= 0, C=1, T=100):
    """
    Compute the overall loss including cut loss and terminal independence.

    :param s: Node partition probabilities.
    :param adjacency_matrix: Graph adjacency matrix.
    :param terminals: List of terminal node indices.
    :param C: Weight for the cut loss.
    :param T: Weight for the terminal independence penalty.
    :return: Total loss.
    """
    cut_loss = soft_min_cut_loss(s, adjacency_matrix)
    terminal_loss = terminal_independence_penalty(s, [0,1,2])
    total_loss = C * cut_loss + T * terminal_loss
    return total_loss


trained_net, bestLost, epoch, inp, lossList = train1('_80wayCut_Lossinter_min_cut_loss_10_new.pth')

Epoch: 0, Cumulative Loss: 80112.09645080566


KeyboardInterrupt: 

In [36]:
import torch
import torch.nn.functional as F

def soft_min_cut_loss(s, adjacency_matrix):
    """
    Calculate a soft min-cut loss that maintains differentiability by penalizing
    the sum of squared differences from binary values (0 or 1).
    """
    s = torch.softmax(s, dim=1)  # Ensure that s is a proper probability distribution

    # Compute differences
    diff = s.unsqueeze(1) - s.unsqueeze(0)
    abs_diff = torch.abs(diff)
    sum_diff = torch.sum(abs_diff, dim=2)
    min_cut_loss = torch.sum(sum_diff * adjacency_matrix)

    return min_cut_loss


def loss_terminal(s, adjacency_matrix,  A= 0, C=1, T=1000):
    """
    Compute the overall loss including cut loss and terminal independence.

    :param s: Node partition probabilities.
    :param adjacency_matrix: Graph adjacency matrix.
    :param terminals: List of terminal node indices.
    :param C: Weight for the cut loss.
    :param T: Weight for the terminal independence penalty.
    :return: Total loss.
    """
    cut_loss = soft_min_cut_loss(s, adjacency_matrix)
    terminal_loss = terminal_independence_penalty(s, [0,1,2])
    total_loss = C * cut_loss + T * terminal_loss
    return total_loss


trained_net, bestLost, epoch, inp, lossList = train1('_80wayCut_Lossinter_min_cut_loss_10_new.pth')

Epoch: 0, Cumulative Loss: 681368.5506591797
Epoch: 100, Cumulative Loss: 216319.7434539795
Epoch: 200, Cumulative Loss: 195802.2372970581
Stopping early at epoch 248
Stopping early at epoch 268
Stopping early at epoch 286
Epoch: 300, Cumulative Loss: 187731.35697174072
Stopping early at epoch 325
Stopping early at epoch 349
Epoch: 400, Cumulative Loss: 176453.22048187256
Stopping early at epoch 484
GNN training took 192.318 seconds.
Best cumulative loss: 705.4923706054688


In [33]:
import torch

# Adjusted example for 3 partitions
s = torch.tensor([[0.5, 0.3, 0.2], [0.2, 0.5, 0.3], [0.4, 0.4, 0.2]])
adjacency_matrix = torch.tensor([[0.0, 1.0, 0.5], [1.0, 0.0, 0.2], [0.5, 0.2, 0.0]])

# Compute differences
diff = s.unsqueeze(1) - s.unsqueeze(0)
abs_diff = torch.abs(diff)
sum_diff = torch.sum(abs_diff, dim=2)
min_cut_loss = torch.sum(sum_diff * adjacency_matrix)

# Print results
print("Differences:\n", diff)
print("Absolute differences:\n", abs_diff)
print("Summed differences:\n", sum_diff)
print("Weighted summed differences:\n", sum_diff * adjacency_matrix)
print("Min Cut Loss:\n", min_cut_loss)


Differences:
 tensor([[[ 0.0000,  0.0000,  0.0000],
         [ 0.3000, -0.2000, -0.1000],
         [ 0.1000, -0.1000,  0.0000]],

        [[-0.3000,  0.2000,  0.1000],
         [ 0.0000,  0.0000,  0.0000],
         [-0.2000,  0.1000,  0.1000]],

        [[-0.1000,  0.1000,  0.0000],
         [ 0.2000, -0.1000, -0.1000],
         [ 0.0000,  0.0000,  0.0000]]])
Absolute differences:
 tensor([[[0.0000, 0.0000, 0.0000],
         [0.3000, 0.2000, 0.1000],
         [0.1000, 0.1000, 0.0000]],

        [[0.3000, 0.2000, 0.1000],
         [0.0000, 0.0000, 0.0000],
         [0.2000, 0.1000, 0.1000]],

        [[0.1000, 0.1000, 0.0000],
         [0.2000, 0.1000, 0.1000],
         [0.0000, 0.0000, 0.0000]]])
Summed differences:
 tensor([[0.0000, 0.6000, 0.2000],
        [0.6000, 0.0000, 0.4000],
        [0.2000, 0.4000, 0.0000]])
Weighted summed differences:
 tensor([[0.0000, 0.6000, 0.1000],
        [0.6000, 0.0000, 0.0800],
        [0.1000, 0.0800, 0.0000]])
Min Cut Loss:
 tensor(1.5600)
