# Training a neural network in PyTorch
This notebook demonstrates training a classifier in PyTorch.

In [1]:
import zarr
import os
import dask
import dask.array 
import torch
import numpy as np
import pathlib

import pytorch_lightning as pl
print(pl.__version__)

import mlflow.pytorch
from mlflow.tracking import MlflowClient
# import dask.distributed
from dask.diagnostics import Profiler, ResourceProfiler, CacheProfiler

1.6.4


In [2]:
root_data_directory = pathlib.Path(os.environ['SCRATCH']) / 'cbh_data'

dev_data_path = root_data_directory / 'analysis_ready' / 'dev.zarr' 
training_data_path = root_data_directory / 'analysis_ready' / 'train.zarr'

In [3]:
# load in the data
def load_data_from_zarr(path):
    
    store = zarr.DirectoryStore(training_data_path)
    zarr_group = zarr.group(store=store)
    print('Loaded zarr, file information:\n', zarr_group.info, '\n')
    
    x = dask.array.from_zarr(zarr_group['humidity_temp_pressure_x.zarr'])
    y_lab = dask.array.from_zarr(zarr_group['onehot_cloud_base_height_y.zarr'])
    y_cont = dask.array.from_zarr(zarr_group['cloud_volume_fraction_y.zarr'])
    
    return x, y_lab, y_cont

In [4]:
train_input, train_labels, train_cloud_volume = load_data_from_zarr(training_data_path)
dev_input, dev_labels, dev_cloud_volume = load_data_from_zarr(dev_data_path)

Loaded zarr, file information:
 Name        : /
Type        : zarr.hierarchy.Group
Read-only   : False
Store type  : zarr.storage.DirectoryStore
No. members : 3
No. arrays  : 3
No. groups  : 0
Arrays      : cloud_volume_fraction_y.zarr, humidity_temp_pressure_x.zarr,
            : onehot_cloud_base_height_y.zarr
 

Loaded zarr, file information:
 Name        : /
Type        : zarr.hierarchy.Group
Read-only   : False
Store type  : zarr.storage.DirectoryStore
No. members : 3
No. arrays  : 3
No. groups  : 0
Arrays      : cloud_volume_fraction_y.zarr, humidity_temp_pressure_x.zarr,
            : onehot_cloud_base_height_y.zarr
 



## Define the network

