# Training a neural network in PyTorch

This notebook demonstrates training a cloud-base classifier with PyTorch Lightning, loading the data from zarr stores and tracking the run with MLflow.

In [1]:
# file handling
import os
import pathlib

import dask
import dask.array

# math operators
import numpy as np
import pytorch_lightning as pl

# ml
import torch
import zarr

# show the installed PyTorch Lightning version so it is recorded in the cell output
print("pl ver:", pl.__version__)

import datetime

# training helpers
import mlflow.pytorch
from dask.diagnostics import CacheProfiler, Profiler, ResourceProfiler
from mlflow.tracking import MlflowClient
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import (
    RichProgressBar,
)  # this progress bar works through jupyterHub on spice

# defined in directory (model related definitions)
import cbh_data_definitions
import cbh_torch_lstm
import cbh_torch_MLP

pl ver: 1.6.4


In [2]:
# Set to False to skip re-importing the project-local modules.
RELOAD_PACKAGES = True
if RELOAD_PACKAGES:
    import importlib

    # Reload the local model/data modules (in this order) so that edits made to
    # them on disk are picked up without restarting the kernel.
    for local_module in (cbh_torch_lstm, cbh_torch_MLP, cbh_data_definitions):
        importlib.reload(local_module)

In [3]:
# All data lives under $SCRATCH/cbh_data; a missing SCRATCH variable raises KeyError.
root_data_directory = pathlib.Path(os.environ["SCRATCH"]).joinpath("cbh_data")
analysis_ready_directory = root_data_directory.joinpath("analysis_ready")

# zarr stores for the dev and training splits
dev_data_path = analysis_ready_directory.joinpath("dev.zarr")
training_data_path = analysis_ready_directory.joinpath("train.zarr")

In [6]:
# Load the training split, then the dev split; each call returns an
# (input, labels, cloud volume) triple.
train_input, train_labels, train_cloud_volume = (
    cbh_data_definitions.load_data_from_zarr(training_data_path)
)
dev_input, dev_labels, dev_cloud_volume = (
    cbh_data_definitions.load_data_from_zarr(dev_data_path)
)

Loaded zarr, file information:
 Name        : /
Type        : zarr.hierarchy.Group
Read-only   : False
Store type  : zarr.storage.DirectoryStore
No. members : 3
No. arrays  : 3
No. groups  : 0
Arrays      : cloud_base_label_y.zarr, cloud_volume_fraction_y.zarr,
            : humidity_temp_pressure_x.zarr
 

Loaded zarr, file information:
 Name        : /
Type        : zarr.hierarchy.Group
Read-only   : False
Store type  : zarr.storage.DirectoryStore
No. members : 3
No. arrays  : 3
No. groups  : 0
Arrays      : cloud_base_label_y.zarr, cloud_volume_fraction_y.zarr,
            : humidity_temp_pressure_x.zarr
 



In [7]:
# Optionally keep only the first LIMIT_DATA_INT entries (along the leading axis)
# of every training and dev array.
LIMIT_DATA = True
LIMIT_DATA_INT = 5024000
if LIMIT_DATA:
    leading_entries = slice(None, LIMIT_DATA_INT)
    train_input, train_labels, train_cloud_volume = (
        train_input[leading_entries],
        train_labels[leading_entries],
        train_cloud_volume[leading_entries],
    )
    dev_input, dev_labels, dev_cloud_volume = (
        dev_input[leading_entries],
        dev_labels[leading_entries],
        dev_cloud_volume[leading_entries],
    )

## Seed the random number generators

In [8]:
# enforce reproducibility: one fixed seed handed to Lightning's seed_everything
RANDOM_SEED = 42
seed_everything(RANDOM_SEED)

Global seed set to 42


42

## Perform the network initialization and training

In [9]:
# Sizes read off the training inputs: axis 1 is used as the height dimension
# and axis 2 as the per-cell feature dimension.
height_dimension_size = train_input.shape[1]
feature_dimension_size = train_input.shape[2]

# Hyperparameters for each available architecture.
lstm_hyperparameters = {
    "input_size": feature_dimension_size,  # the cell input is the feature dimension
    "lstm_layers": 1,
    "lstm_hidden_size": 8,
    "output_size": 1,  # one cloud base probability value per height layer
    "height_dimension": height_dimension_size,
    "embed_size": 1,
    "BILSTM": False,
    "batch_first": True,
    "lr": 0.003,
}
mlp_hyperparameters = {
    "input_size": feature_dimension_size * height_dimension_size,
    "ff_nodes": 256,
    "output_size": height_dimension_size,
    "lr": 1.0e-4,
}
model_hyperparameter_dictionary = {
    "LSTM": lstm_hyperparameters,
    "MLP": mlp_hyperparameters,
}

# Both candidate models are instantiated here; only one is selected below.
model_definition_dictionary = {
    "LSTM": cbh_torch_lstm.CloudBaseLSTM(**model_hyperparameter_dictionary["LSTM"]),
    "MLP": cbh_torch_MLP.CloudBaseMLP(**model_hyperparameter_dictionary["MLP"]),
}

model = model_definition_dictionary["LSTM"]  # pick a model

# Training-related hyperparameters.
epochs = 1
max_time = "00:02:20:00"  # dd:hh:mm:ss

# Dataloader settings, defined once the training parameters are known.
# workers_on_system is held at 0 to prevent multiple packages from trying to
# parallelise at the same time without communicating.
workers_on_system = 0
collate_fn = cbh_data_definitions.dataloader_collate_with_dask
batch_size = 1024

In [10]:
train_loader = val_loader = None

# Choose the loader factory once: the in-memory variant or the default one.
INTO_MEMORY = True
if INTO_MEMORY:
    build_loader = cbh_data_definitions.define_data_get_loader_into_memory
