-
Notifications
You must be signed in to change notification settings - Fork 323
Expand file tree
/
Copy pathnixtla_client.py
More file actions
3490 lines (3272 loc) · 138 KB
/
nixtla_client.py
File metadata and controls
3490 lines (3272 loc) · 138 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Public names exported by a star import: the API error type and the client.
__all__ = ["ApiError", "NixtlaClient"]
import datetime
from http import HTTPStatus
import logging
import math
import os
import warnings
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor, as_completed
from enum import Enum
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Callable,
Dict,
Literal,
Optional,
TypeVar,
Union,
overload,
)
import annotated_types
import httpcore
import httpx
import numpy as np
import orjson
import pandas as pd
import utilsforecast.processing as ufp
import zstandard as zstd
from pydantic import AfterValidator, BaseModel, TypeAdapter
from tenacity import (
RetryCallState,
retry,
retry_if_exception,
stop_after_attempt,
stop_after_delay,
wait_fixed,
)
from utilsforecast.compat import DataFrame, DFType, pl_DataFrame
from utilsforecast.feature_engineering import _add_time_features, time_features
from utilsforecast.preprocessing import fill_gaps, id_time_grid
from utilsforecast.processing import ensure_sorted
from utilsforecast.validation import ensure_time_dtype, validate_format
if TYPE_CHECKING:
try:
from fugue import AnyDataFrame
except ModuleNotFoundError:
pass
try:
import matplotlib.pyplot as plt
except ModuleNotFoundError:
pass
try:
import plotly
except ModuleNotFoundError:
pass
try:
import triad
except ModuleNotFoundError:
pass
try:
from polars import DataFrame as PolarsDataFrame
except ModuleNotFoundError:
pass
try:
from dask.dataframe import DataFrame as DaskDataFrame
except ModuleNotFoundError:
pass
try:
from pyspark.sql import DataFrame as SparkDataFrame
except ModuleNotFoundError:
pass
try:
from ray.data import Dataset as RayDataset
except ModuleNotFoundError:
pass
# Type variable over every dataframe backend accepted by the audit helpers and
# the client: pandas plus the optional dask / polars / ray / spark types. The
# optional ones are written as strings because they are only imported under
# TYPE_CHECKING.
AnyDFType = TypeVar(
    "AnyDFType",
    "DaskDataFrame",
    pd.DataFrame,
    "PolarsDataFrame",
    "RayDataset",
    "SparkDataFrame",
)
# Type variable restricted to the distributed backends (dask, ray, spark).
DistributedDFType = TypeVar(
    "DistributedDFType",
    "DaskDataFrame",
    "RayDataset",
    "SparkDataFrame",
)
# NOTE(review): calling basicConfig at import time configures the root logger
# of every application that imports this module; consider leaving logging
# configuration to the caller.
logging.basicConfig(level=logging.INFO)
# Hide httpx's per-request log lines; only its errors are shown.
logging.getLogger("httpx").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)
def validate_extra_params(value: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Validate that extra request parameters only contain simple values.

    Each value may be a primitive (``str``, ``int``, ``float``, ``bool`` or
    ``None``) or a single level of nesting: a ``dict`` whose values are
    primitives, or a ``list``/``tuple``/``set`` whose items are primitives.
    Anything deeper (e.g. a list inside a dict) is rejected.

    Args:
        value: Mapping of parameter names to values, or ``None``.

    Returns:
        ``value`` itself, unchanged.

    Raises:
        TypeError: if a value, or a value nested one level down, is not a
            primitive.
    """
    primitives = (str, int, float, bool, type(None))

    def check(item: Any) -> None:
        if not isinstance(item, primitives):
            raise TypeError(f"Invalid value type: {type(item).__name__}")

    if value is None:
        return value
    # Only the values are validated; the keys are not inspected.
    for v in value.values():
        if isinstance(v, dict):
            # nested structure allowed but they can support primitive values only
            for nested in v.values():
                check(nested)
        elif isinstance(v, (list, tuple, set)):
            for nested in v:
                check(nested)
        else:
            check(v)
    return value
# Integer constrained to be strictly positive (> 0).
_PositiveInt = Annotated[int, annotated_types.Gt(0)]
# Integer constrained to be non-negative (>= 0).
_NonNegativeInt = Annotated[int, annotated_types.Ge(0)]
# Optional dict of extra parameters; after pydantic's own type validation,
# validate_extra_params rejects values nested more than one level deep.
_ExtraParamDataType = Annotated[
    Optional[Dict[str, Any]], AfterValidator(validate_extra_params)
]
# Reusable pydantic validator for _ExtraParamDataType.
extra_param_checker = TypeAdapter(_ExtraParamDataType)
# Accepted loss names (also the type of FinetunedModel.loss).
_Loss = Literal["default", "mae", "mse", "rmse", "mape", "smape", "poisson"]
# Model identifiers are plain strings.
_Model = str
# Allowed finetune depth values.
_FinetuneDepth = Literal[1, 2, 3, 4, 5]
# Frequency as supplied by users: a pandas/polars alias string, an integer,
# or a pandas offset object.
_Freq = Union[str, int, pd.offsets.BaseOffset]
# TypeVar form of _Freq, so helpers can return the same kind of frequency
# they were given.
_FreqType = TypeVar("_FreqType", str, int, pd.offsets.BaseOffset)
# presumably selects univariate vs multivariate anomaly thresholds — confirm
# against the callers (not visible here).
_ThresholdMethod = Literal["univariate", "multivariate"]
class FinetunedModel(BaseModel, extra="allow"):  # type: ignore
    """Metadata describing a finetuned model.

    Declared with ``extra="allow"``, so fields that are not listed here are
    accepted and kept on the instance instead of being rejected.
    """

    id: str  # identifier of the finetuned model
    created_at: datetime.datetime  # creation timestamp
    created_by: str  # creator, as reported in the metadata
    base_model_id: str  # identifier of the model it was finetuned from
    steps: int  # number of finetune steps
    depth: int  # finetune depth
    loss: _Loss  # one of the _Loss literals
    model: _Model  # model identifier string
    freq: str  # frequency alias the model was finetuned with
# Default calendar features to add for each pandas frequency alias when the
# user asks for date features without naming them.
# NOTE(review): keys are matched exactly, so newer pandas (>= 2.2) aliases
# such as "h", "s", "ME", "QE", "YE" and anchored/multiplied aliases such as
# "W-SUN" or "2D" get no defaults (the caller then only logs a warning).
# Adding them would change model inputs for those users — decide deliberately.
_date_features_by_freq = {
    # Daily frequencies
    "B": ["year", "month", "day", "weekday"],
    "C": ["year", "month", "day", "weekday"],
    "D": ["year", "month", "day", "weekday"],
    # Weekly
    "W": ["year", "week", "weekday"],
    # Monthly
    "M": ["year", "month"],
    "SM": ["year", "month", "day"],
    "BM": ["year", "month"],
    "CBM": ["year", "month"],
    "MS": ["year", "month"],
    "SMS": ["year", "month", "day"],
    "BMS": ["year", "month"],
    "CBMS": ["year", "month"],
    # Quarterly
    "Q": ["year", "quarter"],
    "BQ": ["year", "quarter"],
    "QS": ["year", "quarter"],
    "BQS": ["year", "quarter"],
    # Yearly
    "A": ["year"],
    "Y": ["year"],
    "BA": ["year"],
    "BY": ["year"],
    "AS": ["year"],
    "YS": ["year"],
    "BAS": ["year"],
    "BYS": ["year"],
    # Hourly
    "BH": ["year", "month", "day", "hour", "weekday"],
    "H": ["year", "month", "day", "hour"],
    # Minutely
    "T": ["year", "month", "day", "hour", "minute"],
    "min": ["year", "month", "day", "hour", "minute"],
    # Secondly
    "S": ["year", "month", "day", "hour", "minute", "second"],
    # Milliseconds
    "L": ["year", "month", "day", "hour", "minute", "second", "millisecond"],
    "ms": ["year", "month", "day", "hour", "minute", "second", "millisecond"],
    # Microseconds
    "U": ["year", "month", "day", "hour", "minute", "second", "microsecond"],
    "us": ["year", "month", "day", "hour", "minute", "second", "microsecond"],
    # Nanoseconds: no default features (an empty list triggers the warning path)
    "N": [],
}
def _retry_strategy(max_retries: int, retry_interval: int, max_wait_time: int):
    """Build the tenacity ``retry`` decorator used around API calls.

    Args:
        max_retries: Maximum number of attempts before giving up.
        retry_interval: Fixed number of seconds to wait between attempts.
        max_wait_time: Upper bound in seconds on the total retrying time.
            Retrying stops as soon as either limit is reached.

    Returns:
        A decorator that retries on the transient network errors listed below
        and on ``ApiError`` instances whose ``status_code`` is one of the
        transient HTTP statuses below. Each retried failure is logged, and
        the last exception is re-raised unchanged once retrying stops.
    """
    # Connection-level failures that are worth another attempt.
    transient_errors = (
        ConnectionResetError,
        httpcore.ConnectError,
        httpcore.RemoteProtocolError,
        httpx.ConnectTimeout,
        httpx.ReadError,
        httpx.RemoteProtocolError,
        httpx.ReadTimeout,
        httpx.PoolTimeout,
        httpx.WriteError,
        httpx.WriteTimeout,
    )
    # HTTP statuses treated as transient server-side conditions.
    transient_statuses = (
        HTTPStatus.REQUEST_TIMEOUT,
        HTTPStatus.CONFLICT,
        HTTPStatus.TOO_MANY_REQUESTS,
        HTTPStatus.BAD_GATEWAY,
        HTTPStatus.SERVICE_UNAVAILABLE,
        HTTPStatus.GATEWAY_TIMEOUT,
    )

    def is_transient(exc: Exception) -> bool:
        if isinstance(exc, transient_errors):
            return True
        return isinstance(exc, ApiError) and exc.status_code in transient_statuses

    def log_failed_attempt(retry_state: RetryCallState) -> None:
        # Only exceptions trigger retries, so the outcome always holds one.
        error = retry_state.outcome.exception()
        logger.error(f"Attempt {retry_state.attempt_number} failed with error: {error}")

    give_up = stop_after_attempt(max_retries) | stop_after_delay(max_wait_time)
    return retry(
        retry=retry_if_exception(is_transient),
        wait=wait_fixed(retry_interval),
        after=log_failed_attempt,
        stop=give_up,
        reraise=True,
    )
def _maybe_infer_freq(
    df: DataFrame,
    freq: Optional[_FreqType],
    id_col: str,
    time_col: str,
) -> _FreqType:
    """Return ``freq`` as given, or infer it from ``df`` when it is ``None``.

    Inference is only supported for pandas frames. The timestamps of the
    series with the most rows are sorted, converted to naive UTC if they are
    timezone-aware, and passed to ``pd.infer_freq``.

    Raises:
        ValueError: if ``freq`` is ``None`` and ``df`` is a polars DataFrame.
        RuntimeError: if pandas cannot infer a frequency.
    """
    if freq is not None:
        return freq
    if isinstance(df, pl_DataFrame):
        raise ValueError(
            "Cannot infer frequency for a polars DataFrame, please set the "
            "`freq` argument to a valid polars offset.\nYou can find them at "
            "https://pola-rs.github.io/polars/py-polars/html/reference/expressions/api/polars.Expr.dt.offset_by.html"
        )
    assert isinstance(df, pd.DataFrame)
    # The longest series gives pandas the most evidence to infer from.
    longest_id = df[id_col].value_counts(sort=True).index[0]
    in_longest = df[id_col] == longest_id
    stamps = df.loc[in_longest, time_col].sort_values()
    if stamps.dt.tz is not None:
        stamps = stamps.dt.tz_convert("UTC").dt.tz_localize(None)
    inferred = pd.infer_freq(stamps.values)
    if inferred is None:
        raise RuntimeError(
            "Could not infer the frequency of the time column. This could be due "
            "to inconsistent intervals. Please check your data for missing, "
            "duplicated or irregular timestamps"
        )
    logger.info(f"Inferred freq: {inferred}")
    return inferred
def _standardize_freq(freq: _Freq, processed: ufp.ProcessedDF) -> str:
    """Convert a user-supplied frequency into the string sent to the API.

    * ``str``: polars' month alias ``"mo"`` is rewritten to ``"MS"``; every
      other string is returned unchanged (the rest are pandas-compatible).
    * pandas offset: its ``freqstr``.
    * ``int``: mapped to ``"MS"``.

    ``processed`` is accepted but not used.

    Raises:
        ValueError: for any other type of ``freq``.
    """
    if isinstance(freq, str):
        return freq.replace("mo", "MS")
    if isinstance(freq, pd.offsets.BaseOffset):
        return freq.freqstr
    if isinstance(freq, int):
        # NOTE(review): integer frequencies are sent as "MS"; presumably a
        # placeholder the API expects for integer time indices — confirm.
        return "MS"
    raise ValueError(
        f"`freq` must be a string, int or pandas offset, got {type(freq).__name__}"
    )
def _array_tails(
    x: np.ndarray,
    indptr: np.ndarray,
    out_sizes: np.ndarray,
) -> np.ndarray:
    """Take the last ``out_sizes[i]`` elements of every series packed in ``x``.

    Series ``i`` occupies ``x[indptr[i]:indptr[i + 1]]``, so ``indptr`` has
    one more entry than there are series.

    Args:
        x: Values of all series concatenated along the first axis.
        indptr: Series boundaries within ``x``.
        out_sizes: Number of trailing elements to keep per series. Any integer
            sequence (e.g. a list decoded from JSON) is accepted.

    Returns:
        The selected tails, concatenated in series order. The result is empty
        when there are no series or every requested size is zero.

    Raises:
        ValueError: if a requested size exceeds the length of its series.
    """
    out_sizes = np.asarray(out_sizes)
    if (out_sizes > np.diff(indptr)).any():
        raise ValueError("out_sizes must be at most the original sizes.")
    tails = [np.arange(end - size, end) for end, size in zip(indptr[1:], out_sizes)]
    if not tails:
        # np.hstack raises on an empty list; with no series nothing is selected.
        return x[np.empty(0, dtype=np.intp)]
    return x[np.hstack(tails)]
def _tail(proc: ufp.ProcessedDF, n: int) -> ufp.ProcessedDF:
    """Keep at most the last ``n`` observations of every series in ``proc``.

    Series shorter than ``n`` are kept whole. ``uids`` and ``last_times`` are
    carried over unchanged, and ``sort_idxs`` is reset to ``None``.
    """
    kept = np.minimum(np.diff(proc.indptr), n)
    kept_indptr = np.append(0, kept.cumsum())
    return ufp.ProcessedDF(
        uids=proc.uids,
        last_times=proc.last_times,
        data=_array_tails(proc.data, proc.indptr, kept),
        indptr=kept_indptr,
        sort_idxs=None,
    )
def _partition_series(
    payload: dict[str, Any], n_part: int, h: int
) -> list[dict[str, Any]]:
    """Split a request payload into at most ``n_part`` payloads, by series.

    ``payload["series"]`` holds the packed data for all series:

    * ``sizes``: the number of rows of each series.
    * ``y`` and each array in ``X``: one entry per observation, sliced
      by the cumulative sizes.
    * ``X_future``: used only when ``h > 0``; ``h`` entries per series.

    Consecutive series are grouped into chunks of ``ceil(n_series / n_part)``.
    Every other payload key is copied into each part.

    Args:
        payload: Request payload. It is not modified.
        n_part: Desired number of parts; clamped to ``[1, n_series]``.
        h: Forecast horizon (entries per series in ``X_future``). With
            ``h == 0`` no ``X_future`` key is produced.

    Returns:
        One payload per part. A payload without any series is returned as a
        single part.
    """
    # Work on a shallow view so the caller's dict keeps its "series" entry.
    rest = {k: v for k, v in payload.items() if k != "series"}
    series = payload["series"]
    n_series = len(series["sizes"])
    if n_series == 0:
        # Nothing to split; avoids dividing by zero below.
        return [{"series": series, **rest}]
    n_part = max(1, min(n_part, n_series))
    series_per_part = math.ceil(n_series / n_part)
    parts = []
    row_start = 0
    for i in range(0, n_series, series_per_part):
        sizes = series["sizes"][i : i + series_per_part]
        n_rows = sum(sizes)
        rows = slice(row_start, row_start + n_rows)
        row_start += n_rows
        part_series = {"y": series["y"][rows], "sizes": sizes}
        if series["X"] is None:
            part_series["X"] = None
            if h > 0:
                part_series["X_future"] = None
        else:
            part_series["X"] = [x[rows] for x in series["X"]]
            if h > 0:
                if series["X_future"] is None:
                    part_series["X_future"] = None
                else:
                    # future features hold exactly h entries per series
                    futr_rows = slice(i * h, (i + series_per_part) * h)
                    part_series["X_future"] = [
                        x[futr_rows] for x in series["X_future"]
                    ]
        if "categorical_exog" in series:
            # NOTE(review): copied whole into every part rather than sliced per
            # series; confirm this field is not per-series/per-row data.
            part_series["categorical_exog"] = series["categorical_exog"]
        parts.append({"series": part_series, **rest})
    return parts
def _maybe_add_date_features(
    df: DFType,
    X_df: Optional[DFType],
    features: Union[bool, Sequence[Union[str, Callable]]],
    one_hot: Union[bool, list[str]],
    freq: _Freq,
    h: int,
    id_col: str,
    time_col: str,
    target_col: str,
) -> tuple[DFType, Optional[DFType]]:
    """Add calendar features to ``df`` and to the future frame ``X_df``.

    Args:
        df: Historical data.
        X_df: Future exogenous values. When ``None``, ``time_features``
            builds a future frame covering ``h`` steps.
        features: ``True`` to use the defaults for ``freq`` from
            ``_date_features_by_freq``; an explicit sequence (list, tuple,
            ...) of feature names or callables to use exactly those; falsy
            to add nothing.
        one_hot: ``True`` to one-hot encode every non-callable feature, an
            explicit sequence of feature names to encode only those, or
            falsy for no encoding.
        freq: Series frequency. Features are only added for string freqs.
        h: Forecast horizon.
        id_col: Name of the series id column.
        time_col: Name of the timestamp column.
        target_col: Name of the target column.

    Returns:
        ``(df, X_df)`` with the features added; ``X_df`` is ``None`` when
        ``h == 0``.
    """
    if not features or not isinstance(freq, str):
        return df, X_df
    # Any non-string sequence is an explicit selection. Previously only
    # ``list`` was recognised, so a tuple was silently replaced by the defaults.
    if isinstance(features, Sequence) and not isinstance(features, str):
        date_features: Sequence[Union[str, Callable]] = list(features)
    else:
        date_features = _date_features_by_freq.get(freq, [])
        if not date_features:
            # NOTE(review): the lookup is exact, so anchored or multiplied
            # aliases such as "W-SUN" or "2D" fall through to this warning.
            logger.warning(
                f"Non default date features for {freq} "
                "please provide a list of date features"
            )
    # add features
    if X_df is None:
        df, X_df = time_features(
            df=df,
            freq=freq,
            features=date_features,
            h=h,
            id_col=id_col,
            time_col=time_col,
        )
    else:
        df = _add_time_features(df, features=date_features, time_col=time_col)
        X_df = _add_time_features(X_df, features=date_features, time_col=time_col)
    # one hot
    if isinstance(one_hot, Sequence) and not isinstance(one_hot, str):
        features_one_hot = list(one_hot)
    elif one_hot:
        features_one_hot = [f for f in date_features if not callable(f)]
    else:
        features_one_hot = []
    if features_one_hot:
        # Encode history and future together so both get the same dummy
        # columns, then split them back apart by row count.
        X_df = ufp.assign_columns(X_df, target_col, 0)
        full_df = ufp.vertical_concat([df, X_df])
        if isinstance(full_df, pd.DataFrame):
            full_df = pd.get_dummies(full_df, columns=features_one_hot, dtype="float32")
        else:
            full_df = full_df.to_dummies(columns=features_one_hot)
        df = ufp.take_rows(full_df, slice(0, df.shape[0]))
        X_df = ufp.take_rows(full_df, slice(df.shape[0], full_df.shape[0]))
        X_df = ufp.drop_columns(X_df, target_col)
        X_df = ufp.drop_index_if_pandas(X_df)
    if h == 0:
        # time_features returns an empty df, we use it as None here
        X_df = None
    return df, X_df
def _validate_exog(
    df: DFType,
    X_df: Optional[DFType],
    id_col: str,
    time_col: str,
    target_col: str,
    hist_exog: Optional[list[str]],
) -> tuple[DFType, Optional[DFType]]:
    """Select and order the exogenous columns of ``df`` and ``X_df``.

    Every column of ``df`` other than the id, time and target columns is an
    exogenous candidate. Columns of ``X_df`` are future features; names in
    ``hist_exog`` are historic features.

    * Without ``X_df``, only declared historic features are kept. Any other
      exogenous column is dropped with a warning.
    * With ``X_df``, exogenous columns of ``df`` that are neither in ``X_df``
      nor declared historic are dropped with a warning. A feature both in
      ``X_df`` and declared historic is treated as historic, with a warning.

    Returns:
        ``(df, X_df)``. ``df`` is ordered ``[id, time, target, *future,
        *historic]`` and ``X_df`` is ordered ``[id, time, *future]``.
        ``X_df`` is ``None`` when it was not provided.

    Raises:
        ValueError: if ``X_df`` has exogenous columns that ``df`` lacks.
            Declared historic names missing from ``df`` are not checked here.
    """
    base_cols = {id_col, time_col, target_col}
    exogs = [c for c in df.columns if c not in base_cols]
    hist = [] if hist_exog is None else hist_exog

    if X_df is None:
        # all exogs must be historic
        ignored_exogs = [c for c in exogs if c not in hist]
        if ignored_exogs:
            logger.warning(
                f"`df` contains the following exogenous features: {ignored_exogs}, "
                "but `X_df` was not provided and they were not declared in `hist_exog_list`. "
                "They will be ignored."
            )
        kept = [c for c in exogs if c in hist]
        return df[[id_col, time_col, target_col, *kept]], None

    futr_exog = [c for c in X_df.columns if c not in base_cols]
    # exogs in df that weren't declared as historic nor future
    declared_exogs = {*hist, *futr_exog}
    ignored_exogs = [c for c in exogs if c not in declared_exogs]
    if ignored_exogs:
        logger.warning(
            f"`df` contains the following exogenous features: {ignored_exogs}, "
            "but they were not found in `X_df` nor declared in `hist_exog_list`. "
            "They will be ignored."
        )
    # future exogenous provided in X_df must also have history in df
    missing_futr = set(futr_exog) - set(exogs)
    if missing_futr:
        raise ValueError(
            "The following exogenous features are present in `X_df` "
            f"but not in `df`: {missing_futr}."
        )
    # features provided through X_df but declared as historic
    futr_and_hist = set(futr_exog) & set(hist)
    if futr_and_hist:
        logger.warning(
            "The following features were declared as historic but found in `X_df`: "
            f"{futr_and_hist}, they will be considered as historic."
        )
    futr_only = [c for c in futr_exog if c not in hist]
    ordered_df = df[[id_col, time_col, target_col, *futr_only, *hist]]
    return ordered_df, X_df[[id_col, time_col, *futr_only]]
def _extract_categorical_exog(
    df: DFType,
    categorical_exog_list: Optional[list[str]],
    id_col: str,
    time_col: str,
    target_col: str,
    X_df: Optional[DFType] = None,
) -> tuple[
    DFType,
    Optional[DFType],
    dict[str, np.ndarray],
    list[str],
    list[str],
    list[list],
]:
    """Validate, extract and strip categorical exogenous columns from df/X_df.

    A categorical column found in ``X_df`` is a future categorical; it must
    also exist in ``df`` so that its history is available. Every other listed
    column is a historic categorical and must exist in ``df``.

    Returns:
        df: ``df`` with every categorical column removed.
        X_df: ``X_df`` with the future categorical columns removed
            (``None`` stays ``None``).
        df_cat_vals: mapping from each column in ``categorical_exog_list`` to
            its raw values in ``df``, in ``df``'s row order.
        futr_cat_cols: categorical columns found in ``X_df``.
        hist_cat_cols: categorical columns not found in ``X_df``.
        X_df_cat_future: for each future categorical column, its values
            from ``X_df`` after sorting by id and time. Empty when there are
            none.

    Raises:
        ValueError: if a listed column is missing from both frames, or is
            present in ``X_df`` but missing from ``df``.
    """
    if not categorical_exog_list:
        return df, X_df, {}, [], [], []

    cat_set = set(categorical_exog_list)
    non_exog_x = {id_col, time_col}
    x_df_exog_cols = (
        set() if X_df is None else {c for c in X_df.columns if c not in non_exog_x}
    )
    non_exog_df = {id_col, time_col, target_col}
    df_exog_cols = {c for c in df.columns if c not in non_exog_df}

    invalid_cats = set(categorical_exog_list) - df_exog_cols - x_df_exog_cols
    if invalid_cats:
        location = "`df` or `X_df`" if X_df is not None else "`df`"
        raise ValueError(
            "The following columns in `categorical_exog_list` were not "
            f"found in {location}: {invalid_cats}."
        )

    futr_cat_cols = [c for c in categorical_exog_list if c in x_df_exog_cols]
    hist_cat_cols = [c for c in categorical_exog_list if c not in x_df_exog_cols]

    # futr_cat_cols must also exist in df to provide historical context rows for X.
    futr_cat_missing_from_df = set(futr_cat_cols) - df_exog_cols
    if futr_cat_missing_from_df:
        raise ValueError(
            "The following columns in `categorical_exog_list` were found in `X_df` but are "
            f"missing from `df`: {futr_cat_missing_from_df}. Future categorical features must "
            "also be present in `df` to provide historical context."
        )

    # History of every categorical column comes from df.
    # NOTE(review): values are taken in df's row order, not sorted order —
    # confirm callers reorder them with the processed sort indices if needed.
    df_cat_vals: dict[str, np.ndarray] = {c: df[c].to_numpy() for c in categorical_exog_list}

    X_df_cat_future: list[list] = []
    if futr_cat_cols and X_df is not None:
        X_df_sorted = ensure_sorted(X_df, id_col=id_col, time_col=time_col)
        X_df_cat_future = [X_df_sorted[c].tolist() for c in futr_cat_cols]
        X_df = X_df[[c for c in X_df.columns if c not in futr_cat_cols]]

    df = df[[c for c in df.columns if c not in cat_set]]
    return df, X_df, df_cat_vals, futr_cat_cols, hist_cat_cols, X_df_cat_future
def _validate_input_size(
    processed: ufp.ProcessedDF,
    model_input_size: int,
    model_horizon: int,
) -> None:
    """Check that every series covers the model's input window plus its horizon.

    Raises:
        ValueError: if the shortest series in ``processed`` has fewer than
            ``model_input_size + model_horizon`` observations.
    """
    shortest = np.diff(processed.indptr).min().item()
    required = model_input_size + model_horizon
    if shortest >= required:
        return
    raise ValueError(
        "Some series are too short. "
        "Please make sure that each series contains "
        f"at least {required} observations."
    )
def _quantile_to_signed_level(q: float) -> int:
    """Map quantile ``q`` to its signed prediction-interval level.

    ``100 - 200 * q`` is positive below the median (the interval's lower
    bound, "lo") and negative above it ("hi"): 0.1 -> 80, 0.9 -> -80.
    The value is rounded to 6 decimals before truncating to ``int``, so that
    binary floating-point noise cannot drop a whole level. Without it,
    ``int(100 - 200 * 0.57)`` is -13 instead of -14. Genuinely fractional
    levels keep the historical truncation (0.0125 -> 97).
    """
    return int(round(100 - 200 * q, 6))


def _quantile_column_pct(q: float) -> int:
    """Percentage used in the ``TimeGPT-q-{pct}`` column name for quantile ``q``.

    Snapped like ``_quantile_to_signed_level``: ``int(0.57 * 100)`` is 56,
    whereas this returns 57. Fractional percentages still truncate
    (0.125 -> 12).
    """
    return int(round(q * 100, 6))


def _prepare_level_and_quantiles(
    level: Optional[list[Union[int, float]]],
    quantiles: Optional[list[float]],
) -> tuple[Optional[list[Union[int, float]]], Optional[list[float]]]:
    """Validate ``level`` / ``quantiles`` and derive the levels to request.

    Each quantile ``q`` is expressed as the symmetric interval level whose
    bounds include it, so ``q`` and ``1 - q`` share a level and 0.5 maps to 0.

    Returns:
        ``(level, quantiles)``. ``level`` is recomputed from ``quantiles``
        when ``quantiles`` is provided.

    Raises:
        ValueError: if both are given, or a quantile is not strictly between
            0 and 1.
    """
    if level is not None and quantiles is not None:
        raise ValueError("You should provide `level` or `quantiles`, but not both.")
    if quantiles is None:
        return level, quantiles
    # we recover level from quantiles
    if not all(0 < q < 1 for q in quantiles):
        raise ValueError("`quantiles` should be floats between 0 and 1.")
    level = [abs(_quantile_to_signed_level(q)) for q in quantiles]
    return level, quantiles


def _maybe_convert_level_to_quantiles(
    df: DFType,
    quantiles: Optional[list[float]],
) -> DFType:
    """Replace interval columns with quantile columns when quantiles were asked.

    Every ``-lo-``/``-hi-`` column is dropped. For each quantile in ascending
    order, a ``TimeGPT-q-{pct}`` column is appended, copied from the matching
    interval bound; the median (0.5) copies the ``TimeGPT`` column. Levels
    are derived with the same mapping ``_prepare_level_and_quantiles`` used
    for the request, so the looked-up column names always match.
    """
    if quantiles is None:
        return df
    out_cols = [c for c in df.columns if "-lo-" not in c and "-hi-" not in c]
    df = ufp.copy_if_pandas(df, deep=False)
    for q in sorted(quantiles):
        if q == 0.5:
            col = "TimeGPT"
        else:
            lv = _quantile_to_signed_level(q)
            hi_or_lo = "lo" if lv > 0 else "hi"
            col = f"TimeGPT-{hi_or_lo}-{abs(lv)}"
        q_col = f"TimeGPT-q-{_quantile_column_pct(q)}"
        df = ufp.assign_columns(df, q_col, df[col])
        out_cols.append(q_col)
    return df[out_cols]
def _preprocess(
    df: DFType,
    X_df: Optional[DFType],
    h: int,
    freq: str,
    date_features: Union[bool, Sequence[Union[str, Callable]]],
    date_features_to_one_hot: Union[bool, list[str]],
    id_col: str,
    time_col: str,
    target_col: str,
) -> tuple[ufp.ProcessedDF, Optional[DFType], list[str], Optional[list[str]]]:
    """Add date features and pack the data into arrays for the API.

    Returns:
        processed: ``ufp.process_df`` output for ``df``.
        X_future: transposed data of the processed future frame (one row per
            future feature), or ``None`` when ``X_df`` has no columns beyond
            the id and time columns.
        x_cols: exogenous column names of ``df`` (everything but id, time and
            target).
        futr_cols: future feature names of ``X_df``, or ``None``.
    """
    df, X_df = _maybe_add_date_features(
        df=df,
        X_df=X_df,
        features=date_features,
        one_hot=date_features_to_one_hot,
        freq=freq,
        h=h,
        id_col=id_col,
        time_col=time_col,
        target_col=target_col,
    )
    processed = ufp.process_df(
        df=df, id_col=id_col, time_col=time_col, target_col=target_col
    )
    X_future = None
    futr_cols = None
    # More than two columns means X_df carries features besides id and time.
    if X_df is not None and X_df.shape[1] > 2:
        X_df = ensure_time_dtype(X_df, time_col=time_col)
        processed_future = ufp.process_df(
            df=X_df,
            id_col=id_col,
            time_col=time_col,
            target_col=None,
        )
        X_future = processed_future.data.T
        futr_cols = [c for c in X_df.columns if c not in (id_col, time_col)]
    x_cols = [c for c in df.columns if c not in (id_col, time_col, target_col)]
    return processed, X_future, x_cols, futr_cols
def _forecast_payload_to_in_sample(payload: dict, h: int, n_windows: int) -> dict:
    """Turn a forecast request payload into an in-sample one, in place.

    * Finetuning is disabled (``finetune_steps = 0``).
    * ``hist_exog`` lists the rows of ``series["X"]`` that have no counterpart
      in ``series["X_future"]``; future features occupy the first rows of
      ``X``. It is ``None`` when ``X`` is ``None``.
    * ``series["X_future"]`` is deleted (the key must be present).
    * ``h``, ``step_size`` (equal to ``h``) and ``n_windows`` are set.

    Returns:
        The same ``payload`` object.
    """
    # No finetuning for in-sample
    payload["finetune_steps"] = 0
    series = payload["series"]
    X = series["X"]
    if X is None:
        hist_exog = None
    else:
        X_future = series["X_future"]
        n_futr_exog = 0 if X_future is None else len(X_future)
        hist_exog = list(range(len(X)))[n_futr_exog:]
    payload["hist_exog"] = hist_exog
    del series["X_future"]
    # in-sample horizon and number of windows
    payload["h"] = h
    payload["step_size"] = h
    payload["n_windows"] = n_windows
    return payload
def _get_in_sample_horizon_and_windows(
    sizes: np.ndarray,
    model_horizon: int,
    model_input_size: int,
    clean_ex_first: bool,
    level: Optional[list[Union[int, float]]],
) -> tuple[int, int]:
    """Choose the in-sample horizon and number of windows from the shortest series.

    The horizon is ``model_horizon``, capped at one less than the shortest
    series length. The number of windows is how many ``model_horizon`` steps
    fit after reserving ``model_input_size`` observations. When
    ``clean_ex_first`` is false, ``model_horizon + 2 * h`` more are also
    reserved. At least one window is always used.

    Returns:
        ``(h, n_windows)``.
    """
    shortest = min(sizes)
    h = min(model_horizon, shortest - 1)
    usable = shortest - model_input_size
    if not clean_ex_first:
        usable -= model_horizon + 2 * h
    n_windows = max(usable // model_horizon, 1)
    # In case of multiple windows, we reduce one to avoid errors when running with level argument
    if level is not None and n_windows > 1:
        n_windows -= 1
    return h, n_windows
def _maybe_add_intervals(
    df: DFType,
    intervals: Optional[dict[str, list[float]]],
) -> DFType:
    """Append one ``TimeGPT-{key}`` column per interval bound, sorted by key.

    ``df`` is returned unchanged when ``intervals`` is ``None``, is empty, or
    its first entry is ``None``.
    """
    if intervals is None:
        return df
    leading_key = next(iter(intervals), None)
    nothing_to_add = leading_key is None or intervals[leading_key] is None
    if nothing_to_add:
        return df
    interval_columns = {
        f"TimeGPT-{name}": intervals[name] for name in sorted(intervals.keys())
    }
    return ufp.horizontal_concat([df, type(df)(interval_columns)])
def _maybe_drop_id(df: DFType, id_col: str, drop: bool) -> DFType:
    """Return ``df`` without ``id_col`` when ``drop`` is truthy, else ``df`` itself."""
    return ufp.drop_columns(df, id_col) if drop else df
def _parse_in_sample_output(
    in_sample_output: dict[str, Union[list[float], dict[str, list[float]]]],
    df: DataFrame,
    processed: ufp.ProcessedDF,
    id_col: str,
    time_col: str,
    target_col: str,
) -> DataFrame:
    """Build the in-sample results frame from the API output.

    ``in_sample_output`` provides ``sizes`` (number of fitted values per
    series), ``mean`` (the fitted values) and ``intervals``. Times and
    targets from ``df`` are reordered with ``processed.sort_idxs`` when it is
    set. The last ``sizes[i]`` of each series are then aligned with the
    fitted values.

    Returns:
        A frame of the same type as ``df``, with the id, time, target and
        ``TimeGPT`` columns. Interval columns are added when present.
    """
    sizes = in_sample_output["sizes"]
    times = df[time_col].to_numpy()
    targets = df[target_col].to_numpy()
    order = processed.sort_idxs
    if order is not None:
        times = times[order]
        targets = targets[order]
    tail_times = _array_tails(times, processed.indptr, sizes)
    tail_targets = _array_tails(targets, processed.indptr, sizes)
    ids = ufp.repeat(processed.uids, sizes)
    columns = {
        id_col: ids,
        time_col: tail_times,
        target_col: tail_targets,
        "TimeGPT": in_sample_output["mean"],
    }
    return _maybe_add_intervals(type(df)(columns), in_sample_output["intervals"])  # type: ignore
def _restrict_input_samples(level, input_size, model_horizon, h) -> int:
    """Return the per-series input size to send to the API.

    Without ``level``, only the model's own input window is needed. With
    ``level``, the window grows to ``3 * input_size + max(model_horizon, h)``.
    That is enough history to compute the conformal prediction intervals,
    while still capping how much of a long series is sent. The original
    author (@AzulGarza) describes this as an opinionated trade-off to reduce
    latency.
    """
    if level is None:
        return input_size
    return 3 * input_size + max(model_horizon, h)
def _extract_target_array(df: DataFrame, target_col: str) -> np.ndarray:
    """Return ``df[target_col]`` as a numpy array.

    For pandas frames, the column dtype's ``type`` is passed to ``to_numpy``.
    On pandas<2.2, converting a nullable extension column (e.g.
    ``pd.Float64Dtype``) without an explicit dtype can yield an object array.
    """
    if not isinstance(df, pd.DataFrame):
        return df[target_col].to_numpy()
    numpy_type = df.dtypes[target_col].type
    return df[target_col].to_numpy(dtype=numpy_type)
def _process_exog_features(
    processed_data: np.ndarray,
    x_cols: list[str],
    hist_exog_list: Optional[list[str]] = None,
) -> tuple[Optional[np.ndarray], Optional[list[int]]]:
    """Build the exogenous matrix for the API and locate the historic features.

    Args:
        processed_data: 2-D array whose first column is the target, followed
            by one column per name in ``x_cols``.
        x_cols: Names of the exogenous columns in ``processed_data[:, 1:]``.
        hist_exog_list: Features that are only known historically. Every
            other feature in ``x_cols`` is treated as a future feature.

    Returns:
        ``(X, hist_exog)``.
        * ``X``: the transposed feature matrix (one row per feature),
          reordered as [future..., historic...]. ``None`` when there are no
          exogenous columns.
        * ``hist_exog``: row indices of the historic features in ``X``.
          ``None`` when ``hist_exog_list`` is ``None`` or there are no
          exogenous columns.

    Raises:
        ValueError: if a name in ``hist_exog_list`` is not in ``x_cols``. This
            is now checked even when ``processed_data`` has no exogenous
            columns, which previously passed silently.
    """
    if hist_exog_list is not None:
        missing_hist: set[str] = set(hist_exog_list) - set(x_cols)
        if missing_hist:
            raise ValueError(
                "The following exogenous features were declared as historic "
                f"but were not found in `df`: {missing_hist}."
            )
    if processed_data.shape[1] <= 1:
        # only the target column: no exogenous features
        return None, None
    X = processed_data[:, 1:].T
    hist_exog = None
    if hist_exog_list is None:
        futr_exog = x_cols
    else:
        futr_exog = [c for c in x_cols if c not in hist_exog_list]
        # match the forecast method order [future, historic]
        fcst_features_order = futr_exog + hist_exog_list
        x_idxs = [x_cols.index(c) for c in fcst_features_order]
        X = X[x_idxs]
        hist_exog = [fcst_features_order.index(c) for c in hist_exog_list]
    if futr_exog:
        logger.info(f"Using future exogenous features: {futr_exog}")
    if hist_exog_list:
        logger.info(f"Using historical exogenous features: {hist_exog_list}")
    return X, hist_exog
class AuditDataSeverity(Enum):
    """Severity attached to the result of each data-audit check.

    Members:
        FAIL: a critical issue that requires immediate attention.
        CASE_SPECIFIC: an issue that may be acceptable in specific contexts.
        PASS: the data is acceptable.
    """

    FAIL = "Fail"
    CASE_SPECIFIC = "Case Specific"
    PASS = "Pass"
def _audit_duplicate_rows(
    df: AnyDFType,
    id_col: str = "unique_id",
    time_col: str = "ds",
) -> tuple[AuditDataSeverity, AnyDFType]:
    """Flag rows that share the same (id, time) pair.

    Returns:
        ``(FAIL, duplicated_rows)``, with every copy of each duplicate kept,
        or ``(PASS, empty DataFrame)``.

    Raises:
        ValueError: for non-pandas frames.
    """
    if not isinstance(df, pd.DataFrame):
        raise ValueError(f"Dataframe type {type(df)} is not supported yet.")
    is_dup = df.duplicated(subset=[id_col, time_col], keep=False)
    if not is_dup.any():
        return AuditDataSeverity.PASS, pd.DataFrame()
    return AuditDataSeverity.FAIL, df[is_dup]
def _audit_missing_dates(
    df: AnyDFType,
    freq: _Freq,
    id_col: str = "unique_id",
    time_col: str = "ds",
    start: Union[str, int, datetime.date, datetime.datetime] = "per_serie",
    end: Union[str, int, datetime.date, datetime.datetime] = "global",
) -> tuple[AuditDataSeverity, AnyDFType]:
    """Report (id, time) pairs that ``fill_gaps`` expects but ``df`` lacks.

    The time column is first converted with ``ensure_time_dtype``. The
    expected grid comes from ``fill_gaps`` with the given ``freq``, ``start``
    and ``end``.

    Returns:
        ``(FAIL, missing id/time pairs)`` or ``(PASS, empty DataFrame)``.

    Raises:
        ValueError: for non-pandas frames.
    """
    if not isinstance(df, pd.DataFrame):
        raise ValueError(f"Dataframe type {type(df)} is not supported yet.")
    # Convert time_col to datetime if it's string/object type
    df = ensure_time_dtype(df, time_col=time_col)
    expected = fill_gaps(
        df, freq=freq, id_col=id_col, time_col=time_col, start=start, end=end
    )
    keys = [id_col, time_col]
    merged = pd.merge(expected, df, on=keys, how="outer", indicator=True)
    df_missing = merged.query("_merge == 'left_only'")[keys]
    if len(df_missing) == 0:
        return AuditDataSeverity.PASS, pd.DataFrame()
    return AuditDataSeverity.FAIL, df_missing
def _audit_categorical_variables(
    df: AnyDFType,
    id_col: str = "unique_id",
    time_col: str = "ds",
) -> tuple[AuditDataSeverity, AnyDFType]:
    """Flag ``category``/``object`` columns other than the id and time columns.

    Returns:
        ``(FAIL, those columns)`` or ``(PASS, empty DataFrame)``.

    Raises:
        ValueError: for non-pandas frames.
    """
    if not isinstance(df, pd.DataFrame):
        raise ValueError(f"Dataframe type {type(df)} is not supported yet.")
    candidate_cols = df.select_dtypes(include=["category", "object"]).columns
    categorical_cols = candidate_cols.drop([id_col, time_col], errors="ignore").tolist()
    if not categorical_cols:
        return AuditDataSeverity.PASS, pd.DataFrame()
    return AuditDataSeverity.FAIL, df[categorical_cols]
def _audit_leading_zeros(
    df: pd.DataFrame,
    id_col: str = "unique_id",
    time_col: str = "ds",
    target_col: str = "y",
) -> tuple[AuditDataSeverity, pd.DataFrame]:
    """Flag series whose first observation(s) are zero.

    After sorting by id and time, each series compares the index label of
    its first row with the label of its first non-zero target. Series where
    they differ are reported.

    NOTE(review): a series that is entirely zero reports its first label for
    both and is therefore not flagged. The comparison is also label-based,
    so duplicate index labels could hide a case — confirm this is intended.

    Returns:
        ``(CASE_SPECIFIC, per-series first_index / first_nonzero_index)`` or
        ``(PASS, empty DataFrame)``.

    Raises:
        ValueError: for non-pandas frames (after sorting).
    """
    df = ensure_sorted(df, id_col=id_col, time_col=time_col)
    if not isinstance(df, pd.DataFrame):
        raise ValueError(f"Dataframe type {type(df)} is not supported yet.")

    def first_label(s: pd.Series):
        return s.index[0]

    def first_nonzero_label(s: pd.Series):
        nonzero = s.ne(0)
        return nonzero.idxmax() if nonzero.any() else s.index[0]

    per_series = df.groupby(id_col).agg(
        first_index=(target_col, first_label),
        first_nonzero_index=(target_col, first_nonzero_label),
    )
    starts_with_zero = per_series["first_index"] != per_series["first_nonzero_index"]
    leading_zeros_df = per_series[starts_with_zero].reset_index()
    if len(leading_zeros_df) == 0:
        return AuditDataSeverity.PASS, pd.DataFrame()
    return AuditDataSeverity.CASE_SPECIFIC, leading_zeros_df
def _audit_negative_values(
    df: AnyDFType,
    target_col: str = "y",
) -> tuple[AuditDataSeverity, AnyDFType]:
    """Flag rows whose target is negative.

    Returns:
        ``(CASE_SPECIFIC, those rows)`` or ``(PASS, empty DataFrame)``.

    Raises:
        ValueError: for non-pandas frames.
    """
    if not isinstance(df, pd.DataFrame):
        raise ValueError(f"Dataframe type {type(df)} is not supported yet.")
    negatives = df.loc[df[target_col] < 0]
    if len(negatives) == 0:
        return AuditDataSeverity.PASS, pd.DataFrame()
    return AuditDataSeverity.CASE_SPECIFIC, negatives
class ApiError(Exception):
    """Error carrying the status code and body of a failed API response.

    Both fields are keyword-only constructor arguments and default to
    ``None``. Note: ``Exception.__init__`` receives no arguments, so
    ``args`` stays empty. Pickling rebuilds exceptions as ``cls(*args)``
    and then restores ``__dict__``, which works with this keyword-only
    signature.
    """

    status_code: Optional[int]  # HTTP status code, when known
    body: Any  # error body attached to the failure

    def __init__(
        self, *, status_code: Optional[int] = None, body: Optional[Any] = None
    ):
        self.status_code, self.body = status_code, body

    def __str__(self) -> str:
        return f"status_code: {self.status_code}, body: {self.body}"
class NixtlaClient:
def __init__(
self,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
timeout: Optional[int] = 60,
max_retries: int = 6,
retry_interval: int = 10,
max_wait_time: int = 6 * 60,
):
"""
Client to interact with the Nixtla API.
Args:
api_key (str, optional): The authorization API key to interact
with the Nixtla API. If not provided, will use the
NIXTLA_API_KEY environment variable.
base_url (str, optional): Custom base URL.
If not provided, will use the NIXTLA_BASE_URL environment
variable.
timeout (int, optional): Request timeout in seconds.
Set to `None` to disable it. Defaults to 60.
max_retries (int, optional): The maximum number of attempts to
make when calling the API before giving up. It defines how
many times the client will retry the API call if it fails.
Default value is 6, indicating the client will attempt the
API call up to 6 times in total. Defaults to 60.
retry_interval (int, optional): The interval in seconds between
consecutive retry attempts. This is the waiting period before
the client tries to call the API again after a failed attempt.
Default value is 10 seconds, meaning the client waits for
10 seconds between retries. Defaults to 10.
max_wait_time (int, optional): The maximum total time in seconds
that the client will spend on all retry attempts before
giving up. This sets an upper limit on the cumulative
waiting time for all retry attempts. If this time is
exceeded, the client will stop retrying and raise an
exception. Default value is 360 seconds, meaning the
client will cease retrying if the total time spent on retries
exceeds 360 seconds. The client throws a ReadTimeout error
after 60 seconds of inactivity. If you want to catch these
errors, use max_wait_time >> 60. Defaults to 360.
"""
if api_key is None:
api_key = os.environ["NIXTLA_API_KEY"]
if base_url is None:
base_url = os.getenv("NIXTLA_BASE_URL") or "https://api.nixtla.io"
self._client_kwargs = {
"base_url": base_url,
"headers": {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
"timeout": timeout,
}
self._retry_strategy = _retry_strategy(
max_retries=max_retries,
retry_interval=retry_interval,
max_wait_time=max_wait_time,
)
self._model_params: dict[tuple[str, str], tuple[int, int]] = {}
self._is_azure = "ai.azure" in base_url
def _make_request(
self,
client: httpx.Client,
endpoint: str,
payload: dict[str, Any],
multithreaded_compress: bool,
) -> dict[str, Any]:
def ensure_contiguous_if_array(x):
if not isinstance(x, np.ndarray):
return x
if np.issubdtype(x.dtype, np.floating):
x = np.nan_to_num(
np.ascontiguousarray(x, dtype=np.float32),
nan=np.nan,
posinf=np.finfo(np.float32).max,
neginf=np.finfo(np.float32).min,
copy=False,
)
else:
x = np.ascontiguousarray(x)
return x
def ensure_contiguous_arrays(d: dict[str, Any]) -> None:
for k, v in d.items():
if isinstance(v, np.ndarray):
d[k] = ensure_contiguous_if_array(v)
elif isinstance(v, list):
d[k] = [ensure_contiguous_if_array(x) for x in v]
elif isinstance(v, dict):
ensure_contiguous_arrays(v)
ensure_contiguous_arrays(payload)
content = orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)
content_size_mb = len(content) / 2**20
if content_size_mb > 200:
raise ValueError(
f"The payload is too large. Set num_partitions={math.ceil(content_size_mb / 200)}"
)