### Modules

#### Losses:

##### accuracy.py

In [1]:
def accuracy(pred, target, topk=1, thresh=None):
    """Calculate top-k accuracy (in percent) of ``pred`` against ``target``.

    Args:
        pred (torch.Tensor): The model prediction, shape (N, num_class, ...).
        target (torch.Tensor): The target of each prediction, shape (N, ...).
        topk (int | tuple[int], optional): A prediction counts as correct
            when the target is among its ``topk`` highest-scoring classes.
            Defaults to 1.
        thresh (float, optional): If not None, predictions whose score is
            not strictly greater than this threshold are considered
            incorrect. Defaults to None.

    Returns:
        torch.Tensor | list[torch.Tensor]: A single accuracy tensor when
            ``topk`` is an int, otherwise a list with one accuracy tensor
            per entry of ``topk``. For an empty batch the accuracies are
            zero-valued tensors.
    """
    assert isinstance(topk, (int, tuple))
    if isinstance(topk, int):
        # Normalise to a tuple, but remember that the caller wants a scalar.
        topk = (topk, )
        return_single = True
    else:
        return_single = False

    maxk = max(topk)
    if pred.size(0) == 0:
        # Empty batch: nothing to score, report 0 for every requested k.
        accu = [pred.new_tensor(0.) for _ in topk]
        return accu[0] if return_single else accu
    assert pred.ndim == target.ndim + 1
    assert pred.size(0) == target.size(0)
    assert maxk <= pred.size(1), \
        f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
    pred_value, pred_label = pred.topk(maxk, dim=1)
    # transpose to shape (maxk, N, ...)
    pred_label = pred_label.transpose(0, 1)
    correct = pred_label.eq(target.unsqueeze(0).expand_as(pred_label))
    if thresh is not None:
        # Only prediction values larger than thresh are counted as correct.
        # ``transpose(0, 1)`` (not ``.t()``) so that inputs with trailing
        # spatial dimensions, e.g. (N, C, H, W), are supported as well.
        correct = correct & (pred_value > thresh).transpose(0, 1)
    res = []
    for k in topk:
        # ``reshape`` instead of ``view``: ``correct`` is derived from a
        # transposed tensor and is not guaranteed to be contiguous.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / target.numel()))
    return res[0] if return_single else res


##### cross_entropy_loss

In [2]:
# import libraries

import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
def reduce_loss(loss, reduction):
    """Collapse an element-wise loss tensor according to ``reduction``.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: The untouched tensor for "none", its mean for "mean" or its
            sum for "sum".
    """
    # PyTorch's reduction enum: none -> 0, (elementwise_)mean -> 1, sum -> 2
    mode = F._Reduction.get_enum(reduction)
    if mode == 0:
        return loss
    if mode == 1:
        return loss.mean()
    if mode == 2:
        return loss.sum()

In [7]:
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    """Scale an element-wise loss by ``weight`` and then reduce it.

    Args:
        loss (Tensor): Element-wise loss.
        weight (Tensor): Element-wise weights. Must have the same number of
            dimensions as ``loss``; for multi-dimensional weights dim 1 must
            be 1 or match ``loss``.
        reduction (str): Same as built-in losses of PyTorch.
        avg_factor (float): Average factor used in place of the element
            count when ``reduction`` is "mean".

    Returns:
        Tensor: Processed loss values.

    Raises:
        ValueError: If ``avg_factor`` is given together with a reduction
            other than "mean" or "none".
    """
    if weight is not None:
        assert weight.dim() == loss.dim()
        if weight.dim() > 1:
            assert weight.size(1) in (1, loss.size(1))
        loss = loss * weight

    # Without an explicit averaging factor the plain reduction applies.
    if avg_factor is None:
        return reduce_loss(loss, reduction)

    # With an averaging factor only "mean" and "none" are meaningful.
    if reduction == 'mean':
        return loss.sum() / avg_factor
    if reduction != 'none':
        raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss


In [8]:
def cross_entropy(pred,
                  label,
                  weight=None,
                  class_weight=None,
                  reduction='mean',
                  avg_factor=None,
                  ignore_index=-100):
    """Wrapper around :func:`F.cross_entropy` with sample weights.

    Args:
        pred (Tensor): Prediction logits, forwarded to ``F.cross_entropy``.
        label (Tensor): Target class indices.
        weight (Tensor, optional): Element-wise weight applied to the
            unreduced loss (cast to float first).
        class_weight (Tensor, optional): Per-class rescaling weight,
            forwarded as ``weight`` to ``F.cross_entropy``.
        reduction (str): "none", "mean" or "sum".
        avg_factor (float, optional): Averaging factor for "mean".
        ignore_index (int): Label value ignored by ``F.cross_entropy``.

    Returns:
        Tensor: The weighted and reduced loss.
    """
    # Keep the loss element-wise so sample weights can be applied afterwards.
    elementwise = F.cross_entropy(
        pred,
        label,
        weight=class_weight,
        reduction='none',
        ignore_index=ignore_index)

    sample_weight = None if weight is None else weight.float()
    return weight_reduce_loss(
        elementwise,
        weight=sample_weight,
        reduction=reduction,
        avg_factor=avg_factor)

In [9]:
def _expand_onehot_labels(labels, label_weights, target_shape, ignore_index):
    """Expand class-index labels to one-hot form matching the prediction.

    Args:
        labels (Tensor): Class indices; negative values and
            ``ignore_index`` are treated as invalid.
        label_weights (Tensor | None): Weights with the same shape as
            ``labels``, or None.
        target_shape (tuple | torch.Size): Shape of the prediction, with
            the class axis at dim 1.
        ignore_index (int): Label value to exclude.

    Returns:
        tuple[Tensor, Tensor]: The one-hot labels (same dtype as
            ``labels``) and the weights expanded to ``target_shape``, with
            invalid positions zeroed. The input ``label_weights`` is never
            modified.
    """
    bin_labels = labels.new_zeros(target_shape)
    valid_mask = (labels >= 0) & (labels != ignore_index)
    inds = torch.nonzero(valid_mask, as_tuple=True)

    if inds[0].numel() > 0:
        if labels.dim() == 3:
            # (N, H, W) labels scatter into an (N, C, H, W) target.
            bin_labels[inds[0], labels[valid_mask], inds[1], inds[2]] = 1
        else:
            bin_labels[inds[0], labels[valid_mask]] = 1

    valid_mask = valid_mask.unsqueeze(1).expand(target_shape).float()
    if label_weights is None:
        bin_label_weights = valid_mask
    else:
        # Multiply out-of-place: ``expand`` returns a view in which many
        # elements share memory, so an in-place ``*=`` would either raise
        # (C > 1) or silently overwrite the caller's weights (C == 1).
        bin_label_weights = (
            label_weights.unsqueeze(1).expand(target_shape) * valid_mask)

    return bin_labels, bin_label_weights

