<a href="https://colab.research.google.com/github/SeanBarnier/HAFS_Air-Sea/blob/main/tempMaps.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Set up environment

In [None]:
# Mount Google Drive so the notebook can read the saved model output
# (dataPath in the user-parameters cell points under /content/drive/MyDrive).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install cfgrib
!pip install cartopy

In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cartopy.crs as ccrs
import cartopy.feature as cft
from datetime import datetime as dt
import cfgrib

# User parameters

In [None]:
# --- Storm and forecast cycle ---
name = "Milton"     # storm name
tcNum = "14"        # storm number, used in the HAFS-A track-file URL
trackType = ""

# Forecast cycle to analyse: the time when Milton began its most rapid intensification
initTime = dt(2024, 10, 7, 12)

fHourStep = 12       # hours between plotted forecast times (normally 3 for HAFS-A)
forecastLength = 36  # last forecast hour plotted (normally 126 for HAFS-A)
#runStep = 6         #Normally 6 for HAFS-A

# --- Input / output locations ---
figureSuffix = "_RI"
subfolder = "RI/"
dataPath = "/content/drive/MyDrive/savedData/"
figurePath = "/content/drive/MyDrive/figures/"

# --- Plot options ---
potentialTemp = True  # use atmospheric potential temperature instead of in-situ temperature

radType = "dist"  # box drawn around the storm on the flux maps: "wind", "dist", or None
windRad = 34      # wind-speed threshold (kt) whose ATCF radius sizes the box; only used when radType == "wind"
dist = 1.0        # box side length in degrees; only used when radType == "dist"

# Retrieve HAFS-A Data

Find times needed

In [None]:
dateFormat = "%Y-%m-%d %H:%M:%S"
runFormat = "%Y%m%d%H"

outputStep = 3  # hours between HAFS-A output files

def _validTimes(stepHours):
    """Return valid times from initTime through initTime + forecastLength, every `stepHours` hours.

    The first entry is initTime itself; each later entry adds pd.Timedelta(hours=stepHours)
    to the previous one.
    """
    times = []
    validTime = initTime
    for _ in range(0, forecastLength + 1, stepHours):
        times.append(validTime)
        validTime += pd.Timedelta(hours=stepHours)
    return times

# Valid times plotted as figure columns (a plain list, in forecast order)
fcastTimes = _validTimes(fHourStep)

# Every output time in the forecast window (used for the track and the flux statistics)
allTimes = _validTimes(outputStep)

Find the storm location in the HAFS-A forecast from the ATCF track file. It is used to mark the storm center on the maps and to find the along-storm profile.

In [None]:
cols = ["BASIN", "CY", "YYYYMMDDHH", "TECHNUM/MIN", "TECH", "TAU", "LatN/S", "LonE/W",
    "VMAX", "MSLP", "TY", "RAD", "WINDCODE", "RAD1", "RAD2", "RAD3", "RAD4",
    "POUTER", "ROUTER", "RMW", "GUSTS", "EYE", "SUBREGION", "MAXSEAS", "INITIALS",
    "DIR", "SPEED", "STORMNAME", "DEPTH", "SEAS", "SEASCODE", "SEAS1", "SEAS2",
    "SEAS3", "SEAS4", "USERDEFINED1", "Thermo1", "Thermo2", "Thermo3", "Thermo4",
    "Thermo5", "Thermo6", "Thermo7", "USERDEFINED2", "DT", "SHR82", "SHR81_1",
    "SHR82_2",  "USERDEFINED3", "SST", "USERDEFINED4", "ARMW1", "ARMW2"]

initStr, initHour = initTime.strftime("%Y%m%d_%H").split("_")

atcfURL = f"https://noaa-nws-hafs-pds.s3.amazonaws.com/hfsa/{initStr}/{initHour}/{tcNum}l.{initStr}{initHour}.hfsa.trak.atcfunix"
atcfFile = "atcf_" + initStr + "_" + initHour + ".csv"

!wget -O {atcfFile} {atcfURL}
atcf = pd.read_csv(atcfFile, names=cols)

In [None]:
def _parseAtcfCoord(raw, defaultHemi):
    """Convert an ATCF position field to signed decimal degrees.

    ATCF positions are integer tenths of a degree followed by a hemisphere letter
    (e.g. ' 230N', ' 860W'); S and W become negative. `defaultHemi` is used only when
    the field carries no hemisphere letter.
    """
    text = str(raw).strip()
    hemi = defaultHemi
    if text[-1:].isalpha():
        hemi, text = text[-1].upper(), text[:-1]
    tenths = int(text)
    return -tenths / 10 if hemi in ("S", "W") else tenths / 10

tcLocs = {}    # valid time -> (lat, lon) of the HAFS-A storm center in degrees (west negative)
windRads = {}  # valid time -> smallest windRad-kt radius over the four quadrants, in degrees

