In [2]:
import argparse
import data.ts_datasets as my
import pandas as pd
from models import DumbMLP4TS, ARv2
import torch
import numpy as np
import pickle as pkl
import time 
import sys

In [2]:
DEVICE = torch.device("cpu")
dataset="ETTh1_96"
batch_size=64
multi_channel=False

file_dict = {
    "ETTh1": "ETT-small/ETTh1.csv",
    "ETTh2": "ETT-small/ETTh2.csv",
    "ETTm1": "ETT-small/ETTm1.csv",
    "ETTm2": "ETT-small/ETTm2.csv",
    "ECL": "electricity/electricity.csv",
    "ER": "exchange_rate/exchange_rate.csv",
    "ILI": "illness/national_illness.csv",
    "Traffic": "traffic/traffic.csv",
    "Weather": "weather/weather.csv"
}


params = dataset.split("_")
prefix = params[0]
horizon = int(params[1])
assert prefix in file_dict, f"Invalid dataset {dataset} from possible {list(file_dict.keys())}"
input_length = 96 if prefix == "ILI" else 512
train_loader, val_loader, test_loader = my.get_timeseries_dataloaders(
    f"./data/ts_datasets/all_six_datasets/{file_dict[prefix]}", 
    batch_size=batch_size, 
    seq_len=input_length,
    forecast_horizon=horizon,
    multi_channel=multi_channel
)
train_data, val_data, test_data = train_loader.dataset, val_loader.dataset, test_loader.dataset
train_df, val_df, test_df = pd.DataFrame(train_data.data), pd.DataFrame(val_data.data), pd.DataFrame(test_data.data)

do_diff = True
start_time = time.perf_counter()
ar = ARv2(input_length, horizon)
ar.fit_raw(train_df, val_df, DEVICE, 10, do_diff, use_ols=True, first_lag=400)
print(f"AR search for {dataset} {horizon} finished in {time.perf_counter() - start_time} seconds)")
ar.fit_preset(pd.concat((train_df, val_df)), ar.best_lags, do_diff, torch.device("cpu"), use_ols=True)
mse, mae = ar.test_loss_acc_df(test_df, torch.device("cpu"))
print(f"Final AR metrics for {dataset} {horizon}: MSE {np.mean(mse)} | MAE {np.mean(mae)}")

10 hours limit


  0%|          | 0/112 [00:00<?, ?it/s]

400


  1%|          | 1/112 [00:00<01:02,  1.76it/s]

401


  2%|▏         | 2/112 [00:01<00:56,  1.94it/s]

402


  3%|▎         | 3/112 [00:01<00:55,  1.97it/s]

403


  4%|▎         | 4/112 [00:02<00:55,  1.96it/s]

404


  4%|▍         | 5/112 [00:02<00:54,  1.96it/s]

405


  5%|▌         | 6/112 [00:03<00:54,  1.94it/s]

406


  6%|▋         | 7/112 [00:03<00:52,  1.99it/s]

407


  7%|▋         | 8/112 [00:04<00:51,  2.03it/s]

408


  8%|▊         | 9/112 [00:04<00:50,  2.05it/s]

409


  9%|▉         | 10/112 [00:04<00:49,  2.07it/s]

410


 10%|▉         | 11/112 [00:05<00:49,  2.05it/s]

411


 11%|█         | 12/112 [00:05<00:48,  2.07it/s]

412


 12%|█▏        | 13/112 [00:06<00:47,  2.07it/s]

413


 12%|█▎        | 14/112 [00:06<00:49,  2.00it/s]

414


 13%|█▎        | 15/112 [00:07<00:48,  1.98it/s]

415


 14%|█▍        | 16/112 [00:07<00:47,  2.01it/s]

416


 15%|█▌        | 17/112 [00:08<00:47,  2.01it/s]

417


 16%|█▌        | 18/112 [00:08<00:46,  2.02it/s]

418


 17%|█▋        | 19/112 [00:09<00:46,  2.02it/s]

419


 18%|█▊        | 20/112 [00:09<00:45,  2.03it/s]

420


 19%|█▉        | 21/112 [00:10<00:44,  2.03it/s]

421


 20%|█▉        | 22/112 [00:10<00:45,  2.00it/s]

422


 21%|██        | 23/112 [00:11<00:44,  2.01it/s]

423


 21%|██▏       | 24/112 [00:11<00:43,  2.02it/s]

