In [1]:
# for the sentence John saw Mary
# scores taken from the lab handout

import numpy as np

# adjacency_matrix[i, j] is the score of the edge i -> j (head i, dependent j);
# -inf means "no edge". Node 0 is ROOT, so column 0 is all -inf (nothing may
# point at the root), and the diagonal is -inf (no self-loops).
# NOTE(review): presumably nodes 1..3 are John, saw, Mary in sentence order — confirm.
# Uses np.inf: the capitalised alias np.Inf was removed in NumPy 2.0.
adjacency_matrix = np.array([
    [-np.inf, 9, 10, 9],
    [-np.inf, -np.inf, 20, 3],
    [-np.inf, 30, -np.inf, 30],
    [-np.inf, 11, 0, -np.inf],
])

In [16]:
# column 2 = weights of every edge i -> 2, i.e. the incoming scores of node 2
adjacency_matrix.T[2]

array([ 10.,  20., -inf,   0.])

In [26]:
from typing import List

# construct graph
class Vertex:
    """A graph node together with its weighted incoming and outgoing edges."""

    def __init__(self, node_id: int) -> None:
        self.node_id = node_id

        # neighbour node_id -> edge weight mapping
        self.incoming = dict()  # edges neighbour -> self
        self.outgoing = dict()  # edges self -> neighbour

    def __str__(self) -> str:
        return str(self.__dict__)

    # Containers (e.g. a list of vertices) format their elements with repr(),
    # which would otherwise be "<Vertex object at 0x...>"; reuse __str__.
    __repr__ = __str__
        
class Graph:
    """A list of vertices; callers index ``nodes`` by node_id, so keep them in id order."""

    def __init__(self) -> None:
        self.nodes: List[Vertex] = list()

    def __str__(self) -> str:
        # Join str(node) explicitly instead of str(self.nodes): the list form
        # uses each element's repr(), which for objects without a __repr__ is
        # just "<... object at 0x...>".
        return "[" + ", ".join(str(node) for node in self.nodes) + "]"

In [33]:
from typing import Dict


def find_incoming(node_id: int, matrix: np.ndarray) -> Dict:
    """Collect the edges that point INTO ``node_id``.

    Column ``node_id`` of ``matrix`` holds the weight of every edge
    ``i -> node_id``; ``-inf`` marks a missing edge.

    Args:
        node_id: index of the target node (a column of ``matrix``).
        matrix: 2-D score matrix where ``matrix[i, j]`` is the weight of i -> j.

    Returns:
        Dict mapping source node index -> edge weight (builtin ``float``).
    """
    column = matrix[:, node_id]
    # np.isneginf rather than `!= -np.Inf`: the np.Inf alias was removed in
    # NumPy 2.0. float(...) keeps printed dicts readable, because NumPy 2
    # reprs scalars as "np.float64(9.0)".
    return {
        source: float(weight)
        for source, weight in enumerate(column)
        if not np.isneginf(weight)
    }


def find_outgoing(node_id: int, matrix: np.ndarray) -> Dict:
    """Collect the edges that leave ``node_id``.

    Row ``node_id`` of ``matrix`` holds the weight of every edge
    ``node_id -> j``; ``-inf`` marks a missing edge.

    Args:
        node_id: index of the source node (a row of ``matrix``).
        matrix: 2-D score matrix where ``matrix[i, j]`` is the weight of i -> j.

    Returns:
        Dict mapping target node index -> edge weight (builtin ``float``).
    """
    row = matrix[node_id, :]
    # Walk the row itself (length matrix.shape[1]); iterating
    # range(matrix.shape[0]) is only correct for square matrices.
    return {
        target: float(weight)
        for target, weight in enumerate(row)
        if not np.isneginf(weight)
    }

def construct_graph(matrix: np.ndarray) -> Graph:
    """Build a Graph with one Vertex per row/column of a square score matrix.

    Vertex ``k`` is appended at position ``k`` of ``graph.nodes`` and gets its
    ``incoming`` / ``outgoing`` weight maps from column / row ``k``.

    Args:
        matrix: square 2-D score matrix, ``matrix[i, j]`` = weight of i -> j,
            with ``-inf`` for missing edges.

    Returns:
        The constructed Graph.

    Raises:
        ValueError: if ``matrix`` is not a square 2-D array (rows and columns
            must index the same node set).
    """
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError(f"expected a square 2-D matrix, got shape {matrix.shape}")

    graph = Graph()
    for node_id in range(matrix.shape[0]):
        vertex = Vertex(node_id)
        vertex.incoming = find_incoming(node_id, matrix)
        vertex.outgoing = find_outgoing(node_id, matrix)
        graph.nodes.append(vertex)

    return graph

graph = construct_graph(adjacency_matrix)

