In [59]:
%%capture
import numpy as np
import scipy as sp
import matplotlib as mpl
import matplotlib.pyplot as plt
import networkx as nx
%matplotlib inline

In [60]:
%%capture
import sys
sys.path.append('/home/wrwt/Programming/pygraphmodels')
import graphmodels as gm

In [61]:
import warnings
warnings.filterwarnings('ignore', module='matplotlib')
warnings.filterwarnings('ignore', module='IPython')

In [62]:
%%capture
import theano
import theano.tensor as T
from theano.tensor import nnet

In [63]:
class Neurofunction:
    def __init__(self, n_in, n_hid, lambda_=1e-4):
        self.n_in = n_in
        self.n_hid = n_hid
        self.n_params = (self.n_in + 2) * self.n_hid + 1
        
        self.W0 = T.dmatrix('W0')
        self.b0 = T.dvector('b0')
        self.W1 = T.dvector('W1')
        self.b1 = T.dscalar('b1')

        self.X = T.dmatrix('X')
        self.x = T.dvector('x')
        self.y = T.dvector('y')
        
        self.hidden = nnet.sigmoid(self.X.dot(self.W0) + self.b0)
        self.xhidden = nnet.sigmoid(self.x.dot(self.W0) + self.b0)
        self.out = self.hidden.dot(self.W1) + self.b1
        self.xout = self.xhidden.dot(self.W1) + self.b1
        
        self.loss = T.mean((self.out - self.y) ** 2) + lambda_ * (T.sum(self.W0 ** 2) + T.sum(self.W1 ** 2))
        
        self.theta_grad = T.concatenate([T.grad(self.loss, wrt=self.W0).flatten(),
                              T.grad(self.loss, wrt=self.W1).flatten(),
                              T.grad(self.loss, wrt=self.b0).flatten(),
                              [T.grad(self.loss, wrt=self.b1)]])
        self.loss_theta_grad_f = theano.function(inputs=[self.X, self.y, self.W0, self.W1, self.b0, self.b1],
                                           outputs=[self.loss, self.theta_grad])
        self.predict_f = theano.function(inputs=[self.X, self.W0, self.W1, self.b0, self.b1], 
                                         outputs=self.out, name='predict')
        
        self.x_grad = T.grad(-self.xout, wrt=self.x)
        self.loss_x_grad_f = theano.function(inputs=[self.x, self.W0, self.W1, self.b0, self.b1], 
                                            outputs=[-self.xout, self.x_grad])
    
        
    def _parse_params(self, theta):
        n_in = self.n_in
        n_hid = self.n_hid
        n_out = 1
        W0 = theta[:n_in*n_hid].reshape((n_in, n_hid))
        W1 = theta[n_in*n_hid:n_in*n_hid + n_hid*n_out]
        b0 = theta[n_in*n_hid + n_hid*n_out:n_in*n_hid + n_hid*n_out + n_hid]
        b1 = theta[n_in*n_hid + n_hid*n_out + n_hid]
        return W0, W1, b0, b1
        
    def fit(self, X, y):
        theta = np.random.rand(self.n_params)
        
        def loss_grad(theta):
            W0, W1, b0, b1 = self._parse_params(theta)
            return self.loss_theta_grad_f(X, y, W0, W1, b0, b1)
    
        res = sp.optimize.minimize(loss_grad, theta, method='L-BFGS-B', jac=True)
        self.theta = res['x']
        
    def __call__(self, X):
        params = self._parse_params(self.theta)
        return self.predict_f(np.atleast_2d(X), *params)
    
    def maximize(self, assignment=None):
        if assignment is None:
            assignment = [None] * self.n_in
        
        n_assigned = sum([1 if a is not None else 0 for a in assignment])
        x = np.random.rand(self.n_in - n_assigned)
        idx = [i for i, a in enumerate(assignment) if a is None]
        
        
        dispatch = {}
        for i, a in enumerate(assignment):
            if a is None:
                dispatch[len(dispatch)] = i
        
        def loss_grad(x_compl):
            x = np.array(assignment)
            for i, xi in enumerate(x_compl):
                x[dispatch[i]] = xi
            W0, W1, b0, b1 = self._parse_params(self.theta)
            loss, grad = self.loss_x_grad_f(x.astype('float'), W0, W1, b0, b1)
            grad = grad[idx]
            return loss, grad
        
        res = sp.optimize.minimize(loss_grad, x, method='L-BFGS-B', jac=True, 
                                   bounds=[(0, 1)] * (self.n_in - n_assigned))
        return res['x']