424


 22%|██▏       | 25/112 [00:12<00:42,  2.03it/s]

425


 23%|██▎       | 26/112 [00:12<00:43,  1.97it/s]

426


 24%|██▍       | 27/112 [00:13<00:42,  1.99it/s]

427


 25%|██▌       | 28/112 [00:13<00:42,  2.00it/s]

428


 26%|██▌       | 29/112 [00:14<00:41,  2.01it/s]

429


 27%|██▋       | 30/112 [00:14<00:41,  1.99it/s]

430


 28%|██▊       | 31/112 [00:15<00:40,  1.99it/s]

431


 29%|██▊       | 32/112 [00:16<00:52,  1.52it/s]

432


 29%|██▉       | 33/112 [00:17<01:00,  1.31it/s]

433


 30%|███       | 34/112 [00:18<01:04,  1.21it/s]

434


 31%|███▏      | 35/112 [00:19<00:57,  1.34it/s]

435


 32%|███▏      | 36/112 [00:19<00:51,  1.48it/s]

436


 33%|███▎      | 37/112 [00:20<00:47,  1.58it/s]

437


 34%|███▍      | 38/112 [00:20<00:44,  1.66it/s]

438


 35%|███▍      | 39/112 [00:21<00:42,  1.72it/s]

439


 36%|███▌      | 40/112 [00:21<00:40,  1.77it/s]

440


 37%|███▋      | 41/112 [00:22<00:39,  1.80it/s]

441


 38%|███▊      | 42/112 [00:22<00:38,  1.81it/s]

442


 38%|███▊      | 43/112 [00:23<00:37,  1.82it/s]

443


 39%|███▉      | 44/112 [00:23<00:37,  1.83it/s]

444


 40%|████      | 45/112 [00:24<00:36,  1.83it/s]

445


 41%|████      | 46/112 [00:24<00:36,  1.82it/s]

446


 42%|████▏     | 47/112 [00:25<00:36,  1.80it/s]

447


 43%|████▎     | 48/112 [00:26<00:35,  1.79it/s]

448


 44%|████▍     | 49/112 [00:26<00:35,  1.79it/s]

449


 45%|████▍     | 50/112 [00:27<00:41,  1.48it/s]

450


 46%|████▌     | 51/112 [00:28<00:45,  1.35it/s]

451


 46%|████▋     | 52/112 [00:29<00:46,  1.28it/s]

452


 47%|████▋     | 53/112 [00:29<00:42,  1.39it/s]

453


 48%|████▊     | 54/112 [00:30<00:39,  1.47it/s]

454


 49%|████▉     | 55/112 [00:31<00:37,  1.54it/s]

455


 50%|█████     | 56/112 [00:31<00:35,  1.58it/s]

456


 51%|█████     | 57/112 [00:32<00:34,  1.62it/s]

457


 52%|█████▏    | 58/112 [00:32<00:33,  1.63it/s]

458


 53%|█████▎    | 59/112 [00:33<00:33,  1.60it/s]

459


 54%|█████▎    | 60/112 [00:34<00:31,  1.63it/s]

460


 54%|█████▍    | 61/112 [00:34<00:31,  1.63it/s]

461


 55%|█████▌    | 62/112 [00:35<00:32,  1.53it/s]

462


 56%|█████▋    | 63/112 [00:36<00:31,  1.56it/s]

463


 57%|█████▋    | 64/112 [00:36<00:30,  1.58it/s]

464


 58%|█████▊    | 65/112 [00:37<00:31,  1.47it/s]

465


 59%|█████▉    | 66/112 [00:38<00:30,  1.51it/s]

466


 60%|█████▉    | 67/112 [00:38<00:29,  1.52it/s]

467


 61%|██████    | 68/112 [00:39<00:29,  1.49it/s]

468


 62%|██████▏   | 69/112 [00:40<00:28,  1.51it/s]

469


 62%|██████▎   | 70/112 [00:40<00:27,  1.51it/s]

470


 63%|██████▎   | 71/112 [00:42<00:39,  1.04it/s]

471


 64%|██████▍   | 72/112 [00:45<01:06,  1.66s/it]

472


 65%|██████▌   | 73/112 [00:48<01:12,  1.86s/it]

473


 66%|██████▌   | 74/112 [00:48<00:56,  1.50s/it]

474


 67%|██████▋   | 75/112 [00:49<00:46,  1.24s/it]

