In [None]:
from TinySAM import *
import numpy as np
import matplotlib.pyplot as plt
import torch
from tqdm.auto import tqdm, trange
from torch.utils.data import DataLoader

# Setting Up Models

In [None]:
# Two-stage pipeline: the text-prompted detector proposes boxes, and the SAM-style model
# (EdgeSAM) turns those boxes into instance masks in the "Run SAM" cell below.
# NOTE(review): the checkpoint string looks like a Hugging Face Hub model id — confirm in GDino.
GDINO_CHECKPOINT = "IDEA-Research/grounding-dino-tiny"
GroundingModel = GDino(GDINO_CHECKPOINT)
SAMModel = EdgeSAM()

# Load Data to RAM

In [None]:
# Dataset wrapper around the Cityscapes folder. It is handed the detector's processor so
# prompts/images are prepared the way Grounding DINO expects.
# NOTE(review): do_preprocess=False skips the dataset's own preprocessing at construction —
# confirm where (or whether) preprocessing happens instead.
DATA_DIR = './Data/cityscapes/'
data = ZeroShotObjectDetectionDataset(
    DATA_DIR,
    do_preprocess=False,
    processor=GroundingModel.processor,
)

In [None]:
# Quick sanity check of the loaded annotations: draw one image chosen at random.
rnd_idx = np.random.randint(len(data))  # same draw as randint(0, len(data))
data.visualize(rnd_idx)

# Run Grounding DINO

In [None]:
# Detector batches. shuffle=False is essential: the returned boxes/labels/scores are later
# paired with data.images by position (SAM cell) and indexed per image (visualization cell).
batch_size = 8
dataloader = DataLoader(data, shuffle=False, batch_size=batch_size)

# Text prompts, moved to the detector's device.
# NOTE(review): presumably the tokenized form of data.text_prompts — confirm in the dataset.
input_ids = data.input_prompt_ins.input_ids.to(GroundingModel.device)
# NOTE(review): assumed to be the image size the predicted boxes are expressed in — confirm in run_loader.
target_image_size = data.image_size

boxes, labels, scores = GroundingModel.run_loader(
    dataloader,
    input_ids,
    data.text_prompts,
    target_image_size,
)

# Run EdgeSAM (box-prompted segmentation)

In [None]:
# Box-prompted segmentation: the detector's boxes are passed alongside the images they came
# from; the result is indexed per image later (masks[rnd_idx] pairs with boxes[rnd_idx]).
masks = SAMModel(
    data.images,  # every image of the dataset, in dataset order
    boxes,        # Grounding DINO boxes, in the same order
)

In [None]:
# Save Results (if needed)
def _to_object_array(per_image_results):
    """Pack a per-image sequence into a 1-D object array with exactly one slot per image.

    ``np.array(seq, dtype=object)`` lets NumPy infer a shape from the elements: when every
    image's entry has the same shape it silently builds an N-D object array (so indexing
    after reload yields object-dtype slices instead of the original arrays), and for some
    ragged shapes it raises instead. Filling a pre-allocated 1-D array element by element
    stores each entry verbatim, so ``loaded[i]`` is exactly what image ``i`` had in memory.

    Args:
        per_image_results: any sized iterable with one entry per image
            (e.g. the ``masks``/``boxes``/``labels``/``scores`` produced above).

    Returns:
        A 1-D ``numpy`` array of dtype ``object`` with ``len(per_image_results)`` slots.
    """
    packed = np.empty(len(per_image_results), dtype=object)
    for i, item in enumerate(per_image_results):
        packed[i] = item  # element assignment stores the object itself, no shape inference
    return packed


# Object arrays are written with pickle inside the .npy file, hence allow_pickle=True on load.
np.save('masks.npy', _to_object_array(masks))
np.save('boxes.npy', _to_object_array(boxes))
np.save('labels.npy', _to_object_array(labels))
np.save('scores.npy', _to_object_array(scores))

In [None]:
# Load Results (if needed)
# The saved arrays have dtype=object, so numpy must unpickle them (allow_pickle=True).
# Unpickling can execute arbitrary code: only load .npy files written by this notebook.
def _load_object_npy(path):
    """Load one object-dtype result array saved by the previous cell."""
    return np.load(path, allow_pickle=True)


masks = _load_object_npy('masks.npy')
boxes = _load_object_npy('boxes.npy')
labels = _load_object_npy('labels.npy')
scores = _load_object_npy('scores.npy')

# Visualize A Prediction

In [None]:
# Compare ground truth with raw and post-processed predictions for one image.
# A fixed seed keeps Restart & Run All showing the same example; change it to inspect
# another image. int() keeps a plain Python int index, as np.random.randint returned.
rnd_idx = int(np.random.default_rng(seed=0).integers(len(data)))

# visualize ground truth
data.visualize(rnd_idx)
plt.title('Ground Truth')
plt.show()  # flush this figure so the next plot cannot draw onto the same axes

# visualize the predicted masks
data.visualize_prediction(rnd_idx, boxes[rnd_idx], masks[rnd_idx], labels[rnd_idx])
plt.title('Predicted Instances Raw')
plt.show()

# visualize the predicted masks after post-processing (unify=True)
data.visualize_prediction(rnd_idx, boxes[rnd_idx], masks[rnd_idx], labels[rnd_idx], unify=True)
plt.title('Predicted Instances Unified')
plt.show()

In [None]:
# Score all predictions against the ground truth. With return_processed=True the call yields
# eight values; the processed_* / unified_masks outputs are presumably the post-processed
# predictions — confirm against the dataset class.
# NOTE(review): `evaluate_precitions` must match the method name defined in TinySAM exactly;
# if the typo is ever fixed there, rename it here too.
(
    mIoU,
    mAP,
    overall_iou,
    processed_boxes,
    processed_labels,
    processed_masks,
    processed_scores,
    unified_masks,
) = data.evaluate_precitions(boxes, labels, masks, scores, return_processed=True)

In [None]:
# Collapse the mIoU result from the evaluation cell into a single headline number.
# NOTE(review): whether mIoU holds per-image or per-class values is not visible here — confirm.
mean_mIoU = mIoU.mean()
mean_mIoU