In [1]:
import numpy as np
import torch
import torchvision
from torch.nn.parallel import DataParallel
from torchvision.transforms import ToTensor, Compose
from torch.utils.data import DataLoader, TensorDataset, random_split, Subset
import matplotlib.pyplot as plt
import tqdm
import os
from torchmetrics.classification import BinaryJaccardIndex

from segment_anything.utils.transforms import ResizeLongestSide

from datasets import Embedding_Dataset, Custom_Dataset, Cutout_Dataset
from utils import SAMPreprocess, PILToNumpy, NumpyToTensor, SamplePoint, embedding_collate, is_valid_file
from utils import create_cutouts
from models import SAM_Baseline

# Shared binary IoU metric instance; evaluate_model calls it once per batch as
# jaccard(ground-truth masks, selected predicted masks).
jaccard = BinaryJaccardIndex()

In [2]:
# ---- Evaluation configuration (read as globals by the cells below) ----
batch_size = 8

# Every dataset folder has the form "<workspace root>/<relative dir>/dataset".
_workspace_root = '/pfs/work7/workspace/scratch/ul_xto11-FSSAM'
_dataset_dirs = (
    'Liebherr',
    'FSSAM Datasets/EgoHOS',
    'FSSAM Datasets/GTEA',
    'FSSAM Datasets/LVIS',
    'FSSAM Datasets/NDIS Park',
    'FSSAM Datasets/TrashCan',
    'FSSAM Datasets/ZeroWaste-f',
)
folder_paths = ['/'.join((_workspace_root, rel_dir, 'dataset')) for rel_dir in _dataset_dirs]

# Forwarded to SAM_Baseline; evaluate_model expects (masks, iou) from the model when True.
multi_output = True
# How one mask is picked per image; must be one of: 'random', 'first', 'highest_pred', 'max_sim'
select_mode = 'max_sim'
# Embedding backbone for the 'max_sim' selection: 'clip' or 'dino'.
embed_model_type = 'dino'
# When True, cutouts keep a blurred, dimmed background instead of none (see create_cutouts_f).
visual_prompt_engineering = False

In [3]:
# Prefer the GPU when one is visible; otherwise everything runs on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Baseline model under evaluation; `multi_output` comes from the configuration cell.
model = SAM_Baseline(multi_output)

In [4]:
# Load the image-embedding backbone used to compare mask cutouts with the example bank.
# Each branch must define `embed_model` (an object exposing `encode_image`) and
# `cutout_size` (the side length passed to CreateCutouts).
if embed_model_type == 'clip':
    import clip
    embed_model, _ = clip.load("ViT-L/14@336px", device=device) # Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
    cutout_size = embed_model.visual.input_resolution
    embed_model.to(device)
    print("CLIP loaded", cutout_size)
elif embed_model_type == 'dino':
    class DINO:
        """Wrapper that gives the DINOv2 hub model the same `encode_image` entry point as CLIP."""

        def __init__(self, device):
            self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
            self.model.to(device)
            # The model is only used for inference, so switch off any train-time behaviour.
            self.model.eval()

        def encode_image(self, image):
            """Run the backbone on `image` without tracking gradients and return its output."""
            with torch.no_grad():
                return self.model(image)

    embed_model = DINO(device)
    # Same side length as the CLIP "@336px" variant above; 336 is a multiple of 14
    # (presumably the patch size of 'dinov2_vitl14' — confirm).
    cutout_size = 336
    print("DINO loaded", cutout_size)
else:
    # Fail here instead of leaving `embed_model` / `cutout_size` undefined (or stale from
    # an earlier run of this cell) and crashing much later in the notebook.
    raise ValueError(
        "embed_model_type must be 'clip' or 'dino', got {!r}".format(embed_model_type)
    )

Using cache found in /home/ul/ul_student/ul_xto11/.cache/torch/hub/facebookresearch_dinov2_main


