<a href="https://colab.research.google.com/github/SeanBarnier/HAFS_Air-Sea/blob/main/fluxes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Set up environment

In [None]:
!pip install cfgrib
!pip install seawater

In [None]:
import os
from datetime import datetime as dt

import cfgrib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seawater
import xarray as xr

In [None]:
# Mount Google Drive (Colab only) so the saved model output and figure folders
# under /content/drive are reachable from this notebook.
from google.colab import drive
drive.mount('/content/drive')

# User parameters

In [None]:
# --- Analysis window --------------------------------------------------------
initTime = dt(2024, 10, 7, 12)  # Time when Milton began its most rapid intensification
forecastLength = 48             # Normally 126 for HAFS-A.
fHourStep = 3                   # Normally 3 for HAFS-A

# --- Storm identification ---------------------------------------------------
name = "Milton"
tcNum = "14"     # storm number; used to build the ATCF track-file URL
trackType = ""

# --- File locations (Google Drive) ------------------------------------------
dataPath = "/content/drive/MyDrive/savedData/"
figurePath = "/content/drive/MyDrive/figures/"
subfolder = "RI/"
figureSuffix = "_RI"

# --- Sampling options -------------------------------------------------------
stormCentered = True
potentialTemp = True  # Use atmospheric potential temperature instead of in-situ temperature
# NOTE(review): stormCentered, potentialTemp, trackType, name, figurePath and
# figureSuffix are not read by any cell visible in this notebook - confirm
# whether they are still needed.

atmTop = 850     # hPa; only pressure levels strictly greater than this are read
oceFloor = 100   # only MOM6 layers with z_l <= oceFloor are read (presumably metres)

avgType = 'NE'   # one of "NE", "SE", "SW", "NW", "centered", "point"
areaRange = 0.5  # side length of the averaging box, in degrees

# Retrieve HAFS-A Data

Find the forecast valid times needed.

In [None]:
dateFormat = "%Y-%m-%d %H:%M:%S"
runFormat = "%Y%m%d%H"

# A non-positive step never advances toward forecastLength, so a step-by-step
# loop would never finish; fail loudly instead.
if fHourStep <= 0:
    raise ValueError(f"fHourStep must be a positive number of hours, got {fHourStep!r}")

# Chronological list of forecast valid times: initTime, initTime + fHourStep h,
# ... up to and including forecastLength hours when it is a multiple of fHourStep.
fcastTimes = [initTime + pd.Timedelta(hours=hour)
              for hour in range(0, forecastLength + 1, fHourStep)]

Find storm location in HAFS-A from ATCF files.

In [None]:
# Column names for the HAFS-A ATCF track file, which has no header row.
# NOTE(review): there are presumably several rows per TAU (e.g. one per wind-radius
# threshold); the next cell takes the first row for each TAU.
cols = ["BASIN", "CY", "YYYYMMDDHH", "TECHNUM/MIN", "TECH", "TAU", "LatN/S", "LonE/W",
    "VMAX", "MSLP", "TY", "RAD", "WINDCODE", "RAD1", "RAD2", "RAD3", "RAD4",
    "POUTER", "ROUTER", "RMW", "GUSTS", "EYE", "SUBREGION", "MAXSEAS", "INITIALS",
    "DIR", "SPEED", "STORMNAME", "DEPTH", "SEAS", "SEASCODE", "SEAS1", "SEAS2",
    "SEAS3", "SEAS4", "USERDEFINED1", "Thermo1", "Thermo2", "Thermo3", "Thermo4",
    "Thermo5", "Thermo6", "Thermo7", "USERDEFINED2", "DT", "SHR82", "SHR81_1",
    "SHR82_2",  "USERDEFINED3", "SST", "USERDEFINED4", "ARMW1", "ARMW2"]

# Date ("YYYYMMDD") and hour ("HH") strings for the S3 object path.
# Note: initStr is reassigned to "YYYYMMDDHH" in the HAFS-A output cell below.
initStr, initHour = initTime.strftime("%Y%m%d_%H").split("_")

# The "l" after the storm number is presumably the Atlantic basin letter - TODO
# confirm before using this for other basins.
atcfURL = f"https://noaa-nws-hafs-pds.s3.amazonaws.com/hfsa/{initStr}/{initHour}/{tcNum}l.{initStr}{initHour}.hfsa.trak.atcfunix"
atcfFile = "atcf_" + initStr + "_" + initHour + ".csv"