for valid in allTimes:
    fHour = int((valid-initTime).total_seconds() / 3600)
    tauRows = atcf[atcf.TAU==fHour]
    if tauRows.empty:
        raise ValueError(f"No ATCF track entry at TAU={fHour} h; cannot locate the storm at {valid}")

    tcLocs[valid] = (_parseAtcfCoord(tauRows["LatN/S"].iloc[0], "N"),
                     _parseAtcfCoord(tauRows["LonE/W"].iloc[0], "W"))

    # Filter the TAU subset itself: chaining atcf[mask1][mask2] with masks built on the full
    # frame makes pandas reindex the second mask and emit a UserWarning.
    windRadData = tauRows[tauRows.RAD==windRad]
    if windRadData.empty:
        windRads[valid] = 0  # no radius listed: the forecast TC does not reach windRad kt
    else:
        # Smallest quadrant radius: beyond it, the wind is below windRad kt in at least one quadrant.
        # Radii are in n mi; 60 n mi = 1 degree of latitude.
        windRads[valid] = np.min(windRadData[["RAD1", "RAD2", "RAD3", "RAD4"]].to_numpy()) / 60

# Figures

Set Figure Parameters

In [None]:
# Fixed map domain in degrees (longitudes west-negative) for the level-by-level maps
lonMin, lonMax, latMin, latMax = -95, -82, 20, 27

# Indices along the pressure (atm) / depth (ocean) dimension to plot; one figure row each
atmLayers = [0, 1, 2, 3, 4, 5, 6]
oceLayers = [0, 1, 2, 3, 4, 5, 6]

# Axes layout in figure-fraction units: columns = forecast times, rows = layers (top to bottom)
buffer = 0.1 / len(atmLayers)
xWidth = 0.8 / len(fcastTimes)
yWidth = 0.4 / len(atmLayers)
# Corners come from integer counts: np.arange with a float step can return one corner
# too many or too few because of floating-point rounding of the end point.
xcorners = 0.1 + xWidth * np.arange(len(fcastTimes))
ycorners = 0.7 - yWidth * np.arange(len(atmLayers))

diffType = "top"  # "top": difference each level against level index 0; "incremental": against the adjacent level

Atmospheric Temperature

In [None]:
atmFig = plt.figure(figsize=(len(fcastTimes)*5,len(atmLayers)*6))
atmAxes = [[atmFig.add_axes([xcorner, ycorner, xWidth-buffer, yWidth-buffer], projection=ccrs.PlateCarree()) for ycorner in ycorners] for xcorner in xcorners]

# Colorbar label names the quantity actually plotted
tempLabel = "Potential Temperature (K)" if potentialTemp else "Temperature (K)"

# Contour levels per layer, fixed at the first forecast time so all columns share one color scale
contourLevs = {}
for valid, atmAxColumn in zip(fcastTimes, atmAxes):

  fhour = f"{int((valid-initTime).total_seconds() / 3600):03d}"  # zero-padded forecast hour, e.g. "012"

  atmFile = "hafsa_" + initStr + initHour + "_f" + fhour + ".nc"
  atmPath = dataPath + "hafsaOutput/" + subfolder + atmFile
  atmData = xr.open_dataset(atmPath)

  for atmLayer, atmAx in zip(atmLayers, atmAxColumn):

    # File longitudes are 0-360, hence the +360 on the (west-negative) domain bounds
    atmSlice = atmData.isel(isobaricInhPa=atmLayer).sel(longitude=slice(lonMin+360, lonMax+360), latitude=slice(latMin, latMax))

    # Potential temperature: T * (1000 / p) ** 0.286, with p the level pressure in hPa
    if potentialTemp: temp = atmSlice.t.data * (1000/atmSlice.isobaricInhPa.data) ** 0.286
    else: temp = atmSlice.t.data

    if atmLayer not in contourLevs:
      contourLevs[atmLayer] = [round(l, 1) for l in np.linspace(np.nanmin(temp)-3, np.nanmax(temp)+3, 15)]

    tempContour = atmAx.contourf(atmSlice.longitude.data, atmSlice.latitude.data, temp, cmap="coolwarm", transform=ccrs.PlateCarree(),
                                extent = [lonMin, lonMax, latMin, latMax], levels=contourLevs[atmLayer])

    if valid == fcastTimes[-1]: atmFig.colorbar(tempContour, shrink=0.8, label=tempLabel)

    #atmAx.scatter(tcLocs[valid][1], tcLocs[valid][0], marker="*", color="black", s=75, transform=ccrs.PlateCarree())

    atmAx.add_feature(cft.COASTLINE)
    atmAx.add_feature(cft.BORDERS)
    # Grid labels only on the outer edges (bottom row, left column), as in the ocean maps
    gridLabels = []
    if atmLayer == atmLayers[-1]: gridLabels.append("bottom")
    if valid == fcastTimes[0]: gridLabels.append("left")
    atmAx.gridlines(draw_labels=gridLabels, alpha=0.5)
    atmAx.set_title(f'{atmData.isobaricInhPa.data[atmLayer]} hPa\n{valid.strftime("%Y-%m-%d %HUTC")}')

    atmAx.set_extent([lonMin, lonMax, latMin, latMax])

#atmFig.suptitle(f"Atmosphere Initialized {initTime.strftime('%Y-%m-%d %HUTC')}")

Oceanic Temperature

In [None]:
oceFig = plt.figure(figsize=(len(fcastTimes)*5,len(oceLayers)*6))
oceAxes = [[oceFig.add_axes([xcorner, ycorner, xWidth-buffer, yWidth-buffer], projection=ccrs.PlateCarree()) for ycorner in ycorners] for xcorner in xcorners]

