In [None]:
# | default_exp nets/detr_3d

# Imports

In [None]:
# | export


from functools import wraps

import torch
from einops import rearrange, repeat
from huggingface_hub import PyTorchModelHubMixin
from scipy.optimize import linear_sum_assignment
from torch import nn
from torch.nn import functional as F

from vision_architectures.blocks.transformer import Attention1DWithMLPConfig, TransformerDecoderBlock1D
from vision_architectures.docstrings import populate_docstring
from vision_architectures.layers.embeddings import AbsolutePositionEmbeddings3D, AbsolutePositionEmbeddings3DConfig
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.custom_base_model import CustomBaseModel, Field
from vision_architectures.utils.rearrange import rearrange_channels

# Config

In [None]:
# | export


class DETRDecoderConfig(Attention1DWithMLPConfig):
    """Configuration for ``DETRDecoder``: the attention/MLP settings inherited from ``Attention1DWithMLPConfig`` plus
    the number of stacked decoder layers."""

    # How many ``TransformerDecoderBlock1D`` layers ``DETRDecoder`` stacks.
    num_layers: int = Field(..., description="Number of transformer decoder layers.")


class DETRBBoxMLPConfig(CustomBaseModel):
    """Configuration for ``DETRBBoxMLP``, the head that maps each object embedding to ``6 + 1 + num_classes`` outputs
    (bounding box parameters followed by class logits)."""

    # Width of the object embeddings fed to the head's linear layer.
    dim: int = Field(..., description="Dimension of the input features.")
    # The head emits 1 + num_classes class logits per object (see DETRBBoxMLP).
    num_classes: int = Field(..., description="Number of classes for the bounding box predictions.")


class DETR3DConfig(DETRDecoderConfig, DETRBBoxMLPConfig, AbsolutePositionEmbeddings3DConfig):
    """Configuration for the full ``DETR3D`` model.

    Combines the decoder, bounding-box head and 3D absolute position embedding settings with the model-level options
    declared below.
    """

    # Number of learned object queries, i.e. the number of predictions produced per input.
    num_objects: int = Field(..., description="Maximum number of objects to detect.")
    # Used by DETR3D for the nn.Dropout applied right after the position embeddings are added.
    drop_prob: float = Field(0.0, description="Dropout probability for input embeddings.")

# Architecture

### Decoder

In [None]:
# | export