DINO loaded 336


In [5]:
# Per-channel normalisation statistics (same numbers as the Normalize(...) note next to clip.load).
PIXEL_MEAN = (0.48145466, 0.4578275, 0.40821073)
PIXEL_STD = (0.26862954, 0.26130258, 0.27577711)

# Shaped (3, 1, 1) so they broadcast over the trailing two axes of a 3-channel image tensor.
pixel_mean = torch.tensor(PIXEL_MEAN).view(3, 1, 1)
pixel_std = torch.tensor(PIXEL_STD).view(3, 1, 1)

In [6]:
# One resize helper shared by the image pipeline and the mask pipeline.
sam_transform = ResizeLongestSide(model.img_size)

# Ground-truth mask pipeline: rescale (tensor variant), pad without normalising, then SamplePoint.
_target_steps = [
    sam_transform.apply_image_torch,  # rescale
    SAMPreprocess(model.img_size, normalize=False),  # padding
    SamplePoint(),
]
target_transform = Compose(_target_steps)

# Input image pipeline: PIL -> numpy -> rescale -> tensor -> SAMPreprocess with its default settings.
_image_steps = [
    PILToNumpy(),
    sam_transform.apply_image,  # rescale
    NumpyToTensor(),
    SAMPreprocess(model.img_size),  # padding
]
transform = Compose(_image_steps)

# Images used for the example bank are only converted to tensors.
example_transform = Compose([ToTensor()])

def create_dataloader(folder_path):
    """Build the evaluation loader and the example set for one dataset folder.

    Args:
        folder_path: Root directory handed to both dataset classes.

    Returns:
        A ``(test_loader, example_set)`` tuple: a DataLoader over an
        Embedding_Dataset (batched with the global ``batch_size`` and collated by
        ``embedding_collate``) and a Custom_Dataset over the same root that uses
        ``example_transform``.
    """
    embedding_set = Embedding_Dataset(
        root=folder_path,
        transform=transform,
        target_transform=target_transform,
        is_valid_file=is_valid_file,
    )
    example_set = Custom_Dataset(
        root=folder_path,
        transform=example_transform,
        is_valid_file=is_valid_file,
    )
    return (
        DataLoader(embedding_set, batch_size=batch_size, collate_fn=embedding_collate),
        example_set,
    )

In [7]:
# Both dictionaries are keyed by dataset name and filled in the order of `folder_paths`.
dataloaders, example_sets = {}, {}

for folder_path in folder_paths:
    # The dataset name is the directory directly above the trailing ".../dataset" component.
    dataset_name = folder_path.rsplit('/', 2)[-2]
    loader_and_examples = create_dataloader(folder_path)
    dataloaders[dataset_name] = loader_and_examples[0]
    example_sets[dataset_name] = loader_and_examples[1]
    print(dataset_name, "loaded")

Liebherr loaded
EgoHOS loaded
GTEA loaded
LVIS loaded
NDIS Park loaded
TrashCan loaded
ZeroWaste-f loaded


In [8]:
from torchvision.transforms import GaussianBlur

class CreateCutouts:
    """Callable that binds the cutout settings so callers only pass ``(image, masks)``."""

    def __init__(self, cutout_size, padding, background_transform, background_intensity):
        """Store the four settings that are forwarded to ``create_cutouts`` on every call."""
        self.cutout_size, self.padding = cutout_size, padding
        self.background_transform = background_transform
        self.background_intensity = background_intensity

    def __call__(self, image, masks):
        """Return ``create_cutouts(image, masks, ...)`` using the stored settings."""
        # The image is passed through as-is: normalising it with pixel_mean / pixel_std
        # was tried and gave worse performance, so no normalization is applied here.
        settings = (
            self.cutout_size,
            self.padding,
            self.background_transform,
            self.background_intensity,
        )
        return create_cutouts(image, masks, *settings)
        
