In [1]:
import os
import warnings
from io import StringIO
from urllib.parse import urljoin

import numpy as np
import pandas as pd
import requests
import torch
import torch.nn as nn
import xarray as xr
from bs4 import BeautifulSoup
from torch.utils.data import Dataset, DataLoader

In [2]:
def _listing_hrefs(url, suffix, timeout=30):
    """Fetch an HTML index page and return the link targets ending in ``suffix``.

    Self/parent navigation links ('/', './', '../') are dropped so they can
    never be mistaken for a data directory.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    # A timeout keeps a stalled server from hanging the notebook forever.
    resp = requests.get(url, timeout=timeout)
    resp.raise_for_status()
    soup = BeautifulSoup(resp.text, 'html.parser')
    hrefs = [a['href'] for a in soup.find_all('a', href=True)]
    return [h for h in hrefs if h.endswith(suffix) and h not in ('/', './', '../')]


def get_latest_cora_url(base_url):
    """Return the URL of the lexicographically last ``.nc`` file under ``base_url``.

    The index page at ``base_url`` is scanned for sub-directory links; the
    last one in sort order is opened and scanned for ``.nc`` links, and the
    last of those is returned as an absolute URL.

    Args:
        base_url: URL of an HTML directory index.

    Returns:
        str: absolute URL of the selected NetCDF file.

    Raises:
        requests.HTTPError: if either index page cannot be fetched.
        IndexError: if no sub-directory or no ``.nc`` file is listed (same
            exception type as before, now with an explanatory message).

    NOTE(review): the output recorded at the end of this notebook shows the
    default CORA S3 URL answering 404 to this request, so that endpoint may
    not serve an HTML index at all - confirm the listing URL.
    """
    # Relative hrefs must resolve *inside* base_url, which needs a trailing '/'.
    root = base_url if base_url.endswith('/') else base_url + '/'

    dirs = _listing_hrefs(base_url, '/')
    if not dirs:
        raise IndexError(f"no sub-directories listed at {base_url}")
    # urljoin also copes with absolute hrefs ('/a/b/') and full URLs, which
    # plain string concatenation would turn into a broken address.
    dir_url = urljoin(root, max(dirs))

    files = _listing_hrefs(dir_url, '.nc')
    if not files:
        raise IndexError(f"no .nc files listed at {dir_url}")
    return urljoin(dir_url, max(files))

In [3]:
def load_cora_patch(s3_base_url, start_date, end_date, lat, lon, window=3):
    """Load monthly-mean ``sea_surface_height`` around a point from the latest CORA file.

    The newest file under ``s3_base_url`` is opened, trimmed to
    ``[start_date, end_date]`` on ``time``, and a ``window`` x ``window``
    neighbourhood centred on the grid point nearest to (``lat``, ``lon``) is
    cut out and averaged per month.

    Args:
        s3_base_url: index URL handed to ``get_latest_cora_url``.
        start_date, end_date: bounds of the ``time`` slice.
        lat, lon: target location, compared against ``ds.lat`` / ``ds.lon``.
        window: edge length of the patch in grid cells.

    Returns:
        np.ndarray of dtype float32 holding the monthly means. The patch is
        smaller than ``window`` along an axis when the nearest point lies
        closer than ``window // 2`` cells to that axis' edge.

    NOTE(review): the variable is indexed positionally as
    (time, lat, lon) - confirm the file really uses that dimension order.
    """
    url = get_latest_cora_url(s3_base_url)
    half = window // 2
    # The context manager releases the file handle even if selection fails.
    with xr.open_dataset(url) as ds:
        subset = ds.sel(time=slice(start_date, end_date))
        # Plain ints keep the slice arithmetic independent of the array type.
        lat_i = int(np.argmin(np.abs(np.asarray(subset.lat) - lat)))
        lon_i = int(np.argmin(np.abs(np.asarray(subset.lon) - lon)))
        # Clamp the lower bound: a negative start would wrap around to the end
        # of the axis and silently produce an empty patch. The upper bound is
        # already truncated by normal slicing rules.
        patch = subset['sea_surface_height'][:,
                   max(lat_i - half, 0):lat_i + half + 1,
                   max(lon_i - half, 0):lon_i + half + 1]
        monthly = patch.resample(time='M').mean()
        # Materialise the values before the dataset is closed.
        return monthly.values.astype(np.float32)

In [4]:
def load_tide_patch(nc_url, time_var, level_var, start_date, end_date, lat, lon, window=3):
    """Load monthly-mean water level around a point from a NetCDF source.

    The dataset at ``nc_url`` is trimmed to ``[start_date, end_date]`` along
    ``time_var``; a ``window`` x ``window`` neighbourhood of ``level_var``
    centred on the grid point nearest to (``lat``, ``lon``) is cut out and
    averaged per month.

    Args:
        nc_url: path or URL accepted by ``xr.open_dataset``.
        time_var: name of the time coordinate used for slicing and resampling.
        level_var: name of the data variable to extract.
        start_date, end_date: bounds of the time slice.
        lat, lon: target location, compared against ``ds.lat`` / ``ds.lon``.
        window: edge length of the patch in grid cells.

    Returns:
        np.ndarray of dtype float32 holding the monthly means. The patch is
        smaller than ``window`` along an axis when the nearest point lies
        closer than ``window // 2`` cells to that axis' edge.

    NOTE(review): the variable is indexed positionally as
    (time, lat, lon) - confirm the file really uses that dimension order.
    """
    half = window // 2
    # The context manager releases the file handle even if selection fails.
    with xr.open_dataset(nc_url) as ds:
        subset = ds.sel({time_var: slice(start_date, end_date)})
        # Plain ints keep the slice arithmetic independent of the array type.
        lat_i = int(np.argmin(np.abs(np.asarray(subset.lat) - lat)))
        lon_i = int(np.argmin(np.abs(np.asarray(subset.lon) - lon)))
        # Clamp the lower bound: a negative start would wrap around to the end
        # of the axis and silently produce an empty patch.
        patch = subset[level_var][:,
                   max(lat_i - half, 0):lat_i + half + 1,
                   max(lon_i - half, 0):lon_i + half + 1]
        monthly = patch.resample({time_var: 'M'}).mean()
        # Materialise the values before the dataset is closed.
        return monthly.values.astype(np.float32)

In [5]:
def load_grace_patch(nc_path, lat, lon, window=3, start_date=None, end_date=None):
    """Load monthly-mean ``water_equivalent_thickness`` around a point.

    A ``window`` x ``window`` neighbourhood centred on the grid point nearest
    to (``lat``, ``lon``) is cut out of the file at ``nc_path`` and averaged
    per month.

    Args:
        nc_path: path or URL accepted by ``xr.open_dataset``.
        lat, lon: target location, compared against ``ds.lat`` / ``ds.lon``.
        window: edge length of the patch in grid cells.
        start_date, end_date: optional bounds of a ``time`` slice. When both
            are ``None`` (the default) the full record is used, as before.
            Passing the same range as the other loaders keeps the series on
            a common time span.

    Returns:
        np.ndarray of dtype float32 holding the monthly means. The patch is
        smaller than ``window`` along an axis when the nearest point lies
        closer than ``window // 2`` cells to that axis' edge.

    NOTE(review): the variable is indexed positionally as
    (time, lat, lon) - confirm the file really uses that dimension order.
    """
    half = window // 2
    # The context manager releases the file handle even if selection fails.
    with xr.open_dataset(nc_path) as ds:
        subset = ds
        if start_date is not None or end_date is not None:
            subset = ds.sel(time=slice(start_date, end_date))
        # Plain ints keep the slice arithmetic independent of the array type.
        lat_i = int(np.argmin(np.abs(np.asarray(subset.lat) - lat)))
        lon_i = int(np.argmin(np.abs(np.asarray(subset.lon) - lon)))
        # Clamp the lower bound: a negative start would wrap around to the end
        # of the axis and silently produce an empty patch.
        patch = subset['water_equivalent_thickness'][:,
                   max(lat_i - half, 0):lat_i + half + 1,
                   max(lon_i - half, 0):lon_i + half + 1]
        monthly = patch.resample(time='M').mean()
        # Materialise the values before the dataset is closed.
        return monthly.values.astype(np.float32)

In [6]:
class SpatialTimeDataset(Dataset):
    """Sliding-window dataset over three patch series sharing a time axis.

    Each sample pairs ``seq_len`` consecutive time steps of all three series
    (stacked as channels in the order tide, cora, grace) with the mean of the
    tide series over the following ``pred_len`` steps.

    Args:
        tide, cora, grace: arrays indexed by time on axis 0. The target is
            reduced over axes (0, 1, 2), so the tide array must be 3-D after
            time slicing.
        seq_len: number of input time steps per sample.
        pred_len: number of following time steps averaged into the target.
    """

    def __init__(self, tide, cora, grace, seq_len, pred_len):
        self.data = [tide, cora, grace]
        self.seq_len = seq_len
        self.pred_len = pred_len
        lengths = [arr.shape[0] for arr in self.data]
        if len(set(lengths)) > 1:
            # Windows are cut by position, so unequal lengths mean the series
            # are probably not aligned in time; make that visible.
            warnings.warn(
                f"series lengths differ {lengths}; windows are limited to the "
                "shortest series and may be misaligned in time",
                stacklevel=2,
            )
        # Bound by the shortest series so every window is complete in all of
        # them (using only the tide length could yield ragged stacks).
        self.T = min(lengths)

    def __len__(self):
        # Never negative: len() rejects negative values with a ValueError.
        return max(0, self.T - self.seq_len - self.pred_len + 1)

    def __getitem__(self, idx):
        """Return ``(x, y)``: x has shape (seq_len, 3, ...), y is a scalar tensor.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
                This also lets plain ``for sample in dataset`` terminate.
        """
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            raise IndexError(f"index {idx} out of range for {n} windows")
        stop = idx + self.seq_len
        # Channel axis is inserted right after time: (seq, channel, H, W).
        x = np.stack([arr[idx:stop] for arr in self.data], axis=1)
        target = self.data[0][stop:stop + self.pred_len]
        y = target.mean(axis=(0, 1, 2))
        return torch.from_numpy(x), torch.tensor(y, dtype=torch.float32)


In [7]:
class SpatialTemporalModel(nn.Module):
    """Per-frame CNN encoder followed by a Transformer encoder over time.

    Every time step's patch is reduced to one feature vector by a small CNN
    with global average pooling, projected to ``d_model``, passed through a
    Transformer encoder, and the last time step's encoding is mapped to a
    single scalar per sequence.

    Config keys read: ``cnn_filters``, ``d_model``, ``nhead``,
    ``dim_feedforward``, ``dropout``, ``num_layers`` and, optionally,
    ``in_channels`` (number of input channels; defaults to 3).
    """

    def __init__(self, config):
        super().__init__()
        # Optional key so existing configs keep the original three channels.
        C = config['in_channels'] if 'in_channels' in config else 3
        self.cnn = nn.Sequential(
            nn.Conv2d(C, config['cnn_filters'], 3, padding=1), nn.ReLU(),
            nn.Conv2d(config['cnn_filters'], config['cnn_filters'], 3, padding=1), nn.ReLU(),
            # Collapse H x W to 1 x 1 so any patch size yields one vector.
            nn.AdaptiveAvgPool2d((1, 1))
        )
        d_model = config['d_model']
        self.proj = nn.Linear(config['cnn_filters'], d_model)
        encoder = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=config['nhead'],
            dim_feedforward=config['dim_feedforward'],
            dropout=config['dropout'],
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder, num_layers=config['num_layers'])
        self.head = nn.Linear(d_model, 1)

    def forward(self, x):
        """Map ``x`` of shape (batch, seq, C, H, W) to predictions of shape (batch,)."""
        b, t, C, H, W = x.shape
        # reshape (not view) so permuted / non-contiguous inputs are accepted.
        frames = x.reshape(b * t, C, H, W)
        f = self.cnn(frames).reshape(b, t, -1)   # (batch, seq, cnn_filters)
        p = self.proj(f)                         # (batch, seq, d_model)
        e = self.transformer(p)
        out = e[:, -1, :]                        # encoding of the last time step
        return self.head(out).squeeze(-1)

In [8]:
def prepare_loaders(tide, cora, grace, config):
    """Build training and validation loaders from the three patch series.

    The windows are split chronologically: the first ``train_split`` fraction
    goes to training and the remainder to validation. Neighbouring windows
    overlap heavily, so a random split would place almost identical windows
    on both sides and let training inputs contain validation targets.

    Args:
        tide, cora, grace: arrays passed straight to ``SpatialTimeDataset``.
        config: mapping with ``seq_len``, ``pred_len``, ``train_split`` and
            ``batch_size``.

    Returns:
        (train_loader, val_loader): the training loader shuffles its
        windows, the validation loader keeps them in time order.
    """
    ds = SpatialTimeDataset(tide, cora, grace, config['seq_len'], config['pred_len'])
    n_total = len(ds)
    n_train = int(n_total * config['train_split'])
    # Deterministic, time-ordered partition; shuffling happens only inside
    # the training loader.
    train_ds = torch.utils.data.Subset(ds, range(n_train))
    val_ds = torch.utils.data.Subset(ds, range(n_train, n_total))
    train_loader = DataLoader(train_ds, config['batch_size'], shuffle=True)
    val_loader = DataLoader(val_ds, config['batch_size'])
    return train_loader, val_loader

In [9]:
def train_model(config):
    """Load the three patch series, then train and validate the model.

    Args:
        config: mapping providing the data-source keys (``tide_nc_url``,
            ``tide_time_var``, ``tide_level_var``, ``s3_base_url``,
            ``grace_path``, ``start_date``, ``end_date``, ``lat``, ``lon``,
            ``window``), the training keys (``device``, ``lr``, ``epochs``)
            and whatever ``prepare_loaders`` and ``SpatialTemporalModel``
            read from it.

    Returns:
        The trained model, left in eval mode after the last epoch.

    One line with the mean training and validation loss is printed per epoch.
    """
    device = config['device']

    tide = load_tide_patch(
        config['tide_nc_url'], config['tide_time_var'], config['tide_level_var'],
        config['start_date'], config['end_date'], config['lat'], config['lon'], config['window']
    )
    cora = load_cora_patch(
        config['s3_base_url'], config['start_date'], config['end_date'], config['lat'], config['lon'], config['window']
    )
    grace = load_grace_patch(
        config['grace_path'], config['lat'], config['lon'], config['window']
    )
    train_loader, val_loader = prepare_loaders(tide, cora, grace, config)

    model = SpatialTemporalModel(config).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=config['lr'])
    loss_fn = nn.MSELoss()

    for epoch in range(1, config['epochs']+1):
        model.train()
        train_losses = []
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            opt.step()
            train_losses.append(loss.item())

        model.eval()
        val_losses = []
        # No gradients are needed for validation; skipping graph construction
        # saves memory and time.
        with torch.no_grad():
            for x, y in val_loader:
                val_losses.append(loss_fn(model(x.to(device)), y.to(device)).item())

        # An empty loader reports nan without tripping numpy's empty-mean warning.
        train_mean = np.mean(train_losses) if train_losses else float('nan')
        val_mean = np.mean(val_losses) if val_losses else float('nan')
        print(f"Epoch {epoch}/{config['epochs']} | Train: {train_mean:.4f} | Val: {val_mean:.4f}")
    return model

In [10]:
if __name__ == '__main__':
    # Single mapping holding every setting used by the loaders, the model
    # and the training loop.
    config = {
        # Data sources
        's3_base_url': 'https://noaa-nos-cora-pds.s3.amazonaws.com/V1.1/assimilated/native_grid/',
        'tide_nc_url': None,  # filled in below when left unset
        'tide_time_var': 'time',
        'tide_level_var': 'zeta',
        'grace_path': '/data/GRACE_DA_NLDAS_0.25deg.nc',
        # Time span and location of the patch
        'start_date': '2010-01-01',
        'end_date': '2023-12-31',
        'lat': 37.7749,
        'lon': -122.4194,
        'window': 3,
        # Windowing of the time series
        'seq_len': 12,
        'pred_len': 1,
        # Model architecture
        'cnn_filters': 16,
        'd_model': 64,
        'nhead': 4,
        'num_layers': 2,
        'dim_feedforward': 128,
        'dropout': 0.1,
        # Optimisation
        'batch_size': 32,
        'train_split': 0.8,
        'lr': 1e-3,
        'epochs': 20,
        # Use the GPU when torch can see one, otherwise the CPU.
        'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    }

    # Without an explicit tide source, reuse the newest file found under
    # the CORA base URL.
    if config['tide_nc_url'] is None:
        config['tide_nc_url'] = get_latest_cora_url(config['s3_base_url'])

    model = train_model(config)
    # Persist only the learned weights.
    torch.save(model.state_dict(), 'spatial_temporal_cnn_transformer.pth')


HTTPError: 404 Client Error: Not Found for url: https://noaa-nos-cora-pds.s3.amazonaws.com/V1.1/assimilated/native_grid/