Skip to content

Commit

Permalink
Split train, val, test in dataloader (by episode)
Browse files Browse the repository at this point in the history
  • Loading branch information
Cubevoid committed Apr 9, 2024
1 parent 1ce89f2 commit f5279d6
Showing 1 changed file with 29 additions and 9 deletions.
38 changes: 29 additions & 9 deletions src/data_collection/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,16 @@


class DataLoader:
def __init__(self, game: str, num_obj: int):
def __init__(self, game: str, num_obj: int, train_pct: float = 0.7, val_pct: float = 0.15, test_pct: float = 0.15):
assert train_pct + val_pct + test_pct == 1, "Train, validation and test percentages should sum to 1"
self.dataset_path = get_data_directory(game)
self.load_data()
self.history_len = 4
self.num_obj = num_obj
self.num_train = int((train_pct * len(self.episode_data)) // self.history_len) * self.history_len
self.num_train = max(self.num_train, 1)
self.num_val = int((val_pct * len(self.episode_data)) // self.history_len) * self.history_len
self.num_test = len(self.episode_data) - self.num_train - self.num_val

def load_data(self) -> None:
"""
Expand All @@ -29,24 +34,39 @@ def load_data(self) -> None:
episode_lengths.append(get_length_from_episode_name(episode))
episode_id = get_id_from_episode_name(episode)
assert episode_id not in episode_data, f"Episode {episode_id} already exists in the dataset"
episode_data[episode_id] = (data["episode_frames"], data["episode_object_types"], data["episode_object_bounding_boxes"],
data["episode_detected_masks"], data["episode_actions"], data["episode_last_idx"])
episode_data[episode_id] = (
data["episode_frames"],
data["episode_object_types"],
data["episode_object_bounding_boxes"],
data["episode_detected_masks"],
data["episode_actions"],
data["episode_last_idx"],
)
episode_counts += 1

self.episode_data = episode_data
self.episode_weights = np.array(episode_lengths) / sum(episode_lengths)

def sample(self, batch_size: int, time_steps: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
def sample(self, batch_size: int, time_steps: int, data_type: str = "train") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Sample a batch of states from episodes
Args:
batch_size: Number of states to sample
time_steps: Number of time steps to sample for bboxes
data_type: Type of data to sample from. Can be "train", "val" or "test"
Returns:
Tuple containing stacked states [batch_size, stacked_frames=4, channels=3, H=128, W=128], Object bounding boxes [batch_size, t, num_objects, 4],
Tuple containing stacked states [batch_size, time_steps=4, channels=3, H=128, W=128], Object bounding boxes [batch_size, t, num_objects, 4],
Masks [batch_size, num_obj, H=128, W=128], Actions [batch_size, t]
"""
episodes = np.random.choice(len(self.episode_data), size=batch_size, p=self.episode_weights)
if data_type == "train":
start, end = 0, self.num_train
elif data_type == "val":
start, end = self.num_train, self.num_train + self.num_val
elif data_type == "test":
start, end = self.num_train + self.num_val, len(self.episode_data)
weights = self.episode_weights[start:end]
weights = weights / weights.sum() # renormalize probabilities to sum to 1
episodes = np.random.choice(np.arange(start, end), size=batch_size, p=weights)
states: List[npt.NDArray] = []
object_bounding_boxes_list: List[npt.NDArray] = []
masks: List[npt.NDArray] = []
Expand All @@ -56,7 +76,7 @@ def sample(self, batch_size: int, time_steps: int) -> Tuple[torch.Tensor, torch.
start = np.random.randint(0, len(frames) - self.history_len - time_steps)
base = start + self.history_len
states.append(frames[start:base])
obj_bbxs = object_bounding_boxes[base:base + time_steps] # [T, O, 4]
obj_bbxs = object_bounding_boxes[base : base + time_steps] # [T, O, 4]
objs = obj_bbxs[0].sum(-1) != 0 # [O]
orderd_bbxs = np.zeros_like(obj_bbxs) # [T, O, 4] ordered by the initial object they are tracking
order = np.arange(objs.sum()) # [o]
Expand All @@ -65,7 +85,7 @@ def sample(self, batch_size: int, time_steps: int) -> Tuple[torch.Tensor, torch.
order = last_idxs[base + t, order]
object_bounding_boxes_list.append(orderd_bbxs)
masks.append(detected_masks[base])
actions.append(episode_actions[base:base + time_steps])
actions.append(episode_actions[base : base + time_steps])

states_tensor = torch.from_numpy(np.array(states))
states_tensor = states_tensor / 255
Expand All @@ -75,7 +95,7 @@ def sample(self, batch_size: int, time_steps: int) -> Tuple[torch.Tensor, torch.
w = states_tensor.shape[-2]
h = states_tensor.shape[-1]
object_bounding_boxes_tensor /= torch.Tensor([h, w, h, w]).float()
object_bounding_boxes_tensor = object_bounding_boxes_tensor[:, :, :self.num_obj]
object_bounding_boxes_tensor = object_bounding_boxes_tensor[:, :, : self.num_obj]

states_tensor = states_tensor.reshape(*states_tensor.shape[:1], -1, *states_tensor.shape[3:])
states_tensor = F.interpolate(states_tensor, (128, 128))
Expand Down

0 comments on commit f5279d6

Please sign in to comment.