diff --git a/.gitignore b/.gitignore index ba49e8c..683c052 100644 --- a/.gitignore +++ b/.gitignore @@ -160,4 +160,5 @@ cython_debug/ #.idea/ models/sam_* outputs/ -.vscode/ \ No newline at end of file +.vscode/ +wandb/ \ No newline at end of file diff --git a/run_data_collection.py b/run_data_collection.py index 6b9f2e4..198611e 100644 --- a/run_data_collection.py +++ b/run_data_collection.py @@ -2,8 +2,11 @@ from hydra.utils import instantiate from omegaconf import DictConfig +from src.utils.logging import start_logging + @hydra.main(version_base=None, config_path='./configs/data_collection', config_name='config') def main(cfg: DictConfig) -> None: + start_logging(cfg, name=f"data-collection-{cfg.collector.game}-{cfg.collector.num_samples}") data_collector = instantiate(cfg.collector) data_collector.collect_data() diff --git a/src/data_collection/data_collector.py b/src/data_collection/data_collector.py index 0975bd4..73bc7a8 100644 --- a/src/data_collection/data_collector.py +++ b/src/data_collection/data_collector.py @@ -6,6 +6,7 @@ from ocatari.utils import load_agent from segment_anything import sam_model_registry, SamAutomaticMaskGenerator # type: ignore import tqdm +import wandb class DataCollector: def __init__(self, game: str, num_samples: int) -> None: @@ -46,6 +47,7 @@ def collect_data(self) -> None: self.episode_actions.append(action) masks = self.generator.generate(obs) self.episode_detected_masks.append([mask["segmentation"] for mask in masks]) + wandb.log({"data_collected": self.collected_data}) for obj in self.env.objects: self.episode_object_types[-1].append(obj.category) self.episode_object_bounding_boxes[-1].append(obj.xywh) @@ -68,7 +70,9 @@ def store_episode(self) -> None: episode_detected_masks=np.array(self.episode_detected_masks), episode_actions=np.array(self.episode_actions)) self.curr_episode_id += 1 - self.collected_data += len(self.episode_frames) + episode_length = len(self.episode_frames) + self.collected_data += episode_length + wandb.log({"episode_length": episode_length}) self.episode_frames = [] self.episode_object_types = [] self.episode_object_bounding_boxes = [] diff --git a/src/utils/logging.py b/src/utils/logging.py new file mode 100644 index 0000000..3fd361a --- /dev/null +++ b/src/utils/logging.py @@ -0,0 +1,8 @@ +from omegaconf import DictConfig +import wandb + +def start_logging(config: DictConfig, name: str) -> None: + """ + Start logging to wandb + """ + wandb.init(project="oc-data-collection", entity='atari-obj-pred', name=name, config=dict(config)) \ No newline at end of file