From 47d2b439ee4e4550f7b073e6aa48a824a55ffdac Mon Sep 17 00:00:00 2001 From: isaacmg Date: Tue, 7 May 2024 13:45:24 -0400 Subject: [PATCH] re-add changes in crossformer commit without junk --- .circleci/config.yml | 2 +- flood_forecast/transformer_xl/cross_former.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 917c5f0ab..577315f08 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -334,7 +334,7 @@ jobs: coverage run flood_forecast/trainer.py -p tests/dsanet.json echo -e 'running trainer_decoder_test \n' coverage run flood_forecast/trainer.py -p tests/decoder_test.json - echo -e 'running trainer_full_transformer_test \n' + echo -e 'running trainer_full_transformer_test' coverage run flood_forecast/trainer.py -p tests/full_transformer.json - store_test_results: diff --git a/flood_forecast/transformer_xl/cross_former.py b/flood_forecast/transformer_xl/cross_former.py index cba0e7552..50c082b20 100644 --- a/flood_forecast/transformer_xl/cross_former.py +++ b/flood_forecast/transformer_xl/cross_former.py @@ -19,6 +19,7 @@ def __init__( e_layers=3, dropout=0.0, baseline=False, + n_targs=None, device=torch.device("cuda:0"), ): """Crossformer: Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting. @@ -57,6 +58,7 @@ def __init__( self.out_len = forecast_length self.seg_len = seg_len self.merge_win = win_size + self.n_targs = n_time_series if n_targs is None else n_targs self.baseline = baseline @@ -126,7 +128,9 @@ def forward(self, x_seq: torch.Tensor): ) predict_y = self.decoder(dec_in, enc_out) - return base + predict_y[:, : self.out_len, :] + result = base + predict_y[:, : self.out_len, :] + res = result[:, :, :self.n_targs] + return res class SegMerging(nn.Module):