# Iris — Linear Regression (predict `sepal_length`)

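This notebook fits an ordinary least-squares linear regression that predicts sepal length from the other three Iris measurements. It evaluates the model on a held-out 20% test split (MAE, RMSE and $R^2$) and visualizes the fit and the feature correlations.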

In [None]:
# Imports and dataset
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

# Load Iris as a DataFrame and give the four measurement columns (all in cm)
# short snake_case names used throughout the rest of the notebook.
# `rename` returns a new frame, so no explicit `.copy()` or `inplace=True` is needed.
# errors='raise' makes the cell fail loudly if scikit-learn's column labels ever
# change, instead of silently skipping the rename and breaking later cells with a KeyError.
iris = load_iris(as_frame=True)
df = iris.frame.rename(
    columns={
        'sepal length (cm)': 'sepal_length',
        'sepal width (cm)': 'sepal_width',
        'petal length (cm)': 'petal_length',
        'petal width (cm)': 'petal_width',
    },
    errors='raise',
)

df.head()

## Features and target
We predict `sepal_length` using `sepal_width`, `petal_length`, and `petal_width`.

In [None]:
# Features / target: predict sepal_length from the other three measurements.
X = df[['sepal_width', 'petal_length', 'petal_width']]
y = df['sepal_length']

# Hold out 20% of rows for evaluation; the fixed random_state makes the split
# (and therefore every metric and plot below) reproducible on Restart & Run All.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
assert len(X_train) + len(X_test) == len(X), 'train/test split lost rows'

model = LinearRegression()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

# RMSE is computed as sqrt(MSE) explicitly: the `squared=False` argument of
# `mean_squared_error` was deprecated in scikit-learn 1.4 and removed in 1.6,
# so the explicit square root works on every version.
metrics = pd.Series(
    {
        'MAE': mean_absolute_error(y_test, y_pred),
        'RMSE': np.sqrt(mean_squared_error(y_test, y_pred)),
        'R2': r2_score(y_test, y_pred),
    },
    name='test-set metrics',
)
metrics

## Visualizations
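Three diagnostic plots:
- **True vs. predicted** on the test set. Points on the dashed diagonal are perfect predictions.
- **Residuals vs. predicted.** Visible structure (curvature, a funnel shape) suggests the linear model is misspecified.
- **Correlation matrix** of the four measurements (Pearson, the pandas default). Strongly correlated features make individual regression coefficients hard to interpret.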

In [None]:
# --- True vs. predicted (test set) ---
# The dashed identity line spans both the true and the predicted range, so it stays
# visible even when some predictions fall outside the range of the true values.
fig, ax = plt.subplots()
ax.scatter(y_test, y_pred)
lims = [min(y_test.min(), y_pred.min()), max(y_test.max(), y_pred.max())]
ax.plot(lims, lims, linestyle='--')
ax.set(xlabel='True sepal_length', ylabel='Predicted sepal_length', title='True vs Predicted')
plt.show()

# --- Residuals vs. predicted ---
# Residual = true - predicted; points should scatter evenly around the zero line.
residuals = y_test - y_pred
fig, ax = plt.subplots()
ax.scatter(y_pred, residuals)
ax.axhline(0, linestyle='--')
ax.set(xlabel='Predicted sepal_length', ylabel='Residuals', title='Residuals vs Predicted')
plt.show()

# --- Correlation matrix of the four measurements ---
corr = df[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']].corr()
fig, ax = plt.subplots()
# Pin the colour scale to the full correlation range [-1, 1] with a diverging map;
# otherwise matplotlib autoscales to the data and colours are not comparable to r values.
im = ax.imshow(corr.to_numpy(), interpolation='nearest', vmin=-1, vmax=1, cmap='coolwarm')
fig.colorbar(im, ax=ax)
ax.set_xticks(range(len(corr.columns)))
ax.set_xticklabels(corr.columns, rotation=45, ha='right')
ax.set_yticks(range(len(corr.index)))
ax.set_yticklabels(corr.index)
# Annotate each cell with its coefficient so exact values can be read off the figure.
for i in range(corr.shape[0]):
    for j in range(corr.shape[1]):
        ax.text(j, i, f'{corr.iloc[i, j]:.2f}', ha='center', va='center')
ax.set_title('Correlation matrix')
fig.tight_layout()
plt.show()