In [107]:
# Edge scores for the sentence "John saw Mary",
# taken from the lab handout.

import numpy as np

# Edge-score matrix: adjacency_matrix[i, j] is the score of the edge i -> j
# (row = source node, column = target node); -inf marks "no such edge".
# Column 0 is all -inf, i.e. nothing points into node 0.
# `np.inf` is used because the capitalised `np.Inf` alias was removed in NumPy 2.0.
adjacency_matrix = np.array([
    [-np.inf, 9, 10, 9],
    [-np.inf, -np.inf, 20, 3],
    [-np.inf, 30, -np.inf, 30],
    [-np.inf, 11, 0, -np.inf]
]) # type: ignore

In [108]:
# Column 2: the scores of every edge pointing into node 2 (one entry per source row).
adjacency_matrix[:, 2]

array([ 10.,  20., -inf,   0.])

In [109]:
from typing import List

# construct graph
class Vertex:
    """A graph node together with its weighted incoming and outgoing edges.

    Attributes are plain dicts so they can be replaced or pruned freely by the
    algorithms that operate on the graph.
    """

    def __init__(self, node_id: int) -> None:
        """Create a vertex with the given id and no edges."""
        self.node_id = node_id

        # neighbour node index -> edge weight
        self.incoming = dict()
        self.outgoing = dict()

    def __str__(self) -> str:
        """Return the vertex's attribute dict rendered as text."""
        return str(self.__dict__)

    # Containers format their elements with repr(), so without this alias a
    # list of vertices would print as opaque "<Vertex object at 0x...>" entries.
    __repr__ = __str__
        
class Graph:
    """A bare container holding the vertices of a graph in a list."""

    def __init__(self) -> None:
        """Start with an empty vertex list."""
        self.nodes: List[Vertex] = []

    def __str__(self) -> str:
        """Render the graph as the text form of its vertex list."""
        return f"{self.nodes}"

In [114]:
from typing import Dict


def find_incoming(node_id: int, matrix: np.ndarray) -> Dict:
    """Collect the edges that point into ``node_id``.

    Args:
        node_id: Column index of the target node in ``matrix``.
        matrix: Score matrix where ``matrix[i, j]`` is the weight of the edge
            ``i -> j`` and ``-inf`` means the edge does not exist.

    Returns:
        Mapping from source node index to edge weight; ``-inf`` entries are
        left out.
    """
    # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
    return {
        source: weight
        for source, weight in enumerate(matrix[:, node_id])
        if weight != -np.inf
    }


def find_outgoing(node_id: int, matrix: np.ndarray) -> Dict:
    """Collect the edges that leave ``node_id``.

    Args:
        node_id: Row index of the source node in ``matrix``.
        matrix: Score matrix where ``matrix[i, j]`` is the weight of the edge
            ``i -> j`` and ``-inf`` means the edge does not exist.

    Returns:
        Mapping from target node index to edge weight; ``-inf`` entries are
        left out.
    """
    # Iterate over the row itself (its length is the column count) rather than
    # over matrix.shape[0], so every target column is visited even if the
    # matrix is not square.
    # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
    return {
        target: weight
        for target, weight in enumerate(matrix[node_id, :])
        if weight != -np.inf
    }

def construct_graph(matrix: np.ndarray) -> Graph:
    """Turn a square score matrix into a ``Graph`` of ``Vertex`` objects.

    One vertex is created per row index; its ``incoming`` and ``outgoing``
    dicts are filled by ``find_incoming`` and ``find_outgoing``.
    """
    result = Graph()

    # The matrix is square, so the row count doubles as the node count.
    node_count = matrix.shape[0]
    for idx in range(node_count):
        vertex = Vertex(idx)
        vertex.incoming = find_incoming(idx, matrix)
        vertex.outgoing = find_outgoing(idx, matrix)
        result.nodes.append(vertex)

    return result

# for node in graph.nodes:
#     print(node)

In [111]:
def has_cycle(node1: Vertex, node2: Vertex) -> bool:
    """Tell whether the two vertices point at each other.

    Returns True only when ``node1`` appears among ``node2``'s incoming edges
    and ``node2`` appears among ``node1``'s incoming edges, i.e. the pair forms
    a two-node cycle.
    """
    # Guard clause: without an edge node1 -> node2 there is nothing to check.
    if node1.node_id not in node2.incoming:
        return False
    return node2.node_id in node1.incoming
    
    
