Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions config/moirai2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
{
"repo": "Salesforce/moirai-2.0-R-small",
"config": {
"context_len": 128,
"horizon_len": 64,
"num_layers": 100,
"model_type": "moirai2",
"model_size": "small"
}
}
12 changes: 12 additions & 0 deletions config/timesfm_2p5.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"repo": "google/timesfm-2.5-200m-pytorch",
"config": {
"max_context": 1024,
"max_horizon": 256,
"normalize_inputs": true,
"use_continuous_quantile_head": true,
"force_flip_invariance": true,
"infer_is_positive": true,
"fix_quantile_crossing": true
}
}
25 changes: 6 additions & 19 deletions example/chronosbolt.ipynb

Large diffs are not rendered by default.

80 changes: 49 additions & 31 deletions example/moirai.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions example/moment_anomaly_detection.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
"from samay.model import MomentModel\n",
"from samay.dataset import MomentDataset\n",
"from samay.utils import load_args\n",
"from samay.models.momentfm.utils.anomaly_detection_metrics import adjbestf1\n",
"\n",
"repo = \"AutonLab/MOMENT-1-large\"\n",
"config = {\n",
Expand Down
182 changes: 182 additions & 0 deletions example/timesfm_2p5.ipynb

Large diffs are not rendered by default.

237 changes: 237 additions & 0 deletions example/trial.ipynb

Large diffs are not rendered by default.

56 changes: 44 additions & 12 deletions leaderboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import pandas as pd
import torch

src_path = os.path.abspath(os.path.join("src"))
if src_path not in sys.path:
sys.path.insert(0, src_path)

from samay.model import TimesfmModel, MomentModel, ChronosModel, ChronosBoltModel, TinyTimeMixerModel, MoiraiTSModel, LPTMModel, TimeMoEModel
from samay.dataset import TimesfmDataset, MomentDataset, ChronosDataset, ChronosBoltDataset, TinyTimeMixerDataset, MoiraiDataset, LPTMDataset, TimeMoEDataset
Expand Down Expand Up @@ -39,7 +42,7 @@
# }


SERIES = "gifteval" # "monash" or "gifteval"
SERIES = "monash" # "monash" or "gifteval"

print("Loading datasets...")
# start = time.time()
Expand Down Expand Up @@ -120,22 +123,22 @@
}

