In [1]:
import torch
from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

import pandas as pd

import wandb

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mbhavye-mathur[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
means = {"U": 5.589, "V": 0.018}
stds = {"U": 9.832, "V": 3.232}

ESTIMATE_QUANTILE = 0.9935
TRAIN_TEST_SPLIT = 0.8
TRAIN_VAL_SPLIT = 0.9


class WindDataset(Dataset):
    def __init__(self, subset: str):
        data = pd.read_feather(f"../raw/subset/UV-NGCT-{ESTIMATE_QUANTILE}-{100000000}.ft")

        if subset == "train":
            self.x = data.iloc[:int(len(data) * TRAIN_TEST_SPLIT)]
            self.x = data.iloc[:int(len(self.x) * TRAIN_VAL_SPLIT)]
        elif subset == "test":
            self.x = data.iloc[int(len(data) * TRAIN_TEST_SPLIT):]
        elif subset == "validation":
            self.x = data.iloc[:int(len(data) * TRAIN_TEST_SPLIT)]
            self.x = data.iloc[int(len(self.x) * TRAIN_VAL_SPLIT):]
        else:
            raise ValueError("Invalid Subset")

        self.y = self.x[["U", "V"]]

        del self.x["U"]
        del self.x["V"]

        self.x = self.x.values.astype("float16")
        self.y = self.y.values.astype("float16")

    def __len__(self):
        return len(self.x)

    def __getitem__(self, i):
        return self.x[i], self.y[i]

In [2]:
BATCH_SIZE = 64

train_data = WindDataset("train")
validation_data = WindDataset("validation")
test_data = WindDataset("test")

In [3]:
def get_dense_model(input_size: int,
                    hidden_sizes: list[int],
                    output_size: int,
                    activation_func: callable):
    layers = []

    for size in hidden_sizes:
        layers.append(torch.nn.Linear(input_size, size))
        layers.append(activation_func())
        input_size = size

    layers.append(torch.nn.Linear(input_size, output_size))

    return torch.nn.Sequential(*layers)


class WindModel(pl.LightningModule):
    def __init__(self,
                 variable: str,
                 learning_rate: float,
                 loss_func: callable,
                 input_size: int,
                 hidden_sizes: list[int],
                 output_size: int,
                 activation_func: callable):

        super().__init__()

        self.learning_rate = learning_rate
        self.loss_func = loss_func

        self.variable = variable
        self.denorm_mult = stds[variable]
        self.denorm_add = means[variable]

        self.model = get_dense_model(input_size, hidden_sizes, output_size, activation_func)
        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch

        pred = self.model(x)
        loss = self.loss_func(pred, y)
        if self.loss_func == torch.nn.functional.mse_loss:
            mse = loss
        else:
            mse = torch.nn.functional.mse_loss(pred, y)

        self.log("train_loss", loss)
        self.log("train_rmse", (mse ** 0.5) * self.denorm_mult, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch

        pred = self.model(x)
        mse = torch.nn.functional.mse_loss(pred, y)
        mae = torch.nn.functional.l1_loss(pred, y)

        print(pred.shape, y.shape)

        self.log("validation_rmse", (mse ** 0.5) * self.denorm_mult)
        self.log("validation_mae", mae * self.denorm_mult)

        return mse

    def test_step(self, batch, batch_idx):
        x, y = batch

        pred = self.model(x)
        mse = torch.nn.functional.mse_loss(pred, y)
        mae = torch.nn.functional.l1_loss(pred, y)

        print(pred.shape, y.shape)

        self.log("test_rmse", (mse ** 0.5) * self.denorm_mult)
        self.log("test_mae", mae * self.denorm_mult)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def train_dataloader(self):
        return DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=10)

    def test_dataloader(self):
        return DataLoader(test_data, batch_size=BATCH_SIZE, num_workers=10)

    def val_dataloader(self):
        return DataLoader(validation_data, batch_size=BATCH_SIZE, num_workers=10)

In [4]:
VARIABLE = "U"
INPUT_SIZE = 15
OUTPUT_SIZE = 2
HIDDEN_SIZES = [256, 128]
LEARNING_RATE = 1e-3

ACTIVATION = torch.nn.ReLU
LOSS_FUNC = torch.nn.functional.mse_loss

model = WindModel("U", LEARNING_RATE, LOSS_FUNC, INPUT_SIZE, HIDDEN_SIZES, OUTPUT_SIZE, ACTIVATION)
model


WindModel(
  (model): Sequential(
    (0): Linear(in_features=15, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=2, bias=True)
  )
)

In [11]:
wandb_logger = WandbLogger(project="MERRA2-Wind", log_model="all")
wandb_logger.experiment.config["estimate_quantile"] = ESTIMATE_QUANTILE


In [5]:
trainer = pl.Trainer(devices=1,
                     accelerator="cpu",
                     precision=16,

                     limit_train_batches=0.1,
                     limit_val_batches=0.1,
                     val_check_interval=0.25,

                     auto_lr_find=True,
                     max_epochs=40,

                     # logger=wandb_logger
                     )

  rank_zero_warn(
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


In [6]:
%matplotlib notebook

lr_finder = trainer.tuner.lr_find(model)
lr_finder.plot(suggest=True, show=True)

Missing logger folder: /Users/bhavyemathur/Desktop/Projects/Spherindrical Fourier Transform/MERRA-2/wind-prediction/models/lightning_logs


torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])
torch.Size([64, 2]) torch.Size([64, 2])


Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_steps=100` reached.
Restoring states from the checkpoint path at /Users/bhavyemathur/Desktop/Projects/Spherindrical Fourier Transform/MERRA-2/wind-prediction/models/.lr_find_de59601d-95f2-4d1a-b6bc-25ca3ea35041.ckpt
Restored all states from the checkpoint file at /Users/bhavyemathur/Desktop/Projects/Spherindrical Fourier Transform/MERRA-2/wind-prediction/models/.lr_find_de59601d-95f2-4d1a-b6bc-25ca3ea35041.ckpt
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x12d359fc0>
Traceback (most recent call last):
  File "/Users/bhavyemathur/Desktop/Projects/Spherindrical Fourier Transform/MERRA-2/venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
    self._shutdown_workers()
  File "/Users/bhavyemathur/Desktop/Projects/Spherindrical Fourier Transform/MERRA-2/venv/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1430, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERV

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [14]:
trainer.fit(model)
