In [1]:
import os
import sys
import numpy as np
import pandas as pd
from functools import partial
from multiprocessing import Pool
from DataProcessor import createData
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt

In [2]:
# Directory whose entries are later read one at a time with pd.read_csv.
path = "./stocks data/Data/"
# os.listdir returns names in arbitrary order; sort so every run processes
# the files in the same sequence.
stock_data_filenames = sorted(os.listdir(path))

In [3]:
def cal_acurracy(stockname, method = "Logistic"):
    """Fit one classifier on a single data file and score it on a hold-out split.

    Parameters
    ----------
    stockname : str
        File name (not a full path) inside ``./stocks data/Data/``; the file
        is read with ``pd.read_csv``.
    method : str, default "Logistic"
        ``"Logistic"`` fits a ``LogisticRegression``, ``"RandomForest"`` fits a
        ``RandomForestClassifier``. Any other value fits nothing and the
        function returns ``(0, 0)``.

    Returns
    -------
    tuple
        ``(accuracy, area)`` where ``accuracy`` is the model's score on the
        30% test split and ``area`` is the ROC AUC computed from the
        probability in column 1 of ``predict_proba``. This is a best-effort
        function: if anything fails, the error is printed and whichever value
        had not yet been computed stays 0, so callers that average the results
        count failed files as zeros.
    """
    accuracy = area = 0
    path = os.path.join("./stocks data/Data/", stockname)

    try:
        df = pd.read_csv(path)
        data = createData(df, label_return_period=20, annualized_return=0.1).values
        # Every column but the last is a feature; the last column is the label.
        X, Y = data[:, :-1], data[:, -1]
        # NOTE(review): train_test_split shuffles rows by default. If the rows
        # are time-ordered with overlapping label windows, this leaks future
        # information into training and inflates the scores -- confirm.
        x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.3, random_state=18)

        if method == "Logistic":
            model = LogisticRegression()
        elif method == "RandomForest":
            # Seeded so repeated runs give the same forest and the same scores.
            model = RandomForestClassifier(random_state=18)
        else:
            model = None

        if model is not None:
            model.fit(x_train, y_train)
            accuracy = model.score(x_test, y_test)
            # Column 1 of predict_proba is the score handed to roc_auc_score.
            area = roc_auc_score(y_test, model.predict_proba(x_test)[:, 1])

    except Exception as exc:
        # Catch Exception rather than everything, so Ctrl-C and SystemExit
        # still stop the run; report which file failed and why.
        print("Error occurs for {}: {!r}".format(stockname, exc))

    return accuracy, area

def run(file_names, method="Logistic", processes=10):
    """Score every file in parallel and print the mean accuracy and mean AUC.

    Parameters
    ----------
    file_names : iterable of str
        File names passed one by one to ``cal_acurracy``.
    method : str, default "Logistic"
        Forwarded to ``cal_acurracy`` as its ``method`` argument.
    processes : int, default 10
        Number of worker processes in the pool.

    Notes
    -----
    ``cal_acurracy`` reports a failed file as ``(0, 0)``; those entries are
    included in the printed averages.
    """
    pf = partial(cal_acurracy, method=method)
    pool = Pool(processes=processes)
    try:
        accs = pool.map(pf, file_names)
    finally:
        # Always release the workers, even when map() raises; join() waits
        # for them to exit so no child processes are left behind.
        pool.close()
        pool.join()

    print("Average accuracy: " + str(np.mean([acc for acc, _ in accs])))
    print("Average auc: " + str(np.mean([auc for _, auc in accs])))

## Logistic regression

In [4]:
# Score the logistic-regression model on every listed data file.
run(stock_data_filenames, method="Logistic")

Error occurs.
Average accuracy: 0.659014854969
Average auc: 0.677799423232


## RandomForest

In [5]:
# Score the random-forest model on every listed data file.
run(stock_data_filenames, method="RandomForest")

Error occurs.
Average accuracy: 0.867435525304
Average auc: 0.937926759741


## XGBoost