class DETRDecoder(nn.Module, PyTorchModelHubMixin):
    """DETR Transformer decoder.

    A stack of ``config.num_layers`` ``TransformerDecoderBlock1D`` layers. Each layer receives the running object
    embeddings together with the (unchanged) input embeddings and produces updated object embeddings.
    """

    @populate_docstring
    def __init__(self, config: DETRDecoderConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the DETRDecoder. Activation checkpointing level 4.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        # Merge once and hand the merged settings to the layers as well. Previously the layers were built from
        # ``config`` alone, so settings passed only through ``**kwargs`` never reached them.
        merged_config = config | kwargs
        self.config = DETRDecoderConfig.model_validate(merged_config)

        self.layers = nn.ModuleList(
            [TransformerDecoderBlock1D(merged_config, checkpointing_level) for _ in range(self.config.num_layers)]
        )

        self.checkpointing_level4 = ActivationCheckpointing(4, checkpointing_level)

    @populate_docstring
    def _forward(
        self, object_queries: torch.Tensor, embeddings: torch.Tensor, return_intermediates: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
        """Forward pass of the DETR3D decoder.

        Args:
            object_queries: Tokens that represent object queries. {INPUT_1D_DOC}
            embeddings: Actual embeddings of the input. {INPUT_1D_DOC}
            return_intermediates: If True, also returns the outputs of all layers. Defaults to False.

        Returns:
            If return_intermediates is True, returns the final object embeddings and a list of outputs from all layers.
            Otherwise, returns only the final object embeddings.
        """
        # object_queries: (b, num_possible_objects, dim)
        # embeddings: (b, num_embed_tokens, dim)

        object_embeddings = object_queries

        # Each layer refines the object embeddings; every layer's output is kept for intermediate supervision.
        layer_outputs = []
        for layer in self.layers:
            object_embeddings = layer(object_embeddings, embeddings)
            layer_outputs.append(object_embeddings)

        if return_intermediates:
            return object_embeddings, layer_outputs
        return object_embeddings

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        return self.checkpointing_level4(self._forward, *args, **kwargs)

In [None]:
torch.manual_seed(0)  # reproducible smoke test

test_config = {
    "attn_drop_prob": 0.2,
    "dim": 54,
    "drop_prob": 0.2,
    "embed_spacing_info": False,
    "in_channels": 1,
    "mlp_ratio": 2,
    "layer_norm_eps": 1e-6,
    "learnable_absolute_position_embeddings": False,
    "mlp_drop_prob": 0.2,
    "num_heads": 6,
    "patch_size": (8, 16, 16),
    "proj_drop_prob": 0.2,
    "num_layers": 4,
}

test = DETRDecoder(test_config)
display(test)
o = test(
    torch.randn(2, 10, 54),
    torch.randn(2, 64, 54),
    return_intermediates=True,
)
display((o[0].shape, [x.shape for x in o[1]]))

# The decoder keeps the object-query shape and returns one intermediate output per layer
assert o[0].shape == (2, 10, test_config["dim"])
assert len(o[1]) == test_config["num_layers"]
assert all(x.shape == (2, 10, test_config["dim"]) for x in o[1])


[1;35mDETRDecoder[0m[1m([0m
  [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m-[1;36m3[0m[1m)[0m: [1;36m4[0m x [1;35mTransformerDecoderBlock1D[0m[1m([0m
      [1m([0mattn1[1m)[0m: [1;35mAttention1D[0m[1m([0m
        [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mproj_drop[1m)[0m: [1;35mDropout[0m[1


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

In [None]:
# | export


class DETRBBoxMLP(nn.Module):
    """DETR Bounding Box MLP. This module predicts bounding boxes and class scores from object query embeddings.

    A single linear layer maps every object embedding to ``6 + 1 + num_classes`` values: 6 bounding box parameters,
    squashed to ``[0, 1]`` with a sigmoid, followed by ``1 + num_classes`` raw class logits (one objectness class plus
    ``num_classes`` object classes).
    """

    @populate_docstring
    def __init__(self, config: DETRBBoxMLPConfig = {}, **kwargs):
        """Initialize the DETRBBoxMLP.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = DETRBBoxMLPConfig.model_validate(config | kwargs)

        self.linear = nn.Linear(self.config.dim, 6 + 1 + self.config.num_classes)

    @populate_docstring
    def forward(
        self,
        object_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """Forward pass of the DETRBBoxMLP.

        Args:
            object_embeddings: Object embeddings from the DETR decoder. {INPUT_1D_DOC}

        Returns:
            A tensor of shape (b, num_possible_objects, 6 bounding box parameters + 1 objectness class + num_classes)
            containing the sigmoid-activated bounding box parameters followed by the raw class logits.
        """
        # object_embeddings: (b, num_possible_objects, dim)

        outputs = self.linear(object_embeddings)
        # (b, num_possible_objects, 6 + 1 + num_classes)

        # Sigmoid only the bounding box parameters; class logits stay raw for the cross-entropy based loss.
        # Built out-of-place (instead of writing into a slice of the linear output) and indexed with ``...`` so any
        # number of leading dimensions is supported.
        bboxes = torch.cat([outputs[..., :6].sigmoid(), outputs[..., 6:]], dim=-1)

        return bboxes

In [None]:
torch.manual_seed(0)  # reproducible smoke test

test_config = {
    "dim": 54,
    "num_classes": 10,
}

test = DETRBBoxMLP(test_config)
display(test)
o = test(
    torch.randn(2, 10, 54),
)
# Show the full output shape (not just the first batch element's) and one example prediction
display(o.shape, o[0][0])

assert o.shape == (2, 10, 6 + 1 + test_config["num_classes"])
# Bounding box parameters are sigmoid-activated
assert ((o[..., :6] >= 0) & (o[..., :6] <= 1)).all()


[1;35mDETRBBoxMLP[0m[1m([0m
  [1m([0mlinear[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m17[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m10[0m, [1;36m17[0m[1m][0m[1m)[0m


[1;35mtensor[0m[1m([0m[1m[[0m [1;36m0.6007[0m,  [1;36m0.5349[0m,  [1;36m0.3729[0m,  [1;36m0.6909[0m,  [1;36m0.5232[0m,  [1;36m0.5380[0m, [1;36m-0.2862[0m,  [1;36m0.1478[0m,
        [1;36m-0.9619[0m,  [1;36m0.6182[0m,  [1;36m0.1466[0m,  [1;36m0.9278[0m,  [1;36m0.2595[0m, [1;36m-0.1345[0m, [1;36m-0.6035[0m, [1;36m-0.0612[0m,
        [1;36m-0.0189[0m[1m][0m, [33mgrad_fn[0m=[1m<[0m[1;95mSelectBackward0[0m[1m>[0m[1m)[0m

# Models

In [None]:
# | export


class DETR3D(nn.Module, PyTorchModelHubMixin):
    """DETR 3D model. Also implements bipartite matching loss which is essential for DETR training."""

    @populate_docstring
    def __init__(self, config: DETR3DConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the DETR3D. Activation checkpointing level 4.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        # Merge once and hand the merged settings to every sub-module. Previously the sub-modules were built from
        # ``config`` alone, so settings passed only through ``**kwargs`` never reached them.
        merged_config = config | kwargs
        self.config = DETR3DConfig.model_validate(merged_config)

        self.embeddings = AbsolutePositionEmbeddings3D(merged_config)
        self.pos_drop = nn.Dropout(self.config.drop_prob)
        self.num_possible_objects = self.config.num_objects
        # Learned object queries with a leading batch axis of 1; repeated to the batch size in ``_forward``.
        self.object_queries = nn.Parameter(torch.randn(1, self.num_possible_objects, self.config.dim))
        self.decoder = DETRDecoder(merged_config, checkpointing_level)
        self.bbox_mlp = DETRBBoxMLP(merged_config)

        self.checkpointing_level4 = ActivationCheckpointing(4, checkpointing_level)

    @populate_docstring
    def _forward(
        self,
        embeddings: torch.Tensor,
        spacings: torch.Tensor | None = None,
        channels_first: bool = True,
        return_intermediates: bool = False,
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]:
        """Forward pass of the DETR3D.

        Args:
            embeddings: Encoded input features. {INPUT_3D_DOC}
            spacings: {SPACINGS_DOC}
            channels_first: {CHANNELS_FIRST_DOC}
            return_intermediates: If True, also returns the outputs of all layers. Defaults to False.

        Returns:
            A tuple containing bounding boxes, object embeddings, and layer outputs if return_intermediates is True.
            Else, returns only the bounding boxes.
        """
        # embeddings: (b, [dim], num_tokens_z, num_tokens_y, num_tokens_x, [dim])
        # spacings: (b, 3)

        embeddings = rearrange_channels(embeddings, channels_first, True)
        # (b, dim, num_tokens_z, num_tokens_y, num_tokens_x)

        embeddings = self.embeddings(embeddings, spacings=spacings)
        embeddings = self.pos_drop(embeddings)
        # (b, dim, num_tokens_z, num_tokens_y, num_tokens_x)

        embeddings = rearrange(embeddings, "b d z y x -> b (z y x) d")
        # (b, num_embed_tokens, dim)

        object_queries = repeat(self.object_queries, "1 n d -> b n d", b=embeddings.shape[0])
        # (b, num_possible_objects, dim)

        object_embeddings, layer_outputs = self.decoder(object_queries, embeddings, return_intermediates=True)
        # object_embeddings: (b, num_possible_objects, dim)
        # layer_outputs: list of (b, num_possible_objects, dim)

        bboxes = self.bbox_mlp(object_embeddings)
        # (b, num_possible_objects, 6 + 1 + num_classes)

        if return_intermediates:
            return bboxes, object_embeddings, layer_outputs

        return bboxes

    @wraps(_forward)
    def forward(self, *args, **kwargs):
        return self.checkpointing_level4(self._forward, *args, **kwargs)

    @staticmethod
    def bipartite_matching_loss(
        pred: torch.Tensor,
        target: torch.Tensor | list[torch.Tensor],
        classification_cost_weight: float = 1.0,
        bbox_l1_cost_weight: float = 1.0,
        bbox_iou_cost_weight: float = 1.0,
        reduction: str = "mean",
    ) -> torch.Tensor:
        """Bipartite matching loss for DETR. The classes are expected to optimize for a multi-class classification
        problem. Expects raw logits in the class part of ``pred``, not probabilities. This is what ``forward``
        returns, since it only applies a sigmoid to the 6 bounding box parameters.

        Predictions that are not matched to any target do not contribute to the loss. Batch elements without any
        target objects contribute a loss of zero.

        Args:
            pred: Predicted bounding boxes and class scores. It should be of shape
                `(B, num_objects, 6 + 1 + num_classes)`. Number of objects and number of classes will be inferred from
                here.
            target: Target bounding boxes and class labels. If provided as a list, each element should be a tensor for
                the corresponding batch element in ``pred`` and therefore should have a length of `B`. Each tensor
                should have less than or equal to the number of objects in `pred`. Its last dimension is either
                `6 + 1 + num_classes` (class scores, reduced internally with an argmax) or `6 + 1` (6 bounding box
                parameters followed by a single integer class index). The caller's list is not modified.
            classification_cost_weight: Weight for the classification cost in hungarian matching.
            bbox_l1_cost_weight: Weight for the bounding box L1 loss cost in hungarian matching.
            bbox_iou_cost_weight: Weight for the bounding box IoU cost in hungarian matching.
            reduction: Specifies the reduction to apply to the output: ``"mean"``, ``"sum"`` or ``"none"``.

        Returns:
            A tensor containing the bipartite matching loss with the shape depending on the `reduction` argument.

        Raises:
            ValueError: If ``reduction`` is invalid or the number of target tensors differs from the batch size.
        """
        if reduction not in ("mean", "sum", "none"):
            raise ValueError(f"Invalid reduction mode: {reduction}")

        B = pred.shape[0]

        # Always work on a fresh list so the caller's list is never modified in place
        target = list(target)
        if len(target) != B:
            raise ValueError(f"Expected {B} target tensors (one per batch element), got {len(target)}")

        # argmax encode the class labels if they are not already (6 bbox + 1 class label -> 7 columns)
        target = [
            (
                torch.cat([t[:, :6], t[:, 6:].argmax(-1, keepdim=True).to(t.dtype)], dim=-1)
                if t.shape[-1] > 7
                else t
            )
            for t in target
        ]

        # Perform hungarian matching
        matched_indices = DETR3D.hungarian_matching(
            pred, target, classification_cost_weight, bbox_l1_cost_weight, bbox_iou_cost_weight
        )

        losses = []
        for i, (pred_indices, target_indices) in enumerate(matched_indices):
            if len(target_indices) == 0:
                # Nothing to match: contribute zero while staying attached to the autograd graph. Calling the losses
                # below on empty tensors would produce NaN.
                losses.append(pred[i].sum() * 0.0)
                continue

            pred_index_tensor = torch.as_tensor(pred_indices, dtype=torch.long, device=pred.device)
            target_index_tensor = torch.as_tensor(target_indices, dtype=torch.long, device=target[i].device)

            matched_pred = pred[i][pred_index_tensor]
            matched_target = target[i][target_index_tensor]

            pred_bboxes = matched_pred[:, :6]
            target_bboxes = matched_target[:, :6]

            pred_classes = matched_pred[:, 6:]
            target_class_labels = matched_target[:, 6].long()

            # Compute losses for matched pairs
            # BBox L1 loss
            bbox_l1_loss = F.l1_loss(pred_bboxes, target_bboxes)

            # BBox IOU loss
            bbox_iou_loss = 1 - DETR3D._generalized_bbox_iou(pred_bboxes, target_bboxes)

            # Classification loss
            class_loss = F.cross_entropy(pred_classes, target_class_labels)

            # Total loss for this batch element
            total_loss = (
                classification_cost_weight * class_loss
                + bbox_l1_cost_weight * bbox_l1_loss
                + bbox_iou_cost_weight * bbox_iou_loss
            )
            losses.append(total_loss)

        # Stack batch losses and apply reduction
        loss = torch.stack(losses)

        if reduction == "mean":
            loss = loss.mean()
        elif reduction == "sum":
            loss = loss.sum()

        return loss

    # ``staticmethod`` must be the outermost decorator. Applying ``torch.no_grad()`` on top of a staticmethod object
    # turns the class attribute back into a plain function, which would bind ``self`` when called on an instance.
    @staticmethod
    @torch.no_grad()
    def hungarian_matching(
        pred: torch.Tensor,
        target: list[torch.Tensor],
        classification_cost_weight: float = 1.0,
        bbox_l1_cost_weight: float = 1.0,
        bbox_iou_cost_weight: float = 1.0,
    ) -> list[tuple[list[int], list[int]]]:
        """Hungarian matching between predictions and targets.

        Args:
            pred: Predicted bounding boxes and class scores. It should be of shape
                `(B, num_objects, 6 + 1 + num_classes)`. Number of objects and number of classes will be inferred from
                here.
            target: Target bounding boxes and class scores. This is in argmax encoding.
            classification_cost_weight: Weight for the classification cost.
            bbox_l1_cost_weight: Weight for the bounding box L1 loss cost.
            bbox_iou_cost_weight: Weight for the bounding box IoU cost.

        Returns:
            A list of tuples containing matched indices for predictions and targets. Each tuple is of the form
            `(pred_indices, target_indices)`, where `pred_indices` and `target_indices` are lists of indices for the
            matched predictions and targets, respectively. Batch elements without targets yield `([], [])`.
        """
        B = pred.shape[0]

        matched_indices = []
        for i in range(B):
            if target[i].shape[0] == 0:
                # Nothing to assign for this batch element
                matched_indices.append(([], []))
                continue

            pred_bboxes = pred[i, :, :6]  # (num_objects, 6)
            target_bboxes = target[i][:, :6]  # (<=num_objects, 6)

            pred_class_logits = pred[i, :, 6:]  # (num_objects, num_classes)
            target_class_labels = target[i][:, 6].long()  # (<=num_objects,) this is in argmax encoding

            # ----- Cost matrix calculation -----

            # Classification cost
            pred_class_probabilities = F.softmax(pred_class_logits, dim=-1)
            # (num_objects, num_classes)

            classification_cost = -pred_class_probabilities[:, target_class_labels]
            # (num_objects, <=num_objects)

            # L1 loss for bounding boxes
            bbox_l1_cost = torch.cdist(pred_bboxes, target_bboxes, p=1)
            # (num_objects, <=num_objects)

            # IOU cost for bounding boxes
            bbox_iou_cost = 1 - DETR3D._generalized_pairwise_bbox_iou(pred_bboxes, target_bboxes)
            # (num_objects, <=num_objects)

            # Total cost matrix
            cost_matrix = (
                classification_cost_weight * classification_cost
                + bbox_l1_cost_weight * bbox_l1_cost
                + bbox_iou_cost_weight * bbox_iou_cost
            )
            # (num_objects, <=num_objects)

            # Hungarian matching (upcast first: numpy cannot represent e.g. bfloat16 cost matrices)
            pred_indices_element, target_indices_element = linear_sum_assignment(
                cost_matrix.detach().float().cpu().numpy()
            )

            matched_indices.append((pred_indices_element.tolist(), target_indices_element.tolist()))

        return matched_indices

    @staticmethod
    def _elementwise_generalized_bbox_iou(
        pred_bboxes: torch.Tensor,
        target_bboxes: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the generalized IoU between corresponding boxes, broadcasting over all leading dimensions.

        Args:
            pred_bboxes: Predicted boxes in center format (z, y, x, d, h, w) along the last dimension.
            target_bboxes: Target boxes in the same format, broadcastable against ``pred_bboxes``.

        Returns:
            A tensor with the broadcast leading shape containing one generalized IoU value per box pair.
        """

        # Convert bboxes from center format (z, y, x, d, h, w) to (min corner, max corner)
        def center_to_corners(bboxes):
            centers = bboxes[..., :3]
            half_sizes = bboxes[..., 3:] / 2
            return centers - half_sizes, centers + half_sizes

        pred_min, pred_max = center_to_corners(pred_bboxes)
        target_min, target_max = center_to_corners(target_bboxes)

        # Intersection volume
        inter_dims = (torch.minimum(pred_max, target_max) - torch.maximum(pred_min, target_min)).clamp(min=0)
        inter_vol = inter_dims.prod(dim=-1)

        # Volumes
        pred_vol = (pred_max - pred_min).prod(dim=-1)
        target_vol = (target_max - target_min).prod(dim=-1)
        union_vol = pred_vol + target_vol - inter_vol

        # Smallest enclosing box
        enc_dims = (torch.maximum(pred_max, target_max) - torch.minimum(pred_min, target_min)).clamp(min=0)
        enc_vol = enc_dims.prod(dim=-1)

        iou = inter_vol / union_vol.clamp(min=1e-7)
        return iou - (enc_vol - union_vol) / enc_vol.clamp(min=1e-7)

    @staticmethod
    def _generalized_bbox_iou(
        pred_bboxes: torch.Tensor,
        target_bboxes: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the mean generalized IoU between two matched sets of bounding boxes.

        Args:
            pred_bboxes: Predicted bounding boxes of shape `(num_boxes, 6)` in center format (z, y, x, d, h, w).
            target_bboxes: Target bounding boxes of shape `(num_boxes, 6)` in the same format.

        Returns:
            A scalar tensor with the generalized IoU averaged over the matched pairs (``1 -`` this is the IoU loss).
        """
        return DETR3D._elementwise_generalized_bbox_iou(pred_bboxes, target_bboxes).mean()

    @staticmethod
    def _generalized_pairwise_bbox_iou(
        pred_bboxes: torch.Tensor,
        target_bboxes: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the generalized IoU between all combinations of predicted and target bounding boxes.

        Args:
            pred_bboxes: Predicted bounding boxes of shape `(num_objects, 6)`.
            target_bboxes: Target bounding boxes of shape `(<=num_objects, 6)`.

        Returns:
            A tensor of shape `(num_objects, <=num_objects)` with the generalized IoU of every combination.
        """
        # Broadcast (N, 1, 6) against (1, M, 6) instead of looping over every pair in Python
        return DETR3D._elementwise_generalized_bbox_iou(pred_bboxes[:, None, :], target_bboxes[None, :, :])

In [None]:
torch.manual_seed(0)  # reproducible smoke test

test_config = {
    "patch_size": (8, 16, 16),
    "in_channels": 1,
    "dim": 54,
    "num_heads": 6,
    "mlp_ratio": 2,
    "layer_norm_eps": 1e-6,
    "attn_drop_prob": 0.2,
    "proj_drop_prob": 0.2,
    "mlp_drop_prob": 0.2,
    "learnable_absolute_position_embeddings": True,
    "embed_spacing_info": False,
    "image_size": (32, 512, 512),
    "num_objects": 10,
    "num_classes": 5,
    "num_layers": 4,
}

test = DETR3D(test_config)
display(test)
# Input embeddings are (b, dim, z, y, x); the token grid is image_size / patch_size = (4, 32, 32)
o = test(
    torch.randn(2, test_config["dim"], 4, 32, 32),
    torch.randn(2, 3),
    return_intermediates=True,
)
display((o[0].shape, o[1].shape, [x.shape for x in o[2]]))
assert o[0].shape == (2, test_config["num_objects"], 6 + 1 + test_config["num_classes"])

for gt_bboxes in [
    [
        torch.cat([torch.rand(10, 6), torch.randint(0, 5, (10, 1))], dim=-1),
        torch.cat([torch.rand(10, 6), torch.randint(0, 5, (10, 1))], dim=-1),
    ],  # Regular testing: 6 bbox parameters + 1 integer class label
    [torch.rand(10, 12), torch.rand(2, 12)],  # Requiring argmax encoding
]:
    loss = DETR3D.bipartite_matching_loss(o[0], gt_bboxes, reduction="none")
    display(loss)
    assert loss.shape == (2,) and torch.isfinite(loss).all()


[1;35mDETR3D[0m[1m([0m
  [1m([0membeddings[1m)[0m: [1;35mAbsolutePositionEmbeddings3D[0m[1m([0m[1m)[0m
  [1m([0mpos_drop[1m)[0m: [1;35mDropout[0m[1m([0m[33mp[0m=[1;36m0[0m[1;36m.0[0m, [33minplace[0m=[3;91mFalse[0m[1m)[0m
  [1m([0mdecoder[1m)[0m: [1;35mDETRDecoder[0m[1m([0m
    [1m([0mlayers[1m)[0m: [1;35mModuleList[0m[1m([0m
      [1m([0m[1;36m0[0m-[1;36m3[0m[1m)[0m: [1;36m4[0m x [1;35mTransformerDecoderBlock1D[0m[1m([0m
        [1m([0mattn1[1m)[0m: [1;35mAttention1D[0m[1m([0m
          [1m([0mW_q[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mW_k[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout_features[0m=[1;36m54[0m, [33mbias[0m=[3;92mTrue[0m[1m)[0m
          [1m([0mW_v[1m)[0m: [1;35mLinear[0m[1m([0m[33min_features[0m=[1;36m54[0m, [33mout


[1m([0m
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m12[0m[1m][0m[1m)[0m,
    [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
    [1m[[0m
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m,
        [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m10[0m, [1;36m54[0m[1m][0m[1m)[0m
    [1m][0m
[1m)[0m

[1;35mtensor[0m[1m([0m[1m[[0m[1;36m6.0385[0m, [1;36m5.0299[0m[1m][0m, [33mgrad_fn[0m=[1m<[0m[1;95mStackBackward0[0m[1m>[0m[1m)[0m

[1;35mtensor[0m[1m([0m[1m[[0m[1;36m5.8381[0m, [1;36m3.9212[0m[1m][0m, [33mgrad_fn[0m=[1m<[0m[1;95mStackBackward0[0m[1m>[0m[1m)[0m

# nbdev

In [None]:
!nbdev_export