<a href="https://colab.research.google.com/github/HarlanAlternative/I4/blob/main/I4_BDAS_Notebook-checkpoint.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Iteration 4: BDAS - PySpark + Colab + AWS
## Life Expectancy Prediction with Spark MLlib

This notebook implements the complete data science pipeline using PySpark for big data analytics based on the I3.py analysis.


## 1. Environment Setup and Spark Installation




In [1]:
# Install required packages for Spark environment
!apt-get update
!apt-get install openjdk-11-jdk-headless -qq > /dev/null

# Download and setup Spark 3.5.0
!wget -q https://archive.apache.org/dist/spark/spark-3.5.0/spark-3.5.0-bin-hadoop3.tgz
!tar xf spark-3.5.0-bin-hadoop3.tgz

# Install Python packages
!pip install findspark pyspark==3.5.0

import os
import findspark
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.5.0-bin-hadoop3"
findspark.init()

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.ml import Pipeline
from pyspark.ml.feature import *
from pyspark.ml.regression import *
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Initialize SparkSession
spark = SparkSession.builder \
    .appName("LifeExpectancyPrediction") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .getOrCreate()

print(f"Spark version: {spark.version}")
spark


0% [Working]            Get:1 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Get:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
Hit:3 https://cli.github.com/packages stable InRelease
Hit:4 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Hit:5 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:6 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Get:7 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Hit:8 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:9 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Get:10 http://security.ubuntu.com/ubuntu jammy-security/universe amd64 Packages [1,287 kB]
Hit:11 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Get:12 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]
Get:13 https://r2u.stat.illinois.edu/ubuntu jammy/main amd64 Package

## 2. 导入并检查数据


In [2]:
# 从 CSV 加载 WHO Life Expectancy 数据集
df = spark.read.options(header=True, inferSchema=True).csv("Life Expectancy Data.csv")

# 自动推断字段类型、打印 schema 与前 5 行样例
print("Data Schema:")
df.printSchema()

print("\nFirst 5 rows:")
df.show(5)

# 统计总行数与各列缺失数量
total_rows = df.count()
print(f"\nTotal rows: {total_rows}")

print("\nMissing values per column:")
missing_stats = df.select([count(when(col(c).isNull(), c)).alias(c) for c in df.columns])
missing_stats.show(1, vertical=True)


AnalysisException: [PATH_NOT_FOUND] Path does not exist: file:/content/Life Expectancy Data.csv.

## 3. 清洗数据


In [None]:
# 转换数据类型和处理缺失值
numeric_columns = ['Life expectancy ', 'Adult Mortality', 'infant deaths', 'Alcohol',
                   'percentage expenditure', 'Hepatitis B', 'Measles ', ' BMI ',
                   'under-five deaths ', 'Polio', 'Total expenditure', 'Diphtheria ',
                   ' HIV/AIDS', 'GDP', 'Population', ' thinness  1-19 years',
                   ' thinness 5-9 years', 'Income composition of resources', 'Schooling']

# 转换为适当的数据类型
df_clean = df
for col_name in numeric_columns:
    if col_name in df.columns:
        df_clean = df_clean.withColumn(col_name, col(col_name).cast("double"))

# 确保 Year 是整数
df_clean = df_clean.withColumn("Year", col("Year").cast("int"))

# 用均值填补数值缺失
numeric_means = {}
for col_name in numeric_columns:
    if col_name in df_clean.columns:
        mean_val = df_clean.select(mean(col(col_name))).collect()[0][0]
        numeric_means[col_name] = mean_val if mean_val is not None else 0.0

for col_name, mean_val in numeric_means.items():
    df_clean = df_clean.withColumn(col_name, when(col(col_name).isNull(), mean_val).otherwise(col(col_name)))

# 用 "Developing" 填补 Status
df_clean = df_clean.withColumn("Status", when(col("Status").isNull(), "Developing").otherwise(col("Status")))

# 校正不合理值（如疫苗覆盖率 > 100 或 < 0）
vaccine_cols = ['Hepatitis B', 'Polio', 'Diphtheria ']
for vaccine_col in vaccine_cols:
    if vaccine_col in df_clean.columns:
        df_clean = df_clean.withColumn(vaccine_col,
            when(col(vaccine_col) > 100, 100.0)
            .when(col(vaccine_col) < 0, 0.0)
            .otherwise(col(vaccine_col)))