# Download (overwriting any local copy) and parse as headerless CSV.
!wget -O {atcfFile} {atcfURL}
atcf = pd.read_csv(atcfFile, names=cols)

In [None]:
def _atcf_coord(raw):
    """Convert an ATCF lat/lon field to signed decimal degrees.

    The field is an integer number of tenths of a degree followed by a
    hemisphere letter, e.g. "235N" -> 23.5 and "872W" -> -87.2. South and
    West are negative, so storms outside the N/W quadrant are handled too.
    """
    text = str(raw).strip()
    magnitude = int(text[:-1]) / 10
    return -magnitude if text[-1].upper() in ("S", "W") else magnitude


tcLocs = {}  # valid time -> (lat, lon) of the forecast storm centre
vmax = {}    # valid time -> VMAX from the track file
mslp = {}    # valid time -> MSLP from the track file

for valid in fcastTimes:
    fHour = int((valid - initTime).total_seconds() / 3600)

    # Filter the track once per valid time and reuse the selection.
    rows = atcf[atcf.TAU == fHour]
    if rows.empty:
        raise KeyError(f"No ATCF track entry for TAU={fHour} h (valid {valid})")

    tcLocs[valid] = (_atcf_coord(rows["LatN/S"].iloc[0]), _atcf_coord(rows["LonE/W"].iloc[0]))
    vmax[valid] = rows["VMAX"].iloc[0]
    mslp[valid] = rows["MSLP"].iloc[0]

Get data from HAFS-A output.

In [None]:
# Sanity check: display the forecast valid times built above.
fcastTimes

In [None]:
atm = {}  # valid time -> pressure level -> {variable: area-mean value}

# Run tag used in the saved file names, e.g. "2024100712". It does not change
# between valid times, so compute it once; later cells also read this name.
initStr = initTime.strftime("%Y%m%d%H")

for valid in fcastTimes:
  atm[valid] = {}

  # Zero-padded three-digit forecast hour, e.g. 6 -> "006".
  fhour = f"{int((valid - initTime).total_seconds() / 3600):03d}"
  atmFile = "hafsa_" + initStr + "_f" + fhour + ".nc"
  atmPath = dataPath + "hafsaOutput/" + subfolder + atmFile

  point = tcLocs[valid]  # (lat, lon) with lon in -180..180

  # Averaging box [lat_min, lon_min, lat_max, lon_max] around this valid time's
  # storm centre; None when sampling the single nearest grid point.
  # TODO: "wind" option - convert a radius in n mi to degrees (does it work for longitude?)
  half = 0.5 * areaRange
  area = {
      "NE": [point[0], point[1], point[0] + areaRange, point[1] + areaRange],
      "SE": [point[0] - areaRange, point[1], point[0], point[1] + areaRange],
      "SW": [point[0] - areaRange, point[1] - areaRange, point[0], point[1]],
      "NW": [point[0], point[1] - areaRange, point[0] + areaRange, point[1]],
      "centered": [point[0] - half, point[1] - half, point[0] + half, point[1] + half],
      "point": None,
  }[avgType]

  # Close each file as soon as its values have been extracted.
  with xr.open_dataset(atmPath) as atmData:
    # Pressure levels strictly greater than atmTop hPa (atmTop itself excluded).
    levels = atmData.isobaricInhPa.data[atmData.isobaricInhPa.data > atmTop]

    for level in levels:
      # Longitudes in the atm files are degrees east, while `point` (and the
      # MOM6 files) use -180..180, hence the +360 shift.
      if avgType == "point":
        validPoint = atmData.sel(latitude=point[0], longitude=point[1] + 360, isobaricInhPa=level, method="nearest")
      else:
        validPoint = atmData.sel(latitude=slice(area[0], area[2]), longitude=slice(area[1] + 360, area[3] + 360), isobaricInhPa=level)

      atm[valid][level] = {
          "T": np.mean(validPoint.t.data),
          "q": np.mean(validPoint.q.data) * 1000,  # kg/kg -> g/kg
          "u": np.mean(validPoint.u.data),
          "v": np.mean(validPoint.v.data),
          "gh": np.mean(validPoint.gh.data),
          "sst": np.mean(validPoint.sst.data),
          "shf": np.mean(validPoint.ishf.data),
          "lhf": np.mean(validPoint.slhtf.data),
          # Magnitude of the box-mean stress vector (components averaged first).
          "tau": (np.mean(validPoint.utaua.data)**2 + np.mean(validPoint.vtaua.data)**2)**0.5,
      }

