# Training a neural network in PyTorch

This notebook trains a cloud-base-height classifier (an MLP or an LSTM) with PyTorch Lightning and logs the run to MLflow.

In [1]:
# file handling
import os
import pathlib

import dask
import dask.array

# math operators
import numpy as np
import pytorch_lightning as pl

# ml
import torch
import zarr

print("pl ver:", pl.__version__)

import datetime

# training helpers
import mlflow.pytorch
from mlflow.tracking import MlflowClient

# import dask.distributed # sometimes breaks things
# from dask.diagnostics import Profiler, ResourceProfiler, CacheProfiler
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import (
    RichProgressBar,
)  # this progress bar works through jupyterHub on spice

# defined in directory (model related definitions)
import cbh_data_definitions
import cbh_torch_lstm
import cbh_torch_MLP

pl ver: 1.6.4


In [23]:
# Re-import the local model/data modules so edits to their .py files are picked
# up without restarting the kernel.
# NOTE(review): objects built before this cell ran (e.g. an existing `model`)
# keep the old class definitions — re-run the downstream cells after reloading.
import importlib

for reloaded_module in (cbh_torch_lstm, cbh_torch_MLP, cbh_data_definitions):
    importlib.reload(reloaded_module)

<module 'cbh_data_definitions' from '/net/home/h02/hsouth/github_committing/data_science_cop/challenges/2021_CyrilMorcrette_cloudBaseHeight/cbh_data_definitions.py'>

In [3]:
# All analysis-ready zarr stores live under $SCRATCH/cbh_data/analysis_ready.
root_data_directory = pathlib.Path(os.environ["SCRATCH"], "cbh_data")

analysis_ready_directory = root_data_directory.joinpath("analysis_ready")
dev_data_path = analysis_ready_directory.joinpath("dev.zarr")
training_data_path = analysis_ready_directory.joinpath("train.zarr")

In [4]:
# Load each split from its zarr store; each call is unpacked here as
# (input, labels, cloud volume) arrays.
train_input, train_labels, train_cloud_volume = cbh_data_definitions.load_data_from_zarr(
    training_data_path
)
dev_input, dev_labels, dev_cloud_volume = cbh_data_definitions.load_data_from_zarr(
    dev_data_path
)

Loaded zarr, file information:
 Name        : /
Type        : zarr.hierarchy.Group
Read-only   : False
Store type  : zarr.storage.DirectoryStore
No. members : 3
No. arrays  : 3
No. groups  : 0
Arrays      : cloud_volume_fraction_y.zarr, humidity_temp_pressure_x.zarr,
            : onehot_cloud_base_height_y.zarr
 

Loaded zarr, file information:
 Name        : /
Type        : zarr.hierarchy.Group
Read-only   : False
Store type  : zarr.storage.DirectoryStore
No. members : 3
No. arrays  : 3
No. groups  : 0
Arrays      : cloud_volume_fraction_y.zarr, humidity_temp_pressure_x.zarr,
            : onehot_cloud_base_height_y.zarr
 



In [5]:
# Optionally keep only the first LIMIT_DATA_INT entries (along the leading axis)
# of every train/dev array so experiments run quickly.
LIMIT_DATA = True
LIMIT_DATA_INT = 1000000
if LIMIT_DATA:
    (
        train_input,
        train_labels,
        train_cloud_volume,
        dev_input,
        dev_labels,
        dev_cloud_volume,
    ) = (
        split_array[:LIMIT_DATA_INT]
        for split_array in (
            train_input,
            train_labels,
            train_cloud_volume,
            dev_input,
            dev_labels,
            dev_cloud_volume,
        )
    )

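## Define the network

First, seed all random number generators so that network initialisation and data shuffling are reproducible.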

In [6]:
# enforce reproducibility
seed_everything(seed=42)  # global seed shared by every RNG Lightning seeds

Global seed set to 42


42

## Initialize and train the network

In [7]:
# Choose which architecture to train; exactly one branch below builds `model`.
TRAIN_LSTM = False
TRAIN_MLP = True

if TRAIN_LSTM:
    # assumes train_input is (samples, height levels, features) — consistent with
    # the MLP branch below and the [batch, 70, 3] batch shape printed later
    height_dim = train_input.shape[1]

    # define model and hyperparameters
    layers = 1
    input_size = train_input.shape[2]  # input size is the cell input (feat dim)
    output_size = 1  # for each height layer, predict one value for cloud base prob
    hidden_size = 8
    embed_size = 1
    BILSTM = False
    batch_first = True

    learn_rate = 0.003

    model = cbh_torch_lstm.CloudBaseLSTM(
        input_size,
        layers,
        hidden_size,
        output_size,
        height_dim,
        embed_size,
        BILSTM,
        batch_first,
        lr=learn_rate,
    )

elif TRAIN_MLP:
    # input is the flattened sample: every feature at every height level
    input_size = train_input.shape[2] * train_input.shape[1]
    ff_nodes = 256
    output_size = train_input.shape[1]  # one output per height level
    learn_rate = 1.0e-4

    model = cbh_torch_MLP.CloudBaseMLP(
        input_size,
        ff_nodes,
        output_size,
        learn_rate,
    )

else:
    # Without this, `model` would be undefined (or silently left over from an
    # earlier run of this cell) and the problem would only surface at trainer.fit.
    raise ValueError("Set TRAIN_LSTM or TRAIN_MLP to True to choose a model to train.")

# define training related hyperparameters
epochs = 1
max_time = "00:02:20:00"  # dd:hh:mm:ss

# After the training parameters are defined, load the datasets into dataloaders.
# Workers are forced to 0 to stop multiple packages from trying to parallelise
# without communicating with each other.
workers_on_system = 0
collate_fn = cbh_data_definitions.dataloader_collate_with_dask
batch_size = 200

In [24]:
# Build the train/validation dataloaders. INTO_MEMORY picks which loader factory
# to use (the `_into_memory` variant vs. the plain one); both accept identical
# arguments, so the shared settings are written once instead of four times.
INTO_MEMORY = True
define_loader = (
    cbh_data_definitions.define_data_get_loader_into_memory
    if INTO_MEMORY
    else cbh_data_definitions.define_data_get_loader
)
shared_loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=workers_on_system,
    collate_fn=collate_fn,
)

# Shuffle only the training data; validation order stays fixed.
train_loader = define_loader(
    train_input,
    train_cloud_volume,
    train_labels,
    shuffle=True,
    **shared_loader_kwargs,
)
val_loader = define_loader(
    dev_input,
    dev_cloud_volume,
    dev_labels,
    shuffle=False,
    **shared_loader_kwargs,
)

In [9]:
# define trainer

# Checkpoint on a training-time interval (every 20 minutes) and show a Rich
# progress bar, which works through JupyterHub on spice.
time_for_checkpoint = datetime.timedelta(minutes=20)
checkpoint_callback = pl.callbacks.ModelCheckpoint(train_time_interval=time_for_checkpoint)
callbacks = [checkpoint_callback, RichProgressBar()]

trainer = Trainer(
    # stopping criteria: whichever of max_epochs / max_time is reached first
    max_epochs=epochs,
    max_time=max_time,
    # hardware: let Lightning choose whatever is available
    accelerator="auto",
    devices="auto",
    # reproducibility and validation cadence
    deterministic=True,
    check_val_every_n_epoch=1,
    # checkpointing is driven by the callback above
    enable_checkpointing=True,
    callbacks=callbacks,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [10]:
# setup mlflow logging

# Enable MLflow autologging for PyTorch Lightning: subsequent `trainer.fit` calls
# record the optimizer parameters and logged metrics to the active MLflow run
# (they appear in the params/metrics printed in the results section).
mlflow.pytorch.autolog()

In [11]:
# run the training function
# Keep the run handle (`run`) so later cells can look the run up by
# `run.info.run_id`; the context manager ends the MLflow run when the block exits.
with mlflow.start_run() as run:
    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)



## Display and evaluate results

In [12]:
def print_auto_logged_info(r):
    """Print a summary of an MLflow run: id, model artifacts, params, metrics, tags.

    Args:
        r: an MLflow run object (e.g. the result of ``mlflow.get_run``); only
            ``r.info.run_id`` and ``r.data.{tags,params,metrics}`` are read.

    Only artifacts under the run's ``"model"`` artifact directory are listed, and
    tags whose key starts with ``"mlflow."`` (system tags) are hidden.
    """
    run_id = r.info.run_id
    user_tags = dict(
        (key, value) for key, value in r.data.tags.items() if not key.startswith("mlflow.")
    )
    model_artifacts = [
        artifact.path for artifact in MlflowClient().list_artifacts(run_id, "model")
    ]
    for label, value in (
        ("run_id", run_id),
        ("artifacts", model_artifacts),
        ("params", r.data.params),
        ("metrics", r.data.metrics),
        ("tags", user_tags),
    ):
        print(f"{label}: {value}")

In [13]:
# display mlflow output
finished_run = mlflow.get_run(run_id=run.info.run_id)
print_auto_logged_info(finished_run)
# The `with mlflow.start_run()` block already ended this run; end_run is kept
# only as a safety net in case a run was left active (no-op otherwise).
mlflow.end_run()

run_id: e285eddb762f4a5786d9ec95fa6b987c
artifacts: []
params: {'epochs': '1', 'optimizer_name': 'Adam', 'lr': '0.0001', 'betas': '(0.9, 0.999)', 'eps': '1e-08', 'weight_decay': '0', 'amsgrad': 'False'}
metrics: {'training loss': 2.511777639389038, 'validation loss': 2.922545909881592}
tags: {'Mode': 'training'}


In [14]:
# Save a final checkpoint under a timestamped filename so repeated runs don't
# overwrite each other. str(datetime.now()) would contain a space and colons,
# which are awkward in shells and invalid in Windows filenames, so format the
# timestamp explicitly (microseconds kept so names stay unique).
unique_save_str_element = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S-%f")
save_str = f"model_out_{unique_save_str_element}.ckpt"
trainer.save_checkpoint(save_str)

In [26]:
# test model functionality
# Pull one training batch and check input/prediction/target shapes.
example_batch = next(iter(train_loader))
print(example_batch['x'].shape, 'inp pre-flat')
# Flatten height x features into one vector per sample — this matches the
# TRAIN_MLP model's input_size; an LSTM model would need the unflattened input.
inputs = torch.flatten(example_batch['x'], start_dim=1)
print(inputs.shape, 'inp')
# Inspection only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    preds = model(inputs)
print(preds.shape, 'preds')
targs = example_batch['cloud_base_target']
print(targs[0:5], 'targ ex')
print(example_batch['cloud_volume_target'][0:5], 'targ verif')
print(targs.shape)

torch.Size([200, 70, 3]) inp pre-flat
torch.Size([200, 210]) inp
torch.Size([200, 70]) preds
tensor([10,  0,  8, 69,  0]) targ ex
tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1406, 0.2969, 0.0312, 0.9219,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5625,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0312, 0.0469, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5781, 0.7812, 0.8750, 0.9375, 0.9844, 1.0000, 0.2500, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.