print("Data cleaning completed.")
print("\nSummary statistics:")
df_clean.describe().show()


## 4. 构造新特征


In [None]:
from pyspark.sql.functions import log, log1p

# 基于 I3.py 特征工程部分创建新特征
df_features = df_clean

# 1. log_GDP：GDP 取对数
df_features = df_features.withColumn("log_GDP", log1p(col("GDP") + 1e-6))

# 2. immunization_avg：三种疫苗平均
vaccine_cols = ['Hepatitis B', 'Polio', 'Diphtheria ']
available_vaccine_cols = [c for c in vaccine_cols if c in df_features.columns]
if available_vaccine_cols:
    df_features = df_features.withColumn("immunization_avg",
        sum(*[col(vc) for vc in available_vaccine_cols]) / len(available_vaccine_cols))

# 3. mortality_ratio：婴儿死亡率比
if 'infant deaths' in df_features.columns and 'under-five deaths ' in df_features.columns and 'Adult Mortality' in df_features.columns:
    df_features = df_features.withColumn("child_mortality_rate",
        col('infant deaths') + col('under-five deaths '))
    df_features = df_features.withColumn("mortality_ratio",
        col('child_mortality_rate') / (col('Adult Mortality') + 1e-6))

# 4. health_spend_pc：人均健康支出
if 'percentage expenditure' in df_features.columns and 'Population' in df_features.columns:
    df_features = df_features.withColumn("health_spend_pc",
        col('percentage expenditure') / (col('Population') + 1e-6))

# 基于 I3.py 的额外特征
if 'percentage expenditure' in df_features.columns and 'GDP' in df_features.columns:
    df_features = df_features.withColumn('health_expenditure_ratio',
        col('percentage expenditure') / (col('GDP') + 1e-6))

disease_cols = ['Measles ', 'Polio', 'Diphtheria ']
if 'GDP' in df_features.columns and 'Income composition of resources' in df_features.columns:
    df_features = df_features.withColumn('economic_dev_index',
        col('GDP') * col('Income composition of resources'))

print("New features created.")
new_feature_cols = [c for c in df_features.columns if c not in df_clean.columns]
print(f"New features: {new_feature_cols}")


## 5. 整合与丰富数据



