In [4]:
import sys
from os import mkdir

import torch.nn.functional as F

sys.path.append('.')
from config import cfg

from engine.sam_trainer import do_train

from data.build import SAM_Dataloader
from modeling.segment_anything import prepare_sam
from solver.optimizer import build_optimizer
from solver.lr_schedular import build_lrSchedular
from solver.loss import FocalLoss, DiceLoss

from utils.logger import setup_logger

from config import cfg  # Import the default config file

In [5]:
def train(cfg):
    # Get the model and download the checkpoint if needed
    model = prepare_sam(checkpoint=cfg.MODEL.CHECKPOINT)
    device = cfg.MODEL.DEVICE

    #Set the portion of the model to be trained (We will train only the mask_decoder part)
    for name, param in model.named_parameters():
        if name.startswith('image_encoder') or name.startswith('prompt_encoder'):
            param.requires_grad = False
    
    # Get the dataloader and prepare the train, validation and test dataloader
    sam_dataloader = SAM_Dataloader(cfg)
    train_loader, val_loader, _ = sam_dataloader.build_dataloader()

    # Get the optimizer and scheduler
    optimizer = build_optimizer(cfg, model)
    scheduler = build_lrSchedular(cfg=None, optimizer=optimizer)

    loss_fn = FocalLoss()

    do_train( 
        cfg,
        model,
        train_loader,
        val_loader,
        optimizer,
        scheduler,
        loss_fn,
    )

In [6]:
cfg.freeze()
print(cfg)

DATALOADER:
  BATCH_SIZE: 4
  NUM_WORKERS: 0
  TEST_DATA: 0.2
  TRAIN_DATA: 0.8
  VALID_DATA: 0.2
DATASETS:
  ROOT_DIR: /home/legion-ubuntu/Research/Kerner-Lab/datasets/images2/
INPUT:
  
MODEL:
  CHECKPOINT: /home/legion-ubuntu/Research/Kerner-Lab/SAM-FineTuning/modeling/model_checkpoints/sam_vit_b_01ec64.pth
  DEVICE: cuda
OUTPUT_DIR: 
SOLVER:
  BASE_LR: 0.01
  ITEMS_PER_BATCH: 4
  MAX_EPOCHS: 5
  WEIGHT_DECAY: 0.001
TEST:
  ITEMS_PER_BATCH: 2
VALID:
  ITEMS_PER_BATCH: 2


In [7]:
train(cfg)

Downloading SAM ViT-B checkpoint...
sam_vit_b_01ec64.pth  is downloaded!


  0%|          | 0/111 [00:00<?, ?it/s]


TypeError: list indices must be integers or slices, not str

In [1]:
import numpy as np

In [2]:
def _build_point_grid(n_per_side: int) -> np.ndarray:
        """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
        offset = 1 / (2 * n_per_side)
        points_one_side = np.linspace(offset, 1 - offset, n_per_side)
        points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
        points_y = np.tile(points_one_side[:, None], (1, n_per_side))
        points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
        return points

In [4]:
bbox = [193, 301, 387, 448]
x_min, y_min, x_max, y_max = bbox

In [5]:
# Get grid points within the bounding box
grid_points = _build_point_grid(10)

In [9]:
#This forms a 2D grid points and we need to normalize it to the bbox size
grid_points[:, 0] = grid_points[:, 0] * (x_max - x_min) + x_min
grid_points[:, 1] = grid_points[:, 1] * (y_max - y_min) + y_min

TypeError: only size-1 arrays can be converted to Python scalars

In [8]:
grid_points

array([[202.7 , 308.35],
       [222.1 , 308.35],
       [241.5 , 308.35],
       [260.9 , 308.35],
       [280.3 , 308.35],
       [299.7 , 308.35],
       [319.1 , 308.35],
       [338.5 , 308.35],
       [357.9 , 308.35],
       [377.3 , 308.35],
       [202.7 , 323.05],
       [222.1 , 323.05],
       [241.5 , 323.05],
       [260.9 , 323.05],
       [280.3 , 323.05],
       [299.7 , 323.05],
       [319.1 , 323.05],
       [338.5 , 323.05],
       [357.9 , 323.05],
       [377.3 , 323.05],
       [202.7 , 337.75],
       [222.1 , 337.75],
       [241.5 , 337.75],
       [260.9 , 337.75],
       [280.3 , 337.75],
       [299.7 , 337.75],
       [319.1 , 337.75],
       [338.5 , 337.75],
       [357.9 , 337.75],
       [377.3 , 337.75],
       [202.7 , 352.45],
       [222.1 , 352.45],
       [241.5 , 352.45],
       [260.9 , 352.45],
       [280.3 , 352.45],
       [299.7 , 352.45],
       [319.1 , 352.45],
       [338.5 , 352.45],
       [357.9 , 352.45],
       [377.3 , 352.45],
