In [2]:
import math

import torch
import torch.nn as nn
import torchvision

In [4]:
class HR2O_NL(nn.Module):
    def __init__(self, hidden_dim=512, kernel_size=3, mlp_1x1=False):
        super(HR2O_NL, self).__init__()

        self.hidden_dim = hidden_dim

        padding = kernel_size // 2
        self.conv_q = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, padding=padding, bias=False)
        self.conv_k = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, padding=padding, bias=False)
        self.conv_v = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, padding=padding, bias=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.3)

        self.conv = nn.Conv2d(
            hidden_dim, hidden_dim,
            1 if mlp_1x1 else kernel_size,
            padding=0 if mlp_1x1 else padding,
            bias=False
        )
        self.norm = nn.GroupNorm(1, hidden_dim, affine=True)
        self.dp = nn.Dropout(0.2)

    def forward(self, x):
        query = self.conv_q(x).unsqueeze(1)
        key = self.conv_k(x).unsqueeze(0)
        att = (query * key).sum(2) / (self.hidden_dim ** 0.5)
        att = nn.Softmax(dim=1)(att)
        value = self.conv_v(x)
        virt_feats = (att.unsqueeze(2) * value).sum(1)

        virt_feats = self.norm(virt_feats)
        virt_feats = self.lrelu(virt_feats)
        virt_feats = self.conv(virt_feats)
        virt_feats = self.dp(virt_feats)

        x = x + virt_feats
        return x


In [6]:
batch_size = 4
hidden_dim = 512
height = 32
width = 32

dummy_input = torch.randn(batch_size, hidden_dim, height, width)

# Initialize the model
model = HR2O_NL(hidden_dim=hidden_dim)

# Pass the dummy data through the model
output = model(dummy_input)

# Print the shape of the output
print(f"Output shape: {output.shape}")

Output shape: torch.Size([4, 512, 32, 32])


In [28]:
import torch
import torch.nn as nn

class SceneUnderstandor(nn.Module):
    def __init__(self, hidden_dim=512, kernel_size=3, mlp_1x1=False):
        super(SceneUnderstandor, self).__init__()

        # The HR2O_NL layer (High Order Reasoning Operator)
        self.hr2o_nl = HR2O_NL(hidden_dim=hidden_dim, kernel_size=kernel_size, mlp_1x1=mlp_1x1)

        # Convolution layer to process RGB image (to get the same hidden_dim)
        self.conv_rgb = nn.Conv2d(3, hidden_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)

        # Linear layers for ROI bounding boxes (8 points) and keypoints (84 points)
        self.fc_roi = nn.Linear(8, hidden_dim)
        self.fc_keypoints = nn.Linear(84, hidden_dim)
        self.fc_keypoints
        # A fully connected layer to integrate all features (RGB + ROI + keypoints)
        self.fc_integrated = nn.Linear(3 * hidden_dim, hidden_dim)

        # Additional layers for processing the integrated feature
        self.fc_out = nn.Linear(hidden_dim, hidden_dim)
        self.relu = nn.ReLU()

    def forward(self, rgb_image, roi_bbox, keypoints):
        # Step 1: Process the RGB image through convolution
        rgb_features = self.conv_rgb(rgb_image)  # Shape: (batch_size, hidden_dim, height, width)
        rgb_features = self.hr2o_nl(rgb_features)  # Apply the HR2O_NL layer

        print(rgb_features.shape)
        # Flatten the rgb_features tensor: (batch_size, hidden_dim, height, width) -> (batch_size, hidden_dim * height * width)
        rgb_features = rgb_features.view(rgb_features.size(0), -1)

        # Step 2: Process the ROI bounding box (8 points) and keypoints (84 points) through fully connected layers
        roi_features = self.fc_roi(roi_bbox)  # Shape: (batch_size, hidden_dim)
        keypoints_features = self.fc_keypoints(keypoints)  # Shape: (batch_size, hidden_dim)

        print(rgb_features.shape, keypoints_features.shape, roi_features.shape)
        # Step 3: Integrate the features (RGB, ROI, and keypoints)
        integrated_features = torch.cat((rgb_features, roi_features, keypoints_features), dim=1)

        print(integrated_features.shape)

        # Step 4: Pass through fully connected layers to obtain the final output features
        integrated_features = self.fc_integrated(integrated_features)
        integrated_features = self.relu(integrated_features)

        # Output layer (you can adjust the output size as needed)
        output = self.fc_out(integrated_features)

        return output