else:
    build_loader = cbh_data_definitions.define_data_get_loader

# Options common to the training and validation loaders.
shared_loader_options = {
    "batch_size": batch_size,
    "num_workers": workers_on_system,
    "collate_fn": collate_fn,
}

# Only the training loader shuffles its samples.
train_loader = build_loader(
    train_input,
    train_cloud_volume,
    train_labels,
    shuffle=True,
    **shared_loader_options,
)
val_loader = build_loader(
    dev_input,
    dev_cloud_volume,
    dev_labels,
    shuffle=False,
    **shared_loader_options,
)

Computing x...
Computing y...
Computing lab...
Computing x...
Computing y...
Computing lab...


In [11]:
# Build the trainer: checkpoint on a 20-minute training-time interval and show
# a rich progress bar.
time_for_checkpoint = datetime.timedelta(minutes=20)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    train_time_interval=time_for_checkpoint,
)
callbacks = [checkpoint_callback, RichProgressBar()]

trainer_options = {
    "max_epochs": epochs,
    "deterministic": True,
    "check_val_every_n_epoch": 1,
    "devices": "auto",
    "accelerator": "auto",
    "max_time": max_time,
    "enable_checkpointing": True,
    "callbacks": callbacks,
}
trainer = pl.Trainer(**trainer_options)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [12]:
# set up MLflow logging: enable PyTorch autologging before the training cell runs

mlflow.pytorch.autolog()

In [13]:
# Fit the model inside an MLflow run; `run` is kept so the run can be looked
# up again after training.
fit_dataloaders = {
    "train_dataloaders": train_loader,
    "val_dataloaders": val_loader,
}
with mlflow.start_run() as run:
    trainer.fit(model=model, **fit_dataloaders)



## Display and evaluate results

In [14]:
def print_auto_logged_info(r):
    """Print a short summary of an MLflow run.

    Prints, one per line: the run id, the paths of the artifacts listed under
    "model" for that run, the logged params, the logged metrics, and the run's
    tags with every key starting with "mlflow." left out.

    Args:
        r: a run object exposing ``info.run_id``, ``data.tags``,
            ``data.params`` and ``data.metrics``.
    """
    user_tags = {
        key: value
        for key, value in r.data.tags.items()
        if not key.startswith("mlflow.")
    }
    model_artifacts = MlflowClient().list_artifacts(r.info.run_id, "model")
    artifact_paths = [entry.path for entry in model_artifacts]

    print(f"run_id: {r.info.run_id}")
    print(f"artifacts: {artifact_paths}")
    print(f"params: {r.data.params}")
    print(f"metrics: {r.data.metrics}")
    print(f"tags: {user_tags}")

In [15]:
# Look the finished run up by id and print what was logged for it, then make
# sure no MLflow run is left active.
logged_run = mlflow.get_run(run_id=run.info.run_id)
print_auto_logged_info(logged_run)
mlflow.end_run()

run_id: 6dfe03b528b54577ac22e83a85627bfc
artifacts: []
params: {'epochs': '1', 'optimizer_name': 'Adam', 'lr': '0.003', 'betas': '(0.9, 0.999)', 'eps': '1e-08', 'weight_decay': '0', 'amsgrad': 'False'}
metrics: {'training base height loss component': 2.7354509830474854, 'training loss': 2.7354509830474854, 'validation base height loss component': 3.0111770629882812, 'validation loss': 3.0111770629882812}
tags: {'Mode': 'training'}


In [16]:
# Save the trained model under a timestamped name so repeated runs do not
# overwrite each other. The stamp is formatted explicitly (YYYYmmddTHHMMSS_micro)
# because str(datetime.now()) contains a space and colons, which are invalid in
# Windows filenames and awkward in shells and file-transfer tools.
unique_save_str_element = datetime.datetime.now().strftime("%Y%m%dT%H%M%S_%f")
save_str = f"model_out_{unique_save_str_element}.ckpt"
trainer.save_checkpoint(save_str)

In [19]:
# Sanity-check the trained model on a single batch drawn from the training loader.
example_batch = next(iter(train_loader))
heights = example_batch["height_vector"]
inputs = example_batch["x"]
print(inputs.shape, "Input shape")
print("Height vector shape:", heights.shape)

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    try:
        # First try the call form that takes the inputs plus the height vector
        # and returns a pair (this is the form the LSTM model accepts).
        preds, _ = model(inputs, heights)
    except Exception:
        # Fall back to a model that takes one flattened input tensor.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate instead of being silently swallowed.
        print(example_batch["x"].shape, "inp pre-flat")
        inputs = torch.flatten(example_batch["x"], start_dim=1)
        preds = model(inputs)
print(preds.shape, "prediction output")

# The predicted label is the index with the highest score along axis 1.
# .cpu() is a no-op for CPU tensors and keeps .numpy() valid otherwise.
pred_label = np.argmax(preds.detach().cpu().numpy(), axis=1)
print(pred_label.shape, "prediction label shape")
targs = example_batch["cloud_base_target"]
targs = np.array(targs)
print(targs.shape, "targ shape")

# Element-wise comparison of targets against predicted labels.
correct = targs == pred_label
correct_count = np.count_nonzero(correct)
print("Correct samples:", correct_count)
print("Total samples tested:", len(correct))
print("Accuracy:", (correct_count / len(correct) * 100), "%")

torch.Size([1024, 70, 3]) Input shape
Height vector shape: torch.Size([1024, 70])
torch.Size([1024, 70]) prediction output
(1024,) prediction label shape
(1024,) targ shape
Correct samples: 366
Total samples tested: 1024
Accuracy: 35.7421875 %
