![](../images/segment_anything.png)

### Prompt Encoder

In [1]:
from typing import Optional, Tuple, List
import torch
from torch import nn

In [2]:
class LayerNorm2d(nn.Module):
    """Layer normalization across the channel axis (dim 1) of a 4-D tensor.

    A ``None`` placeholder cannot be used here: ``LayerNorm2d(channels)`` is
    called while building the mask-downscaling ``nn.Sequential`` and would
    raise ``TypeError``.

    Args:
        num_channels: Size of dim 1 of the inputs; one learnable scale and
            one learnable shift are kept per channel.
        eps: Constant added to the variance before the square root to avoid
            division by zero.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` over dim 1, then apply the per-channel affine map."""
        mean = x.mean(1, keepdim=True)
        variance = (x - mean).pow(2).mean(1, keepdim=True)  # biased estimate
        normalized = (x - mean) / torch.sqrt(variance + self.eps)
        # [:, None, None] broadcasts the (C,) parameters over the two trailing axes.
        return self.weight[:, None, None] * normalized + self.bias[:, None, None]

In [None]:
class PromptEncoder:
    """Embedding routines for sparse prompts (points and boxes).

    The methods read ``self.pe_layer``, ``self.input_image_size``,
    ``self.not_a_point_embed`` and ``self.point_embeddings``. ``__init__``
    does not create any of them, so they have to be attached to the instance
    before the methods are called.
    """

    def __init__(self):
        pass

    def _embed_points(
            self,
            points: torch.Tensor,
            labels: torch.Tensor,
            pad: bool
    ) -> torch.Tensor:
        """Embed point prompts.

        Args:
            points: Point coordinates, shaped (B, N, 2).
            labels: Per-point labels, shaped (B, N): 1 = foreground,
                0 = background, -1 = padding / "not a point".
            pad: When true, one extra padding point (label -1) is appended
                to every batch entry so the sequence length stays constant.

        Returns:
            The positional encodings of the points with the label-specific
            learned embeddings added.
        """
        centered = points + 0.5  # move each coordinate to the pixel centre

        if pad:
            extra_point = torch.zeros((centered.shape[0], 1, 2), device=centered.device)
            extra_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            centered = torch.cat([centered, extra_point], dim=1)
            labels = torch.cat([labels, extra_label], dim=1)

        # Positional encoding of every (possibly padded) point.
        embedded = self.pe_layer.forward_with_coords(centered, self.input_image_size)

        # Padding entries drop their positional encoding and carry only the
        # dedicated "not a point" embedding.
        is_padding = labels == -1
        embedded[is_padding] = 0.0
        embedded[is_padding] += self.not_a_point_embed.weight

        # Index 0 holds the background embedding, index 1 the foreground one.
        for label_value in (0, 1):
            embedded[labels == label_value] += self.point_embeddings[label_value].weight

        return embedded

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embed box prompts.

        Every box is regrouped into two 2-D corner points; each corner gets
        its positional encoding plus a corner-specific learned embedding
        (index 2 for the first corner, index 3 for the second).
        """
        # Shift to the pixel centre, then split every box into two corners.
        corners = (boxes + 0.5).reshape(-1, 2, 2)
        embedded = self.pe_layer.forward_with_coords(corners, self.input_image_size)

        for corner_index, table_index in ((0, 2), (1, 3)):
            embedded[:, corner_index, :] += self.point_embeddings[table_index].weight

        return embedded

### Mask

In [None]:
class Mask:
    """Dense (mask) prompt handling: downscaling network and "no mask" embedding.

    Args:
        mask_in_chans: Number of hidden channels in the mask downscaling
            network.
        embed_dim: Channel dimension of the produced mask embeddings.
        activation: Zero-argument callable (e.g. ``nn.GELU``) returning the
            activation module placed between the convolutions.

    NOTE(review): ``__get_dense_embeddings`` reads ``self.image_embedding_size``,
    which this class never sets; it must be assigned on the instance before
    that method is called with ``masks=None``.
    """

    def __init__(self, mask_in_chans, embed_dim, activation):
        self._init_mask(mask_in_chans, embed_dim, activation)

    def __expand_image_to_mask(self, image_embeddings, tokens, dense_prompt_embeddings, image_pe):
        """Repeat per-image tensors along the batch axis, once per token batch entry.

        Returns:
            A tuple ``(src, pos_src, (b, c, h, w))`` where ``src`` is the
            repeated image embedding with the dense prompt embedding added,
            ``pos_src`` is the repeated image positional encoding and the
            last item is the shape of ``src``.
        """
        repeats = tokens.shape[0]
        # Expand per-image data in batch direction to be per-mask
        src = torch.repeat_interleave(image_embeddings, repeats, dim=0)
        src = src + dense_prompt_embeddings  # Add mask embeddings to the image
        pos_src = torch.repeat_interleave(image_pe, repeats, dim=0)
        b, c, h, w = src.shape
        # Hand the results back to the caller instead of discarding them.
        return src, pos_src, (b, c, h, w)

    def __get_dense_embeddings(self, masks: Optional[torch.Tensor], bs):
        """Return the dense prompt embedding for ``masks``.

        When ``masks`` is ``None`` the learned "no mask" vector is broadcast
        to ``(bs, embed_dim, H, W)`` with ``(H, W) = self.image_embedding_size``.
        """
        if masks is not None:
            return self._embed_masks(masks)
        # If no mask is specified, use a special "no mask" embedding
        height, width = self.image_embedding_size[0], self.image_embedding_size[1]
        return self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(bs, -1, height, width)

    def _init_mask(self, mask_in_chans, embed_dim, activation):
        """Create the mask downscaling network and the "no mask" embedding."""
        # Two stride-2 convolutions shrink each spatial side by 4x in total;
        # the final 1x1 convolution projects to ``embed_dim`` channels.
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2),
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1),
        )
        # Assignment, not comparison: ``==`` here would read a missing
        # attribute and raise AttributeError.
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Embeds mask prompts by running them through the downscaling network."""
        return self.mask_downscaling(masks)

