In [1]:
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt

In [2]:
# Load the unit-path table (CSV path is relative to the working directory).
# generate_graph() reads its 'domain', 'range', 'P' and 'W' columns.
unit_path = pd.read_csv('../unit_path.csv')

In [3]:
def generate_graph(paths=None):
    """Build a directed multigraph from a table of unit paths.

    Every row becomes one edge ``domain -> range`` carrying three attributes:
    ``property`` and ``label`` (both the row's ``P`` value) and ``weight``
    (the row's ``W`` value). Parallel edges between the same pair of nodes
    are kept, since the graph is a ``MultiDiGraph``.

    Parameters
    ----------
    paths : pandas.DataFrame, optional
        Table with ``domain``, ``range``, ``P`` and ``W`` columns. When
        omitted, the notebook-level ``unit_path`` frame is used, which keeps
        the original zero-argument call working.

    Returns
    -------
    networkx.MultiDiGraph
        Graph with one edge per input row.
    """
    if paths is None:
        paths = unit_path

    G = nx.MultiDiGraph()

    # Iterate the four columns in lockstep instead of iterrows(), which
    # builds a Series object for every row.
    rows = zip(paths['domain'], paths['range'], paths['P'], paths['W'])
    for source, target, edge, weight in rows:
        # add_edge() creates any missing endpoint node itself, so separate
        # add_node() calls are not needed.
        G.add_edge(source, target, property=edge, weight=weight, label=edge)

    return G

In [4]:
# Build the graph once; the path queries in the following cells reuse it.
G = generate_graph()

In [5]:
# Endpoints used by the example path queries in the following cells.
source_node = 'schema.MusicAlbum'
target_node = 'schema.Country'

In [6]:
# Baseline: networkx's built-in shortest path, using the 'weight' edge attribute.
nx.shortest_path(G, source_node, target_node, weight='weight')

['schema.MusicAlbum', 'mo.Track', 'mo.MusicGroup', 'schema.Country']

In [7]:
from heapq import heappop, heappush
from itertools import count

# Limit on shortest-path length, counted in triples.
# NOTE(review): no active code visible in this notebook reads this constant.
TRIPLE_NUM = 4

def find_shortest_path(G, source, target, p_name=None, max_triples=None):
    """Return the minimum-weight simple path from ``source`` to ``target``.

    The search enumerates simple paths (no node visited twice) in order of
    accumulated weight and keeps every one that ends at ``target``. It does
    not stop at the first hit, so the cost grows with the number of simple
    paths leaving ``source``; use ``max_triples`` to bound it.

    Parameters
    ----------
    G : graph
        Multigraph whose ``_adj`` mapping has the shape
        ``{node: {neighbor: {edge_key: attrs}}}``, where each ``attrs`` dict
        holds a ``'weight'`` and a ``'property'`` entry.
    source, target : node
        Start and end nodes.
    p_name : optional
        When given, only paths that traverse at least one edge whose
        ``'property'`` equals ``p_name`` are accepted.
    max_triples : int, optional
        When given, paths are never extended beyond this many edges
        (triples). ``None`` (the default) means no limit.

    Returns
    -------
    list
        The nodes along the cheapest accepted path, in order. An empty list
        is returned when ``source == target``.

    Raises
    ------
    networkx.NetworkXNoPath
        If no accepted path reaches ``target``.
    """
    if source == target:
        return []

    tie_breaker = count()
    # Heap entries are (distance, tie_breaker, node, path). The unique counter
    # guarantees the comparison never falls through to the node or the path.
    # ``path`` interleaves nodes and properties: [n0, p0, n1, p1, ..., nk].
    fringe = [(0, next(tie_breaker), source, [source])]
    result = []

    while fringe:
        d, _, v, path = heappop(fringe)

        if v == target:
            # Properties sit at the odd positions; look only there so that a
            # node sharing the name of p_name cannot satisfy the filter.
            if p_name is None or p_name in path[1::2]:
                result.append((d, path))
            continue

        # Opt-in depth bound: do not grow a path that already has the
        # maximum number of triples.
        if max_triples is not None and len(path) // 2 >= max_triples:
            continue

        # Nodes sit at the even positions; compare neighbours only with those
        # so that a property name equal to a node name does not block it.
        visited = path[::2]

        # _adj is read directly, as in the original implementation.
        for u, edges in G._adj[v].items():
            if u in visited:
                continue
            for e in edges.values():
                cost = e['weight']
                # Skip edges with no usable weight. Missing values loaded via
                # pandas arrive as NaN rather than None; NaN is the only value
                # not equal to itself, and it would otherwise corrupt the
                # heap order and the final minimum.
                if cost is None or cost != cost:
                    continue
                heappush(
                    fringe,
                    (d + cost, next(tie_breaker), u, path + [e['property'], u]),
                )

    if not result:
        raise nx.NetworkXNoPath(f"Node {target} not reachable from {source} with {p_name}")

    # Cheapest (distance, path) pair; keep only the nodes of that path.
    return min(result)[1][::2]

In [8]:
# Unconstrained query: no property filter is applied (p_name=None).
result = find_shortest_path(G, source_node, target_node, p_name=None)
result

['schema.MusicAlbum', 'mo.Track', 'mo.MusicGroup', 'schema.Country']

In [9]:
# Same endpoints, but only paths that include 'schema.nationality' are accepted.
result = find_shortest_path(G, source_node, target_node, p_name='schema.nationality')
result

['schema.MusicAlbum',
 'mo.Track',
 'mo.MusicGroup',
 'foaf.Person',
 'schema.Country']