Skip to content

Commit

Permalink
Implement loading snapshot from stream
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom94 committed Oct 6, 2023
1 parent db56177 commit 8ee14d5
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 14 deletions.
3 changes: 3 additions & 0 deletions include/neural-graphics-primitives/testbed.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,7 @@ class Testbed {
);
void visualize_nerf_cameras(ImDrawList* list, const mat4& world2proj);
fs::path find_network_config(const fs::path& network_config_path);
nlohmann::json load_network_config(std::istream& stream, bool is_compressed);
nlohmann::json load_network_config(const fs::path& network_config_path);
void reload_network_from_file(const fs::path& path = "");
void reload_network_from_json(const nlohmann::json& json, const std::string& config_base_path=""); // config_base_path is needed so that if the passed in json uses the 'parent' feature, we know where to look... be sure to use a filename, or if a directory, end with a trailing slash
Expand Down Expand Up @@ -484,7 +485,9 @@ class Testbed {
vec2 fov_xy() const ;
void set_fov_xy(const vec2& val);
void save_snapshot(const fs::path& path, bool include_optimizer_state, bool compress);
void load_snapshot(nlohmann::json config);
void load_snapshot(const fs::path& path);
void load_snapshot(std::istream& stream, bool is_compressed = true);
CameraKeyframe copy_camera_to_keyframe() const;
void set_camera_from_keyframe(const CameraKeyframe& k);
void set_camera_from_time(float t);
Expand Down
2 changes: 1 addition & 1 deletion src/main.cu
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ int main_func(const std::vector<std::string>& arguments) {
}

if (snapshot_flag) {
testbed.load_snapshot(get(snapshot_flag));
testbed.load_snapshot(static_cast<fs::path>(get(snapshot_flag)));
} else if (network_config_flag) {
testbed.reload_network_from_file(get(network_config_flag));
}
Expand Down
2 changes: 1 addition & 1 deletion src/python_api.cu
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ PYBIND11_MODULE(pyngp, m) {
.def("n_params", &Testbed::n_params, "Number of trainable parameters")
.def("n_encoding_params", &Testbed::n_encoding_params, "Number of trainable parameters in the encoding")
.def("save_snapshot", &Testbed::save_snapshot, py::arg("path"), py::arg("include_optimizer_state")=false, py::arg("compress")=true, "Save a snapshot of the currently trained model. Optionally compressed (only when saving '.ingp' files).")
.def("load_snapshot", &Testbed::load_snapshot, py::arg("path"), "Load a previously saved snapshot")
.def("load_snapshot", py::overload_cast<const fs::path&>(&Testbed::load_snapshot), py::arg("path"), "Load a previously saved snapshot")
.def("load_camera_path", &Testbed::load_camera_path, py::arg("path"), "Load a camera path")
.def("load_file", &Testbed::load_file, py::arg("path"), "Load a file and automatically determine how to handle it. Can be a snapshot, dataset, network config, or camera path.")
.def_property("loop_animation", &Testbed::loop_animation, &Testbed::set_loop_animation)
Expand Down
49 changes: 37 additions & 12 deletions src/testbed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,14 @@ fs::path Testbed::find_network_config(const fs::path& network_config_path) {
return network_config_path;
}

json Testbed::load_network_config(std::istream& stream, bool is_compressed) {
if (is_compressed) {
zstr::istream zstream{stream};
return json::from_msgpack(zstream);
}
return json::from_msgpack(stream);
}

json Testbed::load_network_config(const fs::path& network_config_path) {
bool is_snapshot = equals_case_insensitive(network_config_path.extension(), "msgpack") || equals_case_insensitive(network_config_path.extension(), "ingp");
if (network_config_path.empty() || !network_config_path.exists()) {
Expand Down Expand Up @@ -1543,7 +1551,7 @@ void Testbed::imgui() {
ImGui::SameLine();
if (ImGui::Button("Load")) {
try {
load_snapshot(m_imgui.snapshot_path);
load_snapshot(static_cast<fs::path>(m_imgui.snapshot_path));
} catch (const std::exception& e) {
imgui_error_string = fmt::format("Failed to load snapshot: {}", e.what());
ImGui::OpenPopup("Error");
Expand Down Expand Up @@ -2339,14 +2347,14 @@ void Testbed::SecondWindow::draw(GLuint texture) {
}

void Testbed::init_opengl_shaders() {
static const char* shader_vert = R"(#version 140
static const char* shader_vert = R"glsl(#version 140
out vec2 UVs;
void main() {
UVs = vec2((gl_VertexID << 1) & 2, gl_VertexID & 2);
gl_Position = vec4(UVs * 2.0 - 1.0, 0.0, 1.0);
})";
})glsl";

static const char* shader_frag = R"(#version 140
static const char* shader_frag = R"glsl(#version 140
in vec2 UVs;
out vec4 frag_color;
uniform sampler2D rgba_texture;
Expand Down Expand Up @@ -2386,7 +2394,7 @@ void Testbed::init_opengl_shaders() {
//Uncomment the following line of code to visualize debug the depth buffer for debugging.
// frag_color = vec4(vec3(texture(depth_texture, tex_coords.xy).r), 1.0);
gl_FragDepth = texture(depth_texture, tex_coords.xy).r;
})";
})glsl";

GLuint vert = glCreateShader(GL_VERTEX_SHADER);
glShaderSource(vert, 1, &shader_vert, NULL);
Expand Down Expand Up @@ -4746,12 +4754,7 @@ void Testbed::save_snapshot(const fs::path& path, bool include_optimizer_state,
tlog::success() << "Saved snapshot '" << path.str() << "'";
}

void Testbed::load_snapshot(const fs::path& path) {
auto config = load_network_config(path);
if (!config.contains("snapshot")) {
throw std::runtime_error{fmt::format("File '{}' does not contain a snapshot.", path.str())};
}

void Testbed::load_snapshot(nlohmann::json config) {
const auto& snapshot = config["snapshot"];
if (snapshot.value("version", 0) < SNAPSHOT_FORMAT_VERSION) {
throw std::runtime_error{"Snapshot uses an old format and can not be loaded."};
Expand Down Expand Up @@ -4841,7 +4844,6 @@ void Testbed::load_snapshot(const fs::path& path) {
m_render_aabb = snapshot.value("render_aabb", m_render_aabb);
if (snapshot.contains("up_dir")) from_json(snapshot.at("up_dir"), m_up_dir);

m_network_config_path = path;
m_network_config = std::move(config);

reset_network(false);
Expand All @@ -4868,6 +4870,29 @@ void Testbed::load_snapshot(const fs::path& path) {
set_all_devices_dirty();
}

void Testbed::load_snapshot(const fs::path& path) {
auto config = load_network_config(path);
if (!config.contains("snapshot")) {
throw std::runtime_error{fmt::format("File '{}' does not contain a snapshot.", path.str())};
}

load_snapshot(std::move(config));

m_network_config_path = path;
}

void Testbed::load_snapshot(std::istream& stream, bool is_compressed) {
auto config = load_network_config(stream, is_compressed);
if (!config.contains("snapshot")) {
throw std::runtime_error{"Given stream does not contain a snapshot."};
}

load_snapshot(std::move(config));

// Network config path is unknown.
m_network_config_path = "";
}

Testbed::CudaDevice::CudaDevice(int id, bool is_primary) : m_id{id}, m_is_primary{is_primary} {
auto guard = device_guard();
m_stream = std::make_unique<StreamAndEvent>();
Expand Down

0 comments on commit 8ee14d5

Please sign in to comment.