In [2]:
pip install wandb -q

Note: you may need to restart the kernel to use updated packages.


In [3]:
import os

from dask.distributed import Client

# Connect to an already-running Dask scheduler. The port below looks session-specific
# (presumably a randomly assigned scheduler port), so allow overriding the address via
# DASK_SCHEDULER_ADDRESS instead of editing the notebook; the original address is the fallback.
DASK_SCHEDULER_ADDRESS = os.environ.get("DASK_SCHEDULER_ADDRESS", "tcp://127.0.0.1:35703")
client = Client(DASK_SCHEDULER_ADDRESS)

# Imports

In [4]:
import numpy as np
import xarray as xr
import torch
import pytorch_lightning as pl
from scipy import ndimage
import itertools
import os
from pytorch_lightning.loggers import WandbLogger
import wandb
from pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint,Callback
import myParam
import matplotlib.pyplot as plt

In [5]:
import platform

# Record the OS/platform this notebook ran on, for reproducibility.
print(f"{platform.platform()}")

Linux-5.10.133+-x86_64-with-glibc2.35


In [6]:
# Record the PyTorch version used for these runs.
print(f"{torch.__version__}")

1.13.1.post200


In [7]:
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")

In [8]:
import importlib

In [9]:
# Re-execute myParam.py so edits to the module are picked up without restarting the kernel
# (redundant right after a fresh import). The reloaded module is the cell's displayed output.
importlib.reload(myParam)

<module 'myParam' from '/home/jovyan/oceanDataNotebooks/parametrization_NN/myParam.py'>

## Open data

In [10]:
# Read PERSISTENT_BUCKET from the environment (raises KeyError if it is unset);
# presumably the persistent storage bucket holding the data — confirm.
# NOTE(review): not referenced by any visible cell; confirm it is still needed.
PERSISTENT_BUCKET = os.environ['PERSISTENT_BUCKET'] 

In [11]:
# One entry per (region, season) pair: regions '1'..'3' crossed with seasons 'fma'/'aso',
# labelled "<AREA> <SEASON>" (e.g. 'GULFSTR FMA'). Order: region-major, 'fma' before 'aso'.
data_dict = [
    {'region': region, 'season': season, 'label': f'{area} {season.upper()}'}
    for region, area in (('1', 'GULFSTR'), ('2', 'MIDATL'), ('3', 'WESTMED'))
    for season in ('fma', 'aso')
]

In [12]:
# Mini-batch size handed to the datamodules below.
batch_size = 4
# NOTE(review): passed to PyLiDataModule as height/width — presumably the sample
# size in grid points; confirm against myParam.
height, width = 45, 40

In [13]:
%%time
features_to_add_to_sample = ['temp', 'temp_var', 'rho_ct_ct', 'diff_temp_sqr']
auxiliary_features = ['z_l', 'f', 'e1t', 'e2t']
all_data_3D = myParam.PyLiDataModule(data_dict, '3D', features_to_add_to_sample, auxiliary_features, height, width, batch_size=batch_size)

CPU times: user 16.5 s, sys: 2.36 s, total: 18.8 s
Wall time: 1min 38s


In [14]:
%%time
features_to_add_to_sample = ['temp', 'temp_var', 'rho_ct_ct', 'diff_temp_sqr']
auxiliary_features = ['e1t', 'e2t']
all_data_2D = myParam.PyLiDataModule(data_dict, '2D', features_to_add_to_sample, auxiliary_features, height, width, batch_size=batch_size)

CPU times: user 1.92 s, sys: 262 ms, total: 2.18 s
Wall time: 14.4 s


In [None]:
def _resolve_name(dotted_name):
    """Return the object named by ``dotted_name`` from this notebook's globals.

    ``dotted_name`` is a global name optionally followed by attribute accesses,
    e.g. ``'all_data_3D'`` or ``'myParam.CNN'``. Used instead of ``eval`` on
    config strings: only existing objects can be looked up, no code is executed.

    Raises
    ------
    NameError
        If the leading name is not a global (the same exception ``eval`` raised).
    AttributeError
        If one of the dotted attributes does not exist.
    """
    head, *attribute_path = dotted_name.split('.')
    try:
        resolved = globals()[head]
    except KeyError:
        raise NameError(f"name {head!r} is not defined") from None
    for attribute in attribute_path:
        resolved = getattr(resolved, attribute)
    return resolved


