Skip to content

Commit

Permalink
CUDA code fixes, configurable number of actions (num_actions), and some trained models
Browse files (browse the repository at this point in the history)
  • Loading branch information
quajak committed Apr 19, 2024
1 parent b500fb6 commit a61cd1d
Show file tree
Hide file tree
Showing 11 changed files with 24 additions and 20 deletions.
1 change: 1 addition & 0 deletions configs/training/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ game: Pong
model: FastSAM-x
num_iterations: 1000
ground_truth_masks: false
t: 0

save_models: false
3 changes: 2 additions & 1 deletion configs/training/predictor/baseline.yaml
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
_target_: src.model.predictor_baseline.PredictorBaseline
_target_: src.model.predictor_baseline.PredictorBaseline
num_actions: 18
3 changes: 2 additions & 1 deletion configs/training/predictor/residual.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
_target_: src.model.residual_predictor.ResidualPredictor
num_layers: 1
num_layers: 1
num_actions: 18
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion src/data_collection/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,4 +135,4 @@ def sample_idxes(self, time_steps: int, device: str, frames: Iterable[int]) -> T
masks_tensor = F.one_hot(masks_tensor.long(), num_classes=self.num_obj + 1).float()[:, :, :, 1:]
masks_tensor = masks_tensor.permute(0, 3, 1, 2)
masks_tensor = F.interpolate(masks_tensor, (128, 128))
return states_tensor, object_bounding_boxes_tensor, masks_tensor, torch.from_numpy(np.array(actions))
return states_tensor, object_bounding_boxes_tensor, masks_tensor, torch.from_numpy(np.array(actions)).to(device)
4 changes: 2 additions & 2 deletions src/model/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

