In [12]:
import sys
from pathlib import Path
from importlib import reload

project_root = Path.cwd().parent.parent
sys.path.insert(0, str(project_root))

import pandas as pd
import src.model
import src.trainer
import src.visualization

reload(src.model)
reload(src.trainer)
reload(src.visualization)

from src.model import TanaForecast
from src.trainer import TimeSeriesDataset, TanaForecastTrainer

train_df = pd.read_csv(project_root / 'src' / 'datasets' / 'delhi' / 'train.csv')
test_df = pd.read_csv(project_root / 'src' / 'datasets' / 'delhi' / 'test.csv')

train_df.describe()

Unnamed: 0,meantemp,humidity,wind_speed,meanpressure
count,1462.0,1462.0,1462.0,1462.0
mean,25.495521,60.771702,6.802209,1011.104548
std,7.348103,16.769652,4.561602,180.231668
min,6.0,13.428571,0.0,-3.041667
25%,18.857143,50.375,3.475,1001.580357
50%,27.714286,62.625,6.221667,1008.563492
75%,31.305804,72.21875,9.238235,1014.944901
max,38.714286,100.0,42.22,7679.333333


In [13]:
import torch

feature_cols = ['meantemp', 'humidity', 'wind_speed', 'meanpressure']
target_cols = ['meantemp']

train_dataset = TimeSeriesDataset(
    df=train_df,
    context_window=90,
    prediction_length=7,
    feature_columns=feature_cols,
    target_columns=target_cols,
    stride=1,
    normalize=True
)

val_dataset = TimeSeriesDataset(
    df=test_df,
    context_window=90,
    prediction_length=7,
    feature_columns=feature_cols,
    target_columns=target_cols,
    stride=1,
    normalize=True
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Val dataset size: {len(val_dataset)}")

context, target = train_dataset[0]
print(f"Context shape: {context.shape}")
print(f"Target shape: {target.shape}")


Train dataset size: 1366
Val dataset size: 18
Context shape: torch.Size([4, 90])
Target shape: torch.Size([1, 7])


In [14]:
from src.model import TanaForecast
from src.utils import Loss

model = TanaForecast(
    context_window=90,
    prediction_length=7,
)

trainer = TanaForecastTrainer(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    batch_size=64,
    learning_rate=1e-3,
    num_epochs=3,
    checkpoint_dir=str(project_root / 'checkpoints' / 'delhi'),
    early_stopping_patience=-1,
    loss_fn=Loss.quantile_loss,   # Direct reference
    loss_name='Quantile_0.9',      # Custom name for logging
    loss_kwargs={'q': 0.9}         # Quantile level
)

number_of_parameters = model.number_of_parameters
print(f"Number of parameters: {number_of_parameters}")
history = trainer.train()


Number of parameters: 13828698
Training on cpu
Total epochs: 3
Starting from epoch: 1
Batch size: 64
Train batches: 22
Val batches: 1
------------------------------------------------------------
Epoch 1/3 | Train Loss: 0.113533 | Val Loss: 0.125457 | LR: 7.53e-04 | Time: 10.10s
  → New best model saved (Val Loss: 0.125457)
Epoch 2/3 | Train Loss: 0.064423 | Val Loss: 0.204692 | LR: 2.58e-04 | Time: 11.23s
Epoch 3/3 | Train Loss: 0.056170 | Val Loss: 0.189643 | LR: 1.00e-05 | Time: 11.84s
------------------------------------------------------------
Training completed. Best Val Loss: 0.125457
Training run logged to /Users/amaurydelille/Documents/projects/tana-forecast/src/logs/training_logs.csv


In [15]:
from src.visualization import plot_training_history

fig = plot_training_history(history)
fig.show()

print(f"\nBest Validation Loss: {min(history['val_loss']):.6f}")
print(f"Final Train Loss: {history['train_loss'][-1]:.6f}")



Best Validation Loss: 0.125457
Final Train Loss: 0.056170


In [16]:
from src.visualization import plot_forecast, compute_metrics, print_metrics

context, target = val_dataset[10]
prediction = trainer.predict(context.unsqueeze(0)).squeeze(0)

fig = plot_forecast(
    context=context,
    target=target,
    prediction=prediction,
    dataset=val_dataset,
    title='Delhi Temperature Forecast (7-day ahead)',
    feature_idx=0,
    feature_name='Temperature (°C)'
)
fig.show()

target_denorm = val_dataset.denormalize(target, is_target=True)
prediction_denorm = val_dataset.denormalize(prediction, is_target=True)

metrics = compute_metrics(target_denorm[0], prediction_denorm[0])
print_metrics(metrics)


MSE:  3.6558
MAE:  1.5480
RMSE: 1.9120
MAPE: 0.0495
SMAPE: 0.0514