def run_experiment(config, project):
    """Train, predict with and test one model described by ``config``, logging to W&B.

    Parameters
    ----------
    config : dict
        Experiment description. Keys read here:

        - ``'model_label'``, ``'version'``: joined with ``'_'`` to form the W&B
          run name and version.
        - ``'torch_model'``: name of the model factory (e.g. ``'myParam.FCNN'``),
          called with ``**config['torch_model_params']``.
        - ``'module_params'``: keyword arguments for ``myParam.GenericPyLiModule``;
          ``['output_features'][0]`` is also given to the prediction-logging callback.
        - ``'training_params'``: extra keyword arguments for ``pl.Trainer``.
        - ``'datamodule'``: name of a notebook-level datamodule (e.g. ``'all_data_3D'``).

        The whole dict is also passed to ``WandbLogger`` as the run config.
    project : str
        W&B project the run is logged under.

    Notes
    -----
    Trains on GPU when CUDA is available and on CPU otherwise.
    ``wandb.finish()`` runs even if training or testing raises, so a failed
    experiment does not leave its W&B run open for the next one.
    """
    run_name = config['model_label'] + '_' + config['version']
    wandb_logger = WandbLogger(name=run_name, version=run_name, project=project,
                               config=config, resume=False, log_model=True, offline=False)
    try:
        torch_model = _resolve_name(config['torch_model'])(**config['torch_model_params'])
        pylight_module = myParam.GenericPyLiModule(torch_model, **config['module_params'])

        # Callbacks
        # NOTE(review): checkpoints are ranked by training loss while early stopping
        # watches validation loss — confirm this is intended.
        checkpoint_callback = ModelCheckpoint(monitor="loss_train", save_last=True)
        early_stopping_callback = EarlyStopping(monitor="loss_validation", mode="min")
        # NOTE(review): LogPredictionsCallback is not defined in any visible cell
        # (cells 15-20 are missing from this export); confirm where it comes from,
        # otherwise Restart & Run All raises NameError here.
        log_predictions_callback = LogPredictionsCallback(wandb_logger, config['module_params']['output_features'][0])

        # accelerator='gpu' cannot run on a CPU-only machine, so choose it explicitly.
        accelerator = 'gpu' if torch.cuda.is_available() else 'cpu'
        trainer = pl.Trainer(**config['training_params'], logger=wandb_logger,
                             callbacks=[early_stopping_callback, checkpoint_callback, log_predictions_callback],
                             accelerator=accelerator, devices=1)

        # Resolve the datamodule once; fit, predict and test all use the same object.
        datamodule = _resolve_name(config['datamodule'])
        trainer.fit(model=pylight_module, datamodule=datamodule)

        # Perform tests
        datamodule.setup(stage='test')
        trainer.predict(model=pylight_module, dataloaders=datamodule.test_dataloader())
        # NOTE(review): trainer.test(datamodule=...) normally calls setup('test') itself,
        # so this second call is probably redundant; kept to preserve behavior.
        datamodule.setup(stage='test')
        trainer.test(model=pylight_module, datamodule=datamodule)
    finally:
        wandb.finish()

# Runs

In [21]:
# Close any W&B run still open from an earlier (e.g. interrupted) session before starting new runs.
wandb.finish()

In [22]:
# W&B project that every run_experiment call below logs to.
project_name = "tests"

In [23]:
# Experiment configs, appended by each run cell below (configs[-1] is the one just run).
configs = []

## Linear regression

In [None]:
# Baseline: 2D linear regression of temp_var on diff_temp_sqr (Huber loss, SGD).
configs.append({
    'model_label': 'LinReg',
    'version': 'HuberLoss2D_nn',
    'torch_model': 'myParam.lin_regr_model',
    'datamodule': 'all_data_2D',
    'torch_model_params': {
        'data_geometry': '2D',
        'nb_of_input_features': 1,
        'nb_of_output_features': 1,
    },
    'module_params': {
        'input_features': ['diff_temp_sqr'],
        'output_features': ['temp_var'],
        'output_units': None,
        'loss': torch.nn.functional.huber_loss,
        'optimizer': torch.optim.SGD,
        'learning_rate': 1e-3,
        'loss_normalization': False,
    },
    'training_params': {
        'max_epochs': 100,
        'limit_train_batches': 1.0,
    },
})
run_experiment(configs[-1], project_name)

## Fully connected neural network (FCNN)

In [None]:
# FCNN on 3D samples: temp -> temp_var, input/output patch size 3, MSE loss, Adam.
configs.append({
    'model_label': 'FCNN',
    'version': '3D_patch3-3_MSELoss_inNorm_outnondim',
    'torch_model': 'myParam.FCNN',
    'datamodule': 'all_data_3D',
    'torch_model_params': {
        'data_geometry': '3D',
        'nb_of_input_features': 1,
        'nb_of_output_features': 1,
        'input_patch_size': 3,
        'output_patch_size': 3,
        'int_layer_width': 50,
    },
    'module_params': {
        'input_features': ['temp'],
        'output_features': ['temp_var'],
        'output_units': ['diff_temp_sqr'],
        'input_normalization_features': ['sqrt_filtered_diff_temp_sqr'],
        'loss': torch.nn.functional.mse_loss,
        'optimizer': torch.optim.Adam,
        'learning_rate': 1e-4,
        'loss_normalization': False,
    },
    'training_params': {
        'max_epochs': 100,
        'limit_train_batches': 1.0,
    },
})
run_experiment(configs[-1], project_name)

## Convolutional neural network (CNN)

In [None]:
# CNN on 3D samples: temp -> temp_var, kernel size 3, MSE loss with loss normalization, Adam.
configs.append({
    'model_label': 'CNN',
    'version': 'kernel3_MSELossNorm_inNorm_outnondim',
    'torch_model': 'myParam.CNN',
    'datamodule': 'all_data_3D',
    'torch_model_params': {
        'data_geometry': '3D',
        'nb_of_input_features': 1,
        'nb_of_output_features': 1,
        'kernel_size': 3,
        'int_layer_width': 64,
    },
    'module_params': {
        'input_features': ['temp'],
        'output_features': ['temp_var'],
        'output_units': ['diff_temp_sqr'],
        'input_normalization_features': ['sqrt_filtered_diff_temp_sqr'],
        'loss': torch.nn.functional.mse_loss,
        'optimizer': torch.optim.Adam,
        'learning_rate': 1e-3,
        'loss_normalization': True,
    },
    'training_params': {
        'max_epochs': 100,
        'limit_train_batches': 1.0,
    },
})
run_experiment(configs[-1], project_name)

# Check the metrics and test results

The metrics monitored during training and the post-training test results can be found at https://wandb.ai/anagorb63/tests?workspace=user-anagorb63