# FuseMedML - Hello world
Welcome!\
In this tutorial we'll cover the basics of our FuseMedML open-source library through a hands-on notebook.\

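Goals:
* Get familiar with FuseMedML's key concepts.
* Train, infer and analyze a simple MNIST classifier with FuseMedML.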


------------------
## FuseMedML
[![Github repo](https://img.shields.io/static/v1?label=GitHub&message=FuseMedML&color=brightgreen)](https://github.com/IBM/fuse-med-ml)

[![PyPI version](https://badge.fury.io/py/fuse-med-ml.svg)](https://badge.fury.io/py/fuse-med-ml)

[![Slack channel](https://img.shields.io/badge/support-slack-slack.svg?logo=slack)](https://join.slack.com/t/fusemedml/shared_invite/zt-xr1jaj29-h7IMsSc0Lq4qpVNxW97Phw)

[![Open Source](https://badges.frapsoft.com/os/v1/open-source.svg)](https://github.com/IBM/fuse-med-ml)


FuseMedML is an open-source, Python-based framework designed to enhance collaboration and accelerate discoveries in Fused Medical data through advanced Machine Learning technologies. 

The initial version is PyTorch-based and focuses on deep learning for medical imaging.


## **FuseMedML Key Concepts in a Nutshell**
### Share and Reuse

Most components in the pipeline come with a common, generic implementation that you can reuse. 

The naming convention for these common implementations is `Fuse***Default` (for example, `FuseManagerDefault` and `FuseLossDefault`). 

FuseMedML comes with a large collection of components that grows with each new project. Some of them are entirely generic, while others are domain-specific.


Don't forget to **contribute** back and **share** them. 

### Decoupling
Decoupling is achieved because, in most cases, objects do not interact directly. Instead, information and data are routed between components using *namespaces* (see the example below). 

That is, each object reads its input from, and writes its output to, a dictionary named `batch_dict`. 

`batch_dict` aggregates the outputs of all the objects for a single batch. 

<br />

**Example of the decoupling approach:**
```python
FuseMetricAUC(pred_name='model.output.classification', target_name='data.gt.classification')  
```

`FuseMetricAUC` reads the tensors required to compute the AUC from `batch_dict`, using the dictionary keys given by `pred_name` and `target_name`. 

This approach allows writing a generic metric that is completely independent of the model and the data extractor. 

In addition, it makes it easy to reuse this object in a plug & play manner without adding extra code. 

It also allows the same metric to be used several times when there are multiple heads/tasks.
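
To illustrate the idea (this is not FuseMedML's actual implementation), here is a minimal sketch of a decoupled accuracy metric. The class `KeyedAccuracy` and the helper `get_by_path` are hypothetical; the only contract they assume is that predictions and targets are looked up in `batch_dict` by dot-separated keys, which is what lets the same class be instantiated once per head.

```python
from functools import reduce


def get_by_path(nested: dict, path: str):
    """Hypothetical helper: look up a dot-separated key such as 'model.output.head_a'."""
    return reduce(lambda node, key: node[key], path.split("."), nested)


class KeyedAccuracy:
    """Hypothetical metric: it knows only key names, not the model or the data source."""

    def __init__(self, pred_name: str, target_name: str):
        self.pred_name = pred_name
        self.target_name = target_name

    def __call__(self, batch_dict: dict) -> float:
        preds = get_by_path(batch_dict, self.pred_name)      # per-sample class scores
        targets = get_by_path(batch_dict, self.target_name)  # per-sample labels
        # argmax over each row of scores gives the predicted class
        predicted = [max(range(len(row)), key=row.__getitem__) for row in preds]
        correct = sum(p == t for p, t in zip(predicted, targets))
        return correct / len(targets)


# Two heads, one metric class, different keys
batch_dict = {
    "model": {"output": {"head_a": [[0.9, 0.1], [0.2, 0.8]],
                         "head_b": [[0.3, 0.7], [0.6, 0.4]]}},
    "data": {"gt": {"head_a": [0, 1], "head_b": [0, 1]}},
}
metrics = {
    "acc_a": KeyedAccuracy("model.output.head_a", "data.gt.head_a"),
    "acc_b": KeyedAccuracy("model.output.head_b", "data.gt.head_b"),
}
print({name: metric(batch_dict) for name, metric in metrics.items()})  # {'acc_a': 1.0, 'acc_b': 0.0}
```

Because the metric never imports the model or the dataset, adding a third head only requires a new dictionary entry with different keys.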

<br />


When a batch is completed, only the required key-value pairs from `batch_dict`, such as the loss values, will be collected in another dictionary named `epoch_results`. 
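
A minimal sketch of that collection step, assuming per-batch loss values are averaged at the end of the epoch. The averaging rule and the function name `collect_epoch_results` are assumptions made for illustration; FuseMedML's own aggregation may differ.

```python
from statistics import mean


def collect_epoch_results(batch_dicts, keys):
    """Hypothetical helper: keep only the requested entries and average them per epoch.

    Each batch_dict maps e.g. 'losses' -> {'cls_loss': float}; everything else
    (such as model outputs) is discarded.
    """
    epoch_results = {}
    for key in keys:
        names = batch_dicts[0][key].keys()
        epoch_results[key] = {name: mean(bd[key][name] for bd in batch_dicts) for name in names}
    return epoch_results


batches = [
    {"losses": {"cls_loss": 0.75}, "model": {"output": [0.1, 0.9]}},
    {"losses": {"cls_loss": 0.25}, "model": {"output": [0.8, 0.2]}},
]
print(collect_epoch_results(batches, keys=["losses"]))  # {'losses': {'cls_loss': 0.5}}
```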

Both `batch_dict` and `epoch_results` are nested dictionaries. To easily access the data stored in those dictionaries, use `FuseUtilsHierarchicalDict`:

```python
FuseUtilsHierarchicalDict.get(batch_dict, 'model.output.classification')
``` 

will return `batch_dict['model']['output']['classification']`
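
For intuition, dotted-key access over a nested dictionary can be sketched in a few lines. The functions `hget` and `hset` below are hypothetical stand-ins, not the `FuseUtilsHierarchicalDict` implementation:

```python
def hget(nested: dict, key: str):
    """Hypothetical: return nested['a']['b']['c'] for key 'a.b.c'."""
    node = nested
    for part in key.split("."):
        node = node[part]
    return node


def hset(nested: dict, key: str, value) -> None:
    """Hypothetical: set nested['a']['b']['c'] = value, creating intermediate dicts."""
    *parents, leaf = key.split(".")
    node = nested
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value


batch_dict = {}
hset(batch_dict, "model.output.classification", [0.1, 0.9])
print(batch_dict)  # {'model': {'output': {'classification': [0.1, 0.9]}}}
print(hget(batch_dict, "model.output.classification"))  # [0.1, 0.9]
```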

### Manager API
The manager is the main API when using Fuse - it is responsible for the training and inference functionality.

Possible workflows are listed in FuseManagerDefault's documentation. Here are two examples:

##### **Train: Init -> set objects -> train**
```python
manager = FuseManagerDefault(output_model_dir=paths['model_dir'], force_reset=paths['force_reset_model_dir'])

manager.set_objects(net=model,
                    optimizer=optimizer,
                    losses=losses,
                    metrics=metrics,
                    best_epoch_source=train_common_params['manager.best_epoch_source'],
                    lr_scheduler=scheduler,
                    callbacks=callbacks,
                    train_params=train_common_params['manager.train_params'],
                    output_model_dir=paths['model_dir'])

manager.train(train_dataloader=train_dataloader,
              validation_dataloader=validation_dataloader)
```

##### **Infer: Init -> infer**
```python
manager = FuseManagerDefault()

output_columns = ['model.output.classification', 'data.label']

manager.infer(data_loader=validation_dataloader,
              input_model_dir=paths['model_dir'],
              checkpoint=infer_common_params['checkpoint'],
              output_columns=output_columns,
              output_file_name=os.path.join(paths["inference_dir"], infer_common_params["infer_filename"]))
```



### Use PyTorch directly and alternative frameworks

FuseMedML uses and extends PyTorch only when required by the user. 
You can mix FuseMedML with PyTorch code, components from alternative frameworks and other popular GitHub projects. 



------------
## **Installation Details - Google Colab**

#### **Enable GPU Support**

To use a GPU in Google Colab, change the runtime type to GPU:

From the "Runtime" menu, select "Change Runtime Type", choose "GPU" from the drop-down menu and click "SAVE".
When asked, reboot the system.

#### **Install FuseMedML**

In [None]:
# !git clone https://github.com/IBM/fuse-med-ml.git
# %cd fuse-med-ml
# !pip install -e .

## **Setup environment**

##### **Imports**

In [None]:
import logging
import os

import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms

from fuse.analyzer.analyzer_default import FuseAnalyzerDefault
from fuse.data.dataset.dataset_wrapper import FuseDatasetWrapper
from fuse.data.sampler.sampler_balanced_batch import FuseSamplerBalancedBatch
from fuse.losses.loss_default import FuseLossDefault
from fuse.managers.callbacks.callback_metric_statistics import FuseMetricStatisticsCallback
from fuse.managers.callbacks.callback_tensorboard import FuseTensorboardCallback
from fuse.managers.callbacks.callback_time_statistics import FuseTimeStatisticsCallback
from fuse.managers.manager_default import FuseManagerDefault
from fuse.metrics.classification.metric_accuracy import FuseMetricAccuracy
from fuse.metrics.classification.metric_auc import FuseMetricAUC
from fuse.metrics.classification.metric_roc_curve import FuseMetricROCCurve
from fuse.models.model_wrapper import FuseModelWrapper
from fuse.utils.utils_gpu import FuseUtilsGPU
from fuse.utils.utils_debug import FuseUtilsDebug
from fuse.utils.utils_logger import fuse_logger_start

##### **Setup debugger**
The supported modes are: 'default', 'fast', 'debug', 'verbose', 'user'.

See FuseUtilsDebug for more details.

In [None]:
mode = 'default'
debug = FuseUtilsDebug(mode)

##### **Output paths**
You can customize the output directory by changing `ROOT` below.

In [None]:
ROOT = 'examples' # TODO: fill path here
PATHS = {'model_dir': os.path.join(ROOT, 'mnist/model_dir'),
         'force_reset_model_dir': True,  # If True will reset model dir automatically - otherwise will prompt 'are you sure' message.
         'cache_dir': os.path.join(ROOT, 'mnist/cache_dir'),
         'inference_dir': os.path.join(ROOT, 'mnist/infer_dir'),
         'analyze_dir': os.path.join(ROOT, 'mnist/analyze_dir')}

paths = PATHS

##### **Training Parameters**
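* Data - define the data loading parameters (batch size and number of workers).
* Manager - define the parameters used by the manager (training, best-epoch selection and optimizer settings).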

In [None]:
TRAIN_COMMON_PARAMS = {}
TRAIN_COMMON_PARAMS['data.batch_size'] = 30
TRAIN_COMMON_PARAMS['data.train_num_workers'] = 8
TRAIN_COMMON_PARAMS['data.validation_num_workers'] = 8

TRAIN_COMMON_PARAMS['manager.train_params'] = {
    'device': 'cuda', 
    'num_epochs': 5,
    'virtual_batch_size': 1,  # number of batches in one virtual batch
    'start_saving_epochs': 10,  # first epoch to start saving checkpoints from
    'gap_between_saving_epochs': 5,  # number of epochs between saved checkpoint
}
TRAIN_COMMON_PARAMS['manager.best_epoch_source'] = {
    'source': 'metrics.accuracy',  # can be any key from 'epoch_results'
    'optimization': 'max',  # can be either min/max
    'on_equal_values': 'better',  # can be either better/worse - whether to consider best epoch when values are equal
}
TRAIN_COMMON_PARAMS['manager.learning_rate'] = 1e-4
TRAIN_COMMON_PARAMS['manager.weight_decay'] = 0.001
TRAIN_COMMON_PARAMS['manager.resume_checkpoint_filename'] = None  # if not None, will try to load the checkpoint

train_params = TRAIN_COMMON_PARAMS
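
The `best_epoch_source` entry can be read as a comparison rule. Below is a minimal sketch of how `optimization` and `on_equal_values` could interact, assuming they behave as their comments describe. `is_new_best` is a hypothetical helper, not a FuseMedML function.

```python
def is_new_best(value, best_so_far, optimization="max", on_equal_values="better"):
    """Hypothetical: decide whether the current epoch's value replaces the best epoch."""
    if best_so_far is None:          # the first epoch is always the best so far
        return True
    if value == best_so_far:         # tie: 'better' means prefer the newer epoch
        return on_equal_values == "better"
    if optimization == "max":
        return value > best_so_far
    return value < best_so_far       # optimization == 'min'


# Simulate 'metrics.accuracy' over five epochs
accuracies = [0.91, 0.95, 0.95, 0.94, 0.96]
best_value, best_epoch = None, None
for epoch, acc in enumerate(accuracies):
    if is_new_best(acc, best_value, optimization="max", on_equal_values="better"):
        best_value, best_epoch = acc, epoch
print(best_epoch, best_value)  # 4 0.96
```

With `on_equal_values='better'`, epoch 2 would replace epoch 1 on the tie at 0.95; with `'worse'`, epoch 1 would be kept.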

##### **Allocate GPUs**
Look for free GPUs and allocate them accordingly.



In [None]:
# To use cpu - set NUM_GPUS to 0
NUM_GPUS = 0
if NUM_GPUS == 0:
    TRAIN_COMMON_PARAMS['manager.train_params']['device'] = 'cpu' 
# set force_gpus to a list of GPU indices (e.g. [0]) to use specific GPUs instead of automatically looking for free ones
force_gpus = None  # [0]
FuseUtilsGPU.choose_and_enable_multiple_gpus(NUM_GPUS, force_gpus=force_gpus)

##### **Define helper function**

In [None]:
def perform_softmax(output):
    if isinstance(output, torch.Tensor):  # validation
        logits = output
    else:  # train
        logits = output.logits
    cls_preds = F.softmax(logits, dim=1)
    return logits, cls_preds
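
`F.softmax(logits, dim=1)` normalizes each row (one sample) of the logits over the class dimension, so every row of `cls_preds` sums to 1. A plain-Python sketch of the same row-wise computation, for illustration only (`softmax_rows` is a hypothetical name):

```python
import math


def softmax_rows(logits):
    """Row-wise softmax: one row per sample, one column per class (like dim=1)."""
    probs = []
    for row in logits:
        m = max(row)                               # subtract the max for numerical stability
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        probs.append([e / total for e in exps])
    return probs


cls_preds = softmax_rows([[2.0, 1.0, 0.1], [0.0, 0.0, 0.0]])
print([round(sum(row), 6) for row in cls_preds])  # [1.0, 1.0]
```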

##### **Init logger**
The logger does two things:
- Outputs the log automatically to three destinations:
    1. Console
    2. File - a copy of the console output.
    3. Verbose file - used for debugging.

<p></p>

- Saves a copy of the template file.

In [None]:
fuse_logger_start(output_path=paths['model_dir'], console_verbose_level=logging.INFO)
lgr = logging.getLogger('Fuse')
lgr.info('Fuse Train', {'attrs': ['bold', 'underline']})
lgr.info(f'model_dir={paths["model_dir"]}', {'color': 'magenta'})
lgr.info(f'cache_dir={paths["cache_dir"]}', {'color': 'magenta'})
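
`fuse_logger_start` configures this for you. For readers curious about the mechanics, the three-destination setup can be approximated with the standard `logging` module. The file names and the function `start_logger_sketch` below are hypothetical and do not reproduce FuseMedML's exact behavior (for example, the sketch does not copy the template file).

```python
import logging
import os
import sys


def start_logger_sketch(output_path, console_level=logging.INFO, name="FuseSketch"):
    """Hypothetical: console + plain file (same level as console) + verbose DEBUG file."""
    os.makedirs(output_path, exist_ok=True)
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)      # let each handler do its own filtering
    logger.propagate = False            # avoid duplicate output via the root logger
    logger.handlers.clear()             # avoid duplicate handlers when the cell is re-run

    console = logging.StreamHandler(sys.stdout)
    console.setLevel(console_level)

    copy_file = logging.FileHandler(os.path.join(output_path, "fuse.log"))
    copy_file.setLevel(console_level)   # mirrors what the console shows

    verbose_file = logging.FileHandler(os.path.join(output_path, "fuse_verbose.log"))
    verbose_file.setLevel(logging.DEBUG)

    fmt = logging.Formatter("%(asctime)s %(levelname)s %(message)s")
    for handler in (console, copy_file, verbose_file):
        handler.setFormatter(fmt)
        logger.addHandler(handler)
    return logger


log = start_logger_sketch("logs_sketch")
log.info("shown on the console and in both files")
log.debug("only in the verbose file")
```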

## **Training the model**

##### **Data**

In [None]:
# Train Data
lgr.info(f'Train Data:', {'attrs': 'bold'})

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
# Create dataset
torch_train_dataset = torchvision.datasets.MNIST(paths['cache_dir'], download=True, train=True, transform=transform)
# wrapping torch dataset
# FIXME: support also using torch dataset directly
train_dataset = FuseDatasetWrapper(name='train', dataset=torch_train_dataset, mapping=('image', 'label'))
train_dataset.create()
lgr.info(f'- Create sampler:')
sampler = FuseSamplerBalancedBatch(dataset=train_dataset,
                                balanced_class_name='data.label',
                                num_balanced_classes=10,
                                batch_size=train_params['data.batch_size'],
                                balanced_class_weights=None)
lgr.info(f'- Create sampler: Done')

# Create dataloader
train_dataloader = DataLoader(dataset=train_dataset, batch_sampler=sampler, num_workers=train_params['data.train_num_workers'])
lgr.info(f'Train Data: Done', {'attrs': 'bold'})

## Validation data
lgr.info(f'Validation Data:', {'attrs': 'bold'})
# Create dataset
torch_validation_dataset = torchvision.datasets.MNIST(paths['cache_dir'], download=True, train=False, transform=transform)
# wrapping torch dataset
validation_dataset = FuseDatasetWrapper(name='validation', dataset=torch_validation_dataset, mapping=('image', 'label'))
validation_dataset.create()

# dataloader
validation_dataloader = DataLoader(dataset=validation_dataset, batch_size=train_params['data.batch_size'],
                                num_workers=train_params['data.validation_num_workers'])
lgr.info(f'Validation Data: Done', {'attrs': 'bold'})
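
With `batch_size=30` and `num_balanced_classes=10`, a balanced batch sampler would be expected to draw about three samples per class in each training batch, so rare classes are seen as often as common ones. Below is a simplified sketch of that idea, assuming sampling with replacement and an equal share per class. `balanced_batches` is hypothetical; `FuseSamplerBalancedBatch` has its own options, such as `balanced_class_weights`.

```python
import random
from collections import Counter, defaultdict


def balanced_batches(labels, num_classes, batch_size, num_batches, seed=0):
    """Hypothetical sketch: yield index lists with an equal number of samples per class."""
    if batch_size % num_classes != 0:
        raise ValueError("this sketch assumes batch_size is divisible by num_classes")
    per_class = batch_size // num_classes          # 30 // 10 == 3
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)
    for _ in range(num_batches):
        batch = []
        for cls in range(num_classes):
            # sample with replacement so small classes can still fill their share
            batch.extend(rng.choices(by_class[cls], k=per_class))
        rng.shuffle(batch)
        yield batch


# Imbalanced toy labels: class 0 is 9x more frequent than each other class
labels = [0] * 900 + [c for c in range(1, 10) for _ in range(100)]
first = next(balanced_batches(labels, num_classes=10, batch_size=30, num_batches=1))
print(Counter(labels[i] for i in first))  # 3 samples of each of the 10 classes
```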

##### **Model**
Build the model using PyTorch's API, then wrap it with `FuseModelWrapper`. 

The model outputs will be aggregated in `batch_dict['model.*']`.


Another option to implement a model is to use o

In [None]:
lgr.info('Model:', {'attrs': 'bold'})

torch_model = models.resnet18(num_classes=10)
# modify conv1 to support single channel image
torch_model.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
# use adaptive avg pooling to support mnist low resolution images
torch_model.avgpool = torch.nn.AdaptiveAvgPool2d(1)

model = FuseModelWrapper(model=torch_model,
                        model_inputs=['data.image'],
                        post_forward_processing_function=perform_softmax,
                        model_outputs=['logits.classification', 'output.classification']
                        )

lgr.info('Model: Done', {'attrs': 'bold'})
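
Conceptually, the wrapper connects a plain model to the `batch_dict` namespace: it reads the keys listed in `model_inputs`, calls the model, applies the post-processing function, and stores each returned value under `model.<name>`. A toy sketch of that flow, using plain Python callables instead of PyTorch modules (`ToyModelWrapper` and its helpers are hypothetical and are not the `FuseModelWrapper` source):

```python
def get_by_path(nested, path):
    """Hypothetical helper: nested['a']['b'] for path 'a.b'."""
    node = nested
    for part in path.split("."):
        node = node[part]
    return node


def set_by_path(nested, path, value):
    """Hypothetical helper: create intermediate dicts and set the leaf value."""
    *parents, leaf = path.split(".")
    node = nested
    for part in parents:
        node = node.setdefault(part, {})
    node[leaf] = value


class ToyModelWrapper:
    """Hypothetical sketch of routing model inputs/outputs through batch_dict."""

    def __init__(self, model, model_inputs, model_outputs, post_forward_processing_function=None):
        self.model = model
        self.model_inputs = model_inputs
        self.model_outputs = model_outputs
        self.post = post_forward_processing_function

    def __call__(self, batch_dict):
        inputs = [get_by_path(batch_dict, key) for key in self.model_inputs]
        output = self.model(*inputs)
        if self.post is not None:
            output = self.post(output)         # e.g. returns (logits, cls_preds)
        if len(self.model_outputs) == 1:
            output = (output,)
        for name, value in zip(self.model_outputs, output):
            set_by_path(batch_dict, f"model.{name}", value)
        return batch_dict


# A "model" that doubles its input, and a post-processing step returning two outputs
wrapper = ToyModelWrapper(
    model=lambda image: [2 * x for x in image],
    model_inputs=["data.image"],
    model_outputs=["logits.classification", "output.classification"],
    post_forward_processing_function=lambda logits: (logits, [x / sum(logits) for x in logits]),
)
batch_dict = wrapper({"data": {"image": [1.0, 3.0]}})
print(batch_dict["model"])
```

This is why the loss below can refer to `model.logits.classification` and the metric to `model.output.classification` without touching the model itself.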

##### **Loss function**
Dictionary of loss elements. Each element is a sub-class of FuseLossBase.

The total loss will be the weighted sum of all the elements.


In [None]:
losses = {
    'cls_loss': FuseLossDefault(pred_name='model.logits.classification', target_name='data.label', callable=F.cross_entropy, weight=1.0),
}
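
The total loss described above is the weighted sum of the elements, `total = sum(weight_i * loss_i)`. A minimal sketch with plain numbers in place of tensors; the `(value, weight)` pairs and the extra `aux_loss` entry are illustrative only (in FuseMedML each element computes its own value from `batch_dict`):

```python
def total_loss(loss_elements):
    """Weighted sum of named loss values: sum(weight * value)."""
    return sum(weight * value for value, weight in loss_elements.values())


# Hypothetical per-element results for one batch: name -> (value, weight)
batch_losses = {
    "cls_loss": (0.5, 1.0),
    "aux_loss": (2.0, 0.25),
}
print(total_loss(batch_losses))  # 1.0
```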

##### **Metrics**
Dictionary of metric elements. Each element is a sub-class of FuseMetricBase.

The metrics will be calculated per epoch for both the training and validation sets.

The 'best_epoch_source', used to select and save the best model, can be based on one of these metrics.

In [None]:
metrics = {
    'accuracy': FuseMetricAccuracy(pred_name='model.output.classification', target_name='data.label')
}

##### **Callbacks**
Callbacks are sub-classes of FuseCallbackBase.

A callback is an object that can perform actions at various stages of training.

At each stage, it can manipulate the data, `batch_dict` or `epoch_results`.


In [None]:
callbacks = [
    # default callbacks
    FuseTensorboardCallback(model_dir=paths['model_dir']),  # save statistics for tensorboard
    FuseMetricStatisticsCallback(output_path=os.path.join(paths['model_dir'], "metrics.csv")),  # save statistics to a csv file
    FuseTimeStatisticsCallback(num_epochs=train_params['manager.train_params']['num_epochs'], load_expected_part=0.1)  # time profiler
]
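
A callback is essentially a set of hooks that the training loop calls at fixed points. The hook names below (`on_epoch_begin`, `on_batch_end`, `on_epoch_end`) and the toy loop are hypothetical and only illustrate the pattern; check `FuseCallbackBase` for the actual hook names and signatures.

```python
class CallbackSketch:
    """Hypothetical base class: every hook is a no-op unless overridden."""

    def on_epoch_begin(self, epoch):
        pass

    def on_batch_end(self, batch_dict):
        pass

    def on_epoch_end(self, epoch, epoch_results):
        pass


class LossHistory(CallbackSketch):
    """Records the epoch-level loss, similar in spirit to a CSV statistics callback."""

    def __init__(self):
        self.history = []

    def on_epoch_end(self, epoch, epoch_results):
        self.history.append((epoch, epoch_results["losses"]["total_loss"]))


def toy_train_loop(batches_per_epoch, num_epochs, callbacks):
    """Hypothetical loop showing where each hook would be called."""
    for epoch in range(num_epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        losses = []
        for _ in range(batches_per_epoch):
            batch_dict = {"losses": {"total_loss": 1.0 / (epoch + 1)}}  # fake loss value
            losses.append(batch_dict["losses"]["total_loss"])
            for cb in callbacks:
                cb.on_batch_end(batch_dict)
        epoch_results = {"losses": {"total_loss": sum(losses) / len(losses)}}
        for cb in callbacks:
            cb.on_epoch_end(epoch, epoch_results)


history = LossHistory()
toy_train_loop(batches_per_epoch=3, num_epochs=2, callbacks=[history])
print(history.history)  # [(0, 1.0), (1, 0.5)]
```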

##### **Train**

In [None]:
lgr.info('Train:', {'attrs': 'bold'})

# Create optimizer
optimizer = optim.Adam(model.parameters(), lr=train_params['manager.learning_rate'], weight_decay=train_params['manager.weight_decay'])

# create scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# train from scratch
manager = FuseManagerDefault(output_model_dir=paths['model_dir'], force_reset=paths['force_reset_model_dir'])

# Providing the objects required for the training process.
manager.set_objects(net=model,
                    optimizer=optimizer,
                    losses=losses,
                    metrics=metrics,
                    best_epoch_source=train_params['manager.best_epoch_source'],
                    lr_scheduler=scheduler,
                    callbacks=callbacks,
                    train_params=train_params['manager.train_params'])

## Continue training
if train_params['manager.resume_checkpoint_filename'] is not None:
    # Loading the checkpoint including model weights, learning rate, and epoch_index.
    manager.load_checkpoint(checkpoint=train_params['manager.resume_checkpoint_filename'], mode='train')

# Start training
manager.train(train_dataloader=train_dataloader, validation_dataloader=validation_dataloader)

lgr.info('Train: Done', {'attrs': 'bold'})

## **Infer**

##### **Define Infer Common Params**


In [None]:
INFER_COMMON_PARAMS = {}
INFER_COMMON_PARAMS['infer_filename'] = 'validation_set_infer.gz'
INFER_COMMON_PARAMS['checkpoint'] = 'best'  # Fuse TIP: possible values are 'best', 'last' or epoch_index.

infer_common_params = INFER_COMMON_PARAMS

##### **Logging**

In [None]:
lgr.info('Fuse Inference', {'attrs': ['bold', 'underline']})
lgr.info(f'infer_filename={os.path.join(paths["inference_dir"], infer_common_params["infer_filename"])}', {'color': 'magenta'})

In [None]:
## Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
# Create dataset
torch_validation_dataset = torchvision.datasets.MNIST(paths['cache_dir'], download=True, train=False, transform=transform)
# wrapping torch dataset
validation_dataset = FuseDatasetWrapper(name='validation', dataset=torch_validation_dataset, mapping=('image', 'label'))
validation_dataset.create()
# dataloader
validation_dataloader = DataLoader(dataset=validation_dataset, collate_fn=validation_dataset.collate_fn, batch_size=2, num_workers=2)

## Manager for inference
manager = FuseManagerDefault()
output_columns = ['model.output.classification', 'data.label']
manager.infer(data_loader=validation_dataloader,
              input_model_dir=paths['model_dir'],
              checkpoint=infer_common_params['checkpoint'],
              output_columns=output_columns,
              output_file_name=os.path.join(paths["inference_dir"], infer_common_params["infer_filename"]))
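
The inference output file can then be inspected directly. A minimal sketch, assuming the `.gz` file is a gzip-compressed pickle (as the `data_pickle_filename` argument passed to the analyzer suggests). The exact object stored inside is an assumption, so the sketch writes and reads its own toy file with a hypothetical layout of one entry per output column.

```python
import gzip
import pickle


def load_gz_pickle(path):
    """Load an object stored as a gzip-compressed pickle (only unpickle trusted files)."""
    with gzip.open(path, "rb") as f:
        return pickle.load(f)


# Toy stand-in for the inference file: one entry per output column (hypothetical layout)
toy_results = {
    "model.output.classification": [[0.1, 0.9], [0.8, 0.2]],
    "data.label": [1, 0],
}
with gzip.open("toy_infer.gz", "wb") as f:
    pickle.dump(toy_results, f)

loaded = load_gz_pickle("toy_infer.gz")
print(loaded["data.label"])  # [1, 0]
```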

## **Analyze**

##### **Define Analyze Common Params**


In [None]:
ANALYZE_COMMON_PARAMS = {}
ANALYZE_COMMON_PARAMS['infer_filename'] = INFER_COMMON_PARAMS['infer_filename']
ANALYZE_COMMON_PARAMS['output_filename'] = os.path.join(PATHS['analyze_dir'], 'all_metrics')
analyze_common_params = ANALYZE_COMMON_PARAMS

lgr.info('Fuse Analyze', {'attrs': ['bold', 'underline']})

# metrics
metrics = {
    'accuracy': FuseMetricAccuracy(pred_name='model.output.classification', target_name='data.label'),
    'roc': FuseMetricROCCurve(pred_name='model.output.classification', target_name='data.label', output_filename=os.path.join(paths['inference_dir'], 'roc_curve.png')),
    'auc': FuseMetricAUC(pred_name='model.output.classification', target_name='data.label')
}

##### **Analyze**

In [None]:
# create analyzer
analyzer = FuseAnalyzerDefault()

# run
# FIXME: simplify analyze interface for this case
results = analyzer.analyze(gt_processors={},
                           data_pickle_filename=os.path.join(paths["inference_dir"], analyze_common_params["infer_filename"]),
                           metrics=metrics,
                           output_filename=analyze_common_params['output_filename'])
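
For intuition about what the `auc` metric measures: for a single class treated one-vs-rest, the AUC equals the probability that a randomly chosen positive sample gets a higher score than a randomly chosen negative one, with ties counting as half. A quadratic-time sketch of that definition follows (`one_vs_rest_auc` is a hypothetical name); how `FuseMetricAUC` combines the ten MNIST classes is not shown here.

```python
def one_vs_rest_auc(scores, labels, positive_class):
    """AUC for one class via pairwise comparison of positive vs. negative scores."""
    pos = [s for s, y in zip(scores, labels) if y == positive_class]
    neg = [s for s, y in zip(scores, labels) if y != positive_class]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5      # ties count as half
    return wins / (len(pos) * len(neg))


# Scores for class 1 (e.g. column 1 of 'model.output.classification')
scores = [0.9, 0.4, 0.35, 0.8]
labels = [1, 0, 1, 2]
print(one_vs_rest_auc(scores, labels, positive_class=1))  # 0.5
```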