Skip to content

Commit

Permalink
fix: use Ref instead of raw pointer
Browse files Browse the repository at this point in the history
  • Loading branch information
Adriankhl committed May 29, 2024
1 parent 88fd65a commit 261fe25
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 37 deletions.
6 changes: 3 additions & 3 deletions src/gdllava.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ Error GDLlava::run_generate_text_base64(String prompt, String image_base64) {
}


String GDLlava::generate_text_image_internal(String prompt, Image* image) {
String GDLlava::generate_text_image_internal(String prompt, Ref<Image> image) {
glog_verbose("generate_text_image_internal");

String image_base64 = Marshalls::get_singleton()->raw_to_base64(image->save_jpg_to_buffer());
Expand All @@ -319,7 +319,7 @@ String GDLlava::generate_text_image_internal(String prompt, Image* image) {
return full_generated_text;
}

String GDLlava::generate_text_image(String prompt, Image* image) {
String GDLlava::generate_text_image(String prompt, Ref<Image> image) {
glog_verbose("generate_text_image");

func_mutex->lock();
Expand All @@ -341,7 +341,7 @@ String GDLlava::generate_text_image(String prompt, Image* image) {
return full_generated_text;
}

Error GDLlava::run_generate_text_image(String prompt, Image* image) {
Error GDLlava::run_generate_text_image(String prompt, Ref<Image> image) {
glog_verbose("run_generate_text_image");
func_mutex->lock();

Expand Down
6 changes: 3 additions & 3 deletions src/gdllava.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class GDLlava : public Node {
Ref<Thread> generate_text_thread;
String generate_text_common(String prompt, String image_base64);
String generate_text_base64_internal(String prompt, String image_base64);
String generate_text_image_internal(String prompt, Image* image);
String generate_text_image_internal(String prompt, Ref<Image> image);
std::function<void(std::string)> glog;
std::function<void(std::string)> glog_verbose;
std::string generate_text_buffer;
Expand Down Expand Up @@ -60,8 +60,8 @@ class GDLlava : public Node {
bool is_running();
String generate_text_base64(String prompt, String image_base64);
Error run_generate_text_base64(String prompt, String image_base64);
String generate_text_image(String prompt, Image* image);
Error run_generate_text_image(String prompt, Image* image);
String generate_text_image(String prompt, Ref<Image> image);
Error run_generate_text_image(String prompt, Ref<Image> image);
void stop_generate_text();
};

Expand Down
58 changes: 32 additions & 26 deletions src/llm_db.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "gdembedding.hpp"
#include "sqlite3.h"
#include "sqlite-vec.h"
#include <algorithm>
#include <cstring>
#include <gdextension_interface.h>
#include <godot_cpp/classes/global_constants.hpp>
Expand All @@ -20,6 +21,7 @@
#include <godot_cpp/variant/utility_functions.hpp>
#include <godot_cpp/variant/variant.hpp>
#include <queue>
#include <vector>

namespace godot {

Expand Down Expand Up @@ -50,36 +52,36 @@ LlmDBMetaData::LlmDBMetaData() : data_name {"default_name"},

LlmDBMetaData::~LlmDBMetaData() {}

LlmDBMetaData* LlmDBMetaData::create(String data_name, int data_type) {
LlmDBMetaData* data = memnew(LlmDBMetaData());
Ref<LlmDBMetaData> LlmDBMetaData::create(String data_name, int data_type) {
Ref<LlmDBMetaData> data = memnew(LlmDBMetaData());
data->set_data_name(data_name);
data->set_data_type(data_type);
return data;
}

LlmDBMetaData* LlmDBMetaData::create_int(String data_name) {
LlmDBMetaData* data = memnew(LlmDBMetaData());
Ref<LlmDBMetaData> LlmDBMetaData::create_int(String data_name) {
Ref<LlmDBMetaData> data = memnew(LlmDBMetaData());
data->set_data_name(data_name);
data->set_data_type(0);
return data;
}

LlmDBMetaData* LlmDBMetaData::create_real(String data_name) {
LlmDBMetaData* data = memnew(LlmDBMetaData());
Ref<LlmDBMetaData> LlmDBMetaData::create_real(String data_name) {
Ref<LlmDBMetaData> data = memnew(LlmDBMetaData());
data->set_data_name(data_name);
data->set_data_type(1);
return data;
}

LlmDBMetaData* LlmDBMetaData::create_text(String data_name) {
LlmDBMetaData* data = memnew(LlmDBMetaData());
Ref<LlmDBMetaData> LlmDBMetaData::create_text(String data_name) {
Ref<LlmDBMetaData> data = memnew(LlmDBMetaData());
data->set_data_name(data_name);
data->set_data_type(2);
return data;
}

LlmDBMetaData* LlmDBMetaData::create_blob(String data_name) {
LlmDBMetaData* data = memnew(LlmDBMetaData());
Ref<LlmDBMetaData> LlmDBMetaData::create_blob(String data_name) {
Ref<LlmDBMetaData> data = memnew(LlmDBMetaData());
data->set_data_name(data_name);
data->set_data_type(2);
return data;
Expand Down Expand Up @@ -277,7 +279,7 @@ TypedArray<LlmDBMetaData> LlmDB::get_meta() const {

void LlmDB::set_meta(TypedArray<LlmDBMetaData> p_meta) {
bool is_id_valid = true;
int col_to_remove = -1;
std::vector<int> cols_to_remove {};


if (p_meta.size() != 0) {
Expand All @@ -286,27 +288,31 @@ void LlmDB::set_meta(TypedArray<LlmDBMetaData> p_meta) {
UtilityFunctions::print_verbose("Checking meta data " + String::num_int64(i));
if (p_meta[i].get_type() != Variant::NIL) {
UtilityFunctions::print_verbose("Correct resource type");
LlmDBMetaData* sd = Object::cast_to<LlmDBMetaData>(p_meta[i]);
Ref<LlmDBMetaData> sd = Object::cast_to<LlmDBMetaData>(p_meta[i]);
if (sd->get_data_name() == "id") {
UtilityFunctions::printerr("Column " + String::num_int64(i) + " error: Id column must be the first column (0)");
col_to_remove = i;
cols_to_remove.push_back(i);
}
}
}
if (col_to_remove != -1) {
UtilityFunctions::printerr("Removing column " + String::num(col_to_remove));
p_meta.remove_at(col_to_remove);

// Remove from the end to make sure the indexes are correct
std::reverse(cols_to_remove.begin(), cols_to_remove.end());

for (int i : cols_to_remove) {
UtilityFunctions::printerr("Removing column " + String::num(i));
p_meta.remove_at(i);
}

LlmDBMetaData* sd0 = Object::cast_to<LlmDBMetaData>(p_meta[0]);
Ref<LlmDBMetaData> sd0 = Object::cast_to<LlmDBMetaData>(p_meta[0]);
if (sd0->get_data_name() == "id" && sd0->get_data_type() != LlmDBMetaDataType::TEXT) {
UtilityFunctions::printerr("Id column should be TEXT type, removing");
p_meta.remove_at(0);
}

// Get again since it might get removed
sd0 = Object::cast_to<LlmDBMetaData>(p_meta[0]);
if (sd0->get_data_name() != "id") {
Ref<LlmDBMetaData> sd0_1 = Object::cast_to<LlmDBMetaData>(p_meta[0]);
if (sd0_1->get_data_name() != "id") {
UtilityFunctions::printerr("First column is not id");
is_id_valid = false;
}
Expand Down Expand Up @@ -485,7 +491,7 @@ void LlmDB::create_llm_tables() {
UtilityFunctions::print_verbose("create_llm_tables: " + table_name);
String statement = "CREATE TABLE IF NOT EXISTS " + table_name + " (";
for (int i = 0; i < meta.size(); i++) {
LlmDBMetaData* sd = Object::cast_to<LlmDBMetaData>(meta[i]);
Ref<LlmDBMetaData> sd = Object::cast_to<LlmDBMetaData>(meta[i]);
statement += "'" + sd->get_data_name() + "' ";
statement += type_int_to_string(sd->get_data_type());
statement += ", ";
Expand All @@ -507,7 +513,7 @@ void LlmDB::create_llm_tables() {

String statement_meta = "CREATE TABLE IF NOT EXISTS " + meta_table_name + " (";
for (int i = 0; i < meta.size(); i++) {
LlmDBMetaData* sd = Object::cast_to<LlmDBMetaData>(meta[i]);
Ref<LlmDBMetaData> sd = Object::cast_to<LlmDBMetaData>(meta[i]);
statement_meta += " '" + sd->get_data_name() + "' ";
statement_meta += type_int_to_string(sd->get_data_type());
if (i == 0) {
Expand Down Expand Up @@ -632,7 +638,7 @@ bool LlmDB::is_table_valid(String p_table_name) {
String name = String::utf8((char *) sqlite3_column_text(stmt, 1));
String type = String::utf8((char *) sqlite3_column_text(stmt, 2));

LlmDBMetaData* sd = Object::cast_to<LlmDBMetaData>(meta[i]);
Ref<LlmDBMetaData> sd = Object::cast_to<LlmDBMetaData>(meta[i]);

if (name != sd->get_data_name()) {
UtilityFunctions::printerr("Column name wrong, table : " + name + ", meta: " + sd->get_data_name());
Expand Down Expand Up @@ -668,7 +674,7 @@ void LlmDB::store_meta(Dictionary meta_dict) {
Dictionary p_meta_dict = meta_dict.duplicate(false);
PackedStringArray array_bind {PackedStringArray()};
for (int i = 0; i < meta.size(); i++) {
LlmDBMetaData* sd = Object::cast_to<LlmDBMetaData>(meta[i]);
Ref<LlmDBMetaData> sd = Object::cast_to<LlmDBMetaData>(meta[i]);
if(p_meta_dict.has(sd->get_data_name())) {
Variant v = p_meta_dict.get(sd->get_data_name(), nullptr);
if (v.get_type() != type_int_to_variant(sd->get_data_type())) {
Expand Down Expand Up @@ -875,14 +881,14 @@ void LlmDB::insert_text_by_id(String id, String text) {
String statement = "INSERT INTO " + table_name + " (";

for (int i = 0; i < meta.size(); i++) {
LlmDBMetaData* sd = Object::cast_to<LlmDBMetaData>(meta[i]);
Ref<LlmDBMetaData> sd = Object::cast_to<LlmDBMetaData>(meta[i]);
statement += sd->get_data_name() + ", ";
}

statement += "llm_text, embedding) VALUES (?, ";

for (int i = 1; i < meta.size(); i++) {
LlmDBMetaData* sd = Object::cast_to<LlmDBMetaData>(meta[i]);
Ref<LlmDBMetaData> sd = Object::cast_to<LlmDBMetaData>(meta[i]);
statement += "(SELECT " + sd->get_data_name() + " FROM " + table_name + "_meta" + " WHERE id=?), ";
}

Expand Down Expand Up @@ -985,7 +991,7 @@ void LlmDB::insert_text_by_meta(Dictionary meta_dict, String text) {
Dictionary p_meta_dict = meta_dict.duplicate(false);
PackedStringArray array_bind {PackedStringArray()};
for (int i = 0; i < meta.size(); i++) {
LlmDBMetaData* sd = Object::cast_to<LlmDBMetaData>(meta[i]);
Ref<LlmDBMetaData> sd = Object::cast_to<LlmDBMetaData>(meta[i]);
if(p_meta_dict.has(sd->get_data_name())) {
Variant v = p_meta_dict.get(sd->get_data_name(), nullptr);
if (v.get_type() != type_int_to_variant(sd->get_data_type())) {
Expand Down
10 changes: 5 additions & 5 deletions src/llm_db.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ class LlmDBMetaData : public Resource {
public:
LlmDBMetaData();
~LlmDBMetaData();
static LlmDBMetaData* create(String data_name, int data_type);
static LlmDBMetaData* create_int(String data_name);
static LlmDBMetaData* create_real(String data_name);
static LlmDBMetaData* create_text(String data_name);
static LlmDBMetaData* create_blob(String data_name);
static Ref<LlmDBMetaData> create(String data_name, int data_type);
static Ref<LlmDBMetaData> create_int(String data_name);
static Ref<LlmDBMetaData> create_real(String data_name);
static Ref<LlmDBMetaData> create_text(String data_name);
static Ref<LlmDBMetaData> create_blob(String data_name);
String get_data_name() const;
void set_data_name(const String p_data_name);
int get_data_type() const;
Expand Down

0 comments on commit 261fe25

Please sign in to comment.