# Background settings for the cutouts. Without visual prompt engineering no background
# transform is used and the intensity is 0; with it, the background transform is a
# Gaussian blur (kernel size 11, sigma 10) and the intensity is 0.1.
if not visual_prompt_engineering:
    background_transform = None
    background_intensity = 0
else:
    background_transform = Compose([GaussianBlur(11, 10)])
    background_intensity = .1

# Single shared cutout factory; the second argument (30) is the `padding` setting.
create_cutouts_f = CreateCutouts(
    cutout_size,
    30,
    background_transform,
    background_intensity,
)

In [9]:
def get_cutouts_dataset(dataset):
    """Turn every ``(image, masks)`` pair of ``dataset`` into cutouts and wrap them.

    Args:
        dataset: Iterable yielding ``(image, masks)`` pairs.

    Returns:
        A Cutout_Dataset built from the per-image cutout tensors concatenated
        along dimension 0 (via the global ``create_cutouts_f``).
    """
    per_image_cutouts = [
        create_cutouts_f(image, masks) for image, masks in tqdm.tqdm(dataset)
    ]
    stacked = torch.cat(per_image_cutouts)
    return Cutout_Dataset(stacked)

def get_cutout_embeddings(dataset, model, device, batch_size):
    """Embed every cutout in ``dataset`` and return the embeddings on the CPU.

    Args:
        dataset: Dataset of cutouts; it is batched with a DataLoader.
        model: Object exposing ``encode_image(batch)`` that returns one
            embedding row per cutout.
        device: Device each batch is moved to before encoding.
        batch_size: Number of cutouts per forward pass.

    Returns:
        A single tensor holding all embedding rows, concatenated along dimension 0.
        Rows containing any NaN are dropped, so the result can have fewer rows
        than ``dataset`` has items.
    """
    cutouts_loader = DataLoader(dataset, batch_size=batch_size)
    examples = []
    # Inference only: disable autograd here so that no graph is built for models whose
    # encode_image does not do this itself.
    with torch.no_grad():
        for cutouts in tqdm.tqdm(cutouts_loader):
            cutout_embeddings = model.encode_image(cutouts.to(device))
            # remove nan embeddings
            valid_rows = ~torch.any(cutout_embeddings.isnan(), dim=1)
            examples.append(cutout_embeddings[valid_rows].detach().cpu())
    return torch.cat(examples)

In [10]:
def get_example_embeddings(example_set, dataset_name=None):
    """Return the example-bank embeddings for one dataset, using an on-disk cache.

    The cache file is ``example_embeddings/<dataset>[_vpe]_<embed_model_type>_full.pt``;
    the ``_vpe`` tag is present when the global ``visual_prompt_engineering`` is set.

    Args:
        example_set: Dataset of ``(image, masks)`` pairs; only read on a cache miss.
        dataset_name: Name used in the cache file name. When omitted, the
            notebook-level variable ``name`` (set by the evaluation loop) is used,
            which keeps existing one-argument calls working.

    Returns:
        The embedding tensor, either loaded from the cache or freshly computed
        and then saved.

    Raises:
        ValueError: If ``dataset_name`` is omitted and no global ``name`` exists.
    """
    if dataset_name is None:
        # Backward compatibility with callers that rely on the loop variable `name`.
        dataset_name = globals().get('name')
        if dataset_name is None:
            raise ValueError(
                "dataset_name was not given and no global `name` is defined"
            )

    name_tag = '_vpe_' if visual_prompt_engineering else '_'
    file = 'example_embeddings/' + dataset_name + name_tag + embed_model_type + '_full.pt'

    if os.path.exists(file):
        # The cache holds CPU tensors; map_location keeps the load independent of the device.
        example_embeddings = torch.load(file, map_location='cpu')
    else:
        # Create the cache directory before the long computation, so a missing or
        # unwritable directory fails immediately rather than after the work is done.
        os.makedirs(os.path.dirname(file), exist_ok=True)
        cutouts_set = get_cutouts_dataset(example_set)
        example_embeddings = get_cutout_embeddings(cutouts_set, embed_model, device, batch_size=16)
        torch.save(example_embeddings, file)

    print(example_embeddings.shape)
    return example_embeddings

