In [36]:
import numpy as np
import pandas as pd
from sklearn.linear_model import Ridge, LinearRegression, Lasso
from sklearn.model_selection import train_test_split
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis, QuadraticDiscriminantAnalysis
from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.naive_bayes import GaussianNB
import matplotlib.pyplot as plt
import seaborn as sns
sns.set(style = 'whitegrid')

# Question 2

In [127]:
def get_lambda(lam_vals, data, x_test, y_test, model_type = 'ridge'):
    error_vals = []
    for lam in lam_vals:
        error = []
        if model_type == 'ridge':
            model = Ridge(alpha = lam, max_iter = 1e6)
        elif model_type == 'lasso':
            model = Lasso(alpha = lam, max_iter = 1e6)
        for i in range(5):
            cur_data = data[data['group'] == i]
            x_train = cur_data.drop('MEDV', axis = 1)
            y_train = cur_data['MEDV']
            cur_error = get_error(model, x_train, y_train, x_test, y_test)
            error.append(cur_error)
        avg_error = np.mean(np.array(error))
        error_vals.append(np.mean(np.array(error)))
    min_ind = np.where(error_vals == np.amin(error_vals))
    return lam_vals[min_ind]

In [154]:
def get_error(model, x_train, y_train, x_test, y_test, beta = False):
    model.fit(x_train, y_train)
    preds = model.predict(x_test)
    error = np.mean(np.square(preds - y_test))
    if not beta:
        return error
    else:
        coef = model.coef_
        print('error = {}'.format(error))
        print(coef)
        result = np.concatenate((np.array([error]), coef))
        print(result)
        return result

In [108]:
def add_groups(data):
    groups = np.concatenate(([0]*80, [1]*80, [2]*80, [3]*80, [4]*80, [5]*106))
    np.random.shuffle(groups)
    data['group'] = groups

In [107]:
#read the data into a dataframe
housing_data = pd.read_csv('housing.data.txt', sep = '\t')
data = np.zeros(housing_data.shape)
for i in housing_data.index:
    data[i] = housing_data['CRIM'][i].split()
cols = []
for col in housing_data.columns:
    cols.append(col.strip())
housing_data = pd.DataFrame(data = data, columns = cols)

In [174]:
lam_vals = np.arange(0, 1.1, 0.1)
trials = np.arange(0, 10, 1)

data_dict = {'error': pd.DataFrame(index = trials, columns = ['ridge', 'lasso', 'ols'])}
for col in housing_data.columns:
    if not col == 'group' and not col == 'MEDV':
        data_dict[col] = pd.DataFrame(index = trials, columns = ['ridge', 'lasso', 'ols'])
    
for trial in trials:
    add_groups(housing_data)
    x_test = housing_data[housing_data['group'] == 5].drop('MEDV', axis = 1)
    y_test = housing_data[housing_data['group'] == 5]['MEDV']
    
    #find the optimal value of lambda
    ridge_lam = get_lambda(lam_vals, housing_data, x_test, y_test, model_type = 'ridge')
    lasso_lam = get_lambda(lam_vals, housing_data, x_test, y_test, model_type = 'lasso')
    
    x_train = housing_data[housing_data['group'] != 5].drop('MEDV', axis = 1)
    y_train = housing_data[housing_data['group'] != 5]['MEDV']
    
    ridge = Ridge(alpha = ridge_lam, max_iter = 1e6)
    lasso = Lasso(alpha = lasso_lam, max_iter = 1e6)
    regr = LinearRegression()
    
    ridge_vals = get_error(ridge, x_train, y_train, x_test, y_test, beta = True)
    lasso_vals = get_error(lasso, x_train, y_train, x_test, y_test, beta = True)
    ols_vals = get_error(regr, x_train, y_train, x_test, y_test, beta = True)
    for ridge_val, lasso_val, ols_val, key in zip(ridge_vals, lasso_vals, ols_vals, data_dict.keys()):
        data_dict[key]['ridge'][trial] = ridge_val
        data_dict[key]['lasso'][trial] = lasso_val
        data_dict[key]['ols'][trial] = ols_val
    

  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)


