Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
:arrow_minus_sign: Revert dependencies.
- Loading branch information
Daniel Kaminski de Souza
committed
Jul 3, 2020
1 parent
565ac2e
commit d488955
Showing
4 changed files
with
181 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
""" | ||
Tests the LSTMTimeSeriesPredictor | ||
""" | ||
from pathlib import Path | ||
|
||
import torch | ||
from skorch.callbacks import EarlyStopping | ||
from src.flights_dataset import FlightsDataset | ||
from src.model import BenchmarkLSTM | ||
from src.oze_dataset import OzeNPZDataset, npz_check | ||
from time_series_predictor import TimeSeriesPredictor | ||
# from tune_sklearn.tune_gridsearch import TuneGridSearchCV | ||
|
||
if __name__ == "__main__": | ||
tsp = TimeSeriesPredictor( | ||
BenchmarkLSTM(), | ||
max_epochs=500, | ||
early_stopping=EarlyStopping(patience=30), | ||
# train_split=None, # default = skorch.dataset.CVSplit(5) | ||
optimizer=torch.optim.Adam | ||
) | ||
dataset = OzeNPZDataset( | ||
dataset_path=npz_check( | ||
Path('datasets'), | ||
'dataset' | ||
) | ||
) | ||
|
||
tsp.fit(dataset) | ||
mean_r2_score = tsp.score(dataset) | ||
assert mean_r2_score > -50 | ||
|
||
# if __name__ == "__main__": | ||
# tsp = TimeSeriesPredictor( | ||
# BenchmarkLSTM(), | ||
# max_epochs=50, | ||
# train_split=None, # default = skorch.dataset.CVSplit(5) | ||
# optimizer=torch.optim.Adam | ||
# ) | ||
# dataset = FlightsDataset() | ||
|
||
# tsp.fit(dataset) | ||
# mean_r2_score = tsp.score(dataset) | ||
# assert mean_r2_score > 0.75 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# | ||
# This file is autogenerated by pip-compile | ||
# To update, run: | ||
# | ||
# pip-compile --find-links=https://download.pytorch.org/whl/torch_stable.html --generate-hashes --output-file=requirements-lock.txt | ||
# | ||
--find-links https://download.pytorch.org/whl/torch_stable.html | ||
|
||
future==0.18.2 \ | ||
--hash=sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d \ | ||
# via torch | ||
joblib==0.16.0 \ | ||
--hash=sha256:8f52bf24c64b608bf0b2563e0e47d6fcf516abc8cfafe10cfd98ad66d94f92d6 \ | ||
--hash=sha256:d348c5d4ae31496b2aa060d6d9b787864dd204f9480baaa52d18850cb43e9f49 \ | ||
# via scikit-learn | ||
numpy==1.19.0 \ | ||
--hash=sha256:13af0184177469192d80db9bd02619f6fa8b922f9f327e077d6f2a6acb1ce1c0 \ | ||
--hash=sha256:26a45798ca2a4e168d00de75d4a524abf5907949231512f372b217ede3429e98 \ | ||
--hash=sha256:26f509450db547e4dfa3ec739419b31edad646d21fb8d0ed0734188b35ff6b27 \ | ||
--hash=sha256:30a59fb41bb6b8c465ab50d60a1b298d1cd7b85274e71f38af5a75d6c475d2d2 \ | ||
--hash=sha256:33c623ef9ca5e19e05991f127c1be5aeb1ab5cdf30cb1c5cf3960752e58b599b \ | ||
--hash=sha256:356f96c9fbec59974a592452ab6a036cd6f180822a60b529a975c9467fcd5f23 \ | ||
--hash=sha256:3c40c827d36c6d1c3cf413694d7dc843d50997ebffbc7c87d888a203ed6403a7 \ | ||
--hash=sha256:4d054f013a1983551254e2379385e359884e5af105e3efe00418977d02f634a7 \ | ||
--hash=sha256:63d971bb211ad3ca37b2adecdd5365f40f3b741a455beecba70fd0dde8b2a4cb \ | ||
--hash=sha256:658624a11f6e1c252b2cd170d94bf28c8f9410acab9f2fd4369e11e1cd4e1aaf \ | ||
--hash=sha256:76766cc80d6128750075378d3bb7812cf146415bd29b588616f72c943c00d598 \ | ||
--hash=sha256:7b57f26e5e6ee2f14f960db46bd58ffdca25ca06dd997729b1b179fddd35f5a3 \ | ||
--hash=sha256:7b852817800eb02e109ae4a9cef2beda8dd50d98b76b6cfb7b5c0099d27b52d4 \ | ||
--hash=sha256:8cde829f14bd38f6da7b2954be0f2837043e8b8d7a9110ec5e318ae6bf706610 \ | ||
--hash=sha256:a2e3a39f43f0ce95204beb8fe0831199542ccab1e0c6e486a0b4947256215632 \ | ||
--hash=sha256:a86c962e211f37edd61d6e11bb4df7eddc4a519a38a856e20a6498c319efa6b0 \ | ||
--hash=sha256:a8705c5073fe3fcc297fb8e0b31aa794e05af6a329e81b7ca4ffecab7f2b95ef \ | ||
--hash=sha256:b6aaeadf1e4866ca0fdf7bb4eed25e521ae21a7947c59f78154b24fc7abbe1dd \ | ||
--hash=sha256:be62aeff8f2f054eff7725f502f6228298891fd648dc2630e03e44bf63e8cee0 \ | ||
--hash=sha256:c2edbb783c841e36ca0fa159f0ae97a88ce8137fb3a6cd82eae77349ba4b607b \ | ||
--hash=sha256:cbe326f6d364375a8e5a8ccb7e9cd73f4b2f6dc3b2ed205633a0db8243e2a96a \ | ||
--hash=sha256:d34fbb98ad0d6b563b95de852a284074514331e6b9da0a9fc894fb1cdae7a79e \ | ||
--hash=sha256:d97a86937cf9970453c3b62abb55a6475f173347b4cde7f8dcdb48c8e1b9952d \ | ||
--hash=sha256:dd53d7c4a69e766e4900f29db5872f5824a06827d594427cf1a4aa542818b796 \ | ||
--hash=sha256:df1889701e2dfd8ba4dc9b1a010f0a60950077fb5242bb92c8b5c7f1a6f2668a \ | ||
--hash=sha256:fa1fe75b4a9e18b66ae7f0b122543c42debcf800aaafa0212aaff3ad273c2596 \ | ||
# via scikit-learn, scipy, skorch, torch | ||
psutil==5.7.0 \ | ||
--hash=sha256:1413f4158eb50e110777c4f15d7c759521703bd6beb58926f1d562da40180058 \ | ||
--hash=sha256:298af2f14b635c3c7118fd9183843f4e73e681bb6f01e12284d4d70d48a60953 \ | ||
--hash=sha256:60b86f327c198561f101a92be1995f9ae0399736b6eced8f24af41ec64fb88d4 \ | ||
--hash=sha256:685ec16ca14d079455892f25bd124df26ff9137664af445563c1bd36629b5e0e \ | ||
--hash=sha256:73f35ab66c6c7a9ce82ba44b1e9b1050be2a80cd4dcc3352cc108656b115c74f \ | ||
--hash=sha256:75e22717d4dbc7ca529ec5063000b2b294fc9a367f9c9ede1f65846c7955fd38 \ | ||
--hash=sha256:a02f4ac50d4a23253b68233b07e7cdb567bd025b982d5cf0ee78296990c22d9e \ | ||
--hash=sha256:d008ddc00c6906ec80040d26dc2d3e3962109e40ad07fd8a12d0284ce5e0e4f8 \ | ||
--hash=sha256:d84029b190c8a66a946e28b4d3934d2ca1528ec94764b180f7d6ea57b0e75e26 \ | ||
--hash=sha256:e2d0c5b07c6fe5a87fa27b7855017edb0d52ee73b71e6ee368fae268605cc3f5 \ | ||
--hash=sha256:f344ca230dd8e8d5eee16827596f1c22ec0876127c28e800d7ae20ed44c4b310 \ | ||
# via time_series_predictor (setup.py) | ||
scikit-learn==0.23.1 \ | ||
--hash=sha256:04799686060ecbf8992f26a35be1d99e981894c8c7860c1365cda4200f954a16 \ | ||
--hash=sha256:058d213092de4384710137af1300ed0ff030b8c40459a6c6f73c31ccd274cc39 \ | ||
--hash=sha256:0c3464e46ef8bd4f1bfa5c009648c6449412c8f7e9b3fc0c9e3d800139c48827 \ | ||
--hash=sha256:0e7b55f73b35537ecd0d19df29dd39aa9e076dba78f3507b8136c819d84611fd \ | ||
--hash=sha256:16feae4361be6b299d4d08df5a30956b4bfc8eadf173fe9258f6d59630f851d4 \ | ||
--hash=sha256:244ca85d6eba17a1e6e8a66ab2f584be6a7784b5f59297e3d7ff8c7983af627c \ | ||
--hash=sha256:3e6e92b495eee193a8fa12a230c9b7976ea0fc1263719338e35c986ea1e42cff \ | ||
--hash=sha256:5bcea4d6ee431c814261117281363208408aa4e665633655895feb059021aca6 \ | ||
--hash=sha256:93f56abd316d131645559ec0ab4f45e3391c2ccdd4eadaa4912f4c1e0a6f2c96 \ | ||
--hash=sha256:9e04c0811ea92931ee8490d638171b8cb2f21387efcfff526bbc8c2a3da60f1c \ | ||
--hash=sha256:bded94236e16774385202cafd26190ce96db18e4dc21e99473848c61e4fdc400 \ | ||
--hash=sha256:c2fa33d20408b513cf432505c80e6eb4bf4d71434f1ae36680765d4a2c2a16ec \ | ||
--hash=sha256:e3fec1c8831f8f93ad85581ca29ca1bb88e2da377fb097cf8322aa89c21bc9b8 \ | ||
--hash=sha256:e585682e37f2faa81ad6cd4472fff646bf2fd0542147bec93697a905db8e6bd2 \ | ||
--hash=sha256:e9879ba9e64ec3add41bf201e06034162f853652ef4849b361d73b0deb3153ad \ | ||
--hash=sha256:ebe853e6f318f9d8b3b74dd17e553720d35646eff675a69eeaed12fbbbb07daa \ | ||
# via skorch | ||
scipy==1.4.1 \ | ||
--hash=sha256:00af72998a46c25bdb5824d2b729e7dabec0c765f9deb0b504f928591f5ff9d4 \ | ||
--hash=sha256:0902a620a381f101e184a958459b36d3ee50f5effd186db76e131cbefcbb96f7 \ | ||
--hash=sha256:1e3190466d669d658233e8a583b854f6386dd62d655539b77b3fa25bfb2abb70 \ | ||
--hash=sha256:2cce3f9847a1a51019e8c5b47620da93950e58ebc611f13e0d11f4980ca5fecb \ | ||
--hash=sha256:3092857f36b690a321a662fe5496cb816a7f4eecd875e1d36793d92d3f884073 \ | ||
--hash=sha256:386086e2972ed2db17cebf88610aab7d7f6e2c0ca30042dc9a89cf18dcc363fa \ | ||
--hash=sha256:71eb180f22c49066f25d6df16f8709f215723317cc951d99e54dc88020ea57be \ | ||
--hash=sha256:770254a280d741dd3436919d47e35712fb081a6ff8bafc0f319382b954b77802 \ | ||
--hash=sha256:787cc50cab3020a865640aba3485e9fbd161d4d3b0d03a967df1a2881320512d \ | ||
--hash=sha256:8a07760d5c7f3a92e440ad3aedcc98891e915ce857664282ae3c0220f3301eb6 \ | ||
--hash=sha256:8d3bc3993b8e4be7eade6dcc6fd59a412d96d3a33fa42b0fa45dc9e24495ede9 \ | ||
--hash=sha256:9508a7c628a165c2c835f2497837bf6ac80eb25291055f56c129df3c943cbaf8 \ | ||
--hash=sha256:a144811318853a23d32a07bc7fd5561ff0cac5da643d96ed94a4ffe967d89672 \ | ||
--hash=sha256:a1aae70d52d0b074d8121333bc807a485f9f1e6a69742010b33780df2e60cfe0 \ | ||
--hash=sha256:a2d6df9eb074af7f08866598e4ef068a2b310d98f87dc23bd1b90ec7bdcec802 \ | ||
--hash=sha256:bb517872058a1f087c4528e7429b4a44533a902644987e7b2fe35ecc223bc408 \ | ||
--hash=sha256:c5cac0c0387272ee0e789e94a570ac51deb01c796b37fb2aad1fb13f85e2f97d \ | ||
--hash=sha256:cc971a82ea1170e677443108703a2ec9ff0f70752258d0e9f5433d00dda01f59 \ | ||
--hash=sha256:dba8306f6da99e37ea08c08fef6e274b5bf8567bb094d1dbe86a20e532aca088 \ | ||
--hash=sha256:dc60bb302f48acf6da8ca4444cfa17d52c63c5415302a9ee77b3b21618090521 \ | ||
--hash=sha256:dee1bbf3a6c8f73b6b218cb28eed8dd13347ea2f87d572ce19b289d6fd3fbc59 \ | ||
# via scikit-learn, skorch, time_series_predictor (setup.py) | ||
skorch==0.8.0 \ | ||
--hash=sha256:5908fdc3c1c8ae49d16fa3edb1fbdd412c44f2baee02abdd5432b7a47933a7d0 \ | ||
--hash=sha256:f292e9866f65df7fb7cf209f503924e2cb67377d7524a50c3e5dc6ae5a5ecd47 \ | ||
# via time_series_predictor (setup.py) | ||
tabulate==0.8.7 \ | ||
--hash=sha256:ac64cb76d53b1231d364babcd72abbb16855adac7de6665122f97b593f1eb2ba \ | ||
--hash=sha256:db2723a20d04bcda8522165c73eea7c300eda74e0ce852d9022e0159d7895007 \ | ||
# via skorch | ||
threadpoolctl==2.1.0 \ | ||
--hash=sha256:38b74ca20ff3bb42caca8b00055111d74159ee95c4370882bbff2b93d24da725 \ | ||
--hash=sha256:ddc57c96a38beb63db45d6c159b5ab07b6bced12c45a1f07b2b92f272aebfa6b \ | ||
# via scikit-learn | ||
torch==1.5.0+cu92 \ | ||
--hash=sha256:21c6cd3f053b21b0c219963a0403eaeb289d53cb4a8ecf9a099c2e9232293fa4 \ | ||
--hash=sha256:2281b4d9fbec7925f44e96a6e5c753a58161d3812bc99fb5d7b0a555d1d60d7f \ | ||
--hash=sha256:352d2ac173d6b203f10fe719545c36810182badace9c6f9bc9b34dc7d138d90e \ | ||
--hash=sha256:560ce89a62b01f7c6c07f1e4099af3e74f4a7213adcc86abd04a2ae2f8bf7663 \ | ||
--hash=sha256:77586f5deca99bf854dce2bce9e533a90dd97694d190b15bd17c170ef493e2b1 \ | ||
--hash=sha256:8078aeffc481549f63e5fcb0ff295164faedd368a4f01e4a489fafb51d794734 \ | ||
--hash=sha256:8e9df2aa4ec2516476dc8c09c984a6d0e11d52126833e47a6d9f18c177e83de1 \ | ||
--hash=sha256:a903fe0e2c9f2017b664938ed27ec81e45f0405b37c1faabdfd12b167f552510 \ | ||
--hash=sha256:af9d74ed62b716ab3803dfa8774380abb8aa97b84c656ca814487d7d9611ccdd \ | ||
--hash=sha256:eca7d1ab146e75fb672e53382a71c0073caceaf3920a3a7aaf6d07419f8b38f7 \ | ||
# via time_series_predictor (setup.py) | ||
tqdm==4.46.0 \ | ||
--hash=sha256:4733c4a10d0f2a4d098d801464bdaf5240c7dadd2a7fde4ee93b0a0efd9fb25e \ | ||
--hash=sha256:acdafb20f51637ca3954150d0405ff1a7edde0ff19e38fb99a80a66210d2a28f \ | ||
# via skorch, time_series_predictor (setup.py) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters