## Create toy dataset

In [45]:
import torch
from pathlib import Path
import random

In [46]:
def generate_sphere_volume(
    shape=(64, 64, 64),
    radius_range=(4, 8),
):
    """
    Returns:
        volume: (1, D, H, W)
        box: (6,) -> (cx, cy, cz, dx, dy, dz)
    """
    D, H, W = shape
    volume = torch.zeros(1, D, H, W)

    r = random.randint(*radius_range)
    cx = random.randint(r, W - r - 1)
    cy = random.randint(r, H - r - 1)
    cz = random.randint(r, D - r - 1)

    z, y, x = torch.meshgrid(
        torch.arange(D),
        torch.arange(H),
        torch.arange(W),
        indexing="ij",
    )

    mask = (x - cx) ** 2 + (y - cy) ** 2 + (z - cz) ** 2 <= r ** 2
    volume[0][mask] = 1.0

    # Bounding box in (cx, cy, cz, dx, dy, dz)
    x1, x2 = cx - r, cx + r
    y1, y2 = cy - r, cy + r
    z1, z2 = cz - r, cz + r

    box = torch.tensor(
        [
            (x1 + x2) / 2,
            (y1 + y2) / 2,
            (z1 + z2) / 2,
            x2 - x1,
            y2 - y1,
            z2 - z1,
        ],
        dtype=torch.float32,
    )

    return volume, box

def create_toy_detection_dataset(
    root_dir,
    num_train=500,
    num_val=100,
    shape=(64, 64, 64),
):
    root = Path(root_dir)

    for split, n_samples in [("train", num_train), ("val", num_val)]:
        vol_dir = root / split / "volumes"
        tgt_dir = root / split / "targets"
        vol_dir.mkdir(parents=True, exist_ok=True)
        tgt_dir.mkdir(parents=True, exist_ok=True)

        for i in range(n_samples):
            volume, box = generate_sphere_volume(shape)

            torch.save(
                volume,
                vol_dir / f"sample_{i:04d}.pt",
            )

            torch.save(
                {
                    "boxes": box.unsqueeze(0),      # (1, 6)
                    "labels": torch.tensor([0]),    # single class
                },
                tgt_dir / f"sample_{i:04d}.pt",
            )

    print(f"Dataset written to: {root}")



In [47]:
create_toy_detection_dataset(
    root_dir="/home/users/ishan.tiwari/Ishan_Nodseg/qct_3d_nod_detect/toy_dataset",
    num_train=1000,
    num_val=200,
    shape=(128,)*3
)

Dataset written to: /home/users/ishan.tiwari/Ishan_Nodseg/qct_3d_nod_detect/toy_dataset


In [33]:
import torch
import matplotlib.pyplot as plt
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display