In [11]:
def get_closest(images, masks, examples, create_cutouts_f, model, device):
    """Pick, for every image, the mask candidate that best matches the example bank.

    Each candidate mask is cut out of its image, embedded with ``model`` and
    compared with every example embedding by cosine similarity. A candidate is
    scored by its second-highest similarity (presumably because the example
    bank can contain the very same object, which would make the best match
    trivial — confirm), and the candidate with the highest score wins.

    Args:
        images: Sequence of images, one per sample.
        masks: Sequence of candidate-mask collections, aligned with ``images``.
            Every image must have the same number of candidates.
        examples: 2-D tensor of example embeddings, one row per example.
            It is not modified.
        create_cutouts_f: Callable ``(image, masks) -> tensor of cutouts``.
        model: Object exposing ``encode_image(cutouts)``.
        device: Device used for the embedding and similarity computation.

    Returns:
        A list with one candidate index per image.

    Raises:
        ValueError: If fewer than two examples are given (a second-best match
            does not exist then).
    """
    if len(examples) < 2:
        raise ValueError("get_closest needs at least two example embeddings")

    embeddings = []
    for image, mask_suggestions in zip(images, masks):
        cutouts = create_cutouts_f(image, mask_suggestions).to(device)
        embeddings.append(model.encode_image(cutouts).detach())
    embeddings = torch.cat(embeddings)

    # Normalise out of place: `examples.to(device)` returns the caller's own tensor when
    # it already lives on `device`, so an in-place division would alter the shared bank.
    embeddings = embeddings / torch.norm(embeddings, dim=1, keepdim=True)
    examples = examples.to(device)
    examples = examples / torch.norm(examples, dim=1, keepdim=True)
    similarity_matrix = embeddings @ examples.T

    # (images, candidates per image, examples); the candidate count is inferred so that
    # both multi-mask and single-mask predictions are supported.
    similarity_matrix = similarity_matrix.reshape(len(images), -1, len(examples))
    topk_values, _ = torch.topk(similarity_matrix, k=2, dim=-1)
    second_best = topk_values[:, :, 1]
    second_max_i = torch.argmax(second_best, dim=-1)
    return second_max_i.cpu().tolist()

In [12]:
# Wrap the model in DataParallel only when more than one GPU is visible.
_gpu_count = torch.cuda.device_count()
if _gpu_count > 1:
    print("Let's use", _gpu_count, "GPUs!")
    model = DataParallel(model)

# Move the (possibly wrapped) model to the selected device.
model.to(device)
print("SAM loaded")

SAM loaded