In [None]:
# 为 Country 增加 Region 字段
region_data = {
    'Afghanistan': 'South Asia', 'Albania': 'Europe', 'Algeria': 'North Africa',
    'Argentina': 'South America', 'Armenia': 'Europe', 'Australia': 'Oceania',
    'Austria': 'Europe', 'Azerbaijan': 'Europe', 'Bahamas': 'Caribbean',
    'Bahrain': 'Middle East', 'Bangladesh': 'South Asia', 'Barbados': 'Caribbean',
    'Belarus': 'Europe', 'Belgium': 'Europe', 'Belize': 'Central America',
    'Benin': 'West Africa', 'Bhutan': 'South Asia', 'Bolivia': 'South America',
    'Bosnia and Herzegovina': 'Europe', 'Botswana': 'Southern Africa', 'Brazil': 'South America',
    'Brunei': 'Southeast Asia', 'Bulgaria': 'Europe', 'Burkina Faso': 'West Africa',
    'Burundi': 'East Africa', 'Cabo Verde': 'West Africa', 'Cambodia': 'Southeast Asia',
    'Cameroon': 'Central Africa', 'Canada': 'North America', 'Central African Republic': 'Central Africa',
    'Chad': 'Central Africa', 'Chile': 'South America', 'China': 'East Asia',
    'Colombia': 'South America', 'Comoros': 'East Africa', 'Congo': 'Central Africa',
    'Costa Rica': 'Central America', "Cote d'Ivoire": 'West Africa', 'Croatia': 'Europe',
    'Cuba': 'Caribbean', 'Cyprus': 'Europe', 'Czech Republic': 'Europe',
    'Denmark': 'Europe', 'Djibouti': 'East Africa', 'Dominican Republic': 'Caribbean',
    'Ecuador': 'South America', 'Egypt': 'North Africa', 'El Salvador': 'Central America',
    'Eritrea': 'East Africa', 'Estonia': 'Europe', 'Ethiopia': 'East Africa',
    'Fiji': 'Oceania', 'Finland': 'Europe', 'France': 'Europe',
    'Gabon': 'Central Africa', 'Gambia': 'West Africa', 'Georgia': 'Europe',
    'Germany': 'Europe', 'Ghana': 'West Africa', 'Greece': 'Europe',
    'Grenada': 'Caribbean', 'Guatemala': 'Central America', 'Guinea': 'West Africa',
    'Guinea-Bissau': 'West Africa', 'Guyana': 'South America', 'Haiti': 'Caribbean',
    'Honduras': 'Central America', 'Hungary': 'Europe', 'Iceland': 'Europe',
    'India': 'South Asia', 'Indonesia': 'Southeast Asia', 'Iran': 'Middle East',
    'Iraq': 'Middle East', 'Ireland': 'Europe', 'Israel': 'Middle East',
    'Italy': 'Europe', 'Jamaica': 'Caribbean', 'Japan': 'East Asia',
    'Jordan': 'Middle East', 'Kazakhstan': 'Central Asia', 'Kenya': 'East Africa',
    'Kiribati': 'Oceania', 'Kuwait': 'Middle East', 'Kyrgyzstan': 'Central Asia',
    'Laos': 'Southeast Asia', 'Latvia': 'Europe', 'Lebanon': 'Middle East',
    'Lesotho': 'Southern Africa', 'Liberia': 'West Africa', 'Libya': 'North Africa',
    'Lithuania': 'Europe', 'Luxembourg': 'Europe', 'Madagascar': 'East Africa',
    'Malawi': 'East Africa', 'Malaysia': 'Southeast Asia', 'Maldives': 'South Asia',
    'Mali': 'West Africa', 'Malta': 'Europe', 'Mauritania': 'North Africa',
    'Mauritius': 'East Africa', 'Mexico': 'North America', 'Mongolia': 'East Asia',
    'Montenegro': 'Europe', 'Morocco': 'North Africa', 'Mozambique': 'East Africa',
    'Myanmar': 'Southeast Asia', 'Namibia': 'Southern Africa', 'Nepal': 'South Asia',
    'Netherlands': 'Europe', 'New Zealand': 'Oceania', 'Nicaragua': 'Central America',
    'Niger': 'West Africa', 'Nigeria': 'West Africa', 'Norway': 'Europe',
    'Oman': 'Middle East', 'Pakistan': 'South Asia', 'Panama': 'Central America',
    'Papua New Guinea': 'Oceania', 'Paraguay': 'South America', 'Peru': 'South America',
    'Philippines': 'Southeast Asia', 'Poland': 'Europe', 'Portugal': 'Europe',
    'Qatar': 'Middle East', 'Romania': 'Europe', 'Russia': 'Europe',
    'Rwanda': 'East Africa', 'Samoa': 'Oceania', 'Saudi Arabia': 'Middle East',
    'Senegal': 'West Africa', 'Serbia': 'Europe', 'Seychelles': 'East Africa',
    'Sierra Leone': 'West Africa', 'Singapore': 'Southeast Asia', 'Slovakia': 'Europe',
    'Slovenia': 'Europe', 'Solomon Islands': 'Oceania', 'Somalia': 'East Africa',
    'South Africa': 'Southern Africa', 'South Korea': 'East Asia', 'South Sudan': 'East Africa',
    'Spain': 'Europe', 'Sri Lanka': 'South Asia', 'Sudan': 'North Africa',
    'Suriname': 'South America', 'Swaziland': 'Southern Africa', 'Sweden': 'Europe',
    'Switzerland': 'Europe', 'Syria': 'Middle East', 'Tajikistan': 'Central Asia',
    'Tanzania': 'East Africa', 'Thailand': 'Southeast Asia', 'Timor-Leste': 'Southeast Asia',
    'Togo': 'West Africa', 'Tonga': 'Oceania', 'Trinidad and Tobago': 'Caribbean',
    'Tunisia': 'North Africa', 'Turkey': 'Europe', 'Turkmenistan': 'Central Asia',
    'Uganda': 'East Africa', 'Ukraine': 'Europe', 'United Arab Emirates': 'Middle East',
    'United Kingdom': 'Europe', 'United States of America': 'North America', 'Uruguay': 'South America',
    'Uzbekistan': 'Central Asia', 'Vanuatu': 'Oceania', 'Venezuela': 'South America',
    'Vietnam': 'Southeast Asia', 'Yemen': 'Middle East', 'Zambia': 'East Africa', 'Zimbabwe': 'East Africa'
}

