# Nixtla Client

In [None]:
#| default_exp nixtla_client

In [None]:
#| hide 
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import logging
import math
import os
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

import httpcore
import httpx
import numpy as np
import orjson
import pandas as pd
import utilsforecast.processing as ufp
from fastcore.basics import patch
from pydantic import NonNegativeInt, PositiveInt
from tenacity import (
    RetryCallState,
    retry,
    retry_if_exception,
    stop_after_attempt,
    stop_after_delay,
    wait_fixed,
)
from utilsforecast.compat import DFType, DataFrame, pl_DataFrame
from utilsforecast.feature_engineering import _add_time_features, time_features
from utilsforecast.validation import ensure_time_dtype, validate_format
if TYPE_CHECKING:
    try:
        from fugue import AnyDataFrame
    except ModuleNotFoundError:
        pass
    try:
        import matplotlib.pyplot as plt
    except ModuleNotFoundError:
        pass
    try:
        import plotly
    except ModuleNotFoundError:
        pass
    try:
        import triad
    except ModuleNotFoundError:
        pass
    try:
        from polars import DataFrame as PolarsDataFrame
    except ModuleNotFoundError:
        pass
    try:
        from dask.dataframe import DataFrame as DaskDataFrame
    except ModuleNotFoundError:
        pass
    try:
        from pyspark.sql import DataFrame as SparkDataFrame
    except ModuleNotFoundError:
        pass
    try:
        from ray.data import Dataset as RayDataset
    except ModuleNotFoundError:
        pass

In [None]:
#| exporti
# Type variable constrained to every dataframe flavour referenced in this
# module. The non-pandas flavours are given as strings (forward references)
# because they are only imported under `TYPE_CHECKING`.
AnyDFType = TypeVar(
    "AnyDFType",
    "DaskDataFrame",
    pd.DataFrame,
    "PolarsDataFrame",
    "RayDataset",
    "SparkDataFrame",
)
# Type variable constrained to the Dask, Ray and Spark dataframe flavours only.
DistributedDFType = TypeVar(
    "DistributedDFType",
    "DaskDataFrame",
    "RayDataset",
    "SparkDataFrame",
)
# NOTE(review): this configures the root logger as an import-time side effect.
logging.basicConfig(level=logging.INFO)
# Raise the threshold of the 'httpx' logger so only errors are emitted by it.
logging.getLogger('httpx').setLevel(logging.ERROR)
# Module-level logger used by the helpers and the client below.
logger = logging.getLogger(__name__)

In [None]:
#| hide
from contextlib import contextmanager
from itertools import product
from time import time, sleep

from dotenv import load_dotenv
from fastcore.test import test_eq, test_fail
from utilsforecast.data import generate_series

from nixtla.date_features import SpecialDates

In [None]:
#| hide
# Notebook-only setup (hidden cell): load variables from a `.env` file.
# NOTE(review): presumably this supplies credentials such as NIXTLA_API_KEY
# for the tests in this notebook -- confirm.
load_dotenv(override=True)

In [None]:
#| exporti
_Loss = Literal["default", "mae", "mse", "rmse", "mape", "smape"]
_Model = Literal["azureai", "timegpt-1", "timegpt-1-long-horizon"]

_date_features_by_freq = {
    # Daily frequencies
    'B': ['year', 'month', 'day', 'weekday'],
    'C': ['year', 'month', 'day', 'weekday'],
    'D': ['year', 'month', 'day', 'weekday'],
    # Weekly
    'W': ['year', 'week', 'weekday'],
    # Monthly
    'M': ['year', 'month'],
    'SM': ['year', 'month', 'day'],
    'BM': ['year', 'month'],
    'CBM': ['year', 'month'],
    'MS': ['year', 'month'],
    'SMS': ['year', 'month', 'day'],
    'BMS': ['year', 'month'],
    'CBMS': ['year', 'month'],
    # Quarterly
    'Q': ['year', 'quarter'],
    'BQ': ['year', 'quarter'],
    'QS': ['year', 'quarter'],
    'BQS': ['year', 'quarter'],
    # Yearly
    'A': ['year'],
    'Y': ['year'],
    'BA': ['year'],
    'BY': ['year'],
    'AS': ['year'],
    'YS': ['year'],
    'BAS': ['year'],
    'BYS': ['year'],
    # Hourly
    'BH': ['year', 'month', 'day', 'hour', 'weekday'],
    'H': ['year', 'month', 'day', 'hour'],
    # Minutely
    'T': ['year', 'month', 'day', 'hour', 'minute'],
    'min': ['year', 'month', 'day', 'hour', 'minute'],
    # Secondly
    'S': ['year', 'month', 'day', 'hour', 'minute', 'second'],
    # Milliseconds
    'L': ['year', 'month', 'day', 'hour', 'minute', 'second', 'millisecond'],
    'ms': ['year', 'month', 'day', 'hour', 'minute', 'second', 'millisecond'],
    # Microseconds
    'U': ['year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond'],
    'us': ['year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond'],
    # Nanoseconds
    'N': []
}

def _retry_strategy(max_retries: int, retry_interval: int, max_wait_time: int):
    """Build the tenacity ``retry`` decorator used for API calls.

    Parameters
    ----------
    max_retries : int
        Passed to ``stop_after_attempt``.
    retry_interval : int
        Passed to ``wait_fixed`` (wait between attempts).
    max_wait_time : int
        Passed to ``stop_after_delay``.

    Returns
    -------
    The decorator returned by ``tenacity.retry``. A call is retried only when
    it raised one of the transport exceptions listed below or an ``ApiError``
    whose status code is in the retriable list; retrying stops when either
    stop condition is met, and ``reraise=True`` is set.
    """
    transient_exceptions = (
        ConnectionResetError,
        httpcore.ConnectError,
        httpcore.RemoteProtocolError,
        httpx.ConnectTimeout,
        httpx.ReadError,
        httpx.RemoteProtocolError,
        httpx.ReadTimeout,
        httpx.PoolTimeout,
        httpx.WriteError,
        httpx.WriteTimeout,
    )
    transient_status_codes = [408, 409, 429, 502, 503, 504]

    def _is_transient(exc: Exception) -> bool:
        if isinstance(exc, transient_exceptions):
            return True
        return isinstance(exc, ApiError) and exc.status_code in transient_status_codes

    def _log_failed_attempt(retry_state: RetryCallState) -> None:
        failure = retry_state.outcome.exception()
        logger.error(
            f"Attempt {retry_state.attempt_number} failed with error: {failure}"
        )

    # give up on whichever limit is reached first
    give_up = stop_after_attempt(max_retries) | stop_after_delay(max_wait_time)
    return retry(
        retry=retry_if_exception(_is_transient),
        wait=wait_fixed(retry_interval),
        after=_log_failed_attempt,
        stop=give_up,
        reraise=True,
    )

def _maybe_infer_freq(
    df: DataFrame,
    freq: Optional[str],
    id_col: str,
    time_col: str,
) -> str:
    if freq is not None and freq not in ['W', 'M', 'Q', 'Y', 'A']:
        return freq
    if isinstance(df, pl_DataFrame):
        raise ValueError(
            "Cannot infer frequency for a polars DataFrame, please set the "
            "`freq` argument to a valid polars offset.\nYou can find them at "
            "https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.dt.offset_by.html"
        )
    assert isinstance(df, pd.DataFrame)
    sizes = df[id_col].value_counts(sort=True)
    times = df.loc[df[id_col] == sizes.index[0], time_col].sort_values()
    if times.dt.tz is not None:
        times = times.dt.tz_convert('UTC').dt.tz_localize(None)
    inferred_freq = pd.infer_freq(times.values)
    if inferred_freq is None:
        raise RuntimeError(
            'Could not infer the frequency of the time column. This could be due '
            'to inconsistent intervals. Please check your data for missing, '
            'duplicated or irregular timestamps'
        )
    if freq is not None:
        # check we have the same base frequency
        # except when we have yearly frequency (A, and Y means the same)
        if (
            (freq[0] != inferred_freq[0] and freq[0] not in ('A', 'Y'))
            or (freq[0] in ('A', 'Y') and inferred_freq[0] not in ('A', 'Y'))
        ):
            raise RuntimeError(f'Failed to infer special date, inferred freq {inferred_freq}')
    logger.info(f'Inferred freq: {inferred_freq}')
    return inferred_freq

def _standardize_freq(freq: str) -> str:
    return freq.replace('mo', 'MS')

def _array_tails(
    x: np.ndarray,
    indptr: np.ndarray,
    out_sizes: np.ndarray,
) -> np.ndarray:
    if (out_sizes > np.diff(indptr)).any():
        raise ValueError('out_sizes must be at most the original sizes.')
    idxs = np.hstack(
        [
            np.arange(end - size, end)
            for end, size in zip(indptr[1:], out_sizes)
        ]
    )
    return x[idxs]

def _tail(proc: ufp.ProcessedDF, n: int) -> ufp.ProcessedDF:
    """Return a ``ProcessedDF`` keeping at most the last ``n`` rows per series.

    ``uids`` and ``last_times`` are carried over unchanged, ``indptr`` is
    rebuilt from the truncated sizes and ``sort_idxs`` is set to ``None``.
    """
    original_sizes = np.diff(proc.indptr)
    kept_sizes = np.minimum(original_sizes, n)
    return ufp.ProcessedDF(
        uids=proc.uids,
        last_times=proc.last_times,
        data=_array_tails(proc.data, proc.indptr, kept_sizes),
        indptr=np.append(0, kept_sizes.cumsum()),
        sort_idxs=None,
    )

def _partition_series(
    payload: Dict[str, Any], n_part: int, h: int
) -> List[Dict[str, Any]]:
    parts = []
    series = payload.pop("series")
    n_series = len(series["sizes"])
    n_part = min(n_part, n_series)
    series_per_part = math.ceil(n_series / n_part)
    prev_size = 0
    for i in range(0, n_series, series_per_part):
        sizes = series["sizes"][i : i + series_per_part]
        curr_size = sum(sizes)
        part_idxs = slice(prev_size, prev_size + curr_size)
        prev_size += curr_size
        part_series = {
            "y": series["y"][part_idxs],
            "sizes": sizes,
        }
        if series["X"] is None:
            part_series["X"] = None
            if h > 0:
                part_series["X_future"] = None
        else:
            part_series["X"] = [x[part_idxs] for x in series["X"]]
            if h > 0:
                part_series["X_future"] = [
                    x[i * h : (i + series_per_part) * h] for x in series["X_future"]
                ]
        parts.append({"series": part_series, **payload})
    return parts

def _maybe_add_date_features(
    df: DFType,
    X_df: Optional[DFType],
    features: Union[bool, List[Union[str, Callable]]],
    one_hot: Union[bool, List[str]],
    freq: str,
    h: int,
    id_col: str,
    time_col: str,
    target_col: str,
) -> Tuple[DFType, Optional[DFType]]:
    if not features:
        return df, X_df
    if isinstance(features, list):
        date_features = features
    else:
        date_features = _date_features_by_freq.get(freq, [])
        if not date_features:
            warnings.warn(
                f'Non default date features for {freq} '
                'please provide a list of date features'
            )
    # add features
    if X_df is None:
        df, X_df = time_features(
            df=df,
            freq=freq,
            features=date_features,
            h=h,
            id_col=id_col,
            time_col=time_col,
        )
    else:
        df = _add_time_features(df, features=date_features, time_col=time_col)
        X_df = _add_time_features(X_df, features=date_features,time_col=time_col)
    # one hot
    if isinstance(one_hot, list):
        features_one_hot = one_hot
    elif one_hot:
        features_one_hot = [f for f in date_features if not callable(f)]
    else:
        features_one_hot = []
    if features_one_hot:
        X_df = ufp.assign_columns(X_df, target_col, 0)
        full_df = ufp.vertical_concat([df, X_df])
        if isinstance(full_df, pd.DataFrame):
            full_df = pd.get_dummies(
                full_df, columns=features_one_hot, dtype='float32'
            )
        else:
            full_df = full_df.to_dummies(columns=features_one_hot)
        df = ufp.take_rows(full_df, slice(0, df.shape[0]))
        X_df = ufp.take_rows(full_df, slice(df.shape[0], full_df.shape[0]))
        X_df = ufp.drop_columns(X_df, target_col)
        X_df = ufp.drop_index_if_pandas(X_df)
    if h == 0:
        # time_features returns an empty df, we use it as None here
        X_df = None
    return df, X_df

def _validate_exog(
    df: DFType,
    X_df: Optional[DFType],
    id_col: str,
    time_col: str,
    target_col: str,
) -> Tuple[DFType, Optional[DFType]]:

    exog_list = [c for c in df.columns if c not in (id_col, time_col, target_col)]

    if X_df is None:
        df = df[[id_col, time_col, target_col, *exog_list]]
        return df, None

    futr_exog_list = [c for c in X_df.columns if c not in (id_col, time_col)]
    hist_exog_list = list(set(exog_list) - set(futr_exog_list))

    # Capture case where future exogenous are provided in X_df that are not in df
    missing_futr = set(futr_exog_list) - set(exog_list)
    if missing_futr:
        raise ValueError(
            "The following exogenous features are present in `X_df` "
            f"but not in `df`: {missing_futr}."
        )

    # Make sure df and X_df are in right order
    df = df[[id_col, time_col, target_col, *futr_exog_list, *hist_exog_list]]
    X_df = X_df[[id_col, time_col, *futr_exog_list]]

    return df, X_df

def _validate_input_size(
    df: DataFrame,
    id_col: str,
    model_input_size: int,
    model_horizon: int,
) -> None:
    """Ensure every series has at least ``model_input_size + model_horizon`` rows.

    Raises
    ------
    ValueError
        If the shortest series has fewer rows than required.
    """
    required = model_input_size + model_horizon
    shortest = ufp.counts_by_id(df, id_col)['counts'].min()
    if shortest < required:
        raise ValueError(
            'Your time series data is too short '
            'Please make sure that your each serie contains '
            f'at least {required} observations.'
        )

def _prepare_level_and_quantiles(
    level: Optional[List[Union[int, float]]], 
    quantiles: Optional[List[float]],
) -> Tuple[List[Union[int, float]], Optional[List[float]]]:
    if level is not None and quantiles is not None:
        raise ValueError(
            "You should provide `level` or `quantiles`, but not both."
        )
    if quantiles is None:
        return level, quantiles
    # we recover level from quantiles
    if not all(0 < q < 1 for q in quantiles):
        raise ValueError("`quantiles` should be floats between 0 and 1.")
    level = [abs(int(100 - 200 * q)) for q in quantiles]
    return level, quantiles

def _maybe_convert_level_to_quantiles(
    df: DFType,
    quantiles: Optional[List[float]],
) -> DFType:
    if quantiles is None:
        return df
    out_cols = [c for c in df.columns if '-lo-' not in c and '-hi-' not in c]
    df = ufp.copy_if_pandas(df, deep=False)
    for q in sorted(quantiles):
        if q == 0.5:
            col = 'TimeGPT'
        else:
            lv = int(100 - 200 * q)
            hi_or_lo = 'lo' if lv > 0 else 'hi'
            lv = abs(lv)
            col = f"TimeGPT-{hi_or_lo}-{lv}"
        q_col = f"TimeGPT-q-{int(q * 100)}"
        df = ufp.assign_columns(df, q_col, df[col])
        out_cols.append(q_col)
    return df[out_cols]

