# Plots for use in the final thesis

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.ticker as mticker
import seaborn as sns
import sys
from pathlib import Path
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable
from helpers import add_right_strip

# Global matplotlib configuration for thesis-quality figures.

# Typeset all figure text with LaTeX, using the Latin Modern serif family.
mpl.rcParams["text.usetex"] = True
mpl.rcParams["font.family"] = "serif"
mpl.rcParams["font.serif"] = ["Latin Modern Roman"]

# LaTeX packages loaded for every text element (adjacent literals are
# concatenated into a single preamble string).
mpl.rcParams["text.latex.preamble"] = (
    r"\usepackage{lmodern}"
    r"\usepackage[T1]{fontenc}"
    r"\usepackage{amsmath}"
    r"\usepackage{amssymb}"
)

# Font sizes (in points) for axis labels/titles, tick labels and legends.
mpl.rcParams["axes.labelsize"] = 10
mpl.rcParams["axes.titlesize"] = 10
mpl.rcParams["xtick.labelsize"] = 9
mpl.rcParams["ytick.labelsize"] = 9
mpl.rcParams["legend.fontsize"] = 9

# Set up project paths.
# NOTE: "../" is resolved against the current working directory, so this
# assumes the notebook runs from a direct subfolder of the project root.
ROOT_DIR = Path("../").resolve()

# Make project-level modules importable. Notebook cells are routinely
# re-executed in the same kernel, so only add the entry if it is missing
# instead of growing sys.path with duplicates on every run.
if str(ROOT_DIR) not in sys.path:
    sys.path.append(str(ROOT_DIR))

# Define data and output directories
DATA_DIR = ROOT_DIR / "data" / "processed"
OUTPUT_DIR = ROOT_DIR / "results" / "plots"
# Create the output folder (and any missing parents) up front so later
# cells can write into it without further checks.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
# Read the processed CSV for each region into a DataFrame, keyed by a short
# region code (insertion order: ch, eu, us).
datasets = {
    region: pd.read_csv(DATA_DIR / filename)
    for region, filename in (
        ("ch", "ch_data_final.csv"),
        ("eu", "eu_data_final.csv"),
        ("us", "us_data_final.csv"),
    )
}

# Preprocess to time series format
def preprocess(df, column):
    dates = pd.to_datetime(df["date"]) + pd.offsets.MonthEnd(0)
    return pd.Series(df[column].values, index=dates).asfreq("ME").dropna()

# Build one year-over-year inflation series per region by pairing each
# dataset with the column that holds its headline inflation rate.
series = {
    region: preprocess(datasets[region], column)
    for region, column in (
        ("ch", "cpi_total_yoy"),
        ("eu", "hcpi_yoy"),
        ("us", "cpi_all_yoy"),
    )
}

In [None]:
# Stacked inflation plot: one panel per region on a shared time axis.

# Use seaborn's "ticks" style: a clean background without a default grid.
# seaborn style presets also reset ``font.family`` to sans-serif, which would
# silently undo the serif setup made via mpl.rcParams earlier in the notebook,
# so the serif family is re-asserted in the same call.
sns.set_style("ticks", {"font.family": "serif"})

# Define the plotting order and labels for each subplot
plot_data = [
    {'key': 'us', 'label': 'USA (CPI)'},
    {'key': 'ch', 'label': 'Switzerland (CPI)'},
    {'key': 'eu', 'label': 'Euro Area (HCPI)'}
]

# Create one vertically stacked subplot per entry, all sharing the x-axis.
# squeeze=False always returns a 2-D array, so taking its single column
# yields a 1-D array of axes for any number of panels.
fig, axes = plt.subplots(len(plot_data), 1, figsize=(10, 8), sharex=True, squeeze=False)
axes = axes[:, 0]

# Fixed line colour per region. A lookup table raises a clear KeyError for an
# unknown region instead of silently reusing the previous panel's colour.
colors = sns.color_palette("tab10")
line_colors = {
    'us': colors[2],  # Green
    'ch': colors[3],  # Red
    'eu': colors[0],  # Blue
}

# Overall x-axis limits: the full span covered by any plotted series
# (identical to the US range whenever the US series is the longest).
plotted_series = [series[item['key']] for item in plot_data]
start_date = min(s.index.min() for s in plotted_series)
end_date = max(s.index.max() for s in plotted_series)

# Draw each region into its own panel
for ax, item in zip(axes, plot_data):
    key = item['key']
    label = item['label']
    data = series[key]
    color = line_colors[key]

    # Plot the time series data with a slightly thicker line
    ax.plot(data.index, data.values, color=color, linewidth=1.5)

    # Mark the observation located 40% of the way through this series'
    # sample. This is a position in the sample (counted in observations),
    # not the 40th percentile of the inflation values.
    split_position = int(len(data) * 0.4)
    split_date = data.index[split_position]
    ax.axvline(x=split_date, color='black', linestyle=':', linewidth=1.5, zorder=3)

    # Shade the parts of the common time axis this series does not cover
    series_start = data.index.min()
    series_end = data.index.max()
    if series_start > start_date:
        ax.axvspan(start_date, series_start, color='lightgray', alpha=0.6, zorder=0, linewidth=0)
    if series_end < end_date:
        ax.axvspan(series_end, end_date, color='lightgray', alpha=0.6, zorder=0, linewidth=0)

    # Add a horizontal line at y=0 for reference
    ax.axhline(0, color='black', linestyle='--', linewidth=0.7)

    # Ensure y-axis ticks are integers.
    ax.yaxis.set_major_locator(mticker.MaxNLocator(integer=True))

    # Add thin dashed grid lines
    ax.grid(True, which='both', linestyle='--', linewidth=0.5)

    # Project helper (helpers.py); expected to add a grey strip with the
    # rotated panel label on the right -- see its definition for details.
    add_right_strip(ax, label)


# Set the y-label on the middle panel only, so it reads as a shared label
axes[len(axes) // 2].set_ylabel(r"\textrm{Inflation \%}", fontsize=12)

# Format the shared x-axis to display year labels every 5 years
axes[-1].xaxis.set_major_locator(mdates.YearLocator(5))
axes[-1].xaxis.set_major_formatter(mdates.DateFormatter('%Y'))
# Set the tick label size on the bottom panel explicitly. plt.xticks() acts
# on whatever the "current" axes is, which need not be the bottom panel
# (e.g. if add_right_strip creates additional axes).
axes[-1].tick_params(axis='x', labelsize=11)

# Apply the x-axis limits (propagates to all panels through sharex)
axes[-1].set_xlim([start_date, end_date])

# Adjust spacing between subplots
plt.subplots_adjust(hspace=0.15)

# Save the figure. Only OUTPUT_DIR itself is created during setup, so make
# sure the "eda" subfolder exists before writing into it.
output_path = OUTPUT_DIR / "eda/inflation_plot.pdf"
output_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(output_path, dpi=1200, bbox_inches='tight')
print(f"Plot saved to: {output_path}")

# Display the plot
plt.show()