86 changes: 65 additions & 21 deletions example/timesfm_2p5.ipynb

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion leaderboard/monash_moirai.csv
@@ -27,4 +27,6 @@ kaggle_web_traffic_weekly (W-SUN),104.91,43.93m,677892827.0147512,1691.418769624
temperature_rain (D),106.72,193.2m,731.8126906485586,5.994333053214229,0.4581582907670293,0.249426455236986,27.05203671904499,0.0036437877732579,0.7951260719877393,0.0656396716800804,0.694280646566257,131.2446799187715,2.7762765877611284
solar_4_seconds (4s),181.73,265.02m,3.56522408907416e-10,1.5385409490183596e-05,1.5385409490183597,1.5385409490183597,1.888180099745297e-05,1.8881800997452969,1.0573307614461165,1.5385409490183597,1.5385409490183597,0.0174618302899513,0.0159100078192055
wind_4_seconds (4s),184.15,295.86m,3.108434090470448,1.6458506820258083,12.77600689413882,0.082432241288121,1.763075180039254,0.5342635870919336,0.0786722471393581,0.6037483766474667,0.0811029831305508,5.187621701262427,1.674276021405027
kaggle_web_traffic (D),639.07,1270.6m,2452766448416.5938,20043.06505466295,0.29456241553874846,2.4090639022504874,1566131.044458475,0.0005082301009184855,0.5116392842774276,7.789624846643337,0.45522290224502393,2109413088817.9578,14272.143354708753
kaggle_web_traffic (D),639.07,1270.6m,2452766448416.5938,20043.06505466295,0.2945624155387484,2.4090639022504874,1566131.044458475,0.0005082301009184,0.5116392842774276,7.789624846643337,0.4552229022450239,2109413088817.9573,14272.143354708753
wind_farms_minutely (min),835.99,3349.23m,16.206971220492438,0.7136967852770562,0.01016127294468,0.0144428278254881,4.0257882731823385,0.0061153134085223,0.0171983034046891,0.0055462004428248,0.0083765532898652,12.73310756721593,0.5549177864411421
weather (D),1180.74,2492.3m,23.087079294192066,0.43981675735280407,0.11220441599089351,0.09457813572814078,4.804901590479462,0.00866736117753027,0.08763642437944673,0.023134484880568202,0.10417793350460773,2.903852911979907,0.18787836223668095
29 changes: 21 additions & 8 deletions src/samay/dataset.py
@@ -2056,6 +2056,7 @@ def __init__(
stride=10,
context_len=512,
horizon_len=96,
normalize=True,
**kwargs,
):
super().__init__(
@@ -2071,6 +2072,7 @@ def __init__(

self.stride = stride
self.boundaries = boundaries
self.normalize = normalize

self.pad = False
self._read_data()
@@ -2093,15 +2095,15 @@ def _read_data(self):
if self.boundaries[2] == 0:
self.boundaries[2] = int(len(self.df) - 1)

scaler = StandardScaler()
self.scaler = StandardScaler()
if self.boundaries == [-1, -1, -1]:
# use all data for training
self.boundaries = [0, 0, len(self.df) - 1]

scaler = scaler.fit(self.df)
self.scaler = self.scaler.fit(self.df)
else:
# fit the scaler on the training data
scaler = scaler.fit(self.df[slice(0, self.boundaries[0]), :])
self.scaler = self.scaler.fit(self.df[slice(0, self.boundaries[0]), :])


if self.mode == "train":
@@ -2110,7 +2112,7 @@ def _read_data(self):
elif self.mode == "test":
self.data = self.df[slice(self.boundaries[1], self.boundaries[2]), :]

self.data = scaler.transform(self.data)
self.data = self.scaler.transform(self.data)

self.length_timeseries = self.data.shape[0]
self.required_len = self.context_len + self.horizon_len
@@ -2150,7 +2152,7 @@ def __getitem__(self, index):
return input_seq, forecast_seq

elif self.task_name == "finetune":
pred_end = seq_end + 1
pred_end = seq_end + self.horizon_len
if pred_end > self.length_timeseries:
pred_end = self.length_timeseries
seq_end = pred_end - 1
@@ -2160,8 +2162,8 @@
seq_start:seq_end, channel_idx
] # shape: (context_len, )
forecast_seq = self.data[seq_end:pred_end, channel_idx]
loss_mask = np.ones(input_seq.shape[0])
return input_seq, forecast_seq, loss_mask
# loss_mask = np.ones(input_seq.shape[0])
return input_seq, forecast_seq
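
Note: the fix above widens the finetune target from a single step (`seq_end + 1`) to the full horizon (`seq_end + self.horizon_len`). A minimal sketch of the corrected windowing, using the defaults above (context_len=512, horizon_len=96); the series, stride, and index values are illustrative:

import numpy as np

data = np.arange(1000, dtype=np.float32)      # illustrative 1-D series
context_len, horizon_len, stride, index = 512, 96, 10, 3

seq_start = stride * index
seq_end = seq_start + context_len
pred_end = seq_end + horizon_len              # old code: seq_end + 1
input_seq = data[seq_start:seq_end]           # shape: (context_len,)
forecast_seq = data[seq_end:pred_end]         # shape: (horizon_len,)
assert forecast_seq.shape == (horizon_len,)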

def __len__(self):
if self.length_timeseries < self.context_len + self.horizon_len:
@@ -2173,4 +2175,15 @@ def get_data_loader(self):
return DataLoader(self, shuffle=True, batch_size=self.batchsize)
else:
return DataLoader(self, shuffle=False, batch_size=self.batchsize)
# shape: (batch_size, n_channels, seq_len)

def _denormalize_data(self, data: np.ndarray):
data = np.asarray(data)
if self.normalize:
data = data[:, : self.n_channels, :]
data_flatten = np.transpose(data, (0, 2, 1)).reshape(-1, self.n_channels)
return self.scaler.inverse_transform(data_flatten).reshape(
data.shape[0], data.shape[1], data.shape[2]
)
else:
return data
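
Usage sketch for the new `normalize` flag and `_denormalize_data`, assuming scikit-learn's `StandardScaler` (the scaler is fit on the training split in channel-last layout, so (batch, n_channels, horizon) arrays are flattened to (batch * horizon, n_channels) before `inverse_transform`); the shapes are illustrative, and the sketch transposes back after the inverse transform so the channel and time axes are restored:

import numpy as np
from sklearn.preprocessing import StandardScaler

n_channels, horizon_len = 3, 96
train_split = np.random.randn(1000, n_channels)        # stands in for self.df
scaler = StandardScaler().fit(train_split)

preds = np.random.randn(8, n_channels, horizon_len)    # normalized-space forecasts
flat = np.transpose(preds, (0, 2, 1)).reshape(-1, n_channels)
denorm = scaler.inverse_transform(flat)
denorm = denorm.reshape(8, horizon_len, n_channels).transpose(0, 2, 1)
assert denorm.shape == preds.shape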
162 changes: 42 additions & 120 deletions src/samay/model.py
@@ -266,7 +266,10 @@ def evaluate(self, dataset, **kwargs):
trues = dataset._denormalize_data(trues)
preds = dataset._denormalize_data(preds)
histories = dataset._denormalize_data(histories)
quantiles = dataset._denormalize_data(q for q in quantiles)
new_quantiles = []
for i in range(quantiles.shape[0]):
new_quantiles.append(dataset._denormalize_data(quantiles[i]))
quantiles = np.array(new_quantiles)
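
The removed line fed `_denormalize_data` a generator expression, which numpy treats as a single object rather than an array of forecasts; the loop above denormalizes one quantile level at a time instead. An equivalent one-liner, assuming the quantile axis comes first:

quantiles = np.array([dataset._denormalize_data(q) for q in quantiles])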

mse = MSE(trues, preds)
mae = MAE(trues, preds)
@@ -2594,14 +2597,15 @@ class TimesFM_2p5_Model(Basemodel):

def __init__(self, config=None, repo=None, **kwargs):
super().__init__(config=config, repo=repo)
self.model = TimesFM_2p5_200M_torch(device=self.device)

if repo:
self.model.load_checkpoint(hf_repo_id=repo)
self.model = TimesFM_2p5_200M_torch.from_pretrained(repo, device=self.device)
else:
self.model.load_checkpoint()
self.model = TimesFM_2p5_200M_torch(device=self.device)

self.config = ForecastConfig(**config)
self.model.compile(self.config)
self.quantiles = self.model.model.config.quantiles
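
A hedged usage sketch of the revised constructor: with `repo` set, weights now load via `from_pretrained`; otherwise a freshly initialized `TimesFM_2p5_200M_torch` is used, and the compiled model's quantile levels are exposed as `self.quantiles`. The repo id and config fields below are illustrative, not mandated by this diff:

config = {"max_context": 512, "max_horizon": 96}   # illustrative ForecastConfig fields
model = TimesFM_2p5_Model(config=config, repo="google/timesfm-2.5-200m-pytorch")
print(model.quantiles)                             # quantile levels used by evaluate()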

def evaluate(self, dataset, **kwargs):
"""Evaluate the model on the given dataset.
@@ -2614,6 +2618,7 @@ def evaluate(self, dataset, **kwargs):
"""
# self.model.to(self.device)
self.model.model.eval()
self.model.model.to(self.device)
inference_loader = dataset.get_data_loader()
horizon = dataset.horizon_len

@@ -2633,6 +2638,7 @@
input_seq,
mask_seq,
)
quantile_forecast = quantile_forecast[..., 1:].transpose(2, 0, 1)

trues.append(target_seq.cpu().numpy())
preds.append(point_forecast)
@@ -2641,9 +2647,17 @@

trues = np.concatenate(trues, axis=0).reshape(-1, dataset.n_channels, dataset.horizon_len)
preds = np.concatenate(preds, axis=0).reshape(-1, dataset.n_channels, dataset.horizon_len)
q_preds = np.concatenate(q_preds, axis=0).reshape(q_preds[-1].shape[-1], -1, dataset.n_channels, dataset.horizon_len)
q_preds = np.concatenate(q_preds, axis=1).reshape(q_preds[-1].shape[0], -1, dataset.n_channels, dataset.horizon_len)
histories = np.concatenate(histories, axis=0).reshape(-1, dataset.n_channels, dataset.context_len)

trues = dataset._denormalize_data(trues)
preds = dataset._denormalize_data(preds)
new_q_preds = np.zeros_like(q_preds)
for i in range(q_preds.shape[0]):
new_q_preds[i] = dataset._denormalize_data(q_preds[i])
q_preds = new_q_preds
histories = dataset._denormalize_data(histories)

# Calculate metrics
mse = MSE(trues, preds)
mae = MAE(trues, preds)
@@ -2654,8 +2668,8 @@
smape = SMAPE(trues, preds)
msis = MSIS(trues, preds)
nd = ND(trues, preds)
mwsq = MWSQ(trues, preds, q_preds)
crps = CRPS(trues, preds, q_preds)
mwsq = MWSQ(trues, q_preds, quantiles=self.quantiles)
crps = CRPS(trues, q_preds, quantiles=self.quantiles)
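
Both probabilistic metrics now take the stacked quantile forecasts plus the model's quantile levels. For reference, CRPS over a finite set of quantiles is commonly approximated as twice the mean pinball loss across levels; a minimal numpy sketch (not the `CRPS` implementation imported by this module):

import numpy as np

def crps_from_quantiles(y, q_preds, levels):
    # y: targets of shape (...); q_preds: (Q, ...) stacked quantile
    # forecasts; levels: sequence of Q quantile levels in (0, 1).
    q = np.asarray(levels).reshape(-1, *([1] * y.ndim))
    diff = y[None, ...] - q_preds
    pinball = np.maximum(q * diff, (q - 1.0) * diff)
    return 2.0 * pinball.mean()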

cleanup_dataloader(inference_loader)
return {
@@ -2698,130 +2712,38 @@ def finetune(self, dataset, **kwargs):
epoch = 5 if "epoch" not in kwargs else kwargs["epoch"]

optimizer = torch.optim.AdamW(self.model.model.parameters(), lr=lr)
self.model.model.to(self.device)
# self.model.model.to(self.device)
self.model.model.train()
dataloader = dataset.get_data_loader()
horizon = dataset.horizon_len
# torch.autograd.set_detect_anomaly(True)

for ep in range(epoch):
total_loss = 0
for i, batch in enumerate(dataloader):
for batch in dataloader:
input_seq, target_seq = batch
batch_size = input_seq.shape[0]
input_seq = input_seq.float().to(self.device)
input_seq = input_seq.float()
target_seq = target_seq.float().to(self.device)
mask_seq = torch.zeros_like(input_seq).to(self.device)
mask_seq = torch.zeros_like(input_seq)

optimizer.zero_grad()
# perform auto-regressive forward pass
num_decode_steps = (dataset.horizon_len - 1) // self.model.model.o
num_input_patches = dataset.context_len // self.model.model.p
decode_cache_size = num_input_patches + num_decode_steps * self.model.model.m

# Prefill
patched_inputs = torch.reshape(input_seq, (batch_size, -1, self.p))
patched_masks = torch.reshape(mask_seq, (batch_size, -1, self.p))

# running stats
n = torch.zeros(batch_size, device=input_seq.device)
mu = torch.zeros(batch_size, device=input_seq.device)
sigma = torch.zeros(batch_size, device=input_seq.device)
patch_mu = []
patch_sigma = []
for i in range(num_input_patches):
(n, mu, sigma), _ = util.update_running_stats(
n, mu, sigma, patched_inputs[:, i], patched_masks[:, i]
)
patch_mu.append(mu)
patch_sigma.append(sigma)
last_n, last_mu, last_sigma = n, mu, sigma
context_mu = torch.stack(patch_mu, dim=1)
context_sigma = torch.stack(patch_sigma, dim=1)
decode_caches = [
util.DecodeCache(
next_index=torch.zeros(
batch_size, dtype=torch.int32, device=input_seq.device
),
num_masked=torch.zeros(
batch_size, dtype=torch.int32, device=input_seq.device
),
key=torch.zeros(
batch_size,
decode_cache_size,
self.model.model.h,
self.model.model.hd,
device=input_seq.device,
),
value=torch.zeros(
batch_size,
decode_cache_size,
self.model.model.h,
self.model.model.hd,
device=input_seq.device,
),
)
for _ in range(self.model.model.x)
]

normed_inputs = util.revin(
patched_inputs, context_mu, context_sigma, reverse=False
)
normed_inputs = torch.where(patched_masks, 0.0, normed_inputs)
(_, _, normed_outputs, normed_quantile_spread), decode_caches = self.model.model(
normed_inputs, patched_masks, decode_caches
)
renormed_outputs = torch.reshape(
util.revin(normed_outputs, context_mu, context_sigma, reverse=True),
(batch_size, -1, self.model.model.o, self.model.model.q),
point_forecast, quantile_forecast = self.model.compiled_decode(
horizon,
input_seq,
mask_seq,
train=True
)
renormed_quantile_spread = torch.reshape(
util.revin(
normed_quantile_spread, context_mu, context_sigma, reverse=True
),
(batch_size, -1, self.model.model.os, self.model.model.q),
)[:, -1, ...]

# Autoregressive decode
ar_outputs = []
last_renormed_output = renormed_outputs[:, -1, :, self.model.model.aridx]

for _ in range(num_decode_steps):
new_patched_input = torch.reshape(
last_renormed_output, (batch_size, self.model.model.m, self.model.model.p)
)
new_mask = torch.zeros_like(new_patched_input, dtype=torch.bool)

n, mu, sigma = last_n, last_mu, last_sigma
new_mus, new_sigmas = [], []
for i in range(self.model.model.m):
(n, mu, sigma), _ = util.update_running_stats(
n, mu, sigma, new_patched_input[:, i], new_mask[:, i]
)
new_mus.append(mu)
new_sigmas.append(sigma)
last_n, last_mu, last_sigma = n, mu, sigma
new_mu = torch.stack(new_mus, dim=1)
new_sigma = torch.stack(new_sigmas, dim=1)

new_normed_input = util.revin(
new_patched_input, new_mu, new_sigma, reverse=False
)
(_, _, new_normed_output, _), decode_caches = self.model.model(
new_normed_input, new_mask, decode_caches
)

new_renormed_output = torch.reshape(
util.revin(new_normed_output, new_mu, new_sigma, reverse=True),
(batch_size, self.model.model.m, self.model.model.o, self.model.model.q),
)
ar_outputs.append(new_renormed_output[:, -1, ...])
last_renormed_output = new_renormed_output[:, -1, :, self.model.model.aridx]

if num_decode_steps > 0:
ar_renormed_outputs = torch.stack(ar_outputs, dim=1)
else:
ar_renormed_outputs = None

forecast_seq = ar_renormed_outputs if ar_renormed_outputs is not None else renormed_outputs[:, :, :dataset.horizon_len, :]
# quantile_forecast = quantile_forecast[..., 1:]
# point_forecast, quantile_forecast = torch.Tensor(point_forecast), torch.Tensor(quantile_forecast)

loss = nn.functional.mse_loss(point_forecast, target_seq)
quantiles = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
for i, quantile in enumerate(quantiles):
last_patch_quantile = quantile_forecast[:, :, i + 1]
loss += torch.mean(
quantile_loss(last_patch_quantile, target_seq.squeeze(-1),
quantile))

loss.backward()
optimizer.step()
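
`quantile_loss` is not shown in this diff; a standard pinball formulation consistent with the call above, where the caller reduces with `torch.mean` (a sketch, not necessarily the project's exact definition):

import torch

def quantile_loss(pred: torch.Tensor, target: torch.Tensor, q: float) -> torch.Tensor:
    # Elementwise pinball loss: penalizes under-prediction by q and
    # over-prediction by (1 - q).
    diff = target - pred
    return torch.maximum(q * diff, (q - 1.0) * diff)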
1 change: 1 addition & 0 deletions src/samay/models/timesfm/timesfm/v2/configs.py
@@ -94,6 +94,7 @@ class TransformerConfig:
use_bias: bool
use_rotary_position_embeddings: bool
ff_activation: Literal["relu", "swish", "none"]
fuse_qkv: bool


@dataclasses.dataclass(frozen=True)
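
The new `fuse_qkv` flag presumably switches attention to computing query, key, and value with one fused projection rather than three separate ones; a minimal sketch of that pattern (not this repo's attention implementation):

import torch
import torch.nn as nn

class FusedQKV(nn.Module):
    # One (d_model -> 3 * d_model) matmul replaces three separate
    # (d_model -> d_model) projections, typically improving throughput.
    def __init__(self, d_model: int, use_bias: bool = False):
        super().__init__()
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=use_bias)

    def forward(self, x: torch.Tensor):
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        return q, k, v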
1 change: 1 addition & 0 deletions src/samay/models/timesfm/timesfm/v2/timesfm_2p5_base.py
@@ -107,6 +107,7 @@ class TimesFM_2p5_200M_Definition:
use_bias=False,
use_rotary_position_embeddings=True,
ff_activation="swish",
fuse_qkv=True,
),
)
output_projection_point: ResidualBlockConfig = ResidualBlockConfig(