# Build a fresh graph for this sanity check so the cell does not depend on a
# `graph` variable that is only created further down the notebook (which made
# this line raise NameError on Restart & Run All).
cycle_check_graph = construct_graph(adjacency_matrix)
assert has_cycle(cycle_check_graph.nodes[1], cycle_check_graph.nodes[2])

In [116]:
def contract_cycle(cycle: Graph, graph: Graph):
    """Find the best edges connecting a two-node cycle to the rest of the graph.

    Every vertex of ``graph`` that is not part of the cycle is examined for:

    * edges leaving the cycle (cycle node -> outside node); the heaviest one
      is tracked, and
    * edges entering the cycle (outside node -> cycle node); each is scored as
      its own weight plus the weight of the cycle edge going from the entered
      node to the other cycle node, and the best-scoring source is tracked.

    Only the id of the best entering node is reported (printed). Nothing is
    returned and neither graph is modified.

    Args:
        cycle: Graph whose ``nodes`` are the two vertices forming the cycle.
        graph: The full graph that contains the cycle.
    """
    cycle_node_ids = [node.node_id for node in cycle.nodes]
    # print("nodes in cycle ", cycle_node_ids)

    # Best edge leaving the cycle.
    # (np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.)
    max_outgoing_weight = -np.inf
    max_outgoing_id = -np.inf

    # Best edge entering the cycle.
    max_incoming_id = -np.inf
    max_incoming_weight = -np.inf

    for i, c_node in enumerate(cycle.nodes):
        for g_node in graph.nodes:
            # Skip the cycle's own vertices; compare ids with ids rather than
            # relying on a vertex's list position matching its node_id.
            if g_node.node_id in cycle_node_ids:
                continue

            # Outgoing: edge c_node -> g_node. Its weight is stored in
            # g_node.incoming; g_node.outgoing would be the reverse edge and
            # may not even exist.
            if c_node.node_id in g_node.incoming:
                leaving_weight = g_node.incoming[c_node.node_id]
                if leaving_weight > max_outgoing_weight:
                    max_outgoing_weight = leaving_weight
                    max_outgoing_id = g_node.node_id

            # Incoming: edge g_node -> c_node.
            if c_node.node_id in g_node.outgoing:
                # `% 2` assumes the cycle has exactly two nodes: the partner of
                # c_node is simply "the other one".
                partner = cycle.nodes[(i + 1) % 2]
                t = g_node.outgoing[c_node.node_id] + partner.incoming[c_node.node_id]
                if t > max_incoming_weight:
                    max_incoming_weight = t
                    max_incoming_id = g_node.node_id

    # print("max out  from cycle", max_outgoing_id)
    print("max in to cycle ", max_incoming_id)

In [115]:
def cle(graph: Graph):
    """Greedily pick the best incoming edge for every non-ROOT node.

    Presumably the first phase of the Chu-Liu/Edmonds algorithm, going by the
    name; the contraction and recursion steps are not implemented yet.

    For each node except node 0, the node's ``incoming`` dict is replaced in
    place by a single-entry dict holding its highest-weight incoming edge. If
    that edge closes a two-node cycle (see ``has_cycle``), ``contract_cycle``
    is called on the pair.

    Args:
        graph: Graph to process; its vertices' ``incoming`` dicts are mutated.
    """
    # ignore node 0, since ROOT
    for node in graph.nodes[1:]:
        incoming = node.incoming
        if not incoming:
            # Nothing points at this node; max() on an empty dict would raise
            # ValueError, so leave it untouched.
            continue

        # Highest-weight incoming edge, found in a single pass.
        max_node_id, max_node_weight = max(incoming.items(), key=lambda edge: edge[1])
        max_node = graph.nodes[max_node_id]

        # Keep only the best incoming edge for this node.
        graph.nodes[node.node_id].incoming = {max_node_id: max_node_weight}

        # check for cycle
        if has_cycle(node, max_node):
            # contract
            cycle = Graph()
            cycle.nodes = [node, max_node]

            contract_cycle(cycle, graph)

            # TODO: recurse (cle) on the contracted graph
        # else: TODO: return the result once no cycle remains


# Build the graph from the score matrix and run the greedy head-selection pass.
# Note: cle() rewrites each non-ROOT node's `incoming` dict in place.
graph = construct_graph(adjacency_matrix)
cle(graph)

max in to cycle  0
max in to cycle  0
