# Training a neural network in PyTorch

This notebook demonstrates training a classifier in PyTorch.

In [1]:
# file handling
import os
import pathlib
import sys

from pytorch_lightning.loggers import MLFlowLogger

import dask
import dask.array

# math operators
import numpy as np
import pytorch_lightning as pl

# ml
import torch
import zarr

import datetime
from tempfile import TemporaryDirectory

# training helpers
import mlflow.pytorch
from dask.diagnostics import CacheProfiler, Profiler, ResourceProfiler, visualize
from mlflow.tracking import MlflowClient
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import (
    RichProgressBar,
)  # this progress bar works through jupyterHub on spice

# defined in directory (model related definitions)
import cbh_data_definitions
cbh_data_definitions.register_cache()
import cbh_torch_lstm
import cbh_torch_MLP

print("pl ver:", pl.__version__)
print("mlflow ver:", mlflow.__version__)
print("torch ver:", torch.__version__)
print("Python ver:", sys.version_info)

pl ver: 1.7.7
mlflow ver: 1.30.0
torch ver: 1.12.1
Python ver: sys.version_info(major=3, minor=10, micro=6, releaselevel='final', serial=0)


In [2]:
RELOAD_PACKAGES = True
if RELOAD_PACKAGES:
    import importlib
    importlib.reload(cbh_torch_lstm)
    importlib.reload(cbh_torch_MLP)
    importlib.reload(cbh_data_definitions)

In [3]:
root_data_directory = pathlib.Path(os.environ["SCRATCH"]) / "cbh_data"

dev_data_path = root_data_directory / "analysis_ready" / "dev.zarr"
training_data_path = root_data_directory / "analysis_ready" / "train.zarr"

mlflow_command_line_run = """
    mlflow server --port 5001 --backend-store-uri sqlite:///mlflowSQLserver.db  --default-artifact-root ./mlflow_artifacts/
"""
mlflow_server_address = 'vld425'
mlflow_server_port = 5001
mlflow_server_uri = f'http://{mlflow_server_address}:{mlflow_server_port:d}'
mlflow_artifact_root = pathlib.Path('./mlflow_artifacts/')

hparams_for_mlflow = {}

CPU_COUNT = 8
RAM_GB = 64
hparams_for_mlflow['CPU Count'] = CPU_COUNT
hparams_for_mlflow['Compute Memory'] = RAM_GB

In [4]:
(
    train_input,
    train_labels,
    _,
) = cbh_data_definitions.load_data_from_zarr(training_data_path)

(
    dev_input, 
    dev_labels, 
    _
) = cbh_data_definitions.load_data_from_zarr(dev_data_path)

# the cloud volume is not needed for the task, so isn't saved on the load
#show a chunk, used to inform dask cache size
train_input

Loaded zarr, file information:
 Name              : /
Type              : zarr.hierarchy.Group
Read-only         : False
Synchronizer type : zarr.sync.ThreadSynchronizer
Store type        : zarr.storage.DirectoryStore
No. members       : 3
No. arrays        : 3
No. groups        : 0
Arrays            : cloud_base_label_y.zarr, cloud_volume_fraction_y.zarr,
                  : humidity_temp_pressure_x.zarr
 

Loaded zarr, file information:
 Name              : /
Type              : zarr.hierarchy.Group
Read-only         : False
Synchronizer type : zarr.sync.ThreadSynchronizer
Store type        : zarr.storage.DirectoryStore
No. members       : 3
No. arrays        : 3
No. groups        : 0
Arrays            : cloud_base_label_y.zarr, cloud_volume_fraction_y.zarr,
                  : humidity_temp_pressure_x.zarr
 



Unnamed: 0,Array,Chunk
Bytes,87.48 GiB,373.24 MiB
Shape,"(111820800, 70, 3)","(465920, 70, 3)"
Count,2 Graph Layers,240 Chunks
Type,float32,numpy.ndarray
"Array Chunk Bytes 87.48 GiB 373.24 MiB Shape (111820800, 70, 3) (465920, 70, 3) Count 2 Graph Layers 240 Chunks Type float32 numpy.ndarray",3  70  111820800,

Unnamed: 0,Array,Chunk
Bytes,87.48 GiB,373.24 MiB
Shape,"(111820800, 70, 3)","(465920, 70, 3)"
Count,2 Graph Layers,240 Chunks
Type,float32,numpy.ndarray


