In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
from os.path import join
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from segment_anything import SamPredictor, sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide
from utils.SurfaceDice import compute_dice_coefficient
from sklearn.model_selection import train_test_split
import radvis as rv
import monai

torch.manual_seed(2023)
np.random.seed(2023)

In [2]:
DATA_LOCATION = "C:/Users/matth/Data/brainseg/nfbs_dataset_preprocessed"
IMAGES_FOLDER = join(DATA_LOCATION, "images")
MASKS_FOLDER = join(DATA_LOCATION, "masks")
EMBEDDINGS_FOLDER = join(DATA_LOCATION, "embeddings")
DATA_SPLIT = 0.2
MODEL_VERSION = "0.0.1"
MODEL_TASK = "brainstrip"
MODEL_SAVE_PATH = f"models/{MODEL_TASK}/{MODEL_VERSION}"

# Create folder for save path
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

In [3]:
embedding_ids = [i for i in os.listdir(EMBEDDINGS_FOLDER) if i.endswith(".npz")]
# Train test split
train_ids, test_ids = train_test_split(embedding_ids, test_size=DATA_SPLIT, random_state=2023)
train_filenames = [join(EMBEDDINGS_FOLDER, i) for i in train_ids]
test_filenames = [join(EMBEDDINGS_FOLDER, i) for i in test_ids]

In [4]:
print(f"{len(train_ids)} training samples")
print(f"{len(test_ids)} testing samples")

100 training samples
25 testing samples


In [5]:
import time


class NpzDataset(Dataset):
    def __init__(self, data_filenames):
        self.npz_files = data_filenames
        self.cumulative_sizes = self.compute_cumulative_sizes()
        self.last_loaded_file = None
        self.embedding_cache = None
        self.gts_cache = None

    def compute_cumulative_sizes(self):
        cumulative_sizes = []
        total_size = 0
        for f in self.npz_files:
            data = np.load(f)
            total_size += data['gts'].shape[0]
            cumulative_sizes.append(total_size)
        return cumulative_sizes

    def find_file_index(self, index):
        for i, size in enumerate(self.cumulative_sizes):
            if index < size:
                return i, index if i == 0 else index - self.cumulative_sizes[i-1]
        raise IndexError("index out of range")

    def load_data(self, file_index):
        if self.last_loaded_file == file_index:
            return self.gts_cache, self.embedding_cache
        else:
            # Print the files completed
            # Time the process
            start = time.time()
            data = np.load(self.npz_files[file_index])
            self.embedding_cache = data['img_embeddings']
            self.gts_cache = data['gts']
            self.last_loaded_file = file_index
            end = time.time()
            #print(f"Time to load: {end-start}")
            return self.gts_cache, self.embedding_cache

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, index):
        file_index, file_item_index = self.find_file_index(index)
        # Time the process
        
        ori_data, embedding_data = self.load_data(file_index)
        start = time.time()
        
        ori_gt = ori_data[file_item_index]
        img_embedding = embedding_data[file_item_index]

        y_indices, x_indices = np.where(ori_gt > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)


        # add perturbation to bounding box coordinates
        H, W = ori_gt.shape
        x_min = max(0, x_min - np.random.randint(0, 20))
        x_max = min(W, x_max + np.random.randint(0, 20))
        y_min = max(0, y_min - np.random.randint(0, 20))
        y_max = min(H, y_max + np.random.randint(0, 20))

        bboxes = np.array([x_min, y_min, x_max, y_max])

        # Print time before the tensor converts
        # convert img embedding, mask, bounding box to torch tensor
        result = (torch.tensor(img_embedding).float(), torch.tensor(ori_gt[None, :,:]).long(), torch.tensor(bboxes).float())
        end = time.time()
        #print(f"Time to get item: {end-start}")
        return result


In [6]:

demo_dataset = NpzDataset(train_filenames)
demo_dataloader = DataLoader(demo_dataset, batch_size=8, shuffle=False)
for img_embed, gt2D, bboxes in demo_dataloader:
    # img_embed: (B, 256, 64, 64), gt2D: (B, 1, 256, 256), bboxes: (B, 4)
    print(f"{img_embed.shape=}, {gt2D.shape=}, {bboxes.shape=}")
    break

