------------
## Installation Details - Google Colab

### **Enable GPU Support**

To use GPU through Google Colab, change the runtime mode to GPU:

From the "Runtime" menu, select "Change Runtime Type", choose "GPU" from the drop-down menu, and click "SAVE".
When prompted, restart the runtime.

### **Install FuseMedML**

In [None]:
# One-time setup: the clone and editable-install commands are left commented out.
# Uncomment them when the repository is not yet available in the runtime.
# !git clone https://github.com/IBM/fuse-med-ml.git
# NOTE: this %cd assumes a `fuse-med-ml` checkout exists in the current directory.
%cd fuse-med-ml
# !pip install -e .
# Switch into the tutorial folder; the relative ROOT output path used by the
# configuration cell is resolved against this working directory.
%cd fuse_examples/tutorials/multimodality_image_clinical

### **Setup imports**

In [None]:
import logging
import os

In [None]:
###########################################################################################################
# Fuse - global configuration shared by the train / infer / analyze cells
###########################################################################################################

# ------------------------------------------
# Debug mode
# ------------------------------------------
# Options: 'default', 'fast', 'debug', 'verbose', 'user'. See details in FuseUtilsDebug
mode = 'default'
debug = FuseUtilsDebug(mode)

# ------------------------------------------
# Output paths (all located under ROOT)
# ------------------------------------------
ROOT = 'examples'  # TODO: fill path here
PATHS = {
    'model_dir': os.path.join(ROOT, 'mnist/model_dir'),
    # If True will reset model dir automatically - otherwise will prompt 'are you sure' message.
    'force_reset_model_dir': True,
    'cache_dir': os.path.join(ROOT, 'mnist/cache_dir'),
    'inference_dir': os.path.join(ROOT, 'mnist/infer_dir'),
    'analyze_dir': os.path.join(ROOT, 'mnist/analyze_dir'),
}

# ------------------------------------------
# Train common params
# ------------------------------------------
TRAIN_COMMON_PARAMS = {
    # --- data ---
    'data.batch_size': 30,
    'data.train_num_workers': 8,
    'data.validation_num_workers': 8,

    # --- manager: training loop ---
    'manager.train_params': {
        'device': 'cuda',
        'num_epochs': 5,
        'virtual_batch_size': 1,  # number of batches in one virtual batch
        'start_saving_epochs': 10,  # first epoch to start saving checkpoints from
        'gap_between_saving_epochs': 5,  # number of epochs between saved checkpoint
    },

    # --- manager: how the best epoch is selected ---
    'manager.best_epoch_source': {
        'source': 'metrics.accuracy',  # can be any key from 'epoch_results'
        'optimization': 'max',  # can be either min/max
        # can be either better/worse - whether to consider best epoch when values are equal
        'on_equal_values': 'better',
    },

    # --- manager: optimizer settings and optional resume ---
    'manager.learning_rate': 1e-4,
    'manager.weight_decay': 0.001,
    'manager.resume_checkpoint_filename': None,  # if not None, will try to load the checkpoint
}


def perform_softmax(output):
    """Return the raw logits together with their class probabilities.

    Args:
        output: either a ``torch.Tensor`` of logits, or any object exposing
            the logits through a ``.logits`` attribute.

    Returns:
        A ``(logits, cls_preds)`` tuple, where ``cls_preds`` is the softmax of
        ``logits`` taken over dimension 1.
    """
    # A bare tensor is used as-is; anything else is unwrapped via `.logits`.
    logits = output if isinstance(output, torch.Tensor) else output.logits
    return logits, F.softmax(logits, dim=1)


### **Set logger**

In [None]:
# Start the Fuse logger, passing the model directory as its output path, then
# log the main output locations.
# The path settings live in the upper-case PATHS dict from the configuration
# cell; the lower-case `paths` used previously is not defined in this notebook.
fuse_logger_start(output_path=PATHS['model_dir'], console_verbose_level=logging.INFO)
lgr = logging.getLogger('Fuse')
lgr.info('Fuse Train', {'attrs': ['bold', 'underline']})
lgr.info(f'model_dir={PATHS["model_dir"]}', {'color': 'magenta'})
lgr.info(f'cache_dir={PATHS["cache_dir"]}', {'color': 'magenta'})

### **Train**

In [None]:
# ==============================================================================
# Logger
# ==============================================================================
# Reads the upper-case PATHS dict from the configuration cell (a lower-case
# `paths` is not defined anywhere in this notebook).
fuse_logger_start(output_path=PATHS['model_dir'], console_verbose_level=logging.INFO)
lgr = logging.getLogger('Fuse')
lgr.info('Fuse Train', {'attrs': ['bold', 'underline']})

