In [394]:
import numpy as np
import polars as pl
import os
from pathlib import Path
from typing import List, Tuple, Union
from datetime import datetime, timedelta
from pydantic import BaseModel
from loguru import logger
from enum import Enum
import random


class EndStatus(Enum):
    DISSIPATED = 0
    MOVE_OUT_OF_RESPONSIBILITY = 1
    MERGED = 2
    NEARLY_STATIONARY = 3


class CycloneCategory(Enum):
    BELOW_TD_OR_UNKNOWN = 0
    TROPICAL_DEPRESSION = 1  # 热带低压 (TD, 10.8-17.1m/s)
    TROPICAL_STORM = 2  # 热带风暴 (TS, 17.2-24.4 m/s)
    SEVERE_TROPICAL_STORM = 3  # 强热带风暴 (STS, 24.5-32.6 m/s)
    TYPHOON = 4  # 台风 (TY, 32.7-41.4 m/s)
    SEVERE_TYPHOON = 5  # 强台风 (STY, 41.5-50.9 m/s)
    SUPER_TYPHOON = 6  # 超强台风 (SuperTY, ≥51.0 m/s)
    EXTRATROPICAL = 9  # 变性 (The change is complete)


class HurricaneHeader(BaseModel):
    data_type: int
    country_code: int
    data_count: int
    hurricane_code: int
    china_hurricane_code: int
    end_status: EndStatus
    time_interval_hr: int
    hurricane_name: str
    dataset_record_time: datetime


class HurricaneEntry(BaseModel):
    date: datetime
    category: CycloneCategory
    latitude: float
    longitude: float
    lowest_pressure: int
    wind_speed: int


class Hurricane(BaseModel):
    header: HurricaneHeader
    entries: List[HurricaneEntry]


script_folder = Path(os.getcwd())
dataset_folder = script_folder / "CMABSTdata"

# https://tcdata.typhoon.org.cn/zjljsjj.html
# example_file = dataset_folder / "CH2022BST.txt"
example_file = dataset_folder / "CH1950BST.txt"
logger.info(f"example_file: {example_file}")


def parse_header(line: str) -> HurricaneHeader:
    entry = line.split()
    data_type = int(entry[0])
    country_code = int(entry[1])
    data_count = int(entry[2])
    hurricane_code = int(entry[3])
    try:
        china_hurricane_code = int(entry[4])
    except ValueError:
        # might be a tuple (a,b)
        codes = entry[4].split(",")
        china_hurricane_code = int(codes[0])
    hurricane_end_enum = int(entry[5])
    end_status = EndStatus(hurricane_end_enum)
    time_interval_hr = int(entry[6])
    hurricane_name = entry[7]
    dataset_record_time = entry[8]
    time_format = "%Y%m%d"
    dataset_record_time = datetime.strptime(dataset_record_time, time_format)
    return HurricaneHeader(data_type=data_type,
                           country_code=country_code,
                           data_count=data_count,
                           hurricane_code=hurricane_code,
                           china_hurricane_code=china_hurricane_code,
                           end_status=end_status,
                           time_interval_hr=time_interval_hr,
                           hurricane_name=hurricane_name,
                           dataset_record_time=dataset_record_time)


def parse_entry(line: str) -> HurricaneEntry:
    entry = line.split()
    date_str = entry[0]
    time_format = "%Y%m%d%H"
    date = datetime.strptime(date_str, time_format)
    category = int(entry[1])
    hurricane_category = CycloneCategory(category)
    latitude = float(int(entry[2])) / 10.0
    longitude = float(int(entry[3])) / 10.0
    # in hPa
    lowest_pressure = int(entry[4])
    # 2分钟平均近中心最大风速(MSW, m/s)
    # WND=9 表示 MSW < 10m/s,
    # WND=0 为缺测
    wind_speed = int(entry[5])
    # not sure about OWD
    return HurricaneEntry(date=date,
                          category=hurricane_category,
                          latitude=latitude,
                          longitude=longitude,
                          lowest_pressure=lowest_pressure,
                          wind_speed=wind_speed)