In [13]:
def evaluate_model(model, dataloader, example_embeddings, device): # select_mode, create_cutouts_f, embed_model
    """Compute the sample-weighted mean IoU of ``model`` over ``dataloader``.

    For every batch the model predicts mask candidates, one candidate per image
    is chosen according to the global ``select_mode`` and the batch IoU from the
    global ``jaccard`` metric is accumulated, weighted by the batch size.

    Reads the notebook globals ``multi_output``, ``select_mode``, ``jaccard``,
    ``create_cutouts_f`` and ``embed_model``.

    Args:
        model: Callable ``(embeddings, points)`` returning mask logits, plus
            predicted IoU scores when ``multi_output`` is true.
        dataloader: Iterable of ``(images, masks, points, embeddings)`` batches.
        example_embeddings: Example bank passed to ``get_closest`` (only used
            when ``select_mode == 'max_sim'``).
        device: Device the embeddings and points are moved to.

    Returns:
        The mean IoU (batch IoU weighted by batch size).

    Raises:
        ValueError: If ``select_mode`` is unknown, if ``'highest_pred'`` is
            requested while ``multi_output`` is false, or if ``dataloader``
            yields no samples.
    """
    valid_modes = ('random', 'first', 'highest_pred', 'max_sim')
    if select_mode not in valid_modes:
        raise ValueError(
            "select_mode must be one of {}, got {!r}".format(valid_modes, select_mode)
        )
    if select_mode == 'highest_pred' and not multi_output:
        raise ValueError("select_mode 'highest_pred' needs multi_output=True (no IoU predictions otherwise)")

    # Seeds NumPy's global RNG (presumably for NumPy-based randomness in the data
    # pipeline — confirm). It does not affect torch, so the 'random' mode below draws
    # from its own seeded generator to be reproducible as well.
    np.random.seed(42)
    selection_rng = torch.Generator().manual_seed(42)

    model.eval()
    total_iou = 0.0
    total_samples = 0

    with torch.no_grad():
        pbar = tqdm.tqdm(dataloader)
        for images, masks, points, embeddings in pbar:
            embeddings = embeddings.to(device)
            points = points.to(device)

            if multi_output:
                pred_masks, pred_iou = model(embeddings, points)
                pred_masks, pred_iou = pred_masks.cpu(), pred_iou.cpu()
            else:
                pred_masks = model(embeddings, points)
                pred_masks = pred_masks.cpu()
            # Logits -> boolean masks.
            pred_masks = pred_masks > 0

            n_images = len(images)
            # Number of mask candidates per image, taken from the prediction itself.
            n_candidates = pred_masks.shape[1]

            if select_mode == 'random':
                max_i = torch.randint(n_candidates, (n_images,), generator=selection_rng)
            elif select_mode == 'first':
                max_i = torch.zeros((n_images,), dtype=torch.long)
            elif select_mode == 'highest_pred':
                max_i = torch.argmax(pred_iou, dim=1)
            else:  # 'max_sim'
                max_i = get_closest(images, pred_masks, example_embeddings, create_cutouts_f, embed_model, device)

            # Keep only the chosen candidate of each image, restoring the candidate axis.
            pred_masks = pred_masks[range(len(pred_masks)), max_i].unsqueeze(1)
            iou = jaccard(masks, pred_masks)

            total_iou += iou * images.size(0)
            total_samples += images.size(0)

            pbar.set_postfix({'IoU': (total_iou / total_samples).item()})

    if total_samples == 0:
        raise ValueError("dataloader yielded no samples; cannot compute an average IoU")

    average_iou = total_iou / total_samples
    return average_iou

In [14]:
# Evaluate every dataset in insertion order and collect one IoU score per dataset.
iou_scores = []
# NOTE: the loop variable has to stay called `name`; get_example_embeddings reads it
# from the notebook globals to build the cache file name.
for name in dataloaders:
    dataloader = dataloaders[name]

    print("Loading example embeddings for", name)
    example_embeddings = get_example_embeddings(example_sets[name])

    print("Testing", name)
    iou_score = evaluate_model(model, dataloader, example_embeddings, device)
    print(f'IoU: {iou_score:.4f}', flush=True)
    iou_scores.append(iou_score)

print(iou_scores)

Loading example embeddings for Liebherr


100%|██████████| 848/848 [26:24<00:00,  1.87s/it]
100%|██████████| 991/991 [09:13<00:00,  1.79it/s]


torch.Size([15847, 1024])
Testing Liebherr


100%|██████████| 106/106 [07:56<00:00,  4.50s/it, IoU=0.503]

IoU: 0.5027





Loading example embeddings for EgoHOS


100%|██████████| 1000/1000 [07:47<00:00,  2.14it/s]
100%|██████████| 216/216 [01:57<00:00,  1.83it/s]


torch.Size([3451, 1024])
Testing EgoHOS


100%|██████████| 125/125 [07:04<00:00,  3.40s/it, IoU=0.572]

