In [2]:
import pickle
from pathlib import Path

import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression, Lasso
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import mean_absolute_error

#### Load, clean, and train/test split

In [99]:
data = pd.read_csv("datasets/model_data.csv")

# Normalize every text column in a single pass: lower-case, then strip surrounding
# whitespace. Both steps are chained on the same column so neither discards the other.
# Non-text columns pass through unchanged. Missing values stay NaN (no astype(str),
# which would turn them into the literal category "nan"); per pandas `.str` semantics,
# non-string values inside an object column also become NaN.
# NOTE(review): the Flask app presumably needs to apply the same normalization to its
# inputs so the one-hot columns line up with training — confirm.
data_frame_trimmed = data.apply(
    lambda col: col.str.lower().str.strip()
    if col.dtype == "object" or pd.api.types.is_string_dtype(col)
    else col
)

In [100]:
# Map the raw CSV headers to the short lower-case names used by the rest of the notebook.
# Reassignment (instead of inplace=True) keeps the cell free of hidden in-place mutation.
data_frame_trimmed = data_frame_trimmed.rename(
    columns={
        '5G': '5g',
        'Battery Capacity': 'battery',
        'Card slot': 'card_slot',
        'Display type': 'display_type',
        'GPS': 'gps',
        'Internal Memory': 'internal_memory',
        'LTE (4G) Network': '4g',
        'NFC': 'nfc',
        'Number of cameras': 'num_of_cameras',
        'Primary Camera': 'primary_cam',
        'RAM': 'ram',
        'Screen size': 'screen_size',
        'Year': 'year',
        'Selfie Camera': 'selfie_cam',
    }
)

In [101]:
# Features: every column except `name` (presumably a per-row identifier) and the
# target `price`; categorical columns are one-hot encoded by get_dummies.
feature_frame = data_frame_trimmed.drop(columns=['name', 'price'])
X = pd.get_dummies(feature_frame)
y = data_frame_trimmed['price']

# Hold out 20% of rows for evaluation; the fixed seed keeps the split reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=1,
)

#### Linear Regression

In [102]:
# Ordinary least-squares baseline; fit() returns the estimator itself.
lr = LinearRegression().fit(X_train, y_train)

# Mean absolute error on the training split, then on the held-out test split.
for split_X, split_y in ((X_train, y_train), (X_test, y_test)):
    print(mean_absolute_error(split_y, lr.predict(split_X)))

147.7080112891959
176.2236846099498


In [103]:
# Spot-check one held-out row: predicted price, then the actual price.
lr_test_preds = lr.predict(X_test)
print(lr_test_preds[3])
print(y_test.iloc[3])

208.68034924166204
200.0


#### Lasso

In [105]:
# L1-regularized linear model; fit() returns the estimator itself.
# NOTE(review): alpha=0.13 is not derived anywhere in this notebook (presumably hand-tuned),
# and the features are not standardized, so the L1 penalty weighs columns unevenly.
# Consider a StandardScaler + cross-validated alpha.
ls = Lasso(alpha=0.13).fit(X_train, y_train)

# Mean absolute error on the training split, then on the held-out test split.
for split_X, split_y in ((X_train, y_train), (X_test, y_test)):
    print(mean_absolute_error(split_y, ls.predict(split_X)))

148.77647626048628
174.51144599295313


In [106]:
# Spot-check the same held-out row with the Lasso model: predicted, then actual price.
ls_test_preds = ls.predict(X_test)
print(ls_test_preds[3])
print(y_test.iloc[3])

224.12509716400564
200.0


#### Save model

In [107]:
# Persist the linear-regression model as {"model": lr}. The artifact is written
# directly to the path the reload cell reads, so that check inspects what this
# run produced. A `with` block guarantees the file is flushed and closed.
# NOTE(review): the Lasso above scored a slightly lower test MAE (174.51 vs 176.22);
# confirm `lr` is the intended model to ship.
MODEL_PATH = Path('flask_app/ml_models/model.p')
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)

pckl = {"model": lr}
with MODEL_PATH.open("wb") as model_file:
    pickle.dump(pckl, model_file)

In [111]:
# Reload the pickled artifact under flask_app/ (presumably what the Flask app serves)
# and spot-check that it predicts the same value for test row 3 as the cells above.
# NOTE: pickle.load can execute arbitrary code from the file; only load trusted artifacts.
with open('flask_app/ml_models/model.p', 'rb') as f:
    saved = pickle.load(f)
m = saved['model']

m_test_preds = m.predict(X_test)
print(m_test_preds[3])
print(y_test.iloc[3])

208.68034924166204
200.0
