In [92]:
import pandas as pd
import pickle
import unittest


In [93]:

# Load the pickled model object from disk into `clf`.
# NOTE(review): unpickling can execute arbitrary code — only load this file
# when it comes from a trusted source.
with open('dwarf_classifier_model.pickle', mode='rb') as model_file:
  clf = pickle.load(model_file)


In [94]:
# Smoke-check the loaded model on a single hand-written row of 12 values.
# expected value: 18480
# NOTE(review): the output recorded for this cell in this export is 30061.42,
# which is far from the expected value above — confirm the feature order and
# units of the row before trusting either number.
clf.predict([
    [0, 50, 2019, 3, 0, 5, 1.5, 6.7, 4, 1, 1, 7],
])


array([30061.42])

In [95]:
class NormalizeData:
  """Chainable clean-up steps applied to a pandas DataFrame.

  Every step returns ``self`` so calls can be chained, and the current state
  of the frame is always available as ``.data``. No step modifies the frame
  object that was handed to the constructor; each one rebinds ``.data`` to a
  new frame.
  """

  def __init__(self, data):
    """Keep a reference to the frame that the steps will start from.

    Args:
      data: pandas DataFrame to clean (stored by reference, not copied).
    """
    self.data = data

  def remove_dashes(self):
    """Replace every cell whose value is exactly the string '-' with 0.

    Returns:
      self, to allow chaining.
    """
    self.data = self.data.replace('-', 0)
    return self

  def parse_float(self, field):
    """Convert one column to float.

    Args:
      field: label of the column to convert.

    Returns:
      self, to allow chaining.
    """
    # Work on a copy so the caller's original frame is never mutated.
    converted = self.data.copy()
    converted[field] = converted[field].astype(float)
    self.data = converted
    return self

  def remove_string(self, field, string_to_remove):
    """Delete every occurrence of a literal substring from a string column.

    Args:
      field: label of the column to clean.
      string_to_remove: text to delete; it is matched literally, never as a
        regular expression.

    Returns:
      self, to allow chaining.
    """
    # Work on a copy so the caller's original frame is never mutated.
    cleaned = self.data.copy()
    # regex=False pins literal matching: the default of `regex` differs
    # between pandas versions, and text such as ' (km)' or '$' must not be
    # interpreted as a pattern.
    cleaned[field] = cleaned[field].str.replace(
        string_to_remove, '', regex=False)
    self.data = cleaned
    return self

  def remove_outliers(self, condition):
    """Drop exactly the rows selected by a boolean condition.

    Args:
      condition: boolean mask (Series, list or array) marking the rows to
        remove, or a callable that receives the current frame and returns
        such a mask.

    Returns:
      self, to allow chaining.
    """
    if callable(condition):
      condition = condition(self.data)
    # Select by row position instead of by index label: dropping by label
    # would also remove non-matching rows that merely share an index label
    # with a matching row (non-unique index).
    row_numbers = pd.Series(range(len(self.data)), index=self.data.index)
    flagged = row_numbers[condition]
    keep = ~row_numbers.isin(flagged)
    self.data = self.data[keep.to_numpy()]
    return self

  def remove_fields(self, fields):
    """Drop the given columns.

    Args:
      fields: column label or list of column labels to remove.

    Returns:
      self, to allow chaining.
    """
    self.data = self.data.drop(fields, axis=1)
    return self


In [96]:


class TestNormalizeData(unittest.TestCase):

  def remove_dashes(self):
    test_data = pd.DataFrame({'price': [0, '-', 2, 5]})
    normalize = NormalizeData(test_data)
    prices = normalize.remove_dashes().data['price'].tolist()
    self.assertListEqual([0, 0, 2, 5], prices)

  def parse_float(self):
    test_data = pd.DataFrame({'price': [0, 0, 2, '5', '11.5']})
    normalize = NormalizeData(test_data)
    prices = normalize.parse_float('price').data['price'].tolist()
    self.assertListEqual([0.0, 0.0, 2.0, 5.0, 11.5], prices)

  def remove_string(self):
    test_data = pd.DataFrame({'price': ['2 km', '5', '11.5 km']})
    normalize = NormalizeData(test_data)
    prices = normalize.remove_string('price', ' km').data['price'].tolist()
    self.assertListEqual(['2', '5', '11.5'], prices)

  def remove_outliers(self):
    test_data = pd.DataFrame({'price': [10, 0, 2, 5, 11.5]})
    normalize = NormalizeData(test_data)

    prices = normalize.remove_outliers(
        test_data['price'] > 10).data['price'].tolist()
    self.assertListEqual([10, 0, 2, 5], prices)

  def remove_fields(self):
    test_data = pd.DataFrame({'price': [10, 0, 2, 5, 11.5]})
    normalize = NormalizeData(test_data)

    prices = normalize.remove_fields(['price']).data.columns.tolist()
    self.assertListEqual([], prices)


# Drive each check by hand on a single TestCase instance, in a fixed order;
# a failing assertion raises AssertionError out of the cell.
test = TestNormalizeData()
for check_name in (
    'remove_dashes',
    'parse_float',
    'remove_string',
    'remove_outliers',
    'remove_fields',
):
  getattr(test, check_name)()
