- EMPIRICAL

In [None]:
#3d yearly 
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import Normalize
from matplotlib import cm

# Station ID (the CSV 'STATION' column) -> display name. Insertion order is the plotting
# order along the y axis (station_ids_in_order = list(station_order.keys())).
# Nine stations carry short codes (TP, PA, AG, CL, EN, RG, CT, SR, ME) so that they match
# stations_to_label and get a tick label; all others are drawn unlabelled.
# The 2D heatmap cells use the same IDs with full names, in exactly the reverse order.
# NOTE(review): the order looks geographic (Trapani first, Messina last) — verify against
# merged_COO.csv before relying on it.
station_order = {
    779: "TP",
    775: "ERICE",
    780: "TRAPANI FULGATORE",
    778: "SALEMI",
    774: "CASTELVETRANO",
    773: "CASTELLAMMARE DEL GOLFO",
    744: "Contessa Entellina",
    751: "PARTINICO",
    742: "CAMPOREALE",
    749: "MONREALE VIGNA API",
    745: "CORLEONE",
    693: "RIBERA",
    750: "PA",
    686: "BIVONA",
    748: "MISILMERI",
    747: "MEZZOJUSO",
    683: "AG",
    755: "TERMINI IMERESE",
    685: "ARAGONA",
    684: "AGRIGENTO MANDRASCAVA",
    687: "CAMMARATA",
    740: "ALIA",
    689: "CANICATTì",
    700: "MUSSOMELI",
    703: "SCLAFANI BAGNI",
    690: "LICATA",
    746: "LASCARI",
    696: "DELIA",
    754: "POLIZZI GENEROSA",
    753: "PETRALIA SOTTANA",
    695: "CL",
    701: "RIESI",
    743: "CASTELBUONO",
    718: "EN",
    752: "GANGI",
    699: "MAZZARINO",
    736: "PETTINEO",
    731: "MISTRETTA",
    722: "PIAZZA ARMERINA",
    762: "ACATE",
    721: "NICOSIA",
    717: "AIDONE",
    723: "CARONIA BUZZA",
    760: "SANTA CROCE CAMERINA",
    710: "MAZZARRONE",
    716: "CALTAGIRONE",
    757: "COMISO",
    737: "SAN FRATELLO",
    730: "MILITELLO ROSMARINO",
    761: "SCICLI",
    756: "RG",
    725: "CESARò VIGNAZZA",
    711: "MINEO",
    733: "NASO",
    705: "BRONTE",
    712: "PATERNò",
    770: "PALAZZOLO ACREIDE",
    709: "MALETTO",
    765: "FRANCOFONTE",
    766: "LENTINI",
    715: "RANDAZZO",
    758: "ISPICA",
    735: "PATTI",
    713: "PEDARA",
    767: "NOTO",
    706: "CT",
    769: "PACHINO",
    708: "LINGUAGLOSSA",
    734: "NOVARA DI SICILIA",
    764: "SR",
    707: "RIPOSTO",
    739: "TORREGROTTA",
    738: "SAN PIER NICETO",
    727: "FIUMEDINISI",
    729: "ME"
}

# Stations whose (upper-cased) name gets a visible y-axis tick label; all others stay blank
stations_to_label = ["CT", "PA", "ME", "AG", "EN", "CL", "TP", "SR", "RG"]

# Read the CSV file into a DataFrame
df = pd.read_csv('75gauges.csv')

# Parse timestamps and derive the calendar year used as the x axis
df['DATETIME'] = pd.to_datetime(df['DATETIME'])
df['Year'] = df['DATETIME'].dt.year

# Custom 6-colour discrete colormap
cmap = cm.colors.ListedColormap(['#9ABDDC', '#B4CF68', '#FFD872', '#FF96C5', '#FF00FF', 'purple'])

# Create a list of station IDs in the desired order
station_ids_in_order = list(station_order.keys())

BAR_WIDTH = 0.75  # bar footprint along x (years) and y (stations)
# NOTE(review): hand-tuned shift applied to every y tick position. It places each label
# 5.5 rows away from the bar it names; confirm this is intended (the per-season export
# cell further down uses no shift).
Y_TICK_OFFSET = 5.5

fig = plt.figure(figsize=(7, 7))
ax = fig.add_subplot(111, projection='3d')  # Create a single 3D subplot

# Keep readings strictly between 1 and 2000
filtered_df = df[(df['VALUE'] > 1) & (df['VALUE'] < 2000)]
# NOTE(review): aggfunc='max' keeps the largest single reading per station-year, while the
# title below says "Sum" (the seasonal cells use aggfunc='sum') — confirm which is intended.
pivot_table = filtered_df.pivot_table(index='STATION', columns='Year', values='VALUE', aggfunc='max').fillna(0)

# Reorder rows to station_order. reindex (unlike .loc) gives a station with no qualifying
# readings an all-zero row instead of raising KeyError.
pivot_table = pivot_table.reindex(station_ids_in_order, fill_value=0)

# Y-axis labels: only the stations listed in stations_to_label are named
Y_labels = [name if name.upper() in stations_to_label else "" for name in station_order.values()]

# Normalise to this table's range and map every cell through the custom colormap
norm = Normalize(vmin=pivot_table.values.min(), vmax=pivot_table.values.max())
colors = cmap(norm(pivot_table.values.ravel()))

Z = pivot_table.values
X_labels = pivot_table.columns

# Bar origins: grid rows are stations (y), grid columns are years (x).
# Row-major ravel order matches the order of `colors`.
X, Y = np.meshgrid(np.arange(len(X_labels)), np.arange(len(Y_labels)))

# Draw all bars in ONE bar3d call: far faster than one call (and one collection) per bar,
# and matplotlib depth-sorts all faces together instead of collection by collection.
if Z.size:
    ax.bar3d(X.ravel(), Y.ravel(), np.zeros(Z.size), BAR_WIDTH, BAR_WIDTH, Z.ravel(),
             shade=True, color=colors)

ax.set_title('Yearly Sum Volumes')
ax.set_zlabel('mm')
ax.set_yticks(np.arange(len(Y_labels)) + Y_TICK_OFFSET)
ax.set_yticklabels(Y_labels, fontsize=7)
ax.set_xticks(np.arange(len(X_labels)))
ax.set_xticklabels(X_labels, rotation=45)

plt.tight_layout(rect=[0, 0.05, 1, 0.95])
# fig.savefig("75gauge_3d_yearly_volumes.jpg", dpi=300)  # uncomment to export
plt.show()

In [None]:
#2d yearly 
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import Normalize
from matplotlib import cm

# Station ID (the CSV 'STATION' column) -> full display name. Insertion order is the row
# order of the heatmap (station_ids_in_order = list(station_order.keys())); imshow draws
# row 0 at the top. This is exactly the reverse of the order used by the 3D cells, which
# use short codes (TP, PA, ...) for nine of these stations.
station_order = {
    729: "MESSINA",
    727: "FIUMEDINISI",
    738: "SAN PIER NICETO",
    739: "TORREGROTTA",
    707: "RIPOSTO",
    764: "SIRACUSA",
    734: "NOVARA DI SICILIA",
    708: "LINGUAGLOSSA",
    769: "PACHINO",
    706: "CATANIA",
    767: "NOTO",
    713: "PEDARA",
    735: "PATTI",
    758: "ISPICA",
    715: "RANDAZZO",
    766: "LENTINI",
    765: "FRANCOFONTE",
    709: "MALETTO",
    770: "PALAZZOLO ACREIDE",
    712: "PATERNò",
    705: "BRONTE",
    733: "NASO",
    711: "MINEO",
    725: "CESARò VIGNAZZA",
    756: "RAGUSA",
    761: "SCICLI",
    730: "MILITELLO ROSMARINO",
    737: "SAN FRATELLO",
    757: "COMISO",
    716: "CALTAGIRONE",
    710: "MAZZARRONE",
    760: "SANTA CROCE CAMERINA",
    723: "CARONIA BUZZA",
    717: "AIDONE",
    721: "NICOSIA",
    762: "ACATE",
    722: "PIAZZA ARMERINA",
    731: "MISTRETTA",
    736: "PETTINEO",
    699: "MAZZARINO",
    752: "GANGI",
    718: "ENNA",
    743: "CASTELBUONO",
    701: "RIESI",
    695: "CALTANISSETTA",
    753: "PETRALIA SOTTANA",
    754: "POLIZZI GENEROSA",
    696: "DELIA",
    746: "LASCARI",
    690: "LICATA",
    703: "SCLAFANI BAGNI",
    700: "MUSSOMELI",
    689: "CANICATTì",
    740: "ALIA",
    687: "CAMMARATA",
    684: "AGRIGENTO MANDRASCAVA",
    685: "ARAGONA",
    755: "TERMINI IMERESE",
    683: "AGRIGENTO SCIBICA",
    747: "MEZZOJUSO",
    748: "MISILMERI",
    686: "BIVONA",
    750: "PALERMO",
    693: "RIBERA",
    745: "CORLEONE",
    749: "MONREALE VIGNA API",
    742: "CAMPOREALE",
    751: "PARTINICO",
    744: "Contessa Entellina",
    773: "CASTELLAMMARE DEL GOLFO",
    774: "CASTELVETRANO",
    778: "SALEMI",
    780: "TRAPANI FULGATORE",
    775: "ERICE",
    779: "TRAPANI FONTANASALSA"
}

