From 7b31aeaa423f03e77e30e56e7f32ca977b2eef3f Mon Sep 17 00:00:00 2001 From: Rishabh Agarwal Date: Wed, 1 Mar 2017 17:35:38 +0530 Subject: [PATCH] Added some unittests in test_learning.py --- tests/test_learning.py | 41 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 37 insertions(+), 4 deletions(-) diff --git a/tests/test_learning.py b/tests/test_learning.py index 46ac8dd26..738fa7690 100644 --- a/tests/test_learning.py +++ b/tests/test_learning.py @@ -1,6 +1,11 @@ -from learning import parse_csv, weighted_mode, weighted_replicate, DataSet, \ - PluralityLearner, NaiveBayesLearner, NearestNeighborLearner +import pytest +import math from utils import DataFile +from learning import ( + parse_csv, weighted_mode, weighted_replicate, DataSet, + PluralityLearner, NaiveBayesLearner, NearestNeighborLearner, + rms_error, manhattan_distance, mean_boolean_error, mean_error +) def test_parse_csv(): @@ -33,5 +38,33 @@ def test_naive_bayes(): def test_k_nearest_neighbors(): iris = DataSet(name="iris") - kNN = NearestNeighborLearner(iris, k=3) - assert kNN([5, 3, 1, 0.1]) == "setosa" + kNN = NearestNeighborLearner(iris,k=3) + assert kNN([5,3,1,0.1]) == "setosa" + +def test_rms_error(): + assert rms_error([2,2], [2,2]) == 0 + assert rms_error((0,0), (0,1)) == math.sqrt(0.5) + assert rms_error((1,0), (0,1)) == 1 + assert rms_error((0,0), (0,-1)) == math.sqrt(0.5) + assert rms_error((0,0.5), (0,-0.5)) == math.sqrt(0.5) + +def test_manhattan_distance(): + assert manhattan_distance([2,2], [2,2]) == 0 + assert manhattan_distance([0,0], [0,1]) == 1 + assert manhattan_distance([1,0], [0,1]) == 2 + assert manhattan_distance([0,0], [0,-1]) == 1 + assert manhattan_distance([0,0.5], [0,-0.5]) == 1 + +def test_mean_boolean_error(): + assert mean_boolean_error([1,1], [0,0]) == 1 + assert mean_boolean_error([0,1], [1,0]) == 1 + assert mean_boolean_error([1,1], [0,1]) == 0.5 + assert mean_boolean_error([0,0], [0,0]) == 0 + assert mean_boolean_error([1,1], [1,1]) == 0 + +def test_mean_error(): + assert mean_error([2,2], [2,2]) == 0 + assert mean_error([0,0], [0,1]) == 0.5 + assert mean_error([1,0], [0,1]) == 1 + assert mean_error([0,0], [0,-1]) == 0.5 + assert mean_error([0,0.5], [0,-0.5]) == 0.5 \ No newline at end of file