In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models.video import r3d_18  # Import the 3D ResNet-18 model
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
# Import your components from the modules you created
from def_detr_transformer import DeformableTransformerEncoder, DeformableTransformerDecoder, DeformableTransformer, DeformableTransformerEncoderLayer, DeformableTransformerDecoderLayer
from position_encoding import PositionEmbeddingSine3D, PositionEmbeddingLearned3D
from ms_deform_attn import MSDeformAttn

import numpy as np

In [7]:
class BasicBlock3D(nn.Module):
    """ResNet "basic" residual block for 3D volumes.

    Main path: 3x3x3 conv (with ``stride``) -> BatchNorm -> ReLU -> 3x3x3 conv -> BatchNorm.
    Skip path: identity when input and output shapes agree, otherwise a strided
    1x1x1 conv + BatchNorm projection. The two paths are summed and passed through ReLU.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        stride: stride of the first conv (and of the projection shortcut).
    """

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        self.conv1 = nn.Conv3d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes)
        self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm3d(planes)

        # A projection is only needed when the block changes resolution or width;
        # an empty Sequential acts as the identity.
        needs_projection = stride != 1 or in_planes != planes
        projection = (
            [nn.Conv3d(in_planes, planes, kernel_size=1, stride=stride, bias=False), nn.BatchNorm3d(planes)]
            if needs_projection
            else []
        )
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        residual = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + residual)
class ResNet3D(nn.Module):
    """3D ResNet backbone that returns the output of every stage (multi-scale features).

    A stride-2 7x7x7 stem (conv + BatchNorm + ReLU, no max-pool) is followed by four
    stages of residual blocks with 64, 128, 256 and 512 channels. The first block of
    stages 2-4 is built with stride 2, so with a block that downsamples by its stride
    (e.g. BasicBlock3D) the stage outputs are at 1/2, 1/4, 1/8 and 1/16 of the input
    resolution.

    Args:
        block: residual block class, instantiated as ``block(in_planes, planes, stride)``.
        num_blocks: sequence giving the number of blocks in each of the four stages.
        in_channels: channels of the input volume. Defaults to 1 (the previously
            hard-coded value), so existing callers are unaffected.
    """

    def __init__(self, block, num_blocks, in_channels=1):
        super(ResNet3D, self).__init__()
        # Running channel count consumed by _make_layer; it is advanced as stages are built.
        self.in_planes = 64

        self.conv1 = nn.Conv3d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the remaining ones stride 1."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ``[x1, x2, x3, x4]``, the feature maps produced by stages 1-4."""
        x = F.relu(self.bn1(self.conv1(x)))

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        return [x1, x2, x3, x4]