In [10]:
def binary_cross_entropy(pred,
                         label,
                         weight=None,
                         reduction='mean',
                         avg_factor=None,
                         class_weight=None,
                         ignore_index=255):
    """Compute binary cross-entropy (with logits) between pred and label.

    When ``label`` has one dimension fewer than ``pred`` it is treated as
    class indices and expanded to one-hot form first; positions equal to
    ``ignore_index`` then receive zero weight.

    Args:
        pred (torch.Tensor): The prediction logits. Supported layouts when
            expansion is needed are [N, C] and [N, C, H, W].
        label (torch.Tensor): The learning label of the prediction.
        weight (torch.Tensor, optional): Sample-wise loss weight.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum".
        avg_factor (int, optional): Average factor that is used to average
            the loss. Defaults to None.
        class_weight (torch.Tensor, optional): Forwarded as ``pos_weight``
            to ``F.binary_cross_entropy_with_logits``.
        ignore_index (int | None): The label index to be ignored. Default: 255

    Returns:
        torch.Tensor: The calculated loss
    """
    needs_onehot = pred.dim() != label.dim()
    if needs_onehot:
        flat_case = pred.dim() == 2 and label.dim() == 1
        dense_case = pred.dim() == 4 and label.dim() == 3
        assert flat_case or dense_case, \
            'Only pred shape [N, C], label shape [N] or pred shape ' \
            '[N, C, H, W], label shape [N, H, W] are supported'
        label, weight = _expand_onehot_labels(label, weight, pred.shape,
                                              ignore_index)

    if weight is not None:
        weight = weight.float()

    # Element-wise loss first; weighting and reduction happen afterwards.
    elementwise = F.binary_cross_entropy_with_logits(
        pred, label.float(), pos_weight=class_weight, reduction='none')
    return weight_reduce_loss(
        elementwise, weight, reduction=reduction, avg_factor=avg_factor)

In [11]:
def mask_cross_entropy(pred,
                       target,
                       label,
                       reduction='mean',
                       avg_factor=None,
                       class_weight=None,
                       ignore_index=None):
    """Calculate the binary cross-entropy loss for per-class mask logits.

    For every sample the slice of ``pred`` belonging to that sample's class
    (``label``) is selected and compared against ``target``.

    Args:
        pred (torch.Tensor): Mask logits indexed as ``pred[i, class]``.
        target (torch.Tensor): The learning target for the selected slices.
        label (torch.Tensor): Class index of each sample, used to pick the
            mask of the class the object belongs to when the prediction is
            not class-agnostic.
        reduction (str, optional): Reserved; must be "mean".
        avg_factor (int, optional): Reserved; must be None.
        class_weight (torch.Tensor, optional): Forwarded as ``weight`` to
            ``F.binary_cross_entropy_with_logits``.
        ignore_index (None): Placeholder, to be consistent with other loss.
            Must be None.

    Returns:
        torch.Tensor: The mean loss as a tensor of shape (1,).
    """
    assert ignore_index is None, 'BCE loss does not support ignore_index'
    # TODO: handle these two reserved arguments
    assert reduction == 'mean' and avg_factor is None

    roi_inds = torch.arange(
        0, pred.size()[0], dtype=torch.long, device=pred.device)
    # Pick, for every sample, the slice of its own class.
    selected = pred[roi_inds, label].squeeze(1)
    mean_loss = F.binary_cross_entropy_with_logits(
        selected, target, weight=class_weight, reduction='mean')
    # Add a leading axis so the scalar becomes a 1-element tensor.
    return mean_loss[None]

In [12]:
def compute_cross_entropy_loss(cls_score,
                               label,
                               *,
                               use_sigmoid=False,
                               use_mask=False,
                               class_weight=None,
                               loss_weight=1.0,
                               weight=None,
                               avg_factor=None,
                               reduction='mean',
                               reduction_override=None,
                               ignore_index=-100,
                               **kwargs):
    """
    Functional equivalent of CrossEntropyLoss class forward method.

    Dispatches to ``binary_cross_entropy`` (``use_sigmoid``),
    ``mask_cross_entropy`` (``use_mask``) or ``cross_entropy`` (default) and
    scales the result by ``loss_weight``.

    Args:
        cls_score (Tensor): Prediction logits.
        label (Tensor): Ground-truth labels. In mask mode this is the mask
            target.
        use_sigmoid (bool): Use sigmoid + BCE.
        use_mask (bool): Use mask cross-entropy.
        class_weight (list[float] | Tensor | None): Per-class weight.
            Non-tensor values are converted with ``cls_score.new_tensor``.
        loss_weight (float): Scalar multiplier on the loss.
        weight (Tensor | None): Sample-wise weighting. In mask mode this
            slot instead carries the per-sample class indices that select
            the mask channel, because ``mask_cross_entropy`` takes
            ``(pred, target, label)`` and has no weight argument.
        avg_factor (float | None): Averaging factor for mean reduction.
        reduction (str): 'none' | 'mean' | 'sum'.
        reduction_override (str | None): Overrides the reduction method.
        ignore_index (int): Label index to ignore. Not forwarded in mask
            mode, where ``mask_cross_entropy`` only accepts None.
        kwargs: Passed to the specific loss function.

    Returns:
        Tensor: Computed loss.

    Raises:
        ValueError: If ``use_mask`` is set but ``weight`` (the class
            indices) is None.
    """
    assert not (use_sigmoid and use_mask), \
        "Cannot use both sigmoid and mask mode."

    reduction = reduction_override if reduction_override else reduction

    # Convert class weights to tensor if needed
    if class_weight is not None and not torch.is_tensor(class_weight):
        class_weight = cls_score.new_tensor(class_weight)

    if use_mask:
        # mask_cross_entropy has a different signature from the other two
        # losses: it has no ``weight`` keyword and asserts that
        # ``ignore_index`` is None. Passing the generic keyword set would
        # always raise, so dispatch it separately.
        if weight is None:
            raise ValueError(
                'use_mask=True requires `weight` to carry the class index '
                'of each sample')
        return loss_weight * mask_cross_entropy(
            cls_score,
            label,
            weight,
            reduction=reduction,
            avg_factor=avg_factor,
            class_weight=class_weight,
            ignore_index=None,
            **kwargs
        )

    loss_fn = binary_cross_entropy if use_sigmoid else cross_entropy
    return loss_weight * loss_fn(
        cls_score,
        label,
        weight=weight,
        class_weight=class_weight,
        reduction=reduction,
        avg_factor=avg_factor,
        ignore_index=ignore_index,
        **kwargs
    )


#### Archs

##### Fcn_arch

In [15]:
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule #, normal_init
#from mmseg.ops import resize

  from torch.distributed.optim import \


In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

def init_decode_head(in_channels,
                     channels,
                     num_classes,
                     dropout_ratio=0.1,
                     conv_cfg=None,
                     norm_cfg=dict(type='BN'),
                     act_cfg=dict(type='ReLU'),
                     in_index=-1,
                     input_transform=None,
                     ignore_index=255,
                     align_corners=False):
    """
    Create the shared decode-head state as a plain dictionary.

    The input channel spec is validated/merged by ``process_inputs``; a 1x1
    classification conv (``channels`` -> ``num_classes``) and an optional
    ``Dropout2d`` (only when ``dropout_ratio > 0``) are instantiated. All
    other arguments are stored unchanged.

    Returns:
        dict containing initialized layers and parameters.
    """
    resolved_in_channels = process_inputs(in_channels, in_index,
                                          input_transform)

    head = dict(
        in_channels=resolved_in_channels,
        channels=channels,
        num_classes=num_classes,
        dropout_ratio=dropout_ratio,
        conv_cfg=conv_cfg,
        norm_cfg=norm_cfg,
        act_cfg=act_cfg,
        in_index=in_index,
        input_transform=input_transform,
        ignore_index=ignore_index,
        align_corners=align_corners,
    )
    # Layers: per-pixel classifier and (optionally) channel-wise dropout.
    head['conv_seg'] = nn.Conv2d(channels, num_classes, kernel_size=1)
    head['dropout'] = (
        nn.Dropout2d(dropout_ratio) if dropout_ratio > 0 else None)
    return head

def process_inputs(in_channels, in_index, input_transform):
    """
    Validate the input spec and return the effective input channels.

    Without a transform both ``in_channels`` and ``in_index`` must be ints
    and ``in_channels`` is returned as is. With 'resize_concat' the channel
    counts are summed; with 'multiple_select' the sequence is returned
    unchanged.
    """
    if input_transform is None:
        # Single-feature-map mode.
        assert isinstance(in_channels, int)
        assert isinstance(in_index, int)
        return in_channels

    # Multi-feature-map mode: parallel sequences of channels and indices.
    assert input_transform in ['resize_concat', 'multiple_select']
    assert isinstance(in_channels, (list, tuple))
    assert isinstance(in_index, (list, tuple))
    assert len(in_channels) == len(in_index)
    if input_transform == 'resize_concat':
        return sum(in_channels)
    return in_channels

def transform_inputs(inputs, config):
    """
    Select (and optionally merge) feature maps for the decode head.

    Reads ``input_transform``, ``in_index`` and, for 'resize_concat',
    ``align_corners`` from ``config``:

    * 'resize_concat': the selected maps are bilinearly resized to the
      spatial size of the first selected map and concatenated on dim 1.
    * 'multiple_select': the selected maps are returned as a list.
    * anything else: ``inputs[in_index]`` is returned.
    """
    mode = config['input_transform']

    if mode == 'multiple_select':
        return [inputs[idx] for idx in config['in_index']]

    if mode != 'resize_concat':
        return inputs[config['in_index']]

    selected = [inputs[idx] for idx in config['in_index']]
    resized = []
    for feat in selected:
        resized.append(
            F.interpolate(
                feat,
                size=selected[0].shape[2:],
                mode='bilinear',
                align_corners=config['align_corners']))
    return torch.cat(resized, dim=1)

def classify_pixels(feat, config):
    """
    Run the optional dropout followed by the ``conv_seg`` classifier.

    Both layers are looked up in ``config``; dropout is skipped when the
    stored value is falsy (e.g. None).
    """
    dropout = config['dropout']
    classifier = config['conv_seg']
    if dropout:
        feat = dropout(feat)
    return classifier(feat)


In [17]:
def init_fcn_head(num_convs=2, kernel_size=3, concat_input=True, base_config=None):
    """
    Create the FCN head state (conv stack plus optional fusion conv).

    Args:
        num_convs (int): Number of convolutions in the head. With 0 the
            stack is an identity and ``in_channels`` must equal ``channels``.
        kernel_size (int): Kernel size for convolutions.
        concat_input (bool): Whether to concatenate input with output before
            classification; if so a fusion conv taking
            ``in_channels + channels`` channels is created.
        base_config (dict): Configuration from BaseDecodeHead; its
            'in_channels' and 'channels' entries are read here.

    Returns:
        dict containing initialized layers and parameters.
    """
    assert num_convs >= 0, "Number of convolutions must be non-negative"

    if num_convs == 0:
        assert base_config['in_channels'] == base_config['channels'], "in_channels must match channels when num_convs=0"

    head = {
        'num_convs': num_convs,
        'kernel_size': kernel_size,
        'concat_input': concat_input,
        'base_config': base_config,
        'convs': build_convs(num_convs, kernel_size, base_config),
        'conv_cat': None,
    }

    if concat_input:
        # Fuses the raw input with the conv-stack output.
        fused_channels = base_config['in_channels'] + base_config['channels']
        head['conv_cat'] = nn.Conv2d(
            fused_channels,
            base_config['channels'],
            kernel_size=kernel_size,
            padding=kernel_size // 2
        )

    return head

def build_convs(num_convs, kernel_size, base_config):
    """
    Stack ``num_convs`` same-padded Conv2d layers into an ``nn.Sequential``.

    The first layer maps ``base_config['in_channels']`` to
    ``base_config['channels']``; the rest keep ``channels``. Returns
    ``nn.Identity()`` when ``num_convs`` is 0.
    """
    if num_convs == 0:
        return nn.Identity()

    in_ch = base_config['in_channels']
    mid_ch = base_config['channels']
    pad = kernel_size // 2  # keeps spatial size for odd kernels

    stack = [nn.Conv2d(in_ch, mid_ch, kernel_size, padding=pad)]
    stack.extend(
        nn.Conv2d(mid_ch, mid_ch, kernel_size, padding=pad)
        for _ in range(num_convs - 1))
    return nn.Sequential(*stack)

def fcn_head_forward(inputs, config):
    """
    Forward pass for FCNHead.

    Selects the input features, runs the conv stack, optionally fuses the
    result with the raw input through ``conv_cat`` and finally classifies
    each pixel.

    Args:
        inputs: Features handed to ``transform_inputs``.
        config (dict): Configuration dictionary produced for the FCN head;
            must contain 'base_config', 'convs', 'concat_input' and, when
            concatenation is enabled, 'conv_cat'.

    Returns:
        torch.Tensor: Output segmentation logits.
    """
    base = config['base_config']
    feats = transform_inputs(inputs, base)
    hidden = config['convs'](feats)

    if config['concat_input']:
        # Skip connection: input features are concatenated on the channel
        # axis before the fusion conv.
        hidden = config['conv_cat'](torch.cat([feats, hidden], dim=1))

    return classify_pixels(hidden, base)


In [18]:
import torch
import torch.nn as nn

def init_multihead_fcn(in_channels, channels, num_classes, num_convs=2, kernel_size=3, 
                       concat_input=True, num_heads=18, dropout_ratio=0.1, base_config=None):
    """
    Create the state of a multi-head FCN as a plain dictionary.

    Args:
        in_channels (int): Input channels.
        channels (int): Internal feature channels.
        num_classes (int): Number of segmentation classes.
        num_convs (int): Number of convolutions per head.
        kernel_size (int): Convolution kernel size.
        concat_input (bool): Whether to concatenate input with output.
        num_heads (int): Number of segmentation heads.
        dropout_ratio (float): Dropout probability; no dropout layer is
            created unless it is positive.
        base_config (dict): Configuration from BaseDecodeHead (stored
            unchanged).

    Returns:
        dict containing initialized layers and parameters.
    """
    assert num_convs >= 0, "Number of convolutions must be non-negative"

    # One dropout layer is shared by every head.
    shared_dropout = None
    if dropout_ratio > 0:
        shared_dropout = nn.Dropout2d(dropout_ratio)

    classifiers = build_multihead_classifiers(num_heads, channels, num_classes)
    branches, fusions = build_multihead_convs(
        num_heads, num_convs, kernel_size, in_channels, channels, concat_input)

    return dict(
        in_channels=in_channels,
        channels=channels,
        num_classes=num_classes,
        num_convs=num_convs,
        kernel_size=kernel_size,
        concat_input=concat_input,
        num_heads=num_heads,
        dropout=shared_dropout,
        conv_heads=classifiers,
        convs=branches,
        conv_cats=fusions,
        base_config=base_config,
    )

def build_multihead_classifiers(num_heads, channels, num_classes):
    """Return a ModuleList of ``num_heads`` independent 1x1 classifier convs."""
    heads = nn.ModuleList()
    for _ in range(num_heads):
        heads.append(nn.Conv2d(channels, num_classes, kernel_size=1))
    return heads

def build_multihead_convs(num_heads, num_convs, kernel_size, in_channels, channels, concat_input):
    """Build one conv branch (and optional fusion conv) per head.

    Args:
        num_heads (int): Number of branches to create.
        num_convs (int): Convs per branch; with 0 (or less) the branch is an
            ``nn.Identity`` and no conv layer is allocated for it.
        kernel_size (int): Kernel size; padding is ``kernel_size // 2``.
        in_channels (int): Channels entering each branch.
        channels (int): Channels produced by each branch.
        concat_input (bool): Whether to create, per head, a fusion conv that
            takes ``in_channels + channels`` channels.

    Returns:
        tuple[nn.ModuleList, nn.ModuleList]: The branches and the fusion
            convs (the latter is empty when ``concat_input`` is False).
    """
    padding = kernel_size // 2
    branches = []
    fusions = []

    for _ in range(num_heads):
        if num_convs > 0:
            layers = [nn.Conv2d(in_channels, channels, kernel_size, padding=padding)]
            layers.extend(
                nn.Conv2d(channels, channels, kernel_size, padding=padding)
                for _ in range(num_convs - 1))
            branches.append(nn.Sequential(*layers))
        else:
            # Mirror build_convs: no layers are constructed (and no random
            # initialisation is consumed) for an identity branch.
            branches.append(nn.Identity())

        if concat_input:
            fusions.append(
                nn.Conv2d(in_channels + channels, channels, kernel_size, padding=padding))

    return nn.ModuleList(branches), nn.ModuleList(fusions)

def multihead_fcn_forward(inputs, config):
    """
    Forward pass for MultiHeadFCN.

    The selected input features are shared by all heads; each head applies
    its own conv branch, optional fusion conv, the optional shared dropout
    and its own classifier.

    Args:
        inputs: Features handed to ``transform_inputs``.
        config (dict): Configuration dictionary for the multi-head FCN.

    Returns:
        List[torch.Tensor]: Output segmentation logits for each head.
    """
    feats = transform_inputs(inputs, config['base_config'])
    logits_per_head = []

    head_idx = 0
    total_heads = config['num_heads']
    while head_idx < total_heads:
        hidden = config['convs'][head_idx](feats)

        if config['concat_input']:
            # Fuse the raw input back in along the channel axis.
            fused = torch.cat([feats, hidden], dim=1)
            hidden = config['conv_cats'][head_idx](fused)

        if config['dropout']:
            hidden = config['dropout'](hidden)

        logits_per_head.append(config['conv_heads'][head_idx](hidden))
        head_idx += 1

    return logits_per_head


