In [1]:
import xarray as xr
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcol
import numpy as np
import datetime as dt
import sys
import os
from data_processing import prepare_data
from train import train_model
from train import make_snapshot_data
from evaluate import evaluate_model
import torch
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def load_ctd_data(fpath, start_year, end_year):
    """Open the CTD observation file as an xarray Dataset.

    ``start_year`` and ``end_year`` are accepted for interface compatibility
    but are currently ignored: the time filter below is disabled, so the whole
    file is returned.
    """
    # Disabled time filter (note: ``where`` would mask, not drop, time steps):
    # start_date = pd.Timestamp(f"{start_year}-01-01")
    # end_date = pd.Timestamp(f"{end_year+1}-01-01")
    # ds = ds.where((ds.time >= start_date) & (ds.time < end_date))
    return xr.open_dataset(fpath)
    
def load_model_data(fpath, start_year, end_year):
    """Open the gridded model file and extract its bathymetry.

    Returns:
      (dataset, bathy) where ``bathy`` is ``dataset['bathy']`` as a numpy array.

    ``start_year`` / ``end_year`` are currently ignored (time filter disabled).
    """
    model_ds = xr.open_dataset(fpath)
    # Disabled time filter:
    # start_date = pd.Timestamp(f"{start_year}-01-01")
    # end_date = pd.Timestamp(f"{end_year+1}-01-01")
    # ds = ds.where((ds.time >= start_date) & (ds.time < end_date))
    bathy = model_ds['bathy'].values
    return model_ds, bathy

def normalize_dataset(ds, var_methods=None):
    """
    Normalize selected variables in an xarray.Dataset for ML.

    Each data variable is scaled by the method looked up in ``var_methods``
    (the built-in defaults are used when ``var_methods`` is None):

      - "zscore": (x - mean) / std
      - "minmax": (x - min) / (max - min)
      - None, or a variable not in the mapping: left unchanged

    Statistics skip NaNs. If a variable has no spread (a constant field, so
    std == 0 or max == min), the divisor is replaced by 1.0 so the result is
    all zeros instead of NaN from 0/0. The parameters actually used are what
    get stored: std = 1.0 for zscore, and max = min + 1.0 for minmax. As a
    result, apply_normalization reproduces exactly the same scaling.

    Args:
      ds: dataset whose ``data_vars`` are normalized; ``ds`` itself is not
        modified (a deep copy is returned).
      var_methods: optional mapping {variable name: "zscore" | "minmax" | None}.
        When given, it replaces the defaults rather than extending them.

    Returns:
      - normalized dataset
      - dictionary of scaling parameters for rescaling later, keyed by variable
    """

    ds_norm = ds.copy(deep=True)
    scale_params = {}

    # Default normalization methods (can override with var_methods)
    default_methods = {
        "Temperature": "zscore",
        "t_pot": "zscore",
        "Salinity": "minmax",
        "Oxygen": "zscore",
        "Bathymetry": "minmax",
        "Depth": "minmax",
        "Latitude": None,
        "Longitude": None,
        "DOY" : None
    }

    if var_methods is None:
        var_methods = default_methods

    for var in ds.data_vars:
        method = var_methods.get(var, None)
        data = ds[var]

        if method == "zscore":
            mean_val = float(data.mean(skipna=True))
            std_val = float(data.std(skipna=True))
            # std is never negative, so "not > 0" catches both a zero spread
            # and NaN (e.g. fewer than two valid samples).
            if not std_val > 0:
                std_val = 1.0
            ds_norm[var] = (data - mean_val) / std_val

            scale_params[var] = {
                "method": "zscore",
                "mean": mean_val,
                "std": std_val
            }

        elif method == "minmax":
            min_val = float(data.min(skipna=True))
            max_val = float(data.max(skipna=True))
            # Constant field: use a unit range so the output is 0, not NaN.
            if max_val == min_val:
                max_val = min_val + 1.0
            ds_norm[var] = (data - min_val) / (max_val - min_val)

            scale_params[var] = {
                "method": "minmax",
                "min": min_val,
                "max": max_val
            }

        else:
            # Variable not normalized (e.g., coordinates)
            scale_params[var] = {"method": None}
            continue

        print(f"Normalized {var} using {method}")

    return ds_norm, scale_params

