# FuseMedML - Hello World
[![Github repo](https://img.shields.io/static/v1?label=GitHub&message=FuseMedML&color=brightgreen)](https://github.com/IBM/fuse-med-ml)

[![PyPI version](https://badge.fury.io/py/fuse-med-ml.svg)](https://badge.fury.io/py/fuse-med-ml)

[![Slack channel](https://img.shields.io/badge/support-slack-slack.svg?logo=slack)](https://join.slack.com/t/fusemedml/shared_invite/zt-xr1jaj29-h7IMsSc0Lq4qpVNxW97Phw)

[![Open Source](https://badges.frapsoft.com/os/v1/open-source.svg)](https://github.com/IBM/fuse-med-ml)


**Welcome to Fuse's 'hello world' hands-on notebook!**

In this notebook we'll examine a basic FuseMedML use case: MNIST multiclass classification, including training, inference and evaluation.

FuseMedML is an open-source, Python-based framework designed to enhance collaboration and accelerate discoveries in fused medical data through advanced machine learning technologies.

The initial version is PyTorch-based and focuses on deep learning for medical imaging.

By the end of the session we hope you'll be familiar with Fuse's basic workflow and appreciate its potential.

Open and run this notebook in [Google Colab](https://colab.research.google.com/github/IBM/fuse-med-ml/blob/master/fuse_examples/tutorials/hello_world/hello_world.ipynb)

ENJOY

------------
## **Installation Details - Google Colab**


#### **Install FuseMedML**
- Skip this step if the fuse-med-ml package is already cloned and installed.

In [None]:
!git clone https://github.com/IBM/fuse-med-ml.git
%cd fuse-med-ml
!pip install -e .


## **Setup environment**

##### **Imports**

In [1]:
import os
from collections import OrderedDict

import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms

from fuse.eval.evaluator import EvaluatorDefault
from fuse.data.dataset.dataset_wrapper import FuseDatasetWrapper
from fuse.data.sampler.sampler_balanced_batch import FuseSamplerBalancedBatch
from fuse.losses.loss_default import FuseLossDefault
from fuse.managers.callbacks.callback_tensorboard import FuseTensorboardCallback
from fuse.managers.manager_default import FuseManagerDefault
from fuse.eval.metrics.classification.metrics_classification_common import MetricAccuracy, MetricAUCROC, MetricROCCurve
from fuse.eval.metrics.classification.metrics_thresholding_common import MetricApplyThresholds
from fuse.models.model_wrapper import FuseModelWrapper
from fuse_examples.tutorials.hello_world.hello_world_utils import LeNet, perform_softmax


##### **Output paths**
Customize the output directories below; set ROOT to the directory where all outputs should be written.

In [2]:
ROOT = 'examples' # TODO: fill path here
PATHS = {'model_dir': os.path.join(ROOT, 'mnist/model_dir'),
         'force_reset_model_dir': True,  # If True will reset model dir automatically - otherwise will prompt 'are you sure' message.
         'cache_dir': os.path.join(ROOT, 'mnist/cache_dir'),
         'inference_dir': os.path.join(ROOT, 'mnist/infer_dir'),
         'eval_dir': os.path.join(ROOT, 'mnist/eval_dir')}

paths = PATHS

##### **Training Parameters**
* Model - the model architecture to use.
* Data - parameters for data loading (batch size and number of workers).
* Manager - parameters for the training session.

In [3]:
TRAIN_COMMON_PARAMS = {}

### Model ###
TRAIN_COMMON_PARAMS['model'] = 'lenet'

### Data ###
TRAIN_COMMON_PARAMS['data.batch_size'] = 100
TRAIN_COMMON_PARAMS['data.train_num_workers'] = 8
TRAIN_COMMON_PARAMS['data.validation_num_workers'] = 8

### Manager ###
TRAIN_COMMON_PARAMS['manager.train_params'] = {
    'device': 'cuda', 
    'num_epochs': 5,
    'virtual_batch_size': 1,  # number of batches in one virtual batch
    'start_saving_epochs': 10,  # first epoch to start saving checkpoints from
    'gap_between_saving_epochs': 5,  # number of epochs between saved checkpoint
}
TRAIN_COMMON_PARAMS['manager.best_epoch_source'] = {
    'source': 'metrics.accuracy',  # can be any key from 'epoch_results'
    'optimization': 'max',  # can be either min/max
    'on_equal_values': 'better',
    # can be either better/worse - whether to consider best epoch when values are equal
}
TRAIN_COMMON_PARAMS['manager.learning_rate'] = 1e-4
TRAIN_COMMON_PARAMS['manager.weight_decay'] = 0.001
TRAIN_COMMON_PARAMS['manager.resume_checkpoint_filename'] = None  # if not None, will try to load the checkpoint

TRAIN_COMMON_PARAMS['manager.train_params']['device'] = 'cpu'  # overrides the 'cuda' default above, e.g. when no GPU is available

train_params = TRAIN_COMMON_PARAMS
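To make the 'best_epoch_source' settings concrete, here is a minimal pure-Python sketch of how such a rule could pick the best epoch. It is an illustration only, not FuseManagerDefault's implementation: the helpers `get_nested`, `is_new_best` and `select_best_epoch` are hypothetical, epoch results are assumed to be nested dicts, and the per-epoch accuracies are made up.

```python
from typing import Any, Dict, List, Optional, Tuple

# Same structure as TRAIN_COMMON_PARAMS['manager.best_epoch_source'] above.
best_epoch_source = {
    'source': 'metrics.accuracy',
    'optimization': 'max',
    'on_equal_values': 'better',
}


def get_nested(results: Dict[str, Any], dotted_key: str) -> Any:
    """Hypothetical helper: look up 'metrics.accuracy' in {'metrics': {'accuracy': ...}}."""
    value: Any = results
    for part in dotted_key.split('.'):
        value = value[part]
    return value


def is_new_best(value: float, best: Optional[float], optimization: str, on_equal_values: str) -> bool:
    """Decide whether this epoch's value should replace the current best value."""
    if best is None:
        return True  # the first epoch is always the best so far
    if value == best:
        return on_equal_values == 'better'  # 'better': the later epoch wins ties
    if optimization == 'max':
        return value > best
    if optimization == 'min':
        return value < best
    raise ValueError(f'unknown optimization mode: {optimization!r}')


def select_best_epoch(epoch_results: List[Dict[str, Any]],
                      source: Dict[str, str]) -> Tuple[Optional[int], Optional[float]]:
    """Scan per-epoch results in order and return (best_epoch, best_value)."""
    best_epoch, best_value = None, None
    for epoch, results in enumerate(epoch_results):
        value = get_nested(results, source['source'])
        if is_new_best(value, best_value, source['optimization'], source['on_equal_values']):
            best_epoch, best_value = epoch, value
    return best_epoch, best_value


# Made-up validation accuracies for four epochs (illustration only).
history = [{'metrics': {'accuracy': acc}} for acc in (0.95, 0.97, 0.97, 0.96)]
print(select_best_epoch(history, best_epoch_source))  # (2, 0.97)
print(select_best_epoch(history, {**best_epoch_source, 'on_equal_values': 'worse'}))  # (1, 0.97)
```

Under this reading of `'on_equal_values': 'better'`, the later of two equally good epochs wins, which is why epoch 2 is chosen over epoch 1.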

## **Training the model**

##### **Data**
Downloading the MNIST dataset and building dataloaders (torch.utils.data.DataLoader) for both training and validation using Fuse components:
1. Wrapper - **FuseDatasetWrapper**:

    Wraps a PyTorch dataset so that each sample is converted to a dictionary according to the provided mapping (here, ('image', 'label') yields the keys 'data.image' and 'data.label').
2. Sampler - **FuseSamplerBalancedBatch**:

    Implements 'torch.utils.data.Sampler' and is passed to the DataLoader as its batch_sampler.

    The sampler creates balanced batches comprising an equal number of samples per label.
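The two components above can be sketched in plain Python to show the idea. This is a simplified, hypothetical illustration rather than Fuse's code: `wrap_sample` assumes the wrapper prefixes each mapped name with 'data.', and `balanced_batches` assumes the batch size is divisible by the number of classes and draws each class's share with replacement.

```python
import random
from collections import defaultdict
from typing import Any, Dict, Iterator, List, Sequence, Tuple


def wrap_sample(sample: Tuple[Any, ...], mapping: Sequence[str] = ('image', 'label')) -> Dict[str, Any]:
    """Hypothetical: turn an (image, label) tuple into {'data.image': ..., 'data.label': ...}."""
    return {f'data.{name}': value for name, value in zip(mapping, sample)}


def balanced_batches(labels: Sequence[int], num_classes: int, batch_size: int,
                     seed: int = 0) -> Iterator[List[int]]:
    """Yield lists of sample indices with batch_size // num_classes indices per class."""
    if batch_size % num_classes != 0:
        raise ValueError('batch_size must be a multiple of num_classes')
    per_class = batch_size // num_classes
    rng = random.Random(seed)

    # Group sample indices by label.
    indices_by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        indices_by_class[label].append(idx)

    for _ in range(len(labels) // batch_size):
        batch: List[int] = []
        for cls in range(num_classes):
            # Sampling with replacement oversamples rare classes (an assumption of this sketch).
            batch.extend(rng.choices(indices_by_class[cls], k=per_class))
        rng.shuffle(batch)  # avoid class-ordered batches
        yield batch


print(wrap_sample(('<image tensor>', 7)))  # {'data.image': '<image tensor>', 'data.label': 7}

toy_labels = [0, 0, 0, 0, 0, 1, 1, 2, 2, 0, 0, 0]  # imbalanced toy labels
for batch in balanced_batches(toy_labels, num_classes=3, batch_size=6):
    print(sorted(toy_labels[i] for i in batch))  # [0, 0, 1, 1, 2, 2]
```

With the notebook's settings (batch_size=100, num_balanced_classes=10), the same rule would put 10 samples of each digit in every batch.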

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Create dataset
torch_train_dataset = torchvision.datasets.MNIST(paths['cache_dir'], download=True, train=True, transform=transform)

# wrapping torch dataset
train_dataset = FuseDatasetWrapper(name='train', dataset=torch_train_dataset, mapping=('image', 'label'))
train_dataset.create()

sampler = FuseSamplerBalancedBatch(dataset=train_dataset,
                                balanced_class_name='data.label',
                                num_balanced_classes=10,
                                batch_size=train_params['data.batch_size'],
                                balanced_class_weights=None)

# Create dataloader
train_dataloader = DataLoader(dataset=train_dataset, batch_sampler=sampler, num_workers=train_params['data.train_num_workers'])

## Validation data
# Create dataset
torch_validation_dataset = torchvision.datasets.MNIST(paths['cache_dir'], download=True, train=False, transform=transform)
# wrapping torch dataset
validation_dataset = FuseDatasetWrapper(name='validation', dataset=torch_validation_dataset, mapping=('image', 'label'))
validation_dataset.create()

# dataloader
validation_dataloader = DataLoader(dataset=validation_dataset, batch_size=train_params['data.batch_size'],
                                num_workers=train_params['data.validation_num_workers'])

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to examples/mnist/cache_dir/MNIST/raw/train-images-idx3-ubyte.gz
Extracting examples/mnist/cache_dir/MNIST/raw/train-images-idx3-ubyte.gz to examples/mnist/cache_dir/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to examples/mnist/cache_dir/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting examples/mnist/cache_dir/MNIST/raw/train-labels-idx1-ubyte.gz to examples/mnist/cache_dir/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to examples/mnist/cache_dir/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting examples/mnist/cache_dir/MNIST/raw/t10k-images-idx3-ubyte.gz to examples/mnist/cache_dir/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to examples/mnist/cache_dir/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting examples/mnist/cache_dir/MNIST/raw/t10k-labels-idx1-ubyte.gz to examples/mnist/cache_dir/MNIST/raw

##### **Model**
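Building the LeNet model using PyTorch's API and then wrapping it with FuseModelWrapper, which feeds batch_dict['data.image'] to the model.

The model outputs will be stored in batch_dict['model.*'], e.g. batch_dict['model.logits.classification'] and batch_dict['model.output.classification'].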

In [5]:
torch_model = LeNet()

model = FuseModelWrapper(model=torch_model,
                        model_inputs=['data.image'],
                        post_forward_processing_function=perform_softmax,
                        model_outputs=['logits.classification', 'output.classification']
                        )
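The wrapper's data flow can be sketched with plain dictionaries. In this hypothetical sketch, `wrapped_forward` reads the listed input keys from the batch dict, calls the model, applies the post-forward function and stores each output under 'model.<name>'. It handles a single sample rather than a batched tensor, `toy_model` stands in for LeNet, and `logits_and_probs` assumes that perform_softmax returns the logits together with their softmax, in the same order as model_outputs.

```python
import math
from typing import Any, Callable, Dict, List, Optional, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax for one sample."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def wrapped_forward(batch_dict: Dict[str, Any], model_fn: Callable[..., Any],
                    model_inputs: Sequence[str], model_outputs: Sequence[str],
                    post_fn: Optional[Callable[[Any], Any]] = None) -> Dict[str, Any]:
    """Hypothetical dict-in/dict-out wrapper around a plain model function."""
    inputs = [batch_dict[key] for key in model_inputs]
    outputs = model_fn(*inputs)
    if post_fn is not None:
        outputs = post_fn(outputs)
    if not isinstance(outputs, tuple):
        outputs = (outputs,)
    for name, value in zip(model_outputs, outputs):
        batch_dict[f'model.{name}'] = value  # e.g. 'model.logits.classification'
    return batch_dict


def toy_model(image: Any) -> List[float]:
    """Stand-in for LeNet: ignores the image and returns fixed logits for 3 classes."""
    return [0.0, 2.0, 1.0]


def logits_and_probs(logits: List[float]):
    """Assumed behaviour of a softmax post-processing step: keep logits, add probabilities."""
    return logits, softmax(logits)


batch_dict = {'data.image': '<image tensor>', 'data.label': 1}
wrapped_forward(batch_dict, toy_model,
                model_inputs=['data.image'],
                model_outputs=['logits.classification', 'output.classification'],
                post_fn=logits_and_probs)
print(sorted(batch_dict))
# ['data.image', 'data.label', 'model.logits.classification', 'model.output.classification']
```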

##### **Loss function**
Dictionary of loss elements. Each element is a sub-class of FuseLossBase.

The total loss will be the weighted sum of all the elements.

Fuse's loss API extracts the predictions and the labels from the dictionary hierarchy (here 'model.logits.classification' and 'data.label') and then applies the callable loss function, scaled by the element's weight.

In [6]:
losses = {
    'cls_loss': FuseLossDefault(pred_name='model.logits.classification', target_name='data.label', callable=F.cross_entropy, weight=1.0),
}
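Here is a minimal plain-Python sketch of "weighted sum of loss elements". It mirrors the pred_name/target_name/weight idea of FuseLossDefault but is not its implementation: `cross_entropy`, `loss_element` and `total_loss` are hypothetical helpers, the batch dict uses flat dotted keys, and each element averages over the batch as F.cross_entropy does by default (reduction='mean'). The second element, 'aux_loss', is invented purely to show the weighting.

```python
import math
from typing import Any, Callable, Dict, List, Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Per-sample cross-entropy: -log(softmax(logits)[target]), via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def loss_element(batch_dict: Dict[str, Any], pred_name: str, target_name: str,
                 loss_fn: Callable[[Sequence[float], int], float], weight: float) -> float:
    """Extract predictions and targets by key, average the per-sample loss, apply the weight."""
    preds: List[Sequence[float]] = batch_dict[pred_name]
    targets: List[int] = batch_dict[target_name]
    per_sample = [loss_fn(p, t) for p, t in zip(preds, targets)]
    return weight * sum(per_sample) / len(per_sample)


def total_loss(batch_dict: Dict[str, Any], losses: Dict[str, Dict[str, Any]]) -> float:
    """Weighted sum of all loss elements."""
    return sum(loss_element(batch_dict, **cfg) for cfg in losses.values())


toy_batch = {
    'model.logits.classification': [[0.0, 0.0], [2.0, 0.0]],  # 2 samples, 2 classes
    'data.label': [0, 0],
}
toy_losses = {
    'cls_loss': dict(pred_name='model.logits.classification', target_name='data.label',
                     loss_fn=cross_entropy, weight=1.0),
    'aux_loss': dict(pred_name='model.logits.classification', target_name='data.label',
                     loss_fn=cross_entropy, weight=0.5),  # hypothetical extra element
}
print(round(total_loss(toy_batch, toy_losses), 4))  # 0.6151
```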

##### **Metrics**
Dictionary of metric elements. Each element is a sub-class of FuseMetricBase.

The metrics are calculated per epoch for both the training and validation sets.

The 'best_epoch_source' used to select and save the best model can be based on one of these metrics (here 'metrics.accuracy').

In [7]:
metrics = OrderedDict([
    ('operation_point', MetricApplyThresholds(pred='model.output.classification')), # will apply argmax
    ('accuracy', MetricAccuracy(pred='results:metrics.operation_point.cls_pred', target='data.label'))
])
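The 'results:' prefix in the accuracy metric refers to the output of an earlier metric, which is presumably why the metrics are kept in an OrderedDict. The plain-Python sketch below illustrates that idea under stated assumptions: `apply_thresholds`, `accuracy`, `resolve` and `compute_metrics` are hypothetical helpers, not fuse.eval's API, and results are stored under flat keys such as 'metrics.operation_point.cls_pred'.

```python
from collections import OrderedDict
from typing import Any, Dict, List, Sequence


def apply_thresholds(probs: Sequence[Sequence[float]]) -> Dict[str, List[int]]:
    """Multiclass operation point: predict the class with the highest score (argmax)."""
    return {'cls_pred': [max(range(len(p)), key=p.__getitem__) for p in probs]}


def accuracy(preds: Sequence[int], targets: Sequence[int]) -> float:
    return sum(int(p == t) for p, t in zip(preds, targets)) / len(targets)


def resolve(key: str, data: Dict[str, Any], results: Dict[str, Any]) -> Any:
    """'results:<key>' reads an earlier metric's output; any other key reads the data."""
    if key.startswith('results:'):
        return results[key[len('results:'):]]
    return data[key]


def compute_metrics(data: Dict[str, Any], metric_specs: "OrderedDict") -> Dict[str, Any]:
    results: Dict[str, Any] = {}
    for name, (fn, arg_keys) in metric_specs.items():  # order matters: later metrics read earlier ones
        value = fn(*(resolve(k, data, results) for k in arg_keys))
        if isinstance(value, dict):
            for sub_key, sub_value in value.items():
                results[f'metrics.{name}.{sub_key}'] = sub_value
        else:
            results[f'metrics.{name}'] = value
    return results


toy_data = {
    'model.output.classification': [[0.1, 0.7, 0.2], [0.6, 0.3, 0.1],
                                    [0.2, 0.2, 0.6], [0.5, 0.4, 0.1]],
    'data.label': [1, 0, 2, 1],
}
toy_metric_specs = OrderedDict([
    ('operation_point', (apply_thresholds, ['model.output.classification'])),
    ('accuracy', (accuracy, ['results:metrics.operation_point.cls_pred', 'data.label'])),
])
toy_results = compute_metrics(toy_data, toy_metric_specs)
print(toy_results['metrics.operation_point.cls_pred'])  # [1, 0, 2, 0]
print(toy_results['metrics.accuracy'])  # 0.75
```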

##### **Callbacks**
Callbacks are sub-classes of FuseCallbackBase.

A callback is an object that can perform actions at various stages of training.

At each stage it can manipulate the data, the batch_dict or the epoch_results.


In [8]:
callbacks = [
    FuseTensorboardCallback(model_dir=paths['model_dir']),  # save statistics for tensorboard
]
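The callback mechanism can be sketched as a set of hooks that the training loop calls at fixed points. The hook names `on_batch_end` and `on_epoch_end`, the two example callbacks and the fake loop below are all hypothetical and only illustrate the pattern; they are not FuseCallbackBase's actual interface.

```python
from typing import Any, Dict, List, Sequence


class Callback:
    """Hypothetical base class: every hook is a no-op unless a subclass overrides it."""

    def on_batch_end(self, mode: str, batch_dict: Dict[str, Any]) -> None:
        pass

    def on_epoch_end(self, mode: str, epoch: int, epoch_results: Dict[str, Any]) -> None:
        pass


class ErrorRateCallback(Callback):
    """Manipulates epoch_results: adds an error rate derived from the accuracy."""

    def on_epoch_end(self, mode, epoch, epoch_results):
        epoch_results['metrics.error'] = 1.0 - epoch_results['metrics.accuracy']


class HistoryCallback(Callback):
    """Records one value per epoch and mode, similar in spirit to logging scalars for TensorBoard."""

    def __init__(self, key: str) -> None:
        self.key = key
        self.history: Dict[str, List[float]] = {}

    def on_epoch_end(self, mode, epoch, epoch_results):
        self.history.setdefault(mode, []).append(epoch_results[self.key])


def run_fake_training(callbacks: Sequence[Callback], accuracies: Sequence[float]) -> None:
    """Fake loop: one batch per mode per epoch; callbacks run in list order."""
    for epoch, acc in enumerate(accuracies):
        for mode in ('train', 'validation'):
            batch_dict = {'data.label': [0, 1]}
            for cb in callbacks:
                cb.on_batch_end(mode, batch_dict)
            epoch_results = {'metrics.accuracy': acc}
            for cb in callbacks:
                cb.on_epoch_end(mode, epoch, epoch_results)


recorder = HistoryCallback('metrics.error')
run_fake_training([ErrorRateCallback(), recorder], accuracies=[0.5, 0.75])
print(recorder.history)  # {'train': [0.5, 0.25], 'validation': [0.5, 0.25]}
```

In this sketch the list order matters: ErrorRateCallback has to add 'metrics.error' before HistoryCallback reads it.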

##### **Train**
Building Fuse's manager and supplying it with PyTorch's optimizer and scheduler.

Possible workflows are listed in FuseManagerDefault's documentation.

Note that the manager uses the training parameters that we've set above.

In [9]:
# Create optimizer
optimizer = optim.Adam(model.parameters(), lr=train_params['manager.learning_rate'], weight_decay=train_params['manager.weight_decay'])

# create scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# train from scratch
manager = FuseManagerDefault(output_model_dir=paths['model_dir'], force_reset=paths['force_reset_model_dir'])

# Providing the objects required for the training process.
manager.set_objects(net=model,
                    optimizer=optimizer,
                    losses=losses,
                    metrics=metrics,
                    best_epoch_source=train_params['manager.best_epoch_source'],
                    lr_scheduler=scheduler,
                    callbacks=callbacks,
                    train_params=train_params['manager.train_params'])

# Start training
manager.train(train_dataloader=train_dataloader, validation_dataloader=validation_dataloader)

100%|██████████| 100/100 [00:02<00:00, 42.66it/s]
100%|██████████| 675/675 [00:36<00:00, 18.43it/s]
100%|██████████| 100/100 [00:01<00:00, 66.65it/s]
100%|██████████| 675/675 [00:34<00:00, 19.74it/s]
100%|██████████| 100/100 [00:01<00:00, 61.11it/s]
100%|██████████| 675/675 [00:34<00:00, 19.69it/s]
100%|██████████| 100/100 [00:01<00:00, 50.91it/s]
100%|██████████| 675/675 [00:34<00:00, 19.42it/s]
100%|██████████| 100/100 [00:01<00:00, 58.35it/s]


## **Infer**

##### **Define Infer Common Params**


In [10]:
INFER_COMMON_PARAMS = {}
INFER_COMMON_PARAMS['infer_filename'] = 'validation_set_infer.gz'
INFER_COMMON_PARAMS['checkpoint'] = 'best' 

infer_common_params = INFER_COMMON_PARAMS

##### **Infer**

In [11]:
validation_dataloader = DataLoader(dataset=validation_dataset, collate_fn=validation_dataset.collate_fn, batch_size=2, num_workers=2)

## Manager for inference
manager = FuseManagerDefault()
output_columns = ['model.output.classification', 'data.label']
manager.infer(data_loader=validation_dataloader,
                input_model_dir=paths['model_dir'],
                checkpoint=infer_common_params['checkpoint'],
                output_columns=output_columns,
                output_file_name=os.path.join(paths["inference_dir"], infer_common_params["infer_filename"]))

100%|██████████| 5000/5000 [00:58<00:00, 85.79it/s] 


|  | descriptor | id | model.output.classification | data.label |
|---|---|---|---|---|
| 0 | (validation, 0) | (validation, 0) | [1.0348453e-07, 1.7178486e-07, 5.3252716e-06, ... | 7 |
| 1 | (validation, 1) | (validation, 1) | [4.7777103e-05, 4.2194282e-05, 0.9998847, 1.20... | 2 |
| 2 | (validation, 2) | (validation, 2) | [3.941711e-06, 0.99916375, 0.00014531022, 3.80... | 1 |
| 3 | (validation, 3) | (validation, 3) | [0.9996038, 6.0095076e-06, 8.758462e-06, 7.953... | 0 |
| 4 | (validation, 4) | (validation, 4) | [2.5701502e-06, 3.2625428e-06, 3.6159786e-06, ... | 4 |
| ... | ... | ... | ... | ... |
| 9995 | (validation, 9995) | (validation, 9995) | [6.151244e-08, 4.437438e-06, 0.9999622, 1.3406... | 2 |
| 9996 | (validation, 9996) | (validation, 9996) | [6.680257e-07, 4.749617e-06, 1.2762486e-05, 0.... | 3 |
| 9997 | (validation, 9997) | (validation, 9997) | [2.6505043e-08, 9.910561e-07, 2.8513983e-07, 1... | 4 |
| 9998 | (validation, 9998) | (validation, 9998) | [2.3059108e-07, 9.940045e-07, 2.3683555e-07, 4... | 5 |
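Each row of the table corresponds to one validation sample, even though inference ran with batch_size=2. The hypothetical sketch below shows one way per-batch dictionaries could be flattened into such rows while keeping only the requested output_columns; the 'descriptor' key and the `collect_rows` helper are assumptions for illustration, not how FuseManagerDefault.infer is implemented.

```python
from typing import Any, Dict, List, Sequence


def collect_rows(batches: Sequence[Dict[str, Sequence[Any]]],
                 output_columns: Sequence[str]) -> List[Dict[str, Any]]:
    """Split batched values into one row per sample, keeping only output_columns."""
    rows: List[Dict[str, Any]] = []
    for batch in batches:
        for i, descriptor in enumerate(batch['descriptor']):
            row = {'descriptor': descriptor}
            row.update({column: batch[column][i] for column in output_columns})
            rows.append(row)
    return rows


# Two toy batches of size 2 (probabilities truncated to 3 classes for brevity).
toy_batches = [
    {'descriptor': [('validation', 0), ('validation', 1)],
     'model.output.classification': [[0.1, 0.1, 0.8], [0.2, 0.7, 0.1]],
     'data.label': [2, 1],
     'data.image': ['<img0>', '<img1>']},  # not requested, so dropped
    {'descriptor': [('validation', 2), ('validation', 3)],
     'model.output.classification': [[0.9, 0.05, 0.05], [0.3, 0.3, 0.4]],
     'data.label': [0, 2],
     'data.image': ['<img2>', '<img3>']},
]
toy_rows = collect_rows(toy_batches, ['model.output.classification', 'data.label'])
print(len(toy_rows))  # 4
print(toy_rows[1])
# {'descriptor': ('validation', 1), 'model.output.classification': [0.2, 0.7, 0.1], 'data.label': 1}
```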


## **Evaluation**
Using the Evaluator from FuseMedML's evaluation package (fuse.eval), a standalone library for evaluating ML models that were not necessarily trained with FuseMedML.

More details and examples for the evaluation package can be found [here](https://github.com/IBM/fuse-med-ml/blob/master/fuse/eval/README.md).


##### **Define EVAL Common Params**


In [12]:
EVAL_COMMON_PARAMS = {}
EVAL_COMMON_PARAMS['infer_filename'] = INFER_COMMON_PARAMS['infer_filename']
eval_common_params = EVAL_COMMON_PARAMS

##### **Calculate metrics**

In [13]:
class_names = [str(i) for i in range(10)]

# metrics
metrics = OrderedDict([
    ('operation_point', MetricApplyThresholds(pred='model.output.classification')), # will apply argmax
    ('accuracy', MetricAccuracy(pred='results:metrics.operation_point.cls_pred', target='data.label')),
    ('roc', MetricROCCurve(pred='model.output.classification', target='data.label', class_names=class_names, output_filename=os.path.join(paths['inference_dir'], 'roc_curve.png'))),
    ('auc', MetricAUCROC(pred='model.output.classification', target='data.label', class_names=class_names)),
])
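For multiclass data, a per-class ROC AUC is commonly computed one-vs-rest: each class in turn is treated as positive and all others as negative. Assuming MetricAUCROC follows that convention (not confirmed here), the plain-Python sketch below computes it through the rank (Mann-Whitney) form of the AUC: the probability that a randomly chosen positive sample scores higher than a randomly chosen negative one, with ties counting as half. The toy scores are made up, and `binary_auc` / `one_vs_rest_auc` are hypothetical helpers.

```python
from typing import Dict, Sequence


def binary_auc(scores: Sequence[float], is_positive: Sequence[bool]) -> float:
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 1/2."""
    pos = [s for s, y in zip(scores, is_positive) if y]
    neg = [s for s, y in zip(scores, is_positive) if not y]
    if not pos or not neg:
        raise ValueError('need at least one positive and one negative sample')
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def one_vs_rest_auc(probs: Sequence[Sequence[float]], labels: Sequence[int],
                    class_names: Sequence[str]) -> Dict[str, float]:
    """One AUC per class: column c of probs scored against (label == c)."""
    return {name: binary_auc([p[c] for p in probs], [y == c for y in labels])
            for c, name in enumerate(class_names)}


toy_probs = [[0.8, 0.1, 0.1], [0.35, 0.55, 0.1], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]]
toy_targets = [0, 1, 2, 0]
print(one_vs_rest_auc(toy_probs, toy_targets, ['0', '1', '2']))
# {'0': 0.75, '1': 0.6666666666666666, '2': 1.0}
```

The pairwise loop costs O(positives × negatives) per class, which is fine for a toy example but slow in pure Python for the 10,000 validation samples; practical implementations sort the scores instead.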

##### **Evaluate**

In [14]:
# create evaluator
evaluator = EvaluatorDefault()

# run
results = evaluator.eval(ids=None,
                data=os.path.join(paths["inference_dir"], eval_common_params["infer_filename"]),
                metrics=metrics,
                output_dir=paths['eval_dir'])

Results:

Metric operation_point:
------------------------------------------------
cls_pred:
<fuse.eval.metrics.utils.PerSampleData object at 0x7f1d269d3e10>

Metric accuracy:
------------------------------------------------
0.9865

Metric roc:
------------------------------------------------
0.fpr:
[0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+