Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion climanet/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@ def __init__(

# Reshape daily → (M, T=31, H, W), monthly → (M, H, W),
# and get padded_days_mask → (M, T=31)
daily_mt, monthly_m, padded_days_mask = add_month_day_dims(
daily_mt, monthly_m, padded_days_mask, daily_timef = add_month_day_dims(
daily_da, monthly_da, time_dim=time_dim
)

# Convert to numpy once — all __getitem__ calls use these
self.daily_np = daily_mt.to_numpy().copy() # (M, T=31, H, W) float
self.monthly_np = monthly_m.to_numpy().copy() # (M, H, W) float
self.padded_mask_np = padded_days_mask.to_numpy().copy() # (M, T=31) bool
self.daily_timef_np = daily_timef.to_numpy().copy() # (M,T=31, 4)

# Store coordinate arrays
self.lat_coords = daily_da[spatial_dims[0]].to_numpy().copy()
Expand Down Expand Up @@ -123,6 +124,7 @@ def __getitem__(self, idx):
daily_nan_mask = self.daily_nan_mask[
:, :, i : i + ph, j : j + pw
] # (M, T, H, W)
daily_timef_patch = self.daily_timef_np # (M,T,4)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
daily_timef_patch = self.daily_timef_np # (M,T,4)


if self.land_mask_np is not None:
land_patch = self.land_mask_np[i : i + ph, j : j + pw] # (H, W)
Expand All @@ -137,6 +139,8 @@ def __getitem__(self, idx):
monthly_tensor = torch.from_numpy(monthly_patch).float()
# (1, M, T, H, W)
daily_nan_mask = torch.from_numpy(daily_nan_mask).unsqueeze(0)
# ( M, T, 4)
daily_timef_tensor = torch.from_numpy(daily_timef_patch).float()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
daily_timef_tensor = torch.from_numpy(daily_timef_patch).float()
daily_timef_tensor = torch.from_numpy(self.daily_timef_np).float()


# daily_mask: NaN locations that are NOT land
# Reshape land_tensor for broadcasting: (H, W) → (1, 1, 1, H, W)
Expand All @@ -154,6 +158,7 @@ def __getitem__(self, idx):
"monthly_patch": monthly_tensor, # (M, H, W)
"daily_mask_patch": daily_mask_tensor, # (C=1, M, T=31, H, W)
"land_mask_patch": land_tensor, # (H,W) True=Land
"daily_timef_patch": daily_timef_tensor, #(M, T=31, 4)
"padded_days_mask": self.padded_days_tensor, # (M, T=31) True=padded
"coords": (i, j),
"lat_patch": lat_patch, # (H,)
Expand Down
1 change: 1 addition & 0 deletions climanet/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def predict_monthly_var(
predictions = model(
batch["daily_patch"].to(device, non_blocking=use_cuda),
batch["daily_mask_patch"].to(device, non_blocking=use_cuda),
batch["daily_timef_patch"].to(device,non_blocking=use_cuda),
batch["land_mask_patch"].to(device, non_blocking=use_cuda),
batch["padded_days_mask"].to(device, non_blocking=use_cuda),
)
Expand Down
96 changes: 90 additions & 6 deletions climanet/st_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,79 @@ def forward(self, x, mask):
x = self.drop(x)
return x # (B, N_patches, embed_dim)

class CyclicTimeEmbedding(nn.Module):
    """Cyclical temporal encoding built from day-of-year and hour-of-day
    sine/cosine features.

    This module generates fixed (non-learnable) sinusoidal temporal encodings
    for the temporal dimension, using the cyclically encoded day-of-year (doy)
    and hour-of-day (hod) values extracted from the datetimes associated with
    the input. This is a natural positional encoding on the temporal cycles of
    the solar year and the diurnal cycle.

    Fixed Fourier frequencies expand the base cyclic encoding up to the
    embedding dimension. The returned encodings are intended to be added to
    the input embeddings by the caller; the module does not perform the
    addition.

    NOTE(review): the input features are expected to already be
    [sin(doy), cos(doy), sin(hod), cos(hod)]. Applying sin/cos again in
    ``forward`` — i.e. sin(k * sin(phase)) — is not a standard Fourier
    feature expansion. Consider feeding raw phases (2*pi*doy/period,
    2*pi*hod/period) with base_dim=2 instead — TODO confirm with the
    upstream data pipeline before changing.
    """

    def __init__(self, embed_dim=128, base_dim=4):
        """Initialize the fixed temporal encoding.

        Args:
            embed_dim: Dimension of the embedding (default 128). Many vision
                transformers use embedding dimensions that are multiples of
                64 (e.g. 64, 128, 256); this can be tuned. Must be an even
                multiple of ``2 * base_dim``.
            base_dim: Dimension of the input cyclical encodings of doy and
                hod. Defaults to 4 (sin/cos(doy), sin/cos(hod)).

        Raises:
            ValueError: if ``embed_dim`` is not divisible by ``2 * base_dim``.
        """
        super().__init__()

        self.embed_dim = embed_dim
        self.base_dim = base_dim

        # Number of Fourier frequencies F so that base_dim * 2F == embed_dim.
        # Guard clause instead of the nested if/else; error message keeps a
        # space between sentences (was "encoding.Got").
        if embed_dim % (2 * base_dim) != 0:
            raise ValueError(
                f"embed_dim must be an even multiple of 2*base_dim for fixed "
                f"encoding. Got embed_dim: {embed_dim} and base_dim: {base_dim}."
            )
        num_frequencies = embed_dim // (2 * base_dim)
        self.num_frequencies = num_frequencies  # fixed: was misspelled/unused
        # Frequencies 1..F; registered as a buffer so it moves with .to(device)
        # and is saved in state_dict but is not a learnable parameter.
        self.register_buffer(
            "freqs", torch.linspace(1.0, num_frequencies, num_frequencies)
        )

    def forward(self, time_features):
        """Create encodings of size ``embed_dim`` from the base time features.

        Args:
            time_features: (B, M, T, D) tensor, where D is ``base_dim``.

        Returns:
            emb_encode: (B, M, T, embed_dim) tensor of fixed sinusoidal
            encodings.
        """
        B, M, T, D = time_features.shape

        # (B, M, T, D, 1) * (1, 1, 1, 1, F) -> (B, M, T, D, F)
        x = time_features.unsqueeze(-1) * self.freqs.view(1, 1, 1, 1, -1)

        # (B, M, T, D, 2F): per feature, F sine values then F cosine values.
        emb_encode = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)

        # Flatten (D, 2F) -> embed_dim.
        return emb_encode.view(B, M, T, -1)



class TemporalPositionalEncoding(nn.Module):
"""Temporal Positional Encoding using sine and cosine functions.
Expand Down Expand Up @@ -153,8 +226,10 @@ def __init__(self, embed_dim=128, max_days=31, max_months=12):
"""
super().__init__()

self.time_embed = CyclicTimeEmbedding(embed_dim=embed_dim)

# Positional encodings for days and months
self.pos_days = TemporalPositionalEncoding(embed_dim, max_len=max_days)
#self.pos_days = TemporalPositionalEncoding(embed_dim, max_len=max_days) REMOVE THIS AND REPLACE WITH time_embed
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#self.pos_days = TemporalPositionalEncoding(embed_dim, max_len=max_days) REMOVE THIS AND REPLACE WITH time_embed

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that time_embed already implicitly captures day-within-month, we can remove self.pos_days.

self.pos_months = TemporalPositionalEncoding(embed_dim, max_len=max_months)

# Day scorer (within each month)
Expand Down Expand Up @@ -182,14 +257,15 @@ def __init__(self, embed_dim=128, max_days=31, max_months=12):
nn.Linear(4 * embed_dim, embed_dim),
)

def forward(self, x, M, T, H, W, padded_days_mask=None):
def forward(self, x, M, T, H, W, time_features, padded_days_mask=None):
"""
Args:
x: (B, M, T, H, W, C) containing spatio-temporal tokens, where C is the embedding dimension.
M: number of months
T: number of temporal tokens per month after temporal patching (Tp)
H: spatial height after spatial patching
W: spatial width after spatial patching
time_features: (B,M,T,4) containing cyclically encoded DOY and HOD
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change the dimension 4 to 2 after fixing the time_features

padded_days_mask: Optional boolean tensor of shape (B, M, T), bool,
True indicating which day tokens are padded (because some months
have fewer days). This is used to mask out padded tokens in attention computation.
Expand All @@ -201,10 +277,13 @@ def forward(self, x, M, T, H, W, padded_days_mask=None):
# Reshape to (B, Hp*Wp, M, Tp, C) for temporal processing
seq = x.permute(0, 3, 4, 1, 2, 5).reshape(B, Hp * Wp, M, Tp, C)

pe_days = self.pos_days(T).to(seq.device).to(seq.dtype) # (T, C)
temp_emb = self.time_embed(time_features) # (B,M,T,emd_dim)
#expand spatially
temp_emb = temp_emb[:, None, :, :, :] #[B, 1, M, T, C]
temp_emb = temp_emb.expand(-1, H*W, -1, -1, -1)
pe_months = self.pos_months(M).to(seq.device).to(seq.dtype) # (M, C)

seq = seq + pe_days[None, None, None, :, :] # add day PE
seq = seq + temp_emb # add temporal embeddings
seq = seq + pe_months[None, None, :, None, :] # add month PE

# Day attention per month
Expand Down Expand Up @@ -554,13 +633,15 @@ def __init__(
)
self.patch_size = patch_size

def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None):
def forward(self, daily_data, daily_mask, daily_timef, land_mask_patch, padded_days_mask=None,):
"""Forward pass of the Spatio-Temporal model.

Args:
daily_data: Tensor of shape (B, C, M, T, H, W) containing daily
data, where C is the number of channels (e.g., 1 for SST)
daily_mask: Boolean tensor of same shape as daily_data indicating missing values
daily_timef: Tensor of shape (B, M, T, 4) containing the cyclically encoded day-of-year
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change the dimension 4 to 2 after fixing the time_features

and hour-of-day information for the daily data
land_mask_patch: Boolean tensor of shape (B, H, W) to mask land areas in the output
padded_days_mask: Optional boolean tensor of shape (B, M, T) indicating which day tokens are padded
(True for padded tokens). Used to mask out padded tokens in temporal attention.
Expand All @@ -582,6 +663,9 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None
)
assert T % self.patch_size[0] == 0, "T must be divisible by patch size"