### Mask Decoder

<img src="../images/segment_anything_mask_decoder.png" alt="drawing" width="70%"/>

In [None]:
# NOTE: scratch fragment. `self` and `sparse_prompt_embeddings` are not
# defined at cell level; the same steps live in MaskDecoder._prepare_tokens.
# Stack the IoU token and the mask tokens into one set of output tokens.
output_token = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
# Give every prompt in the batch its own view of the output tokens.
output_token = output_token.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
# Prepend the output tokens to the sparse prompt tokens along the token axis.
tokens = torch.cat([output_token, sparse_prompt_embeddings], dim=1)

In [None]:
class MaskDecoder:
    """Fragments of the mask decoder: token preparation and the layer loop.

    The methods read ``self.layers``, ``self.iou_token`` and
    ``self.mask_tokens``; ``__init__`` does not create them, so they must be
    attached to the instance before use.
    """

    def __init__(self):
        pass

    def _tmp(self, point_embedding, image_embedding, image_pe):
        """Run the prompt tokens and image embedding through every layer.

        Args:
            point_embedding: Prompt tokens; used both as the initial queries
                and as the query positional encoding for every layer.
            image_embedding: Image embedding; used as the initial keys.
            image_pe: Positional encoding passed as ``key_pe`` to every layer.

        Returns:
            The ``(queries, keys)`` pair produced by the last layer, or the
            inputs unchanged when ``self.layers`` is empty.
        """
        # Prepare queries
        queries = point_embedding
        keys = image_embedding

        # Apply the transformer blocks (no final attention / layernorm here).
        for layer in self.layers:
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,
                key_pe=image_pe,
            )
        return queries, keys

    def _prepare_tokens(self, sparse_prompt_embeddings):
        """Build the decoder input tokens.

        The IoU token and the mask tokens are stacked, broadcast over the
        batch dimension of ``sparse_prompt_embeddings`` and prepended to it
        along the token axis (dim 1).

        Returns:
            The concatenated token tensor.
        """
        output_token = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0)
        output_token = output_token.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
        return torch.cat([output_token, sparse_prompt_embeddings], dim=1)

