# Hands-On: FuseMedML and MultiModality

Welcome!

This notebook will guide you through the hands-on session.

Open and run this notebook in Google Colab (see the 'Installation Details - Google Colab' section for instructions):

https://colab.research.google.com/github/IBM/fuse-med-ml/blob/master/fuse_examples/multimodality/image_clinical/multimodality_image_clinical.ipynb

## Session take-away
* Introduction to the FuseMedML framework
* Introduction to multimodal data and tasks
* Training a multimodality-based deep-learning model: a demonstration of integrating imaging and clinical data in a skin lesion classification task


------------------
## FuseMedML
See https://github.com/IBM/fuse-med-ml

---------
## Multimodality
Radiologists never diagnose based solely on a single modality. The decision is made by combining information from various sources. Therefore, it is important to include such information in machine learning algorithms. 

Radiologists take into account clinical information such as the reason the scan was ordered. If needed, they can also examine other clinical information from the electronic health records of the hospital.  

Prior images are another type of data that is routinely used in radiology reading. Radiologists will often compare a current study with imaging or other tests done in the past to assess change.

It is also a common practice to consider findings from several imaging modalities when making a diagnosis. Each reveals different aspects and attributes of the suspicious finding.

In this session, we will demonstrate two simple yet effective methods to integrate clinical data with imaging data.
In either method, the clinical data should first be pre-processed, for example normalized and encoded as a numeric vector.

<img src="arch.png" alt="drawing" width="100%"/>

* **Imaging only implementation**

* **Imaging and Tabular data - concatenate tabular data after image feature extraction**

    The tabular data is integrated after feature extraction, done by a convolutional network followed by a pooling layer that extracts non-spatial features from the image.


* **Imaging and Tabular data - concatenate directly with the image**

    The tabular data is integrated at the beginning of the network by adding more channels to the input image. Each channel holds a single element of the clinical feature vector (for example, one bit of a one-hot encoding), repeated across all pixels.

    Integrating the clinical features this early lets the backbone use them while extracting image features, in contrast to the standard approach of integrating this data only after feature extraction.
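
The two integration strategies above differ mainly in where the clinical vector meets the image. The following is a minimal PyTorch sketch of the tensor shapes involved, not FuseMedML's implementation: the function names, batch size, and feature-map size are illustrative assumptions, and the late-fusion variant skips any extra layers a real classification head may apply to the tabular data before concatenation.

```python
import torch


def fuse_after_feature_extraction(feature_map: torch.Tensor, clinical: torch.Tensor) -> torch.Tensor:
    """Late fusion: pool the backbone output, then append the clinical vector."""
    pooled = feature_map.mean(dim=(2, 3))  # global average pooling: [B, C, H, W] -> [B, C]
    return torch.cat((pooled, clinical), dim=1)  # [B, C + F]


def fuse_as_input_channels(image: torch.Tensor, clinical: torch.Tensor) -> torch.Tensor:
    """Early fusion: broadcast each clinical value to a constant H x W plane."""
    _, _, height, width = image.shape
    planes = clinical[:, :, None, None].expand(-1, -1, height, width)  # [B, F, H, W]
    return torch.cat((image, planes), dim=1)  # [B, 3 + F, H, W]


# Hypothetical sizes: 2 samples, 13 clinical features, 384 backbone channels.
clinical = torch.rand(2, 13)
late = fuse_after_feature_extraction(torch.rand(2, 384, 8, 8), clinical)
early = fuse_as_input_channels(torch.rand(2, 3, 64, 64), clinical)
print(late.shape, early.shape)
```

In the late-fusion case the classifier sees the clinical values next to the pooled image features; in the early-fusion case every convolution in the backbone can read them at every pixel.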




-------------
## Task - ISIC 2019 challenge to classify dermoscopic images and clinical data among nine different diagnostic categories.

This task was chosen for demonstration since the data is simple and public, which will make the session more effective.

We also explored the effectiveness of each method in two other tasks:
* Article: [Context in medical imaging: the case of focal liver lesion classification](https://www.spiedigitallibrary.org/conference-proceedings-of-spie/12032/120320O/Context-in-medical-imaging--the-case-of-focal-liver/10.1117/12.2609385.short?SSO=1)
* FuseMedML example on [Duke dataset](https://sites.duke.edu/mazurowski/resources/breast-cancer-mri-dataset/).

Skin cancer is the most common cancer globally, with melanoma being the most deadly form. 

Dermoscopy is a skin imaging modality that has been shown to improve the diagnosis of skin cancer compared to unaided visual inspection.

The goal for ISIC 2019 is to classify dermoscopic images among nine different diagnostic categories:

* Melanoma
* Melanocytic nevus
* Basal cell carcinoma
* Actinic keratosis
* Benign keratosis (solar lentigo / seborrheic keratosis / lichen planus-like keratosis)
* Dermatofibroma
* Vascular lesion
* Squamous cell carcinoma
* None of the others

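25,331 images are available for training across 8 of these categories; "None of the others" does not appear in the training set, which is why the models below use `num_classes=8`.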

Two tasks were available for participation:
* classify dermoscopic images without meta-data
* classify images with additional available meta-data, including age, gender and anatomic site

[1] Tschandl P., Rosendahl C. & Kittler H. The HAM10000 dataset, a large collection of multi-source dermatoscopic images of common pigmented skin lesions. Sci. Data 5, 180161, doi:10.1038/sdata.2018.161 (2018)

[2] Noel C. F. Codella, David Gutman, M. Emre Celebi, Brian Helba, Michael A. Marchetti, Stephen W. Dusza, Aadi Kalloo, Konstantinos Liopyris, Nabin Mishra, Harald Kittler, Allan Halpern: “Skin Lesion Analysis Toward Melanoma Detection: A Challenge at the 2017 International Symposium on Biomedical Imaging (ISBI), Hosted by the International Skin Imaging Collaboration (ISIC)”, 2017; arXiv:1710.05006.

[3] Marc Combalia, Noel C. F. Codella, Veronica Rotemberg, Brian Helba, Veronica Vilaplana, Ofer Reiter, Allan C. Halpern, Susana Puig, Josep Malvehy: “BCN20000: Dermoscopic Lesions in the Wild”, 2019; arXiv:1908.02288.

------------
## Installation Details - Google Colab (skip if you run it using an already installed FuseMedML)

### **Enable GPU Support**

To use GPU through Google Colab, change the runtime mode to GPU:

From the "Runtime" menu, select "Change Runtime Type", choose "GPU" from the drop-down menu, and click "SAVE".
When asked, reboot the system.
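
After switching the runtime type, it can help to confirm that a GPU is actually visible before installing anything. The snippet below is a minimal check, assuming PyTorch is already importable in the runtime; `describe_device` is a hypothetical helper name used only for this illustration.

```python
import torch


def describe_device() -> str:
    """Return a short description of the device PyTorch will use."""
    if torch.cuda.is_available():
        return f"GPU: {torch.cuda.get_device_name(0)}"
    return "CPU only (no GPU visible to PyTorch)"


print(describe_device())
```

If this reports CPU only, the training cells further down, which request `accelerator="gpu"`, are unlikely to run as written.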

### **Install FuseMedML**

In [None]:
!git clone https://github.com/IBM/fuse-med-ml.git
%cd fuse-med-ml
!pip install -e .[all,examples]

**Please reboot the session when asked!**

----------------
### Select GPU, start a logger and define dataset size

In [None]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from fuse.utils.utils_logger import fuse_logger_start
import logging

fuse_logger_start(output_path=None, console_verbose_level=logging.INFO)
all_data = False  # use all data or just 400 samples
model_dir = "model_dir"  # path to model dir
cache_dir = "cache_dir"  # path to cache dir
data_dir = "data_dir"
reset_cache = True
reset_split_file = True

----------------
### Data

(Don't forget to first follow the installation instructions listed above)

In [None]:
print("Downloading data: this might take a few minutes")
from fuseimg.datasets.isic import ISIC

if not all_data:
    from fuse_examples.imaging.classification.isic.golden_members import FULL_GOLDEN_MEMBERS as sample_ids
else:
    sample_ids = None
ISIC.download(data_path=data_dir, sample_ids_to_download=sample_ids)

In [None]:
from fuse_examples.multimodality.image_clinical.dataset import isic_2019_dataloaders

# build the dataloaders (and datasets) using the FuseMedML data package and the processing implementation located in fuseimg.datasets.isic
train_dl, validation_dl = isic_2019_dataloaders(
    data_path=data_dir, cache_path=cache_dir, reset_cache=reset_cache, reset_split_file=reset_split_file
)

The original code can be found [here](https://github.com/IBM/fuse-med-ml/blob/master/fuse_examples/multimodality/image_clinical/dataset.py).

<br/>

Details and instructions about our data package can be found [here](https://github.com/IBM/fuse-med-ml/blob/master/fuse/data/README.md)

In [None]:
sample_index = 10
print(train_dl.dataset[sample_index])

### Imaging Only Implementation

In [None]:
from fuse.dl.models.model_multihead import ModelMultiHead
from fuse.dl.models.heads.head_global_pooling_classifier import HeadGlobalPoolingClassifier
from fuse.dl.models.backbones.backbone_inception_resnet_v2 import BackboneInceptionResnetV2

model = ModelMultiHead(
    conv_inputs=(("data.input.img", 3),),
    backbone=BackboneInceptionResnetV2(input_channels_num=3, pretrained_weights_url=None),
    heads=[
        HeadGlobalPoolingClassifier(
            head_name="classification",
            dropout_rate=0.5,
            conv_inputs=[("model.backbone_features", 384)],
            layers_description=(256,),
            num_classes=8,
            pooling="avg",
        ),
    ],
)

In [None]:
import copy
from collections import OrderedDict
import torch.nn.functional as F
from fuse.dl.losses.loss_default import LossDefault
from fuse.eval.metrics.classification.metrics_classification_common import MetricAUCROC, MetricAccuracy, MetricConfusion
from fuse.eval.metrics.classification.metrics_thresholding_common import MetricApplyThresholds

# ====================================================================================
#  Loss
# ====================================================================================
losses = {
    "cls_loss": LossDefault(
        pred="model.logits.classification", target="data.label", callable=F.cross_entropy, weight=1.0
    )
}

# ====================================================================================
# Metrics
# ====================================================================================
class_names = ["MEL", "NV", "BCC", "AK", "BKL", "DF", "VASC", "SCC"]

train_metrics = OrderedDict(
    [
        ("op", MetricApplyThresholds(pred="model.output.classification")),  # will apply argmax,
        (
            "balanced_acc",
            MetricConfusion(
                pred="results:metrics.op.cls_pred",
                target="data.label",
                metrics=("sensitivity",),
                class_names=class_names,
            ),
        ),
    ]
)
validation_metrics = copy.deepcopy(train_metrics)

best_epoch_source = {
    "monitor": "validation.metrics.balanced_acc.sensitivity.macro_avg",
    "mode": "max",
}
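
The `balanced_acc` entry above is monitored through its macro-averaged sensitivity, that is, the mean of the per-class recalls. The sketch below shows that computation in plain Python on toy labels. It is an illustration rather than FuseMedML's implementation; in particular, skipping classes that have no ground-truth samples is an assumption made here.

```python
from collections import Counter
from typing import Sequence


def macro_sensitivity(targets: Sequence[int], predictions: Sequence[int], num_classes: int) -> float:
    """Mean per-class recall; classes with no ground-truth samples are skipped."""
    support = Counter(targets)
    true_positives = Counter(t for t, p in zip(targets, predictions) if t == p)
    recalls = [true_positives[c] / support[c] for c in range(num_classes) if support[c] > 0]
    return sum(recalls) / len(recalls)


# Hypothetical labels for a 3-class toy problem.
targets = [0, 0, 0, 0, 1, 1, 2]
predictions = [0, 0, 0, 1, 1, 0, 2]
print(macro_sensitivity(targets, predictions, num_classes=3))
```

Here the per-class recalls are 3/4, 1/2 and 1/1, so the macro average is 0.75, while plain accuracy is 5/7. Because every class counts equally, this metric is less dominated by frequent classes, which matters for an imbalanced dataset.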

In [None]:
import torch.optim as optim
import pytorch_lightning as pl
from fuse.dl.lightning.pl_module import LightningModuleDefault

# create optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=0.001)

# create scheduler
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
lr_sch_config = dict(scheduler=lr_scheduler, monitor="validation.losses.total_loss")
# optimizer and lr scheduler - see pl.LightningModule.configure_optimizers return value for all options
optimizers_and_lr_schs = dict(optimizer=optimizer, lr_scheduler=lr_sch_config)

# create instance of PL module - FuseMedML generic version
pl_module = LightningModuleDefault(
    model_dir=model_dir,
    model=model,
    losses=losses,
    train_metrics=train_metrics,
    validation_metrics=validation_metrics,
    best_epoch_source=best_epoch_source,
    optimizers_and_lr_schs=optimizers_and_lr_schs,
)

# create lightning trainer.
pl_trainer = pl.Trainer(default_root_dir=model_dir, max_epochs=2, accelerator="gpu", devices=1)

# train
pl_trainer.fit(pl_module, train_dl, validation_dl)

In [None]:
# Load the TensorBoard notebook extension - model_dir is hard-coded (change it if necessary)
%load_ext tensorboard
%tensorboard --logdir model_dir

### Imaging and Tabular data - concatenate tabular data after image feature extraction ###
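
The encoding op defined in this section packs age, anatomical site, and sex into a single vector: age is scaled to roughly [0, 1] (or set to -1 when missing or out of range), and each categorical field becomes a one-hot block. The plain-Python sketch below mirrors that layout on a toy example. The index maps `SITE_INDEX` and `SEX_INDEX_TOY` are hypothetical stand-ins for `ANATOM_SITE_INDEX` and `SEX_INDEX`, so the vector here is shorter than the real one.

```python
from typing import Dict, List, Optional

# Hypothetical index maps; the real ones are ANATOM_SITE_INDEX and SEX_INDEX in dataset.py.
SITE_INDEX: Dict[str, int] = {"head/neck": 0, "torso": 1, "extremity": 2}
SEX_INDEX_TOY: Dict[str, int] = {"male": 0, "female": 1}


def one_hot(value: Optional[str], index: Dict[str, int]) -> List[float]:
    """One-hot encode value; unknown or missing values give an all-zero vector."""
    encoded = [0.0] * len(index)
    if value in index:
        encoded[index[value]] = 1.0
    return encoded


def encode_clinical(age: Optional[float], site: Optional[str], sex: Optional[str]) -> List[float]:
    """Concatenate [scaled age | site one-hot | sex one-hot]."""
    age_feature = age / 120.0 if age is not None and 0 < age < 120 else -1.0
    return [age_feature] + one_hot(site, SITE_INDEX) + one_hot(sex, SEX_INDEX_TOY)


print(encode_clinical(60, "torso", "female"))
print(encode_clinical(None, "unknown", None))
```

The length of the resulting vector is 1 plus the sizes of the two index maps; that total is the number that has to match the size given for the tabular input of the classification head.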

In [None]:
from fuse.utils.ndict import NDict
import torch
from fuse.data import OpBase
from fuse_examples.multimodality.image_clinical.dataset import ANATOM_SITE_INDEX, SEX_INDEX

### Generate Data
# Create an operation to add to the end of the current processing pipeline
class OpClinicalEncodeing(OpBase):
    """
    Collect clinical data into a single vector and store it in sample_dict["data.input.clinical.all"]
    """

    def __call__(self, sample_dict: NDict) -> NDict:

        age = sample_dict["data.input.clinical.age_approx"]
        if age > 0 and age < 120:
            age = torch.tensor(age / 120.0).reshape(-1)
        else:
            age = torch.tensor(-1.0).reshape(-1)

        anatom_site = sample_dict["data.input.clinical.anatom_site_general"]
        anatom_site_one_hot = torch.zeros(len(ANATOM_SITE_INDEX))
        if anatom_site in ANATOM_SITE_INDEX:
            anatom_site_one_hot[ANATOM_SITE_INDEX[anatom_site]] = 1

        sex = sample_dict["data.input.clinical.sex"]
        sex_one_hot = torch.zeros(len(SEX_INDEX))
        if sex in SEX_INDEX:
            sex_one_hot[SEX_INDEX[sex]] = 1

        clinical_encoding = torch.cat((age, anatom_site_one_hot, sex_one_hot), dim=0)
        sample_dict["data.input.clinical.all"] = clinical_encoding

        return sample_dict


# add single step to the pipeline. format (<operation>, kwargs for <operation>.__call__() method)
append_dyn_pipeline = [(OpClinicalEncodeing(), dict())]
train_dl, validation_dl = isic_2019_dataloaders(
    data_path=data_dir, cache_path=cache_dir, append_dyn_pipeline=append_dyn_pipeline, sample_ids=sample_ids
)

### Define model - add the new clinical data vector as an additional argument to the classification head
model = ModelMultiHead(
    conv_inputs=(("data.input.img", 3),),
    backbone=BackboneInceptionResnetV2(input_channels_num=3, pretrained_weights_url=None),
    heads=[
        HeadGlobalPoolingClassifier(
            head_name="classification",
            dropout_rate=0.5,
            conv_inputs=[("model.backbone_features", 384)],
            tabular_data_inputs=[("data.input.clinical.all", 13)],
            layers_description=(256,),
            tabular_layers_description=(128,),
            num_classes=8,
            pooling="avg",
        ),
    ],
)


# create optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=0.001)

# create scheduler
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
lr_sch_config = dict(scheduler=lr_scheduler, monitor="validation.losses.total_loss")
# optimizer and lr scheduler - see pl.LightningModule.configure_optimizers return value for all options
optimizers_and_lr_schs = dict(optimizer=optimizer, lr_scheduler=lr_sch_config)

### Start a training process
# create instance of PL module - FuseMedML generic version
pl_module = LightningModuleDefault(
    model_dir=model_dir,
    model=model,
    losses=losses,
    train_metrics=train_metrics,
    validation_metrics=validation_metrics,
    best_epoch_source=best_epoch_source,
    optimizers_and_lr_schs=optimizers_and_lr_schs,
)

# create lightning trainer.
pl_trainer = pl.Trainer(default_root_dir=model_dir, max_epochs=2, accelerator="gpu", devices=1)


# train
pl_trainer.fit(pl_module, train_dl, validation_dl)

### Imaging and Tabular data - concatenate directly with the image

In [None]:
### Generate Data
# Create an operation to add to the end of the current processing pipeline
class OpClinicalPadToImage(OpBase):
    """Append the clinical data to the image as additional constant-valued channels"""

    def __call__(self, sample_dict: NDict) -> NDict:
        clinical_encoding = sample_dict["data.input.clinical.all"]
        image = sample_dict["data.input.img"]

        clinical_data_spatial = clinical_encoding.reshape((clinical_encoding.shape + (1, 1))).repeat(
            (1,) + image.shape[1:]
        )  # repeat and reshape to [num_features, H, W]
        image = torch.cat((image, clinical_data_spatial), dim=0)  # concat to get [num_features + 3, H, W]

        sample_dict["data.input.img"] = image

        return sample_dict


# add two steps to the pipeline. format (<operation>, kwargs for <operation>.__call__() method)
append_dyn_pipeline = [(OpClinicalEncodeing(), dict()), (OpClinicalPadToImage(), dict())]
train_dl, validation_dl = isic_2019_dataloaders(
    data_path=data_dir, cache_path=cache_dir, append_dyn_pipeline=append_dyn_pipeline, sample_ids=sample_ids
)

### Define model - this time use an image with 16 channels (3 image channels + 13 clinical features embedded into the image)
model = ModelMultiHead(
    conv_inputs=(("data.input.img", 16),),
    backbone=BackboneInceptionResnetV2(input_channels_num=16, pretrained_weights_url=None),
    heads=[
        HeadGlobalPoolingClassifier(
            head_name="classification",
            dropout_rate=0.5,
            conv_inputs=[("model.backbone_features", 384)],
            layers_description=(256,),
            num_classes=8,
            pooling="avg",
        ),
    ],
)


# create optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=0.001)

# create scheduler
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
lr_sch_config = dict(scheduler=lr_scheduler, monitor="validation.losses.total_loss")
# optimizer and lr scheduler - see pl.LightningModule.configure_optimizers return value for all options
optimizers_and_lr_schs = dict(optimizer=optimizer, lr_scheduler=lr_sch_config)


# create instance of PL module - FuseMedML generic version
pl_module = LightningModuleDefault(
    model_dir=model_dir,
    model=model,
    losses=losses,
    train_metrics=train_metrics,
    validation_metrics=validation_metrics,
    best_epoch_source=best_epoch_source,
    optimizers_and_lr_schs=optimizers_and_lr_schs,
)


# create lightning trainer.
pl_trainer = pl.Trainer(default_root_dir=model_dir, max_epochs=2, accelerator="gpu", devices=1)
# train
pl_trainer.fit(pl_module, train_dl, validation_dl)