def apply_normalization(ds, scale_params):
    """Apply precomputed normalization parameters to a dataset.

    Args:
      ds: dataset to normalize; not modified (a deep copy is returned).
      scale_params: mapping {variable: params}. Each params dict has the form
        {"method": "zscore", "mean": m, "std": s},
        {"method": "minmax", "min": lo, "max": hi}, or
        {"method": None}. Any other or missing method leaves the variable
        unchanged.

    Returns:
      The normalized copy of ``ds``.

    Raises:
      KeyError: if a variable with a zscore/minmax entry is not present in ``ds``.

    A zero spread (std == 0, or max == min) is treated as a unit spread. This
    way a constant field maps to 0 instead of NaN from 0/0.
    """
    ds_norm = ds.copy(deep=True)
    for var, params in scale_params.items():
        method = params.get("method")
        if method == "zscore":
            mean_val = params["mean"]
            std_val = params["std"]
            if std_val == 0:
                std_val = 1.0
            ds_norm[var] = (ds[var] - mean_val) / std_val

        elif method == "minmax":
            min_val = params["min"]
            spread = params["max"] - min_val
            if spread == 0:
                spread = 1.0
            ds_norm[var] = (ds[var] - min_val) / spread
        # else: leave unchanged
    return ds_norm

def reshape_to_tcsd(ds_input: xr.Dataset, ds_target: xr.Dataset):
    """Stack the data variables of each dataset along a new 'channels' dim.

    Args:
      ds_input: dataset whose data variables become the input channels.
      ds_target: dataset whose data variables become the target channels.

    Returns:
      (inputs, targets, mask) as numpy arrays.
      - ``inputs`` and ``targets`` have NaNs replaced by 0.
      - ``mask`` is an int array equal to 1 where the target was not NaN and
        0 where it was, so callers can tell real zeros from filled gaps.

    Note:
      ``xr.concat`` along a dimension that does not exist yet inserts it as the
      FIRST dimension, so each array is (channels, *variable dims).
      NOTE(review): despite the "tcsd" name, no transpose to
      (time, channels, ...) happens here; confirm what the training code expects.
    """
    def _stack_channels(ds):
        # One channel per data variable, in the dataset's variable order.
        return xr.concat([ds[name] for name in ds.data_vars], dim='channels')

    inputs = _stack_channels(ds_input)
    targets = _stack_channels(ds_target)
    # Build the mask before filling so it reflects the original gaps;
    # notnull() also works for non-float dtypes, where np.isnan would raise.
    mask = targets.notnull().astype(int)
    return (inputs.fillna(0).to_numpy(), targets.fillna(0).to_numpy(), mask.to_numpy())

def haversine(la0, lo0, la1, lo1, radius=6371e3):
    """ haversine formula with numpy array handling
    Calculates spherical distance between points on Earth in meters
    Compares elements of (la0,lo0) with (la1,lo1)
    Shapes must be compatible with numpy array broadcasting
    args: lats and lons in decimal degrees
          radius: sphere radius in meters (default: volumetric mean Earth
                  radius, 6371 km)
    returns: distance on the sphere of the given radius, in meters
    """
    # convert to radians
    la0 = np.radians(la0)
    la1 = np.radians(la1)
    lo0 = np.radians(lo0)
    lo1 = np.radians(lo1)
    hav = (np.sin((la0 - la1) / 2) ** 2
           + np.cos(la0) * np.cos(la1) * np.sin((lo0 - lo1) / 2) ** 2)
    # Floating-point rounding can push hav marginally above 1 for (near-)
    # antipodal points; arcsin of sqrt(>1) would return NaN, so clamp it.
    theta = 2 * np.arcsin(np.sqrt(np.clip(hav, 0.0, 1.0)))
    return radius * theta

In [3]:
    # NOTE(review): absolute, user-specific paths; adjust for other machines.
    work_dir='/space/hall5/sitestore/eccc/crd/ccrn/users/reo000/StatDownOc/output/work'
    data_dir='/space/hall5/sitestore/eccc/crd/ccrn/users/reo000/StatDownOc/griddedFiles'
    year_range= [1990, 1992]
    target_variable= "t_pot"
    # Fractions of the time axis used for the train / validation splits; the
    # remainder (if any) becomes the test split. These must be plain floats:
    # a trailing comma would silently make them 1-tuples and break the split.
    train_ratio = 0.7
    val_ratio = 0.15

