------------
## Installation Details - Google Colab

### **Enable GPU Support**

To use GPU through Google Colab, change the runtime mode to GPU:

From the "Runtime" menu, select "Change Runtime Type", choose "GPU" from the drop-down menu, and click "SAVE".
When asked, reboot the system.

### **Install FuseMedML**

In [41]:
# Colab only: uncomment the three lines below to clone FuseMedML and install it in editable mode.
# !git clone https://github.com/IBM/fuse-med-ml.git
# %cd fuse-med-ml
# !pip install -e .

### **Setup imports**

In [43]:
import logging
import os

import torch
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.models as models
from torch.utils.data.dataloader import DataLoader
from torchvision import transforms

from fuse.analyzer.analyzer_default import FuseAnalyzerDefault
from fuse.data.dataset.dataset_wrapper import FuseDatasetWrapper
from fuse.data.sampler.sampler_balanced_batch import FuseSamplerBalancedBatch
from fuse.losses.loss_default import FuseLossDefault
from fuse.managers.callbacks.callback_metric_statistics import FuseMetricStatisticsCallback
from fuse.managers.callbacks.callback_tensorboard import FuseTensorboardCallback
from fuse.managers.callbacks.callback_time_statistics import FuseTimeStatisticsCallback
from fuse.managers.manager_default import FuseManagerDefault
from fuse.metrics.classification.metric_accuracy import FuseMetricAccuracy
from fuse.metrics.classification.metric_auc import FuseMetricAUC
from fuse.metrics.classification.metric_roc_curve import FuseMetricROCCurve
from fuse.models.model_wrapper import FuseModelWrapper
from fuse.utils.utils_debug import FuseUtilsDebug
from fuse.utils.utils_gpu import FuseUtilsGPU
from fuse.utils.utils_logger import fuse_logger_start

## **Setup environment**


##### **Debugger**


In [44]:
# Debug/run mode handed to FuseUtilsDebug below.
mode = 'default'  # Options: 'default', 'fast', 'debug', 'verbose', 'user'. See details in FuseUtilsDebug
# NOTE(review): `debug` is not referenced again in the visible cells; presumably
# constructing it is what applies the selected mode - confirm against FuseUtilsDebug.
debug = FuseUtilsDebug(mode)



##### **Output paths**
TODO: elaborate

In [45]:
# Base directory that every path below is joined onto.
ROOT = 'fuse_examples' # TODO: fill path here

# Locations used by the train / infer / analyze stages of this example.
PATHS = dict(
    model_dir=os.path.join(ROOT, 'mnist/model_dir'),
    # If True will reset model dir automatically - otherwise will prompt 'are you sure' message.
    force_reset_model_dir=True,
    cache_dir=os.path.join(ROOT, 'mnist/cache_dir'),
    inference_dir=os.path.join(ROOT, 'mnist/infer_dir'),
    analyze_dir=os.path.join(ROOT, 'mnist/analyze_dir'),
)

# Lower-case alias used by the cells below.
paths = PATHS

##### **Define Train Common Params**
TODO: elaborate: what are these params used for? What are the common options? How do they serve the user?

In [46]:

# Training configuration, keyed by '<group>.<name>'.
TRAIN_COMMON_PARAMS = {
    # ============
    # Data
    # ============
    'data.batch_size': 30,
    'data.train_num_workers': 8,
    'data.validation_num_workers': 8,

    # ===============
    # Manager - Train
    # ===============
    'manager.train_params': {
        'device': 'cuda',
        'num_epochs': 5,
        'virtual_batch_size': 1,  # number of batches in one virtual batch
        'start_saving_epochs': 10,  # first epoch to start saving checkpoints from
        'gap_between_saving_epochs': 5,  # number of epochs between saved checkpoint
    },
    'manager.best_epoch_source': {
        'source': 'metrics.accuracy',  # can be any key from 'epoch_results'
        'optimization': 'max',  # can be either min/max
        # can be either better/worse - whether to consider best epoch when values are equal
        'on_equal_values': 'better',
    },
    'manager.learning_rate': 1e-4,
    'manager.weight_decay': 0.001,
    # if not None, will try to load the checkpoint
    'manager.resume_checkpoint_filename': None,
}