def _preprocess(
    df: DFType,
    X_df: Optional[DFType],
    h: int,
    freq: str,
    date_features: Union[bool, List[Union[str, Callable]]],
    date_features_to_one_hot: Union[bool, List[str]],
    id_col: str,
    time_col: str,
    target_col: str,
) -> Tuple[ufp.ProcessedDF, Optional[np.ndarray], List[str], Optional[List[str]]]:
    """Add date features and convert the frames to their processed form.

    Returns
    -------
    processed : ufp.ProcessedDF
        Result of ``ufp.process_df`` on ``df`` (after adding date features).
    X_future : array or None
        Transposed ``data`` of the processed future frame, or ``None`` when
        there is no future frame.
        NOTE(review): annotated as ``np.ndarray`` on the assumption that
        ``ProcessedDF.data`` is a numpy array -- confirm.
    x_cols : list of str
        Columns of ``df`` other than id, time and target.
    futr_cols : list of str or None
        Columns of the future frame other than id and time, or ``None`` when
        there is no future frame.
    """
    df, X_df = _maybe_add_date_features(
        df=df,
        X_df=X_df,
        features=date_features,
        one_hot=date_features_to_one_hot,
        freq=freq,
        h=h,
        id_col=id_col,
        time_col=time_col,
        target_col=target_col,
    )
    processed = ufp.process_df(
        df=df, id_col=id_col, time_col=time_col, target_col=target_col
    )
    if X_df is not None:
        X_df = ensure_time_dtype(X_df, time_col=time_col)
        processed_X = ufp.process_df(
            df=X_df, id_col=id_col, time_col=time_col, target_col=None,
        )
        X_future = processed_X.data.T
        futr_cols = [c for c in X_df.columns if c not in (id_col, time_col)]
    else:
        X_future = None
        futr_cols = None
    x_cols = [c for c in df.columns if c not in (id_col, time_col, target_col)]
    return processed, X_future, x_cols, futr_cols

def _forecast_payload_to_in_sample(payload):
    in_sample_payload = {
        k: v
        for k, v in payload.items()
        if k not in ('h', 'finetune_steps', 'finetune_loss')
    }
    del in_sample_payload['series']['X_future']
    return in_sample_payload

def _maybe_add_intervals(
    df: DFType,
    intervals: Optional[Dict[str, list[float]]],
) -> DFType:
    if intervals is None:
        return df
    intervals_df = type(df)(
        {f'TimeGPT-{k}': intervals[k] for k in sorted(intervals.keys())}
    )
    return ufp.horizontal_concat([df, intervals_df])

def _maybe_drop_id(df: DFType, id_col: str, drop: bool) -> DFType:
    if drop:
        df = ufp.drop_columns(df, id_col)
    return df

def _parse_in_sample_output(
    in_sample_output: Dict[str, Union[List[float], Dict[str, List[float]]]],
    df: DataFrame,
    processed: ufp.ProcessedDF,
    id_col: str,
    time_col: str,
    target_col: str,
) -> DataFrame:
    """Build the in-sample (fitted values) frame from an API response.

    Parameters
    ----------
    in_sample_output : dict
        Response with keys 'sizes' (rows returned per series), 'mean' (fitted
        values) and 'intervals' (``None`` or a mapping of interval values).
    df : DataFrame
        Input frame; supplies the time and target values.
    processed : ufp.ProcessedDF
        Processed form of ``df``; supplies ids, group boundaries and, when
        set, the sort order.

    Returns
    -------
    A frame of the same type as ``df`` with the id, time and target columns
    for the last ``sizes[i]`` rows of each series, a 'TimeGPT' column and one
    'TimeGPT-<key>' column per interval.
    """
    # `typing.List` is used in the signature (as elsewhere in this module)
    # because the builtin `list[float]` form cannot be evaluated on Python < 3.9.
    times = df[time_col].to_numpy()
    targets = df[target_col].to_numpy()
    if processed.sort_idxs is not None:
        # bring the raw columns into the same order as the processed data
        times = times[processed.sort_idxs]
        targets = targets[processed.sort_idxs]
    times = _array_tails(
        times, processed.indptr, in_sample_output['sizes']
    )
    targets = _array_tails(
        targets, processed.indptr, in_sample_output['sizes']
    )
    uids = ufp.repeat(processed.uids, in_sample_output['sizes'])
    out = type(df)(
        {
            id_col: uids,
            time_col: times,
            target_col: targets,
            'TimeGPT': in_sample_output['mean'],
        }
    )
    return _maybe_add_intervals(out, in_sample_output['intervals'])

def _restrict_input_samples(level, input_size, model_horizon, h) -> int:
    if level is not None:
        # add sufficient info to compute
        # conformal interval
        # @AzulGarza
        #  this is an old opinionated decision
        #  about reducing the data sent to the api
        #  to reduce latency when
        #  a user passes level. since currently the model
        #  uses conformal prediction, we can change a minimum
        #  amount of data if the series are too large
        new_input_size = 3 * input_size + max(model_horizon, h)
    else:
        # we only want to forecast
        new_input_size = input_size
    return new_input_size

In [None]:
#| export
class ApiError(Exception):
    """Error raised for a failed API response.

    Carries the HTTP status code and the response body (decoded body, or a
    message when the body could not be parsed).
    """

    status_code: Optional[int]
    body: Any

    def __init__(self, *, status_code: Optional[int] = None, body: Optional[Any] = None):
        self.body = body
        self.status_code = status_code

    def __str__(self) -> str:
        return "status_code: {}, body: {}".format(self.status_code, self.body)

