In [1]:
from test import create_dataloaders, test_model, plot_iou_scores
from segment_anything import SamPredictor, sam_model_registry
from segment_anything.utils.transforms import ResizeLongestSide

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from torch.nn.parallel import DataParallel
from torchvision.transforms import ToTensor, Compose

from utils import Embedding_Dataset, SAMPreprocess, PILToNumpy, NumpyToTensor, sample_point, SAMPostprocess

In [2]:
# Root of the shared workspace that holds every dataset; change this single value
# to run the notebook against a different storage location.
DATA_ROOT = '/pfs/work7/workspace/scratch/ul_xto11-FSSAM'
# Public datasets stored under `<DATA_ROOT>/FSSAM Datasets/<name>/dataset`.
FSSAM_DATASETS = ['EgoHOS', 'GTEA', 'LVIS', 'NDIS Park', 'TrashCan', 'ZeroWaste-f']

# Entry 0 must stay the Liebherr dataset: `only_liebherr` keeps just that entry,
# and the final summary cell excludes it with `[1:]`.
folder_paths = [f'{DATA_ROOT}/Liebherr/dataset'] + [
    f'{DATA_ROOT}/FSSAM Datasets/{name}/dataset' for name in FSSAM_DATASETS
]
only_liebherr = True                # evaluate only folder_paths[0]
load_weights = 'Fine-Tune-SAM.pth'  # state dict loaded into SAM_Baseline; None skips loading
only_test = True                    # forwarded to create_dataloaders

In [3]:
class SAM_Baseline(torch.nn.Module):
    """SAM (ViT-B) prompted with a single foreground point per image.

    ``forward`` does not run the image encoder: callers pass precomputed image
    embeddings, and only the prompt encoder and mask decoder are executed.
    """

    def __init__(self):
        super(SAM_Baseline, self).__init__()
        # Pretrained ViT-B SAM; the checkpoint path is relative, i.e. resolved
        # against the current working directory.
        self.sam_model = sam_model_registry['vit_b'](checkpoint='sam_vit_b_01ec64.pth')
        self.img_size = self.sam_model.image_encoder.img_size
        self.postprocess_masks = SAMPostprocess(self.img_size)

    def forward(self, embeddings, points):
        """Decode candidate masks from image embeddings and one point prompt each.

        Args:
            embeddings: image-encoder output; only ``embeddings.shape[0]`` (the
                batch size) is read directly here.
            points: one point per image, presumably shaped (B, 2) — a singleton
                point axis is inserted before prompting. TODO confirm layout.

        Returns:
            ``(masks, iou_predictions)``: decoder masks passed through
            ``postprocess_masks`` plus the decoder's IoU predictions, with
            ``multimask_output=True`` (expected (B, 3, 1024, 1024) and (B, 3)).
        """
        # Every prompt point is a foreground click (label 1). Build the labels on
        # the points' device: `Tensor.to` is not in-place, so creating them on the
        # CPU and calling `.to(...)` without reassigning left them on the CPU.
        labels = torch.ones(embeddings.shape[0], 1, device=points.device)
        sparse_embeddings, dense_embeddings = self.sam_model.prompt_encoder(
            points=(points.unsqueeze(1), labels),
            boxes=None,
            masks=None,
        )
        masks, iou_predictions = self.sam_model.mask_decoder(
            image_embeddings=embeddings,
            image_pe=self.sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
        )
        masks = self.postprocess_masks(masks)
        return masks, iou_predictions  # (B, 3, 1024, 1024), (B, 3)

# Build the model, optionally load fine-tuned weights, and spread it over all GPUs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SAM_Baseline()
if load_weights is not None:
    # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
    # CPU-only host; the whole model is moved to `device` below.
    model.load_state_dict(torch.load(load_weights, map_location='cpu'))
if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # NOTE: the DataParallel wrapper hides custom attributes such as `img_size`;
    # access them via `model.module` when wrapped.
    model = DataParallel(model)
model.to(device)
print("Let's go")

Let's go


In [4]:
# `model` is wrapped in DataParallel when more than one GPU is visible, and the
# wrapper does not expose custom attributes like `img_size` -> unwrap first.
base_model = model.module if isinstance(model, DataParallel) else model
sam_transform = ResizeLongestSide(base_model.img_size)

# Ground-truth side: resize the mask, pad it (no normalisation), then
# `sample_point` — presumably returns the (mask, point) pair that
# `custom_collate` unpacks. TODO confirm against utils.sample_point.
target_transform = Compose([
    sam_transform.apply_image_torch, # rescale
    SAMPreprocess(base_model.img_size, normalize=False), # padding
    sample_point,
])
# Image side: PIL -> numpy, resize, -> tensor, pad (`normalize` left at its default).
transform = Compose([
    PILToNumpy(),
    sam_transform.apply_image, # rescale
    NumpyToTensor(),
    SAMPreprocess(base_model.img_size) # padding
])
def custom_collate(batch):
    """Collate ``(image, (mask, point), embedding)`` samples into batch tensors.

    Args:
        batch: sequence of samples; each target is itself a ``(mask, point)`` pair.

    Returns:
        Tuple ``(images, masks, points, embeddings)``, each a ``torch.stack`` of
        the corresponding per-sample tensors along a new leading batch axis.
    """
    image_seq, target_seq, embedding_seq = zip(*batch)
    mask_seq, point_seq = zip(*target_seq)
    stacked = [torch.stack(seq) for seq in (image_seq, mask_seq, point_seq, embedding_seq)]
    return tuple(stacked)

# Pick the datasets to evaluate without rebinding `folder_paths`: re-running this
# cell after toggling `only_liebherr` then behaves exactly like a fresh Run All.
selected_folder_paths = folder_paths[:1] if only_liebherr else folder_paths

dataloaders = create_dataloaders(selected_folder_paths, transform=transform, target_transform=target_transform,
                                 collate_fn=custom_collate, batch_size=8, only_test=only_test)

In [7]:
# Evaluate on every selected dataloader; the result unpacks into per-dataset IoU
# scores, oracle IoU scores and dataset names.
evaluation = test_model(model, dataloaders, device)
iou_scores, oracle_iou_scores, dataset_names = evaluation

Testing Liebherr


100%|██████████| 16/16 [01:02<00:00,  3.93s/it, IoU=0.608, oracle IoU=0.699]

Average IoU tensor(0.6082)
Average Oracle IoU tensor(0.6988)





In [None]:
# Summarise every dataset except entry 0 (Liebherr, first in `folder_paths`),
# assuming test_model returns results in dataloader order — TODO confirm.
if len(dataset_names) > 1:
    plot_iou_scores(iou_scores[1:], oracle_iou_scores[1:], dataset_names[1:])
    print(np.mean(iou_scores[1:]), np.mean(oracle_iou_scores[1:]))
else:
    # e.g. only_liebherr=True: nothing follows index 0, and np.mean([]) would just
    # print nan with a "Mean of empty slice" RuntimeWarning next to an empty plot.
    print("Only", list(dataset_names), "was evaluated; nothing to summarise beyond index 0.")