In [2]:
import sys
sys.path.append("..")
import json
import PIL

from lib.models.llava import LLaVACaptioner
from lib.models.detic import DeticModel
from lib.models.chat_gpt import ChatGPTGraphBuilder
from lib.data.models_registry import LocalModelsRegistry
from lib.data.detection_pipeline import DataPipeLine, TrajectoryImageCollector, SceneCaptionStep, DetectionStep, DetectionReduceStep, GPTGraphGenerateStep

In [None]:
# Registry bundling the model instances that the pipeline steps below receive.
registry = LocalModelsRegistry(
    llava_captioner=LLaVACaptioner(),
    detic=DeticModel(ignore_classes=("person", "lightbulb")),
    # This slot holds the graph-builder model, not a pipeline step. The
    # previous value, GPTGraphGenerateStep(), is the step that is built from
    # the registry further down, so it cannot also be a member of it.
    # NOTE(review): assumes ChatGPTGraphBuilder is constructible without
    # arguments, like the other models here; confirm against lib.models.chat_gpt.
    gpt_graph_builder=ChatGPTGraphBuilder(),
)

# Data pipeline assembled from the steps in the list below.
# NOTE(review): steps are presumably applied in list order; confirm in DataPipeLine.
pipeline = DataPipeLine(
    [
        TrajectoryImageCollector(threshold_distance=7.),
        SceneCaptionStep(models=registry, batch_size=2),
        DetectionStep(registry),
        DetectionReduceStep(),
        GPTGraphGenerateStep(registry),
    ]
)

In [None]:
# Run the pipeline on the recorded dataset directory and store its output
# as JSON in the current working directory.
dataset_dir = "/mnt/vol0/datasets/rosbag2_navigation_for_graph_21_feb_2024"
result = pipeline(dataset_dir)

with open("dataset.json", mode="w") as out_file:
    json.dump(result, out_file)
