Skip to content

Commit

Permalink
Merge branch 'ground_truth_testing' of https://github.com/Cubevoid/at…
Browse files Browse the repository at this point in the history
…ari-obj-pred into ground_truth_testing
  • Loading branch information
Cubevoid committed Apr 9, 2024
2 parents f89c13f + 15b462d commit 8583c32
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 17 deletions.
25 changes: 18 additions & 7 deletions src/data_collection/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,26 +74,38 @@ def sample(self, batch_size: int, time_steps: int, device: str, data_type: str =
start, end = self.num_train, self.num_train + self.num_val
elif data_type == "test":
start, end = self.num_train + self.num_val, len(self.frames)
frames = np.random.choice(np.arange(start, end), size=batch_size)
frames: np.ndarray[os.Any, os.Any] = np.random.choice(np.arange(start + time_steps, end - self.history_len), size=batch_size)
states_tensor, object_bounding_boxes_tensor, masks_tensor, actions = self.sample_idxes(time_steps, device, frames)

return states_tensor, object_bounding_boxes_tensor, masks_tensor, torch.from_numpy(np.array(actions))

def sample_idxes(self, time_steps, device, frames) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[int]]:
"""
Sample a given array of indexes (frames)
Args:
time_steps: Number of time steps to sample for bboxes
device: Device to move tensors to
frames: List of indexes to sample
Returns:
States, Object bounding boxes, Masks, Actions
"""
states: List[npt.NDArray] = []
object_bounding_boxes_list: List[npt.NDArray] = []
masks: List[npt.NDArray] = []
actions = []
for idx in frames:
idx = min(idx, len(self.frames) - time_steps) # Ensure we don't go out of bounds
idx = max(idx, self.history_len)
frame = self.frames[idx - self.history_len : idx]
object_bounding_boxes = self.object_bounding_boxes[idx : idx + time_steps] # [T, O, 4] future bboxes
mask = self.detected_masks[idx] # [O, H, W]
action = self.actions[idx : idx + time_steps] # [T]
last_idxs = self.last_idx[idx : idx + time_steps] # [T, O]
states.append(frame)
objs = object_bounding_boxes[0].sum(-1) != 0 # [O]
orderd_bbxs = np.zeros_like(object_bounding_boxes) # [T, O, 4] ordered by the initial object they are tracking
order = np.arange(objs.sum()) # [o]
for t in range(time_steps):
orderd_bbxs[t, order] = object_bounding_boxes[t, objs]
order = last_idxs[t, order]
states.append(frame)
object_bounding_boxes_list.append(orderd_bbxs)
masks.append(mask)
actions.append(action)
Expand All @@ -105,7 +117,7 @@ def sample(self, batch_size: int, time_steps: int, device: str, data_type: str =
object_bounding_boxes_tensor = object_bounding_boxes_tensor.float()
w = states_tensor.shape[-2]
h = states_tensor.shape[-1]
object_bounding_boxes_tensor /= torch.Tensor([h, w, h, w]).to(device).float()
object_bounding_boxes_tensor /= torch.Tensor([w, h, w, h]).to(device).float()
object_bounding_boxes_tensor = object_bounding_boxes_tensor[:, :, : self.num_obj]

states_tensor = states_tensor.reshape(*states_tensor.shape[:1], -1, *states_tensor.shape[3:])
Expand All @@ -116,5 +128,4 @@ def sample(self, batch_size: int, time_steps: int, device: str, data_type: str =
masks_tensor = F.one_hot(masks_tensor.long(), num_classes=self.num_obj + 1).float()[:, :, :, 1:]
masks_tensor = masks_tensor.permute(0, 3, 1, 2)
masks_tensor = F.interpolate(masks_tensor, (128, 128))

return states_tensor, object_bounding_boxes_tensor, masks_tensor, torch.from_numpy(np.array(actions))
return states_tensor, object_bounding_boxes_tensor, masks_tensor, actions
20 changes: 12 additions & 8 deletions src/data_visualization/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, cfg: DictConfig) -> None:
self.feature_extractor = FeatureExtractor(num_objects=cfg.num_objects)
self.feature_extractor.load_state_dict(feature_extractor_state)
predictor_state = torch.load("models/trained/Pong/1711831906_transformer_predictor.pth", map_location='cpu')
self.predictor = Predictor(num_layers=1)
self.predictor = Predictor(num_layers=1, log=False)
self.predictor.load_state_dict(predictor_state)

ctk.set_appearance_mode("dark")
Expand Down Expand Up @@ -115,15 +115,19 @@ def update_surface(self, _: Any) -> None:
# visualize predictions
if self.predictor is not None:
frame = frame * 0.5
m_frame, m_bbxs, m_masks, _= self.data_loader.sample_idxes(5, "cpu", [frame_idx])
m_bbxs = m_bbxs[:, :, :, :2]
with torch.no_grad():
frame_tensor = torch.from_numpy(orig_img).permute(2, 0, 1).unsqueeze(0).float()
masks_tensor = torch.from_numpy(masks).unsqueeze(0).float()
features = self.feature_extractor(frame_tensor, masks_tensor)
features = self.feature_extractor(m_frame, m_masks)
predictions = self.predictor(features)
for i, prediction in enumerate(predictions[0]):
x, y = prediction
if x != 0 or y != 0:
frame = cv2.circle(frame, (int(x), int(y)), 2, color_map[i], 2)
for t_pred in predictions[0]:
for i, prediction in enumerate(t_pred):
x, y = prediction[0] * 210, prediction[1] * 160
frame = cv2.circle(frame, (int(x), int(y)), 1, color_map[i], 1)
for t_pred in m_bbxs[0]:
for i, prediction in enumerate(t_pred):
x, y = prediction[0] * 210, prediction[1] * 160
frame = cv2.circle(frame, (int(x), int(y)), 1, color_map[i], 1)

self.ax.imshow(frame)
self.ax.axis("off")
Expand Down
6 changes: 4 additions & 2 deletions src/model/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@

class Predictor(nn.Module):
def __init__(self, input_size: int = 128, hidden_size: int = 32, output_size: int = 120, num_layers: int = 2,
hidden_dim: int = 120, nhead: int = 2, time_steps: int = 5) -> None:
hidden_dim: int = 120, nhead: int = 2, time_steps: int = 5, log: bool = True) -> None:
super().__init__()
self.log = log
self.time_steps = time_steps
self.fc1 = nn.Linear(input_size, hidden_size)
encoder_layers = nn.TransformerEncoderLayer(d_model=output_size, nhead=nhead, dim_feedforward=hidden_dim)
Expand Down Expand Up @@ -36,5 +37,6 @@ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None, src_key_
x = torch.stack(predictions, 1) # [B, time_steps, num_objects, output_size]
x = F.relu(self.fc3(x)) # [B, time_steps, num_objects, output_size]
x = self.fc4(x) # [B, time_steps, num_objects, 2]
wandb.log(debug_stats)
if self.log:
wandb.log(debug_stats)
return x

0 comments on commit 8583c32

Please sign in to comment.