Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/data_structures/floyds_algorithm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,5 +91,12 @@ mod tests {

assert!(has_cycle(&linked_list));
assert_eq!(detect_cycle(&linked_list), Some(3));

// Break the cycle before the list is dropped
unsafe {
if let Some(mut tail) = linked_list.tail {
tail.as_mut().next = None;
}
}
}
}
114 changes: 100 additions & 14 deletions src/general/huffman_encoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,50 @@ pub struct HuffmanDictionary<T> {
}

impl<T: Clone + Copy + Ord> HuffmanDictionary<T> {
/// The list of alphabet symbols and their respective frequency should
/// be given as input
pub fn new(alphabet: &[(T, u64)]) -> Self {
/// Creates a new Huffman dictionary from alphabet symbols and their frequencies.
///
/// Returns `None` if the alphabet is empty.
///
/// # Arguments
/// * `alphabet` - A slice of tuples containing symbols and their frequencies
///
/// # Example
/// ```
/// # use the_algorithms_rust::general::HuffmanDictionary;
/// let freq = vec![('a', 5), ('b', 2), ('c', 1)];
/// let dict = HuffmanDictionary::new(&freq).unwrap();
/// ```
pub fn new(alphabet: &[(T, u64)]) -> Option<Self> {
if alphabet.is_empty() {
return None;
}

let mut alph: BTreeMap<T, HuffmanValue> = BTreeMap::new();

// Special case: single symbol
if alphabet.len() == 1 {
let (symbol, _freq) = alphabet[0];
alph.insert(
symbol,
HuffmanValue {
value: 0,
bits: 1, // Must use at least 1 bit per symbol
},
);

let root = HuffmanNode {
left: None,
right: None,
symbol: Some(symbol),
frequency: alphabet[0].1,
};

return Some(HuffmanDictionary {
alphabet: alph,
root,
});
}

let mut queue: BinaryHeap<HuffmanNode<T>> = BinaryHeap::new();
for (symbol, freq) in alphabet.iter() {
queue.push(HuffmanNode {
Expand All @@ -101,11 +141,14 @@ impl<T: Clone + Copy + Ord> HuffmanDictionary<T> {
frequency: sm_freq,
});
}
let root = queue.pop().unwrap();
HuffmanNode::get_alphabet(0, 0, &root, &mut alph);
HuffmanDictionary {
alphabet: alph,
root,
if let Some(root) = queue.pop() {
HuffmanNode::get_alphabet(0, 0, &root, &mut alph);
Some(HuffmanDictionary {
alphabet: alph,
root,
})
} else {
None
}
}
pub fn encode(&self, data: &[T]) -> HuffmanEncoding {
Expand Down Expand Up @@ -143,27 +186,48 @@ impl HuffmanEncoding {
}
self.num_bits += data.bits as u64;
}

/// Reports whether the bit at index `pos` of the encoded stream is set.
///
/// `pos >> 6` selects the element of `self.data` and `pos & 63` selects the
/// bit inside that element, counted from the least significant bit.
#[inline]
fn get_bit(&self, pos: u64) -> bool {
    let word = self.data[(pos >> 6) as usize];
    let offset = pos & 63;
    ((word >> offset) & 1) != 0
}

/// In case the encoding is invalid, `None` is returned
pub fn decode<T: Clone + Copy + Ord>(&self, dict: &HuffmanDictionary<T>) -> Option<Vec<T>> {
// Handle empty encoding
if self.num_bits == 0 {
return Some(vec![]);
}

// Special case: single symbol in dictionary
if dict.alphabet.len() == 1 {
// Every bit stands for the dictionary's only symbol, so emit it once per bit.
let symbol = dict.alphabet.keys().next()?;
let result = vec![*symbol; self.num_bits as usize];
return Some(result);
}

// Normal case: multiple symbols
let mut state = &dict.root;
let mut result: Vec<T> = vec![];

for i in 0..self.num_bits {
if state.symbol.is_some() {
result.push(state.symbol.unwrap());
if let Some(symbol) = state.symbol {
result.push(symbol);
state = &dict.root;
}
state = if self.get_bit(i) {
state.right.as_ref().unwrap()
state.right.as_ref()?
} else {
state.left.as_ref().unwrap()
state.left.as_ref()?
}
}

// Check if we ended on a symbol
if self.num_bits > 0 {
result.push(state.symbol?);
}

Some(result)
}
}
Expand All @@ -181,12 +245,34 @@ mod tests {
.for_each(|(b, &cnt)| result.push((b as u8, cnt)));
result
}

/// Building a dictionary from the frequencies of an empty text must yield `None`.
#[test]
fn empty_text() {
    let freq = get_frequency("".as_bytes());
    assert!(HuffmanDictionary::new(&freq).is_none());
}

/// A text made of one repeated symbol must encode to one bit per symbol
/// and decode back to the original bytes.
#[test]
fn one_symbol_text() {
    let input = "aaaa".as_bytes();
    let frequencies = get_frequency(input);
    let dict = HuffmanDictionary::new(&frequencies).unwrap();

    let encoded = dict.encode(input);
    assert_eq!(encoded.num_bits, 4);

    let round_trip = encoded.decode(&dict).unwrap();
    assert_eq!(round_trip, input);
}

#[test]
fn small_text() {
let text = "Hello world";
let bytes = text.as_bytes();
let freq = get_frequency(bytes);
let dict = HuffmanDictionary::new(&freq);
let dict = HuffmanDictionary::new(&freq).unwrap();
let encoded = dict.encode(bytes);
assert_eq!(encoded.num_bits, 32);
let decoded = encoded.decode(&dict).unwrap();
Expand All @@ -208,7 +294,7 @@ mod tests {
);
let bytes = text.as_bytes();
let freq = get_frequency(bytes);
let dict = HuffmanDictionary::new(&freq);
let dict = HuffmanDictionary::new(&freq).unwrap();
let encoded = dict.encode(bytes);
assert_eq!(encoded.num_bits, 2372);
let decoded = encoded.decode(&dict).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion src/machine_learning/loss_function/kl_divergence_loss.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ pub fn kld_loss(actual: &[f64], predicted: &[f64]) -> f64 {
let loss: f64 = actual
.iter()
.zip(predicted.iter())
.map(|(&a, &p)| ((a + eps) * ((a + eps) / (p + eps)).ln()))
.map(|(&a, &p)| (a + eps) * ((a + eps) / (p + eps)).ln())
.sum();
loss
}
Expand Down