In [64]:
def generate_subset(target, factor, kmin=1, kmax=4, size=1):
    """
    Generate `size` random subsets of parents for node `target` with factor `factor` 
    and calculate discrete mutual information for them.
    Number of parents lies between kmin and kmax
    """
    arguments = list(factor.arguments)
    arguments.remove(target)
    arguments = np.asarray(arguments)
    
    n = np.random.randint(low=kmin, high=kmax+1, size=size)
    result = []
    score = []
    for i, cn in enumerate(n):
        current = np.random.choice(arguments, size=cn, replace=False)
        result.append([1.0 if arg in current else 0.0 for arg in arguments])
        if cn == 0:
            score.append(0)
        else:
            score.append(gm.information.discrete_mutual_information(data[[target]], data[current]))
    return np.vstack(result), np.hstack(score)

In [65]:
from os import listdir
import os.path
NETWORKS_PATH = '/home/wrwt/Programming/pygraphmodels/networks/'
network_filenames = listdir(NETWORKS_PATH)

In [226]:
from itertools import permutations, repeat

class Cyclic(Exception): pass
class Outdated(Exception): pass

class ScoreManager:
    def __init__(self, dgm, score_f):
        self.dgm = dgm
        self.data = data
        self.score_f = score_f
        self.scores = {node: self.score_f(node, self.dgm.predecessors(node)) for node in self.dgm.nodes()}
        self.score = sum(self.scores.values())
        self.ops = ['add_edge', 'remove_edge', 'reverse_edge']
        self.inv_ops = {
            'add_edge': self.remove_edge, 
            'remove_edge': self.add_edge,
            'reverse_edge': lambda dst, src: self.reverse_edge(src, dst)
                       }
    
    def _recalc_score(self, node):
        self.score -= self.scores[node]
        self.scores[node] = self.score_f(node, self.dgm.predecessors(node))
        self.score += self.scores[node]
    
    def operations(self, candidates):
        ops = []
        scores = []
        for src, dst in candidates:
            if self.dgm.has_edge(src, dst):
                # remove_edge
                ops.append(('remove_edge', src, dst))
                self.dgm.remove_edge(src, dst)
                scores.append(self.score_f(dst, self.dgm.predecessors(dst)) - self.scores[dst])
                self.dgm.add_edge(src, dst)
                # reverse_edge
                ops.append(('reverse_edge', src, dst))
                self.dgm.remove_edge(src, dst)
                self.dgm.add_edge(dst, src)
                scores.append(self.score_f(dst, self.dgm.predecessors(dst)) - self.scores[dst] + \
                              self.score_f(src, self.dgm.predecessors(src)) - self.scores[src])
                self.dgm.remove_edge(dst, src)
                self.dgm.add_edge(src, dst)
            elif not self.dgm.has_edge(dst, src):
                # add_edge
                ops.append(('add_edge', src, dst))
                self.dgm.add_edge(src, dst)
                scores.append(self.score_f(dst, self.dgm.predecessors(dst)) - self.scores[dst])
                self.dgm.remove_edge(src, dst)
        return scores, ops
    
    def node_operations(self, node):
        others = list(self.dgm.nodes())
        others.remove(node)
        candidates = list(zip(repeat(node), others)) + list(zip(others, repeat(node)))
        return self.operations(candidates)
    
    def all_operations(self):
        return self.operations(permutations(self.dgm.nodes(), 2))
    
    def op_score(self, op, src, dst):
        return getattr(self, op + '_score')(src, dst)
    
    def apply_op(self, dscore, op, src, dst, eps=1e-5):
        delta_score = -self.score
        if not getattr(self, op)(src, dst):
            raise Outdated()
        
        if not nx.is_directed_acyclic_graph(self.dgm):
            self.inv_ops[op](src, dst)
            raise Cyclic()
        
        self._recalc_score(dst)
        if op == 'reverse_edge':
            self._recalc_score(src)
        
        delta_score += self.score
        if delta_score + eps < dscore:
            getattr(self, self.inv_ops[op])(src, dst)
            self._recalc_score(dst)
            if op == 'reverse_edge':
                self._recalc_score(src)
            raise Outdated()
    
    def add_edge(self, src, dst):
        if self.dgm.has_edge(src, dst):
            return False
        self.dgm.add_edge(src, dst)
        return True
        
    def remove_edge(self, src, dst):
        if not self.dgm.has_edge(src, dst):
            return False
        self.dgm.remove_edge(src, dst)
        return True
        
    def reverse_edge(self, src, dst):
        if not self.dgm.has_edge(src, dst):
            return False
        self.dgm.remove_edge(src, dst)
        self.dgm.add_edge(dst, src)
        return True