IoU: 0.5716





Loading example embeddings for GTEA


100%|██████████| 652/652 [00:44<00:00, 14.77it/s]
100%|██████████| 76/76 [00:41<00:00,  1.85it/s]


torch.Size([1208, 1024])
Testing GTEA


100%|██████████| 82/82 [03:59<00:00,  2.91s/it, IoU=0.738]

IoU: 0.7384





Loading example embeddings for LVIS


100%|██████████| 1000/1000 [03:11<00:00,  5.22it/s]
100%|██████████| 652/652 [05:56<00:00,  1.83it/s]


torch.Size([10421, 1024])
Testing LVIS


100%|██████████| 125/125 [06:18<00:00,  3.03s/it, IoU=0.558]

IoU: 0.5579





Loading example embeddings for NDIS Park


100%|██████████| 112/112 [07:30<00:00,  4.02s/it]
100%|██████████| 157/157 [01:25<00:00,  1.84it/s]


torch.Size([2508, 1024])
Testing NDIS Park


100%|██████████| 14/14 [01:27<00:00,  6.22s/it, IoU=0.721]

IoU: 0.7210





Loading example embeddings for TrashCan


100%|██████████| 1000/1000 [00:55<00:00, 18.06it/s]
100%|██████████| 135/135 [01:13<00:00,  1.85it/s]


torch.Size([2151, 1024])
Testing TrashCan


100%|██████████| 125/125 [05:39<00:00,  2.71s/it, IoU=0.42] 

IoU: 0.4202





Loading example embeddings for ZeroWaste-f


100%|██████████| 1000/1000 [05:39<00:00,  2.94it/s]
100%|██████████| 133/133 [01:12<00:00,  1.82it/s]


torch.Size([2117, 1024])
Testing ZeroWaste-f


100%|██████████| 125/125 [07:10<00:00,  3.45s/it, IoU=0.309]

IoU: 0.3088





[tensor(0.5027), tensor(0.5716), tensor(0.7384), tensor(0.5579), tensor(0.7210), tensor(0.4202), tensor(0.3088)]


In [2]:
# Per-dataset IoU scores typed in by hand, one list per selection strategy.
# The mean of each list is printed at the bottom, in the order the lists appear here.
RANDOM = [0.4695, 0.5637, 0.3954, 0.4686, 0.1120, 0.2694]
FIRST_SINGLE = [0.5494, 0.8276, 0.4359, 0.3154, 0.3054, 0.2320]
HIGHEST = [0.5349, 0.7512, 0.4374, 0.4321, 0.1903, 0.2587]

##### CLIP #####
MAX_SIM = [0.4979, 0.6253, 0.4101, 0.3702, 0.2812, 0.2767]
VPE = [0.5338, 0.8117, 0.4168, 0.3844, 0.2712, 0.2645]
# Keep references to the CLIP rows: the two names are rebound to the DINO rows just below.
_clip_rows = (MAX_SIM, VPE)

##### DINO #####
MAX_SIM = [0.5388, 0.7259, 0.4657, 0.6028, 0.3367, 0.2933]
VPE = [0.5079, 0.7455, 0.4574, 0.5149, 0.3457, 0.3003]

# Same seven values as the `iou_scores` printed by the evaluation loop (Liebherr first).
# The first entry is skipped when averaging, presumably so that the mean covers the same
# six datasets as the rows above — confirm.
FULL = [0.5027, 0.5716, 0.7384, 0.5579, 0.7210, 0.4202, 0.3088]

for _scores in (RANDOM, FIRST_SINGLE, HIGHEST, *_clip_rows, MAX_SIM, VPE, FULL[1:]):
    print(np.mean(_scores))

0.37976666666666664
0.44428333333333336
0.4341000000000001
0.4102333333333334
0.44706666666666667
0.4938666666666666
0.4786166666666667
0.5529833333333334