In [5]:
# define RNN
class CloudBaseLSTM(pl.LightningModule):
    def __init__(self, inputSize, lstmLayers, lstmHiddenSize, output_size, height_dimension, embed_size, BILSTM=True, batch_first=False, lr=2e-3, log_boolean=False, do_linear_fit=False):
        super().__init__()
        
        self.LSTM = torch.nn.LSTM(inputSize+embed_size, lstmHiddenSize, lstmLayers, batch_first=batch_first, bidirectional=BILSTM, proj_size=output_size)
        
        
        
        self.batch_first = batch_first
        self.proj_size = output_size
        
        self.relu = torch.nn.ReLU()
        
        self.height_embedding = torch.nn.Embedding(height_dimension, embed_size)
        self.BILSTM = BILSTM
        self.lr = lr
        
        self.loss_fn_vol = torch.nn.MSELoss()
        self.do_linear_fit = do_linear_fit
        if do_linear_fit:
            self.loss_fn_base = torch.nn.CrossEntropyLoss()
            self.linearCap = torch.nn.Linear(height_dimension, height_dimension)
        
        self.log_bool = log_boolean
        
    def forward(self, x, height):
        
        #produce height embeds
        height_embeds = self.height_embedding(height)
        height_embeds = torch.flatten(height_embeds, start_dim=2)
        # print(height_embeds.size())
        
        #concat with feature vector
        x_and_height = torch.cat((x, height_embeds), 2)
        
        #send through LSTM
        lstm_out, _ = self.LSTM(x_and_height)
        # combine backward and forward LSTM outputs for each cell
        if(self.BILSTM):
            lstm_out = lstm_out[:,:,:self.proj_size] + lstm_out[:,:,self.proj_size:]
        # combinedLSTMOut = combinedLSTMOut / 2
        
        # # softmax but check for batch first
        # softmax_dim = 0
        # if self.batch_first:
        #     softmax_dim = 1
            
            
        # flatten seq out
        lstm_out = torch.flatten(lstm_out, start_dim=1)
        
        # #normalization
        # out = torch.nn.functional.log_softmax(nn_out, dim=softmax_dim)
        
        # apply ReLU
        relu_out = self.relu(lstm_out)
        
        nn_out = None # initialize for clarity
        
        if self.do_linear_fit:
            # apply linear layer for base prediction
            nn_out = self.linearCap(relu_out)
            
        return nn_out, relu_out
        
        
        # return both the nn_out and the lstm out for loss calculations
        
    
    def generic_model_step(self, batch, batch_idx, str_of_step_name):
        # print("Start step")
        
         #### #### #### WARNING MAY CAUSE SOME WEIRD OBJECT ORIENTED RELATED BEHAVIOUR I AM UNAWARE ABOUT AND NOT WORK #### #### ####
            
        # print("CHECKING")
        # print(batch)
            
        base_pred, vol_pred = self(batch['x'], batch['height_vector'])
        loss = self.loss_fn_vol(vol_pred, batch['cloud_volume_target'])
        
        if self.do_linear_fit:
            loss_2 = self.loss_fn_base(base_pred, batch['cloud_base_target'])
            loss = (loss*40) + loss_2 # 40 adjusts for differences in numerical values produced by loss function
        
        #log to tensorboard
        if self.log_bool:
            self.log((str_of_step_name + 'loss'), loss)
            self.log(str_of_step_name, 'volume loss component', loss_1)
            self.log(str_of_step_name, 'base height loss component', loss_2)
        
        return loss
        
    
    
    def training_step(self, batch, batch_idx):
        
        return self.generic_model_step(batch, batch_idx, 'training')
    
    def validation_step(self, batch, batch_idx):
        
        return self.generic_model_step(batch, batch_idx, 'validation')
    
    def test_step(self, batch, batch_idx):
        
        return self.generic_model_step(batch, batch_idx, 'test')
    
    def configure_optimizers(self):
        optim = torch.optim.Adam(self.parameters(), self.lr)
        
        return optim

# define torch dataloader
class CBH_Dataset(torch.utils.data.Dataset):
    def __init__(self, data_x, data_y, cloud_base_label):
        
        # print('begin init')
        
        self.temp_humidity_pressure = data_x
        self.cloudbase_target = data_y
        self.cbh_label = cloud_base_label
        
        self.height_layer_number = data_x.shape[1] # take the shape at index 1 as data_x of format sample, height, feature
        
        assert self.height_layer_number == 70
        
        # print('end init')
        
    def __len__(self):
        return len(self.temp_humidity_pressure)

    def __getitem__(self, idx):
        
        # since dask is being used, first compute the values on the index given to the get function, convert the array to tensor for pytorch
        
        # torch.from_numpy(x.compute())
        
        input_features = self.temp_humidity_pressure[idx]
        output_target = self.cloudbase_target[idx]
        # print(output_target.dtype)
        # output_target = output_target.type(torch.FloatTensor)
        cbh_lab = self.cbh_label[idx]
        
        # print('CALL ON GETITEM')
        
        height_vec = torch.from_numpy(np.arange(self.height_layer_number)) # should have produced this vector here, as it is the same every time, but will leave it since sunken cost and maybe it improves performance??? 
        
        item_in_dataset = {'x':input_features, 'cloud_volume_target':output_target, 'cloud_base_target':cbh_lab, 'height_vector':height_vec}
        return item_in_dataset
    