MONASH_SETTINGS = {
# "weather": 30,
"weather": 30,
"tourism_yearly": 4,
"tourism_quarterly": 8,
"tourism_monthly": 24,
"cif_2016": 12,
# "london_smart_meters": 60,
"london_smart_meters": 60,
"australian_electricity_demand": 60,
# "wind_farms_minutely": 60,
"wind_farms_minutely": 60,
"bitcoin": 30,
"pedestrian_counts": 48,
"vehicle_trips": 30,
"kdd_cup_2018": 48,
"nn5_daily": 56,
"nn5_weekly": 8,
# "kaggle_web_traffic": 59,
# "kaggle_web_traffic_weekly": 8,
"kaggle_web_traffic": 59,
"kaggle_web_traffic_weekly": 8,
"solar_10_minutes": 60,
"solar_weekly": 5,
"car_parts": 12,
Expand All @@ -162,7 +165,7 @@
NAMES = get_gifteval_datasets("data/gifteval")
elif SERIES == "monash":
# Load the datasets from the Monash dataset
NAMES = get_monash_datasets("data/monash", MONASH_NAMES, MONASH_SETTINGS)
NAMES = get_monash_datasets("data/monash")

end = time.time()
print(NAMES)
Expand Down Expand Up @@ -202,7 +205,7 @@ def calc_pred_and_context_len(freq):

if __name__ == "__main__":

for model_name in ["timemoe"]:
for model_name in ["moirai2"]:
print(f"Evaluating model: {model_name}")
# create csv file for leaderboard if not already created
csv_path = f"leaderboard/{model_name}.csv"
Expand Down Expand Up @@ -246,6 +249,9 @@ def calc_pred_and_context_len(freq):
elif model_name == "moirai":
arg_path = "config/moirai.json"
args = load_args(arg_path)
elif model_name == "moirai2":
arg_path = "config/moirai2.json"
args = load_args(arg_path)
elif model_name == "lptm":
arg_path = "config/lptm.json"
args = load_args(arg_path)
Expand All @@ -255,7 +261,8 @@ def calc_pred_and_context_len(freq):

mod_start = time.time()
mod_timestamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
for fname, freq, fs in filesizes:
for fpath, (freq, fs) in NAMES.items():
fname = fpath.split("/")[2]
print(f"Model eval started at: {mod_timestamp}")
print(
f"Evaluating {fname} ({freq}) started at: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
Expand All @@ -275,7 +282,7 @@ def calc_pred_and_context_len(freq):
elif model_name == "ttm":
args["config"]["horizon_len"] = pred_len
args["config"]["context_len"] = context_len
elif model_name == "moirai":
elif model_name == "moirai" or model_name == "moirai2":
args["config"]["horizon_len"] = pred_len
args["config"]["context_len"] = context_len

Expand Down Expand Up @@ -451,6 +458,31 @@ def calc_pred_and_context_len(freq):
del dataset
torch.cuda.empty_cache()
gc.collect()

elif model_name == "moirai2":
model = MoiraiTSModel(**args)
dataset = MoiraiDataset(
name=fname,
datetime_col="timestamp",
freq=freq,
path=dataset_path,
mode="test",
context_len=context_len,
horizon_len=pred_len,
)

start = time.time()
metrics = model.evaluate(dataset, leaderboard=True)
end = time.time()
print(f"Size of dataset: {fs:.2f} MB")
print(
f"Time taken for evaluation of {fname}: {end - start:.2f} seconds"
)

del model
del dataset
torch.cuda.empty_cache()
gc.collect()

elif model_name == "lptm":
args["config"]["task_name"] = "forecasting2"
Expand Down Expand Up @@ -531,7 +563,7 @@ def calc_pred_and_context_len(freq):
df.to_csv(csv_path, index=False)
mod_end = time.time()
print(f"Time taken for model {model_name}: {mod_end - mod_start:.2f} seconds")
mod_times[model_name] = round(mod_end - mod_start, 2)
mod_timestamp[model_name] = round(mod_end - mod_start, 2)

print("All models evaluated!")
print("Model evaluation times: ", mod_times)
print("Model evaluation times: ", mod_timestamp)
56 changes: 56 additions & 0 deletions leaderboard/moirai2.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
dataset,size_in_MB,eval_time,mse,mae,mase,mape,rmse,nrmse,smape,msis,nd,mwsq,crps
us_births (M),0.0,0.72s,1.5618207475695225,1.114127944727978,1.8662446050970924,0.7233964267224883,1.2497282694928216,0.3878005765718537,1.3142118012513526,0.3509142943015808,0.8388756017713396,0.106674720349532,1.093371706001256
ett1 (W),0.01,0.05s,1.1976832933498784,0.8352266027516095,1.6916772244537828,-1.6767026023162568,1.0943871770766864,0.3188511536006459,1.2240231064721958,0.2521905408895456,-20.62029143884937,0.0943893091452368,1.0167355318211104
ett2 (W),0.01,0.05s,0.7753762452889312,0.6734449925184919,2.971465602726998,1.0962724162106536,0.8805545101178752,0.2413152034722861,1.0035548922160251,0.2369503083491107,-1.0087637456025809,0.0275568384021428,1.193377547065538
saugeenday (M),0.02,0.04s,0.3791637376167401,0.4140662770653676,0.6168052311848707,1.2906715251939034,0.6157627283432638,0.1484457277965084,0.8094952833871165,0.0836234314847085,4.78772491719547,0.1536365231860225,0.3560235590310691
us_births (W),0.02,0.05s,0.2212465707900221,0.3134049087129613,1.278917777028986,0.2657185080629962,0.3824549913652852,0.1310138361261375,0.2991811455553517,0.1252945644915848,0.2212199241431789,0.235366953287976,0.1916985204723228
ett1 (D),0.05,0.05s,0.640002297339701,0.6394201910805829,1.7783068983191364,-0.1283846296712268,0.8000014358360247,0.1865967478014498,1.2284616038297638,0.1958092301787431,-3.404135994329177,0.3253131309167447,0.4772902802706129
ett2 (D),0.05,0.05s,0.9000595886642733,0.7673781522857226,4.675992175860224,-1.8098284507559927,0.9487147035143249,0.2532869927272646,1.323599299713584,0.2814078285479234,-1.0452313073568982,0.2605802238470588,0.9766917659867376
solar (W),0.06,0.5s,2.1906756767310247,1.3389192316233405,1.5478764610426743,-0.8083795570390953,1.4800931311005483,0.4226288831746657,1.7382591926323756,0.5255064725995481,-0.996140032244834,1.3615429995581023,2.320574553629524
saugeenday (W),0.07,0.05s,1.2939860472399958,0.5506614547380495,1.2758437043987887,0.6045803106937465,1.1335463478550842,0.1767842617567758,0.9588178608347544,0.0869638279535065,6.196270263005705,-0.0104231083696542,0.5660171870260812
jena_weather (D),0.08,0.1s,0.46481772862836,0.462728795546598,1.529422709052349,2.3749521954874173,0.6817754238958457,0.0913233740705993,0.8763643926883901,0.1021245251108769,-1.1685434714612388,0.1885007885956199,0.3965992726314405
us_births (D),0.13,0.07s,0.2170142096028851,0.30351963503741,0.344623471018035,-0.2216703152082396,0.4384489429685211,0.1136014547824791,0.3119616319446845,0.074590279576131,0.3457695048954985,0.0634586315364336,0.2189554676275023
hierarchical_sales (W),0.15,0.43s,0.8573508615965433,0.5703766021105302,0.8485540487759559,-0.1190983866276265,0.9259324282022654,0.0552369369864987,1.0105003519129914,0.1150261743659994,-4.018324875729953,0.1940347784611105,0.59078965665434
bizitobs_l2c (H),0.18,0.16s,0.3814205784640619,0.3566320194072934,3.934213870432756,0.0361783463699681,0.5537551001240625,0.128893625292251,0.5882020205848557,0.0800048706540301,-0.0614852850435086,0.1144914814967857,0.2761122069349694
M_DENSE (D),0.21,0.11s,1.5202762105186614,0.7496050409399547,1.1339096401153206,0.5732222420885908,1.2329948136625155,0.1151983237700738,0.7870548509569417,0.1575935549103235,-5.204200135537375,-0.84573408654159,3.164533638548172
covid_deaths (D),0.27,0.93s,38.504567719036274,1.5400394452458483,15.557689624111836,0.1317982441321062,6.205204889367979,0.0531776979228609,0.1913422444387034,0.033424944198528,0.4519661028787954,4.95614705732852,126.64622268402154
bizitobs_application (10s),0.33,0.16s,0.2001495654372555,0.182952270441944,6.2945761379808856,-0.0088061539722238,0.2796514739226038,0.1290526677813379,0.3409897768197215,0.0812678037752585,0.1365893116515497,0.0016719740577447,0.1805862230654025
solar (D),0.35,0.49s,2.0637404228293468,1.078350279808447,1.3019442677516824,-0.9541530841198606,1.436572456519109,0.3385662583308205,1.7257210376606285,0.2834827805154771,-1.290577631546478,1.0609962176724352,2.0995334900309626
hospital (ME),0.35,2.63s,0.9334167007015908,0.7310550844425,0.8544896324791669,26.0512851440814,0.966134928828055,0.0951547615998668,1.184317791889724,0.1278078685610263,3.1664436777448683,0.3551075388135753,0.8498453279393328
saugeenday (D),0.38,0.23s,1.2418786706968648,0.4961153800355255,3.747074390889561,-1451.4983217574352,0.827398556859149,0.2395087764512039,0.8716079347962947,0.158533773606695,0.0557121405729807,-0.0621391201358524,0.5795737029876511
car_parts_with_missing (ME),0.58,9.11s,1.8753214009850416,0.4689303009320326,0.7222610028765081,32.256627428926514,1.3694237477804458,0.0209263715748897,0.5082102445561045,0.0573740827845184,-17.449022252854995,0.0468302897724156,0.8892621459157944
electricity (W),0.66,1.3s,137420812.97130525,458.71490062498543,9.44302862718281,0.8440714891016733,11722.662367026753,0.0330747612568545,0.709304505556285,69.67677654161118,1.0003089780822898,0.2429718747136128,66.0610733532417
hierarchical_sales (D),0.9,1.25s,0.9313689066765386,0.5464736461007927,0.8090745056393108,0.1627537383444756,0.9632634676021296,0.0571299777292587,1.0200595328120816,0.1123750245653353,13.461161462702924,0.0965877316611081,0.5217708737210384
kdd_cup_2018_with_missing (D),1.08,0.93s,1.070217638934671,0.6497289258892555,1.14395724061027,0.6266911940736422,1.0345132376797654,0.0329394540252049,1.1528128745042638,0.1521307666890273,-18.540340250195268,-0.0418543618891689,0.67923946299307
LOOP_SEATTLE (D),1.13,1.14s,0.969241464413956,0.6839508083159247,0.7306472863450527,0.1085332894378542,0.9845006167666712,0.0744202773793354,0.8696419810686549,0.1211121493772571,-2.9040136397151155,0.2327376939811077,0.9448497104278318
SZ_TAXI (H),1.14,0.55s,0.2091646697959583,0.3160283092759068,0.8532816266827321,-0.2983466827545597,0.4573452413614449,0.0678801178051041,0.6155771081765646,0.1055975112782393,-0.5839166765785212,0.1605316962514359,0.2009855887963579
ett1 (H),1.22,0.9s,0.4084598253839612,0.4117874904996299,1.8100245536802064,-1.8204830427552905,0.6168273265663928,0.1038510057804331,0.727088951166426,0.0968809131626775,101.82769192113868,0.0476616118890475,0.6572092625998253
ett2 (H),1.26,0.91s,0.150393013451258,0.2493953824483289,2.2623174678401967,0.0337385578853845,0.3539967914256723,0.1023097691558252,0.6400689179639643,0.0914314296690173,0.5329406964712626,0.0847824799291074,0.1681873401263136
jena_weather (H),1.65,1.35s,0.2561291605928043,0.2861363923875734,3.501378811272285,-0.4489240467029559,0.4948277490852944,0.0825490755315428,0.654368747678897,0.0864938441845049,-2.323662552275527,0.1026014600785833,0.187617565837507
bizitobs_l2c (5T),1.68,1.7s,0.2774132201123239,0.2582235284881951,55.21828304433266,-0.0307860776102781,0.404618183167394,0.1562390447634499,0.4313020529666456,0.1024837961145215,0.0545487131543644,-0.0163956041363768,0.2295854589810427
restaurant (D),1.77,2.79s,26.78373349868724,0.990423541004266,1.1121647444231988,2.425263771833482,5.175300329322661,0.0259509212871583,1.0521503958354017,0.2018534594424923,1.1643380900871894,0.5405322870170337,27.619815157471365
m4_hourly (h),2.43,2.91s,0.114839168929891,0.1557805711950071,0.9713727054400704,-0.0380367601451967,0.3309915870225652,0.0393216207509111,0.3654871225423399,0.0483959948670403,0.5576057573391415,0.027841870720656,0.0920221583234635
bizitobs_service (10s),3.07,2.67s,0.459612539238353,0.3508957145858495,7.225407512838781,0.2372734353640967,0.6481100032671299,0.1260535326715766,0.5616821650357674,0.08925058094901,1.95655655107444,-0.0389959828941912,0.558527321464036
M_DENSE (H),3.7,3.86s,0.2290540528339278,0.25402215034518,0.7394580185261268,-0.1750812521399131,0.4517563410562519,0.069447260657107,0.4682000124345461,0.0744508441583504,-2.352079486426145,0.049078732815286,0.1941743681487937
ett1 (15T),4.4,3.87s,0.3970112528963926,0.3765267557229977,2.7639435877955028,-0.8236410273377459,0.5921374046768627,0.1146154630217933,0.6730127855218513,0.0871713208507931,-0.7166911025807322,0.0662346987405489,0.5340310698329703
ett2 (15T),4.57,3.83s,0.1066165572006297,0.2156378008502821,2.705775335198284,0.658947281658779,0.3051382152895508,0.0987241697681859,0.5921184575632874,0.0805381860472943,-0.611517597539877,0.0466238708852433,0.1353224909646941
SZ_TAXI (15T),4.58,3.24s,0.3201090805988852,0.3892825320744635,0.9393528090877864,12.24086730644436,0.5651520199828886,0.0676203857529825,0.7404812353354995,0.1129100472962687,-0.8305596418280313,0.1314106721693844,0.2900697496786759
electricity (D),4.63,3.84s,1351260.1222181264,33.110530917893165,6.676083436126185,-0.8421084036151313,782.8566459233369,0.0228944895849153,0.542807560955192,18.412028086909185,0.7411831170416928,-62519.94139253908,1442499654.990039
solar (H),5.97,8.64s,0.2197442160244397,0.2144319019566735,1.2277385840529491,-0.196133507626756,0.4424378749907756,0.1368061894031313,0.3243819017965509,0.0738721865696927,-2.045738370587047,0.0443261534333789,0.1327801707690051
bitbrains_rnd (H),6.1,3.43s,43.07305225258606,0.8474669387254162,2.149173289906192,-0.2243800214350195,6.563006342567868,0.0064521179240096,0.5755518603012232,0.0331688495390634,-2.3605612031422614,0.6758162051078175,11.823752201902396
m4_weekly (W-SUN),7.18,6.16s,0.0118760552176095,0.0165799924055716,4.875090039017533,-0.0034710402652487,0.0943505084106502,0.0113352670013461,0.011878821961081,0.0155511218152244,0.0365194408306354,0.0076039590678303,0.0363696959650915
jena_weather (10T),7.18,8.55s,0.1578926971390692,0.1953089946678596,5.177155154654224,0.4781668341732507,0.3734198895876103,0.083502917783025,0.4931976602911049,0.0656340565773135,1.8448059999565471,0.0562227250356211,0.1124969808848997
kdd_cup_2018_with_missing (H),14.28,20.95s,0.8854415544327428,0.5505850301492371,3.385052668597709,-0.1799570222748047,0.8718096321908475,0.0692394521961844,1.0491254278774729,0.1433438316191303,0.8203745590567397,0.0824589888569725,0.5765527347596519
bitbrains_fast_storage (H),15.64,8.58s,15.000319900900791,0.544571573855085,1.320756724442783,-0.2755196533738627,3.8730246450159327,0.0053183301601658,0.7213016666392367,0.1059652514229847,5.057050122726341,0.1585859067694192,4.80040230872521
LOOP_SEATTLE (H),27.05,20.26s,0.5599113734390007,0.423220551104305,1.2129690297625784,0.0382687396271934,0.7282390522529767,0.0568397030935923,0.7478995960144319,0.0944913143934629,1.93249755792969,0.1066675271162883,0.9817790555596028
solar (10T),33.4,58.18s,0.135528298167247,0.1668160437497549,3.47102455989649,-0.6638918974704568,0.3045184058128946,0.1068429887345069,0.3020665701862008,0.0926846920794283,2.639260906254941,0.0565264113088214,0.0886823663431343
m4_yearly (YE-DEC),51.4,79.81s,0.0053045228568788,0.0131120896915413,43.7633672632011,0.0719534957595048,0.0728321553771331,0.0033780184969905,0.0453064278443034,0.0277057577527267,0.0523417783393487,0.0073103659378347,0.0053249508106549
bitbrains_rnd (5T),63.69,63.68s,252.71855612221083,0.4406898078723834,0.8553251210325973,603.7336670121113,8.054605160775605,0.0142031689030316,0.4414359129929475,0.0919010291211265,0.7474936836955323,-2.489079740213258,848.6411586065256
electricity (H),110.58,101.01s,183.4537689942918,0.5391298273655827,1.028040055915683,1491.3586753809727,8.647953813025216,0.0080527491217961,0.3810214084339235,1.101122464067246,0.2611993936088893,-72.84790128736753,277958.2038584184
temperature_rain_with_missing (D),113.99,112.32s,6.364929800835,0.7636738721535776,1.066015280023254,2.535600090034728,2.522881249848078,0.003400277426589,0.7578630937377708,0.0937278317567291,1.8616786534485303,0.0408604270291028,1.64943477270479
bitbrains_fast_storage (5T),160.06,160.32s,7505.603532557848,1.1992279611120773,1.2078396605763542,58.467442244248026,25.847548942551946,0.0147718258510876,0.5233332340770569,0.3424046483653578,-1.5846809087277534,2.261512861421257,7459.360552487737
m4_quarterly (QE-DEC),163.93,84.07s,0.0005428223573765,0.0005831916714262,10.42038125109977,0.0002272885760235,0.0232985483963399,0.0037582143372271,0.0007539949299842,0.010581338498256,0.0027180472802341,0.0002383654295061,0.0005218964130563
m4_daily (D),316.28,306.76s,0.0008733735383475,0.0002450463141843,5.314440321933075,-8.548048707285547e-06,0.0195783559057103,0.0018120517668324,0.0001617314180249,0.0044619037030663,0.0011340822787879,-0.0001024467675414,0.0036900865235162
LOOP_SEATTLE (5T),324.08,300.46s,0.6040522188236592,0.4490252668197635,1.6027602310266642,0.2430395688726573,0.7435173022765758,0.0701184516854129,0.819615631252254,0.107330173029401,-3.3063240559673863,0.2174530420178736,0.8236256033832557
electricity (15T),442.39,501.22s,12.79965959138448,0.3389993187426714,1.4571262314973856,308.1255846202696,2.940649544584549,0.0067247496886965,0.4211286269575287,0.3735479254117926,0.3152344364247284,-1.8706622770018224,2709.352654836476
m4_monthly (ME),1025.34,16.72m,0.0008910824007933492,0.00013815554308538395,3.432069901342716,3.6564446104936877e-05,0.026039563480095278,0.0013936455221847846,8.633480089219661e-05,0.008221830063983306,0.001579898179219596,0.0002981379964391846,0.002655506282765169
Loading