# Contour levels per ocean layer (keyed by oceLayers), fixed at the first forecast time
# so all columns share one color scale
contourLevs = {}

for valid, oceAxColumn in zip(fcastTimes, oceAxes):

  fhour = f"{int((valid-initTime).total_seconds() / 3600):03d}"  # zero-padded forecast hour, e.g. "012"

  oceFile = "mom6_" + initStr + initHour + "_f" + fhour + ".nc"
  ocePath = dataPath + "mom6Output/" + subfolder + oceFile
  oceData = xr.open_dataset(ocePath, decode_times=False)

  for oceLayer, oceAx in zip(oceLayers, oceAxColumn):

    oceSlice = oceData.isel(z_l=oceLayer).sel(xh=slice(lonMin, lonMax), yh=slice(latMin, latMax), time=oceData.time.data[0])

    temp = oceSlice.temp.data + 273.15  # degrees C -> K, matching the colorbar label
    if oceLayer not in contourLevs:
      contourLevs[oceLayer] = [round(l, 1) for l in np.linspace(np.nanmin(temp)-1, np.nanmax(temp)+1, 15)]

    tempContour = oceAx.contourf(oceSlice.xh.data, oceSlice.yh.data, temp, cmap="coolwarm", transform=ccrs.PlateCarree(),
                                extent=[lonMin, lonMax, latMin, latMax], levels=contourLevs[oceLayer])
    if valid == fcastTimes[-1]: oceFig.colorbar(tempContour, shrink=0.8, label="Temperature (K)")

    # Storm center from the ATCF track (tcLocs holds (lat, lon))
    oceAx.scatter(tcLocs[valid][1], tcLocs[valid][0], marker="*", color="black", s=100, transform=ccrs.PlateCarree())

    # Grid labels only on the outer edges: bottom row and left column
    gridLabels = []
    if oceLayer == oceLayers[-1]: gridLabels.append("bottom")
    if valid == fcastTimes[0]: gridLabels.append("left")
    oceAx.gridlines(draw_labels=gridLabels, alpha=0.5)

    oceAx.add_feature(cft.COASTLINE)
    oceAx.add_feature(cft.BORDERS)
    oceAx.set_title(f"{round(oceData.z_l.data[oceLayer],1)} m\n{valid.strftime('%Y-%m-%d %HUTC')}")

    oceAx.set_extent([lonMin, lonMax, latMin, latMax])

#oceFig.suptitle(f"Initialized {initTime.strftime('%Y-%m-%d %HUTC')}")

Atmosphere Temperature Difference by Level

In [None]:
atmFig = plt.figure(figsize=(len(fcastTimes)*5,len(atmLayers)*6))
atmAxes = [[atmFig.add_axes([xcorner, ycorner, xWidth-buffer, yWidth-buffer], projection=ccrs.PlateCarree()) for ycorner in ycorners[:-1]] for xcorner in xcorners]

tempSymbol = r"$\theta$" if potentialTemp else "T"

def _plottedTemp(levelSlice):
  """Temperature field plotted for one pressure level of a HAFS-A slice.

  Returns potential temperature T * (1000 / p) ** 0.286 when potentialTemp is set,
  using the slice's own pressure p (hPa); otherwise the temperature T unchanged.
  """
  if potentialTemp:
    return levelSlice.t.data * (1000/levelSlice.isobaricInhPa.data) ** 0.286
  return levelSlice.t.data

# Contour levels per row, fixed at the first forecast time so all columns share one color scale
contourLevs = {}
for valid, atmAxColumn in zip(fcastTimes, atmAxes):

  fhour = f"{int((valid-initTime).total_seconds() / 3600):03d}"  # zero-padded forecast hour, e.g. "012"

  atmFile = "hafsa_" + initStr + initHour + "_f" + fhour + ".nc"
  atmPath = dataPath + "hafsaOutput/" + subfolder + atmFile
  atmData = xr.open_dataset(atmPath)

  # Reference level for each row: index 0 for "top", the preceding level for "incremental"
  lowerLayers = {"top":[0]*len(atmLayers[:-1]), "incremental":atmLayers[:-1]}[diffType]
  for upperLayer, lowerLayer, atmAx in zip(atmLayers[1:], lowerLayers, atmAxColumn):

    upperSlice = atmData.isel(isobaricInhPa=upperLayer).sel(longitude=slice(lonMin+360, lonMax+360), latitude=slice(latMin, latMax))
    lowerSlice = atmData.isel(isobaricInhPa=lowerLayer).sel(longitude=slice(lonMin+360, lonMax+360), latitude=slice(latMin, latMax))

    # Each level is converted with its own pressure before differencing
    tdiff = _plottedTemp(upperSlice) - _plottedTemp(lowerSlice)

    if upperLayer not in contourLevs:
      maxAbs = np.nanmax(np.abs(tdiff)) + 1
      contourLevs[upperLayer] = np.round(np.linspace(-maxAbs, maxAbs, 15), 2)

    tempContour = atmAx.contourf(upperSlice.longitude.data, upperSlice.latitude.data, tdiff, cmap="bwr", transform=ccrs.PlateCarree(),
                                extent = [lonMin, lonMax, latMin, latMax], levels=contourLevs[upperLayer])

    upperP = int(atmData.isobaricInhPa.data[upperLayer])
    lowerP = int(atmData.isobaricInhPa.data[lowerLayer])
    if valid == fcastTimes[-1]: atmFig.colorbar(tempContour, shrink=0.8, label=f"{tempSymbol}({upperP} hPa) - {tempSymbol}({lowerP} hPa) (K)")

    #atmAx.scatter(tcLocs[valid][1], tcLocs[valid][0], marker="*", color="black", s=75, transform=ccrs.PlateCarree())
    # Grid labels only on the outer edges: the bottom row holds the last upper level
    gridLabels = []
    if upperLayer == atmLayers[-1]: gridLabels.append("bottom")
    if valid == fcastTimes[0]: gridLabels.append("left")
    atmAx.gridlines(draw_labels=gridLabels, alpha=0.5)

    atmAx.add_feature(cft.COASTLINE)
    atmAx.add_feature(cft.BORDERS)
    atmAx.set_title(f"{upperP} - {lowerP} hPa\n" + valid.strftime("%Y-%m-%d %HUTC"))

    atmAx.set_extent([lonMin, lonMax, latMin, latMax])