In [6]:
# # define dask specific collate function for dataloader, collate is the step where the dataloader combines all the samples into a singular batch to be enumerated on, 
# # after getting all items 
# from torch._six import string_classes
# import collections
# def temp_and_real_collate_default(batch):
#     elem = batch[0]
#     elem_type = type(elem)
#     print('0')
#     if isinstance(elem, torch.Tensor):
#         print('1')
#         out = None
#         if torch.utils.data.get_worker_info() is not None:
#             print('2')
#             # If we're in a background process, concatenate directly into a
#             # shared memory tensor to avoid an extra copy
#             numel = sum(x.numel() for x in batch)
#             storage = elem.storage()._new_shared(numel, device=elem.device)
#             out = elem.new(storage).resize_(len(batch), *list(elem.size()))
#         return torch.stack(batch, 0, out=out)
#     elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
#             and elem_type.__name__ != 'string_':
#         print('3')
#         if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
#             print('4')
#             # array of string classes and object
#             if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
#                 print('5')
#                 raise TypeError(default_collate_err_msg_format.format(elem.dtype))

#             return default_collate([torch.as_tensor(b) for b in batch])
#         elif elem.shape == ():  # scalars
#             print('6')
#             return torch.as_tensor(batch)
#     elif isinstance(elem, float):
#         print('7')
#         return torch.tensor(batch, dtype=torch.float64)
#     elif isinstance(elem, int):
#         print('8')
#         return torch.tensor(batch)
#     elif isinstance(elem, string_classes):
#         print('9')
#         return batch
#     elif isinstance(elem, collections.abc.Mapping):
#         print('10')
#         try:
#             print('11')
#             return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
#         except TypeError:
#             print('012')
#             # The mapping type may not support `__init__(iterable)`.
#             return {key: default_collate([d[key] for d in batch]) for key in elem}
#     elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
#         print('013')# namedtuple
#         return elem_type(*(default_collate(samples) for samples in zip(*batch)))
#     elif isinstance(elem, collections.abc.Sequence):
#         print('014')
#         # check to make sure that the elements in batch have consistent size
#         it = iter(batch)
#         elem_size = len(next(it))
#         if not all(len(elem) == elem_size for elem in it):
#             print('015')
#             raise RuntimeError('each element in list of batch should be of equal size')
#         transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.

#         if isinstance(elem, tuple):
#             print('016')
#             return [default_collate(samples) for samples in transposed]  # Backwards compatibility.
#         else:
#             print('017')
#             try:
#                 print('018')
#                 return elem_type([default_collate(samples) for samples in transposed])
#             except TypeError:
#                 print('019')
#                 # The sequence type may not support `__init__(iterable)` (e.g., `range`).
#                 return [default_collate(samples) for samples in transposed]

#     raise TypeError(default_collate_err_msg_format.format(elem_type))


def dataloader_collate_with_dask(batch):
    # print("call OG collate")
    elem = batch[0]
    elem_type = type(elem)
    
    
    # assert torch.utils.data.get_worker_info() is None # if this assertion fails, there are issues in code and this case needs to be handled see pytorch source of default collate fn

    try:
        return elem_type({key: collate_helper_send_dict_elements_to_tensor([d[key] for d in batch]) for key in elem})
        
    except TypeError:
        # print('Should not have reached here')
        # raise TypeError()
        return {key: collate_helper_send_dict_elements_to_tensor([d[key] for d in batch]) for key in elem}
    
    raise TypeError(default_collate_err_msg_format.format(elem_type))

    
