In [101]:
from enum import Enum
from time import sleep
from io import StringIO
import requests
from requests.compat import urljoin
import pandas as pd
import numpy as np
import os
from typing import List, Dict
from datetime import datetime
import re
from typing import Tuple

class MeasureType(Enum):
  """Kinds of measure used in this notebook.

  Each member's value is the lowercase type name that appears inside a
  measure identifier (see ``Measure.__str__``).
  """
  LEVEL = "level"
  FLOW = "flow"
  RAINFALL = "rainfall"

  @property
  def units(self):
    """Identifier suffix for this measure type, or ``None`` if it has none defined.

    ``Measure.__str__`` appends this string after the type value when it
    builds a measure identifier.
    """
    if self is MeasureType.LEVEL:
      return "i-900-m-qualified"
    if self is MeasureType.FLOW:
      return "i-900-m3s-qualified"
    if self is MeasureType.RAINFALL:
      return "t-900-mm-qualified"
    return None

  @property
  def observed_property_name(self):
    """Name sent as the ``observedProperty`` query parameter, or ``None`` if undefined."""
    if self is MeasureType.LEVEL:
      return "waterLevel"
    if self is MeasureType.FLOW:
      return "waterFlow"
    if self is MeasureType.RAINFALL:
      return "rainfall"
    return None

class Measure:
  """One measure: a station combined with a measure type.

  ``str(measure)`` yields ``<station_id>-<type value>-<type units>``, which is
  the form ``from_string`` expects at the end of a measure URI.
  """

  def __init__(
    self,
    station_id: str,
    measure_type: "MeasureType",
  ):
    """Store the station identifier and the measure type.

    Args:
        station_id: Station identifier (a GUID in the URIs parsed by ``from_string``).
        measure_type: The ``MeasureType`` member describing what is measured.
    """
    self.station_id = station_id
    self.measure_type = measure_type

  def __str__(self):
    return f"{self.station_id}-{self.measure_type.value}-{self.measure_type.units}"

  def __repr__(self):
    return f"Measure(station_id={self.station_id}, measure_type={self.measure_type})"

  @staticmethod
  def from_string(str: str) -> "Measure":
    """Parse a measure URI (or bare measure identifier) into a ``Measure``.

    Args:
        str: Measure URI or identifier. The parameter name shadows the builtin
            but is kept so existing keyword callers keep working.

    Returns:
        Measure: The station id and measure type found in the last path segment.

    Raises:
        AssertionError: If no station id or no known measure type is found.
    """
    # Format of string is http://environment.data.gov.uk/hydrology/id/measures/ba3f8598-e654-430d-9bb8-e1652e6ff93d-level-i-900-m-qualified
    # A trailing "/" would otherwise leave an empty last segment.
    measure_id = str.rstrip("/").split("/")[-1]
    station_id_regex = r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"
    # Map each "<value>-<units>" suffix back to its enum member.
    possible_measures = {f'{m.value}-{m.units}': m for m in MeasureType}

    station_match = re.search(station_id_regex, measure_id)
    assert station_match, f"Could not find station ID in {measure_id}"

    measure_type = next(
      (member for suffix, member in possible_measures.items() if suffix in measure_id),
      None,
    )
    assert measure_type is not None, f"Could not find measure type in {measure_id}"

    # measure_type is already an enum member, so no re-wrapping is needed.
    return Measure(station_match.group(0), measure_type)
    
    