#atmFig.suptitle(f"Atmosphere Initialized {initTime.strftime('%Y-%m-%d %HUTC')}")

Ocean Temperature Difference by Level

In [None]:
oceFig = plt.figure(figsize=(len(fcastTimes)*6, len(oceLayers)*7))
# One column per forecast time, one row per layer pair (the last row corner is unused)
oceAxes = []
for xcorner in xcorners:
  axColumn = []
  for ycorner in ycorners[:-1]:
    axColumn.append(oceFig.add_axes([xcorner, ycorner, xWidth-buffer, yWidth-buffer], projection=ccrs.PlateCarree()))
  oceAxes.append(axColumn)

# Per-row contour levels; the "Empty" placeholder is replaced on the first forecast time
contourLevs = {layer: ["Empty"] for layer in oceLayers}
boxWindow = dict(xh=slice(lonMin, lonMax), yh=slice(latMin, latMax))

for valid, oceAxColumn in zip(fcastTimes, oceAxes):

  leadHours = int((valid - initTime).total_seconds() / 3600)
  fhour = f"{leadHours:03d}"
  oceData = xr.open_dataset(f"{dataPath}mom6Output/{subfolder}mom6_{initStr}{initHour}_f{fhour}.nc", decode_times=False)
  timeText = valid.strftime('%m-%d %HUTC')
  stormLat, stormLon = tcLocs[valid]

  # Shallower member of each pair: layer 0 for "top", the preceding layer for "incremental"
  upperLayers = {"top": [0]*len(oceLayers[:-1]), "incremental": oceLayers[:-1]}[diffType]
  for upperLayer, lowerLayer, oceAx in zip(upperLayers, oceLayers[1:], oceAxColumn):

    upperDepth = round(oceData.z_l.data[upperLayer], 1)
    lowerDepth = round(oceData.z_l.data[lowerLayer], 1)
    pairText = f"T(z={upperDepth}) - T(z={lowerDepth})"

    upperSlice = oceData.isel(z_l=upperLayer, time=0).sel(**boxWindow)
    lowerSlice = oceData.isel(z_l=lowerLayer, time=0).sel(**boxWindow)
    tdiff = upperSlice.temp.data - lowerSlice.temp.data

    finiteDiff = tdiff[~np.isnan(tdiff)]
    maxContour = np.abs(finiteDiff).max() + 0.5
    if "Empty" in contourLevs[lowerLayer]:
      contourLevs[lowerLayer] = np.round(np.linspace(-maxContour, maxContour, 20), 2)

    tempContour = oceAx.contourf(lowerSlice.xh.data, lowerSlice.yh.data, tdiff, cmap="bwr", transform=ccrs.PlateCarree(),
                                 extent=[lonMin, lonMax, latMin, latMax], levels=contourLevs[lowerLayer])
    if valid == fcastTimes[-1]:
      # Small colorbar just right of the last column's panel
      corners = oceAx.get_position().get_points()
      cax = oceFig.add_axes([corners[1, 0]+0.02, corners[0, 1], 0.01, 0.05])
      oceFig.colorbar(tempContour, cax=cax, shrink=0.4, label=pairText)
    if valid == fcastTimes[0]:
      oceAx.set_ylabel(f"{upperDepth} m - {lowerDepth} m")

    oceAx.scatter(stormLon, stormLat, marker="*", color="black", s=200, transform=ccrs.PlateCarree())

    labelSides = [side for side, wanted in (("bottom", lowerLayer == oceLayers[-1]), ("left", valid == fcastTimes[0])) if wanted]
    oceAx.gridlines(draw_labels=labelSides, alpha=0.5)

    oceAx.add_feature(cft.COASTLINE)
    oceAx.add_feature(cft.BORDERS)
    oceAx.set_title(pairText + "\n" + timeText, fontsize=20)

    oceAx.set_extent([lonMin, lonMax, latMin, latMax])

#oceFig.suptitle(f"Initialized {initTime.strftime('%Y-%m-%d %HUTC')}")

Ocean Salinity Difference by Level