In [None]:
#| export
class NixtlaClient:

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        timeout: Optional[int] = 60,
        max_retries: int = 6,
        retry_interval: int = 10,
        max_wait_time: int = 6 * 60,
    ):
        """
        Client to interact with the Nixtla API.

        Parameters
        ----------
        api_key : str, optional (default=None)
            The authorization api_key used to interact with the Nixtla API.
            If not provided, the NIXTLA_API_KEY environment variable is used
            (a KeyError is raised when it is not set).
        base_url : str, optional (default=None)
            Custom base_url.
            If not provided, the NIXTLA_BASE_URL environment variable is used,
            falling back to https://api.nixtla.io.
        timeout : int, optional (default=60)
            Request timeout in seconds. Set this to `None` to disable it.
        max_retries : int (default=6)
            The maximum number of attempts to make when calling the API before giving up.
            Default value is 6, indicating the client will attempt the API call up to 6 times in total.
        retry_interval : int (default=10)
            The interval in seconds between consecutive retry attempts.
            Default value is 10 seconds, meaning the client waits for 10 seconds between retries.
        max_wait_time : int (default=360)
            The maximum total time in seconds that the client will spend on all retry attempts before giving up.
            If this time is exceeded, the client will stop retrying and raise an exception.
            With the default `timeout`, the client throws a ReadTimeout error after 60 seconds
            of inactivity. If you want to catch these errors, use max_wait_time >> 60.
        """
        if api_key is None:
            api_key = os.environ['NIXTLA_API_KEY']
        if base_url is None:
            base_url = os.getenv('NIXTLA_BASE_URL', 'https://api.nixtla.io')
        # keyword arguments used for every httpx.Client this object creates
        self._client_kwargs = {
            'base_url': base_url,
            'headers': {
                'Authorization': f'Bearer {api_key}',
                'Content-Type': 'application/json',
            },
            'timeout': timeout,
        }
        self._retry_strategy = _retry_strategy(
            max_retries=max_retries, retry_interval=retry_interval, max_wait_time=max_wait_time
        )
        # cache of (model, freq) -> (input_size, horizon)
        self._model_params: Dict[Tuple[str, str], Tuple[int, int]] = {}
        # Azure endpoints are detected from the URL and expose a single model
        self._is_azure = 'ai.azure' in base_url
        if self._is_azure:
            self.supported_models = ['azureai']
        else:
            self.supported_models = ['timegpt-1', 'timegpt-1-long-horizon']

    def _make_request(self, client: httpx.Client, endpoint: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """POST ``payload`` to ``endpoint`` and return the decoded response.

        Numpy arrays anywhere in ``payload`` (including nested dicts) are
        replaced in place by contiguous arrays; floating arrays are cast to
        float32 with infinities clamped to the float32 limits (NaNs are kept).

        Returns
        -------
        The decoded JSON body, or its 'data' entry when it has one.

        Raises
        ------
        ValueError
            If the serialized payload is larger than 200 MB.
        ApiError
            If the response is not valid JSON or its status code is not 200.
        """
        max_payload_mb = 200
        float32_limits = np.finfo(np.float32)

        def _prepare_arrays(node: Dict[str, Any]) -> None:
            for key, value in node.items():
                if isinstance(value, dict):
                    _prepare_arrays(value)
                elif isinstance(value, np.ndarray):
                    if not np.issubdtype(value.dtype, np.floating):
                        node[key] = np.ascontiguousarray(value)
                        continue
                    as_float32 = np.ascontiguousarray(value, dtype=np.float32)
                    node[key] = np.nan_to_num(
                        as_float32,
                        nan=np.nan,
                        posinf=float32_limits.max,
                        neginf=float32_limits.min,
                        copy=False,
                    )

        _prepare_arrays(payload)
        body = orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)
        size_mb = len(body) / (1024 * 1024)
        if size_mb > max_payload_mb:
            raise ValueError(f'The payload is too large. Set num_partitions={math.ceil(size_mb / max_payload_mb)}')
        response = client.post(url=endpoint, content=body)
        try:
            decoded = orjson.loads(response.content)
        except orjson.JSONDecodeError:
            raise ApiError(
                status_code=response.status_code,
                body=f'Could not parse JSON: {response.content}',
            )
        if response.status_code != 200:
            raise ApiError(status_code=response.status_code, body=decoded)
        # unwrap the payload when the API nests it under 'data'
        return decoded['data'] if 'data' in decoded else decoded

    def _make_request_with_retries(
        self,
        client: httpx.Client,
        endpoint: str,
        payload: Dict[str, Any],
    ) -> Dict[str, Any]:
        return self._retry_strategy(self._make_request)(
            client=client,
            endpoint=endpoint,
            payload=payload,
        )

    def _make_partitioned_requests(
        self,
        client: httpx.Client,
        endpoint: str,
        payloads: List[Dict[str, Any]],
    ) -> Dict[str, Any]:
        """Send one request per payload concurrently and merge the responses.

        Requests run on a thread pool with at most 10 workers. Results are
        stored at the position of their payload, so the merged arrays follow
        the order of ``payloads`` regardless of completion order.

        Returns
        -------
        A dict with 'mean' and 'intervals' and, when present in the first
        response, 'sizes', 'anomaly' and 'idxs'. 'weights_x' and
        'feature_contributions' are always set (``None`` when the first
        response lacks them or holds ``None``).
        """
        from tqdm.auto import tqdm

        num_partitions = len(payloads)
        results = num_partitions * [None]
        max_workers = min(10, num_partitions)
        with ThreadPoolExecutor(max_workers) as executor:
            future2pos = {
                executor.submit(
                    self._make_request_with_retries, client, endpoint, payload
                ): i
                for i, payload in enumerate(payloads)
            }
            for future in tqdm(as_completed(future2pos), total=len(future2pos)):
                pos = future2pos[future]
                results[pos] = future.result()
        resp = {"mean": np.hstack([res["mean"] for res in results])}
        first_res = results[0]
        for k in ('sizes', 'anomaly'):
            if k in first_res:
                resp[k] = np.hstack([res[k] for res in results])
        if 'idxs' in first_res:
            # Each partition's 'idxs' are shifted by the number of rows in ALL
            # preceding partitions, hence the cumulative sum. Using only the
            # size of the immediately preceding partition is wrong from the
            # third partition onwards.
            rows_per_partition = [sum(p['series']['sizes']) for p in payloads[:-1]]
            offsets = np.cumsum([0] + rows_per_partition)
            resp['idxs'] = np.hstack(
                [
                    np.array(res['idxs'], dtype=np.int64) + offset
                    for res, offset in zip(results, offsets)
                ]
            )
        if first_res["intervals"] is None:
            resp["intervals"] = None
        else:
            resp["intervals"] = {}
            for k in first_res["intervals"].keys():
                resp["intervals"][k] = np.hstack(
                    [res["intervals"][k] for res in results]
                )
        if "weights_x" not in first_res or first_res["weights_x"] is None:
            resp["weights_x"] = None
        else:
            # one entry per partition
            resp["weights_x"] = [res["weights_x"] for res in results]
        if "feature_contributions" not in first_res or first_res["feature_contributions"] is None:
            resp["feature_contributions"] = None
        else:
            resp["feature_contributions"] = np.vstack([
                np.stack(res["feature_contributions"], axis=1) for res in results
            ]).T
        return resp

    def _maybe_override_model(self, model: str) -> str:
        if self._is_azure:
            model = 'azureai'
        return model

    def _get_model_params(self, model: str, freq: str) -> Tuple[int, int]:
        key = (model, freq)
        if key not in self._model_params:
            logger.info('Querying model metadata...')
            payload = {'model': model, 'freq': freq}
            with httpx.Client(**self._client_kwargs) as client:
                params = self._make_request_with_retries(
                    client, 'model_params', payload
                )['detail']
            self._model_params[key] = (params['input_size'], params['horizon'])
        return self._model_params[key]

    def _maybe_assign_weights(
        self,
        weights: Optional[Union[List[float], List[List[float]]]],
        df: DataFrame,
        x_cols: List[str],
    ) -> None:
        if weights is None:
            return
        if isinstance(weights[0], list):
            self.weights_x = [
                type(df)({'features': x_cols, 'weights': w}) for w in weights
            ]
        else:
            self.weights_x = type(df)(
                {'features': x_cols, 'weights': weights}
            )

    def _maybe_assign_feature_contributions(
        self,
        expected_contributions: bool,
        resp: Dict[str, Any],
        x_cols: List[str],
        out_df: DataFrame,
        insample_feat_contributions: Optional[List[List[float]]],
    ) -> None:
        if not expected_contributions:
            return
        if 'feature_contributions' not in resp:
            if self._is_azure:
                warnings.warn(
                    "feature_contributions aren't implemented in Azure yet."
                )
                return
            else:
                raise RuntimeError(
                    'feature_contributions expected in response but not found'
                )
        feature_contributions = resp['feature_contributions']
        if feature_contributions is None:
            return     
        shap_cols = x_cols + ["base_value"]
        shap_df = type(out_df)(dict(zip(shap_cols, feature_contributions)))
        if insample_feat_contributions is not None:
            insample_shap_df = type(out_df)(
                dict(zip(shap_cols, insample_feat_contributions))
            )
            shap_df = ufp.vertical_concat([insample_shap_df, shap_df])
        self.feature_contributions = ufp.horizontal_concat([out_df, shap_df])

    def _run_validations(
        self,
        df: DFType,
        X_df: Optional[DFType],
        id_col: str,
        time_col: str,
        target_col: str,
        model: str,
        validate_api_key: bool,
    ) -> Tuple[DFType, Optional[DFType], bool]:
        """Validate the request inputs and normalise the input frames.

        Returns
        -------
        df : the validated frame, with a constant ``id_col`` added when it was
            missing and, for pandas, a datetime index moved into ``time_col``.
        X_df : the future frame, with the same constant ``id_col`` when added.
        drop_id : True when ``id_col`` was added here (so callers can drop it
            from the output).

        Raises
        ------
        Exception
            If ``validate_api_key`` is set and the key is not valid.
        ValueError
            If ``model`` is not supported or the target has missing values.
        """
        if validate_api_key and not self.validate_api_key(log=False):
            raise Exception('API Key not valid, please email ops@nixtla.io')
        if model not in self.supported_models:
            raise ValueError(
                f'unsupported model: {model}. supported models: {self.supported_models}'
            )
        drop_id = id_col not in df.columns
        if drop_id:
            # single series without an id: add a constant one
            df = ufp.copy_if_pandas(df, deep=False)
            df = ufp.assign_columns(df, id_col, 0)
            if X_df is not None:
                X_df = ufp.copy_if_pandas(X_df, deep=False)
                X_df = ufp.assign_columns(X_df, id_col, 0)
        if (
            isinstance(df, pd.DataFrame)
            and time_col not in df
            and pd.api.types.is_datetime64_any_dtype(df.index)
        ):
            # Name the index through `rename_axis`, which returns a new frame,
            # rather than assigning `df.index.name`, which would rename the
            # index of the caller's DataFrame as a side effect.
            df = df.rename_axis(index=time_col).reset_index()
        df = ensure_time_dtype(df, time_col=time_col)
        validate_format(df=df, id_col=id_col, time_col=time_col, target_col=target_col)
        if ufp.is_nan_or_none(df[target_col]).any():
            raise ValueError(f'Target column ({target_col}) cannot contain missing values.')
        return df, X_df, drop_id

    def validate_api_key(self, log: bool = True) -> bool:
        """Returns True if your api_key is valid.

        Calls the 'validate_token' endpoint. Any error raised by the request
        is treated as an invalid key rather than propagated.

        Parameters
        ----------
        log : bool (default=True)
            If True and the response has a 'support' entry, log it.
        """
        try:
            with httpx.Client(**self._client_kwargs) as client:
                validation = self._make_request_with_retries(
                    client, 'validate_token', {}
                )
        except Exception:
            # `Exception` instead of a bare `except`, so KeyboardInterrupt and
            # SystemExit are not silently reported as an invalid key.
            validation = {}
        if 'support' in validation and log:
            logger.info(f'Happy Forecasting! :), {validation["support"]}')
        return (
            validation.get('message', '') == 'success'
            or 'Forecasting! :)' in validation.get('detail', '')
        )

    def forecast(
        self,
        df: AnyDFType,
        h: PositiveInt,
        freq: Optional[str] = None,    
        id_col: str = 'unique_id',
        time_col: str = 'ds',
        target_col: str = 'y',
        X_df: Optional[AnyDFType] = None,
        level: Optional[List[Union[int, float]]] = None,
        quantiles: Optional[List[float]] = None,
        finetune_steps: NonNegativeInt = 0,
        finetune_loss: _Loss = 'default',
        clean_ex_first: bool = True,
        validate_api_key: bool = False,
        add_history: bool = False,
        date_features: Union[bool, List[Union[str, Callable]]] = False,
        date_features_to_one_hot: Union[bool, List[str]] = False,
        model: _Model = 'timegpt-1',
        num_partitions: Optional[PositiveInt] = None,
        feature_contributions: bool = False
    ) -> AnyDFType:
        """Forecast your time series using TimeGPT.

        Parameters
        ----------
        df : pandas or polars DataFrame
            The DataFrame on which the function will operate. Expected to contain at least the following columns:
            - time_col:
                Column name in `df` that contains the time indices of the time series. This is typically a datetime
                column with regular intervals, e.g., hourly, daily, monthly data points.
            - target_col:
                Column name in `df` that contains the target variable of the time series, i.e., the variable we 
                wish to predict or analyze.
            Additionally, you can pass multiple time series (stacked in the dataframe) considering an additional column:
            - id_col:
                Column name in `df` that identifies unique time series. Each unique value in this column
                corresponds to a unique time series.
        h : int
            Forecast horizon.
        freq : str
            Frequency of the data. By default, the freq will be inferred automatically.
            See [pandas' available frequencies](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases).
        id_col : str (default='unique_id')
            Column that identifies each serie.
        time_col : str (default='ds')
            Column that identifies each timestep, its values can be timestamps or integers.
        target_col : str (default='y')
            Column that contains the target.
        X_df : pandas or polars DataFrame, optional (default=None)
            DataFrame with [`unique_id`, `ds`] columns and `df`'s future exogenous.
        level : List[float], optional (default=None)
            Confidence levels between 0 and 100 for prediction intervals.
        quantiles : List[float], optional (default=None)
            Quantiles to forecast, list between (0, 1).
            `level` and `quantiles` should not be used simultaneously.
            The output dataframe will have the quantile columns
            formatted as TimeGPT-q-(100 * q) for each q.
            100 * q represents percentiles but we choose this notation
            to avoid having dots in column names.
        finetune_steps : int (default=0)
            Number of steps used to finetune learning TimeGPT in the
            new data.
        finetune_loss : str (default='default')
            Loss function to use for finetuning. Options are: `default`, `mae`, `mse`, `rmse`, `mape`, and `smape`.
        clean_ex_first : bool (default=True)
            Clean exogenous signal before making forecasts using TimeGPT.
        validate_api_key : bool (default=False)
            If True, validates api_key before sending requests.
        add_history : bool (default=False)
            Return fitted values of the model.
        date_features : bool or list of str or callable, optional (default=False)
            Features computed from the dates. 
            Can be pandas date attributes or functions that will take the dates as input.
            If True automatically adds most used date features for the 
            frequency of `df`.
        date_features_to_one_hot : bool or list of str (default=False)
            Apply one-hot encoding to these date features.
            If `date_features=True`, then all date features are
            one-hot encoded by default.
        model : str (default='timegpt-1')
            Model to use as a string. Options are: `timegpt-1`, and `timegpt-1-long-horizon`. 
            We recommend using `timegpt-1-long-horizon` for forecasting 
            if you want to predict more than one seasonal 
            period given the frequency of your data.
        num_partitions : int (default=None)
            Number of partitions to use.
            If None, the number of partitions will be equal
            to the available parallel resources in distributed environments.
        feature_contributions: bool (default=False)
            Compute SHAP values
            Gives access to computed SHAP values to explain the impact
            of features on the final predictions.
        
        Returns
        -------
        pandas, polars, dask or spark DataFrame or ray Dataset.
            DataFrame with TimeGPT forecasts for point predictions and probabilistic
            predictions (if level is not None).
        """
        if not isinstance(df, (pd.DataFrame, pl_DataFrame)):
            return self._distributed_forecast(
                df=df,
                h=h,
                freq=freq,
                id_col=id_col,
                time_col=time_col,
                target_col=target_col,
                X_df=X_df,
                level=level,
                quantiles=quantiles,
                finetune_steps=finetune_steps,
                finetune_loss=finetune_loss,
                clean_ex_first=clean_ex_first,
                validate_api_key=validate_api_key,
                add_history=add_history,
                date_features=date_features,
                date_features_to_one_hot=date_features_to_one_hot,
                model=model,
                num_partitions=num_partitions,
                feature_contributions=feature_contributions,
            )
        self.__dict__.pop('weights_x', None)
        self.__dict__.pop('feature_contributions', None)
        model = self._maybe_override_model(model)
        logger.info('Validating inputs...')
        df, X_df, drop_id = self._run_validations(
            df=df,
            X_df=X_df,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            validate_api_key=validate_api_key,
            model=model,
        )
        df, X_df = _validate_exog(
            df, X_df, id_col=id_col, time_col=time_col, target_col=target_col
        )
        level, quantiles = _prepare_level_and_quantiles(level, quantiles)
        freq = _maybe_infer_freq(df, freq=freq, id_col=id_col, time_col=time_col)
        standard_freq = _standardize_freq(freq)
        model_input_size, model_horizon = self._get_model_params(model, standard_freq)
        if finetune_steps > 0 or level is not None or add_history:
            _validate_input_size(df, id_col, model_input_size, model_horizon)
        if h > model_horizon:
            logger.warning(
                'The specified horizon "h" exceeds the model horizon. '
                'This may lead to less accurate forecasts. '
                'Please consider using a smaller horizon.'  
            )

        logger.info('Preprocessing dataframes...')
        processed, X_future, x_cols, futr_cols = _preprocess(
            df=df,
            X_df=X_df,
            h=h,
            freq=standard_freq,
            date_features=date_features,
            date_features_to_one_hot=date_features_to_one_hot,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
        )
        restrict_input = finetune_steps == 0 and not x_cols and not add_history
        if restrict_input:
            logger.info('Restricting input...')
            new_input_size = _restrict_input_samples(
                level=level,
                input_size=model_input_size,
                model_horizon=model_horizon,
                h=h,
            )
            processed = _tail(processed, new_input_size)
        if processed.data.shape[1] > 1:
            X = processed.data[:, 1:].T
            if futr_cols is not None:
                hist_exog_set= set(x_cols) - set(futr_cols)
                if hist_exog_set:
                    logger.info(f'Using historical exogenous features: {list(hist_exog_set)}')
                logger.info(f'Using future exogenous features: {futr_cols}')
            else:
                logger.info(f'Using historical exogenous features: {x_cols}')
        else:
            X = None

        logger.info('Calling Forecast Endpoint...')
        payload = {
            'series': {
                'y': processed.data[:, 0],
                'sizes': np.diff(processed.indptr),
                'X': X,
                'X_future': X_future,
            },
            'model': model,
            'h': h,
            'freq': standard_freq,
            'clean_ex_first': clean_ex_first,
            'level': level,
            'finetune_steps': finetune_steps,
            'finetune_loss': finetune_loss,
            'feature_contributions': feature_contributions and X is not None,
        }
        with httpx.Client(**self._client_kwargs) as client:
            insample_feat_contributions = None
            if num_partitions is None:
                resp = self._make_request_with_retries(client, 'v2/forecast', payload)
                if add_history:
                    in_sample_payload = _forecast_payload_to_in_sample(payload)
                    logger.info('Calling Historical Forecast Endpoint...')
                    in_sample_resp = self._make_request_with_retries(
                        client, 'v2/historic_forecast', in_sample_payload
                    )
                    insample_feat_contributions = in_sample_resp.get(
                        'feature_contributions', None
                    )
            else:
                payloads = _partition_series(payload, num_partitions, h)
                resp = self._make_partitioned_requests(client, 'v2/forecast', payloads)
                if add_history:
                    in_sample_payloads = [
                        _forecast_payload_to_in_sample(p) for p in payloads
                    ]
                    logger.info('Calling Historical Forecast Endpoint...')
                    in_sample_resp = self._make_partitioned_requests(
                        client, 'v2/historic_forecast', in_sample_payloads
                    )
                    insample_feat_contributions = in_sample_resp.get(
                        'feature_contributions', None
                    )

        # assemble result
        out = ufp.make_future_dataframe(
            uids=processed.uids,
            last_times=type(processed.uids)(processed.last_times),
            freq=freq,
            h=h,
            id_col=id_col,
            time_col=time_col,
        )
        out = ufp.assign_columns(out, 'TimeGPT', resp['mean'])
        out = _maybe_add_intervals(out, resp['intervals'])
        out = _maybe_convert_level_to_quantiles(out, quantiles)
        if add_history:
            in_sample_df = _parse_in_sample_output(
                in_sample_output=in_sample_resp,
                df=df,
                processed=processed,
                id_col=id_col,
                time_col=time_col,
                target_col=target_col,
            )
            in_sample_df = ufp.drop_columns(in_sample_df, target_col)
            out = ufp.vertical_concat([in_sample_df, out])
        self._maybe_assign_feature_contributions(
            expected_contributions=feature_contributions,
            resp=resp,
            x_cols=x_cols,
            out_df=out[[id_col, time_col, 'TimeGPT']],
            insample_feat_contributions=insample_feat_contributions,
        )
        if add_history:
            sort_idxs = ufp.maybe_compute_sort_indices(out, id_col=id_col, time_col=time_col)
            if sort_idxs is not None:
                out = ufp.take_rows(out, sort_idxs)
                if hasattr(self, 'feature_contributions'):
                    self.feature_contributions = ufp.take_rows(self.feature_contributions, sort_idxs)
        out = _maybe_drop_id(df=out, id_col=id_col, drop=drop_id)
        self._maybe_assign_weights(weights=resp['weights_x'], df=df, x_cols=x_cols)
        return out

    def detect_anomalies(
        self,
        df: AnyDFType,
        freq: Optional[str] = None,
        id_col: str = 'unique_id',
        time_col: str = 'ds',
        target_col: str = 'y',
        level: Union[int, float] = 99,
        clean_ex_first: bool = True,
        validate_api_key: bool = False,
        date_features: Union[bool, List[Union[str, Callable]]] = False,
        date_features_to_one_hot: Union[bool, List[str]] = False,
        model: _Model = 'timegpt-1',
        num_partitions: Optional[PositiveInt] = None,
    ) -> AnyDFType:
        """Detect anomalies in your time series using TimeGPT.

        Inputs that are neither pandas nor polars DataFrames are forwarded,
        with the same arguments, to the distributed implementation.

        Parameters
        ----------
        df : pandas or polars DataFrame
            The DataFrame on which the function will operate. Expected to contain at least the following columns:
            - time_col:
                Column name in `df` that contains the time indices of the time series. This is typically a datetime
                column with regular intervals, e.g., hourly, daily, monthly data points.
            - target_col:
                Column name in `df` that contains the target variable of the time series, i.e., the variable we
                wish to predict or analyze.
            Additionally, you can pass multiple time series (stacked in the dataframe) considering an additional column:
            - id_col:
                Column name in `df` that identifies unique time series. Each unique value in this column
                corresponds to a unique time series.
        freq : str
            Frequency of the data. By default, the freq will be inferred automatically.
            See [pandas' available frequencies](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases).
        id_col : str (default='unique_id')
            Column that identifies each series.
        time_col : str (default='ds')
            Column that identifies each timestep, its values can be timestamps or integers.
        target_col : str (default='y')
            Column that contains the target.
        level : float (default=99)
            Confidence level between 0 and 100 for detecting the anomalies.
        clean_ex_first : bool (default=True)
            Clean exogenous signal before making forecasts
            using TimeGPT.
        validate_api_key : bool (default=False)
            If True, validates api_key before sending requests.
        date_features : bool or list of str or callable, optional (default=False)
            Features computed from the dates.
            Can be pandas date attributes or functions that will take the dates as input.
            If True automatically adds most used date features for the
            frequency of `df`.
        date_features_to_one_hot : bool or list of str (default=False)
            Apply one-hot encoding to these date features.
            If `date_features=True`, then all date features are
            one-hot encoded by default.
        model : str (default='timegpt-1')
            Model to use as a string. Options are: `timegpt-1`, and `timegpt-1-long-horizon`.
            We recommend using `timegpt-1-long-horizon` for forecasting
            if you want to predict more than one seasonal
            period given the frequency of your data.
        num_partitions : int (default=None)
            Number of partitions to use.
            If None, the number of partitions will be equal
            to the available parallel resources in distributed environments.
            For pandas and polars inputs, None sends a single request.

        Returns
        -------
        pandas, polars, dask or spark DataFrame or ray Dataset.
            DataFrame with anomalies flagged by TimeGPT (column `anomaly`).
        """
        if not isinstance(df, (pd.DataFrame, pl_DataFrame)):
            return self._distributed_detect_anomalies(
                df=df,
                freq=freq,
                id_col=id_col,
                time_col=time_col,
                target_col=target_col,
                level=level,
                clean_ex_first=clean_ex_first,
                validate_api_key=validate_api_key,
                date_features=date_features,
                date_features_to_one_hot=date_features_to_one_hot,
                model=model,
                num_partitions=num_partitions,
            )
        # forget the weights stored by a previous call before computing new ones
        self.__dict__.pop('weights_x', None)
        model = self._maybe_override_model(model)
        logger.info('Validating inputs...')
        df, _, drop_id = self._run_validations(
            df=df,
            X_df=None,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            validate_api_key=validate_api_key,
            model=model,
        )
        freq = _maybe_infer_freq(df, freq=freq, id_col=id_col, time_col=time_col)
        standard_freq = _standardize_freq(freq)
        # The returned (input_size, horizon) pair is not used for anomaly
        # detection. The call itself is kept because it may validate the
        # model/frequency pair or fill a cache.
        # NOTE(review): confirm, and drop the call if it has no side effects.
        self._get_model_params(model, standard_freq)

        logger.info('Preprocessing dataframes...')
        processed, _, x_cols, _ = _preprocess(
            df=df,
            X_df=None,
            h=0,
            freq=standard_freq,
            date_features=date_features,
            date_features_to_one_hot=date_features_to_one_hot,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
        )
        # column 0 of the processed data is sent as the target; any remaining
        # columns are sent (transposed) as exogenous features
        if processed.data.shape[1] > 1:
            X = processed.data[:, 1:].T
            logger.info(f'Using the following exogenous features: {x_cols}')
        else:
            X = None

        logger.info('Calling Anomaly Detector Endpoint...')
        payload = {
            'series': {
                'y': processed.data[:, 0],
                'sizes': np.diff(processed.indptr),
                'X': X,
            },
            'model': model,
            'freq': standard_freq,
            'clean_ex_first': clean_ex_first,
            'level': level,
        }
        with httpx.Client(**self._client_kwargs) as client:
            if num_partitions is None:
                resp = self._make_request_with_retries(
                    client, 'v2/anomaly_detection', payload
                )
            else:
                payloads = _partition_series(payload, num_partitions, h=0)
                resp = self._make_partitioned_requests(client, 'v2/anomaly_detection', payloads)

        # assemble result
        out = _parse_in_sample_output(
            in_sample_output=resp,
            df=df,
            processed=processed,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
        )
        out = ufp.assign_columns(out, 'anomaly', resp['anomaly'])
        out = _maybe_drop_id(df=out, id_col=id_col, drop=drop_id)
        self._maybe_assign_weights(weights=resp['weights_x'], df=df, x_cols=x_cols)
        return out

    def cross_validation(
        self,
        df: AnyDFType,
        h: PositiveInt,
        freq: Optional[str] = None,
        id_col: str = "unique_id",
        time_col: str = "ds",
        target_col: str = "y",
        level: Optional[List[Union[int, float]]] = None,
        quantiles: Optional[List[float]] = None,
        validate_api_key: bool = False,
        n_windows: PositiveInt = 1,
        step_size: Optional[PositiveInt] = None,
        finetune_steps: NonNegativeInt = 0,
        finetune_loss: str = 'default',
        clean_ex_first: bool = True,
        date_features: Union[bool, List[Union[str, Callable]]] = False,
        date_features_to_one_hot: Union[bool, List[str]] = False,
        model: str = 'timegpt-1',
        num_partitions: Optional[PositiveInt] = None,
    ) -> AnyDFType:
        """Perform cross validation in your time series using TimeGPT.

        Inputs that are neither pandas nor polars DataFrames are forwarded,
        with the same arguments, to the distributed implementation.

        Parameters
        ----------
        df : pandas or polars DataFrame
            The DataFrame on which the function will operate. Expected to contain at least the following columns:
            - time_col:
                Column name in `df` that contains the time indices of the time series. This is typically a datetime
                column with regular intervals, e.g., hourly, daily, monthly data points.
            - target_col:
                Column name in `df` that contains the target variable of the time series, i.e., the variable we
                wish to predict or analyze.
            Additionally, you can pass multiple time series (stacked in the dataframe) considering an additional column:
            - id_col:
                Column name in `df` that identifies unique time series. Each unique value in this column
                corresponds to a unique time series.
        h : int
            Forecast horizon.
        freq : str
            Frequency of the data. By default, the freq will be inferred automatically.
            See [pandas' available frequencies](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#offset-aliases).
        id_col : str (default='unique_id')
            Column that identifies each series.
        time_col : str (default='ds')
            Column that identifies each timestep, its values can be timestamps or integers.
        target_col : str (default='y')
            Column that contains the target.
        level : List[float], optional (default=None)
            Confidence levels between 0 and 100 for prediction intervals.
        quantiles : List[float], optional (default=None)
            Quantiles to forecast, list between (0, 1).
            `level` and `quantiles` should not be used simultaneously.
            The output dataframe will have the quantile columns
            formatted as TimeGPT-q-(100 * q) for each q.
            100 * q represents percentiles but we choose this notation
            to avoid having dots in column names.
        validate_api_key : bool (default=False)
            If True, validates api_key before sending requests.
        n_windows : int (default=1)
            Number of windows to evaluate.
        step_size : int, optional (default=None)
            Step size between each cross validation window. If None it will be equal to `h`.
        finetune_steps : int (default=0)
            Number of steps used to finetune TimeGPT in the
            new data.
        finetune_loss : str (default='default')
            Loss function to use for finetuning. Options are: `default`, `mae`, `mse`, `rmse`, `mape`, and `smape`.
        clean_ex_first : bool (default=True)
            Clean exogenous signal before making forecasts
            using TimeGPT.
        date_features : bool or list of str or callable, optional (default=False)
            Features computed from the dates.
            Can be pandas date attributes or functions that will take the dates as input.
            If True automatically adds most used date features for the
            frequency of `df`.
        date_features_to_one_hot : bool or list of str (default=False)
            Apply one-hot encoding to these date features.
            If `date_features=True`, then all date features are
            one-hot encoded by default.
        model : str (default='timegpt-1')
            Model to use as a string. Options are: `timegpt-1`, and `timegpt-1-long-horizon`.
            We recommend using `timegpt-1-long-horizon` for forecasting
            if you want to predict more than one seasonal
            period given the frequency of your data.
        num_partitions : int (default=None)
            Number of partitions to use.
            If None, the number of partitions will be equal
            to the available parallel resources in distributed environments.
            For pandas and polars inputs, None sends a single request.

        Returns
        -------
        pandas, polars, dask or spark DataFrame or ray Dataset.
            DataFrame with cross validation forecasts. For pandas and polars
            inputs it holds the id, time, `cutoff` and target columns plus
            `TimeGPT` and any interval or quantile columns.
        """
        if not isinstance(df, (pd.DataFrame, pl_DataFrame)):
            return self._distributed_cross_validation(
                df=df,
                h=h,
                freq=freq,
                id_col=id_col,
                time_col=time_col,
                target_col=target_col,
                level=level,
                quantiles=quantiles,
                n_windows=n_windows,
                step_size=step_size,
                validate_api_key=validate_api_key,
                finetune_steps=finetune_steps,
                finetune_loss=finetune_loss,
                clean_ex_first=clean_ex_first,
                date_features=date_features,
                date_features_to_one_hot=date_features_to_one_hot,
                model=model,
                num_partitions=num_partitions,
            )
        model = self._maybe_override_model(model)
        logger.info('Validating inputs...')
        df, _, drop_id = self._run_validations(
            df=df,
            X_df=None,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            validate_api_key=validate_api_key,
            model=model,
        )
        freq = _maybe_infer_freq(df, freq=freq, id_col=id_col, time_col=time_col)
        standard_freq = _standardize_freq(freq)
        level, quantiles = _prepare_level_and_quantiles(level, quantiles)
        model_input_size, model_horizon = self._get_model_params(model, standard_freq)
        # windows are non-overlapping by default
        if step_size is None:
            step_size = h

        logger.info('Preprocessing dataframes...')
        processed, _, x_cols, _ = _preprocess(
            df=df,
            X_df=None,
            h=0,
            freq=standard_freq,
            date_features=date_features,
            date_features_to_one_hot=date_features_to_one_hot,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
        )
        if isinstance(df, pd.DataFrame):
            # in pandas<2.2 to_numpy can lead to an object array if
            # the type is a pandas nullable type, e.g. pd.Float64Dtype
            # we thus use the dtype's type as the target dtype
            target_dtype = df.dtypes[target_col].type
            targets = df[target_col].to_numpy(dtype=target_dtype)
        else:
            targets = df[target_col].to_numpy()
        times = df[time_col].to_numpy()
        # reorder the raw arrays with the sort indices recorded during
        # preprocessing (presumably so they line up with `processed` -- confirm)
        if processed.sort_idxs is not None:
            targets = targets[processed.sort_idxs]
            times = times[processed.sort_idxs]
        # the history is only trimmed when there is no finetuning and no
        # exogenous features
        restrict_input = finetune_steps == 0 and not x_cols
        if restrict_input:
            logger.info('Restricting input...')
            new_input_size = _restrict_input_samples(
                level=level,
                input_size=model_input_size,
                model_horizon=model_horizon,
                h=h,
            )
            # also keep the samples spanned by the validation windows:
            # h for one window plus step_size for each additional one
            new_input_size += h + step_size * (n_windows - 1)
            orig_indptr = processed.indptr
            processed = _tail(processed, new_input_size)
            # trim the raw arrays to the same per-series sizes as `processed`
            times = _array_tails(times, orig_indptr, np.diff(processed.indptr))
            targets = _array_tails(targets, orig_indptr, np.diff(processed.indptr))
        if processed.data.shape[1] > 1:
            X = processed.data[:, 1:].T
            logger.info(f'Using the following exogenous features: {x_cols}')
        else:
            X = None

        logger.info('Calling Cross Validation Endpoint...')
        payload = {
            'series': {
                'y': targets,
                'sizes': np.diff(processed.indptr),
                'X': X,
            },
            'model': model,
            'h': h,
            'n_windows': n_windows,
            'step_size': step_size,
            'freq': standard_freq,
            'clean_ex_first': clean_ex_first,
            'level': level,
            'finetune_steps': finetune_steps,
            'finetune_loss': finetune_loss,
        }
        with httpx.Client(**self._client_kwargs) as client:
            if num_partitions is None:
                resp = self._make_request_with_retries(
                    client, 'v2/cross_validation', payload
                )
            else:
                payloads = _partition_series(payload, num_partitions, h=0)
                resp = self._make_partitioned_requests(client, 'v2/cross_validation', payloads)

        # assemble result
        # `idxs` is used below to index into `times` and `targets`
        idxs = np.array(resp['idxs'], dtype=np.int64)
        sizes = np.array(resp['sizes'], dtype=np.int64)
        # offset of the first row of every h-sized window in the flat output
        window_starts = np.arange(0, sizes.sum(), h)
        # the cutoff of a window is the sample right before its first row,
        # repeated for each of the window's h rows
        cutoff_idxs = np.repeat(idxs[window_starts] - 1, h)
        out = type(df)(
            {
                id_col: ufp.repeat(processed.uids, sizes),
                time_col: times[idxs],
                'cutoff': times[cutoff_idxs],
                target_col: targets[idxs],
            }
        )
        out = ufp.assign_columns(out, 'TimeGPT', resp['mean'])
        out = _maybe_add_intervals(out, resp['intervals'])
        out = _maybe_drop_id(df=out, id_col=id_col, drop=drop_id)
        return _maybe_convert_level_to_quantiles(out, quantiles)

    def plot(
        self,
        df: Optional[DataFrame] = None,
        forecasts_df: Optional[DataFrame] = None,
        id_col: str = 'unique_id',
        time_col: str = 'ds',
        target_col: str = 'y',
        unique_ids: Union[Optional[List[str]], np.ndarray] = None,
        plot_random: bool = True,
        max_ids: int = 8,
        models: Optional[List[str]] = None,
        level: Optional[List[float]] = None,
        max_insample_length: Optional[int] = None,
        plot_anomalies: bool = False,
        engine: Literal['matplotlib', 'plotly', 'plotly-resampler'] = 'matplotlib',
        resampler_kwargs: Optional[Dict] = None,
        ax: Optional[Union["plt.Axes", np.ndarray, "plotly.graph_objects.Figure"]] = None,
    ):
        """Plot forecasts and insample values.

        Parameters
        ----------
        df : pandas or polars DataFrame, optional (default=None)
            The DataFrame on which the function will operate. Expected to contain at least the following columns:
            - time_col:
                Column name in `df` that contains the time indices of the time series. This is typically a datetime
                column with regular intervals, e.g., hourly, daily, monthly data points.
            - target_col:
                Column name in `df` that contains the target variable of the time series, i.e., the variable we
                wish to predict or analyze.
            Additionally, you can pass multiple time series (stacked in the dataframe) considering an additional column:
            - id_col:
                Column name in `df` that identifies unique time series. Each unique value in this column
                corresponds to a unique time series.
            If `id_col` is missing, a constant id `'ts_0'` is added.
        forecasts_df : pandas or polars DataFrame, optional (default=None)
            DataFrame with columns [`unique_id`, `ds`] and models.
            If `id_col` is missing, a constant id `'ts_0'` is added.
            If it contains an `anomaly` column (output of `detect_anomalies`),
            `df` is ignored and `level`, `models` and `plot_anomalies` are overridden.
        id_col : str (default='unique_id')
            Column that identifies each series.
        time_col : str (default='ds')
            Column that identifies each timestep, its values can be timestamps or integers.
        target_col : str (default='y')
            Column that contains the target.
        unique_ids : List[str], optional (default=None)
            Time Series to plot.
            If None, time series are selected randomly.
        plot_random : bool (default=True)
            Select time series to plot randomly.
        max_ids : int (default=8)
            Maximum number of ids to plot.
        models : List[str], optional (default=None)
            List of models to plot.
        level : List[float], optional (default=None)
            List of prediction intervals to plot if passed.
        max_insample_length : int, optional (default=None)
            Max number of train/insample observations to be plotted.
        plot_anomalies : bool (default=False)
            Plot anomalies for each prediction interval.
        engine : str (default='matplotlib')
            Library used to plot. 'matplotlib', 'plotly' or 'plotly-resampler'.
        resampler_kwargs : dict
            Kwargs to be passed to plotly-resampler constructor.
            For further customization ("show_dash") call the method,
            store the plotting object and add the extra arguments to
            its `show_dash` method.
        ax : matplotlib axes, array of matplotlib axes or plotly Figure, optional (default=None)
            Object where plots will be added.

        Returns
        -------
        The object returned by `utilsforecast.plotting.plot_series`.

        Raises
        ------
        Exception
            If the optional plotting dependencies are not installed.
        """
        try:
            from utilsforecast.plotting import plot_series
        except ModuleNotFoundError as err:
            raise Exception(
                'You have to install additional dependencies to use this method, '
                'please install them using `pip install "nixtla[plotting]"`'
            ) from err
        # `df` is optional: only normalize it when it was actually provided,
        # otherwise plotting just `forecasts_df` would fail on a None input
        if df is not None:
            if id_col not in df.columns:
                df = ufp.copy_if_pandas(df, deep=False)
                df = ufp.assign_columns(df, id_col, 'ts_0')
            df = ensure_time_dtype(df, time_col=time_col)
        if forecasts_df is not None:
            if id_col not in forecasts_df.columns:
                forecasts_df = ufp.copy_if_pandas(forecasts_df, deep=False)
                forecasts_df = ufp.assign_columns(forecasts_df, id_col, 'ts_0')
            forecasts_df = ensure_time_dtype(forecasts_df, time_col=time_col)
            if 'anomaly' in forecasts_df.columns:
                # special case to plot outputs
                # from detect_anomalies
                df = None
                forecasts_df = ufp.drop_columns(forecasts_df, 'anomaly')
                # recover the level from the first 'TimeGPT-lo-<level>' column name
                cols = [c for c in forecasts_df.columns if 'TimeGPT-lo-' in c]
                level = [c.replace('TimeGPT-lo-', '') for c in cols][0]
                level = float(level) if '.' in level else int(level)
                level = [level]
                plot_anomalies = True
                models = ['TimeGPT']
        return plot_series(
            df=df,
            forecasts_df=forecasts_df,
            ids=unique_ids,
            plot_random=plot_random,
            max_ids=max_ids,
            models=models,
            level=level,
            max_insample_length=max_insample_length,
            plot_anomalies=plot_anomalies,
            engine=engine,
            resampler_kwargs=resampler_kwargs,
            palette="tab20b",
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            ax=ax,
        )

