In [1]:
import numpy as np
import torch
import torchvision
from torch.nn.parallel import DataParallel
from torchvision.transforms import ToTensor, Compose
from torch.utils.data import DataLoader, TensorDataset, random_split, Subset
import matplotlib.pyplot as plt
import tqdm
import os
from torchmetrics.classification import BinaryJaccardIndex

from segment_anything.utils.transforms import ResizeLongestSide

from datasets import Embedding_Dataset, Custom_Dataset, Cutout_Dataset
from utils import SAMPreprocess, PILToNumpy, NumpyToTensor, SamplePoint, embedding_collate, is_valid_file
from utils import create_cutouts
from models import SAM_Concat

# Binary IoU (Jaccard index) metric; evaluate_model calls it once per batch.
jaccard = BinaryJaccardIndex()

In [6]:
batch_size = 8

# Root of the shared scratch workspace. Set the FSSAM_WORKSPACE environment variable to run
# elsewhere; the default is the original cluster path.
_workspace = os.environ.get('FSSAM_WORKSPACE', '/pfs/work7/workspace/scratch/ul_xto11-FSSAM')
# Every entry is a '<...>/<DatasetName>/dataset' folder; the parent folder name becomes the
# dataset key in the loading cell. Liebherr is deliberately listed first.
folder_paths = [_workspace + '/Liebherr/dataset'] + [
    _workspace + '/FSSAM Datasets/' + _bench + '/dataset'
    for _bench in ('EgoHOS', 'GTEA', 'LVIS', 'NDIS Park', 'TrashCan', 'ZeroWaste-f')
]

embed_model_type = 'dino'  # one of: 'clip', 'dino'
guidance_method = 'concat'  # one of: 'concat', 'attn'
example_method = 'ground_truth'  # one of: 'ground_truth', 'noisy_ground_truth', 'retrieval'
multi_output = False
visual_prompt_engineering = False

# Checkpoint to evaluate. It must have been trained with the settings above: the model is built
# from example_dim (which depends on embed_model_type) and multi_output.
load_weights = 'SAM_single_concat_dino_ground_truth.pt'

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
if embed_model_type == 'clip':
    import clip

    class CLIP:
        """Inference-only CLIP ViT-L/14@336px image encoder.

        Calling an instance runs ``encode_image`` on a batch of images under
        ``torch.no_grad()`` and returns the image embeddings.
        """

        def __init__(self, device):
            self.model, _ = clip.load("ViT-L/14@336px", device=device)
            self.model.eval()
            self.model.to(device)

        def __call__(self, image):
            with torch.no_grad():
                return self.model.encode_image(image)

    embed_model = CLIP(device)
    cutout_size = 336  # matches the 336 px input of ViT-L/14@336px
    example_dim = 768  # embedding width handed to the SAM guidance model
    print("CLIP loaded", cutout_size)
elif embed_model_type == 'dino':

    class DINO:
        """Inference-only DINOv2 ViT-L/14 backbone loaded from torch.hub.

        Calling an instance runs the backbone forward pass on a batch of images
        under ``torch.no_grad()`` and returns its output embeddings.
        """

        def __init__(self, device):
            self.model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
            self.model.eval()
            self.model.to(device)

        def __call__(self, image):
            with torch.no_grad():
                return self.model(image)

    embed_model = DINO(device)
    cutout_size = 336
    example_dim = 1024  # embedding width handed to the SAM guidance model
    print("DINO loaded", cutout_size)
else:
    # Fail here instead of leaving embed_model / cutout_size / example_dim undefined,
    # which previously surfaced later as a confusing NameError.
    raise ValueError(f"Unknown embed_model_type {embed_model_type!r}; expected 'clip' or 'dino'")

CLIP loaded 336


In [7]:
if guidance_method == 'concat':
    model = SAM_Concat(example_dim, multi_output)
else:
    # Only the concat guidance head is wired up in this notebook. Fail loudly instead of
    # falling through with `model` undefined and crashing on load_state_dict below.
    raise NotImplementedError(f"guidance_method {guidance_method!r} is not supported here; use 'concat'")

