## Notebook 4. (Must be run as python script)

file: `4_acc=ddp-16bit-precision-2-gpu.ipynb`
### ddp - 16bit-precision -- 2-gpu

<!-- tune-batchsize -- tune_lr -- include a batch_size tuning step, then a learning_rate tuning step (both using only 'dp'), before training (using 'ddp') then testing -->

Created by: Jacob A Rose  
Created on: Wednesday July 7th, 2021

### Scaling model training series

A collection of notebooks meant to demonstrate minimal-complexity examples for:
* Integrating new training methods for scaling up to large numbers of experiments run in parallel, and
* Making maximum use of hardware resources

1. 16bit precision, single gpu, train -> test
2. 16bit precision, single gpu, batch_size tune -> train -> test
3. 16bit precision, single gpu, batch_size tune -> lr tune -> train -> test
4. ddp, 16bit precision, 2x gpus, batch_size tune -> lr tune -> train -> test

In [None]:
import wandb
wandb.__version__

In [None]:
# available_datasets = {'1':1,
#                      '2':2}

# class TestClass:
    
# #     @property
# #     def available_datasets(self):
# #         return available_datasets




#     @property
#     def available_datasets(self):
#         """
#         Subclasses must define this property
#         Must return a dict mapping dataset key names to their absolute paths on disk.
#         """
#         return available_datasets
        
#     @available_datasets.setter
#     def available_datasets(self, new):
#         """
#         Subclasses must define this property
#         Must return a dict mapping dataset key names to their absolute paths on disk.
#         """
#         try:
#             available_datasets.update(new)
#         except:
#             raise Exception

# obj = TestClass()

# print(obj.available_datasets)

# obj.available_datasets['3'] = 3

# print(obj.available_datasets)

# obj.available_datasets = {'4':4,
#                           "5":5}

# print(obj.available_datasets)

In [None]:
import gc
import os
import shutil
from typing import Any, List, Tuple

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import rich
import torch
from pytorch_lightning.metrics.classification import Accuracy
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
# from rich import pretty
# rich.pretty.install()

# Use the caller's TOY_DATA_DIR when it is already set; otherwise register the
# fallback location in the environment. Either way, default_root_dir ends up
# holding the value stored under os.environ['TOY_DATA_DIR'].
default_root_dir = os.environ.setdefault(
    'TOY_DATA_DIR',
    "/media/data_cifs/projects/prj_fossils/data/toy_data",
)

In [None]:
    
    
def main(config_path = "/media/data/jacob/GitHub/lightning-hydra-classifiers/configs/experiment/3_16bit_precision-single_gpu--tune_batchsize--tune_lr.yaml"):
    """
    Run the experiment pipeline: batch_size tuning -> lr tuning -> train -> test.

    The hydra config is read from `config_path`. If a file already exists at
    `config.hparams_log_path`, its contents are merged into `config.hparams`
    and any tuning stage whose value is already present (int batch_size,
    float lr) is skipped. The current `config.hparams` are written back to
    `config.hparams_log_path` after each tuning stage.

    Args:
        config_path: Path to the experiment yaml file. Its parent directory is
            used as the hydra config dir and its stem as the config name.

    Returns:
        A tuple `(test_results, config.hparams)` where `test_results` is
        whatever `trainer.test(...)` returns.
    """
    
    
    config = read_hydra_config(config_dir = str(Path(config_path).parent),
                               job_name="test_app",
                               config_name=Path(config_path).stem)
    template_utils.extras(config)
    
    if "seed" in config:
        pl.seed_everything(config.seed)
    
#     import pdb; pdb.set_trace()
    
#     from IPython.core import debugger
#     debugger.set_trace()
    