In [None]:
class Transformer:
    """One two-way attention block between prompt tokens and the image embedding.

    ``forward`` reads ``self.skip_first_layer_pe``, ``self.self_attn``,
    ``self.cross_attn_token_to_image``, ``self.cross_attn_image_to_token``,
    ``self.mlp`` and ``self.norm1`` .. ``self.norm4``; ``__init__`` does not
    create them.
    """

    def __init__(self):
        pass

    def forward(
            self,
            queries: torch.Tensor,
            keys: torch.Tensor,
            query_pe: torch.Tensor,
            key_pe: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Update the prompt tokens and the image embedding once.

        Args:
            queries: Prompt tokens.
            keys: Image embedding.
            query_pe: Positional encoding added to the prompt tokens.
            key_pe: Positional encoding added to the image embedding.

        Returns:
            The updated ``(queries, keys)`` pair.
        """
        # Step 1: self-attention over the prompt tokens.
        if not self.skip_first_layer_pe:
            # Positional encoding goes into q and k only; values stay raw.
            tokens_with_pe = queries + query_pe
            queries = queries + self.self_attn(q=tokens_with_pe, k=tokens_with_pe, v=queries)
        else:
            # No positional encoding and no residual connection on this path.
            queries = self.self_attn(q=queries, k=queries, v=queries)
        queries = self.norm1(queries)

        # Step 2: prompt tokens attend to the image. The values are the image
        # embedding without its positional encoding.
        token_update = self.cross_attn_token_to_image(
            q=queries + query_pe,
            k=keys + key_pe,
            v=keys,
        )
        queries = self.norm2(queries + token_update)

        # Step 3: MLP on the prompt tokens, with residual connection.
        queries = self.norm3(queries + self.mlp(queries))

        # Step 4: the image attends to the prompt tokens. Roles are swapped:
        # the image side is passed as q and the token side as k.
        image_update = self.cross_attn_image_to_token(
            q=keys + key_pe,
            k=queries + query_pe,
            v=queries,
        )
        keys = self.norm4(keys + image_update)

        return queries, keys

In [None]:
class Output:
    """Final decoder stage: turns transformer outputs into masks and IoU scores.

    ``forward`` reads ``self.transformer``, ``self.num_mask_tokens``,
    ``self.output_upscaling``, ``self.output_hypernetworks_mlps`` and
    ``self.iou_prediction_head``; ``__init__`` does not create them.
    """

    def __init__(self):
        pass

    def forward(self, src, pos_src, tokens, hs, b, c, h, w):
        """Predict masks and their quality scores.

        Args:
            src: Image embedding fed to the transformer.
            pos_src: Positional encoding of the image embedding.
            tokens: Output tokens concatenated with the prompt tokens.
            hs: Unused on input; it is overwritten by the transformer output
                and kept only so the signature stays unchanged.
            b, c, h, w: Shape used to fold the transformer's image output
                back into a 4-D tensor.

        Returns:
            ``(masks, iou_pred)``: the predicted masks and the output of the
            IoU prediction head.
        """
        # Run the transformer:
        #   hs  - transformer output for the tokens
        #   src - transformer output for the image
        hs, src = self.transformer(src, pos_src, tokens)
        iou_token_out = hs[:, 0, :]  # Output token for the IoU prediction
        mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :]  # Output tokens for the masks

        # Upscale mask embedding and predict masks using the mask tokens
        src = src.transpose(1, 2).view(b, c, h, w)
        upscale_embedding = self.output_upscaling(src)

        # Run each mask token through its own MLP.
        hyper_in = torch.stack(
            [
                self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])
                for i in range(self.num_mask_tokens)
            ],
            dim=1,
        )

        # The upscaled embedding has a new shape; re-read it before flattening.
        b, c, h, w = upscale_embedding.shape
        # Dot product of each token's MLP output with the upscaled image
        # embedding (each output token represents one mask).
        masks = (hyper_in @ upscale_embedding.view(b, c, h * w)).view(b, -1, h, w)

        # Generate mask quality predictions
        iou_pred = self.iou_prediction_head(iou_token_out)

        return masks, iou_pred