In [4]:
import numpy as np
import scipy as sp
import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
import pandas as pd
import sys
sys.path.append('/home/wrwt/Programming/pygraphmodels')
import graphmodels as gm
from itertools import permutations
from graphmodels import MatrixGraph, DGM
%matplotlib inline
%load_ext line_profiler

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [5]:
from graphmodels import AddEdge, RemoveEdge, ReverseEdge, InvalidOperation, GreedySearch, ScoreBIC
import heapq
class HeapGreedySearch(object):
    """Greedy graph-structure search that picks the best edge operation via a heap.

    The search starts from a graph that has one node per column of ``data``
    and no edges. Candidate operations are ``AddEdge``, ``RemoveEdge`` and
    ``ReverseEdge`` for every ordered pair of distinct nodes. Each iteration
    applies the highest-scoring operation that does not raise
    ``InvalidOperation``; the search stops when the best remaining score is
    not above ``MIN_GAIN``.

    Attributes:
        graph: the ``MatrixGraph`` being modified in place by the operations.
        fscore: the score object built as ``cls_score(graph, data)``.
        ops: every candidate operation, in construction order.
        op_heap: heap of ``(-score, index, op)`` entries. ``index`` is the
            position of ``op`` in ``ops``; it breaks ties between equal
            scores so the operation objects themselves are never compared.
    """

    # Smallest score an operation must exceed to be applied.
    MIN_GAIN = 1e-5

    def __init__(self, data, cls_score):
        """Build the empty graph, the score object and all candidate operations.

        Args:
            data: object exposing ``columns``; each column becomes a node
                (presumably a pandas DataFrame -- confirm with callers).
            cls_score: callable invoked as ``cls_score(graph, data)``; its
                result is handed to every operation and must provide
                ``total()`` for verbose output.
        """
        graph = nx.DiGraph()
        graph.add_nodes_from(data.columns)
        graph = MatrixGraph.from_networkx_DiGraph(graph, order=data.columns)
        self.graph = graph
        self.fscore = cls_score(graph, data)

        # Same ordering as before: all additions, then removals, then reversals.
        node_pairs = list(permutations(graph.nodes(), 2))
        self.ops = [
            op_cls(graph, self.fscore, u, v)
            for op_cls in (AddEdge, RemoveEdge, ReverseEdge)
            for u, v in node_pairs
        ]
        self._rebuild_heap()

    def _rebuild_heap(self):
        """Re-score every candidate operation and rebuild ``op_heap`` from scratch."""
        # heapq is a min-heap, so scores are negated to pop the largest first.
        # The index keeps tuple comparison from ever reaching the op objects,
        # which would raise TypeError on Python 3 when two scores are equal.
        self.op_heap = [(-op.score(), idx, op) for idx, op in enumerate(self.ops)]
        heapq.heapify(self.op_heap)

    def iteration(self):
        """Apply the best valid operation, if any is worth applying.

        Returns:
            True when the search is finished (heap exhausted, or the best
            remaining operation scores at most ``MIN_GAIN``); False when an
            operation was applied and the heap was rebuilt.
        """
        while self.op_heap:
            op = heapq.heappop(self.op_heap)[-1]
            if op.score() <= self.MIN_GAIN:
                return True
            try:
                op.do()
            except InvalidOperation:
                # This operation cannot be applied right now; try the next best.
                continue
            # The graph changed, so every cached priority may be stale.
            self._rebuild_heap()
            return False
        return True

    def __call__(self, max_iter=40, verbose=True):
        """Run the search and return the resulting model.

        Args:
            max_iter: maximum number of operations to apply. The limit is
                checked before each iteration, so at most ``max_iter``
                operations are applied (none when ``max_iter`` is 0).
            verbose: when true, print ``fscore.total()`` after each applied
                operation.

        Returns:
            ``DGM`` built from the final graph converted back to a networkx
            ``DiGraph``.
        """
        for _ in range(max_iter):
            if self.iteration():
                break
            if verbose:
                print(self.fscore.total())
        return DGM(self.graph.to_networkx_DiGraph())

In [6]:
from os import listdir
import os.path
# Directory that is listed and read below. NOTE(review): machine-specific
# absolute path; the notebook will not run elsewhere without editing it.
NETWORKS_PATH = '/home/wrwt/Programming/pygraphmodels/networks/'
# Names of all entries found in NETWORKS_PATH.
network_filenames = listdir(NETWORKS_PATH)
# Read the model stored in alarm.bif; later cells sample data from it.
true_dgm = gm.DGM.read(os.path.join(NETWORKS_PATH, 'alarm.bif'))
true_dgm.draw()

In [7]:
# Presumably draws 100000 random samples from the reference model (verify
# against DGM.rvs). NOTE(review): no seed is passed here, so reruns may
# produce different data and therefore different scores.
data = true_dgm.rvs(size=100000)

In [9]:
%%time
# Uses GreedySearch imported from graphmodels, not the HeapGreedySearch class
# defined in this notebook.
gs = GreedySearch(data, ScoreBIC)
# Run the search with max_iter=100 under line_profiler, profiling
# ScoreBIC.__call__; the result is bound to `res`.
%lprun -f ScoreBIC.__call__ res = gs(max_iter=100)

60461.3743157
113181.798143
165348.619621
216695.329249
263558.412198
309078.133992
354328.497394
397750.879897
438197.057677
474756.984151
510140.222606
541718.71571
573060.53566
604829.905532
636002.767434
664618.701917
691007.838214
716111.984527
741094.380921
765885.824215
782851.784255
803082.167523
819738.102696
834512.315468
849123.841095
863551.042464
877775.535885
891273.596471
902786.812911
913678.106793
924459.791201
933672.942228
942003.845499
949109.839953
955386.516975
960935.707444
965335.024117
969393.976861
973278.506696
977138.130625
980500.741198
983518.789904
986463.083058
989134.151852
991278.298823
993369.185764
995213.261109
996667.897027
997729.614323
998775.801933
998990.538613
999186.410559
999314.712821
999406.816225
999477.505129
999525.895844
CPU times: user 2min 40s, sys: 236 ms, total: 2min 41s
Wall time: 2min 40s


In [33]:
# Create a new GreedySearch instance (imported from graphmodels) on the same
# data, replacing the `gs` bound by the earlier profiling cell.
gs = GreedySearch(data, ScoreBIC)

In [34]:
%%time
# Time a search run with max_iter=100 (no line profiler this time) and draw
# the object it returns.
gs(max_iter=100).draw()

61063.689687
114174.871475
166496.506152
218103.593792
264898.632596
310563.505731
355471.789741
398705.793392
439179.153348
475870.068866
511571.937049
543266.155435
574793.082133
606048.465386
637190.969008
665622.301738
692091.385988
717665.911013
742895.008643
767884.080735
785108.630416
805200.64713
822016.684726
837087.09322
852131.648734
866626.12354
880807.179415
893929.489418
905628.14434
916733.669417
927542.620493
936937.136605
945274.966527
952245.391466
958422.169575
963898.977437
968159.699658
972307.005056
976282.362825
980211.402893
983703.482027
986698.692826
989563.330292
992303.919378
994510.030504
996609.618743
998496.319409
999934.240579
1001104.75809
1002120.57471
1002335.81498
1002532.54129
1002664.88768
1002756.99108
1002818.64799
1002867.74116
CPU times: user 2min 7s, sys: 96 ms, total: 2min 7s
Wall time: 2min 7s
