In [1]:
# Library imports
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import numpy as np
from math import sqrt
from tqdm import tqdm
import gc

# Importing model, data loaders and other utilities
from models.neo_phurie import NeoPHURIE
from models.oracle_phurie import OraclePHURIE
from dataLoaders.embedding_loader import EmbeddingDataset
from utils import OneCycleLR

# Constants
# Compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Mini-batch size of the training DataLoader (also scales the progress bar).
BATCH_SIZE = 32
# File handed to EmbeddingDataset for both training and testing.
DATASET_PATH = "./datasets/embedding_dataset_2015.data"
# Factor applied to absolute intensity errors in test_oracle; presumably the
# standard deviation used to normalise intensities -- TODO confirm value/units.
INTENSITY_STD = 23.12


## Model Training

In [2]:
def train_model(model,
                num_epochs,
                learning_rate=(1e-5, 1e-4),
                regularization_rate=0,
                train_years=None,
                verbose=False):
    """Train ``model`` in place with SGD, a one-cycle LR schedule and AMP.

    Args:
        model: Module to optimise; it is moved to the global ``device``.
        num_epochs: Number of passes over the training data.
        learning_rate: Pair of learning rates. The optimiser is created with
            the first entry and the whole pair is handed to ``OneCycleLR`` as
            ``lr_range``.
        regularization_rate: ``weight_decay`` given to the SGD optimiser.
        train_years: Years forwarded to ``EmbeddingDataset``. ``None`` (the
            default) is forwarded as an empty list.
        verbose: When true, show a progress bar and print the RMSE of every
            epoch.

    Returns:
        None. The model is updated in place.
    """
    # Memory management
    gc.collect()
    torch.cuda.empty_cache()

    # Training Dataset (a fresh list replaces the former mutable default)
    years = [] if train_years is None else train_years
    train_loader = DataLoader(EmbeddingDataset(DATASET_PATH, years),
                              batch_size=BATCH_SIZE, shuffle=True)

    # Training optimizer, scheduler and loss
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate[0],
                                weight_decay=regularization_rate)
    scheduler = OneCycleLR(optimizer,
                           num_steps=num_epochs * len(train_loader),
                           lr_range=learning_rate)
    criterion = nn.MSELoss()

    # Mixed precision only makes sense on CUDA; disabling it explicitly on CPU
    # avoids the "CUDA is not available" warnings and keeps plain fp32 maths.
    use_amp = device.type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # Moving model to compute device
    model.to(device)

    # Training loop
    for epoch in range(num_epochs):

        # Training Phase
        model.train()
        train_squared_residuals = 0.0
        samples_seen = 0

        # Progress bar
        pbar = tqdm(enumerate(train_loader, 0), total=len(train_loader),
                    unit="pairs", unit_scale=BATCH_SIZE, disable=not verbose)

        # Iterate through batches in training dataset
        for i, data in pbar:
            x, y = data
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()

            with autocast(enabled=use_amp):
                outputs = model(x)
                loss = criterion(outputs, y)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # Running statistics. Each batch mean is weighted by the number of
            # samples it actually holds: the last batch can be smaller than
            # BATCH_SIZE, and weighting it by BATCH_SIZE inflated the RMSE.
            batch_len = x.shape[0]
            train_squared_residuals += loss.item() * batch_len
            samples_seen += batch_len
            pbar.set_description(f"[Epoch {epoch + 1}]: Training RMSE = {sqrt(train_squared_residuals / samples_seen):.4f}")

        train_mse = train_squared_residuals / max(samples_seen, 1)

        if verbose:
            print(f"[Epoch {epoch + 1}]: Training RMSE = {sqrt(train_mse):.4f}")

        # Memory management
        gc.collect()
        torch.cuda.empty_cache()