In [None]:
oce = {}  # valid time -> MOM6 layer -> {variable: area-mean value}

# Run tag for the MOM6 file names, computed here rather than relying on the
# `initStr` left behind by another cell.
oceRunStr = initTime.strftime("%Y%m%d%H")

# Every layer stores these keys; valid times whose file is missing get NaN for all of them.
oceKeys = ("T", "s", "u", "v", "sst", "ssh", "shf", "lhf", "tau")
missingTimes = []  # valid times with no MOM6 file on disk
oceLevels = None   # layers of the first file read, reused for the NaN fill

for valid in fcastTimes:
  fhour = f"{int((valid - initTime).total_seconds() / 3600):03d}"
  oceFile = "mom6_" + oceRunStr + "_f" + fhour + ".nc"
  ocePath = dataPath + "mom6Output/" + subfolder + oceFile

  # Reserve the slot now so oce keeps the same order as fcastTimes.
  oce[valid] = {}

  # Some files are missing (e.g. mom6_2024100800_f000.nc); they are filled with
  # NaN after the loop, once the layer list is known.
  if not os.path.exists(ocePath):
    missingTimes.append(valid)
    continue

  # Averaging box around THIS valid time's storm centre. MOM6 longitudes use
  # -180..180 like `point`, so there is no +360 shift here.
  point = tcLocs[valid]
  half = 0.5 * areaRange
  area = {
      "NE": [point[0], point[1], point[0] + areaRange, point[1] + areaRange],
      "SE": [point[0] - areaRange, point[1], point[0], point[1] + areaRange],
      "SW": [point[0] - areaRange, point[1] - areaRange, point[0], point[1]],
      "NW": [point[0], point[1] - areaRange, point[0] + areaRange, point[1]],
      "centered": [point[0] - half, point[1] - half, point[0] + half, point[1] + half],
      "point": None,
  }[avgType]

  with xr.open_dataset(ocePath, decode_times=False) as oceData:
    # All layers with z_l <= oceFloor; the shallowest is expected to be ~1 m.
    levels = oceData.z_l.data[oceData.z_l.data <= oceFloor]
    if levels.size == 0:
      continue
    if oceLevels is None:
      oceLevels = levels

    # Horizontal selection shared by the surface fields and every layer.
    if avgType == "point":
      horiz = dict(yh=point[0], yq=point[0], xq=point[1], xh=point[1])
      sfcPoint = oceData.sel(**horiz, method="nearest")
    else:
      horiz = dict(yq=slice(area[0], area[2]), yh=slice(area[0], area[2]),
                   xq=slice(area[1], area[3]), xh=slice(area[1], area[3]))
      sfcPoint = oceData.sel(**horiz)

    # Surface fields do not depend on the layer, so compute them once per file.
    sfc = {
        "sst": np.mean(sfcPoint.SST.data) + 273.15,  # C -> K
        "ssh": np.mean(sfcPoint.SSH.data),
        "shf": np.mean(sfcPoint.sensible.data),
        "lhf": np.mean(sfcPoint.latent.data),
        # NOTE(review): flagged by the author as probably wrong - this is
        # rho(SSS, SST) * |surface current|^2 with no drag coefficient; confirm
        # the intended stress definition before relying on it.
        "tau": float(seawater.dens0(np.mean(sfcPoint.SSS.data), np.mean(sfcPoint.SST.data))
                     * (np.mean(sfcPoint.SSU.data)**2 + np.mean(sfcPoint.SSV.data)**2)),
    }

    for level in levels:
      if avgType == "point":
        validPoint = oceData.sel(**horiz, z_l=level, method="nearest")
      else:
        validPoint = oceData.sel(**horiz, z_l=level)

      oce[valid][level] = {
          "T": np.mean(validPoint.temp.data) + 273.15,  # potential temperature, C -> K
          "s": np.mean(validPoint.so.data),
          "u": np.mean(validPoint.uo.data),
          "v": np.mean(validPoint.vo.data),
          **sfc,
      }

# Missing valid times get NaN on the same layers as the files that were read.
if oceLevels is not None:
  for valid in missingTimes:
    oce[valid] = {level: dict.fromkeys(oceKeys, np.nan) for level in oceLevels}

