In [1]:
import numpy as np
import torch
import torchvision
from torch.nn.parallel import DataParallel
from torchvision.transforms import ToTensor, Compose
from torch.utils.data import DataLoader, TensorDataset, random_split, Subset
import matplotlib.pyplot as plt
import tqdm
import os
from torchmetrics.classification import BinaryJaccardIndex

from segment_anything.utils.transforms import ResizeLongestSide

from datasets import Embedding_Dataset, Custom_Dataset, Cutout_Split_Dataset
from utils import SAMPreprocess, PILToNumpy, NumpyToTensor, SamplePoint, embedding_collate, is_valid_file
from utils import create_cutout_split
from models import SAM_Baseline

# Shared binary IoU (Jaccard index) metric instance; evaluate_model calls it once per batch.
jaccard = BinaryJaccardIndex()

In [2]:
# Batch size of the evaluation DataLoaders built by create_dataloader.
batch_size = 8
# Dataset folders to evaluate. The name of each folder's parent directory is used
# as the dataset name (key of `dataloaders` / `example_sets`).
folder_paths = [
    #'/pfs/work7/workspace/scratch/ul_xto11-FSSAM/Liebherr/dataset',
    '/pfs/work7/workspace/scratch/ul_xto11-FSSAM/FSSAM Datasets/EgoHOS/dataset',
    '/pfs/work7/workspace/scratch/ul_xto11-FSSAM/FSSAM Datasets/GTEA/dataset',
    '/pfs/work7/workspace/scratch/ul_xto11-FSSAM/FSSAM Datasets/LVIS/dataset',
    '/pfs/work7/workspace/scratch/ul_xto11-FSSAM/FSSAM Datasets/NDIS Park/dataset',
    '/pfs/work7/workspace/scratch/ul_xto11-FSSAM/FSSAM Datasets/TrashCan/dataset',
    '/pfs/work7/workspace/scratch/ul_xto11-FSSAM/FSSAM Datasets/ZeroWaste-f/dataset',
]
# Strategy evaluate_model uses to pick one of the candidate masks predicted per image.
select_mode = 'max_sim' # must be one of: 'random', 'first', 'highest_pred', 'max_sim'

In [3]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# SAM baseline model from the project's `models` module; it is moved to `device`
# (and optionally wrapped in DataParallel) in a later cell.
model = SAM_Baseline()

In [4]:
from maskCLIP.clip import clip # modified version of CLIP from ov-seg, adds mask-embedding to image
# Build the mask-prompt CLIP model, then overwrite its weights with the CLIP branch
# stored inside the fine-tuned OVSeg checkpoint.
clip_model, clip_preprocess = clip.load('ViT-L/14', mask_prompt_depth=3, device=device)

# map_location='cpu' keeps the checkpoint readable on CPU-only machines (device falls
# back to the CPU when CUDA is unavailable) and avoids a second copy of the weights on
# the GPU; load_state_dict below copies the values into clip_model's own parameters.
ov_seg = torch.load('ovseg_swinbase_vitL14_ft_mpt.pth', map_location='cpu')
state_dict = ov_seg["model"]

# Keep only the entries belonging to the CLIP model and strip exactly the leading
# prefix so the keys match clip_model's own parameter names.
clip_key_prefix = "clip_adapter.clip_model."
state_dict = {
    k[len(clip_key_prefix):]: v
    for k, v in state_dict.items()
    if k.startswith(clip_key_prefix)
}
clip_model.load_state_dict(state_dict)

# Side length of the cutouts handed to CLIP's visual encoder.
cutout_size = clip_model.visual.input_resolution
print("CLIP loaded", cutout_size)

CLIP loaded 224


In [5]:
# Per-channel (RGB) mean and standard deviation, kept both as plain tuples and as
# (3, 1, 1) tensors that broadcast over a (3, H, W) image.
# NOTE(review): only referenced by the commented-out normalisation in CreateCutouts;
# presumably CLIP's preprocessing statistics — confirm.
PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
PIXEL_STD = (0.26862954, 0.26130258, 0.27577711)
pixel_mean, pixel_std = (
    torch.tensor(channel_stats).view(3, 1, 1)
    for channel_stats in (PIXEL_MEAN, PIXEL_STD)
)

