In [None]:
#| default_exp detector
#| export
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from qct_3d_nod_detect.backbone import Backbone
from qct_3d_nod_detect.rpn import RPN3D
from qct_3d_nod_detect.roi_heads import RoIHead3D
from qct_3d_nod_detect.structures import Instances3D, Boxes3D, ImageList3D
from qct_3d_nod_detect.matcher import Matcher
from qct_3d_nod_detect.box_regression import BoxTransform3D
from qct_3d_nod_detect.poolers import ROIPooler3D

class Faster_RCNN3D(nn.Module):
    """Two-stage 3D detector: backbone -> RPN -> RoI pooler -> RoI head.

    The class only wires the stages together; each stage is injected through
    the constructor. The RoI-head loss is not implemented yet, so a training
    forward pass currently ends in ``NotImplementedError``.
    """

    def __init__(
        self,
        backbone: "Backbone",
        rpn: "RPN3D",
        pooler: "ROIPooler3D",
        roi_head: "RoIHead3D",
        matcher: "Matcher",
        box_transform: "BoxTransform3D",
        num_classes: int = 2,
    ):
        """Store the injected stages and helpers.

        Args:
            backbone: Callable applied to ``images.tensors`` to build the
                feature dict when ``forward`` receives no precomputed features.
            rpn: Callable taking ``images``, ``features`` and ``gt_instances``
                keywords and returning ``(proposals, losses)``.
            pooler: Callable taking ``x`` (list of feature maps) and
                ``box_lists``; its ``scales_order`` attribute selects the
                feature levels that are pooled.
            roi_head: Callable mapping pooled features to
                ``(cls_logits, box_deltas)``.
            matcher: Stored for the RoI loss computation (not used yet).
            box_transform: Stored for the RoI loss computation (not used yet).
            num_classes: Stored as-is; not read anywhere in this class yet.
        """
        super().__init__()
        self.backbone_fpn = backbone
        self.rpn = rpn
        self.pooler = pooler
        self.roi_head = roi_head

        self.matcher = matcher
        self.box_transform = box_transform
        self.num_classes = num_classes

    def forward(
        self,
        images: "ImageList3D",
        features: Optional[Dict[str, torch.Tensor]] = None,
        gt_instances: Optional[List["Instances3D"]] = None,
    ) -> Tuple[Dict[str, torch.Tensor], List["Instances3D"]]:
        """Run the detector on one batch.

        Args:
            images: Input batch; ``images.tensors`` is passed to the backbone
                when ``features`` is ``None``.
            features: Optional precomputed feature maps keyed by level name.
                Computed from ``images`` when omitted.
            gt_instances: Ground truth, one entry per image. Required in
                training mode, ignored in inference mode.

        Returns:
            A ``(losses, proposals)`` pair in both modes. In inference mode
            ``losses`` is empty. In training mode ``losses`` merges the RPN and
            RoI losses and ``proposals`` are the RPN proposals.

        Raises:
            ValueError: If the module is in training mode and ``gt_instances``
                is ``None``.
            NotImplementedError: In training mode, raised by
                ``_compute_roi_losses`` until the RoI loss is implemented.
        """
        # Fail early and clearly instead of deep inside the RPN / loss code.
        if self.training and gt_instances is None:
            raise ValueError("gt_instances must be provided in training mode")

        if features is None:
            features = self.backbone_fpn(images.tensors)

        proposals, rpn_losses = self.rpn(
            images=images,
            features=features,
            gt_instances=gt_instances if self.training else None,
        )

        if not self.training:
            # NOTE(review): these are the RPN outputs returned unchanged; the
            # RoI head is not applied on the inference path yet.
            return {}, proposals

        proposal_boxes = [inst.proposal_boxes for inst in proposals]

        # Fall back to the default pyramid level names when the pooler does
        # not specify an order.
        level_names = self.pooler.scales_order or ["p3", "p4", "p5"]
        pooled_features = self.pooler(
            x=[features[name] for name in level_names],
            box_lists=proposal_boxes,
        )

        cls_logits, box_deltas = self.roi_head(pooled_features)
        roi_losses = self._compute_roi_losses(
            proposal_boxes,
            gt_instances,
            cls_logits,
            box_deltas,
        )

        # Return a 2-tuple like the inference branch and the annotated type.
        return {**rpn_losses, **roi_losses}, proposals

    def _compute_roi_losses(
        self,
        proposal_boxes: List["Boxes3D"],
        gt_instances: List["Instances3D"],
        cls_logits: torch.Tensor,
        box_deltas: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """Compute the RoI-head losses (not implemented yet).

        Args:
            proposal_boxes: Proposal boxes, one entry per image.
            gt_instances: Ground truth, one entry per image; each entry is
                expected to expose ``gt_boxes`` and ``gt_classes``.
            cls_logits: Classification output of the RoI head.
            box_deltas: Box regression output of the RoI head.

        Returns:
            A dict of named loss tensors, once implemented.

        Raises:
            NotImplementedError: Always, for now.
        """
        # The previous draft gathered ``inst.gt_boxes`` / ``inst.gt_classes``,
        # called ``self.matcher.match()`` with no inputs, discarded the result
        # and fell off the end, so the caller crashed while merging ``None``
        # into the loss dict. Fail loudly until the real computation exists.
        # TODO: match ``proposal_boxes`` to the ground truth with
        # ``self.matcher`` and derive the classification / regression losses
        # (presumably using ``self.box_transform`` for the regression targets
        # -- confirm against its API).
        raise NotImplementedError(
            "RoI head loss computation is not implemented yet"
        )





ModuleNotFoundError: No module named 'qct_3d_nod_detect.backbone'