img_embed.shape=torch.Size([8, 256, 64, 64]), gt2D.shape=torch.Size([8, 1, 256, 256]), bboxes.shape=torch.Size([8, 4])


In [7]:
# prepare SAM model
model_type = 'vit_b'
checkpoint = 'work_dir/SAM/sam_vit_b_01ec64.pth'
device = 'cuda:0'
sam_model = sam_model_registry[model_type](checkpoint=checkpoint).to(device)
sam_model.train()

# Set up the optimizer, hyperparameter tuning will improve performance here
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [8]:

from tqdm.notebook import tqdm as tqdm_notebook

num_epochs = 100
losses = []
best_loss = 1e10
train_dataset = NpzDataset(train_filenames)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=False)
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    epoch_loss = 0
    step = 0
    # train
    for image_embedding, gt2D, boxes in tqdm_notebook(train_dataloader):
        # do not compute gradients for image encoder and prompt encoder
        with torch.no_grad():
            # convert box to 1024x1024 grid
            box_np = boxes.numpy()
            sam_trans = ResizeLongestSide(sam_model.image_encoder.img_size)
            box = sam_trans.apply_boxes(box_np, (gt2D.shape[-2], gt2D.shape[-1]))
            box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
            if len(box_torch.shape) == 2:
                box_torch = box_torch[:, None, :] # (B, 1, 4)
            # get prompt embeddings 
            sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
                points=None,
                boxes=box_torch,
                masks=None,
            )
        # predicted masks
        mask_predictions, _ = sam_model.mask_decoder(
            image_embeddings=image_embedding.to(device), # (B, 256, 64, 64)
            image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
            multimask_output=False,
          )

        loss = seg_loss(mask_predictions, gt2D.to(device))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        step += 1

    losses.append(epoch_loss)
    print(f'EPOCH: {epoch}, Loss: {epoch_loss}')
    # save the latest model checkpoint
    torch.save(sam_model.state_dict(), join(MODEL_SAVE_PATH, 'sam_model_latest.pth'))
    # save the best model
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(sam_model.state_dict(), join(MODEL_SAVE_PATH, 'sam_model_best.pth'))