In [5]:
LIMIT_DATA = False
LIMIT_DATA_INT = -1
if LIMIT_DATA:
    LIMIT_DATA_INT = 10024
    train_input = train_input[:LIMIT_DATA_INT]
    train_labels = train_labels[:LIMIT_DATA_INT]
    # train_cloud_volume = train_cloud_volume[:LIMIT_DATA_INT]
    dev_input = dev_input[:LIMIT_DATA_INT]
    dev_labels = dev_labels[:LIMIT_DATA_INT]
    # dev_cloud_volume = dev_cloud_volume[:LIMIT_DATA_INT]
hparams_for_mlflow['Limited sample number'] =  LIMIT_DATA_INT  

## Define the network

In [6]:
# enforce reproducibility
seed_everything_int = 42
seed_everything(seed_everything_int)
hparams_for_mlflow['Random seed'] = seed_everything_int

Global seed set to 42


## Perform the network initialization and training

In [7]:
# define model and hyperparameters
model_hyperparameter_dictionary = {
    "LSTM": {
        "input_size": train_input.shape[2],  # input size is the cell input (feat dim)
        "lstm_layers": 1,
        "lstm_hidden_size": 8,
        "output_size": 1,  # for each height layer, predict one value for cloud base prob
        "height_dimension": train_input.shape[1],
        "embed_size": 1,
        "BILSTM": False,
        "batch_first": True,
        "lr": 0.003,
    },
    "MLP": {
        "input_size": train_input.shape[2] * train_input.shape[1],
        "ff_nodes": 32,
        "output_size": train_input.shape[1],
        "lr": 1.0e-4,
    },
}

model_definition_dictionary = {
    "LSTM": cbh_torch_lstm.CloudBaseLSTM(**model_hyperparameter_dictionary["LSTM"]),
    "MLP": cbh_torch_MLP.CloudBaseMLP(**model_hyperparameter_dictionary["MLP"]),
}
model_picked = "MLP"
model = model_definition_dictionary[model_picked]  # pick a model
hparams_for_mlflow["Model defined hparams"] = model_hyperparameter_dictionary[model_picked]


# define training related hyperparameters

epochs = 1
hparams_for_mlflow["Max epochs"] = epochs

# after training parameters defined, load datasets into dataloaders (enforce 0 as workers on sys to prevent multiple packages
# trying to parallelise while not communicating
collate_fn = cbh_data_definitions.dataloader_collate_with_dask
print("Data chunk size:", train_input.chunksize[0])
print("Factors of chunk: ", [n for n in range(1, train_input.chunksize[0] + 1) if train_input.chunksize[0] % n == 0])
batch_size = 64
hparams_for_mlflow["Batch size"] = batch_size

Data chunk size: 465920
Factors of chunk:  [1, 2, 4, 5, 7, 8, 10, 13, 14, 16, 20, 26, 28, 32, 35, 40, 52, 56, 64, 65, 70, 80, 91, 104, 112, 128, 130, 140, 160, 182, 208, 224, 256, 260, 280, 320, 364, 416, 448, 455, 512, 520, 560, 640, 728, 832, 896, 910, 1024, 1040, 1120, 1280, 1456, 1664, 1792, 1820, 2080, 2240, 2560, 2912, 3328, 3584, 3640, 4160, 4480, 5120, 5824, 6656, 7168, 7280, 8320, 8960, 11648, 13312, 14560, 16640, 17920, 23296, 29120, 33280, 35840, 46592, 58240, 66560, 93184, 116480, 232960, 465920]


In [8]:
train_loader, val_loader = None, None

single_proc_workers = False # else crashes
if single_proc_workers:
    WORKERS_CPU_COUNT=0
else:
    WORKERS_CPU_COUNT = CPU_COUNT

data_loader_hparam_dict = {
    'batch_size':batch_size,
    'num_workers':WORKERS_CPU_COUNT,
    'pin_memory':False,
    'collate_fn':collate_fn,
    'thread_count_for_dask':CPU_COUNT
}
shuffle_training_data=False,

INTO_MEMORY = False
if INTO_MEMORY:
    train_loader = cbh_data_definitions.define_data_get_loader_into_memory(
        train_input,
        train_labels,
        shuffle=shuffle_training_data,
        **data_loader_hparam_dict
    )
    val_loader = cbh_data_definitions.define_data_get_loader_into_memory(
        dev_input,
        dev_labels,
        shuffle=False,
        **data_loader_hparam_dict
    )