In [None]:
class DeformableDETR3D(nn.Module):
    """Deformable-DETR-style detector for 3D volumes.

    Pipeline: a ResNet3D backbone yields four feature maps; each is projected to
    ``hidden_dim`` channels by a 1x1x1 conv; all levels are flattened into a single
    token sequence (plus 3D sine positional embeddings) and run through a deformable
    transformer encoder; learnable object queries are decoded against the encoder
    memory and mapped to class logits and box parameters.

    Args:
        num_classes: number of object classes; the classifier emits ``num_classes + 1``
            logits (the extra one is presumably the DETR "no object" class).
        hidden_dim: transformer width (presumably must be divisible by ``n_heads``,
            as MSDeformAttn requires upstream).
        num_queries: number of learnable object queries, i.e. predictions per volume.
        n_heads: attention heads in every deformable attention layer.
        dim_feedforward: hidden size of the transformer feed-forward sublayers.
        dropout: dropout rate inside the transformer layers.
        n_points: sampling points per head and level in deformable attention.
        num_encoder_layers: number of encoder layers.
        num_decoder_layers: number of decoder layers.
    """

    def __init__(self, num_classes, hidden_dim=360, num_queries=100, n_heads=6,
                 dim_feedforward=1024, dropout=0.1, n_points=4,
                 num_encoder_layers=6, num_decoder_layers=6):
        super(DeformableDETR3D, self).__init__()

        # Backbone whose four stages output 64/128/256/512 channels.
        self.backbone = ResNet3D(BasicBlock3D, [2, 2, 2, 2])

        # One 1x1x1 projection per backbone level into the shared transformer width.
        backbone_channels = (64, 128, 256, 512)
        self.input_proj = nn.ModuleList(
            [nn.Conv3d(channels, hidden_dim, kernel_size=1) for channels in backbone_channels]
        )
        n_levels = len(self.input_proj)

        # Positional encodings for the 3D feature maps.
        self.position_embedding = PositionEmbeddingSine3D(channels=hidden_dim)

        use_cuda = torch.cuda.is_available()

        encoder_layer = DeformableTransformerEncoderLayer(
            d_model=hidden_dim, d_ffn=dim_feedforward, dropout=dropout, activation="relu",
            n_levels=n_levels, n_heads=n_heads, n_points=n_points, use_cuda=use_cuda
        )
        self.encoder = DeformableTransformerEncoder(encoder_layer, num_layers=num_encoder_layers)

        decoder_layer = DeformableTransformerDecoderLayer(
            d_model=hidden_dim, d_ffn=dim_feedforward, dropout=dropout, activation="relu",
            n_levels=n_levels, n_heads=n_heads, n_points=n_points, use_cuda=use_cuda
        )
        self.decoder = DeformableTransformerDecoder(
            decoder_layer, num_layers=num_decoder_layers, return_intermediate=True
        )

        # Prediction heads: class logits and 6 box values per query.
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 6)

        # Learnable object queries, shared across the batch.
        self.query_pos = nn.Parameter(torch.rand(num_queries, hidden_dim))

        self._reset_parameters()

    def _reset_parameters(self):
        """Re-initialize the object queries uniformly in [0, 1)."""
        nn.init.uniform_(self.query_pos)

    @staticmethod
    def _level_metadata(features):
        """Compute per-level shapes and offsets for a list of [B, C, D, H, W] maps.

        Returns:
            spatial_shapes: long tensor [n_levels, 3] with (D, H, W) of each level.
            level_start_index: long tensor [n_levels]; entry i is the offset of level i
                in the concatenated token sequence.
        """
        spatial_shapes = torch.as_tensor(
            [tuple(feat.shape[-3:]) for feat in features],
            dtype=torch.long,
            device=features[0].device,
        )
        level_sizes = spatial_shapes.prod(dim=1)
        level_start_index = torch.cat((level_sizes.new_zeros((1,)), level_sizes.cumsum(0)[:-1]))
        return spatial_shapes, level_start_index

    @staticmethod
    def _flatten_levels(maps):
        """Flatten [B, C, D, H, W] maps and concatenate them into one [B, sum(D*H*W), C] sequence."""
        return torch.cat([m.flatten(2).transpose(1, 2) for m in maps], dim=1)

    @staticmethod
    def _batch_queries(query_pos, batch_size):
        """Copy the [num_queries, C] query table for every sample -> [batch_size, num_queries, C]."""
        # repeat (a real copy) rather than expand, so downstream in-place ops cannot alias rows.
        return query_pos.unsqueeze(0).repeat(batch_size, 1, 1)

    def forward(self, inputs):
        """Run the detector on a batch of volumes.

        Args:
            inputs: tensor [batch_size, channels, D, H, W]; the channel count must match
                the backbone stem.

        Returns:
            dict with ``'pred_logits'`` ([batch_size, num_queries, num_classes + 1]) and
            ``'pred_boxes'`` ([batch_size, num_queries, 6], sigmoid-squashed into (0, 1)),
            assuming the decoder keeps the batch-first layout of its inputs.
        """
        batch_size = inputs.size(0)

        multi_scale_features = self.backbone(inputs)
        features = [proj(feat) for proj, feat in zip(self.input_proj, multi_scale_features)]
        # NOTE(review): assumes the positional embedding returns [B, hidden_dim, D, H, W]
        # per level, since it is flattened exactly like the features below.
        pos_embeds = [self.position_embedding(feat) for feat in features]

        spatial_shapes, level_start_index = self._level_metadata(features)
        src_flatten = self._flatten_levels(features)  # [B, sum(D*H*W), hidden_dim]
        pos_flatten = self._flatten_levels(pos_embeds)

        total_elements = int(spatial_shapes.prod(dim=1).sum())
        assert total_elements == src_flatten.shape[1], f"Mismatch between spatial shapes total {total_elements} and src_flatten {src_flatten.shape[1]}"

        # No padding is used, so every level is fully valid along all three axes.
        valid_ratios = src_flatten.new_ones((batch_size, len(features), 3))

        memory = self.encoder(src_flatten, spatial_shapes, level_start_index, valid_ratios, pos=pos_flatten)

        # The batch size must come from the inputs: the encoder input is batch-first,
        # so memory.shape[1] is the token count (~300k for 128^3 volumes), not the batch.
        query_embed = self._batch_queries(self.query_pos, batch_size)  # [B, num_queries, hidden_dim]

        # NOTE(review): DeformableTransformerDecoder lives in def_detr_transformer, which is
        # not visible here. In upstream Deformable DETR its forward is (tgt, reference_points,
        # src, src_spatial_shapes, src_level_start_index, src_valid_ratios, query_pos=None,
        # src_padding_mask=None). If the local copy keeps that signature, this call passes
        # `memory` as reference_points and omits an argument. Confirm, and if so add a learned
        # Linear(hidden_dim, 3) + sigmoid to produce per-query 3D reference points.
        hs, _ = self.decoder(query_embed, memory, spatial_shapes, level_start_index, valid_ratios)

        # With return_intermediate=True the decoder presumably stacks per-layer outputs
        # along a leading layer axis; predict from the last layer only.
        if hs.dim() == 4:
            hs = hs[-1]

        pred_logits = self.linear_class(hs)          # [B, num_queries, num_classes + 1]
        pred_boxes = self.linear_bbox(hs).sigmoid()  # [B, num_queries, 6]

        return {'pred_logits': pred_logits, 'pred_boxes': pred_boxes}

# Usage example: one forward pass to check the output shapes.
torch.manual_seed(0)  # reproducible random weights and random input

num_classes = 5  # Number of object classes
model = DeformableDETR3D(num_classes, hidden_dim=360)
model.eval()  # inference mode: dropout off, BatchNorm uses running statistics

# [batch_size, channels, depth, height, width]. A 128^3 volume yields
# 64^3 + 32^3 + 16^3 + 8^3 = 299,520 encoder tokens, so this pass is slow and
# memory-hungry on CPU.
inputs = torch.randn(1, 1, 128, 128, 128)

# A shape check needs no gradients; disabling autograd avoids keeping
# activations for every encoder token.
with torch.no_grad():
    outputs = model(inputs)

print(outputs['pred_logits'].shape)  # expected [batch_size, num_queries, num_classes + 1]
print(outputs['pred_boxes'].shape)   # expected [batch_size, num_queries, 6]



src_flatten shape: torch.Size([1, 299520, 360])
spatial_shapes: tensor([[64, 64, 64],
        [32, 32, 32],
        [16, 16, 16],
        [ 8,  8,  8]])
level_start_index: tensor([     0, 262144, 294912, 299008])
valid_ratios shape: torch.Size([1, 4, 3])
pos_flatten shape: torch.Size([1, 299520, 360])