#############################################################################
    # Resume from previously saved hparams, if any.
    if os.path.isfile(config.hparams_log_path):
        config.hparams.update(OmegaConf.load(config.hparams_log_path))
        # A stage counts as already tuned only if the loaded value has the
        # expected numeric type (the lr is saved as the string "default" below
        # when only the batch size has been tuned so far).
        if isinstance(config.hparams.batch_size, int):
            config.tuner.scale_batch_size.tuned = True
        if isinstance(config.hparams.lr, float):
            config.tuner.lr_find.tuned = True
        # Keep a copy of the loaded hparams file inside this run's log dir.
        os.makedirs(config.log_dir, exist_ok=True)
        shutil.copyfile(config.hparams_log_path, Path(config.log_dir, Path(config.hparams_log_path).name))
    else:
        # No saved hparams yet: make sure the directory they will be saved to exists.
        os.makedirs(Path(config.hparams_log_path).parent, exist_ok=True)
#############################################################################
#     if not config.tuner.scale_batch_size.tuned:
#         config.hparams.batch_size = config.tuner.scale_batch_size.kwargs.init_val
        
#     if not config.tuner.lr_find.tuned:
#         config.hparams.lr = config.tuner.lr_find.kwargs.min_lr
        

        
    
    # Show the hparams (and their types) that the stages below will start from.
    for k,v in config.hparams.items():
        print(k, v, type(v))

#############################################################################
# Stage 1: batch size tuning (skipped when already marked as tuned)
#     if config.hparams.batch_size is None:
    if not config.tuner.scale_batch_size.tuned:
        log.info(f"<----------Initiating auto scale_batch_size tuning----------->")
        log.info(f"Tuning kwargs:")
        log.info(OmegaConf.to_yaml(config.tuner.scale_batch_size.kwargs))
        # TODO: create hparams dataclass for checkpointing these tuned parameters
        
        # Seed the hparams with the tuner's starting values before instantiating
        # anything -- presumably the datamodule/model configs reference them; verify
        # against the yaml config.
        config.hparams.batch_size = config.tuner.scale_batch_size.kwargs.init_val        
        config.hparams.lr = config.tuner.lr_find.kwargs.min_lr
        datamodule, config = configure_datamodule(config)    
        template_utils.print_config(config, resolve=True)
        model = configure_model(config)
        tuner = configure_tuner(config)

        
        best_bsz = tuner.scale_batch_size(model,
                                          datamodule=datamodule,
                                          **config.tuner.scale_batch_size.kwargs)
        config.hparams.batch_size = best_bsz
        
        log.info("<----------scale_batch_size-Results----------->")
        log.info(f'Best batch_size={best_bsz}')
        
        # Drop the tuning objects and release cached GPU memory before the next
        # stage builds fresh ones.
        del datamodule, model, tuner        
        gc.collect()
        torch.cuda.empty_cache()

#     datamodule.batch_size = config.hparams.batch_size
#     model.batch_size = config.hparams.batch_size
    # Persist the hparams now. If the lr has not been tuned yet, it is saved as
    # the non-float placeholder "default", so the isinstance(float) check at the
    # top of this function will not mark lr as tuned when this file is reloaded.
    if not config.tuner.lr_find.tuned:
        config.hparams.lr = "default"
    OmegaConf.save(config.hparams, resolve=True, f=config.hparams_log_path)
#############################################################################


#############################################################################
# Stage 2: learning rate tuning (skipped when already marked as tuned)
    if not config.tuner.lr_find.tuned:
        # Replace the "default" placeholder with the tuner's minimum lr before
        # instantiating the model.
        config.hparams.lr = config.tuner.lr_find.kwargs.min_lr
        log.info(f"<----------Initiating learning_rate tuning----------->")
        log.info(f"Tuning kwargs:")
        log.info(OmegaConf.to_yaml(config.tuner.lr_find.kwargs))
        