In [6]:
# Resize helper from segment_anything, configured with the SAM model's input size.
# NOTE(review): this reads `model.img_size` directly, so the cell presumably has to run
# before `model` is wrapped in DataParallel further down — confirm before re-running
# cells out of order.
sam_transform = ResizeLongestSide(model.img_size)
# Pipeline passed to Embedding_Dataset as `target_transform`: rescale, then
# SAMPreprocess with normalisation switched off, then SamplePoint.
# NOTE(review): SamplePoint presumably draws the prompt point from the mask — verify in utils.
target_transform = Compose([
    sam_transform.apply_image_torch, # rescale
    SAMPreprocess(model.img_size, normalize=False), # padding
    SamplePoint(),
])
# Pipeline passed to Embedding_Dataset as `transform`: PIL -> numpy -> rescale ->
# tensor -> SAMPreprocess with its default settings.
transform = Compose([
    PILToNumpy(),
    sam_transform.apply_image, # rescale
    NumpyToTensor(),
    SAMPreprocess(model.img_size) # padding
])
# Example images (used for the CLIP cutouts) are only converted to tensors, not resized.
example_transform = Compose([
    ToTensor(),
])

def create_dataloader(folder_path):
    """Build the evaluation loader and the example subset for one dataset folder.

    The folder is opened twice: as an Embedding_Dataset (SAM transforms) and as a
    Custom_Dataset (plain ToTensor). The sample indices are split 75 % / 25 % with a
    fixed seed; the larger part of the embedding dataset is evaluated, the smaller
    part of the custom dataset serves as examples.

    NOTE(review): the same index split is applied to both datasets, which assumes they
    enumerate the folder in the same order and with the same length — confirm.

    Args:
        folder_path: root directory of the dataset.

    Returns:
        (test_loader, example_set): a DataLoader over the test subset (batched with the
        notebook-level `batch_size` and `embedding_collate`) and a Subset with the examples.
    """
    embedding_data = Embedding_Dataset(
        root=folder_path,
        transform=transform,
        target_transform=target_transform,
        is_valid_file=is_valid_file,
    )
    example_data = Custom_Dataset(
        root=folder_path,
        transform=example_transform,
        is_valid_file=is_valid_file,
    )

    n_total = len(embedding_data)
    n_examples = int(0.25 * n_total)
    split_sizes = [n_total - n_examples, n_examples]

    # Fixed seed so the split is identical on every run.
    split_rng = torch.Generator().manual_seed(42)
    test_indices, example_indices = random_split(range(n_total), split_sizes, generator=split_rng)

    test_loader = DataLoader(
        Subset(embedding_data, test_indices),
        batch_size=batch_size,
        collate_fn=embedding_collate,
    )
    return test_loader, Subset(example_data, example_indices)

In [7]:
# Dataset name -> test DataLoader / example Subset, filled for every configured folder.
dataloaders, example_sets = {}, {}

for folder_path in folder_paths:
    # '<...>/<dataset name>/dataset' -> '<dataset name>'
    dataset_name = folder_path.rsplit('/', 2)[-2]
    test_loader, example_subset = create_dataloader(folder_path)
    dataloaders[dataset_name] = test_loader
    example_sets[dataset_name] = example_subset
    print(dataset_name, "loaded")

EgoHOS loaded
GTEA loaded
LVIS loaded
NDIS Park loaded
TrashCan loaded
ZeroWaste-f loaded


In [8]:
from torchvision.transforms import GaussianBlur

class CreateCutouts:
    """Callable that turns an image and its masks into cutouts via create_cutout_split."""

    def __init__(self, cutout_size, padding):
        # Both settings are forwarded unchanged to create_cutout_split on every call.
        self.cutout_size, self.padding = cutout_size, padding

    def __call__(self, image, masks):
        # Normalising the image with pixel_mean / pixel_std is deliberately skipped:
        # performance was better without normalization, even in this setting.
        cutout_args = (image, masks, self.cutout_size, self.padding)
        return create_cutout_split(*cutout_args)

# Shared cutout factory: cutouts use CLIP's input resolution (cutout_size) and a padding of 30.
create_cutouts_f = CreateCutouts(cutout_size, 30)

#image, masks = example_set[1]
#image_s, mask_s = create_cutouts_f(image, masks)
#fig, axs = plt.subplots(1, 2)
#axs[0].imshow(image_s[0].permute(1, 2, 0))
#axs[1].imshow(mask_s[0].permute(1, 2, 0))

