Skip to content

Commit

Permalink
PureData reload mode
Browse files Browse the repository at this point in the history
  • Loading branch information
caillonantoine committed May 24, 2023
1 parent 4553058 commit 3517fb5
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions src/frontend/puredata/nn_tilde/nn_tilde.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "../../maxmsp/shared/circular_buffer.h"
#include "m_pd.h"
#include <memory>
#include <mutex>
#include <string>
#include <vector>

Expand All @@ -28,7 +29,7 @@ typedef struct _nn_tilde {

int m_enabled;
// BACKEND RELATED MEMBERS
Backend m_model;
std::unique_ptr<Backend> m_model;
std::vector<std::string> settable_attributes;
t_symbol *m_method, *m_path;
std::unique_ptr<std::thread> m_compute_thread;
Expand Down Expand Up @@ -57,15 +58,15 @@ void model_perform(t_nn_tilde *nn_instance) {
for (int c(0); c < nn_instance->m_out_dim; c++)
out_model.push_back(nn_instance->m_out_model[c].get());

nn_instance->m_model.perform(in_model, out_model, nn_instance->m_buffer_size,
nn_instance->m_method->s_name, 1);
nn_instance->m_model->perform(in_model, out_model, nn_instance->m_buffer_size,
nn_instance->m_method->s_name, 1);
}

// DSP CALL
t_int *nn_tilde_perform(t_int *w) {
t_nn_tilde *x = (t_nn_tilde *)(w[1]);

if (!x->m_model.is_loaded() || !x->m_enabled) {
if (!x->m_model->is_loaded() || !x->m_enabled) {
for (int c(0); c < x->m_out_dim; c++) {
for (int i(0); i < x->m_dsp_vec_size; i++) {
x->m_dsp_out_vec[c][i] = 0;
Expand Down Expand Up @@ -131,7 +132,7 @@ void nn_tilde_free(t_nn_tilde *x) {
void *nn_tilde_new(t_symbol *s, int argc, t_atom *argv) {
t_nn_tilde *x = (t_nn_tilde *)pd_new(nn_tilde_class);

x->m_model = Backend();
x->m_model = std::make_unique<Backend>();
x->m_head = 0;
x->m_compute_thread = nullptr;
x->m_in_dim = 1;
Expand Down Expand Up @@ -170,29 +171,29 @@ void *nn_tilde_new(t_symbol *s, int argc, t_atom *argv) {
}

// TRY TO LOAD MODEL
if (x->m_model.load(x->m_path->s_name)) {
if (x->m_model->load(x->m_path->s_name)) {
post("error during loading");
return (void *)x;
} else {
// cout << "successfully loaded model" << endl;
}

// GET MODEL'S METHOD PARAMETERS
auto params = x->m_model.get_method_params(x->m_method->s_name);
x->settable_attributes = x->m_model.get_settable_attributes();
auto params = x->m_model->get_method_params(x->m_method->s_name);
x->settable_attributes = x->m_model->get_settable_attributes();

if (!params.size()) {
post("method not found, using forward instead");
x->m_method = gensym("forward");
params = x->m_model.get_method_params(x->m_method->s_name);
params = x->m_model->get_method_params(x->m_method->s_name);
}

x->m_in_dim = params[0];
x->m_in_ratio = params[1];
x->m_out_dim = params[2];
x->m_out_ratio = params[3];

auto higher_ratio = x->m_model.get_higher_ratio();
auto higher_ratio = x->m_model->get_higher_ratio();

if (!x->m_buffer_size) {
// NO THREAD MODE
Expand Down Expand Up @@ -229,6 +230,7 @@ void *nn_tilde_new(t_symbol *s, int argc, t_atom *argv) {
}

void nn_tilde_enable(t_nn_tilde *x, t_floatarg arg) { x->m_enabled = int(arg); }
void nn_tilde_reload(t_nn_tilde *x) { x->m_model->reload(); }

void nn_tilde_set(t_nn_tilde *x, t_symbol *s, int argc, t_atom *argv) {
if (argc < 2) {
Expand All @@ -254,7 +256,7 @@ void nn_tilde_set(t_nn_tilde *x, t_symbol *s, int argc, t_atom *argv) {
}
}
try {
x->m_model.set_attribute(argname, attribute_args);
x->m_model->set_attribute(argname, attribute_args);
} catch (const std::exception &e) {
post(e.what());
}
Expand All @@ -280,6 +282,8 @@ void nn_tilde_setup(void) {
0);
class_addmethod(nn_tilde_class, (t_method)nn_tilde_enable, gensym("enable"),
A_DEFFLOAT, A_NULL);
class_addmethod(nn_tilde_class, (t_method)nn_tilde_reload, gensym("reload"),
A_NULL);
class_addmethod(nn_tilde_class, (t_method)nn_tilde_set, gensym("set"),
A_GIMME, A_NULL);
CLASS_MAINSIGNALIN(nn_tilde_class, t_nn_tilde, f);
Expand Down

0 comments on commit 3517fb5

Please sign in to comment.