error = 22.767270429086025
[-1.57826238e-01  4.42528077e-02 -5.37848405e-02  3.34152529e+00
 -8.05273662e+00  4.00758484e+00 -7.30928881e-03 -1.31948875e+00
  3.24699360e-01 -1.12041465e-02 -8.00262455e-01  1.19906348e-02
 -5.59089743e-01 -5.30147233e-02]
[ 2.27672704e+01 -1.57826238e-01  4.42528077e-02 -5.37848405e-02
  3.34152529e+00 -8.05273662e+00  4.00758484e+00 -7.30928881e-03
 -1.31948875e+00  3.24699360e-01 -1.12041465e-02 -8.00262455e-01
  1.19906348e-02 -5.59089743e-01 -5.30147233e-02]
error = 22.345576871391806
[-1.27682519e-01  4.74980807e-02 -3.98907675e-02  0.00000000e+00
 -0.00000000e+00  2.66437349e+00  2.63185148e-03 -9.69142000e-01
  3.19633856e-01 -1.39831079e-02 -7.16197453e-01  1.13251201e-02
 -6.81128488e-01 -0.00000000e+00]
[ 2.23455769e+01 -1.27682519e-01  4.74980807e-02 -3.98907675e-02
  0.00000000e+00 -0.00000000e+00  2.66437349e+00  2.63185148e-03
 -9.69142000e-01  3.19633856e-01 -1.39831079e-02 -7.16197453e-01
  1.13251201e-02 -6.81128488e-01 -0.00000000e+00

  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)


error = 17.123943818622298
[-1.12894254e-01  5.60903390e-02  3.94054748e-02  2.64685584e+00
 -1.12836682e+01  3.52360926e+00  3.04143059e-03 -1.42327285e+00
  3.45522169e-01 -1.46295324e-02 -8.80164038e-01  1.04590636e-02
 -6.09619490e-01 -1.65194534e-01]
[ 1.71239438e+01 -1.12894254e-01  5.60903390e-02  3.94054748e-02
  2.64685584e+00 -1.12836682e+01  3.52360926e+00  3.04143059e-03
 -1.42327285e+00  3.45522169e-01 -1.46295324e-02 -8.80164038e-01
  1.04590636e-02 -6.09619490e-01 -1.65194534e-01]
error = 18.394198254215592
[-0.09787434  0.05577376  0.          0.         -0.          2.27506104
  0.00902487 -1.06231172  0.31236265 -0.01568287 -0.75560998  0.01084177
 -0.70863254 -0.        ]
[ 1.83941983e+01 -9.78743370e-02  5.57737564e-02  0.00000000e+00
  0.00000000e+00 -0.00000000e+00  2.27506104e+00  9.02486849e-03
 -1.06231172e+00  3.12362650e-01 -1.56828686e-02 -7.55609979e-01
  1.08417737e-02 -7.08632538e-01 -0.00000000e+00]
error = 17.547609962449457
[-1.16698208e-01  5.44562296

  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)


error = 22.647322246162386
[-9.71072604e-02  4.07493132e-02 -3.62233203e-02  2.53998703e+00
 -9.27742144e+00  3.85088896e+00 -2.70676277e-03 -1.36321346e+00
  2.90234243e-01 -1.08775081e-02 -8.49543274e-01  1.31839580e-02
 -5.76521401e-01  2.81781308e-01]
[ 2.26473222e+01 -9.71072604e-02  4.07493132e-02 -3.62233203e-02
  2.53998703e+00 -9.27742144e+00  3.85088896e+00 -2.70676277e-03
 -1.36321346e+00  2.90234243e-01 -1.08775081e-02 -8.49543274e-01
  1.31839580e-02 -5.76521401e-01  2.81781308e-01]
error = 22.07695881476452
[-0.07452876  0.04316375 -0.02378704  0.         -0.          2.83022783
  0.00432833 -0.99611201  0.2826685  -0.01396489 -0.7377274   0.01304181
 -0.67222945  0.01687718]
[ 2.20769588e+01 -7.45287604e-02  4.31637492e-02 -2.37870378e-02
  0.00000000e+00 -0.00000000e+00  2.83022783e+00  4.32833436e-03
 -9.96112011e-01  2.82668501e-01 -1.39648859e-02 -7.37727400e-01
  1.30418103e-02 -6.72229449e-01  1.68771792e-02]
