In [1]:
import os
import sys

# Work from the repository root so relative paths such as "experiments/..." resolve,
# and make that root importable for the `src.*` packages.
# The root is resolved to an absolute path only once per kernel session, so
# re-running this cell does not keep climbing the directory tree (a bare
# os.chdir("..") moves up one more level on every execution).
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("..")
os.chdir(PROJECT_ROOT)
# Add the absolute root (a relative ".." appended *after* the chdir would point at
# the root's parent) and avoid duplicate entries on re-runs.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [10]:
import yaml

import matplotlib.pyplot as plt
import numpy as np
import torch

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from statsmodels.tsa.seasonal import STL
from tqdm import tqdm


from src.models.utils import get_model
from src.utils.data_loading import load_features, load_test_data, load_score
from src.utils.features import decomps_and_features
from src.utils.transformations import manipulate_trend_component, manipulate_seasonal_determination

In [19]:
# Model architectures whose trained checkpoints are inspected below
# (each name is passed through as the experiment sub-directory).
models = [
    "feedforward",
    "seq2seq",
    "nbeats_g",
    "tcn",
    "transformer",
]

# Datasets: the two `*_nips` sets followed by the six M4 frequency subsets,
# in the order results are reported.
datasets = ["electricity_nips", "traffic_nips"] + [
    f"m4_{freq}" for freq in ("hourly", "daily", "weekly", "monthly", "quarterly", "yearly")
]

In [22]:
# Count the trainable parameters of every trained (dataset, model) checkpoint.
# Result: num_params[dataset][model_name] -> int.

# The target device does not change between iterations, so resolve it once.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

num_params = {}

for dataset in datasets:
    datadir = f"data/{dataset}"  # NOTE(review): unused in this cell; kept in case later cells read it
    num_params[dataset] = {}
    for model_name in models:
        experiment_dir = f"experiments/{dataset}/{model_name}"

        with open(os.path.join(experiment_dir, "config.yaml"), "r") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)

        model = get_model(config["model_name"])(**config["model_args"], device=device, path=config["path"])
        # map_location remaps stored tensors onto `device`; without it a checkpoint
        # saved on a GPU cannot be loaded on a CPU-only machine. Loading the weights
        # (strict by default) presumably serves to confirm the configured architecture
        # matches the checkpoint before counting.
        model.load_state_dict(torch.load(os.path.join(config["path"], "model.pth"), map_location=device))

        # Only parameters with requires_grad=True are counted (trainable parameters).
        num_params[dataset][model_name] = sum(p.numel() for p in model.parameters() if p.requires_grad)

In [24]:
# Report the trainable-parameter count per dataset and model.
# The loop variables deliberately avoid the name `model`: notebook loop variables
# are globals, and reusing `model` here would overwrite the nn.Module loaded in
# the previous cell with a string.
for dataset in num_params:
    print(dataset)
    for model_name, count in num_params[dataset].items():
        print(f"\t{model_name}: {count}")

electricity_nips
	feedforward: 29424
	seq2seq: 411579
	nbeats_g: 29189760
	tcn: 8107
	transformer: 23015
traffic_nips
	feedforward: 29424
	seq2seq: 414544
	nbeats_g: 29189760
	tcn: 11072
	transformer: 25980
m4_hourly
	feedforward: 24648
	seq2seq: 411799
	nbeats_g: 27345120
	tcn: 7527
	transformer: 23235
m4_daily
	feedforward: 14414
	seq2seq: 430864
	nbeats_g: 24730860
	tcn: 24992
	transformer: 42300
m4_weekly
	feedforward: 14113
	seq2seq: 411524
	nbeats_g: 24653970
	tcn: 5652
	transformer: 22960
m4_monthly
	feedforward: 15618
	seq2seq: 649729
	nbeats_g: 25038420
	tcn: 244657
	transformer: 261165
m4_quarterly
	feedforward: 12608
	seq2seq: 529729
	nbeats_g: 24269520
	tcn: 123057
	transformer: 141165
m4_yearly
	feedforward: 12006
	seq2seq: 524729
	nbeats_g: 24115740
	tcn: 118057
	transformer: 136165