else:
    train_loader = cbh_data_definitions.define_data_get_loader(
        train_input,
        train_labels,
        shuffle=shuffle_training_data,
        **data_loader_hparam_dict
    )
    val_loader = cbh_data_definitions.define_data_get_loader(
        dev_input,
        dev_labels,
        shuffle=False,
        **data_loader_hparam_dict
    )
data_loader_hparam_dict['shuffle_training_data']=shuffle_training_data
hparams_for_mlflow["Data loaded into memory"] = INTO_MEMORY
hparams_for_mlflow['data loader hparams'] = data_loader_hparam_dict

In [9]:
experiment_name = 'cbh-label-model-runs'
experiment_name = 'test-setup-for-model-runs'

# torch.set_num_threads(CPU_COUNT)

mlflow.set_tracking_uri(mlflow_server_uri)
# make vars global
mlf_exp = None
mlf_exp_id = None
try: 
    print('Creating experiment')
    mlf_exp_id = mlflow.create_experiment(experiment_name)
    mlf_exp = mlflow.get_experiment(mlf_exp_id)
except mlflow.exceptions.RestException:
    mlf_exp = mlflow.get_experiment_by_name(experiment_name)

Creating experiment


In [10]:
class MLFlowLogger(pl.loggers.MLFlowLogger): #overwrite mlflogger
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def after_save_checkpoint(self, model_checkpoint: pl.callbacks.ModelCheckpoint) -> None:
        """
        Called after model checkpoint callback saves a new checkpoint.
        """
        best_chkpt = torch.load(capture_var.best_model_path)
        # print(best_chkpt)
        # print(best_chkpt['callbacks'])
        checkpoint_for_mlflow = {
            "val loss": float(best_chkpt['callbacks'][list(key for key in list(best_chkpt['callbacks'].keys()) if "ModelCheckpoint" in key)[0]]['current_score']),
            "train loss at step-1": list(train_loss_metric.value for train_loss_metric in mlf_logger._mlflow_client.get_metric_history(run.info.run_id, "Train loss") if (int(train_loss_metric.step) == int(best_chkpt['global_step']-1)))[0],
            "global_step": best_chkpt['global_step'],
            "model_state_dict": best_chkpt['state_dict'],
            "checkpoint": best_chkpt,

        }
        with TemporaryDirectory() as tmpdirname:
            f_name = os.path.join(tmpdirname, f"{run.info.run_id}-best_model_checkpoint-step_{best_chkpt['global_step']}.pt")
            torch.save(checkpoint_for_mlflow, f_name)
            mlflow.log_artifact(f_name)
        # print(trainer._checkpoint_connector.dump_checkpoint() == chkpt_state_dict)
        # print(trainer._checkpoint_connector.dump_checkpoint()['state_dict'] == chkpt_state_dict)
        # chkpt_state_dict.update(trainer._checkpoint_connector.dump_checkpoint()['state_dict'])
        # # print(chkpt_state_dict)
        # rmlist = ["epoch", "global_step", "pytorch-lightning_version", "state_dict", "loops", "callbacks", "optimizer_states", "lr_schedulers", "hparams_name", "hyper_parameters"]
        # for elem in rmlist:
        #     chkpt_state_dict.pop(elem)
        
        # torch_native_model = cbh_torch_MLP.CloudBaseMLP(**model_hyperparameter_dictionary["MLP"])  # PLEASE GENERALIZE
        # torch_native_model.load_state_dict(chkpt_state_dict)
        # with TemporaryDirectory() as tmpdirname:
        #     f_name = join(tmpdirname, f"{run.info.run_id}-best_model_checkpoint.pt")
        #     torch.save(checkpoint, f_name)
        #     mlflow.log_artifact(f_name)

In [12]:
# import traceback
# import warnings
# import sys
# def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
#     log = file if hasattr(file,'write') else sys.stderr
#     traceback.print_stack(file=log)
#     log.write(warnings.formatwarning(message, category, filename, lineno, line))
# warnings.showwarning = warn_with_traceback
# warnings.simplefilter("always")

# run the training function
# setup mlflow logging

max_time = "00:02:20:00"  # dd:hh:mm:ss

hparams_for_mlflow["Training timeout"] = max_time