In [3]:
# Create a fresh OraclePHURIE and train it on the years 2001-2014
# (the upper bound of range() is exclusive).
oracle = OraclePHURIE()
train_model(
    oracle,
    num_epochs=5,
    learning_rate=(1e0, 1e1),
    regularization_rate=1e-5,
    train_years=range(2001, 2015),
    verbose=True,
)

[Epoch 1]: Training RMSE = 0.1411: 100%|██████████| 52192/52192 [00:28<00:00, 1850.43pairs/s]


[Epoch 1]: Training RMSE = 0.1411


[Epoch 2]: Training RMSE = 0.1279: 100%|██████████| 52192/52192 [00:28<00:00, 1848.28pairs/s]


[Epoch 2]: Training RMSE = 0.1279


[Epoch 3]: Training RMSE = 0.1268: 100%|██████████| 52192/52192 [00:26<00:00, 1963.16pairs/s]


[Epoch 3]: Training RMSE = 0.1268


[Epoch 4]: Training RMSE = 0.1232: 100%|██████████| 52192/52192 [00:27<00:00, 1881.02pairs/s]


[Epoch 4]: Training RMSE = 0.1233


[Epoch 5]: Training RMSE = 0.1163: 100%|██████████| 52192/52192 [00:27<00:00, 1876.24pairs/s]


[Epoch 5]: Training RMSE = 0.1164


## Model Testing

In [23]:
def test_oracle(oracle,
                decoder,
                testing_years=None,
                max_forecast_steps=10,
                useMeasuredIntensities=False):
    """Evaluate autoregressive intensity forecasts of ``oracle``.

    For every pair of consecutive time steps of every hurricane, the oracle is
    fed the two most recent embeddings (features concatenated with an
    intensity) and its prediction is fed back in as the newest embedding. The
    absolute error between the predicted and the measured intensity is
    collected per forecast offset and scaled by ``INTENSITY_STD``.

    Args:
        oracle: Forecasting model; called on the concatenation of two
            embeddings and expected to return the next embedding.
        decoder: Model providing ``predict_from_embedding`` to estimate the
            intensities used to seed the forecast.
        testing_years: Years forwarded to ``EmbeddingDataset``. ``None`` (the
            default) is forwarded as an empty list.
        max_forecast_steps: Largest forecast offset (in time steps) evaluated.
        useMeasuredIntensities: Seed the forecast with the measured
            intensities ``y`` instead of the decoder's predictions.

    Returns:
        A list with one ``(mean, Q1, median, Q3)`` tuple of absolute errors
        per offset ``1..max_forecast_steps``. An offset for which no forecast
        could be made (no hurricane is long enough) yields a tuple of NaNs.
    """
    # Position of the intensity inside an oracle output. Presumably the last
    # element, i.e. the feature vector has 1152 entries -- TODO confirm.
    intensity_index = 1152

    decoder.to(device)
    decoder.eval()
    oracle.to(device)
    oracle.eval()

    years = [] if testing_years is None else testing_years
    testing_dataset = EmbeddingDataset(DATASET_PATH, years)

    errors_at_offset = {offset: [] for offset in range(1, max_forecast_steps + 1)}

    # Materialise the hurricanes once: tqdm can then read the total from the
    # list instead of the dataset being iterated a second time just to count.
    hurricanes = list(testing_dataset.get_hurricanes())

    with torch.no_grad():
        for x, y in tqdm(hurricanes):
            x, y = x.to(device), y.to(device)

            # Get NeoPHURIE predictions for intensities
            pred_y = decoder.predict_from_embedding(x)
            seed_intensities = y if useMeasuredIntensities else pred_y

            num_steps = x.shape[0]
            for start_idx in range(num_steps - 2):

                # Embeddings with which to start forecasting. torch.cat
                # already returns a new tensor on the inputs' device, so no
                # clone or extra device transfer is required.
                embedding_1 = torch.cat([x[start_idx],
                                         seed_intensities[start_idx]])
                embedding_2 = torch.cat([x[start_idx + 1],
                                         seed_intensities[start_idx + 1]])

                # Stop at the requested horizon or at the end of the track,
                # whichever comes first.
                last_offset = min(max_forecast_steps, num_steps - start_idx - 2)
                for offset in range(1, last_offset + 1):

                    pred_embedding = oracle(torch.cat([embedding_1, embedding_2]))
                    pred_intensity = pred_embedding[intensity_index]

                    errors_at_offset[offset].append(
                        abs(pred_intensity.item() - y[start_idx + offset + 1].item()) * INTENSITY_STD)

                    # Slide the two-step window forward by one step.
                    embedding_1 = embedding_2
                    embedding_2 = pred_embedding

    test_statistics = []
    for offset in range(1, max_forecast_steps + 1):
        errors = errors_at_offset[offset]
        if not errors:
            # np.quantile raises on empty input; report "no data" explicitly.
            test_statistics.append((float("nan"),) * 4)
            continue
        q1, median, q3 = np.quantile(errors, [0.25, 0.5, 0.75])
        test_statistics.append((np.mean(errors), q1, median, q3))

    return test_statistics