if self.patch_size[0] > 1:
daily_timef = daily_timef.view(B, M, Tp, self.patch_size[0], 4).mean(dim=3) # -> (B,M, Tp, 4)

if padded_days_mask is not None and self.patch_size[0] > 1:
B, M, T_days = padded_days_mask.shape
if T_days % self.patch_size[0] != 0:
Expand Down Expand Up @@ -611,7 +695,7 @@ def forward(self, daily_data, daily_mask, land_mask_patch, padded_days_mask=None
latent = latent.view(B, M, Tp, Hp, Wp, embed_dim)

agg_latent = self.temporal(
latent, M, Tp, Hp, Wp, padded_days_mask=padded_days_mask
latent, M, Tp, Hp, Wp, daily_timef, padded_days_mask=padded_days_mask
) # (B, M, Hp*Wp, embed_dim)

# Step 3: Add spatial positional encodings and mix spatial features
Expand Down
1 change: 1 addition & 0 deletions climanet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def train_monthly_model(
pred = model(
batch["daily_patch"],
batch["daily_mask_patch"],
batch["daily_timef_patch"],
batch["land_mask_patch"],
batch["padded_days_mask"],
) # (B, M, H, W)
Expand Down
39 changes: 38 additions & 1 deletion climanet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def add_month_day_dims(
daily_m : xr.DataArray - dims: (M, T, H, W)
monthly_m : xr.DataArray - dims: (M, H, W)
padded_days_mask : xr.DataArray - dims: (M, T=31), bool, True where day is padded
time_features : xr.DataArray - dims: (M, T, 4)
"""
# Month key as integer YYYYMM
dkey = daily_ts[time_dim].dt.year * 100 + daily_ts[time_dim].dt.month
Expand Down Expand Up @@ -126,7 +127,43 @@ def add_month_day_dims(
.sel(M=month_keys)
)

return daily_indexed, monthly_m, padded_days_mask
#-----------------------------------------
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#-----------------------------------------

# Build aligned datetime array (M,T)
time_da = daily_ts[time_dim]

#time_indexed is (M,T) with NaT for padded days
time_indexed = (
time_da.assign_coords(M=(time_dim, dkey.values),
T=(time_dim, time_da.dt.day.values))
.set_index({time_dim: ("M", "T")})
.unstack(time_dim)
.reindex(T=np.arange(1,32), M=month_keys)
)
#-------------------------------------------
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#-------------------------------------------


#determine day-of-year (doy) [and hour-of-day (hod) if applicable], fill NaT with 0 inplace
doy_period = 365.0
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

leap years will be slightly mis-phased here. We can add a comment about it here, or we can calculate it as:
days_in_year = xr.where(time_indexed.dt.is_leap_year, 366.0, 365.0).fillna(365.0)

hod_period = 24.0

doy = time_indexed.dt.dayofyear.fillna(0)

if "hour" in dir(time_indexed.dt):
hod = time_indexed.dt.hour.fillna(0)
else:
hod = xr.zeros_like(doy)

#Create cyclic encodings
doy_sin = np.sin(2*np.pi*doy/doy_period)
doy_cos = np.cos(2*np.pi*doy/doy_period)
hod_sin = np.sin(2*np.pi*hod/hod_period)
hod_cos = np.cos(2*np.pi*hod/hod_period)
Comment on lines +156 to +159
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if class CyclicTimeEmbedding applies sin/cos, these should be changed to raw phase as:

 doy_phase = (2*np.pi*doy/doy_period)
 hod_phase = (2*np.pi*hod/hod_period)


#Stack cyclic encodings into time_features (M,T,4)
time_features = xr.concat([doy_sin,doy_cos,hod_sin,hod_cos],
dim="feature"
).transpose("M","T","feature")

return daily_indexed, monthly_m, padded_days_mask, time_features


def pred_to_numpy(pred, orig_H=None, orig_W=None, land_mask=None):
Expand Down
Loading
Loading