#         config.hparams.lr = config.tuner.lr_find.kwargs.min_lr
        datamodule, config = configure_datamodule(config)    
        template_utils.print_config(config, resolve=True)
        model = configure_model(config)
        tuner = configure_tuner(config)

        
        lr_finder = tuner.lr_find(model=model,
                                  datamodule=datamodule,
                                  **config.tuner.lr_find.kwargs)

        log.info("<----------lr_finder-Results----------->")
        # NOTE(review): this logs `lr_finder.results`, not the suggested lr; the
        # suggestion itself is only read below via `lr_finder.suggestion()`.
        log.info(f"Best lr: {lr_finder.results}")

        # The plot is saved next to the hparams file via pyplot's current figure;
        # `fig` itself is not used afterwards.
        fig = lr_finder.plot(suggest=True)
        plt.savefig(Path(config.hparams_log_path).parent / "lr_tuning_plot.png")
        config.hparams.lr = lr_finder.suggestion()
   
        # Same cleanup as after batch size tuning.
        del datamodule, model, tuner        
        gc.collect()
        torch.cuda.empty_cache()
        

#     model.lr = config.hparams.lr
    # Persist the hparams again, now including the lr chosen above (if it was tuned).
    OmegaConf.save(config.hparams, resolve=True, f=config.hparams_log_path)
    
    
#############################################################################
# Stage 3: train and test with the final hparams

    datamodule, config = configure_datamodule(config)    
    template_utils.print_config(config, resolve=True)
    model = configure_model(config)
        
    # Unlike the tuning stages, the trainer is built from the unmodified config
    # (configure_tuner swaps any 'ddp' accelerator for 'dp' on its own copy).
    trainer = configure_trainer(config)
    print(f'[START] Training with tuned batch_size = {config.hparams.batch_size}')
    print(f'and tuned learning rate = {config.hparams.lr}')
    
    trainer.fit(model, datamodule=datamodule)    
    # No model is passed to test(); which weights get evaluated is decided by the trainer.
    test_results = trainer.test(datamodule=datamodule)
    
    return test_results, config.hparams

### Function definitions

