# FuseMedML - Hello World
[![Github repo](https://img.shields.io/static/v1?label=GitHub&message=FuseMedML&color=brightgreen)](https://github.com/BiomedSciAI/fuse-med-ml)

[![PyPI version](https://badge.fury.io/py/fuse-med-ml.svg)](https://badge.fury.io/py/fuse-med-ml)

[![Slack channel](https://img.shields.io/badge/support-slack-slack.svg?logo=slack)](https://join.slack.com/t/fusemedml/shared_invite/zt-xr1jaj29-h7IMsSc0Lq4qpVNxW97Phw)

[![Open Source](https://badges.frapsoft.com/os/v1/open-source.svg)](https://github.com/BiomedSciAI/fuse-med-ml)


**Welcome to FuseMedML's 'hello world' hands-on notebook!**

In this notebook we'll walk through a basic FuseMedML use case: MNIST multiclass classification, including training, inference, and evaluation.

By the end of the session, we hope you'll be familiar with the basic FuseMedML workflow and appreciate its potential.

Open and run this notebook in [Google Colab](https://colab.research.google.com/github/BiomedSciAI/fuse-med-ml/blob/master/fuse_examples/imaging/hello_world/hello_world.ipynb)

Enjoy!

------------
## **Installation Details - Google Colab**

In [None]:
# @title 1. Install FuseMedML

# @markdown Please choose whether or not to install FuseMedML and execute this cell by pressing the *Play* button on the left.


install_fuse = False  # @param {type:"boolean"}
use_gpu = True  # @param {type:"boolean"}

# @markdown ### **Warning!**
# @markdown If you wish to install FuseMedML, please follow these steps as a workaround for
# @markdown [this issue](https://stackoverflow.com/questions/57831187/need-to-restart-runtime-before-import-an-installed-package-in-colab):
# @markdown <br>
# @markdown 1. Execute this cell by pressing the ▶️ button on the left.
# @markdown 2. Restart runtime
# @markdown 3. Execute it once again
# @markdown 4. Enjoy
if install_fuse:
    !git clone https://github.com/BiomedSciAI/fuse-med-ml.git
    %cd fuse-med-ml
    %pip install -e .[all,examples]


## **Setup environment**

##### **Imports**

In [None]:
# @title 1. Imports

# @markdown Please execute this cell by pressing the *Play* button on the left.

import os
import copy
from collections import OrderedDict  # concrete class; typing.OrderedDict is meant for annotations

import torch.nn.functional as F
import torch.optim as optim
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader

from fuse.eval.evaluator import EvaluatorDefault
from fuse.dl.losses.loss_default import LossDefault
from fuse.eval.metrics.classification.metrics_classification_common import MetricAccuracy, MetricAUCROC, MetricROCCurve
from fuse.eval.metrics.classification.metrics_thresholding_common import MetricApplyThresholds
from fuse.dl.models.model_wrapper import ModelWrapSeqToDict
from fuse.data.utils.samplers import BatchSamplerDefault
from fuse.data.utils.collates import CollateDefault
from fuse.dl.lightning.pl_module import LightningModuleDefault
from fuse.dl.lightning.pl_funcs import convert_predictions_to_dataframe
from fuse.utils.file_io.file_io import create_dir, save_dataframe
from fuseimg.datasets.mnist import MNIST

from fuse_examples.imaging.hello_world.hello_world_utils import LeNet, perform_softmax

##### **Output paths**
You can customize the output directory paths below.

In [None]:
ROOT = "_examples/mnist"
model_dir = os.path.join(ROOT, "model_dir")
PATHS = {
    "model_dir": model_dir,
    "cache_dir": os.path.join(ROOT, "cache_dir"),
    "inference_dir": os.path.join(model_dir, "infer_dir"),
    "eval_dir": os.path.join(model_dir, "eval_dir"),
}

paths = PATHS

##### **Training Parameters**

In [None]:
TRAIN_COMMON_PARAMS = {}

### Data ###
TRAIN_COMMON_PARAMS["data.batch_size"] = 100
TRAIN_COMMON_PARAMS["data.train_num_workers"] = 8
TRAIN_COMMON_PARAMS["data.validation_num_workers"] = 8

### PL Trainer ###
TRAIN_COMMON_PARAMS["trainer.num_epochs"] = 2
TRAIN_COMMON_PARAMS["trainer.num_devices"] = 1
TRAIN_COMMON_PARAMS["trainer.accelerator"] = "gpu" if use_gpu else "cpu"
TRAIN_COMMON_PARAMS["trainer.ckpt_path"] = None  # path to a checkpoint to resume training from (None starts from scratch)

### Optimizer ###
TRAIN_COMMON_PARAMS["opt.lr"] = 1e-4
TRAIN_COMMON_PARAMS["opt.weight_decay"] = 0.001

train_params = TRAIN_COMMON_PARAMS

## **Training the model**

##### **Data**
Downloading the MNIST dataset and building dataloaders (`torch.utils.data.DataLoader`) for both training and validation.


In [None]:
## Training Data
# Create dataset
train_dataset = MNIST.dataset(paths["cache_dir"], train=True)

# Create Fuse's custom sampler
sampler = BatchSamplerDefault(
    dataset=train_dataset,
    balanced_class_name="data.label",
    num_balanced_classes=10,
    batch_size=train_params["data.batch_size"],
    balanced_class_weights=None,
)

# Create dataloader
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_sampler=sampler,
    collate_fn=CollateDefault(),
    num_workers=train_params["data.train_num_workers"],
)

## Validation data
# Create dataset
validation_dataset = MNIST.dataset(paths["cache_dir"], train=False)

# dataloader
validation_dataloader = DataLoader(
    dataset=validation_dataset,
    batch_size=train_params["data.batch_size"],
    collate_fn=CollateDefault(),
    num_workers=train_params["data.validation_num_workers"],
)
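
The training dataloader above uses a batch sampler that is configured to balance the ten classes. The following is a minimal pure-Python sketch of the general idea of class-balanced batches, not FuseMedML's implementation: the helper name `balanced_batch`, the equal share per class, and sampling with replacement are all assumptions made for illustration.

```python
import random
from collections import defaultdict


def balanced_batch(labels, batch_size, num_classes, rng):
    """Return sample indices for one batch with an equal share per class.

    Hypothetical helper for illustration only; not part of FuseMedML.
    """
    # Group sample indices by their class label.
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    per_class = batch_size // num_classes
    batch = []
    for cls in range(num_classes):
        # Sample with replacement so rare classes can still fill their share.
        batch.extend(rng.choices(by_class[cls], k=per_class))
    rng.shuffle(batch)
    return batch


# Toy, heavily imbalanced label list: class 0 dominates.
toy_labels = [0] * 90 + [1] * 5 + [2] * 5
toy_batch = balanced_batch(toy_labels, batch_size=30, num_classes=3, rng=random.Random(0))
class_counts = [sum(1 for i in toy_batch if toy_labels[i] == c) for c in range(3)]
print(class_counts)  # prints [10, 10, 10]
```

Even though class 0 makes up 90% of the toy data, every class contributes the same number of samples to the batch.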

##### **Model**
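Building a LeNet model using "pure" PyTorch and wrapping it with a FuseMedML component (`ModelWrapSeqToDict`) so that it reads its inputs from, and writes its outputs to, the batch dictionary.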

In [None]:
def create_model():
    torch_model = LeNet()
    # wrap basic torch model to automatically read inputs from batch_dict and save its outputs to batch_dict
    model = ModelWrapSeqToDict(
        model=torch_model,
        model_inputs=["data.image"],
        post_forward_processing_function=perform_softmax,
        model_outputs=["model.logits.classification", "model.output.classification"],
    )
    return model


model = create_model()
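
To make the wrapping idea concrete, here is a small pure-Python sketch under stated assumptions. It is not FuseMedML's code: `wrap_seq_to_dict` is a hypothetical stand-in, the batch is a flat `dict` with dotted string keys, and the post-processing step is assumed to return the logits together with their softmax probabilities, as the two output keys above suggest.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def wrap_seq_to_dict(forward, input_keys, output_keys, post_process=None):
    """Hypothetical stand-in for the wrapping idea; not FuseMedML's API."""

    def wrapped(batch_dict):
        # Read positional inputs from the named keys.
        outputs = forward(*[batch_dict[key] for key in input_keys])
        if post_process is not None:
            outputs = post_process(outputs)
        if not isinstance(outputs, tuple):
            outputs = (outputs,)
        # Write each output back under its named key.
        for key, value in zip(output_keys, outputs):
            batch_dict[key] = value
        return batch_dict

    return wrapped


def toy_model(image):
    """Toy 'model' that returns fixed logits for a 3-class problem."""
    return [2.0, 1.0, 0.1]


wrapped_model = wrap_seq_to_dict(
    toy_model,
    input_keys=["data.image"],
    output_keys=["model.logits.classification", "model.output.classification"],
    post_process=lambda logits: (logits, softmax(logits)),
)
toy_batch_dict = wrapped_model({"data.image": "pixels-go-here"})
print(sorted(toy_batch_dict))
```

The downstream components (losses, metrics) then only need to know key names, not the model's call signature.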

##### **Loss function**
Dictionary of loss elements, each an instance of a subclass of `LossBase`. The total loss is the weighted sum of all the elements.

In [None]:
losses = {
    "cls_loss": LossDefault(
        pred="model.logits.classification", target="data.label", callable=F.cross_entropy, weight=1.0
    ),
}
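
As a rough illustration of "weighted sum of loss elements", the sketch below computes a cross-entropy term from raw logits and combines it with a second term. The helper names, the `aux_loss` entry, and its weight are hypothetical and exist only for this example; the notebook itself uses a single `cls_loss` with weight 1.0.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample from raw logits: logsumexp(logits) - logits[target]."""
    shift = max(logits)
    log_sum_exp = shift + math.log(sum(math.exp(x - shift) for x in logits))
    return log_sum_exp - logits[target]


def total_loss(loss_values, weights):
    """Weighted sum of named loss terms (illustrative helper, not FuseMedML's API)."""
    return sum(weights[name] * value for name, value in loss_values.items())


loss_values = {
    "cls_loss": cross_entropy([0.0, 0.0], target=0),  # equals log(2) for two equal logits
    "aux_loss": 2.0,  # hypothetical second term, made up for the example
}
weights = {"cls_loss": 1.0, "aux_loss": 0.25}
combined = total_loss(loss_values, weights)
print(combined)
```

With a single entry of weight 1.0, as in the cell above, the total loss reduces to the cross-entropy term itself.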

##### **Metrics**
Dictionary of metric elements, each an instance of a subclass of `MetricBase`.

The metrics are calculated once per epoch on both the training and validation sets.

In [None]:
train_metrics = OrderedDict(
    [
        ("operation_point", MetricApplyThresholds(pred="model.output.classification")),  # will apply argmax
        ("accuracy", MetricAccuracy(pred="results:metrics.operation_point.cls_pred", target="data.label")),
    ]
)
validation_metrics = copy.deepcopy(train_metrics)  # use the same metrics in validation as well
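
The two metrics form a small pipeline: class probabilities are first turned into a predicted class (argmax), and accuracy is then computed on those predictions. A minimal pure-Python sketch of that pipeline follows; the function names and the toy numbers are made up for illustration and are not FuseMedML's API.

```python
def apply_argmax(probabilities):
    """Pick the index of the largest probability in each row (first index wins ties)."""
    return [max(range(len(row)), key=row.__getitem__) for row in probabilities]


def accuracy(predicted, target):
    """Fraction of predictions that match the target labels."""
    correct = sum(1 for p, t in zip(predicted, target) if p == t)
    return correct / len(target)


# Toy 3-class probabilities and labels, made up for the example.
toy_probs = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
]
toy_targets = [0, 1, 2, 1]

cls_pred = apply_argmax(toy_probs)
print(cls_pred, accuracy(cls_pred, toy_targets))  # prints [0, 1, 2, 0] 0.75
```

This mirrors why the accuracy metric above reads its predictions from the output of the `operation_point` metric rather than from the model output directly.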

##### **Best Epoch Source**
Defining which epoch counts as 'the best epoch', so that the model checkpoint is saved accordingly. Here, the best epoch is the one with the highest validation accuracy.

In [None]:
best_epoch_source = dict(monitor="validation.metrics.accuracy", mode="max")
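
The `monitor`/`mode` pair can be read as "track this value and prefer the largest (or smallest) one". The sketch below shows that selection logic on a made-up training history; `pick_best_epoch` is a hypothetical helper, not how FuseMedML or PyTorch Lightning implement checkpointing.

```python
def pick_best_epoch(history, monitor, mode):
    """Return the index of the epoch whose monitored value is best.

    mode="max" prefers larger values, mode="min" prefers smaller ones.
    On ties, the earliest epoch is returned.
    """
    if mode not in ("max", "min"):
        raise ValueError(f"unsupported mode: {mode!r}")
    pick = max if mode == "max" else min
    return pick(range(len(history)), key=lambda epoch: history[epoch][monitor])


# Made-up per-epoch values, for illustration only.
toy_history = [
    {"validation.metrics.accuracy": 0.91, "validation.losses.total_loss": 0.30},
    {"validation.metrics.accuracy": 0.97, "validation.losses.total_loss": 0.12},
    {"validation.metrics.accuracy": 0.95, "validation.losses.total_loss": 0.10},
]

print(pick_best_epoch(toy_history, "validation.metrics.accuracy", "max"))   # prints 1
print(pick_best_epoch(toy_history, "validation.losses.total_loss", "min"))  # prints 2
```

Note that the two criteria can disagree, which is why the monitored key and the mode should be chosen together.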

##### **Train**
Training session using PyTorch Lightning's trainer.

In [None]:
# create optimizer
optimizer = optim.Adam(model.parameters(), lr=train_params["opt.lr"], weight_decay=train_params["opt.weight_decay"])

# create scheduler
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
lr_sch_config = dict(scheduler=lr_scheduler, monitor="validation.losses.total_loss")

# optimizer and lr sch - see pl.LightningModule.configure_optimizers return value for all options
optimizers_and_lr_schs = dict(optimizer=optimizer, lr_scheduler=lr_sch_config)

# create instance of PL module - FuseMedML generic version
pl_module = LightningModuleDefault(
    model_dir=paths["model_dir"],
    model=model,
    losses=losses,
    train_metrics=train_metrics,
    validation_metrics=validation_metrics,
    best_epoch_source=best_epoch_source,
    optimizers_and_lr_schs=optimizers_and_lr_schs,
)

# create lightning trainer
pl_trainer = pl.Trainer(
    default_root_dir=paths["model_dir"],
    max_epochs=train_params["trainer.num_epochs"],
    accelerator=train_params["trainer.accelerator"],
    devices=train_params["trainer.num_devices"],
    auto_select_gpus=True,
)

# train
pl_trainer.fit(pl_module, train_dataloader, validation_dataloader, ckpt_path=train_params["trainer.ckpt_path"])

## **Infer**

##### **Define Infer Common Params**


In [None]:
INFER_COMMON_PARAMS = {}
INFER_COMMON_PARAMS["infer_filename"] = "infer_file.gz"
INFER_COMMON_PARAMS["checkpoint"] = "best_epoch.ckpt"
INFER_COMMON_PARAMS["trainer.num_devices"] = TRAIN_COMMON_PARAMS["trainer.num_devices"]
INFER_COMMON_PARAMS["trainer.accelerator"] = TRAIN_COMMON_PARAMS["trainer.accelerator"]

infer_common_params = INFER_COMMON_PARAMS

##### **Infer**

In [None]:
# setting dir and paths
create_dir(paths["inference_dir"])
infer_file = os.path.join(paths["inference_dir"], infer_common_params["infer_filename"])
checkpoint_file = os.path.join(paths["model_dir"], infer_common_params["checkpoint"])

# creating a dataloader
validation_dataloader = DataLoader(dataset=validation_dataset, collate_fn=CollateDefault(), batch_size=2, num_workers=2)

# load pytorch lightning module
model = create_model()
pl_module = LightningModuleDefault.load_from_checkpoint(
    checkpoint_file, model_dir=paths["model_dir"], model=model, map_location="cpu", strict=True
)

# set the prediction keys to extract (the ones used by the evaluation function).
pl_module.set_predictions_keys(
    ["model.output.classification", "data.label"]
)  # which keys to extract and dump into file

# create a trainer instance
pl_trainer = pl.Trainer(
    default_root_dir=paths["model_dir"],
    accelerator=infer_common_params["trainer.accelerator"],
    devices=infer_common_params["trainer.num_devices"],
    auto_select_gpus=True,
)

# predict
predictions = pl_trainer.predict(pl_module, validation_dataloader, return_predictions=True)

# convert list of batch outputs into a dataframe
infer_df = convert_predictions_to_dataframe(predictions)
save_dataframe(infer_df, infer_file)

## **Evaluation**
This section uses `EvaluatorDefault` from FuseMedML's evaluation package (`fuse.eval`), a standalone library for evaluating ML models, including models that were not trained with FuseMedML.

More details and examples for the evaluation package can be found [here](https://github.com/BiomedSciAI/fuse-med-ml/blob/master/fuse/eval/README.md).


##### **Define EVAL Common Params**


In [None]:
EVAL_COMMON_PARAMS = {}
EVAL_COMMON_PARAMS["infer_filename"] = INFER_COMMON_PARAMS["infer_filename"]

eval_common_params = EVAL_COMMON_PARAMS

##### **Define metrics**

In [None]:
class_names = [str(i) for i in range(10)]

# metrics
metrics = OrderedDict(
    [
        ("operation_point", MetricApplyThresholds(pred="model.output.classification")),  # will apply argmax
        ("accuracy", MetricAccuracy(pred="results:metrics.operation_point.cls_pred", target="data.label")),
        (
            "roc",
            MetricROCCurve(
                pred="model.output.classification",
                target="data.label",
                class_names=class_names,
                output_filename=os.path.join(paths["inference_dir"], "roc_curve.png"),
            ),
        ),
        ("auc", MetricAUCROC(pred="model.output.classification", target="data.label", class_names=class_names)),
    ]
)

##### **Evaluate**

In [None]:
# create evaluator
evaluator = EvaluatorDefault()

# run eval
results = evaluator.eval(
    ids=None,
    data=os.path.join(paths["inference_dir"], eval_common_params["infer_filename"]),
    metrics=metrics,
    output_dir=paths["eval_dir"],
    silent=False,
)

print("Done!")
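
For intuition about the AUC metric in a multiclass setting, the sketch below computes a one-vs-rest ROC AUC per class using the pairwise-ranking definition (the probability that a positive sample scores higher than a negative one, counting ties as half). This is an assumption about how to read the per-class results, not a description of how `MetricAUCROC` is implemented, and the helper names and toy data are made up.

```python
def binary_auc(scores, is_positive):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly."""
    pos = [s for s, p in zip(scores, is_positive) if p]
    neg = [s for s, p in zip(scores, is_positive) if not p]
    # A correctly ordered pair counts 1, a tie counts 0.5.
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def one_vs_rest_auc(probabilities, targets, num_classes):
    """Per-class AUC: class c is 'positive', all other classes are 'negative'."""
    return [
        binary_auc([row[c] for row in probabilities], [t == c for t in targets])
        for c in range(num_classes)
    ]


# Toy 3-class probabilities and labels, made up for the example.
auc_probs = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
]
auc_targets = [0, 1, 2, 1]

per_class_auc = one_vs_rest_auc(auc_probs, auc_targets, num_classes=3)
print(per_class_auc)  # prints [1.0, 0.875, 1.0]
```

The pairwise loop is quadratic in the number of samples, which is fine for a toy example but not for a full test set.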

In [None]:
# For testing purposes
test_result_acc = results["metrics.accuracy"]