-
Notifications
You must be signed in to change notification settings - Fork 0
14 annual embed #39
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
14 annual embed #39
Changes from all commits
a94b321
9306575
e8edf08
3bdea7d
83d9ffc
abb2392
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -41,14 +41,15 @@ def __init__( | |||||
|
|
||||||
| # Reshape daily → (M, T=31, H, W), monthly → (M, H, W), | ||||||
| # and get padded_days_mask → (M, T=31) | ||||||
| daily_mt, monthly_m, padded_days_mask = add_month_day_dims( | ||||||
| daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_day_dims( | ||||||
| daily_da, monthly_da, time_dim=time_dim | ||||||
| ) | ||||||
|
|
||||||
| # Convert to numpy once — all __getitem__ calls use these | ||||||
| self.daily_np = daily_mt.to_numpy().copy() # (M, T=31, H, W) float | ||||||
| self.monthly_np = monthly_m.to_numpy().copy() # (M, H, W) float | ||||||
| self.padded_mask_np = padded_days_mask.to_numpy().copy() # (M, T=31) bool | ||||||
| self.daily_timef_np = daily_timef.to_numpy().copy() # (M,T=31, 4) | ||||||
|
|
||||||
| # Store coordinate arrays | ||||||
| self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy() | ||||||
|
|
@@ -123,6 +124,7 @@ def __getitem__(self, idx): | |||||
| daily_nan_mask = self.daily_nan_mask[ | ||||||
| :, :, i : i + ph, j : j + pw | ||||||
| ] # (M, T, H, W) | ||||||
| daily_timef_patch = self.daily_timef_np # (M,T,4) | ||||||
|
|
||||||
| if self.land_mask_np is not None: | ||||||
| land_patch = self.land_mask_np[i : i + ph, j : j + pw] # (H, W) | ||||||
|
|
@@ -137,6 +139,8 @@ def __getitem__(self, idx): | |||||
| monthly_tensor = torch.from_numpy(monthly_patch).float() | ||||||
| # (1, M, T, H, W) | ||||||
| daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0) | ||||||
| # ( M, T, 4) | ||||||
| daily_timef_tensor = torch.from_numpy(daily_timef_patch).float() | ||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| # daily_mask: NaN locations that are NOT land | ||||||
| # Reshape land_tensor for broadcasting: (H, W) → (1, 1, 1, H, W) | ||||||
|
|
@@ -154,6 +158,7 @@ def __getitem__(self, idx): | |||||
| "monthly_patch": monthly_tensor, # (M, H, W) | ||||||
| "daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, H, W) | ||||||
| "land_mask_patch": land_tensor, # (H,W) True=Land | ||||||
| "daily_timef_patch": daily_timef_tensor, #(M, T=31, 4) | ||||||
| "padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded | ||||||
| "coords": (i, j), | ||||||
| "lat_patch": lat_patch, # (H,) | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -78,6 +78,79 @@ def forward(self, x, mask): | |||||||||||||||||||||||||||||||||||||||||||||
| x = self.drop(x) | ||||||||||||||||||||||||||||||||||||||||||||||
| return x # (B, N_patches, embed_dim) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| class CyclicTimeEmbedding(nn.Module): | ||||||||||||||||||||||||||||||||||||||||||||||
| """Cyclical Temporal encoding using day-of-year and hour-of-day values in combination | ||||||||||||||||||||||||||||||||||||||||||||||
| sine and cosine functions | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| This module generates fixed (non-learnable) sinusoidal temporal encodings for the temporal dimension | ||||||||||||||||||||||||||||||||||||||||||||||
| using the sinusoidally encoded day-of-year and hour-of-day values extracted from the datetime associated with the input. | ||||||||||||||||||||||||||||||||||||||||||||||
| This represents a natural positional encoding on the temporal cycle related to the solar year and the | ||||||||||||||||||||||||||||||||||||||||||||||
| diurnal cycle. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| The module uses fixed Fourier frequencies to expand the cyclic encoding to the embedding dimension | ||||||||||||||||||||||||||||||||||||||||||||||
| The returned encodings are intended to be added emddings of the input data by the caller. The module does | ||||||||||||||||||||||||||||||||||||||||||||||
| not perform the additon | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+83
to
+92
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixing long length lines. |
||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def __init__(self, embed_dim=128, base_dim=4): | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| Initialize temporal encodings | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||
| embed_dim: Dimension of the embedding.The default is 128. | ||||||||||||||||||||||||||||||||||||||||||||||
| Many vision transformers use embedding dimensions that are multiples | ||||||||||||||||||||||||||||||||||||||||||||||
| of 64 (e.g., 64, 128, 256). This can be tuned. | ||||||||||||||||||||||||||||||||||||||||||||||
| base_dim: Dimension of the input cyclical encodings of doy and hod. This is 4 by default (sin/cos(doy) sin/cos(hod)) | ||||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixing long length lines. |
||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| self.embed_dim = embed_dim | ||||||||||||||||||||||||||||||||||||||||||||||
| self.base_dim = base_dim | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+108
to
+109
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unused variables. |
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| #Determine number of frequencies for Fourier expansion in line with embedding dimension | ||||||||||||||||||||||||||||||||||||||||||||||
| if (embed_dim % (2*base_dim)==0): | ||||||||||||||||||||||||||||||||||||||||||||||
| num_frequencies = int(embed_dim/(2*base_dim)) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.num_freqencies = num_frequencies | ||||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unused variables |
||||||||||||||||||||||||||||||||||||||||||||||
| freqs = torch.linspace(1.0, num_frequencies, num_frequencies) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.register_buffer("freqs", freqs) | ||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||
| f"embed_dim must be an even multiple of 2*base_dim for fixed encoding." | ||||||||||||||||||||||||||||||||||||||||||||||
| f"Got embed_dim: {embed_dim} and base_dim: {base_dim}." | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, time_features): | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| create encodings in of size embedding dimension | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||
| time_features: (B, M, T, D) ; D is base_dim | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||
| emb_encode : (B,M,T, embed_dim) | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| B, M, T, D = time_features.shape | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| #(B, M, T, D, 1) | ||||||||||||||||||||||||||||||||||||||||||||||
| x= time_features.unsqueeze(-1) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| #(1,1,1,1,F) | ||||||||||||||||||||||||||||||||||||||||||||||
| freqs = self.freqs.view(1,1,1,1,-1) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| #apply frequencies | ||||||||||||||||||||||||||||||||||||||||||||||
| x = x * freqs # (B, M, T, D, F) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| sinx = torch.sin(x) | ||||||||||||||||||||||||||||||||||||||||||||||
| cosx = torch.cos(x) | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+144
to
+145
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The doy_sin = np.sin(2*np.pi*doy/doy_period)
doy_cos = np.cos(2*np.pi*doy/doy_period)
hod_sin = np.sin(2*np.pi*hod/hod_period)
hod_cos = np.cos(2*np.pi*hod/hod_period)to: doy_phase = (2*np.pi*doy/doy_period)
hod_phase = (2*np.pi*hod/hod_period)After this change, the dimension should be changed from 4 to 2 as well.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Alternatively, if you want to keep the |
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| emb_encode = torch.cat([sinx,cosx],dim=-1) # (B,M,T,D, 2F) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| emb_encode = emb_encode.view(B,M,T,-1) # flatten | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| return emb_encode | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| class TemporalPositionalEncoding(nn.Module): | ||||||||||||||||||||||||||||||||||||||||||||||
| """Temporal Positional Encoding using sine and cosine functions. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -153,8 +226,10 @@ def __init__(self, embed_dim=128, max_days=31, max_months=12): | |||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| super().__init__() | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| self.time_embed = CyclicTimeEmbedding(embed_dim=embed_dim) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Positional encodings for days and months | ||||||||||||||||||||||||||||||||||||||||||||||
| self.pos_days = TemporalPositionalEncoding(embed_dim, max_len=max_days) | ||||||||||||||||||||||||||||||||||||||||||||||
| #self.pos_days = TemporalPositionalEncoding(embed_dim, max_len=max_days) REMOVE THIS AND REPLACE WITH time_embed | ||||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. given |
||||||||||||||||||||||||||||||||||||||||||||||
| self.pos_months = TemporalPositionalEncoding(embed_dim, max_len=max_months) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Day scorer (within each month) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -182,14 +257,15 @@ def __init__(self, embed_dim=128, max_days=31, max_months=12): | |||||||||||||||||||||||||||||||||||||||||||||
| nn.Linear(4 * embed_dim, embed_dim), | ||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, x, M, T, H, W, padded_days_mask=None): | ||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, x, M, T, H, W, time_features, padded_days_mask=None): | ||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||
| x: (B, M, T, H, W, C) containing spatio-temporal tokens, where C is the embedding dimension. | ||||||||||||||||||||||||||||||||||||||||||||||
| M: number of months | ||||||||||||||||||||||||||||||||||||||||||||||
| T: number of temporal tokens per month after temporal patching (Tp) | ||||||||||||||||||||||||||||||||||||||||||||||
| H: spatial height after spatial patching | ||||||||||||||||||||||||||||||||||||||||||||||
| W: spatial width after spatial patching | ||||||||||||||||||||||||||||||||||||||||||||||
| time_features: (B,M,T,4) containing cyclically encoded DOY and HOD | ||||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. change the dimension 4 to 2 after fixing the |
||||||||||||||||||||||||||||||||||||||||||||||
| padded_days_mask: Optional boolean tensor of shape (B, M, T), bool, | ||||||||||||||||||||||||||||||||||||||||||||||
| True indicating which day tokens are padded (because some months | ||||||||||||||||||||||||||||||||||||||||||||||
| have fewer days). This is used to mask out padded tokens in attention computation. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -201,10 +277,13 @@ def forward(self, x, M, T, H, W, padded_days_mask=None): | |||||||||||||||||||||||||||||||||||||||||||||
| # Reshape to (B, Hp*Wp, M, Tp, C) for temporal processing | ||||||||||||||||||||||||||||||||||||||||||||||
| seq = x.permute(0, 3, 4, 1, 2, 5).reshape(B, Hp * Wp, M, Tp, C) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| pe_days = self.pos_days(T).to(seq.device).to(seq.dtype) # (T, C) | ||||||||||||||||||||||||||||||||||||||||||||||
| temp_emb = self.time_embed(time_features) # (B,M,T,emd_dim) | ||||||||||||||||||||||||||||||||||||||||||||||
| #expand spatially | ||||||||||||||||||||||||||||||||||||||||||||||
| temp_emb = temp_emb[:, None, :, :, :] #[B, 1, M, T, C] | ||||||||||||||||||||||||||||||||||||||||||||||
| temp_emb = temp_emb.expand(-1, H*W, -1, -1, -1) | ||||||||||||||||||||||||||||||||||||||||||||||
| pe_months = self.pos_months(M).to(seq.device).to(seq.dtype) # (M, C) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| seq = seq + pe_days[None, None, None, :, :] # add day PE | ||||||||||||||||||||||||||||||||||||||||||||||
| seq = seq + temp_emb # add temporal embeddings | ||||||||||||||||||||||||||||||||||||||||||||||
| seq = seq + pe_months[None, None, :, None, :] # add month PE | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Day attention per month | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -554,13 +633,15 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| self.patch_size = patch_size | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None): | ||||||||||||||||||||||||||||||||||||||||||||||
| def forward(self, daily_data, daily_mask, daily_timef, land_mask_patch, padded_days_mask=None,): | ||||||||||||||||||||||||||||||||||||||||||||||
| """Forward pass of the Spatio-Temporal model. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||
| daily_data: Tensor of shape (B, C, M, T, H, W) containing daily | ||||||||||||||||||||||||||||||||||||||||||||||
| data, where C is the number of channels (e.g., 1 for SST) | ||||||||||||||||||||||||||||||||||||||||||||||
| daily_mask: Boolean tensor of same shape as daily_data indicating missing values | ||||||||||||||||||||||||||||||||||||||||||||||
| daily_timef: Tensor of shape (B, M, T, 4) containing the cyclically encoded day-of-year | ||||||||||||||||||||||||||||||||||||||||||||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. change the dimension 4 to 2 after fixing the time_features |
||||||||||||||||||||||||||||||||||||||||||||||
| and hour-of-day information for the daily data | ||||||||||||||||||||||||||||||||||||||||||||||
| land_mask_patch: Boolean tensor of shape (B, H, W) to mask land areas in the output | ||||||||||||||||||||||||||||||||||||||||||||||
| padded_days_mask: Optional boolean tensor of shape (B, M, T) indicating which day tokens are padded | ||||||||||||||||||||||||||||||||||||||||||||||
| (True for padded tokens). Used to mask out padded tokens in temporal attention. | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -582,6 +663,9 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None | |||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||
| assert T % self.patch_size[0] == 0, "T must be divisible by patch size" | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if self.patch_size[0] > 1: | ||||||||||||||||||||||||||||||||||||||||||||||
| daily_timef = daily_timef.view(B, M, Tp, self.patch_size[0], 4).mean(dim=3) # -> (B,M, Tp, 4) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if padded_days_mask is not None and self.patch_size[0] > 1: | ||||||||||||||||||||||||||||||||||||||||||||||
| B, M, T_days = padded_days_mask.shape | ||||||||||||||||||||||||||||||||||||||||||||||
| if T_days % self.patch_size[0] != 0: | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -611,7 +695,7 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None | |||||||||||||||||||||||||||||||||||||||||||||
| latent = latent.view(B, M, Tp, Hp, Wp, embed_dim) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| agg_latent = self.temporal( | ||||||||||||||||||||||||||||||||||||||||||||||
| latent, M, Tp, Hp, Wp, padded_days_mask=padded_days_mask | ||||||||||||||||||||||||||||||||||||||||||||||
| latent, M, Tp, Hp, Wp, daily_timef, padded_days_mask=padded_days_mask | ||||||||||||||||||||||||||||||||||||||||||||||
| ) # (B, M, Hp*Wp, embed_dim) | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| # Step 3: Add spatial positional encodings and mix spatial features | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -93,6 +93,7 @@ def add_month_day_dims( | |||
| daily_m : xr.DataArray - dims: (M, T, H, W) | ||||
| monthly_m : xr.DataArray - dims: (M, H, W) | ||||
| padded_days_mask : xr.DataArray - dims: (M, T=31), bool, True where day is padded | ||||
| time_features : xr.DataArray - dims: (M, T, 4) | ||||
| """ | ||||
| # Month key as integer YYYYMM | ||||
| dkey = daily_ts[time_dim].dt.year * 100 + daily_ts[time_dim].dt.month | ||||
|
|
@@ -126,7 +127,43 @@ def add_month_day_dims( | |||
| .sel(M=month_keys) | ||||
| ) | ||||
|
|
||||
| return daily_indexed, monthly_m, padded_days_mask | ||||
| #----------------------------------------- | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
| # Build aligned datetime array (M,T) | ||||
| time_da = daily_ts[time_dim] | ||||
|
|
||||
| #time_indexed is (M,T) with NaT for padded days | ||||
| time_indexed = ( | ||||
| time_da.assign_coords(M=(time_dim, dkey.values), | ||||
| T=(time_dim, time_da.dt.day.values)) | ||||
| .set_index({time_dim: ("M", "T")}) | ||||
| .unstack(time_dim) | ||||
| .reindex(T=np.arange(1,32), M=month_keys) | ||||
| ) | ||||
| #------------------------------------------- | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||
|
|
||||
| #determine day-of-year (doy) [and hour-of-day (hod) if applicable], fill NaT with 0 inplace | ||||
| doy_period = 365.0 | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. leap years will be slightly mis-phased here. We can add a comment about it here, or we can calculate it as: |
||||
| hod_period = 24.0 | ||||
|
|
||||
| doy = time_indexed.dt.dayofyear.fillna(0) | ||||
|
|
||||
| if "hour" in dir(time_indexed.dt): | ||||
| hod = time_indexed.dt.hour.fillna(0) | ||||
| else: | ||||
| hod = xr.zeros_like(doy) | ||||
|
|
||||
| #Create cyclic encodings | ||||
| doy_sin = np.sin(2*np.pi*doy/doy_period) | ||||
| doy_cos = np.cos(2*np.pi*doy/doy_period) | ||||
| hod_sin = np.sin(2*np.pi*hod/hod_period) | ||||
| hod_cos = np.cos(2*np.pi*hod/hod_period) | ||||
|
Comment on lines
+156
to
+159
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if class CyclicTimeEmbedding applies sin/cos, these should be changed to raw phase as: doy_phase = (2*np.pi*doy/doy_period)
hod_phase = (2*np.pi*hod/hod_period) |
||||
|
|
||||
| #Stack cyclic encodings into time_features (M,T,4) | ||||
| time_features = xr.concat([doy_sin,doy_cos,hod_sin,hod_cos], | ||||
| dim="feature" | ||||
| ).transpose("M","T","feature") | ||||
|
|
||||
| return daily_indexed, monthly_m, padded_days_mask, time_features | ||||
|
|
||||
|
|
||||
| def pred_to_numpy(pred, orig_H=None, orig_W=None, land_mask=None): | ||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.