In [None]:
oceFig = plt.figure(figsize=(len(fcastTimes)*6,len(oceLayers)*7))
oceAxes = [[oceFig.add_axes([xcorner, ycorner, xWidth-buffer, yWidth-buffer], projection=ccrs.PlateCarree()) for ycorner in ycorners[:-1]] for xcorner in xcorners]

# Fixed salinity-difference levels (psu) shared by every panel. A data-driven alternative:
# np.round(np.linspace(-maxAbs, maxAbs, 20), 2) with maxAbs = np.nanmax(np.abs(sdiff)) + 0.5
salinityLevs = np.linspace(-2.5, 2.5, 16)

for valid, oceAxColumn in zip(fcastTimes, oceAxes):

  fhour = f"{int((valid-initTime).total_seconds() / 3600):03d}"  # zero-padded forecast hour, e.g. "012"

  oceFile = "mom6_" + initStr + initHour + "_f" + fhour + ".nc"
  ocePath = dataPath + "mom6Output/" + subfolder + oceFile
  oceData = xr.open_dataset(ocePath, decode_times=False)

  # Shallower member of each pair: layer 0 for "top", the preceding layer for "incremental"
  upperLayers = {"top":[0]*len(oceLayers[:-1]), "incremental":oceLayers[:-1]}[diffType]
  for upperLayer, lowerLayer, oceAx in zip(upperLayers, oceLayers[1:], oceAxColumn):

    upperSlice = oceData.isel(z_l=upperLayer, time=0).sel(xh=slice(lonMin, lonMax), yh=slice(latMin, latMax))
    lowerSlice = oceData.isel(z_l=lowerLayer, time=0).sel(xh=slice(lonMin, lonMax), yh=slice(latMin, latMax))

    sdiff = upperSlice.so.data - lowerSlice.so.data  # salinity, upper minus lower layer

    tempContour = oceAx.contourf(lowerSlice.xh.data, lowerSlice.yh.data, sdiff, cmap="PRGn", transform=ccrs.PlateCarree(),
                                extent=[lonMin, lonMax, latMin, latMax], levels=salinityLevs)
    if valid == fcastTimes[-1]:
      cax = oceFig.add_axes([oceAx.get_position().get_points()[1, 0]+0.02, oceAx.get_position().get_points()[0, 1], 0.01, 0.05])
      # Raw string: "\D" in a normal literal is an invalid escape sequence (SyntaxWarning on recent Python)
      oceFig.colorbar(tempContour, cax=cax, shrink=0.4, label=r"$\Delta$psu")
    if valid == fcastTimes[0]: oceAx.set_ylabel(f"{round(oceData.z_l.data[upperLayer],1)} m - {round(oceData.z_l.data[lowerLayer],1)} m")

    oceAx.scatter(tcLocs[valid][1], tcLocs[valid][0], marker="*", color="black", s=200, transform=ccrs.PlateCarree())

    gridLabels = []
    if lowerLayer == oceLayers[-1]: gridLabels.append("bottom")
    if valid == fcastTimes[0]: gridLabels.append("left")
    oceAx.gridlines(draw_labels=gridLabels, alpha=0.5)

    oceAx.add_feature(cft.COASTLINE)
    oceAx.add_feature(cft.BORDERS)
    oceAx.set_title(f"S(z={round(oceData.z_l.data[upperLayer],1)}) - S(z={round(oceData.z_l.data[lowerLayer],1)})" + "\n" + valid.strftime('%m-%d %HUTC'), fontsize=20)

    oceAx.set_extent([lonMin, lonMax, latMin, latMax])

#oceFig.suptitle(f"Initialized {initTime.strftime('%Y-%m-%d %HUTC')}")

Map of MLD

In [None]:
oceFig = plt.figure(figsize=(len(fcastTimes)*5, 6))
oceAxes = [oceFig.add_axes([xcorner, 0.1, xWidth-buffer, 0.8], projection=ccrs.PlateCarree()) for xcorner in xcorners]

# Fixed MLD contour levels (m) shared by every panel.
# Data-driven alternative: np.round(np.linspace(np.nanmin(mld)-1, np.nanmax(mld)+1, 15))
mldLevs = np.linspace(1, 81, 17)

for valid, oceAx in zip(fcastTimes, oceAxes):

  fhour = f"{int((valid-initTime).total_seconds() / 3600):03d}"  # zero-padded forecast hour, e.g. "012"

  oceFile = "mom6_" + initStr + initHour + "_f" + fhour + ".nc"
  ocePath = dataPath + "mom6Output/" + subfolder + oceFile
  oceData = xr.open_dataset(ocePath, decode_times=False)

  # MLD_003 is treated as a 2-D (yh, xh) field once the time is selected, so no
  # vertical level is chosen here.
  oceSlice = oceData.sel(xh=slice(lonMin, lonMax), yh=slice(latMin, latMax), time=oceData.time.data[0])

  mld = oceSlice.MLD_003.data

  tempContour = oceAx.contourf(oceSlice.xh.data, oceSlice.yh.data, mld, cmap="cividis_r", transform=ccrs.PlateCarree(),
                              extent=[lonMin, lonMax, latMin, latMax], levels=mldLevs)
  if valid == fcastTimes[-1]:
    cax = oceFig.add_axes([oceAx.get_position().get_points()[1, 0]+0.02, oceAx.get_position().get_points()[0, 1], 0.01, 0.3])
    oceFig.colorbar(tempContour, cax=cax, shrink=0.5, label="MLD (m)")

  oceAx.scatter(tcLocs[valid][1], tcLocs[valid][0], marker="*", color="black", s=100, transform=ccrs.PlateCarree())

  # Single row of panels: every panel gets bottom labels, only the first column gets left labels
  gridLabels = ["bottom"]
  if valid == fcastTimes[0]: gridLabels.append("left")
  oceAx.gridlines(draw_labels=gridLabels, alpha=0.5)

  oceAx.add_feature(cft.COASTLINE)
  oceAx.add_feature(cft.BORDERS)
  oceAx.set_title(f"{valid.strftime('%Y-%m-%d %HUTC')}")

  oceAx.set_extent([lonMin, lonMax, latMin, latMax])