1. Configure logger (using python's logging module)
2. Configure experiment Config (using hydra + omegaconf.DictConfig)
3. Configure datamodule (using custom LightningDataModule)
4. Configure model (using custom LightningModule)
5. Configure trainer (using pl.Trainer, as well as pytorch lightning loggers & callbacks)

In [None]:
import sys
import logging
from lightning_hydra_classifiers.utils import template_utils
import hydra
from hydra.experimental import compose, initialize_config_dir
from omegaconf import OmegaConf, DictConfig
from pathlib import Path
from rich.logging import RichHandler

def get_standard_python_logger(name: str='notebook',
                               log_path: str=None):
    """
    Set up the standard python logging module for command line debugging.

    Configures the root logger to write INFO-level records to stdout and
    returns the logger called `name` with a single rich `RichHandler` attached.
    The function is safe to call repeatedly (e.g. when a notebook cell is
    re-run): a `RichHandler` is only added if the logger does not already
    have one, so log lines are not duplicated once per call.

    Args:
        name: Name of the logger to return; also used as the log file stem.
        log_path: Optional directory for the log file. When given, the
            directory is created if it does not exist. When None, the log
            file path defaults to `./{name}.log`.

    Returns:
        The configured `logging.Logger` instance.
    """
    if log_path is None:
        log_path = f"./{name}.log"
    else:
        os.makedirs(log_path, exist_ok=True)
        log_path = str(Path(log_path, name)) + '.log'
    # NOTE: `log_path` is resolved (and its directory created) but nothing is
    # written to it yet -- file logging is currently disabled. Attach a
    # logging.FileHandler(log_path) here to enable it.

    logging.basicConfig(
        format='%(asctime)s [%(levelname)s] %(name)s - %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        stream=sys.stdout
    )
    log = logging.getLogger(name)

    # logging.getLogger returns the same object for the same name, so without
    # this check every call would stack one more RichHandler on the logger.
    has_rich_handler = any(isinstance(handler, RichHandler) for handler in log.handlers)
    if not has_rich_handler:
        rh = RichHandler(rich_tracebacks=True)
        rh.setLevel(logging.INFO)
        log.addHandler(rh)

    return log


# log = get_standard_python_logger(name='notebook_experiment')


def read_hydra_config(config_dir: str,
                      job_name: str="test_app",
                      config_name: str="experiment") -> DictConfig:
    """
    Compose a hydra config from a yaml file on disk and return it as a DictConfig.

    Note that this changes the process working directory to `config_dir`
    before composing the config.

    Args:
        config_dir: Directory containing the yaml config; passed to hydra's
            `initialize_config_dir` and used as the new working directory.
        job_name: Job name handed to hydra.
        config_name: Name of the config to compose (file stem, no extension).

    Returns:
        The composed config. If it has a truthy `print_config` entry, the
        resolved config is also printed via `template_utils.print_config`.
    """
    os.chdir(config_dir)

    hydra_context = initialize_config_dir(config_dir=config_dir, job_name=job_name)
    with hydra_context:
        composed = compose(config_name=config_name)

    wants_printout = composed.get("print_config")
    if wants_printout:
        template_utils.print_config(composed, resolve=True)

    return composed

In [None]:
def configure_datamodule(config: DictConfig) -> Tuple[pl.LightningDataModule, DictConfig]:
    """
    Instantiate the datamodule described by `config.datamodule`.

    After instantiation, the datamodule is set up for the "fit" stage and its
    class list is copied into the config as `config.hparams.classes`, with
    `config.hparams.num_classes` set to the length of that list. This step is
    best-effort: if it fails, a warning is logged and the datamodule is
    returned anyway with `config.hparams` left as far as it got.

    Args:
        config: Experiment config; must contain a hydra-instantiable
            `datamodule` node. Modified in place.

    Returns:
        A tuple `(datamodule, config)`.
    """
    log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
    datamodule: pl.LightningDataModule = hydra.utils.instantiate(config.datamodule)

    try:
        datamodule.setup(stage="fit")
        config.hparams.classes = datamodule.classes
        config.hparams.num_classes = len(config.hparams.classes)
    except Exception as e:
        # Deliberately non-fatal, but report through the logger (with the
        # exception type) so the failure is not lost among plain stdout prints.
        log.warning(f"Could not record classes/num_classes from the datamodule: {e!r}")

    return datamodule, config

In [None]:
def configure_model(config: DictConfig) -> pl.LightningModule:
    """
    Instantiate the model described by `config.model` via hydra.

    Args:
        config: Experiment config containing a hydra-instantiable `model` node.

    Returns:
        The instantiated model.
    """
    target_name = config.model._target_
    log.info(f"Instantiating model <{target_name}>")

    return hydra.utils.instantiate(config.model)

In [None]:
def configure_tuner(config: DictConfig) -> pl.tuner.tuning.Tuner:
    """
    Build a Tuner backed by a trainer that never uses a 'ddp' accelerator.

    Any accelerator setting containing 'ddp' is replaced with 'dp' before the
    trainer is built. A missing or null accelerator (e.g. a single-GPU config)
    is left untouched instead of raising a TypeError on the membership test.

    Args:
        config: Experiment config with `trainer` and `tuner.instantiate` nodes.

    Returns:
        The tuner instantiated from `config.tuner.instantiate`, bound to the
        newly created trainer.
    """
    # Work on a separate config object so the accelerator override below is
    # local to tuning -- presumably leaving the caller's config untouched for
    # the real training run; confirm OmegaConf.create copies DictConfig inputs.
    config = OmegaConf.create(config)

    accelerator = config.trainer.get("accelerator")
    if isinstance(accelerator, str) and 'ddp' in accelerator:
        config.trainer.accelerator = 'dp'

    trainer: pl.Trainer = configure_trainer(config)
    tuner: pl.tuner.tuning.Tuner = hydra.utils.instantiate(config.tuner.instantiate, trainer=trainer)

    return tuner

In [None]:
from typing import List
# from pytorch_lightning import LightningModule, LightningDataModule, Callback, Trainer
# from pytorch_lightning.loggers import LightningLoggerBase
# from pytorch_lightning import seed_everything

def configure_trainer(config: DictConfig) -> pl.Trainer:
    """
    Instantiate a trainer, with its callbacks and loggers, from the config.

    Entries under `config.callbacks` and `config.logger` are instantiated via
    hydra only if they define a `_target_`; others are skipped. The callback
    stored under the key "wandb" additionally receives the fully resolved
    experiment config as a plain container through its `config` keyword.

    Args:
        config: Experiment config with a `trainer` node and optional
            `callbacks` / `logger` nodes.

    Returns:
        The trainer instantiated from `config.trainer`.
    """
    # Lightning callbacks
    callbacks: List[pl.Callback] = []
    if "callbacks" in config:
        for cb_name, cb_conf in config["callbacks"].items():
            if "_target_" not in cb_conf:
                continue
            log.info(f"Instantiating callback <{cb_conf._target_}>")
            extra_kwargs = {}
            if cb_name == "wandb":
                extra_kwargs["config"] = OmegaConf.to_container(config, resolve=True)
            callbacks.append(hydra.utils.instantiate(cb_conf, **extra_kwargs))

    # Lightning loggers
    logger: List[pl.loggers.LightningLoggerBase] = []
    if "logger" in config:
        for _, lg_conf in config["logger"].items():
            if "_target_" not in lg_conf:
                continue
            log.info(f"Instantiating logger <{lg_conf._target_}>")
            logger.append(hydra.utils.instantiate(lg_conf))

    log.info(f"Instantiating trainer <{config.trainer._target_}>")
    return hydra.utils.instantiate(
        config.trainer,
        callbacks=callbacks,
        logger=logger,
        _convert_="partial",
    )

# trainer = configure_trainer(config)

In [None]:
# Experiment config used to name the module-level logger.
config_path = "/media/data/jacob/GitHub/lightning-hydra-classifiers/configs/experiment/3_16bit_precision-single_gpu--tune_batchsize--tune_lr.yaml"

# The logger is named after the config file's stem, and its log directory is
# the "experiment_logs" folder that sits beside the config file.
log = get_standard_python_logger(
    name=Path(config_path).stem,
    log_path=Path(config_path).with_name("experiment_logs"),
)

### Main

In [None]:
# %reload_ext tensorboard
# %tensorboard --port 0 --logdir lightning_logs/

# Only start the experiment when this file is executed directly, so that
# importing it does not kick off a training run. The module-level config_path
# is passed explicitly so the run uses the same config file the logger is
# named after, rather than relying on main()'s duplicated default.
if __name__ == "__main__":
    main(config_path=config_path)

In [None]:
# %debug

In [None]:
# %pip list
#| grep tornado

## Scratch

In [None]:


# config = read_hydra_config(config_dir = str(Path(config_path).parent),
#                            job_name="test_app",
#                            config_name=Path(config_path).stem)

# template_utils.extras(config)



# log.info(f"<----------Initiating auto scale_batch_size tuning----------->")
# log.info(f"Tuning kwargs:")
# # log.info(f"Tuning kwargs:\n{OmegaConf.to_yaml(config.tuner.scale_batch_size.kwargs, resolve=True)}")
# log.info(OmegaConf.to_container(config.tuner.scale_batch_size.kwargs, resolve=True))
# # OmegaConf.to_container(

# # OmegaConf.to_yaml(

# import rich

# name = Path(config_path).stem
# log_path=Path(config_path).parent / "experiment_logs"


# if log_path is None:
#     log_path = f"./{name}.log"
# else:
#     os.makedirs(log_path, exist_ok=True)
#     log_path = str(Path(log_path, name)) + '.log'

# file = open(log_path, 'w')
# template_utils.print_config(config, resolve=True, file=file)
# file.close()
# rich.print
# rich.print