def collate_helper_send_dict_elements_to_tensor(batch):
    # print("call sub collate")
    # assert torch.utils.data.get_worker_info() is None
    
    elem = batch[0]
    
    # print(type(elem))
    # print(batch)
    
    if type(elem) is dask.array.core.Array:
        new_batch = np.stack(batch, 0) # emulate torch stack
        # print("Start compute", len(batch))
        new_batch = new_batch.compute()
        # print("End compute")
        to_return = torch.from_numpy(new_batch)
        
    # elif isinstance(elem, torch.Tensor):
    #     out = None
    #     if torch.utils.data.get_worker_info() is not None:
    #         # If we're in a background process, concatenate directly into a
    #         # shared memory tensor to avoid an extra copy
    #         numel = sum(x.numel() for x in batch)
    #         storage = elem.storage()._new_shared(numel)
    #         out = elem.new(storage).resize_(len(batch), *list(elem.size()))
    #     return torch.stack(batch, 0, out=out)
    
    
    else:
        to_return = torch.stack(batch, 0)
    # print('okay')
    # print(to_return)
    return to_return
    
    

In [7]:
# enforce reproducibility
from pytorch_lightning import Trainer, seed_everything

seed_everything(42)


Global seed set to 42


42

## Perform the network initialization and training

In [8]:
# load into torcg dataset 

collate_fn = dataloader_collate_with_dask

train_cbh_data = CBH_Dataset(train_input, train_cloud_volume, train_labels)
dev_cbh_data = CBH_Dataset(dev_input, dev_cloud_volume, dev_labels)

height_dim = train_input.shape[1]

# define model and hyperparameters
layers = 3
input_size = train_input.shape[2] # input size is the cell input (feat dim)
output_size = 1 # for each height layer, predict one value for cloud base prob
hidden_size = 32
embed_size = 5
BILSTM = False
batch_first = True

learn_rate = 0.002

log_with_pl = False # do not log, as track with mlFlow

model = CloudBaseLSTM(input_size, layers, hidden_size, output_size, height_dim, embed_size, BILSTM, batch_first, lr=learn_rate, log_boolean=log_with_pl)

# define training related hyperparameters

epochs = 10
max_time ="00:12:00:00" #dd:hh:mm:ss

batch_size = 1000

# after training parameters defined, load datasets into dataloaders
import multiprocessing as mp
workers_on_system = np.min((8, mp.cpu_count()))
train_loader = torch.utils.data.DataLoader(train_cbh_data, batch_size=batch_size, shuffle=True, collate_fn=collate_fn, num_workers = workers_on_system)
val_loader = torch.utils.data.DataLoader(dev_cbh_data, batch_size=batch_size, shuffle=False, collate_fn=collate_fn, num_workers = workers_on_system) # don't shuffle in val

# define trainer
trainer = pl.Trainer(max_epochs = epochs, deterministic=True, check_val_every_n_epoch=1, devices="auto", accelerator="auto", max_time=max_time)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [9]:
# setup mlflow logging

mlflow.pytorch.autolog()

In [10]:
# run the training function 
with mlflow.start_run() as run:
    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader)

Missing logger folder: /net/home/h02/hsouth/github_committing/data_science_cop/challenges/2021_CyrilMorcrette_cloudBaseHeight/lightning_logs

  | Name             | Type      | Params
-----------------------------------------------
0 | LSTM             | LSTM      | 2.5 K 
1 | relu             | ReLU      | 0     
2 | height_embedding | Embedding | 350   
3 | loss_fn_vol      | MSELoss   | 0     
-----------------------------------------------
2.9 K     Trainable params
0         Non-trainable params
2.9 K     Total params
0.012     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


## Display and evaluate results

In [18]:
def print_auto_logged_info(r):

    tags = {k: v for k, v in r.data.tags.items() if not k.startswith("mlflow.")}
    artifacts = [f.path for f in MlflowClient().list_artifacts(r.info.run_id, "model")]
    print("run_id: {}".format(r.info.run_id))
    print("artifacts: {}".format(artifacts))
    print("params: {}".format(r.data.params))
    print("metrics: {}".format(r.data.metrics))
    print("tags: {}".format(tags))

In [19]:
# display mlflow output
print_auto_logged_info(mlflow.get_run(run_id=run.info.run_id))

run_id: 5ebfb65f4e5441de92e1d59798e36e14
artifacts: []
params: {}
metrics: {}
tags: {}


In [None]:
# sample some predictions for understanding