In [9]:
from  torch.cuda.amp import autocast

def get_cutouts_dataset(dataset):
    """Create the cutouts of every (image, masks) sample and pool them in one dataset.

    Each sample is passed through the notebook-level `create_cutouts_f`; the resulting
    image cutouts and mask cutouts are concatenated along dim 0 and wrapped in a
    Cutout_Split_Dataset.

    Args:
        dataset: iterable yielding (image, masks) pairs.

    Returns:
        Cutout_Split_Dataset holding all image cutouts and all mask cutouts.
    """
    cutout_pairs = [
        create_cutouts_f(sample_image, sample_masks)
        for sample_image, sample_masks in tqdm.tqdm(dataset)
    ]
    all_image_cutouts = torch.cat([image_cutouts for image_cutouts, _ in cutout_pairs])
    all_mask_cutouts = torch.cat([mask_cutouts for _, mask_cutouts in cutout_pairs])
    return Cutout_Split_Dataset(all_image_cutouts, all_mask_cutouts)

def get_cutout_embeddings(dataset, model, device, batch_size):
    """Embed every cutout in `dataset` with `model.visual` and return the embeddings on the CPU.

    Args:
        dataset: dataset yielding (images, masks) cutout pairs that a DataLoader can batch.
        model: object exposing `visual(images, masks)` that returns one embedding row per
            cutout (the result is indexed along dim 1, so it must be 2-D).
        device: device the batches are moved to before the forward pass.
        batch_size: number of cutouts per forward pass.

    Returns:
        Tensor with all embeddings concatenated along dim 0. Rows that contain a NaN are
        dropped, so the result can have fewer rows than there are cutouts.
    """
    cutouts_loader = DataLoader(dataset, batch_size=batch_size)
    examples = []
    # Pure inference: no_grad keeps autograd from recording the forward pass, so the
    # activations are not retained only to be detached afterwards.
    with torch.no_grad():
        for images, masks in tqdm.tqdm(cutouts_loader):
            images, masks = images.to(device), masks.to(device)
            with autocast():
                cutout_embeddings = model.visual(images, masks)
            # Remove embeddings that contain at least one NaN.
            valid_rows = ~torch.any(cutout_embeddings.isnan(), dim=1)
            examples.append(cutout_embeddings[valid_rows].detach().cpu())
    return torch.cat(examples)

In [10]:
def get_example_embeddings(example_set, dataset_name=None):
    """Return the CLIP embeddings of a dataset's example cutouts, cached on disk.

    The embeddings are stored in 'example_embeddings/<dataset_name>_mpt.pt'. When that
    file exists it is loaded; otherwise the cutouts are created from `example_set`,
    embedded with the notebook-level `clip_model` on `device`, and the result is saved.

    Args:
        example_set: dataset of (image, masks) samples; only read when no cache exists.
        dataset_name: name used for the cache file. When omitted, the notebook-global
            `name` (set by the evaluation loop) is used, which preserves the original
            calling convention.

    Returns:
        Tensor of example embeddings (its shape is printed as a side effect).
    """
    if dataset_name is None:
        # Legacy behaviour: fall back to the loop variable of the evaluation cell.
        dataset_name = name
    file = 'example_embeddings/' + dataset_name + '_mpt.pt'

    if os.path.exists(file):
        example_embeddings = torch.load(file)
    else:
        cutouts_set = get_cutouts_dataset(example_set)
        example_embeddings = get_cutout_embeddings(cutouts_set, clip_model, device, batch_size=16)
        # torch.save does not create missing parent directories.
        os.makedirs(os.path.dirname(file), exist_ok=True)
        torch.save(example_embeddings, file)

    print(example_embeddings.shape)
    return example_embeddings