##### unet arch

In [50]:
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_activation_layer,
                      build_norm_layer, build_upsample_layer)
#from mmcv.runner import load_checkpoint
#from mmseg.utils import get_root_logger

In [51]:
import torch
from mmcv.cnn import ConvModule, build_upsample_layer

def build_upconv_block(
    conv_block,
    in_channels,
    skip_channels,
    out_channels,
    num_convs=2,
    stride=1,
    dilation=1,
    with_cp=False,
    conv_cfg=None,
    norm_cfg=dict(type='BN'),
    act_cfg=dict(type='ReLU'),
    upsample_cfg=dict(type='InterpConv'),
    dcn=None,
    plugins=None
):
    """Create the conv block and the upsample layer of a UNet decoder stage.

    ``conv_block`` is called as a factory with ``2 * skip_channels`` input
    channels. The upsample layer maps ``in_channels`` to ``skip_channels``:
    it is built from ``upsample_cfg`` via ``build_upsample_layer``, or is a
    1x1 ``ConvModule`` when ``upsample_cfg`` is None.

    Returns:
        tuple: ``(conv_block_instance, upsample_instance)``.
    """

    assert dcn is None, 'Not implemented yet.'
    assert plugins is None, 'Not implemented yet.'

    # Settings shared by the conv block and whichever upsampler is built.
    shared_cfg = dict(norm_cfg=norm_cfg, act_cfg=act_cfg)

    block = conv_block(
        in_channels=2 * skip_channels,
        out_channels=out_channels,
        num_convs=num_convs,
        stride=stride,
        dilation=dilation,
        with_cp=with_cp,
        conv_cfg=conv_cfg,
        dcn=None,
        plugins=None,
        **shared_cfg
    )

    if upsample_cfg is None:
        # No upsampling requested: only adapt the channel count.
        upsampler = ConvModule(
            in_channels,
            skip_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            **shared_cfg
        )
    else:
        upsampler = build_upsample_layer(
            cfg=upsample_cfg,
            in_channels=in_channels,
            out_channels=skip_channels,
            with_cp=with_cp,
            **shared_cfg
        )

    return block, upsampler

def upconv_block_forward(skip, x, conv_block, upsample):
    """Upsample ``x``, concatenate it after ``skip`` on dim 1, then convolve."""
    merged = torch.cat([skip, upsample(x)], dim=1)
    return conv_block(merged)


In [52]:
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule

def build_basic_conv_block(
    in_channels,
    out_channels,
    num_convs=2,
    stride=1,
    dilation=1,
    with_cp=False,
    conv_cfg=None,
    norm_cfg=dict(type='BN'),
    act_cfg=dict(type='ReLU'),
    dcn=None,
    plugins=None
):
    """
    Assemble a plain UNet conv block out of 3x3 ``ConvModule`` layers.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        num_convs (int): Number of convolutional layers. Default: 2.
        stride (int): Stride of the first layer; later layers use 1.
            Default: 1.
        dilation (int): Dilation (and padding) of every layer except the
            first, which uses dilation 1 and padding 1. Default: 1.
        with_cp (bool): Checkpointing flag, returned unchanged.
            Default: False.
        conv_cfg (dict | None): Configuration for convolution layer. Default: None.
        norm_cfg (dict | None): Configuration for normalization layer. Default: dict(type='BN').
        act_cfg (dict | None): Configuration for activation function. Default: dict(type='ReLU').
        dcn (bool): Deformable convolution support. Not implemented. Default: None.
        plugins (dict): Plugins for conv layers. Not implemented. Default: None.

    Returns:
        nn.Sequential: A sequential module containing the convolutional layers.
        bool: Whether checkpointing is enabled.
    """
    assert dcn is None, 'Not implemented yet.'
    assert plugins is None, 'Not implemented yet.'

    layers = []
    for layer_idx in range(num_convs):
        # Only the first layer changes the channel count and may stride;
        # only the following layers are dilated.
        if layer_idx == 0:
            layer_in, layer_stride, layer_dilation, layer_pad = (
                in_channels, stride, 1, 1)
        else:
            layer_in, layer_stride, layer_dilation, layer_pad = (
                out_channels, 1, dilation, dilation)
        layers.append(
            ConvModule(
                in_channels=layer_in,
                out_channels=out_channels,
                kernel_size=3,
                stride=layer_stride,
                dilation=layer_dilation,
                padding=layer_pad,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg
            )
        )

    return nn.Sequential(*layers), with_cp


def basic_conv_block_forward(x, convs, with_cp):
    """
    Apply the conv block to ``x``, optionally through checkpointing.

    Args:
        x (Tensor): Input tensor.
        convs (nn.Sequential): Convolutional layers.
        with_cp (bool): Whether checkpointing is enabled.

    Returns:
        Tensor: Output tensor after applying convs.
    """
    # Checkpointing is only used when enabled and the input requires grad.
    use_checkpoint = with_cp and x.requires_grad
    return cp.checkpoint(convs, x) if use_checkpoint else convs(x)


In [53]:
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import build_norm_layer, build_activation_layer

def build_deconv_module(
    in_channels,
    out_channels,
    with_cp=False,
    norm_cfg=dict(type='BN'),
    act_cfg=dict(type='ReLU'),
    *,
    kernel_size=4,
    scale_factor=2
):
    """
    Create a transposed-conv upsample module (deconv -> norm -> activation).

    The transposed convolution uses ``stride = scale_factor`` and
    ``padding = (kernel_size - scale_factor) // 2``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Checkpointing flag, returned unchanged. Default: False.
        norm_cfg (dict | None): Config dict for normalization layer. Default: dict(type='BN').
        act_cfg (dict | None): Config dict for activation function. Default: dict(type='ReLU').
        kernel_size (int): Kernel size of the transposed convolution. Default: 4.
        scale_factor (int): Upsampling factor (stride). Default: 2.

    Returns:
        nn.Sequential: A sequential module for deconv -> norm -> activation.
        bool: Whether checkpointing is enabled.
    """
    excess = kernel_size - scale_factor
    assert excess >= 0 and excess % 2 == 0, (
        f'Invalid kernel/scale config: kernel_size={kernel_size}, scale_factor={scale_factor}. '
        'kernel_size must be >= scale_factor and their difference must be even.')

    deconv = nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=scale_factor,
        padding=excess // 2
    )

    # build_norm_layer returns a (name, layer) pair; only the layer is used.
    norm = build_norm_layer(norm_cfg, out_channels)[1]
    activate = build_activation_layer(act_cfg)

    return nn.Sequential(deconv, norm, activate), with_cp


