In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Plot a 2x2 grid of heatmaps: for each Pokémon status, the mean of each base stat
# per generation. Expects a DataFrame `df` from an earlier cell with the columns
# listed in `focus_columns`.

# Stats shown on each heatmap's y-axis, and the statuses plotted (one panel each).
stat_columns = ['hp', 'attack', 'defense', 'speed']
status_groups = ['Legendary', 'Mythical', 'Sub Legendary', 'Normal']

# Columns to keep. 'name' is not plotted, but it stays in the selection so that
# rows with a missing name are still removed by dropna() below.
focus_columns = stat_columns + ['generation', 'status', 'name']

# Filter and clean data: drop incomplete rows, keep only the plotted statuses.
df_focus = df[focus_columns].dropna()
df_focus = df_focus[df_focus['status'].isin(status_groups)]

# Mean of every stat per (status, generation). Computing it once lets all panels
# share the same color scale, so the single colorbar is valid for every heatmap.
# Without this, each panel autoscales independently and equal colors would mean
# different values in different panels.
mean_stats = df_focus.groupby(['status', 'generation'])[stat_columns].mean()
vmin = mean_stats.min().min()
vmax = mean_stats.max().max()

# Build a 2x2 grid of heatmaps
fig, axes = plt.subplots(2, 2, figsize=(18, 14))
fig.suptitle("Stat Distribution by Pokémon Status and Generation", fontsize=18)

colorbar_drawn = False  # draw exactly one (shared) colorbar, on the first rendered panel
for ax, status in zip(axes.flat, status_groups):
    group = df_focus[df_focus['status'] == status]
    # Rows = stats, columns = generations.
    heatmap_data = group.groupby('generation')[stat_columns].mean().T

    if heatmap_data.empty:
        # sns.heatmap cannot draw an empty frame; show a placeholder instead of crashing.
        ax.text(0.5, 0.5, "No data", ha='center', va='center', transform=ax.transAxes)
        ax.set_axis_off()
    else:
        sns.heatmap(
            heatmap_data,
            annot=True,
            fmt=".1f",
            cmap='YlOrRd',
            vmin=vmin,  # shared scale across panels (see above)
            vmax=vmax,
            ax=ax,
            cbar=not colorbar_drawn,
            linewidths=0.5,
            linecolor='gray',
        )
        colorbar_drawn = True

    ax.set_title(f"{status} Pokémon", fontsize=14)
    ax.set_xlabel("Generation")
    ax.set_ylabel("Stat")

# Leave room at the top for the figure-level suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()
