diff --git a/climanet/dataset.py b/climanet/dataset.py index 7976297..ba22197 100644 --- a/climanet/dataset.py +++ b/climanet/dataset.py @@ -41,7 +41,7 @@ def __init__( # Reshape daily → (M, T=31, H, W), monthly → (M, H, W), # and get padded_days_mask → (M, T=31) - daily_mt, monthly_m, padded_days_mask = add_month_day_dims( + daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_day_dims( daily_da, monthly_da, time_dim=time_dim ) @@ -49,6 +49,7 @@ def __init__( self.daily_np = daily_mt.to_numpy().copy() # (M, T=31, H, W) float self.monthly_np = monthly_m.to_numpy().copy() # (M, H, W) float self.padded_mask_np = padded_days_mask.to_numpy().copy() # (M, T=31) bool + self.daily_timef_np = daily_timef.to_numpy().copy() # (M,T=31, 4) # Store coordinate arrays self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy() @@ -123,6 +124,7 @@ def __getitem__(self, idx): daily_nan_mask = self.daily_nan_mask[ :, :, i : i + ph, j : j + pw ] # (M, T, H, W) + daily_timef_patch = self.daily_timef_np # (M,T,4) if self.land_mask_np is not None: land_patch = self.land_mask_np[i : i + ph, j : j + pw] # (H, W) @@ -137,6 +139,8 @@ def __getitem__(self, idx): monthly_tensor = torch.from_numpy(monthly_patch).float() # (1, M, T, H, W) daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0) + # ( M, T, 4) + daily_timef_tensor = torch.from_numpy(daily_timef_patch).float() # daily_mask: NaN locations that are NOT land # Reshape land_tensor for broadcasting: (H, W) → (1, 1, 1, H, W) @@ -154,6 +158,7 @@ def __getitem__(self, idx): "monthly_patch": monthly_tensor, # (M, H, W) "daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, H, W) "land_mask_patch": land_tensor, # (H,W) True=Land + "daily_timef_patch": daily_timef_tensor, #(M, T=31, 4) "padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded "coords": (i, j), "lat_patch": lat_patch, # (H,) diff --git a/climanet/predict.py b/climanet/predict.py index bb716d9..ebf1ded 100644 --- a/climanet/predict.py +++ 
class CyclicTimeEmbedding(nn.Module):
    """Fixed (non-learnable) cyclical temporal embedding.

    Expands sine/cosine encodings of day-of-year and hour-of-day (the
    ``base_dim`` input features) into ``embed_dim`` channels using a bank of
    fixed Fourier frequencies. This gives the model a natural positional
    encoding over the solar-year and diurnal cycles.

    The returned embeddings are intended to be *added* to the input
    embeddings by the caller; this module performs no addition itself.
    """

    def __init__(self, embed_dim=128, base_dim=4):
        """Initialize the fixed temporal encoding.

        Args:
            embed_dim: Output embedding dimension. Default is 128. Many
                vision transformers use embedding dimensions that are
                multiples of 64 (e.g., 64, 128, 256); this can be tuned.
                Must be an even multiple of ``2 * base_dim`` so each base
                feature gets a whole number of sin/cos frequency pairs.
            base_dim: Dimension of the input cyclical encodings of
                day-of-year and hour-of-day. Default is 4
                (sin/cos(doy), sin/cos(hod)).

        Raises:
            ValueError: If ``embed_dim`` is not divisible by ``2 * base_dim``.
        """
        super().__init__()

        self.embed_dim = embed_dim
        self.base_dim = base_dim

        # Each base feature is expanded to F frequencies, each contributing a
        # sin and a cos channel, so embed_dim == base_dim * 2 * F must hold.
        if embed_dim % (2 * base_dim) != 0:
            raise ValueError(
                "embed_dim must be an even multiple of 2*base_dim for fixed "
                f"encoding. Got embed_dim: {embed_dim} and base_dim: {base_dim}."
            )

        num_frequencies = embed_dim // (2 * base_dim)
        # Fixed typo: attribute was previously misspelled "num_freqencies".
        self.num_frequencies = num_frequencies
        # Integer frequencies 1..F. Registered as a buffer so the tensor
        # follows the module across devices/dtypes without being trainable.
        self.register_buffer(
            "freqs", torch.arange(1, num_frequencies + 1, dtype=torch.float32)
        )

    def forward(self, time_features):
        """Create encodings of size ``embed_dim`` from base cyclic features.

        Args:
            time_features: Tensor of shape (B, M, T, D), where D is
                ``base_dim``.

        Returns:
            Tensor of shape (B, M, T, embed_dim): per base feature,
            ``num_frequencies`` sin channels followed by
            ``num_frequencies`` cos channels, flattened over the last axes.
        """
        B, M, T, D = time_features.shape

        # (B, M, T, D, 1) * (1, 1, 1, 1, F) -> (B, M, T, D, F)
        scaled = time_features.unsqueeze(-1) * self.freqs.view(1, 1, 1, 1, -1)

        # (B, M, T, D, 2F), then flatten feature/frequency axes to embed_dim.
        emb_encode = torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=-1)
        return emb_encode.view(B, M, T, -1)
