In [1]:
# %%
# Re-import edited project modules (src.dataset, src.models, src.training, ...)
# automatically before each cell runs, so edits to those files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

from pathlib import Path

import pandas as pd
import torch

from src.config import CLIMATE_OUTPUT_DIR
from src.dataset import load_and_preprocess_data
from src.models import EALSTM
from src.training import train_epoch, evaluate

# %%
# 1. Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EPOCHS = 20
HIDDEN_DIM = 256
LEARNING_RATE = 1e-3
SEED = 42

# Seed torch's RNG (covers CPU and all CUDA devices) so that weight
# initialisation and any torch-driven shuffling repeat across
# Restart & Run All.
# NOTE(review): if src.dataset / src.training also draw from numpy or
# `random`, seed those generators here as well.
torch.manual_seed(SEED)

print(f"Using device: {DEVICE}")

# %%
# 2. Load Data
# Builds sliding windows of SEQUENCE_LENGTH time steps; this takes about a minute.
# NOTE(review): the last recorded run raised
# "ValueError: all input arrays must have the same shape" from this call,
# right after it printed "Normalizing features...". The error appears to
# originate inside src.dataset.load_and_preprocess_data, not in this cell.
SEQUENCE_LENGTH = 365  # window length in steps; 365 ≈ one year if the inputs are daily

train_loader, test_loader, station_ids = load_and_preprocess_data(sequence_length=SEQUENCE_LENGTH)

# %%
# 3. Initialize Model
# The EA-LSTM input sizes match the feature counts fed by the data loader:
#   dynamic inputs: Precip, Tmax, Tmin -> 3
#   static inputs:  Area, Glacier%     -> 2
model = EALSTM(
    input_dim_dyn=3,
    input_dim_stat=2,
    hidden_dim=HIDDEN_DIM,
)
model = model.to(DEVICE)

# Adam optimiser over all of the model's parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

print(model)

# %%
# 4. Training Loop
# Each epoch calls train_epoch on the training loader (it receives the
# optimizer, so presumably performs the parameter updates), then scores the
# model on the test loader and logs both numbers.
# NOTE(review): only train/test loaders exist here. Do not pick the epoch
# count or other settings from this Test NSE, or it stops being an unbiased
# held-out score.
for epoch in range(EPOCHS):
    train_loss = train_epoch(model, train_loader, optimizer, DEVICE)
    test_nse = evaluate(model, test_loader, DEVICE)
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {train_loss:.4f} | Test NSE: {test_nse:.4f}")

# %%
# 5. Save Model
# torch.save does not create missing directories, so ensure ../models exists
# before writing the weights.
# NOTE(review): the path is relative to the kernel's working directory;
# confirm it resolves to the intended models/ folder.
MODEL_PATH = Path("../models/ealstm_run1.pth")
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)
print("Model saved.")

Using device: cpu
⏳ Loading datasets...
   Aligning dates and stations...
   Common Period: 1980-01-01 to 2022-12-31
   Common Stations: 132
   Normalizing features...


ValueError: all input arrays must have the same shape

In [2]:
# Load the daily precipitation table for inspection, e.g. to chase the shape
# mismatch raised during data loading above. The first CSV column becomes a
# parsed datetime index; the remaining columns hold one series per station ID.
# (pandas and CLIMATE_OUTPUT_DIR are imported in the notebook's top import block.)
precip = pd.read_csv(CLIMATE_OUTPUT_DIR / "daily_precipitation.csv", index_col=0, parse_dates=True)

In [3]:
# Preview the first five rows of the precipitation table.
precip.head(n=5)

Unnamed: 0_level_0,05AA004,05AA008,05AA022,05AA027,05AA028,05AB005,05AB029,05AD003,05AD035,05AE005,...,07JC001,07JD002,07JD003,07JF002,07KE001,07OA001,07OB003,07OB004,07OB006,07OC001
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1980-01-01,0.053606,0.232041,0.289149,0.06342,0.325748,2.1e-05,0.000505,0.294787,0.0,0.0017,...,0.256983,0.291562,0.286798,0.304178,0.392809,0.204253,0.227905,0.576973,0.527004,0.166458
1980-01-02,0.538679,1.546677,1.204349,1.178653,1.196,0.529095,0.34853,0.871475,0.159057,0.290331,...,0.458887,0.281282,0.902249,1.393473,1.181054,2.244732,2.325565,2.020966,1.933861,2.018898
1980-01-03,0.656272,1.329921,1.280606,0.723043,1.460468,0.587772,0.575559,1.53485,0.745324,0.288815,...,3.655144,3.306906,2.916684,3.817195,1.616598,4.219817,3.562104,2.722252,2.868103,3.963051
1980-01-04,0.034958,0.198072,0.176594,0.077786,0.189435,0.003662,0.001656,0.146171,0.0,0.002382,...,1.88757,2.166312,2.898976,1.858584,3.171192,1.7014,1.713192,1.737582,1.960607,1.615711
1980-01-05,5.758264,4.711472,5.823813,4.281838,6.383781,4.125207,4.206116,7.228878,5.601971,6.614492,...,0.179671,0.618457,0.301169,0.201826,0.431614,0.259102,0.239105,0.322188,0.365159,0.197012