# Read the CSV file into a DataFrame
df = pd.read_csv('75gauges.csv')

# Parse timestamps and derive the calendar year used as the x axis
df['DATETIME'] = pd.to_datetime(df['DATETIME'])
df['Year'] = df['DATETIME'].dt.year

# Custom 6-colour discrete colormap (same palette as the 3D plots)
cmap = cm.colors.ListedColormap(['#9ABDDC','#B4CF68','#FFD872', '#FF96C5','#FF00FF','purple'])

# Create a list of station IDs in the desired order
station_ids_in_order = list(station_order.keys())

# Create a single plot for yearly data per station
fig, ax = plt.subplots(figsize=(10, 6))

# Keep readings strictly between 1 and 2000
filtered_df = df[(df['VALUE'] > 1) & (df['VALUE'] < 2000)]

# NOTE(review): aggfunc='max' = largest single reading per station-year — confirm intent
pivot_table = filtered_df.pivot_table(index='STATION', columns='Year', values='VALUE', aggfunc='max').fillna(0)

# Reorder rows to station_order. reindex (unlike .loc) gives a station with no qualifying
# readings an all-zero row instead of raising KeyError.
pivot_table = pivot_table.reindex(station_ids_in_order, fill_value=0)

# Every station is labelled on this 2D plot (same order as the rows)
Y_labels = list(station_order.values())
X_labels = pivot_table.columns

# Colour scale spans this table's full value range
norm = Normalize(vmin=pivot_table.values.min(), vmax=pivot_table.values.max())

# Plot the heatmap (imshow draws row 0, the first station in station_order, at the top)
heatmap = ax.imshow(pivot_table.values, cmap=cmap, norm=norm)
ax.set_title('Yearly Data per Station')
ax.set_xlabel('Year')
ax.set_ylabel('Station')
ax.set_xticks(np.arange(len(X_labels)))
ax.set_xticklabels(X_labels, rotation=90, fontsize=5)
ax.set_yticks(np.arange(len(Y_labels)))
ax.set_yticklabels(Y_labels, rotation=0, fontsize=5)

# Add colorbar
cbar = fig.colorbar(heatmap, ax=ax)
cbar.set_label('mm')

plt.tight_layout(rect=[0, 0.05, 1, 0.95])
# fig.savefig("75gauge_3d_yearlyvolumessheatmap.jpg", dpi=300)  # uncomment to export

plt.show()

In [None]:
#3d seasonal

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import Normalize
from matplotlib import cm

# Station ID (the CSV 'STATION' column) -> display name. Insertion order is the plotting
# order along the y axis of every seasonal panel (station_ids_in_order = list(station_order.keys())).
# Nine stations carry short codes (TP, PA, AG, CL, EN, RG, CT, SR, ME) so that they match
# stations_to_label and get a tick label; all others are drawn unlabelled.
# The 2D heatmap cells use the same IDs with full names, in exactly the reverse order.
# NOTE(review): the order looks geographic (Trapani first, Messina last) — verify against
# merged_COO.csv before relying on it.
station_order = {
    779: "TP",
    775: "ERICE",
    780: "TRAPANI FULGATORE",
    778: "SALEMI",
    774: "CASTELVETRANO",
    773: "CASTELLAMMARE DEL GOLFO",
    744: "Contessa Entellina",
    751: "PARTINICO",
    742: "CAMPOREALE",
    749: "MONREALE VIGNA API",
    745: "CORLEONE",
    693: "RIBERA",
    750: "PA",
    686: "BIVONA",
    748: "MISILMERI",
    747: "MEZZOJUSO",
    683: "AG",
    755: "TERMINI IMERESE",
    685: "ARAGONA",
    684: "AGRIGENTO MANDRASCAVA",
    687: "CAMMARATA",
    740: "ALIA",
    689: "CANICATTì",
    700: "MUSSOMELI",
    703: "SCLAFANI BAGNI",
    690: "LICATA",
    746: "LASCARI",
    696: "DELIA",
    754: "POLIZZI GENEROSA",
    753: "PETRALIA SOTTANA",
    695: "CL",
    701: "RIESI",
    743: "CASTELBUONO",
    718: "EN",
    752: "GANGI",
    699: "MAZZARINO",
    736: "PETTINEO",
    731: "MISTRETTA",
    722: "PIAZZA ARMERINA",
    762: "ACATE",
    721: "NICOSIA",
    717: "AIDONE",
    723: "CARONIA BUZZA",
    760: "SANTA CROCE CAMERINA",
    710: "MAZZARRONE",
    716: "CALTAGIRONE",
    757: "COMISO",
    737: "SAN FRATELLO",
    730: "MILITELLO ROSMARINO",
    761: "SCICLI",
    756: "RG",
    725: "CESARò VIGNAZZA",
    711: "MINEO",
    733: "NASO",
    705: "BRONTE",
    712: "PATERNò",
    770: "PALAZZOLO ACREIDE",
    709: "MALETTO",
    765: "FRANCOFONTE",
    766: "LENTINI",
    715: "RANDAZZO",
    758: "ISPICA",
    735: "PATTI",
    713: "PEDARA",
    767: "NOTO",
    706: "CT",
    769: "PACHINO",
    708: "LINGUAGLOSSA",
    734: "NOVARA DI SICILIA",
    764: "SR",
    707: "RIPOSTO",
    739: "TORREGROTTA",
    738: "SAN PIER NICETO",
    727: "FIUMEDINISI",
    729: "ME"
}

# Stations whose (upper-cased) name gets a visible y-axis tick label; all others stay blank
stations_to_label = ["CT", "PA", "ME", "AG", "EN", "CL", "TP", "SR", "RG"]

# Read the CSV file into a DataFrame
df = pd.read_csv('75gauges.csv')

# Convert the 'DATETIME' column to a datetime type
df['DATETIME'] = pd.to_datetime(df['DATETIME'])

def map_month_to_season(month):
    """Map a month number to a season name.

    3-5 -> 'Spring', 6-8 -> 'Summer', 9-11 -> 'Autumn'; any other value
    (12, 1, 2) -> 'Winter'.
    """
    if month in [3, 4, 5]:
        return 'Spring'
    elif month in [6, 7, 8]:
        return 'Summer'
    elif month in [9, 10, 11]:
        return 'Autumn'
    else:
        return 'Winter'

df['Year'] = df['DATETIME'].dt.year
df['Month'] = df['DATETIME'].dt.month
# NOTE(review): 'Winter' of a given Year pools December with the January/February of the
# SAME calendar year (not the preceding December) — confirm this is intended.
df['Season'] = df['Month'].apply(map_month_to_season)

# Custom 6-colour discrete colormap
cmap = cm.colors.ListedColormap(['#9ABDDC','#B4CF68','#FFD872', '#FF96C5','#FF00FF','purple'])

