In [None]:
import numpy as np  # this cell runs before the main import cell, so import numpy here

# Input NetCDF file (path assumes Google Drive is mounted in Colab).
filepath = '/content/drive/MyDrive/air.2m.mon.mean.nc'
# Month-of-year indices (axis 0 of the anomaly array) used to build each grid point's
# series. All 12 are kept by default; pass a subset to restrict the temporal domain.
choose_months = np.arange(0, 12)
lag = 0  # lag (in series steps) at which the cross-correlation between series is read
threshold = 0.2  # grid points are linked when |correlation| exceeds this value

# Installation of Packages

In [None]:
!pip install netCDF4

In [None]:
!pip install h5py

In [None]:
!pip install --upgrade netCDF4 h5py

In [None]:
!pip install cdlib

In [None]:
!pip install geopandas
!pip install geoplot

# Importing

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import netCDF4 as nc
import h5py
import networkx as nx
import cartopy.crs as ccrs
from tqdm import tqdm
from statsmodels.tsa.stattools import ccf
from cdlib import algorithms

# Functions required

In [None]:
#Loading and Reading Data
def load_and_read(filepath):
  """Read the NetCDF temperature file and compute month-of-year anomalies.

  The time axis of the 'air' variable is split into (month-of-year, year). This
  assumes consecutive monthly records whose first record is the first month of a
  year. Trailing records that do not form a complete year are dropped.

  Parameters
  ----------
  filepath : str
      Path to a NetCDF file with 'lat', 'lon' and 'air' variables. 'air' is indexed
      as (time, ..., lat, lon); any axes between time and lat must be singletons.

  Returns
  -------
  monthly_means : array, shape (12, n_lat, n_lon)
      Mean over years for each month-of-year and grid point.
  anomaly : array, shape (12, n_years, n_lat, n_lon)
      Each record minus the mean of its month-of-year.
  lat_arr, lon_arr : arrays
      Values of the 'lat' and 'lon' variables.

  Raises
  ------
  ValueError
      If the file holds fewer than 12 time records.
  """
  # The context manager closes the file; [:] has already read the data into memory.
  with nc.Dataset(filepath, 'r') as f:
    lat_arr = f.variables['lat'][:]
    lon_arr = f.variables['lon'][:]
    temp_arr = f.variables['air'][:]

  n_time = temp_arr.shape[0]
  n_lat, n_lon = temp_arr.shape[-2:]
  temp_arr = temp_arr.reshape(n_time, n_lat, n_lon)  # drop singleton middle axes, if any
  n_years = n_time // 12
  if n_years == 0:
    raise ValueError(f"need at least 12 monthly records, got {n_time}")

  # (time, lat, lon) -> (year, month, lat, lon) -> (month, year, lat, lon): record t
  # lands at [t % 12, t // 12], the same layout as a Fortran-order reshape of the
  # time axis, but sized from the data instead of hard-coded.
  monthly_temp = (temp_arr[:n_years * 12]
                  .reshape(n_years, 12, n_lat, n_lon)
                  .transpose(1, 0, 2, 3))
  monthly_means = monthly_temp.mean(axis=1)

  # Broadcast each month's climatology across the year axis.
  anomaly = monthly_temp - monthly_means[:, np.newaxis]

  return monthly_means, anomaly, lat_arr, lon_arr

#Finding Correlations between spatial locations
def find_correlations(anomaly, choose_months = choose_months, lag=lag):
  """Cross-correlate the anomaly series of every pair of grid points.

  This is a vectorised equivalent of evaluating
  ``statsmodels.tsa.stattools.ccf(series_i, series_j)[lag]`` for every pair i < j.
  Each series is demeaned and divided by its population standard deviation. The
  lag-``lag`` cross-product sum is divided by ``n - lag``, which is ccf's default
  ``adjusted=True`` normalisation.

  Parameters
  ----------
  anomaly : array, shape (12, n_years, ...spatial dims)
      Anomalies as returned by ``load_and_read``. Masks are ignored, as ccf ignores them.
  choose_months : int, slice, list or array
      Index into the month axis selecting which months form each series.
  lag : int
      Lag at which the cross-correlation is read. Negative values index from the end,
      and out-of-range values raise IndexError, exactly like ``ccf(...)[lag]``.

  Returns
  -------
  correlations : ndarray, shape (n_points, n_points)
      Strictly upper-triangular. For i < j, ``correlations[i, j]`` correlates point
      i (leading by ``lag``) with point j. The diagonal and lower triangle are zero.
      Pairs involving a constant series are NaN.
  """
  n_month_axis, n_years = anomaly.shape[:2]
  reshaped = np.asarray(anomaly, dtype=np.float64).reshape(n_month_axis, n_years, -1)
  n_points = reshaped.shape[-1]

  # One column per grid point. Rows are month-major (all years of the first chosen
  # month, then the next, ...), matching the per-pair ``.flatten()`` of the pairwise form.
  # NOTE(review): with this ordering a lag of k steps pairs the same calendar month
  # k years apart, except across month boundaries. Confirm that is what `lag` means.
  series = reshaped[choose_months].reshape(-1, n_points)
  n_samples = series.shape[0]
  lag = range(n_samples)[lag]  # same negative/out-of-range semantics as ccf(...)[lag]

  with np.errstate(divide='ignore', invalid='ignore'):
    # Zero-variance series become NaN, matching ccf's 0/0.
    standardized = (series - series.mean(axis=0)) / series.std(axis=0)
    correlations = standardized[lag:].T @ standardized[:n_samples - lag]
    correlations /= n_samples - lag

  # Keep only pairs i < j (zero diagonal and lower triangle); build_network symmetrises.
  for i in range(n_points):
    correlations[i, :i + 1] = 0.0

  return correlations


