Merge branch 'main' into ground_truth_testing
quajak authored Apr 9, 2024
2 parents 88e6121 + 44759be commit f98f678
Showing 2 changed files with 97 additions and 100 deletions.
107 changes: 61 additions & 46 deletions src/data_collection/data_loader.py
@@ -10,64 +10,103 @@


class DataLoader:
def __init__(self, game: str, num_obj: int):
def __init__(self, game: str, num_obj: int, train_pct: float = 0.7, val_pct: float = 0.15, test_pct: float = 0.15):
assert train_pct + val_pct + test_pct == 1, "Train, validation and test percentages should sum to 1"
self.dataset_path = get_data_directory(game)
self.load_data()
self.history_len = 4
self.num_obj = num_obj
self.num_train = int(((train_pct * len(self.frames))) // self.history_len) * self.history_len
self.num_val = int(((val_pct * len(self.frames))) // self.history_len) * self.history_len
self.num_test = len(self.frames) - self.num_train - self.num_val
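
Aside (not part of the commit): a minimal sketch of the split arithmetic above, with made-up numbers. The train and validation counts are rounded down to the nearest multiple of history_len, and whatever is left over lands in the test split.

num_frames = 1000                      # illustrative total frame count
history_len = 4
train_pct, val_pct = 0.7, 0.15
num_train = int((train_pct * num_frames) // history_len) * history_len  # 700
num_val = int((val_pct * num_frames) // history_len) * history_len      # 150 rounds down to 148
num_test = num_frames - num_train - num_val                             # 152 absorbs the remainder
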

def load_data(self) -> None:
"""
Load all the data from the disk
"""
episodes_paths = [f for f in os.listdir(to_absolute_path(self.dataset_path)) if f.endswith(".npz")]
episode_counts = 0
episode_lengths = []
episode_data = {}
self.episode_count = 0
self.episode_lengths = []
episode_ids = set()
frames = []
object_types = []
object_bounding_boxes = []
detected_masks = []
actions = []
last_idx = []
for episode in episodes_paths:
data = np.load(to_absolute_path(self.dataset_path + episode))
episode_lengths.append(get_length_from_episode_name(episode))
self.episode_lengths.append(get_length_from_episode_name(episode))
episode_id = get_id_from_episode_name(episode)
assert episode_id not in episode_data, f"Episode {episode_id} already exists in the dataset"
episode_data[episode_id] = (data["episode_frames"], data["episode_object_types"], data["episode_object_bounding_boxes"],
data["episode_detected_masks"], data["episode_actions"], data["episode_last_idx"])
episode_counts += 1

self.episode_data = episode_data
self.episode_weights = np.array(episode_lengths) / sum(episode_lengths)
assert episode_id not in episode_ids, f"Episode {episode_id} already exists in the dataset"
episode_ids.add(episode_id)
frames.append(data["episode_frames"])
object_types.append(data["episode_object_types"])
object_bounding_boxes.append(data["episode_object_bounding_boxes"])
detected_masks.append(data["episode_detected_masks"])
actions.append(data["episode_actions"])
last_idx.append(data["episode_last_idx"])
self.episode_count += 1
self.frames = np.concatenate(frames, axis=0)
self.object_types = np.concatenate(object_types, axis=0)
self.object_bounding_boxes = np.concatenate(object_bounding_boxes, axis=0)
self.detected_masks = np.concatenate(detected_masks, axis=0)
self.actions = np.concatenate(actions, axis=0)
self.last_idx = np.concatenate(last_idx, axis=0)
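
Aside (not part of the commit): a small sketch of how a single episode archive could be inspected, assuming only the key names read above; the file name and shapes are made up, but every array's first axis is the per-episode frame count that load_data concatenates along axis 0.

import numpy as np

data = np.load("episode_3_1000.npz")  # hypothetical file name
for key in ("episode_frames", "episode_object_types", "episode_object_bounding_boxes",
            "episode_detected_masks", "episode_actions", "episode_last_idx"):
    print(key, data[key].shape)       # first dimension is the episode length
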

def sample(self, batch_size: int, time_steps: int, device: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def sample(self, batch_size: int, time_steps: int, device: str, data_type: str = "train") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sample a batch of states from episodes
Args:
batch_size: Number of states to sample
time_steps: Number of time steps to sample for bboxes
data_type: Type of data to sample from. Can be "train", "val" or "test"
Returns:
Tuple containing stacked states [batch_size, stacked_frames=4, channels=3, H=128, W=128], Object bounding boxes [batch_size, t, num_objects, 4],
Tuple containing stacked states [batch_size, time_steps=4, channels=3, H=128, W=128], Object bounding boxes [batch_size, t, num_objects, 4],
Masks [batch_size, num_obj, H=128, W=128], Actions [batch_size, t]
- States are in the past
- Masks are in the present
- Bounding boxes, actions are in the future
"""
episodes = np.random.choice(len(self.episode_data), size=batch_size, p=self.episode_weights)
if data_type == "train":
start, end = 0, self.num_train
elif data_type == "val":
start, end = self.num_train, self.num_train + self.num_val
elif data_type == "test":
start, end = self.num_train + self.num_val, len(self.frames)
frames = np.random.choice(np.arange(start, end), size=batch_size)
states: List[npt.NDArray] = []
object_bounding_boxes_list: List[npt.NDArray] = []
masks: List[npt.NDArray] = []
actions = []
for episode in episodes:
stacked_state, orderd_bbxs, cur_mask, action = self.get_step(episode, time_steps)

states.append(stacked_state)
for idx in frames:
idx = min(idx, len(self.frames) - time_steps) # Ensure we don't go out of bounds
idx = max(idx, self.history_len)
frame = self.frames[idx - self.history_len : idx]
object_bounding_boxes = self.object_bounding_boxes[idx : idx + time_steps] # [T, O, 4] future bboxes
mask = self.detected_masks[idx] # [O, H, W]
action = self.actions[idx : idx + time_steps] # [T]
last_idxs = self.last_idx[idx : idx + time_steps] # [T, O]
states.append(frame)
objs = object_bounding_boxes[0].sum(-1) != 0 # [O]
orderd_bbxs = np.zeros_like(object_bounding_boxes) # [T, O, 4] ordered by the initial object they are tracking
order = np.arange(objs.sum()) # [o]
for t in range(time_steps):
orderd_bbxs[t, order] = object_bounding_boxes[t, objs]
order = last_idxs[t, order]
object_bounding_boxes_list.append(orderd_bbxs)
masks.append(cur_mask)
masks.append(mask)
actions.append(action)

states_tensor = torch.from_numpy(np.array(states)).to(device)
states_tensor = states_tensor / 255
states_tensor = states_tensor.permute(0, 4, 1, 2, 3)
states_tensor = states_tensor.permute(0, 1, 4, 2, 3) # [B, T, C, H, W]
object_bounding_boxes_tensor = torch.from_numpy(np.array(object_bounding_boxes_list)).to(device)
object_bounding_boxes_tensor = object_bounding_boxes_tensor.float()
w = states_tensor.shape[-2]
h = states_tensor.shape[-1]
object_bounding_boxes_tensor /= torch.Tensor([h, w, h, w]).to(device).float()
object_bounding_boxes_tensor = object_bounding_boxes_tensor[:, :, :self.num_obj]
object_bounding_boxes_tensor = object_bounding_boxes_tensor[:, :, : self.num_obj]

states_tensor = states_tensor.reshape(*states_tensor.shape[:1], -1, *states_tensor.shape[3:])
states_tensor = F.interpolate(states_tensor, (128, 128))
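
Aside (not part of the commit): the division by [h, w, h, w] a few lines up rescales pixel-space box coordinates into [0, 1]. A tiny illustration with made-up values:

import torch

box = torch.tensor([[50.0, 30.0, 70.0, 60.0]])           # one box in pixel coordinates
h, w = 160, 210                                           # illustrative frame dimensions
box /= torch.tensor([h, w, h, w], dtype=torch.float32)
print(box)                                                # tensor([[0.3125, 0.1429, 0.4375, 0.2857]])
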
@@ -79,27 +118,3 @@ def sample(self, batch_size: int, time_steps: int, device: str) -> Tuple[torch.T
masks_tensor = F.interpolate(masks_tensor, (128, 128))

return states_tensor, object_bounding_boxes_tensor, masks_tensor, torch.from_numpy(np.array(actions))


def get_step(self, episode: int, time_steps: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
Get the frame from the episode
Args:
episode: Episode number
Returns:
Frame as numpy array, bounding boxes, current FastSAM masks, Action
"""
frames, _, object_bounding_boxes, detected_masks, episode_actions, last_idxs = self.episode_data[episode]
start = np.random.randint(0, len(frames) - self.history_len - time_steps)
base = start + self.history_len
obj_bbxs = object_bounding_boxes[base:base + time_steps] # [T, O, 4]
objs = obj_bbxs[0].sum(-1) != 0 # [O]
orderd_bbxs = np.zeros_like(obj_bbxs) # [T, O, 4] ordered by the initial object they are tracking
order = np.arange(objs.sum()) # [o]
for t in range(time_steps):
orderd_bbxs[t, order] = obj_bbxs[t, objs]
order = last_idxs[base + t, order]
stacked_state = frames[start:base]
cur_mask = detected_masks[base]
action = episode_actions[base:base + time_steps]
return stacked_state, orderd_bbxs, cur_mask, action
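
Aside (not part of the commit): a hypothetical end-to-end call of the reworked loader; the game name "Pong" and num_obj=32 are made-up values that assume matching data has already been collected to disk.

from src.data_collection.data_loader import DataLoader

loader = DataLoader("Pong", num_obj=32, train_pct=0.7, val_pct=0.15, test_pct=0.15)
states, boxes, masks, actions = loader.sample(batch_size=8, time_steps=4, device="cpu", data_type="val")
# Per the docstring: states [8, 4, 3, 128, 128], boxes [8, 4, 32, 4] scaled to [0, 1],
# masks [8, 32, 128, 128], actions [8, 4].
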
90 changes: 36 additions & 54 deletions src/data_visualization/visualizer.py
@@ -14,17 +14,20 @@
from src.model.feat_extractor import FeatureExtractor
from src.model.predictor import Predictor


# generate a list of 32 distinct colors for matplotlib
def get_distinct_colors(n: int) -> List[Tuple[float, float, float]]:
colors = []

for i in np.arange(0., 360., 360. / n):
h = i / 360.
l = (50 + np.random.rand() * 10) / 100.
s = (90 + np.random.rand() * 10) / 100.
for i in np.arange(0.0, 360.0, 360.0 / n):
h = i / 360.0
l = (50 + np.random.rand() * 10) / 100.0
s = (90 + np.random.rand() * 10) / 100.0
colors.append(hls_to_rgb(h, l, s))

return colors
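
Aside (not part of the diff): a quick check of what the helper returns — n evenly spaced hues with lightly jittered lightness and saturation, each as an (r, g, b) triple of floats in [0, 1], matching the float frames drawn on below.

colors = get_distinct_colors(8)
print(len(colors))                                          # 8
print(all(0.0 <= c <= 1.0 for rgb in colors for c in rgb))  # True
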


# color_map = [plt.cm.Set1(i) for i in np.linspace(0, 1, 33)] # type: ignore[attr-defined] # pylint: disable=no-member
color_map = get_distinct_colors(8)

@@ -45,70 +48,49 @@ def __init__(self, cfg: DictConfig) -> None:
self.root.geometry("1200x800+200x200")
self.root.title("Data Visualizer")
self.root.update()
self.frame = ctk.CTkFrame(master=self.root,
height= self.root.winfo_height()*0.66,
width = self.root.winfo_width()*0.66,
fg_color="darkblue")
self.frame = ctk.CTkFrame(master=self.root, height=self.root.winfo_height() * 0.66, width=self.root.winfo_width() * 0.66, fg_color="darkblue")
self.frame.place(relx=0.33, rely=0.025)
num_episodes = len(self.data_loader.episode_data)
self.episode_slider = ctk.CTkSlider(master=self.root,
width=300,
height=20,
from_=1,
to=num_episodes,
number_of_steps=num_episodes-1,
command=self.update_surface)
self.episode_slider.place(relx= 0.025,rely=0.5)
self.data_slider = ctk.CTkSlider(master=self.root,
width=300,
height=20,
from_=1,
to=num_episodes,
number_of_steps=num_episodes-1,
command=self.update_surface)
self.data_slider = ctk.CTkSlider(
master=self.root,
width=300,
height=20,
from_=1,
to=len(self.data_loader.frames),
number_of_steps=len(self.data_loader.frames) - 1,
command=self.update_surface,
)

self.radio_var = tkinter.IntVar(value=1)
radiobutton_1 = ctk.CTkRadioButton(self.root, text="Image",
command=self.set_display_mode, variable= self.radio_var, value=1)
radiobutton_1.place(relx= 0.025,rely=0.1)
radiobutton_2 = ctk.CTkRadioButton(self.root, text="SAM Masks",
command=self.set_display_mode, variable= self.radio_var, value=2)
radiobutton_2.place(relx= 0.025,rely=0.15)
radiobutton_3 = ctk.CTkRadioButton(self.root, text="SAM Masks + Image",
command=self.set_display_mode, variable= self.radio_var, value=3)
radiobutton_3.place(relx= 0.025,rely=0.2)
radiobutton_4 = ctk.CTkRadioButton(self.root, text="Groundtruth",
command=self.set_display_mode, variable= self.radio_var, value=4)
radiobutton_4.place(relx= 0.025,rely=0.25)
radiobutton_5 = ctk.CTkRadioButton(self.root, text="SAM Mask + Groundtruth",
command=self.set_display_mode, variable= self.radio_var, value=5)
radiobutton_5.place(relx= 0.025,rely=0.3)
radiobutton_1 = ctk.CTkRadioButton(self.root, text="Image", command=self.set_display_mode, variable=self.radio_var, value=1)
radiobutton_1.place(relx=0.025, rely=0.1)
radiobutton_2 = ctk.CTkRadioButton(self.root, text="SAM Masks", command=self.set_display_mode, variable=self.radio_var, value=2)
radiobutton_2.place(relx=0.025, rely=0.15)
radiobutton_3 = ctk.CTkRadioButton(self.root, text="SAM Masks + Image", command=self.set_display_mode, variable=self.radio_var, value=3)
radiobutton_3.place(relx=0.025, rely=0.2)
radiobutton_4 = ctk.CTkRadioButton(self.root, text="Groundtruth", command=self.set_display_mode, variable=self.radio_var, value=4)
radiobutton_4.place(relx=0.025, rely=0.25)
radiobutton_5 = ctk.CTkRadioButton(self.root, text="SAM Mask + Groundtruth", command=self.set_display_mode, variable=self.radio_var, value=5)
radiobutton_5.place(relx=0.025, rely=0.3)

self.fig, self.ax = plt.subplots()
self.fig.set_size_inches(6,6)
self.fig.set_size_inches(6, 6)
self.canvas = FigureCanvasTkAgg(self.fig, master=self.root)
self.canvas.get_tk_widget().place(relx=0.33, rely=0.025)

self.data_slider.place(relx= 0.025,rely=0.75)
self.update_data_slider(None)
# add event handler to episode slider to update max number of data slider
self.episode_slider.bind("<ButtonRelease-1>", self.update_data_slider)
self.data_slider.place(relx=0.025, rely=0.75)
self.update_surface(None)
self.root.mainloop()

def set_display_mode(self) -> None:
self.update_surface(None)

def update_data_slider(self, _: Any) -> None:
episode = self.data_loader.episode_data[int(self.episode_slider.get()) - 1][0]
self.data_slider.configure(number_of_steps=len(episode)-1, to=len(episode))

def update_surface(self, _: Any) -> None:
episode_idx = int(self.data_slider.get()) - 1
frame, types, boxes, masks, actions, _ = self.data_loader.episode_data[int(self.episode_slider.get()) - 1] # type: ignore
frame, types, boxes, masks, actions = frame[episode_idx], types[episode_idx], boxes[episode_idx], masks[episode_idx], actions[episode_idx]
masks = F.one_hot(torch.from_numpy(masks).long()).movedim(-1,0).numpy()[1:] # the 0 mask is the background [O, W, H]
frame = frame.astype(np.float32) / 255.
frame_idx = int(self.data_slider.get()) - 1
frame = self.data_loader.frames[frame_idx]
boxes = self.data_loader.object_bounding_boxes[frame_idx]
masks = self.data_loader.detected_masks[frame_idx]
masks = F.one_hot(torch.from_numpy(masks).long()).movedim(-1, 0).numpy()[1:] # the 0 mask is the background [O, W, H]
frame = frame.astype(np.float32) / 255.0
orig_img = np.array(frame)
mode = self.radio_var.get()
if mode in [2, 3, 5]:
@@ -122,7 +104,7 @@ def update_surface(self, _: Any) -> None:
if mode == 5:
mask_ys, mask_xs = np.nonzero(mask == 1)
if mask_xs.size > 0:
frame = cv2.arrowedLine(frame, (int(mask_xs.mean()), int(mask_ys.mean())), (x,y), color_map[i], 1) # pylint: disable=no-member
frame = cv2.arrowedLine(frame, (int(mask_xs.mean()), int(mask_ys.mean())), (x, y), color_map[i], 1) # pylint: disable=no-member
frame = frame.clip(0, 1)
if mode in [4, 5]:
for i, box in enumerate(boxes):
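
Aside (not part of the diff): a minimal sketch of the mask conversion used in update_surface above, assuming detected_masks stores an integer label map per frame with 0 as the background.

import numpy as np
import torch
import torch.nn.functional as F

label_map = np.array([[0, 1, 1, 0],
                      [0, 1, 2, 2],
                      [0, 0, 2, 2]])                    # made-up 3x4 frame with two objects
masks = F.one_hot(torch.from_numpy(label_map).long()).movedim(-1, 0).numpy()[1:]
print(masks.shape)                                      # (2, 3, 4): one binary mask per object
print(masks[0])                                         # pixels labelled 1
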