# Figures

Goal: Correlate atmospheric sensible and latent heat fluxes and momentum fluxes with intensity forecasts and intensity errors across runs.  
In particular, look at the wind-driven and current-driven momentum fluxes.

Using the 1-m ocean temperature

**TODO:** Clean up and simplify these figures.

In [None]:
# Ocean (1 m) minus atmosphere (1000 hPa) temperature difference at each valid
# time, plotted against the HAFS-A surface heat fluxes on the right-hand axes.
plotTimes = list(atm.keys())
tDiff = [oce[valid][1.0]["T"] - atm[valid][1000.0]["T"] for valid in plotTimes]
shf = [atm[valid][1000.0]["shf"] for valid in plotTimes]
lhf = [atm[valid][1000.0]["lhf"] for valid in plotTimes]
flux = [atm[valid][1000.0]["lhf"] + atm[valid][1000.0]["shf"] for valid in plotTimes]

# (bottom edge of the panel, flux series, legend label, colour), top panel first.
panels = [
    (0.65, shf, "Sensible Heat Flux", "darkorange"),
    (0.35, lhf, "Latent Heat Flux", "green"),
    (0.05, flux, "Latent + Sensible Heat Flux", "magenta"),
]

fluxFig = plt.figure(figsize=(8,5))
for i, (bottom, series, label, color) in enumerate(panels):
  tAx = fluxFig.add_axes([0.1, bottom, 0.8, 0.25])
  fAx = tAx.twinx()

  # Label the shared dT line only once so the figure legend lists it a single time.
  tAx.plot(plotTimes, tDiff, label="T(1 m) - T(1000 hPa)" if i == 0 else None, color="blue")
  fAx.plot(plotTimes, series, label=label, color=color)

  tAx.grid(alpha=0.5)
  tAx.set_title(f"Correlation = {round(np.corrcoef(tDiff, series)[0, 1], 2)}")
  if i < len(panels) - 1:
    tAx.tick_params(labelbottom=False)  # only the bottom panel shows dates

fluxFig.legend()
# A figure holds a single supylabel (a second call replaces the first), so the
# right-hand flux axis is labelled with free figure text instead.
fluxFig.supylabel(r"Ocean - Atmosphere $\Delta$T (K)")
fluxFig.text(0.99, 0.5, "Heat Flux (W/m$^2$)", rotation=90, ha="right", va="center");

Using the SST

In [None]:
# SST minus 1000 hPa air temperature at each valid time (both from the HAFS-A
# atmosphere files), plotted against the surface heat fluxes on the right-hand axes.
plotTimes = list(atm.keys())
tDiff = [atm[valid][1000.0]["sst"] - atm[valid][1000.0]["T"] for valid in plotTimes]
shf = [atm[valid][1000.0]["shf"] for valid in plotTimes]
lhf = [atm[valid][1000.0]["lhf"] for valid in plotTimes]
flux = [atm[valid][1000.0]["lhf"] + atm[valid][1000.0]["shf"] for valid in plotTimes]

# (bottom edge of the panel, flux series, legend label, colour), top panel first.
panels = [
    (0.65, shf, "Sensible Heat Flux", "darkorange"),
    (0.35, lhf, "Latent Heat Flux", "green"),
    (0.05, flux, "Latent + Sensible Heat Flux", "magenta"),
]

fluxFig = plt.figure(figsize=(8,5))
for i, (bottom, series, label, color) in enumerate(panels):
  tAx = fluxFig.add_axes([0.1, bottom, 0.8, 0.25])
  fAx = tAx.twinx()

  # Label the shared dT line only once so the figure legend lists it a single time.
  tAx.plot(plotTimes, tDiff, label="SST - T(1000 hPa)" if i == 0 else None, color="blue")
  fAx.plot(plotTimes, series, label=label, color=color)

  tAx.grid(alpha=0.5)
  tAx.set_title(f"Correlation = {round(np.corrcoef(tDiff, series)[0, 1], 2)}")
  if i < len(panels) - 1:
    tAx.tick_params(labelbottom=False)  # only the bottom panel shows dates