region_df = spark.createDataFrame(
    [(country, region) for country, region in region_data.items()],
    ["Country", "Region"]
)

df_enriched = df_features.join(region_df, "Country", "left")
df_enriched = df_enriched.withColumn("Region",
    when(col("Region").isNull(), "Other").otherwise(col("Region")))

print("Region distribution:")
df_enriched.groupBy("Region").count().orderBy(desc("count")).show()

print("\nStatus distribution:")
df_enriched.groupBy("Status").count().show()


## 6. 特征标准化与编码


In [None]:
# 准备 ML 管道的特征列
exclude_cols = ['Country', 'Year', 'Status', 'Region', 'Life expectancy ']
feature_cols = [c for c in df_enriched.columns if c not in exclude_cols]

print(f"Selected feature columns ({len(feature_cols)})")

# 处理特征列中剩余的空值
for col_name in feature_cols:
    df_enriched = df_enriched.withColumn(col_name,
        when(col(col_name).isNull(), 0.0).otherwise(col(col_name)))

# 创建索引器和编码器
status_indexer = StringIndexer(inputCol="Status", outputCol="StatusIndex")
status_encoder = OneHotEncoder(inputCol="StatusIndex", outputCol="StatusVec")
region_indexer = StringIndexer(inputCol="Region", outputCol="RegionIndex")
region_encoder = OneHotEncoder(inputCol="RegionIndex", outputCol="RegionVec")

# 组装所有特征
all_feature_cols = feature_cols + ["StatusVec", "RegionVec"]
assembler = VectorAssembler(inputCols=all_feature_cols, outputCol="features_raw")

# 标准化特征
scaler = StandardScaler(inputCol="features_raw", outputCol="features",
                       withStd=True, withMean=True)

print("Feature preprocessing pipeline components defined.")


## 7. 划分训练与测试集


In [None]:
# 准备带有目标变量的数据集
df_ml = df_enriched.filter(col("Life expectancy ").isNotNull())

# 按 0.8 / 0.2 随机划分（固定 seed）
train_df, test_df = df_ml.randomSplit([0.8, 0.2], seed=42)

print(f"Training set size: {train_df.count()}")
print(f"Test set size: {test_df.count()}")

# 缓存结果提高后续效率
train_df.cache()
test_df.cache()

print("\nTrain set sample:")
train_df.select("Country", "Year", "Life expectancy ", "Status", "Region").show(5)


## 8. 建模与比较


In [None]:
# 创建 ML 管道
def create_preprocessing_pipeline():
    return Pipeline(stages=[
        status_indexer,
        region_indexer,
        status_encoder,
        region_encoder,
        assembler,
        scaler
    ])

# 定义要比较的模型
models = {
    'Linear Regression': LinearRegression(featuresCol='features', labelCol='Life expectancy '),
    'Decision Tree': DecisionTreeRegressor(featuresCol='features', labelCol='Life expectancy ', seed=42),
    'Random Forest': RandomForestRegressor(featuresCol='features', labelCol='Life expectancy ', seed=42),
    'GBT': GBTRegressor(featuresCol='features', labelCol='Life expectancy ', seed=42)
}

# 设置评估指标
evaluator_rmse = RegressionEvaluator(labelCol='Life expectancy ', predictionCol='prediction', metricName='rmse')
evaluator_mae = RegressionEvaluator(labelCol='Life expectancy ', predictionCol='prediction', metricName='mae')
evaluator_r2 = RegressionEvaluator(labelCol='Life expectancy ', predictionCol='prediction', metricName='r2')

# 存储结果
results = {}
best_models = {}

# 训练和评估每个模型
for model_name, model in models.items():
    print(f"\nTraining {model_name}...")

    full_pipeline = Pipeline(stages=create_preprocessing_pipeline().getStages() + [model])
    param_grid = ParamGridBuilder().build()

    cv = CrossValidator(
        estimator=full_pipeline,
        estimatorParamMaps=param_grid,
        evaluator=evaluator_r2,
        numFolds=5,
        seed=42
    )

    cv_model = cv.fit(train_df)
    predictions = cv_model.transform(test_df)

    rmse = evaluator_rmse.evaluate(predictions)
    mae = evaluator_mae.evaluate(predictions)
    r2 = evaluator_r2.evaluate(predictions)

    results[model_name] = {'RMSE': rmse, 'MAE': mae, 'R²': r2}
    best_models[model_name] = cv_model

    print(f"{model_name} - RMSE: {rmse:.4f}, MAE: {mae:.4f}, R²: {r2:.4f}")