In [None]:
#| hide
@contextmanager
def delete_env_var(key):
    """Temporarily remove the environment variable `key`.

    The variable is removed (if it is set) on entry and put back with its
    previous value on exit, even when the body raises. A variable that was
    not set on entry is left untouched on exit.
    """
    # pop returns None when the variable is absent; real values are always strings
    previous = os.environ.pop(key, None)
    try:
        yield
    finally:
        if previous is not None:
            os.environ[key] = previous
# test api_key fail: with both NIXTLA_API_KEY and TIMEGPT_TOKEN removed from
# the environment, building a client without an explicit key must raise an
# error whose message contains 'NIXTLA_API_KEY'
with delete_env_var('NIXTLA_API_KEY'), delete_env_var('TIMEGPT_TOKEN'):
    test_fail(
        lambda: NixtlaClient(),
        contains='NIXTLA_API_KEY',
    )

In [None]:
#| hide
# client shared by the tests below; no api_key is passed explicitly here
# (presumably it is read from the environment -- confirm in NixtlaClient.__init__)
nixtla_client = NixtlaClient()

In [None]:
#| hide
# smoke test: validating the key of the shared client must not raise
nixtla_client.validate_api_key()

In [None]:
#| hide
# an invalid key must make validate_api_key return False
_nixtla_client = NixtlaClient(api_key="invalid")
test_eq(_nixtla_client.validate_api_key(), False)