fluxFig.legend()
# A figure holds a single supylabel (a second call replaces the first), so the
# right-hand flux axis is labelled with free figure text instead.
fluxFig.supylabel(r"Sea Surface - Atmosphere $\Delta$T (K)")
fluxFig.text(0.99, 0.5, "Heat Flux (W/m$^2$)", rotation=90, ha="right", va="center");

Profiles and SST

In [None]:
# Valid times at 00/06/12/18 UTC; each gets one profile panel.
times = [key for key in atm.keys() if key.hour % 6 == 0]
if not times:
  raise ValueError("No valid times fall on a 6-hourly (00/06/12/18 UTC) mark")

# Panels sit in equal slots of width 0.8/len(times) starting at x = 0.1; each
# panel is 0.6/len(times) wide, leaving a gap between neighbours.
xwidth = (0.60/len(times))
# linspace(..., endpoint=False) always yields exactly len(times) left edges; a
# float-step np.arange can yield one extra from rounding (an extra empty panel).
xcorners = np.linspace(0.1, 0.9, len(times), endpoint=False)

In [None]:
# One vertical temperature profile per 6-hourly valid time, with the SST marked
# at 1000 hPa for comparison.
profFig = plt.figure(figsize=(14,4))
profAxes = [profFig.add_axes([xcorner, 0.1, xwidth, 0.8]) for xcorner in xcorners]

for i, (ax, valid) in enumerate(zip(profAxes, times)):
  pressures = list(atm[valid].keys())
  temps = [atm[valid][level]["T"] for level in pressures]
  ax.plot(temps, pressures, color="blue", label="Atm. Temp.")
  ax.scatter(atm[valid][1000.0]["sst"], 1000, color="red", s=50, label="SST")

  ax.set_xlim(290, 310)
  ax.invert_yaxis()  # pressure decreases upward
  ax.grid(alpha=0.5)
  ax.set_title(valid.strftime("%m-%d %HUTC"))

  # Keep the legend and pressure tick labels on the left-most panel only. This
  # is keyed on the panel index, so it also works when the first forecast time
  # is not on a 6-hourly mark.
  if i == 0:
    ax.legend(ncols=2, loc=(0, -0.15))
  else:
    ax.tick_params(labelleft=False)

profFig.supxlabel("Temperature (K)")
profFig.supylabel("Pressure (hPa)");

Look at the evolution of temperature and heat flux.

In [None]:
# Change over each output interval of the 1000 hPa air temperature, the 1 m
# ocean temperature and the SST, plotted against the heat fluxes at the end of
# each interval.
t = list(atm.keys())
steps = list(zip(t[:-1], t[1:]))  # (past, valid) pairs of consecutive valid times
atmT = [atm[valid][1000.0]["T"] - atm[past][1000.0]["T"] for (past, valid) in steps]
sst = [atm[valid][1000.0]["sst"] - atm[past][1000.0]["sst"] for (past, valid) in steps]
oceT = [oce[valid][1.0]["T"] - oce[past][1.0]["T"] for (past, valid) in steps]
shf = [atm[valid][1000.0]["shf"] for valid in t[1:]]
lhf = [atm[valid][1000.0]["lhf"] for valid in t[1:]]
flux = [atm[valid][1000.0]["lhf"] + atm[valid][1000.0]["shf"] for valid in t[1:]]

# (bottom edge of the panel, flux series, legend label, colour), top panel first.
panels = [
    (0.65, shf, "Sensible Heat Flux", "darkorange"),
    (0.35, lhf, "Latent Heat Flux", "green"),
    (0.05, flux, "Latent + Sensible Heat Flux", "magenta"),
]

fluxFig = plt.figure(figsize=(8,5))
for i, (bottom, series, label, color) in enumerate(panels):
  tAx = fluxFig.add_axes([0.1, bottom, 0.8, 0.25])
  fAx = tAx.twinx()

  # Temperature lines are labelled in the top panel only, so each appears once in the legend.
  first = i == 0
  tAx.plot(t[1:], atmT, label=r"$\Delta$T(1000 hPa)" if first else None, color="red")
  tAx.plot(t[1:], oceT, label=r"$\Delta$T(1 m)" if first else None, color="blue")
  tAx.plot(t[1:], sst, label=r"$\Delta$SST" if first else None, color="purple")
  fAx.plot(t[1:], series, label=label, color=color)

  tAx.grid(alpha=0.5)
  if i < len(panels) - 1:
    tAx.tick_params(labelbottom=False)  # only the bottom panel shows dates