Map of ocean depth (assumes the shallowest level with a NaN temperature marks the bottom of the ocean)

In [None]:
# Estimate the model ocean depth at each grid point as the shallowest z_l level whose
# temperature is NaN (assumed to lie below the sea floor); columns with no NaN get the deepest level.
# The ocean file for the last plotted forecast time is loaded here so this cell does not
# depend on which earlier cell last assigned oceData.
fhour = f"{int((fcastTimes[-1]-initTime).total_seconds() / 3600):03d}"
ocePath = dataPath + "mom6Output/" + subfolder + "mom6_" + initStr + initHour + "_f" + fhour + ".nc"
oceData = xr.open_dataset(ocePath, decode_times=False)
oceSlice = oceData.sel(xh=slice(lonMin, lonMax), yh=slice(latMin, latMax), time=oceData.time.data[0])

zLevels = oceSlice.z_l.data
isMissing = np.isnan(oceSlice.temp.transpose("z_l", "yh", "xh").data)  # (z, lat, lon)
# Depth of each NaN cell, +inf elsewhere; the minimum over z is the shallowest NaN level
nanDepths = np.where(isMissing, zLevels[:, np.newaxis, np.newaxis], np.inf).min(axis=0).astype(float)
nanDepths[~isMissing.any(axis=0)] = zLevels.max()

In [None]:
# Map of the depth estimate computed in the previous cell (nanDepths on the oceSlice grid)
fig = plt.figure(figsize=(6, 4))
ax = fig.add_axes([0.1,0.1,0.8,0.8], projection=ccrs.PlateCarree())
contour = ax.contourf(oceSlice.xh.data, oceSlice.yh.data, nanDepths, transform=ccrs.PlateCarree(), cmap="plasma_r",
             extent=[lonMin, lonMax, latMin, latMax], extend="both")
ax.coastlines()
ax.set_title("Depth of shallowest NaN temperature level")
# Trailing semicolon suppresses the Colorbar object's text repr in the cell output
fig.colorbar(contour, shrink=0.5, label="Depth (m)");

Plot of Vertical Model Resolution

In [None]:
resFig = plt.figure(figsize=(5, 8))
atmAx = resFig.add_axes([0.1, 0.55, 0.8, 0.4])
oceAx = resFig.add_axes([0.1, 0.05, 0.8, 0.4])

# Atmosphere: pressure of each output level against its 1-based level number
pressures = atmData.isobaricInhPa.data
atmLevelNum = np.arange(1, len(pressures) + 1)
atmAx.plot(atmLevelNum, pressures, label="Atmosphere")
atmAx.scatter(atmLevelNum, pressures)
atmAx.grid(alpha=0.5)
atmAx.set_xlabel("# of Levels")
atmAx.set_ylabel("Pressure (hPa)")
atmAx.invert_yaxis()

# Ocean: depth of each level in the upper 150 m against its 1-based level number
oceZ = oceData.z_l.data[oceData.z_l.data <= 150]
oceLevelNum = np.arange(1, len(oceZ) + 1)
oceAx.plot(oceLevelNum, oceZ, label="Ocean")
oceAx.scatter(oceLevelNum, oceZ)
oceAx.set_xticks(np.arange(1, len(oceZ), 3))
oceAx.grid(alpha=0.5)
oceAx.set_xlabel("# of Levels")
oceAx.set_ylabel("Depth (m)")
oceAx.invert_yaxis()

# Air-Sea Fluxes

Parameters for flux figures

In [None]:
# Rows of the flux figure: atmVars come from the HAFS-A files, oceVars from the MOM6 files
atmVars = ["Total Heat Flux", "Latent Heat Flux", "Wind Speed"]
oceVars = []
vars = atmVars + oceVars  # NOTE: shadows the built-in vars(); later cells rely on this name

varUnits = {"Momentum Flux":"N m$^{-2}$", "Latent Heat Flux":"W m$^{-2}$", "Total Heat Flux":"W m$^{-2}$", "SST":"K", "Wind Speed":"m s$^{-1}$", "MLD":"m"}
cmaps = {"Momentum Flux":"cividis", "Latent Heat Flux":"summer_r", "Total Heat Flux":"cool_r", "SST":"coolwarm", "Wind Speed":"viridis", "MLD":"cividis_r"}
contourLevs = {"Wind Speed":np.linspace(0,80,17), "Total Heat Flux": np.linspace(-3300, 0, 12), "Latent Heat Flux":np.linspace(-2100, 100, 12),
               "Momentum Flux": np.linspace(-12, 0, 13), "SST": np.linspace(300.0, 303.5, 15), "MLD":np.linspace(1, 66, 14)}

