## Class for Candidate Location

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

class Boundary_coordinate:
    """Candidate point locations sampled on a regular grid over a square image.

    Attributes:
        grid_size: Number of grid lines generated per direction.
        image_size: Side length of the (square) image, in pixels.
        gap: Pixel spacing between neighbouring grid lines
            (``image_size // grid_size``).
        boundary_points: ``(N, 2)`` integer array of unique ``(x, y)``
            coordinates, sorted lexicographically by ``np.unique``.
        boundary_points_tensor: ``float32`` tensor copy of ``boundary_points``.
    """

    def __init__(self, grid_size=16, image_size=256):
        """
        Initialize the Boundary_coordinate generator and precompute its points.

        Parameters:
        - grid_size: The size of the grid for boundary point extraction
        - image_size: The dimensions of the image (assumed square)

        Raises:
        - ValueError: if either size is not positive, or if grid_size exceeds
          image_size (the grid spacing would be zero).
        """
        if grid_size <= 0 or image_size <= 0:
            raise ValueError(
                f"grid_size and image_size must be positive, "
                f"got grid_size={grid_size}, image_size={image_size}"
            )
        gap = image_size // grid_size
        if gap == 0:
            # A zero spacing cannot be used as a step when sweeping the image.
            raise ValueError(
                f"grid_size ({grid_size}) must not exceed image_size ({image_size})"
            )

        self.grid_size = grid_size
        self.image_size = image_size
        self.gap = gap
        self.boundary_points = self._generate_boundary_points()
        self.boundary_points_tensor = torch.tensor(self.boundary_points, dtype=torch.float32)

    def _generate_boundary_points(self):
        """
        Build the unique (x, y) points lying on the horizontal and vertical
        grid lines.

        Each of the ``grid_size`` lines sits at a multiple of ``gap`` and is
        sampled every ``gap`` pixels across the whole image extent.

        Returns:
        - boundary_points: Array of boundary points, shape (N, 2), with
          duplicate intersections removed
        """
        # Positions sampled along a line: every `gap` pixels across the image.
        sweep = np.arange(0, self.image_size, self.gap)
        # Positions of the grid lines themselves: the first `grid_size` multiples of `gap`.
        lines = np.arange(self.grid_size) * self.gap

        # Horizontal lines: x sweeps the image while y is fixed on a line.
        horizontal = np.stack(np.meshgrid(sweep, lines, indexing="ij"), axis=-1).reshape(-1, 2)
        # Vertical lines: x is fixed on a line while y sweeps the image.
        vertical = np.stack(np.meshgrid(lines, sweep, indexing="ij"), axis=-1).reshape(-1, 2)

        # Line intersections appear in both sets; keep each point once.
        return np.unique(np.concatenate((horizontal, vertical)), axis=0)

In [None]:
# Build the candidate-location generator with 32 grid lines per direction;
# image_size keeps its default of 256, so the grid spacing is 256 // 32 = 8.
creator = Boundary_coordinate(grid_size=32)

In [None]:
# Reuse the points the constructor already computed instead of calling the
# private generator a second time.
boundary_points = creator.boundary_points
print(f"boundary_points {boundary_points.shape}")

In [None]:
# Bare expression: shows the (x, y) coordinate array as the cell output.
boundary_points

## SAM Prompt Encoder with Randomly Initialized Weights (Note: pretrained weights are not used; only the point-based prompt-embedding functionality of the prompt encoder is utilized.)

In [None]:
import torch
from transformers import SamModel, SamConfig

# Set device: keep using the second GPU ("cuda:1") when it exists, fall back
# to the first GPU when only one is visible, and to the CPU otherwise.
# (Unconditionally selecting "cuda:1" fails on single-GPU machines.)
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"

# Initialize the model from scratch with the default configuration
# (no pretrained checkpoint is loaded). Modify `config` here if needed.
config = SamConfig()
SAM = SamModel(config).to(device)

# Freeze the prompt encoder's parameters so they are never updated.
for param in SAM.prompt_encoder.parameters():
    param.requires_grad = False

In [None]:
point_coordinate = torch.tensor(boundary_points)
print("point_coordinate", point_coordinate.shape)

# Represent every candidate point as a degenerate box (x, y, x, y) whose two
# corners coincide. The boxes are created as float32: the coordinates come
# from an integer array, and the prompt encoder performs floating-point
# arithmetic on its box inputs.
input_boxes = torch.cat([point_coordinate, point_coordinate], dim=1).to(
    device=device, dtype=torch.float32
)

# Insert a "boxes per item" axis: (N, 4) -> (N, 1, 4).
input_boxes = input_boxes.unsqueeze(1)
# Scale coordinates by 4 and round to whole pixels.
# NOTE(review): presumably maps the 256-px grid onto SAM's 1024-px input
# resolution -- confirm against the model config.
input_boxes = torch.round(input_boxes * 4)

print("point_coordinate", point_coordinate)
print("input_boxes", input_boxes.shape)
print("input_boxes", input_boxes)

In [None]:
# Run the prompt encoder with box prompts only; point, label and mask
# prompts are explicitly disabled. It returns a (sparse, dense) pair.
sparse_embeddings, dense_embeddings = SAM.prompt_encoder(
    input_boxes=input_boxes,
    input_points=None,
    input_labels=None,
    input_masks=None,
)

In [None]:
# Report the shapes of both encoder outputs.
print("sparse_embeddings:::::", sparse_embeddings.size())
print("dense_embeddings:::::", dense_embeddings.size())

In [None]:
# Keep index 0 along the second and third axes, leaving one embedding vector
# per candidate location.
# NOTE(review): presumably axis 1 is the box slot and axis 2 the box corner,
# so this selects the first corner of each degenerate box -- confirm.
sparse_emd = sparse_embeddings[:, 0, 0]
print("sparse_emd:::::", sparse_emd.size())

## Saving this location embedding

In [None]:
# Numeric suffix for the saved embedding file name.
i = 1

In [None]:
# `sparse_emd` is a view into the larger `sparse_embeddings` tensor, and
# torch.save serializes a view's entire backing storage. Clone first so only
# the selected embeddings are written to disk.
torch.save(sparse_emd.clone(), './Candidate_Prompt_Embedding' + str(i) + '.pt')