fluxFig.legend()
# These are per-interval changes, not atmosphere-ocean differences. A figure
# holds a single supylabel, so the flux axis is labelled with figure text.
fluxFig.supylabel(f"Temperature change per {fHourStep} h (K)")
fluxFig.text(0.99, 0.5, "Heat Flux (W/m$^2$)", rotation=90, ha="right", va="center");

In [None]:
# Correlation matrix of the per-interval temperature changes and the total heat
# flux, labelled so rows/columns are readable (order matches the list below).
corrLabels = ["dT(1000 hPa)", "dT(1 m)", "dSST", "LHF + SHF"]
pd.DataFrame(np.corrcoef([atmT, oceT, sst, flux]), index=corrLabels, columns=corrLabels)

Fluxes with intensity

In [None]:
# Create a figure with a subplot for each run time
fig, axes = plt.subplots(3, 1, figsize=(8, 6), sharex=True)


shf = [atm[time][1000.0]["shf"] for time in atm.keys()]
lhf = [atm[time][1000.0]["lhf"] for time in atm.keys()]
flux = [atm[time][1000.0]["lhf"]+atm[time][1000.0]["shf"] for time in atm.keys()]
intensity = [vmax[time] for time in atm.keys()] # Get intensity values


ax0 = axes[0].twinx()
ax1 = axes[1].twinx()
ax2 = axes[2].twinx()

axes[0].plot(atm.keys(), shf, label='Sensible Heat Flux', color="darkorange")
axes[1].plot(atm.keys(), lhf, label='Latent Heat Flux', color="green")
axes[2].plot(atm.keys(), flux, label='Latent + Sensible Heat Flux', color="magenta")
ax0.plot(atm.keys(), intensity, label='Intensity', color="black")
ax1.plot(atm.keys(), intensity, color="black")
ax2.plot(atm.keys(), intensity, color="black")

ax0.set_ylabel("Intensity (kt)")
ax1.set_ylabel("Intensity (kt)")
ax2.set_ylabel("Intensity (kt)")
axes[0].grid(alpha=0.5)
axes[1].grid(alpha=0.5)
axes[2].grid(alpha=0.5)
ax0.set_title("Sensible Heat Flux")
ax1.set_title("Latent Heat Flux")
ax2.set_title("Latent + Sensible Heat Flux")

# Set common x-label for the last subplot
axes[-1].set_xlabel("Time")
fig.supylabel("Heat Flux (W/m" + "$^2$" + ")")
fig.legend(loc=[0.05, 0.96], ncols=4)
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

Fluxes with $\Delta$Intensity

In [None]:
# Create a figure with a subplot for each run time
fig, axes = plt.subplots(4, 1, figsize=(8, 6), sharex=True)

validTimes = list(atm.keys())[1:]
pastTimes = list(atm.keys())[:-1]

shf = [atm[time][1000.0]["shf"] for time in validTimes]
lhf = [atm[time][1000.0]["lhf"] for time in validTimes]
flux = [atm[time][1000.0]["lhf"]+atm[time][1000.0]["shf"] for time in validTimes]
mf = [atm[time][1000.0]["tau"] for time in validTimes]
dIntensity = [vmax[validTime]-vmax[pastTime] for validTime, pastTime in zip(validTimes, pastTimes)] # Get intensity values

labels = ["Sensible Heat Flux", "Latent Heat Flux", "Latent + Sensible Heat Flux", "Wind Momentum Flux"]
colors = ["darkorange", "green", "magenta", "red"]

for ax, flux, label, color in zip(axes, [shf, lhf, flux, mf], labels, colors):

  intAx = ax.twinx()
  ax.plot(validTimes, flux, label=label, color=color)
  intAx.plot(validTimes, dIntensity, color="black")

  intAx.set_ylabel("Intensity (kt)")
  ax.set_ylabel("Heat Flux (W/m" + "$^2$" + ")")
  ax.grid(alpha=0.5)
  ax.set_title(f"Correlation = {str(round(np.corrcoef(dIntensity, flux)[0,1], 2))}")

# Set common x-label for the last subplot
axes[-1].set_xlabel("Time")
axes[-1].set_ylabel("Mom. Flux (N/m" + "$^2$" + ")")


fig.legend(loc=[0.05, 0.96], ncols=4)
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

Momentum flux is less correlated with the intensity change in this quadrant.