# 显示结果表
results_df = pd.DataFrame.from_dict(results, orient='index')
print("\nModel Performance Comparison:")
print(results_df.round(4))


In [None]:
# 选择最佳模型
best_model_name = max(results.keys(), key=lambda x: results[x]['R²'])
best_model = best_models[best_model_name]

print(f"\nBest model: {best_model_name}")
print(f"Performance: R² = {results[best_model_name]['R²']:.4f}, RMSE = {results[best_model_name]['RMSE']:.4f}")


## 9. 运行模型并生成预测结果


In [None]:
# 在 test 集运行预测
final_predictions = best_model.transform(test_df)

# 输出预测样例
print("Prediction samples:")
final_predictions.select(
    "Country", "Year",
    col("Life expectancy ").alias("True_Value"),
    col("prediction").alias("Predicted_Value")
).show(10)

# 保存预测结果
final_predictions.select(
    "Country", "Year", "Status", "Region",
    col("Life expectancy ").alias("True_Value"),
    col("prediction").alias("Predicted_Value")
).coalesce(1).write.mode("overwrite").option("header", "true").csv("predictions_output")

print("\nPredictions saved to 'predictions_output' directory.")


## 10. 分析模型输出模式


In [None]:
# 计算残差统计
residuals_df = final_predictions.withColumn("residual",
    col("Life expectancy ") - col("prediction"))

residual_stats = residuals_df.select(
    mean("residual").alias("mean_residual"),
    stddev("residual").alias("std_residual")
).collect()[0]

print(f"Residual statistics - Mean: {residual_stats['mean_residual']:.4f}, Std: {residual_stats['std_residual']:.4f}")

# 按 Status 分组
print("\nAverage predictions by Status:")
residuals_df.groupBy("Status").agg(
    mean("prediction").alias("avg_prediction"),
    mean("Life expectancy ").alias("avg_actual")
).show()

# 按 Region 分组
print("\nAverage predictions by Region:")
residuals_df.groupBy("Region").agg(
    mean("prediction").alias("avg_prediction"),
    mean("Life expectancy ").alias("avg_actual")
).orderBy(desc("avg_prediction")).show()

# Top 误差样本
print("\nTop 10 largest absolute errors:")
top_errors = residuals_df.select(
    "Country", "Year", "Status", "Region",
    "Life expectancy ", "prediction", col("residual").alias("error")
).withColumn("abs_error", abs(col("error"))).orderBy(desc("abs_error")).limit(10)
top_errors.show()


In [None]:
# 输出特征重要性
if best_model_name in ['Decision Tree', 'Random Forest', 'GBT']:
    try:
        model_stage = best_model.bestModel.stages[-1]
        if hasattr(model_stage, 'featureImportances'):
            importances = model_stage.featureImportances
            print(f"\nFeature importance for {best_model_name}:")

            feature_names = all_feature_cols
            importance_list = [(feature_names[i], importances[i]) for i in range(len(feature_names))]
            importance_list.sort(key=lambda x: x[1], reverse=True)

            for i, (feature, importance) in enumerate(importance_list[:15]):
                print(f"  {i+1:2d}. {feature}: {importance:.4f}")
    except Exception as e:
        print(f"Feature importance not available: {e}")
else:
    print("Feature importance not available for Linear Regression model.")


## 11. 可视化结果


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# 设置 matplotlib 字体支持（基于 I3.py）
plt.rcParams['font.sans-serif'] = ['Arial', 'DejaVu Sans']
plt.rcParams['axes.unicode_minus'] = False

# 转换预测结果为 Pandas 用于可视化
viz_data = residuals_df.select(
    "Life expectancy ", "prediction", "residual", "Status", "Region"
).toPandas()

# 设置绘图
plt.style.use('default')
fig = plt.figure(figsize=(20, 15))

