In [7]:
import pyspark
from pyspark.sql import SparkSession
import matplotlib.pyplot as plt
import numpy as np
from pyspark.ml.regression import LinearRegression
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.linalg import Vectors

In [2]:
# Obtain the Spark session for this analysis; getOrCreate() returns an
# already-running session in this kernel instead of starting a second one.
spark = (
    SparkSession.builder
    .appName("ChinaEnergyAnalysis")
    .getOrCreate()
)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/07 19:37:14 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Load the OWID energy dataset (gzip-compressed CSV): first row is the header,
# column types are inferred by scanning the data.
df = (
    spark.read.format("csv")
    .options(header="true", inferschema="true")
    .load("data/owid-energy-data.csv.gz")
)

In [4]:
# China rows that have BOTH an electricity-demand and a population value.
# - population must be non-null too: it is the regression feature later, and
#   VectorAssembler rejects null inputs by default (handleInvalid="error"); the
#   collected `pop` list would otherwise contain None and break np.min/np.max.
# - filter before select so the predicate's `country` column is still present,
#   instead of relying on Spark resolving a column the projection dropped.
df_cn = (
    df.where(
        "country == 'China' "
        "AND electricity_demand IS NOT NULL "
        "AND population IS NOT NULL"
    )
    .select('year', 'population', 'electricity_demand')
)


In [5]:
# Countries whose name starts with "Po", restricted to year 2000 onward.
# NOTE(review): df2 is not referenced by any cell visible here — confirm it is
# still needed, or drop this cell.
df2 = (
    df.select("country", "year", "population", "electricity_demand")
    .filter("country like 'Po%' AND year >= 2000")
)

In [8]:
# Sort chronologically and bring the three series to the driver for plotting.
# A single collect() runs one Spark job (instead of three) and guarantees that
# y[i], pop[i] and dem[i] come from the same row.
df_cn = df_cn.orderBy('year')
cn_rows = df_cn.select('year', 'population', 'electricity_demand').collect()
y = [row['year'] for row in cn_rows]                    # years (x-axis, not the target)
pop = [row['population'] for row in cn_rows]
dem = [row['electricity_demand'] for row in cn_rows]

+----+----------+------------------+---------------+
|year|population|electricity_demand|       features|
+----+----------+------------------+---------------+
|2000|1264099072|           1346.85|[1.264099072E9]|
|2001|1272739584|           1472.19|[1.272739584E9]|
|2002|1280926080|           1645.61| [1.28092608E9]|
|2003|1288873344|           1903.22|[1.288873344E9]|
|2004|1296816768|           2197.23|[1.296816768E9]|
+----+----------+------------------+---------------+
only showing top 5 rows



25/04/07 19:42:08 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
25/04/07 19:42:08 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.VectorBLAS


RMSE: 385.3277526451458
r2: 0.9718833109232434
iterations: 4
demand = [4.2259922733572943e-05]*population  -52582.60663196023

Year-based regression:
RMSE: 197.60475088211462
r2: 0.992605677930012
iterations: 4
demand = [345.13683014301574]*year  -689387.0897654308


In [None]:
# Plot electricity demand vs. year and save it to disk (the figure is closed,
# so it is not displayed inline).
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(y, dem, label='demand')
ax.set_xlabel('Year')
ax.set_ylabel('Electricity Demand (TWh)')
ax.set_title('China Electricity Demand vs. Year')
# The series carries a label, so render the legend; previously it was never shown.
ax.legend()
ax.grid(True)
fig.savefig('china_demand_vs_year.png')
plt.close(fig)

In [None]:

# Scatter of electricity demand against population, written to a PNG file;
# the figure is closed afterwards, so nothing is displayed inline.
fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(pop, dem)
ax.set(
    xlabel='Population',
    ylabel='Electricity Demand (TWh)',
    title='China Electricity Demand vs. Population',
)
ax.grid(True)
fig.savefig('china_demand_vs_population.png')
plt.close(fig)

