# Imports

In [1]:
import os
import cv2
from pathlib import Path

# Run from the repository root so the `src.*` imports and the relative
# `data/...` paths used below resolve against it.
# The guard keeps this cell idempotent: without it, every re-run of the cell
# climbs one more directory up and the later imports/paths stop resolving.
if not Path("src").is_dir():
    os.chdir(Path().resolve().parent)

from src.data_processing.loader import SAMDataset, create_dataloader
from src.modeling.trainer import SAMTrainer

# Prepare the data

In [None]:
# Dataset built from the training image directory and its label-mask directory.
train_dataset = SAMDataset(
    image_dir="data/train_images",
    mask_dir="data/train_labels",
    spacing_metadata_dir="data/metadata/spacing_mm.txt",
    processor="facebook/sam-vit-base",
)

# Report the dataset size, then the shape of every field in the first record.
print(f"Number of records: {len(train_dataset)}")
print("Example of one record:")
first_train_record = train_dataset[0]
for field_name, field_value in first_train_record.items():
    print(f"{field_name}: {field_value.shape}")

In [3]:
# Dataset built from the test image directory. No mask directory is passed
# here; a bounding-box file is supplied instead.
test_dataset = SAMDataset(
    image_dir="data/test_images",
    bbox_file_dir="data/metadata/test1_bbox.txt",
    spacing_metadata_dir="data/metadata/spacing_mm.txt",
    processor="facebook/sam-vit-base",
)

# Report the dataset size, then the shape of every field in the first record.
print(f"Number of records: {len(test_dataset)}")
print("Example of one record:")
first_test_record = test_dataset[0]
for field_name, field_value in first_test_record.items():
    print(f"{field_name}: {field_value.shape}")

Number of records: 6543
Example of one record:
pixel_values: torch.Size([3, 1024, 1024])
original_sizes: torch.Size([2])
reshaped_input_sizes: torch.Size([2])
input_boxes: torch.Size([1, 4])


# Use dataloader

In [4]:
# Batching options handed to the project's `create_dataloader` helper.
# NOTE(review): `train_ratio=0.8` suggests only part of the dataset ends up in
# this loader -- confirm in `create_dataloader` where the remainder goes.
loader_options = {
    "batch_size": 36,
    "train_ratio": 0.8,
    "shuffle": True,
    "num_workers": 2,
}
train_dataloader = create_dataloader(train_dataset, **loader_options)

In [5]:
# Pull a single batch from the loader to inspect what it yields.
batch_iterator = iter(train_dataloader)
batch = next(batch_iterator)

# Print the shape of every field in that batch.
print("Example of one batch:")
for field_name, field_value in batch.items():
    print(f"{field_name}: {field_value.shape}")

Example of one batch:
pixel_values: torch.Size([36, 3, 1024, 1024])
original_sizes: torch.Size([36, 2])
reshaped_input_sizes: torch.Size([36, 2])
input_boxes: torch.Size([36, 1, 4])
ground_truth_mask: torch.Size([36, 512, 512])


# Train Model

In [6]:
# Settings for the SAM fine-tuning wrapper. The "default" labels below are
# carried over from the notebook's original annotations and have not been
# checked against SAMTrainer's signature.
trainer_settings = {
    "model_name": "facebook/sam-vit-base",  # annotated as the default checkpoint
    "device": "cpu",  # pinned to CPU; change to "cuda" to train on a GPU when one is available
    "learning_rate": 1e-5,  # annotated as the default learning rate
    "weight_decay": 0,  # annotated as the default weight decay
}
trainer = SAMTrainer(**trainer_settings)

# Dump the module tree of the wrapped model for inspection.
print(trainer.model)

