In [None]:
import json
import os

from sklearn import linear_model
from sklearn import neighbors
from sklearn import svm
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report

try:
    import joblib
except ImportError:  # fallback for old scikit-learn; sklearn.externals.joblib was removed in 0.23
    from sklearn.externals import joblib

class MLModel:
    """Thin wrapper around a few scikit-learn estimators with save/load helpers.

    The estimator is picked by name from ``self.models``. ``fit`` trains it and
    keeps the fitted estimator in ``self.clf``. ``save``/``load`` persist
    ``self.clf``, and ``predict``/``evaluate`` use it.

    NOTE(review): "ridge" and "lasso" are regressors. Their continuous
    predictions are not valid input for ``evaluate``'s classification_report.
    """

    def __init__(self, target_model, save_path):
        """Select an estimator by name.

        Args:
            target_model: a key of ``self.models`` ("ridge", "lasso", "knn",
                "randomforest", "sgd_svm"), or None for an instance with no
                estimator selected (presumably one that only calls ``load``).
            save_path: destination used by ``save``. With mode="pickle" it is
                the output file path. With mode="json" it is a directory, and
                ``model.json`` is written inside it.

        Raises:
            NotImplementedError: if ``target_model`` is not a known key.
        """
        self.save_path = save_path
        # TODO exercise: add unsupervised clustering models.
        self.models = {
            "ridge": linear_model.Ridge(alpha=.5),
            "lasso": linear_model.Lasso(alpha=0.1),
            "knn": neighbors.KNeighborsClassifier(5),
            "randomforest": RandomForestClassifier(n_estimators=10),
            "sgd_svm": linear_model.SGDClassifier(max_iter=1000, tol=1e-3),
        }
        # Always define the attribute so fit() can report a clear error
        # instead of an AttributeError when no model was selected.
        self.model = None
        if target_model is not None:
            if target_model not in self.models:
                raise NotImplementedError(
                    f"unknown model {target_model!r}; "
                    f"choose one of {sorted(self.models)}"
                )
            self.model = self.models[target_model]

        self.clf = None

    def fit(self, X, y):
        """Train the selected estimator on ``X``/``y`` and store it in ``self.clf``.

        Raises:
            ValueError: if the instance was created with ``target_model=None``.
        """
        if self.model is None:
            raise ValueError("no model selected; pass target_model to MLModel()")
        # scikit-learn's fit() returns the fitted estimator itself.
        self.clf = self.model.fit(X, y)

    def save(self, mode="pickle"):
        """Persist the trained estimator held in ``self.clf``.

        Args:
            mode: "pickle" dumps ``self.clf`` with joblib to ``self.save_path``.
                "json" writes ``{"clf": self.clf}`` to
                ``<self.save_path>/model.json``.

        Raises:
            ValueError: if nothing has been trained or loaded yet.
            TypeError: in "json" mode when ``self.clf`` is not
                JSON-serializable (e.g. a fitted scikit-learn estimator; use
                "pickle" for those).
            NotImplementedError: for any other ``mode``.
        """
        if mode == "pickle":
            if self.clf is None:
                raise ValueError("train before saving the classifier")
            joblib.dump(self.clf, self.save_path)
        elif mode == "json":
            if self.clf is None:
                raise ValueError("Train before saving the classifier")
            # Serialize before opening the file, so a failure leaves no
            # empty/partial model.json behind.
            try:
                json_clf = json.dumps({"clf": self.clf}, indent=4)
            except TypeError as err:
                raise TypeError(
                    f"{type(self.clf).__name__} is not JSON-serializable; "
                    "use save(mode='pickle') instead"
                ) from err
            out_path = os.path.join(self.save_path, "model.json")
            with open(out_path, "w", encoding="utf-8") as file:
                file.write(json_clf)
        else:
            raise NotImplementedError(f"unsupported save mode: {mode!r}")

    def load(self, path, mode="pickle"):
        """Load an estimator into ``self.clf`` and return it.

        Args:
            path: file to read. For "json" this is the ``model.json`` file
                itself, not the directory given as ``save_path``.
            mode: "pickle" (joblib) or "json".

        Raises:
            NotImplementedError: for any other ``mode``.
        """
        if mode == "pickle":
            # SECURITY: joblib.load unpickles arbitrary objects and can execute
            # code; only load files from a trusted source.
            self.clf = joblib.load(path)
            print(self.clf)
        elif mode == "json":
            with open(path, "r", encoding="utf-8") as file:
                model_dict = json.load(file)
            self.clf = model_dict["clf"]
        else:
            raise NotImplementedError(f"unsupported load mode: {mode!r}")
        return self.clf

    def predict(self, X):
        """Return ``self.clf.predict(X)``.

        Raises:
            ValueError: if no estimator has been trained or loaded yet.
        """
        if self.clf is None:
            raise ValueError("train or load the classifier before predicting")
        return self.clf.predict(X)

    def evaluate(self, y_true, y_pred, target_names=None):
        """Print scikit-learn's classification_report for ``y_true`` vs ``y_pred``.

        ``target_names`` optionally supplies display names for the classes.
        """
        print(classification_report(y_true, y_pred, target_names=target_names))

## TODO exercises

# Reuse the MLModel class above on the data transformed in the previous sections for:

# 1. classification
# 2. clustering