In [29]:
# Restore the trained NeoPHURIE decoder. map_location remaps the stored
# tensors onto the active device, so a checkpoint saved from a GPU also loads
# on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
decoder = NeoPHURIE()
decoder.load_state_dict(torch.load("./checkpoints/neo_phurie_2015.pt", map_location=device))

# Forecast up to 24 steps ahead on the 2015 data, seeding the oracle with the
# decoder's predicted intensities rather than the measured ones.
statistics = test_oracle(oracle,
                         decoder,
                         testing_years = [2015],
                         max_forecast_steps=24,
                         useMeasuredIntensities=False)

# One forecast step is reported as 3 hours of lead time.
for i, stat in enumerate(statistics):
    print(f"Forecasting at {3*i + 3} hours: MAE = {stat[0]:.2f} ; Q1 = {stat[1]:.2f} ; Median = {stat[2]:.2f} ; Q3 = {stat[3]:.2f}")

100%|██████████| 83/83 [01:48<00:00,  1.31s/it]

Forecasting at 3 hours: MAE = 7.12 ; Q1 = 2.25 ; Median = 5.16 ; Q3 = 9.88
Forecasting at 6 hours: MAE = 7.35 ; Q1 = 2.35 ; Median = 5.37 ; Q3 = 10.21
Forecasting at 9 hours: MAE = 7.71 ; Q1 = 2.41 ; Median = 5.60 ; Q3 = 10.63
Forecasting at 12 hours: MAE = 8.23 ; Q1 = 2.61 ; Median = 6.01 ; Q3 = 11.31
Forecasting at 15 hours: MAE = 8.85 ; Q1 = 2.76 ; Median = 6.39 ; Q3 = 12.26
Forecasting at 18 hours: MAE = 9.56 ; Q1 = 2.95 ; Median = 6.86 ; Q3 = 13.34
Forecasting at 21 hours: MAE = 10.31 ; Q1 = 3.10 ; Median = 7.43 ; Q3 = 14.30
Forecasting at 24 hours: MAE = 11.12 ; Q1 = 3.46 ; Median = 7.94 ; Q3 = 15.40
Forecasting at 27 hours: MAE = 11.89 ; Q1 = 3.68 ; Median = 8.38 ; Q3 = 16.61
Forecasting at 30 hours: MAE = 12.71 ; Q1 = 3.85 ; Median = 8.92 ; Q3 = 17.77
Forecasting at 33 hours: MAE = 13.55 ; Q1 = 3.98 ; Median = 9.53 ; Q3 = 18.91
Forecasting at 36 hours: MAE = 14.48 ; Q1 = 4.36 ; Median = 10.39 ; Q3 = 20.28
Forecasting at 39 hours: MAE = 15.39 ; Q1 = 4.62 ; Median = 11.12 ; Q3 = 


