-
Notifications
You must be signed in to change notification settings - Fork 50
/
persistence.rs
257 lines (221 loc) · 8.75 KB
/
persistence.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
pub mod entity {
    use rustc_hash::FxHashMap;
    use std::sync::RwLock;

    /// Storage abstraction for the mapping from an entity hash to the entity text.
    ///
    /// Every method takes `&self`, so implementations that keep state need
    /// interior mutability.
    pub trait EntityMappingPersistor {
        /// Returns an owned copy of the entity stored under `hash`, or `None` if
        /// nothing is stored for it.
        fn get_entity(&self, hash: u64) -> Option<String>;

        /// Stores `entity` under `hash`.
        fn put_data(&self, hash: u64, entity: String);

        /// Reports whether an entity is stored under `hash`.
        fn contains(&self, hash: u64) -> bool;
    }

    /// [`EntityMappingPersistor`] that keeps all mappings in a hash map guarded
    /// by an [`RwLock`].
    #[derive(Debug, Default)]
    pub struct InMemoryEntityMappingPersistor {
        entity_mappings: RwLock<FxHashMap<u64, String>>,
    }

    impl EntityMappingPersistor for InMemoryEntityMappingPersistor {
        /// Clones the stored entity out of the map while holding the read lock.
        ///
        /// # Panics
        ///
        /// Panics if the lock is poisoned.
        fn get_entity(&self, hash: u64) -> Option<String> {
            self.entity_mappings.read().unwrap().get(&hash).cloned()
        }

        /// Inserts the entity under the write lock; an existing entry for the
        /// same `hash` is replaced.
        ///
        /// # Panics
        ///
        /// Panics if the lock is poisoned.
        fn put_data(&self, hash: u64, entity: String) {
            self.entity_mappings.write().unwrap().insert(hash, entity);
        }

        /// Checks for the key under the read lock.
        ///
        /// # Panics
        ///
        /// Panics if the lock is poisoned.
        fn contains(&self, hash: u64) -> bool {
            self.entity_mappings.read().unwrap().contains_key(&hash)
        }
    }
}
pub mod embedding {
use crate::persistence::embedding::memmap::OwnedMmapArrayViewMut;
use ndarray::{s, Array};
use ndarray_npy::write_zeroed_npy;
use std::fs::File;
use std::io;
use std::io::{BufWriter, Error, ErrorKind, Write};
/// Sink for computed embeddings.
///
/// The implementations in this file expect the calls in this order:
/// `put_metadata` once, then `put_data` once per entity, then `finish`.
pub trait EmbeddingPersistor {
    /// Announces how many entities will be written and the length of each
    /// embedding vector.
    fn put_metadata(&mut self, entity_count: u32, dimension: u16) -> io::Result<()>;

    /// Persists one entity: its name, its occurrence count and its embedding
    /// vector.
    fn put_data(&mut self, entity: &str, occur_count: u32, vector: Vec<f32>) -> io::Result<()>;

    /// Completes the output after the last `put_data` call.
    fn finish(&mut self) -> io::Result<()>;
}
/// Embedding writer that emits a plain-text file through a buffered writer.
pub struct TextFileVectorPersistor {
    /// Buffered handle to the output file.
    buf_writer: BufWriter<File>,
    /// When `true`, each record carries the entity occurrence count right
    /// after the entity name.
    produce_entity_occurrence_count: bool,
}

impl TextFileVectorPersistor {
    /// Creates `filename` (truncating it if it already exists) and wraps it in
    /// a buffered writer.
    ///
    /// # Panics
    ///
    /// Panics with `Unable to create file: <filename>` if the file cannot be
    /// created.
    pub fn new(filename: String, produce_entity_occurrence_count: bool) -> Self {
        let failure_msg = format!("Unable to create file: {}", filename);
        let created = File::create(&filename);
        Self {
            buf_writer: BufWriter::new(created.expect(&failure_msg)),
            produce_entity_occurrence_count,
        }
    }
}
impl EmbeddingPersistor for TextFileVectorPersistor {
    /// Writes the header `<entity_count> <dimension>` with no trailing newline;
    /// every record written by `put_data` starts with its own newline.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by the underlying writer.
    fn put_metadata(&mut self, entity_count: u32, dimension: u16) -> Result<(), io::Error> {
        write!(self.buf_writer, "{} {}", entity_count, dimension)
    }

