Skip to content

Commit

Permalink
Update tests to account for absence of Rbm
Browse files Browse the repository at this point in the history
  • Loading branch information
JohnVinyard committed Aug 25, 2017
1 parent 5c70f34 commit 3e2927a
Showing 1 changed file with 15 additions and 16 deletions.
31 changes: 15 additions & 16 deletions zounds/learn/test_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from random_samples import ReservoirSampler
from preprocess import \
UnitNorm, MeanStdNormalization, PreprocessingPipeline, Pipeline
from learn import LinearRbm, Learned
from learn import Learned, KMeans
import numpy as np


Expand Down Expand Up @@ -45,9 +45,8 @@ class Rbm(featureflow.BaseModel, Settings):
store=False)

rbm = featureflow.PickleFeature(
LinearRbm,
hdim=64,
epochs=5,
KMeans,
centroids=64,
needs=meanstd,
store=False)

Expand All @@ -66,27 +65,27 @@ def data():

class RbmTests(unittest2.TestCase):
def test_can_retrieve_rbm_pipeline(self):
Rbm = build_classes()
Rbm.process(iterator=data())
self.assertIsInstance(Rbm().pipeline, Pipeline)
KMeans = build_classes()
KMeans.process(iterator=data())
self.assertIsInstance(KMeans().pipeline, Pipeline)


class LearnedTests(unittest2.TestCase):
def test_can_use_learned_feature(self):
Rbm = build_classes()
Rbm.process(iterator=data())
l = Learned(learned=Rbm())
KMeans = build_classes()
KMeans.process(iterator=data())
l = Learned(learned=KMeans())
results = list(l._process(np.random.random_sample((33, 3))))[0]
self.assertEqual((33, 64), results.shape)

def test_pipeline_changes_version_when_recomputed(self):
Rbm = build_classes()
Rbm.process(iterator=data())
v1 = Learned(learned=Rbm()).version
v2 = Learned(learned=Rbm()).version
KMeans = build_classes()
KMeans.process(iterator=data())
v1 = Learned(learned=KMeans()).version
v2 = Learned(learned=KMeans()).version
self.assertEqual(v1, v2)
Rbm.process(iterator=data())
v3 = Learned(learned=Rbm()).version
KMeans.process(iterator=data())
v3 = Learned(learned=KMeans()).version
self.assertNotEqual(v1, v3)

def test_pipeline_does_not_store_computed_data_from_training(self):
Expand Down

0 comments on commit 3e2927a

Please sign in to comment.