In [1]:
%env PYTORCH_ENABLE_MPS_FALLBACK=1

from tqdm.notebook import tqdm
import wandb

from WindModel import *

wandb.login()

env: PYTORCH_ENABLE_MPS_FALLBACK=1


[34m[1mwandb[0m: Currently logged in as: [33mbhavye-mathur[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
DEVICE = "cpu"

In [3]:
WindDataset.init(0.1)

train = WindDataset("train")
validation = WindDataset("validation")
test = WindDataset("test")

del WindDataset.data

In [4]:
INPUT_SIZE = 15
OUTPUT_SIZE = 1

LOSS_FUNC = torch.nn.MSELoss

In [5]:
def get_dense_model(input_size: int,
                    hidden_sizes: list[int],
                    output_size: int,
                    activation_func: callable):
    layers = []

    for size in hidden_sizes:
        layers.append(torch.nn.Linear(input_size, size))
        layers.append(activation_func())
        input_size = size

    layers.append(torch.nn.Linear(input_size, output_size))

    return torch.nn.Sequential(*layers)

In [6]:
def evaluate_one_epoch(model, epoch, dl):
    mse = 0
    mae = 0

    with torch.no_grad():
        for inputs, targets in dl:
            prediction = model(inputs).squeeze()

            mse += torch.nn.functional.mse_loss(prediction, targets)
            mae += torch.nn.functional.l1_loss(prediction, targets)

    n = len(dl)
    wandb.log({"val_rmse": ((mse / n) ** 0.5) * stds[VARIABLE]})
    wandb.log({"val_mae": (mae / n) * stds[VARIABLE]})


def train_one_batch(model, optimizer, criterion, batch, batch_idx):
    optimizer.zero_grad()

    inputs, targets = batch

    prediction = model(inputs).squeeze()
    loss = criterion(prediction, targets)

    loss.backward()
    optimizer.step()

    if batch_idx % 100 == 0:
        with torch.no_grad():
            rmse = (torch.nn.functional.mse_loss(prediction, targets) ** 0.5) * stds[VARIABLE]
            wandb.log({"train_loss": loss})
            wandb.log({"train_rmse": rmse})


def train_one_epoch(model, optimizer, criterion, epoch, dl):
    data = iter(dl)
    for i in range(len(dl)):
        train_one_batch(model, optimizer, criterion, next(data), i)


def main(config={}):
    if config["lr_scheduler"] is None and "lr_scheduler_kwargs" in config:
        config.pop("lr_scheduler_kwargs")

    wandb.init(project=f"MERRA2-{VARIABLE}", dir="wandb-local", config=config)

    learning_rate = wandb.config.learning_rate
    batch_size = wandb.config.batch_size
    layers = wandb.config.layers
    epochs = wandb.config.epochs

    activation = wandb.config.activation
    activation = getattr(torch.nn, activation)

    train_dl = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=10, pin_memory=True)
    validation_dl = DataLoader(validation, batch_size=batch_size, shuffle=False, num_workers=10, pin_memory=True)

    model = get_dense_model(INPUT_SIZE, layers, OUTPUT_SIZE, activation)
    model = model.to(DEVICE)
    print(model)

    criterion = LOSS_FUNC()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if (scheduler := wandb.config.lr_scheduler) is None:
        scheduler = None
    elif scheduler == "StepLR":
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **wandb.config.lr_scheduler_kwargs, verbose=True)

    wandb.watch(model, log_freq=100)
    model.train()

    for ep in tqdm(range(epochs)):
        print(ep, end=" ")

        wandb.log({"epoch": ep})

        train_one_epoch(model, optimizer, criterion, ep, train_dl)
        evaluate_one_epoch(model, ep, validation_dl)

        if scheduler:
            scheduler.step()
            wandb.log({"learning_rate": scheduler.get_last_lr()})

    wandb.finish()

In [7]:
main({
    "batch_size": 65536,
    "learning_rate": 0.0005,
    "lr_scheduler": "StepLR",
    "lr_scheduler_kwargs": {"step_size": 25},
    "layers": [512, 256],
    "activation": "PReLU",
    "estimate_quantile": ESTIMATE_QUANTILE,
    "dataset": DATASET,
    "epochs": 100,
})



Sequential(
  (0): Linear(in_features=15, out_features=512, bias=True)
  (1): ReLU()
  (2): Linear(in_features=512, out_features=256, bias=True)
  (3): ReLU()
  (4): Linear(in_features=256, out_features=128, bias=True)
  (5): ReLU()
  (6): Linear(in_features=128, out_features=1, bias=True)
)
Adjusting learning rate of group 0 to 5.0000e-04.


  0%|          | 0/150 [00:00<?, ?it/s]

0 Epoch 00000: adjusting learning rate of group 0 to 5.0000e-04.
1 



Epoch 00001: adjusting learning rate of group 0 to 5.0000e-04.
2 Epoch 00002: adjusting learning rate of group 0 to 5.0000e-04.
3 Epoch 00003: adjusting learning rate of group 0 to 5.0000e-04.
4 Epoch 00004: adjusting learning rate of group 0 to 5.0000e-04.
5 Epoch 00005: adjusting learning rate of group 0 to 5.0000e-04.
6 Epoch 00006: adjusting learning rate of group 0 to 5.0000e-04.
7 Epoch 00007: adjusting learning rate of group 0 to 5.0000e-04.
8 Epoch 00008: adjusting learning rate of group 0 to 5.0000e-04.
9 Epoch 00009: adjusting learning rate of group 0 to 5.0000e-04.
10 Epoch 00010: adjusting learning rate of group 0 to 5.0000e-05.
11 Epoch 00011: adjusting learning rate of group 0 to 5.0000e-05.
12 Epoch 00012: adjusting learning rate of group 0 to 5.0000e-05.
13 Epoch 00013: adjusting learning rate of group 0 to 5.0000e-05.
14 Epoch 00014: adjusting learning rate of group 0 to 5.0000e-05.
15 Epoch 00015: adjusting learning rate of group 0 to 5.0000e-05.
16 Epoch 00016: adjus

KeyboardInterrupt: 

In [None]:
sweep_configuration = {
    "method": "bayes",
    "name": f"sweep-{DATASET}",
    "metric": {
        "goal": "minimize",
        "name": "val_rmse"
    },
    "parameters": {
        "batch_size": {"value": 65536},
        "learning_rate": {"max": 0.001, "min": 0.00005},
        "lr_scheduler": {"values": [None, "StepLR"]},
        "lr_scheduler_kwargs": {"parameters": {"step_size": {"max": 20, "min": 10},
                                               "gamma": {"max": 0.75, "min": 0.25}}},
        "layers": {"values": [(512, 256), (1024, 512)]},
        "epochs": {"value": 50},
        "activation": {"values": ["ReLU", "PReLU", "LeakyReLU"]},
        "estimate_quantile": {"value": ESTIMATE_QUANTILE},
        "dataset": {"value": DATASET},
    },
    "early_terminate": {
        "type": "hyperband",
        "min_iter": 3,
    }
}

sweep_id = wandb.sweep(sweep=sweep_configuration, project=f"MERRA2-{VARIABLE}")
wandb.agent(sweep_id, function=main)

In [14]:
def test(model, dl):
    model.eval()

    mse = 0
    mae = 0

    with torch.no_grad():
        for inputs, targets in tqdm(dl):
            prediction = model(inputs).squeeze()

            mse += torch.nn.functional.mse_loss(prediction, targets)
            mae += torch.nn.functional.l1_loss(prediction, targets)

    return (mse / len(dl)) ** 0.5 * stds[VARIABLE], (mae / len(dl)) * stds[VARIABLE]


test_dl = DataLoader(test, batch_size=2048, shuffle=False)
test_rmse, test_mae = test(model, test)

print(f"RMSE: {test_rmse} m/s")
print(f"MAE:  {test_mae} m/s")


  0%|          | 0/973 [00:00<?, ?it/s]

RMSE: 0.4102303087711334 m/s
MAE:  0.29246410727500916 m/s
