Skip to content

Commit

Permalink
Merge branch 'final_data_collection' of https://github.com/Cubevoid/atari-obj-pred into final_data_collection
Browse files Browse the repository at this point in the history
  • Loading branch information
quajak committed Apr 19, 2024
2 parents 9aa2264 + 56892c3 commit dfaa7f3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
16 changes: 9 additions & 7 deletions src/model/predictor_baseline.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,34 @@
import torch
from torch import nn


class PredictorBaseline(nn.Module):
    """MLP baseline that predicts per-object 2-D positions over future time steps.

    Given encoded object features, the current object positions, and a sequence
    of discrete actions, the model rolls a latent state forward ``time_steps``
    times, decodes a per-step 2-D movement for each object, and accumulates the
    movements onto ``curr_pos`` to produce absolute positions.
    """

    def __init__(self, input_size: int = 128, time_steps: int = 5, embed_dim: int = 8, num_actions: int = 18, log: bool = False):
        """
        Args:
            input_size: Dimensionality of each object's feature vector.
            time_steps: Number of future steps to predict.
            embed_dim: Size of the learned action embedding.
            num_actions: Number of discrete actions (18 = full Atari action set).
            log: Flag stored for callers; not used inside this module.
        """
        super().__init__()
        self.time_steps = time_steps
        self.log = log
        self.encoder = nn.Sequential(nn.Linear(input_size, input_size), nn.ReLU(), nn.Linear(input_size, input_size))
        self.next_state = nn.Sequential(nn.Linear(input_size, input_size), nn.ReLU(), nn.Linear(input_size, input_size))
        self.output = nn.Sequential(nn.Linear(input_size, input_size), nn.ReLU(), nn.Linear(input_size, 2))
        self.action_embedding = nn.Embedding(num_actions, embed_dim)
        self.embedding = nn.Sequential(nn.Linear(input_size + embed_dim, input_size), nn.ReLU())

    def forward(self, x: torch.Tensor, curr_pos: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Predict object positions for the next ``time_steps`` steps.

        Args:
            x: Object features, shape [B, num_objects, input_size].
            curr_pos: Current object positions, shape [B, num_objects, 2].
            actions: Discrete action indices, shape [B, time_steps] (long).

        Returns:
            Predicted positions, shape [B, time_steps, num_objects, 2]; index 0
            along the time axis is ``curr_pos`` itself.
        """
        z = self.encoder(x)  # [B, num_objects, input_size]
        act_embed = self.action_embedding(actions).unsqueeze(-2)  # [B, T, 1, embed_dim]
        # Broadcast each step's action embedding across every object slot.
        act_embed = act_embed.expand(x.size(0), -1, x.size(1), -1)  # [B, T, num_objects, embed_dim]
        predictions = []
        for t in range(self.time_steps):
            # Condition the latent state on the action taken at step t, then advance it.
            z = self.embedding(torch.cat((z, act_embed[:, t, :, :]), dim=2))
            z = self.next_state(z)
            predictions.append(z)
        movements = self.output(torch.stack(predictions, 1))  # [B, T, num_objects, 2]
        outputs = torch.zeros((curr_pos.shape[0], self.time_steps, curr_pos.shape[1], 2), device=x.device)
        outputs[:, 0, :, :] = curr_pos
        # Integrate: position at step t+1 = position at step t + predicted movement at t.
        # (The movement predicted at the final step is intentionally unused.)
        for t in range(self.time_steps - 1):
            outputs[:, t + 1, :, :] = outputs[:, t, :, :] + movements[:, t, :, :]
        return outputs
16 changes: 7 additions & 9 deletions src/scripts/eval_model.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,18 @@
import os
import time
import typing
from typing import Any, Dict

from omegaconf import DictConfig, OmegaConf
from omegaconf import DictConfig
from tqdm import tqdm
import torch
from torch import nn
import hydra
from hydra.utils import to_absolute_path, instantiate
import wandb
from hydra.utils import instantiate

from src.scripts.train_prediction import get_ground_truth_masks
from src.data_collection.data_loader import DataLoader


@hydra.main(version_base=None, config_path="../../configs/training", config_name="config")
def eval(cfg: DictConfig) -> None:
def evaluate(cfg: DictConfig) -> None:
device = "cuda" if torch.cuda.is_available() else "cpu"

data_loader = instantiate(cfg.data_loader, model=cfg.model, game=cfg.game, num_obj=cfg.num_objects, val_pct=0, test_pct=0.3)
Expand Down Expand Up @@ -61,6 +57,7 @@ def eval(cfg: DictConfig) -> None:
print(f"Median: {sum(med) / len(med)}")
print(f"Ninetieth: {sum(ninetieth) / len(ninetieth)}")


def test_metrics(cfg: DictConfig, data_loader: DataLoader, feature_extractor: nn.Module, predictor: nn.Module, criterion: Any) -> Dict[str, Any]:
"""
Test the model on the test set and return the evaluation metrics.
Expand Down Expand Up @@ -96,7 +93,7 @@ def eval_metrics(
Returns:
A dictionary containing the evaluation metrics
"""
mask = target != 0
# mask = target != 0
diff = torch.pow(output - target, 2)
max_loss = torch.max(torch.abs((output - target))).item()
total_movement = torch.sum(torch.abs((target[:, cfg.time_steps - 1, :, :] - target[:, 0, :, :])))
Expand Down Expand Up @@ -133,5 +130,6 @@ def eval_metrics(
log_dict = {f"{prefix}/{key}": value for key, value in log_dict.items()}
return log_dict


if __name__ == "__main__":
    # Hydra injects the config; pylint cannot see that, hence the disable.
    # (Renamed from `eval` to `evaluate` to avoid shadowing the builtin.)
    evaluate()  # pylint: disable=no-value-for-parameter

0 comments on commit dfaa7f3

Please sign in to comment.