Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 122 additions & 74 deletions core/connectors/sources/postgres_source/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,23 @@ pub struct DatabaseRecord {
pub old_data: Option<serde_json::Value>,
}

/// Borrowed, per-table configuration for turning fetched rows into produced
/// messages. Derives `Copy` so a per-table value can be built cheaply from a
/// shared template via struct-update syntax (`..row_config`).
#[derive(Clone, Copy)]
struct RowProcessingConfig<'a> {
    // Name of the table the rows were read from (recorded in DatabaseRecord).
    table: &'a str,
    // Column whose value is captured as the read offset for this poll.
    tracking_column: &'a str,
    // Column treated as the primary key; may be the same as tracking_column.
    pk_column: &'a str,
    // Format used when decoding the dedicated payload column.
    payload_format: PayloadFormat,
    // Column whose raw bytes become the message payload; empty string disables
    // payload extraction and the full row is serialized instead.
    payload_col: &'a str,
    // When true, column names are converted with to_snake_case().
    snake_case_columns: bool,
    // When true, the row is wrapped in DatabaseRecord metadata directly;
    // when false, the column map is nested under a "data" key first.
    include_metadata: bool,
}

/// Result of converting one database row: the message to produce plus the
/// tracking-offset and primary-key values extracted from that row.
struct ProcessedRow {
    // Fully built message ready to be pushed to the producer.
    message: ProducedMessage,
    // Stringified value of the tracking column for this row; None when the
    // column was absent or neither a JSON string nor a number.
    max_offset: Option<String>,
    // Stringified value of the primary-key column, under the same rules.
    row_pk: Option<String>,
}

const CONNECTOR_NAME: &str = "PostgreSQL source";

