This notebook is an explanation of the `notebooks/automatic_mask_generator_example.ipynb` notebook.

# **Algorithm Overview**

1. Build the model with the `build_sam_vit_h` wrapper, which calls `_build_sam` with the ViT-H hyperparameters:

   ```python
   def build_sam_vit_h(checkpoint=None):
       return _build_sam(
           encoder_embed_dim=1280,
           encoder_depth=32,
           encoder_num_heads=16,
           encoder_global_attn_indexes=[7, 15, 23, 31],
           checkpoint=checkpoint,
       )
   ```


2. The `_build_sam` function constructs and returns a `Sam` object.

The `Sam` object is composed of:

- an `image_encoder`, a Vision Transformer (ViT) that turns the image into an image embedding;
- a `prompt_encoder`, which encodes the given prompts (points, boxes, and masks);
- a `mask_decoder`, which maps the image embedding, prompt embeddings, and an output token to a mask (a modification of a Transformer decoder block followed by a dynamic mask prediction head).

3. Pass the `Sam` object to `SamAutomaticMaskGenerator` and call its `generate` method on an image (an HWC `uint8` NumPy array).
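To make the `build_sam_vit_h` hyperparameters concrete, the sketch below derives a few quantities from them. It assumes the other values `_build_sam` uses in `segment_anything/build_sam.py` (`image_size=1024`, `vit_patch_size=16`, `window_size=14`) and the rule that blocks listed in `encoder_global_attn_indexes` use global attention while all other blocks use windowed attention. `VitHConfig` is a hypothetical helper for illustration, not part of the library.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class VitHConfig:
    # Values passed by build_sam_vit_h
    encoder_embed_dim: int = 1280
    encoder_depth: int = 32
    encoder_num_heads: int = 16
    encoder_global_attn_indexes: List[int] = field(default_factory=lambda: [7, 15, 23, 31])
    # Assumed values used inside _build_sam
    image_size: int = 1024
    patch_size: int = 16
    window_size: int = 14

    def head_dim(self) -> int:
        # Each attention head works on embed_dim / num_heads channels.
        return self.encoder_embed_dim // self.encoder_num_heads

    def grid_size(self) -> int:
        # Number of patch tokens along each side of the padded input image.
        return self.image_size // self.patch_size

    def block_window_sizes(self) -> List[int]:
        # 0 means global attention; otherwise attention within windows of this size.
        return [
            0 if i in self.encoder_global_attn_indexes else self.window_size
            for i in range(self.encoder_depth)
        ]


cfg = VitHConfig()
print(cfg.head_dim())   # 80
print(cfg.grid_size())  # 64
sizes = cfg.block_window_sizes()
print([i for i, w in enumerate(sizes) if w == 0])  # [7, 15, 23, 31]
```

Under these assumptions, 28 of the 32 blocks attend within 14 × 14 windows and only 4 attend across the whole 64 × 64 token grid, which reduces the attention cost at this resolution.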

# **Forward Propagation in SAM**

<img src = "https://github.com/PragyanSubedi/Segment-Anything-Model-Breakdown/blob/main/images/segment_anything_model.PNG?raw=true" >

The code for this is in the `Sam.forward` method of `segment_anything.modeling.sam`.

#### **Step 1: Image pre-processing**

```python
input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0)
```

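1. `preprocess` normalizes the colors (standardization) by subtracting a fixed per-channel mean (`pixel_mean`) from every pixel and dividing by a fixed per-channel standard deviation (`pixel_std`). These are constants stored in the model (by default the ImageNet RGB statistics on a 0-255 scale), not statistics computed from the batch.
2. The normalized image is then zero-padded on the bottom and right so that it becomes a square of the image encoder's input size (1024 × 1024). The image is expected to have already been resized so that its longest side is 1024.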
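A minimal NumPy sketch of this pre-processing step, assuming a channels-first `(3, H, W)` image that was already resized so its longest side is 1024. The mean/std values are the defaults SAM uses for `pixel_mean`/`pixel_std`; the function itself is an illustration, not the library's torch implementation.

```python
import numpy as np

# Per-channel RGB statistics (ImageNet mean/std on a 0-255 scale), SAM's defaults.
PIXEL_MEAN = np.array([123.675, 116.28, 103.53]).reshape(3, 1, 1)
PIXEL_STD = np.array([58.395, 57.12, 57.375]).reshape(3, 1, 1)


def preprocess(x: np.ndarray, img_size: int = 1024) -> np.ndarray:
    """Normalize a (3, H, W) image and zero-pad it on the bottom/right to img_size x img_size."""
    # Standardize each channel with the fixed constants (not batch statistics).
    x = (x - PIXEL_MEAN) / PIXEL_STD
    h, w = x.shape[-2:]
    pad_h, pad_w = img_size - h, img_size - w
    # Pad widths per axis: (channels), (top, bottom), (left, right); zeros by default.
    return np.pad(x, ((0, 0), (0, pad_h), (0, pad_w)))


img = np.full((3, 768, 1024), 128.0)  # e.g. a 480 x 640 photo resized to 768 x 1024
out = preprocess(img)
print(out.shape)  # (3, 1024, 1024)
```

Because padding happens after normalization, the padded region is 0 in normalized space, i.e. it corresponds to the mean color.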

#### **Step 2: Create image embeddings using the Vision Transformer (ViT) image encoder**

```python
image_embeddings = self.image_encoder(input_images)
```

Original ViT paper ("An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"): https://arxiv.org/pdf/2010.11929v2.pdf

**Key point:**

> Inspired by the Transformer scaling successes in NLP, we experiment with applying a standard Transformer directly to images, with the fewest possible modifications. To do so, we split an image into patches and provide the sequence of linear embeddings of these patches as an input to a Transformer. Image patches are treated the same way as tokens (words) in an NLP application. We train the model on image classification in supervised fashion.

<img src ="https://github.com/PragyanSubedi/Segment-Anything-Model-Breakdown/blob/main/images/vit_architecture.PNG?raw=true" width=80%>

**For each image, x:**

1. Compute a patch embedding of x with a `Conv2d` whose kernel size and stride are both 16 × 16, so each non-overlapping 16 × 16 patch becomes one token. The input has 3 channels (RGB) and the embedding dimension is 1280 (ViT-H), so a 1024 × 1024 input becomes a 64 × 64 grid of 1280-dimensional tokens.
2. Add the absolute positional embedding to x if `pos_embed` is not None.
3. Pass x through the 32 Transformer blocks in sequence (`encoder_depth=32`).
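The patch embedding is a convolution whose kernel size equals its stride, which is the same as flattening every non-overlapping patch and multiplying it by one shared weight matrix. The NumPy sketch below shows that equivalence; `patch_embed` is a hypothetical helper (SAM uses a torch `Conv2d` inside its `PatchEmbed` module), and small sizes are used so it runs quickly.

```python
import numpy as np


def patch_embed(img: np.ndarray, weight: np.ndarray, bias: np.ndarray, patch: int = 16) -> np.ndarray:
    """Embed a (C, H, W) image into an (H/patch, W/patch, D) grid of tokens.

    weight has the Conv2d layout (D, C, patch, patch); kernel size == stride == patch.
    """
    c, h, w = img.shape
    gh, gw = h // patch, w // patch
    # (C, gh, p, gw, p) -> (gh, gw, C, p, p): one patch per grid cell.
    patches = img.reshape(c, gh, patch, gw, patch).transpose(1, 3, 0, 2, 4)
    # Flatten each patch in (C, p, p) order to match the flattened weight.
    patches = patches.reshape(gh, gw, c * patch * patch)
    return patches @ weight.reshape(weight.shape[0], -1).T + bias


rng = np.random.default_rng(0)
img = rng.standard_normal((3, 32, 48))
weight = rng.standard_normal((8, 3, 16, 16))  # D = 8 here instead of 1280
bias = rng.standard_normal(8)
tokens = patch_embed(img, weight, bias)
print(tokens.shape)  # (2, 3, 8)
```

With ViT-H sizes (a 3 × 1024 × 1024 input and D = 1280) the same computation yields a 64 × 64 × 1280 token grid.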


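```python
# Block.forward (segment_anything/modeling/image_encoder.py), run by each of the 32 blocks
def forward(self, x: torch.Tensor) -> torch.Tensor:
    # x has shape (B, H, W, C)
    shortcut = x
    x = self.norm1(x)
    # Window partition: split into non-overlapping windows, padding if needed
    if self.window_size > 0:
        H, W = x.shape[1], x.shape[2]
        x, pad_hw = window_partition(x, self.window_size)

    # Multi-head self-attention (within each window, or globally if window_size == 0)
    x = self.attn(x)
    # Reverse window partition: merge the windows and crop away the padding
    if self.window_size > 0:
        x = window_unpartition(x, self.window_size, pad_hw, (H, W))

    # First residual connection
    x = shortcut + x
    # Second residual connection around the pre-normalized MLP
    x = x + self.mlp(self.norm2(x))

    return x
```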

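In each block, x is first layer-normalized. If `window_size > 0`, x is split into non-overlapping windows (padding it if needed) so that attention is computed only within each window; the blocks listed in `encoder_global_attn_indexes` use `window_size = 0` and attend over the whole token grid. After multi-head attention, the windows are merged back and the padding is cropped. The result is added to the block's input (first residual connection), then normalized, passed through the MLP, and added to its own pre-normalization value (second residual connection).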
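The window partition and its reverse can be sketched in NumPy as below. This mirrors the logic of `window_partition` / `window_unpartition` in `image_encoder.py` but is a rewrite for illustration, assuming a channels-last `(B, H, W, C)` array; attention itself is omitted.

```python
import numpy as np
from typing import Tuple


def window_partition(x: np.ndarray, window_size: int) -> Tuple[np.ndarray, Tuple[int, int]]:
    """Split (B, H, W, C) into (B * nH * nW, ws, ws, C) windows, zero-padding H and W."""
    b, h, w, c = x.shape
    pad_h = (window_size - h % window_size) % window_size
    pad_w = (window_size - w % window_size) % window_size
    x = np.pad(x, ((0, 0), (0, pad_h), (0, pad_w), (0, 0)))
    hp, wp = h + pad_h, w + pad_w
    x = x.reshape(b, hp // window_size, window_size, wp // window_size, window_size, c)
    # Group the two window-grid axes together, then flatten them into the batch axis.
    windows = x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, window_size, window_size, c)
    return windows, (hp, wp)


def window_unpartition(windows: np.ndarray, window_size: int,
                       pad_hw: Tuple[int, int], hw: Tuple[int, int]) -> np.ndarray:
    """Inverse of window_partition: merge windows and crop the padding."""
    hp, wp = pad_hw
    h, w = hw
    n_h, n_w = hp // window_size, wp // window_size
    b = windows.shape[0] // (n_h * n_w)
    x = windows.reshape(b, n_h, n_w, window_size, window_size, -1)
    x = x.transpose(0, 1, 3, 2, 4, 5).reshape(b, hp, wp, -1)
    return x[:, :h, :w, :]


# A 64 x 64 token grid with 14 x 14 windows is padded to 70 x 70 -> 5 x 5 = 25 windows.
x = np.arange(64 * 64 * 2, dtype=float).reshape(1, 64, 64, 2)
windows, pad_hw = window_partition(x, 14)
print(windows.shape, pad_hw)  # (25, 14, 14, 2) (70, 70)
restored = window_unpartition(windows, 14, pad_hw, (64, 64))
print(np.array_equal(restored, x))  # True
```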

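4. Pass the output of the last block through the neck: conv2d (1 × 1) -> layernorm -> conv2d (3 × 3) -> layernorm, which projects the 1280-dimensional tokens down to 256 channels.

This gives the `image_embeddings`, of shape (B, 256, 64, 64) for a 1024 × 1024 input.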
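A NumPy sketch of the neck, assuming bias-free convolutions and a channel-wise layer norm like SAM's `LayerNorm2d`, but omitting its learned scale and shift (i.e. scale = 1, shift = 0). `neck`, `conv3x3`, and `layer_norm_2d` are hypothetical helpers, and small channel counts stand in for 1280 -> 256.

```python
import numpy as np


def layer_norm_2d(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Normalize a (C, H, W) array over the channel axis at every spatial position."""
    mean = x.mean(axis=0, keepdims=True)
    var = ((x - mean) ** 2).mean(axis=0, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def conv3x3(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    """3x3 convolution with padding 1 and no bias; w has shape (C_out, C_in, 3, 3)."""
    _, h, wd = x.shape
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1)))
    out = np.zeros((w.shape[0], h, wd))
    for dy in range(3):
        for dx in range(3):
            # Add the contribution of one kernel tap, summed over input channels.
            out += np.einsum("oc,chw->ohw", w[:, :, dy, dx], xp[:, dy:dy + h, dx:dx + wd])
    return out


def neck(tokens: np.ndarray, w1: np.ndarray, w2: np.ndarray) -> np.ndarray:
    """(H, W, C_embed) tokens from the last block -> (C_out, H, W) image embedding."""
    x = tokens.transpose(2, 0, 1)        # channels-first
    x = np.einsum("oc,chw->ohw", w1, x)  # 1x1 conv: C_embed -> C_out
    x = layer_norm_2d(x)
    x = conv3x3(x, w2)                   # 3x3 conv with padding 1: C_out -> C_out
    return layer_norm_2d(x)


rng = np.random.default_rng(0)
tokens = rng.standard_normal((4, 4, 16))  # stands in for (64, 64, 1280)
w1 = rng.standard_normal((8, 16))         # stands in for 1280 -> 256
w2 = rng.standard_normal((8, 8, 3, 3))
emb = neck(tokens, w1, w2)
print(emb.shape)  # (8, 4, 4)
```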

#### **Step 3: Use a prompt encoder to get `sparse_embeddings` and `dense_embeddings`**

```python
for image_record, curr_embedding in zip(batched_input, image_embeddings):
    if "point_coords" in image_record:
        points = (image_record["point_coords"], image_record["point_labels"])
    else:
        points = None
    sparse_embeddings, dense_embeddings = self.prompt_encoder(
        points=points,
        boxes=image_record.get("boxes", None),
        masks=image_record.get("mask_inputs", None),
    )
```

Loop over each image record in the batch together with its image embedding.

- If the image record contains point coordinates `point_coords`, set `points` to the tuple (`point_coords`, `point_labels`). Otherwise, set `points` to None.

- Then use the `prompt_encoder` to create sparse embeddings (from points and boxes) and dense embeddings (from masks). If no mask is given, the dense embedding is a learned "no mask" embedding broadcast over the 64 × 64 grid.
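As an illustration of how a point becomes a sparse embedding, here is a NumPy sketch assuming the random Fourier positional encoding used by SAM's `PositionEmbeddingRandom`: coordinates are shifted to the pixel center, normalized to [0, 1], mapped to [-1, 1], projected by a fixed random Gaussian matrix, and turned into sin/cos features; a per-label embedding is then added. The Gaussian matrix and `LABEL_EMBED` here are random placeholders for the model's initialized/learned weights, and padding points and box corners are omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
NUM_POS_FEATS = 128  # 2 * 128 = 256, the prompt embedding dimension
# Random projection (fixed after initialization in SAM; random placeholder here).
GAUSSIAN_MATRIX = rng.standard_normal((2, NUM_POS_FEATS))
# Placeholders for the learned per-label embeddings.
LABEL_EMBED = {
    1: rng.standard_normal(2 * NUM_POS_FEATS),  # foreground point
    0: rng.standard_normal(2 * NUM_POS_FEATS),  # background point
}


def fourier_pe(coords01: np.ndarray) -> np.ndarray:
    """Map (N, 2) coords in [0, 1] to (N, 256) sin/cos features."""
    proj = (2 * coords01 - 1) @ GAUSSIAN_MATRIX * 2 * np.pi
    return np.concatenate([np.sin(proj), np.cos(proj)], axis=-1)


def embed_points(coords: np.ndarray, labels: np.ndarray, input_size=(1024, 1024)) -> np.ndarray:
    """coords: (N, 2) as (x, y) pixels in the resized input; labels: (N,) of 0/1."""
    h, w = input_size
    # Shift to the pixel center, then normalize x by width and y by height.
    xy = (coords + 0.5) / np.array([w, h])
    return fourier_pe(xy) + np.stack([LABEL_EMBED[int(lab)] for lab in labels])


pts = np.array([[100.0, 200.0], [512.0, 512.0]])
labels = np.array([1, 0])
print(embed_points(pts, labels).shape)  # (2, 256)
```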

#### **Step 4: Generate `low_res_masks` and `iou_predictions`**

```python
low_res_masks, iou_predictions = self.mask_decoder(
    image_embeddings=curr_embedding.unsqueeze(0),
    image_pe=self.prompt_encoder.get_dense_pe(),
    sparse_prompt_embeddings=sparse_embeddings,
    dense_prompt_embeddings=dense_embeddings,
    multimask_output=multimask_output,
)
```

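The mask decoder takes the image embedding, the dense positional encoding of the image grid (`get_dense_pe()`), and the sparse and dense prompt embeddings. It predicts low-resolution (256 × 256) mask logits together with a predicted Intersection over Union (IoU) score for each mask, i.e. the model's own estimate of how well that mask matches the object. With `multimask_output=True` it returns three masks per prompt; otherwise it returns one.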
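The predicted IoU is an estimate of the quantity below. The sketch computes IoU between two boolean masks and shows one way the predictions can be used: `SamAutomaticMaskGenerator` keeps only masks whose predicted IoU exceeds its `pred_iou_thresh` parameter (0.88 by default). `mask_iou` and `filter_by_predicted_iou` are hypothetical helpers for illustration.

```python
import numpy as np


def mask_iou(a: np.ndarray, b: np.ndarray) -> float:
    """Intersection over Union of two boolean masks of the same shape."""
    inter = np.logical_and(a, b).sum()
    union = np.logical_or(a, b).sum()
    return float(inter / union) if union > 0 else 0.0


def filter_by_predicted_iou(masks: np.ndarray, iou_predictions: np.ndarray, thresh: float = 0.88):
    """Keep masks whose predicted IoU is above thresh (cf. pred_iou_thresh)."""
    keep = iou_predictions > thresh
    return masks[keep], iou_predictions[keep]


a = np.zeros((4, 4), dtype=bool)
a[:2, :] = True  # top half: 8 pixels
b = np.zeros((4, 4), dtype=bool)
b[:, :2] = True  # left half: 8 pixels
print(mask_iou(a, b))  # 4 / 12 = 0.333...

masks = np.stack([a, b, a | b])
preds = np.array([0.95, 0.5, 0.9])
kept, kept_preds = filter_by_predicted_iou(masks, preds)
print(kept.shape, kept_preds)  # (2, 4, 4) [0.95 0.9 ]
```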


#### **Step 5: Postprocess masks and provide output predictions**

```python
masks = self.postprocess_masks(
    low_res_masks,
    input_size=image_record["image"].shape[-2:],
    original_size=image_record["original_size"],
)
masks = masks > self.mask_threshold
outputs.append(
    {
        "masks": masks,
        "iou_predictions": iou_predictions,
        "low_res_logits": low_res_masks,
    }
)
```

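- Upscale the low-resolution mask logits to the encoder's input size (1024 × 1024), crop away the padding added in Step 1, and resize the result to the original image size.
- Threshold the logits: pixels whose logit is greater than `mask_threshold` (0.0 by default) become True.
- Append a dictionary with the binary masks, the IoU predictions, and the low-resolution logits to `outputs`.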
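The post-processing chain can be sketched in NumPy as follows. SAM's `postprocess_masks` uses bilinear interpolation (`F.interpolate`); this sketch swaps in nearest-neighbour resizing to stay short, so values near mask boundaries can differ. `resize_nearest` and `postprocess_mask` are hypothetical helpers operating on a single 2-D logit map.

```python
import numpy as np


def resize_nearest(x: np.ndarray, out_hw) -> np.ndarray:
    """Nearest-neighbour resize of a 2-D array (stand-in for bilinear interpolation)."""
    h, w = x.shape
    oh, ow = out_hw
    rows = np.arange(oh) * h // oh
    cols = np.arange(ow) * w // ow
    return x[rows[:, None], cols[None, :]]


def postprocess_mask(low_res_logits, input_size, original_size, img_size=1024, mask_threshold=0.0):
    x = resize_nearest(low_res_logits, (img_size, img_size))  # 256 x 256 -> 1024 x 1024
    x = x[: input_size[0], : input_size[1]]                   # drop bottom/right padding
    x = resize_nearest(x, original_size)                      # back to original resolution
    return x > mask_threshold                                 # logits -> boolean mask


# A 480 x 640 image was resized to 768 x 1024 before padding to 1024 x 1024.
logits = -np.ones((256, 256))
logits[:, :64] = 1.0  # positive logits in the left quarter of the padded input
mask = postprocess_mask(logits, input_size=(768, 1024), original_size=(480, 640))
print(mask.shape, mask.sum())  # (480, 640) 76800
```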