SamModel(
  (shared_image_embedding): SamPositionalEmbedding()
  (vision_encoder): SamVisionEncoder(
    (patch_embed): SamPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (layers): ModuleList(
      (0-11): 12 x SamVisionLayer(
        (layer_norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (attn): SamVisionAttention(
          (qkv): Linear(in_features=768, out_features=2304, bias=True)
          (proj): Linear(in_features=768, out_features=768, bias=True)
        )
        (layer_norm2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): SamMLPBlock(
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELUActivation()
        )
      )
    )
    (neck): SamVisionNeck(
      (conv1): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (layer_norm1): SamLayerNorm()
     

In [7]:
# Disabled cell: k-fold cross-validation over the training dataloader.
# NOTE(review): no `model` variable is defined at this point in the notebook;
# the object created in the previous cell is named `trainer`. Presumably this
# should read `trainer.k_fold_cross_validation(...)` -- confirm before enabling.
# model.k_fold_cross_validation(
#     dataloader=train_dataloader,  # SAM DataLoader object
#     k_folds=5,  # Default: 5 folds
#     num_epochs=10,  # Default: 10 epochs per fold
# )

# Running Inference

In [10]:
# Sanity check: number of records in the test dataset before running inference.
len(test_dataset)

6543

In [29]:
# Disabled cell: single-record inference with the trainer's model on CPU.
# The output shown below this cell comes from an earlier run, made before the
# code was commented out; it is not reproduced by running the notebook as-is.
# NOTE(review): if re-enabled, move `import torch` (and the transformers import)
# to the import cell at the top of the notebook.
# NOTE(review): the processor is loaded from "facebook/sam-vit-huge" while the
# datasets and trainer use "facebook/sam-vit-base"; `processor` is also not
# used anywhere else in this cell -- confirm whether it is needed at all.
# NOTE(review): `original_sizes` and `reshaped_input_sizes` are prepared but
# are not included in `inputs`, so they never reach the model call.
# import torch
# from transformers import SamModel, SamProcessor

# # device = "cuda" if torch.cuda.is_available() else "cpu"
# device = "cpu"
# model = trainer.model
# processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")

# record = test_dataset[0]
# pixel_values = record["pixel_values"].unsqueeze(0).to(device)
# original_sizes = record["original_sizes"].unsqueeze(0).to(device)
# reshaped_input_sizes = record["reshaped_input_sizes"].unsqueeze(0).to(device)
# input_boxes = record["input_boxes"].unsqueeze(0).to(device)

# inputs = {
#     "pixel_values": pixel_values,
#     "input_boxes": input_boxes,
# }

# with torch.no_grad():
#     outputs = model(**inputs)

# outputs


SamImageSegmentationOutput(iou_scores=tensor([[[0.9527, 0.9629, 0.9341]]]), pred_masks=tensor([[[[[-16.9070, -16.5040, -16.4959,  ..., -16.9824, -17.2585,
            -17.1890],
           [-16.8632, -16.9447, -16.6422,  ..., -17.2348, -17.5858,
            -17.5795],
           [-16.2116, -15.6282, -17.1129,  ..., -16.0656, -17.4824,
            -17.0776],
           ...,
           [-17.4280, -17.9174, -17.2977,  ..., -17.8660, -18.1844,
            -18.3659],
           [-17.0276, -17.0331, -17.9977,  ..., -17.3907, -18.2660,
            -17.6402],
           [-17.2971, -17.0041, -17.6774,  ..., -17.4256, -18.0767,
            -18.7077]],

          [[-14.4527, -14.5858, -14.0422,  ..., -14.1928, -14.9696,
            -14.9991],
           [-14.3798, -15.0212, -14.0075,  ..., -14.1970, -15.1347,
            -15.2938],
           [-13.7320, -13.6956, -14.5075,  ..., -13.4944, -14.5137,
            -14.5743],
           ...,
           [-15.2405, -16.3668, -14.9343,  ..., -15.1885, -1

In [38]:
# Disabled cell: shape check on the predicted masks. It depends on `outputs`
# from the (also disabled) inference cell above.
# outputs.pred_masks.shape

torch.Size([1, 1, 3, 256, 256])

In [39]:
# Disabled cell: turn the predicted mask logits into binary masks.
# NOTE(review): `np` is used below, but numpy is not imported in the import
# cell at the top of the notebook -- add `import numpy as np` before enabling.
# NOTE(review): the recorded shape of `outputs.pred_masks` is
# [1, 1, 3, 256, 256], so after the two squeezes the result would still hold
# three candidate masks (leading axis of size 3), not a single mask --
# presumably one should be selected, e.g. via `iou_scores`; confirm.
# # apply sigmoid
# medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
# # convert soft mask to hard mask
# medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
# medsam_seg = (medsam_seg_prob > 0.5).astype(np.uint8)

# Evaluation