error = 22.27722363367208
[-1.03898504e-01  4.07149881e-

  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)


error = 22.66077853021746
[-8.92746060e-02  5.90461338e-02  1.45769090e-02  2.49190076e+00
 -1.64695279e+01  3.13461572e+00  1.13512425e-02 -1.39358425e+00
  3.49936576e-01 -1.38474792e-02 -1.01213169e+00  9.85295405e-03
 -5.70510862e-01 -1.77621427e-01]
[ 2.26607785e+01 -8.92746060e-02  5.90461338e-02  1.45769090e-02
  2.49190076e+00 -1.64695279e+01  3.13461572e+00  1.13512425e-02
 -1.39358425e+00  3.49936576e-01 -1.38474792e-02 -1.01213169e+00
  9.85295405e-03 -5.70510862e-01 -1.77621427e-01]


  
  positive)
  positive)


error = 22.660778530222498
[-8.92746060e-02  5.90461338e-02  1.45769090e-02  2.49190076e+00
 -1.64695279e+01  3.13461572e+00  1.13512425e-02 -1.39358425e+00
  3.49936576e-01 -1.38474792e-02 -1.01213169e+00  9.85295405e-03
 -5.70510862e-01 -1.77621427e-01]
[ 2.26607785e+01 -8.92746060e-02  5.90461338e-02  1.45769090e-02
  2.49190076e+00 -1.64695279e+01  3.13461572e+00  1.13512425e-02
 -1.39358425e+00  3.49936576e-01 -1.38474792e-02 -1.01213169e+00
  9.85295405e-03 -5.70510862e-01 -1.77621427e-01]
error = 22.66077853021753
[-8.92746060e-02  5.90461338e-02  1.45769090e-02  2.49190076e+00
 -1.64695279e+01  3.13461572e+00  1.13512425e-02 -1.39358425e+00
  3.49936576e-01 -1.38474792e-02 -1.01213169e+00  9.85295405e-03
 -5.70510862e-01 -1.77621427e-01]
[ 2.26607785e+01 -8.92746060e-02  5.90461338e-02  1.45769090e-02
  2.49190076e+00 -1.64695279e+01  3.13461572e+00  1.13512425e-02
 -1.39358425e+00  3.49936576e-01 -1.38474792e-02 -1.01213169e+00
  9.85295405e-03 -5.70510862e-01 -1.77621427e-01]

  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)


error = 22.359888681985016
[-1.43131269e-01  3.26084646e-02  1.71629605e-02  2.62863030e+00
 -1.65804874e+01  3.45770313e+00  1.01340839e-02 -1.28481802e+00
  2.70970536e-01 -9.13469187e-03 -9.82199269e-01  1.06046538e-02
 -5.88715805e-01  1.15086975e-02]
[ 2.23598887e+01 -1.43131269e-01  3.26084646e-02  1.71629605e-02
  2.62863030e+00 -1.65804874e+01  3.45770313e+00  1.01340839e-02
 -1.28481802e+00  2.70970536e-01 -9.13469187e-03 -9.82199269e-01
  1.06046538e-02 -5.88715805e-01  1.15086975e-02]


  
  positive)
  positive)


error = 22.359888681988068
[-1.43131269e-01  3.26084646e-02  1.71629605e-02  2.62863030e+00
 -1.65804874e+01  3.45770313e+00  1.01340839e-02 -1.28481802e+00
  2.70970536e-01 -9.13469187e-03 -9.82199269e-01  1.06046538e-02
 -5.88715805e-01  1.15086975e-02]
[ 2.23598887e+01 -1.43131269e-01  3.26084646e-02  1.71629605e-02
  2.62863030e+00 -1.65804874e+01  3.45770313e+00  1.01340839e-02
 -1.28481802e+00  2.70970536e-01 -9.13469187e-03 -9.82199269e-01
  1.06046538e-02 -5.88715805e-01  1.15086975e-02]