# 1. 绘制 Pred vs True 散点图
plt.subplot(2, 3, 1)
plt.scatter(viz_data['Life expectancy '], viz_data['prediction'], alpha=0.6, s=20)
plt.plot([viz_data['Life expectancy '].min(), viz_data['Life expectancy '].max()],
         [viz_data['Life expectancy '].min(), viz_data['Life expectancy '].max()], 'r--', lw=2)
plt.xlabel('True Life Expectancy')
plt.ylabel('Predicted Life Expectancy')
plt.title('Predicted vs True Life Expectancy')
plt.grid(True, alpha=0.3)

# 2. 绘制 Residual 直方图
plt.subplot(2, 3, 2)
plt.hist(viz_data['residual'], bins=30, alpha=0.7, color='skyblue', edgecolor='black')
plt.axvline(viz_data['residual'].mean(), color='red', linestyle='--', linewidth=2,
           label=f'Mean: {viz_data["residual"].mean():.3f}')
plt.xlabel('Residuals')
plt.ylabel('Frequency')
plt.title('Residual Distribution')
plt.legend()
plt.grid(True, alpha=0.3)

# 3. Residuals vs Fitted
plt.subplot(2, 3, 3)
plt.scatter(viz_data['prediction'], viz_data['residual'], alpha=0.6, s=20)
plt.axhline(y=0, color='red', linestyle='--', linewidth=2)
plt.xlabel('Fitted Values')
plt.ylabel('Residuals')
plt.title('Residuals vs Fitted Values')
plt.grid(True, alpha=0.3)

# 4. 模型比较柱状图
plt.subplot(2, 3, 4)
model_names = list(results.keys())
r2_scores = [results[name]['R²'] for name in model_names]
rmse_scores = [results[name]['RMSE'] for name in model_names]

x = np.arange(len(model_names))
width = 0.35

plt.bar(x - width/2, r2_scores, width, label='R²', alpha=0.8)
plt.bar(x + width/2, np.array(rmse_scores)/max(rmse_scores), width, label='RMSE (normalized)', alpha=0.8)
plt.xlabel('Models')
plt.ylabel('Score')
plt.title('Model Performance Comparison')
plt.xticks(x, [name.replace(' ', '\n') for name in model_names], rotation=45)
plt.legend()
plt.grid(True, alpha=0.3)

# 5. 按 Status 的预测
plt.subplot(2, 3, 5)
status_means = viz_data.groupby('Status').agg({
    'Life expectancy ': 'mean',
    'prediction': 'mean'
}).reset_index()

x = np.arange(len(status_means))
plt.bar(x - 0.2, status_means['Life expectancy '], 0.4, label='Actual', alpha=0.8)
plt.bar(x + 0.2, status_means['prediction'], 0.4, label='Predicted', alpha=0.8)
plt.xlabel('Status')
plt.ylabel('Life Expectancy')
plt.title('Predictions by Status')
plt.xticks(x, status_means['Status'])
plt.legend()
plt.grid(True, alpha=0.3)

# 6. 按 Region 的预测（前5个地区）
plt.subplot(2, 3, 6)
region_means = viz_data.groupby('Region').agg({
    'Life expectancy ': 'mean',
    'prediction': 'mean'
}).reset_index().sort_values('prediction', ascending=False).head(5)

x = np.arange(len(region_means))
plt.bar(x - 0.2, region_means['Life expectancy '], 0.4, label='Actual', alpha=0.8)
plt.bar(x + 0.2, region_means['prediction'], 0.4, label='Predicted', alpha=0.8)
plt.xlabel('Region')
plt.ylabel('Life Expectancy')
plt.title('Predictions by Region (Top 5)')
plt.xticks(x, region_means['Region'], rotation=45)
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('spark_ml_results.png', dpi=300, bbox_inches='tight')
plt.show()

print("Visualization saved as 'spark_ml_results.png'")


## 12. 模型再迭代（Iteration 2 within Iteration 4）


In [None]:
# 在主模型上扩大 numTrees / maxDepth 网格重新调参
print(f"\nPerforming enhanced hyperparameter tuning for {best_model_name}...")

# 根据模型类型定义参数网格
if best_model_name == 'Random Forest':
    param_grid = ParamGridBuilder() \
        .addGrid(models['Random Forest'].numTrees, [50, 100, 200]) \
        .addGrid(models['Random Forest'].maxDepth, [5, 10, 15]) \
        .build()