In [11]:
def get_closest(images, masks, examples, create_cutouts_f, model, device):
    """Pick, per image, the candidate mask whose cutout is most similar to any example.

    For every image the cutouts of its candidate masks are embedded with `model.visual`.
    Embeddings and examples are L2-normalised, so their dot product is the cosine
    similarity. Each candidate is scored by its best similarity over all examples and
    the index of the best-scoring candidate is returned.

    Args:
        images: batch of images, iterated along the first dimension.
        masks: candidate masks per image, iterated in step with `images`.
        examples: 2-D tensor of example embeddings (one row per example). It is not
            modified by this function.
        create_cutouts_f: callable `(image, masks) -> (cutouts, cutout_masks)`.
        model: object exposing `visual(cutouts, cutout_masks)` returning one embedding
            row per cutout.
        device: device used for the forward pass and the similarity computation.

    Returns:
        List with one selected candidate index per image.

    NOTE(review): the reshape below requires exactly 3 cutout embeddings per image.
    """
    embeddings = []
    # Inference only — do not let autograd record the CLIP forward pass.
    with torch.no_grad():
        for image, mask_suggestions in zip(images, masks):
            i, m = create_cutouts_f(image, mask_suggestions)
            i, m = i.to(device), m.to(device)
            with autocast():
                cutout_embeddings = model.visual(i, m)
            embeddings.append(cutout_embeddings.detach())
    embeddings = torch.cat(embeddings)
    examples = examples.to(device)

    # Normalise out of place: `.to(device)` returns the very same tensor when it already
    # lives on `device`, so an in-place division would overwrite the caller's examples.
    embeddings = embeddings / torch.norm(embeddings, dim=1, keepdim=True)
    examples = examples / torch.norm(examples, dim=1, keepdim=True)
    similarity_matrix = embeddings @ examples.T

    # (images * 3, examples) -> (images, 3, examples)
    similarity_matrix = similarity_matrix.reshape(len(images), 3, len(examples))
    # Best-matching example per candidate, then best candidate per image.
    best_per_candidate, _ = torch.max(similarity_matrix, dim=-1)
    max_i = torch.argmax(best_per_candidate, dim=-1)
    return max_i.cpu().tolist()

In [12]:
# Wrap the SAM model for multi-GPU execution when more than one CUDA device is visible,
# then move it to the selected device.
gpu_count = torch.cuda.device_count()
if gpu_count > 1:
    print("Let's use", gpu_count, "GPUs!")
    model = DataParallel(model)
model.to(device)
print("SAM loaded")

SAM loaded


In [13]:
def evaluate_model(model, dataloader, example_embeddings, device): # select_mode, create_cutouts_f, clip_model
    """Evaluate the mask predictor on one dataset and return its average IoU.

    For every batch the model predicts several candidate masks per image from the
    pre-computed embeddings and the prompt points. One candidate per image is chosen
    according to the notebook-level `select_mode`:

    * 'random'       - uniformly random candidate (seeded, hence reproducible),
    * 'first'        - always candidate 0,
    * 'highest_pred' - candidate with the highest predicted IoU,
    * 'max_sim'      - candidate most similar to the example embeddings (get_closest).

    Besides `select_mode` the function reads the notebook globals `jaccard`,
    `create_cutouts_f` and `clip_model`.

    Args:
        model: callable `(embeddings, points) -> (pred_masks, pred_iou)`.
        dataloader: yields `(images, masks, points, embeddings)` batches.
        example_embeddings: example embeddings, only used in 'max_sim' mode.
        device: device the embeddings and points are moved to.

    Returns:
        Batch-size-weighted mean of the per-batch IoU values.

    Raises:
        ValueError: if `select_mode` is not one of the supported modes.

    NOTE(review): 'random' and 'first' assume 3 candidate masks per image along dim 1.
    """
    valid_modes = ('random', 'first', 'highest_pred', 'max_sim')
    if select_mode not in valid_modes:
        # Fail early with a clear message instead of hitting an unbound `max_i` later.
        raise ValueError(
            "select_mode must be one of {}, got {!r}".format(valid_modes, select_mode)
        )

    # Seeds numpy's global RNG — presumably for the point sampling done while the
    # dataloader produces its batches (TODO confirm in utils.SamplePoint).
    np.random.seed(42)
    # Local, seeded generator so 'random' selection is reproducible without touching
    # torch's global RNG state.
    selection_rng = torch.Generator().manual_seed(42)
    model.eval()
    total_iou = 0.0
    total_samples = 0

    with torch.no_grad():
        pbar = tqdm.tqdm(dataloader)
        for images, masks, points, embeddings in pbar:
            embeddings = embeddings.to(device)
            points = points.to(device)

            pred_masks, pred_iou = model(embeddings, points)
            pred_masks, pred_iou = pred_masks.cpu(), pred_iou.cpu()
            # Binarise the predictions at 0.
            pred_masks = pred_masks > 0

            if select_mode == 'random':
                max_i = torch.randint(3, (len(images),), generator=selection_rng)
            elif select_mode == 'first':
                max_i = torch.zeros((len(images),), dtype=int)
            elif select_mode == 'highest_pred':
                max_i = torch.argmax(pred_iou, dim=1)
            elif select_mode == 'max_sim':
                max_i = get_closest(images, pred_masks, example_embeddings, create_cutouts_f, clip_model, device)

            # Keep only the selected candidate of every image; restore the candidate dim.
            pred_masks = pred_masks[range(len(pred_masks)), max_i].unsqueeze(1)
            iou = jaccard(masks, pred_masks)

            # Weight the batch IoU by the batch size so a smaller last batch counts less.
            total_iou += iou * images.size(0)
            total_samples += images.size(0)

            pbar.set_postfix({'IoU': (total_iou / total_samples).item()})

    average_iou = total_iou / total_samples
    return average_iou