def deconv_module_forward(x, module, with_cp):
    """
    Apply the deconvolution upsampling module, optionally checkpointed.

    Args:
        x (Tensor): Input tensor.
        module (nn.Sequential): Deconv -> Norm -> Activation module.
        with_cp (bool): Whether checkpointing is enabled.

    Returns:
        Tensor: Output tensor.
    """
    # Fast path: checkpointing disabled or input does not require grad.
    if not (with_cp and x.requires_grad):
        return module(x)
    return cp.checkpoint(module, x)


In [54]:
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule

def build_interp_conv(
    in_channels,
    out_channels,
    with_cp=False,
    norm_cfg=dict(type='BN'),
    act_cfg=dict(type='ReLU'),
    *,
    conv_cfg=None,
    conv_first=False,
    kernel_size=1,
    stride=1,
    padding=0,
    upsample_cfg=dict(scale_factor=2, mode='bilinear', align_corners=False)
):
    """
    Create an interpolation-based upsample module for the UNet decoder.

    The module chains an ``nn.Upsample`` with a ``ConvModule``; the order
    of the two is controlled by ``conv_first``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        with_cp (bool): Checkpointing flag, returned unchanged. Default: False.
        norm_cfg (dict): Normalization layer config. Default: dict(type='BN').
        act_cfg (dict): Activation layer config. Default: dict(type='ReLU').
        conv_cfg (dict | None): Convolution config. Default: None.
        conv_first (bool): Whether to apply convolution before upsampling. Default: False.
        kernel_size (int): Kernel size of the convolution. Default: 1.
        stride (int): Stride of the convolution. Default: 1.
        padding (int): Padding for the convolution. Default: 0.
        upsample_cfg (dict): Keyword arguments for `nn.Upsample`. Default: bilinear 2x.

    Returns:
        nn.Sequential: A sequential module (Upsample + Conv or Conv + Upsample).
        bool: Whether checkpointing is enabled.
    """
    conv_layer = ConvModule(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        conv_cfg=conv_cfg,
        norm_cfg=norm_cfg,
        act_cfg=act_cfg
    )
    interp_layer = nn.Upsample(**upsample_cfg)

    # Order the two stages according to ``conv_first``.
    stages = (conv_layer, interp_layer) if conv_first else (interp_layer, conv_layer)
    return nn.Sequential(*stages), with_cp


def interp_conv_forward(x, module, with_cp):
    """
    Apply the interpolation-based upsampling module, optionally checkpointed.

    Args:
        x (Tensor): Input tensor.
        module (nn.Sequential): Upsample + Conv or Conv + Upsample.
        with_cp (bool): Whether checkpointing is enabled.

    Returns:
        Tensor: Output tensor after upsampling and convolution.
    """
    # Choose the runner: checkpointed only when enabled and x requires grad.
    if with_cp and x.requires_grad:
        result = cp.checkpoint(module, x)
    else:
        result = module(x)
    return result


