## Class for Candidate Location

In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt

class Boundary_coordinate:
    def __init__(self, grid_size=16, image_size=256):
        """
        Initialize the BoundaryMaskCreator.

        Parameters:
        - grid_size: The size of the grid for boundary point extraction
        - image_size: The dimensions of the image (assumed square)
        """
        self.grid_size = grid_size
        self.image_size = image_size
        self.gap = image_size // grid_size
        self.boundary_points = self._generate_boundary_points()
        self.boundary_points_tensor = torch.tensor(self.boundary_points, dtype=torch.float32)

    def _generate_boundary_points(self):
        """
         
        
        Returns:
        - boundary_points: Array of boundary points
        """
        boundary_points = []
        for i in range(self.grid_size):
            # Horizontal lines
            y = i * self.gap
            for x in range(0, self.image_size, self.gap):
                boundary_points.append((x, y))
            
            # Vertical lines
            x = i * self.gap
            for y in range(0, self.image_size, self.gap):
                boundary_points.append((x, y))
        
        boundary_points = np.array(boundary_points)
        boundary_points = np.unique(boundary_points, axis=0)
        return boundary_points

In [2]:
creator = Boundary_coordinate(grid_size=32)

In [3]:
boundary_points = creator._generate_boundary_points()
print(f"boundary_points {boundary_points.shape}")

boundary_points (1024, 2)


In [4]:
boundary_points

array([[  0,   0],
       [  0,   8],
       [  0,  16],
       ...,
       [248, 232],
       [248, 240],
       [248, 248]])

In [5]:
import torch
import torch.nn as nn
import numpy as np

class SimplePromptEncoder(nn.Module):
    def __init__(self, embed_dim=384, num_pos_feats=128, input_image_size=(256, 256), num_boxes=2):
        """
        A simpler version of PromptEncoder for encoding bounding box coordinates.

        Arguments:
        embed_dim -- Dimension of the embedding (e.g., 384)
        num_pos_feats -- Number of positional features (e.g., 128)
        input_image_size -- Size of the input image (height, width)
        num_boxes -- Number of key points used for encoding (default: 2 for top-left and bottom-right corners)
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.num_boxes = num_boxes  # Usually 2 for two corners

        # Positional embedding matrix
        self.register_buffer("positional_embedding", torch.randn((2, num_pos_feats)) * embed_dim // 2)

        # Learnable embeddings for each box corner
        self.box_embeddings = nn.ModuleList([nn.Embedding(1, num_pos_feats * 2) for _ in range(num_boxes)])

    def forward(self, boxes: torch.Tensor) -> torch.Tensor:
        """
        Embeds bounding box coordinates.

        Arguments:
        boxes -- Tensor of shape (batch_size, num_boxes, 2) containing (x, y) coordinates.

        Returns:
        Tensor of shape (batch_size, num_boxes, embed_dim).
        """
        batch_size = boxes.shape[0]

        # Compute positional embeddings
        box_embedding = self.compute_positional_embedding(boxes)

        # Add learnable embeddings
        for i in range(self.num_boxes):
            box_embedding[:, i, :] += self.box_embeddings[i].weight

        return box_embedding.view(batch_size, -1, box_embedding.shape[-1])  # Reshape to match expected output

    def compute_positional_embedding(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Compute positional embedding for input coordinates.
    
        Arguments:
        coords -- Tensor of shape (batch_size, num_boxes, 2)
    
        Returns:
        Tensor with encoded positional information.
        """
        coords = coords.clone().to(torch.float32)  # Convert to float before division
    
        # Normalize coordinates to [0, 1] range
        height, width = self.input_image_size
        coords[:, :, 0] /= width
        coords[:, :, 1] /= height
    
        # Scale to [-1, 1] range
        coords = 2 * coords - 1
        coords = coords @ self.positional_embedding  # Apply embedding matrix
    
        # Convert to sinusoidal embeddings
        coords = 2 * np.pi * coords
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)  # Concatenate sin and cos



In [6]:
point_coordinate = torch.tensor(boundary_points)
print("point_coordinate", point_coordinate.shape)

input_boxes = torch.zeros((point_coordinate.size(0), 4), dtype=point_coordinate.dtype)

# Fill the new tensor
input_boxes[:, 0] = point_coordinate[:, 0]  # First column
input_boxes[:, 1] = point_coordinate[:, 1]  # Second column
input_boxes[:, 2] = point_coordinate[:, 0]  # Repeat first column
input_boxes[:, 3] = point_coordinate[:, 1]  # Repeat second column

coords = input_boxes.reshape(-1, 2, 2)

# coords = coords.unsqueeze(1)
print("coords", coords.shape)

point_coordinate torch.Size([1024, 2])
coords torch.Size([1024, 2, 2])


In [7]:
# Instantiate the PromptEncoder model
prompt_encoder = SimplePromptEncoder(embed_dim=256, input_image_size=(256, 256))


# Forward pass through the model
sparse_embeddings = prompt_encoder(boxes=coords)

# Print the output shape
print("Sparse embeddings shape:", sparse_embeddings.shape)


Sparse embeddings shape: torch.Size([1024, 2, 256])


In [8]:
sparse_embeddings

tensor([[[ 0.0149, -1.4758, -0.9741,  ...,  0.9460,  1.2408,  0.9130],
         [-2.0411,  0.2696,  0.9735,  ..., -1.8496,  1.4023, -0.4489]],

        [[-0.3678, -1.8584, -0.9742,  ..., -0.4367,  0.9479, -0.7941],
         [-2.4238, -0.1131,  0.9735,  ..., -3.2323,  1.1094, -2.1560]],

        [[ 0.7220, -0.7687, -0.9742,  ..., -0.7611,  0.2408, -0.0870],
         [-1.3340,  0.9767,  0.9735,  ..., -3.5567,  0.4023, -1.4489]],

        ...,

        [[ 0.9388, -0.4758, -1.8981,  ..., -0.4367,  0.6235, -0.7941],
         [-1.1172,  1.2696,  0.0496,  ..., -3.2323,  0.7850, -2.1560]],

        [[-0.6922, -2.3997, -0.0503,  ..., -0.7611,  1.1647, -0.0870],
         [-2.7482, -0.6543,  1.8974,  ..., -3.5567,  1.3262, -1.4489]],

        [[ 0.3975, -0.7687, -1.8981,  ...,  0.8698,  1.1647,  0.6201],
         [-1.6584,  0.9767,  0.0496,  ..., -1.9257,  1.3262, -0.7418]]],
       grad_fn=<ViewBackward0>)

In [9]:
sparse_emd = sparse_embeddings[:, 0, :]
print("sparse_emd:::::", sparse_emd.shape)

sparse_emd::::: torch.Size([1024, 256])


## Saving this location embedding

In [10]:
i = 1

In [11]:
torch.save(sparse_emd, './Candidate_Prompt_Embedding' + str(i) + '.pt')