@@ -153,8 +226,10 @@ def __init__(self, embed_dim=128, max_days=31, max_months=12): """ super().__init__() + self.time_embed = CyclicTimeEmbedding(embed_dim=embed_dim) + # Positional encodings for days and months - self.pos_days = TemporalPositionalEncoding(embed_dim, max_len=max_days) + #self.pos_days = TemporalPositionalEncoding(embed_dim, max_len=max_days) REMOVE THIS AND REPLACE WITH time_embed self.pos_months = TemporalPositionalEncoding(embed_dim, max_len=max_months) # Day scorer (within each month) @@ -182,7 +257,7 @@ def __init__(self, embed_dim=128, max_days=31, max_months=12): nn.Linear(4 * embed_dim, embed_dim), ) - def forward(self, x, M, T, H, W, padded_days_mask=None): + def forward(self, x, M, T, H, W, time_features, padded_days_mask=None): """ Args: x: (B, M, T, H, W, C) containing spatio-temporal tokens, where C is the embedding dimension. @@ -190,6 +265,7 @@ def forward(self, x, M, T, H, W, padded_days_mask=None): T: number of temporal tokens per month after temporal patching (Tp) H: spatial height after spatial patching W: spatial width after spatial patching + time_features: (B,M,T,4) containing cyclically encoded DOY and HOD padded_days_mask: Optional boolean tensor of shape (B, M, T), bool, True indicating which day tokens are padded (because some months have fewer days). This is used to mask out padded tokens in attention computation. 
@@ -201,10 +277,13 @@ def forward(self, x, M, T, H, W, padded_days_mask=None): # Reshape to (B, Hp*Wp, M, Tp, C) for temporal processing seq = x.permute(0, 3, 4, 1, 2, 5).reshape(B, Hp * Wp, M, Tp, C) - pe_days = self.pos_days(T).to(seq.device).to(seq.dtype) # (T, C) + temp_emb = self.time_embed(time_features) # (B,M,T,emd_dim) + #expand spatially + temp_emb = temp_emb[:, None, :, :, :] #[B, 1, M, T, C] + temp_emb = temp_emb.expand(-1, H*W, -1, -1, -1) pe_months = self.pos_months(M).to(seq.device).to(seq.dtype) # (M, C) - seq = seq + pe_days[None, None, None, :, :] # add day PE + seq = seq + temp_emb # add temporal embeddings seq = seq + pe_months[None, None, :, None, :] # add month PE # Day attention per month @@ -554,13 +633,15 @@ def __init__( ) self.patch_size = patch_size - def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None): + def forward(self, daily_data, daily_mask, daily_timef, land_mask_patch, padded_days_mask=None,): """Forward pass of the Spatio-Temporal model. Args: daily_data: Tensor of shape (B, C, M, T, H, W) containing daily data, where C is the number of channels (e.g., 1 for SST) daily_mask: Boolean tensor of same shape as daily_data indicating missing values + daily_timef: Tensor of shape (B, M, T, 4) containing the cyclically encoded day-of-year + and hour-of-day information for the daily data land_mask_patch: Boolean tensor of shape (B, H, W) to mask land areas in the output padded_days_mask: Optional boolean tensor of shape (B, M, T) indicating which day tokens are padded (True for padded tokens). Used to mask out padded tokens in temporal attention. 
@@ -582,6 +663,9 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None ) assert T % self.patch_size[0] == 0, "T must be divisible by patch size" + if self.patch_size[0] > 1: + daily_timef = daily_timef.view(B, M, Tp, self.patch_size[0], 4).mean(dim=3) # -> (B,M, Tp, 4) + if padded_days_mask is not None and self.patch_size[0] > 1: B, M, T_days = padded_days_mask.shape if T_days % self.patch_size[0] != 0: @@ -611,7 +695,7 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None latent = latent.view(B, M, Tp, Hp, Wp, embed_dim) agg_latent = self.temporal( - latent, M, Tp, Hp, Wp, padded_days_mask=padded_days_mask + latent, M, Tp, Hp, Wp, daily_timef, padded_days_mask=padded_days_mask ) # (B, M, Hp*Wp, embed_dim) # Step 3: Add spatial positional encodings and mix spatial features diff --git a/climanet/train.py b/climanet/train.py index 4344d12..504329d 100644 --- a/climanet/train.py +++ b/climanet/train.py @@ -111,6 +111,7 @@ def train_monthly_model( pred = model( batch["daily_patch"], batch["daily_mask_patch"], + batch["daily_timef_patch"], batch["land_mask_patch"], batch["padded_days_mask"], ) # (B, M, H, W) diff --git a/climanet/utils.py b/climanet/utils.py index 0d5688b..bb019fb 100644 --- a/climanet/utils.py +++ b/climanet/utils.py @@ -93,6 +93,7 @@ def add_month_day_dims( daily_m : xr.DataArray - dims: (M, T, H, W) monthly_m : xr.DataArray - dims: (M, H, W) padded_days_mask : xr.DataArray - dims: (M, T=31), bool, True where day is padded + time_features : xr.DataArray - dims: (M, T, 4) """ # Month key as integer YYYYMM dkey = daily_ts[time_dim].dt.year * 100 + daily_ts[time_dim].dt.month @@ -126,7 +127,43 @@ def add_month_day_dims( .sel(M=month_keys) ) - return daily_indexed, monthly_m, padded_days_mask + #----------------------------------------- + # Build aligned datetime array (M,T) + time_da = daily_ts[time_dim] + + #time_indexed is (M,T) with NaT for padded days + time_indexed = ( + 
time_da.assign_coords(M=(time_dim, dkey.values), + T=(time_dim, time_da.dt.day.values)) + .set_index({time_dim: ("M", "T")}) + .unstack(time_dim) + .reindex(T=np.arange(1,32), M=month_keys) + ) + #------------------------------------------- + + #determine day-of-year (doy) [and hour-of-day (hod) if applicable], fill NaT with 0 inplace + doy_period = 365.0 + hod_period = 24.0 + + doy = time_indexed.dt.dayofyear.fillna(0) + + if "hour" in dir(time_indexed.dt): + hod = time_indexed.dt.hour.fillna(0) + else: + hod = xr.zeros_like(doy) + + #Create cyclic encodings + doy_sin = np.sin(2*np.pi*doy/doy_period) + doy_cos = np.cos(2*np.pi*doy/doy_period) + hod_sin = np.sin(2*np.pi*hod/hod_period) + hod_cos = np.cos(2*np.pi*hod/hod_period) + + #Stack cyclic encodings into time_features (M,T,4) + time_features = xr.concat([doy_sin,doy_cos,hod_sin,hod_cos], + dim="feature" + ).transpose("M","T","feature") + + return daily_indexed, monthly_m, padded_days_mask, time_features def pred_to_numpy(pred, orig_H=None, orig_W=None, land_mask=None): diff --git a/notebooks/example.ipynb b/notebooks/example.ipynb index 4143170..0c4f78c 100644 --- a/notebooks/example.ipynb +++ b/notebooks/example.ipynb @@ -29,12 +29,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "13a3b0c8-1d92-460d-84a4-a3a59ca081af", "metadata": {}, "outputs": [], "source": [ - "data_folder = Path(\"./eso4clima\")\n", + "data_folder = Path(\"../../data/output\")\n", "\n", "file_names = [data_folder / \"202001_day_ERA5_masked_ts.nc\", data_folder / \"202002_day_ERA5_masked_ts.nc\"]\n", "daily_data = xr.open_mfdataset(file_names)\n", @@ -99,7 +99,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "bcc04777-5235-4ef3-81bd-2bdcafd8baaa", "metadata": {}, "outputs": [], @@ -112,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "09eeabbe-36ef-46a4-ad39-b82559a2da2e", "metadata": {}, "outputs": [], @@ -136,7 +136,7 
@@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "e84304f1-deb2-4f7c-b026-9ee4bbb38272", "metadata": {}, "outputs": [ @@ -144,33 +144,33 @@ "name": "stdout", "output_type": "stream", "text": [ - "Epoch 0: best_loss = 1.064811\n", - "Epoch 20: best_loss = 0.848913\n", - "Epoch 40: best_loss = 0.742154\n", - "Epoch 60: best_loss = 0.590368\n", - "Epoch 80: best_loss = 0.474284\n", - "Epoch 100: best_loss = 0.373293\n", - "Epoch 120: best_loss = 0.313690\n", - "Epoch 140: best_loss = 0.262067\n", - "Epoch 160: best_loss = 0.233068\n", - "Epoch 180: best_loss = 0.198232\n", - "Epoch 200: best_loss = 0.174762\n", - "Epoch 220: best_loss = 0.157859\n", - "Epoch 240: best_loss = 0.144779\n", - "Epoch 260: best_loss = 0.134052\n", - "Epoch 280: best_loss = 0.127401\n", - "Epoch 300: best_loss = 0.122035\n", - "Epoch 320: best_loss = 0.117025\n", - "Epoch 340: best_loss = 0.112626\n", - "Epoch 360: best_loss = 0.108094\n", - "Epoch 380: best_loss = 0.105740\n", - "Epoch 400: best_loss = 0.103568\n", - "Epoch 420: best_loss = 0.101469\n", - "Epoch 440: best_loss = 0.099437\n", - "Epoch 460: best_loss = 0.097428\n", - "Epoch 480: best_loss = 0.095427\n", - "Epoch 500: best_loss = 0.093429\n", - "Training complete. 
Best loss: 0.093429\n", + "Epoch 0: best_loss = 1.036764\n", + "Epoch 20: best_loss = 0.821099\n", + "Epoch 40: best_loss = 0.671408\n", + "Epoch 60: best_loss = 0.518432\n", + "Epoch 80: best_loss = 0.403573\n", + "Epoch 100: best_loss = 0.332223\n", + "Epoch 120: best_loss = 0.286744\n", + "Epoch 140: best_loss = 0.246473\n", + "Epoch 160: best_loss = 0.203978\n", + "Epoch 180: best_loss = 0.166832\n", + "Epoch 200: best_loss = 0.144872\n", + "Epoch 220: best_loss = 0.126880\n", + "Epoch 240: best_loss = 0.115007\n", + "Epoch 260: best_loss = 0.105522\n", + "Epoch 280: best_loss = 0.098541\n", + "Epoch 300: best_loss = 0.092489\n", + "Epoch 320: best_loss = 0.084668\n", + "Epoch 340: best_loss = 0.081160\n", + "Epoch 360: best_loss = 0.077812\n", + "Epoch 380: best_loss = 0.074545\n", + "Epoch 400: best_loss = 0.072011\n", + "Epoch 420: best_loss = 0.069377\n", + "Epoch 440: best_loss = 0.066885\n", + "Epoch 460: best_loss = 0.065280\n", + "Epoch 480: best_loss = 0.063793\n", + "Epoch 500: best_loss = 0.062360\n", + "Training complete. Best loss: 0.062360\n", "Model saved to runs/best_model.pth\n" ] } @@ -192,7 +192,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "id": "bda9f068", "metadata": {}, "outputs": [], @@ -210,7 +210,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "id": "7c2deb40-bee8-4973-80f0-9d9485eabf0c", "metadata": {}, "outputs": [ @@ -232,7 +232,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 9, "id": "012e01ac-caf1-47f7-a3e2-7b06bb260bfd", "metadata": {}, "outputs": [ @@ -788,8 +788,8 @@ " * lat (lat) float32 640B -29.88 -29.62 -29.38 ... 9.375 9.625 9.875\n", " * lon (lon) float32 640B -49.88 -49.62 -49.38 ... -10.38 -10.12\n", "Data variables:\n", - " predictions (time, lat, lon) float32 205kB 0.0 298.3 298.1 ... 0.0 0.0 0.0