## The Game Plan

1. Load the SAM model and inspect its weights
2. Assemble SAM + a "linear" layer to form a new model
3. Learn how to design the mask output and the loss function
4. Learn to write a dataloader
5. Set up the PyTorch training loop
6. Visualize the output of the new model
7. Evaluate the output of the new model

In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

In [2]:
import sys

# Make the local `segment_anything` package (one directory up) importable.
# Guarded so that re-running this cell does not keep appending ".." to sys.path.
if ".." not in sys.path:
    sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# ViT-H checkpoint, path relative to this notebook.
sam_checkpoint = "../sam_vit_h_4b8939.pth"
# Registry key that selects the model builder; should match the checkpoint above.
model_type = "vit_h"

# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the model and load the checkpoint weights. The model is moved to
# `device` in a later cell.
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)

In [4]:
# Display the module hierarchy of the loaded model (image encoder, prompt encoder, mask decoder).
sam

Sam(
  (image_encoder): ImageEncoderViT(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
    )
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1280, out_features=3840, bias=True)
          (proj): Linear(in_features=1280, out_features=1280, bias=True)
        )
        (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (lin1): Linear(in_features=1280, out_features=5120, bias=True)
          (lin2): Linear(in_features=5120, out_features=1280, bias=True)
          (act): GELU(approximate=none)
        )
      )
      (1): Block(
        (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1280, out_features=3840, bias=True)
          (proj): Linear(in_features=1280, out_features=1280, bia

In [5]:
# Display only the image encoder: a patch-embedding Conv2d followed by a stack of transformer Blocks.
sam.image_encoder

ImageEncoderViT(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 1280, kernel_size=(16, 16), stride=(16, 16))
  )
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1280, out_features=3840, bias=True)
        (proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (norm2): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
      (mlp): MLPBlock(
        (lin1): Linear(in_features=1280, out_features=5120, bias=True)
        (lin2): Linear(in_features=5120, out_features=1280, bias=True)
        (act): GELU(approximate=none)
      )
    )
    (1): Block(
      (norm1): LayerNorm((1280,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=1280, out_features=3840, bias=True)
        (proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (norm2): LayerNorm((1280,), eps=1e-06, element

### Conclusion So Far

SAM consists of an ImageEncoder, a PromptEncoder and a MaskDecoder; for now, we only care about the ImageEncoder.
- Next, we want to determine the input and output dimensions of the ImageEncoder; see the `segment_anything` source code for details.

In [3]:
# IPython shell escape: list the visible NVIDIA GPU(s) and their current memory usage.
!nvidia-smi

Sun Oct 22 19:59:02 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 529.04       Driver Version: 529.04       CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:01:00.0 Off |                  N/A |
| N/A   41C    P4    16W /  40W |      0MiB /  8188MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
# Side length of the square input the image encoder is configured for; it is
# also the target length used with ResizeLongestSide further below.
sam.image_encoder.img_size

1024

In [3]:
# Smoke-test the image encoder with a batch of one all-zeros RGB image.
# The spatial size is read from the encoder itself (1024 for this checkpoint,
# per the img_size cell above) rather than hard-coded, and the tensor is
# allocated directly on the target device instead of being copied there.
img_size = sam.image_encoder.img_size
img = torch.zeros(size=(1, 3, img_size, img_size), device=device)
sam = sam.to(device)

with torch.no_grad():  # inference only: no autograd graph needed
    img_embed = sam.image_encoder(img)

# The patch embedding is a stride-16 Conv2d, so each spatial side shrinks 16x
# (1024 -> 64 for this checkpoint).
assert tuple(img_embed.shape[-2:]) == (img_size // 16, img_size // 16), img_embed.shape
img_embed.shape

torch.Size([1, 256, 64, 64])

In [5]:
# Shape check for ConvTranspose2d: four kernel-2 / stride-2 transposed
# convolutions applied to a (1, 256, 64, 64) tensor (the image-embedding shape)
# should upsample it 16x back to 1024 x 1024, stepping the channels
# 256 -> 128 -> 64 -> 32 -> 20.
# NOTE(review): 20 output channels is presumably the number of mask classes — confirm.
x = torch.zeros(size=(1, 256, 64, 64))
channel_plan = [(256, 128), (128, 64), (64, 32), (32, 20)]
conv1, conv2, conv3, conv4 = (
    torch.nn.ConvTranspose2d(c_in, c_out, kernel_size=(2, 2), stride=(2, 2))
    for c_in, c_out in channel_plan
)
for layer in (conv1, conv2, conv3, conv4):
    x = layer(x)
x.shape

torch.Size([1, 20, 1024, 1024])

In [8]:
# Preprocessing: import ResizeLongestSide, which is exercised in the next cell.
import sys

# Guarded so that re-running this cell (or the earlier setup cell) does not
# keep appending duplicate ".." entries to sys.path.
if ".." not in sys.path:
    sys.path.append("..")
from segment_anything.utils.transforms import ResizeLongestSide

In [14]:
# Preprocessing check on a blank, non-square (H, W, C) uint8 image. Judging by
# the output, ResizeLongestSide(1024) scales the longer side to 1024 and the
# shorter side proportionally (414 x 372 -> 1024 x 920).
resize = ResizeLongestSide(1024)
blank = np.zeros((414, 372, 3), dtype=np.uint8)
img = resize.apply_image(blank)
img.shape

(1024, 920, 3)