error = 22.359888681984625
[-1.43131269e-01  3.26084646e-02  1.71629605e-02  2.62863030e+00
 -1.65804874e+01  3.45770313e+00  1.01340839e-02 -1.28481802e+00
  2.70970536e-01 -9.13469187e-03 -9.82199269e-01  1.06046538e-02
 -5.88715805e-01  1.15086975e-02]
[ 2.23598887e+01 -1.43131269e-01  3.26084646e-02  1.71629605e-02
  2.62863030e+00 -1.65804874e+01  3.45770313e+00  1.01340839e-02
 -1.28481802e+00  2.70970536e-01 -9.13469187e-03 -9.82199269e-01
  1.06046538e-02 -5.88715805e-01  1.15086975e-02

  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)


error = 20.62083176186124
[-8.44710766e-02  5.23202405e-02 -1.11274254e-02  1.26151623e+00
 -1.05124173e+01  4.04613256e+00 -1.10071395e-02 -1.53477467e+00
  2.86438449e-01 -1.39494123e-02 -8.70933751e-01  9.58347879e-03
 -5.37382782e-01 -5.80887926e-02]
[ 2.06208318e+01 -8.44710766e-02  5.23202405e-02 -1.11274254e-02
  1.26151623e+00 -1.05124173e+01  4.04613256e+00 -1.10071395e-02
 -1.53477467e+00  2.86438449e-01 -1.39494123e-02 -8.70933751e-01
  9.58347879e-03 -5.37382782e-01 -5.80887926e-02]
error = 22.034649096998237
[-0.04998085  0.05494856 -0.00520314  0.         -0.          1.91083434
  0.00424522 -0.96699351  0.25057644 -0.01583377 -0.68523682  0.00859935
 -0.71117316 -0.        ]
[ 2.20346491e+01 -4.99808506e-02  5.49485605e-02 -5.20314468e-03
  0.00000000e+00 -0.00000000e+00  1.91083434e+00  4.24521844e-03
 -9.66993511e-01  2.50576444e-01 -1.58337690e-02 -6.85236817e-01
  8.59935440e-03 -7.11173158e-01 -0.00000000e+00]
error = 20.633474496786818
[-8.98137621e-02  5.04952346e

  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)


error = 18.68015388237477
[-1.05475472e-01  4.64419701e-02 -2.60550374e-02  1.76425877e+00
 -1.59825525e+01  3.48472217e+00  6.88694053e-03 -1.49283535e+00
  3.48329636e-01 -1.43383096e-02 -9.12665900e-01  9.60898048e-03
 -5.46351395e-01  3.87470250e-02]
[ 1.86801539e+01 -1.05475472e-01  4.64419701e-02 -2.60550374e-02
  1.76425877e+00 -1.59825525e+01  3.48472217e+00  6.88694053e-03
 -1.49283535e+00  3.48329636e-01 -1.43383096e-02 -9.12665900e-01
  9.60898048e-03 -5.46351395e-01  3.87470250e-02]


  
  positive)
  positive)


error = 18.642729742839673
[-1.05939952e-01  4.61318332e-02 -1.96261519e-02  1.77707683e+00
 -1.72954471e+01  3.47681419e+00  8.03535862e-03 -1.51211456e+00
  3.51122546e-01 -1.42531932e-02 -9.28479908e-01  9.55997215e-03
 -5.45306591e-01  4.12367695e-02]
[ 1.86427297e+01 -1.05939952e-01  4.61318332e-02 -1.96261519e-02
  1.77707683e+00 -1.72954471e+01  3.47681419e+00  8.03535862e-03
 -1.51211456e+00  3.51122546e-01 -1.42531932e-02 -9.28479908e-01
  9.55997215e-03 -5.45306591e-01  4.12367695e-02]
error = 18.64272974284122
[-1.05939952e-01  4.61318332e-02 -1.96261519e-02  1.77707683e+00
 -1.72954471e+01  3.47681419e+00  8.03535862e-03 -1.51211456e+00
  3.51122546e-01 -1.42531932e-02 -9.28479908e-01  9.55997215e-03
 -5.45306591e-01  4.12367695e-02]