lgr.info(f'model_dir={PATHS["model_dir"]}', {'color': 'magenta'})
lgr.info(f'cache_dir={PATHS["cache_dir"]}', {'color': 'magenta'})

# ==============================================================================
# Data
# ==============================================================================
# All settings are read from the upper-case PATHS / TRAIN_COMMON_PARAMS dicts
# defined in the configuration cell.

# Train Data
lgr.info('Train Data:', {'attrs': 'bold'})

# Convert images to tensors and normalize them with the given mean / std values
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
# Create dataset (downloaded into the cache dir)
torch_train_dataset = torchvision.datasets.MNIST(PATHS['cache_dir'], download=True, train=True, transform=transform)
# wrapping torch dataset
# FIXME: support also using torch dataset directly
train_dataset = FuseDatasetWrapper(name='train', dataset=torch_train_dataset, mapping=('image', 'label'))
train_dataset.create()

lgr.info('- Create sampler:')
sampler = FuseSamplerBalancedBatch(dataset=train_dataset,
                                   balanced_class_name='data.label',
                                   num_balanced_classes=10,
                                   batch_size=TRAIN_COMMON_PARAMS['data.batch_size'],
                                   balanced_class_weights=None)
lgr.info('- Create sampler: Done')

# Create dataloader - batches are produced by the balanced sampler
train_dataloader = DataLoader(dataset=train_dataset,
                              batch_sampler=sampler,
                              num_workers=TRAIN_COMMON_PARAMS['data.train_num_workers'])
lgr.info('Train Data: Done', {'attrs': 'bold'})

## Validation data
lgr.info('Validation Data:', {'attrs': 'bold'})
# Create dataset (the MNIST split selected by train=False)
torch_validation_dataset = torchvision.datasets.MNIST(PATHS['cache_dir'], download=True, train=False, transform=transform)
# wrapping torch dataset
validation_dataset = FuseDatasetWrapper(name='validation', dataset=torch_validation_dataset, mapping=('image', 'label'))
validation_dataset.create()

# dataloader
validation_dataloader = DataLoader(dataset=validation_dataset,
                                   batch_size=TRAIN_COMMON_PARAMS['data.batch_size'],
                                   num_workers=TRAIN_COMMON_PARAMS['data.validation_num_workers'])
lgr.info('Validation Data: Done', {'attrs': 'bold'})

# ==============================================================================
# Model
# ==============================================================================
lgr.info('Model:', {'attrs': 'bold'})

# ResNet-18 with a 10-way output
torch_model = models.resnet18(num_classes=10)
# replace conv1 so the network accepts single channel images
torch_model.conv1 = torch.nn.Conv2d(
    in_channels=1,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False,
)
# adaptive average pooling to support mnist low resolution images
torch_model.avgpool = torch.nn.AdaptiveAvgPool2d(output_size=1)

# Wrap the torch model: it reads 'data.image', and the two values returned by
# perform_softmax are exposed under the two names listed in model_outputs.
model = FuseModelWrapper(
    model=torch_model,
    model_inputs=['data.image'],
    post_forward_processing_function=perform_softmax,
    model_outputs=['logits.classification', 'output.classification'],
)

lgr.info('Model: Done', {'attrs': 'bold'})

# ====================================================================================
#  Loss
# ====================================================================================
# Single classification loss: cross entropy between the logits and the labels.
losses = {
    'cls_loss': FuseLossDefault(
        pred_name='model.logits.classification',
        target_name='data.label',
        callable=F.cross_entropy,
        weight=1.0,
    ),
}

# ====================================================================================
# Metrics
# ====================================================================================
# Accuracy is computed from the softmax output and the labels.
metrics = {
    'accuracy': FuseMetricAccuracy(
        pred_name='model.output.classification',
        target_name='data.label',
    ),
}

# =====================================================================================
#  Callbacks
# =====================================================================================
# Settings come from the upper-case PATHS / TRAIN_COMMON_PARAMS dicts defined in
# the configuration cell.
callbacks = [
    # default callbacks
    FuseTensorboardCallback(model_dir=PATHS['model_dir']),  # save statistics for tensorboard
    FuseMetricStatisticsCallback(output_path=os.path.join(PATHS['model_dir'], 'metrics.csv')),  # save statistics to a csv file
    FuseTimeStatisticsCallback(num_epochs=TRAIN_COMMON_PARAMS['manager.train_params']['num_epochs'],
                               load_expected_part=0.1)  # time profiler
]

# =====================================================================================
#  Manager - Train
# =====================================================================================
# Settings come from the upper-case PATHS / TRAIN_COMMON_PARAMS dicts defined in
# the configuration cell.
lgr.info('Train:', {'attrs': 'bold'})