# one line per vertex: its id plus its incoming/outgoing weight maps
for vertex in graph.nodes:
    print(vertex)

{'node_id': 0, 'incoming': {}, 'outgoing': {1: 9.0, 2: 10.0, 3: 9.0}}
{'node_id': 1, 'incoming': {0: 9.0, 2: 30.0, 3: 11.0}, 'outgoing': {2: 20.0, 3: 3.0}}
{'node_id': 2, 'incoming': {0: 10.0, 1: 20.0, 3: 0.0}, 'outgoing': {1: 30.0, 3: 30.0}}
{'node_id': 3, 'incoming': {0: 9.0, 1: 3.0, 2: 30.0}, 'outgoing': {1: 11.0, 2: 0.0}}


In [41]:
def has_cycle(node1: Vertex, node2: Vertex) -> bool:
    """Return True when ``node1`` and ``node2`` point at each other.

    Only a 2-cycle between this specific pair is detected: each node's id must
    appear among the other's ``incoming`` keys. Longer cycles are not found.
    """
    edge_1_to_2 = node1.node_id in node2.incoming
    edge_2_to_1 = node2.node_id in node1.incoming
    return edge_1_to_2 and edge_2_to_1


# nodes 1 and 2 point at each other (weights 20 and 30 in the matrix)
assert has_cycle(graph.nodes[1], graph.nodes[2])

In [61]:
def contract_cycle(cycle: Graph, graph: Graph):
    """Find the best edge leaving ``cycle``, the first step of contracting it.

    Considers every edge ``c -> v`` where ``c`` is a node of ``cycle`` and
    ``v`` is a node of ``graph`` outside the cycle, and keeps the one with the
    highest weight (first one wins on ties). The contraction itself is not
    implemented yet (work in progress).

    Args:
        cycle: Graph whose ``nodes`` are the vertices forming the cycle.
        graph: the full graph containing the cycle.

    Returns:
        Tuple ``(max_outgoing_id, max_outgoing_weight)``: the outside node the
        best leaving edge points to and that edge's weight, or ``(None, -inf)``
        if no edge leaves the cycle.
    """
    cycle_node_ids = [node.node_id for node in cycle.nodes]
    print("nodes in cyle ", cycle_node_ids)

    # the max outgoing edge from the cycle
    max_outgoing_weight = -np.inf
    max_outgoing_id = None

    for cycle_node in cycle.nodes:
        for graph_node in graph.nodes:
            # skip nodes that are part of the cycle
            if graph_node.node_id in cycle_node_ids:
                continue

            # Weight of edge cycle_node -> graph_node, read from the cycle
            # node's `outgoing` map. graph_node.outgoing[cycle_node.node_id]
            # would be the reverse edge graph_node -> cycle_node (wrong weight,
            # and a KeyError when that edge is absent). `outgoing` is also
            # never rewritten by cle(), which only overwrites `incoming`.
            weight = cycle_node.outgoing.get(graph_node.node_id)
            if weight is not None and weight > max_outgoing_weight:
                max_outgoing_weight = weight
                max_outgoing_id = graph_node.node_id

    print("max out  from cycle", max_outgoing_id)
    return max_outgoing_id, max_outgoing_weight

In [62]:
def cle(graph: Graph):
    """Chu-Liu-Edmonds maximum spanning arborescence (work in progress).

    For every non-root node, keep only its highest-weighted incoming edge.
    When that greedy choice forms a 2-cycle with the chosen head, start
    contracting the cycle via ``contract_cycle``. Recursing on the contracted
    graph is not implemented yet.

    NOTE: mutates ``graph`` in place. Each visited node's ``incoming`` dict is
    replaced by its single best edge, so a graph printed before this call is
    stale afterwards.
    TODO: the same cycle is currently found again from its other node
    (output shows [1, 2] and then [2, 1]).
    """
    # ignore node 0, since it is ROOT
    for node in graph.nodes[1:]:
        incoming = node.incoming
        if not incoming:
            # no candidate heads; max() would raise ValueError on an empty dict
            continue

        # find max incoming node id
        max_node_id = max(incoming, key=incoming.get)  # type: ignore
        max_node = graph.nodes[max_node_id]

        # Keep only the best incoming edge. Store its WEIGHT (not the Vertex)
        # so `incoming` stays a node_id -> edge weight mapping.
        graph.nodes[node.node_id].incoming = {max_node_id: incoming[max_node_id]}

        # check for cycle
        if has_cycle(node, max_node):
            # contract
            cycle = Graph()
            cycle.nodes = [node, max_node]

            contract_cycle(cycle, graph)

            # TODO: recurse (cle) on the contracted graph
        # TODO: otherwise the greedy choice is final for this node


cle(graph)

nodes in cyle  [1, 2]
max out  from cycle 3
nodes in cyle  [2, 1]
max out  from cycle 3
