In [41]:
import warnings
warnings.filterwarnings("ignore")

In [42]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.svm import SVR
import matplotlib.pyplot as plt

In [43]:
df = pd.read_csv('../Data/svm_apple_parameters_with_views.csv')  

In [44]:
columns = ['apple_label', 'semi_major_axis', 'semi_minor_axis', 'area', 'perimeter', 'eccentricity', 'volume']

column_dict = {
    0: 'semi_major_axis',
    1: 'semi_minor_axis',
    2: 'area',
    3: 'perimeter',
    4: 'eccentricity'
}


In [45]:
def mean_absolute_percentage_error(y_true, y_pred):
    return np.mean(np.abs((y_true - y_pred) / y_true)) * 100

In [46]:
X = df.drop(columns = ['apple_label', 'area', 'volume']).values.tolist() 
y = df['volume'].values.tolist()  

svr_model = SVR(kernel='linear')
svr_model.fit(X, y)
y_pred = svr_model.predict(X)
print(y_pred)
rmse = np.sqrt(mean_squared_error(y, y_pred))
mse = mean_squared_error(y, y_pred)
mpe = mean_absolute_percentage_error(y, y_pred)

df['predicted_volume'] = y_pred

df[['apple_label', 'volume', 'predicted_volume']].to_csv('../Data/svm_predictions_with_each_view.csv', index=False)


[ 61.45598316  93.79650934 114.28101698 120.86483252 106.1663046
 112.05464045 115.58628147 111.57222095 121.55131048 127.03065595
 127.02770012 182.0869889  182.94501848 185.08860596 175.98024084
 170.23570335 169.67166518 174.70676201 172.49571063 180.90753379
 179.14380825 175.75260861 171.5808859  182.4353419  166.27400316
 178.14077679 223.80840152 222.55947111 238.71184947 244.6144668
 233.9496537  225.6619593  239.36073966 260.43530456 220.86357034
 208.71942519 223.95967557 233.75289672 223.57765135 213.90863084
 223.15539241 233.17646852 217.38960823 127.81951212 135.22354062
 135.83777373 128.08429467 135.6285295  135.11672622 136.0279235
 139.41315337 137.35464419 130.36767358 193.89748649 171.57280032
 190.34670199 185.48193111 187.54699705 188.77639373 171.89786638
 174.89979901 184.56140956 184.01131914 187.7461344  192.03291549
 195.92581966 173.89685914 183.33330641 199.23320644 179.25765104
 239.31394295 238.63386011 242.64909806 229.95024845 227.83570179
 238.75167629

In [47]:
print("********* Testing error **********")
print("Root Mean Squared Error (RMSE):", round(rmse, 3))
print("Mean Squared Error (MSE):", round(mse, 3))
print("Mean Percentage Error (MPE):", round(mpe, 3))

********* Testing error **********
Root Mean Squared Error (RMSE): 21.927
Mean Squared Error (MSE): 480.813
Mean Percentage Error (MPE): 10.313