    /// Writes one record: a newline, the entity name, optionally
    /// ` <occur_count>`, then each vector component preceded by a space.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by the underlying writer.
    fn put_data(
        &mut self,
        entity: &str,
        occur_count: u32,
        vector: Vec<f32>,
    ) -> Result<(), io::Error> {
        self.buf_writer.write_all(b"\n")?;
        self.buf_writer.write_all(entity.as_bytes())?;
        if self.produce_entity_occurrence_count {
            write!(self.buf_writer, " {}", occur_count)?;
        }

        // One scratch buffer is enough for the whole row; each `format_finite`
        // call overwrites it.
        let mut float_buf = ryu::Buffer::new();
        for &component in &vector {
            self.buf_writer.write_all(b" ")?;
            // NOTE(review): `format_finite` skips the NaN/infinity check, so
            // components are assumed to be finite - confirm upstream guarantees.
            self.buf_writer
                .write_all(float_buf.format_finite(component).as_bytes())?;
        }
        Ok(())
    }

    /// Terminates the last record with a newline and flushes the buffered
    /// writer.
    ///
    /// The explicit flush matters: a `BufWriter` that is merely dropped still
    /// tries to write its buffer but discards any error, so a failed final
    /// write would otherwise go unnoticed and leave a truncated file.
    ///
    /// # Errors
    ///
    /// Returns any I/O error from the final write or from the flush.
    fn finish(&mut self) -> Result<(), io::Error> {
        self.buf_writer.write_all(b"\n")?;
        self.buf_writer.flush()
    }
}
mod memmap {
    use memmap::MmapMut;
    use ndarray::ArrayViewMut2;
    use std::fs::OpenOptions;
    use std::io;
    use std::io::{Error, ErrorKind};

    /// A mutable 2-D `f32` array view over a memory-mapped `.npy` file, bundled
    /// with ownership of the mapping it points into.
    ///
    /// The mapping lives in a heap allocation owned through `mmap_ptr`; the view
    /// in `mmap_data` borrows from it with a faked `'static` lifetime. `Drop`
    /// tears both down in the right order.
    pub struct OwnedMmapArrayViewMut {
        /// Pointer obtained from `Box::into_raw`; released in `Drop` (or on the
        /// error path of `new`).
        mmap_ptr: *mut MmapMut,
        /// View into the mapping. `None` only while `Drop` is running.
        mmap_data: Option<ndarray::ArrayViewMut2<'static, f32>>,
    }

    impl OwnedMmapArrayViewMut {
        /// Opens `filename` read-write, memory-maps it and interprets the bytes
        /// as an `.npy` file holding a 2-D `f32` array.
        ///
        /// # Errors
        ///
        /// Returns the I/O error from opening or mapping the file, or an
        /// `ErrorKind::Other` error with the message `Mmap view error` when the
        /// mapped bytes cannot be viewed as such an array. On that last path the
        /// mapping is released before returning.
        pub fn new(filename: &str) -> Result<Self, io::Error> {
            use ndarray_npy::ViewMutNpyExt;

            let file = OpenOptions::new().read(true).write(true).open(filename)?;
            // SAFETY: NOTE(review) - mapping a file is only sound while nothing
            // else truncates or rewrites it underneath us; this relies on the
            // caller having exclusive use of the file. Confirm against callers.
            let mmap = unsafe { MmapMut::map_mut(&file) }?;

            // Move the mapping to a stable heap address that we own through a
            // raw pointer, so that a view borrowing from it can be stored in the
            // same struct.
            let mmap_ptr: *mut MmapMut = Box::into_raw(Box::new(mmap));
            // SAFETY: `mmap_ptr` was just produced by `Box::into_raw`, so it is
            // non-null, aligned and uniquely owned. The allocation is only freed
            // below on the error path or in `Drop`, in both cases when no view
            // into it exists any more.
            let mmap_ref: &'static mut MmapMut = unsafe { &mut *mmap_ptr };

            match ArrayViewMut2::<'static, f32>::view_mut_npy(mmap_ref) {
                Ok(view) => Ok(Self {
                    mmap_ptr,
                    mmap_data: Some(view),
                }),
                Err(_) => {
                    // SAFETY: view creation failed, so nothing borrows the
                    // mapping; reclaiming the box unmaps the file and frees the
                    // allocation instead of leaking both.
                    unsafe { drop(Box::from_raw(mmap_ptr)) };
                    Err(Error::new(ErrorKind::Other, "Mmap view error"))
                }
            }
        }