In [57]:
def build_unet(in_channels=3,
               base_channels=64,
               num_stages=5,
               strides=(1, 1, 1, 1, 1),
               enc_num_convs=(2, 2, 2, 2, 2),
               dec_num_convs=(2, 2, 2, 2),
               downsamples=(True, True, True, True),
               enc_dilations=(1, 1, 1, 1, 1),
               dec_dilations=(1, 1, 1, 1),
               with_cp=False,
               conv_cfg=None,
               norm_cfg=dict(type='BN'),
               act_cfg=dict(type='ReLU'),
               upsample_cfg=dict(type='InterpConv'),
               norm_eval=False,
               dcn=None,
               plugins=None):
    """
    Builds a U-Net encoder-decoder backbone.

    Stage ``i`` of the encoder outputs ``base_channels * 2**i`` channels. For
    every stage after the first, a matching decoder block is created that maps
    ``base_channels * 2**i`` channels back to ``base_channels * 2**(i - 1)``.

    Args:
        in_channels (int): Number of input image channels.
        base_channels (int): Base number of feature channels.
        num_stages (int): Number of encoder stages.
        strides (tuple[int]): Stride of each encoder stage.
        enc_num_convs (tuple[int]): Number of convs in each encoder stage.
        dec_num_convs (tuple[int]): Number of convs in each decoder stage.
        downsamples (tuple[bool]): Whether to downsample with MaxPool per stage.
        enc_dilations (tuple[int]): Dilation rates for encoder stages.
        dec_dilations (tuple[int]): Dilation rates for decoder stages.
        with_cp (bool): Whether to use checkpointing.
        conv_cfg (dict): Convolution config.
        norm_cfg (dict): Normalization config.
        act_cfg (dict): Activation config.
        upsample_cfg (dict): Upsample config for decoder.
        norm_eval (bool): If True, every ``nn.BatchNorm2d`` is put back into
            eval mode whenever ``train()`` is called on the returned model.
        dcn (any): Placeholder for deformable convs (not supported).
        plugins (any): Placeholder for plugin layers (not supported).

    Returns:
        nn.Module: A U-Net model with ``encoder`` and ``decoder`` module
            lists, a ``forward(x)`` method and an ``init_weights`` method.

    Raises:
        AssertionError: If ``dcn`` or ``plugins`` is given, or if a per-stage
            tuple does not have the length implied by ``num_stages``.
    """
    assert dcn is None, 'DCN not implemented'
    assert plugins is None, 'Plugins not implemented'
    assert len(strides) == num_stages
    assert len(enc_num_convs) == num_stages
    assert len(dec_num_convs) == (num_stages - 1)
    assert len(downsamples) == (num_stages - 1)
    assert len(enc_dilations) == num_stages
    assert len(dec_dilations) == (num_stages - 1)

    encoder = nn.ModuleList()
    decoder = nn.ModuleList()
    curr_in_channels = in_channels

    for i in range(num_stages):
        enc_layers = []
        if i != 0:
            # MaxPool is only inserted when the stage's own conv is not
            # strided; a strided conv already reduces the resolution.
            if strides[i] == 1 and downsamples[i - 1]:
                enc_layers.append(nn.MaxPool2d(kernel_size=2))

            # The decoder block only needs to upsample if this encoder
            # stage reduced the resolution in one of the two ways above.
            upsample = (strides[i] != 1 or downsamples[i - 1])
            # NOTE(review): assumes build_upconv_block returns an nn.Module
            # callable as block(skip, x) -- confirm against its definition.
            decoder.append(
                build_upconv_block(
                    in_channels=base_channels * 2**i,
                    skip_channels=base_channels * 2**(i - 1),
                    out_channels=base_channels * 2**(i - 1),
                    num_convs=dec_num_convs[i - 1],
                    stride=1,
                    dilation=dec_dilations[i - 1],
                    with_cp=with_cp,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    upsample_cfg=upsample_cfg if upsample else None
                )
            )

        enc_layers.append(
            build_basic_conv_block(
                in_channels=curr_in_channels,
                out_channels=base_channels * 2**i,
                num_convs=enc_num_convs[i],
                stride=strides[i],
                dilation=enc_dilations[i],
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg
            )
        )
        encoder.append(nn.Sequential(*enc_layers))
        curr_in_channels = base_channels * 2**i

    class UNetWrapper(nn.Module):
        """U-Net model holding the encoder/decoder stages built above."""

        def __init__(self):
            super().__init__()
            self.encoder = encoder
            self.decoder = decoder

        # forward is a regular method that reads self.encoder/self.decoder.
        # Binding a one-argument closure with __get__ would pass `self` as an
        # extra positional argument and raise TypeError on every call.
        def forward(self, x):
            """
            Forward pass of U-Net.

            Args:
                x (torch.Tensor): Input tensor.

            Returns:
                list[torch.Tensor]: The last encoder output followed by the
                    output of each decoder block, deepest block first.
            """
            enc_outs = []
            for enc in self.encoder:
                x = enc(x)
                enc_outs.append(x)

            dec_outs = [x]
            # decoder[i] pairs with the skip connection from encoder stage i.
            for i in reversed(range(len(self.decoder))):
                x = self.decoder[i](enc_outs[i], x)
                dec_outs.append(x)
            return dec_outs

        def init_weights(self, pretrained=None):
            """
            Initializes model weights.

            Args:
                pretrained (str or None): Path to weights or None for random
                    init.

            Raises:
                TypeError: If ``pretrained`` is neither a str nor None.
            """
            if isinstance(pretrained, str):
                # NOTE(review): torch.load unpickles the file, so only load
                # checkpoints from a trusted source.
                state_dict = torch.load(pretrained, map_location='cpu')
                if 'state_dict' in state_dict:
                    state_dict = state_dict['state_dict']
                # strict=False: missing/unexpected keys are ignored silently.
                self.load_state_dict(state_dict, strict=False)
            elif pretrained is None:
                for m in self.modules():
                    if isinstance(m, nn.Conv2d):
                        nn.init.kaiming_normal_(m.weight, mode='fan_out')
                        if m.bias is not None:
                            nn.init.constant_(m.bias, 0)
                    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                        # Norm layers built with affine=False have no
                        # weight/bias to initialize.
                        if m.weight is not None:
                            nn.init.constant_(m.weight, 1)
                        if m.bias is not None:
                            nn.init.constant_(m.bias, 0)
            else:
                raise TypeError('pretrained must be a str or None')

        def train(self, mode=True):
            """Set the training mode; keeps BatchNorm2d in eval if norm_eval.

            Returns:
                nn.Module: ``self``, matching ``nn.Module.train`` so that
                    ``model.train()`` / ``model.eval()`` can be chained.
            """
            super().train(mode)
            if norm_eval:
                for m in self.modules():
                    if isinstance(m, nn.BatchNorm2d):
                        m.eval()
            return self

    return UNetWrapper()