# Lower-case alias used by the training cell.
train_params = TRAIN_COMMON_PARAMS

##### **Helper function**
TODO: elaborate? delete?

In [47]:
def perform_softmax(output):
    """Turn raw model output into a ``(logits, class probabilities)`` pair.

    :param output: either a ``torch.Tensor`` of logits (noted as the validation
        case), or an object exposing the logits through a ``logits`` attribute
        (noted as the train case).
    :return: tuple ``(logits, cls_preds)`` where ``cls_preds`` is the softmax of
        ``logits`` taken over dimension 1.
    """
    # A bare tensor is used as-is; anything else is unwrapped via `.logits`.
    logits = output if isinstance(output, torch.Tensor) else output.logits
    return logits, F.softmax(logits, dim=1)

### **Set logger**

In [48]:
# Start the Fuse logger with the model directory as its output path.
fuse_logger_start(output_path=paths['model_dir'], console_verbose_level=logging.INFO)
lgr = logging.getLogger('Fuse')
lgr.info('Fuse Train', {'attrs': ['bold', 'underline']})
# Report the directories this run reads from / writes to.
for dir_key in ('model_dir', 'cache_dir'):
    lgr.info(f'{dir_key}={paths[dir_key]}', {'color': 'magenta'})

[4m[1mFuse Train[0m
[35mmodel_dir=fuse_examples/mnist/model_dir[0m
[35mcache_dir=fuse_examples/mnist/cache_dir[0m


## **Training the model**

##### **Define the logger**
TODO: elaborate, what is the use of it? common options?

In [49]:
# NOTE(review): this repeats the logger setup already done in the "Set logger" cell;
# one of the two cells is likely redundant.
fuse_logger_start(output_path=paths['model_dir'], console_verbose_level=logging.INFO)
lgr = logging.getLogger('Fuse')
lgr.info('Fuse Train', {'attrs': ['bold', 'underline']})

# Report the directories this run reads from / writes to.
for dir_key in ('model_dir', 'cache_dir'):
    lgr.info(f'{dir_key}={paths[dir_key]}', {'color': 'magenta'})

[4m[1mFuse Train[0m
[35mmodel_dir=fuse_examples/mnist/model_dir[0m
[35mcache_dir=fuse_examples/mnist/cache_dir[0m


##### **Data**

In [50]:
# ==============================================================================
# Data
# ==============================================================================
# ---- Train data ----
lgr.info('Train Data:', {'attrs': 'bold'})

# Pre-processing: convert to tensor, then normalize with mean 0.1307 / std 0.3081.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# MNIST train split, downloaded into the cache dir when missing.
torch_train_dataset = torchvision.datasets.MNIST(
    paths['cache_dir'],
    download=True,
    train=True,
    transform=transform,
)

# Wrap the torch dataset so its (image, label) tuples are exposed under Fuse keys.
# FIXME: support also using torch dataset directly
train_dataset = FuseDatasetWrapper(
    name='train',
    dataset=torch_train_dataset,
    mapping=('image', 'label'),
)
train_dataset.create()

# Batch sampler configured to balance on 'data.label' across the 10 classes.
lgr.info('- Create sampler:')
sampler = FuseSamplerBalancedBatch(
    dataset=train_dataset,
    balanced_class_name='data.label',
    num_balanced_classes=10,
    batch_size=train_params['data.batch_size'],
    balanced_class_weights=None,
)
lgr.info('- Create sampler: Done')

# The sampler yields whole batches, hence `batch_sampler` rather than `batch_size`.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_sampler=sampler,
    num_workers=train_params['data.train_num_workers'],
)
lgr.info('Train Data: Done', {'attrs': 'bold'})

# ---- Validation data ----
lgr.info('Validation Data:', {'attrs': 'bold'})

# MNIST test split serves as the validation set.
torch_validation_dataset = torchvision.datasets.MNIST(
    paths['cache_dir'],
    download=True,
    train=False,
    transform=transform,
)
validation_dataset = FuseDatasetWrapper(
    name='validation',
    dataset=torch_validation_dataset,
    mapping=('image', 'label'),
)
validation_dataset.create()

# Plain fixed-size batches for validation.
validation_dataloader = DataLoader(
    dataset=validation_dataset,
    batch_size=train_params['data.batch_size'],
    num_workers=train_params['data.validation_num_workers'],
)
lgr.info('Validation Data: Done', {'attrs': 'bold'})

# ==============================================================================
# Model
# ==============================================================================
lgr.info('Model:', {'attrs': 'bold'})

# ResNet-18 with a 10-way output.
torch_model = models.resnet18(num_classes=10)
# modify conv1 to support single channel image
torch_model.conv1 = torch.nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
# use adaptive avg pooling to support mnist low resolution images
torch_model.avgpool = torch.nn.AdaptiveAvgPool2d(output_size=1)

# Wrap the torch model: it reads 'data.image', and the two values returned by
# perform_softmax are stored as 'logits.classification' / 'output.classification'.
model = FuseModelWrapper(
    model=torch_model,
    model_inputs=['data.image'],
    post_forward_processing_function=perform_softmax,
    model_outputs=['logits.classification', 'output.classification'],
)

lgr.info('Model: Done', {'attrs': 'bold'})

# ====================================================================================
#  Loss
# ====================================================================================
# Single classification loss: cross-entropy between the model logits and the label.
losses = dict(
    cls_loss=FuseLossDefault(
        pred_name='model.logits.classification',
        target_name='data.label',
        callable=F.cross_entropy,
        weight=1.0,
    ),
)

# ====================================================================================
# Metrics
# ====================================================================================
# Accuracy over the softmax output; also referenced by 'manager.best_epoch_source'.
metrics = dict(
    accuracy=FuseMetricAccuracy(
        pred_name='model.output.classification',
        target_name='data.label',
    ),
)

# =====================================================================================
#  Callbacks
# =====================================================================================
# Callbacks handed to the manager for the training run.
callbacks = [
    # default callbacks
    FuseTensorboardCallback(model_dir=paths['model_dir']),  # save statistics for tensorboard
    # Build the CSV path with os.path.join, like every other path in this notebook,
    # instead of concatenating a hard-coded "/" separator.
    FuseMetricStatisticsCallback(output_path=os.path.join(paths['model_dir'], 'metrics.csv')),  # save statistics a csv file
    FuseTimeStatisticsCallback(num_epochs=train_params['manager.train_params']['num_epochs'], load_expected_part=0.1)  # time profiler
]

# =====================================================================================
#  Manager - Train
# =====================================================================================
# Builds the optimizer, LR scheduler and training manager, optionally resumes from a
# checkpoint, and then runs the training loop. Statement order matters here: the
# manager must receive its objects before a checkpoint is loaded or training starts.
lgr.info('Train:', {'attrs': 'bold'})

# create optimizer
optimizer = optim.Adam(model.parameters(), lr=train_params['manager.learning_rate'], weight_decay=train_params['manager.weight_decay'])

# create learning scheduler
# ReduceLROnPlateau with default settings. The captured training log reports that,
# with no 'lr_sch_target' configured, the target defaults to train.losses.total_loss.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# train from scratch
# force_reset comes from paths['force_reset_model_dir'] (True: reset the model dir without prompting).
manager = FuseManagerDefault(output_model_dir=paths['model_dir'], force_reset=paths['force_reset_model_dir'])
# Providing the objects required for the training process.
manager.set_objects(net=model,
                    optimizer=optimizer,
                    losses=losses,
                    metrics=metrics,
                    best_epoch_source=train_params['manager.best_epoch_source'],
                    lr_scheduler=scheduler,
                    callbacks=callbacks,
                    train_params=train_params['manager.train_params'])

## Continue training
# Skipped with the default configuration, where 'manager.resume_checkpoint_filename' is None.
if train_params['manager.resume_checkpoint_filename'] is not None:
    # Loading the checkpoint including model weights, learning rate, and epoch_index.
    manager.load_checkpoint(checkpoint=train_params['manager.resume_checkpoint_filename'], mode='train')

# Start training
manager.train(train_dataloader=train_dataloader, validation_dataloader=validation_dataloader)

lgr.info('Train: Done', {'attrs': 'bold'})


[1mTrain Data:[0m
- Create sampler:
- Create sampler: Done
[1mTrain Data: Done[0m
[1mValidation Data:[0m
[1mValidation Data: Done[0m
[1mModel:[0m
[1mModel: Done[0m
[1mTrain:[0m
[33mKey lr_sch_target not found in config parameter, setting value to default (train.losses.total_loss)[0m
[1m[31mTotal number of parameters in model:11,175,370, trainable parameters:11,175,370[0m
Train Dataset Summary:
Class = <class 'fuse.data.dataset.dataset_wrapper.FuseDatasetWrapper'>
Processors:
------------------------
<fuse.data.dataset.dataset_wrapper.DatasetProcessor object at 0x7ff06c0c8d10>
Cache destination:
------------------
None
Augmentor:
----------
None
Data source:
------------
FuseDataSourceFromList - 60000 samples

Sample keys:
------------
['data.descriptor', 'data.image', 'data.label']
Basic Data Statistic:
-------------------

Validation Dataset Summary:
Class = <class 'fuse.data.dataset.dataset_wrapper.FuseDatasetWrapper'>
Processors:
------------------------
<fuse.dat

100%|██████████| 334/334 [00:14<00:00, 23.52it/s]


Stats for Pre-Training:
losses.cls_loss      = 2.464972653788721
losses.total_loss    = 2.464972653788721
metrics.accuracy     = 0.0951
Start training on epoch 1


100%|██████████| 2248/2248 [02:48<00:00, 13.35it/s]


Start validation on epoch 1


100%|██████████| 334/334 [00:13<00:00, 23.94it/s]


[1m[32mThis is the best epoch ever (metrics.accuracy = 0.8209)[0m
[1m[4mStats for epoch: 1 (Currently the best epoch for source metrics.accuracy!)[0m

------------------------------------------------------------------------------------------------------------
|                          | Best Epoch Value         | Current Epoch Validation | Current Epoch Train      |
------------------------------------------------------------------------------------------------------------
| losses.cls_loss          | 0.6196                   | 0.6196                   | 1.0614                   |
------------------------------------------------------------------------------------------------------------
| losses.total_loss        | 0.6196                   | 0.6196                   | 1.0614                   |
------------------------------------------------------------------------------------------------------------
| metrics.accuracy         | 0.8209                   | 0.8209               

100%|██████████| 2248/2248 [02:48<00:00, 13.33it/s]


Start validation on epoch 2


100%|██████████| 334/334 [00:14<00:00, 23.36it/s]


[1m[32mThis is the best epoch ever (metrics.accuracy = 0.845)[0m
[1m[4mStats for epoch: 2 (Currently the best epoch for source metrics.accuracy!)[0m

------------------------------------------------------------------------------------------------------------
|                          | Best Epoch Value         | Current Epoch Validation | Current Epoch Train      |
------------------------------------------------------------------------------------------------------------
| losses.cls_loss          | 0.5730                   | 0.5730                   | 0.8927                   |
------------------------------------------------------------------------------------------------------------
| losses.total_loss        | 0.5730                   | 0.5730                   | 0.8927                   |
------------------------------------------------------------------------------------------------------------
| metrics.accuracy         | 0.8450                   | 0.8450                

100%|██████████| 2248/2248 [02:49<00:00, 13.26it/s]


Start validation on epoch 3


100%|██████████| 334/334 [00:14<00:00, 23.26it/s]


[1m[32mThis is the best epoch ever (metrics.accuracy = 0.8553)[0m
[1m[4mStats for epoch: 3 (Currently the best epoch for source metrics.accuracy!)[0m

------------------------------------------------------------------------------------------------------------
|                          | Best Epoch Value         | Current Epoch Validation | Current Epoch Train      |
------------------------------------------------------------------------------------------------------------
| losses.cls_loss          | 0.5509                   | 0.5509                   | 0.9024                   |
------------------------------------------------------------------------------------------------------------
| losses.total_loss        | 0.5509                   | 0.5509                   | 0.9024                   |
------------------------------------------------------------------------------------------------------------
| metrics.accuracy         | 0.8553                   | 0.8553               

100%|██████████| 2248/2248 [02:50<00:00, 13.16it/s]


Start validation on epoch 4


100%|██████████| 334/334 [00:14<00:00, 22.95it/s]


Stats for epoch: 4 (Best epoch is 3 for source metrics.accuracy)

------------------------------------------------------------------------------------------------------------
|                          | Best Epoch Value         | Current Epoch Validation | Current Epoch Train      |
------------------------------------------------------------------------------------------------------------
| losses.cls_loss          | 0.5509                   | 0.5529                   | 0.8478                   |
------------------------------------------------------------------------------------------------------------
| losses.total_loss        | 0.5509                   | 0.5529                   | 0.8478                   |
------------------------------------------------------------------------------------------------------------
| metrics.accuracy         | 0.8553                   | 0.8398                   | 0.7307                   |
----------------------------------------------------------

## **Infer**

##### **Define Infer Common Params**


In [51]:
# Inference configuration.
INFER_COMMON_PARAMS = {
    # File name (inside the inference dir) the inference results are written to.
    'infer_filename': 'validation_set_infer.gz',
    # Fuse TIP: possible values are 'best', 'last' or epoch_index.
    'checkpoint': 'best',
}

# Lower-case alias used by the inference cells.
infer_common_params = INFER_COMMON_PARAMS

##### **Define logger**

In [52]:
# Restart the Fuse logger with the inference directory as its output path.
fuse_logger_start(output_path=paths['inference_dir'], console_verbose_level=logging.INFO)
lgr = logging.getLogger('Fuse')
lgr.info('Fuse Inference', {'attrs': ['bold', 'underline']})
# Full path of the file the inference results will be written to.
infer_file = os.path.join(paths["inference_dir"], infer_common_params["infer_filename"])
lgr.info(f'infer_filename={infer_file}', {'color': 'magenta'})

[4m[1mFuse Inference[0m
[35minfer_filename=fuse_examples/mnist/infer_dir/validation_set_infer.gz[0m


In [53]:
## Data
# Same pre-processing as the training cell: tensor conversion + normalization.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# MNIST test split, wrapped so samples are exposed under Fuse keys.
torch_validation_dataset = torchvision.datasets.MNIST(
    paths['cache_dir'],
    download=True,
    train=False,
    transform=transform,
)
validation_dataset = FuseDatasetWrapper(
    name='validation',
    dataset=torch_validation_dataset,
    mapping=('image', 'label'),
)
validation_dataset.create()

# Inference dataloader; batches are assembled with the dataset's own collate_fn.
validation_dataloader = DataLoader(
    dataset=validation_dataset,
    collate_fn=validation_dataset.collate_fn,
    batch_size=2,
    num_workers=2,
)

## Manager for inference
manager = FuseManagerDefault()
# Columns to keep in the saved inference results.
output_columns = ['model.output.classification', 'data.label']
infer_output_file = os.path.join(paths["inference_dir"], infer_common_params["infer_filename"])
manager.infer(
    data_loader=validation_dataloader,
    input_model_dir=paths['model_dir'],
    checkpoint=infer_common_params['checkpoint'],
    output_columns=output_columns,
    output_file_name=infer_output_file,
)


Loading model from: fuse_examples/mnist/model_dir
[33mLoading checkpoint file: fuse_examples/mnist/model_dir/checkpoint_best_0_epoch.pth. values_to_resume all[0m
[33mKey device not found in config parameter, setting value to default (cuda)[0m
[33mKey virtual_batch_size not found in config parameter, setting value to default (1)[0m


100%|██████████| 5000/5000 [01:01<00:00, 81.63it/s]


Save inference results into fuse_examples/mnist/infer_dir/validation_set_infer.gz


Unnamed: 0,descriptor,model.output.classification,data.label
0,"(validation, 0)","[0.0016879884, 0.0015029196, 0.0051113814, 0.0...",7
1,"(validation, 1)","[0.048130117, 0.024796639, 0.58028364, 0.14518...",2
2,"(validation, 2)","[0.0019267538, 0.981097, 0.00044614665, 0.0015...",1
3,"(validation, 3)","[0.99099356, 0.00012841726, 0.00017717913, 1.5...",0
4,"(validation, 4)","[0.052714493, 0.026158107, 0.021874268, 0.0045...",4
...,...,...,...
9995,"(validation, 9995)","[0.00034772445, 0.0002892727, 0.9494368, 0.010...",2
9996,"(validation, 9996)","[0.054082025, 0.012338727, 0.115864225, 0.6180...",3
9997,"(validation, 9997)","[0.00045667685, 0.010034586, 0.0019736052, 0.0...",4
9998,"(validation, 9998)","[0.010676819, 0.033213064, 0.0027186736, 0.005...",5


## **Analyze**

##### **Define Analyze Common Params**


In [54]:
# Analysis configuration.
ANALYZE_COMMON_PARAMS = {
    # Reuse the inference file name so the analyzer reads what inference wrote.
    'infer_filename': INFER_COMMON_PARAMS['infer_filename'],
    # Where the analyzer stores the computed metrics.
    'output_filename': os.path.join(PATHS['analyze_dir'], 'all_metrics'),
}

# Lower-case alias used by the analysis cell.
analyze_common_params = ANALYZE_COMMON_PARAMS

##### **Define logger**

In [55]:
# Restart the Fuse logger for the analysis stage; output_path=None here, unlike the
# train/infer cells which pass a directory.
fuse_logger_start(output_path=None, console_verbose_level=logging.INFO)
lgr = logging.getLogger('Fuse')
lgr.info('Fuse Analyze', {'attrs': ['bold', 'underline']})

[4m[1mFuse Analyze[0m


In [56]:

# metrics
# Every metric reads the same prediction / ground-truth keys from the inference file.
metric_inputs = {'pred_name': 'model.output.classification', 'target_name': 'data.label'}
# NOTE(review): the ROC curve image is written to the inference dir, not the analyze dir.
roc_curve_file = os.path.join(paths['inference_dir'], 'roc_curve.png')
metrics = {
    'accuracy': FuseMetricAccuracy(**metric_inputs),
    'roc': FuseMetricROCCurve(output_filename=roc_curve_file, **metric_inputs),
    'auc': FuseMetricAUC(**metric_inputs),
}

# create analyzer
analyzer = FuseAnalyzerDefault()

# run
# FIXME: simplify analyze interface for this case
infer_results_file = os.path.join(paths["inference_dir"], analyze_common_params["infer_filename"])
results = analyzer.analyze(
    gt_processors={},
    data_pickle_filename=infer_results_file,
    metrics=metrics,
    output_filename=analyze_common_params['output_filename'],
)


100%|██████████| 5000/5000 [00:00<00:00, 5117.20it/s]


[4m[1mResults[0m
[1m
Metric accuracy:[0m
0.8553
[1m
Metric roc:[0m
[1m
Metric auc:[0m
class_0: 0.9964835512919138
class_1: 0.9974063224431078
class_2: 0.9866720345898249
class_3: 0.9845814381215652
class_4: 0.9909174635567064
class_5: 0.9815306752453372
class_6: 0.987555407172005
class_7: 0.9798445574732284
class_8: 0.9744527672964846
class_9: 0.979112908746209
macro_avg: 0.9858557125936382
[1m[35m
Analyzer done. Results saved in fuse_examples/mnist/analyze_dir/all_metrics  [0m
[1m[35m
Analyzer done.[0m
