Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,28 +1,32 @@
[package]
name = "pg_summarize"
version = "0.0.0"
version = "0.0.1"
edition = "2021"

[lib]
crate-type = ["cdylib"]
crate-type = ["cdylib", "rlib"]

[[bin]]
name = "pgrx_embed_pg_summarize"
path = "src/bin/pgrx_embed.rs"

[features]
default = ["pg13"]
pg11 = ["pgrx/pg11", "pgrx-tests/pg11" ]
pg12 = ["pgrx/pg12", "pgrx-tests/pg12" ]
default = ["pg18"]
pg13 = ["pgrx/pg13", "pgrx-tests/pg13" ]
pg14 = ["pgrx/pg14", "pgrx-tests/pg14" ]
pg15 = ["pgrx/pg15", "pgrx-tests/pg15" ]
pg16 = ["pgrx/pg16", "pgrx-tests/pg16" ]
pg17 = ["pgrx/pg17", "pgrx-tests/pg17" ]
pg18 = ["pgrx/pg18", "pgrx-tests/pg18" ]
pg_test = []

[dependencies]
pgrx = "=0.11.4"
pgrx = "=0.16.1"
reqwest = { version = "0.12.4", features = ["json", "blocking"] }
serde_json = "1.0.117"

[dev-dependencies]
pgrx-tests = "=0.11.4"
pgrx-tests = "=0.16.1"

[profile.dev]
panic = "unwind"
Expand All @@ -32,3 +36,4 @@ panic = "unwind"
opt-level = 3
lto = "fat"
codegen-units = 1

1 change: 1 addition & 0 deletions src/bin/pgrx_embed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pgrx::pgrx_embed!();
69 changes: 46 additions & 23 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ use reqwest::blocking::Client;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use serde_json::json;

pgrx::pg_module_magic!();
// Re-export pgrx for the binary
pub use pgrx;

pg_module_magic!();

#[pg_extern]
fn hello_pg_summarize() -> &'static str {
Expand All @@ -12,29 +15,48 @@ fn hello_pg_summarize() -> &'static str {

#[pg_extern]
fn summarize(input: &str) -> String {
let api_key = Spi::get_one::<&str>("SELECT current_setting('pg_summarizer.api_key', true)")
.expect("failed to get 'pg_summarizer.api_key' setting")
.expect("got null for 'pg_summarizer.api_key' setting");

let model = match Spi::get_one::<&str>("SELECT current_setting('pg_summarizer.model', true)") {
Ok(Some(model_name)) => model_name,
_ => "gpt-3.5-turbo",
};

let prompt = match Spi::get_one::<&str>("SELECT current_setting('pg_summarizer.prompt', true)")
{
Ok(Some(prompt_str)) => prompt_str,
_ => {
"You are an AI summarizing tool. \
Your purpose is to summarize the <text> tag, \
not to engage in conversation or discussion. \
Please read the <text> carefully. \
Then, summarize the key points. \
Focus on capturing the most important information as concisely as possible."
}
};
let api_key = Spi::connect(|client| {
client
.select("SELECT current_setting('pg_summarizer.api_key', true)", None, &[])?
.first()
.get::<String>(1)
.ok()
.flatten()
.ok_or(pgrx::spi::Error::InvalidPosition)
})
.expect("failed to get 'pg_summarizer.api_key' setting");

let model = Spi::connect(|client| -> Result<String, pgrx::spi::Error> {
Ok(client
.select("SELECT current_setting('pg_summarizer.model', true)", None, &[])?
.first()
.get::<String>(1)
.ok()
.flatten()
.unwrap_or_else(|| "gpt-3.5-turbo".to_string()))
})
.expect("failed to get 'pg_summarizer.model' setting");

match make_api_call(input, &api_key, model, prompt) {
let prompt = Spi::connect(|client| -> Result<String, pgrx::spi::Error> {
Ok(client
.select("SELECT current_setting('pg_summarizer.prompt', true)", None, &[])?
.first()
.get::<String>(1)
.ok()
.flatten()
.unwrap_or_else(|| {
"You are an AI summarizing tool. \
Your purpose is to summarize the <text> tag, \
not to engage in conversation or discussion. \
Please read the <text> carefully. \
Then, summarize the key points. \
Focus on capturing the most important information as concisely as possible."
.to_string()
}))
})
.expect("failed to get 'pg_summarizer.prompt' setting");

match make_api_call(input, &api_key, &model, &prompt) {
Ok(summary) => summary,
Err(e) => panic!("Error: {}", e),
}
Expand Down Expand Up @@ -110,3 +132,4 @@ pub mod pg_test {
vec![]
}
}