class HydrologyApi:
    API_BASE_URL = "https://environment.data.gov.uk/hydrology/"
    DATA_DIR = "data"
    CACHE_DIR = "cache"
    START_DATE = datetime(2008, 1, 1)
    
    def _batch_request(
      self,
      *args,
      **kwargs,
    ) -> pd.DataFrame:
      """Deal with batch requests from the API. These may be queued for some time before returning data.

      Returns:
          pd.DataFrame: The data returned by the API
      """
      class BatchRequestStatus(Enum):
        PENDING = "pending"
        IN_PROGRESS = "inprogress"
        COMPLETE = "complete"
        FAILED = "failed"
        
        @staticmethod
        def from_string(s: str):
          s = s.lower()
          s = "complete" if s == "completed" else s
          
          assert s in [e.value for e in BatchRequestStatus], f"Unknown response status: {s}"
          return BatchRequestStatus(s)
        
      status = BatchRequestStatus.PENDING
      
      required_headers = {"Accept-Encoding": "gzip"}
      kwargs["headers"] = {**kwargs.get("headers", {}), **required_headers}
      
      
      while status in [BatchRequestStatus.PENDING, BatchRequestStatus.IN_PROGRESS]:
        response = requests.get(*args, **kwargs)
        
        content_type = response.headers.get("Content-Type", "")
        
        if content_type == "text/csv":
          buffer = StringIO(response.text)
          # write buffer to file for debugging
          with open("data.csv", "w") as f:
            f.write(response.text)
          return pd.read_csv(buffer, low_memory=False)
        
        assert "application/json" in content_type, f"Unexpected content type: {content_type}"
        
        response_data: dict = response.json()
        assert "status" in response_data, "No status field in response"
        status = BatchRequestStatus.from_string(response_data["status"])
        
        match status:
          case BatchRequestStatus.PENDING | BatchRequestStatus.IN_PROGRESS:
            eta = response_data.get("eta", 60 * 1000) / 1000
            sleep(max(eta*0.1, 1))
          
          case BatchRequestStatus.COMPLETE:
            keys = ["dataUrl", "url"] # Some responses have dataUrl, some have url
            data_url = next((response_data.get(k) for k in keys if k in response_data), None)
            assert data_url, f"Could not find data URL in response: {response_data}"
            return pd.read_csv(data_url)

          case BatchRequestStatus.FAILED:
            raise Exception(f"Batch request failed: {response_data}")
          
          case _:
            raise Exception(f"Unknown status: {status}")
          
    def _request(
      self,
      *args,
      **kwargs,
    ):
      response = requests.get(*args, **kwargs)
      response.raise_for_status()
      
      response_data = response.json()
      assert "items" in response_data, "No items field in response"
      return pd.DataFrame(response_data["items"])
      
    def get_stations(
      self,
      measures: MeasureType | List[MeasureType] | None = None,
      river: str = None,
      position: Tuple[float, float] = None,
      radius: float = None,
      limit: int = None,
      return_df = False,
    ):
      if isinstance(measures, MeasureType):
        measures = [measures]
        
      lat, long = position if position else (None, None)
        
      result = requests.get(
        urljoin(self.API_BASE_URL, "id/stations"),
        params = {
          "observedProperty": [measure.observed_property_name for measure in measures] if measures else None,
          "riverName": river, 
          "lat": lat,
          "long": long,
          "dist": radius,
          "_limit": limit,
          "status.label": "Active",
        },
      )
      result_json = result.json()
      assert "items" in result_json, f"Unexpected response: {result_json}"
      return (
        pd.DataFrame(result_json["items"]) 
        if return_df 
        else {
          station['notation']: station['label']
          for station in result_json["items"]
        }
      )
      
    def get_measures(
      self,
      measures: List[Measure],
      station_names: Dict[str, str],
      start_date: datetime = START_DATE,
    ):
      # Estimate how many rows we are going to get back
      # Each measure is every 15 mins
      estimated_rows = 4 * 24 * (datetime.now() - start_date).days * len(measures)
      
      params = {
        'measure': [str(m) for m in measures],
        'mineq-date': start_date.strftime("%Y-%m-%d"),
        '_limit': int(estimated_rows * 1.1),
      }
      
      if estimated_rows > 2_000_000:
        # We need to use the batch api
        df = self._batch_request(
          urljoin(self.API_BASE_URL, 'data/batch-readings/batch'),
          params = params
        )
        
      else:
        df = self._request(
          urljoin(self.API_BASE_URL, 'data/readings.json'),
          params = params
        )
        df['measure'] = df['measure'].str.get('@id')
        
      return (
        df
        .loc[lambda x: x['quality'].isin(['Good', 'Unchecked', 'Estimated'])]
        .assign(
          timestamp = lambda x: pd.to_datetime(x['dateTime']),
          value = lambda x: pd.to_numeric(x['value'], errors='coerce').astype(np.float32),
          series_name = lambda x: (
            pd.Categorical(x['measure'])
            .map(Measure.from_string, na_action='None')
            .map(lambda row: f"{station_names[row.station_id]} ({row.measure_type.value})", na_action='None')
          )
        )
        .drop(
          columns=['measure', 'id', 'date', 'completeness', 'quality', 'qcode', 'dateTime'], 
          errors='ignore'
        )
        .pivot(
          index='timestamp',
          columns='series_name',
          values='value'
        )
        .resample('15min')
        .interpolate(
          "time",
          limit_direction='both',
          limit=24 * 4,
          fill_value="extrapolate",
        )
      )
    