def parse_dataset(filename):
    hurricanes: list[Hurricane] = []
    with open(filename, "r") as f:
        try:
            while True:
                # check if the line is empty
                l = f.readline()
                if not l:
                    break
                header = parse_header(l)
                count = header.data_count
                hurricane_entries = []
                for i in range(count):
                    entry = parse_entry(f.readline())
                    hurricane_entries.append(entry)
                hurricane = Hurricane(header=header, entries=hurricane_entries)
                hurricanes.append(hurricane)
        except ValueError as e:
            logger.error(f"ValueError: {e} for {filename}")
        except IndexError as e:
            logger.warning(f"IndexError: {e} for {filename}")
        except EOFError:
            logger.info(f"EOFError for {filename}")
    return hurricanes



[32m2024-04-23 02:44:08.413[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m63[0m - [1mexample_file: /home/crosstyan/code/hurricane_stuff/CMABSTdata/CH1950BST.txt[0m


In [395]:
total_dataset: list[Hurricane] = []

for file in dataset_folder.glob("*.txt"):
    hurricanes = parse_dataset(file)
    total_dataset.extend(hurricanes)

logger.info(f"total_dataset: {len(total_dataset)}")

[32m2024-04-23 02:44:10.182[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mtotal_dataset: 2469[0m


In [396]:
class FlatHurricaneEntry(BaseModel):
    sample_id: int
    name: str
    china_hurricane_code: int
    date: datetime
    category: CycloneCategory
    latitude: float
    longitude: float
    lowest_pressure: int
    wind_speed: int


def flat_hurricane_entries(
        hurricanes: list[Hurricane]) -> List[FlatHurricaneEntry]:
    counter = 0
    def flat_one(h: Hurricane, counter: int = counter):
        name = h.header.hurricane_name
        hurricane_code = h.header.hurricane_code
        entries = h.entries
        return [
            FlatHurricaneEntry(sample_id=counter,
                               name=name,
                               china_hurricane_code=hurricane_code,
                               date=e.date,
                               category=e.category,
                               latitude=e.latitude,
                               longitude=e.longitude,
                               lowest_pressure=e.lowest_pressure,
                               wind_speed=e.wind_speed) for e in entries
        ]

    entries = []
    for h in hurricanes:
        entries.extend(flat_one(h, counter))
        counter += 1
    return entries


flatten_entries = [
    e.model_dump() for e in flat_hurricane_entries(total_dataset)
]


def entry_enum_to_number(entry: dict[str, any]) -> dict[str, any]:
    entry['category'] = entry['category'].value
    return entry


flatten_entries_without_enum = [
    entry_enum_to_number(e) for e in flatten_entries
]

In [397]:
df = pl.DataFrame(flatten_entries_without_enum)
df_filtered = df.filter(df["wind_speed"] != 0)
df_filtered.describe()

statistic,sample_id,name,china_hurricane_code,date,category,latitude,longitude,lowest_pressure,wind_speed
str,f64,str,f64,str,f64,f64,f64,f64,f64
"""count""",65796.0,"""65796""",65796.0,"""65796""",65796.0,65796.0,65796.0,65796.0,65796.0
"""null_count""",0.0,"""0""",0.0,"""0""",0.0,0.0,0.0,0.0,0.0
"""mean""",1246.461107,,17.502265,"""1985-09-05 05:…",2.866664,20.73495,133.459558,984.894963,25.891893
"""std""",698.361356,,10.381192,,2.121499,8.752285,16.292624,21.13435,14.21218
"""min""",0.0,"""(nameless)""",1.0,"""1949-01-15 00:…",0.0,0.5,95.0,870.0,8.0
"""25%""",656.0,,9.0,"""1968-06-01 00:…",1.0,14.5,121.2,975.0,15.0
"""50%""",1232.0,,17.0,"""1984-10-31 00:…",2.0,19.4,131.7,992.0,20.0
"""75%""",1843.0,,25.0,"""2003-06-17 18:…",4.0,25.5,143.9,1000.0,35.0
"""max""",2468.0,"""Zola""",53.0,"""2022-12-13 06:…",9.0,70.1,243.9,1016.0,110.0


In [398]:
import math
time = df["date"][6]
assert isinstance(time, datetime)
# use sin/cos to normalize the day in a year and the hour in a day

def sinusoidal_hour_in_day(dt: datetime) -> tuple[float, float]:
    """
    Return sin and cos corresponding to the hour of day from a datetime object.
    """
    # Extract the hour from the datetime object
    hour = dt.hour

    # Calculate the radians for the given hour
    radians_per_hour = 2 * math.pi / 24
    hour_in_radians = hour * radians_per_hour

    # Return the sine and cosine values
    return math.sin(hour_in_radians), math.cos(hour_in_radians)

def sinusoidal_day_in_year(dt: datetime) -> tuple[float, float]:
    """
    Return sin and cos corresponding to the day of year from a datetime object.
    """
    # Extract the day of year from the datetime object
    day_of_year = dt.timetuple().tm_yday

    # Handle leap years
    year_length = 366 if dt.year % 4 == 0 and (dt.year % 100 != 0 or dt.year % 400 == 0) else 365

    # Calculate the radians for the given day of year
    radians_per_day = 2 * math.pi / year_length
    day_in_radians = day_of_year * radians_per_day

    # Return the sine and cosine values
    return math.sin(day_in_radians), math.cos(day_in_radians)

logger.info(f"{time} -> {sinusoidal_hour_in_day(time)} {sinusoidal_day_in_year(time)}")

[32m2024-04-23 02:44:16.048[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m37[0m - [1m2013-01-02 12:00:00 -> (1.2246467991473532e-16, -1.0) (0.03442161162274574, 0.9994074007397048)[0m


In [399]:
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler

# longitude and latitude
lat_scaler = RobustScaler()
latitude = lat_scaler.fit_transform(df_filtered["latitude"].to_numpy().reshape(-1, 1)).reshape(-1)
long_scaler = RobustScaler()
longitude = long_scaler.fit_transform(df_filtered["longitude"].to_numpy().reshape(-1, 1)).reshape(-1)

# wind speed
wind_scaler = StandardScaler()
wind_speed = wind_scaler.fit_transform(df_filtered["wind_speed"].to_numpy().reshape(-1, 1)).reshape(-1)

# lowest pressure
lowest_pressure_scaler = StandardScaler()
lowest_pressure = lowest_pressure_scaler.fit_transform(df_filtered["lowest_pressure"].to_numpy().reshape(-1, 1)).reshape(-1)

In [400]:
with_normalized_time = df_filtered.with_columns([
    df_filtered["date"].map_elements(lambda x: sinusoidal_day_in_year(x)[0]).alias("sin_day_in_year"),
    df_filtered["date"].map_elements(lambda x: sinusoidal_day_in_year(x)[1]).alias("cos_day_in_year"),
    df_filtered["date"].map_elements(lambda x: sinusoidal_hour_in_day(x)[0]).alias("sin_hour_in_day"),
    df_filtered["date"].map_elements(lambda x: sinusoidal_hour_in_day(x)[1]).alias("cos_hour_in_day"),
    pl.Series("latitude_norm", latitude),
    pl.Series("longitude_norm", longitude),
    pl.Series("wind_speed_norm", wind_speed),
    pl.Series("lowest_pressure_norm", lowest_pressure),
])

df_features = with_normalized_time.select([
    "sample_id",
    "sin_day_in_year",
    "cos_day_in_year",
    "sin_hour_in_day",
    "cos_hour_in_day",
    "latitude_norm",
    "longitude_norm",
    "wind_speed_norm",
    "lowest_pressure_norm",
])

  df_filtered["date"].map_elements(lambda x: sinusoidal_day_in_year(x)[0]).alias("sin_day_in_year"),
  df_filtered["date"].map_elements(lambda x: sinusoidal_day_in_year(x)[1]).alias("cos_day_in_year"),
  df_filtered["date"].map_elements(lambda x: sinusoidal_hour_in_day(x)[0]).alias("sin_hour_in_day"),
  df_filtered["date"].map_elements(lambda x: sinusoidal_hour_in_day(x)[1]).alias("cos_hour_in_day"),


In [401]:
df_features.height
df_features.describe()

statistic,sample_id,sin_day_in_year,cos_day_in_year,sin_hour_in_day,cos_hour_in_day,latitude_norm,longitude_norm,wind_speed_norm,lowest_pressure_norm
str,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""count""",65796.0,65796.0,65796.0,65796.0,65796.0,65796.0,65796.0,65796.0,65796.0
"""null_count""",0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""mean""",1246.461107,-0.498177,-0.264362,0.006157,0.003637,0.121359,0.077514,7.5743e-17,-2.4017e-16
"""std""",698.361356,0.550623,0.615432,0.705854,0.708332,0.795662,0.717737,1.000008,1.000008
"""min""",0.0,-0.999991,-1.0,-1.0,-1.0,-1.718182,-1.61674,-1.258922,-5.43645
"""25%""",656.0,-0.927542,-0.809017,0.0,-0.707107,-0.445455,-0.462555,-0.766383,-0.468197
"""50%""",1232.0,-0.699458,-0.413279,1.2246e-16,6.1232e-17,0.0,0.0,-0.41457,0.336187
"""75%""",1843.0,-0.263665,0.209315,1.0,1.0,0.554545,0.537445,0.640871,0.71472
"""max""",2468.0,0.999991,1.0,1.0,1.0,4.609091,4.942731,5.918075,1.471788


In [402]:
from numpy.typing import NDArray
from functools import reduce
# group by sample_id and iterate over the groups
grouped = df_features.group_by("sample_id")
from typing import Iterable, Iterator, Tuple, Union

EXPECTED_TIMESTAMP_COUNT = 20


def filter_out_short_sequence(id_and_df: tuple[int, pl.DataFrame]) -> bool:
    return id_and_df[1].height >= EXPECTED_TIMESTAMP_COUNT


def pad_or_truncate(
        id_and_df: tuple[int, pl.DataFrame]) -> tuple[pl.Series, pl.DataFrame]:
    group_id, df = id_and_df
    if df.height < EXPECTED_TIMESTAMP_COUNT:
        # pad with zeros
        diff = EXPECTED_TIMESTAMP_COUNT - df.height
        mask = pl.Series("mask", [True] * df.height + [False] * diff)
        zeros = pl.DataFrame({
            "sample_id": [group_id] * diff,
            "sin_day_in_year": [0.0] * diff,
            "cos_day_in_year": [0.0] * diff,
            "sin_hour_in_day": [0.0] * diff,
            "cos_hour_in_day": [0.0] * diff,
            "latitude_norm": [0.0] * diff,
            "longitude_norm": [0.0] * diff,
            "wind_speed_norm": [0.0] * diff,
            "lowest_pressure_norm": [0.0] * diff,
        })
        stacked = df.vstack(zeros)
        # sort by date
        return mask, stacked.sort("date")
    elif df.height >= EXPECTED_TIMESTAMP_COUNT:
        # truncate
        mask = pl.Series("mask", [True] * EXPECTED_TIMESTAMP_COUNT)
        return mask, df.head(EXPECTED_TIMESTAMP_COUNT)
    else:
        mask = pl.Series("mask", [True] * df.height)
        return mask, df


filtered = filter(filter_out_short_sequence, grouped)
padded = map(pad_or_truncate, filtered)


# for some reason, the reduce function is not working
def to_tensor(
        id_and_df: Iterable[tuple[int,
                                  pl.DataFrame]]) -> tuple[NDArray, NDArray]:
    init_mask, init_data = np.empty(
        (0, EXPECTED_TIMESTAMP_COUNT, 1)), np.empty(
            (0, EXPECTED_TIMESTAMP_COUNT, df_features.width))
    for mask, df in id_and_df:
        current_data = df.to_numpy()
        current_mask = np.expand_dims(mask.to_numpy(), axis=-1)
        try:
            new_data = np.vstack(
                (init_data, np.expand_dims(current_data, axis=0)))
            new_mask = np.vstack(
                (init_mask, np.expand_dims(current_mask, axis=0)))
        except ValueError as e:
            logger.error(f"ValueError: {e}")
            logger.info(
                f"init_data: {init_data.shape}, current_data: {current_data.shape}"
            )
            logger.info(
                f"init_mask: {init_mask.shape}, current_mask: {current_mask.shape}"
            )
            logger.info(
                f"init_data: {init_data}, current_data: {current_data}")
            logger.info(
                f"init_mask: {init_mask}, current_mask: {current_mask}")

        init_data, init_mask = new_data, new_mask
    return init_data, init_mask


data_with_id, mask = to_tensor(padded)
# remove the sample_id column
features = data_with_id[:, :, 1:]

  filtered = filter(filter_out_short_sequence, grouped)


In [403]:
import torch
logger.info(f"torch.cuda.is_available(): {torch.cuda.is_available()}")
logger.info(f"torch.cuda.current_device(): {torch.cuda.current_device()}")
logger.info(f"torch.cuda.device_count(): {torch.cuda.device_count()}")

[32m2024-04-23 02:44:19.304[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mtorch.cuda.is_available(): True[0m
[32m2024-04-23 02:44:19.306[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m3[0m - [1mtorch.cuda.current_device(): 0[0m
[32m2024-04-23 02:44:19.307[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mtorch.cuda.device_count(): 2[0m


In [404]:
diff_features = np.diff(features, axis=1)
X_train = diff_features[:, :10, :]
Y_train = diff_features[:, -10:, :]
display(X_train.shape, Y_train.shape)

(1599, 10, 8)

(1599, 10, 8)

In [405]:
# swap shape to (batch, features, seq) from (batch, seq, features)
X_train = np.swapaxes(X_train, 1, 2)
Y_train = np.swapaxes(Y_train, 1, 2)
display(X_train.shape, Y_train.shape)

(1599, 8, 10)

(1599, 8, 10)

In [406]:
from pytorch_tcn import TCN
from torchsummary import summary
import pytorch_tcn as tcn

num_inputs = X_train.shape[1]
num_outputs = Y_train.shape[1]

In [407]:
from typing import Any
import lightning as L
from lightning import LightningModule
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from torch import optim, nn, utils, Tensor
from lion_pytorch import Lion


class TCNModel(LightningModule):
    num_inputs: int
    num_outputs: int
    tcn: TCN
    linear: torch.nn.Linear

    def __init__(self, num_inputs: int, num_outputs: int):
        super(TCNModel, self).__init__()
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs
        self.tcn = TCN(num_inputs=num_inputs,
                       num_channels=[num_outputs] * num_inputs,
                       kernel_size=8,
                       dropout=0.1)
    
    def summary(self):
        return summary(self.tcn, (self.num_inputs, 10), verbose=0)

    def forward(self, x):
        x = self.tcn(x)
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.functional.mse_loss(y_hat, y)
        self.log("train_loss",
                 loss,
                 on_step=True,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        val_loss = nn.functional.mse_loss(y_hat, y)
        self.log("val_loss",
                 val_loss,
                 on_step=True,
                 on_epoch=True,
                 prog_bar=True,
                 logger=True)
        return {"val_loss": val_loss}
    
    def predict(self, x):
        return self(x)

    def configure_optimizers(self):
        optimizer = Lion(self.parameters(), lr=3e-4, weight_decay=1e-2, use_triton=True)
        return optimizer


In [408]:
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TCNModel(num_inputs, num_outputs)
# display(model.summary())

early_stopping_callback = EarlyStopping(monitor="val_loss", patience=4)
val_ckpt = ModelCheckpoint(monitor="val_loss",
                           dirpath="checkpoints",
                           filename="tcn-{epoch:02d}-{val_loss:.2f}",
                           auto_insert_metric_name=True,
                           save_top_k=3,
                           mode="min")
train_loss_ckpt = ModelCheckpoint(monitor="train_loss",
                                  dirpath="checkpoints",
                                  filename="tcn-{epoch:02d}-{train_loss:.2f}",
                                  auto_insert_metric_name=True,
                                  save_top_k=3,
                                  mode="min")
logger = TensorBoardLogger("logs", name="tcn")
trainer = Trainer(max_epochs=10,
                  logger=logger,
                  callbacks=[val_ckpt, train_loss_ckpt, early_stopping_callback])

Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [409]:
X_train_tensor = torch.tensor(X_train, dtype=torch.float32, device=device)
Y_train_tensor = torch.tensor(Y_train, dtype=torch.float32, device=device)
print("X_train_tensor.shape", X_train_tensor.shape)
print("Y_train_tensor.shape", Y_train_tensor.shape)
dataset = TensorDataset(X_train_tensor, Y_train_tensor)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

BATCH_SIZE = 8192
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
trainer.fit(model, train_dataset, val_dataset)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name | Type | Params
------------------------------
0 | tcn  | TCN  | 8.4 K 
------------------------------
8.4 K     Trainable params
0         Non-trainable params
8.4 K     Total params
0.034     Total estimated model params size (MB)


X_train_tensor.shape torch.Size([1599, 8, 10])
Y_train_tensor.shape torch.Size([1599, 8, 10])
Epoch 9: 100%|██████████| 1279/1279 [00:40<00:00, 31.52it/s, v_num=1, train_loss_step=0.277, val_loss_step=0.329, val_loss_epoch=0.332, train_loss_epoch=0.333] 

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 1279/1279 [00:40<00:00, 31.50it/s, v_num=1, train_loss_step=0.277, val_loss_step=0.329, val_loss_epoch=0.332, train_loss_epoch=0.333]


In [489]:
# use the model to predict the next 10 time steps
# use the last 10 time steps from the training set
# random pick a sample
random_sample_idx = random.randint(0, X_train.shape[0])
X_sample = features[random_sample_idx, :, :]
X_sample = X_sample.reshape(1, X_sample.shape[0], X_sample.shape[1])
display(X_sample.shape)
wind_sample = wind_scaler.inverse_transform(X_sample[0, :, 6].reshape(-1, 1))
pressure_sample = lowest_pressure_scaler.inverse_transform(X_sample[0, :, 7].reshape(-1, 1))
lat_sample = lat_scaler.inverse_transform(X_sample[0, :, 4].reshape(-1, 1))
long_sample = long_scaler.inverse_transform(X_sample[0, :, 5].reshape(-1, 1))
X_test_example = np.hstack((lat_sample, long_sample, wind_sample, pressure_sample))
display(X_test_example)

(1, 20, 8)

array([[  13. ,  135. ,   10. , 1004. ],
       [  14.2,  134.2,   10. , 1004. ],
       [  14.8,  133.3,   12. , 1002. ],
       [  15.2,  132.3,   15. , 1000. ],
       [  15.6,  131.6,   18. ,  998. ],
       [  16.7,  130.4,   20. ,  995. ],
       [  17.2,  129.5,   25. ,  985. ],
       [  17.6,  128.2,   25. ,  985. ],
       [  18.1,  126.8,   30. ,  980. ],
       [  18.7,  125.4,   35. ,  970. ],
       [  18.9,  123.5,   40. ,  960. ],
       [  19.2,  121.9,   40. ,  960. ],
       [  19.4,  119.8,   45. ,  950. ],
       [  19.9,  117.8,   45. ,  940. ],
       [  20.4,  115.7,   50. ,  935. ],
       [  20.7,  113.7,   50. ,  935. ],
       [  21.2,  111.6,   50. ,  935. ],
       [  21.5,  109.5,   30. ,  960. ],
       [  22. ,  108. ,   15. ,  992. ],
       [  22.2,  105.6,   10. , 1000. ]])

In [490]:
CKPT = Path("checkpoints") / "tcn-epoch=06-val_loss=0.33.ckpt"
model_test = TCNModel.load_from_checkpoint(CKPT, num_inputs=num_inputs, num_outputs=num_outputs)

In [498]:
model_input = X_sample[:, :10, :]
display(model_input.shape)
model_input = np.swapaxes(model_input, 1, 2)
display(model_input.shape)
y_tensor = torch.tensor(model_input, dtype=torch.float32, device=device)
y_pred_tensor = model_test.predict(y_tensor)
y_pred_ = y_pred_tensor.cpu().detach().numpy()
y_pred_ = np.swapaxes(y_pred_, 1, 2)

(1, 10, 8)

(1, 8, 10)

In [499]:
display(y_pred_.shape)
# display(y_pred)
y_pred = y_pred_[:,:,4:]
display(y_pred.shape)
display(y_pred)

(1, 10, 8)

(1, 10, 4)

array([[[0.        , 0.14537445, 0.        , 0.9039872 ],
        [0.        , 0.11013216, 0.        , 0.9039872 ],
        [0.        , 0.07048458, 0.        , 0.8093538 ],
        [0.        , 0.02643172, 0.        , 0.7147204 ],
        [0.        , 0.        , 0.        , 0.620087  ],
        [0.        , 0.        , 0.        , 0.47813696],
        [0.        , 0.        , 0.        , 0.00496999],
        [0.        , 0.        , 0.        , 0.00496999],
        [0.        , 0.        , 0.28905758, 0.        ],
        [0.        , 0.        , 0.64087117, 0.        ]]], dtype=float32)

In [500]:
# reverse the normalization
wind_pred = wind_scaler.inverse_transform(y_pred[:, :, 2])
pressure_pred = lowest_pressure_scaler.inverse_transform(y_pred[:, :, 3])
lat_pred = lat_scaler.inverse_transform(y_pred[:, :, 0])
long_pred = long_scaler.inverse_transform(y_pred[:, :, 1])
pred = np.vstack((lat_pred, long_pred, wind_pred, pressure_pred)).T
display(pred)
expected = X_test_example[-10:, :]
display(expected)

array([[  19.4     ,  135.      ,   25.891893, 1004.      ],
       [  19.4     ,  134.2     ,   25.891893, 1004.      ],
       [  19.4     ,  133.3     ,   25.891893, 1002.      ],
       [  19.4     ,  132.3     ,   25.891893, 1000.      ],
       [  19.4     ,  131.7     ,   25.891893,  998.      ],
       [  19.4     ,  131.7     ,   25.891893,  995.      ],
       [  19.4     ,  131.7     ,   25.891893,  985.      ],
       [  19.4     ,  131.7     ,   25.891893,  985.      ],
       [  19.4     ,  131.7     ,   30.      ,  984.89496 ],
       [  19.4     ,  131.7     ,   35.      ,  984.89496 ]],
      dtype=float32)

array([[  18.9,  123.5,   40. ,  960. ],
       [  19.2,  121.9,   40. ,  960. ],
       [  19.4,  119.8,   45. ,  950. ],
       [  19.9,  117.8,   45. ,  940. ],
       [  20.4,  115.7,   50. ,  935. ],
       [  20.7,  113.7,   50. ,  935. ],
       [  21.2,  111.6,   50. ,  935. ],
       [  21.5,  109.5,   30. ,  960. ],
       [  22. ,  108. ,   15. ,  992. ],
       [  22.2,  105.6,   10. , 1000. ]])

In [None]:
import plotly.graph_objects as go
import plotly.express as px
Y_TIME_STEPS=10
fig = go.Figure()
lat = X_test_example[:, 0]
lon = X_test_example[:, 1]
lat_pred = pred[:, 0]
lon_pred = pred[:, 1]
fig.add_trace(
    go.Scatter(
        x=lon[:Y_TIME_STEPS+1],
        y=lat[:Y_TIME_STEPS+1],
        mode="lines+markers",
        name="Input",
        line=dict(width=2, color="blue"),
        marker=dict(size=10, color="blue"),
    ))
fig.add_trace(
    go.Scatter(
        x=lon[-Y_TIME_STEPS:],
        y=lat[-Y_TIME_STEPS:],
        name="True Value",
        mode="lines+markers",
        line=dict(width=2, color="green"),
        marker=dict(size=10, color="green"),
    ))
fig.add_trace(
    go.Scatter(
        x=lon_pred,
        y=lat_pred,
        name="Predicted Value",
        mode="lines+markers",
        line=dict(width=2, color="red"),
        marker=dict(size=10, color="red"),
    ))
fig.update_layout(title="Hurricane Prediction",
                    xaxis_title="Longitude",
                    yaxis_title="Latitude",
                    width=800,
                    height=800)
fig.show()