Skip to content

Commit

Permalink
re-add changes in crossformer commit without junk
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacmg committed May 7, 2024
1 parent 2dd6d0d commit 47d2b43
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ jobs:
coverage run flood_forecast/trainer.py -p tests/dsanet.json
echo -e 'running trainer_decoder_test \n'
coverage run flood_forecast/trainer.py -p tests/decoder_test.json
echo -e 'running trainer_full_transformer_test \n'
echo -e 'running trainer_full_transformer_test'
coverage run flood_forecast/trainer.py -p tests/full_transformer.json
- store_test_results:
Expand Down
6 changes: 5 additions & 1 deletion flood_forecast/transformer_xl/cross_former.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def __init__(
e_layers=3,
dropout=0.0,
baseline=False,
n_targs=None,
device=torch.device("cuda:0"),
):
"""Crossformer: Transformer Utilizing Cross-Dimension Dependency for Multivariate Time Series Forecasting.
Expand Down Expand Up @@ -57,6 +58,7 @@ def __init__(
self.out_len = forecast_length
self.seg_len = seg_len
self.merge_win = win_size
self.n_targs = n_time_series if n_targs is None else n_targs

self.baseline = baseline

Expand Down Expand Up @@ -126,7 +128,9 @@ def forward(self, x_seq: torch.Tensor):
)
predict_y = self.decoder(dec_in, enc_out)

return base + predict_y[:, : self.out_len, :]
result = base + predict_y[:, : self.out_len, :]
res = result[:, :, :self.n_targs]
return res


class SegMerging(nn.Module):
Expand Down

0 comments on commit 47d2b43

Please sign in to comment.