# Load onto the CPU so checkpoints saved from GPU runs also open on CPU-only hosts;
# the model is moved to `device` in a later cell.
pretrained_dict = torch.load(load_weights, map_location='cpu')
# Checkpoints saved from a DataParallel-wrapped model prefix every key with "module.".
# Strip only that leading prefix: str.replace would also corrupt keys such as "submodule.w".
_prefix = "module."
pretrained_dict = {
    (key[len(_prefix):] if key.startswith(_prefix) else key): value
    for key, value in pretrained_dict.items()
}
model.load_state_dict(pretrained_dict)

<All keys matched successfully>

In [8]:
# A later cell may wrap `model` in DataParallel, which does not forward plain attributes such as
# img_size. Unwrap it so this cell can be re-run in any order.
_sam = model.module if isinstance(model, DataParallel) else model

sam_transform = ResizeLongestSide(_sam.img_size)
# Target masks (tensors): resize the long side, pad without normalizing, then SamplePoint
# (presumably draws the prompt point returned by the loader).
target_transform = Compose([
    sam_transform.apply_image_torch,  # rescale
    SAMPreprocess(_sam.img_size, normalize=False),  # padding
    SamplePoint(),
])
# Input images (PIL assumed): to numpy, resize the long side, back to a tensor, then
# SAMPreprocess with its default normalization plus padding.
transform = Compose([
    PILToNumpy(),
    sam_transform.apply_image,  # rescale
    NumpyToTensor(),
    SAMPreprocess(_sam.img_size),  # padding
])
# Example images only need converting to tensors; cutouts are made from them later.
example_transform = Compose([
    ToTensor(),
])

def create_dataloader(folder_path, example_fraction=0.25, seed=42):
    """Build the test loader and the example subset for one dataset folder.

    The images under ``folder_path`` are split once, reproducibly, into a test part
    (batched through ``embedding_collate``) and a disjoint example part (raw images
    used later to build example cutouts).

    Args:
        folder_path: dataset root given to both ``Embedding_Dataset`` and ``Custom_Dataset``.
        example_fraction: share of images reserved as examples (default 0.25).
        seed: seed of the split generator (default 42).

    Returns:
        ``(test_loader, example_subset)``.

    Note:
        Both datasets are indexed with the same split indices. This assumes they enumerate
        the same files in the same order (same root and ``is_valid_file``). TODO confirm
        if either dataset class filters files.
    """
    dataset = Embedding_Dataset(root=folder_path, transform=transform, target_transform=target_transform, is_valid_file=is_valid_file)
    raw_examples = Custom_Dataset(root=folder_path, transform=example_transform, is_valid_file=is_valid_file)

    dataset_size = len(dataset)
    example_size = int(example_fraction * dataset_size)
    test_size = dataset_size - example_size
    generator = torch.Generator().manual_seed(seed)
    test_indices, example_indices = random_split(range(dataset_size), [test_size, example_size], generator=generator)

    test_loader = DataLoader(Subset(dataset, test_indices), batch_size=batch_size, collate_fn=embedding_collate)
    return test_loader, Subset(raw_examples, example_indices)

In [9]:
dataloaders = {}
example_sets = {}

for folder_path in folder_paths:
    # Paths look like '<...>/<DatasetName>/dataset'; the dataset name is the parent folder.
    # normpath drops any trailing separator, which would otherwise shift the split index.
    dataset_name = os.path.basename(os.path.dirname(os.path.normpath(folder_path)))
    dataloaders[dataset_name], example_sets[dataset_name] = create_dataloader(folder_path)
    print(dataset_name, "loaded")

Liebherr loaded
EgoHOS loaded
GTEA loaded
LVIS loaded
NDIS Park loaded
TrashCan loaded
ZeroWaste-f loaded


In [10]:
from torchvision.transforms import GaussianBlur