[ 1.86427297e+01 -1.05939952e-01  4.61318332e-02 -1.96261519e-02
  1.77707683e+00 -1.72954471e+01  3.47681419e+00  8.03535862e-03
 -1.51211456e+00  3.51122546e-01 -1.42531932e-02 -9.28479908e-01
  9.55997215e-03 -5.45306591e-01  4.12367695e-02]

  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)


error = 20.799899248574988
[-1.01526515e-01  4.70775399e-02 -3.04334086e-02  2.76413774e+00
 -1.51975944e+01  3.66593494e+00 -3.24295104e-03 -1.44720384e+00
  3.11172941e-01 -1.14656965e-02 -8.90438243e-01  9.66491618e-03
 -5.44293801e-01 -1.50586637e-01]
[ 2.07998992e+01 -1.01526515e-01  4.70775399e-02 -3.04334086e-02
  2.76413774e+00 -1.51975944e+01  3.66593494e+00 -3.24295104e-03
 -1.44720384e+00  3.11172941e-01 -1.14656965e-02 -8.90438243e-01
  9.66491618e-03 -5.44293801e-01 -1.50586637e-01]


  
  positive)
  positive)


error = 20.799899248581273
[-1.01526515e-01  4.70775399e-02 -3.04334086e-02  2.76413774e+00
 -1.51975944e+01  3.66593494e+00 -3.24295104e-03 -1.44720384e+00
  3.11172941e-01 -1.14656965e-02 -8.90438243e-01  9.66491618e-03
 -5.44293801e-01 -1.50586637e-01]
[ 2.07998992e+01 -1.01526515e-01  4.70775399e-02 -3.04334086e-02
  2.76413774e+00 -1.51975944e+01  3.66593494e+00 -3.24295104e-03
 -1.44720384e+00  3.11172941e-01 -1.14656965e-02 -8.90438243e-01
  9.66491618e-03 -5.44293801e-01 -1.50586637e-01]
error = 20.799899248575027
[-1.01526515e-01  4.70775399e-02 -3.04334086e-02  2.76413774e+00
 -1.51975944e+01  3.66593494e+00 -3.24295104e-03 -1.44720384e+00
  3.11172941e-01 -1.14656965e-02 -8.90438243e-01  9.66491618e-03
 -5.44293801e-01 -1.50586637e-01]
[ 2.07998992e+01 -1.01526515e-01  4.70775399e-02 -3.04334086e-02
  2.76413774e+00 -1.51975944e+01  3.66593494e+00 -3.24295104e-03
 -1.44720384e+00  3.11172941e-01 -1.14656965e-02 -8.90438243e-01
  9.66491618e-03 -5.44293801e-01 -1.50586637e-01

  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)


error = 20.867041230169043
[-1.10349321e-01  4.06704459e-02  5.21950263e-02  1.80415145e+00
 -1.66294650e+01  3.96013281e+00  7.42749159e-04 -1.39918669e+00
  3.09018149e-01 -1.38142896e-02 -9.86102722e-01  7.65535014e-03
 -5.31919542e-01 -1.66289134e-01]
[ 2.08670412e+01 -1.10349321e-01  4.06704459e-02  5.21950263e-02
  1.80415145e+00 -1.66294650e+01  3.96013281e+00  7.42749159e-04
 -1.39918669e+00  3.09018149e-01 -1.38142896e-02 -9.86102722e-01
  7.65535014e-03 -5.31919542e-01 -1.66289134e-01]


  
  positive)
  positive)


error = 20.927823778605244
[-1.12108192e-01  3.99227938e-02  6.37717820e-02  1.84204312e+00
 -1.93741606e+01  3.93491226e+00  3.20517422e-03 -1.43869679e+00
  3.14819695e-01 -1.35083395e-02 -1.01735552e+00  7.49744888e-03
 -5.29492050e-01 -1.64771953e-01]
[ 2.09278238e+01 -1.12108192e-01  3.99227938e-02  6.37717820e-02
  1.84204312e+00 -1.93741606e+01  3.93491226e+00  3.20517422e-03
 -1.43869679e+00  3.14819695e-01 -1.35083395e-02 -1.01735552e+00
  7.49744888e-03 -5.29492050e-01 -1.64771953e-01]
