# In [34]:
import numpy as np
from sklearn import preprocessing

from sklearn.ensemble import ExtraTreesRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import RidgeCV
from sklearn import cross_validation

__author__ = 'amanda'


def preprocess_data(data_dir='./data_ready_to_use'):
    """Load the rating and feature archives and build the model inputs.

    Parameters
    ----------
    data_dir : str, optional
        Directory containing ``clean_rating_data.npz``, ``clean_features.npz``
        and ``gender_list.npz``. Defaults to ``'./data_ready_to_use'``, which
        is resolved relative to the current working directory.

    Returns
    -------
    average_rating : numpy.ndarray
        The ``full_rating`` array averaged over axis 0.
    feature_arr : numpy.ndarray
        ``gender_list`` prepended as the first column of the stored
        ``feature_arr``, then standardized column-wise with
        ``sklearn.preprocessing.scale``.
    """
    import os.path

    def _load(file_name, key):
        # Use the NpzFile as a context manager so its file handle is closed;
        # indexing by key reads that array fully into memory before closing.
        with np.load(os.path.join(data_dir, file_name)) as archive:
            return archive[key]

    full_rating = _load('clean_rating_data.npz', 'full_rating')
    # NOTE(review): presumably rows are raters and columns are rated items, so
    # this yields one mean rating per item -- confirm against the data prep.
    average_rating = full_rating.mean(axis=0)

    feature_arr = _load('clean_features.npz', 'feature_arr')
    gender_list = _load('gender_list.npz', 'gender_list')

    # Make gender_list a column (assumes it is 1-D) and prepend it to the
    # features; hstack requires both to have the same number of rows.
    feature_arr = np.hstack((gender_list[:, np.newaxis], feature_arr))
    # Standardize every column (including gender) to zero mean, unit variance.
    feature_arr = preprocessing.scale(feature_arr)
    return average_rating, feature_arr

# Load the preprocessed data.
rating, feature_arr = preprocess_data()

# Derive the sample count from the data instead of hard-coding 200, so the
# shuffle splits always index rows that actually exist in feature_arr/rating.
n_samples = feature_arr.shape[0]
cv = cross_validation.ShuffleSplit(n_samples, n_iter=10, test_size=0.15,
                                   random_state=0)


# Candidate regressors, keyed by the label printed above their scores.
# (They are fitted inside cross_val_score, not here.)
# NOTE(review): max_features=32 needs at least 32 feature columns -- confirm.
ESTIMATORS = {"Extra trees": ExtraTreesRegressor(n_estimators=10, max_features=32, random_state=0),
              "K-nn": KNeighborsRegressor(),
              "Linear regression": LinearRegression(),
              "Ridge": RidgeCV()}

# NOTE(review): not populated in this cell; presumably filled by later cells --
# confirm, otherwise remove.
y_test_predict = dict()
y_overall_predict = dict()

# Cross-validate every estimator on the same splits and print its per-split
# scores. Iterate in sorted order so the printed order is deterministic
# (plain dict iteration order is arbitrary on Python 2).
for name, estimator in sorted(ESTIMATORS.items()):
    print(name)
    a = cross_validation.cross_val_score(estimator, feature_arr, rating, cv=cv)
    print(a)

# Output of the cell above (per-split scores from cross_val_score, 10 ShuffleSplit splits):
# Ridge
# [ 0.14839561  0.19584775  0.433925    0.23706595  0.1377556   0.13401471
#   0.30756033  0.18602091  0.46530518  0.19753111]
# K-nn
# [ 0.3170519   0.09909138  0.28267922  0.1125815  -0.48924928 -0.13560508
#   0.20111675  0.00468427  0.24530657  0.02829938]
# Extra trees
# [ 0.19740106  0.19146068  0.36235779  0.16944499 -0.13489705 -0.04301468
#   0.17240085 -0.18456174  0.26688219 -0.09677614]
# Linear regression
# [  8.21838525e-02  -6.47944001e-01   3.52706117e-02   2.62173255e-01
#   -4.43820648e-01   1.43728089e-03   1.32926403e-01   1.24245327e-01
#   -1.07710752e-01  -7.70261922e+00]