In [None]:
#| hide
# client pointing at a custom endpoint; key and base url come from the
# NIXTLA_API_KEY_CUSTOM / NIXTLA_BASE_URL_CUSTOM environment variables
# (a KeyError is raised here if either one is not set).
# `_nixtla_client` is reused later by the "custom url" comparison test.
_nixtla_client = NixtlaClient(
    api_key=os.environ['NIXTLA_API_KEY_CUSTOM'], 
    base_url=os.environ['NIXTLA_BASE_URL_CUSTOM'],
)
_nixtla_client.validate_api_key()

In [None]:
#| hide
# forecasting with a bogus key and validate_api_key=True must fail with a
# message containing 'nixtla'
test_fail(
    lambda: NixtlaClient(api_key='transphobic').forecast(df=pd.DataFrame(), h=None, validate_api_key=True),
    contains='nixtla'
)

In [None]:
#| hide
# test input_size: for model 'timegpt-1' and daily frequency
# _get_model_params is expected to return the pair (28, 7)
test_eq(
    nixtla_client._get_model_params(model='timegpt-1', freq='D'),
    (28, 7),
)

Now you can start to make forecasts! Let's import an example:

In [None]:
#| hide
# example dataset used by many of the tests below (downloaded over the network);
# the 'timestamp' column is parsed as dates
df = pd.read_csv(
    'https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/air_passengers.csv',
    parse_dates=['timestamp'],
)
df.head()

In [None]:
#| hide
# test date_features with multiple series
# and different ends: the forecast of every series must cover the `h`
# daily timestamps that follow that series' own last observation
test_series = generate_series(n_series=2, min_length=5, max_length=20)
h = 12
fcst_test_series = nixtla_client.forecast(test_series, h=h, date_features=['dayofweek'])
# check each series once; iterating the raw column would repeat the
# same comparison for every row of the series
uids = test_series['unique_id'].unique()
for uid in uids:
    test_eq(
        fcst_test_series.query('unique_id == @uid')['ds'].tolist(),
        # h + 1 periods starting at the last observed date, minus that first date
        pd.date_range(periods=h + 1, start=test_series.query('unique_id == @uid')['ds'].max())[1:].tolist(),
    )

In [None]:
#| hide
# test quantiles
# `level` and `quantiles` are mutually exclusive: passing both must fail
# with a message containing 'not both'
test_fail(
    lambda: nixtla_client.forecast(
        df=df, 
        h=12, 
        time_col='timestamp', 
        target_col='value', 
        level=[80], 
        quantiles=[0.2, 0.3]
    ),
    contains='not both'
)
# quantiles 0.1, 0.2, ..., 0.9 and the output column names expected for them
test_qls = list(np.arange(0.1, 1, 0.1))
exp_q_cols = [f"TimeGPT-q-{int(100 * q)}" for q in test_qls]
def test_method_qls(method, **kwargs):
    """Check that `method` returns well-formed quantile forecasts.

    `method` is called on the notebook-level `df` with `quantiles=test_qls`
    (plus any extra `kwargs`). The result must contain every column listed in
    `exp_q_cols`, and within each row the quantile values must be
    non-decreasing from the lowest to the highest quantile.
    """
    df_qls = method(
        df=df, 
        h=12, 
        time_col='timestamp', 
        target_col='value', 
        quantiles=test_qls,
        **kwargs
    )
    missing_cols = [col for col in exp_q_cols if col not in df_qls.columns]
    assert not missing_cols, f'missing quantile columns: {missing_cols}'
    # test monotonicity of quantiles: only the quantile columns are compared
    # (they are listed in increasing quantile order); rows without quantile
    # values, if there are any, are skipped
    q_values = df_qls[exp_q_cols].dropna()
    is_monotonic = q_values.apply(lambda row: row.is_monotonic_increasing, axis=1)
    assert is_monotonic.all(), 'quantile forecasts are not monotonically increasing'
# run the quantile checks for forecast (with and without history) and cross validation
test_method_qls(nixtla_client.forecast)
test_method_qls(nixtla_client.forecast, add_history=True)
test_method_qls(nixtla_client.cross_validation)

In [None]:
#| hide
# test num partitions
# the results obtained by splitting the request in several partitions
# must match the ones obtained with a single partition
# A: be aware that num_partitions can produce different results
# when finetune_steps is used
def test_num_partitions_same_results(method: Callable, num_partitions: int, **kwargs):
    """Assert that `method` gives the same frame with `num_partitions` partitions and with one.

    Both outputs are sorted by `unique_id`, `ds` (and `cutoff` when that
    column is present) before being compared with a tolerance of 1e-2.
    """
    partitioned = method(**kwargs, num_partitions=num_partitions)
    single = method(**kwargs, num_partitions=1)
    sort_keys = ['unique_id', 'ds']
    if 'cutoff' in partitioned:
        sort_keys.append('cutoff')

    def _ordered(frame):
        return frame.sort_values(sort_keys).reset_index(drop=True)

    pd.testing.assert_frame_equal(
        _ordered(partitioned),
        _ordered(single),
        rtol=1e-2,
        atol=1e-2,
    )

# run the partition-equivalence check for several frequencies.
# NOTE: the horizon passed below is fixed at 7; the per-frequency value
# stored in `freqs` is not used in this cell
freqs = {'D': 7, 'W-THU': 52, 'Q-DEC': 8, '15T': 4 * 24 * 7}
for freq, h in freqs.items():
    df_freq = generate_series(
        10, 
        min_length=500 if freq != '15T' else 1_200, 
        max_length=550 if freq != '15T' else 2_000,
    )
    # re-date every series so that it follows `freq` and ends on 2023-01-01
    df_freq['ds'] = df_freq.groupby('unique_id', observed=True)['ds'].transform(
        lambda x: pd.date_range(periods=len(x), freq=freq, end='2023-01-01')
    )
    test_num_partitions_same_results(
        nixtla_client.detect_anomalies,
        level=98,
        df=df_freq,
        num_partitions=2,
    )
    test_num_partitions_same_results(
        nixtla_client.cross_validation,
        h=7,
        n_windows=2,
        df=df_freq,
        num_partitions=2,
    )
    test_num_partitions_same_results(
        nixtla_client.forecast,
        df=df_freq,
        h=7,
        add_history=True,
        num_partitions=2,
    )

In [None]:
#| hide
def test_retry_behavior(side_effect, max_retries=5, retry_interval=5, max_wait_time=40, should_retry=True, sleep_seconds=5):
    """Check how long a failing forecast call takes under a given retry configuration.

    A client is built with the retry settings and its `_make_request` is
    replaced by `side_effect`, which is expected to raise. The forecast must
    fail; the elapsed time is then compared against the retry settings.

    Parameters
    ----------
    side_effect : callable
        Replacement for `_make_request`.
    max_retries, retry_interval, max_wait_time : int
        Retry settings given to the client.
    should_retry : bool
        Whether the error raised by `side_effect` is expected to be retried.
    sleep_seconds : int
        Time that `side_effect` itself may spend on every call.
    """
    client = NixtlaClient(
        max_retries=max_retries,
        retry_interval=retry_interval,
        max_wait_time=max_wait_time,
    )
    client._make_request = side_effect
    started_at = time()
    test_fail(
        lambda: client.forecast(df=df, h=12, time_col='timestamp', target_col='value'),
    )
    elapsed = time() - started_at
    if not should_retry:
        # without retries the call only has to finish within the maximum wait time
        assert elapsed <= max_wait_time
        return
    lower_bound = min((max_retries - 1) * retry_interval, max_wait_time)
    upper_bound = min(max_retries * retry_interval, max_wait_time)
    assert elapsed >= lower_bound, "It is not retrying as expected"
    # whatever is left after discounting the retry waits and the time spent
    # inside the side effect must stay within `sleep_seconds`
    overhead = elapsed - upper_bound - (max_retries - 1) * sleep_seconds
    assert overhead <= sleep_seconds

In [None]:
#| hide
# we want the api to retry in these cases
def raise_api_error_with_text(*args, **kwargs):
    """Always raise a 503 `ApiError` whose body is an HTML error page (plain text, not json)."""
    html_body = """
        <html><head>
        <meta http-equiv="content-type" content="text/html;charset=utf-8">
        <title>503 Server Error</title>
        </head>
        <body text=#000000 bgcolor=#ffffff>
        <h1>Error: Server Error</h1>
        <h2>The service you requested is not available at this time.<p>Service error -27.</h2>
        <h2></h2>
        </body></html>
        """
    raise ApiError(status_code=503, body=html_body)
test_retry_behavior(raise_api_error_with_text)

In [None]:
#| hide
# we want the api to not retry in these cases
# here A is assuming that the endpoint always responds
# with a json
def raise_api_error_with_json(*args, **kwargs):
    """Always raise a 422 `ApiError` whose body is a json-like dict."""
    error_body = {'detail': 'Please use numbers'}
    raise ApiError(status_code=422, body=error_body)
test_retry_behavior(raise_api_error_with_json, should_retry=False)

In [None]:
#| hide
# test resilience of api calls
def raise_read_timeout_error(*args, **kwargs):
    """Wait 5 seconds and then raise `httpx.ReadTimeout`, imitating a request that times out."""
    delay = 5
    print(f'raising ReadTimeout error after {delay} seconds')
    sleep(delay)
    raise httpx.ReadTimeout('Timed out')

def raise_http_error(*args, **kwargs):
    """Report the failure on stdout and raise a 503 `ApiError` with a plain text body."""
    print('raising HTTP error')
    raise ApiError(body='HTTP error', status_code=503)
    
# (max_retries, retry_interval, max_wait_time) settings to try
combs = [
    (2, 5, 30),
    (10, 1, 5),
]
side_effects = [raise_read_timeout_error, raise_http_error]

# every retry setting is checked against every failure mode
for (max_retries, retry_interval, max_wait_time), side_effect in product(combs, side_effects):
    test_retry_behavior(
        max_retries=max_retries, 
        retry_interval=retry_interval, 
        max_wait_time=max_wait_time, 
        side_effect=side_effect,
    )

In [None]:
#| hide
# smoke test: plotting the example series with the plotly engine must not raise
nixtla_client.plot(df, time_col='timestamp', target_col='value', engine='plotly')

In [None]:
#| hide
# test we recover the same <mean> forecasts
# with and without restricting input
# (add_history)
def test_equal_fcsts_add_history(**kwargs):
    """Assert that `add_history=True` does not change the forecasts.

    The last `h` rows of every series in the `add_history=True` output must be
    equal to the output of a plain forecast made with the same `kwargs`.
    Returns the plain forecast frame.
    """
    with_history = nixtla_client.forecast(**kwargs, add_history=True)
    horizon_rows = with_history.groupby('unique_id', observed=True).tail(kwargs['h'])
    horizon_rows = horizon_rows.reset_index(drop=True)
    plain_fcst = nixtla_client.forecast(**kwargs)
    pd.testing.assert_frame_equal(horizon_rows, plain_fcst)
    return plain_fcst

# frequency -> forecast horizon used for that frequency
freqs = {'D': 7, 'W-THU': 52, 'Q-DEC': 8, '15T': 4 * 24 * 7}
for freq, h in freqs.items():
    df_freq = generate_series(
        10, 
        min_length=500 if freq != '15T' else 1_200, 
        max_length=550 if freq != '15T' else 2_000,
    )
    # re-date every series so that it follows `freq` and ends on 2023-01-01
    df_freq['ds'] = df_freq.groupby('unique_id', observed=True)['ds'].transform(
        lambda x: pd.date_range(periods=len(x), freq=freq, end='2023-01-01')
    )
    kwargs = dict(
        df=df_freq,
        h=h,
    )
    fcst_1_df = test_equal_fcsts_add_history(**{**kwargs, 'model': 'timegpt-1'})
    fcst_2_df = test_equal_fcsts_add_history(**{**kwargs, 'model': 'timegpt-1-long-horizon'})
    # the two models must not produce the same TimeGPT column
    test_fail(
        lambda: pd.testing.assert_frame_equal(fcst_1_df, fcst_2_df),
        contains='(column name="TimeGPT") are different',
    )
    # TODO: add a num_partitions test here

In [None]:
#| hide
# test same results custom url: the default client and a client pointing at
# the custom endpoint must return equal frames for forecast and detect_anomalies
nixtla_client_custom = NixtlaClient(
    api_key=os.environ['NIXTLA_API_KEY_CUSTOM'], 
    base_url=os.environ['NIXTLA_BASE_URL_CUSTOM'],
)
# forecast method
# NOTE: `fcst_kwargs` is reused (and modified) by later cells
fcst_kwargs = dict(
    df=df, 
    h=12, 
    level=[90, 95], 
    add_history=True, 
    time_col='timestamp', 
    target_col='value',
)
fcst_df = nixtla_client.forecast(**fcst_kwargs)
fcst_df_custom = nixtla_client_custom.forecast(**fcst_kwargs)
pd.testing.assert_frame_equal(
    fcst_df,
    fcst_df_custom,
)
# anomalies method
# NOTE: `anomalies_kwargs` is reused (and modified) by later cells
anomalies_kwargs = dict(
    df=df, 
    level=99,
    time_col='timestamp', 
    target_col='value',
)
anomalies_df = nixtla_client.detect_anomalies(**anomalies_kwargs)
anomalies_df_custom = nixtla_client_custom.detect_anomalies(**anomalies_kwargs)
pd.testing.assert_frame_equal(
    anomalies_df,
    anomalies_df_custom,
)

In [None]:
#| hide
# test different results for different models (forecast):
# comparing the TimeGPT column of both models must raise an assertion error.
# This mutates the `fcst_kwargs` dict created in a previous cell.
fcst_kwargs['model'] = 'timegpt-1'
fcst_timegpt_1 = nixtla_client.forecast(**fcst_kwargs)
fcst_kwargs['model'] = 'timegpt-1-long-horizon'
fcst_timegpt_long = nixtla_client.forecast(**fcst_kwargs)
test_fail(
    lambda: pd.testing.assert_frame_equal(fcst_timegpt_1[['TimeGPT']], fcst_timegpt_long[['TimeGPT']]),
    contains='(column name="TimeGPT") are different'
)

In [None]:
#| hide
# test different results for different models
# cross validation: the TimeGPT column of both models must differ
cv_kwargs = dict(
    df=df, 
    h=12, 
    level=[90, 95], 
    time_col='timestamp', 
    target_col='value',
)
cv_kwargs['model'] = 'timegpt-1'
cv_timegpt_1 = nixtla_client.cross_validation(**cv_kwargs)
cv_kwargs['model'] = 'timegpt-1-long-horizon'
cv_timegpt_long = nixtla_client.cross_validation(**cv_kwargs)
test_fail(
    lambda: pd.testing.assert_frame_equal(cv_timegpt_1[['TimeGPT']], cv_timegpt_long[['TimeGPT']]),
    contains='(column name="TimeGPT") are different'
)

In [None]:
#| hide
# test different results for different models
# anomalies: the TimeGPT column of both models must differ.
# This mutates the `anomalies_kwargs` dict created in a previous cell.
anomalies_kwargs['model'] = 'timegpt-1'
anomalies_timegpt_1 = nixtla_client.detect_anomalies(**anomalies_kwargs)
anomalies_kwargs['model'] = 'timegpt-1-long-horizon'
anomalies_timegpt_long = nixtla_client.detect_anomalies(**anomalies_kwargs)
test_fail(
    lambda: pd.testing.assert_frame_equal(anomalies_timegpt_1[['TimeGPT']], anomalies_timegpt_long[['TimeGPT']]),
    contains='(column name="TimeGPT") are different'
)

In [None]:
#| hide
# test unsupported model (forecast): an unknown model name must fail with a
# message containing 'unsupported model'
fcst_kwargs['model'] = 'a-model'
test_fail(
    lambda: nixtla_client.forecast(**fcst_kwargs),
    contains='unsupported model',
)

In [None]:
#| hide
# test unsupported model (detect_anomalies): an unknown model name must fail
# with a message containing 'unsupported model'
anomalies_kwargs['model'] = 'my-awesome-model'
test_fail(
    lambda: nixtla_client.detect_anomalies(**anomalies_kwargs),
    contains='unsupported model',
)