class CreateCutouts:
    """Callable that binds a fixed set of cutout settings to ``create_cutouts``.

    Call as ``f(image, masks)``; it forwards to ``create_cutouts`` with the stored
    ``cutout_size``, ``padding``, ``background_transform`` and ``background_intensity``.
    """

    def __init__(self, cutout_size, padding, background_transform, background_intensity):
        self.cutout_size, self.padding = cutout_size, padding
        self.background_transform, self.background_intensity = background_transform, background_intensity

    def __call__(self, image, masks):
        # Pixel normalization ((image - pixel_mean) / pixel_std) is intentionally skipped:
        # the author observed better performance without it.
        settings = (self.cutout_size, self.padding, self.background_transform, self.background_intensity)
        return create_cutouts(image, masks, *settings)
        
# Visual prompt engineering passes a blurred-background transform and a small background
# intensity to the cutout helper. Without it: no background transform and intensity 0.
if not visual_prompt_engineering:
    background_intensity, background_transform = 0, None
else:
    background_intensity = .1
    background_transform = Compose([GaussianBlur(11, 10)])  # kernel size 11, sigma 10

create_cutouts_f = CreateCutouts(cutout_size, 30, background_transform, background_intensity)  # padding=30

#image, masks = example_set[1]
#cutouts = create_cutouts_f(image, masks)
#plt.imshow(cutouts[0].permute(1, 2, 0))

In [11]:
def get_cutouts_dataset(dataset):
    """Run ``create_cutouts_f`` on every ``(image, masks)`` pair of ``dataset``.

    The per-image cutout tensors are concatenated along dim 0 and wrapped in a
    ``Cutout_Dataset``.
    """
    per_image = [create_cutouts_f(image, masks) for image, masks in tqdm.tqdm(dataset)]
    return Cutout_Dataset(torch.cat(per_image))

def get_cutout_embeddings(dataset, model, device, batch_size):
    """Embed every cutout in ``dataset`` and return the embeddings as one CPU tensor.

    Args:
        dataset: indexable collection of cutout tensors (e.g. a ``Cutout_Dataset``).
        model: image encoder. Either a plain callable, like the ``CLIP``/``DINO``
            wrappers in this notebook, or an object exposing ``encode_image``
            (e.g. a raw CLIP model).
        device: device the cutouts are moved to before encoding.
        batch_size: number of cutouts per forward pass.

    Returns:
        Embeddings concatenated along dim 0, with every row containing a NaN removed.
    """
    # The CLIP/DINO wrappers only define __call__, so `model.encode_image` raised
    # AttributeError for them; use encode_image when present, else call the object.
    encode = getattr(model, 'encode_image', model)
    loader = DataLoader(dataset, batch_size=batch_size)
    chunks = []
    with torch.no_grad():  # inference only; no autograd graph even for raw models
        for cutouts in tqdm.tqdm(loader):
            emb = encode(cutouts.to(device))
            emb = emb[~torch.isnan(emb).any(dim=1)]  # remove nan embeddings
            chunks.append(emb.detach().cpu())
    return torch.cat(chunks)

In [12]:
def get_example_embeddings(example_set, name=None):
    """Return embeddings of the example cutouts for one dataset, cached on disk.

    On a cache miss, cutouts are built with ``get_cutouts_dataset`` and embedded with
    the global ``embed_model``. The result is saved under ``example_embeddings/`` and
    re-loaded on later calls.

    Args:
        example_set: dataset of ``(image, masks)`` pairs to cut examples from.
        name: dataset name used in the cache file name. When omitted, the module-level
            ``name`` (the loop variable of the evaluation cell) is used, as before.
            Passing it explicitly removes that hidden dependency.

    Returns:
        The example-embedding tensor, on the CPU.

    Note:
        The cache key only covers the dataset name, visual prompt engineering on/off and
        the encoder type. After changing other cutout settings (e.g. padding), delete the
        cached files by hand; stale files are not detected.
    """
    if name is None:
        name = globals().get('name')
        if name is None:
            raise NameError("get_example_embeddings() needs a dataset name: pass `name=` explicitly")

    cache_dir = 'example_embeddings'
    suffix = '_vpe_' if visual_prompt_engineering else '_'
    file = os.path.join(cache_dir, name + suffix + embed_model_type + '.pt')

    if os.path.exists(file):
        example_embeddings = torch.load(file, map_location='cpu')
    else:
        cutouts_set = get_cutouts_dataset(example_set)
        example_embeddings = get_cutout_embeddings(cutouts_set, embed_model, device, batch_size=16)
        # torch.save does not create missing directories.
        os.makedirs(cache_dir, exist_ok=True)
        torch.save(example_embeddings, file)

    print(example_embeddings.shape)
    return example_embeddings