Epoch 1/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 0, Loss: 20.098882243037224
Epoch 2/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 1, Loss: 14.701295489445329
Epoch 3/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 2, Loss: 13.55590415559709
Epoch 4/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 3, Loss: 12.82196581736207
Epoch 5/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 4, Loss: 12.394416432827711
Epoch 6/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 5, Loss: 12.044345108792186
Epoch 7/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 6, Loss: 11.726233242079616
Epoch 8/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 7, Loss: 11.53623440489173
Epoch 9/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 8, Loss: 11.380700625479221
Epoch 10/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 9, Loss: 11.24341774918139
Epoch 11/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 10, Loss: 11.04649905115366
Epoch 12/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 11, Loss: 10.957893604412675
Epoch 13/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 12, Loss: 10.797164564952254
Epoch 14/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 13, Loss: 10.703793110325933
Epoch 15/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 14, Loss: 10.570884838700294
Epoch 16/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 15, Loss: 10.562100175768137
Epoch 17/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 16, Loss: 10.402351746335626
Epoch 18/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 17, Loss: 10.393060468137264
Epoch 19/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 18, Loss: 10.280767489224672
Epoch 20/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 19, Loss: 10.239581817761064
Epoch 21/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 20, Loss: 10.13981183245778
Epoch 22/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 21, Loss: 10.079304862767458
Epoch 23/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 22, Loss: 10.026009660214186
Epoch 24/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 23, Loss: 9.991848032921553
Epoch 25/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 24, Loss: 9.891213169321418
Epoch 26/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 25, Loss: 9.821096589788795
Epoch 27/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 26, Loss: 9.761318273842335
Epoch 28/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 27, Loss: 9.708616264164448
Epoch 29/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 28, Loss: 9.67256435751915
Epoch 30/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 29, Loss: 9.672079250216484
Epoch 31/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 30, Loss: 9.61734189465642
Epoch 32/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 31, Loss: 9.56654567271471
Epoch 33/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 32, Loss: 9.505803855136037
Epoch 34/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 33, Loss: 9.553964968770742
Epoch 35/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 34, Loss: 9.412881530821323
Epoch 36/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 35, Loss: 9.418887186795473
Epoch 37/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 36, Loss: 9.31468565762043
Epoch 38/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 37, Loss: 9.309133676812053
Epoch 39/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 38, Loss: 9.284208105877042
Epoch 40/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 39, Loss: 9.250009674578905
Epoch 41/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 40, Loss: 9.210006438195705
Epoch 42/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 41, Loss: 9.162329766899347
Epoch 43/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 42, Loss: 9.115827925503254
Epoch 44/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 43, Loss: 9.067664373666048
Epoch 45/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 44, Loss: 9.055212551727891
Epoch 46/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 45, Loss: 8.9704261533916
Epoch 47/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 46, Loss: 8.957732524722815
Epoch 48/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 47, Loss: 8.949120063334703
Epoch 49/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 48, Loss: 8.91446471400559
Epoch 50/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 49, Loss: 8.856655159965158
Epoch 51/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 50, Loss: 8.85141378082335
Epoch 52/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 51, Loss: 8.804186107590795
Epoch 53/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 52, Loss: 8.863667020574212
Epoch 54/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 53, Loss: 8.76688970439136
Epoch 55/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 54, Loss: 8.718034159392118
Epoch 56/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 55, Loss: 8.701479567214847
Epoch 57/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 56, Loss: 8.689270757138729
Epoch 58/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 57, Loss: 8.673077503219247
Epoch 59/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 58, Loss: 8.601378146559
Epoch 60/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 59, Loss: 8.6558634955436
Epoch 61/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 60, Loss: 8.569406466558576
Epoch 62/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 61, Loss: 8.534704145044088
Epoch 63/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 62, Loss: 8.529423579573631
Epoch 64/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 63, Loss: 8.498558090999722
Epoch 65/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 64, Loss: 8.475382335484028
Epoch 66/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 65, Loss: 8.449982013553381
Epoch 67/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 66, Loss: 8.431822530925274
Epoch 68/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 67, Loss: 8.41218700632453
Epoch 69/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 68, Loss: 8.373852703720331
Epoch 70/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 69, Loss: 8.367378601804376
Epoch 71/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 70, Loss: 8.322012322023511
Epoch 72/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 71, Loss: 8.325243452563882
Epoch 73/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 72, Loss: 8.314975013956428
Epoch 74/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 73, Loss: 8.27413105033338
Epoch 75/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 74, Loss: 8.239232487976551
Epoch 76/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 75, Loss: 8.230965215712786
Epoch 77/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 76, Loss: 8.223687667399645
Epoch 78/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 77, Loss: 8.198343738913536
Epoch 79/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 78, Loss: 8.207768980413675
Epoch 80/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 79, Loss: 8.174477316439152
Epoch 81/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 80, Loss: 8.1194330137223
Epoch 82/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 81, Loss: 8.125475890934467
Epoch 83/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 82, Loss: 8.115643862634897
Epoch 84/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 83, Loss: 8.031849764287472
Epoch 85/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 84, Loss: 8.091156153008342
Epoch 86/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 85, Loss: 8.03753830678761
Epoch 87/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 86, Loss: 8.02778321877122
Epoch 88/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 87, Loss: 8.027365611866117
Epoch 89/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 88, Loss: 7.987261172384024
Epoch 90/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 89, Loss: 7.964181121438742
Epoch 91/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 90, Loss: 7.990426167845726
Epoch 92/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 91, Loss: 7.983821509405971
Epoch 93/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 92, Loss: 7.940541362389922
Epoch 94/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 93, Loss: 7.924335276708007
Epoch 95/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 94, Loss: 7.884878281503916
Epoch 96/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 95, Loss: 7.820279326289892
Epoch 97/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 96, Loss: 7.819154340773821
Epoch 98/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 97, Loss: 7.818513125181198
Epoch 99/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 98, Loss: 7.843869941309094
Epoch 100/100


  0%|          | 0/518 [00:00<?, ?it/s]

EPOCH: 99, Loss: 7.7748399041593075