475


 68%|██████▊   | 76/112 [00:49<00:38,  1.06s/it]

476


 69%|██████▉   | 77/112 [00:50<00:32,  1.07it/s]

477


 70%|██████▉   | 78/112 [00:51<00:28,  1.18it/s]

478


 71%|███████   | 79/112 [00:51<00:25,  1.27it/s]

479


 71%|███████▏  | 80/112 [00:52<00:23,  1.35it/s]

480


 72%|███████▏  | 81/112 [00:53<00:23,  1.34it/s]

481


 73%|███████▎  | 82/112 [00:53<00:21,  1.38it/s]

482


 74%|███████▍  | 83/112 [00:54<00:20,  1.41it/s]

483


 75%|███████▌  | 84/112 [00:55<00:19,  1.42it/s]

484


 76%|███████▌  | 85/112 [00:56<00:18,  1.43it/s]

485


 77%|███████▋  | 86/112 [00:56<00:18,  1.43it/s]

486


 78%|███████▊  | 87/112 [00:57<00:17,  1.44it/s]

487


 79%|███████▊  | 88/112 [00:58<00:16,  1.43it/s]

488


 79%|███████▉  | 89/112 [00:58<00:16,  1.43it/s]

489


 80%|████████  | 90/112 [00:59<00:15,  1.43it/s]

490


 81%|████████▏ | 91/112 [01:00<00:15,  1.40it/s]

491


 82%|████████▏ | 92/112 [01:00<00:14,  1.39it/s]

492


 83%|████████▎ | 93/112 [01:01<00:13,  1.39it/s]

493


 84%|████████▍ | 94/112 [01:02<00:13,  1.38it/s]

494


 85%|████████▍ | 95/112 [01:03<00:12,  1.36it/s]

495


 86%|████████▌ | 96/112 [01:03<00:11,  1.36it/s]

496


 87%|████████▋ | 97/112 [01:04<00:11,  1.36it/s]

497


 88%|████████▊ | 98/112 [01:05<00:10,  1.37it/s]

498


 88%|████████▊ | 99/112 [01:06<00:09,  1.37it/s]

499


 89%|████████▉ | 100/112 [01:06<00:08,  1.37it/s]

500


 90%|█████████ | 101/112 [01:07<00:08,  1.35it/s]

501


 91%|█████████ | 102/112 [01:08<00:07,  1.36it/s]

502


 92%|█████████▏| 103/112 [01:09<00:06,  1.32it/s]

503


 93%|█████████▎| 104/112 [01:09<00:06,  1.33it/s]

504


 94%|█████████▍| 105/112 [01:10<00:05,  1.34it/s]

505


 95%|█████████▍| 106/112 [01:11<00:04,  1.35it/s]

506


 96%|█████████▌| 107/112 [01:12<00:03,  1.36it/s]

507


 96%|█████████▋| 108/112 [01:12<00:02,  1.35it/s]

508


 97%|█████████▋| 109/112 [01:13<00:02,  1.35it/s]

509


 98%|█████████▊| 110/112 [01:14<00:01,  1.34it/s]

510


 99%|█████████▉| 111/112 [01:15<00:00,  1.34it/s]

511


100%|██████████| 112/112 [01:15<00:00,  1.48it/s]


Best Lags: 503
AR search for ETTh1_96 96 finished in 75.8529167752713 seconds)
Selected Lags: 503
Final AR metrics for ETTh1_96 96: MSE 0.3579343259334564 | MAE 0.38869863748550415


In [19]:
for i, x in enumerate(test_loader):
    if i == 0:
        print(x[1].shape)
    pass

print(i)

torch.Size([64, 1, 96])
304


In [11]:
list1 = np.array([1, 2, 3, 4, 5])
list2 = np.array([5, 6, 7, 8, 9])
y = np.array([1.5, 2.5, 3.5, 4.5, 5.5])

# find alpha such that alpha * list1 + (1 - alpha) * list2 is closest to y
alpha = np.sum(-2*(list2-y)*(list1-list2)) / np.sum(2*(list1-list2)**2)
print(alpha)
print(alpha * list1 + (1 - alpha) * list2)

0.875
[1.5 2.5 3.5 4.5 5.5]


In [4]:
a = np.array([1, 2, 3, 4, 5])
print(np.flip(a, 0))

[5 4 3 2 1]