elif best_model_name == 'Decision Tree':
    param_grid = ParamGridBuilder() \
        .addGrid(models['Decision Tree'].maxDepth, [5, 10, 15, 20]) \
        .build()
elif best_model_name == 'GBT':
    param_grid = ParamGridBuilder() \
        .addGrid(models['GBT'].maxIter, [50, 100, 200]) \
        .addGrid(models['GBT'].maxDepth, [3, 5, 7]) \
        .build()
else:
    param_grid = ParamGridBuilder() \
        .addGrid(models['Linear Regression'].regParam, [0.0, 0.01, 0.1]) \
        .build()

# 创建增强交叉验证器
enhanced_cv = CrossValidator(
    estimator=Pipeline(stages=create_preprocessing_pipeline().getStages() + [models[best_model_name]]),
    estimatorParamMaps=param_grid,
    evaluator=evaluator_r2,
    numFolds=5,
    seed=42
)

# 拟合增强模型
enhanced_model = enhanced_cv.fit(train_df)
enhanced_predictions = enhanced_model.transform(test_df)

# 计算增强指标
enhanced_rmse = evaluator_rmse.evaluate(enhanced_predictions)
enhanced_r2 = evaluator_r2.evaluate(enhanced_predictions)

# 比较两轮 RMSE / R² 变化
print(f"\nEnhanced {best_model_name} Results:")
print(f"Original R²: {results[best_model_name]['R²']:.4f}")
print(f"Enhanced R²: {enhanced_r2:.4f}")
print(f"Original RMSE: {results[best_model_name]['RMSE']:.4f}")
print(f"Enhanced RMSE: {enhanced_rmse:.4f}")

# 保存两轮预测结果对比
comparison_results = {
    'Model': [f'{best_model_name} (Original)', f'{best_model_name} (Enhanced)'],
    'R²': [results[best_model_name]['R²'], enhanced_r2],
    'RMSE': [results[best_model_name]['RMSE'], enhanced_rmse]
}

comparison_df = pd.DataFrame(comparison_results)
print("\nModel Enhancement Comparison:")
print(comparison_df.round(4))


In [None]:
# 保存增强预测结果
enhanced_predictions.select(
    "Country", "Year", "Status", "Region",
    col("Life expectancy ").alias("True_Value"),
    col("prediction").alias("Enhanced_Prediction")
).coalesce(1).write.mode("overwrite").option("header", "true").csv("enhanced_predictions")

print("Enhanced predictions saved to 'enhanced_predictions' directory.")


## 13. 收尾准备报告


In [None]:
# 保存最佳增强模型
enhanced_model.write().overwrite().save("best_life_expectancy_model")

# 最终总结
print("="*60)
print("ITERATION 4 - BDAS PYSPARK ML PIPELINE SUMMARY")
print("="*60)

print(f"\nDataset Information:")
print(f"  - Total records: {df_ml.count():,}")
print(f"  - Training samples: {train_df.count():,}")
print(f"  - Test samples: {test_df.count():,}")
print(f"  - Features after engineering: {len(all_feature_cols)}")

print(f"\nModel Performance (Best: {best_model_name}):")
for model_name, metrics in results.items():
    marker = " << BEST" if model_name == best_model_name else ""
    print(f"  - {model_name}: R² = {metrics['R²']:.4f}, RMSE = {metrics['RMSE']:.4f}{marker}")

print(f"\nEnhanced Model Results:")
print(f"  - Enhanced R²: {enhanced_r2:.4f}")
print(f"  - Enhanced RMSE: {enhanced_rmse:.4f}")
print(f"  - Improvement in R²: {(enhanced_r2 - results[best_model_name]['R²']):.4f}")

print(f"\nOutput Files Generated:")
print(f"  - predictions_output/ : Original model predictions")
print(f"  - enhanced_predictions/ : Enhanced model predictions")
print(f"  - best_life_expectancy_model/ : Saved ML model")
print(f"  - spark_ml_results.png : Visualization charts")

print("\n" + "="*60)
print("PIPELINE COMPLETED SUCCESSFULLY!")
print("="*60)


In [None]:
# 清理和结束环境
train_df.unpersist()
test_df.unpersist()

# spark.stop()

print("Notebook execution completed. Spark session is ready for further use.")
print("To stop the Spark session, uncomment the spark.stop() line above.")
