In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# ChArUcoNet Architecture
class ChArUcoNet(nn.Module):
    """Two-headed fully-convolutional network for ChArUco corner detection and
    corner-ID classification (encoder layout based on SuperPoint).

    The shared VGG-style encoder applies three 2x2 max-pools, so it downsamples
    the input by 8 in each spatial dimension. An input of shape ``[B, C, H, W]``
    therefore yields per-cell predictions of shape ``[B, K, H // 8, W // 8]``.

    Args:
        in_channels: Number of input image channels. Defaults to 1 (grayscale).
        num_ids: Number of distinct corner IDs. The ID head predicts
            ``num_ids + 1`` classes; the extra class is "none". Defaults to 16.

    Raises:
        ValueError: If ``in_channels`` or ``num_ids`` is smaller than 1.
    """

    # Each output cell covers a CELL_SIZE x CELL_SIZE pixel patch (8 = 2**3 from
    # the three max-pools). The detection head scores each of the 64 pixel
    # positions inside a cell plus one "no-point" dustbin class: 65 outputs.
    CELL_SIZE = 8

    def __init__(self, in_channels=1, num_ids=16):
        super().__init__()
        if in_channels < 1:
            raise ValueError(f"in_channels must be >= 1, got {in_channels}")
        if num_ids < 1:
            raise ValueError(f"num_ids must be >= 1, got {num_ids}")
        self.in_channels = in_channels
        self.num_ids = num_ids

        # Channel widths (based on SuperPoint).
        c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256

        # Layer order is kept fixed so state_dict keys (encoder.0, encoder.2, ...)
        # stay stable across versions of this class.
        self.encoder = nn.Sequential(
            # Block 1: [B, in_channels, H, W] -> [B, 64, H/2, W/2]
            nn.Conv2d(in_channels, c1, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Block 2: -> [B, 64, H/4, W/4]
            nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Block 3: -> [B, 128, H/8, W/8]
            nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # Block 4: no pooling -> [B, 128, H/8, W/8]
            nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )

        # Point detection head: 64 in-cell locations + "no-point" = 65 classes.
        self.det_head_a = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.det_head_b = nn.Conv2d(
            c5, self.CELL_SIZE ** 2 + 1, kernel_size=1, stride=1, padding=0
        )

        # ID classification head: num_ids IDs + "none".
        self.id_head_a = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.id_head_b = nn.Conv2d(c5, num_ids + 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x, return_logits=False):
        """Run the shared encoder and both prediction heads.

        Args:
            x: Image tensor of shape ``[B, in_channels, H, W]``. H and W should
                be multiples of 8. Otherwise the max-pools floor the size and the
                trailing rows/columns are ignored.
            return_logits: If True, return raw pre-softmax scores. This is what
                ``nn.CrossEntropyLoss`` expects during training. Defaults to
                False, which returns softmax probabilities over dim 1.

        Returns:
            Tuple ``(point, ids)`` with shapes ``[B, 65, H // 8, W // 8]`` and
            ``[B, num_ids + 1, H // 8, W // 8]``.
        """
        features = self.encoder(x)  # [B, 128, H/8, W/8]

        point_logits = self.det_head_b(F.relu(self.det_head_a(features)))
        id_logits = self.id_head_b(F.relu(self.id_head_a(features)))

        if return_logits:
            return point_logits, id_logits
        return F.softmax(point_logits, dim=1), F.softmax(id_logits, dim=1)


# Example Usage
def main():
    """Build an untrained ChArUcoNet and print its output shapes for one dummy image.

    The dummy image is 320 (H) x 240 (W). Both heads predict on the grid
    downsampled by 8, i.e. 40 x 30 cells.
    """
    charuco_net = ChArUcoNet()
    charuco_net.eval()  # inference mode for the demo

    input_image = torch.randn(1, 1, 320, 240)  # Dummy input: [B, C, H, W]

    # Forward pass. no_grad avoids building an autograd graph that is never used.
    with torch.no_grad():
        point_probs, id_probs = charuco_net(input_image)
    print(f"Point Probabilities Shape: {point_probs.shape}")    # [1, 65, 40, 30]
    print(f"ID Probabilities Shape: {id_probs.shape}")          # [1, 17, 40, 30]

if __name__ == "__main__":
    main()

Point Probabilities Shape: torch.Size([1, 65, 40, 30])
ID Probabilities Shape: torch.Size([1, 17, 40, 30])


In [16]:
from torchsummary import summary

model_charuco = ChArUcoNet()
# torchsummary defaults to device="cuda" and builds CUDA inputs whenever a GPU is
# available. That fails against this CPU-resident model, so pin the device.
# input_size excludes the batch dimension: (C, H, W).
summary(model_charuco, input_size=(1, 240, 320), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 240, 320]             640
              ReLU-2         [-1, 64, 240, 320]               0
            Conv2d-3         [-1, 64, 240, 320]          36,928
              ReLU-4         [-1, 64, 240, 320]               0
         MaxPool2d-5         [-1, 64, 120, 160]               0
            Conv2d-6         [-1, 64, 120, 160]          36,928
              ReLU-7         [-1, 64, 120, 160]               0
            Conv2d-8         [-1, 64, 120, 160]          36,928
              ReLU-9         [-1, 64, 120, 160]               0
        MaxPool2d-10           [-1, 64, 60, 80]               0
           Conv2d-11          [-1, 128, 60, 80]          73,856
             ReLU-12          [-1, 128, 60, 80]               0
           Conv2d-13          [-1, 128, 60, 80]         147,584
             ReLU-14          [-1, 128,