In [None]:

# Regression analysis: electricity demand as a linear function of a single feature,
# first population, then year for comparison. Both fits use the same elastic-net
# regularized LinearRegression (regParam=0.1, elasticNetParam=0.5), so estimator
# construction, reporting and plotting are shared helpers instead of copy-pasted code.


def _make_lr():
    """Return the LinearRegression estimator shared by every single-feature fit.

    Reads feature vectors from the "features" column and the target from
    "electricity_demand".
    """
    return (
        LinearRegression()
        .setMaxIter(10)
        .setRegParam(0.1)
        .setElasticNetParam(0.5)
        .setFeaturesCol("features")
        .setLabelCol("electricity_demand")
    )


def _print_fit(model, feature_name):
    """Print RMSE, r2, iteration count and the fitted equation of a 1-feature model."""
    summary = model.summary
    print(f'RMSE: {summary.rootMeanSquaredError}')
    print(f'r2: {summary.r2}')
    print(f'iterations: {summary.totalIterations}')
    # coefficients is a length-1 vector here: index it so the equation shows a scalar
    # slope, and always print an explicit operator before the intercept magnitude.
    slope = model.coefficients[0]
    intercept = model.intercept
    sign = '+' if intercept >= 0 else '-'
    print(f'demand = {slope}*{feature_name} {sign} {abs(intercept)}')


def _predict_line(model, xs):
    """Return intercept + slope * x for each x in xs as a list of floats.

    Same value model.predict gives for a single-feature linear model, but computed
    in numpy rather than with one JVM round-trip per point.
    """
    return (model.intercept + model.coefficients[0] * np.asarray(xs, dtype=float)).tolist()


def _plot_fit(xs, ys, x_line, y_line, line_style, xlabel, title, path):
    """Scatter observed demand, overlay the fitted line, save to `path`, close the figure."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.scatter(xs, ys, label='Actual Demand')
    ax.plot(x_line, y_line, line_style, label='Fitted Function', linewidth=2)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Electricity Demand (TWh)')
    ax.set_title(title)
    ax.legend()
    ax.grid(True)
    fig.savefig(path)
    plt.close(fig)


# --- Population as the feature ---
va_pop = VectorAssembler().setInputCols(["population"]).setOutputCol("features")
df_cn_pop = va_pop.transform(df_cn)
df_cn_pop.show(5)

lr_pop = _make_lr()
model_pop = lr_pop.fit(df_cn_pop)
_print_fit(model_pop, 'population')

pop_min = np.min(pop)
pop_max = np.max(pop)
pop_range = np.linspace(pop_min, pop_max, 100)
pop_preds = _predict_line(model_pop, pop_range)
_plot_fit(
    pop, dem, pop_range, pop_preds, 'r-', 'Population',
    'China: Electricity Demand vs. Population with Linear Regression',
    'china_demand_vs_population_regression.png',
)

# --- Year as the feature, for comparison ---
va_year = VectorAssembler().setInputCols(["year"]).setOutputCol("features")
df_cn_year = va_year.transform(df_cn)

lr_year = _make_lr()
model_year = lr_year.fit(df_cn_year)
print("\nYear-based regression:")
_print_fit(model_year, 'year')

year_min = np.min(y)
year_max = np.max(y)
# Pad one year on each side so the fitted line extends slightly past the data.
year_range = np.linspace(year_min - 1, year_max + 1, 100)
year_preds = _predict_line(model_year, year_range)
_plot_fit(
    y, dem, year_range, year_preds, 'g-', 'Year',
    'China: Electricity Demand vs. Year with Linear Regression',
    'china_demand_vs_year_regression.png',
)

# Stop the Spark session. Keep this as the very last step: once stopped, `spark`
# and the DataFrames built from it (df, df_cn, df_cn_pop, ...) can no longer run
# jobs, so re-running earlier cells requires re-running the notebook from the
# SparkSession cell onward.
spark.stop()