timestamp_template = '{dt.year:04d}{dt.month:02d}{dt.day:02d}T{dt.hour:02d}{dt.minute:02d}{dt.second:02d}'
run_name_template = 'cbh_challenge_{network_name}_' + timestamp_template
current_run_name = run_name_template.format(network_name=model.__class__.__name__,
                                                dt=datetime.datetime.now()
                                               )

# with Profiler() as prof, ResourceProfiler(dt=0.25) as rprof, CacheProfiler() as cprof:
with mlflow.start_run(experiment_id=mlf_exp.experiment_id, run_name=current_run_name) as run:

    mlflow.pytorch.autolog()
    mlf_logger = MLFlowLogger(experiment_name=experiment_name, tracking_uri=mlflow_server_uri, run_id=run.info.run_id)


    # define trainer
    # time_for_checkpoint = datetime.timedelta(minutes=0.1)
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        # train_time_interval=time_for_checkpoint,
        dirpath=run.info.artifact_uri,
        monitor="Val loss",
        save_on_train_epoch_end=False,
        mode="min"
    )
    callbacks = [checkpoint_callback, RichProgressBar()]
    trainer_hparams = {
        'max_epochs':epochs,
        'deterministic':True,
        'val_check_interval':0.01, # save every percentage of the data
        'devices':"auto",
        'accelerator':"auto",
        'max_time':max_time,
        'enable_checkpointing':True,
        'strategy':None,
        'callbacks':callbacks,
        'logger':mlf_logger,
    }
    hparams_for_mlflow["Trainer hparams"] = trainer_hparams
    mlf_logger.log_hyperparams(hparams_for_mlflow)
    trainer = pl.Trainer(
        **trainer_hparams
    )
    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)
print("Ended run", run.info.run_id)
    # print(visualize([prof, rprof, cprof], filename='profile_loop.html', save=True))



Ended run 97a2587bf69049319b3dc98684a73223


## Display and evaluate results

In [13]:
def print_auto_logged_info(r):

    tags = {k: v for k, v in r.data.tags.items() if not k.startswith("mlflow.")}
    artifacts = [f.path for f in MlflowClient().list_artifacts(r.info.run_id, "model")]
    print("run_id: {}".format(r.info.run_id))
    print("artifacts: {}".format(artifacts))
    print("params: {}".format(r.data.params))
    print("metrics: {}".format(r.data.metrics))
    print("tags: {}".format(tags))
print_auto_logged_info(run)

run_id: 97a2587bf69049319b3dc98684a73223
artifacts: []
params: {}
metrics: {}
tags: {}


In [None]:
# unique_save_str_element = str(datetime.datetime.now())
# save_str = "model_out" + "_" + unique_save_str_element + ".ckpt"
# trainer.save_checkpoint(save_str)

In [None]:
# test model functionality
example_batch = next(iter(train_loader))
inputs = example_batch[0]
print(inputs.shape, "Input shape")
try:
    preds, _ = model(inputs, heights)
except:
    print(example_batch[0].shape, "inp pre-flat")
    inputs = torch.flatten(example_batch[0], start_dim=1)
    preds = model(inputs)
print(preds.shape, "prediction output")
pred_label = np.argmax(preds.detach().numpy(), axis=1)
print(pred_label.shape, "prediction label shape")
targs = example_batch[1]
targs = np.array(targs)
print(targs.shape, "targ shape")
correct = targs == pred_label
print("Correct samples:", np.count_nonzero(correct))
print("Total samples tested:", len(correct))
print("Accuracy:", (np.count_nonzero(correct) / len(correct) * 100), "%")
print(
    "Model predictions binned: (Class labels), (Counts):",
    np.unique(pred_label, return_counts=True),
)

eg_batch_metrics = {
    "Correct samples" : np.count_nonzero(correct),
    "Total samples tested" : len(correct),
    "Accuracy" : (np.count_nonzero(correct) / len(correct) * 100),

    "Model predictions binned: (Class labels), (Counts)"
     : str(np.unique(pred_label, return_counts=True)),

}

In [None]:
eg_batch_metrics = {
    "Single batch example validation metrics/Correct samples" : np.count_nonzero(correct),
    "Single batch example validation metrics/Total samples tested" : len(correct),
    "Single batch example validation metrics/Accuracy" : np.count_nonzero(correct) / len(correct) * 100,
}

In [None]:
mlf_logger.log_metrics(eg_batch_metrics)

In [None]:
# display mlflow output
print_auto_logged_info(mlflow.get_run(run_id=run.info.run_id))
mlflow.end_run()