diff --git a/monai/networks/nets/milmodel.py b/monai/networks/nets/milmodel.py index 75cd9cf53e..d7b8a7ec19 100644 --- a/monai/networks/nets/milmodel.py +++ b/monai/networks/nets/milmodel.py @@ -11,23 +11,24 @@ class MILModel(nn.Module): """ Multiple Instance Learning (MIL) model, with a backbone classification model. - Currently, it only works for 2D images, typical use case is for classification of the + Currently, it only works for 2D images; a typical use case is the classification of digital pathology whole slide images. The expected shape of input data is `[B, N, C, H, W]`, - where `B` is the batch_size of PyTorch Dataloader and `N` is the number of the instances + where `B` is the batch_size of the PyTorch DataLoader and `N` is the number of instances extracted from every original image in the batch. A tutorial example is available at: https://github.com/Project-MONAI/tutorials/tree/master/pathology/multiple_instance_learning. Args: num_classes: number of output classes. - mil_mode: MIL algorithm, available values: - "mean" - average features from all instances, equivalent to pure CNN (non MIL). - "max - retain only the instance with the max probability for loss calculation. - "att" - attention based MIL https://arxiv.org/abs/1802.04712. - "att_trans" - transformer MIL https://arxiv.org/abs/2111.01556. - "att_trans_pyramid" - transformer pyramid MIL https://arxiv.org/abs/2111.01556. - Defaults to ``att``. + mil_mode: MIL algorithm, available values (defaults to ``"att"``): + + - ``"mean"`` - average features from all instances, equivalent to pure CNN (non MIL). + - ``"max"`` - retain only the instance with the max probability for loss calculation. + - ``"att"`` - attention-based MIL https://arxiv.org/abs/1802.04712. + - ``"att_trans"`` - transformer MIL https://arxiv.org/abs/2111.01556. + - ``"att_trans_pyramid"`` - transformer pyramid MIL https://arxiv.org/abs/2111.01556. + pretrained: init backbone with pretrained weights, defaults to ``True``. 
- backbone: Backbone classifier CNN (either None, nn.Module that returns features, + backbone: Backbone classifier CNN (either ``None``, an ``nn.Module`` that returns features, or a string name of a torchvision model). Defaults to ``None``, in which case ResNet50 is used. backbone_num_features: Number of output features of the backbone CNN