#Building network based on calculated correlations and a threshold
def build_network(correlations, lat_arr, lon_arr, resolution=1, threshold=0.5):
  """Build an undirected graph linking grid points with strongly correlated anomalies.

  Parameters
  ----------
  correlations : ndarray, shape (n_points, n_points)
      Square pairwise correlations, typically the upper-triangular output of
      ``find_correlations``. They are symmetrised as ``correlations + correlations.T``.
  lat_arr, lon_arr : 1-D arrays
      Grid coordinates. Flat index ``i`` is grid row ``i // len(lon_arr)`` and column
      ``i % len(lon_arr)``.
  resolution : float in (0, 1]
      Fraction of grid points kept: every ``int(1/resolution)``-th flat index becomes
      a node (1 keeps all points, 1/10 keeps every tenth).
  threshold : float
      Two kept nodes are joined when their absolute correlation is strictly greater
      than this value.

  Returns
  -------
  networkx.Graph
      Nodes are flat grid indices with a ``pos`` attribute ``(lat, lon)``. Longitude
      is wrapped into [-180, 180).

  Raises
  ------
  ValueError
      If ``resolution`` gives a sampling step below 1.
  """
  step = int(1/resolution)
  if step < 1:
    raise ValueError(f"resolution must be in (0, 1], got {resolution}")
  n_lon = len(lon_arr)
  nodes = np.arange(0, correlations.shape[0], step)

  G = nx.Graph()
  for node in nodes.tolist():
    row, col = divmod(node, n_lon)
    # Wrap longitudes into [-180, 180) (identity if they already are). Subtracting
    # 180 outright would relabel each point as the one on the opposite side of the globe.
    lon = (float(lon_arr[col]) + 180.0) % 360.0 - 180.0
    G.add_node(node, pos=(float(lat_arr[row]), lon))

  # Basic slicing is a view, and picks exactly the rows/columns of `nodes`.
  sub = correlations[::step, ::step]
  sym = sub + sub.T
  np.abs(sym, out=sym)
  # Strict upper triangle: each undirected edge once, no self-loops. NaN compares False.
  rows, cols = np.nonzero(np.triu(sym > threshold, k=1))
  G.add_edges_from(zip(nodes[rows].tolist(), nodes[cols].tolist()))

  return G

#Plotting distribution of correlations
def plot_correlations(correlations, flattened=None, threshold=0.5):
  """Histogram the absolute pairwise correlations and mark the edge threshold.

  Parameters
  ----------
  correlations : ndarray, shape (n_points, n_points)
      Output of ``find_correlations``. Used to build the symmetric matrix when
      ``flattened`` is not supplied.
  flattened : ndarray, shape (n_points, n_points), optional
      Symmetric correlation matrix (``correlations + correlations.T``). Only its strict
      upper triangle is plotted, so each pair is counted once.
  threshold : float
      Position of the vertical threshold line.

  Returns
  -------
  matplotlib.figure.Figure
  """
  if flattened is None:
    flattened = correlations + correlations.T
  n = flattened.shape[0]
  # The boolean strict-upper-triangle mask selects the same entries, in the same
  # row-major order, as np.triu_indices(n, k=1), without two int64 index arrays.
  upper_entries = np.abs(flattened[~np.tri(n, dtype=bool)])
  # Drop NaNs (e.g. from constant series) so the histogram range is well defined.
  upper_entries = upper_entries[np.isfinite(upper_entries)]

  fig, ax = plt.subplots(figsize=(10, 5), constrained_layout=True)
  ax.hist(upper_entries, bins=100, color="tab:blue", alpha=0.5, label="counts")
  ax.axvline(threshold, color="tab:red", label="threshold")
  ax.set_xlabel("|Correlation|")
  ax.set_ylabel("Pairs of Locations")
  ax.set_title("Distribution of pairwise absolute correlations")
  ax.legend()

  return fig