In [247]:
import heapq
class LocalSearch:
    def __init__(self, score_manager):
        self.score_manager = score_manager
        self.opheap = list(zip(*self.score_manager.all_operations()))
        self.cyclic = []
        heapq.heapify(self.opheap)
        
    def iteration(self):
        while True:
            score, op = heapq.heappop(self.opheap)
            if score >= 0:
                return False
            try:
                self.score_manager.apply_op(score, *op)
                if op[0] == 'add_edge':
                    for c in self.cyclic:
                        heapq.heappush(self.opheap, c)
                    self.cyclic = []
                for cop in list(zip(*self.score_manager.node_operations(op[2]))):
                    heapq.heappush(self.opheap, cop)
                if op[0] == 'reverse_edge':
                    for cop in list(zip(*self.score_manager.node_operations(op[1]))):
                        heapq.heappush(self.opheap, cop)
                break
            except Cyclic:
                self.cyclic.append((score, op))
            except Outdated:
                pass
        print(op, score, self.score_manager.score)
        return True
        
    def __call__(self):
        while self.iteration():
            pass

In [248]:
true_dgm = gm.DGM.read(os.path.join(NETWORKS_PATH, 'earthquake.bif'))
true_dgm.draw()

In [249]:
data = true_dgm.rvs(size=10000)

In [250]:
dgm = gm.DGM()
dgm.add_nodes_from(true_dgm.nodes())

In [251]:
from graphmodels.information import discrete_mutual_information, discrete_entropy

def bic_score(x, pa):    
    def n_values(x):
        return len(data[x].value_counts())

    k = n_values(x)*np.prod([n_values(pa_i) for pa_i in pa]) - 1
    n = data.shape[0]
    l = n*(discrete_mutual_information(data[[x]], data[pa]) - \
           discrete_entropy(data[[x]]))
    return  -l + 0.5 * np.log(n) * k
    
sm = ScoreManager(dgm, bic_score)

In [252]:
sm.all_operations()