def visualize_3d_bbox(
    root_dir,
    split="train",
    sample_name=None,
):
    """
    Interactive visualization for a 3D volume with 3D bounding box.
    
    Expected structure:
    root_dir/
        train/
            volumes/sample_xxxx.pt
            targets/sample_xxxx.pt
        val/
            volumes/sample_xxxx.pt
            targets/sample_xxxx.pt

    target format:
    {
        'boxes': tensor([[x, y, z, dx, dy, dz]]),
        'labels': tensor([class_id])
    }
    """

    root_dir = Path(root_dir)
    vol_dir = root_dir / split / "volumes"
    tgt_dir = root_dir / split / "targets"

    samples = sorted([p.stem for p in vol_dir.glob("*.pt")])
    if len(samples) == 0:
        raise RuntimeError("No samples found")

    if sample_name is None:
        sample_name = samples[0]

    vol = torch.load(vol_dir / f"{sample_name}.pt")
    tgt = torch.load(tgt_dir / f"{sample_name}.pt")

    vol = vol.squeeze().cpu()
    box = tgt["boxes"][0].cpu()  # [x, y, z, dx, dy, dz]

    x, y, z, dx, dy, dz = box
    x0, x1 = int(x - dx / 2), int(x + dx / 2)
    y0, y1 = int(y - dy / 2), int(y + dy / 2)
    z0, z1 = int(z - dz / 2), int(z + dz / 2)

    x0, y0, z0 = max(x0, 0), max(y0, 0), max(z0, 0)
    x1, y1, z1 = min(x1, vol.shape[2]-1), min(y1, vol.shape[1]-1), min(z1, vol.shape[0]-1)

    def plot_slice(axis, idx):
        plt.figure(figsize=(5, 5))

        if axis == "z":
            img = vol[idx]
            if z0 <= idx <= z1:
                plt.gca().add_patch(
                    plt.Rectangle(
                        (x0, y0),
                        x1 - x0,
                        y1 - y0,
                        fill=False,
                        edgecolor="red",
                        linewidth=2,
                    )
                )
            plt.imshow(img, cmap="gray")

        elif axis == "y":
            img = vol[:, idx, :]
            if y0 <= idx <= y1:
                plt.gca().add_patch(
                    plt.Rectangle(
                        (x0, z0),
                        x1 - x0,
                        z1 - z0,
                        fill=False,
                        edgecolor="red",
                        linewidth=2,
                    )
                )
            plt.imshow(img, cmap="gray")

        elif axis == "x":
            img = vol[:, :, idx]
            if x0 <= idx <= x1:
                plt.gca().add_patch(
                    plt.Rectangle(
                        (y0, z0),
                        y1 - y0,
                        z1 - z0,
                        fill=False,
                        edgecolor="red",
                        linewidth=2,
                    )
                )
            plt.imshow(img, cmap="gray")

        plt.title(f"{sample_name} | axis={axis} | slice={idx}")
        plt.axis("off")
        plt.show()

    axis_dd = widgets.Dropdown(
        options=["z", "y", "x"],
        value="z",
        description="Axis:",
    )

    slice_slider = widgets.IntSlider(
        min=0,
        max=vol.shape[0] - 1,
        step=1,
        value=int(z),
        description="Slice:",
        continuous_update=False,
    )

    def update_slider(*args):
        axis = axis_dd.value
        if axis == "z":
            slice_slider.max = vol.shape[0] - 1
            slice_slider.value = int(z)
        elif axis == "y":
            slice_slider.max = vol.shape[1] - 1
            slice_slider.value = int(y)
        elif axis == "x":
            slice_slider.max = vol.shape[2] - 1
            slice_slider.value = int(x)

    axis_dd.observe(update_slider, names="value")

    ui = widgets.VBox([axis_dd, slice_slider])
    out = widgets.interactive_output(
        plot_slice,
        {"axis": axis_dd, "idx": slice_slider},
    )

    display(ui, out)


# Example usage in notebook:
# visualize_3d_bbox(
#     root_dir="/path/to/dataset",
#     split="train",
#     sample_name="sample_0001"
# )


In [36]:
visualize_3d_bbox(root_dir="/home/users/ishan.tiwari/Ishan_Nodseg/qct_3d_nod_detect/toy_dataset", split="train", sample_name="sample_0011")

