
Integration guides

Aim integrates with popular ML frameworks such as PyTorch Ignite, PyTorch Lightning, Hugging Face and others. Basic integration guides can be found in the Quick Start section.

This section takes a deeper look at how the basic loggers can be extended to track much more. Out of the box, they track only a specific set of metrics and hyperparameters.

There are two ways Aim callbacks/adapters/loggers can be extended:

  • by deriving from them and overriding the main methods responsible for logging;
  • by using the public experiment property, which gives access to the underlying aim.Run object, to track new metrics, params and any other metadata your project needs.
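Here is a framework-agnostic toy sketch that makes the two mechanisms concrete. ToyRun, ToyLogger and ExtendedLogger are hypothetical stand-ins, not Aim classes. ToyRun.track only loosely mirrors the shape of aim.Run.track.

```python
from typing import Any, Dict, List, Optional, Tuple


class ToyRun:
    """Hypothetical stand-in for aim.Run that records every track() call in memory."""

    def __init__(self) -> None:
        self.records: List[Tuple[str, Any, Optional[int]]] = []

    def track(self, value: Any, name: str, step: Optional[int] = None) -> None:
        self.records.append((name, value, step))


class ToyLogger:
    """Hypothetical basic logger: tracks only the metrics it is handed."""

    def __init__(self) -> None:
        self._run = ToyRun()

    @property
    def experiment(self) -> ToyRun:
        # Mechanism 2: expose the underlying run object to user code
        return self._run

    def log_metrics(self, metrics: Dict[str, float], step: int) -> None:
        for name, value in metrics.items():
            self._run.track(value, name=name, step=step)


class ExtendedLogger(ToyLogger):
    # Mechanism 1: derive from the logger and override its logging method
    def log_metrics(self, metrics: Dict[str, float], step: int) -> None:
        super().log_metrics(metrics, step)  # keep the default behaviour
        self.experiment.track(len(metrics), name="num_metrics", step=step)


logger = ExtendedLogger()
logger.log_metrics({"loss": 0.5, "acc": 0.9}, step=1)
# Mechanism 2 used directly from user code
logger.experiment.track("resnet18", name="backbone")
```

Overriding keeps the default behaviour via super() and adds to it. The property route needs no subclass at all.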

Pytorch Ignite

Both extension mechanisms are available with PyTorch Ignite. The example below uses the experiment property to track a confusion matrix as an aim.Image once training has completed.

Here is an example Colab notebook.

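from aim import Image
from aim.pytorch_ignite import AimLogger
from ignite.engine import Events

import matplotlib.pyplot as plt
import seaborn as sns

# Create a logger
aim_logger = AimLogger()

# `trainer` and `val_evaluator` are the Ignite engines from your training script;
# the evaluator is assumed to have a ConfusionMatrix metric attached under the name 'cm'.
@trainer.on(Events.COMPLETED)
def log_confusion_matrix(trainer):
    metrics = val_evaluator.state.metrics
    # Convert the confusion matrix tensor to an integer NumPy array
    cm = metrics['cm'].cpu().numpy().astype(int)
    classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot']
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(cm, annot=True, ax=ax, fmt="d")
    # labels, title and ticks
    ax.set_xlabel('Predicted labels')
    ax.set_ylabel('True labels')
    ax.set_title('Confusion Matrix')
    ax.xaxis.set_ticklabels(classes, rotation=90)
    ax.yaxis.set_ticklabels(classes, rotation=0)
    # Track the matplotlib figure as an Aim image
    aim_logger.experiment.track(Image(fig), name='cm_training_end')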

PyTorch Ignite also offers a third way to extend the integration: custom log handlers. For example, Ignite's TensorBoard logger can track a model's gradients and weights as histograms. The same can be achieved with Aim:

from typing import Optional, Union

import torch.nn as nn
from ignite.contrib.handlers.base_logger import BaseWeightsHistHandler
from ignite.engine import Engine, Events

from aim.pytorch_ignite import AimLogger
from aim import Distribution

class AimGradsHistHandler(BaseWeightsHistHandler):
    def __init__(self, model: nn.Module, tag: Optional[str] = None):
        super().__init__(model, tag=tag)

    def __call__(self, engine: Engine, logger: AimLogger, event_name: Union[str, Events]) -> None:
        global_step = engine.state.get_event_attrib_value(event_name)
        context = {'subset': self.tag} if self.tag else {}
        for name, p in self.model.named_parameters():
            # Skip parameters that received no gradient (e.g. frozen layers)
            if p.grad is None:
                continue

            name = name.replace(".", "/")
            # Track the gradient values as a histogram (aim.Distribution)
            logger.experiment.track(
                Distribution(p.grad.detach().cpu().numpy()),
                name=name,
                step=global_step,
                context=context,
            )

# Create a logger
aim_logger = AimLogger()

# Attach the handler to the trainer to log the model's gradient histograms after each iteration.
# `trainer` and `model` come from your training script.
aim_logger.attach(
    trainer,
    log_handler=AimGradsHistHandler(model),
    event_name=Events.ITERATION_COMPLETED,
)
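An aim.Distribution summarizes the values it receives as a histogram. The stdlib sketch below gives a rough idea of what such a summary looks like for one flattened gradient tensor, using equal-width binning. The histogram function, its bin_count default and the sample values are our own illustration, not Aim's implementation.

```python
from typing import List, Sequence, Tuple


def histogram(values: Sequence[float], bin_count: int = 4) -> Tuple[List[int], Tuple[float, float]]:
    """Count values in bin_count equal-width bins spanning [min, max].

    The maximum value is placed in the last bin, so every value is counted once.
    """
    lo, hi = min(values), max(values)
    width = (hi - lo) / bin_count or 1.0  # all values equal: any positive width works
    counts = [0] * bin_count
    for v in values:
        index = min(int((v - lo) / width), bin_count - 1)
        counts[index] += 1
    return counts, (lo, hi)


# Illustrative values standing in for a flattened gradient tensor
grads = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
counts, bin_range = histogram(grads)
```

Storing counts plus the bin range keeps the tracked object small, however many elements the parameter has.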

Pytorch Lightning

The PyTorch Lightning + Aim example in the Aim GitHub repo already shows how to customize the integration: inside a LightningModule, self.logger.experiment gives direct access to the underlying aim.Run.

    def test_step(self, batch, batch_idx):
        # Track metrics manually
        self.logger.experiment.track(1, name='manually_tracked_metric')

In the same way, you can track any metadata that Aim supports (images, texts and more) at every test step.
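In Aim, values tracked under the same name and context form one sequence, indexed by step. The toy store below mimics that bookkeeping for values tracked once per test batch. ToySequences is hypothetical and only illustrates the idea. It is not how Aim stores data.

```python
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple

Key = Tuple[str, Tuple[Tuple[str, Any], ...]]


class ToySequences:
    """Hypothetical store: one (step, value) sequence per (name, context) pair."""

    def __init__(self) -> None:
        self._seqs: Dict[Key, List[Tuple[int, Any]]] = defaultdict(list)

    @staticmethod
    def _key(name: str, context: Optional[Dict[str, Any]]) -> Key:
        # Contexts are dicts; sort their items so equal dicts map to the same key
        return name, tuple(sorted((context or {}).items()))

    def track(self, value: Any, name: str, step: int,
              context: Optional[Dict[str, Any]] = None) -> None:
        self._seqs[self._key(name, context)].append((step, value))

    def sequence(self, name: str, context: Optional[Dict[str, Any]] = None) -> List[Tuple[int, Any]]:
        return list(self._seqs.get(self._key(name, context), []))


store = ToySequences()
for batch_idx in range(3):  # stands in for Lightning calling test_step once per batch
    store.track(1, name="manually_tracked_metric", step=batch_idx)
    store.track(f"batch {batch_idx} done", name="notes", step=batch_idx,
                context={"subset": "test"})
```

The same name with a different context is a separate sequence. That is why `sequence("notes")` without the context comes back empty.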

Hugging Face

Here is how to extend the basic Hugging Face logger. The CustomCallback below derives from AimCallback and overrides on_log(), the Hugging Face callback method that receives the logged values.

This lets us track every str value passed to on_log() as an aim.Text.

from aim.hugging_face import AimCallback
from aim import Text

class CustomCallback(AimCallback):
    def on_log(self, args, state, control,
               model=None, logs=None, **kwargs):
        # Let the base callback handle its default tracking first
        super().on_log(args, state, control, model, logs, **kwargs)

        context = {
            'subset': self._current_shift,
        }
        # Track every string value as aim.Text
        for log_name, log_value in (logs or {}).items():
            if isinstance(log_value, str):
                self.experiment.track(Text(log_value), name=log_name, context=context)
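on_log() receives a single dict that mixes numbers and strings. The framework-free sketch below shows the kind of type-based split the callback relies on: numbers are metric candidates and strings are text candidates. The split_logs helper and the "note" key are hypothetical. The other keys are typical Trainer log names.

```python
from typing import Any, Dict, Optional, Tuple


def split_logs(logs: Optional[Dict[str, Any]]) -> Tuple[Dict[str, float], Dict[str, str]]:
    """Separate numeric entries (metric candidates) from string entries (text candidates)."""
    numeric: Dict[str, float] = {}
    text: Dict[str, str] = {}
    for name, value in (logs or {}).items():
        if isinstance(value, bool):
            continue  # bool is a subclass of int in Python; skip it explicitly
        if isinstance(value, (int, float)):
            numeric[name] = float(value)
        elif isinstance(value, str):
            text[name] = value
    return numeric, text


numeric, text = split_logs({"loss": 0.42, "learning_rate": 5e-05, "epoch": 1.0,
                            "note": "warmup finished"})
```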


TF/Keras

Here is how to track confusion matrices with Aim by extending the default callback provided for tf.keras. We have taken this example and adapted it to Aim:

import numpy as np
import sklearn.metrics
from aim import Image
from aim.tensorflow import AimCallback

class CustomImageTrackingCallback(AimCallback):
    def __init__(self, data):
        super().__init__()
        # (images, labels) used to build the confusion matrix
        self.data = data

    def on_epoch_end(self, epoch, logs=None):
        super().on_epoch_end(epoch, logs)
        images, labels = self.data
        # Use the model to predict the values from the validation dataset.
        test_pred_raw = self.model.predict(images)
        test_pred = np.argmax(test_pred_raw, axis=1)

        # Calculate the confusion matrix.
        cm = sklearn.metrics.confusion_matrix(labels, test_pred)
        # Render the confusion matrix as a matplotlib figure;
        # plot_confusion_matrix and class_names come from the adapted example.
        figure = plot_confusion_matrix(cm, class_names=class_names)
        cm_image = Image(figure)

        # Log the confusion matrix as an Aim image.
        self.experiment.track(cm_image, name="Confusion Matrix", step=epoch)

aim_callback = CustomImageTrackingCallback((test_images, test_labels))

model.fit(
    train_images,
    train_labels,
    epochs=5,  # illustrative value
    verbose=0,  # Suppress chatty output
    callbacks=[aim_callback],
    validation_data=(test_images, test_labels),
)
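The callback's core computation is an argmax over predicted class scores followed by a count of (true, predicted) pairs. Here is a dependency-free sketch of those two steps. argmax_rows and confusion_matrix are our own helpers. They follow the same conventions as np.argmax(..., axis=1), where the first maximum wins on ties, and sklearn.metrics.confusion_matrix, where rows are true labels and columns are predictions.

```python
from typing import List, Sequence


def argmax_rows(scores: Sequence[Sequence[float]]) -> List[int]:
    """Index of the largest score in each row; on ties the first index wins."""
    return [max(range(len(row)), key=row.__getitem__) for row in scores]


def confusion_matrix(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> List[List[int]]:
    """cm[i][j] counts samples whose true label is i and predicted label is j."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


scores = [[0.1, 0.7, 0.2], [0.8, 0.1, 0.1], [0.2, 0.3, 0.5], [0.6, 0.3, 0.1]]
labels = [1, 0, 2, 2]
preds = argmax_rows(scores)
cm = confusion_matrix(labels, preds, num_classes=3)
```

The diagonal holds correct predictions. Here the last sample (true class 2, predicted 0) lands off-diagonal in row 2, column 0.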


XGBoost

Here is how to override the AimCallback for XGBoost.

from aim import Text
from aim.xgboost import AimCallback

class CustomCallback(AimCallback):

    def after_iteration(self, model, epoch, evals_log):
        for data, metric in evals_log.items():
            for metric_name, log in metric.items():
                # `log` is the metric's history; track its latest entry as text
                self.experiment.track(Text(str(log[-1])), name=metric_name,
                                      step=epoch, context={'subset': data})
        return super().after_iteration(model, epoch, evals_log)
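evals_log maps each dataset name to a dict of metric histories. Each history is a list of floats, or of (mean, std) tuples when training through xgb.cv. The stdlib sketch below pulls out the latest value per (dataset, metric) pair, which is the same lookup the callback does with log[-1]. latest_scores is a hypothetical helper.

```python
from typing import Dict, List, Tuple, Union

Score = Union[float, Tuple[float, float]]  # plain score, or (mean, std) in xgb.cv


def latest_scores(evals_log: Dict[str, Dict[str, List[Score]]]) -> Dict[Tuple[str, str], float]:
    """Return the most recent score per (dataset, metric), dropping the std if present."""
    out: Dict[Tuple[str, str], float] = {}
    for data, metrics in evals_log.items():
        for metric_name, history in metrics.items():
            last = history[-1]
            out[(data, metric_name)] = last[0] if isinstance(last, tuple) else last
    return out


scores = latest_scores({"train": {"rmse": [0.9, 0.7]}, "eval": {"rmse": [1.0, 0.8]}})
```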


CatBoost

CatBoost's .fit() has a log_cout parameter that redirects log output to any object with a write() method. Our logger is such an object: its write() method parses each log string according to its content. Most of the log output is ignored by our parser, but you can add your own parsing logic on top of ours to cover what you need.

from aim.catboost import AimLogger

class CustomLogger(AimLogger):

    def write(self, log):
        # Process the log string through our parser
        super().write(log)

        # Do your own parsing; guard against short or empty lines
        tokens = log.strip().split()
        if len(tokens) > 3 and tokens[1] == 'bin:':
            value_score = self._to_number(tokens[3])
            self.experiment.track(value_score, name='score')
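The minimal sketch below shows what "an object with a write() method" can look like when it parses CatBoost's per-iteration lines itself. The line format in the comment is an assumption about CatBoost's verbose output, so check it against your own logs. LineCollector and parse_line are hypothetical names. Writes are buffered because a single write() call is not guaranteed to carry exactly one full line.

```python
import re
from typing import Dict, List, Optional, Tuple

# Assumed shape of a CatBoost verbose line (verify against your own output):
# "10:\tlearn: 0.5012345\ttest: 0.5123456\tbest: 0.5100000 (8)\ttotal: 57.7ms\tremaining: 5.71s"
_ITERATION = re.compile(r"^\s*(\d+):")
_VALUE = re.compile(r"\b(learn|test|best):\s*([-+]?\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)")


def parse_line(line: str) -> Optional[Tuple[int, Dict[str, float]]]:
    """Return (iteration, {'learn': ..., 'test': ..., 'best': ...}) or None for other lines."""
    match = _ITERATION.match(line)
    if match is None:
        return None
    values = {key: float(number) for key, number in _VALUE.findall(line)}
    return int(match.group(1)), values


class LineCollector:
    """Hypothetical log_cout target: buffers partial writes and parses whole lines."""

    def __init__(self) -> None:
        self._buffer = ""
        self.parsed: List[Tuple[int, Dict[str, float]]] = []

    def write(self, text: str) -> None:
        self._buffer += text
        while "\n" in self._buffer:
            line, self._buffer = self._buffer.split("\n", 1)
            result = parse_line(line)
            if result is not None:
                self.parsed.append(result)
```

An instance could then be passed as log_cout to .fit(). Lines that do not start with an iteration number are silently skipped.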


LightGBM

Here is how to override the AimCallback for LightGBM.

from aim.lightgbm import AimCallback

class CustomCallback(AimCallback):

    def before_tracking(self, env):
        for item in env.evaluation_result_list:
            # manipulate item here
            pass

    def after_tracking(self, env):
        # do any other action if necessary after tracking value
        pass
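Each entry of env.evaluation_result_list is a tuple. As we understand LightGBM's callback environment, lgb.train yields (dataset_name, eval_name, value, is_higher_better), and lgb.cv adds the standard deviation as a fifth element. Below is a stdlib sketch of one way before_tracking() could reshape those entries. group_by_dataset is a hypothetical helper, and the sample names mirror LightGBM's default dataset names.

```python
from typing import Dict, List


def group_by_dataset(evaluation_result_list: List[tuple]) -> Dict[str, Dict[str, float]]:
    """Map dataset name -> {metric name: value}; any extra fields (e.g. cv stdv) are ignored."""
    grouped: Dict[str, Dict[str, float]] = {}
    for item in evaluation_result_list:
        dataset_name, eval_name, value = item[0], item[1], item[2]
        grouped.setdefault(dataset_name, {})[eval_name] = value
    return grouped


grouped = group_by_dataset([
    ("training", "l2", 0.25, False),
    ("valid_0", "l2", 0.31, False),
    ("valid_0", "auc", 0.88, True),
])
```

Indexing by position rather than unpacking the whole tuple lets the same code handle both the 4-element and 5-element layouts.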