Merge pull request #18 from Cubevoid/feat_extractor_baseline
Implement feature extractor baseline
quajak committed Mar 30, 2024
2 parents 6187a67 + 4277c06 commit 14874e6
Showing 10 changed files with 93 additions and 19 deletions.
5 changes: 4 additions & 1 deletion configs/training/config.yaml
@@ -1,8 +1,11 @@
+ defaults:
+   - feature_extractor: feature_extractor
+
debug: True
batch_size: 32
time_steps: 5
num_objects: 32
name: debug-training
game: Pong
predictor: transformer
- num_iterations: 1000
+ num_iterations: 1000
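
The new `defaults` entry makes the feature extractor a Hydra config group, so the implementation is selected by config instead of being hard-coded. A minimal sketch of what the `_target_` files below enable (a toy standalone snippet; the real entry point composes the config via `@hydra.main`):

```python
# Hedged sketch: hydra.utils.instantiate imports the class named in _target_
# and constructs it with any extra keyword arguments.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create({"_target_": "src.model.feat_extractor_baseline.FeatureExtractorBaseline"})
extractor = instantiate(cfg, num_objects=32)
print(type(extractor).__name__)  # FeatureExtractorBaseline
```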
1 change: 1 addition & 0 deletions configs/training/feature_extractor/baseline.yaml
@@ -0,0 +1 @@
_target_: src.model.feat_extractor_baseline.FeatureExtractorBaseline
1 change: 1 addition & 0 deletions configs/training/feature_extractor/feature_extractor.yaml
@@ -0,0 +1 @@
_target_: src.model.feat_extractor.FeatureExtractor
16 changes: 9 additions & 7 deletions src/data_collection/data_loader.py
@@ -8,6 +8,7 @@

from src.data_collection.common import get_data_directory, get_id_from_episode_name, get_length_from_episode_name


class DataLoader:
def __init__(self, game: str, num_obj: int):
self.dataset_path = get_data_directory(game)
@@ -55,7 +56,7 @@ def sample(self, batch_size: int, time_steps: int) -> Tuple[torch.Tensor, torch.
start = np.random.randint(0, len(frames) - self.history_len - time_steps)
base = start + self.history_len
states.append(frames[start:base])
- obj_bbxs = object_bounding_boxes[base:base+time_steps] # [T, O, 4]
+ obj_bbxs = object_bounding_boxes[base:base + time_steps] # [T, O, 4]
objs = obj_bbxs[0].sum(-1) != 0 # [O]
orderd_bbxs = np.zeros_like(obj_bbxs) # [T, O, 4] ordered by the initial object they are tracking
order = np.arange(objs.sum()) # [o]
@@ -64,7 +65,7 @@ def sample(self, batch_size: int, time_steps: int) -> Tuple[torch.Tensor, torch.
order = last_idxs[base + t, order]
object_bounding_boxes_list.append(orderd_bbxs)
masks.append(detected_masks[base])
- actions.append(episode_actions[base:base+time_steps])
+ actions.append(episode_actions[base:base + time_steps])

states_tensor = torch.from_numpy(np.array(states))
states_tensor = states_tensor / 255
@@ -77,11 +78,12 @@ def sample(self, batch_size: int, time_steps: int) -> Tuple[torch.Tensor, torch.
object_bounding_boxes_tensor = object_bounding_boxes_tensor[:, :, :self.num_obj]

states_tensor = states_tensor.reshape(*states_tensor.shape[:1], -1, *states_tensor.shape[3:])
- states_tensor = F.interpolate(states_tensor , (128, 128))
- states_tensor = states_tensor.reshape((-1, 12, 128, 128))
+ states_tensor = F.interpolate(states_tensor, (128, 128))
+ states_tensor = states_tensor.reshape((-1, 12, 128, 128))

- masks_tensor = torch.from_numpy(np.array(masks))[:, :self.num_obj]
- masks_tensor = F.one_hot(masks_tensor.long(), num_classes=self.num_obj + 1).float()[:, :, :, 1:] # get rid of background [B, H, W, O]
- masks_tensor = masks_tensor.permute(0, 3, 1, 2) # [B, O, H, W]
+ masks_tensor = torch.from_numpy(np.array(masks))
+ masks_tensor = F.one_hot(masks_tensor.long(), num_classes=self.num_obj + 1).float()[:, :, :, 1:]
+ masks_tensor = masks_tensor.permute(0, 3, 1, 2)
+ masks_tensor = F.interpolate(masks_tensor, (128, 128))

return states_tensor, object_bounding_boxes_tensor, masks_tensor, torch.from_numpy(np.array(actions))
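
For reference, a toy shape walk-through of the reworked mask pipeline above (assumed sizes, not the project's real data; ids are 0..num_obj with 0 = background):

```python
import torch
import torch.nn.functional as F

num_obj = 3
masks = torch.randint(0, num_obj + 1, (2, 16, 16))                 # [B, H, W] integer ids
one_hot = F.one_hot(masks.long(), num_classes=num_obj + 1).float()
one_hot = one_hot[:, :, :, 1:]                                     # drop background -> [B, H, W, O]
per_obj = one_hot.permute(0, 3, 1, 2)                              # [B, O, H, W]
resized = F.interpolate(per_obj, (128, 128))                       # [B, O, 128, 128]
print(resized.shape)                                               # torch.Size([2, 3, 128, 128])
```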
2 changes: 0 additions & 2 deletions src/data_visualization/visualizer.py
Expand Up @@ -6,8 +6,6 @@
import cv2 # type: ignore
import numpy as np
import matplotlib.pyplot as plt
- import torch
- from torch.nn import functional as F
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

from src.data_collection.data_loader import DataLoader
8 changes: 3 additions & 5 deletions src/model/feat_extractor.py
@@ -32,9 +32,8 @@ class FeatureExtractor(torch.nn.Module):
Performs CNN-based feature extraction and ROI pooling.
"""

- def __init__(self, input_size: int = 128, num_frames: int = 4, num_objects: int = 32, debug: bool = False):
+ def __init__(self, input_size: int = 128, num_frames: int = 4, num_objects: int = 32):
super().__init__()
- self.debug = debug
self.num_frames = num_frames
self.num_objects = num_objects
self.input_size = input_size
@@ -70,9 +69,8 @@ def forward(self, images: torch.Tensor, rois: torch.Tensor) -> torch.Tensor:
Returns:
(B, num_objects, 128) feature vector
"""
- if self.debug:
- assert len(images.shape) == 4, f"Expected 4D tensor, got {images.shape}"
- assert images.shape[-1] == images.shape[-2] == self.input_size, f"Expected input size {self.input_size}, got {images.shape}"
+ assert len(images.shape) == 4, f"Expected 4D tensor, got {images.shape}"
+ assert images.shape[-1] == images.shape[-2] == self.input_size, f"Expected input size {self.input_size}, got {images.shape}"
images = self.conv(images) # [input_size/2, input_size/2]
images = self.position_embed(images)
objects = self.roi_pool(images, rois)
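
With the `debug` flag removed, the shape asserts in `forward` now always run. A toy check of the contract they enforce (assuming the 12-channel layout produced by `DataLoader`, i.e. 4 stacked RGB frames):

```python
import torch

images = torch.rand(8, 12, 128, 128)  # [N, 4 frames x 3 channels, H, W]
assert len(images.shape) == 4, f"Expected 4D tensor, got {images.shape}"
assert images.shape[-1] == images.shape[-2] == 128, f"Expected input size 128, got {images.shape}"
```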
35 changes: 35 additions & 0 deletions src/model/feat_extractor_baseline.py
@@ -0,0 +1,35 @@
import torch
from torch import nn

class FeatureExtractorBaseline(nn.Module):
def __init__(self, input_size: int = 128, num_objects: int = 32, num_features: int = 128, device: str = 'cpu'):
super().__init__()
self.input_size = input_size
self.num_objects = num_objects
self.fc1 = nn.Linear(2, num_features)
self.device = device

def forward(self, rois: torch.Tensor) -> torch.Tensor:
"""
Args:
rois: (B, num_objects, input_size, input_size) per-object mask tensor
Returns:
(B, num_objects, num_features) feature vector
"""
x_indices = torch.arange(self.input_size, dtype=torch.float32, device=self.device).view(1, 1, -1, 1)
y_indices = torch.arange(self.input_size, dtype=torch.float32, device=self.device).view(1, 1, 1, -1)

sum_x = torch.sum(rois * x_indices, dim=(2, 3))
sum_y = torch.sum(rois * y_indices, dim=(2, 3))

mask_areas = torch.sum(rois, dim=(2, 3))
mask_areas[mask_areas == 0] = 1

average_x = sum_x / mask_areas
average_y = sum_y / mask_areas

# (B, num_objects, 2)
average_xy = torch.stack((average_x, average_y), dim=-1)

# output = F.relu(self.fc1(average_xy))  # fc1 is defined but currently unused
return average_xy / self.input_size  # normalized centroids in [0, 1]
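
A usage sketch of the baseline on toy data: it ignores pixel content entirely and reduces each object mask to its mean coordinate, scaled to [0, 1] (`fc1` is constructed but never applied):

```python
import torch

extractor = FeatureExtractorBaseline(input_size=128, num_objects=32)
rois = torch.zeros(1, 32, 128, 128)
rois[0, 0, 60:68, 20:28] = 1.0  # one 8x8 object mask at rows 60-67, cols 20-27
centroids = extractor(rois)     # [1, 32, 2]
print(centroids[0, 0])          # ~tensor([0.4961, 0.1836]), i.e. (63.5, 23.5) / 128
```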
18 changes: 18 additions & 0 deletions src/model/predictor_baseline.py
@@ -0,0 +1,18 @@
import torch
from torch import nn
import torch.nn.functional as F

class PredictorBaseline(nn.Module):
def __init__(self, input_size: int = 128, time_steps: int = 5):
super().__init__()
self.time_steps = time_steps
self.fc1 = nn.Linear(input_size, input_size)
self.fc2 = nn.Linear(input_size, 2)

def forward(self, x: torch.Tensor) -> torch.Tensor:
predictions = []
for _ in range(self.time_steps):
predictions.append(F.relu(self.fc1(x)))
x = torch.stack(predictions, 1)
x = torch.sigmoid(self.fc2(x))  # squash predicted coordinates into [0, 1]
return x
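
A shape sketch for the baseline predictor (toy sizes): features enter as [B, O, 128] and predictions leave as [B, T, O, 2], matching the bounding-box-centre targets in training. Note that `x` is not advanced inside the loop, so every time step applies `fc1` to the same input:

```python
import torch

predictor = PredictorBaseline(input_size=128, time_steps=5)
features = torch.rand(4, 32, 128)  # [B, O, F]
out = predictor(features)          # [B, T, O, 2]
print(out.shape)                   # torch.Size([4, 5, 32, 2])
```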
17 changes: 17 additions & 0 deletions src/model/small_mlp.py
@@ -0,0 +1,17 @@
import torch
from torch import nn
import torch.nn.functional as F

class SmallMLP(nn.Module):
def __init__(self, input_size: int = 2, hidden_size: int = 32, output_size: int = 2):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, output_size)

def forward(self, x: torch.Tensor) -> torch.Tensor:
output = F.relu(self.fc1(x))
output = F.relu(self.fc2(output))
return output
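
A usage sketch for `SmallMLP` on toy sizes: `nn.Linear` broadcasts over leading dimensions, so batches of per-object centroids pass straight through. Because the output layer is also followed by ReLU, predictions are clamped to non-negative values:

```python
import torch

mlp = SmallMLP(input_size=2, hidden_size=32, output_size=2)
out = mlp(torch.rand(4, 32, 2))  # [B, O, 2] -> [B, O, 2]
print(out.shape)                 # torch.Size([4, 32, 2])
```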
9 changes: 5 additions & 4 deletions src/scripts/train_prediction.py
@@ -2,16 +2,16 @@
import time
import typing
from typing import Any, Dict

from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
import torch
from torch import nn
import hydra
- from hydra.utils import to_absolute_path
+ from hydra.utils import to_absolute_path, instantiate
import wandb

from src.data_collection.data_loader import DataLoader
from src.model.feat_extractor import FeatureExtractor
from src.model.predictor import Predictor
from src.model.mlp_predictor import MLPPredictor

@@ -22,7 +22,7 @@ def train(cfg: DictConfig) -> None:
use_mlp = cfg.predictor == "mlp"

data_loader = DataLoader(cfg.game, cfg.num_objects)
- feature_extract = FeatureExtractor(num_objects=cfg.num_objects, debug=cfg.debug).to(device)
+ feature_extract = instantiate(cfg.feature_extractor, num_objects=cfg.num_objects).to(device)
predictor = (MLPPredictor() if use_mlp else Predictor(num_layers=1, time_steps=cfg.time_steps)).to(device)

wandb.init(project="oc-data-training", entity="atari-obj-pred", name=cfg.name + cfg.game, config=typing.cast(Dict[Any, Any], OmegaConf.to_container(cfg)))
@@ -37,6 +37,7 @@ def train(cfg: DictConfig) -> None:
images, bboxes, masks, _ = data_loader.sample(cfg.batch_size, cfg.time_steps)
images, bboxes, masks = images.to(device), bboxes.to(device), masks.to(device)
target = bboxes[:, :, :, :2] # [B, T, O, 2]

# Run models
features: torch.Tensor = feature_extract(images, masks)
output: torch.Tensor = predictor(features)
@@ -53,7 +54,7 @@ def train(cfg: DictConfig) -> None:
tqdm.write(f"loss={loss.item()}, output_mean={output.mean().item()}, std={output.std().item()}")
tqdm.write(f"target_mean={target.mean().item()} std={target.std().item()}")
tqdm.write(f"l1 average loss = {l1sum/total}")
tqdm.write(f"Predicted: {output[:,:,0]}, Target: {target[:,:,0]}")
# tqdm.write(f"Predicted: {output[:,:,0]}, Target: {target[:,:,0]}")
# tqdm.write(f"Std: {std} {std.shape}")
# tqdm.write(f"Corr: {corr} {corr.shape}")
error_dict = {"loss": loss, "error/x": diff[:, :, :, 0].mean(), "error/y": diff[:, :, :, 1].mean()}
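
With the extractor now built through `instantiate`, switching implementations is a config override rather than a code change. The launch command is not part of this diff; an assumed invocation selecting the new baseline would be:

```python
# Assumed invocation (not shown in this PR): Hydra's group-override syntax
# makes instantiate(cfg.feature_extractor, ...) build FeatureExtractorBaseline
# instead of FeatureExtractor:
#
#   python -m src.scripts.train_prediction feature_extractor=baseline
```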
