Skip to content

Commit

Permalink
Inference without Horovod (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
VitorGuizilini-TRI committed Jun 17, 2020
1 parent dfbdc27 commit f824ffc
Show file tree
Hide file tree
Showing 19 changed files with 290 additions and 157 deletions.
7 changes: 5 additions & 2 deletions configs/default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,11 @@
########################################################################################################################
cfg.save = CN()
cfg.save.folder = '' # Folder where data will be saved
cfg.save.viz = True # Flag for saving inverse depth map visualization
cfg.save.npz = True # Flag for saving numpy depth maps
cfg.save.depth = CN()
cfg.save.depth.rgb = True # Flag for saving rgb images
cfg.save.depth.viz = True # Flag for saving inverse depth map visualization
cfg.save.depth.npz = True # Flag for saving numpy depth maps
cfg.save.depth.png = True # Flag for saving png depth maps
########################################################################################################################
### WANDB
########################################################################################################################
Expand Down
7 changes: 5 additions & 2 deletions configs/eval_ddad.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,8 @@ datasets:
cameras: [['camera_01']]
save:
folder: '/data/save'
viz: True
npz: True
depth:
rgb: True
viz: True
npz: True
png: True
7 changes: 5 additions & 2 deletions configs/eval_image.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,8 @@ datasets:
split: ['{:010d}']
save:
folder: '/data/save'
viz: True
npy: True
depth:
rgb: True
viz: True
npz: True
png: True
7 changes: 5 additions & 2 deletions configs/eval_kitti.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,8 @@ datasets:
depth_type: ['velodyne']
save:
folder: '/data/save'
viz: True
npz: True
depth:
rgb: True
viz: True
npz: True
png: True
5 changes: 0 additions & 5 deletions packnet_sfm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,3 @@
Furthermore, it gets better with input resolution and number of parameters, generalizes better, and can run in real-time (with TensorRT). See [References](#references) for more info on our models.
"""

from packnet_sfm.models import ModelWrapper, ModelCheckpoint
from packnet_sfm.trainers import HorovodTrainer

__all__ = ["ModelWrapper", "HorovodTrainer", "ModelCheckpoint"]
10 changes: 0 additions & 10 deletions packnet_sfm/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,3 @@
- ImageDataset: reads from a folder containing image sequences (no support for depth maps)
"""

from packnet_sfm.datasets.kitti_dataset import KITTIDataset
from packnet_sfm.datasets.dgp_dataset import DGPDataset
from packnet_sfm.datasets.image_dataset import ImageDataset

__all__ = [
"KITTIDataset",
"DGPDataset",
"ImageDataset",
]
11 changes: 8 additions & 3 deletions packnet_sfm/datasets/kitti_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@

########################################################################################################################

# Cameras from the stereo pair (left is the origin)
IMAGE_FOLDER = {
'left': 'image_02',
'right': 'image_03',
}
# Name of different calibration files
CALIB_FILE = {
'cam2cam': 'calib_cam_to_cam.txt',
'velo2cam': 'calib_velo_to_cam.txt',
Expand Down Expand Up @@ -144,9 +146,12 @@ def _get_parent_folder(image_file):
return os.path.abspath(os.path.join(image_file, "../../../.."))

@staticmethod
def _get_intrinsics(image_file, calib_data):
    """
    Get intrinsics from the calib_data dictionary.

    Parameters
    ----------
    image_file : str
        Image path; the camera is identified by which IMAGE_FOLDER name it contains
    calib_data : dict
        Calibration entries, indexed here by 'P_rect_XX' keys

    Returns
    -------
    intrinsics : np.array [3,3] or None
        Left 3x3 block of the matching 3x4 'P_rect_XX' matrix,
        or None when neither camera folder name appears in image_file
    """
    for cam in ['left', 'right']:
        folder = IMAGE_FOLDER[cam]
        # Check for both cameras, if found replace and return intrinsics
        if folder in image_file:
            # 'image_02' -> 'P_rect_02', 'image_03' -> 'P_rect_03'
            projection = np.reshape(calib_data[folder.replace('image', 'P_rect')], (3, 4))
            return projection[:, :3]
    # No known camera folder in the path
    return None

@staticmethod
def _read_raw_calib_file(folder):
Expand Down Expand Up @@ -358,7 +363,7 @@ def __getitem__(self, idx):
c_data = self._read_raw_calib_file(parent_folder)
self.calibration_cache[parent_folder] = c_data
sample.update({
'intrinsics': self._get_intrinsics(c_data),
'intrinsics': self._get_intrinsics(self.paths[idx], c_data),
})

# Add pose information if requested
Expand Down
14 changes: 0 additions & 14 deletions packnet_sfm/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,3 @@
- ModelCheckpoint enables saving/restoring state of torch.nn.Module objects
"""

from packnet_sfm.models.model_checkpoint import ModelCheckpoint
from packnet_sfm.models.model_wrapper import ModelWrapper
from packnet_sfm.models.SfmModel import SfmModel
from packnet_sfm.models.SelfSupModel import SelfSupModel
from packnet_sfm.models.SemiSupModel import SemiSupModel

__all__ = [
"ModelCheckpoint",
"ModelWrapper",
"SfmModel",
"SelfSupModel",
"SemiSupModel",
]
4 changes: 3 additions & 1 deletion packnet_sfm/models/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import torch
from torch.utils.data import ConcatDataset, DataLoader

from packnet_sfm.datasets import KITTIDataset, DGPDataset, ImageDataset
from packnet_sfm.datasets.transforms import get_transforms
from packnet_sfm.utils.depth import inv2depth, post_process_inv_depth, compute_depth_metrics
from packnet_sfm.utils.horovod import print0, world_size, rank, on_rank_0
Expand Down Expand Up @@ -516,19 +515,22 @@ def setup_dataset(config, mode, requirements, **kwargs):

# KITTI dataset
if config.dataset[i] == 'KITTI':
from packnet_sfm.datasets.kitti_dataset import KITTIDataset
dataset = KITTIDataset(
config.path[i], path_split,
**dataset_args, **dataset_args_i,
)
# DGP dataset
elif config.dataset[i] == 'DGP':
from packnet_sfm.datasets.dgp_dataset import DGPDataset
dataset = DGPDataset(
config.path[i], config.split[i],
**dataset_args, **dataset_args_i,
cameras=config.cameras[i],
)
# Image dataset
elif config.dataset[i] == 'Image':
from packnet_sfm.datasets.image_dataset import ImageDataset
dataset = ImageDataset(
config.path[i], config.split[i],
**dataset_args, **dataset_args_i,
Expand Down
63 changes: 59 additions & 4 deletions packnet_sfm/utils/depth.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,68 @@
# Copyright 2020 Toyota Research Institute. All rights reserved.

from matplotlib.cm import get_cmap
import torch
import numpy as np
from packnet_sfm.utils.image import \
gradient_x, gradient_y, flip_lr, interpolate_image
import torch
import torchvision.transforms as transforms
from matplotlib.cm import get_cmap

from packnet_sfm.utils.image import load_image, gradient_x, gradient_y, flip_lr, interpolate_image
from packnet_sfm.utils.types import is_seq, is_tensor


def load_depth(file):
    """
    Load a depth map from file

    Parameters
    ----------
    file : str
        Depth map filename (.npz or .png)

    Returns
    -------
    depth : np.array [H,W]
        Depth map (invalid pixels are 0)

    Raises
    ------
    AssertionError
        If a .png file contains no value above 255
    NotImplementedError
        If the filename ends with neither 'npz' nor 'png'
    """
    if file.endswith('npz'):
        # Use the archive as a context manager so its file handle is closed;
        # indexing reads the 'depth' array into memory before the archive closes
        with np.load(file) as data:
            return data['depth']
    elif file.endswith('png'):
        depth_png = np.array(load_image(file), dtype=int)
        # Depth pngs store depth * 256, so a valid one must exceed the 8-bit range
        assert (np.max(depth_png) > 255), 'Wrong .png depth file'
        # np.float was removed in NumPy 1.24; it was an alias of float (float64)
        return depth_png.astype(np.float64) / 256.
    else:
        raise NotImplementedError('Depth extension not supported.')


def write_depth(filename, depth, intrinsics=None):
    """
    Write a depth map to file, and optionally its corresponding intrinsics.

    Parameters
    ----------
    filename : str
        File where depth map will be saved (.npz or .png)
    depth : np.array or torch.Tensor [H,W]
        Depth map
    intrinsics : np.array or torch.Tensor [3,3]
        Optional camera intrinsics matrix (only written to .npz files)

    Raises
    ------
    NotImplementedError
        If filename ends with neither '.npz' nor '.png'
    """
    # If depth is a tensor
    if is_tensor(depth):
        depth = depth.detach().squeeze().cpu()
    # If intrinsics is a tensor
    if is_tensor(intrinsics):
        intrinsics = intrinsics.detach().cpu()
    # If we are saving as a .npz
    if filename.endswith('.npz'):
        np.savez_compressed(filename, depth=depth, intrinsics=intrinsics)
    # If we are saving as a .png
    elif filename.endswith('.png'):
        # numpy arrays have no .int() method, so wrap the input first;
        # torch.as_tensor hands back an existing tensor unchanged
        scaled = (torch.as_tensor(depth) * 256).int()
        transforms.ToPILImage()(scaled).save(filename)
    # Something is wrong
    else:
        raise NotImplementedError('Depth filename not valid.')


def viz_inv_depth(inv_depth, normalizer=None, percentile=95,
colormap='plasma', filter_zeros=False):
"""
Expand Down
35 changes: 29 additions & 6 deletions packnet_sfm/utils/horovod.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@

import horovod.torch as hvd
try:
import horovod.torch as hvd
HAS_HOROVOD = True
except ImportError:
HAS_HOROVOD = False

########################################################################################################################

def hvd_init():
    """
    Initialize Horovod when it is available

    Returns
    -------
    has_horovod : bool
        Value of HAS_HOROVOD: True if horovod.torch was imported, False otherwise
    """
    # Only touch hvd when the import succeeded; otherwise the name is undefined
    if HAS_HOROVOD:
        hvd.init()
    return HAS_HOROVOD

def on_rank_0(func):
def wrapper(*args, **kwargs):
Expand All @@ -13,13 +18,31 @@ def wrapper(*args, **kwargs):
return wrapper

def rank():
    """
    Get the rank of the current process

    Returns
    -------
    rank : int
        hvd.rank() when Horovod is available, 0 otherwise
    """
    return hvd.rank() if HAS_HOROVOD else 0

def world_size():
    """
    Get the number of processes

    Returns
    -------
    size : int
        hvd.size() when Horovod is available, 1 otherwise
    """
    return hvd.size() if HAS_HOROVOD else 1

# NOTE(review): the on_rank_0 wrapper body is collapsed in this view; presumably it
# only calls the wrapped function on rank 0 -- confirm against the full file.
@on_rank_0
def print0(string='\n'):
    """
    Print a string, routed through the on_rank_0 decorator

    Parameters
    ----------
    string : str
        Text to print (defaults to a newline character)
    """
    print(string)

########################################################################################################################
def reduce_value(value, average, name):
    """
    Reduce a value across all Horovod processes

    Parameters
    ----------
    value : torch.Tensor
        Value to be reduced
    average : bool
        Whether values will be averaged or not
    name : str
        Value name

    Returns
    -------
    value : torch.Tensor
        Reduced value (the input itself when Horovod is not available)
    """
    # Without Horovod, hvd is undefined and world_size() is 1,
    # so the reduction of a single value is that value
    if not HAS_HOROVOD:
        return value
    return hvd.allreduce(value, average=average, name=name)
17 changes: 15 additions & 2 deletions packnet_sfm/utils/image.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright 2020 Toyota Research Institute. All rights reserved.

import cv2
import torch
import torch.nn.functional as funct
from functools import lru_cache
from PIL import Image

from packnet_sfm.utils.misc import same_shape

########################################################################################################################

def load_image(path):
"""
Expand All @@ -25,7 +25,20 @@ def load_image(path):
"""
return Image.open(path)

########################################################################################################################

def write_image(filename, image):
    """
    Save an RGB image to disk.

    Parameters
    ----------
    filename : str
        Destination file for the image
    image : np.array [H,W,3]
        RGB image
    """
    # Reverse the last axis (RGB -> BGR) before handing the array to OpenCV
    channels_reversed = image[:, :, ::-1]
    cv2.imwrite(filename, channels_reversed)


def flip_lr(image):
"""
Expand Down
23 changes: 14 additions & 9 deletions packnet_sfm/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,29 +30,34 @@ def pcolor(string, color, on_color=None, attrs=None):
return colored(string, color, on_color, attrs)


def prepare_dataset_prefix(config, dataset_idx):
    """
    Concatenates dataset path and split for metrics logging

    Parameters
    ----------
    config : CfgNode
        Dataset configuration
    dataset_idx : int
        Dataset index for multiple datasets

    Returns
    -------
    prefix : str
        Dataset prefix for metrics logging
    """
    split = config.split[dataset_idx]
    depth_type = config.depth_type[dataset_idx]
    cameras = config.cameras[dataset_idx]
    # Path is always available: last '/'-separated component, extension removed
    parts = [os.path.splitext(config.path[dataset_idx].split('/')[-1])[0]]
    # If split is available and does not contain { character
    # (a '{' marks a filename template such as '{:010d}', not a split name)
    if split != '' and '{' not in split:
        parts.append(os.path.splitext(os.path.basename(split))[0])
    # If depth type is available (compared by value, not identity)
    if depth_type != '':
        parts.append(depth_type)
    # If we are using specific cameras
    if len(cameras) == 1:  # only allows single cameras
        parts.append(cameras[0])
    # Return full prefix
    return '-'.join(str(part) for part in parts)


Expand Down

0 comments on commit f824ffc

Please sign in to comment.