# Create a list of station IDs in the desired order
station_ids_in_order = list(station_order.keys())

# Fixed panel order (df['Season'].unique() would follow CSV row order), restricted to the
# seasons actually present in the data.
SEASON_ORDER = ['Winter', 'Spring', 'Summer', 'Autumn']
present_seasons = set(df['Season'])
unique_seasons = [season for season in SEASON_ORDER if season in present_seasons]

BAR_WIDTH = 0.75  # bar footprint along x (years) and y (stations)
# NOTE(review): hand-tuned shift applied to every y tick position; it places each label
# 7.5 rows away from its bar — confirm this is intended.
Y_TICK_OFFSET = 7.5

# Y-axis labels are identical for every season, so build them once
Y_labels = [name if name.upper() in stations_to_label else "" for name in station_order.values()]

fig = plt.figure(figsize=(22, 22))

for i, season in enumerate(unique_seasons, 1):
    ax = fig.add_subplot(2, 2, i, projection='3d')

    # Seasonal total per station-year of readings strictly between 1 and 2000
    filtered_df = df[(df['Season'] == season) & (df['VALUE'] > 1) & (df['VALUE'] < 2000)]
    pivot_table = filtered_df.pivot_table(index='STATION', columns='Year', values='VALUE', aggfunc='sum').fillna(0)

    # reindex (unlike .loc) gives a station with no qualifying readings this season an
    # all-zero row instead of raising KeyError
    pivot_table = pivot_table.reindex(station_ids_in_order, fill_value=0)

    # Per-season colour normalisation through the custom colormap
    norm = Normalize(vmin=pivot_table.values.min(), vmax=pivot_table.values.max())
    colors = cmap(norm(pivot_table.values.ravel()))

    Z = pivot_table.values
    X_labels = pivot_table.columns

    # Bar origins: grid rows are stations (y), columns are years (x); ravel order matches `colors`
    X, Y = np.meshgrid(np.arange(len(X_labels)), np.arange(len(Y_labels)))

    # One bar3d call per panel: much faster than one per bar, and faces are depth-sorted together
    if Z.size:
        ax.bar3d(X.ravel(), Y.ravel(), np.zeros(Z.size), BAR_WIDTH, BAR_WIDTH, Z.ravel(),
                 shade=True, color=colors)

    ax.set_title(f'Volumes {season}')
    ax.set_zlabel('mm')
    ax.set_zlim(0, 1000)
    ax.set_yticks(np.arange(len(Y_labels)) + Y_TICK_OFFSET)
    ax.set_yticklabels(Y_labels)
    ax.set_xticks(np.arange(len(X_labels)))
    ax.set_xticklabels(X_labels, rotation=45)

plt.tight_layout(rect=[0, 0.05, 1, 0.95])
# fig.savefig("75gauge_3d_volumes.jpg", dpi=300)  # uncomment to export
plt.show()


In [None]:
# Export one 3D bar chart per season to its own JPEG (figures are closed, not displayed).
# NOTE(review): this cell loads no data and imports nothing; it relies on the '#3d seasonal'
# cell above having run (df with its 'Season' column, station_order with short-code names,
# stations_to_label, cm, np, plt, Normalize). If it runs after the '#2d seasonal' cell
# instead, station_order holds full names and every y label silently comes out blank.
assert 'Season' in df.columns, "run the '3d seasonal' cell first: df has no 'Season' column"

# Define a colormap
cmap = cm.colors.ListedColormap(['#9ABDDC', '#B4CF68', '#FFD872', '#FF96C5', '#FF00FF', 'purple'])

# Create a list of station IDs in the desired order
station_ids_in_order = list(station_order.keys())

unique_seasons = df['Season'].unique()

BAR_WIDTH = 0.75  # bar footprint along x (years) and y (stations)

# Y-axis labels are identical for every season, so build them once
Y_labels = [name if name.upper() in stations_to_label else "" for name in station_order.values()]

# Loop through each season and save the plot separately
for season in unique_seasons:
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')

    # Seasonal total per station-year of readings strictly between 1 and 2000
    filtered_df = df[(df['Season'] == season) & (df['VALUE'] > 1) & (df['VALUE'] < 2000)]
    pivot_table = filtered_df.pivot_table(index='STATION', columns='Year', values='VALUE', aggfunc='sum').fillna(0)

    # reindex (unlike .loc) gives a station with no qualifying readings this season an
    # all-zero row instead of raising KeyError
    pivot_table = pivot_table.reindex(station_ids_in_order, fill_value=0)

    # Per-season colour normalisation through the custom colormap
    norm = Normalize(vmin=pivot_table.values.min(), vmax=pivot_table.values.max())
    colors = cmap(norm(pivot_table.values.ravel()))

    Z = pivot_table.values
    X_labels = pivot_table.columns
    # Bar origins: grid rows are stations (y), columns are years (x); ravel order matches `colors`
    X, Y = np.meshgrid(np.arange(len(X_labels)), np.arange(len(Y_labels)))

    # One bar3d call per figure: much faster than one per bar, and faces are depth-sorted together
    if Z.size:
        ax.bar3d(X.ravel(), Y.ravel(), np.zeros(Z.size), BAR_WIDTH, BAR_WIDTH, Z.ravel(),
                 shade=True, color=colors)

    ax.set_title(f'Volumes {season}', fontsize=20)
    ax.set_zlabel('mm')
    ax.set_zlim(0, 1000)
    ax.set_yticks(np.arange(len(Y_labels)))
    ax.set_yticklabels(Y_labels)
    ax.set_xticks(np.arange(len(X_labels)))
    ax.set_xticklabels(X_labels, rotation=45)

    # Save each plot as a separate file
    fig.savefig(f"75gauge_3d_volumes_{season}.jpg", dpi=500)
    plt.close(fig)  # free the figure; many open figures would accumulate memory


In [None]:
#2d seasonal

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import Normalize
from matplotlib import cm

# Station ID (the CSV 'STATION' column) -> full display name. Insertion order is the row
# order of every seasonal heatmap (station_ids_in_order = list(station_order.keys()));
# imshow draws row 0 at the top. This is exactly the reverse of the order used by the 3D
# cells, which use short codes (TP, PA, ...) for nine of these stations.
station_order = {
    729: "MESSINA",
    727: "FIUMEDINISI",
    738: "SAN PIER NICETO",
    739: "TORREGROTTA",
    707: "RIPOSTO",
    764: "SIRACUSA",
    734: "NOVARA DI SICILIA",
    708: "LINGUAGLOSSA",
    769: "PACHINO",
    706: "CATANIA",
    767: "NOTO",
    713: "PEDARA",
    735: "PATTI",
    758: "ISPICA",
    715: "RANDAZZO",
    766: "LENTINI",
    765: "FRANCOFONTE",
    709: "MALETTO",
    770: "PALAZZOLO ACREIDE",
    712: "PATERNò",
    705: "BRONTE",
    733: "NASO",
    711: "MINEO",
    725: "CESARò VIGNAZZA",
    756: "RAGUSA",
    761: "SCICLI",
    730: "MILITELLO ROSMARINO",
    737: "SAN FRATELLO",
    757: "COMISO",
    716: "CALTAGIRONE",
    710: "MAZZARRONE",
    760: "SANTA CROCE CAMERINA",
    723: "CARONIA BUZZA",
    717: "AIDONE",
    721: "NICOSIA",
    762: "ACATE",
    722: "PIAZZA ARMERINA",
    731: "MISTRETTA",
    736: "PETTINEO",
    699: "MAZZARINO",
    752: "GANGI",
    718: "ENNA",
    743: "CASTELBUONO",
    701: "RIESI",
    695: "CALTANISSETTA",
    753: "PETRALIA SOTTANA",
    754: "POLIZZI GENEROSA",
    696: "DELIA",
    746: "LASCARI",
    690: "LICATA",
    703: "SCLAFANI BAGNI",
    700: "MUSSOMELI",
    689: "CANICATTì",
    740: "ALIA",
    687: "CAMMARATA",
    684: "AGRIGENTO MANDRASCAVA",
    685: "ARAGONA",
    755: "TERMINI IMERESE",
    683: "AGRIGENTO SCIBICA",
    747: "MEZZOJUSO",
    748: "MISILMERI",
    686: "BIVONA",
    750: "PALERMO",
    693: "RIBERA",
    745: "CORLEONE",
    749: "MONREALE VIGNA API",
    742: "CAMPOREALE",
    751: "PARTINICO",
    744: "Contessa Entellina",
    773: "CASTELLAMMARE DEL GOLFO",
    774: "CASTELVETRANO",
    778: "SALEMI",
    780: "TRAPANI FULGATORE",
    775: "ERICE",
    779: "TRAPANI FONTANASALSA"
}

