<details>
<summary>Variables</summary>
<br>
Data point

<ul>
<li>longitude - east-west coordinate (degrees): 0.0658</li>
<li>latitude - north-south coordinate (degrees): 0.0341</li>
<li>time - datetime as timestamp</li>

Pressure level (from upper troposphere, primary data)
<li>d - divergence: 0.0208</li>
<li>cc - fraction of cloud cover - 0-1: 0.1779</li>
<li>z - geopotential - m^2/s^2: 0.5817</li>
<li>o3 - ozone mass mixing ratio - kg/kg: -0.3365</li>
<li>pv - potential vorticity: 0.0709</li>
<li>r - relative humidity - %: 0.2829</li>
<li>ciwc - cloud ice water content - kg/kg: 0.4201</li>

This is what I have to predict
<li>q - specific humidity - kg/kg</li>
<li>cswc - cloud snow water content - kg/kg: 0.4252</li>
<li>t - temperature - Kelvin: 0.6219</li>
<li>u - eastward wind - m/s: -0.0637</li>
<li>v - northward wind - m/s: 0.0742</li>
<li>w - vertical velocity - Pa/s: -0.2006</li>
<li>vo - vorticity: -0.0122</li>

Surface level (secondary data)
<li>u10 - 10m eastward wind - m/s: -0.0416</li>
<li>v10 - 10m northward wind - m/s: -0.0398</li>
<li>d2m - 2m dewpoint temperature - Kelvin: 0.5117</li>
<li>fal - Forecast albedo - 0-1: -0.4423</li>
<li>lai_hv - Leaf area index high vegetation: 0.4315</li>
<li>lai_lv - Leaf area index low vegetation: 0.2420</li>
<li>pev - Potential evaporation - m: -0.1723</li>
<li>ro - Runoff - m: 0.3770</li>
<li>skt - Skin temperature - Kelvin: 0.4425</li>
<li>ssr - Surface net solar radiation - Joules/m^2: 0.1212</li>
<li>sp - Surface pressure - Pa: 0.2366</li>
<li>e - Total evaporation - m: -0.3412</li>
<li>tp - Total precipitation - m: 0.3490</li>
</ul>
</details>

In [None]:
#setup
!pip install pycontrails

import pandas as pd
import xarray as xr
import numpy as np
import pycontrails
import seaborn as sns
import matplotlib.pyplot as plt
import importlib.util
import gc

from pycontrails import Flight
from pycontrails.models.cocip import Cocip
from pycontrails import MetDataset, MetDataArray, MetVariable
from pycontrails.models.sac import SAC
from pycontrails.models.issr import ISSR

from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.inspection import permutation_importance
from sklearn.ensemble import RandomForestRegressor

from google.colab import drive
drive.mount('/content/gdrive/', force_remount=True)
dataFrame = None

Collecting pycontrails
  Downloading pycontrails-0.50.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.4 MB)
Collecting overrides>=6.1 (from pycontrails)
  Downloading overrides-7.7.0-py3-none-any.whl (17 kB)
Installing collected packages: overrides, pycontrails
Successfully installed overrides-7.7.0 pycontrails-0.50.0
Mounted at /content/gdrive/


In [None]:
#close vscode CSV viewer before running these functions
savedDataPath = '/content/gdrive/MyDrive/Convex/Saved Data/'
csvNames = {'raw': 'rawDataset.csv', 'processed': 'processedDataset.csv'}

def loadCSV(toLoad):
  #'raw' and 'processed' map to their own files; anything else loads the last saved dataset
  path = savedDataPath + csvNames.get(toLoad, 'lastDataset.csv')
  try:
    return pd.read_csv(path)
  except FileNotFoundError:
    print('CSV not saved yet: ' + path)
    return None

def updateCSV(toUpdate):
  #writes the global dataFrame to the file chosen the same way as in loadCSV
  dataFrame.to_csv(savedDataPath + csvNames.get(toUpdate, 'lastDataset.csv'), index=False)


In [None]:
dataFrame = pd.read_csv('/content/gdrive/MyDrive/Convex/Saved Data/1DayProcessed.csv').dropna()



In [None]:
dataFrame

Unnamed: 0,longitude,latitude,time,u10,v10,d2m,fal,lai_hv,lai_lv,pev,...,zRate,o3Rate,pvRate,rRate,qRate,tRate,uRate,vRate,wRate,voRate
0,0.0,53.5,2022-04-01 9:00,-2.077039,-6.908855,271.59850,0.132501,2.461048,2.195530,-0.001338,...,0.025772,-0.220136,-0.359770,2.878050,3.957481,0.008970,2.041899,0.974710,-0.186826,14.396083
1,0.0,53.0,2022-04-01 9:00,-1.773355,-6.521138,271.71902,0.174736,2.435926,1.764787,-0.001188,...,0.025906,-0.378134,-0.487167,3.603209,3.696060,0.000644,1.331047,0.670987,-0.653069,3.212005
2,0.0,52.5,2022-04-01 9:00,-0.991369,-5.800329,270.89062,0.170515,2.475419,1.644545,-0.001030,...,0.026042,-0.504500,-0.712622,4.950027,4.075321,-0.005760,0.991703,0.515950,-0.478556,0.205993
3,0.0,52.0,2022-04-01 9:00,-1.587004,-7.715757,270.03360,0.192731,0.000000,1.739633,-0.001172,...,0.026114,-0.539149,-0.676435,5.928098,4.377079,-0.009129,0.727864,0.397136,-0.564027,0.690254
4,0.0,51.5,2022-04-01 9:00,-0.996201,-6.828105,271.18735,0.115908,2.393771,1.655431,-0.001189,...,0.026202,-0.602562,-0.709907,5.826116,3.886916,-0.012256,0.268029,0.247983,-18.427923,0.227997
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
88643,359.5,-88.0,2022-04-01 9:00,-3.574753,-5.045676,217.31110,0.849998,0.000000,0.000000,0.000002,...,0.014170,-0.475331,-0.367362,2.707272,0.962927,-0.022267,-1.696121,-0.271042,2.071103,-0.921169
88644,359.5,-88.5,2022-04-01 9:00,-4.048914,-4.713771,216.64687,0.849998,0.000000,0.000000,0.000002,...,0.014734,-0.480531,-0.478450,3.207823,1.172515,-0.023157,-1.280058,-0.377176,0.840897,-1.659065
88645,359.5,-89.0,2022-04-01 9:00,-4.304285,-4.858051,216.16791,0.849998,0.000000,0.000000,0.000002,...,0.015057,-0.465895,-0.477099,3.165981,1.111289,-0.023993,-0.552166,-0.463029,-1.615731,-1.525995
88646,359.5,-89.5,2022-04-01 9:00,-3.841857,-5.656048,216.24020,0.849998,0.000000,0.000000,0.000002,...,0.015012,-0.487575,-0.423617,2.861046,0.883419,-0.024935,-0.007557,-0.402917,-0.221665,-1.257325


In [None]:
dataFrame

Unnamed: 0,longitude,latitude,time,u10,v10,d2m,fal,lai_hv,lai_lv,pev,...,zRate,o3Rate,pvRate,rRate,qRate,tRate,uRate,vRate,wRate,voRate
0,0.0,53.5,2022-04-01 09:00:00,-2.077039,-6.908855,271.59850,0.132501,2.461048,2.195530,-0.001338,...,0.007145,0.153645,-0.078552,0.523635,0.364858,-0.003292,0.387559,-0.576608,-0.404786,-1.380797
1,0.0,53.0,2022-04-01 09:00:00,-1.773355,-6.521138,271.71902,0.174736,2.435926,1.764787,-0.001188,...,0.007161,-0.078840,-0.149359,1.352714,1.020086,-0.005582,0.031765,-0.654341,-0.762656,-0.804109
2,0.0,52.5,2022-04-01 09:00:00,-0.991369,-5.800329,270.89062,0.170515,2.475419,1.644545,-0.001030,...,0.007191,-0.262899,-0.211601,2.081790,1.421874,-0.008502,-0.285051,-0.707204,-0.751774,-0.663476
3,0.0,52.0,2022-04-01 09:00:00,-1.587004,-7.715757,270.03360,0.192731,0.000000,1.739633,-0.001172,...,0.007366,-0.387931,-0.252241,2.204835,1.385980,-0.010617,-0.495482,-0.724470,-0.495318,-0.624469
4,0.0,51.5,2022-04-01 09:00:00,-0.996201,-6.828105,271.18735,0.115908,2.393771,1.655431,-0.001189,...,0.007699,-0.440073,-0.253728,1.915031,1.139167,-0.011330,-0.627971,-0.702734,8.268152,-0.565353
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
112429,,,,,,,,,,,...,0.013096,-0.518735,-0.316010,2.833350,1.168783,-0.020003,0.026426,1.456440,-12.288445,-0.363797
112430,,,,,,,,,,,...,0.012716,-0.521924,-0.331074,2.752553,1.026117,-0.021721,0.746422,1.056494,29.709023,-0.786882
112431,,,,,,,,,,,...,0.012264,-0.533832,-0.350177,2.683497,0.890316,-0.023486,1.316703,0.718512,5.695881,-1.055768
112432,,,,,,,,,,,...,0.011746,-0.548515,-0.381646,2.665466,0.819057,-0.024574,1.702679,0.457941,2.115029,-1.290580


In [None]:
#load raw CSV instead of converting nc to csv again
dataFrame = loadCSV('raw')

In [None]:
#load processed CSV to save time
dataFrame = loadCSV('processed')

In [None]:
dataFrame = pd.read_csv('/content/gdrive/MyDrive/Convex/Saved Data/legacyDataFrame.csv')

In [None]:
features = dataFrame.columns.to_list()
# avoid using rates of change of below values
avoidRate = np.array(['ciwc', 'cswc', 'tp', 'ssr', 'ro', 'cc'])
# avoid using below values completely
avoidAll = np.array(['longitude', 'latitude', 'time', 'lai_hv', 'lai_lv'])

In [None]:
# column names: '<feature>Rate' for rate-of-change features, '<feature>New' for lagged values, plus the two targets
rateColumns = [f + 'Rate' for f in features if f not in avoidRate and f not in avoidAll]
newColumns = [f + 'New' for f in features if f not in avoidAll] + ['qValid', 'tValid']
rateValues = pd.DataFrame(columns=rateColumns)
newValues = pd.DataFrame(columns=newColumns)

In [None]:
#merge surface (era5Land) and pressure-level (era5Pressure) data on shared coordinates
era5Land = xr.open_dataset('/content/gdrive/MyDrive/Convex/Datasets/era5Land.nc').to_dataframe().reset_index()
era5Pressure = xr.open_dataset('/content/gdrive/MyDrive/Convex/Datasets/era5Pressure.nc').to_dataframe().reset_index()
dataFrame = pd.merge(era5Land, era5Pressure, on=['longitude', 'latitude', 'time'], how='inner').dropna()
del era5Land, era5Pressure
#drop cloud liquid water (clwc) and rain water (crwc) content, which aren't used
dataFrame = dataFrame.drop(columns=['clwc', 'crwc'])
updateCSV('raw')

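Another option is to join the two grids in xarray before flattening them, which may use less memory than merging two large DataFrames (untested on these files). This is a minimal sketch on small synthetic grids. `merge_grids` is a hypothetical helper, and the variables `skt` and `t` stand in for the real files' contents.

```python
import numpy as np
import pandas as pd
import xarray as xr

def merge_grids(land: xr.Dataset, pressure: xr.Dataset) -> pd.DataFrame:
    """Inner-join two gridded datasets on their shared coordinates, then flatten to one table."""
    # join='inner' keeps only coordinate values present in both datasets
    merged = xr.merge([land, pressure], join='inner', compat='no_conflicts')
    return merged.to_dataframe().reset_index().dropna()

# synthetic stand-ins for era5Land.nc and era5Pressure.nc (hypothetical variables and values)
rng = np.random.default_rng(0)
lon = [0.0, 0.5, 1.0]
lat = [53.5, 53.0]
time = pd.to_datetime(['2022-04-01 09:00', '2022-04-02 09:00'])
dims = ('longitude', 'latitude', 'time')
land = xr.Dataset({'skt': (dims, rng.random((3, 2, 2)))},
                  coords={'longitude': lon, 'latitude': lat, 'time': time})
pressure = xr.Dataset({'t': (dims, rng.random((2, 2, 2)))},
                      coords={'longitude': lon[:2], 'latitude': lat, 'time': time})

table = merge_grids(land, pressure)
# should match merging the two flattened DataFrames with pd.merge(how='inner')
expected = pd.merge(land.to_dataframe().reset_index(), pressure.to_dataframe().reset_index(),
                    on=['longitude', 'latitude', 'time'], how='inner')
print(table.shape, expected.shape)
```

Only the longitudes present in both grids (0.0 and 0.5) survive, just as with `how='inner'` in `pd.merge`.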

In [None]:
# define date range for rate of change, as well as date to predict

# period for rate of change
ratePeriod = 1
# period for prediction (max 1 week)
predPeriod = 1
dataFrame['time'] = pd.to_datetime(dataFrame['time'])
beforeDataFrame = dataFrame.loc[(dataFrame['time'].dt.day == 1) & (dataFrame['time'].dt.hour != 10)]
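The parsing loop turns each feature into a relative rate of change between the base day and the lag day. Here $x_f^{(d)}$ is the value of feature $f$ at one grid point on day-of-month $d$, $\Delta_r$ is `ratePeriod` and $\Delta_p$ is `predPeriod`. Written out:

$$
r_f =
\begin{cases}
0, & x_f^{(1+\Delta_r)} = x_f^{(1)},\\[4pt]
\dfrac{x_f^{(1+\Delta_r)} - x_f^{(1)}}{x_f^{(1)}}, & \text{otherwise.}
\end{cases}
$$

Points where exactly one of $x_f^{(1)}$ and $x_f^{(1+\Delta_r)}$ is zero are dropped. The `...New` columns hold $x_f^{(1+\Delta_r)}$, and the targets `qValid` and `tValid` come from day $1+\Delta_r+\Delta_p$. With both periods set to 1, the lag values therefore come from day 2 and the targets from day 3.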

In [None]:
#parse dataset (slow, row by row); do not run unless making changes
print(beforeDataFrame)
count = 0

#loop through rows of beforeDataFrame, used as reference
for index, row in beforeDataFrame.iterrows():

    count += 1
    print(count)

    queuedRates = []
    queuedNew = []

    #since script is run for each feature, flag is set to avoid triggering twice per row
    flag = True

    #set long and lat variables
    long = row['longitude']
    lat = row['latitude']
    hour = row['time'].hour

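    #TODO: also match the hour by adding & (dataFrame['time'].dt.hour == hour) to both lookups below
    #lag = same point ratePeriod days later; valid = same point on the day being predicted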
    lag = dataFrame.loc[(dataFrame['longitude'] == long) & (dataFrame['latitude'] == lat) & (dataFrame['time'].dt.day == (1 + ratePeriod))].iloc[0]
    valid = dataFrame.loc[(dataFrame['longitude'] == long) & (dataFrame['latitude'] == lat) & (dataFrame['time'].dt.day == (1 + ratePeriod + predPeriod))].iloc[0]

    #loop through features
    for feature in dataFrame.columns:
        if flag == True and feature not in avoidRate and feature not in avoidAll:
            # skip rows where the value switches between zero and non-zero: 0 -> non-0 would divide by zero, and non-0 -> 0 is dropped too
            if (row[feature] == 0) != (lag[feature] == 0):
                beforeDataFrame = beforeDataFrame.drop(row.name)
                flag = False
            else:
                if row[feature] == lag[feature]:
                    rateOfChange = 0
                else:
                    rateOfChange = (lag[feature] - row[feature]) / row[feature]
                queuedRates.append(rateOfChange)
        if feature not in avoidAll:
            queuedNew.append(lag[feature])
    if flag == True:
        queuedNew.append(valid['q'])
        queuedNew.append(valid['t'])
        rateValues.loc[len(rateValues)] = queuedRates
        newValues.loc[len(newValues)] = queuedNew
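Appending one row at a time with `.loc[len(...)]` gets slow on tens of thousands of points. The same pairing can be sketched with two joins on longitude/latitude. This version assumes the caller passes the feature lists explicitly, matches on position only (like the loop, without the hour), and omits the `hour != 10` filter used for `beforeDataFrame`. `build_features` is a hypothetical helper and has not been checked against the saved CSVs.

```python
import numpy as np
import pandas as pd

def build_features(df, rate_cols, new_cols, base_day=1, rate_period=1, pred_period=1):
    """Vectorised sketch of the row-by-row loop: pair each point with itself
    rate_period and rate_period + pred_period days later (matched on longitude/latitude)."""
    keys = ['longitude', 'latitude']
    day = df['time'].dt.day
    base = df[day == base_day]
    # drop_duplicates keeps the first row per point, like .iloc[0] in the loop
    lag = df[day == base_day + rate_period].drop_duplicates(keys)
    valid = df[day == base_day + rate_period + pred_period].drop_duplicates(keys)

    lagged = lag[keys + new_cols].rename(columns={c: c + 'New' for c in new_cols})
    targets = valid[keys + ['q', 't']].rename(columns={'q': 'qValid', 't': 'tValid'})
    out = base.merge(lagged, on=keys).merge(targets, on=keys)  # inner joins by default

    # drop points where any rate feature switches between zero and non-zero
    keep = np.ones(len(out), dtype=bool)
    for c in rate_cols:
        keep &= ((out[c] == 0) == (out[c + 'New'] == 0)).to_numpy()
    out = out[keep].copy()

    # relative change (new - old) / old; 0 where old is 0 (new is then 0 too, after the filter)
    for c in rate_cols:
        before = out[c].to_numpy(dtype=float)
        after = out[c + 'New'].to_numpy(dtype=float)
        out[c + 'Rate'] = np.divide(after - before, before, out=np.zeros_like(before), where=before != 0)
    return out.reset_index(drop=True)

# tiny example: two points over three days (hypothetical values)
df = pd.DataFrame({
    'longitude': [0.0, 0.5] * 3,
    'latitude': [53.5] * 6,
    'time': pd.to_datetime(['2022-04-01 09:00'] * 2 + ['2022-04-02 09:00'] * 2 + ['2022-04-03 09:00'] * 2),
    'd': [2.0, 0.0, 3.0, 1.0, 4.0, 5.0],
    'q': [1e-4, 2e-4, 1.5e-4, 2e-4, 1.2e-4, 3e-4],
    't': [230.0, 231.0, 232.0, 231.0, 229.0, 228.0],
})
out = build_features(df, rate_cols=['d', 'q', 't'], new_cols=['d', 'q', 't'])
print(out[['longitude', 'dRate', 'qRate', 'tRate', 'qValid', 'tValid']])
```

The point at longitude 0.5 is dropped because `d` goes from 0 to 1. The remaining point gets `dRate = (3 - 2) / 2 = 0.5`.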

In [None]:
dataFrame = pd.concat([beforeDataFrame.reset_index(drop=True), newValues.reset_index(drop=True), rateValues.reset_index(drop=True)], axis=1)

In [None]:
dataFrame.to_csv('/content/gdrive/MyDrive/Convex/Saved Data/1DayProcessed.csv', index=False)

In [None]:
updateCSV('processed')

In [None]:
results = pd.DataFrame(columns=['humidityPred', 'humidityValid', 'tempPred', 'tempValid', 'longitude', 'latitude'])

In [None]:
dataFrame = dataFrame.head(5695)

In [None]:
#now train model for specific humidity
X = dataFrame[['q', 'qNew', 'tNew', 'tRate', 'qRate', 'ciwcNew', 'pevNew', 'ccNew', 'zRate', 'pvRate']]
y = dataFrame[['qValid']]
#will return same data points because of same random_state
X_train,X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3, random_state = 100)

In [None]:
#ravel the single-column target to 1-D, which is the shape scikit-learn expects
rfrHumidity = RandomForestRegressor(n_estimators = 200, max_features = 10, max_depth = 15).fit(X_train, y_train.values.ravel())
y_pred = rfrHumidity.predict(X_test)

# Evaluate the results
rmse = np.sqrt(mean_squared_error(y_test, y_pred))
mae = mean_absolute_error(y_test, y_pred)
print(f"Humidity RMSE: {rmse}")
print(f"Humidity MAE: {mae}")

Humidity RMSE: 4.858077952800099e-05
Humidity MAE: 2.1109033265595093e-05
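`permutation_importance` is imported in the setup cell but never used. It measures how much a fitted model's score drops when a single feature column is shuffled, which is one way to check whether all ten hand-picked features actually contribute. This minimal sketch uses synthetic data, not the notebook's data:

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# synthetic data: column 0 drives the target, columns 1 and 2 are pure noise
rng = np.random.default_rng(0)
X = rng.normal(size=(500, 3))
y = 3.0 * X[:, 0] + rng.normal(scale=0.1, size=500)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=100)

model = RandomForestRegressor(n_estimators=100, max_depth=15, random_state=0).fit(X_train, y_train)

# shuffle each column n_repeats times on held-out data and record the drop in R^2
result = permutation_importance(model, X_test, y_test, n_repeats=10, random_state=0)
ranking = np.argsort(result.importances_mean)[::-1]  # most important feature first
print(ranking, result.importances_mean.round(3))
```

On the real model the same call would be `permutation_importance(rfrHumidity, X_test, y_test.values.ravel(), n_repeats=10, random_state=0)`. Pairing the scores with names via `pd.Series(result.importances_mean, index=X_test.columns).sort_values()` then shows which inputs matter.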


In [None]:
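#index results by the test rows so predictions, targets and coordinates stay aligned
results = pd.DataFrame(index=y_test.index, columns=['humidityPred', 'humidityValid', 'tempPred', 'tempValid', 'longitude', 'latitude'])
results['humidityPred'] = y_pred
results['humidityValid'] = y_test['qValid']
results['longitude'] = dataFrame.loc[y_test.index, 'longitude']
results['latitude'] = dataFrame.loc[y_test.index, 'latitude']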

In [None]:
#now train model for temperature
X = dataFrame[['q', 'qNew', 'tNew', 'tRate', 'qRate', 'ciwcNew', 'pevNew', 'ccNew', 'zRate', 'pvRate']]
y = dataFrame[['tValid']]
#will return same data points because of same random_state
X_train,X_test, y_train, y_test = train_test_split(X, y, test_size = 0.3, random_state = 100)

In [None]:
#ravel the single-column target to 1-D, which is the shape scikit-learn expects
rfrTemp = RandomForestRegressor(n_estimators = 200, max_features = 10, max_depth = 15).fit(X_train, y_train.values.ravel())
y_pred = rfrTemp.predict(X_test)

# Evaluate the results
rmse = np.sqrt(mean_squared_error(y_test, y_pred))
mae = mean_absolute_error(y_test, y_pred)
print(f"Temperature RMSE: {rmse}")
print(f"Temperature MAE: {mae}")



Temperature RMSE: 0.8110788779842881
Temperature MAE: 0.4890732725285861


In [None]:
results['tempPred'] = y_pred
results['tempValid'] = y_test['tValid'].values

In [None]:
#load processed CSV again if performing the grid operation on a specific area (e.g. Great Lakes)
dataFrame = loadCSV('processed')

In [None]:
results

Unnamed: 0,humidityPred,humidityValid,tempPred,tempValid,longitude,latitude
78073,0.000642,0.000432,243.222642,243.14485,302.0,-2.5
47531,0.000022,0.000022,216.443082,216.52554,139.5,75.0
59538,0.000020,0.000019,214.878907,215.20800,237.5,-82.0
5903,0.000362,0.000217,242.272254,242.43245,18.5,-3.0
14452,0.000052,0.000056,241.077310,241.03018,38.0,9.0
...,...,...,...,...,...,...
60049,0.000166,0.000162,228.913920,228.97166,240.0,42.5
19162,0.000017,0.000018,217.772024,218.11917,51.5,67.5
31867,0.000033,0.000033,217.845430,217.53717,93.5,61.0
25067,0.000080,0.000069,223.093546,222.73062,72.0,56.0


In [None]:
#grid operation on local area, don't run otherwise
latRange = [39, 50]
longRange = [265, 285]

#the reloaded CSV stores time as text, so parse it before using .dt
dataFrame['time'] = pd.to_datetime(dataFrame['time'])
#remove the hour condition if running on all timestamps
dataFrame = dataFrame.loc[(dataFrame['time'].dt.hour == 10) & dataFrame['longitude'].between(*longRange) & dataFrame['latitude'].between(*latRange)].drop(columns=['time'])

#predict with the same feature columns the models were trained on (a different column set makes predict() fail)
modelFeatures = ['q', 'qNew', 'tNew', 'tRate', 'qRate', 'ciwcNew', 'pevNew', 'ccNew', 'zRate', 'pvRate']
localHumidityPred = rfrHumidity.predict(dataFrame[modelFeatures])
localTempPred = rfrTemp.predict(dataFrame[modelFeatures])

results = pd.DataFrame()
results['humidityPred'] = localHumidityPred
results['humidityValid'] = dataFrame['qValid'].values
results['tempPred'] = localTempPred
results['tempValid'] = dataFrame['tValid'].values
results['longitude'] = dataFrame['longitude'].values
results['latitude'] = dataFrame['latitude'].values

In [None]:
sacPred = []
sacValid = []
issrPred = []
issrValid = []
# select size of subset to measure accuracy on (500-1000 is typically sufficient)
processedResults = results.head(1000)

In [None]:
# process the data and feed it through SAC in chunks
while True:
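  #copy the slice so adding columns doesn't raise SettingWithCopyWarning
  chunk = processedResults.head(500).copy()

  #place every point on level 300 at the validation time; a scalar also fits a final chunk shorter than 500 rows
  chunk['level'] = 300
  chunk['time'] = np.datetime64("2022-04-04T09")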

  uLong = chunk['longitude'].unique()
  uLat = chunk['latitude'].unique()
  uLevel = chunk['level'].unique()
  uTime = chunk['time'].unique()

  columns = ['humidityPred', 'humidityValid', 'tempPred', 'tempValid']
  pivotedArrays = {}

  for column in columns:
    pivoted = chunk.pivot_table(index=['longitude', 'latitude', 'level', 'time'],
                                values=column,
                                aggfunc='first')

    # Reindex the pivot table to ensure all combinations are present
    index = pd.MultiIndex.from_product([uLong, uLat, uLevel, uTime],
                                       names=['longitude', 'latitude', 'level', 'time'])
    pivoted = pivoted.reindex(index)

    # Convert the pivot table to a 4D array and store it
    pivotedArrays[column] = pivoted.values.reshape(len(uLong), len(uLat), len(uLevel), len(uTime))


  datasetPred = xr.Dataset(
    data_vars = {
        'air_temperature': (['longitude', 'latitude', 'level', 'time'], pivotedArrays['tempPred']),
        'specific_humidity': (['longitude', 'latitude', 'level', 'time'], pivotedArrays['humidityPred']),
    },
    coords = {
        'longitude': uLong,
        'latitude': uLat,
        'level': uLevel,
        'time': uTime,
      }
  )
  print(datasetPred)
  datasetValid = xr.Dataset(
    data_vars = {
        'air_temperature': (['longitude', 'latitude', 'level', 'time'], pivotedArrays['tempValid']),
        'specific_humidity': (['longitude', 'latitude', 'level', 'time'], pivotedArrays['humidityValid']),
    },
    coords = {
        'longitude': uLong,
        'latitude': uLat,
        'level': uLevel,
        'time': uTime,
      }
  )
  datasetPred = MetDataset(datasetPred)
  datasetValid = MetDataset(datasetValid)

  # calculate SAC for predicted data
  sacResults = SAC(met=datasetPred).eval()
  sacPred.append(sacResults['sac'])

  # calculate ISSR for predicted data
  issrResults = ISSR(datasetPred).eval()
  issrPred.append(issrResults['issr'])

  #clean up variables to prevent ram overload
  del sacResults, issrResults, datasetPred, pivotedArrays
  gc.collect()

  # calculate SAC & ISSR for validation data
  sacResults = SAC(met=datasetValid).eval()
  sacValid.append(sacResults['sac'])
  issrResults = ISSR(datasetValid).eval()
  issrValid.append(issrResults['issr'])

  #clean up variables again
  del sacResults, issrResults, datasetValid
  gc.collect()



  if len(processedResults) < 501:
    break
  else:
    print('processed chunk')
    # delete first 500 rows to make way for next 500
    processedResults = processedResults.iloc[500:]



<xarray.Dataset>
Dimensions:            (longitude: 357, latitude: 213, level: 1, time: 1)
Coordinates:
  * longitude          (longitude) float64 302.0 139.5 237.5 ... 5.5 251.5 45.5
  * latitude           (latitude) float64 -2.5 75.0 -82.0 ... -4.5 38.5 45.5
  * level              (level) int64 300
  * time               (time) datetime64[ns] 2022-04-04T09:00:00
Data variables:
    air_temperature    (longitude, latitude, level, time) float64 243.2 ... nan
    specific_humidity  (longitude, latitude, level, time) float64 0.0006418 ....
processed chunk




<xarray.Dataset>
Dimensions:            (longitude: 344, latitude: 218, level: 1, time: 1)
Coordinates:
  * longitude          (longitude) float64 219.5 51.5 100.5 ... 277.5 15.5 140.5
  * latitude           (latitude) float64 66.0 -70.5 -82.5 ... 83.5 21.5 77.0
  * level              (level) int64 300
  * time               (time) datetime64[ns] 2022-04-04T09:00:00
Data variables:
    air_temperature    (longitude, latitude, level, time) float64 229.7 ... nan
    specific_humidity  (longitude, latitude, level, time) float64 9.248e-06 ....
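The `ISSR` step flags ice-supersaturated regions from temperature and specific humidity. The core test is that relative humidity over ice (RHi) exceeds 100 %. This minimal sketch of that test makes two assumptions: it uses the Sonntag fit for saturation vapour pressure over ice, and it approximates vapour pressure as e ≈ q·p·R_v/R_d. It is not the pycontrails implementation, which may use other formulas or extra conditions. The helper names are hypothetical.

```python
import numpy as np

R_D = 287.05  # specific gas constant of dry air, J/(kg K)
R_V = 461.51  # specific gas constant of water vapour, J/(kg K)

def e_sat_ice(temperature):
    """Saturation vapour pressure over ice in Pa (Sonntag fit); temperature in K."""
    t = np.asarray(temperature, dtype=float)
    return 100.0 * np.exp(
        -6024.5282 / t + 24.7219 + 0.010613868 * t - 1.3198825e-5 * t**2 - 0.49382577 * np.log(t)
    )

def rhi(specific_humidity, temperature, pressure):
    """Relative humidity over ice (1.0 = saturated), using e ~ q * p * R_V / R_D; pressure in Pa."""
    q = np.asarray(specific_humidity, dtype=float)
    p = np.asarray(pressure, dtype=float)
    return q * p * (R_V / R_D) / e_sat_ice(temperature)

def is_issr(specific_humidity, temperature, pressure):
    """Core ice-supersaturation test: RHi above 1 (hypothetical helper, not the pycontrails API)."""
    return rhi(specific_humidity, temperature, pressure) > 1.0

# points at 300 hPa (30000 Pa) and 220 K with two different humidities
print(e_sat_ice(273.16), is_issr([1e-4, 1e-5], 220.0, 30000.0))
```

Applied to the table above, `is_issr(results['humidityPred'], results['tempPred'], 30000.0)` gives a rough per-point cross-check of the gridded ISSR output, assuming every point sits at level 300 as in the loop.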


In [None]:
#calc SAC & ISSR accuracy over every processed chunk
def flatGrid(metArray):
  #MetDataArray.data is the underlying xarray.DataArray; drop the single time and level and fix the axis order
  return metArray.data.isel(time=0, level=0).transpose('longitude', 'latitude')

#flatten each chunk's (longitude, latitude) grid and record the coordinates of every cell (for accuracy calc)
SCPA, SCVA, ICPA, ICVA = [], [], [], []
pointLong, pointLat = [], []
for sacChunkPred, sacChunkValid, issrChunkPred, issrChunkValid in zip(sacPred, sacValid, issrPred, issrValid):
  grid = flatGrid(sacChunkPred)
  #meshgrid with indexing='ij' gives coordinates in the same order as .flatten()
  gridLong, gridLat = np.meshgrid(grid['longitude'].values, grid['latitude'].values, indexing='ij')
  pointLong.extend(gridLong.flatten().tolist())
  pointLat.extend(gridLat.flatten().tolist())
  SCPA.extend(grid.values.flatten().tolist())
  SCVA.extend(flatGrid(sacChunkValid).values.flatten().tolist())
  ICPA.extend(flatGrid(issrChunkPred).values.flatten().tolist())
  ICVA.extend(flatGrid(issrChunkValid).values.flatten().tolist())

#keep the last chunk as labelled DataArrays with time removed (for graphing)
SCPG = sacChunkPred.data.isel(time=0)
SCVG = sacChunkValid.data.isel(time=0)
ICPG = issrChunkPred.data.isel(time=0)
ICVG = issrChunkValid.data.isel(time=0)

In [None]:
#compare predicted and validation flags point by point, keeping coordinates for the heatmap
correct = 0
wrong = 0
resultsList = []
for predSAC, validSAC, predISSR, validISSR, longitude, latitude in zip(SCPA, SCVA, ICPA, ICVA, pointLong, pointLat):
  if predSAC == predSAC and predISSR == predISSR: ## NaN is never equal to itself, so empty grid cells are skipped
    isCorrect = int(predSAC == validSAC and predISSR == validISSR)
    resultsList.append({'longitude': longitude, 'latitude': latitude, 'correct': isCorrect})
    correct += isCorrect
    wrong += 1 - isCorrect

print('correct predictions: ' + str(correct))
print('wrong predictions: ' + str(wrong))
print('rate: ' + str(correct/(correct + wrong)))

correct predictions: 892
wrong predictions: 108
rate: 0.892

In [None]:
#average accuracy per point, pivoted into a (latitude, longitude) grid
accuracyHeatmap = pd.DataFrame(resultsList).groupby(['latitude', 'longitude'])['correct'].mean().to_xarray()
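The point-by-point comparison can also be written with NumPy masks, which avoids the Python loop and states the NaN rule explicitly. This is a minimal sketch: `joint_accuracy` is a hypothetical helper that expects four equal-length sequences of 0/1/NaN flags, like the flattened `SCPA`, `SCVA`, `ICPA` and `ICVA` lists.

```python
import numpy as np

def joint_accuracy(sac_pred, sac_valid, issr_pred, issr_valid):
    """Count points where both the SAC and ISSR flags match, skipping NaN (empty) grid cells."""
    sac_pred, sac_valid, issr_pred, issr_valid = (
        np.asarray(a, dtype=float) for a in (sac_pred, sac_valid, issr_pred, issr_valid)
    )
    # a point counts only if both predictions are present
    has_pred = ~np.isnan(sac_pred) & ~np.isnan(issr_pred)
    match = (sac_pred == sac_valid) & (issr_pred == issr_valid)
    correct = int(np.count_nonzero(match & has_pred))
    total = int(np.count_nonzero(has_pred))
    if total == 0:
        return 0, 0, float('nan')
    return correct, total - correct, correct / total

# four points: one full match, one SAC miss, one empty cell, one ISSR miss
correct, wrong, rate = joint_accuracy([1, 0, np.nan, 1], [1, 1, np.nan, 1], [0, 0, np.nan, 1], [0, 0, np.nan, 0])
print(correct, wrong, rate)
```

On the notebook's data the call would be `joint_accuracy(SCPA, SCVA, ICPA, ICVA)`.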

In [None]:
figx = 6
figy = 12

In [None]:
#sac pred plot
SCPG.plot(x='longitude', y='latitude', row='level', cmap='Reds', figsize=(figx, figy))

In [None]:
#sac valid plot
SCVG.plot(x='longitude', y='latitude', row='level', cmap='Reds', figsize=(figx, figy))

In [None]:
#issr pred plot
ICPG.plot(x='longitude', y='latitude', row='level', cmap='Reds', figsize=(figx, figy))

In [None]:
#issr valid plot
ICVG.plot(x='longitude', y='latitude', row='level', cmap='Reds', figsize=(figx, figy))
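Each chunk is drawn on a regular longitude/latitude grid built from about 500 points, so most cells in these plots are NaN. A hedged alternative is to scatter the points at their own coordinates. `scatter_map` is a hypothetical helper that uses only standard matplotlib calls:

```python
import matplotlib.pyplot as plt
import numpy as np

def scatter_map(longitude, latitude, values, title='', cmap='Reds', figsize=(12, 6)):
    """Draw each point at its own coordinates, coloured by value (hypothetical helper)."""
    fig, ax = plt.subplots(figsize=figsize)
    points = ax.scatter(np.asarray(longitude), np.asarray(latitude), c=np.asarray(values), cmap=cmap, s=8)
    fig.colorbar(points, ax=ax)
    ax.set_xlabel('longitude')
    ax.set_ylabel('latitude')
    ax.set_title(title)
    return fig, points

# three example points with hypothetical 0/1 flags
fig, points = scatter_map([265.0, 275.0, 285.0], [40.0, 45.0, 50.0], [1.0, 0.0, 1.0], title='example')
```

For instance, `scatter_map(results['longitude'], results['latitude'], results['tempPred'] - results['tempValid'], title='temperature error (K)', cmap='coolwarm')` would map the temperature error of each test point, with a diverging colormap for signed values.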

In [None]:
del sacPred, sacValid
gc.collect()

34