        /// Returns the array view with its lifetime tied to the borrow of `self`.
        ///
        /// NOTE(review): handing out `&mut ArrayViewMut2<'a, _>` lets a caller
        /// overwrite the stored view with one that borrows shorter-lived data.
        /// The only caller visible in this file merely writes through the view.
        pub fn data_view<'a>(&'a mut self) -> &'a mut ArrayViewMut2<'a, f32> {
            let view = self
                .mmap_data
                .as_mut()
                .expect("Should be always defined. None only used in Drop");
            // SAFETY: shortening the lifetime from 'static to 'a is fine because
            // the mapping is only released in `Drop`, which cannot run while
            // `self` is borrowed for 'a.
            unsafe {
                core::mem::transmute::<
                    &mut ArrayViewMut2<'static, f32>,
                    &mut ArrayViewMut2<'a, f32>,
                >(view)
            }
        }
    }

    impl Drop for OwnedMmapArrayViewMut {
        fn drop(&mut self) {
            // Tear down in reverse order of construction: first the view that
            // points into the mapping...
            self.mmap_data = None;
            // ...then the mapping itself, together with its heap allocation.
            // SAFETY: `mmap_ptr` came from `Box::into_raw` in `new` and is freed
            // nowhere else on the success path. Rebuilding the box (rather than
            // calling `drop_in_place`) also releases the allocation, which would
            // otherwise leak.
            unsafe { drop(Box::from_raw(self.mmap_ptr)) };
        }
    }
}
/// Embedding writer that produces a `<filename>.npy` matrix, a
/// `<filename>.entities` file and, optionally, a `<filename>.occurences` file.
pub struct NpyPersistor {
    /// Entity names collected by `put_data`, in insertion order.
    entities: Vec<String>,
    /// Occurrence counts collected by `put_data`, parallel to `entities`.
    occurences: Vec<u32>,
    /// Path of the `.npy` file; reopened for memory-mapping in `put_metadata`.
    array_file_name: String,
    /// Handle through which `put_metadata` writes the zero-filled array.
    array_file: File,
    /// Memory-mapped view of the array; `None` until `put_metadata` has run.
    array_write_context: Option<OwnedMmapArrayViewMut>,
    /// Writer for the occurrence counts; `None` when they are not requested.
    occurences_buf: Option<BufWriter<File>>,
    /// Writer for the entity names.
    entities_buf: BufWriter<File>,
}

impl NpyPersistor {
    /// Creates the output files derived from `filename`.
    ///
    /// `<filename>.entities` and `<filename>.npy` are always created;
    /// `<filename>.occurences` only when `produce_entity_occurrence_count` is
    /// `true`. Nothing is written to them here.
    ///
    /// # Panics
    ///
    /// Panics with `Unable to create file: <path>` if any of the files cannot
    /// be created.
    pub fn new(filename: String, produce_entity_occurrence_count: bool) -> Self {
        /// Creates `path`, panicking with the shared message on failure.
        fn create_or_panic(path: &str) -> File {
            match File::create(path) {
                Ok(file) => file,
                Err(_) => panic!("Unable to create file: {}", path),
            }
        }

        let entities_path = format!("{}.entities", filename);
        let entities_buf = BufWriter::new(create_or_panic(&entities_path));

        let occurences_path = format!("{}.occurences", filename);
        let occurences_buf = if produce_entity_occurrence_count {
            Some(BufWriter::new(create_or_panic(&occurences_path)))
        } else {
            None
        };

        let array_file_name = format!("{}.npy", filename);
        let array_file = create_or_panic(&array_file_name);

        Self {
            entities: Vec::new(),
            occurences: Vec::new(),
            array_file_name,
            array_file,
            array_write_context: None,
            occurences_buf,
            entities_buf,
        }
    }
}
impl EmbeddingPersistor for NpyPersistor {
fn put_metadata(&mut self, entity_count: u32, dimension: u16) -> Result<(), io::Error> {
write_zeroed_npy::<f32, _>(
&self.array_file,
[entity_count as usize, dimension as usize],
)
.map_err(|_| Error::new(ErrorKind::Other, "Write zeroed npy error"))?;
self.array_write_context = Some(OwnedMmapArrayViewMut::new(&self.array_file_name)?);
Ok(())
}
fn put_data(
&mut self,
entity: &str,
occur_count: u32,
vector: Vec<f32>,
) -> Result<(), io::Error> {
let array = &mut self
.array_write_context
.as_mut()
.expect("Should be defined. Was put_metadata not called?")
.data_view();
array
.slice_mut(s![self.entities.len(), ..])
.assign(&Array::from(vector));
self.entities.push(entity.to_owned());
self.occurences.push(occur_count);
Ok(())
}
fn finish(&mut self) -> Result<(), io::Error> {
use ndarray_npy::WriteNpyExt;
serde_json::to_writer_pretty(&mut self.entities_buf, &self.entities)?;
if let Some(occurences_buf) = self.occurences_buf.as_mut() {
let occur = ndarray::ArrayView1::from(&self.occurences);
occur.write_npy(occurences_buf).map_err(|e| {
Error::new(
ErrorKind::Other,
format!("Could not save occurences: {}", e),
)
})?;
}
Ok(())
}
}
}