# create optimizer
optimizer = optim.Adam(model.parameters(),
                       lr=TRAIN_COMMON_PARAMS['manager.learning_rate'],
                       weight_decay=TRAIN_COMMON_PARAMS['manager.weight_decay'])

# create learning scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)

# train from scratch
manager = FuseManagerDefault(output_model_dir=PATHS['model_dir'], force_reset=PATHS['force_reset_model_dir'])
# Providing the objects required for the training process.
manager.set_objects(net=model,
                    optimizer=optimizer,
                    losses=losses,
                    metrics=metrics,
                    best_epoch_source=TRAIN_COMMON_PARAMS['manager.best_epoch_source'],
                    lr_scheduler=scheduler,
                    callbacks=callbacks,
                    train_params=TRAIN_COMMON_PARAMS['manager.train_params'])

## Continue training
if TRAIN_COMMON_PARAMS['manager.resume_checkpoint_filename'] is not None:
    # Loading the checkpoint including model weights, learning rate, and epoch_index.
    manager.load_checkpoint(checkpoint=TRAIN_COMMON_PARAMS['manager.resume_checkpoint_filename'], mode='train')

# Start training
manager.train(train_dataloader=train_dataloader, validation_dataloader=validation_dataloader)

lgr.info('Train: Done', {'attrs': 'bold'})


### **Infer**

In [None]:
######################################
# Inference Common Params
######################################
INFER_COMMON_PARAMS = {
    # name of the file the inference results are written to
    'infer_filename': 'validation_set_infer.gz',
    # Fuse TIP: possible values are 'best', 'last' or epoch_index.
    'checkpoint': 'best',
}


# ---
#### Logger
# Settings come from the upper-case PATHS dict (configuration cell) and
# INFER_COMMON_PARAMS (defined at the top of this cell).
infer_file_path = os.path.join(PATHS['inference_dir'], INFER_COMMON_PARAMS['infer_filename'])

fuse_logger_start(output_path=PATHS['inference_dir'], console_verbose_level=logging.INFO)
lgr = logging.getLogger('Fuse')
lgr.info('Fuse Inference', {'attrs': ['bold', 'underline']})
lgr.info(f'infer_filename={infer_file_path}', {'color': 'magenta'})

## Data
# Same preprocessing as in the train cell: to-tensor followed by normalization
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
# Create dataset
torch_validation_dataset = torchvision.datasets.MNIST(PATHS['cache_dir'], download=True, train=False, transform=transform)
# wrapping torch dataset
validation_dataset = FuseDatasetWrapper(name='validation', dataset=torch_validation_dataset, mapping=('image', 'label'))
validation_dataset.create()
# dataloader
validation_dataloader = DataLoader(dataset=validation_dataset, collate_fn=validation_dataset.collate_fn, batch_size=2, num_workers=2)

## Manager for inference
manager = FuseManagerDefault()
# columns requested from the manager for the inference output file
output_columns = ['model.output.classification', 'data.label']
manager.infer(data_loader=validation_dataloader,
              input_model_dir=PATHS['model_dir'],
              checkpoint=INFER_COMMON_PARAMS['checkpoint'],
              output_columns=output_columns,
              output_file_name=infer_file_path)


### **Analyze**

In [None]:
######################################
# Analyze Common Params
######################################
# The analysis reads the file written by the inference cell and stores its
# report under PATHS['analyze_dir'].
ANALYZE_COMMON_PARAMS = {
    'infer_filename': INFER_COMMON_PARAMS['infer_filename'],
    'output_filename': os.path.join(PATHS['analyze_dir'], 'all_metrics.txt'),
}
# Make sure the report directory exists before anything is written into it.
os.makedirs(PATHS['analyze_dir'], exist_ok=True)

fuse_logger_start(output_path=None, console_verbose_level=logging.INFO)
lgr = logging.getLogger('Fuse')
lgr.info('Fuse Analyze', {'attrs': ['bold', 'underline']})

# metrics
metrics = {
    'accuracy': FuseMetricAccuracy(pred_name='model.output.classification', target_name='data.label'),
    'roc': FuseMetricROCCurve(pred_name='model.output.classification', target_name='data.label',
                              output_filename=os.path.join(PATHS['inference_dir'], 'roc_curve.png')),
    'auc': FuseMetricAUC(pred_name='model.output.classification', target_name='data.label')
}

# create analyzer
analyzer = FuseAnalyzerDefault()

# run
# FIXME: simplify analyze interface for this case
results = analyzer.analyze(gt_processors={},
                           data_pickle_filename=os.path.join(PATHS['inference_dir'], ANALYZE_COMMON_PARAMS['infer_filename']),
                           metrics=metrics,
                           output_filename=ANALYZE_COMMON_PARAMS['output_filename'])
