Skip to content

Commit

Permalink
Optimizations
Browse files Browse the repository at this point in the history
Improves performance by ~400%
  • Loading branch information
Unrud committed Aug 14, 2019
1 parent ffa3127 commit 9b32e1e
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 65 deletions.
106 changes: 58 additions & 48 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,23 @@ pub const FRAME_LEN: usize = START_SYMBOLS_LEN + PAYLOAD_LEN + ECC_LEN;
const MEASUREMENTS_PER_SYMBOL: usize = 10;
pub const SYMBOL_MNEMONICS: &str = "0123456789abcdefghijklmnopqrstuv";

#[derive(Clone, Copy, Default)]
struct Symbol {
value: u8,
magnitude: f32,
noise: f32,
/// Wraps an index that is already known to lie in `0..2 * len` back into `0..len`.
///
/// This replaces `i % len` for ring-buffer stepping with one comparison and at
/// most one subtraction. The precondition `i < 2 * len` is checked only when
/// debug assertions are enabled.
///
/// Both arguments are evaluated exactly once, so passing an expression that is
/// costly or has side effects is safe.
macro_rules! mod_short {
    ($i:expr, $len:expr) => {{
        // Bind each argument once; the locals are hygienic and cannot clash
        // with names used in the caller's expressions.
        let (i, len) = ($i, $len);
        debug_assert!(i < 2 * len);
        if i < len {
            i
        } else {
            i - len
        }
    }};
}

/// One candidate frame that is being filled with decoded symbols.
///
/// NOTE(review): two `data` declarations are visible here. They look like the
/// removed/added pair of the diff this text was taken from; a struct can keep
/// only one of them.
#[derive(Clone, Copy, Default)]
struct Frame {
// True while this slot is collecting symbols; inactive slots are skipped when a symbol is appended.
active: bool,
data: [Symbol; FRAME_LEN],
// Symbols collected so far, written at index `data_pos`.
data: [u8; FRAME_LEN],
// Index in `data` where the next symbol is written.
data_pos: usize,
// Running sum of the per-symbol magnitude/noise ratio added alongside each symbol.
signal_quality: f32,
}

pub struct Transceiver {
Expand All @@ -53,6 +58,7 @@ pub struct Transceiver {
}

impl Transceiver {
/// Maps a symbol value to its signalling frequency.
///
/// The result is `BASE_FREQ` multiplied by `SEMITONE` raised to the power of
/// the symbol value.
#[inline]
fn calc_freq(symbol: u8) -> f32 {
    let steps = i32::from(symbol);
    BASE_FREQ * SEMITONE.powi(steps)
}
Expand Down Expand Up @@ -96,100 +102,104 @@ impl Transceiver {
data[START_SYMBOLS_LEN..].copy_from_slice(payload);
let encoded_data = self.rs_encoder.encode(&data);
let mut frequencies: [f32; FRAME_LEN] = Default::default();
debug_assert_eq!(encoded_data.len(), frequencies.len());
for (i, symbol) in encoded_data.iter().enumerate() {
frequencies[i] = Self::calc_freq(*symbol);
}
(self.on_transmit)(frequencies);
}

/// Feeds one input sample into the receiver.
///
/// Every call stores `sample` in the ring buffer and counts down
/// `remaining_samples`. Only when that countdown has elapsed does the rest of
/// the function run: the buffered samples are weighted with `window_weights`,
/// the strongest symbol is selected, the symbol is appended to the active
/// frames, a completed frame is checked with `rs_decoder`, and `on_received`
/// is invoked once the accepted frame has aged `MEASUREMENTS_PER_SYMBOL`
/// measurements.
///
/// NOTE(review): this text comes from a rendered diff, so removed and added
/// lines appear next to each other (for example the two assignments to
/// `sample_buffer_pos`) and the braces do not balance as written. Confirm
/// against the real file before relying on statement order.
pub fn push_sample(&mut self, sample: f32) {
// Push new sample to ring buffer
self.sample_buffer[self.sample_buffer_pos] = sample;
self.sample_buffer_pos = (self.sample_buffer_pos + 1) % self.sample_buffer.len();
self.sample_buffer_pos = mod_short!(self.sample_buffer_pos + 1, self.sample_buffer.len());
self.remaining_samples -= 1.;
// Nothing more to do until the countdown has elapsed; the sample is only buffered.
if self.remaining_samples > 0. {
return;
}
// Re-arm the countdown with sample_rate * BEEP_LEN / MEASUREMENTS_PER_SYMBOL samples;
// `+=` carries over the (non-positive) remainder of the previous interval.
self.remaining_samples += (self.sample_rate as f32) * BEEP_LEN / (MEASUREMENTS_PER_SYMBOL as f32);
let mut sample_window = vec![0f32; self.sample_buffer.len()];
for i in self.sample_buffer_pos..self.sample_buffer.len() {
let j = i - self.sample_buffer_pos;
sample_window[j] = self.sample_buffer[i] * self.window_weights[j];

// Decode symbol
// NOTE(review): `std::mem::uninitialized` is deprecated in favour of `MaybeUninit`.
// The array is only fully written by the loop below if `goertzel_filters` yields
// SYMBOL_COUNT items — confirm.
let mut goertzel_partials: [goertzel::Partial; SYMBOL_COUNT] = unsafe {std::mem::uninitialized()};
for (i, goertzel_filter) in self.goertzel_filters.iter().enumerate() {
goertzel_partials[i] = goertzel_filter.start();
}
for i in 0..self.sample_buffer_pos {
let j = i + (self.sample_buffer.len() - self.sample_buffer_pos);
sample_window[j] = self.sample_buffer[i] * self.window_weights[j];
// Walk the ring buffer starting at `sample_buffer_pos`, weight each sample,
// and feed the weighted sample to every filter.
for i in 0..self.sample_buffer.len() {
let j = mod_short!(self.sample_buffer_pos + i, self.sample_buffer.len());
let window_sample = self.sample_buffer[j] * self.window_weights[i];
for goertzel_partial in goertzel_partials.iter_mut() {
goertzel_partial.push(window_sample);
}
}
let mut best_symbol: Symbol = Default::default();
for (i, goertzel_filter) in self.goertzel_filters.iter().enumerate() {
let magnitude = goertzel_filter.mag(&sample_window);
if magnitude > best_symbol.magnitude {
best_symbol = Symbol {
value: i as u8,
magnitude: magnitude,
noise: best_symbol.magnitude
};
} else if magnitude > best_symbol.noise {
best_symbol.noise = magnitude;
// Find symbol with highest magnitude
// `next_symbol_noise` tracks the second-highest value seen so far.
let mut next_symbol = 0u8;
let mut next_symbol_magnitude = 0.;
let mut next_symbol_noise = 0.;
// NOTE(review): the `&goertzel_partial` pattern destructures a reference, so this
// relies on `into_iter()` on the array yielding references; `.iter()` would state
// that directly.
for (i, &goertzel_partial) in goertzel_partials.into_iter().enumerate() {
let magnitude = goertzel_partial.finish_fast();
if magnitude > next_symbol_magnitude {
next_symbol = i as u8;
next_symbol_noise = next_symbol_magnitude;
next_symbol_magnitude = magnitude;
} else if magnitude > next_symbol_noise {
next_symbol_noise = magnitude;
}
}

// Add new symbol to partial frames
// The slot at `frames_pos` is taken out as the completed frame and restarted as a fresh active one.
let completed_frame = self.frames[self.frames_pos];
self.frames[self.frames_pos] = Default::default();
self.frames[self.frames_pos].active = true;
// NOTE(review): `next_symbol_noise` can still be 0. here (e.g. when at most one filter
// reported a positive value), which makes this ratio inf or NaN — confirm that is
// acceptable for the quality comparison further down.
let next_symbol_snr = next_symbol_magnitude / next_symbol_noise;
for i in 0..FRAME_LEN {
let frame = &mut self.frames[(self.frames_pos + i * MEASUREMENTS_PER_SYMBOL) % self.frames.len()];
let frame_pos = mod_short!(self.frames_pos + i * MEASUREMENTS_PER_SYMBOL, self.frames.len());
let frame = &mut self.frames[frame_pos];
if frame.active {
frame.data[frame.data_pos] = best_symbol;
frame.data[frame.data_pos] = next_symbol;
frame.data_pos += 1;
frame.signal_quality += next_symbol_snr;
}
}
self.frames_pos = mod_short!(self.frames_pos + 1, self.frames.len());

// Replace active frame, if completed frame is higher quality
if completed_frame.active {
let mut raw_data: [u8; FRAME_LEN] = Default::default();
for (i, symbol) in completed_frame.data.iter().enumerate() {
raw_data[i] = symbol.value;
}
if let Ok(corrected_data) = self.rs_decoder.correct(&mut raw_data, None) {
let corrected_data = *corrected_data;
debug_assert_eq!(completed_frame.data_pos, completed_frame.data.len());
if let Ok(corrected_data) = self.rs_decoder.correct(&completed_frame.data, None) {
let corrected_data = corrected_data.data();
let start_symbols_ok = corrected_data[..START_SYMBOLS_LEN] == START_SYMBOLS;
if start_symbols_ok {
let mut payload: [u8; PAYLOAD_LEN] = Default::default();
for (i, &c) in corrected_data.iter().skip(START_SYMBOLS_LEN).take(PAYLOAD_LEN).enumerate() {
payload[i] = c;
}
// Count how many received symbols already matched the corrected data.
let mut correct_symbols = 0;
for (i, &c) in corrected_data.iter().enumerate() {
if raw_data[i] == c {
if completed_frame.data[i] == c {
correct_symbols += 1;
}
}
let mut signal_quality = 0.;
for symbol in completed_frame.data.iter() {
signal_quality += symbol.magnitude / symbol.noise;
}
let frame_quality = (correct_symbols, signal_quality);
// Quality is compared as a tuple: matching-symbol count first, accumulated signal quality second.
let frame_quality = (correct_symbols, completed_frame.signal_quality);
if !self.active_frame || self.active_frame_quality < frame_quality {
self.active_frame_payload = payload;
self.active_frame_payload.copy_from_slice(&corrected_data[START_SYMBOLS_LEN..][..PAYLOAD_LEN]);
self.active_frame_quality = frame_quality;
if !self.active_frame {
self.active_frame = true;
self.active_frame_age = 0;
// Skip following frames
// Skip all frames after complete symbol
for i in MEASUREMENTS_PER_SYMBOL..self.frames.len() {
let frame_pos = (self.frames_pos + i) % self.frames.len();
self.frames[frame_pos].active = false;
self.frames[mod_short!(self.frames_pos + i - 1, self.frames.len())].active = false;
}
}
}
}
}
}

// Delay receiving of active frame until symbol is complete
if self.active_frame {
self.active_frame_age += 1;
if self.active_frame_age == MEASUREMENTS_PER_SYMBOL {
self.active_frame = false;
(self.on_received)(self.active_frame_payload);
}
}
self.frames_pos = (self.frames_pos + 1) % self.frames.len();
}
}

Expand Down
50 changes: 33 additions & 17 deletions thirdparty/goertzel/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub struct Parameters {
term_coefficient: f32,
}

#[derive(Clone, Copy)]
pub struct Partial {
params: Parameters,
count: usize,
Expand All @@ -34,43 +35,50 @@ impl Parameters {
}
}

/// Begins a new accumulation that uses these parameters.
///
/// The returned `Partial` has seen no samples yet: its sample counter and both
/// state values start at zero.
#[inline]
pub fn start(self) -> Partial {
    Partial {
        prevprev: 0.0,
        prev: 0.0,
        count: 0,
        params: self,
    }
}

/// Convenience wrapper: starts an accumulation, adds all of `samples` in one
/// call and returns the result of `finish_mag`.
// NOTE(review): the tests in this view no longer call `mag`/`add` and use a `push` loop
// instead, so this function may be one of the diff's removed lines — confirm.
pub fn mag(self, samples: &[f32]) -> f32 {
self.start().add(samples).finish_mag()
}
}

impl Partial {
pub fn add(mut self, samples: &[f32]) -> Self {
for &sample in samples {
let this = self.params.term_coefficient * self.prev - self.prevprev + sample;
self.prevprev = self.prev;
self.prev = this;
}
self.count += samples.len();
self
/// Feeds a single sample into the running filter state.
///
/// The sample counter exists only for the window-size checks made by the
/// `finish*` methods, so it is advanced only when debug assertions are
/// compiled in.
#[inline]
pub fn push(&mut self, sample: f32) {
    let older = self.prev;
    // Uses the not-yet-updated `prevprev`, then shifts the previous value down.
    self.prev = self.params.term_coefficient * older - self.prevprev + sample;
    self.prevprev = older;
    if cfg!(debug_assertions) {
        self.count += 1;
    }
}

/// Finishes the accumulation and returns the result as a `(real, imag)` pair.
///
/// The number of samples fed in is expected to equal `params.window_size`.
// NOTE(review): the same count check appears twice below, once as `assert_eq!` and once as
// `debug_assert_eq!`. They look like the removed/added pair of the diff — confirm which one
// the file keeps. Since `push` only counts samples when debug assertions are enabled, the
// unconditional `assert_eq!` would not hold in a release build.
#[inline]
pub fn finish(self) -> (f32, f32) {
assert_eq!(self.count, self.params.window_size);
debug_assert_eq!(self.count, self.params.window_size);
let real = self.prev - self.prevprev * self.params.cosine;
let imag = self.prevprev * self.params.sine;
(real, imag)
}

/// Finishes the accumulation and returns the magnitude of the `(real, imag)`
/// result produced by `finish`.
#[inline]
pub fn finish_mag(self) -> f32 {
    let (re, im) = self.finish();
    let power = re * re + im * im;
    power.sqrt()
}

/// Finishes the accumulation with a cheaper score than `finish_mag`: no square
/// root and no sine/cosine terms, only the two state values and
/// `term_coefficient`.
///
/// The number of samples fed in is expected to equal `params.window_size`
/// (checked with debug assertions only).
// NOTE(review): presumably this equals the squared magnitude, which holds if
// `term_coefficient` is `2 * cosine` — confirm in `Parameters::new`.
#[inline]
pub fn finish_fast(self) -> f32 {
    debug_assert_eq!(self.count, self.params.window_size);
    let (s1, s2) = (self.prev, self.prevprev);
    let squares = s1 * s1 + s2 * s2;
    let cross = s1 * s2 * self.params.term_coefficient;
    squares - cross
}
}

// An all-zero input must produce a magnitude of exactly zero.
// NOTE(review): the two `add`-based assertions and the `push` loop below look like the
// before/after versions of the same check from the diff — confirm which one the file keeps.
#[test]
fn zero_data() {
let p = Parameters::new(1800., 8000, 256);
assert!(p.start().add(&[0.; 256]).finish_mag() == 0.);
assert!(p.start().add(&[0.; 128]).add(&[0.; 128]).finish_mag() == 0.);
let mut pa = p.start();
// Feed exactly `window_size` (256) samples, one at a time.
for &sample in [0.; 256].iter() {
pa.push(sample);
}
assert_eq!(pa.finish_mag(), 0.);
}

#[test]
Expand All @@ -85,10 +93,18 @@ fn sine() {
}

let p = Parameters::new(freq, 8000, 8000);
let mag = p.start().add(&buf[..]).finish_mag();
let mut pa = p.start();
for &sample in buf.iter() {
pa.push(sample);
}
let mag = pa.finish_mag();
for testfreq in (0 .. 30).map(|x| (x * 100) as f32) {
let p = Parameters::new(testfreq, 8000, 8000);
let testmag = p.mag(&buf[..]);
let mut pa = p.start();
for &sample in buf.iter() {
pa.push(sample);
}
let testmag = pa.finish_mag();
println!("{:4}: {:12.3}", testfreq, testmag);
if (freq-testfreq).abs() > 100. {
println!("{} > 10*{}", mag, testmag);
Expand Down

0 comments on commit 9b32e1e

Please sign in to comment.