In [None]:
#| hide
# test add date features
# `df_` is the example data in the default layout (unique_id, ds, y);
# it is reused by many of the following cells
df_ = df.rename(columns={'timestamp': 'ds', 'value': 'y'})
df_.insert(0, 'unique_id', 'AirPassengers')
date_features = ['year', 'month']
df_date_features, future_df = _maybe_add_date_features(
    df=df_,
    X_df=None,
    h=12, 
    freq='MS', 
    features=date_features,
    one_hot=False,
    id_col='unique_id',
    time_col='ds',
    target_col='y',
)
# the requested features must be present in both returned frames
assert all(col in df_date_features for col in date_features)
assert all(col in future_df for col in date_features)

In [None]:
#| hide
# Test shap values are returned and sum to predictions
# (uses the frames with date features built in the previous cell)
h=12
fcst_df = nixtla_client.forecast(df=df_date_features, h=h, X_df=future_df, feature_contributions=True)
shap_values = nixtla_client.feature_contributions
assert len(shap_values) == len(fcst_df)
# the columns from position 3 onwards are summed per row and compared with the
# forecast (presumably the first three are id, time and forecast -- TODO confirm)
np.testing.assert_allclose(fcst_df["TimeGPT"].values, shap_values.iloc[:, 3:].sum(axis=1).values)

# same check when the historical fitted values are also returned
fcst_hist_df = nixtla_client.forecast(df=df_date_features, h=h, X_df=future_df, add_history=True, feature_contributions=True)
shap_values_hist = nixtla_client.feature_contributions
assert len(shap_values_hist) == len(fcst_hist_df)
np.testing.assert_allclose(fcst_hist_df["TimeGPT"].values, shap_values_hist.iloc[:, 3:].sum(axis=1).values, atol=1e-4)

# test num partitions: the contributions must not change when the request is partitioned
_ = nixtla_client.forecast(df=df_date_features, h=h, X_df=future_df, add_history=True, feature_contributions=True, num_partitions=2)
pd.testing.assert_frame_equal(nixtla_client.feature_contributions, shap_values_hist)

In [None]:
#| hide
# cross validation tests
# NOTE(review): `df_copy` is compared with `df_` right after being copied, so
# this assertion can only fail if `copy` itself misbehaves
df_copy = df_.copy()
pd.testing.assert_frame_equal(
    df_copy,
    df_,
)
# hold out the last 12 observations of every series as the test set
df_test = df_.groupby('unique_id').tail(12)
df_train = df_.drop(df_test.index)
# keyword arguments shared by forecast and cross_validation in the cells below
hyps = [
    # finetune steps is unstable due
    # to numerical reasons
    # dict(finetune_steps=2),
    dict(),
    dict(clean_ex_first=False),
    dict(date_features=['month']),
    dict(level=[80, 90]),
    #dict(level=[80, 90], finetune_steps=2),
]

In [None]:
#| hide
# test exogenous variables cv
# a noisy copy of the target is used as exogenous variable; forecasting the
# held-out 12 steps must match a cross validation with h=12 on the full data
df_ex_ = df_.copy()
df_ex_['exogenous_var'] = df_ex_['y'] + np.random.normal(size=len(df_ex_))
# future values of the exogenous variable for the test period
x_df_test = df_test.drop(columns='y').merge(df_ex_.drop(columns='y'))
for hyp in hyps:
    logger.info(f'Hyperparameters: {hyp}')
    logger.info('\n\nPerforming forecast\n')
    fcst_test = nixtla_client.forecast(
        df_train.merge(df_ex_.drop(columns='y')), h=12, X_df=x_df_test, **hyp
    )
    # attach the observed values so the layout matches the cross validation output
    fcst_test = df_test[['unique_id', 'ds', 'y']].merge(fcst_test)
    fcst_test = fcst_test.sort_values(['unique_id', 'ds']).reset_index(drop=True)
    logger.info('\n\nPerforming Cross validation\n')
    fcst_cv = nixtla_client.cross_validation(df_ex_, h=12, **hyp)
    fcst_cv = fcst_cv.sort_values(['unique_id', 'ds']).reset_index(drop=True)
    logger.info('\n\nVerify difference\n')
    pd.testing.assert_frame_equal(fcst_test, fcst_cv.drop(columns='cutoff'))

In [None]:
#| hide
# test finetune cv: smoke test, cross validation with one finetune step
# must run and return something
finetune_cv = nixtla_client.cross_validation(
            df=df_,
            h=12,
            n_windows=1,
            finetune_steps=1
        )
test_eq(finetune_cv is not None, True)

In [None]:
#| hide
# for every set of hyperparameters, forecasting the held-out 12 steps from
# `df_train` must match (rtol=1e-2) a cross validation with h=12 on `df_`;
# the observed values are attached with a merge
for hyp in hyps:
    fcst_test = nixtla_client.forecast(df_train, h=12, **hyp)
    fcst_test = df_test[['unique_id', 'ds', 'y']].merge(fcst_test)
    fcst_test = fcst_test.sort_values(['unique_id', 'ds']).reset_index(drop=True)
    fcst_cv = nixtla_client.cross_validation(df_, h=12, **hyp)
    fcst_cv = fcst_cv.sort_values(['unique_id', 'ds']).reset_index(drop=True)
    pd.testing.assert_frame_equal(
        fcst_test,
        fcst_cv.drop(columns='cutoff'),
        rtol=1e-2,
    )

In [None]:
#| hide
# same forecast vs cross validation comparison, but the observed values are
# inserted positionally as the third column instead of being merged,
# so this relies on the forecast rows being in the same order as `df_test`
for hyp in hyps:
    fcst_test = nixtla_client.forecast(df_train, h=12, **hyp)
    fcst_test.insert(2, 'y', df_test['y'].values)
    fcst_test = fcst_test.sort_values(['unique_id', 'ds']).reset_index(drop=True)
    fcst_cv = nixtla_client.cross_validation(df_, h=12, **hyp)
    fcst_cv = fcst_cv.sort_values(['unique_id', 'ds']).reset_index(drop=True)
    pd.testing.assert_frame_equal(
        fcst_test,
        fcst_cv.drop(columns='cutoff'),
        rtol=1e-2,
    )

In [None]:
#| hide
# test add callables: a callable date feature (SpecialDates) must add one
# column per named group of dates
date_features = [SpecialDates({'first_dates': ['2021-01-1'], 'second_dates': ['2021-01-01']})]
# NOTE(review): `df_daily` is built here but the call below receives `df_`;
# confirm which frame was meant to be used
df_daily = df_.copy()
df_daily['ds'] = pd.date_range(end='2021-01-01', periods=len(df_daily))
df_date_features, future_df = _maybe_add_date_features(
    df=df_,
    X_df=None,
    h=12, 
    freq='D', 
    features=date_features,
    one_hot=False,
    id_col='unique_id',
    time_col='ds',
    target_col='y',
)
assert all(col in df_date_features for col in ['first_dates', 'second_dates'])
assert all(col in future_df for col in ['first_dates', 'second_dates'])

In [None]:
#| hide
# test add date features one hot encoded: only 'month' is requested to be
# one hot encoded.
# NOTE(review): this cell has no assertions, it only checks that the call runs
date_features = ['year', 'month']
date_features_to_one_hot = ['month']
df_date_features, future_df = _maybe_add_date_features(
    df=df_,
    X_df=None,
    h=12, 
    freq='D', 
    features=date_features,
    one_hot=date_features_to_one_hot,
    id_col='unique_id',
    time_col='ds',
    target_col='y',
)

In [None]:
#| hide
# test pass dataframe with index
# the timestamps live in a DatetimeIndex named 'ds' instead of a column;
# forecasts and anomalies must match (atol=1e-3) the ones obtained when
# 'ds' is a regular column
df_ds_index = df_.set_index('ds')[['unique_id', 'y']]
df_ds_index.index = pd.DatetimeIndex(df_ds_index.index)
fcst_inferred_df_index = nixtla_client.forecast(df_ds_index, h=10)
anom_inferred_df_index = nixtla_client.detect_anomalies(df_ds_index)
fcst_inferred_df = nixtla_client.forecast(df_[['ds', 'unique_id', 'y']], h=10)
anom_inferred_df = nixtla_client.detect_anomalies(df_[['ds', 'unique_id', 'y']])
pd.testing.assert_frame_equal(fcst_inferred_df_index, fcst_inferred_df, atol=1e-3)
pd.testing.assert_frame_equal(anom_inferred_df_index, anom_inferred_df, atol=1e-3)
# repeat the forecast comparison for other frequencies, keeping the last 80
# rows of every series
df_ds_index = df_ds_index.groupby('unique_id').tail(80)
for freq in ['Y', 'W-MON', 'Q-DEC', 'H']:
    # give every series the same 80 timestamps with the frequency under test
    df_ds_index.index = np.concatenate(
        df_ds_index['unique_id'].nunique() * [pd.date_range(end='2023-01-01', periods=80, freq=freq)]
    )
    df_ds_index.index.name = 'ds'
    fcst_inferred_df_index = nixtla_client.forecast(df_ds_index, h=10)
    # NOTE: this rebinds the notebook-level `df_test`
    df_test = df_ds_index.reset_index()
    fcst_inferred_df = nixtla_client.forecast(df_test, h=10)
    pd.testing.assert_frame_equal(fcst_inferred_df_index, fcst_inferred_df, atol=1e-3)

In [None]:
#| hide
# test add date features with exogenous variables 
# and multiple series
# the last 12 (unique_id, ds) rows of `df_` play the role of the future frame
date_features = ['year', 'month']
df_actual_future = df_.tail(12)[['unique_id', 'ds']]
df_date_features, future_df = _maybe_add_date_features(
    df=df_,
    X_df=df_actual_future,
    h=24, 
    freq='H', 
    features=date_features,
    one_hot=False,
    id_col='unique_id',
    time_col='ds',
    target_col='y',
)
assert all(col in df_date_features for col in date_features)
assert all(col in future_df for col in date_features)
# the original columns of both inputs must come back unchanged
pd.testing.assert_frame_equal(
    df_date_features[df_.columns],
    df_,
)
pd.testing.assert_frame_equal(
    future_df[df_actual_future.columns],
    df_actual_future,
)

In [None]:
#| hide
# test add date features one hot with exogenous variables 
# and multiple series
# (reuses `df_actual_future` from the previous cell; every requested
# feature is one hot encoded)
date_features = ['month', 'day']
df_date_features, future_df = _maybe_add_date_features(
    df=df_,
    X_df=df_actual_future,
    h=24, 
    freq='H', 
    features=date_features,
    one_hot=date_features,
    id_col='unique_id',
    time_col='ds',
    target_col='y',
)
# the original columns of both inputs must come back unchanged
pd.testing.assert_frame_equal(
    df_date_features[df_.columns],
    df_,
)
# here the expected future frame is compared after resetting its index
pd.testing.assert_frame_equal(
    future_df[df_actual_future.columns],
    df_actual_future.reset_index(drop=True),
)

In [None]:
#| hide
# test warning horizon too long: forecasting 100 steps from only 3
# observations must still run (no assertion is made on the warning itself)
nixtla_client.forecast(df=df.tail(3), h=100, time_col='timestamp', target_col='value')

In [None]:
#| hide 
# test short input with add_history: only 3 observations are provided, so the
# call must fail with a message containing 'make sure'
test_fail(
    lambda: nixtla_client.forecast(df=df.tail(3), h=12, time_col='timestamp', target_col='value', add_history=True),
    contains='make sure'
)

In [None]:
#| hide 
# test short input with finetuning: only 3 observations are provided, so the
# call must fail with a message containing 'make sure'
test_fail(
    lambda: nixtla_client.forecast(df=df.tail(3), h=12, time_col='timestamp', target_col='value', finetune_steps=10, finetune_loss='mae'),
    contains='make sure'
)

In [None]:
#| hide 
# test short input with level: only 3 observations are provided, so the
# call must fail with a message containing 'make sure'
test_fail(
    lambda: nixtla_client.forecast(df=df.tail(3), h=12, time_col='timestamp', target_col='value', level=[80, 90]),
    contains='make sure'
)

In [None]:
#| hide
# test custom url
# same results: `_nixtla_client` is the client created earlier with the
# custom key and base url; it must return the same forecast as the default client
_timegpt_fcst_df = _nixtla_client.forecast(df=df, h=12, time_col='timestamp', target_col='value')
timegpt_fcst_df = nixtla_client.forecast(df=df, h=12, time_col='timestamp', target_col='value')
pd.testing.assert_frame_equal(
    _timegpt_fcst_df,
    timegpt_fcst_df,
)

In [None]:
#| hide
# test using index as time_col
# same results: anomalies detected from a frame whose timestamps are in the
# index must equal the ones detected from a frame with a 'timestamp' column
df_test = df.copy()
df_test["timestamp"] = pd.to_datetime(df_test["timestamp"])
# move the timestamps to the index and drop the column
df_test.set_index(df_test["timestamp"], inplace=True)
df_test.drop(columns="timestamp", inplace=True)

# Using user_provided time_col and freq
timegpt_anomalies_df_1 = nixtla_client.detect_anomalies(df, time_col='timestamp', target_col='value', freq= 'M')
# Infer time_col and freq from index
timegpt_anomalies_df_2 = nixtla_client.detect_anomalies(df_test, time_col='timestamp', target_col='value')

pd.testing.assert_frame_equal(
    timegpt_anomalies_df_1,
    timegpt_anomalies_df_2 
)

In [None]:
#| hide
# Test large requests raise error and suggest partition number
# NOTE: this rebinds the notebook-level `df` to 20,000 synthetic series of
# 1,000 minute-level observations each
df = generate_series(20_000, min_length=1_000, max_length=1_000, freq='min')
test_fail(
    lambda: nixtla_client.forecast(df=df, h=1, freq='min', finetune_steps=2),
    contains="num_partitions"
)

## Distributed

In [None]:
#| exporti
def _forecast_wrapper(
    df: pd.DataFrame,
    client: NixtlaClient,
    h: PositiveInt,
    freq: Optional[str],
    id_col: str,
    time_col: str,
    target_col: str,
    level: Optional[List[Union[int, float]]],
    quantiles: Optional[List[float]],
    finetune_steps: NonNegativeInt,
    finetune_loss: _Loss,
    clean_ex_first: bool,
    validate_api_key: bool,
    add_history: bool,
    date_features: Union[bool, List[Union[str, Callable]]],
    date_features_to_one_hot: Union[bool, List[str]],
    model: _Model,
    num_partitions: Optional[PositiveInt],
    feature_contributions: bool,
) -> pd.DataFrame:
    """Run `client.forecast` on a single pandas partition.

    When `df` has an `_in_sample` column it holds the historical rows and the
    future exogenous rows stacked together: rows flagged `True` are the
    history, the remaining ones (without `_in_sample` and the target column)
    become `X_df`. Without that column no `X_df` is used. All the other
    arguments are forwarded to `client.forecast` unchanged.
    """
    X_df = None
    if '_in_sample' in df:
        is_history = df['_in_sample']
        X_df = df.loc[~is_history].drop(columns=['_in_sample', target_col])
        df = df.loc[is_history].drop(columns='_in_sample')
    forecast_kwargs = {
        'df': df,
        'h': h,
        'freq': freq,
        'id_col': id_col,
        'time_col': time_col,
        'target_col': target_col,
        'X_df': X_df,
        'level': level,
        'quantiles': quantiles,
        'finetune_steps': finetune_steps,
        'finetune_loss': finetune_loss,
        'clean_ex_first': clean_ex_first,
        'validate_api_key': validate_api_key,
        'add_history': add_history,
        'date_features': date_features,
        'date_features_to_one_hot': date_features_to_one_hot,
        'model': model,
        'num_partitions': num_partitions,
        'feature_contributions': feature_contributions,
    }
    return client.forecast(**forecast_kwargs)