# Axes layout: columns = forecast times, rows = variables (vars[0] in the bottom row)
buffer = 0.1 / len(vars)
xWidth = 0.8 / len(fcastTimes)
yWidth = 0.8 / len(vars)
# Corners come from integer counts: np.arange(0.1, 0.8, yWidth) yields too few rows once
# there are 8 or more variables (e.g. only 8 rows for 9 variables).
xcorners = 0.1 + xWidth * np.arange(len(fcastTimes))
ycorners = 0.1 + yWidth * np.arange(len(vars))

mapRadius = 0.8  # half-width (degrees) of the storm-centred box used for the flux maps

In [None]:
# Storm-centred fields for every output time:
#   allData[valid][var]                 -> 2-D array over the box
#   allDims[model]["lat"/"lon"][valid]  -> matching 1-D coordinates ("atm" or "oce")
allData = {valid:{} for valid in allTimes}
allDims = {"atm":{"lat":{}, "lon":{}}, "oce":{"lat":{}, "lon":{}}}

for valid in allTimes:

  fhour = f"{int((valid-initTime).total_seconds() / 3600):03d}"  # zero-padded forecast hour, e.g. "012"

  # Box of half-width mapRadius around the storm. Local names keep the fixed-domain
  # lonMin/lonMax/latMin/latMax used by the level maps from being overwritten.
  stormLat, stormLon = tcLocs[valid]
  boxLonMin, boxLonMax = stormLon - mapRadius, stormLon + mapRadius
  boxLatMin, boxLatMax = stormLat - mapRadius, stormLat + mapRadius

  atmPath = dataPath + "hafsaOutput/" + subfolder + "hafsa_" + initStr + initHour + "_f" + fhour + ".nc"
  # .data reads the values into memory, so the file can be closed when the block ends
  with xr.open_dataset(atmPath, decode_timedelta=False) as atmData:
    # File longitudes are 0-360, hence the +360 on the west-negative box bounds
    atmSlice = atmData.sel(longitude=slice(boxLonMin+360, boxLonMax+360), latitude=slice(boxLatMin, boxLatMax))
    # Fluxes are sign-flipped (x -1); the flux contour levels are negative, so upward
    # (ocean-to-air) flux presumably plots as negative — confirm the file's sign convention.
    # "Total" = slhtf + ishf + sulwrf only (latent, sensible, upward long-wave, per GRIB short names).
    allData[valid]["Total Heat Flux"] = (atmSlice.slhtf.data + atmSlice.ishf.data + atmSlice.sulwrf.data) * -1
    allData[valid]["Latent Heat Flux"] = atmSlice.slhtf.data * -1
    allData[valid]["Momentum Flux"] = (atmSlice.utaua.data**2 + atmSlice.vtaua.data**2)**0.5
    allData[valid]["SST"] = atmSlice.sst.data
    # Wind speed at the 1000 hPa level
    allData[valid]["Wind Speed"] = (atmSlice.u.sel(isobaricInhPa=1000.0).data**2 + atmSlice.v.sel(isobaricInhPa=1000.0).data**2)**0.5

    allDims["atm"]["lat"][valid] = atmSlice.latitude.data
    allDims["atm"]["lon"][valid] = atmSlice.longitude.data

  if oceVars:
    ocePath = dataPath + "mom6Output/" + subfolder + "mom6_" + initStr + initHour + "_f" + fhour + ".nc"
    with xr.open_dataset(ocePath, decode_times=False) as oceData:
      oceSlice = oceData.sel(xh=slice(boxLonMin, boxLonMax), yh=slice(boxLatMin, boxLatMax), time=oceData.time.data[0])
      allData[valid]["MLD"] = oceSlice.MLD_003.data #Using rho=0.03 as criteria
      allDims["oce"]["lat"][valid] = oceSlice.yh.data
      allDims["oce"]["lon"][valid] = oceSlice.xh.data

In [None]:
atmFig = plt.figure(figsize=(len(fcastTimes)*4,len(vars)*4))
atmAxes = [[atmFig.add_axes([xcorner, ycorner, xWidth-buffer, yWidth-buffer], projection=ccrs.PlateCarree()) for ycorner in ycorners] for xcorner in xcorners]

