In [1]:
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score

In [2]:
# Load the Gapminder table from the working directory.
# NOTE(review): the file appears to carry a saved row index, which pandas
# loads as an ordinary column named `Unnamed: 0` (visible in df.head()).
df = pd.read_csv(filepath_or_buffer='gapminder.csv')

In [3]:
# Preview the first five rows to sanity-check the columns and values.
df.head(n=5)

Unnamed: 0,population,fertility,HIV,CO2,BMI_male,GDP,BMI_female,life,child_mortality,Region
0,34811059.0,2.73,0.1,3.328945,24.5962,12314.0,129.9049,75.3,29.5,Middle East & North Africa
1,19842251.0,6.43,2.0,1.474353,22.25083,7103.0,130.1247,58.3,192.0,Sub-Saharan Africa
2,40381860.0,2.24,0.5,4.78517,27.5017,14646.0,118.8915,75.5,15.4,America
3,2975029.0,1.4,0.1,1.804106,25.35542,7383.0,132.8108,72.5,20.0,Europe & Central Asia
4,21370348.0,1.96,0.1,18.016313,27.56373,41312.0,117.3755,81.5,5.2,East Asia & Pacific


In [4]:
# Target: the `life` column.
y = df['life']

# Features: every remaining column except the target and two non-features.
#   * `Region` is a text category, which LinearRegression cannot consume as-is.
#   * `Unnamed: 0` looks like the row index that was written into the CSV
#     (0, 1, 2, ... in the df.head() output). Keeping it would feed an
#     arbitrary row counter to the model as if it were a predictor.
# The second drop uses errors='ignore' so this cell still runs if the CSV is
# loaded without that column (e.g. with index_col=0).
# NOTE: the metrics printed further down were produced with `Unnamed: 0`
# still in X; re-run the notebook to refresh them.
X = (
    df.drop(columns=['life', 'Region'])
      .drop(columns=['Unnamed: 0'], errors='ignore')
)

### 1. Create training and test sets

In [5]:
# Hold out 30% of the rows for testing; the fixed seed keeps the split
# reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    random_state=42,
)

### 2. Create the linear regression model

In [6]:
# Ordinary least-squares regressor; the intercept is fitted (the default,
# spelled out here for the reader).
reg_all = LinearRegression(fit_intercept=True)

### 3. Fit the regressor to the training set

In [7]:
# Fit on the training split only; the test split stays unseen until scoring.
# Left as the cell's last expression so the fitted estimator is displayed.
reg_all.fit(X=X_train, y=y_train)

LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None, normalize=False)

### 4. Predict on the test data

In [8]:
# Predictions for the held-out rows, used for the RMSE computation.
y_pred = reg_all.predict(X=X_test)

### 5. Compute and print $R^2$ and RMSE on the test set

In [9]:
# R^2 on the held-out split, as returned by the estimator's own score().
r_squared = reg_all.score(X_test, y_test)
print("R^2: {}".format(round(r_squared, 2)))

# RMSE is the square root of the mean squared error between the test
# targets and the predictions.
mse = mean_squared_error(y_test, y_pred)
rmse = np.sqrt(mse)
print("Root Mean Squared Error: {}".format(round(rmse, 2)))

R^2: 0.84
Root Mean Squared Error: 3.25


### 6. Perform 5-fold cross-validation on X and y

In [10]:
# 5-fold cross-validation over the full X / y (not just the training split).
# Yields one score per fold, using the estimator's default scorer.
cv_scores = cross_val_score(estimator=reg_all, X=X, y=y, cv=5)

In [11]:
# Summarise the per-fold scores with their mean, rounded to two decimals.
mean_cv_score = np.mean(cv_scores)
print("Average 5-Fold CV Score: {}".format(round(mean_cv_score, 2)))

Average 5-Fold CV Score: 0.86


### 7. K-Fold CV comparison: 3 folds vs. 10 folds

In [13]:
# Re-run cross-validation with 3 and 10 folds to see how the fold count
# affects the mean score.
cvscores_3 = cross_val_score(reg_all, X, y, cv=3)
cvscores_10 = cross_val_score(reg_all, X, y, cv=10)

# Print the mean score of each run, 3-fold first, then 10-fold.
for fold_scores in (cvscores_3, cvscores_10):
    print(round(np.mean(fold_scores), 2))

0.87
0.84