def _detect_anomalies_wrapper(
    df: pd.DataFrame,
    client: NixtlaClient,
    freq: Optional[str],
    id_col: str,
    time_col: str,
    target_col: str,
    level: Union[int, float],
    clean_ex_first: bool,
    validate_api_key: bool,
    date_features: Union[bool, List[str]],
    date_features_to_one_hot: Union[bool, List[str]],
    model: str,
    num_partitions: Optional[PositiveInt],
) -> pd.DataFrame:
    """Run `client.detect_anomalies` on a single pandas partition.

    Every argument is forwarded by keyword without modification.
    """
    anomalies_kwargs = {
        'df': df,
        'freq': freq,
        'id_col': id_col,
        'time_col': time_col,
        'target_col': target_col,
        'level': level,
        'clean_ex_first': clean_ex_first,
        'validate_api_key': validate_api_key,
        'date_features': date_features,
        'date_features_to_one_hot': date_features_to_one_hot,
        'model': model,
        'num_partitions': num_partitions,
    }
    return client.detect_anomalies(**anomalies_kwargs)

def _cross_validation_wrapper(
    df: pd.DataFrame,
    client: NixtlaClient,
    h: PositiveInt,
    freq: Optional[str],
    id_col: str,
    time_col: str,
    target_col: str,
    level: Optional[List[Union[int, float]]],
    quantiles: Optional[List[float]],
    validate_api_key: bool,
    n_windows: PositiveInt,
    step_size: Optional[PositiveInt],
    finetune_steps: NonNegativeInt,
    finetune_loss: str,
    clean_ex_first: bool,
    date_features: Union[bool, List[str]],
    date_features_to_one_hot: Union[bool, List[str]],
    model: str,
    num_partitions: Optional[PositiveInt],
) -> pd.DataFrame:
    """Run `client.cross_validation` on a single pandas partition.

    Every argument is forwarded by keyword without modification.
    """
    cv_kwargs = {
        'df': df,
        'h': h,
        'freq': freq,
        'id_col': id_col,
        'time_col': time_col,
        'target_col': target_col,
        'level': level,
        'quantiles': quantiles,
        'validate_api_key': validate_api_key,
        'n_windows': n_windows,
        'step_size': step_size,
        'finetune_steps': finetune_steps,
        'finetune_loss': finetune_loss,
        'clean_ex_first': clean_ex_first,
        'date_features': date_features,
        'date_features_to_one_hot': date_features_to_one_hot,
        'model': model,
        'num_partitions': num_partitions,
    }
    return client.cross_validation(**cv_kwargs)

def _get_schema(
    df: 'AnyDataFrame',
    method: str,
    id_col: str,
    time_col: str,
    target_col: str,
    level: Optional[List[Union[int, float]]],
    quantiles: Optional[List[float]],
) -> 'triad.Schema':
    """Build the output schema of a distributed `method` call.

    The schema starts with the id and time columns of `df` (plus the target
    column for every method other than `'forecast'`), followed by `TimeGPT`,
    then `anomaly` for `'detect_anomalies'` or `cutoff` (typed like the time
    column) for `'cross_validation'`, and finally the interval or quantile
    columns.

    Raises
    ------
    ValueError
        If both `level` and `quantiles` are provided.
    """
    import fugue.api as fa

    if method == 'forecast':
        base_cols = [id_col, time_col]
    else:
        base_cols = [id_col, time_col, target_col]
    schema = fa.get_schema(df).extract(base_cols).copy()
    schema.append('TimeGPT:double')
    if method == 'detect_anomalies':
        schema.append('anomaly:bool')
    elif method == 'cross_validation':
        schema.append(('cutoff', schema[time_col].type))
    if level is not None and quantiles is not None:
        raise ValueError("You should provide `level` or `quantiles` but not both.")
    if level is not None:
        # a single level may be given as a scalar
        levels = sorted(level if isinstance(level, list) else [level])
        # lower bounds go from the widest to the narrowest interval,
        # upper bounds from the narrowest to the widest
        lo_cols = [f"TimeGPT-lo-{lv}:double" for lv in reversed(levels)]
        hi_cols = [f"TimeGPT-hi-{lv}:double" for lv in levels]
        schema.append(",".join(lo_cols))
        schema.append(",".join(hi_cols))
    if quantiles is not None:
        schema.append(
            ",".join(f'TimeGPT-q-{int(q * 100)}:double' for q in sorted(quantiles))
        )
    return schema

def _distributed_setup(
    df: 'AnyDataFrame',
    method: str,
    id_col: str,
    time_col: str,
    target_col: str,
    level: Optional[List[Union[int, float]]],
    quantiles: Optional[List[float]],
    num_partitions: Optional[int],
) -> Tuple['triad.Schema', Dict[str, Any]]:
    """Validate a distributed input and prepare what `fa.transform` needs.

    Returns the output schema for `method` and the partition configuration:
    partitioned by `id_col` with the `'coarse'` algorithm, and with `num` set
    to `num_partitions` when it is provided.

    Raises
    ------
    ValueError
        If no execution engine can be inferred from `df`.
    """
    from fugue.execution import infer_execution_engine

    engine = infer_execution_engine([df])
    if engine is None:
        raise ValueError(
            f'Could not infer execution engine for type {type(df).__name__}. '
            'Expected a spark or dask DataFrame or a ray Dataset.'
        )
    schema = _get_schema(
        df=df,
        method=method,
        id_col=id_col,
        time_col=time_col,
        target_col=target_col,
        level=level,
        quantiles=quantiles,
    )
    partition_config = {'by': id_col, 'algo': 'coarse'}
    if num_partitions is not None:
        partition_config['num'] = num_partitions
    return schema, partition_config

@patch
def _distributed_forecast(
    self: NixtlaClient,
    df: DistributedDFType,
    h: PositiveInt,
    freq: Optional[str],
    id_col: str,
    time_col: str,
    target_col: str,
    X_df: Optional[DistributedDFType],
    level: Optional[List[Union[int, float]]],
    quantiles: Optional[List[float]],
    finetune_steps: NonNegativeInt,
    finetune_loss: _Loss,
    clean_ex_first: bool,
    validate_api_key: bool,
    add_history: bool,
    date_features: Union[bool, List[Union[str, Callable]]],
    date_features_to_one_hot: Union[bool, List[str]],
    model: _Model,
    num_partitions: Optional[int],
    feature_contributions: bool,
) -> DistributedDFType:
    """Forecast a distributed dataframe by applying `_forecast_wrapper` to each partition.

    The data is partitioned by `id_col` (see `_distributed_setup`) and the
    result is returned as the native dataframe of the engine. When `X_df` is
    given, its rows are stacked below the rows of `df` so that both travel
    through the same partitioned transform; `_forecast_wrapper` separates them
    again using the `_in_sample` flag.
    """
    import fugue.api as fa

    schema, partition_config = _distributed_setup(
        df=df,
        method='forecast',
        id_col=id_col,
        time_col=time_col,
        target_col=target_col,
        level=level,
        quantiles=quantiles,
        num_partitions=num_partitions,
    )
    if X_df is not None:
        def format_df(df: pd.DataFrame) -> pd.DataFrame:
            return df.assign(_in_sample=True)

        def format_X_df(
            X_df: pd.DataFrame,
            target_col: str,
            df_cols: List[str],
        ) -> pd.DataFrame:
            return X_df.assign(**{'_in_sample': False, target_col: 0.0})[df_cols]

        # historical rows are flagged with _in_sample=True
        df = fa.transform(df, format_df, schema='*,_in_sample:bool')
        # future rows are flagged with _in_sample=False, get a placeholder
        # target of 0.0 and are reordered to the columns of `df`
        X_df = fa.transform(
            X_df,
            format_X_df,
            schema=fa.get_schema(df),
            params={'target_col': target_col, 'df_cols': fa.get_column_names(df)},
        )
        df = fa.union(df, X_df)
    result_df = fa.transform(
        df,
        using=_forecast_wrapper,
        schema=schema,
        params=dict(
            client=self,
            h=h,
            freq=freq,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            level=level,
            quantiles=quantiles,
            finetune_steps=finetune_steps,
            finetune_loss=finetune_loss,
            clean_ex_first=clean_ex_first,
            validate_api_key=validate_api_key,
            add_history=add_history,
            date_features=date_features,
            date_features_to_one_hot=date_features_to_one_hot,
            model=model,
            # `num_partitions` was already used for the partition config above
            num_partitions=None,
            feature_contributions=feature_contributions,   
        ),
        partition=partition_config,
        as_fugue=True,
    )
    return fa.get_native_as_df(result_df)

@patch
def _distributed_detect_anomalies(
    self: NixtlaClient,
    df: DistributedDFType,
    freq: Optional[str],
    id_col: str,
    time_col: str,
    target_col: str,
    level: Union[int, float],
    clean_ex_first: bool,
    validate_api_key: bool,
    date_features: Union[bool, List[str]],
    date_features_to_one_hot: Union[bool, List[str]],
    model: str,
    num_partitions: Optional[int],
) -> DistributedDFType:
    """Detect anomalies on a distributed dataframe.

    `_detect_anomalies_wrapper` is applied to every partition of `df`
    (partitioned by `id_col`, see `_distributed_setup`) and the result is
    returned as the native dataframe of the engine.
    """
    import fugue.api as fa

    schema, partition_config = _distributed_setup(
        df=df,
        method='detect_anomalies',
        id_col=id_col,
        time_col=time_col,
        target_col=target_col,
        level=level,
        quantiles=None,
        num_partitions=num_partitions,
    )
    wrapper_params = {
        'client': self,
        'freq': freq,
        'id_col': id_col,
        'time_col': time_col,
        'target_col': target_col,
        'level': level,
        'clean_ex_first': clean_ex_first,
        'validate_api_key': validate_api_key,
        'date_features': date_features,
        'date_features_to_one_hot': date_features_to_one_hot,
        'model': model,
        # `num_partitions` was already used for the partition config above
        'num_partitions': None,
    }
    anomalies = fa.transform(
        df,
        using=_detect_anomalies_wrapper,
        schema=schema,
        params=wrapper_params,
        partition=partition_config,
        as_fugue=True,
    )
    return fa.get_native_as_df(anomalies)

@patch
def _distributed_cross_validation(
    self: NixtlaClient,
    df: DistributedDFType,
    h: PositiveInt,
    freq: Optional[str],
    id_col: str,
    time_col: str,
    target_col: str,
    level: Optional[List[Union[int, float]]],
    quantiles: Optional[List[float]],
    validate_api_key: bool,
    n_windows: PositiveInt,
    step_size: Optional[PositiveInt],
    finetune_steps: NonNegativeInt,
    finetune_loss: _Loss,
    clean_ex_first: bool,
    date_features: Union[bool, List[Union[str, Callable]]],
    date_features_to_one_hot: Union[bool, List[str]],
    model: _Model,
    num_partitions: Optional[int],
) -> DistributedDFType:
    """Run cross validation on a distributed dataframe.

    `_cross_validation_wrapper` is applied to every partition of `df`
    (partitioned by `id_col`, see `_distributed_setup`) and the result is
    returned as the native dataframe of the engine.
    """
    import fugue.api as fa

    # the schema must be requested for 'cross_validation': only then does
    # `_get_schema` include the target and `cutoff` columns in the output
    schema, partition_config = _distributed_setup(
        df=df,
        method='cross_validation',
        id_col=id_col,
        time_col=time_col,
        target_col=target_col,
        level=level,
        quantiles=quantiles,
        num_partitions=num_partitions,
    )
    result_df = fa.transform(
        df,
        using=_cross_validation_wrapper,
        schema=schema,
        params=dict(
            client=self,
            h=h,
            freq=freq,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            level=level,
            quantiles=quantiles,
            validate_api_key=validate_api_key,
            n_windows=n_windows,
            step_size=step_size,
            finetune_steps=finetune_steps,
            finetune_loss=finetune_loss,
            clean_ex_first=clean_ex_first,
            date_features=date_features,
            date_features_to_one_hot=date_features_to_one_hot,
            model=model,
            # `num_partitions` was already used for the partition config above
            num_partitions=None,
        ),
        partition=partition_config,
        as_fugue=True,
    )
    return fa.get_native_as_df(result_df)

In [None]:
#| hide
#| distributed
import dask.dataframe as dd
import fugue
import fugue.api as fa
import ray
from dask.distributed import Client
from pyspark.sql import SparkSession
from ray.cluster_utils import Cluster

In [None]:
#| hide
#| distributed
def test_forecast(
    df: fugue.AnyDataFrame,
    horizon: int = 12,
    id_col: str = 'unique_id',
    time_col: str = 'ds',
    **fcst_kwargs,
):
    """Forecast `df` and check the number of rows and the output columns.

    The result must hold `horizon` rows for each distinct value of `id_col`
    in `df`, and its columns must be `[id_col, time_col, 'TimeGPT']` followed
    by the `lo` (descending level) and `hi` (ascending level) interval columns
    when `level` is given in `fcst_kwargs`.
    """
    fcst_df = nixtla_client.forecast(
        df=df,
        h=horizon,
        id_col=id_col,
        time_col=time_col,
        **fcst_kwargs,
    )
    fcst_df = fa.as_pandas(fcst_df)
    # Derive the expected size from the inputs instead of a hard-coded horizon
    # and a notebook-level global, so the check holds for any `horizon`/`df`.
    n_series = fa.as_pandas(df)[id_col].nunique()
    test_eq(n_series * horizon, len(fcst_df))
    cols = fcst_df.columns.to_list()
    exp_cols = [id_col, time_col, 'TimeGPT']
    if 'level' in fcst_kwargs:
        level = sorted(fcst_kwargs['level'])
        exp_cols.extend([f'TimeGPT-lo-{lv}' for lv in reversed(level)])
        exp_cols.extend([f'TimeGPT-hi-{lv}' for lv in level])
    test_eq(cols, exp_cols)

def test_forecast_diff_results_diff_models(
    df: fugue.AnyDataFrame,
    horizon: int = 12,
    id_col: str = 'unique_id',
    time_col: str = 'ds',
    **fcst_kwargs,
):
    """Check that 'timegpt-1' and 'timegpt-1-long-horizon' forecasts differ.

    Both forecasts are requested with a single partition; comparing the
    sorted frames is expected to fail on the 'TimeGPT' column.
    """
    def _forecast_with(model_name):
        # one forecast request for the given model, materialized as pandas
        out = nixtla_client.forecast(
            df=df,
            h=horizon,
            num_partitions=1,
            id_col=id_col,
            time_col=time_col,
            model=model_name,
            **fcst_kwargs
        )
        return fa.as_pandas(out)

    short_fcst = _forecast_with('timegpt-1')
    long_fcst = _forecast_with('timegpt-1-long-horizon')
    sort_keys = [id_col, time_col]

    def _assert_same_forecasts():
        pd.testing.assert_frame_equal(
            short_fcst.sort_values(sort_keys).reset_index(drop=True),
            long_fcst.sort_values(sort_keys).reset_index(drop=True),
        )

    test_fail(
        _assert_same_forecasts,
        contains='(column name="TimeGPT") are different',
    )

def test_forecast_same_results_num_partitions(
    df: fugue.AnyDataFrame,
    horizon: int = 12,
    id_col: str = 'unique_id',
    time_col: str = 'ds',
    **fcst_kwargs,
):
    """Check that forecasting with one or two partitions gives equal frames."""
    frames = []
    for n_parts in (1, 2):
        out = nixtla_client.forecast(
            df=df,
            h=horizon,
            num_partitions=n_parts,
            id_col=id_col,
            time_col=time_col,
            **fcst_kwargs
        )
        frames.append(fa.as_pandas(out))
    # row order is not compared: both frames are sorted before the check
    single, double = (
        frame.sort_values([id_col, time_col]).reset_index(drop=True)
        for frame in frames
    )
    pd.testing.assert_frame_equal(single, double)

def test_cv_same_results_num_partitions(
    df: fugue.AnyDataFrame,
    horizon: int = 12,
    id_col: str = 'unique_id',
    time_col: str = 'ds',
    **fcst_kwargs,
):
    """Check that cross validation with one or two partitions gives equal frames."""
    def _cross_validate(n_parts):
        # single cross-validation request, materialized as pandas
        cv_out = nixtla_client.cross_validation(
            df=df,
            h=horizon,
            num_partitions=n_parts,
            id_col=id_col,
            time_col=time_col,
            **fcst_kwargs
        )
        return fa.as_pandas(cv_out)

    def _ordered(frame):
        return frame.sort_values([id_col, time_col]).reset_index(drop=True)

    one_partition = _cross_validate(1)
    two_partitions = _cross_validate(2)
    pd.testing.assert_frame_equal(_ordered(one_partition), _ordered(two_partitions))