error = 20.927823778616386
[-1.12108192e-01  3.99227938e-02  6.37717820e-02  1.84204312e+00
 -1.93741606e+01  3.93491226e+00  3.20517422e-03 -1.43869679e+00
  3.14819695e-01 -1.35083395e-02 -1.01735552e+00  7.49744888e-03
 -5.29492050e-01 -1.64771953e-01]
[ 2.09278238e+01 -1.12108192e-01  3.99227938e-02  6.37717820e-02
  1.84204312e+00 -1.93741606e+01  3.93491226e+00  3.20517422e-03
 -1.43869679e+00  3.14819695e-01 -1.35083395e-02 -1.01735552e+00
  7.49744888e-03 -5.29492050e-01 -1.64771953e-01

  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)
  
  positive)
  positive)


error = 20.277981470555627
[-1.03990271e-01  4.53284518e-02  1.45601229e-02  2.37356995e+00
 -1.08898367e+01  4.24711265e+00 -1.29356383e-02 -1.46771749e+00
  2.94569751e-01 -1.41896683e-02 -8.59261131e-01  1.00673853e-02
 -4.96776266e-01 -1.52963219e-01]
[ 2.02779815e+01 -1.03990271e-01  4.53284518e-02  1.45601229e-02
  2.37356995e+00 -1.08898367e+01  4.24711265e+00 -1.29356383e-02
 -1.46771749e+00  2.94569751e-01 -1.41896683e-02 -8.59261131e-01
  1.00673853e-02 -4.96776266e-01 -1.52963219e-01]
error = 19.930249169292615
[-7.90463016e-02  4.69748209e-02 -0.00000000e+00  0.00000000e+00
 -0.00000000e+00  2.58647616e+00  9.54303133e-04 -9.73676022e-01
  2.78639866e-01 -1.64684048e-02 -7.18086728e-01  9.73982453e-03
 -6.45300514e-01 -0.00000000e+00]