In [29]:
# Create dummy data
batch_size = 4
hidden_dim = 512
height = 32
width = 32

# Dummy RGB image (batch_size, 3, height, width)
rgb_image = torch.randn(batch_size, 3, height, width)

# Dummy ROI bounding boxes (batch_size, 8) with two sets of 4 points each
roi_bbox = torch.randn(batch_size, 8)

# Dummy keypoints (batch_size, 84) representing 42 keypoints in x, y sequence
keypoints = torch.randn(batch_size, 84)

# Initialize the model
scene_understandor = SceneUnderstandor(hidden_dim=hidden_dim)

# Pass the dummy data through the model
output = scene_understandor(rgb_image, roi_bbox, keypoints)

# Print the output shape
print(f"Output shape: {output.shape}")


torch.Size([4, 512, 32, 32])
torch.Size([4, 524288]) torch.Size([4, 512]) torch.Size([4, 512])
torch.Size([4, 525312])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x525312 and 1536x512)

In [27]:
512*32*32

524288

In [30]:
import torch
import torch.nn as nn

class HR2O_NL(nn.Module):
    def __init__(self, hidden_dim=512, kernel_size=3, mlp_1x1=False):
        super(HR2O_NL, self).__init__()

        self.hidden_dim = hidden_dim

        padding = kernel_size // 2
        self.conv_q = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, padding=padding, bias=False)
        self.conv_k = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, padding=padding, bias=False)
        self.conv_v = nn.Conv2d(hidden_dim, hidden_dim, kernel_size, padding=padding, bias=False)
        self.lrelu = nn.LeakyReLU(negative_slope=0.3)

        self.conv = nn.Conv2d(
            hidden_dim, hidden_dim,
            1 if mlp_1x1 else kernel_size,
            padding=0 if mlp_1x1 else padding,
            bias=False
        )
        self.norm = nn.GroupNorm(1, hidden_dim, affine=True)
        self.dp = nn.Dropout(0.2)

    def forward(self, x):
        query = self.conv_q(x).unsqueeze(1)
        key = self.conv_k(x).unsqueeze(0)
        att = (query * key).sum(2) / (self.hidden_dim ** 0.5)
        att = nn.Softmax(dim=1)(att)
        value = self.conv_v(x)
        virt_feats = (att.unsqueeze(2) * value).sum(1)

        virt_feats = self.norm(virt_feats)
        virt_feats = self.lrelu(virt_feats)
        virt_feats = self.conv(virt_feats)
        virt_feats = self.dp(virt_feats)

        x = x + virt_feats
        return x