## Load Dataset

In [119]:
api = HydrologyApi()

# With return_df left at its default, each lookup returns {station notation: station label}.
level_stations = api.get_stations(measures=MeasureType.LEVEL, river="River Wear")
rainfall_stations = api.get_stations(
  measures=MeasureType.RAINFALL,
  position=(54.774, -1.558),
  radius=15,
)

# One Measure per station: level stations first, then rainfall stations.
measures = [
  *(Measure(station_id, MeasureType.LEVEL) for station_id in level_stations),
  *(Measure(station_id, MeasureType.RAINFALL) for station_id in rainfall_stations),
]

# Combined id -> label lookup; rainfall entries win on duplicate ids.
stations = level_stations | rainfall_stations

df = api.get_measures(measures, stations, start_date=datetime(2007, 1, 1))

df.head()

## Forecasting

In [None]:
from neuralforecast.core import NeuralForecast
# NHITS is a model class, so import it from the models package rather than
# from neuralforecast.auto, which is the hyperparameter-tuning module.
from neuralforecast.models import NHITS
from neuralforecast.losses.pytorch import MQLoss

# Column of `df` that is forecast; it becomes the `y` column below.
TARGET_SERIES = "Durham New Elvet Bridge (level)"

# `rename` silently ignores labels that are absent, so fail here with a clear
# message instead of later on a missing `y` column.
assert TARGET_SERIES in df.columns, (
  f"Target series {TARGET_SERIES!r} not found; available columns: {list(df.columns)}"
)

# Turn the timestamp index into a `ds` column, rename the target to `y`, and
# tag every row with one constant `unique_id`.
train_df = (
  df
  .reset_index()
  .rename_axis(None, axis=1)
  .rename(columns={
    TARGET_SERIES: "y",
    "timestamp": "ds",
  })
  .assign(unique_id = "River Wear")
)

In [None]:
def make_loss():
  """Build a fresh multi-quantile loss for the 0.5, 0.9 and 0.99 quantiles."""
  quantile_levels = [0.5, 0.9, 0.99]
  return MQLoss(quantiles=quantile_levels)

# The series is on a 15-minute grid, so one day is 4 * 24 steps.
steps_per_day = 4 * 24

# Every column other than the time, id and target columns is passed as a
# historical exogenous input. Index.drop raises KeyError if any of the three
# is missing.
exogenous_columns = train_df.columns.drop(["ds", "unique_id", "y"]).to_list()

nhits_model = NHITS(
  h=steps_per_day,               # forecast horizon: 1 day
  input_size=3 * steps_per_day,  # input window: 3 days
  hist_exog_list=exogenous_columns,
  scaler_type='robust',
  loss=make_loss(),
  max_steps=1000,
)
models = [nhits_model]

nf = NeuralForecast(models=models, freq='15min')

nf.fit(df=train_df)