Skip to content

Commit

Permalink
write eval combined images for diagnostics (nerfstudio-project#3070)
Browse files Browse the repository at this point in the history
  • Loading branch information
jb-ye authored Apr 11, 2024
1 parent eba72db commit 4714ae7
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions nerfstudio/pipelines/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

import torch
import torch.distributed as dist
import torchvision.utils as vutils
from rich.progress import BarColumn, MofNCompleteColumn, Progress, TextColumn, TimeElapsedColumn
from torch import nn
from torch.cuda.amp.grad_scaler import GradScaler
Expand Down Expand Up @@ -361,6 +362,8 @@ def get_average_eval_image_metrics(
metrics_dict_list = []
assert isinstance(self.datamanager, (VanillaDataManager, ParallelDataManager, FullImageDatamanager))
num_images = len(self.datamanager.fixed_indices_eval_dataloader)
if output_path is not None:
output_path.mkdir(exist_ok=True, parents=True)
with Progress(
TextColumn("[progress.description]{task.description}"),
BarColumn(),
Expand All @@ -369,15 +372,18 @@ def get_average_eval_image_metrics(
transient=True,
) as progress:
task = progress.add_task("[green]Evaluating all eval images...", total=num_images)
idx = 0
for camera, batch in self.datamanager.fixed_indices_eval_dataloader:
# time this the following line
inner_start = time()
outputs = self.model.get_outputs_for_camera(camera=camera)
height, width = camera.height, camera.width
num_rays = height * width
metrics_dict, _ = self.model.get_image_metrics_and_images(outputs, batch)
metrics_dict, image_dict = self.model.get_image_metrics_and_images(outputs, batch)
if output_path is not None:
raise NotImplementedError("Saving images is not implemented yet")
for key in image_dict.keys():
image = image_dict[key] # [H, W, C] order
vutils.save_image(image.permute(2, 0, 1).cpu(), output_path / f"eval_{key}_{idx:04d}.png")

assert "num_rays_per_sec" not in metrics_dict
metrics_dict["num_rays_per_sec"] = (num_rays / (time() - inner_start)).item()
Expand All @@ -386,6 +392,7 @@ def get_average_eval_image_metrics(
metrics_dict[fps_str] = (metrics_dict["num_rays_per_sec"] / (height * width)).item()
metrics_dict_list.append(metrics_dict)
progress.advance(task)
idx = idx + 1
# average the metrics list
metrics_dict = {}
for key in metrics_dict_list[0].keys():
Expand Down

0 comments on commit 4714ae7

Please sign in to comment.