In [39]:
class SceneUnderstandor(nn.Module):
    def __init__(self, hidden_dim=512, kernel_size=3, mlp_1x1=False):
        super(SceneUnderstandor, self).__init__()

        # High Order Reasoning Operator
        self.hr2o_nl = HR2O_NL(hidden_dim=hidden_dim, kernel_size=kernel_size, mlp_1x1=mlp_1x1)

        # Convolutional layer to process RGB image (to get the same hidden_dim)
        self.conv_rgb = nn.Conv2d(3, hidden_dim, kernel_size=kernel_size, padding=kernel_size // 2, bias=False)

        # Linear layers for ROI bounding boxes (8 points) and keypoints (84 points)
        self.fc_roi = nn.Linear(8, hidden_dim)  # ROI bbox to hidden_dim
        self.fc_keypoints = nn.Linear(84, hidden_dim)  # Keypoints to hidden_dim

        # A fully connected layer to integrate all features (RGB + ROI + keypoints)
        self.fc_integrated = nn.Linear(3 * hidden_dim, hidden_dim)

        # Additional layers for processing the integrated feature
        self.fc_out = nn.Linear(hidden_dim, hidden_dim)
        self.relu = nn.ReLU()

    def extract_roi_features(self, image, roi_bbox):
        """
        Extract features from the image using the ROI bounding boxes.
        The bbox is not normalized and represents raw pixel values.
        """
        batch_size, channels, height, width = image.shape
        # Assuming roi_bbox is of shape (batch_size, 8), representing a rectangular box (x1, y1, x2, y2, ...)
        roi_features = []
    
        for i in range(batch_size):
            x1, y1, x2, y2 = roi_bbox[i].tolist()[:4]  # Get the first 4 points as the bounding box
    
            # Convert indices to integers to ensure proper slicing
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
    
            # Pooling over ROI (mean pooling over the ROI region)
            roi_features.append(image[i, :, y1:y2, x1:x2].mean(dim=[1, 2]))  # Pooling over ROI
    
        roi_features = torch.stack(roi_features, dim=0)
        return roi_features  # Shape: (batch_size, hidden_dim)


    def extract_keypoints_features(self, image, keypoints):
        """
        Extract features from the image using the keypoints.
        The keypoints are not normalized.
        """
        batch_size, channels, height, width = image.shape
        keypoints_features = []

        for i in range(batch_size):
            # Assuming keypoints is a flattened list of (x, y) coordinates
            keypoints_list = keypoints[i].reshape(-1, 2).int()
            keypoint_feature_map = torch.zeros_like(image[i, 0, :, :])  # Create a blank feature map for each keypoint

            for (x, y) in keypoints_list:
                if 0 <= x < width and 0 <= y < height:
                    keypoint_feature_map[y, x] = 1  # Set pixel values at keypoint locations

            keypoints_features.append(keypoint_feature_map.unsqueeze(0))  # Keep as a single channel feature map

        keypoints_features = torch.stack(keypoints_features, dim=0)
        return keypoints_features  # Shape: (batch_size, 1, height, width)

    def forward(self, rgb_image, roi_bbox, keypoints):
        # Step 1: Process the RGB image through convolution
        rgb_features = self.conv_rgb(rgb_image)  # Shape: (batch_size, hidden_dim, height, width)

        # Step 2: Extract ROI features from the image
        roi_features = self.extract_roi_features(rgb_image, roi_bbox)  # Shape: (batch_size, hidden_dim)

        # Step 3: Extract keypoints features from the image
        keypoints_features = self.extract_keypoints_features(rgb_image, keypoints)  # Shape: (batch_size, 1, height, width)

        # Step 4: Convert roi_features to the same size as the RGB image (height, width)
        roi_features = roi_features.unsqueeze(2).unsqueeze(3).expand(-1, -1, rgb_features.size(2), rgb_features.size(3))

        # Step 5: Concatenate the RGB image with the ROI and keypoints features
        combined_features = torch.cat((rgb_features, roi_features, keypoints_features), dim=1)

        # Step 6: Apply High Order Reasoning Operator (HR2O)
        final_features = self.hr2o_nl(combined_features)

        # Step 7: Flatten the features for fully connected layers
        final_features = final_features.view(final_features.size(0), -1)

        # Step 8: Pass through fully connected layers to obtain the final output features
        integrated_features = self.fc_integrated(final_features)
        integrated_features = self.relu(integrated_features)

        # Output layer (you can adjust the output size as needed)
        output = self.fc_out(integrated_features)

        return output

In [40]:
# Create dummy data
batch_size = 4
hidden_dim = 512
height = 32
width = 32

# Dummy RGB image (batch_size, 3, height, width)
rgb_image = torch.randn(batch_size, 3, height, width)

# Dummy ROI bounding boxes (batch_size, 8) with two sets of 4 points each
roi_bbox = torch.randn(batch_size, 8)

# Dummy keypoints (batch_size, 84) representing 42 keypoints in x, y sequence
keypoints = torch.randn(batch_size, 84)

# Initialize the model
scene_understandor = SceneUnderstandor(hidden_dim=hidden_dim)

# Pass the dummy data through the model
output = scene_understandor(rgb_image, roi_bbox, keypoints)

# Print the output shape
print(f"Output shape: {output.shape}")


RuntimeError: Given groups=1, weight of size [512, 512, 3, 3], expected input[4, 516, 32, 32] to have 512 channels, but got 516 channels instead

In [41]:
scene_understandor.extract_roi_features(rgb_image, roi_bbox)

tensor([[nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan],
        [nan, nan, nan]])