In [1]:
from matplotlib import pyplot as plt
from models.utils import set_seed
from models.generation import gen_stable_diffusion
from models.cls import classify_clip, plot_classification
from models.det import detect_groundingdino
from models.seg import segment_sam1, plot_segmentation_detection
from models.caption import caption_vqa


# Seed everything up front; presumably this makes the model calls below
# reproducible across runs (TODO confirm what models.utils.set_seed covers).
set_seed(42)

# ANSI escape sequences used to colour the step banners green in the output.
GREEN = "\033[92m"
RESET = "\033[0m"


def _announce(step):
    """Print *step* as a green progress banner."""
    print(f"{GREEN}{step}{RESET}")


def _show_image(image):
    """Display *image* with matplotlib, hiding the axes."""
    plt.imshow(image)
    plt.axis('off')
    plt.show()


# Step 1: synthesize an image from a fixed text prompt and show it.
_announce("Generating an image...")
img = gen_stable_diffusion("an apple and a banana on a table")
_show_image(img)

# Step 2: caption the image; the caption output is reused as the text
# queries for every downstream model.
_announce("Captioning the image...")
text_queries = caption_vqa(img)
print("text_queries:", text_queries)

# Step 3a: score the image against the text queries and plot the result.
_announce("Classifying the image...")
probs = classify_clip(img, text_queries)
plot_classification(img, text_queries, probs)

# Step 3b: detect objects matching the text queries; `threshold` is passed
# straight through as detect_groundingdino's third argument.
_announce("Detecting objects in the image...")
threshold = 0.1
boxes, scores, labels = detect_groundingdino(img, text_queries, threshold)
# NOTE(review): `labels` is never used below — the plot receives
# `text_queries` instead. Confirm that is what plot_segmentation_detection
# expects.

# Step 3c: segment the image conditioned on the detected boxes, then plot
# boxes, scores and masks together.
_announce("Segmenting objects in the image...")
masks = segment_sam1(img, boxes)
plot_segmentation_detection(img, boxes, scores, text_queries, masks)


ModuleNotFoundError: No module named 'utils'