Skip to content

Commit

Permalink
Add wandb logging
Browse files Browse the repository at this point in the history
  • Loading branch information
quajak committed Mar 3, 2024
1 parent 3b3ce8f commit f46b0fa
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 2 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,5 @@ cython_debug/
#.idea/
models/sam_*
outputs/
.vscode/
.vscode/
wandb/
3 changes: 3 additions & 0 deletions run_data_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
from hydra.utils import instantiate
from omegaconf import DictConfig

from src.utils.logging import start_logging

@hydra.main(version_base=None, config_path='./configs/data_collection', config_name='config')
def main(cfg: DictConfig) -> None:
start_logging(cfg, name=f"data-collection-{cfg.collector.game}-{cfg.collector.num_samples}")
data_collector = instantiate(cfg.collector)
data_collector.collect_data()

Expand Down
6 changes: 5 additions & 1 deletion src/data_collection/data_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ocatari.utils import load_agent
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator # type: ignore
import tqdm
import wandb

class DataCollector:
def __init__(self, game: str, num_samples: int) -> None:
Expand Down Expand Up @@ -46,6 +47,7 @@ def collect_data(self) -> None:
self.episode_actions.append(action)
masks = self.generator.generate(obs)
self.episode_detected_masks.append([mask["segmentation"] for mask in masks])
wandb.log({"data_collected": self.collected_data})
for obj in self.env.objects:
self.episode_object_types[-1].append(obj.category)
self.episode_object_bounding_boxes[-1].append(obj.xywh)
Expand All @@ -68,7 +70,9 @@ def store_episode(self) -> None:
episode_detected_masks=np.array(self.episode_detected_masks),
episode_actions=np.array(self.episode_actions))
self.curr_episode_id += 1
self.collected_data += len(self.episode_frames)
episode_length = len(self.episode_frames)
self.collected_data += episode_length
wandb.log({"episode_length": episode_length})
self.episode_frames = []
self.episode_object_types = []
self.episode_object_bounding_boxes = []
Expand Down
8 changes: 8 additions & 0 deletions src/utils/logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from omegaconf import DictConfig
import wandb

def start_logging(config: DictConfig, name: str) -> None:
"""
Start logging to wandb
"""
wandb.init(project="oc-data-collection", entity='atari-obj-pred', name=name, config=dict(config))

0 comments on commit f46b0fa

Please sign in to comment.