# Read the CSV file into a DataFrame
df = pd.read_csv('75gauges.csv')

# Convert the 'DATETIME' column to a datetime type
df['DATETIME'] = pd.to_datetime(df['DATETIME'])

def map_month_to_season(month):
    """Map a month number to a season name.

    3-5 -> 'Spring', 6-8 -> 'Summer', 9-11 -> 'Autumn'; any other value
    (12, 1, 2) -> 'Winter'.
    """
    if month in [3, 4, 5]:
        return 'Spring'
    elif month in [6, 7, 8]:
        return 'Summer'
    elif month in [9, 10, 11]:
        return 'Autumn'
    else:
        return 'Winter'

df['Year'] = df['DATETIME'].dt.year
df['Month'] = df['DATETIME'].dt.month
# NOTE(review): 'Winter' of a given Year pools December with the January/February of the
# SAME calendar year — confirm this is intended.
df['Season'] = df['Month'].apply(map_month_to_season)

# Custom 6-colour discrete colormap (same palette as the 3D plots)
cmap = cm.colors.ListedColormap(['#9ABDDC','#B4CF68','#FFD872', '#FF96C5','#FF00FF','purple'])

# Create a list of station IDs in the desired order
station_ids_in_order = list(station_order.keys())

# Fixed panel order (df['Season'].unique() would follow CSV row order), restricted to the
# seasons actually present in the data.
SEASON_ORDER = ['Winter', 'Spring', 'Summer', 'Autumn']
present_seasons = set(df['Season'])
unique_seasons = [season for season in SEASON_ORDER if season in present_seasons]

# Two panels per row
num_seasons = len(unique_seasons)
num_cols = 2
num_rows = int(np.ceil(num_seasons / num_cols))

# squeeze=False keeps `axes` 2-D even with a single row (<= 2 seasons), so that
# axes[row, col] indexing never raises IndexError.
fig, axes = plt.subplots(num_rows, num_cols, figsize=(22, 22), squeeze=False)

# Every station is labelled on the 2D plots (same order as the rows)
Y_labels = list(station_order.values())