def test_forecast_dataframe(df: fugue.AnyDataFrame):
    """Run the cross-validation and forecast checks on `df` (default column names)."""
    cv_cases = (
        dict(n_windows=2, step_size=1),
        dict(n_windows=3, step_size=None, horizon=1),
        dict(model='timegpt-1-long-horizon', horizon=1),
    )
    for cv_kwargs in cv_cases:
        test_cv_same_results_num_partitions(df, **cv_kwargs)
    test_forecast_diff_results_diff_models(df)
    # point forecasts first, then with prediction intervals
    for extra_kwargs in ({}, {'level': [90, 80]}):
        test_forecast(df, **extra_kwargs, num_partitions=1)
    test_forecast_same_results_num_partitions(df)

def test_forecast_dataframe_diff_cols(
    df: fugue.AnyDataFrame,
    id_col: str = 'id_col',
    time_col: str = 'time_col',
    target_col: str = 'target_col',
):
    """Run the forecast checks on `df` using custom id/time/target column names."""
    col_names = dict(id_col=id_col, time_col=time_col, target_col=target_col)
    # point forecasts first, then with prediction intervals
    for extra_kwargs in ({}, {'level': [90, 80]}):
        test_forecast(df, **col_names, **extra_kwargs, num_partitions=1)
    test_forecast_same_results_num_partitions(df, **col_names)

def test_forecast_x(
    df: fugue.AnyDataFrame,
    X_df: fugue.AnyDataFrame,
    horizon: int = 24,
    id_col: str = 'unique_id',
    time_col: str = 'ds',
    target_col: str = 'y',
    **fcst_kwargs,
):
    """Forecast with future exogenous values and validate the result.

    Checks the number of rows (`horizon` per distinct `id_col` value in
    `X_df`), the output columns, and that the point forecasts differ from the
    ones obtained when `X_df` is not provided.
    """
    shared = dict(
        h=horizon,
        id_col=id_col,
        time_col=time_col,
        target_col=target_col,
    )
    with_x = fa.as_pandas(
        nixtla_client.forecast(df=df, X_df=X_df, **shared, **fcst_kwargs)
    )
    n_series = fa.as_pandas(X_df)[id_col].nunique()
    test_eq(n_series * horizon, len(with_x))

    cols = with_x.columns.to_list()
    exp_cols = [id_col, time_col, 'TimeGPT']
    if 'level' in fcst_kwargs:
        ascending = sorted(fcst_kwargs['level'])
        exp_cols += [f'TimeGPT-lo-{lv}' for lv in ascending[::-1]]
        exp_cols += [f'TimeGPT-hi-{lv}' for lv in ascending]
    test_eq(cols, exp_cols)

    # same request without the future exogenous values
    without_x = fa.as_pandas(
        nixtla_client.forecast(df=df, **shared, **fcst_kwargs)
    )
    sort_keys = [id_col, time_col]
    equal_arrays = np.array_equal(
        with_x.sort_values(sort_keys)['TimeGPT'].values,
        without_x.sort_values(sort_keys)['TimeGPT'].values,
    )
    assert not equal_arrays, 'Forecasts with and without ex vars are equal'

def test_forecast_x_same_results_num_partitions(
    df: fugue.AnyDataFrame,
    X_df: fugue.AnyDataFrame,
    horizon: int = 24,
    id_col: str = 'unique_id',
    time_col: str = 'ds',
    target_col: str = 'y',
    **fcst_kwargs,
):
    """Check that forecasts with exogenous values do not depend on `num_partitions`.

    The same request (including `X_df`) is issued with one and with two
    partitions and the sorted results are compared for equality.
    """
    def _forecast(n_parts):
        # `X_df` must be part of both requests; otherwise the comparison would
        # measure the effect of the exogenous values, not of the partitioning.
        out = nixtla_client.forecast(
            df=df,
            X_df=X_df,
            h=horizon,
            num_partitions=n_parts,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            **fcst_kwargs
        )
        return fa.as_pandas(out)

    fcst_df = _forecast(1)
    fcst_df_2 = _forecast(2)
    pd.testing.assert_frame_equal(
        fcst_df.sort_values([id_col, time_col]).reset_index(drop=True),
        fcst_df_2.sort_values([id_col, time_col]).reset_index(drop=True),
    )

def test_forecast_x_dataframe(df: fugue.AnyDataFrame, X_df: fugue.AnyDataFrame):
    """Run the exogenous-variable forecast checks on `df`/`X_df` (default column names)."""
    # point forecasts first, then with prediction intervals
    for extra_kwargs in ({}, {'level': [90, 80]}):
        test_forecast_x(df, X_df, **extra_kwargs, num_partitions=1)
    test_forecast_x_same_results_num_partitions(df, X_df)

def test_forecast_x_dataframe_diff_cols(
    df: fugue.AnyDataFrame,
    X_df: fugue.AnyDataFrame,
    id_col: str = 'id_col',
    time_col: str = 'time_col',
    target_col: str = 'target_col'
):
    """Run the exogenous-variable forecast checks using custom column names."""
    col_names = dict(id_col=id_col, time_col=time_col, target_col=target_col)
    # point forecasts first, then with prediction intervals
    for extra_kwargs in ({}, {'level': [90, 80]}):
        test_forecast_x(df, X_df, **col_names, **extra_kwargs, num_partitions=1)
    test_forecast_x_same_results_num_partitions(df, X_df, **col_names)

def test_anomalies(
    df: fugue.AnyDataFrame,
    id_col: str = 'unique_id',
    time_col: str = 'ds',
    target_col: str = 'y',
    **anomalies_kwargs,
):
    """Detect anomalies on `df` and check the returned ids and columns.

    The interval columns are expected to be named after `level`, which is
    taken from `anomalies_kwargs` and falls back to 99 when absent.
    """
    detected = nixtla_client.detect_anomalies(
        df=df,
        id_col=id_col,
        time_col=time_col,
        target_col=target_col,
        **anomalies_kwargs,
    )
    detected = fa.as_pandas(detected)
    input_ids = fa.as_pandas(df)[id_col].unique()
    test_eq(input_ids, detected[id_col].unique())

    level = anomalies_kwargs.get('level', 99)
    interval_cols = [f'TimeGPT-lo-{level}', f'TimeGPT-hi-{level}']
    exp_cols = [id_col, time_col, target_col, 'TimeGPT', 'anomaly'] + interval_cols
    test_eq(detected.columns.to_list(), exp_cols)

def test_anomalies_same_results_num_partitions(
    df: fugue.AnyDataFrame,
    id_col: str = 'unique_id',
    time_col: str = 'ds',
    target_col: str = 'y',
    **anomalies_kwargs,
):
    """Check that anomaly detection with one or two partitions agrees (atol=1e-5)."""
    frames = []
    for n_parts in (1, 2):
        out = nixtla_client.detect_anomalies(
            df=df,
            num_partitions=n_parts,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            **anomalies_kwargs
        )
        frames.append(fa.as_pandas(out))
    # row order is not compared: both frames are sorted before the check
    single, double = (
        frame.sort_values([id_col, time_col]).reset_index(drop=True)
        for frame in frames
    )
    pd.testing.assert_frame_equal(single, double, atol=1e-5)

def test_anomalies_diff_results_diff_models(
    df: fugue.AnyDataFrame,
    id_col: str = 'unique_id',
    time_col: str = 'ds',
    target_col: str = 'y',
    **anomalies_kwargs,
):
    """Check that 'timegpt-1' and 'timegpt-1-long-horizon' anomaly results differ.

    Both requests use a single partition; comparing the sorted frames is
    expected to fail on the 'TimeGPT' column.
    """
    def _detect_with(model_name):
        # one anomaly-detection request for the given model, as pandas
        out = nixtla_client.detect_anomalies(
            df=df,
            num_partitions=1,
            id_col=id_col,
            time_col=time_col,
            target_col=target_col,
            model=model_name,
            **anomalies_kwargs
        )
        return fa.as_pandas(out)

    short_res = _detect_with('timegpt-1')
    long_res = _detect_with('timegpt-1-long-horizon')
    sort_keys = [id_col, time_col]

    def _assert_same_results():
        pd.testing.assert_frame_equal(
            short_res.sort_values(sort_keys).reset_index(drop=True),
            long_res.sort_values(sort_keys).reset_index(drop=True),
        )

    test_fail(
        _assert_same_results,
        contains='(column name="TimeGPT") are different',
    )

def test_anomalies_dataframe(df: fugue.AnyDataFrame):
    """Run the anomaly-detection checks on `df` (default column names)."""
    # default level first, then an explicit one
    for extra_kwargs in ({}, {'level': 90}):
        test_anomalies(df, **extra_kwargs, num_partitions=1)
    test_anomalies_same_results_num_partitions(df)

def test_anomalies_dataframe_diff_cols(
    df: fugue.AnyDataFrame,
    id_col: str = 'id_col',
    time_col: str = 'time_col',
    target_col: str = 'target_col',
):
    """Run the anomaly-detection checks on `df` using custom column names."""
    col_names = dict(id_col=id_col, time_col=time_col, target_col=target_col)
    # default level first, then an explicit one
    for extra_kwargs in ({}, {'level': 90}):
        test_anomalies(df, **col_names, **extra_kwargs, num_partitions=1)
    test_anomalies_same_results_num_partitions(df, **col_names)
    # @A: document behavior with exogenous variables in distributed environments.
    #test_anomalies_same_results_num_partitions(df, id_col=id_col, time_col=time_col, date_features=True, clean_ex_first=False)

def test_quantiles(df: fugue.AnyDataFrame, id_col: str = 'id_col', time_col: str = 'time_col'):
    """Check the quantile outputs of `forecast` and `cross_validation` on `df`.

    For quantiles 0.1, 0.2, ..., 0.9 every `TimeGPT-q-<pct>` column must be
    present in the result, and within each row the quantile values must be
    non-decreasing from the lowest to the highest quantile.
    """
    test_qls = list(np.arange(0.1, 1, 0.1))
    exp_q_cols = [f"TimeGPT-q-{int(q * 100)}" for q in test_qls]

    def test_method_qls(method, **kwargs):
        df_qls = method(
            df=df,
            h=12,
            id_col=id_col,
            time_col=time_col,
            quantiles=test_qls,
            **kwargs
        )
        df_qls = fa.as_pandas(df_qls)
        assert all(col in df_qls.columns for col in exp_q_cols)
        # test monotonicity of quantiles: restrict the check to the quantile
        # columns (ordered by increasing quantile) and actually assert it
        monotonic_rows = df_qls[exp_q_cols].apply(
            lambda row: row.is_monotonic_increasing, axis=1
        )
        assert monotonic_rows.all(), 'Quantile forecasts are not monotonically increasing'

    test_method_qls(nixtla_client.forecast)
    test_method_qls(nixtla_client.forecast, add_history=True)
    test_method_qls(nixtla_client.cross_validation)

In [None]:
#| hide
#| distributed
# series produced by `generate_series`, shared by the backend cells below
n_series, horizon = 4, 7
series = generate_series(n_series, min_length=100)
series['unique_id'] = series['unique_id'].astype(str)

# same data under non-default column names
renamer = {'unique_id': 'id_col', 'ds': 'time_col', 'y': 'target_col'}
series_diff_cols = series.rename(columns=renamer)

# data for exogenous tests (column names are lower-cased after loading)
df_x = pd.read_csv(
    'https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/electricity-short-with-ex-vars.csv',
    parse_dates=['ds'],
).rename(columns=str.lower)
future_ex_vars_df = pd.read_csv(
    'https://raw.githubusercontent.com/Nixtla/transfer-learning-time-series/main/datasets/electricity-short-future-ex-vars.csv',
    parse_dates=['ds'],
).rename(columns=str.lower)

### Spark

In [None]:
#| hide
#| distributed
spark = SparkSession.builder.getOrCreate()
try:
    spark_df = spark.createDataFrame(series).repartition(2)
    spark_diff_cols_df = spark.createDataFrame(series_diff_cols).repartition(2)

    test_quantiles(spark_df, id_col="unique_id", time_col="ds")

    test_forecast_dataframe(spark_df)
    test_forecast_dataframe_diff_cols(spark_diff_cols_df)
    test_anomalies_dataframe(spark_df)
    test_anomalies_dataframe_diff_cols(spark_diff_cols_df)
    # test exogenous variables
    spark_df_x = spark.createDataFrame(df_x).repartition(2)
    spark_future_ex_vars_df = spark.createDataFrame(future_ex_vars_df).repartition(2)
    test_forecast_x_dataframe(spark_df_x, spark_future_ex_vars_df)
    # test x different cols
    spark_df_x_diff_cols = spark.createDataFrame(df_x.rename(columns=renamer)).repartition(2)
    spark_future_ex_vars_df_diff_cols = spark.createDataFrame(
        future_ex_vars_df.rename(columns=renamer)
    ).repartition(2)
    test_forecast_x_dataframe_diff_cols(spark_df_x_diff_cols, spark_future_ex_vars_df_diff_cols)
finally:
    # stop the session even when one of the checks above raises
    spark.stop()

### Dask

In [None]:
#| hide
#| distributed
client = Client()
try:
    dask_df = dd.from_pandas(series, npartitions=2)
    dask_diff_cols_df = dd.from_pandas(series_diff_cols, npartitions=2)

    test_quantiles(dask_df, id_col="unique_id", time_col="ds")

    test_forecast_dataframe(dask_df)
    test_forecast_dataframe_diff_cols(dask_diff_cols_df)
    test_anomalies_dataframe(dask_df)
    test_anomalies_dataframe_diff_cols(dask_diff_cols_df)

    # test exogenous variables
    dask_df_x = dd.from_pandas(df_x, npartitions=2)
    dask_future_ex_vars_df = dd.from_pandas(future_ex_vars_df, npartitions=2)
    test_forecast_x_dataframe(dask_df_x, dask_future_ex_vars_df)

    # test x different cols
    dask_df_x_diff_cols = dd.from_pandas(df_x.rename(columns=renamer), npartitions=2)
    dask_future_ex_vars_df_diff_cols = dd.from_pandas(future_ex_vars_df.rename(columns=renamer), npartitions=2)
    test_forecast_x_dataframe_diff_cols(dask_df_x_diff_cols, dask_future_ex_vars_df_diff_cols)
finally:
    # close the client even when one of the checks above raises
    client.close()

### Ray

In [None]:
#| hide
#| distributed
ray_cluster = Cluster(
    initialize_head=True,
    head_node_args={"num_cpus": 2}
)
try:
    ray.init(address=ray_cluster.address, ignore_reinit_error=True)
    # add mock node to simulate a cluster
    mock_node = ray_cluster.add_node(num_cpus=2)

    ray_df = ray.data.from_pandas(series)
    ray_diff_cols_df = ray.data.from_pandas(series_diff_cols)

    test_quantiles(ray_df, id_col="unique_id", time_col="ds")

    test_forecast_dataframe(ray_df)
    test_forecast_dataframe_diff_cols(ray_diff_cols_df)
    test_anomalies_dataframe(ray_df)
    test_anomalies_dataframe_diff_cols(ray_diff_cols_df)

    # test exogenous variables
    ray_df_x = ray.data.from_pandas(df_x)
    ray_future_ex_vars_df = ray.data.from_pandas(future_ex_vars_df)
    test_forecast_x_dataframe(ray_df_x, ray_future_ex_vars_df)

    # test x different cols
    ray_df_x_diff_cols = ray.data.from_pandas(df_x.rename(columns=renamer))
    ray_future_ex_vars_df_diff_cols = ray.data.from_pandas(future_ex_vars_df.rename(columns=renamer))
    test_forecast_x_dataframe_diff_cols(ray_df_x_diff_cols, ray_future_ex_vars_df_diff_cols)
finally:
    # disconnect and tear the local test cluster down even when a check raises
    ray.shutdown()
    ray_cluster.shutdown()