In [58]:
def build_attr_unet(
    in_channels=3,
    base_channels=64,
    num_stages=5,
    attr_embedding=128,
    strides=(1, 1, 1, 1, 1),
    enc_num_convs=(2, 2, 2, 2, 2),
    dec_num_convs=(2, 2, 2, 2),
    downsamples=(True, True, True, True),
    enc_dilations=(1, 1, 1, 1, 1),
    dec_dilations=(1, 1, 1, 1),
    with_cp=False,
    conv_cfg=None,
    norm_cfg=dict(type='BN'),
    act_cfg=dict(type='ReLU'),
    upsample_cfg=dict(type='InterpConv'),
    norm_eval=False,
    dcn=None,
    plugins=None
):
    """
    Builds a U-Net backbone conditioned on an attribute embedding.

    The attribute embedding is broadcast over the spatial dimensions and
    concatenated to the input of every encoder stage, so each stage is built
    with ``attr_embedding`` extra input channels. Stage ``i`` of the encoder
    outputs ``base_channels * 2**i`` channels.

    Args:
        in_channels (int): Number of input image channels.
        base_channels (int): Base number of feature channels.
        num_stages (int): Number of encoder stages.
        attr_embedding (int): Number of channels of the attribute embedding.
        strides (tuple[int]): Stride of each encoder stage.
        enc_num_convs (tuple[int]): Number of convs in each encoder stage.
        dec_num_convs (tuple[int]): Number of convs in each decoder stage.
        downsamples (tuple[bool]): Whether to downsample with MaxPool per stage.
        enc_dilations (tuple[int]): Dilation rates for encoder stages.
        dec_dilations (tuple[int]): Dilation rates for decoder stages.
        with_cp (bool): Whether to use checkpointing.
        conv_cfg (dict): Convolution config.
        norm_cfg (dict): Normalization config.
        act_cfg (dict): Activation config.
        upsample_cfg (dict): Upsample config for decoder.
        norm_eval (bool): If True, every ``nn.BatchNorm2d`` is put back into
            eval mode whenever ``train()`` is called on the returned model.
        dcn (any): Placeholder for deformable convs (not supported).
        plugins (any): Placeholder for plugin layers (not supported).

    Returns:
        nn.Module: A model with ``encoder`` and ``decoder`` module lists, a
            ``forward(x, attr_emb)`` method and an ``init_weights`` method.

    Raises:
        AssertionError: If ``dcn`` or ``plugins`` is given, or if a per-stage
            tuple does not have the length implied by ``num_stages``.
    """
    assert dcn is None, 'DCN not implemented'
    assert plugins is None, 'Plugins not implemented'
    assert len(strides) == num_stages
    assert len(enc_num_convs) == num_stages
    assert len(dec_num_convs) == (num_stages - 1)
    assert len(downsamples) == (num_stages - 1)
    assert len(enc_dilations) == num_stages
    assert len(dec_dilations) == (num_stages - 1)

    encoder = nn.ModuleList()
    decoder = nn.ModuleList()
    # Feature channels entering the current stage, NOT counting the
    # attribute embedding that forward() concatenates on top.
    curr_in_channels = in_channels

    for i in range(num_stages):
        enc_block = []
        if i != 0:
            # MaxPool is only inserted when the stage's own conv is not
            # strided; a strided conv already reduces the resolution.
            if strides[i] == 1 and downsamples[i - 1]:
                enc_block.append(nn.MaxPool2d(kernel_size=2))
            # The decoder block only needs to upsample if this encoder
            # stage reduced the resolution in one of the two ways above.
            upsample = (strides[i] != 1 or downsamples[i - 1])
            # NOTE(review): assumes build_upconv_block returns an nn.Module
            # callable as block(skip, x) -- confirm against its definition.
            decoder.append(
                build_upconv_block(
                    in_channels=base_channels * 2**i,
                    skip_channels=base_channels * 2**(i - 1),
                    out_channels=base_channels * 2**(i - 1),
                    num_convs=dec_num_convs[i - 1],
                    stride=1,
                    dilation=dec_dilations[i - 1],
                    with_cp=with_cp,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    upsample_cfg=upsample_cfg if upsample else None
                )
            )

        enc_block.append(
            build_basic_conv_block(
                # forward() concatenates the attribute embedding in front of
                # EVERY stage, so every stage needs the extra input channels
                # (not only the first one).
                in_channels=curr_in_channels + attr_embedding,
                out_channels=base_channels * 2**i,
                num_convs=enc_num_convs[i],
                stride=strides[i],
                dilation=enc_dilations[i],
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg
            )
        )
        encoder.append(nn.Sequential(*enc_block))
        curr_in_channels = base_channels * 2**i

    class AttrUNetWrapper(nn.Module):
        """Attribute-conditioned U-Net holding the stages built above."""

        def __init__(self):
            super().__init__()
            self.encoder = encoder
            self.decoder = decoder

        # forward is a regular method that reads self.encoder/self.decoder.
        # Binding a two-argument closure with __get__ would pass `self` as an
        # extra positional argument and raise TypeError on every call.
        def forward(self, x, attr_emb):
            """
            Forward pass of AttrUNet.

            Args:
                x (torch.Tensor): Input image tensor (4-D, unpacked as
                    ``_, _, H, W``).
                attr_emb (torch.Tensor): Attribute embedding (B, C_attr).

            Returns:
                list[torch.Tensor]: The last encoder output followed by the
                    output of each decoder block, deepest block first.
            """
            enc_outs = []
            B, C_attr = attr_emb.size()
            for enc in self.encoder:
                # The spatial size changes between stages, so the embedding
                # is re-broadcast to the current H x W each time.
                _, _, H, W = x.shape
                attr = attr_emb.view(B, C_attr, 1, 1).expand(-1, -1, H, W)
                x = enc(torch.cat([x, attr], dim=1))
                enc_outs.append(x)

            dec_outs = [x]
            # decoder[i] pairs with the skip connection from encoder stage i.
            for i in reversed(range(len(self.decoder))):
                x = self.decoder[i](enc_outs[i], x)
                dec_outs.append(x)

            return dec_outs

        def init_weights(self, pretrained=None):
            """
            Initialize weights.

            Args:
                pretrained (str or None): Optional path to checkpoint.

            Raises:
                TypeError: If ``pretrained`` is neither a str nor None.
            """
            if isinstance(pretrained, str):
                # NOTE(review): torch.load unpickles the file, so only load
                # checkpoints from a trusted source.
                state_dict = torch.load(pretrained, map_location='cpu')
                if 'state_dict' in state_dict:
                    state_dict = state_dict['state_dict']
                # Removes every occurrence of 'module.' from each key
                # (presumably to drop a DataParallel-style prefix).
                state_dict = {
                    k.replace('module.', ''): v
                    for k, v in state_dict.items()
                }
                # strict=False: missing/unexpected keys are ignored silently.
                self.load_state_dict(state_dict, strict=False)
            elif pretrained is None:
                for m in self.modules():
                    if isinstance(m, nn.Conv2d):
                        nn.init.kaiming_normal_(m.weight, mode='fan_out')
                        if m.bias is not None:
                            nn.init.constant_(m.bias, 0)
                    elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                        # Norm layers built with affine=False have no
                        # weight/bias to initialize.
                        if m.weight is not None:
                            nn.init.constant_(m.weight, 1)
                        if m.bias is not None:
                            nn.init.constant_(m.bias, 0)
            else:
                raise TypeError('pretrained must be a str or None')

        def train(self, mode=True):
            """Set the training mode; keeps BatchNorm2d in eval if norm_eval.

            Returns:
                nn.Module: ``self``, matching ``nn.Module.train`` so that
                    ``model.train()`` / ``model.eval()`` can be chained.
            """
            super().train(mode)
            if norm_eval:
                for m in self.modules():
                    if isinstance(m, nn.BatchNorm2d):
                        m.eval()
            return self

    return AttrUNetWrapper()