impl PostgresSource {
Expand Down Expand Up @@ -445,14 +462,27 @@ impl PostgresSource {
.primary_key_column
.as_deref()
.unwrap_or(tracking_column);
let payload_format = self.payload_format();
let payload_col = self.config.payload_column.as_deref().unwrap_or("");

let row_config = RowProcessingConfig {
table: "",
tracking_column,
pk_column,
payload_format: self.payload_format(),
payload_col: self.config.payload_column.as_deref().unwrap_or(""),
snake_case_columns: self.config.snake_case_columns.unwrap_or(false),
include_metadata: self.config.include_metadata.unwrap_or(true),
};

// Collect state updates to apply after processing
let mut state_updates: Vec<(String, String)> = Vec::new();
let mut total_processed: u64 = 0;

for table in &self.config.tables {
let table_config = RowProcessingConfig {
table,
..row_config
};

// Get last offset with minimal lock time
let last_offset = {
let state = self.state.lock().await;
Expand All @@ -478,82 +508,16 @@ impl PostgresSource {
let mut processed_ids: Vec<String> = Vec::new();

for row in rows {
let mut row_pk: Option<String> = None;
let mut extracted_payload: Option<Vec<u8>> = None;
let mut data = serde_json::Map::new();

for (i, column) in row.columns().iter().enumerate() {
let column_name = if self.config.snake_case_columns.unwrap_or(false) {
to_snake_case(column.name())
} else {
column.name().to_string()
};

if !payload_col.is_empty() && column.name() == payload_col {
extracted_payload =
Some(self.extract_payload_column(&row, i, payload_format)?);
continue;
}

let value = self.extract_column_value(&row, i)?;
data.insert(column_name.clone(), value.clone());

if column.name() == tracking_column {
if let serde_json::Value::String(ref s) = value {
max_offset = Some(s.clone());
} else if let serde_json::Value::Number(ref n) = value {
max_offset = Some(n.to_string());
}
}

if column.name() == pk_column {
if let serde_json::Value::String(ref s) = value {
row_pk = Some(s.clone());
} else if let serde_json::Value::Number(ref n) = value {
row_pk = Some(n.to_string());
}
}
}
let processed = self.process_row(&row, &table_config)?;

if let Some(pk) = row_pk {
if let Some(pk) = processed.row_pk {
processed_ids.push(pk);
}
if let Some(offset) = processed.max_offset {
max_offset = Some(offset);
}

let payload = if let Some(bytes) = extracted_payload {
bytes
} else {
let record = if self.config.include_metadata.unwrap_or(true) {
DatabaseRecord {
table_name: table.clone(),
operation_type: "SELECT".to_string(),
timestamp: Utc::now(),
data: serde_json::Value::Object(data),
old_data: None,
}
} else {
let mut simple_record = serde_json::Map::new();
simple_record.insert("data".to_string(), serde_json::Value::Object(data));
DatabaseRecord {
table_name: table.clone(),
operation_type: "SELECT".to_string(),
timestamp: Utc::now(),
data: serde_json::Value::Object(simple_record),
old_data: None,
}
};
simd_json::to_vec(&record).map_err(|_| Error::InvalidRecord)?
};

let message = ProducedMessage {
id: Some(Uuid::new_v4().as_u128()),
headers: None,
checksum: None,
timestamp: Some(Utc::now().timestamp_millis() as u64),
origin_timestamp: Some(Utc::now().timestamp_millis() as u64),
payload,
};

messages.push(message);
messages.push(processed.message);
total_processed += 1;
}

Expand Down Expand Up @@ -851,6 +815,90 @@ impl PostgresSource {
None
}

/// Converts a single fetched row into a [`ProcessedRow`]: the message to
/// produce plus the row's tracking-offset and primary-key values (stringified
/// from JSON string/number values, `None` otherwise).
///
/// Payload selection:
/// - If `config.payload_col` names a column, that column's extracted bytes
///   become the message payload verbatim and the column is excluded from the
///   JSON data map.
/// - Otherwise the row is serialized as a `DatabaseRecord`; when
///   `config.include_metadata` is false the column map is first nested under
///   a `"data"` key.
///
/// # Errors
/// Propagates failures from payload/column extraction, and returns
/// `Error::InvalidRecord` if JSON serialization fails.
fn process_row(
    &self,
    row: &sqlx::postgres::PgRow,
    config: &RowProcessingConfig,
) -> Result<ProcessedRow, Error> {
    let mut row_pk: Option<String> = None;
    let mut max_offset: Option<String> = None;
    let mut extracted_payload: Option<Vec<u8>> = None;
    let mut data = serde_json::Map::new();

    // Stringify the JSON values we track out-of-band (offset / primary key).
    let as_string = |value: &serde_json::Value| match value {
        serde_json::Value::String(s) => Some(s.clone()),
        serde_json::Value::Number(n) => Some(n.to_string()),
        _ => None,
    };

    for (i, column) in row.columns().iter().enumerate() {
        let column_name = if config.snake_case_columns {
            to_snake_case(column.name())
        } else {
            column.name().to_string()
        };

        // The designated payload column bypasses JSON assembly entirely.
        if !config.payload_col.is_empty() && column.name() == config.payload_col {
            extracted_payload =
                Some(self.extract_payload_column(row, i, config.payload_format)?);
            continue;
        }

        let value = self.extract_column_value(row, i)?;

        // Capture offset/pk strings from a borrow BEFORE moving the value
        // into the map — avoids cloning the whole JSON value per column.
        if column.name() == config.tracking_column {
            if let Some(s) = as_string(&value) {
                max_offset = Some(s);
            }
        }
        if column.name() == config.pk_column {
            if let Some(s) = as_string(&value) {
                row_pk = Some(s);
            }
        }

        data.insert(column_name, value);
    }

    let payload = if let Some(bytes) = extracted_payload {
        bytes
    } else {
        let record = if config.include_metadata {
            DatabaseRecord {
                table_name: config.table.to_string(),
                operation_type: "SELECT".to_string(),
                timestamp: Utc::now(),
                data: serde_json::Value::Object(data),
                old_data: None,
            }
        } else {
            // Without metadata, callers still get a DatabaseRecord shell, but
            // the columns are nested under a "data" key.
            let mut simple_record = serde_json::Map::new();
            simple_record.insert("data".to_string(), serde_json::Value::Object(data));
            DatabaseRecord {
                table_name: config.table.to_string(),
                operation_type: "SELECT".to_string(),
                timestamp: Utc::now(),
                data: serde_json::Value::Object(simple_record),
                old_data: None,
            }
        };
        simd_json::to_vec(&record).map_err(|_| Error::InvalidRecord)?
    };

    let message = ProducedMessage {
        id: Some(Uuid::new_v4().as_u128()),
        headers: None,
        checksum: None,
        timestamp: Some(Utc::now().timestamp_millis() as u64),
        origin_timestamp: Some(Utc::now().timestamp_millis() as u64),
        payload,
    };

    Ok(ProcessedRow {
        message,
        max_offset,
        row_pk,
    })
}

fn extract_payload_column(
&self,
row: &sqlx::postgres::PgRow,
Expand Down
Loading