In [4]:
    # Split the configured range into its two bounds (exactly two expected).
    [start_year, end_year] = year_range

In [5]:
    obs = load_ctd_data(Path(data_dir) / 'ctd_obs_ds_v2.nc', start_year, end_year)

In [6]:
# Inspect the observation dataset returned by load_ctd_data.
obs

In [7]:
    # Keep only the target variable (list indexing keeps it a Dataset).
    obs = obs[[target_variable]]

In [8]:
    # True wherever the target variable has an observed (non-NaN) value.
    obsmask = obs[target_variable].notnull()

In [9]:
    # Coordinate arrays of the observation grid.
    stations, depths = obs['x'], obs['z']

In [11]:
    ds_input0, bathymetry = load_model_data(Path(data_dir) / 'griddedROMS.nc', None, None)

In [12]:
# Inspect the full gridded model dataset.
ds_input0

In [13]:
# Bathymetry array returned by load_model_data (ds['bathy'].values).
bathymetry

array([[1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       [1., 1., 1., ..., 1., 1., 1.],
       ...,
       [1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.],
       [1., 1., 1., ..., 0., 0., 0.]])

In [14]:
# NOTE(review): opens the same file already loaded as `ds_input0` in In [11],
# leaving a second file handle open; inspecting `ds_input0` would suffice.
f = xr.open_dataset(Path(data_dir) / 'griddedROMS.nc')
f

In [15]:
    # Model field of the target variable only; the target starts as a shallow
    # copy of that same selection.
    ds_input = ds_input0[[target_variable]]
    ds_target = ds_input.copy(deep=False)

In [16]:

    # Insert a 'channels' axis third from the end. Not idempotent: running
    # this cell twice fails because 'channels' already exists.
    obs = obs.expand_dims(dim='channels', axis=-3)

In [17]:
    # Random but reproducible assignment of observation time steps to model
    # time steps. The permutation is repeated (at least twice) until it is at
    # least as long as the model's time axis, so perm[:len(ds_input.time)]
    # always yields one index per model time step.
    SEED = 42
    rng = np.random.default_rng(SEED)
    perm = rng.permutation(len(obs.time))
    n_repeats = max(2, -(-len(ds_input.time) // max(len(perm), 1)))  # ceil division
    perm = np.tile(perm, n_repeats)

In [19]:
    # Pick one observation-coverage pattern per model time step, then multiply
    # the model input by it (unobserved points become 0).
    omask = obsmask.isel(time=perm[: len(ds_input.time)]).values
    ds_input = ds_input * omask

In [20]:
    # Leftover from an earlier version that added a 'channels' dim here:
    #ds_input = ds_input.expand_dims('channels', axis = -3)
    ds_input

In [21]:
# Must match the ("z", "x") dims used when attaching it to ds_input in In [22].
bathymetry.shape

(800, 38)

In [22]:
# A raw ndarray needs explicit dimension names to live in a Dataset.
ds_input['bathymetry'] = (('z', 'x'), bathymetry)

In [23]:
# Must match the ("time", "z", "x") dims used when attaching it in In [24].
omask.shape

(312, 800, 38)

In [24]:
 ds_input['omask'] = (('time', 'z', 'x'), omask)  # observation mask as an input channel

In [25]:

    # Broadcast the sin/cos yearday fields to the target variable's full shape.
    ds_input = ds_input.assign(
        sin_yd=ds_input0['sin_yearday'].broadcast_like(ds_input[target_variable]),
        cos_yd=ds_input0['cos_yearday'].broadcast_like(ds_input[target_variable]),
    )

In [27]:
# Check the broadcast sin_yd feature.
ds_input['sin_yd']

In [None]:

    
    # NOTE(review): disabled step that would overwrite the target at the model
    # points closest to the observations; `model_ind_closet_to_obs` is not
    # defined in the visible notebook cells.
    # for trgt in [target_variable]:
    #     arr = ds_target[trgt].values.copy() 
    #     arr[...,model_ind_closet_to_obs] = obs[trgt].values
    #     ds_target[trgt] = (ds_target[trgt].dims, arr)

    # Add static variables. Raw numpy arrays need explicit dimension names
    # before xarray will store them in a Dataset (same dims as In [22]/[24]).
    if bathymetry is not None:
        ds_input['bathymetry'] = (("z", "x"), bathymetry)
    ds_input['omask'] = (("time", "z", "x"), omask)

    # Yearday features broadcast to the target's shape. In [25] reads them from
    # 'sin_yearday'/'cos_yearday'; fall back to the short names for files that
    # store them as 'sin_yd'/'cos_yd'.
    sin_src = 'sin_yearday' if 'sin_yearday' in ds_input0 else 'sin_yd'
    cos_src = 'cos_yearday' if 'cos_yearday' in ds_input0 else 'cos_yd'
    ds_input['sin_yd'] = ds_input0[sin_src].broadcast_like(ds_input[target_variable])
    ds_input['cos_yd'] = ds_input0[cos_src].broadcast_like(ds_input[target_variable])

    # === Split Data into train, validation, test ===
    # Contiguous split along time: [0, train_end) train, [train_end, val_end)
    # validation, [val_end, T) test.
    T = ds_input.sizes["time"]
    train_end = int(train_ratio * T)
    val_end = int((train_ratio + val_ratio) * T)

    ds_input_train = ds_input.isel(time=slice(0, train_end))
    ds_input_val   = ds_input.isel(time=slice(train_end, val_end))

    ds_target_train = ds_target.isel(time=slice(0, train_end))
    ds_target_val   = ds_target.isel(time=slice(train_end, val_end))

    # Decide on the actual indices, not the ratios: int() truncation can leave
    # no time steps for the test set even when the ratios sum to less than 1.
    if val_end < T:
        ds_input_test  = ds_input.isel(time=slice(val_end, T))
        ds_target_test  = ds_target.isel(time=slice(val_end, T))
    else:
        print('==========================================================\n'+
              'Test split ratio is zero. Test set is the same as validation set! \n' + 
              '==========================================================\n')
        ds_input_test  = ds_input_val.copy()
        ds_target_test  = ds_target_val.copy()

In [None]:
# Re-declared (same value) so the standalone cells below can run on their own.
target_variable = "t_pot"

In [None]:
    # Older observation file; keep the target variable only and flag observed points.
    obs = xr.open_dataset(
        '/space/hall5/sitestore/eccc/crd/ccrn/users/reo000/StatDownOc/output/ctdObs/ctd_obs_ds.nc'
    )[[target_variable]]
    obsmask = obs[target_variable].notnull()

In [None]:
# Older gridded model file (note: different directory from data_dir above).
ds_input0 = xr.open_dataset(
    Path('/space/hall5/sitestore/eccc/crd/ccrn/users/reo000/StatDownOc/output') / 'griddedROMS.nc'
)

In [None]:
bathymetry = ds_input0['bathy'].values
ds_input = ds_input0[[target_variable]]
ds_target = ds_input.copy(deep=False)  # shallow copy, same as .copy()

In [None]:
# Reproducible random assignment of observation time steps to model time
# steps. The permutation is repeated (at least twice) until it covers the
# model's time axis, so the slice below always has one index per time step.
SEED = 42
rng = np.random.default_rng(SEED)
perm = rng.permutation(len(obs.time))
n_repeats = max(2, -(-len(ds_input.time) // max(len(perm), 1)))  # ceil division
perm = np.tile(perm, n_repeats)
omask = obsmask.isel(t_ind=perm[:len(ds_input.time)]).values
# Multiply by the coverage mask: unobserved model points become 0.
ds_input = ds_input * omask

In [None]:
# Inspect the masked model input.
ds_input

In [None]:
# Inspect the masked target-variable field.
ds_input.t_pot

In [None]:
# NOTE(review): renders the full raw array; consider .shape or a small slice.
ds_target.t_pot.values

In [None]:
# Bathymetry array taken from ds_input0['bathy'].
bathymetry

In [None]:
# Integer flag: 1 where bathymetry == 0, else 0 (inverts a 0/1 field).
bathymetry2 = (bathymetry == 0).astype(int)

In [None]:
# Inspect the inverted bathymetry flag.
bathymetry2

In [None]:
# Final look at the model input dataset.
ds_input