PyO3 Python bindings, support for Parquet output and S3 input/output #70

Open
wants to merge 8 commits into base: master
890 changes: 653 additions & 237 deletions Cargo.lock

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions Cargo.toml
@@ -9,6 +9,10 @@ description = """
Cleora is a general-purpose model for efficient, scalable learning of stable and inductive entity embeddings for heterogeneous relational data.
"""

[lib]
name = "cleora"
crate-type = ["cdylib"]

[build]
rustflags = ["-C", "target-cpu=native"]

@@ -28,6 +32,11 @@ ndarray = "0.15.4"
ndarray-npy = "0.8.1"
serde_json = "1.0.81"
uuid = { version = "1.1.2", features = ["v4"] }
pyo3 = { version = "0.16.5", features = ["extension-module"] }
arrow2 = { version="0.12.0", default-features = false, features = ["io_parquet", "io_parquet_compression"] }
rusoto_s3 = "0.42.0"
rusoto_core = "0.42.0"
chrono = "0.4.22"

[dev-dependencies]
criterion = "0.3.3"
19 changes: 19 additions & 0 deletions pyproject.toml
@@ -0,0 +1,19 @@
[build-system]
requires = ["maturin>=0.13,<0.14"]
build-backend = "maturin"

[project]
name = "cleora"
version = "1.2.3"
requires-python = ">=3.7"
description = "Cleora is a general-purpose model for efficient, scalable learning of stable and inductive entity embeddings for heterogeneous relational data."
readme = "README.md"
license = {file = "LICENSE"}
classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]

[project.urls]
repository = "https://github.com/Synerise/cleora"
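
The [lib] section above builds the crate as a cdylib, which maturin (declared in pyproject.toml) packages into a Python wheel. The binding code itself (src/lib.rs) is not shown in this diff excerpt; the following is only a minimal sketch of how a pyo3 0.16 extension module is typically declared. The function run_cleora and its parameters are illustrative names, not the PR's actual API:

use pyo3::prelude::*;

// Hypothetical entry point; the real bindings live in src/lib.rs.
#[pyfunction]
fn run_cleora(input: String, dimension: usize) -> PyResult<String> {
    // ... call into the existing embedding pipeline here ...
    Ok(format!("embedded {} at dimension {}", input, dimension))
}

// The function name must match name = "cleora" under [lib] so that
// `import cleora` resolves to this module.
#[pymodule]
fn cleora(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(run_cleora, m)?)?;
    Ok(())
}

With maturin installed, running maturin develop compiles the extension and installs it into the active virtualenv, after which import cleora works from Python.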
5 changes: 5 additions & 0 deletions src/configuration.rs
@@ -7,6 +7,7 @@ pub enum FileType {
#[derive(Debug)]
pub enum OutputFormat {
TextFile,
Parquet,
Numpy,
}

@@ -53,6 +54,9 @@ pub struct Configuration

/// Columns configuration
pub columns: Vec<Column>,

/// Number of embedding rows written per chunk
pub chunk_size: usize,
}

/// Column configuration
@@ -91,6 +95,7 @@ impl Configuration {
output_format: OutputFormat::TextFile,
relation_name: String::from("emb"),
columns,
chunk_size: 1000,
}
}

69 changes: 60 additions & 9 deletions src/embedding.rs
@@ -331,7 +331,12 @@ pub fn calculate_embeddings<T1, T2>(
let mult = MatrixMultiplicator::new(config.clone(), sparse_matrix_reader);
let init: TwoDimVectorMatrix = mult.initialize();
let res = mult.propagate(config.max_number_of_iteration, init);
mult.persist(
res,
entity_mapping_persistor,
embedding_persistor,
config.chunk_size,
);

info!("Finalizing embeddings calculations!")
}
@@ -417,6 +422,7 @@
res: M,
entity_mapping_persistor: Arc<T1>,
embedding_persistor: &mut dyn EmbeddingPersistor,
chunk_size: usize,
) where
T1: EntityMappingPersistor,
{
@@ -434,22 +440,62 @@

// entities which can't be written to the file (error occurs)
let mut broken_entities = HashSet::new();
// chunk holds (entity names, occurrence counts, one Vec per embedding dimension)
let mut chunk: (Vec<String>, Vec<u32>, Vec<Vec<f32>>) = (
Vec::new(),
Vec::new(),
(0..self.dimension).map(|_| Vec::new()).collect(),
);

let mut entity_names: Vec<String> = Vec::new();

for (i, hash) in self.sparse_matrix_reader.iter_hashes().enumerate() {
let entity_name_opt = entity_mapping_persistor.get_entity(hash.value);
if let Some(entity_name) = entity_name_opt {
chunk.0.push(entity_name.clone());
chunk.1.push(hash.occurrence);
entity_names.push(entity_name);

for j in 0..self.dimension {
let value = res.get_value(i, j);
chunk.2[j].push(value);
}

// flush a full chunk every chunk_size entities (i is 0-based, hence the +1)
if (i + 1) % chunk_size == 0 {
embedding_persistor
.put_data_chunk(chunk)
.unwrap_or_else(|_| {
entity_names.into_iter().for_each(|e| {
broken_entities.insert(e);
});
});

entity_names = Vec::new();
chunk = (
Vec::new(),
Vec::new(),
(0..self.dimension)
.into_iter()
.map(|_x| Vec::new())
.collect(),
);
}
};
}

embedding_persistor
.put_data_chunk(chunk)
.unwrap_or_else(|_| {
entity_names.into_iter().for_each(|e| {
broken_entities.insert(e);
});
});

if !broken_entities.is_empty() {
log_broken_entities(broken_entities);
}
@@ -487,7 +533,12 @@ pub fn calculate_embeddings_mmap<T1, T2>(
let mult = MatrixMultiplicator::new(config.clone(), sparse_matrix_reader);
let init: MMapMatrix = mult.initialize();
let res = mult.propagate(config.max_number_of_iteration, init);
mult.persist(
res,
entity_mapping_persistor,
embedding_persistor,
config.chunk_size,
);

info!("Finalizing embeddings calculations!")
}
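
A note on the layout persist accumulates: the chunk tuple is column-major, with element 0 holding entity names, element 1 occurrence counts, and element 2 one Vec<f32> per embedding dimension. This shape maps directly onto columnar writers such as the Parquet persistor this PR adds. A stand-alone sketch of how such a chunk transposes back into rows (the tuple shape is taken from this diff; the TSV writer below is illustrative, not the PR's actual persistor):

use std::io::{self, Write};

// Chunk layout as built in persist(): (names, occurrence counts, columns).
type EmbeddingChunk = (Vec<String>, Vec<u32>, Vec<Vec<f32>>);

// Writes one chunk as TSV rows, transposing the column-major storage.
fn write_chunk_tsv<W: Write>(out: &mut W, chunk: &EmbeddingChunk) -> io::Result<()> {
    let (names, occurrences, columns) = chunk;
    for row in 0..names.len() {
        write!(out, "{}\t{}", names[row], occurrences[row])?;
        for column in columns {
            // fixed dimension (column), varying entity (row)
            write!(out, "\t{}", column[row])?;
        }
        writeln!(out)?;
    }
    Ok(())
}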
217 changes: 217 additions & 0 deletions src/io.rs
@@ -0,0 +1,217 @@
use rusoto_core::region::Region;
use rusoto_core::{ByteStream, RusotoError};
use rusoto_s3::{
AbortMultipartUploadRequest, CompleteMultipartUploadRequest, CompletedMultipartUpload,
CompletedPart, CreateMultipartUploadRequest, GetObjectError, GetObjectRequest,
UploadPartRequest,
};
use rusoto_s3::{S3Client, S3};
use std::env;
use std::io::{Error, Read, Write};
use std::time::Duration;

pub struct S3File {
bucket_name: String,
object_key: String,
s3_client: S3Client,
upload_id: String,
completed_parts: Vec<CompletedPart>,
part_number: i64,
buff: Vec<u8>,
completed: bool,
part_size: usize,
}

impl Drop for S3File {
fn drop(&mut self) {
self.complete();
}
}

impl S3File {
pub fn create(filename: String) -> S3File {
let (s3_client, bucket_name, object_key) = S3File::create_client(filename);

let part_size = 10 * 1024 * 1024;
let timeout = Duration::from_secs(10);

let completed_parts: Vec<CompletedPart> = Vec::new();
let upload_id = s3_client
.create_multipart_upload(CreateMultipartUploadRequest {
bucket: bucket_name.clone(),
key: object_key.clone(),
..Default::default()
})
.with_timeout(timeout)
.sync()
.unwrap()
.upload_id
.expect("no upload ID");

let buff = Vec::new();

S3File {
bucket_name,
object_key,
s3_client,
upload_id,
completed_parts,
// S3 part numbers must start at 1 (0 is rejected by the API)
part_number: 1,
buff,
completed: false,
part_size,
}
}

pub fn open(
filename: String,
) -> Result<impl std::io::Read + Send, RusotoError<GetObjectError>> {
let (s3_client, bucket_name, object_key) = S3File::create_client(filename);

let data_timeout = Duration::from_secs(300);

s3_client
.get_object(GetObjectRequest {
bucket: bucket_name.clone(),
key: object_key.clone(),
..Default::default()
})
.with_timeout(data_timeout)
.sync()
.map(|output| output.body.unwrap().into_blocking_read())
}

fn create_client(filename: String) -> (S3Client, String, String) {
let region = match env::var("S3_ENDPOINT_URL") {
Ok(endpoint) => Region::Custom {
name: "custom".to_string(),
endpoint,
},
Err(_) => Region::default(),
};

let path: Vec<&str> = filename
.strip_prefix("s3://")
.expect("S3 path must start with s3://")
.split('/')
.collect();
let bucket_name: String = path[0].to_string();
let object_key: String = path[1..].join("/");

let s3_client = S3Client::new(region);

(s3_client, bucket_name, object_key)
}

fn write_buff(&mut self) {
if self.buff.is_empty() {
return;
}

// take the buffered bytes, leaving self.buff empty for the next part
let buff = std::mem::take(&mut self.buff);
let data_timeout = Duration::from_secs(300);

let result = self
.s3_client
.upload_part(UploadPartRequest {
body: Some(ByteStream::from(buff)),
bucket: self.bucket_name.clone(),
key: self.object_key.clone(),
part_number: self.part_number,
upload_id: self.upload_id.clone(),
..Default::default()
})
.with_timeout(data_timeout)
.sync()
.unwrap();

self.completed_parts.push(CompletedPart {
e_tag: result.e_tag,
part_number: Some(self.part_number),
});

self.part_number += 1;
}

pub fn complete(&mut self) {
if !self.completed {
self.write_buff();
let timeout = Duration::from_secs(10);
self.s3_client
.complete_multipart_upload(CompleteMultipartUploadRequest {
bucket: self.bucket_name.clone(),
key: self.object_key.clone(),
upload_id: self.upload_id.clone(),
multipart_upload: Some(CompletedMultipartUpload {
parts: Some(self.completed_parts.clone()),
}),
..Default::default()
})
.with_timeout(timeout)
.sync()
.unwrap();
self.completed = true;
}
}

pub fn abort_upload(&mut self) {
let timeout = Duration::from_secs(10);
self.s3_client
.abort_multipart_upload(AbortMultipartUploadRequest {
bucket: self.bucket_name.clone(),
key: self.object_key.clone(),
upload_id: self.upload_id.clone(),
..Default::default()
})
.with_timeout(timeout)
.sync()
.unwrap();
self.completed = true;
}
}

impl Write for S3File {
fn write(&mut self, buf: &[u8]) -> Result<usize, Error> {
self.buff.extend_from_slice(buf);

if self.buff.len() > self.part_size {
self.write_buff();
}

Ok(buf.len())
}

fn flush(&mut self) -> Result<(), Error> {
// intentionally a no-op: S3 rejects parts smaller than 5 MiB (except the
// last), so data is only uploaded once part_size is exceeded or on complete()
Ok(())
}
}
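
Taken together, create, write and complete implement a buffered multipart upload: writes accumulate in memory, each time the buffer exceeds part_size (10 MiB) one part is uploaded, and complete flushes the remainder and commits the upload (Drop calls complete as a safety net). A hypothetical usage sketch, with bucket and key names being illustrative:

use std::io::Write;

fn upload_example() {
    let mut file = S3File::create("s3://my-bucket/embeddings.txt".to_string());
    // Each write only buffers; a part is uploaded once the buffer
    // exceeds part_size.
    file.write_all(b"entity_a 2 0.1 0.2 0.3\n").unwrap();
    file.write_all(b"entity_b 1 0.4 0.5 0.6\n").unwrap();
    // Uploads the final (possibly short) part and completes the upload.
    file.complete();
}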

#[test]
fn open_write_read_test() {
use std::io::{BufRead, BufReader, Read};

// this test requires a locally running MinIO instance
env::set_var("S3_ENDPOINT_URL", "http://minio:9000");
env::set_var("AWS_ACCESS_KEY_ID", "minioadmin");
env::set_var("AWS_SECRET_ACCESS_KEY", "minioadmin");

let mut f = S3File::create("s3://input/hello.txt".to_string());

f.write(b"hello world\n");
f.write(b"hello world");
f.complete();

let mut file1 = S3File::open("s3://input/hello.txt".to_string()).unwrap();
let mut data: Vec<u8> = Vec::new();
file1.read_to_end(&mut data).unwrap();
assert_eq!(data, b"hello world\nhello world");

let file2 = S3File::open("s3://input/hello.txt".to_string()).unwrap();
let mut buff = BufReader::new(file2);
let mut line = String::new();
buff.read_line(&mut line).unwrap();

assert_eq!(line, "hello world\n");
}