In [14]:
# Evaluate every dataset in turn. The loop variable has to stay called `name`:
# get_example_embeddings reads it as a global to build its cache file path.
for name, dataloader in dataloaders.items():
    print("Loading example embeddings for", name)
    # Loaded from the on-disk cache when present, otherwise computed with CLIP and saved.
    example_embeddings = get_example_embeddings(example_sets[name])
    print("Testing", name)
    iou_score = evaluate_model(model, dataloader, example_embeddings, device)
    print('IoU: {:.4f}'.format(iou_score), flush=True)

Loading example embeddings for EgoHOS


100%|██████████| 250/250 [00:24<00:00, 10.27it/s]
100%|██████████| 53/53 [00:04<00:00, 10.90it/s]


torch.Size([842, 768])
Testing EgoHOS


100%|██████████| 94/94 [03:08<00:00,  2.01s/it, IoU=0.511]

IoU: 0.5113





Loading example embeddings for GTEA


100%|██████████| 163/163 [00:02<00:00, 74.16it/s]
100%|██████████| 19/19 [00:01<00:00, 15.92it/s]


torch.Size([300, 768])
Testing GTEA


100%|██████████| 62/62 [01:34<00:00,  1.52s/it, IoU=0.786]

IoU: 0.7859





Loading example embeddings for LVIS


100%|██████████| 250/250 [00:10<00:00, 24.15it/s]
100%|██████████| 194/194 [00:12<00:00, 15.88it/s]


torch.Size([3088, 768])
Testing LVIS


100%|██████████| 94/94 [02:33<00:00,  1.64s/it, IoU=0.446]

IoU: 0.4462





Loading example embeddings for NDIS Park


100%|██████████| 28/28 [00:13<00:00,  2.01it/s]
100%|██████████| 59/59 [00:03<00:00, 15.98it/s]


torch.Size([915, 768])
Testing NDIS Park


100%|██████████| 11/11 [00:55<00:00,  5.05s/it, IoU=0.283]

IoU: 0.2826





Loading example embeddings for TrashCan


100%|██████████| 250/250 [00:02<00:00, 92.15it/s] 
100%|██████████| 44/44 [00:02<00:00, 15.87it/s]


torch.Size([699, 768])
Testing TrashCan


100%|██████████| 94/94 [02:10<00:00,  1.38s/it, IoU=0.28] 

IoU: 0.2800





Loading example embeddings for ZeroWaste-f


100%|██████████| 250/250 [00:35<00:00,  6.96it/s]
100%|██████████| 35/35 [00:02<00:00, 16.07it/s]


torch.Size([548, 768])
Testing ZeroWaste-f


100%|██████████| 94/94 [03:04<00:00,  1.96s/it, IoU=0.237]

IoU: 0.2371





In [15]:
# Per-dataset IoU scores copied by hand from the evaluation output above
# (order: EgoHOS, GTEA, LVIS, NDIS Park, TrashCan, ZeroWaste-f).
# These are not updated automatically — re-enter them after re-running the evaluation.
MPT = [0.5113, 0.7859, 0.4462, 0.2826, 0.2800, 0.2371]
# Unweighted mean over the six datasets.
print(np.mean(MPT))

0.42385
