Skip to content

Commit

Permalink
Merge branch 'diagnostic_fix' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
bryanhe committed Oct 21, 2016
2 parents 083ec1e + df7f6b7 commit 2cc1e9b
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions numbskull/factorgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,11 @@ def diagnosticsLearning(self, weight_copy=0):
# INFERENCE AND LEARNING #
################################

def burnIn(self, epochs, sample_evidence, var_copy=0, weight_copy=0):
def burnIn(self, epochs, sample_evidence, diagnostics=False,
var_copy=0, weight_copy=0):
"""TODO."""
print("FACTOR " + str(self.fid) + ": STARTED BURN-IN...")
if diagnostics:
print("FACTOR " + str(self.fid) + ": STARTED BURN-IN...")
# NUMBA-based method. Implemented in inference.py
for ep in range(epochs):
args = (self.threads, var_copy, weight_copy,
Expand All @@ -138,17 +140,20 @@ def burnIn(self, epochs, sample_evidence, var_copy=0, weight_copy=0):
self.factor_index, self.Z, self.cstart, self.count,
self.var_value, self.weight_value, sample_evidence, True)
run_pool(self.threadpool, self.threads, gibbsthread, args)
print("FACTOR " + str(self.fid) + ": DONE WITH BURN-IN")
if diagnostics:
print("FACTOR " + str(self.fid) + ": DONE WITH BURN-IN")

def inference(self, burnin_epochs, epochs, sample_evidence=False,
diagnostics=False, var_copy=0, weight_copy=0):
"""TODO."""
# Burn-in
if burnin_epochs > 0:
self.burnIn(burnin_epochs, sample_evidence)
self.burnIn(burnin_epochs, sample_evidence,
diagnostics=diagnostics)

# Run inference
print("FACTOR " + str(self.fid) + ": STARTED INFERENCE")
if diagnostics:
print("FACTOR " + str(self.fid) + ": STARTED INFERENCE")
for ep in range(epochs):
with Timer() as timer:
args = (self.threads, var_copy, weight_copy, self.weight,
Expand All @@ -162,7 +167,8 @@ def inference(self, burnin_epochs, epochs, sample_evidence=False,
if diagnostics:
print('Inference epoch #%d took %.03f sec.' %
(ep, self.inference_epoch_time))
print("FACTOR " + str(self.fid) + ": DONE WITH INFERENCE")
if diagnostics:
print("FACTOR " + str(self.fid) + ": DONE WITH INFERENCE")
# compute marginals
if epochs != 0:
self.marginals = self.count / float(epochs)
Expand All @@ -175,10 +181,11 @@ def learn(self, burnin_epochs, epochs, stepsize, decay,
"""TODO."""
# Burn-in
if burnin_epochs > 0:
self.burnIn(burnin_epochs, True)
self.burnIn(burnin_epochs, True, diagnostics=diagnostics)

# Run learning
print("FACTOR " + str(self.fid) + ": STARTED LEARNING")
if diagnostics:
print("FACTOR " + str(self.fid) + ": STARTED LEARNING")
for ep in range(epochs):
if diagnostics:
print("FACTOR " + str(self.fid) + ": EPOCH #" + str(ep))
Expand All @@ -196,15 +203,10 @@ def learn(self, burnin_epochs, epochs, stepsize, decay,
run_pool(self.threadpool, self.threads, learnthread, args)
self.learning_epoch_time = timer.interval
self.learning_total_time += timer.interval
if diagnostics:
print("FACTOR " + str(self.fid) + ": EPOCH #" + str(ep))
print("Current stepsize = " + str(stepsize))
if verbose:
self.diagnosticsLearning(weight_copy)
sys.stdout.flush() # otherwise output refuses to show in DD
# Decay stepsize
stepsize *= decay
print("FACTOR " + str(self.fid) + ": DONE WITH LEARNING")
if diagnostics:
print("FACTOR " + str(self.fid) + ": DONE WITH LEARNING")

def dump_weights(self, fout, weight_copy=0):
"""Dump <wid, weight> text file in DW format."""
Expand Down

0 comments on commit 2cc1e9b

Please sign in to comment.