# FuseMedML - Hello world
Welcome!\
In this tutorial we'll cover the basics of our FuseMedML open-source library through a hands-on notebook.

Goals:
* 


------------------
## FuseMedML
[![Github repo](https://img.shields.io/static/v1?label=GitHub&message=FuseMedML&color=brightgreen)](https://github.com/IBM/fuse-med-ml)

[![PyPI version](https://badge.fury.io/py/fuse-med-ml.svg)](https://badge.fury.io/py/fuse-med-ml)

[![Slack channel](https://img.shields.io/badge/support-slack-slack.svg?logo=slack)](https://join.slack.com/t/fusemedml/shared_invite/zt-xr1jaj29-h7IMsSc0Lq4qpVNxW97Phw)

[![Open Source](https://badges.frapsoft.com/os/v1/open-source.svg)](https://github.com/IBM/fuse-med-ml)


FuseMedML is an open-source Python-based framework designed to enhance collaboration and accelerate discoveries in Fused Medical data through advanced Machine Learning technologies. 

The initial version is PyTorch-based and focuses on deep learning on medical imaging.


## **FuseMedML Key Concepts in a Nutshell**
### Share and Reuse

A common generic implementation that you can reuse is provided for most components in the pipeline. 

The naming convention for the common implementation is `Fuse***Default`.

FuseMedML comes with a large collection of components that grows with each new project. Some of them are entirely generic and others are domain-specific.


Don't forget to **contribute** back and **share** them. 

### Decoupling
Decoupling is achieved because, in most cases, objects do not interact directly. Instead, information and data are routed between components using *namespaces* (examples below). 

That is, each object extracts its input from, and saves its output into, a dictionary named `batch_dict`. 

`batch_dict` aggregates the outputs of all the objects over a single batch. 

<br />

**Example of the decoupling approach:**
```python
FuseMetricAUC(pred_name='model.output.classification', target_name='data.gt.classification')  
```

`FuseMetricAUC` will read the required tensors to compute AUC from `batch_dict`. The relevant dictionary keys are `pred_name` and `target_name`. 

This approach allows writing a generic metric which is completely independent of the model and data extractor. 

In addition, it makes it easy to reuse this object in a plug & play manner without adding extra code. 

It also lets us use the same metric several times when there are multiple heads or tasks.
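
To make the idea concrete, here is a minimal sketch of a namespace-driven metric in plain Python. `ToyAccuracy` and `lookup` are hypothetical names made up for this illustration and are not part of FuseMedML; the sketch only assumes that `batch_dict` is a nested dictionary addressed by dot-separated keys.

```python
from functools import reduce


def lookup(nested: dict, key: str):
    """Follow a dot-separated key through a nested dictionary."""
    return reduce(lambda node, part: node[part], key.split("."), nested)


class ToyAccuracy:
    """Hypothetical metric that only knows which keys to read from."""

    def __init__(self, pred_name: str, target_name: str):
        self.pred_name = pred_name
        self.target_name = target_name

    def __call__(self, batch_dict: dict) -> float:
        preds = lookup(batch_dict, self.pred_name)
        targets = lookup(batch_dict, self.target_name)
        correct = sum(int(p == t) for p, t in zip(preds, targets))
        return correct / len(targets)


# A batch with two heads; the same metric class serves both.
batch_dict = {
    "model": {"output": {"head_a": [0, 1, 1, 0], "head_b": [2, 2, 0, 1]}},
    "data": {"gt": {"task_a": [0, 1, 0, 0], "task_b": [2, 2, 0, 1]}},
}

metric_a = ToyAccuracy(pred_name="model.output.head_a", target_name="data.gt.task_a")
metric_b = ToyAccuracy(pred_name="model.output.head_b", target_name="data.gt.task_b")

print(metric_a(batch_dict), metric_b(batch_dict))
```

Neither metric instance knows anything about the model or the data pipeline; swapping a head only means changing the key strings.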

<br />


When a batch is completed, only the required key-value pairs from `batch_dict`, such as the loss values, will be collected in another dictionary named `epoch_results`. 

Both `batch_dict` and `epoch_results` are nested dictionaries. To easily access the data stored in those dictionaries, use `FuseUtilsHierarchicalDict`:

```python
FuseUtilsHierarchicalDict.get(batch_dict, 'model.output.classification')
``` 

will return `batch_dict['model']['output']['classification']`.
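
For intuition, dotted-key access on a nested dictionary can be sketched in a few lines of plain Python. The helpers `nested_get` and `nested_set` below are hypothetical stand-ins written for this tutorial; they are not the `FuseUtilsHierarchicalDict` implementation and may differ from it in error handling and features.

```python
def nested_get(nested: dict, key: str):
    """Return the value stored under a dot-separated key."""
    node = nested
    for part in key.split("."):
        node = node[part]  # raises KeyError if any level is missing
    return node


def nested_set(nested: dict, key: str, value) -> None:
    """Store a value under a dot-separated key, creating levels as needed."""
    *parents, leaf = key.split(".")
    node = nested
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value


batch_dict = {}
nested_set(batch_dict, "model.output.classification", [0.1, 0.9])
nested_set(batch_dict, "data.label", 1)

print(batch_dict)
print(nested_get(batch_dict, "model.output.classification"))
```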

### Manager API
The manager is the main API when using Fuse; it is responsible for the train and infer functionality.

Possible workflows are listed in the FuseManagerDefault documentation. Here are two examples:

##### **Train: Init -> set objects -> train**
```python
manager = FuseManagerDefault(output_model_dir=paths['model_dir'], force_reset=paths['force_reset_model_dir'])

manager.set_objects(net=model,
                    optimizer=optimizer,
                    losses=losses,
                    metrics=metrics,
                    best_epoch_source=train_common_params['manager.best_epoch_source'],
                    lr_scheduler=scheduler,
                    callbacks=callbacks,
                    train_params=train_common_params['manager.train_params'],
                    output_model_dir=paths['model_dir'])

manager.train(train_dataloader=train_dataloader,
                validation_dataloader=validation_dataloader)
```

##### **Infer: Init -> infer**
```python
manager = FuseManagerDefault()

output_columns = ['model.output.classification', 'data.label']

manager.infer(data_loader=validation_dataloader,
                input_model_dir=paths['model_dir'],
                checkpoint=infer_common_params['checkpoint'],
                output_columns=output_columns,
                output_file_name=os.path.join(paths["inference_dir"], infer_common_params["infer_filename"]))
```



### Use PyTorch directly and alternative frameworks

FuseMedML uses and extends PyTorch only where you need it to. 
You can mix FuseMedML with plain PyTorch code, components from alternative frameworks, and other popular GitHub projects. 



------------
## **Installation Details - Google Colab**


#### **Install FuseMedML**

In [None]:
# !git clone https://github.com/IBM/fuse-med-ml.git
# %cd fuse-med-ml
# !pip install -e .


## **Setup environment**

##### **Imports**

In [None]:
import os
from collections import OrderedDict

import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms

from fuse.eval.evaluator import EvaluatorDefault
from fuse.data.dataset.dataset_wrapper import FuseDatasetWrapper
from fuse.data.sampler.sampler_balanced_batch import FuseSamplerBalancedBatch
from fuse.losses.loss_default import FuseLossDefault
from fuse.managers.callbacks.callback_tensorboard import FuseTensorboardCallback
from fuse.managers.manager_default import FuseManagerDefault
from fuse.eval.metrics.classification.metrics_classification_common import MetricAccuracy, MetricAUCROC, MetricROCCurve
from fuse.eval.metrics.classification.metrics_thresholding_common import MetricApplyThresholds
from fuse.models.model_wrapper import FuseModelWrapper
from fuse_examples.classification.mnist.lenet import LeNet
from fuse_examples.tutorials.hello_world.hello_world_utils import perform_softmax

##### **Output paths**
The user is able to customize the output directories.

In [None]:
ROOT = 'examples' # TODO: fill path here
PATHS = {'model_dir': os.path.join(ROOT, 'mnist/model_dir'),
         'force_reset_model_dir': True,  # If True will reset model dir automatically - otherwise will prompt 'are you sure' message.
         'cache_dir': os.path.join(ROOT, 'mnist/cache_dir'),
         'inference_dir': os.path.join(ROOT, 'mnist/infer_dir'),
         'eval_dir': os.path.join(ROOT, 'mnist/eval_dir')}

paths = PATHS

##### **Training Parameters**
* Model - which model we are using.
* Data - define parameters for the data preprocessing.
* Manager - define parameters for the training session.

In [None]:
TRAIN_COMMON_PARAMS = {}

### Model ###
TRAIN_COMMON_PARAMS['model'] = 'lenet'

### Data ###
TRAIN_COMMON_PARAMS['data.batch_size'] = 100
TRAIN_COMMON_PARAMS['data.train_num_workers'] = 8
TRAIN_COMMON_PARAMS['data.validation_num_workers'] = 8

### Manager ###
TRAIN_COMMON_PARAMS['manager.train_params'] = {
    'device': 'cuda', 
    'num_epochs': 5,
    'virtual_batch_size': 1,  # number of batches in one virtual batch
    'start_saving_epochs': 10,  # first epoch to start saving checkpoints from
    'gap_between_saving_epochs': 5,  # number of epochs between saved checkpoint
}
TRAIN_COMMON_PARAMS['manager.best_epoch_source'] = {
    'source': 'metrics.accuracy',  # can be any key from 'epoch_results'
    'optimization': 'max',  # can be either min/max
    'on_equal_values': 'better',
    # can be either better/worse - whether to consider best epoch when values are equal
}
TRAIN_COMMON_PARAMS['manager.learning_rate'] = 1e-4
TRAIN_COMMON_PARAMS['manager.weight_decay'] = 0.001
TRAIN_COMMON_PARAMS['manager.resume_checkpoint_filename'] = None  # if not None, will try to load the checkpoint

TRAIN_COMMON_PARAMS['manager.train_params']['device'] = 'cpu'  # overrides the 'cuda' value set above

train_params = TRAIN_COMMON_PARAMS

## **Training the model**

##### **Data**
Download the MNIST dataset and build dataloaders (`torch.utils.data.DataLoader`) for both training and validation using Fuse components:
1. Wrapper - **FuseDatasetWrapper**:

    Wraps a PyTorch dataset so that each sample is converted to a dictionary according to the provided mapping.
2. Sampler - **FuseSamplerBalancedBatch**:

    Implements the `torch.utils.data.sampler` interface.
    
    The sampler selects the list of samples to use for each batch.
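
The two ideas above can be sketched without any Fuse or PyTorch code. The functions `wrap_sample` and `balanced_batches` are hypothetical and written only for illustration. The sketch assumes that the wrapper places the mapped fields under a `data` namespace (as the keys `data.image` and `data.label` used in this notebook suggest) and that "balanced" means an equal number of samples per class in every batch; the real components may behave differently.

```python
import random
from collections import defaultdict

# Toy dataset of (image, label) tuples; strings stand in for image tensors.
raw_samples = [(f"img_{i}", i % 3) for i in range(30)]


def wrap_sample(sample: tuple, mapping: tuple) -> dict:
    """Turn a tuple sample into a dictionary, placing each field under 'data'."""
    return {"data": dict(zip(mapping, sample))}


wrapped = [wrap_sample(s, mapping=("image", "label")) for s in raw_samples]


def balanced_batches(labels, num_classes, batch_size, num_batches, seed=0):
    """Yield lists of sample indices with an equal number of samples per class."""
    rng = random.Random(seed)
    per_class = batch_size // num_classes
    indices_by_class = defaultdict(list)
    for index, label in enumerate(labels):
        indices_by_class[label].append(index)
    for _ in range(num_batches):
        batch = []
        for cls in range(num_classes):
            # Sample with replacement so rare classes can still fill their share.
            batch.extend(rng.choices(indices_by_class[cls], k=per_class))
        rng.shuffle(batch)
        yield batch


labels = [sample["data"]["label"] for sample in wrapped]
batches = list(balanced_batches(labels, num_classes=3, batch_size=6, num_batches=4))
print(wrapped[0])
print(batches[0])
```

Each yielded list plays the role of what a PyTorch `batch_sampler` provides to a `DataLoader`: the indices of the samples that make up one batch.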


Other Fuse components for preprocessing data are:
* FuseDataSourceBase
* FuseProcessorBase
* FuseDatasetBase
* FuseAugmentorBase
* FuseVisualizerBase

    

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Create dataset
torch_train_dataset = torchvision.datasets.MNIST(paths['cache_dir'], download=True, train=True, transform=transform)

# wrapping torch dataset
train_dataset = FuseDatasetWrapper(name='train', dataset=torch_train_dataset, mapping=('image', 'label'))
train_dataset.create()

sampler = FuseSamplerBalancedBatch(dataset=train_dataset,
                                balanced_class_name='data.label',
                                num_balanced_classes=10,
                                batch_size=train_params['data.batch_size'],
                                balanced_class_weights=None)

# Create dataloader
train_dataloader = DataLoader(dataset=train_dataset, batch_sampler=sampler, num_workers=train_params['data.train_num_workers'])

## Validation data
# Create dataset
torch_validation_dataset = torchvision.datasets.MNIST(paths['cache_dir'], download=True, train=False, transform=transform)
# wrapping torch dataset
validation_dataset = FuseDatasetWrapper(name='validation', dataset=torch_validation_dataset, mapping=('image', 'label'))
validation_dataset.create()

# dataloader
validation_dataloader = DataLoader(dataset=validation_dataset, batch_size=train_params['data.batch_size'],
                                num_workers=train_params['data.validation_num_workers'])

##### **Model**
Building the LeNet model using PyTorch's API and then wrapping it. 

The model outputs will be aggregated in batch_dict['model.*'].
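
As a rough mental model of what such a wrapper does, the sketch below reads the inputs by key, runs a model, post-processes the result, and stores each output under a `model.` prefix. Everything here is hypothetical: `toy_model`, `softmax_post_processing` and `wrapped_forward` are invented for illustration, flat dotted keys replace the nested `batch_dict`, and the assumption that the post-processing function returns the logits followed by the softmax output is inferred from the two `model_outputs` names used in this notebook.

```python
import math


def toy_model(image):
    """Stand-in for a network: returns one logit per class (3 classes here)."""
    total = sum(image)
    return [total, 0.5 * total, -total]


def softmax_post_processing(logits):
    """Return both the raw logits and their softmax, in that order."""
    shifted = [value - max(logits) for value in logits]  # for numerical stability
    exps = [math.exp(value) for value in shifted]
    probabilities = [value / sum(exps) for value in exps]
    return logits, probabilities


def wrapped_forward(batch_dict, model, model_inputs, model_outputs, post_processing):
    """Read inputs by key, run the model, and store outputs under 'model.*'."""
    inputs = [batch_dict[key] for key in model_inputs]
    outputs = post_processing(model(*inputs))
    for key, value in zip(model_outputs, outputs):
        batch_dict["model." + key] = value
    return batch_dict


# Flat dotted keys keep the sketch short; the real batch_dict is nested.
batch_dict = {"data.image": [0.2, 0.3, 0.5]}
wrapped_forward(
    batch_dict,
    model=toy_model,
    model_inputs=["data.image"],
    model_outputs=["logits.classification", "output.classification"],
    post_processing=softmax_post_processing,
)
print(sorted(batch_dict))
```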


Another option to implement a model is to use Fuse components such as:
* FuseModelDefault
* FuseBackbone
* FuseHead
* FuseModelEnsemble
* FuseModelMultistream

In [None]:
torch_model = LeNet()

model = FuseModelWrapper(model=torch_model,
                        model_inputs=['data.image'],
                        post_forward_processing_function=perform_softmax,
                        model_outputs=['logits.classification', 'output.classification']
                        )

##### **Loss function**
Dictionary of loss elements. Each element is a sub-class of FuseLossBase.

The total loss will be the weighted sum of all the elements.
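
The weighted sum can be illustrated with a small plain-Python sketch. This is not how `FuseLossDefault` is implemented; the second loss (`reg_loss`) and its keys `model.output.regression` and `data.value` are hypothetical and added only so that there is more than one term to sum, and flat dotted keys stand in for the nested `batch_dict`.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of a single sample given raw logits and a class index."""
    log_norm = math.log(sum(math.exp(value) for value in logits))
    return log_norm - logits[target]


def mean_squared_error(pred, target):
    """Mean of the squared element-wise differences."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(target)


# Each entry: (callable, prediction key, target key, weight).
losses = {
    "cls_loss": (cross_entropy, "model.logits.classification", "data.label", 1.0),
    "reg_loss": (mean_squared_error, "model.output.regression", "data.value", 0.5),
}

batch_dict = {
    "model.logits.classification": [2.0, 0.0, 0.0],
    "data.label": 0,
    "model.output.regression": [1.0, 3.0],
    "data.value": [1.0, 2.0],
}

# Every loss reads its own inputs by key and is scaled by its weight.
loss_values = {
    name: weight * fn(batch_dict[pred_key], batch_dict[target_key])
    for name, (fn, pred_key, target_key, weight) in losses.items()
}
total_loss = sum(loss_values.values())
print(loss_values, total_loss)
```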

#TODO: Elaborate on what it does: extract the label ....

In [None]:
losses = {
    'cls_loss': FuseLossDefault(pred_name='model.logits.classification', target_name='data.label', callable=F.cross_entropy, weight=1.0),
}

##### **Metrics**
Dictionary of metric elements. Each element is a sub-class of FuseMetricBase.

The metrics will be calculated per epoch for both the training and validation sets.

The 'best_epoch_source', used to decide which model to save as the best one, can be based on one of these metrics.
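
To see how the `optimization` and `on_equal_values` settings could interact, here is a small sketch of a best-epoch rule. The function `is_new_best` is hypothetical; it follows the parameter comments in the training parameters cell (a tie counts as a new best only when `on_equal_values` is `'better'`) and is not taken from the manager's code.

```python
def is_new_best(current, best, optimization="max", on_equal_values="better"):
    """Decide whether the current epoch value should replace the best so far."""
    if best is None:
        return True
    if current == best:
        return on_equal_values == "better"
    if optimization == "max":
        return current > best
    return current < best


accuracy_per_epoch = [0.91, 0.95, 0.95, 0.93]

best_epoch, best_value = None, None
for epoch, value in enumerate(accuracy_per_epoch):
    if is_new_best(value, best_value, optimization="max", on_equal_values="better"):
        best_epoch, best_value = epoch, value

print(best_epoch, best_value)
```

With `on_equal_values="better"` the later of the two tied epochs wins; with `"worse"` the earlier one would be kept.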

In [None]:
metrics = OrderedDict([
    ('operation_point', MetricApplyThresholds(pred='model.output.classification')), # will apply argmax
    ('accuracy', MetricAccuracy(pred='results:metrics.operation_point.cls_pred', target='data.label'))
])

##### **Callbacks**
Callbacks are sub-classes of FuseCallbackBase.

A callback is an object that can perform actions at various stages of training.

At each stage, it can manipulate the data, batch_dict or epoch_results.
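
The general pattern can be sketched as follows. The hook names `on_batch_end` and `on_epoch_end`, the classes `ToyCallback` and `LossLogger`, and the loop itself are all hypothetical and chosen for illustration; consult `FuseCallbackBase` for the hooks that Fuse actually provides.

```python
class ToyCallback:
    """Hypothetical base class: every hook is a no-op unless overridden."""

    def on_batch_end(self, batch_dict: dict) -> None:
        pass

    def on_epoch_end(self, epoch: int, epoch_results: dict) -> None:
        pass


class LossLogger(ToyCallback):
    """Records the loss reported at the end of each epoch."""

    def __init__(self):
        self.history = []

    def on_epoch_end(self, epoch, epoch_results):
        self.history.append((epoch, epoch_results["losses"]["total_loss"]))


def toy_training_loop(num_epochs, batches_per_epoch, callbacks):
    """Stand-in for a training loop that invokes the hooks at each stage."""
    for epoch in range(num_epochs):
        batch_losses = []
        for _ in range(batches_per_epoch):
            # Fake loss that shrinks as training progresses.
            batch_dict = {"losses": {"total_loss": 1.0 / (epoch + 1)}}
            batch_losses.append(batch_dict["losses"]["total_loss"])
            for callback in callbacks:
                callback.on_batch_end(batch_dict)
        epoch_results = {"losses": {"total_loss": sum(batch_losses) / len(batch_losses)}}
        for callback in callbacks:
            callback.on_epoch_end(epoch, epoch_results)


logger = LossLogger()
toy_training_loop(num_epochs=3, batches_per_epoch=2, callbacks=[logger])
print(logger.history)
```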


In [None]:
callbacks = [
    FuseTensorboardCallback(model_dir=paths['model_dir']),  # save statistics for tensorboard
]

##### **Train**
Building Fuse's manager and supplying it with PyTorch's optimizer and scheduler.

Note that the manager uses the training parameters that we set above.

In [None]:
# Create optimizer
optimizer = optim.Adam(model.parameters(), lr=train_params['manager.learning_rate'], weight_decay=train_params['manager.weight_decay'])

# create scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# train from scratch
manager = FuseManagerDefault(output_model_dir=paths['model_dir'], force_reset=paths['force_reset_model_dir'])

# Providing the objects required for the training process.
manager.set_objects(net=model,
                    optimizer=optimizer,
                    losses=losses,
                    metrics=metrics,
                    best_epoch_source=train_params['manager.best_epoch_source'],
                    lr_scheduler=scheduler,
                    callbacks=callbacks,
                    train_params=train_params['manager.train_params'])

# Start training
manager.train(train_dataloader=train_dataloader, validation_dataloader=validation_dataloader)

## **Infer**

##### **Define Infer Common Params**


In [None]:
INFER_COMMON_PARAMS = {}
INFER_COMMON_PARAMS['infer_filename'] = 'validation_set_infer.gz'
INFER_COMMON_PARAMS['checkpoint'] = 'best' 

infer_common_params = INFER_COMMON_PARAMS

##### **Infer**

In [None]:
validation_dataloader = DataLoader(dataset=validation_dataset, collate_fn=validation_dataset.collate_fn, batch_size=2, num_workers=2)

## Manager for inference
manager = FuseManagerDefault()
output_columns = ['model.output.classification', 'data.label']
manager.infer(data_loader=validation_dataloader,
                input_model_dir=paths['model_dir'],
                checkpoint=infer_common_params['checkpoint'],
                output_columns=output_columns,
                output_file_name=os.path.join(paths["inference_dir"], infer_common_params["infer_filename"]))

## **Evaluation**

##### **Define EVAL Common Params**


In [None]:
EVAL_COMMON_PARAMS = {}
EVAL_COMMON_PARAMS['infer_filename'] = INFER_COMMON_PARAMS['infer_filename']
eval_common_params = EVAL_COMMON_PARAMS

##### **Calculate metrics**

In [None]:
class_names = [str(i) for i in range(10)]

# metrics
metrics = OrderedDict([
    ('operation_point', MetricApplyThresholds(pred='model.output.classification')), # will apply argmax
    ('accuracy', MetricAccuracy(pred='results:metrics.operation_point.cls_pred', target='data.label')),
    ('roc', MetricROCCurve(pred='model.output.classification', target='data.label', class_names=class_names, output_filename=os.path.join(paths['inference_dir'], 'roc_curve.png'))),
    ('auc', MetricAUCROC(pred='model.output.classification', target='data.label', class_names=class_names)),
])

##### **Evaluate**

In [None]:
# create evaluator
evaluator = EvaluatorDefault()

# run
results = evaluator.eval(ids=None,
                    data=os.path.join(paths["inference_dir"], eval_common_params["infer_filename"]),
                    metrics=metrics,
                    output_dir=paths['eval_dir'])