# Final Group Project: Predict Life expectancy

**Project Info:**
- The dataset covers life expectancy from 2000 to 2015. Health factors were collected from the WHO data repository, and the corresponding economic data was collected from the United Nations website.
- Reason: Mutual interest in world health
- The dataset was obtained from __[Kaggle](https://www.kaggle.com/datasets/kumarajarshi/life-expectancy-who/data)__
- Contributors
    - Priyanka
    - Rohit
    - Grant

## Problem Statement

What are the key factors that influence life expectancy across different countries, and how do socioeconomic variables, healthcare expenditures, 
and mortality rates correlate with life expectancy? Additionally, can we identify actionable insights that could help underperforming regions 
improve their life expectancy?

## Data Pre-Processing

### Load Data
Load the data from the **GitHub repository** so that each team member can run the code directly.<br>
__[Tutorial: How to read a CSV file from GitHub on Jupyter Notebook](https://www.youtube.com/watch?v=4xXBDXDSFts)__

In [None]:
# import library
import pandas as pd

In [None]:
# Load data: load file from github repository
data = pd.read_csv('https://raw.githubusercontent.com/GrantCa24/DA_Group6-Final_Project/main/data_raw/Life%20Expectancy%20Data.csv')
data.head()

### Data Assessment

**Highlights:**
- There are 2938 rows and 22 columns
- Remove leading and trailing whitespace from the column headers
    - __[`Series.str.strip()`](https://pandas.pydata.org/docs/reference/api/pandas.Series.str.strip.html)__
- Rename a column
    - Based on this discussion: __[1-19 years: typo in the column header](https://www.kaggle.com/datasets/kumarajarshi/life-expectancy-who/discussion/276334)__, we decided to rename the header
- No duplicates
- Dirty data: several columns have maximum values that do not make sense

|Field|Description|
|---:|:---|
|Country|Country|
|Year|Year|
|Status|Developed or Developing status|
|Life expectancy|Life expectancy in years|
|Adult Mortality|Adult Mortality Rates of both sexes (probability of dying between 15 and 60 years per 1000 population)|
|infant deaths|Number of Infant Deaths per 1000 population|
|Alcohol|Alcohol, recorded per capita (15+) consumption (in litres of pure alcohol)|
|percentage expenditure|Expenditure on health as a percentage of Gross Domestic Product per capita (%)|
|Hepatitis B|Hepatitis B (HepB) immunization coverage among 1-year-olds (%)|
|Measles|Measles - number of reported cases per 1000 population|
|BMI|Average Body Mass Index of entire population|
|under-five deaths|Number of under-five deaths per 1000 population|
|Polio|Polio (Pol3) immunization coverage among 1-year-olds (%)|
|Total expenditure|General government expenditure on health as a percentage of total government expenditure (%)|
|Diphtheria|Diphtheria tetanus toxoid and pertussis (DTP3) immunization coverage among 1-year-olds (%)|
|HIV/AIDS|Deaths per 1000 live births HIV/AIDS (0-4 years)|
|GDP|Gross Domestic Product per capita (in USD)|
|Population|Population of the country|
|thinness 10-19 years|Prevalence of thinness among children and adolescents aged 10 to 19 (%)|
|thinness 5-9 years|Prevalence of thinness among children aged 5 to 9 (%)|
|Income composition of resources|Income composition of resources|
|Schooling|Number of years of schooling (years)|

In [None]:
data.columns

In [None]:
# Remove spaces at the beginning and at the end of the headers(string)
data.columns = data.columns.str.strip()
print(data.columns)

In [None]:
# Rename column 1-19 years to 10-19 years
data.rename(columns={'thinness  1-19 years': 'thinness 10-19 years'}, inplace=True) # modify the DataFrame

In [None]:
# Final check after renaming column
data.columns

In [None]:
# Check the total number of rows and columns
rows, columns = data.shape
print(f"Rows: {rows}, Columns: {columns}")

In [None]:
data.info()

### Check Duplicates

There are **no duplicates** to handle.

In [None]:
# Check for duplicate rows
duplicate_rows = data.duplicated()

# Count of duplicate rows
print(f"Number of duplicate rows: {duplicate_rows.sum()}")

### Check and Remove Null values in all the columns and rows

1. **Dropna**: We decided to drop the rows with null values in columns where less than 10% of the values are null
    - Columns: `Life expectancy`, `Adult Mortality`, `Alcohol`, `BMI`, `Polio`, `Total expenditure`, `Diphtheria`, `thinness 10-19 years`, `thinness 5-9 years`, `Income composition of resources`, `Schooling`
    - Reason: These null values are only a small portion of the whole dataset, so dropping them should not affect the analysis much.

2. **Imputation**: We decided to impute the columns where 10% to 20% of the values are null, using the mean value by `Status`
    - Columns: `Hepatitis B` and `GDP`
    - Reason: The share of missing values is sizable, but not large enough for imputation to distort the overall picture. We believe `Status` is a reasonable categorical indicator to impute by, given the time and effort available.

    - Notes: We attempted to impute these two columns per country using a moving average, but after examining the data in detail, this turned out to be too complicated and time-consuming.

3. **Delete column**: We decided to delete the columns where more than 20% of the values are null
    - Column: `Population` (22.19% null values)

    - Reason: The share of missing values is too large; imputing it would alter a large part of the dataset.

**Strategy:**
1. Drop all rows that contain a null value in the columns listed under **Dropna**
2. Impute the remaining null values with the mean value by `Status`
3. Delete the `Population` column
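The three thresholds above can be expressed as a small helper. This is only a sketch: the function name `plan_null_handling`, its parameters, and the toy frame are assumptions made for illustration and are not part of the notebook, which assigns the columns by hand.

```python
import numpy as np
import pandas as pd

def plan_null_handling(df, drop_rows_below=10.0, drop_col_above=20.0):
    """Assign each column that has nulls to one of the three actions (hypothetical helper)."""
    pct = df.isnull().mean() * 100  # percentage of null values per column
    plan = {}
    for col, p in pct.items():
        if p == 0:
            continue  # nothing to handle for this column
        if p < drop_rows_below:
            plan[col] = "drop rows"
        elif p <= drop_col_above:
            plan[col] = "impute by Status"
        else:
            plan[col] = "delete column"
    return plan

# Toy frame with 20 rows and made-up values: 5%, 15% and 25% nulls
toy = pd.DataFrame({
    "Year": list(range(2000, 2020)),
    "Schooling": [np.nan] * 1 + [12.0] * 19,
    "GDP": [np.nan] * 3 + [1000.0] * 17,
    "Population": [np.nan] * 5 + [1e6] * 15,
})

plan = plan_null_handling(toy)
print(plan)
```

Running the same helper on the real `data` frame should reproduce the column grouping listed above, provided the null percentages are as reported.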

**Notes:**
- Years range from 2000 to 2015 in the dataset

#### Null value count & percentage

In [None]:
# Checking for missing values in each column
missing_values = data.isnull().sum()
print(missing_values)

In [None]:
missing_percentage = missing_values * 100 / len(data)
print(missing_percentage)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Visual representation of missing values in the dataset
plt.figure(figsize=(15,10))
sns.heatmap(data.isnull(), cmap = 'crest')
plt.show()

#### Strategy Step 1:
**Drop all rows that contain a null value in the columns listed under Dropna.**

In [None]:
# Strategy Step 1: drop rows that have a null value in any of these columns
data.dropna(
    subset=['Life expectancy', 'Adult Mortality', 'Alcohol', 'BMI', 'Polio', 'Total expenditure', 'Diphtheria', 'thinness 10-19 years', 'thinness 5-9 years', 'Income composition of resources', 'Schooling'],
    inplace=True)
# Show the remaining columns that have null values
data.isnull().sum()

In [None]:
data['Year'].unique()

#### Strategy Step 2: Impute the remaining null values with the mean value by `Status`

##### Hepatitis B

In [None]:
null_hep_b = data[data['Hepatitis B'].isnull()]
null_hep_b_country = null_hep_b['Country'].unique()

In [None]:
for country in null_hep_b_country:
    # Use a separate name so the array being iterated is not overwritten
    country_rows = null_hep_b[null_hep_b['Country'] == country]
    print(country, ":")
    print(country_rows['Year'].unique())

##### GDP

In [None]:
null_gdp = data[data['GDP'].isnull()]
null_gdp_country = null_gdp['Country'].unique()

In [None]:
for country in null_gdp_country:
    # Use a separate name so the array being iterated is not overwritten
    country_rows = null_gdp[null_gdp['Country'] == country]
    print(country, ":")
    print(country_rows['Year'].unique())

##### Mean value of `Hepatitis B` & `GDP` by `Status`
- __[`mean()` excludes null values by default](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.mean.html)__
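As a compact alternative to computing the group means and filling each `Status` separately, the same idea can be sketched with `groupby(...).transform("mean")`. The toy values below are made up for illustration; the notebook itself uses the explicit `.loc` approach in the following cells.

```python
import numpy as np
import pandas as pd

# Toy data with made-up GDP values and one gap per Status group
toy = pd.DataFrame({
    "Status": ["Developed", "Developed", "Developed",
               "Developing", "Developing", "Developing"],
    "GDP": [40000.0, np.nan, 50000.0, 1000.0, 3000.0, np.nan],
})

# Mean of each Status group, broadcast back to every row of that group.
# mean skips NaN by default, so the gaps do not distort the group mean.
group_mean = toy.groupby("Status")["GDP"].transform("mean")

# Fill only the missing entries with the mean of their own group
toy["GDP"] = toy["GDP"].fillna(group_mean)
print(toy)
```

Because `transform` returns a Series aligned with the original index, no per-status branching is needed, and the same line would work unchanged if more `Status` categories existed.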

In [None]:
# Impute with Developed / Developing country's mean value

#Create a groupby object
data_group = data.groupby('Status')

#Select only required columns
data_columns = data_group[['Hepatitis B', 'GDP']]

#Apply aggregate function
hep_B_gdp_by_status = data_columns.mean()

hep_B_gdp_by_status

##### Fill null values (impute with mean)

In [None]:
# Fill missing values for 'Hepatitis B' based on 'Status'
data.loc[data['Status'] == 'Developed', 'Hepatitis B'] = data.loc[data['Status'] == 'Developed', 'Hepatitis B'].fillna(
    hep_B_gdp_by_status.loc['Developed','Hepatitis B'])
data.loc[data['Status'] == 'Developing', 'Hepatitis B'] = data.loc[data['Status'] == 'Developing', 'Hepatitis B'].fillna(
    hep_B_gdp_by_status.loc['Developing','Hepatitis B'])

In [None]:
# Fill missing values for 'GDP' based on 'Status'
data.loc[data['Status'] == 'Developed', 'GDP'] = data.loc[data['Status'] == 'Developed', 'GDP'].fillna(
    hep_B_gdp_by_status.loc['Developed','GDP'])
data.loc[data['Status'] == 'Developing', 'GDP'] = data.loc[data['Status'] == 'Developing', 'GDP'].fillna(
    hep_B_gdp_by_status.loc['Developing','GDP'])

#### Strategy Step 3: Delete column `Population`

In [None]:
null_population = data[data['Population'].isnull()]
null_population_country = null_population['Country'].unique()

In [None]:
# Drop Population
data.drop(columns=['Population'], inplace=True)

# Show the null value across columns
data.isnull().sum()

### Remove Dirty Data

__[Warning from the discussion](https://www.kaggle.com/datasets/kumarajarshi/life-expectancy-who/discussion/161872)__
- Filter out observations with a value > 1000 in the three columns that are measured `per 1000 population`
    - `infant deaths`
    - `Measles`
    - `under-five deaths`

In [None]:
data.describe()

#### Filter out dirty data in columns measured `per 1000 population`

In [None]:
# value of infant deaths, Measles, and under-five deaths should be <= 1000
cols = ["infant deaths", "Measles", "under-five deaths"]

# Filter out rows where any of the specified columns has a value > 1,000
data = data[(data[cols] <= 1000).all(axis=1)] # Keep only rows where all three columns are <= 1000

In [None]:
data.describe()

In [None]:
# Show the remaining columns that have null values
data.isnull().sum()

In [None]:
data.shape

## Exploratory Data Analysis (EDA)

**Highlights:**
- `Life expectancy`: The histogram peaks in the **70-80 year range**. The boxplot shows that **more than 50% of the observations (country-year records) have a life expectancy above 70 years**, with the maximum nearing 90 years. The median is just above 70 years, indicating a generally high life expectancy. However, 25% of the observations have a life expectancy below 65 years, with **several outliers below 50 years**. So while overall life expectancy is high, there are notable disparities, and a portion of the observations shows significantly lower life expectancy.
- **Right skew ( > 1)**: `under-five deaths`, `infant deaths` , `HIV/AIDS`, `percentage expenditure`, `GDP`, `Measles`, `Adult Mortality`, `thinness 5-9 years`, `thinness 10-19 years`
- **Left skew ( < -1)**: `Income composition of resources`, `Hepatitis B`, `Polio`, `Diphtheria`

<br>

**Methods:**
- `pandas.DataFrame.hist` : Only **numerical columns** will be plotted. __[Here for more info](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.hist.html)__
- `subplot(nrows, ncols, index)` __[Here for more info](https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.subplot.html)__
- `kdeplot` : Only **numerical columns** will be plotted. __[Here for more info](https://seaborn.pydata.org/generated/seaborn.kdeplot.html)__

In [None]:
# Check the histograms
data.hist(bins=35, figsize=(18, 12))
plt.show()

In [None]:
num_cols = data.select_dtypes("number").columns # select all numeric types
print(f"There are {len(num_cols)} numeric columns: \n {num_cols}")

non_num_cols = data.select_dtypes(exclude="number").columns # select all non-numeric types
print(f"There are {len(non_num_cols)} non numeric columns: \n {non_num_cols}")

In [None]:
fig = plt.figure(figsize=(25,18))

graph_index = 1 # Set the position of the subplot to 1
for col in num_cols:
    plt.subplot(5, 4, graph_index) # subplot(nrows, ncols, index)
    graph = sns.kdeplot(data = data, x = col, fill = True)
    graph_index += 1 # Set the position to the next one

In [None]:
fig = plt.figure(figsize=(25,18))

boxplot_index = 1 # Set the position of the subplot to 1
for col in num_cols:
    plt.subplot(5, 4, boxplot_index) # subplot(nrows, ncols, index)
    # The higher the better (Life expectancy & immunization coverage)
    if col in ['Life expectancy', 'Hepatitis B', 'Polio', 'Diphtheria']:
        boxplot = sns.boxplot(data=data, x=col, boxprops=dict(alpha=1))  # Set alpha for transparency
    else:
        boxplot = sns.boxplot(data=data, x=col, boxprops=dict(alpha=0.4))  # Set alpha for transparency
    boxplot_index += 1 # Set the position to the next one

In [None]:
data.skew(axis = 0, skipna=True, numeric_only=True).sort_values(ascending=False) # Skewness in each numeric column in descending order

- **Positive** value: The distribution is skewed to the **right**.
- **Negative** value: The distribution is skewed to the **left**.
- **0**: The distribution is **symmetric**. A normal distribution is one example, but zero skewness alone does not imply normality.
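To make the sign convention concrete, here is a small sketch on made-up series (the values are illustrative and not taken from the dataset). It uses the same `skew()` method as the cell above.

```python
import pandas as pd

# Made-up series chosen only to show the sign of the skewness
right_tailed = pd.Series([1, 1, 2, 2, 3, 20])     # long tail toward large values
left_tailed = pd.Series([1, 18, 19, 19, 20, 20])  # long tail toward small values
symmetric = pd.Series([1, 2, 3, 4, 5])            # mirror image around its mean

for name, s in [("right-tailed", right_tailed),
                ("left-tailed", left_tailed),
                ("symmetric", symmetric)]:
    print(f"{name}: skew = {s.skew():.2f}")
```

The symmetric series is evenly spaced rather than bell-shaped, which shows that a skewness near zero only indicates symmetry, not a normal distribution.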

### 🔓Synthetic data
**Adding synthetic data to balance the two `Status` groups** <br>
Resources:
- __[imbalanced-learn documentation](https://imbalanced-learn.org/stable/)__
- __[YouTube tutorial](https://www.youtube.com/watch?v=4SivdTLIwHc)__

On further analysis, we found that among a total of 168 countries:
- Developed: **29** (with 406 observations)
- Developing: **139** (with 1710 observations)

This imbalance may explain, to some degree, why the box plots show so many outliers.
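The mechanism can be illustrated on toy data: values from a small group can fall outside the whiskers of a pooled box plot even when neither group has outliers on its own. The sketch below uses the 1.5 × IQR rule, which is the default whisker length in seaborn box plots. The helper name `count_iqr_outliers` and the values are assumptions for illustration; whether this fully explains the outliers in the real data is not verified here.

```python
import pandas as pd

def count_iqr_outliers(series, whis=1.5):
    """Count values outside [Q1 - whis*IQR, Q3 + whis*IQR] (hypothetical helper)."""
    q1, q3 = series.quantile(0.25), series.quantile(0.75)
    iqr = q3 - q1
    lower, upper = q1 - whis * iqr, q3 + whis * iqr
    return int(((series < lower) | (series > upper)).sum())

# Made-up values: a large, tight Developing group and a small Developed group
toy = pd.DataFrame({
    "Status": ["Developing"] * 9 + ["Developed"] * 2,
    "Life expectancy": [60, 60, 61, 61, 61, 62, 62, 62, 63, 82, 83],
})

# Outliers when both groups are pooled into one box plot
pooled = count_iqr_outliers(toy["Life expectancy"])

# Outliers when each Status group is considered separately
per_status = toy.groupby("Status")["Life expectancy"].apply(count_iqr_outliers)

print(f"Pooled outliers: {pooled}")
print(per_status)
```

Plotting the real data with `Status` as a grouping variable would be one way to check how much of the effect carries over.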

In [None]:
country_status = data.groupby('Country')['Status'].value_counts()
country_status # Type: Series

In [None]:
# Count the number of unique countries
num_countries = country_status.index.get_level_values('Country').nunique()
print(f'Number of unique countries: {num_countries}')

In [None]:
# Count the total number of countries in each status group
countries_per_status = country_status.groupby('Status').size()

# Group by 'Status' and then count the number of unique countries in each group
#countries_per_status = country_status.groupby('Status').apply(lambda x: x.index.get_level_values('Country').nunique())

print(countries_per_status)

In [None]:
# Count the total number of observations in each status group
observations_per_status = country_status.groupby('Status').sum()
observations_per_status

In [None]:
# Combine the results into a DataFrame
status_summary = pd.DataFrame({'# of Countries': countries_per_status, '# of Observations': observations_per_status})

print(status_summary)

Use an over-sampling method and store the result in `data_resampled`
- Make the number of Developed observations equal to the number of Developing observations: 1710
- Drop the `Country` column, since the linear regression model only uses numerical data
- As a result, `data_resampled` has 3420 rows and 20 columns
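For intuition about what the over-sampling step does, the sketch below creates synthetic minority rows by interpolating between a sample and one of its nearest minority neighbours, which is the core idea behind SMOTE. It is a simplified illustration with a made-up function name and toy values, not the `imblearn` implementation used in the next cell.

```python
import numpy as np

def smote_like_samples(minority, n_new, k=2, seed=42):
    """Create n_new synthetic rows by interpolating between minority neighbours."""
    rng = np.random.default_rng(seed)
    minority = np.asarray(minority, dtype=float)
    synthetic = np.empty((n_new, minority.shape[1]))
    for i in range(n_new):
        # Pick a random minority sample as the starting point
        base = minority[rng.integers(len(minority))]
        # Distances from the base point to all minority points
        dist = np.linalg.norm(minority - base, axis=1)
        # The k nearest neighbours, skipping the first entry (distance 0)
        neighbours = np.argsort(dist)[1:k + 1]
        neighbour = minority[rng.choice(neighbours)]
        # Place the new point somewhere on the segment base -> neighbour
        gap = rng.random()  # uniform in [0, 1)
        synthetic[i] = base + gap * (neighbour - base)
    return synthetic

# Made-up minority rows with two features
minority = np.array([[80.0, 12.0], [82.0, 14.0], [81.0, 13.0], [79.0, 15.0]])
new_rows = smote_like_samples(minority, n_new=5)
print(new_rows.shape)
```

Because every synthetic row lies between two existing minority rows, the new data stays within the range of the observed values rather than extrapolating beyond it.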

#### Resampling: Apply one-hot encoding to `Country`

In [None]:
import pandas as pd
from imblearn.over_sampling import SMOTE
from sklearn.preprocessing import StandardScaler

# One-hot encode the 'Country' column
data_encoded = pd.get_dummies(data, columns=['Country'])

# Separate features and target variable
X = data_encoded.drop(['Status'], axis=1)  # Drop the target column 'Status'
y = data_encoded['Status']  # Target variable

# Apply SMOTE to balance the dataset
smote = SMOTE(random_state=42)
X_resampled, y_resampled = smote.fit_resample(X, y)

# Combine resampled data into a DataFrame
data_resampled = pd.DataFrame(X_resampled, columns=X.columns)
data_resampled['Status'] = y_resampled

# Display the count of each unique value in the 'Status' column after resampling
print(data_resampled['Status'].value_counts())

In [None]:
# Show pie plot
pie_plot = y_resampled.value_counts().plot.pie(autopct='%.2f')
title = pie_plot.set_title("Over-sampling")

In [None]:
print(data_resampled.head())

In [None]:
# Check the (row, column) of data_resampled
data_resampled.shape

In [None]:
# KDE plot resample data: continuous features
fig = plt.figure(figsize=(25,18))

graph_index = 1 # Set the position of the subplot to 1
for col in num_cols:
    plt.subplot(5, 4, graph_index) # subplot(nrows, ncols, index)
    graph = sns.kdeplot(data = data_resampled, x = col, fill = True)
    graph_index += 1 # Set the position to the next one

In [None]:
# Box plot resample data
fig = plt.figure(figsize=(25,18))

boxplot_index = 1 # Set the position of the subplot to 1
for col in num_cols:
    plt.subplot(5, 4, boxplot_index) # subplot(nrows, ncols, index)
    # The higher the better (Life expectancy & immunization coverage)
    if col in ['Life expectancy', 'Hepatitis B', 'Polio', 'Diphtheria']:
        boxplot = sns.boxplot(data=data_resampled, x=col, boxprops=dict(alpha=1))  # Set alpha for transparency
    else:
        boxplot = sns.boxplot(data=data_resampled, x=col, boxprops=dict(alpha=0.4))  # Set alpha for transparency
    boxplot_index += 1 # Set the position to the next one

In [None]:
data_resampled.skew(axis = 0, skipna=True, numeric_only=True).sort_values(ascending=False) # Skewness in each numeric column in descending order

🔓**Needed for the visualizations that follow**

In [None]:
# Add this section to visualization part if needed
# Reverse one-hot encoding for 'Country'
country_columns = [col for col in data_resampled.columns if col.startswith('Country_')]
# Split only on the first '_' so country names that contain '_' are kept whole
data_resampled['Country'] = data_resampled[country_columns].idxmax(axis=1).apply(lambda x: x.split('_', 1)[1])
data_resampled.drop(country_columns, axis=1, inplace=True)

In [None]:
data_resampled

In [None]:
# After reversing the one-hot encoding, check how many Developed/Developing countries there are
country_status = data_resampled.groupby('Country')['Status'].value_counts()

# Count the total number of countries in each status group
countries_per_status = country_status.groupby('Status').size()

# Group by 'Status' and then count the number of unique countries in each group
#countries_per_status = country_status.groupby('Status').apply(lambda x: x.index.get_level_values('Country').nunique())

print(countries_per_status)

### Correlation

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd


# Select numerical features for the correlation heatmap
numerical_features = [
    'Life expectancy', 'Adult Mortality', 'infant deaths', 'Alcohol', 'percentage expenditure', 
    'Hepatitis B', 'Measles', 'BMI', 'under-five deaths', 'Polio', 
    'Total expenditure', 'Diphtheria', 'HIV/AIDS', 'GDP', 
    'thinness 10-19 years', 'thinness 5-9 years', 'Income composition of resources', 
    'Schooling'
]

# Extract the relevant data
data_num = data_resampled[numerical_features]

# Compute the correlation matrix
correlation_matrix = data_num.corr()
print(correlation_matrix)

In [None]:
import numpy as np

# data_resampled
# Select numerical features for the correlation heatmap
numerical_features = [
    'Life expectancy', 'Adult Mortality', 'infant deaths', 'Alcohol', 'percentage expenditure', 
    'Hepatitis B', 'Measles', 'BMI', 'under-five deaths', 'Polio', 
    'Total expenditure', 'Diphtheria', 'HIV/AIDS', 'GDP', 
    'thinness 10-19 years', 'thinness 5-9 years', 'Income composition of resources', 
    'Schooling'
]

# Extract the relevant data
data_resampled_corr = data_resampled[numerical_features]

# Compute the correlation matrix
correlation_matrix = data_resampled_corr.corr()

# Plot the heatmap using matplotlib
fig, ax = plt.subplots(figsize=(10, 8))
cax = ax.matshow(correlation_matrix, cmap='coolwarm')

# Add color bar
fig.colorbar(cax)

# Set axis labels
ax.set_xticks(range(len(correlation_matrix.columns)))
ax.set_yticks(range(len(correlation_matrix.index)))
ax.set_xticklabels(correlation_matrix.columns, rotation=90)
ax.set_yticklabels(correlation_matrix.index)

# Add the correlation values as text
for (i, j), val in np.ndenumerate(correlation_matrix):
    ax.text(j, i, f'{val:.1f}', ha='center', va='center', color='black')

plt.title("Correlation Heatmap of Life Expectancy Variables", pad=20)
plt.show()

In [None]:
data_resampled.columns

In [None]:
data_resampled.to_csv('output.csv', index=False)

## Data Visualization

## 1. Which 10-year life expectancy range is the most common?

In [None]:
# Create a new column for Life Expectancy Range in decades
# right=False makes each bin left-closed, e.g. [30, 40), so the labels use the bin edges
bins = [0, 30, 40, 50, 60, 70, 80, 90, 100]
labels = ['0-30', '30-40', '40-50', '50-60', '60-70', '70-80', '80-90', '90-100']
data_resampled['Life Expectancy Range'] = pd.cut(data_resampled['Life expectancy'], bins=bins, labels=labels, right=False)

# Create a pivot table to count the number of occurrences in each Life Expectancy Range
pivot_table = data_resampled['Life Expectancy Range'].value_counts().sort_index().reset_index()
pivot_table.columns = ['Life Expectancy Range', 'Count']

# Choose a list of colors for the bars
colors = ['#FF6347', '#FFD700', '#90EE90', '#87CEEB', '#9370DB', '#FF69B4', '#FF4500', '#4682B4']

# Plot the pivot table as a bar chart with custom colors
plt.figure(figsize=(14, 8))
plt.bar(pivot_table['Life Expectancy Range'], pivot_table['Count'], color=colors[:len(pivot_table)])
plt.title('Distribution of Life Expectancy in Decade Windows')
plt.xlabel('Life Expectancy Range')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.show()

## 2. What is the proportion of different causes of deaths?

In [None]:
# Summing up the values for different death causes
death_causes = {
    'Infant Deaths': data_resampled['infant deaths'].sum(),
    'Under-Five Deaths': data_resampled['under-five deaths'].sum(),
    'Adult Mortality': data_resampled['Adult Mortality'].sum()
}

# Labels and sizes for the pie chart
labels = death_causes.keys()
sizes = death_causes.values()

# Plotting the pie chart
plt.figure(figsize=(8, 8))
plt.pie(sizes, labels=labels, autopct='%1.1f%%', colors=['lightcoral', 'lightskyblue', 'lightgreen'], startangle=140)
plt.title('Proportion of Different Causes of Death')
plt.axis('equal')  # Equal aspect ratio ensures that pie is drawn as a circle.

# Show the plot
plt.show()

## 3. Which 10 countries have the highest life expectancy?

In [None]:
# Calculate the mean life expectancy for each country
country_life_expectancy = data_resampled.groupby('Country')['Life expectancy'].mean().reset_index()

# Sort the countries by life expectancy and select the top 10
top_10_countries = country_life_expectancy.nlargest(10, 'Life expectancy')

# Plot the top 10 countries with the highest life expectancy
plt.figure(figsize=(14, 8))
ax = sns.barplot(x='Life expectancy', y='Country', data=top_10_countries, palette='viridis')
plt.title('Top 10 Countries with Highest Life Expectancy')
plt.xlabel('Average Life Expectancy')
plt.ylabel('Country')

# Set the x-axis limit
plt.xlim(50, top_10_countries['Life expectancy'].max() + 10)  # Add a little padding on the upper limit

plt.show()

## 4. What has been the average Adult Mortality Rate over the years?

In [None]:
# Group the data by 'Year' and calculate the mean of 'Adult Mortality' for each year
grouped_data = data_resampled.groupby('Year')['Adult Mortality'].mean().reset_index()

# Set the size of the figure
plt.figure(figsize=(14, 8))

# Create an area plot by filling the area under the line plot
# 'fill_between' fills the area between the x-axis and the line representing 'Adult Mortality'
# 'alpha=0.4' sets the transparency level of the filled area to 40%
plt.fill_between(grouped_data['Year'], grouped_data['Adult Mortality'], color='skyblue', alpha=0.4)

# Plot a line chart for 'Adult Mortality' over the years
# 'marker="o"' adds a circle marker at each data point
# 'color="blue"' sets the color of the line and markers
# 'linestyle="-" ' sets the line style to a solid line
plt.plot(grouped_data['Year'], grouped_data['Adult Mortality'], marker='o', color='blue', linestyle='-')

# Set the title of the plot to describe what the chart represents
plt.title('Average Adult Mortality Rate Over the Years')

# Label the x-axis as 'Year' to indicate the time period on this axis
plt.xlabel('Year')

# Label the y-axis as 'Average Adult Mortality Rate' to indicate the average mortality rates measured
plt.ylabel('Average Adult Mortality Rate')

# Add a grid to the plot to improve readability and to make it easier to compare values
plt.grid(True)

# Display the plot
plt.show()

## 5. Correlation of various diseases with Life expectancy

In [None]:
# Select relevant columns for analysis
columns_of_interest = ['Life expectancy', 'Alcohol', 'Polio', 'Measles', 'HIV/AIDS']

# Calculate correlations with life expectancy
correlations = data_resampled[columns_of_interest].corr()['Life expectancy'].drop('Life expectancy')

# Plotting the correlations
plt.figure(figsize=(10, 6))
correlations.sort_values().plot(kind='barh', color='skyblue')

# Adding title and labels
plt.title('Correlation of Various Factors with Life Expectancy')
plt.xlabel('Correlation Coefficient')
plt.ylabel('Factors')

# Show the plot
plt.show()

## 6. Do Schooling and GDP have an impact on life expectancy?

In [None]:
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# 3D scatter plot
scatter = ax.scatter(data_resampled['GDP'], data_resampled['Schooling'], data_resampled['Life expectancy'],
                     c=data_resampled['Life expectancy'], cmap='viridis', s=50)

ax.set_xlabel('GDP')
ax.set_ylabel('Schooling')
ax.set_zlabel('Life Expectancy')
plt.title('3D Plot of Life Expectancy, GDP, and SCHOOLING')
fig.colorbar(scatter, ax=ax, shrink=0.5, aspect=5)
plt.show()

## 7. Do different categories have varied life expectancy?


In [None]:
plt.figure(figsize=(12, 8))

# Violin plot of life expectancy by country status
sns.violinplot(x='Status', y='Life expectancy', data=data_resampled, palette='coolwarm')
plt.title('Violin Plot of Life Expectancy Across Different Categories')
plt.xlabel('Country Status')
plt.ylabel('Life Expectancy')
plt.show()

## 8. What is the expenditure of the 5 countries with the lowest life expectancy?

In [None]:

import matplotlib.pyplot as plt
import numpy as np

# Group by 'Country' and calculate the average 'Life expectancy' and 'Total expenditure'
avg_life_exp_expenditure = data_resampled.groupby('Country').agg({
    'Life expectancy': 'mean',
    'Total expenditure': 'mean'
}).sort_values(by='Life expectancy', ascending=True).head(5)

# Create a list of colors for the bubbles
colors = plt.cm.viridis(np.linspace(0, 1, len(avg_life_exp_expenditure)))

# Plotting
plt.figure(figsize=(12, 8))
plt.scatter(avg_life_exp_expenditure['Life expectancy'], 
            avg_life_exp_expenditure['Total expenditure'], 
            s=avg_life_exp_expenditure['Total expenditure'] * 50,  # Increase the size of the bubbles
            alpha=0.7, color=colors, edgecolor='black')

plt.title('5 Countries with Lowest Average Life Expectancy and their Average Total Expenditure')
plt.xlabel('Average Life Expectancy')
plt.ylabel('Average Total Expenditure')
plt.grid(True)

# Adding text labels for each point (country names and expenditure amounts)
for i in range(len(avg_life_exp_expenditure)):
    plt.text(avg_life_exp_expenditure['Life expectancy'].iloc[i] + 0.1,  # Adjust for better positioning
             avg_life_exp_expenditure['Total expenditure'].iloc[i] + 0.1, 
             f"{avg_life_exp_expenditure.index[i]}\n{avg_life_exp_expenditure['Total expenditure'].iloc[i]:,.2f}%",  # Total expenditure is a percentage
             fontsize=9, verticalalignment='center')

plt.show()

## Predictive Model

### Linear Regression

In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

# List of numerical features and target
numerical_features = [
    'Adult Mortality', 'infant deaths', 'Alcohol', 'percentage expenditure', 
    'Hepatitis B', 'Measles', 'BMI', 'under-five deaths', 'Polio', 
    'Total expenditure', 'Diphtheria', 'HIV/AIDS', 'GDP', 
    'thinness 10-19 years', 'thinness 5-9 years', 'Income composition of resources', 
    'Schooling'
]
target = 'Life expectancy'

# Prepare the features and target (missing values were already handled during pre-processing)
X = data_resampled[numerical_features]
y = data_resampled[target]

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Fit the linear regression model
model = LinearRegression()
model.fit(X_train, y_train)

# Predict and evaluate the model
y_pred = model.predict(X_test)
r_squared = model.score(X_test, y_test)
mse = mean_squared_error(y_test, y_pred)

print(f'R-squared: {r_squared:.2f}, Mean Squared Error: {mse:.2f}')


In [None]:
# Scatter plot for Actual vs Predicted values
plt.figure(figsize=(10, 6))
plt.scatter(y_test, y_pred, alpha=0.5, color='blue')
plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], color='red')  # Diagonal line
plt.xlabel('Actual Life Expectancy')
plt.ylabel('Predicted Life Expectancy')
plt.title('Actual vs Predicted Life Expectancy')
plt.show()

## Performance of Other Models

In [None]:
from sklearn.linear_model import Ridge
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score

# Create a Ridge regression model
ridge_model = make_pipeline(StandardScaler(), Ridge(alpha=1.0))

# Train the model
ridge_model.fit(X_train, y_train)

# Predict on the test set
y_pred_ridge = ridge_model.predict(X_test)

# Calculate mean squared error and R-squared
mse_ridge = mean_squared_error(y_test, y_pred_ridge)
r2_ridge = r2_score(y_test, y_pred_ridge)

print(f'Mean Squared Error (Ridge): {mse_ridge}')
print(f'R-squared (Ridge): {r2_ridge}')


In [None]:
from sklearn.linear_model import Lasso

# Create a Lasso regression model
lasso_model = make_pipeline(StandardScaler(), Lasso(alpha=0.1))

# Train the model
lasso_model.fit(X_train, y_train)

# Predict on the test set
y_pred_lasso = lasso_model.predict(X_test)

# Calculate mean squared error and R-squared
mse_lasso = mean_squared_error(y_test, y_pred_lasso)
r2_lasso = r2_score(y_test, y_pred_lasso)

print(f'Mean Squared Error (Lasso): {mse_lasso}')
print(f'R-squared (Lasso): {r2_lasso}')


In [None]:
from sklearn.ensemble import RandomForestRegressor

# Create a Random Forest regression model
rf_model = RandomForestRegressor(n_estimators=100, random_state=42)

# Train the model
rf_model.fit(X_train, y_train)

# Predict on the test set
y_pred_rf = rf_model.predict(X_test)

# Calculate mean squared error and R-squared
mse_rf = mean_squared_error(y_test, y_pred_rf)
r2_rf = r2_score(y_test, y_pred_rf)

print(f'Mean Squared Error (Random Forest): {mse_rf}')
print(f'R-squared (Random Forest): {r2_rf}')



In [None]:
import matplotlib.pyplot as plt

def plot_results(y_test, y_pred_ridge, y_pred_lasso, y_pred_rf):
    plt.figure(figsize=(15, 5))

    # Ridge Regression
    plt.subplot(1, 3, 1)
    plt.scatter(y_test, y_pred_ridge, alpha=0.5, color='blue')
    plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], color='red')  # Diagonal line
    plt.xlabel('Actual Life Expectancy')
    plt.ylabel('Predicted Life Expectancy')
    plt.title('Ridge Regression')

    # Lasso Regression
    plt.subplot(1, 3, 2)
    plt.scatter(y_test, y_pred_lasso, alpha=0.5, color='green')
    plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], color='red')  # Diagonal line
    plt.xlabel('Actual Life Expectancy')
    plt.ylabel('Predicted Life Expectancy')
    plt.title('Lasso Regression')

    # Random Forest Regression
    plt.subplot(1, 3, 3)
    plt.scatter(y_test, y_pred_rf, alpha=0.5, color='purple')
    plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], color='red')  # Diagonal line
    plt.xlabel('Actual Life Expectancy')
    plt.ylabel('Predicted Life Expectancy')
    plt.title('Random Forest Regression')

    plt.tight_layout()
    plt.show()

# Plot results for each model side by side
plot_results(y_test, y_pred_ridge, y_pred_lasso, y_pred_rf)