#Plotting network
def plot_network(G, plot_edges=True, plot_nodes=True):
  """Draw G on a Plate Carree map using each node's ``pos`` attribute.

  ``pos`` is stored as ``(lat, lon)``, while networkx draws ``(x, y)``, so each pair
  is swapped to ``(lon, lat)``. Edges and nodes can be toggled independently.
  Coastlines and gridlines are added. Returns the created Figure.
  """
  coords = {}
  for node, latlon in nx.get_node_attributes(G, 'pos').items():
    lat, lon = latlon
    coords[node] = (lon, lat)

  fig, ax = plt.subplots(figsize=(15, 10), subplot_kw={'projection': ccrs.PlateCarree()})

  if plot_edges:
    nx.draw_networkx_edges(G, coords, ax=ax, edge_color='k', alpha=0.2)
  if plot_nodes:
    nx.draw_networkx_nodes(G, coords, ax=ax, node_color='r', node_size=15, alpha=0.7)

  ax.coastlines()
  ax.gridlines()
  return fig


def community_detection(G):
  """Partition G with cdlib's Louvain algorithm and report the result.

  Prints the number of communities found and the partition's Newman-Girvan
  modularity, then returns the cdlib clustering object.
  """
  # NOTE(review): Louvain can return different partitions between runs unless its
  # randomness is fixed; consider seeding if results must be reproducible.
  partition = algorithms.louvain(G)
  n_found = len(partition.communities)
  print('No. of Communities:', n_found)
  modularity = partition.newman_girvan_modularity()
  print('Modularity:', modularity)
  return partition

def plot_communities(G, communities, plot_edges=False, node_size=50):
  """Draw G on a Plate Carree map, colouring each node by its community.

  Parameters
  ----------
  G : networkx.Graph
      Graph whose nodes carry a ``pos`` attribute ``(lat, lon)``.
  communities : cdlib clustering
      Object whose ``.communities`` is a sequence of node collections.
  plot_edges : bool
      Also draw the graph's edges (off by default).
  node_size : float
      Marker size passed to ``scatter``.

  Returns
  -------
  matplotlib.figure.Figure
  """
  num_communities = len(communities.communities)
  community_colors = plt.cm.rainbow(np.linspace(0, 1, num_communities))
  community_mapping = {node: community_id for community_id, community in enumerate(communities.communities) for node in community}

  fig, ax = plt.subplots(figsize = (15, 10), subplot_kw={'projection': ccrs.PlateCarree()})
  pos = {node: (lon, lat) for node, (lat, lon) in nx.get_node_attributes(G, 'pos').items()}

  if plot_edges:
    nx.draw_networkx_edges(G, pos, ax=ax, edge_color='k', alpha=0.3)

  if pos:
    lons, lats = zip(*pos.values())
    # Nodes that are in no community fall back to the first colour.
    color_idx = [community_mapping.get(node, 0) for node in pos]
    # A single scatter call instead of one artist per node.
    ax.scatter(lons, lats, c=community_colors[color_idx], s=node_size, transform=ccrs.PlateCarree())

  ax.coastlines()
  ax.gridlines()

  legend_labels = [f"Community {i}" for i in range(num_communities)]
  legend_handles = [plt.Line2D([0], [0], marker='o', color='w', markerfacecolor=community_colors[i], markersize=10) for i in range(num_communities)]
  ax.legend(legend_handles, legend_labels, loc='upper right')

  return fig




# Pipeline

In [None]:
# Load anomalies, correlate every pair of grid points once, then build two networks.
monthly_means, anomaly, lat_arr, lon_arr = load_and_read(filepath)
correlations = find_correlations(anomaly, choose_months=choose_months, lag=lag)

# Same correlations, two node samplings: every 10th grid point vs. every grid point.
G1 = build_network(correlations, lat_arr, lon_arr, resolution=1/10, threshold=threshold)
G2 = build_network(correlations, lat_arr, lon_arr, resolution=1, threshold=threshold)

# community_detection prints the community count and modularity for each graph.
print('Resolution 1/10th')
communities1 = community_detection(G1)
print('\n')
print('Full Resolution')
communities2 = community_detection(G2)

# Plotting

In [None]:
# Network at 1/10th resolution (every 10th grid point), with edges and nodes drawn
fig = plot_network(G1, plot_edges=True, plot_nodes=True)

In [None]:
# Louvain communities of the 1/10th-resolution network
fig = plot_communities(G1, communities=communities1)

In [None]:
# Louvain communities of the full-resolution network; smaller markers since every grid point is a node
fig = plot_communities(G2, communities=communities2, node_size=10)

In [None]:
# Size of each full-resolution community.
community_sizes = [len(community) for community in communities2.communities]
# plot_communities keeps its colour table local, so rebuild the same rainbow mapping here.
community_colors = plt.cm.rainbow(np.linspace(0, 1, len(community_sizes)))

fig, ax = plt.subplots()
ax.vlines(range(len(community_sizes)), ymin=0, ymax=community_sizes, colors=community_colors)
ax.set_xlabel("Community No.")
ax.set_ylabel("Number of Nodes")
ax.set_title("Community sizes (full resolution)")
plt.show()