class Predictor(nn.Module):
def __init__(self, input_size: int = 128, hidden_size: int = 32, output_size: int = 120, num_layers: int = 2,
hidden_dim: int = 120, embed_dim: int = 8, nhead: int = 2, time_steps: int = 5, log: bool = True) -> None:
hidden_dim: int = 120, embed_dim: int = 8, nhead: int = 2, time_steps: int = 5, log: bool = True, num_actions: int = 18) -> None:
super().__init__()
self.log = log
self.time_steps = time_steps
Expand All @@ -17,7 +17,7 @@ def __init__(self, input_size: int = 128, hidden_size: int = 32, output_size: in
self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
self.time_mlp = nn.Sequential(nn.Linear(output_size, output_size))
self.pred_mlp = nn.Sequential(nn.Linear(output_size, output_size), nn.ReLU(), nn.Linear(output_size, 2))
self.action_embedding = nn.Embedding(18, embed_dim)
self.action_embedding = nn.Embedding(num_actions, embed_dim)
self.embedding = nn.Sequential(nn.Linear(output_size+embed_dim, output_size), nn.ReLU())

def forward(self, x: torch.Tensor, curr_pos: torch.Tensor, actions: torch.Tensor) -> torch.Tensor: # pylint: disable = unused-argument
Expand Down
4 changes: 2 additions & 2 deletions src/model/predictor_baseline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from torch import nn

class PredictorBaseline(nn.Module):
def __init__(self, input_size: int = 128, time_steps: int = 5, embed_dim: int = 8):
def __init__(self, input_size: int = 128, time_steps: int = 5, embed_dim: int = 8, num_actions: int = 18, log: bool = False):
super().__init__()
self.time_steps = time_steps
self.encoder = nn.Sequential(nn.Linear(input_size, input_size), nn.ReLU(), nn.Linear(input_size, input_size))
self.next_state = nn.Sequential(nn.Linear(input_size, input_size), nn.ReLU(), nn.Linear(input_size, input_size))
self.output = nn.Sequential(nn.Linear(input_size, input_size), nn.ReLU(), nn.Linear(input_size, 2))
self.action_embedding = nn.Embedding(18, embed_dim)
self.action_embedding = nn.Embedding(num_actions, embed_dim)
self.embedding = nn.Sequential(nn.Linear(input_size+embed_dim, input_size), nn.ReLU())

def forward(self, x: torch.Tensor, curr_pos: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
Expand Down
27 changes: 14 additions & 13 deletions src/scripts/eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
def eval(cfg: DictConfig) -> None:
device = "cuda" if torch.cuda.is_available() else "cpu"

data_loader = instantiate(cfg.data_loader, model=cfg.model, game=cfg.game, num_obj=cfg.num_objects, val_pct=0, test_pct=0.3)
data_loader = instantiate(cfg.data_loader, model=cfg.model, game=cfg.game, num_obj=cfg.num_objects, val_pct=0, test_pct=0.3, max_data=10000)

t = 1713391908
t = cfg.t
feature_extractor_state = torch.load(f"models/trained/{cfg.game}/{t}_feat_extract.pth", map_location=device)
feature_extractor = instantiate(cfg.feature_extractor, num_objects=cfg.num_objects, history_len=cfg.data_loader.history_len)
feature_extractor = instantiate(cfg.feature_extractor, num_objects=cfg.num_objects, history_len=cfg.data_loader.history_len).to(device)
feature_extractor.load_state_dict(feature_extractor_state)
predictor = instantiate(cfg.predictor, time_steps=cfg.time_steps, log=False)
predictor = instantiate(cfg.predictor, time_steps=cfg.time_steps, log=False).to(device)
predictor_state = torch.load(f"models/trained/{cfg.game}/{t}_{type(predictor).__name__}.pth", map_location=device)
predictor.load_state_dict(predictor_state)

Expand All @@ -45,14 +45,15 @@ def eval(cfg: DictConfig) -> None:
gt_positions = positions[:, : cfg.data_loader.history_len, :, :] # [B, H, O, 2]

# Run models
features: torch.Tensor = feature_extractor(images, masks, gt_positions)
output: torch.Tensor = predictor(features, target[:, 0], actions)
loss: torch.Tensor = criterion(output, target)
with torch.no_grad():
features: torch.Tensor = feature_extractor(images, masks, gt_positions)
output: torch.Tensor = predictor(features, target[:, 0], actions)
loss: torch.Tensor = criterion(output, target)

log_dict = eval_metrics(cfg, features, target, output, loss, prefix="test")
mean.append(log_dict[f"test/l1_movement_mean/time_{cfg.time_steps-1}"])
med.append(log_dict[f"test/l1_movement_median/time_{cfg.time_steps-1}"])
ninetieth.append(log_dict[f"test/l1_movement_90th_percentile/time_{cfg.time_steps-1}"])
mean.append(log_dict[f"test/l1_movement_mean/time_{cfg.time_steps-1}"].item())
med.append(log_dict[f"test/l1_movement_median/time_{cfg.time_steps-1}"].item())
ninetieth.append(log_dict[f"test/l1_movement_90th_percentile/time_{cfg.time_steps-1}"].item())

print(f"Mean: {sum(mean) / len(mean)}")
print(f"Median: {sum(med) / len(med)}")
Expand Down Expand Up @@ -121,9 +122,9 @@ def eval_metrics(
total_movement = torch.sum(torch.abs((target[:, t, :, :] - target[:, 0, :, :])))
log_dict[f"average_movement/time_{t}"] = total_movement / torch.sum(movement_mask)
l1 = torch.abs(target[:, t, :, :][movement_mask] - output[:, t, :, :][movement_mask])
log_dict[f"l1_movement_mean/time_{t}"] = torch.mean(l1)
log_dict[f"l1_movement_median/time_{t}"] = torch.median(l1)
log_dict[f"l1_movement_90th_percentile/time_{t}"] = torch.quantile(l1, 0.9)
log_dict[f"l1_movement_mean/time_{t}"] = torch.mean(l1) if l1.shape[0] != 0 else torch.zeros([1])
log_dict[f"l1_movement_median/time_{t}"] = torch.median(l1) if l1.shape[0] != 0 else torch.zeros([1])
log_dict[f"l1_movement_90th_percentile/time_{t}"] = torch.quantile(l1, 0.9) if l1.shape[0] != 0 else torch.zeros([1])

log_dict[f"error/time_{t}"] = diff[:, t, :, :].mean()

Expand Down

0 comments on commit a61cd1d

Please sign in to comment.