In [0]:
from pyspark.sql.functions import col

# Country/year gold table, read through the notebook-provided `spark` session.
df = spark.read.table("workspace.default.gold_country_year")

# History window used for the growth calculation: 2018 through 2024,
# both endpoints included.
df_hist = df.where((col("year") >= 2018) & (col("year") <= 2024))


In [0]:
# NOTE: `min`, `max` and `pow` shadow the Python builtins from here on; the
# names are kept because a later cell in this notebook calls `pow` on columns.
from pyspark.sql.functions import min, max, pow, struct, when

# One row per country holding the first and last observation in the history
# window. Spark orders structs field by field, so min/max over
# struct(year, value) selects the (year, value) pair of the earliest / latest
# year. Taking min/max of the value column on its own would pair the smallest
# and largest values regardless of when they occurred, which forces CAGR >= 0
# and projects from the peak instead of the latest value.
# NOTE(review): if a country has several rows for the same year, the tie is
# broken by value; presumably the table has one row per country-year - confirm.
endpoints = (
    df_hist.groupBy("country_code", "country_name", "region")
           .agg(
               min(struct("year", "total_value_usd")).alias("first_obs"),
               max(struct("year", "total_value_usd")).alias("last_obs"),
           )
)

base = endpoints.select(
    "country_code",
    "country_name",
    "region",
    col("first_obs.year").alias("start_year"),
    col("last_obs.year").alias("end_year"),
    col("first_obs.total_value_usd").alias("start_value"),
    col("last_obs.total_value_usd").alias("end_value"),
)

# Number of compounding periods between the first and last observed year.
base = base.withColumn(
    "num_years",
    col("end_year") - col("start_year")
)

# Null-safe divisors: a country with a single year of data (num_years == 0) or
# a non-positive start value has no defined CAGR. Replacing the divisor with
# null makes the whole expression null instead of dividing by zero, which
# raises when ANSI SQL mode is enabled.
periods = when(col("num_years") > 0, col("num_years"))
start_value = when(col("start_value") > 0, col("start_value"))

# CAGR = (end / start) ** (1 / periods) - 1; null where it is undefined.
base = base.withColumn(
    "cagr",
    pow(col("end_value") / start_value, 1 / periods) - 1
)


In [0]:
# Years to project; one single-column ("year") row is created per entry.
forecast_horizon = [2025, 2026, 2027]

future_years = spark.createDataFrame(
    [(target_year,) for target_year in forecast_horizon],
    ["year"],
)


In [0]:
from pyspark.sql.functions import explode, array

# Compound the last observed value forward:
#   projected = end_value * (1 + cagr) ** (year - end_year)
years_ahead = col("year") - col("end_year")
projected_value = col("end_value") * pow(1 + col("cagr"), years_ahead)

# Pair every country row in `base` with every future year, then keep the same
# column layout as the actuals so the two frames can be unioned by name.
forecast = base.crossJoin(future_years).select(
    "country_code",
    "country_name",
    "region",
    "year",
    projected_value.alias("total_value_usd"),
)


In [0]:
# Columns shared by the actual and projected rows.
output_columns = [
    "country_code",
    "country_name",
    "region",
    "year",
    "total_value_usd",
]

# Actuals come from the full source frame `df`, not the 2018-2024 slice.
# NOTE(review): if `df` already holds rows for the forecast years, the union
# below will contain both the actual and the projected row - confirm.
actuals = df.select(*output_columns)

# Stack actual and projected rows, matching columns by name.
final_forecast = actuals.unionByName(forecast)


In [0]:
# Persist the combined actuals + forecast as a Delta table, replacing any
# existing contents of the target table.
(
    final_forecast.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("workspace.default.gold_country_forecast")
)