for valid, atmAxColumn in zip(fcastTimes, atmAxes):
  # Storm-centred extent for this column (local name keeps the global lonMin/... untouched)
  stormLat, stormLon = tcLocs[valid]
  boxExtent = [stormLon-mapRadius, stormLon+mapRadius, stormLat-mapRadius, stormLat+mapRadius]

  for var, atmAx in zip(vars, atmAxColumn):

    data = allData[valid][var]
    # Atmospheric and ocean fields are on different grids; use the matching coordinates
    model = "atm" if var in atmVars else "oce"
    contour = atmAx.contourf(allDims[model]["lon"][valid], allDims[model]["lat"][valid], data, cmap=cmaps[var], transform=ccrs.PlateCarree(), levels=contourLevs[var])

    # ycorners increase upward, so vars[0] is the bottom row; the first time is the left column
    labelSides = []
    if var == vars[0]: labelSides.append("bottom")
    if valid == fcastTimes[0]: labelSides.append("left")
    atmAx.gridlines(draw_labels=labelSides, alpha=0.5)

    atmAx.set_title(valid.strftime('%m-%d %HUTC'), fontsize=16)
    atmAx.add_feature(cft.COASTLINE)
    atmAx.add_feature(cft.BORDERS)

    if valid == fcastTimes[-1]:
      # Colorbar to the right of the last column, ticked at every other contour level
      corners = atmAx.get_position().get_points()
      cax = atmFig.add_axes([corners[1, 0]+0.03, corners[0, 1]+0.02, 0.01, 0.14])
      cbar = atmFig.colorbar(contour, cax=cax, ticks=contourLevs[var][::2])
      cbar.set_label(label=f"{var} ({varUnits[var]})", fontsize=16)

    atmAx.set_extent(boxExtent)

In [None]:
# Pixel-wise correlation through time between each flux and an environmental variable.
# Only atmospheric variables are used: ocean fields are on a different grid (resolution).
# Grid index (i, j) is relative to each time's storm-centred box, so the maps are storm-relative.
correlation_pairs = [
    ("Total Heat Flux", "SST"),
    ("Latent Heat Flux", "SST"),
    ("Total Heat Flux", "Wind Speed"),
    ("Latent Heat Flux", "Wind Speed")]
corrVars = sorted({var for pair in correlation_pairs for var in pair})  # deterministic order

# Boxes centred on different storm positions can differ by a grid point, so use the
# shape shared by every time (indexing past a smaller box would raise IndexError).
lat_dim = min(len(allDims["atm"]["lat"][valid]) for valid in allTimes)
lon_dim = min(len(allDims["atm"]["lon"][valid]) for valid in allTimes)

# (time, lat, lon) stack per variable, cropped to the common shape
stacks = {var: np.stack([allData[valid][var][:lat_dim, :lon_dim] for valid in allTimes]) for var in corrVars}
# True where every variable is non-NaN at that time and grid point
allFinite = np.logical_and.reduce([~np.isnan(stacks[var]) for var in corrVars])

# Correlation maps keyed by the lower-cased names run together, e.g. "total heat fluxwind speed"
correlation_data = {fluxVar.lower()+envVar.lower(): np.full((lat_dim, lon_dim), np.nan)
                    for fluxVar, envVar in correlation_pairs}

for i in range(lat_dim):
    for j in range(lon_dim):
        keep = allFinite[:, i, j]
        if keep.sum() < 2:  # a correlation needs at least two valid times
            continue
        for fluxVar, envVar in correlation_pairs:
            # x -1 offsets the sign flip applied to the fluxes when they were loaded
            # (presumably so positive r means a stronger flux with larger envVar — confirm)
            r = np.corrcoef(stacks[fluxVar][:, i, j][keep], stacks[envVar][:, i, j][keep])[0, 1]
            correlation_data[fluxVar.lower()+envVar.lower()][i, j] = r * -1

In [None]:
# Overlay flux-vs-envVar correlation contours and a storm-centred box on the flux figure
envVar = "wind speed"  # lower-case name, as used in the correlation_data keys
overlayVars = ("Total Heat Flux", "Latent Heat Flux")
mapAxes = {mapAx for axColumn in atmAxes for mapAx in axColumn}

for valid, axColumn in zip(fcastTimes, atmAxes):
  lons = allDims["atm"]["lon"][valid]
  lats = allDims["atm"]["lat"][valid]
  stormLat, stormLon = tcLocs[valid]

  # Side length (degrees) of the box centred on the storm, or None for no box.
  # NOTE(review): for radType == "wind" the side equals the wind radius (half-width = radius/2) — confirm intent.
  boxSide = {"wind": windRads[valid], "dist": dist}.get(radType)

  for var, ax in zip(vars, axColumn):
    if var in overlayVars:
      corr = correlation_data[var.lower()+envVar.lower()]
      # Trim to the shape shared by this time's coordinates and the correlation map
      ny, nx = min(len(lats), corr.shape[0]), min(len(lons), corr.shape[1])
      c = ax.contour(lons[:nx], lats[:ny], corr[:ny, :nx], transform=ccrs.PlateCarree(),
                     levels=np.arange(0.3, 1, 0.3), colors=["blue", "purple", "red"])
      ax.clabel(c)

    if boxSide is not None:
      rect = patches.Rectangle((stormLon - boxSide*0.5, stormLat - boxSide*0.5), boxSide, boxSide,
                               linewidth=1, edgecolor='k', facecolor='none', transform=ccrs.PlateCarree())
      ax.add_patch(rect)

# Enlarge the colorbar tick labels. plt.xticks/plt.yticks would act on pyplot's current figure,
# which under Jupyter's inline backend (figures closed after each cell) is a new, blank one,
# so target atmFig's non-map (colorbar) axes directly.
for figAx in atmFig.axes:
  if figAx not in mapAxes:
    figAx.tick_params(labelsize=14)

In [None]:
# Display the flux figure again so the correlation contours and boxes added in the previous cell are drawn
atmFig