In [None]:
# Colab only: mount Google Drive so that files under /content/drive are visible to the notebook
import google.colab.drive as drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
cd /content/drive/MyDrive/

/content/drive/MyDrive


In [None]:
"""
(C) Copyright 2021 IBM Corp.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Created on June 30, 2021
"""

from typing import Tuple, Any

import torch.nn as nn
from torch import Tensor
from torch.hub import load_state_dict_from_url
from torchvision.models.video.resnet import VideoResNet, BasicBlock, Conv3DSimple, BasicStem, model_urls


class FuseBackboneResnet3D(VideoResNet):
    """
    3D ResNet backbone built on torchvision's VideoResNet.
    Returns the spatial feature map of the last residual stage (no pooling / classification layer).
    The stem is rebuilt when needed so that inputs with any number of channels are supported.
    """

    def __init__(self, pretrained: bool = False, in_channels: int = 2, name: str = "r3d_18") -> None:
        """
        Create 3D ResNet model
        :param pretrained: if True, load the weights downloaded from torchvision's model_urls[name]
        :param in_channels: number of input channels
        :param name: model name. currently only 'r3d_18' is supported
        :raises KeyError: if name is not a supported backbone
        """
        # init parameters per required backbone
        supported_backbones = {
            'r3d_18': {'block': BasicBlock,
                       'conv_makers': [Conv3DSimple] * 4,
                       'layers': [2, 2, 2, 2],
                       'stem': BasicStem},
        }
        if name not in supported_backbones:
            raise KeyError(f"unsupported backbone name {name!r}; supported names: {sorted(supported_backbones)}")
        # init original model
        super().__init__(**supported_backbones[name])

        # load pretrained parameters if required
        # NOTE(review): `model_urls` is not exported by newer torchvision releases - confirm the installed version
        if pretrained:
            state_dict = load_state_dict_from_url(model_urls[name])
            self.load_state_dict(state_dict)

        # save input parameters
        self.pretrained = pretrained
        self.in_channels = in_channels
        # The default stem expects 3 input channels. Replace it (same layer configuration, configurable input
        # channels) only when a different channel count is requested, so that pretrained stem weights survive
        # for 3-channel input instead of being overwritten by a freshly initialized layer.
        if self.in_channels != 3:
            self.stem = nn.Sequential(
                nn.Conv3d(self.in_channels, 64, kernel_size=(3, 7, 7), stride=(1, 2, 2),
                          padding=(1, 3, 3), bias=False),
                nn.BatchNorm3d(64),
                nn.ReLU(inplace=True)
            )

    def features(self, x: Tensor) -> Tensor:
        """
        Extract spatial features - given a 3D tensor
        :param x: Input tensor - shape: [batch_size, channels, z, y, x]
        :return: spatial features - shape [batch_size, n_features, z', y', x']
        """
        x = self.stem(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass: backbone feature extraction only (no global pooling, no classification layer)
        :param x: Input volume. shape: [batch_size, channels, z, y, x]
        :return: spatial feature map. shape: [batch_size, n_features, z', y', x']
        """
        return self.features(x)

In [None]:
import torch

# Smoke test: push a random volume [batch=2, channels=2, z=10, y=128, x=128] through the pretrained backbone
model = FuseBackboneResnet3D(pretrained=True)
inp = torch.rand(2, 2, 10, 128, 128)
out = model(inp)
print(out.size())

torch.Size([2, 512, 2, 8, 8])


In [None]:
# Global max pooling over (z, y, x): spatial dims collapse to 1x1x1
gmp = nn.AdaptiveMaxPool3d(output_size=1)
outm = gmp(out)
print(outm.size())

torch.Size([2, 512, 1, 1, 1])


In [None]:
import torch
import torch.nn as nn

# Prototype of the 3D classification head applied to the backbone output `out` (from an earlier cell)
gmp = nn.AdaptiveMaxPool3d(output_size=1)
outm = gmp(out)
outms = outm.squeeze(dim=2)
print(outms.shape)

# A kernel_size=1 Conv3d on a 1x1x1 map acts as a fully-connected layer: 512 -> 256 channels
rr = torch.rand(1, 512, 1, 1, 1)
c1 = nn.Conv3d(512, 256, kernel_size=1)
outc = c1(rr)
print(outc.shape)

conv_classifier_3d = nn.Sequential(
    nn.Conv3d(512, 256, kernel_size=1),
    nn.ReLU(),
    nn.Dropout3d(p=0.5),
    nn.Conv3d(256, 3, kernel_size=1),
)
outcm = conv_classifier_3d(outm)
print(outcm.shape)

# drop the trailing singleton spatial dims: [batch, 3, 1, 1, 1] -> [batch, 3]
logits = outcm.squeeze(dim=4).squeeze(dim=3).squeeze(dim=2)
print(logits.shape)

torch.Size([1, 512, 1, 1])
torch.Size([1, 256, 1, 1, 1])
torch.Size([1, 3, 1, 1, 1])
torch.Size([1, 3])


In [None]:
# pooled backbone features produced by an earlier cell
print(outm.size())

torch.Size([1, 512, 1, 1, 1])


In [None]:
from typing import Optional, Sequence
import torch.nn as nn
class ClassifierMLP(nn.Module):
    """
    Multi-layer perceptron: a stack of Linear -> ReLU [-> Dropout] hidden layers,
    optionally followed by a final Linear projection to num_classes outputs.
    """

    def __init__(self, in_ch: int, num_classes: Optional[int], layers_description: Sequence[int]=(256,128), dropout_rate: float = 0.1):
        """
        :param in_ch: number of input features (last input dimension)
        :param num_classes: output size of the final Linear layer; if None no final layer is added and the output
                            size is the last hidden layer size (or in_ch when layers_description is empty)
        :param layers_description: hidden layer sizes, applied in order; may be empty
        :param dropout_rate: dropout probability after each hidden ReLU; None or 0 disables dropout
        """
        super().__init__()
        use_dropout = dropout_rate is not None and dropout_rate > 0
        layer_list = []
        last_layer_size = in_ch
        for curr_layer_size in layers_description:
            layer_list.append(nn.Linear(last_layer_size, curr_layer_size))
            layer_list.append(nn.ReLU())
            if use_dropout:
                layer_list.append(nn.Dropout(p=dropout_rate))
            last_layer_size = curr_layer_size

        if num_classes is not None:
            layer_list.append(nn.Linear(last_layer_size, num_classes))

        # an empty nn.Sequential returns its input unchanged (no hidden layers and num_classes=None)
        self.classifier = nn.Sequential(*layer_list)

    def forward(self, x):
        """
        :param x: input tensor whose last dimension is in_ch
        :return: classifier output
        """
        x = self.classifier(x)
        return x


modelnlpnew = ClassifierMLP(11, num_classes=None)

In [None]:
modelnlpnew = ClassifierMLP(11, num_classes=None)
inp = torch.rand(2, 11)  # tabular input: [batch, 11]
outnew = modelnlpnew(inp)
print(outnew.shape)
# lift the tabular embedding to a 1x1x1 volume so it can be concatenated with the pooled image features
features = outnew.reshape(outnew.shape + (1, 1, 1))
print(features.shape)
# `outm` comes from an earlier cell; fail with a clear message if its batch size does not match
assert outm.shape[0] == features.shape[0], \
    f'batch size mismatch: outm {tuple(outm.shape)} vs features {tuple(features.shape)} - re-run the backbone cells'
global_features = torch.cat((outm, features), dim=1)
print(global_features.shape)

torch.Size([2, 128])
torch.Size([2, 128, 1, 1, 1])
torch.Size([2, 640, 1, 1, 1])


In [None]:
class ClassifierMLP(nn.Module):
    """
    Multi-layer perceptron classifier: Linear -> ReLU [-> Dropout] hidden layers followed by an optional
    Linear projection to num_classes logits.
    NOTE: this cell redefines (and shadows) the ClassifierMLP defined in an earlier cell, with different defaults.
    """

    def __init__(self, in_ch: int = 512, num_classes: Optional[int] = 3, layers_description: Sequence[int] = (256,), dropout_rate: float = 0.1):
        """
        :param in_ch: number of input features (last input dimension)
        :param num_classes: output size of the final Linear layer; None skips the final layer
        :param layers_description: hidden layer sizes, applied in order; may be empty
        :param dropout_rate: dropout probability after each hidden ReLU; None or 0 disables dropout
        """
        # `in_ch: 512` / `num_classes: 3` were previously annotations, not defaults, so both were required
        super().__init__()
        use_dropout = dropout_rate is not None and dropout_rate > 0
        layer_list = []
        last_layer_size = in_ch
        for curr_layer_size in layers_description:
            layer_list.append(nn.Linear(last_layer_size, curr_layer_size))
            layer_list.append(nn.ReLU())
            if use_dropout:
                layer_list.append(nn.Dropout(p=dropout_rate))
            last_layer_size = curr_layer_size

        if num_classes is not None:
            layer_list.append(nn.Linear(last_layer_size, num_classes))

        self.classifier = nn.Sequential(*layer_list)

    def forward(self, x):
        """
        :param x: input tensor whose last dimension is in_ch
        :return: classifier output
        """
        x = self.classifier(x)
        return x

modelnlp = ClassifierMLP(3, num_classes=3)


In [None]:
# nn.Module's repr lists the layers of the classifier
print(repr(modelnlp))

ClassifierMLP(
  (classifier): Sequential(
    (0): Linear(in_features=3, out_features=256, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.1, inplace=False)
    (3): Linear(in_features=256, out_features=3, bias=True)
  )
)


In [None]:
"""
(C) Copyright 2021 IBM Corp.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
   http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Created on June 30, 2021
"""

from typing import Dict, Tuple, Sequence, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from fuse.utils.utils_hierarchical_dict import FuseUtilsHierarchicalDict
from fuse.models.heads.common import ClassifierMLP


class FuseHead3dClassifier(nn.Module):
    """
    3D classification head.
    Global-max-pools the 3D feature maps listed in conv_inputs, optionally concatenates global (e.g. tabular)
    features listed in append_features, and classifies the fused vector with two 1x1x1 convolutions.
    """

    def __init__(self, head_name: str = 'head_0',
                 conv_inputs: Optional[Sequence[Tuple[str, int]]] = (('model.backbone_features', 512),),
                 dropout_rate: float = 0.1,
                 num_classes: int = 3,
                 append_features: Optional[Sequence[Tuple[str, int]]] = None,
                 layers_description: Sequence[int] = (256,),
                 append_layers_description: Sequence[int] = tuple(),
                 append_dropout_rate: float = 0.0,
                 fused_dropout_rate: float = 0.0,
                 ) -> None:
        """
        Create simple 3D classification head
        :param head_name: string representing the head name; used in the batch_dict output keys
        :param conv_inputs: Sequence of tuples, each indicating a features name in batch_dict and size of features (channels).
                            The tensors are concatenated along dim 1 before global max pooling. May be None.
        :param dropout_rate: dropout fraction (Dropout3d) applied to the pooled conv features
        :param num_classes: number of output classes
        :param append_features: Sequence of tuples, each indicating a features name in batch_dict and size of features (channels).
                                Those are global features that are appended after the global max pooling operation. May be None.
        :param layers_description: Layers description for the classifier module - sequence of hidden layers sizes.
                                   Not used currently: the classifier hidden layer size is fixed to 256.
        :param append_layers_description: Layers description for the tabular data, before the concatenation with the features
                                          extracted from the image - sequence of hidden layers sizes. Empty: append the features as is.
        :param append_dropout_rate: Dropout rate for tabular layers
        :param fused_dropout_rate: Dropout rate (Dropout3d) between the two 1x1x1 convolutions of the classifier
        :raises ValueError: if both conv_inputs and append_features are None
        """
        super().__init__()
        if conv_inputs is None and append_features is None:
            # forward() would have no input to classify (and would fail with an UnboundLocalError)
            raise ValueError('FuseHead3dClassifier requires at least one of conv_inputs / append_features')
        # save input params
        self.head_name = head_name
        self.conv_inputs = conv_inputs
        self.dropout_rate = dropout_rate
        self.num_classes = num_classes
        self.append_features = append_features
        self.gmp = nn.AdaptiveMaxPool3d(output_size=1)
        self.features_size = sum(features[1] for features in self.conv_inputs) if self.conv_inputs is not None else 0

        # calc appended feature size if used
        if self.append_features is not None:
            append_in_size = sum(post_concat_input[1] for post_concat_input in append_features)
            if len(append_layers_description) == 0:
                # appended features are concatenated as is
                self.features_size += append_in_size
                self.append_features_module = nn.Identity()
            else:
                # appended features are first embedded by an MLP whose last hidden layer size is concatenated
                self.features_size += append_layers_description[-1]
                self.append_features_module = ClassifierMLP(in_ch=append_in_size,
                                                            num_classes=None,
                                                            layers_description=append_layers_description,
                                                            dropout_rate=append_dropout_rate)

        # kernel_size=1 convolutions on the pooled [batch_size, features_size, 1, 1, 1] tensor act as a 2-layer MLP
        self.conv_classifier_3d = nn.Sequential(
            nn.Conv3d(self.features_size, 256, kernel_size=1),
            nn.ReLU(),
            nn.Dropout3d(p=fused_dropout_rate),
            nn.Conv3d(256, self.num_classes, kernel_size=1),
        )

        self.do = nn.Dropout3d(p=self.dropout_rate)

    def forward(self, batch_dict: Dict) -> Dict:
        """
        Forward pass
        :param batch_dict: dictionary containing the tensors named in conv_inputs (spatial features with 3D context,
                           shape: [batch_size, in_features, z, y, x]) and, if used, the tensors named in append_features
                           (each reshaped to [-1, size])
        :return: batch dict with fields model.logits.<head_name>, model.output.<head_name> (softmax over classes)
                 and, when conv_inputs are used, model.<head_name>.gmp_features
        """
        if self.conv_inputs is not None:
            conv_input = torch.cat(
                [FuseUtilsHierarchicalDict.get(batch_dict, conv_input[0]) for conv_input in self.conv_inputs], dim=1)
            global_features = self.gmp(conv_input)
            # save global max pooling features in case needed (mostly to analyze)
            FuseUtilsHierarchicalDict.set(batch_dict, 'model.' + self.head_name +'.gmp_features', global_features.squeeze(dim=4).squeeze(dim=3).squeeze(dim=2))
            # backward compatibility - presumably for model objects created before `self.do` was introduced
            if hasattr(self, 'do'):
                global_features = self.do(global_features)
        # append global features if are used
        if self.append_features is not None:
            features = torch.cat(
                [FuseUtilsHierarchicalDict.get(batch_dict, features[0]).reshape(-1, features[1]) for features in self.append_features], dim=1)
            features = self.append_features_module(features)
            # [batch_size, n] -> [batch_size, n, 1, 1, 1] to match the pooled conv features
            features = features.reshape(features.shape + (1, 1, 1))
            if self.conv_inputs is not None:
                global_features = torch.cat((global_features, features), dim=1)
            else:
                global_features = features

        logits = self.conv_classifier_3d(global_features)
        # squeeze the singleton spatial dims: [batch_size, num_classes, 1, 1, 1] -> [batch_size, num_classes]
        logits = logits.squeeze(dim=4).squeeze(dim=3).squeeze(dim=2)

        cls_preds = F.softmax(logits, dim=1)
        FuseUtilsHierarchicalDict.set(batch_dict, 'model.logits.' + self.head_name, logits)
        FuseUtilsHierarchicalDict.set(batch_dict, 'model.output.' + self.head_name, cls_preds)

        return batch_dict

In [None]:
import torch

# Toy batch for the loss demo: logits for 3 samples x 5 classes, integer targets in [0, 5),
# and per-batch kwargs forwarded to the loss function
batch_dict = {
    'pred': torch.randn(3, 5, requires_grad=True),
    'gt': torch.empty(3, dtype=torch.long).random_(5),
    'batch_loss_kwargs': {'reduction': 'mean', 'ignore_index': 0},
}

In [None]:
# Display the per-batch loss kwargs; FuseLossDefault merges these over its constructor kwargs
batch_dict['batch_loss_kwargs']

{'ignore_index': 0, 'reduction': 'mean'}

In [None]:
from typing import Callable, Dict, Optional

import torch

#from fuse.losses.loss_base import FuseLossBase
#from fuse.utils.utils_hierarchical_dict import FuseUtilsHierarchicalDict
import torch
from typing import Any, Set, Callable, Optional, List, Sequence, Union

import numpy
import torch


class FuseUtilsHierarchicalDict:
    """
    Helpers for nested dictionaries addressed with dot-separated keys,
    e.g. key 'x.y.z' refers to dict['x']['y']['z'].
    """

    @classmethod
    def get(cls, hierarchical_dict: dict, key: str):
        """
        get(dict, 'x.y.z') <==> dict['x']['y']['z']
        Falls back to the flattened view of the dict, so keys that contain '.' (e.g. dict['x.y']) are found too.
        :raises KeyError: if the key does not exist
        """
        # split according to '.'
        hierarchical_key = key.split('.')

        # go over the dictionary towards the requested value
        try:
            value = hierarchical_dict[hierarchical_key[0]]
            for sub_key in hierarchical_key[1:]:
                value = value[sub_key]
            return value
        except Exception:
            # `Exception`, not a bare except: indexing a non-dict value may raise KeyError/TypeError/IndexError/...,
            # but KeyboardInterrupt / SystemExit must propagate
            flat_dict = FuseUtilsHierarchicalDict.flatten(hierarchical_dict)
            if key in flat_dict:
                return flat_dict[key]
            else:
                raise KeyError(f'key {key} does not exist\n. Possible keys are: {str(list(flat_dict.keys()))}')

    @classmethod
    def set(cls, hierarchical_dict: dict, key: str, value: Any) -> None:
        """
        set(dict, 'x.y.z', value) <==> dict['x']['y']['z'] = value
        If either 'x', 'y' or 'z' nodes do not exist, this function will create them
        """
        # split according to '.'
        hierarchical_key = key.split('.')

        # go over the dictionary according to the path, create the nodes that do not exist
        element = hierarchical_dict
        for node_key in hierarchical_key[:-1]:
            if node_key not in element:
                element[node_key] = {}
            element = element[node_key]

        # set the value
        element[hierarchical_key[-1]] = value

    @classmethod
    def get_all_keys(cls, hierarchical_dict: dict, include_values: bool = False) -> Union[List[str], dict]:
        """
        Get all hierarchical keys (leaf paths joined with '.') in hierarchical_dict
        :param include_values: if True return a flat {key: value} dict instead of a list of keys
        """
        all_keys = {}
        for key in hierarchical_dict:
            if isinstance(hierarchical_dict[key], dict):
                all_sub_keys = FuseUtilsHierarchicalDict.get_all_keys(hierarchical_dict[key], include_values=True)
                keys_to_add = {f'{key}.{sub_key}':all_sub_keys[sub_key] for sub_key in all_sub_keys}
                all_keys.update(keys_to_add)
            else:
                all_keys[key] = hierarchical_dict[key]
        if include_values:
            return all_keys
        else:
            return list(all_keys.keys())

    @classmethod
    def subkey(cls, key: str, start: int, end: Optional[int]) -> Optional[str]:
        """
        Sub string of hierarchical key.
        Example: subkey('a.b.c.d.f', 1, 3) -> 'b.c'
        :param key: the original key
        :param start: start index
        :param end: end index, not including
        :return: str, or None if start/end exceed the number of key parts
        """
        key_parts = key.split('.')

        # if end not specified set to max.
        if end is None:
            end = len(key_parts)

        if len(key_parts) < start or len(key_parts) < end:
            return None

        res = '.'.join(key_parts[start:end])
        return res

    @classmethod
    def apply_on_all(cls, hierarchical_dict: dict, apply_func: Callable, *args: Any) -> None:
        """
        Replace every leaf value v with apply_func(v, *args), in place
        """
        all_keys = cls.get_all_keys(hierarchical_dict)
        for key in all_keys:
            new_value = apply_func(cls.get(hierarchical_dict, key), *args)
            cls.set(hierarchical_dict, key, new_value)

    @classmethod
    def flatten(cls, hierarchical_dict: dict) -> dict:
        """
        Flatten the dict
        @param hierarchical_dict: dict to flatten
        @return: dict where keys are the hierarchical_dict keys separated by periods.
        """
        return cls.get_all_keys(hierarchical_dict, include_values=True)

    @classmethod
    def indices(cls, hierarchical_dict: dict, indices: List[int]) -> dict:
        """
        Extract the specified indices from each element in the dictionary (if possible)
        :param hierarchical_dict: input dict
        :param indices: indices to extract
        :return: dict with only the required indices
        """
        new_dict = {}
        all_keys = cls.get_all_keys(hierarchical_dict)
        for key in all_keys:
            value = cls.get(hierarchical_dict, key)
            if isinstance(value, numpy.ndarray) or isinstance(value, torch.Tensor):
                new_value = value[indices]
            elif isinstance(value, Sequence):
                # NOTE(review): here `indices` is used as a boolean mask (indices[i] per item), while arrays/tensors
                # above use it as fancy indexing; strings are Sequences too and would be split into characters
                new_value =[item for i, item in enumerate(value) if indices[i]]
            else:
                new_value = value
            cls.set(new_dict, key, new_value)
        return new_dict

    @classmethod
    def to_string(cls, hierarchical_dict: dict) -> str:
        """
        Get flat string including the content of the dictionary
        :param hierarchical_dict: input dict
        :return: string
        """
        keys = cls.get_all_keys(hierarchical_dict)
        keys = sorted(keys)
        res = ''
        for key in keys:
            res += f'{key} = {FuseUtilsHierarchicalDict.get(hierarchical_dict, key)}\n'

        return res

    @classmethod
    def pop(cls, hierarchical_dict: dict, key:str):
        """
        return the value hierarchical_dict[key] and remove the key from the dict.
        Keys that contain literal '.' characters (e.g. dict['x.y'] or dict['a']['b.c']) are removed as well.
        :param hierarchical_dict: the dictionary
        :param key: the key to return and remove
        :raises KeyError: if the key does not exist
        """
        # split according to '.'
        hierarchical_key = key.split('.')
        # walk to the parent node of the requested value and pop from it
        try:
            parent = hierarchical_dict
            for sub_key in hierarchical_key[:-1]:
                parent = parent[sub_key]
            return parent.pop(hierarchical_key[-1])
        except Exception:
            # the plain path failed - the key may be stored with literal '.' characters; find it and remove it
            found, value = cls._pop_nested(hierarchical_dict, hierarchical_key)
            if found:
                return value
            flat_dict = FuseUtilsHierarchicalDict.flatten(hierarchical_dict)
            if key in flat_dict:
                # NOTE(review): reachable only when the key cannot be matched literally (e.g. non-str dict keys);
                # as before, the value is returned but not removed
                return flat_dict[key]
            else:
                raise KeyError(f'key {key} does not exist\n. Possible keys are: {str(list(flat_dict.keys()))}')

    @classmethod
    def _pop_nested(cls, node: Any, key_parts: List[str]) -> tuple:
        """
        Remove and return the value whose flattened key is '.'.join(key_parts), trying every grouping of key_parts
        into literal dict keys (e.g. ['a', 'b', 'c'] may be stored as node['a.b']['c'] or node['a']['b.c']).
        :return: (True, value) if found and removed, otherwise (False, None)
        """
        if not isinstance(node, dict):
            return False, None
        # longest literal head first, so an exact match of the remaining key wins
        for split in range(len(key_parts), 0, -1):
            head = '.'.join(key_parts[:split])
            if head not in node:
                continue
            if split == len(key_parts):
                return True, node.pop(head)
            found, value = cls._pop_nested(node[head], key_parts[split:])
            if found:
                return True, value
        return False, None

    @classmethod
    def is_in(cls, hierarchical_dict: dict, key:str) -> bool:
        """
        Returns True if the full key is in dict, False otherwise.
        e.g., for dict = {'a':1, 'b.c':2} is_in(dict, 'b.c') returns True, but is_in(dict, 'c') returns False.
        :param hierarchical_dict: dict to check
        :param key: key to search
        :return: key in hierarchical_dict
        """
        return key in cls.get_all_keys(hierarchical_dict)


class FuseLossBase(torch.nn.Module):
    """
    Common base for Fuse losses: remembers which batch_dict keys hold the prediction and the target,
    and the scalar multiplier for the loss value.
    """

    def __init__(self,
                 pred_name: Optional[str] = None,
                 target_name: Optional[str] = None,
                 weight: float = 1.0, ) -> None:
        """
        :param pred_name: batch_dict key of the prediction
        :param target_name: batch_dict key of the target
        :param weight: multiplier for the loss value
        """
        super().__init__()
        # plain attributes (strings / float), stored in the module's __dict__
        self.weight = weight
        self.target_name = target_name
        self.pred_name = pred_name

class FuseLossDefault(FuseLossBase):
    """
    Default Fuse loss function: wraps a PyTorch loss function so that it reads its inputs from a batch_dict
    """

    def __init__(self,
                 pred_name: Optional[str] = None,
                 target_name: Optional[str] = None,
                 batch_kwargs_name: Optional[str] = None,
                 callable: Callable = None,
                 sample_weight_name: Optional[str] = None,
                 weight: float = 1.0,
                 filter_func: Optional[Callable] = None,
                 **kwargs
                 ) -> None:
        """
        This class wraps a PyTorch loss function with a Fuse api.
        :param pred_name:               batch_dict key for prediction (e.g., network output)
        :param target_name:             batch_dict key for target (e.g., ground truth label)
        :param batch_kwargs_name:       batch_dict key for additional, ad-hoc kwargs for loss function
                                        Note: batch_kwargs are merged over (and override) the other loss function kwargs
        :param callable:                PyTorch loss function handle (e.g., torch.nn.functional.cross_entropy)
        :param sample_weight_name:      batch_dict key that holds the sample weight for loss summation
        :param weight:                  Weight multiplier for final loss value
        :param filter_func:             function that filters batch_dict. The function gets as input batch_dict and returns filtered batch_dict
        :param kwargs:                  kwargs for PyTorch loss function
        """
        # pred_name / target_name / weight are stored by the base class
        super().__init__(pred_name=pred_name, target_name=target_name, weight=weight)
        self.batch_kwargs_name = batch_kwargs_name
        self.callable = callable
        self.sample_weight_name = sample_weight_name
        self.filter_func = filter_func
        self.kwargs = kwargs

    def forward(self, batch_dict: Dict) -> torch.Tensor:
        """
        Compute the loss for one batch. Implemented as `forward` (instead of overriding `__call__`)
        so that calling the module runs through nn.Module.__call__, including registered hooks.
        :param batch_dict: dict holding predictions, targets and optionally batch kwargs / sample weights
        :return: loss value multiplied by self.weight; with sample weights, the mean of the weighted per-sample losses
        """
        # filter batch_dict if required
        if self.filter_func is not None:
            batch_dict = self.filter_func(batch_dict)
        preds = FuseUtilsHierarchicalDict.get(batch_dict, self.pred_name)
        targets = FuseUtilsHierarchicalDict.get(batch_dict, self.target_name)
        batch_kwargs = FuseUtilsHierarchicalDict.get(batch_dict, self.batch_kwargs_name) if self.batch_kwargs_name is not None else {}
        # per-batch kwargs override the constructor kwargs
        kwargs_copy = self.kwargs.copy()
        kwargs_copy.update(batch_kwargs)
        if self.sample_weight_name is not None:
            assert 'reduction' not in kwargs_copy.keys(), 'reduction is forced to none when applying sample weight'
            # per-sample losses are needed to apply the sample weights
            kwargs_copy.update({'reduction': 'none'})
        loss_obj = self.callable(preds, targets, **kwargs_copy) * self.weight
        if self.sample_weight_name is not None:
            sample_weight = FuseUtilsHierarchicalDict.get(batch_dict, self.sample_weight_name)
            weighted_loss = loss_obj * sample_weight
            loss_obj = torch.mean(weighted_loss)

        return loss_obj

# NOTE: batch_dict['batch_loss_kwargs'] is merged over the constructor kwargs,
# so its reduction='mean' overrides the reduction='sum' given here
loss = FuseLossDefault(
    pred_name='pred',
    target_name='gt',
    batch_kwargs_name='batch_loss_kwargs',
    callable=torch.nn.functional.cross_entropy,
    weight=1.0,
    reduction='sum',
)

res = loss(batch_dict)
print(f'Loss output = {res!s}')

Loss output = tensor(1.8709, grad_fn=<MulBackward0>)


In [None]:
import torch

# Re-run the loss demo on a fresh random batch: logits for 3 samples x 5 classes, targets in [0, 5)
batch_dict = dict(
    pred=torch.randn(3, 5, requires_grad=True),
    gt=torch.empty(3, dtype=torch.long).random_(5),
    batch_loss_kwargs=dict(reduction='mean', ignore_index=0),
)

# reduction='mean' from batch_loss_kwargs takes precedence over reduction='sum' below
loss = FuseLossDefault(
    pred_name='pred',
    target_name='gt',
    batch_kwargs_name='batch_loss_kwargs',
    callable=torch.nn.functional.cross_entropy,
    weight=1.0,
    reduction='sum',
)

res = loss(batch_dict)
print('Loss output =', res)