([-223.43246723168761,
  -385.74860704055413,
  -186.07504644610708,
  8.7642420581869374,
  -223.43246723168761,
  -427.31493897696964,
  -187.73043971199695,
  -66.742887211499806,
  -385.74860704055413,
  -427.31493897696964,
  -367.41616982602773,
  -115.0294250918455,
  -186.07504644610492,
  -187.73043971199661,
  -367.41616982602989,
  -33.150936335712686,
  8.7642420581868237,
  -66.742887211499806,
  -115.02942509184561,
  -33.150936335708593],
 [('add_edge', 'Burglary', 'MaryCalls'),
  ('add_edge', 'Burglary', 'Alarm'),
  ('add_edge', 'Burglary', 'JohnCalls'),
  ('add_edge', 'Burglary', 'Earthquake'),
  ('add_edge', 'MaryCalls', 'Burglary'),
  ('add_edge', 'MaryCalls', 'Alarm'),
  ('add_edge', 'MaryCalls', 'JohnCalls'),
  ('add_edge', 'MaryCalls', 'Earthquake'),
  ('add_edge', 'Alarm', 'Burglary'),
  ('add_edge', 'Alarm', 'MaryCalls'),
  ('add_edge', 'Alarm', 'JohnCalls'),
  ('add_edge', 'Alarm', 'Earthquake'),
  ('add_edge', 'JohnCalls', 'Burglary'),
  ('add_edge', 'JohnCall

In [253]:
bic_score('Alarm', ['Burglary']), bic_score('Alarm', []), bic_score('Alarm', ['Burglary', 'Earthquake'])

(455.62211712694591, 841.37072416750004, 282.71800134178477)

In [254]:
%load_ext line_profiler

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [255]:
ls = LocalSearch(sm)

In [256]:
ls()

(('add_edge', 'Alarm', 'MaryCalls'), -427.31493897696964, 5465.0047032576076)
(('add_edge', 'Alarm', 'Burglary'), -385.74860704055413, 5079.2560962170537)
(('add_edge', 'JohnCalls', 'Alarm'), -367.41616982602989, 4711.8399263910233)
(('add_edge', 'Burglary', 'MaryCalls'), -223.43246723168761, 4730.1427488948984)
(('add_edge', 'JohnCalls', 'MaryCalls'), -187.73043971199661, 4765.743888152776)
(('add_edge', 'JohnCalls', 'Burglary'), -186.07504644610492, 4783.7779296278432)
(('add_edge', 'Earthquake', 'Alarm'), -115.02942509184561, 4718.8627811218767)
(('add_edge', 'Earthquake', 'MaryCalls'), -66.742887211499806, 4788.7051719061519)
(('remove_edge', 'Burglary', 'MaryCalls'), -72.665870808734894, 4716.0393010974167)
(('remove_edge', 'JohnCalls', 'MaryCalls'), -72.237592877493398, 4680.2501310627158)
(('remove_edge', 'Earthquake', 'MaryCalls'), -69.842390784275153, 4664.9588193601239)
(('reverse_edge', 'Alarm', 'MaryCalls'), -63.109251812010939, 4930.8072094700656)
(('reverse_edge', 'MaryCa

KeyboardInterrupt: 

In [191]:
ls.opheap

[(0.0, ('reverse_edge', 'MaryCalls', 'Alarm')),
 (7.9462395774498873e-07, ('reverse_edge', 'Burglary', 'JohnCalls')),
 (5.6843418860808015e-13, ('reverse_edge', 'Alarm', 'JohnCalls')),
 (6.9592987104382473, ('add_edge', 'Earthquake', 'Burglary')),
 (6.9592987104381336, ('add_edge', 'Burglary', 'Earthquake')),
 (4.5730747615380096e-06, ('reverse_edge', 'Alarm', 'Burglary')),
 (18.039901177126012, ('add_edge', 'MaryCalls', 'JohnCalls')),
 (18.055906278657432, ('add_edge', 'JohnCalls', 'Burglary')),
 (18.343603276289173, ('add_edge', 'MaryCalls', 'Burglary')),
 (18.26854228914749, ('reverse_edge', 'Burglary', 'JohnCalls')),
 (36.308359289369037, ('add_edge', 'MaryCalls', 'JohnCalls')),
 (6.9592987104381336, ('add_edge', 'Burglary', 'Earthquake')),
 (17.845851436435396, ('add_edge', 'Earthquake', 'JohnCalls')),
 (18.055906278657432, ('add_edge', 'JohnCalls', 'Burglary')),
 (18.055905484033474, ('add_edge', 'Burglary', 'JohnCalls')),
 (330.37515916246349, ('remove_edge', 'MaryCalls', 'Alarm

In [192]:
dgm.draw()

In [158]:
bic_score('Alarm', ['Earthquake'])

6926.3138675641449

In [159]:
bic_score('Alarm', [])

8233.8573378805067