[ 1.99302492e+01 -7.90463016e-02  4.69748209e-02 -0.00000000e+00
  0.00000000e+00 -0.00000000e+00  2.58647616e+00  9.54303133e-04
 -9.73676022e-01  2.78639866e-01 -1.64684048e-02 -7.18086728e-01
  9.73982453e-03 -6.45300514e-01 -0.00000000e+00

In [175]:
data_df = pd.DataFrame(index = data_dict.keys(), columns = empty_df.columns)
for i in data_df.index:
    for col in data_df.columns:
        data_df[col][i] = np.mean(data_dict[i][col])
data_df

Unnamed: 0,ridge,lasso,ols
error,20.8805,21.0173,20.8605
CRIM,-0.110605,-0.0981093,-0.113281
ZN,0.0464586,0.0473146,0.0457791
INDUS,-0.00197235,-0.00234289,0.0181255
CHAS,2.36165,1.15038,2.42858
NOX,-13.0876,-8.49172,-17.6308
RM,3.73784,2.9937,3.71237
AGE,-0.000504533,0.00506675,0.0032674
DIS,-1.41261,-1.20447,-1.4801
RAD,0.313089,0.30419,0.323958


- Lasso has the highest error because it pulls the slope coefficients closer to 0

# Question 4

In [41]:
def get_data(model, x_train, y_train, x_test, y_test):
    model.fit(x_train, y_train)
    preds = model.predict(x_test)
    vals = np.arange(1, 12, 1)
    df = pd.DataFrame(index = vals, columns = vals, data = confusion_matrix(y_test, preds))
    acc = accuracy_score(y_test, preds)
    return df, 1 - acc

In [7]:
vowel_test = pd.read_csv('vowel.test.txt', index_col = 'row.names')
vowel_train = pd.read_csv('vowel.train.txt', index_col = 'row.names')

In [10]:
x_train = vowel_train.drop('y', axis = 1)
x_test = vowel_test.drop('y', axis = 1)
y_train = vowel_train['y']
y_test = vowel_test['y']

In [37]:
linear = LinearDiscriminantAnalysis()
quadratic = QuadraticDiscriminantAnalysis()
diagonal = GaussianNB()

In the Confusion Matrices Below
- the row is the actual digit
- the column is the predicted digit

In [43]:
get_confusion_mat(linear, x_train, y_train, x_test, y_test)

Unnamed: 0,1,2,3,4,5,6,7,8,9,10,11
1,28,10,1,0,0,0,0,0,3,0,0
2,23,16,2,0,0,1,0,0,0,0,0
3,0,11,16,11,0,4,0,0,0,0,0
4,0,0,2,33,0,6,0,0,0,0,1
5,0,0,0,1,7,22,9,0,0,0,3
6,0,0,5,3,8,19,1,0,0,0,6
7,0,0,1,0,9,12,11,4,4,0,1
8,0,0,0,0,1,0,2,23,8,8,0
9,0,2,0,0,0,0,0,6,15,14,5
10,8,1,5,0,0,0,0,0,9,13,6


In [44]:
linear_mat, linear_acc = get_data(linear, x_train, y_train, x_test, y_test)
print("The accuracy of the linear model is {}".format(linear_acc))
linear_mat

The accuracy of the linear model is 0.5562770562770563


Unnamed: 0,1,2,3,4,5,6,7,8,9,10,11
1,28,10,1,0,0,0,0,0,3,0,0
2,23,16,2,0,0,1,0,0,0,0,0
3,0,11,16,11,0,4,0,0,0,0,0
4,0,0,2,33,0,6,0,0,0,0,1
5,0,0,0,1,7,22,9,0,0,0,3
6,0,0,5,3,8,19,1,0,0,0,6
7,0,0,1,0,9,12,11,4,4,0,1
8,0,0,0,0,1,0,2,23,8,8,0
9,0,2,0,0,0,0,0,6,15,14,5
10,8,1,5,0,0,0,0,0,9,13,6


The linear model has a hard time distinguishing beween the following vowels
- 5, 6, and 7 are frequently confused
- 8 and 9 are frequently confused
- 1 and 2 are frequently confused
- 2, 3, and 4 and frequently confused

In [45]:
quadratic_mat, quadratic_acc = get_data(quadratic, x_train, y_train, x_test, y_test)
print("The accuracy of the linear model is {}".format(quadratic_acc))
quadratic_mat

The accuracy of the linear model is 0.5281385281385281


Unnamed: 0,1,2,3,4,5,6,7,8,9,10,11
1,37,4,0,0,0,0,0,0,1,0,0
2,18,22,1,0,0,0,0,0,1,0,0
3,9,13,12,5,0,2,0,0,1,0,0
4,0,2,3,12,5,17,2,0,0,0,1
5,0,0,0,0,16,7,19,0,0,0,0
6,0,0,0,1,0,22,14,0,0,0,5
7,0,0,0,0,11,1,22,0,3,0,5
8,0,0,0,0,0,0,15,6,21,0,0
9,0,0,0,0,0,0,3,1,38,0,0
10,2,4,0,0,0,0,4,0,21,11,0


The quadratic model has a hard time distinguishing beween the following vowels
- 4 and 6 are frequently confused
- 5 and 7 are frequently confused
- 9, 10, and 11 are frequently confused

In [46]:
diag_mat, diag_acc = get_data(quadratic, x_train, y_train, x_test, y_test)
print("The accuracy of the linear model is {}".format(diag_acc))
diag_mat

The accuracy of the linear model is 0.5281385281385281


Unnamed: 0,1,2,3,4,5,6,7,8,9,10,11
1,37,4,0,0,0,0,0,0,1,0,0
2,18,22,1,0,0,0,0,0,1,0,0
3,9,13,12,5,0,2,0,0,1,0,0
4,0,2,3,12,5,17,2,0,0,0,1
5,0,0,0,0,16,7,19,0,0,0,0
6,0,0,0,1,0,22,14,0,0,0,5
7,0,0,0,0,11,1,22,0,3,0,5
8,0,0,0,0,0,0,15,6,21,0,0
9,0,0,0,0,0,0,3,1,38,0,0
10,2,4,0,0,0,0,4,0,21,11,0


The quadratic model has a hard time distinguishing beween the following vowels
- 3 and 11 are frequently confused
- 4 and 6 are frequently confused
- 7 and 9 are frequently confused
- 6 and 11 are frequently confused