In [1]:
import numpy as np
import pandas as pd
from sklearn import linear_model

def read_data(csv_file):
    """Load a CSV source into a pandas DataFrame.

    Args:
        csv_file: Anything ``pandas.read_csv`` accepts (path or file-like).

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    return pd.read_csv(csv_file)

def split_data(csv_df):
    """Split a DataFrame into a feature matrix and a target column vector.

    Every column except the last is treated as a feature; the last column
    is the target.

    Args:
        csv_df: DataFrame whose final column holds the target values.

    Returns:
        A tuple ``(x, y)`` of NumPy arrays: ``x`` holds the feature columns
        and ``y`` is the last column reshaped to ``(n_rows, 1)``.
    """
    features = csv_df.iloc[:, :-1].values
    # Reshape to a column vector so matrix products yield (n_rows, 1) results.
    target = csv_df.iloc[:, -1].values.reshape((-1, 1))
    return features, target

def find_optimize(input, outcome):
    """Return the least-squares weights ``w`` minimising ``||input @ w - outcome||``.

    The solution is computed as ``pinv(input) @ outcome``. This is
    mathematically identical to the normal-equation form
    ``pinv(input.T @ input) @ input.T @ outcome``, but it avoids explicitly
    forming ``input.T @ input``, which squares the condition number and can
    make nearly collinear feature columns look exactly singular in floating
    point.

    Args:
        input: 2-D design matrix of shape ``(n_samples, n_features)``.
        outcome: Targets with ``n_samples`` rows (1-D or 2-D).

    Returns:
        Weight array with ``n_features`` rows; for rank-deficient inputs this
        is the minimum-norm least-squares solution.
    """
    # The SVD inside pinv is applied to `input` itself, so small singular
    # values are compared against the cutoff un-squared.
    return np.dot(np.linalg.pinv(input), outcome)

def optimize_with_sklearn(input, outcome):
    """Fit a scikit-learn linear regression and return its coefficients.

    The model is created with ``fit_intercept=False``, so no separate
    intercept is estimated.

    Args:
        input: Feature matrix passed to ``LinearRegression.fit``.
        outcome: Target values passed to ``LinearRegression.fit``.

    Returns:
        The fitted model's ``coef_`` attribute.
    """
    fitted = linear_model.LinearRegression(fit_intercept=False).fit(input, outcome)
    return fitted.coef_

def get_loss_value(input, outcome, w):
    """Compute half the sum of squared errors of the linear model ``input @ w``.

    Each target/prediction pair is printed while the loss is accumulated.

    Args:
        input: Design matrix multiplied by ``w`` to obtain predictions.
        outcome: Targets, iterated row by row; element 0 of each row is used.
        w: Weights; predictions are iterated row by row and element 0 is used.

    Returns:
        The sum of squared differences between targets and predictions,
        divided by two.
    """
    predictions = np.dot(input, w)
    squared_error_sum = 0
    for actual_row, predicted_row in zip(outcome, predictions):
        actual, predicted = actual_row[0], predicted_row[0]
        print('Outcome:', actual, 'Predict:', predicted)
        residual = actual - predicted
        squared_error_sum += residual ** 2
    return squared_error_sum / 2


def predict_new_data(input, w):
    """Predict outputs for raw feature rows using weights that include a bias.

    A column of ones is prepended to ``input`` so that the first entry of
    ``w`` is applied as the intercept.

    Args:
        input: 2-D array of raw feature rows (without the ones column).
        w: Weights whose first row corresponds to the prepended ones column.

    Returns:
        The product of the augmented feature matrix and ``w``.
    """
    bias_column = np.ones((input.shape[0], 1))
    # Augmented matrix: [1, feature_1, ..., feature_k] per row.
    augmented = np.hstack((bias_column, input))
    return np.dot(augmented, w)


# Script entry point: fit the linear model on data.csv with both solvers,
# then print the weights, the training loss and one sample prediction.
if __name__ == '__main__':
    df = read_data('data.csv')
    print(df)
    # NOTE(review): `input` shadows the builtin of the same name from here on.
    input, outcome = split_data(df)
    # Prepend a column of ones so the first weight acts as the intercept term.
    one = np.ones((input.shape[0], 1))
    input = np.concatenate((one, input), axis=1)
    # w1: closed-form pseudo-inverse solution; w2: scikit-learn fit without a
    # separate intercept (the ones column already supplies it).
    w1 = find_optimize(input, outcome)
    w2 = optimize_with_sklearn(input, outcome)
    # w1 is printed transposed so it can be compared side by side with w2.
    print(w1.T)
    print(w2)
    print('Loss value:', get_loss_value(input, outcome, w1))
    # predict_new_data adds the ones column itself, so raw features are passed.
    print('Predict new label:', predict_new_data(np.array([[70, 73, 79]]), w1))

    exam1  exam2  exam3  final_exam
0      73     80     75         152
1      93     88     93         185
2      89     91     90         180
3      96     98    100         196
4      73     66     70         142
5      53     46     55         101
6      69     74     77         149
7      47     56     60         115
8      87     79     90         175
9      79     70     88         164
10     69     70     73         141
11     70     65     74         141
12     93     95     91         184
13     79     80     73         152
14     70     73     78         148
15     93     89     96         192
16     78     75     68         147
17     81     90     93         183
18     88     92     86         177
19     78     83     77         159
20     82     86     90         177
21     86     82     89         175
22     78     83     85         175
23     76     83     71         149
24     96     93     95         192
[[-4.3361024   0.35593822  0.54251876  1.16744422]]
[[-4.3361024