# FuseMedML - Hello World
[![Github repo](https://img.shields.io/static/v1?label=GitHub&message=FuseMedML&color=brightgreen)](https://github.com/IBM/fuse-med-ml)

[![PyPI version](https://badge.fury.io/py/fuse-med-ml.svg)](https://badge.fury.io/py/fuse-med-ml)

[![Slack channel](https://img.shields.io/badge/support-slack-slack.svg?logo=slack)](https://join.slack.com/t/fusemedml/shared_invite/zt-xr1jaj29-h7IMsSc0Lq4qpVNxW97Phw)

[![Open Source](https://badges.frapsoft.com/os/v1/open-source.svg)](https://github.com/IBM/fuse-med-ml)


**Welcome to Fuse's 'hello world' hands-on notebook!**

FuseMedML is an open-source, Python-based framework designed to enhance collaboration and accelerate discoveries in fused medical data through advanced machine learning technologies.

The initial version is PyTorch-based and focuses on deep learning for medical imaging.

By the end of the session, we hope you'll be familiar with Fuse's basic workflow and appreciate its potential.

ENJOY

------------
## **Installation Details - Google Colab**


#### **Install FuseMedML**

In [6]:
# !git clone https://github.com/IBM/fuse-med-ml.git
# %cd fuse-med-ml
# !pip install -e .


## **Setup environment**

##### **Imports**

In [7]:
import os
from collections import OrderedDict  # concrete class; typing.OrderedDict is meant for annotations

import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms

from fuse.eval.evaluator import EvaluatorDefault
from fuse.data.dataset.dataset_wrapper import FuseDatasetWrapper
from fuse.data.sampler.sampler_balanced_batch import FuseSamplerBalancedBatch
from fuse.losses.loss_default import FuseLossDefault
from fuse.managers.callbacks.callback_tensorboard import FuseTensorboardCallback
from fuse.managers.manager_default import FuseManagerDefault
from fuse.eval.metrics.classification.metrics_classification_common import MetricAccuracy, MetricAUCROC, MetricROCCurve
from fuse.eval.metrics.classification.metrics_thresholding_common import MetricApplyThresholds
from fuse.models.model_wrapper import FuseModelWrapper
from fuse_examples.classification.mnist.lenet import LeNet
from fuse_examples.tutorials.hello_world.hello_world_utils import perform_softmax

##### **Output paths**
The user is able to customize the output directories.

In [8]:
ROOT = 'examples' # TODO: fill path here
PATHS = {'model_dir': os.path.join(ROOT, 'mnist/model_dir'),
         'force_reset_model_dir': True,  # If True will reset model dir automatically - otherwise will prompt 'are you sure' message.
         'cache_dir': os.path.join(ROOT, 'mnist/cache_dir'),
         'inference_dir': os.path.join(ROOT, 'mnist/infer_dir'),
         'eval_dir': os.path.join(ROOT, 'mnist/eval_dir')}

paths = PATHS
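
As a small illustration of customizing these locations, the sketch below builds the same kind of path dictionary under a temporary root and creates the directories up front. Whether Fuse creates missing directories on its own is not stated here, so creating them explicitly is only a defensive assumption. `build_paths` and `ensure_dirs` are hypothetical helpers, not part of Fuse.

```python
import os
import tempfile


def build_paths(root: str) -> dict:
    """Hypothetical helper: build the output-path dictionary under `root`."""
    return {
        'model_dir': os.path.join(root, 'mnist/model_dir'),
        'force_reset_model_dir': True,
        'cache_dir': os.path.join(root, 'mnist/cache_dir'),
        'inference_dir': os.path.join(root, 'mnist/infer_dir'),
        'eval_dir': os.path.join(root, 'mnist/eval_dir'),
    }


def ensure_dirs(paths: dict) -> None:
    """Create every directory entry; non-string entries (flags) are skipped."""
    for key, value in paths.items():
        if key.endswith('_dir') and isinstance(value, str):
            os.makedirs(value, exist_ok=True)


demo_root = tempfile.mkdtemp()  # stand-in for ROOT
demo_paths = build_paths(demo_root)
ensure_dirs(demo_paths)
```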

##### **Training Parameters**
* Model - which model we are using.
* Data - define parameters for the data preprocessing.
* Manager - define parameters for the training session.

In [9]:
TRAIN_COMMON_PARAMS = {}

### Model ###
TRAIN_COMMON_PARAMS['model'] = 'lenet'

### Data ###
TRAIN_COMMON_PARAMS['data.batch_size'] = 100
TRAIN_COMMON_PARAMS['data.train_num_workers'] = 8
TRAIN_COMMON_PARAMS['data.validation_num_workers'] = 8

### Manager ###
TRAIN_COMMON_PARAMS['manager.train_params'] = {
    'device': 'cuda', 
    'num_epochs': 5,
    'virtual_batch_size': 1,  # number of batches in one virtual batch
    'start_saving_epochs': 10,  # first epoch to start saving checkpoints from
    'gap_between_saving_epochs': 5,  # number of epochs between saved checkpoint
}
TRAIN_COMMON_PARAMS['manager.best_epoch_source'] = {
    'source': 'metrics.accuracy',  # can be any key from 'epoch_results'
    'optimization': 'max',  # can be either min/max
    'on_equal_values': 'better',
    # can be either better/worse - whether to consider best epoch when values are equal
}
TRAIN_COMMON_PARAMS['manager.learning_rate'] = 1e-4
TRAIN_COMMON_PARAMS['manager.weight_decay'] = 0.001
TRAIN_COMMON_PARAMS['manager.resume_checkpoint_filename'] = None  # if not None, will try to load the checkpoint

TRAIN_COMMON_PARAMS['manager.train_params']['device'] = 'cpu'

train_params = TRAIN_COMMON_PARAMS
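
To make the 'manager.best_epoch_source' settings more concrete, here is a minimal sketch of how 'optimization' and 'on_equal_values' could combine into a "is this epoch the new best?" rule. It is an illustration only, not Fuse's implementation. In particular, reading 'better' as "a tie counts as a new best" is an assumption, and the accuracy values are made up.

```python
def is_new_best(current: float, best: float, optimization: str = 'max',
                on_equal_values: str = 'better') -> bool:
    """Illustrative comparison rule; not Fuse's actual code."""
    if current == best:
        # Assumption: 'better' means a tie is treated as a new best epoch.
        return on_equal_values == 'better'
    if optimization == 'max':
        return current > best
    return current < best


# Hypothetical per-epoch values of 'metrics.accuracy'
accuracies = [0.91, 0.95, 0.95, 0.93]
best_epoch, best_value = 0, accuracies[0]
for epoch, value in enumerate(accuracies[1:], start=1):
    if is_new_best(value, best_value, optimization='max', on_equal_values='better'):
        best_epoch, best_value = epoch, value
```

With 'on_equal_values' set to 'worse', the tie at the third epoch would be ignored and the second epoch would remain the best one.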

## **Training the model**

##### **Data**
Downloading the MNIST dataset and building dataloaders (torch.utils.data.DataLoader) for both training and validation using Fuse components:
1. Wrapper - **FuseDatasetWrapper**:

    Wraps a PyTorch dataset so that each sample is converted to a dictionary according to the provided mapping.
2. Sampler - **FuseSamplerBalancedBatch**:

    Implements the 'torch.utils.data.Sampler' interface.
    
    The sampler selects the list of samples to use for each batch.


Other Fuse components for preprocessing data are:
* FuseDataSourceBase
* FuseProcessorBase
* FuseDatasetBase
* FuseAugmentorBase
* FuseVisualizerBase
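
The wrapper idea described above can be sketched in plain Python. `DictDatasetSketch` is a hypothetical stand-in, not FuseDatasetWrapper itself. The 'data.' key prefix and the (name, index) descriptor are assumptions inferred from keys such as 'data.label' and from the inference output shown later in this notebook, and flat dotted keys are used here purely for simplicity.

```python
from typing import Any, Dict, Sequence, Tuple


class DictDatasetSketch:
    """Hypothetical stand-in for a dataset wrapper (not FuseDatasetWrapper)."""

    def __init__(self, name: str, dataset: Sequence[Tuple[Any, ...]],
                 mapping: Tuple[str, ...]):
        self.name = name
        self.dataset = dataset
        self.mapping = mapping

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        sample = self.dataset[index]
        # Assumption: mapped fields are stored under a 'data.' prefix.
        sample_dict = {f'data.{key}': value
                       for key, value in zip(self.mapping, sample)}
        # Assumption: each sample is identified by (dataset name, index).
        sample_dict['data.descriptor'] = (self.name, index)
        return sample_dict


toy_dataset = [([0.0, 1.0], 7), ([1.0, 0.0], 2)]  # (image, label) tuples
wrapped = DictDatasetSketch(name='train', dataset=toy_dataset,
                            mapping=('image', 'label'))
first_sample = wrapped[0]
```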

    

In [10]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Create dataset
torch_train_dataset = torchvision.datasets.MNIST(paths['cache_dir'], download=True, train=True, transform=transform)

# wrapping torch dataset
train_dataset = FuseDatasetWrapper(name='train', dataset=torch_train_dataset, mapping=('image', 'label'))
train_dataset.create()

sampler = FuseSamplerBalancedBatch(dataset=train_dataset,
                                balanced_class_name='data.label',
                                num_balanced_classes=10,
                                batch_size=train_params['data.batch_size'],
                                balanced_class_weights=None)

# Create dataloader
train_dataloader = DataLoader(dataset=train_dataset, batch_sampler=sampler, num_workers=train_params['data.train_num_workers'])

## Validation data
# Create dataset
torch_validation_dataset = torchvision.datasets.MNIST(paths['cache_dir'], download=True, train=False, transform=transform)
# wrapping torch dataset
validation_dataset = FuseDatasetWrapper(name='validation', dataset=torch_validation_dataset, mapping=('image', 'label'))
validation_dataset.create()

# dataloader
validation_dataloader = DataLoader(dataset=validation_dataset, batch_size=train_params['data.batch_size'],
                                num_workers=train_params['data.validation_num_workers'])
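
For intuition about the balanced sampler used for the training dataloader, the sketch below yields batches of indices with the same number of samples from every class. It is a simplified illustration and not FuseSamplerBalancedBatch itself. The reuse of minority-class samples once they run out, the divisibility requirement, and the name `balanced_batches` are all assumptions made for this example.

```python
import random
from collections import defaultdict
from typing import Dict, Iterator, List


def balanced_batches(labels: List[int], num_classes: int, batch_size: int,
                     seed: int = 0) -> Iterator[List[int]]:
    """Yield index batches with an equal number of samples per class."""
    if batch_size % num_classes != 0:
        raise ValueError('this sketch needs batch_size divisible by num_classes')
    per_class = batch_size // num_classes
    rng = random.Random(seed)

    # Group sample indices by their class label.
    indices_by_class: Dict[int, List[int]] = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_class[label].append(index)
    if len(indices_by_class) != num_classes:
        raise ValueError('every class needs at least one sample')

    for indices in indices_by_class.values():
        rng.shuffle(indices)
    cursors = {label: 0 for label in indices_by_class}

    # Enough batches for the largest class to be seen once.
    largest = max(len(indices) for indices in indices_by_class.values())
    num_batches = -(-largest // per_class)  # ceiling division

    for _batch_index in range(num_batches):
        batch: List[int] = []
        for label, indices in indices_by_class.items():
            for _slot in range(per_class):
                if cursors[label] == len(indices):
                    # Class exhausted: reshuffle and reuse its samples.
                    rng.shuffle(indices)
                    cursors[label] = 0
                batch.append(indices[cursors[label]])
                cursors[label] += 1
        yield batch


toy_labels = [0] * 6 + [1] * 2  # imbalanced toy labels
toy_batches = list(balanced_batches(toy_labels, num_classes=2, batch_size=4))
```

A list of index batches like this is the kind of object that `DataLoader` accepts through its `batch_sampler` argument.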

##### **Model**
Building the LeNet model using PyTorch's API and then wrapping it. 

The model outputs will be aggregated in batch_dict['model.*'].


Another option to implement a model is to use Fuse components such as:
* FuseModelDefault
* FuseBackbone
* FuseHead
* FuseModelEnsemble
* FuseModelMultistream

In [11]:
torch_model = LeNet()

model = FuseModelWrapper(model=torch_model,
                        model_inputs=['data.image'],
                        post_forward_processing_function=perform_softmax,
                        model_outputs=['logits.classification', 'output.classification']
                        )
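
The wrapping step can be pictured as "read inputs by key, run the model, write outputs under 'model.*'". The sketch below shows that flow in plain Python with a toy two-class model. It is not FuseModelWrapper's code. `wrapped_forward`, `toy_model`, and `logits_and_probabilities` are hypothetical names, and the assumption that the post-forward function returns one value per output name is inferred from the two entries in 'model_outputs'.

```python
import math
from typing import Callable, Dict, List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a single vector."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def wrapped_forward(batch_dict: Dict[str, object], model: Callable,
                    model_inputs: List[str], post_forward: Callable,
                    model_outputs: List[str]) -> Dict[str, object]:
    """Read inputs by key, run the model, store outputs under 'model.*'."""
    inputs = [batch_dict[key] for key in model_inputs]
    raw_output = model(*inputs)
    # Assumption: post_forward returns one value per name in model_outputs.
    outputs = post_forward(raw_output)
    for name, value in zip(model_outputs, outputs):
        batch_dict[f'model.{name}'] = value
    return batch_dict


def toy_model(image: List[float]) -> List[float]:
    """Hypothetical 'model' that returns its input as logits."""
    return list(image)


def logits_and_probabilities(logits: List[float]):
    """Plays the role assumed for perform_softmax: (logits, softmax(logits))."""
    return logits, softmax(logits)


batch = {'data.image': [2.0, 0.0]}
batch = wrapped_forward(batch, toy_model, ['data.image'],
                        logits_and_probabilities,
                        ['logits.classification', 'output.classification'])
```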

##### **Loss function**
Dictionary of loss elements. Each element is a sub-class of FuseLossBase.

The total loss will be the weighted sum of all the elements.

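Here, FuseLossDefault reads the prediction ('model.logits.classification') and the label ('data.label') from batch_dict, applies the callable (F.cross_entropy) to them, and multiplies the result by its weight.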

In [12]:
losses = {
    'cls_loss': FuseLossDefault(pred_name='model.logits.classification', target_name='data.label', callable=F.cross_entropy, weight=1.0),
}
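
The "weighted sum of loss elements" can be sketched without PyTorch. `LossSketch` below is a hypothetical element that picks the prediction and target out of a batch dictionary by key, and `total_loss` sums the weighted results. The second element with weight 0.5 exists only to show the summation and is not part of this notebook's setup.

```python
import math
from typing import Callable, Dict, List


def cross_entropy(logits: List[float], target: int) -> float:
    """Cross-entropy of one sample from raw logits (log-sum-exp form)."""
    shift = max(logits)
    log_sum_exp = shift + math.log(sum(math.exp(x - shift) for x in logits))
    return log_sum_exp - logits[target]


class LossSketch:
    """Hypothetical loss element; argument names mirror the cell above."""

    def __init__(self, pred_name: str, target_name: str,
                 callable: Callable, weight: float = 1.0):
        self.pred_name = pred_name
        self.target_name = target_name
        self.callable = callable
        self.weight = weight

    def __call__(self, batch_dict: Dict[str, object]) -> float:
        pred = batch_dict[self.pred_name]
        target = batch_dict[self.target_name]
        return self.weight * self.callable(pred, target)


def total_loss(losses: Dict[str, LossSketch], batch_dict: Dict[str, object]) -> float:
    """Sum of all (already weighted) loss elements."""
    return sum(loss(batch_dict) for loss in losses.values())


toy_batch = {'model.logits.classification': [2.0, 0.0], 'data.label': 0}
toy_losses = {
    'cls_loss': LossSketch('model.logits.classification', 'data.label',
                           callable=cross_entropy, weight=1.0),
    # Hypothetical second element, only to illustrate the weighted sum.
    'cls_loss_half': LossSketch('model.logits.classification', 'data.label',
                                callable=cross_entropy, weight=0.5),
}
value = total_loss(toy_losses, toy_batch)
```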

##### **Metrics**
Dictionary of metric elements. Each element is a sub-class of FuseMetricBase.

The metrics are calculated per epoch for both the training and validation sets.

The 'best_epoch_source', which is used to save the best model, can be based on one of these metrics.

In [13]:
metrics = OrderedDict([
    ('operation_point', MetricApplyThresholds(pred='model.output.classification')), # will apply argmax
    ('accuracy', MetricAccuracy(pred='results:metrics.operation_point.cls_pred', target='data.label'))
])
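
The two metrics chain together: the first turns class probabilities into a predicted class (the comment above says it applies argmax), and the second compares those predictions with the labels. A plain-Python sketch of that chain, using made-up probabilities and hypothetical helper names, could look like this:

```python
from typing import List, Sequence


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (first one on ties)."""
    return max(range(len(values)), key=values.__getitem__)


def accuracy(predictions: List[int], targets: List[int]) -> float:
    """Fraction of predictions that equal their target."""
    correct = sum(1 for p, t in zip(predictions, targets) if p == t)
    return correct / len(targets)


# Made-up two-class probabilities and labels
probabilities = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
labels = [1, 0, 0]

cls_pred = [argmax(p) for p in probabilities]  # operation point
acc = accuracy(cls_pred, labels)               # accuracy on top of it
```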

##### **Callbacks**
Callbacks are sub-classes of FuseCallbackBase.

A callback is an object that can perform actions at various stages of training.

At each stage, it can manipulate the data, batch_dict, or epoch_results.


In [14]:
callbacks = [
    FuseTensorboardCallback(model_dir=paths['model_dir']),  # save statistics for tensorboard
]
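
To show the general callback pattern, here is a self-contained sketch with a toy training loop. The hook names (`on_epoch_begin`, `on_batch_end`, `on_epoch_end`), the 'losses.total_loss' key, and the loss values are all hypothetical and are not taken from FuseCallbackBase. Consult that class for the real hook names and signatures.

```python
from typing import Dict, List


class CallbackSketch:
    """Hypothetical base class with no-op hooks."""

    def on_epoch_begin(self, epoch: int) -> None:
        pass

    def on_batch_end(self, batch_dict: Dict[str, float]) -> None:
        pass

    def on_epoch_end(self, epoch: int, epoch_results: Dict[str, float]) -> None:
        pass


class LossHistory(CallbackSketch):
    """Records the mean loss reported at the end of every epoch."""

    def __init__(self) -> None:
        self.values: List[float] = []

    def on_epoch_end(self, epoch: int, epoch_results: Dict[str, float]) -> None:
        self.values.append(epoch_results['losses.total_loss'])


def run_toy_training(callbacks: List[CallbackSketch], num_epochs: int,
                     batches_per_epoch: int) -> None:
    """Toy loop that only exercises the hooks; no real training happens."""
    for epoch in range(num_epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        batch_losses = []
        for step in range(batches_per_epoch):
            batch_dict = {'losses.total_loss': 1.0 / (epoch + step + 1)}
            for cb in callbacks:
                cb.on_batch_end(batch_dict)
            batch_losses.append(batch_dict['losses.total_loss'])
        epoch_results = {'losses.total_loss': sum(batch_losses) / len(batch_losses)}
        for cb in callbacks:
            cb.on_epoch_end(epoch, epoch_results)


history = LossHistory()
run_toy_training([history], num_epochs=2, batches_per_epoch=2)
```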

##### **Train**
Building Fuse's manager and supplying it with PyTorch's optimizer and scheduler.

The manager is the main API when using Fuse - it is responsible for the train and infer functionality.

Possible workflows are listed in FuseManagerDefault's documentation.

Note that the manager uses the training parameters that we set above.

In [15]:
# Create optimizer
optimizer = optim.Adam(model.parameters(), lr=train_params['manager.learning_rate'], weight_decay=train_params['manager.weight_decay'])

# create scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# train from scratch
manager = FuseManagerDefault(output_model_dir=paths['model_dir'], force_reset=paths['force_reset_model_dir'])

# Providing the objects required for the training process.
manager.set_objects(net=model,
                    optimizer=optimizer,
                    losses=losses,
                    metrics=metrics,
                    best_epoch_source=train_params['manager.best_epoch_source'],
                    lr_scheduler=scheduler,
                    callbacks=callbacks,
                    train_params=train_params['manager.train_params'])

# Start training
manager.train(train_dataloader=train_dataloader, validation_dataloader=validation_dataloader)

100%|██████████| 100/100 [00:01<00:00, 58.15it/s]
100%|██████████| 675/675 [00:25<00:00, 26.61it/s]
100%|██████████| 100/100 [00:01<00:00, 66.50it/s]
100%|██████████| 675/675 [00:28<00:00, 23.55it/s]
100%|██████████| 100/100 [00:01<00:00, 71.73it/s]
100%|██████████| 675/675 [00:26<00:00, 25.41it/s]
100%|██████████| 100/100 [00:01<00:00, 56.66it/s]
100%|██████████| 675/675 [00:30<00:00, 22.02it/s]
100%|██████████| 100/100 [00:01<00:00, 73.79it/s]
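
The scheduler above is created with PyTorch's defaults, which for ReduceLROnPlateau are mode='min', factor=0.1, and patience=10. The sketch below reproduces the core "reduce when the metric stops improving" logic in plain Python. It is a simplification that ignores the threshold, cooldown, and minimum learning rate options, and `PlateauSketch` is a hypothetical name. A shorter patience is used so the reduction is visible on a made-up metric sequence.

```python
class PlateauSketch:
    """Simplified reduce-on-plateau rule for a metric that should decrease."""

    def __init__(self, lr: float, factor: float = 0.1, patience: int = 10):
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float('inf')
        self.num_bad_epochs = 0

    def step(self, metric: float) -> float:
        if metric < self.best:
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        # Reduce only after more than `patience` epochs without improvement.
        if self.num_bad_epochs > self.patience:
            self.lr *= self.factor
            self.num_bad_epochs = 0
        return self.lr


scheduler_sketch = PlateauSketch(lr=1e-4, patience=2)
made_up_losses = [1.0, 0.9, 0.9, 0.9, 0.9, 0.8]
lr_history = [scheduler_sketch.step(m) for m in made_up_losses]
```

If the manager steps the scheduler once per epoch, a run of only 5 epochs with the default patience of 10 would not be expected to trigger a reduction.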


## **Infer**

##### **Define Infer Common Params**


In [16]:
INFER_COMMON_PARAMS = {}
INFER_COMMON_PARAMS['infer_filename'] = 'validation_set_infer.gz'
INFER_COMMON_PARAMS['checkpoint'] = 'best' 

infer_common_params = INFER_COMMON_PARAMS

##### **Infer**

In [17]:
validation_dataloader = DataLoader(dataset=validation_dataset, collate_fn=validation_dataset.collate_fn, batch_size=2, num_workers=2)

## Manager for inference
manager = FuseManagerDefault()
output_columns = ['model.output.classification', 'data.label']
manager.infer(data_loader=validation_dataloader,
                input_model_dir=paths['model_dir'],
                checkpoint=infer_common_params['checkpoint'],
                output_columns=output_columns,
                output_file_name=os.path.join(paths["inference_dir"], infer_common_params["infer_filename"]))

100%|██████████| 5000/5000 [00:18<00:00, 267.43it/s]


| | descriptor | id | model.output.classification | data.label |
|---|---|---|---|---|
| 0 | (validation, 0) | (validation, 0) | [1.0180833e-07, 6.6291017e-07, 6.315717e-06, 1... | 7 |
| 1 | (validation, 1) | (validation, 1) | [9.043416e-06, 0.00047580755, 0.9995017, 6.418... | 2 |
| 2 | (validation, 2) | (validation, 2) | [5.333834e-06, 0.999655, 3.9554485e-05, 1.2428... | 1 |
| 3 | (validation, 3) | (validation, 3) | [0.999801, 4.07752e-06, 1.1369637e-05, 4.63483... | 0 |
| 4 | (validation, 4) | (validation, 4) | [5.7314537e-07, 8.599522e-06, 1.8101055e-07, 6... | 4 |
| ... | ... | ... | ... | ... |
| 9995 | (validation, 9995) | (validation, 9995) | [1.3568248e-07, 2.0020354e-05, 0.9999684, 7.08... | 2 |
| 9996 | (validation, 9996) | (validation, 9996) | [8.5210326e-07, 7.1025497e-06, 3.578722e-05, 0... | 3 |
| 9997 | (validation, 9997) | (validation, 9997) | [3.583553e-09, 2.403053e-06, 1.9147938e-08, 2.... | 4 |
| 9998 | (validation, 9998) | (validation, 9998) | [2.1659675e-06, 2.8251183e-08, 1.9621085e-08, ... | 5 |


## **Evaluation**

##### **Define EVAL Common Params**


In [18]:
EVAL_COMMON_PARAMS = {}
EVAL_COMMON_PARAMS['infer_filename'] = INFER_COMMON_PARAMS['infer_filename']
eval_common_params = EVAL_COMMON_PARAMS

##### **Calculate metrics**

In [19]:
class_names = [str(i) for i in range(10)]

# metrics
metrics = OrderedDict([
    ('operation_point', MetricApplyThresholds(pred='model.output.classification')), # will apply argmax
    ('accuracy', MetricAccuracy(pred='results:metrics.operation_point.cls_pred', target='data.label')),
    ('roc', MetricROCCurve(pred='model.output.classification', target='data.label', class_names=class_names, output_filename=os.path.join(paths['inference_dir'], 'roc_curve.png'))),
    ('auc', MetricAUCROC(pred='model.output.classification', target='data.label', class_names=class_names)),
])
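
For a multi-class problem, ROC curves and AUC values are commonly computed per class in a one-vs-rest fashion: the probability of class k is the score, and "label equals k" is the positive event. The sketch below shows that idea in plain Python. It computes AUC as the probability that a random positive outranks a random negative. Whether MetricAUCROC uses exactly this formulation is an assumption. The helper names and toy numbers are hypothetical.

```python
from typing import Dict, List


def binary_auc(scores: List[float], positives: List[bool]) -> float:
    """AUC as P(score of a positive > score of a negative), ties count 0.5."""
    pos = [s for s, is_pos in zip(scores, positives) if is_pos]
    neg = [s for s, is_pos in zip(scores, positives) if not is_pos]
    if not pos or not neg:
        raise ValueError('need at least one positive and one negative sample')
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


def one_vs_rest_auc(probabilities: List[List[float]], labels: List[int],
                    class_names: List[str]) -> Dict[str, float]:
    """Per-class AUC, treating each class in turn as the positive class."""
    result = {}
    for class_index, name in enumerate(class_names):
        scores = [row[class_index] for row in probabilities]
        positives = [label == class_index for label in labels]
        result[name] = binary_auc(scores, positives)
    return result


# Made-up two-class probabilities and labels
toy_probabilities = [[0.9, 0.1], [0.5, 0.5], [0.3, 0.7], [0.55, 0.45]]
toy_labels = [0, 0, 1, 1]
auc_per_class = one_vs_rest_auc(toy_probabilities, toy_labels, ['0', '1'])
```

The pairwise loop is quadratic in the number of samples, which is fine for a toy example but not for the full validation set. Library implementations typically sort the scores instead.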

##### **Evaluate**

In [20]:
# create evaluator
evaluator = EvaluatorDefault()

# run
results = evaluator.eval(ids=None,
                    data=os.path.join(paths["inference_dir"], eval_common_params["infer_filename"]),
                    metrics=metrics,
                    output_dir=paths['eval_dir'])

Results:

Metric operation_point:
------------------------------------------------
cls_pred:
<fuse.eval.metrics.utils.PerSampleData object at 0x7fd163f9e8d0>

Metric accuracy:
------------------------------------------------
0.9855

Metric roc:
------------------------------------------------
0.fpr:
[0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
 0.00000000e+00 0.00000000e+