In [14]:
# Spread batches over all visible GPUs. The isinstance guard keeps the cell idempotent:
# re-running it must not wrap an already-wrapped model a second time.
if torch.cuda.device_count() > 1 and not isinstance(model, DataParallel):
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = DataParallel(model)
model.to(device)
print("SAM loaded")

SAM loaded


In [18]:
def evaluate_model(model, dataloader, example_embeddings, device):
    """Return the batch-size-weighted mean IoU of ``model`` over ``dataloader``.

    Guidance examples come from each batch's own ground truth. Every ``(image, masks)``
    pair goes through the global ``create_cutouts_f``, and the cutouts are embedded with
    the global ``embed_model``.

    Args:
        model: SAM guidance model (optionally DataParallel-wrapped), called as
            ``model(embeddings, points, examples)``.
        dataloader: yields ``(images, masks, points, embeddings)`` batches.
        example_embeddings: currently unused. Kept for the retrieval-style variants
            that are commented out below.
        device: device for embeddings, points and cutouts.

    Returns:
        float: mean of the per-batch ``jaccard`` value weighted by batch size, or NaN
        when the loader yields no samples.
    """
    np.random.seed(42)  # presumably makes SamplePoint's point sampling reproducible — TODO confirm
    model.eval()
    iou_sum = 0.0
    n_samples = 0

    with torch.no_grad():
        pbar = tqdm.tqdm(dataloader)
        for images, masks, points, embeddings in pbar:
            n_batch = images.size(0)
            embeddings = embeddings.to(device)
            points = points.to(device)

            # Ground-truth examples: cutout embeddings reshaped to (N, 1, D).
            # NOTE(review): assumes create_cutouts_f yields one cutout per image so N == n_batch.
            cutouts = torch.cat([create_cutouts_f(i, m) for i, m in zip(images, masks)]).to(device)
            examples = embed_model(cutouts).unsqueeze(1)

            if multi_output:
                # Keep, per sample, the candidate mask with the highest predicted IoU.
                outputs, iou_pred = model(embeddings, points, examples)
                outputs = outputs[range(len(outputs)), torch.argmax(iou_pred, dim=1)].unsqueeze(1)
            else:
                outputs = model(embeddings, points, examples)

            # torchmetrics takes (preds, target). IoU is symmetric, but keep the documented order.
            batch_iou = jaccard(outputs.cpu() > 0, masks > 0).item()
            # Count and weight with the same batch size (previously embeddings vs. images).
            iou_sum += batch_iou * n_batch
            n_samples += n_batch

            pbar.set_postfix({'IoU': iou_sum / n_samples})

    if n_samples == 0:
        return float('nan')  # empty loader: avoid ZeroDivisionError
    return iou_sum / n_samples

In [None]:
# Evaluate each dataset in insertion order of `dataloaders` (i.e. folder_paths order).
# NOTE: keep `name` as this module-level loop variable — get_example_embeddings() reads the
# global `name` to build its cache file name when no name is passed.
iou_scores = []
for name, dataloader in dataloaders.items():
    print("Loading example embeddings for", name)
    # Cached or freshly computed example embeddings. evaluate_model currently builds
    # ground-truth examples per batch and does not read them.
    example_embeddings = get_example_embeddings(example_sets[name]).to(device)
    print("Testing", name)
    iou_score = evaluate_model(model, dataloader, example_embeddings, device)
    print('IoU: {:.4f}'.format(iou_score), flush=True)
    iou_scores.append(iou_score)
# Mean over every dataset except the first one (Liebherr in folder_paths). Presumably that
# set is reported separately; TODO confirm this exclusion is intended.
print(iou_scores, np.mean(iou_scores[1:]))

Loading example embeddings for Liebherr
torch.Size([3993, 768])
Testing Liebherr


 79%|███████▉  | 63/80 [03:22<00:55,  3.28s/it, IoU=0.502]