VBox(children=(Dropdown(description='Axis:', options=('z', 'y', 'x'), value='z'), IntSlider(value=16, continuoâ€¦

Output()

## Load the toy sphere dataset

In [17]:
from torch import nn
from qct_3d_nod_detect.structures import Instances3D, Boxes3D
import lightning.pytorch as pl
import torch

class GeneralizedRCNN3D(nn.Module):
    def __init__(
        self,
        backbone_fpn,
        rpn,
        roi_pooler,
        roi_head,
    ):
        super().__init__()
        self.backbone_fpn = backbone_fpn
        self.rpn = rpn
        self.roi_pooler = roi_pooler
        self.roi_head = roi_head

    def forward(self, images, gt_instances=None):
        """
        Args:
            images: Tensor[B, 1, D, H, W]
            gt_instances: Optional[List[Instances3D]]
        """

        # ----------------------------------
        # Wrap images as ImageList3d
        # ----------------------------------
        image_list = build_image_list_3d(images)

        # ----------------------------------
        # Backbone + FPN
        # ----------------------------------
        features = self.backbone_fpn(images)
        # Dict[str, Tensor]

        # ----------------------------------
        # RPN (uses ImageList3d + Instances3D)
        # ----------------------------------
        proposals, rpn_losses = self.rpn(
            image_list,
            features,
            gt_instances if self.training else None,
        )

        proposal_boxes = [x.proposal_boxes for x in proposals]

        # ----------------------------------
        # ROI Pooling
        # ----------------------------------
        x = list(features.values())
        roi_features = self.roi_pooler(
            x,
            proposal_boxes,
        )

        # ----------------------------------
        # ROI Head
        # ----------------------------------
        predictions = self.roi_head(roi_features)

        if self.training:
            
            roi_losses = self.roi_head.losses(predictions, proposals)

            losses = {}
            losses.update(rpn_losses)
            losses.update(roi_losses)
            return losses

        else:
            detections, _ = self.roi_head.inference(predictions, proposals)
            return detections

from qct_3d_nod_detect.structures import ImageList3D

def build_image_list_3d(images: torch.Tensor) -> ImageList3D:
    """
    Args:
        images: Tensor[B, C, D, H, W]
    """
    image_sizes = [tuple(images.shape[-3:]) for _ in range(images.shape[0])]
    return ImageList3D(image_sizes)

def build_instances_3d(batch):
    instances = []

    for boxes, classes in zip(batch["gt_boxes"], batch["gt_classes"]):
        inst = Instances3D(image_size=batch["image"].shape[-3:])
        inst.gt_boxes = Boxes3D(boxes)
        inst.gt_classes = classes
        instances.append(inst)

    return instances

# Lightning module
import lightning.pytorch as pl

class FasterRCNN3DLightning(pl.LightningModule):
    def __init__(self, model, lr=1e-4):
        super().__init__()
        self.model = model
        self.lr = lr

    def training_step(self, batch, batch_idx):
        images = batch["image"]                # Tensor[B, 1, D, H, W]
        gt_instances = build_instances_3d(batch)

        losses = self.model(images, gt_instances)

        total_loss = sum(losses.values())

        self.log_dict(
            {k: v.detach() for k, v in losses.items()},
            prog_bar=True,
        )
        self.log("loss_total", total_loss, prog_bar=True)

        return total_loss


    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr)


In [18]:
# Dataset
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import torch

class ToySphereDetectionDataset(Dataset):
    def __init__(self, root_dir, split="train"):
        self.root = Path(root_dir) / split
        self.vol_dir = self.root / "volumes"
        self.tgt_dir = self.root / "targets"

        self.ids = sorted(p.stem for p in self.vol_dir.glob("*.pt"))

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        sid = self.ids[idx]

        volume = torch.load(self.vol_dir / f"{sid}.pt")   # (1, D, H, W)
        target = torch.load(self.tgt_dir / f"{sid}.pt")

        return {
            "image": volume,                  # Tensor[1, D, H, W]
            "gt_boxes": target["boxes"],      # Tensor[N, 6]
            "gt_classes": target["labels"],   # Tensor[N]
        }
    
class ToySphereDetectionDataModule(pl.LightningDataModule):
    def __init__(
        self,
        root_dir: str,
        batch_size: int = 2,
        num_workers: int = 4,
        pin_memory: bool = True,
    ):
        super().__init__()
        self.root_dir = root_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def setup(self, stage=None):
        # Called once per process
        self.train_dataset = ToySphereDetectionDataset(
            root_dir=self.root_dir,
            split="train",
        )

        self.val_dataset = ToySphereDetectionDataset(
            root_dir=self.root_dir,
            split="val",
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=detection_collate,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            collate_fn=detection_collate,
        )

def detection_collate(batch):
    return {
        "image": torch.stack([b["image"] for b in batch], dim=0),
        "gt_boxes": [b["gt_boxes"] for b in batch],
        "gt_classes": [b["gt_classes"] for b in batch],
    }

train_dataset = ToySphereDetectionDataset(
    root_dir="/home/users/ishan.tiwari/Ishan_Nodseg/qct_3d_nod_detect/toy_dataset",
    split="train",
)

val_dataset = ToySphereDetectionDataset(
    root_dir="/home/users/ishan.tiwari/Ishan_Nodseg/qct_3d_nod_detect/toy_dataset",
    split="val",
)

datamodule = ToySphereDetectionDataModule(root_dir='/home/users/ishan.tiwari/Ishan_Nodseg/qct_3d_nod_detect/toy_dataset',
                                          batch_size=2)

In [19]:
train_dataloader = DataLoader(
    train_dataset,
    batch_size=2,          # start small
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    collate_fn=detection_collate,
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=2,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    collate_fn=detection_collate,
)


In [20]:
batch = next(iter(train_dataloader))

print(batch["image"].shape)       # (B, 1, D, H, W)
print(len(batch["gt_boxes"]))     # B
print(batch["gt_boxes"][0].shape) # (N, 6)


torch.Size([2, 1, 128, 128, 128])
2
torch.Size([1, 6])


In [21]:
from qct_3d_nod_detect.rpn import RPN3D, StandardRPNHead3d
from qct_3d_nod_detect.roi_head import FasterRCNNOutputLayers3D
from qct_3d_nod_detect.poolers import ROIPooler3D
from qct_3d_nod_detect.anchor_generator_3d import DefaultAnchorGenerator3D
from qct_3d_nod_detect.box_regression import Box3DTransform
from qct_3d_nod_detect.matcher import Matcher
from qct_3d_nod_detect.backbones import build_vit_backbone_with_fpn
import math

In [22]:
anchor_generator_3d = DefaultAnchorGenerator3D(
    sizes=[[1], [2], [4], [8]],
    aspect_ratios_3d=[[(1.0, 1.0)], [(1.0, 1.0)], [(1.0, 1.0)], [(1.0, 1.0)]],
    strides=[4, 8, 16, 32],
    offset=0.5,
)

backbone_fpn = build_vit_backbone_with_fpn(
    variant="L",
    ckpt_path="/raid15/utkarsh.singh/ctssl/experiments/mae_vit3d_l_grid384x512x512_cs128_ps8_randmask09/checkpoints/best.ckpt",
    scales=[1, 2, 0.5, 0.25],
    out_channels=256
)

box3d2box3d_transform = Box3DTransform(
    weights=(1.0, 1.0, 1.0, 1.0, 1.0, 1.0),
    scale_clamp=math.log(1000.0),
)

rpn_head_3d = StandardRPNHead3d(
    in_channels=256,
    num_anchors=anchor_generator_3d.num_cell_anchors[0],
    box_dim=6
)

anchor_matcher = Matcher(
    thresholds=[0.3, 0.7],
    labels=[0, -1, 1],
    allow_low_quality_matches=True,
)

roi_pooler = ROIPooler3D(
    output_size=(7, 7, 7),
    canonical_level=4,
    canonical_box_size=224,
    pooler_type="ROIALign3DV2",
    scales=[1, 2, 0.5, 0.25]
)

rpn = RPN3D(
    in_features=["p2", "p3", "p4", "p5"],
    head=rpn_head_3d,
    anchor_generator=anchor_generator_3d,
    anchor_matcher=anchor_matcher,
    box3d_transform=box3d2box3d_transform,
    batch_size_per_image=256,
    positive_fraction=0.5,
    pre_nms_topk=(200, 100),
    post_nms_topk=(100, 50),
    nms_thresh=0.5,
    min_box_size=2.0,
    box_reg_loss_type="smooth_l1",
    smooth_l1_beta=0.0,
    is_training=True
)

roi_head = FasterRCNNOutputLayers3D(
    input_shape=(32, 256, 7, 7, 7),
    num_classes=1,
    box2box_transform=box3d2box3d_transform,
    cls_agnostic_bbox_reg=False,
)

rcnn = GeneralizedRCNN3D(backbone_fpn=backbone_fpn,
                         rpn=rpn,
                         roi_pooler=roi_pooler,
                         roi_head=roi_head,)

lit_model = FasterRCNN3DLightning(
    model=rcnn,
    lr=1e-4
)

[32m2026-01-21 09:43:42.463[0m | [1mINFO    [0m | [36mqct_3d_nod_detect.backbones[0m:[36mload_state_dict_into_vit[0m:[36m351[0m - [1mMissing MAE keys during load: %s[0m


Missing keys: ['pos_embed'], Unexpected keys: []


In [23]:
trainer = pl.Trainer(
    max_epochs=5,
    accelerator="gpu",        # or "cpu"
    devices=1,
    precision=32,             # start without AMP
    log_every_n_steps=5,
    enable_checkpointing=False,
    enable_model_summary=True,
)

trainer.fit(
    lit_model,
    datamodule=datamodule
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/users/ishan.tiwari/miniconda3/envs/qct_nod_seg/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:68: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.
You are using a CUDA device ('NVIDIA L40S') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name  | Type              | Params | Mode 
----------------------------------------------------
0 | model | GeneralizedRCNN3D | 351 M  | train
----------------------------------------------------
347 M     Trainable params
4.2 M     Non-trainable params
351 M     Total params
1,404.989 Total estimated model params size (MB)
436       Modules in train mode
0         Modules in eval mode


Training: |          | 0/? [00:00<?, ?it/s]

AttributeError: Cannot find field 'gt_classes' in the given Instances!

In [25]:
proposals[0]

Instances3D(num_instances=32, image_height=128, image_width=128, fields=[proposal_boxes: Boxes3D(tensor([[ 81.6441,  57.4957, 113.2783,  82.4525,  58.4657, 114.3582],
        [ 89.6444,  57.5030, 113.2813,  90.4595,  58.4642, 114.3541],
        [ 81.6471,  65.5006, 113.2797,  82.4592,  66.4578, 114.3580],
        [ 89.6488,  65.5047, 113.2798,  90.4632,  66.4579, 114.3544],
        [ 33.6491,   9.5046,  65.2855,  34.4605,  10.4533,  66.3566],
        [ 89.6460,  49.5050, 113.2894,  90.4539,  50.4666, 114.3578],
        [ 81.6465,  49.4965, 113.2828,  82.4484,  50.4705, 114.3632],
        [ 81.6515,  57.5023, 105.2821,  82.4588,  58.4588, 106.3612],
        [ 73.6546,  57.4946, 113.2885,  74.4594,  58.4582, 114.3676],
        [ 89.6516,  57.5062, 105.2839,  90.4614,  58.4574, 106.3589],
        [ 11.8512,  11.9104,  42.9540,  19.2663,  18.9172,  51.5887],
        [ 74.7293,  43.4593, 107.3351,  83.1400,  51.1693, 114.9526],
        [ 75.2848,  75.7828, 107.5316,  83.4733,  83.9337, 115.

In [8]:
images = batch["image"]                # Tensor[B, 1, D, H, W]
gt_instances = build_instances_3d(batch)
features = backbone_fpn(images)

In [9]:
image_list = build_image_list_3d(images)
proposals, rpn_losses = rpn(
                            image_list,
                            features,
                            gt_instances,
                        )

proposal_boxes = [x.proposal_boxes for x in proposals]
# ----------------------------------
# ROI Pooling
# ----------------------------------
x = list(features.values())
roi_features = roi_pooler(
    x,
    proposal_boxes,
)

In [10]:
roi_features.shape

torch.Size([62, 256, 7, 7, 7])

In [11]:
predictions = roi_head(roi_features)

if True:
    roi_losses = roi_head.losses(predictions, proposals)

    losses = {}
    losses.update(rpn_losses)
    losses.update(roi_losses)

else:
    detections, _ = self.roi_head.inference(predictions, proposals)

AttributeError: Cannot find field 'gt_classes' in the given Instances!

In [14]:
predictions[0].shape

torch.Size([62, 2])

tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0