In [None]:
import cv2
import torch
from segment_anything import SamPredictor, sam_model_registry
from torch.nn.functional import threshold, normalize
from segment_anything.utils.transforms import ResizeLongestSide
import numpy as np
import matplotlib.pyplot as plt
from  tqdm import tqdm
from statistics import mean


Helper functions

In [None]:
def show_box(box, ax):
    """Outline a bounding box on a matplotlib axes.

    Args:
        box: Indexable holding ``(x0, y0, x1, y1)`` corner coordinates.
        ax: Axes-like object exposing ``add_patch``.
    """
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    # Rectangle wants an anchor corner plus width/height, not two corners.
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0,0,0,0),
        lw=2,
    )
    ax.add_patch(outline)

def show_mask(mask, ax, random_color=False):
    """Overlay a mask on a matplotlib axes as a semi-transparent RGBA image.

    Args:
        mask: Array whose last two dimensions are (height, width); any leading
            dimensions must multiply to 1 for the reshape below to succeed.
        ax: Axes-like object exposing ``imshow``.
        random_color: When true, draw a random RGB colour; otherwise use a
            fixed blue. Alpha is 0.6 in both cases.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30/255, 144/255, 255/255, 0.6])
    )
    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) against (1, 1, 4) so every mask pixel scales the colour.
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

In [None]:
# Load the pre-trained ViT-H SAM checkpoint and switch it to training mode.
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam_model = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam_model.to(device=device)
# Trailing semicolon keeps the cell from echoing the full module repr.
sam_model.train();

In [None]:
# Optimisation setup. Only the mask decoder's parameters are handed to Adam,
# so the rest of the model is never updated by `optimizer.step()`.
lr = 1e-4  # learning rate
wd = 0     # weight decay (L2 penalty); 0 disables it
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.Adam(
    sam_model.mask_decoder.parameters(),
    lr=lr,
    weight_decay=wd,
)

In [None]:
# Build the train/validation id lists from the FoodSeg103 training split file.
SPLIT_FILE = "/Users/ioanna/Downloads/FoodSeg103/ImageSets/train.txt"
SPLIT_SEED = 42        # fixed so the sampled subset and the split are reproducible
NUM_IMAGES = 1000      # size of the random subset used for fine-tuning
TRAIN_FRACTION = 0.9   # remaining 10% is held out for validation

with open(SPLIT_FILE, "r") as f:
    raw_lines = f.read().splitlines()

# Strip whitespace BEFORE dropping the extension: slicing first would cut into
# the file name whenever a line carries trailing spaces. Blank lines are skipped
# so they do not turn into empty ids.
image_ids = []
for line in raw_lines:
    name = line.strip()
    if not name:
        continue
    image_ids.append(name[:-4] if name.lower().endswith(".jpg") else name)

# A dedicated generator gives a deterministic shuffle without touching the
# global NumPy random state.
split_rng = np.random.default_rng(SPLIT_SEED)
split_rng.shuffle(image_ids)

# Keep a random subset, then split it into train and validation ids.
image_ids = image_ids[:NUM_IMAGES]
split_index = int(len(image_ids) * TRAIN_FRACTION)
train_ids = image_ids[:split_index]
val_ids = image_ids[split_index:]

In [None]:
# Fine-tune the SAM mask decoder with box prompts, one gradient step per
# (image, class) pair, validate after every epoch and stop early once the
# validation loss has failed to improve `patience` epochs in a row.
num_epochs = 100
losses = []
val_losses = []
mean_val_losses = []
patience = 5
early_stopping_counter = 0

# Dataset root. The validation ids are carved out of the training split, so
# both phases read from the same `train` directories.
DATA_ROOT = "/Users/ioanna/Downloads/FoodSeg103"


def _load_sample(image_id):
    """Read one sample from disk.

    Returns:
        (image, mask): the RGB image and the grayscale annotation in which
        each pixel value is a class id (0 is skipped as background).

    Raises:
        FileNotFoundError: if OpenCV cannot read either file.
    """
    mask_path = "{}/Images/ann_dir/train/{}.png".format(DATA_ROOT, image_id)
    image_path = "{}/Images/img_dir/train/{}.jpg".format(DATA_ROOT, image_id)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising.
    if mask is None or image is None:
        raise FileNotFoundError(
            "Could not read sample '{}' ({} / {})".format(image_id, image_path, mask_path)
        )
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB), mask


def _embed_image(image, transform):
    """Resize, preprocess and encode one RGB image.

    Returns:
        (image_embedding, input_size) where input_size is the (H, W) of the
        resized image before `sam_model.preprocess` is applied.
    """
    resized = transform.apply_image(image)
    resized_torch = torch.as_tensor(resized, device=device)
    batched = resized_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
    input_size = tuple(batched.shape[-2:])
    # The image encoder is not optimised, so no autograd graph is needed.
    with torch.no_grad():
        image_embedding = sam_model.image_encoder(sam_model.preprocess(batched))
    return image_embedding, input_size


def _class_loss(image_embedding, input_size, original_image_size, class_mask, transform):
    """Loss between the predicted mask and `class_mask` for one box prompt.

    Args:
        image_embedding: output of `_embed_image` for the current image.
        input_size: resized (H, W) returned by `_embed_image`.
        original_image_size: (H, W) of the image as loaded from disk.
        class_mask: boolean array, True where the pixel belongs to the class.
        transform: the `ResizeLongestSide` instance used for the image.
    """
    # Tight box around the class pixels; +1 makes the far corner exclusive.
    pixel_coords = np.argwhere(class_mask)
    (y1, x1), (y2, x2) = pixel_coords.min(0), pixel_coords.max(0) + 1
    bbox_coords = np.array([x1, y1, x2, y2])

    box = transform.apply_boxes(bbox_coords, original_image_size)
    box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
    box_torch = box_torch[None, :]

    # The prompt encoder is not in the optimiser; skipping its graph saves
    # memory and avoids accumulating gradients that are never cleared.
    with torch.no_grad():
        sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
            points=None,
            boxes=box_torch,
            masks=None,
        )
    low_res_masks, iou_predictions = sam_model.mask_decoder(
        image_embeddings=image_embedding,
        image_pe=sam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )
    upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size)

    # Use sigmoid probabilities rather than normalize(threshold(...)): that
    # combination maps each positive logit to exactly 1 and everything else to
    # 0, and its derivative w.r.t. the logits is zero, so the decoder would
    # receive no gradient and never change.
    predicted_mask = torch.sigmoid(upscaled_masks)

    gt_binary_mask = torch.as_tensor(class_mask[None, None], dtype=torch.float32, device=device)
    return loss_fn(predicted_mask, gt_binary_mask)


def _iter_class_losses(image_id):
    """Yield one loss tensor per non-background class present in the image."""
    image, mask = _load_sample(image_id)
    class_values = [value for value in np.unique(mask) if value != 0]
    if not class_values:
        return

    transform = ResizeLongestSide(target_length=sam_model.image_encoder.img_size)
    # Encode once per image and reuse for every class. This is valid because
    # the optimiser only updates the mask decoder, so the embedding cannot
    # change between the gradient steps taken for this image.
    image_embedding, input_size = _embed_image(image, transform)
    original_image_size = image.shape[:2]

    for class_value in class_values:
        yield _class_loss(
            image_embedding,
            input_size,
            original_image_size,
            mask == class_value,
            transform,
        )


for epoch in tqdm(range(num_epochs)):
    sam_model.train()
    epoch_losses = []
    epoch_val_losses = []

    for image_id in tqdm(train_ids):
        for loss in _iter_class_losses(image_id):
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_losses.append(loss.item())
            print(f'Loss: {loss.item()}')

    losses.append(mean(epoch_losses))
    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')
    # Overwrite the rolling checkpoint after every epoch.
    torch.save(sam_model.state_dict(), 'trained_foodseg103_checkpoint.pth')

    sam_model.eval()
    with torch.no_grad():
        for val_image_id in tqdm(val_ids):
            for val_loss in _iter_class_losses(val_image_id):
                epoch_val_losses.append(val_loss.item())

    val_losses.append(mean(epoch_val_losses))

    # `val_losses` already contains this epoch, so "greater than the minimum"
    # means this epoch did not match or beat the best validation loss so far.
    if mean(epoch_val_losses) > min(val_losses):
        print('Validation loss is increasing')
        early_stopping_counter += 1
        if early_stopping_counter == patience:
            print('Early stopping')
            break
    else:
        early_stopping_counter = 0


# Persist the per-epoch mean training losses, one value per line.
with open('losses103.txt', 'w') as f:
    for item in losses:
        f.write("%s\n" % item)

# Persist the per-epoch mean validation losses, one value per line.
with open('val_losses103.txt', 'w') as f:
    for item in val_losses:
        f.write("%s\n" % item)

# Save the weights as they are when the loop ends (last epoch run, not
# necessarily the epoch with the best validation loss).
torch.save(sam_model.state_dict(), 'trained_foodseg103.pth')
    