In [2]:
# %%
%load_ext autoreload
%autoreload 2

import torch
from src.dataset import load_and_preprocess_data
from src.models import EALSTM
from src.training import train_epoch, evaluate

# %%
# 1. Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 20
HIDDEN_DIM = 256
LEARNING_RATE = 1e-3

print(f"Using device: {DEVICE}")

# %%
# 2. Load Data
# This takes a minute as it creates sliding windows
train_loader, test_loader, station_ids = load_and_preprocess_data(sequence_length=365)

# %%
# 3. Initialize Model
# Dynamic Features: Precip, Tmax, Tmin (3)
# Static Features: Area, Glacier% (2)
model = EALSTM(input_dim_dyn=3, 
               input_dim_stat=2, 
               hidden_dim=HIDDEN_DIM).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

print(model)

# %%
# 4. Training Loop
for epoch in range(EPOCHS):
    train_loss = train_epoch(model, train_loader, optimizer, DEVICE)
    test_nse = evaluate(model, test_loader, DEVICE)
    
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Test NSE: {test_nse:.4f}")

# %%
# 5. Save Model
torch.save(model.state_dict(), "../models/ealstm_run1.pth")
print("Model saved.")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Using device: cpu
⏳ Loading datasets...
   Aligning dates and stations...
   Common Stations: 132
   Master Index Range: 1979-01-01 to 2022-12-31
   Normalizing features...
   Generating Training Sequences...
   Processing 8401 time steps...
   Generating Testing Sequences...
   Processing 7305 time steps...
✅ Data Ready.
   Train Samples: 1108932
   Test Samples:  964260
EALSTM(
  (input_gate_net): Linear(in_features=2, out_features=256, bias=True)
  (w_f): Linear(in_features=259, out_features=256, bias=True)
  (w_g): Linear(in_features=259, out_features=256, bias=True)
  (w_o): Linear(in_features=259, out_features=256, bias=True)
  (dropout): Dropout(p=0.4, inplace=False)
  (head): Linear(in_features=256, out_features=1, bias=True)
)


Training:   0%|          | 8/4332 [00:11<1:40:01,  1.39s/it]


KeyboardInterrupt: 