for i, season in enumerate(unique_seasons):
    ax = axes[i // num_cols, i % num_cols]

    # Seasonal total per station-year of readings strictly between 1 and 2000
    filtered_df = df[(df['Season'] == season) & (df['VALUE'] > 1) & (df['VALUE'] < 2000)]
    pivot_table = filtered_df.pivot_table(index='STATION', columns='Year', values='VALUE', aggfunc='sum').fillna(0)

    # reindex (unlike .loc) gives a station with no qualifying readings this season an
    # all-zero row instead of raising KeyError
    pivot_table = pivot_table.reindex(station_ids_in_order, fill_value=0)

    # Per-season colour normalisation
    norm = Normalize(vmin=pivot_table.values.min(), vmax=pivot_table.values.max())
    X_labels = pivot_table.columns

    # Plot the heatmap (row 0 = first station in station_order, drawn at the top)
    heatmap = ax.imshow(pivot_table.values, cmap=cmap, norm=norm)
    ax.set_title(f'Volumes {season}')
    ax.set_xlabel('Year')
    ax.set_ylabel('Station')
    ax.set_xticks(np.arange(len(X_labels)))
    ax.set_xticklabels(X_labels, rotation=90)
    ax.set_yticks(np.arange(len(Y_labels)))
    ax.set_yticklabels(Y_labels, rotation=0)

    # Add colorbar
    cbar = fig.colorbar(heatmap, ax=ax)
    cbar.set_label('mm')

# Hide grid cells left empty when the number of seasons is odd
for unused_ax in axes.flat[num_seasons:]:
    unused_ax.set_visible(False)

plt.tight_layout(rect=[0, 0.05, 1, 0.95])
fig.savefig("75gauge_3d_volumesheatmap.jpg", dpi=300)

plt.show()


NETWORK

In [None]:
#all networks grid visualization
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from statsmodels.tsa.stattools import grangercausalitytests

# Gauge readings (the '12H' file; plot titles below say 12h) and station coordinates,
# joined on STATION so every reading carries its station's coordinate columns.
gauges_data = pd.read_csv('75gauges12H.csv')
coordinates_data = pd.read_csv('merged_COO.csv')
merged_data = gauges_data.merge(coordinates_data, on='STATION')
merged_data['DATETIME'] = pd.to_datetime(merged_data['DATETIME'])

# The nine network nodes (the same IDs as the grid's station_order further down)
stations_of_interest = [779, 750, 684, 695, 718, 756, 706, 764, 729]
merged_data_filtered = merged_data.loc[merged_data['STATION'].isin(stations_of_interest)]

# Granger causality test function
def granger_test(dataframe, column1, column2, max_lag=1, alpha=0.01):
    """Lag-1 Granger-causality F test between two columns of ``dataframe``.

    Follows the ``grangercausalitytests`` convention: the null hypothesis is that
    ``column2`` (the SECOND column) does NOT Granger-cause ``column1``. So
    ``granger_test(df, a, b)`` measures the directed link b -> a.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Holds both series as columns, rows in time order.
    column1 : hashable
        Label of the effect series.
    column2 : hashable
        Label of the candidate-cause series.
    max_lag : int, default 1
        Passed as ``maxlag``; only the lag-1 result is read.
    alpha : float, default 0.01
        Significance level for the SSR-based F test.

    Returns
    -------
    float
        The lag-1 F statistic when its p-value is below ``alpha``, otherwise 0.
        Also 0 when the test raises (the error is printed, not re-raised).
    """
    try:
        gc_test = grangercausalitytests(dataframe[[column1, column2]], maxlag=max_lag)
        f_statistic = gc_test[1][0]['ssr_ftest'][0]  # F-statistic
        p_value = gc_test[1][0]['ssr_ftest'][1]  # p-value
        return f_statistic if p_value < alpha else 0
    except Exception as e:
        # Best effort: any failure is reported and treated as "no significant link"
        print(f"Error in granger_test: {e}")
        return 0

# Tiny positive stand-in for zero readings so that log() stays finite
small_quantity = 0.000001


def calculate_log_returns(group):
    """Return log(x_t) - log(x_{t-1}) for one station's series (first value is NaN)."""
    return np.log(group) - np.log(group.shift(1))


# Maxima across ALL years, so every yearly grid shares one size/colour scale
global_max_out_links = 0
global_max_in_links = 0
global_max_f_statistic_out = 0
global_max_f_statistic_in = 0
# (station, year) -> (sum of significant F statistics, number of significant links)
station_year_data_out = {}
station_year_data_in = {}

analysis_years = merged_data_filtered['DATETIME'].dt.year
for year in range(analysis_years.min(), analysis_years.max() + 1):
    # NOTE(review): months 12, 1 and 2 of the SAME calendar year are pooled, so December
    # follows the January/February that precede it — confirm this is intended.
    in_window = ((merged_data_filtered['DATETIME'].dt.year == year) &
                 (merged_data_filtered['DATETIME'].dt.month.isin([12, 1, 2])))
    # Explicit copy: assigning columns on a boolean-mask slice triggers
    # SettingWithCopyWarning and is not reliable across pandas versions.
    data_year = merged_data_filtered[in_window].copy()
    data_year['VALUE'] = data_year['VALUE'].replace(0, small_quantity)

    data_year['LOG_RETURN'] = data_year.groupby('STATION')['VALUE'].transform(calculate_log_returns)
    # Drops each station's first row (NaN log return) and any row with a missing field
    data_year = data_year.dropna().reset_index(drop=True)

    # One column of log returns per station, aligned on timestamp
    pivot_data = data_year.pivot(index='DATETIME', columns='STATION', values='LOG_RETURN').fillna(0)

    temp_out_links_count = {station: 0 for station in pivot_data.columns}
    temp_in_links_count = {station: 0 for station in pivot_data.columns}
    temp_f_statistics_sum_out = {station: 0 for station in pivot_data.columns}
    temp_f_statistics_sum_in = {station: 0 for station in pivot_data.columns}

    # Analysis for each ordered station pair
    for station1 in pivot_data.columns:
        for station2 in pivot_data.columns:
            if station1 == station2:
                continue
            # granger_test(df, a, b) asks whether b Granger-causes a (second column =
            # cause), so the arguments are ordered to match the link each name describes.
            f_statistic_out = granger_test(pivot_data, station2, station1)  # station1 -> station2
            f_statistic_in = granger_test(pivot_data, station1, station2)   # station2 -> station1

            if f_statistic_out > 0:
                temp_out_links_count[station1] += 1
                temp_f_statistics_sum_out[station1] += f_statistic_out
                station_year_data_out[(station1, year)] = (temp_f_statistics_sum_out[station1], temp_out_links_count[station1])

            if f_statistic_in > 0:
                temp_in_links_count[station1] += 1
                temp_f_statistics_sum_in[station1] += f_statistic_in
                station_year_data_in[(station1, year)] = (temp_f_statistics_sum_in[station1], temp_in_links_count[station1])

    # default=0: a year without any station data must not raise ValueError on max()
    global_max_out_links = max(global_max_out_links, max(temp_out_links_count.values(), default=0))
    global_max_in_links = max(global_max_in_links, max(temp_in_links_count.values(), default=0))
    global_max_f_statistic_out = max(global_max_f_statistic_out, max(temp_f_statistics_sum_out.values(), default=0))
    global_max_f_statistic_in = max(global_max_f_statistic_in, max(temp_f_statistics_sum_in.values(), default=0))

# Station ID -> short code for the nine network nodes. Insertion order is the grid's row
# order from the bottom up (create_visualization uses each key's position as its y value).
# The original tag read "ESTOVEST" (Italian for east-west); NOTE(review): confirm the
# intended geographic direction of this ordering.
station_order = dict(zip(
    (779, 750, 684, 695, 718, 756, 706, 764, 729),
    ("TP", "PA", "AG", "CL", "EN", "RG", "CT", "SR", "ME"),
))

def create_visualization(station_year_data, global_max_links, global_max_f_statistic, title, filename, fontsize=12):
    """Draw a station x year grid of circles, save it to ``filename`` and show it.

    Each ``(station, year) -> (f_stat_sum, link_count)`` entry becomes a circle at
    (year, row of ``station``). Its radius is ``sqrt(0.2 * f_stat_sum / global_max_f_statistic)``,
    and its coolwarm colour encodes ``link_count`` on a 0..``global_max_links`` scale.

    Reads module-level ``station_order`` (row order and tick labels) and
    ``merged_data_filtered`` (year range of the x axis). ``fontsize`` applies to the title.
    """
    fig, ax = plt.subplots(figsize=(15, 10))
    ax.set_facecolor('white')  # Set background to white

    # Define scales for size and color
    max_node_size = 0.2  # area-like scale; the radius is its square root
    color_norm = plt.Normalize(0, global_max_links)  # Normalize link count
    color_map = plt.cm.coolwarm

    # Row of each station, computed once instead of list(...).index() per circle
    row_of_station = {station: row for row, station in enumerate(station_order)}

    # Draw each station-year as a circle on the grid
    for (station, year), (f_stat, links) in station_year_data.items():
        size = (f_stat / global_max_f_statistic) * max_node_size  # Scale size based on f_stat
        color = color_map(color_norm(links))  # Colour based on number of links
        circle = Circle((year, row_of_station[station]), np.sqrt(size), color=color, alpha=0.6)
        ax.add_patch(circle)

    # Add color bar
    sm = plt.cm.ScalarMappable(cmap=color_map, norm=color_norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, orientation='vertical', pad=0.02)
    cbar.set_label('Number of Links', fontsize=12)
    cbar.set_ticks(range(0, global_max_links + 1))  # Integer ticks on the color bar

    # Axis labels, ticks and limits; the y ticks show station names, not IDs
    years = merged_data_filtered['DATETIME'].dt.year
    first_year, last_year = years.min(), years.max()
    ax.set_xlabel('Year')
    ax.set_ylabel('Station')
    ax.set_xticks(np.arange(first_year, last_year + 1))
    ax.set_yticks(np.arange(len(station_order)))
    ax.set_yticklabels(list(station_order.values()))
    ax.set_xlim(first_year - 1, last_year + 1)
    ax.set_ylim(-1, len(station_order))

    ax.grid(False)
    ax.set_title(title, fontsize=fontsize)
    fig.savefig(filename, dpi=300)
    plt.show()

# One grid per link direction; each uses the maxima computed across all years, so circle
# sizes and colours are comparable between years.
create_visualization(
    station_year_data_out,
    global_max_out_links,
    global_max_f_statistic_out,
    title='Winter Outlinks - 12h',
    filename="1.12Hgrid_visualization9gauges.jpg",
    fontsize=14,
)
create_visualization(
    station_year_data_in,
    global_max_in_links,
    global_max_f_statistic_in,
    title='Winter Inlinks - 12h',
    filename="1.12Hingrid_visualization9gauges.jpg",
    fontsize=14,
)


In [None]:
# network string visualization

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.tsa.stattools import grangercausalitytests

# Gauge readings at the '24H' resolution. Other inputs used in earlier runs:
# combined_data_9gauges.csv, 75gauges.csv, 75gauges6H.csv, 75gauges12H.csv.
gauges_data = pd.read_csv('75gauges24H.csv')
coordinates_data = pd.read_csv('merged_COO.csv')

# Attach each station's coordinates to its readings, then parse timestamps
merged_data = pd.merge(left=gauges_data, right=coordinates_data, on='STATION')
merged_data['DATETIME'] = pd.to_datetime(merged_data['DATETIME'])

# Keep only the nine network nodes
stations_of_interest = [779, 750, 684, 695, 718, 756, 706, 764, 729]
merged_data_filtered = merged_data.loc[merged_data['STATION'].isin(stations_of_interest)]

# Granger causality test function
def granger_test(dataframe, column1, column2, max_lag=1, alpha=0.01):
    """Lag-1 Granger-causality F test between two columns of ``dataframe``.

    Follows the ``grangercausalitytests`` convention: the null hypothesis is that
    ``column2`` (the SECOND column) does NOT Granger-cause ``column1``. So
    ``granger_test(df, a, b)`` measures the directed link b -> a.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Holds both series as columns, rows in time order.
    column1 : hashable
        Label of the effect series.
    column2 : hashable
        Label of the candidate-cause series.
    max_lag : int, default 1
        Passed as ``maxlag``; only the lag-1 result is read.
    alpha : float, default 0.01
        Significance level for the SSR-based F test.

    Returns
    -------
    float
        The lag-1 F statistic when its p-value is below ``alpha``, otherwise 0.
        Also 0 when the test raises (the error is printed, not re-raised).
    """
    try:
        gc_test = grangercausalitytests(dataframe[[column1, column2]], maxlag=max_lag)
        f_statistic = gc_test[1][0]['ssr_ftest'][0]  # F-statistic
        p_value = gc_test[1][0]['ssr_ftest'][1]  # p-value
        return f_statistic if p_value < alpha else 0
    except Exception as e:
        # Best effort: any failure is reported and treated as "no significant link"
        print(f"Error in granger_test: {e}")
        return 0

# Function to calculate log returns
def calculate_log_returns(group):
    """Differences of log values between consecutive entries of ``group``.

    Element t is log(group[t]) - log(group[t-1]); the first element is NaN.
    The index of ``group`` is preserved, as ``groupby(...).transform`` expects.
    """
    return np.log(group).diff()

# Season analysed below. The plot title is built from SEASON_NAME so that it cannot drift
# from the month filter (months 6-8 are 'Summer' per map_month_to_season in this notebook).
SEASON_NAME = 'Summer'
SEASON_MONTHS = [6, 7, 8]
small_quantity = 0.000001  # stand-in for zero readings so log() stays finite

# Easting (EST) per station, looked up once instead of filtering coordinates_data per link.
# drop_duplicates keeps each station's first row, like the earlier `.values[0]` lookup.
station_easting = coordinates_data.drop_duplicates('STATION').set_index('STATION')['EST']

# year -> number of significant links pointing east / west
eastward_links = {}
westward_links = {}
# Years in which at least two stations had data, i.e. pairwise tests actually ran
analysed_years = []

analysis_years = merged_data_filtered['DATETIME'].dt.year
for year in range(analysis_years.min(), analysis_years.max() + 1):
    in_window = ((merged_data_filtered['DATETIME'].dt.year == year) &
                 (merged_data_filtered['DATETIME'].dt.month.isin(SEASON_MONTHS)))
    # Explicit copy: column assignment on a boolean-mask slice triggers SettingWithCopyWarning
    data_year = merged_data_filtered[in_window].copy()
    data_year['VALUE'] = data_year['VALUE'].replace(0, small_quantity)

    data_year['LOG_RETURN'] = data_year.groupby('STATION')['VALUE'].transform(calculate_log_returns)
    # Drops each station's first row (NaN log return) and any row with a missing field
    data_year = data_year.dropna().reset_index(drop=True)

    # One column of log returns per station, aligned on timestamp
    pivot_data = data_year.pivot(index='DATETIME', columns='STATION', values='LOG_RETURN').fillna(0)
    if len(pivot_data.columns) >= 2:
        analysed_years.append(year)

    # (source, target) -> significant F statistic for this year
    links = {}

    for station1 in pivot_data.columns:
        for station2 in pivot_data.columns:
            if station1 == station2:
                continue
            # granger_test(df, a, b) asks whether b Granger-causes a (second column =
            # cause), so the directed link station1 -> station2 is tested in this order.
            f_statistic = granger_test(pivot_data, station2, station1)
            if f_statistic <= 0:
                continue
            links[(station1, station2)] = f_statistic

            # A link whose target has a larger easting points east
            if station_easting.loc[station1] < station_easting.loc[station2]:
                eastward_links[year] = eastward_links.get(year, 0) + 1
            else:
                westward_links[year] = westward_links.get(year, 0) + 1

# Plot every analysed year; a year with no links in one direction shows 0 instead of being
# silently skipped (which would draw a line straight across it).
years = analysed_years
number_of_westward_links = [westward_links.get(year, 0) for year in years]
number_of_eastward_links = [eastward_links.get(year, 0) for year in years]

fig, ax = plt.subplots(figsize=(10, 6))

# Westward links solid red, eastward links dashed blue
ax.plot(years, number_of_westward_links, marker='o', linestyle='-', color='r', label='Westward Links')
ax.plot(years, number_of_eastward_links, marker='o', linestyle='--', color='b', label='Eastward Links')

ax.set_title(f'{SEASON_NAME} no of Granger Causality Links over Years (East vs West) - 24h', fontsize=16)
ax.set_xlabel('Year', fontsize=14)
ax.set_ylabel('Number of Links', fontsize=14)

ax.grid(True)
ax.legend(loc='upper right')
ax.tick_params(axis='x', labelrotation=45)
fig.tight_layout()
# fig.savefig('4.24Heast_west_links_over_years.jpg', dpi=300)  # uncomment to export
plt.show()


In [None]:
# pairwise string visualization
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.tsa.stattools import grangercausalitytests

# Inputs: gauge readings from '75gauges.csv' plus station coordinates
gauges_data = pd.read_csv('75gauges.csv')
coordinates_data = pd.read_csv('merged_COO.csv')

# Join readings to coordinates by station, then parse timestamps
merged_data = gauges_data.merge(coordinates_data, on='STATION')
merged_data['DATETIME'] = pd.to_datetime(merged_data['DATETIME'])

# Network nodes: only these nine station IDs are analysed
stations_of_interest = [779, 750, 684, 695, 718, 756, 706, 764, 729]
merged_data_filtered = merged_data.loc[merged_data['STATION'].isin(stations_of_interest)]

# Granger causality test function
def granger_test(dataframe, column1, column2, max_lag=1, alpha=0.01):
    """Return the F-statistic for "``column1`` Granger-causes ``column2``", or 0.

    statsmodels' ``grangercausalitytests`` tests whether the SECOND column of its
    input Granger-causes the FIRST. The columns are therefore passed as
    ``[column2, column1]``. Callers in this notebook read the result as
    ``column1 -> column2``: arrows are drawn from ``column1`` to ``column2``, and the
    750/779 pair is labelled 'PA-TP'.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Holds both series as columns (rows presumably in time order — TODO confirm).
    column1, column2 : hashable
        Candidate cause and effect column labels.
    max_lag : int, default 1
        Forwarded to ``grangercausalitytests`` as ``maxlag``. Only the lag-1 result is read.
    alpha : float, default 0.01
        Significance level. The lag-1 SSR F-test p-value must be strictly below it.

    Returns
    -------
    float
        The lag-1 SSR F-statistic when ``p < alpha``, otherwise 0. Any exception
        (e.g. a missing column) is printed and 0 is returned.
    """
    try:
        # Effect column first, cause column second (statsmodels convention).
        gc_test = grangercausalitytests(dataframe[[column2, column1]], maxlag=max_lag)
        f_statistic, p_value = gc_test[1][0]['ssr_ftest'][:2]
        return f_statistic if p_value < alpha else 0
    except Exception as e:
        # Best-effort: a failing pair is reported and treated as "no link".
        print(f"Error in granger_test: {e}")
        return 0

# Function to calculate log returns
def calculate_log_returns(group):
    """Return the step-to-step log change of ``group``.

    Element ``t`` is ``log(group[t]) - log(group[t-1])``. The first element has no
    predecessor and is NaN. It is called through ``groupby(...).transform``, so it
    sees one station's values at a time.
    """
    return np.log(group).diff()

# Function to create the plot with nodes and directional arrows
def create_network_plot(year, links, coordinates, title, filename):
    """Draw the station network for one period, save it to ``filename`` and show it.

    Parameters
    ----------
    year : int
        Not used by the drawing; kept for the existing call signature.
    links : dict
        ``{(station1, station2): f_statistic}``. One arrow is drawn per key, pointing
        from ``station1`` to ``station2``. The F-statistic value itself is not drawn.
    coordinates : pandas.DataFrame
        Needs ``STATION``, ``EST``, ``NORD`` and ``LOCATION`` columns.
    title, filename : str
        Figure title and output path given to ``savefig``.

    Only stations in the module-level ``stations_of_interest`` are drawn as nodes.
    """
    def station_xy(station):
        # First (EST, NORD) row matching this station ID.
        match = coordinates[coordinates['STATION'] == station]
        return match[['EST', 'NORD']].values[0]

    fig, ax = plt.subplots(figsize=(12, 8))

    # Nodes: tiny blue dot with the location name placed to its left.
    nodes = coordinates[coordinates['STATION'].isin(stations_of_interest)]
    for east, north, name in zip(nodes['EST'], nodes['NORD'], nodes['LOCATION']):
        ax.plot(east, north, 'bo', markersize=2)
        ax.text(east, north, name, fontsize=12, ha='right')

    # Directed edges: arrow tail at the first station of the key, head at the second.
    for source, target in links:
        tail = station_xy(source)
        head = station_xy(target)
        ax.annotate("",
                    xy=(head[0], head[1]), xycoords='data',
                    xytext=(tail[0], tail[1]), textcoords='data',
                    arrowprops=dict(arrowstyle="->", color='blue', lw=1))

    ax.set_title(title)
    ax.set(xlabel='', ylabel='', xticks=[], yticks=[])
    ax.axis('off')
    plt.savefig(filename)
    plt.show()

# Per-year F-statistics for the PA/TP station pair (750 = "PA", 779 = "TP" in station_order):
#   f_stats_pa_ct[year] = granger_test(pivot, 750, 779)
#   f_stats_ct_pa[year] = granger_test(pivot, 779, 750)
# The "ct" in the variable names is historical; station 779 is TP, not CT.
f_stats_pa_ct = {}
f_stats_ct_pa = {}

# Months analysed in this cell: December, January, February.
# NOTE(review): rows are selected by calendar year, so December of `year` is grouped with
# January/February of the SAME year, not the following one. The log returns therefore span
# the Feb -> Dec gap. Confirm this is intended.
winter_months = [12, 1, 2]
# Zero values are replaced by this before taking logs, so the log returns stay finite.
small_quantity = 0.000001

datetime_col = merged_data_filtered['DATETIME'].dt
for year in range(datetime_col.year.min(), datetime_col.year.max() + 1):
    # .copy(): work on an independent frame, so the column assignments below do not act on
    # a slice of merged_data_filtered (avoids pandas' SettingWithCopyWarning).
    data_year = merged_data_filtered[(datetime_col.year == year) &
                                     (datetime_col.month.isin(winter_months))].copy()

    data_year['VALUE'] = data_year['VALUE'].replace(0, small_quantity)

    # Per-station log returns. Each station's first row has no predecessor and becomes NaN,
    # which the dropna below removes.
    data_year['LOG_RETURN'] = data_year.groupby('STATION')['VALUE'].transform(calculate_log_returns)
    data_year = data_year.dropna().reset_index(drop=True)

    # One column per station, one row per timestamp; missing observations become 0.
    pivot_data = data_year.pivot(index='DATETIME', columns='STATION', values='LOG_RETURN').fillna(0)

    # links[(station1, station2)] = F-statistic for every ordered pair with a non-zero result.
    links = {}
    for station1 in pivot_data.columns:
        for station2 in pivot_data.columns:
            if station1 == station2:
                continue
            f_statistic_out = granger_test(pivot_data, station1, station2)
            if f_statistic_out > 0:
                links[(station1, station2)] = f_statistic_out

            # Each ordered pair is visited once per year, so a plain assignment suffices.
            if (station1, station2) == (750, 779):
                f_stats_pa_ct[year] = f_statistic_out
            elif (station1, station2) == (779, 750):
                f_stats_ct_pa[year] = f_statistic_out

    # Title matches the December-February month filter above.
    title = f'Granger Causality Network for Winter {year}'
    filename = f'0.gcnetwork_{year}.jpg'
    create_network_plot(year, links, coordinates_data, title, filename)

# Heatmap input: one column per direction, one row per year (transposed when plotted).
years = sorted(f_stats_pa_ct)
heatmap_data = pd.DataFrame({
    'PA-TP': [f_stats_pa_ct[year] for year in years],
    'TP-PA': [f_stats_ct_pa[year] for year in years]
}, index=years)

# Wide, short figure so the two direction rows read as strips; colour range pinned to [0, max].
plt.figure(figsize=(12, 2))
max_val = max(heatmap_data.values.flatten())
sns.heatmap(heatmap_data.T, annot=False, cmap='coolwarm', cbar_kws={'label': 'F-statistic'}, vmin=0, vmax=max_val, square=True)
plt.title('Winter Granger Causality F-statistics')
plt.xlabel('Year')
plt.ylabel('Direction')
plt.yticks(rotation=0)
plt.xticks(rotation=45)
plt.savefig('1.PATP.jpg', dpi=300)
plt.show()

In [None]:
#LINK 2D GRID seasonal

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from math import radians, cos, sin, asin, sqrt
from statsmodels.tsa.stattools import grangercausalitytests

# Haversine formula to calculate the distance between two points on the Earth
def haversine(lon1, lat1, lon2, lat2):
    """Great-circle distance in kilometres between two points given in degrees.

    Uses the haversine formula on a sphere of radius 6371 km. Arguments are
    (longitude, latitude) of the first point, then of the second.
    """
    earth_radius_km = 6371
    phi1 = radians(lat1)
    phi2 = radians(lat2)
    delta_lambda = radians(lon2) - radians(lon1)
    delta_phi = phi2 - phi1
    hav = sin(delta_phi / 2)**2 + cos(phi1) * cos(phi2) * sin(delta_lambda / 2)**2
    central_angle = 2 * asin(sqrt(hav))
    return central_angle * earth_radius_km

# Function to perform Granger causality test and return F-statistic
def granger_test(dataframe, column1, column2, max_lag=1, alpha=0.0001):
    """Return the F-statistic for "``column1`` Granger-causes ``column2``", or 0.

    statsmodels' ``grangercausalitytests`` tests whether the SECOND column of its
    input Granger-causes the FIRST. The columns are therefore passed as
    ``[column2, column1]``. Callers in this cell read the result as
    ``column1 -> column2``: ``link_color`` treats the first station as the source
    when deciding east/west.

    Parameters
    ----------
    dataframe : pandas.DataFrame
        Holds both series as columns (rows presumably in time order — TODO confirm).
    column1, column2 : hashable
        Candidate cause and effect column labels.
    max_lag : int, default 1
        Forwarded to ``grangercausalitytests`` as ``maxlag``. Only the lag-1 result is read.
    alpha : float, default 0.0001
        Significance level. The lag-1 SSR F-test p-value must be strictly below it.

    Returns
    -------
    float
        The lag-1 SSR F-statistic when ``p < alpha``, otherwise 0. Any ``Exception``
        (e.g. a station missing from the pivot) silently yields 0.
    """
    try:
        # Effect column first, cause column second (statsmodels convention).
        gc_test = grangercausalitytests(dataframe[[column2, column1]], maxlag=max_lag)
        f_statistic, p_value = gc_test[1][0]['ssr_ftest'][:2]
        return f_statistic if p_value < alpha else 0
    except Exception:
        # Deliberately quiet: callers loop over every station pair, and pairs involving a
        # station absent from a season's data are expected to fail and mean "no link".
        # Catching Exception (not a bare except) still lets KeyboardInterrupt stop the run.
        return 0

# Function to determine the color of a cell based on the link direction and distance
def link_color(lon1, lat1, lon2, lat2, max_distance):
    """Hex colour for a directed link from (lon1, lat1) to (lon2, lat2).

    The colormap encodes direction. ``Reds`` is used when the target lies east of the
    source (``lon2 > lon1``); ``Blues`` is used otherwise, including equal longitudes.
    The shade is the haversine distance divided by ``max_distance``, which must be
    non-zero.
    """
    shade = haversine(lon1, lat1, lon2, lat2) / max_distance
    colormap = plt.cm.Reds if lon2 > lon1 else plt.cm.Blues
    return mcolors.to_hex(colormap(shade))

# Gauge records and station coordinates, joined on the station ID, with the
# timestamp column parsed to datetimes.
gauges_data = pd.read_csv('75gauges.csv')
coordinates_data = pd.read_csv('merged_COO.csv')
merged_data = gauges_data.merge(coordinates_data, on='STATION').assign(
    DATETIME=lambda frame: pd.to_datetime(frame['DATETIME'])
)

# Station ID -> display name. Insertion order matters: station_list below is built from these
# keys, and its positions are the row/column indices of every adjacency matrix and grid in
# this cell. (Appears to duplicate the station_order defined at the top of the notebook.)
station_order = {
    779: "TP",
    775: "ERICE",
    780: "TRAPANI FULGATORE",
    778: "SALEMI",
    774: "CASTELVETRANO",
    773: "CASTELLAMMARE DEL GOLFO",
    744: "Contessa Entellina",
    751: "PARTINICO",
    742: "CAMPOREALE",
    749: "MONREALE VIGNA API",
    745: "CORLEONE",
    693: "RIBERA",
    750: "PA",
    686: "BIVONA",
    748: "MISILMERI",
    747: "MEZZOJUSO",
    683: "AG",
    755: "TERMINI IMERESE",
    685: "ARAGONA",
    684: "AGRIGENTO MANDRASCAVA",
    687: "CAMMARATA",
    740: "ALIA",
    689: "CANICATTì",
    700: "MUSSOMELI",
    703: "SCLAFANI BAGNI",
    690: "LICATA",
    746: "LASCARI",
    696: "DELIA",
    754: "POLIZZI GENEROSA",
    753: "PETRALIA SOTTANA",
    695: "CL",
    701: "RIESI",
    743: "CASTELBUONO",
    718: "EN",
    752: "GANGI",
    699: "MAZZARINO",
    736: "PETTINEO",
    731: "MISTRETTA",
    722: "PIAZZA ARMERINA",
    762: "ACATE",
    721: "NICOSIA",
    717: "AIDONE",
    723: "CARONIA BUZZA",
    760: "SANTA CROCE CAMERINA",
    710: "MAZZARRONE",
    716: "CALTAGIRONE",
    757: "COMISO",
    737: "SAN FRATELLO",
    730: "MILITELLO ROSMARINO",
    761: "SCICLI",
    756: "RG",
    725: "CESARò VIGNAZZA",
    711: "MINEO",
    733: "NASO",
    705: "BRONTE",
    712: "PATERNò",
    770: "PALAZZOLO ACREIDE",
    709: "MALETTO",
    765: "FRANCOFONTE",
    766: "LENTINI",
    715: "RANDAZZO",
    758: "ISPICA",
    735: "PATTI",
    713: "PEDARA",
    767: "NOTO",
    706: "CT",
    769: "PACHINO",
    708: "LINGUAGLOSSA",
    734: "NOVARA DI SICILIA",
    764: "SR",
    707: "RIPOSTO",
    739: "TORREGROTTA",
    738: "SAN PIER NICETO",
    727: "FIUMEDINISI",
    729: "ME"
}


# Station IDs in display order; position k is row/column k of each link_matrix[year].
station_list = list(station_order.keys())

# Autumn window (September-November), shared by the threshold pass and the link-matrix
# pass so the two cannot drift apart.
autumn_months = [9, 10, 11]
datetime_year = merged_data['DATETIME'].dt.year
datetime_month = merged_data['DATETIME'].dt.month

# Calendar years spanned by the data (inclusive); reused by the plotting cells below.
years = range(datetime_year.min(), datetime_year.max() + 1)

# Pass 1: run each ordered-pair Granger test exactly once per year and cache its F-statistic.
# The matrix pass reads from this cache instead of re-running every test.
f_stats_by_year = {}
all_f_statistics = []
for year in years:
    data_year = merged_data[(datetime_year == year) & datetime_month.isin(autumn_months)]
    pivot_data = data_year.pivot(index='DATETIME', columns='STATION', values='VALUE').fillna(0)

    year_f_stats = {}
    for station1 in pivot_data.columns:
        for station2 in pivot_data.columns:
            if station1 != station2:
                f_statistic = granger_test(pivot_data, station1, station2)
                year_f_stats[(station1, station2)] = f_statistic
                all_f_statistics.append(f_statistic)
    f_stats_by_year[year] = year_f_stats

# Link threshold: 99th percentile of all pairwise F-statistics over all autumns.
# Pairs that granger_test reports as 0 (not significant) are part of this pool.
f_statistic_99th_percentile = np.percentile(all_f_statistics, 99)

# Pass 2: one (num_stations x num_stations) matrix per year, in station_list order.
# Entry [i, j] holds the cached F-statistic for (station_list[i], station_list[j]) when it
# exceeds the threshold, else 0. A station absent from a year's data has no cached entry,
# so it gets 0, matching what granger_test returns for a missing column.
num_stations = len(station_list)
link_matrix = {year: np.zeros((num_stations, num_stations)) for year in years}
for year in years:
    year_f_stats = f_stats_by_year[year]
    for i, station1 in enumerate(station_list):
        for j, station2 in enumerate(station_list):
            if station1 == station2:
                continue
            f_statistic = year_f_stats.get((station1, station2), 0)
            if f_statistic > f_statistic_99th_percentile:
                link_matrix[year][i, j] = f_statistic

# Station ID -> {'EST': ..., 'NORD': ...} lookup built from the coordinate table.
station_coordinates = coordinates_data.set_index('STATION')[['EST', 'NORD']].to_dict('index')

# Largest haversine distance over all unordered station pairs in station_list. link_color
# divides by it to scale colour intensity. The leading 0 keeps the result 0 for < 2 stations.
# NOTE(review): haversine treats EST/NORD as longitude/latitude in degrees. Confirm that
# merged_COO.csv is not in projected (e.g. UTM metre) coordinates.
pairwise_distances = [
    haversine(station_coordinates[first]['EST'], station_coordinates[first]['NORD'],
              station_coordinates[second]['EST'], station_coordinates[second]['NORD'])
    for idx, first in enumerate(station_list)
    for second in station_list[idx + 1:]
]
max_distance = max([0, *pairwise_distances])

# One colour-coded adjacency grid per year, saved as a JPEG.
# Cell (i, j) stays white unless link_matrix[year][i, j] is positive. In that case it takes
# the link_color of the link station_list[i] -> station_list[j].
station_labels = [station_order[s] for s in station_list]
for year in years:
    adjacency = link_matrix[year]
    grid = np.ones((num_stations, num_stations, 3))  # RGB image, 1.0 everywhere = white

    for i, station1 in enumerate(station_list):
        source = station_coordinates[station1]
        for j, station2 in enumerate(station_list):
            if not adjacency[i, j] > 0:
                continue  # no significant link for this ordered pair
            target = station_coordinates[station2]
            hex_color = link_color(source['EST'], source['NORD'], target['EST'], target['NORD'], max_distance)
            grid[i, j] = mcolors.to_rgb(hex_color)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(grid, interpolation='nearest')
    ax.set_title(f'Autumn Adjacency Matrix for {year}')
    ax.set_xticks(range(num_stations))
    ax.set_xticklabels(station_labels, rotation=90)
    ax.set_yticks(range(num_stations))
    ax.set_yticklabels(station_labels)
    ax.grid(False)
    fig.savefig(f"4.Grid_{year}autumn.jpg", bbox_inches='tight', dpi=300)

    plt.show()
    #GRANGER2DGRID4

In [None]:
from mpl_toolkits.mplot3d import Axes3D

# One 3D bar chart per year. A bar stands at (i, j) for every positive entry of
# link_matrix[year] (link station_list[i] -> station_list[j]). Its height is the F-statistic
# and its colour comes from link_color (direction + distance), as in the 2D grids.
station_labels = [station_order[s] for s in station_list]
# bar3d anchors each bar at (i, j) and extends it by dx = dy = bar_size, so bar k spans
# [k, k + bar_size]. Ticks go at the bar centres; integer ticks would label the bar edges.
bar_size = 1
tick_positions = [k + 0.5 * bar_size for k in range(num_stations)]

for year in years:
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')

    adjacency = link_matrix[year]
    for i, station1 in enumerate(station_list):
        lon1, lat1 = station_coordinates[station1]['EST'], station_coordinates[station1]['NORD']
        for j, station2 in enumerate(station_list):
            f_value = adjacency[i, j]
            if f_value > 0:  # only links above the threshold were stored as non-zero
                lon2, lat2 = station_coordinates[station2]['EST'], station_coordinates[station2]['NORD']
                color = link_color(lon1, lat1, lon2, lat2, max_distance)
                ax.bar3d(i, j, 0, bar_size, bar_size, f_value, color=color, zsort='average')

    # x axis = source station (row i), y axis = target station (column j).
    ax.set_xticks(tick_positions)
    ax.set_xticklabels(station_labels, rotation=90)
    ax.set_yticks(tick_positions)
    ax.set_yticklabels(station_labels)
    ax.set_xlabel('Station')
    ax.set_ylabel('Station')
    ax.set_zlabel('F-Statistic')
    ax.set_title(f'Autumn Adjacency Matrix for {year}')

    fig.savefig(f"4.3D_grid_{year}autumn.jpg", bbox_inches='tight', dpi=300)

    plt.show()