Skip to content

Commit

Permalink
Changed ComplexVector and ComplexMatrix implementation to use con…
Browse files Browse the repository at this point in the history
…st generics.

By using const generics, we avoid having to manually assert that instances of vectors and matrices have the correct dimensions to operate on them. The compiler enforces that for us.
  • Loading branch information
jpyamamoto committed Jul 11, 2022
1 parent cd5e64d commit c2c58a4
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 160 deletions.
16 changes: 8 additions & 8 deletions src/exercises/chapter2.rs
Expand Up @@ -5,33 +5,33 @@ use crate::utils::complex_matrix::ComplexMatrix;
pub fn programming_drill_2_1_1() {
println!("Solution to the programming drill 2.1.1.");

let v1 = ComplexVector(vec![Complex::new(6.0, -4.0), Complex::new(7.0, 3.0), Complex::new(4.2, -8.1), Complex::new(0.0, -3.0)]);
let v1 = ComplexVector([Complex::new(6.0, -4.0), Complex::new(7.0, 3.0), Complex::new(4.2, -8.1), Complex::new(0.0, -3.0)]);
print!("-{} = ", v1);
println!("{}", -v1);

let v2 = ComplexVector(vec![Complex::new(6.0, 3.0), Complex::new(0.0, 0.0), Complex::new(5.0, 1.0), Complex::new(4.0, 0.0)]);
let v2 = ComplexVector([Complex::new(6.0, 3.0), Complex::new(0.0, 0.0), Complex::new(5.0, 1.0), Complex::new(4.0, 0.0)]);
print!("{} * {} = ", Complex::new(3.0, 2.0), v2);
println!("{}", v2 * Complex::new(3.0, 2.0));

let v3 = ComplexVector(vec![Complex::new(6.0, -4.0), Complex::new(7.0, 3.0), Complex::new(4.2, -8.1), Complex::new(0.0, -3.0)]);
let v4 = ComplexVector(vec![Complex::new(16.0, 2.5), Complex::new(0.0, -7.0), Complex::new(6.0, 0.0), Complex::new(0.0, -4.0)]);
let v3 = ComplexVector([Complex::new(6.0, -4.0), Complex::new(7.0, 3.0), Complex::new(4.2, -8.1), Complex::new(0.0, -3.0)]);
let v4 = ComplexVector([Complex::new(16.0, 2.5), Complex::new(0.0, -7.0), Complex::new(6.0, 0.0), Complex::new(0.0, -4.0)]);
print!("{} + {} = ", v3, v4);
println!("{}", v3+v4);
}

pub fn programming_drill_2_2_1() {
println!("Solution to the programming drill 2.2.1.");

let v1 = ComplexMatrix::new(vec![Complex::new(6.0, -4.0), Complex::new(7.0, 3.0), Complex::new(4.2, -8.1), Complex::new(0.0, -3.0)], 2, 2);
let v1 = ComplexMatrix::new([[Complex::new(6.0, -4.0), Complex::new(7.0, 3.0)], [Complex::new(4.2, -8.1), Complex::new(0.0, -3.0)]]);
print!("-{} = ", v1);
println!("{}", -v1);

let v2 = ComplexMatrix::new(vec![Complex::new(6.0, 3.0), Complex::new(0.0, 0.0), Complex::new(5.0, 1.0), Complex::new(4.0, 0.0)], 2, 2);
let v2 = ComplexMatrix::new([[Complex::new(6.0, 3.0), Complex::new(0.0, 0.0)], [Complex::new(5.0, 1.0), Complex::new(4.0, 0.0)]]);
print!("{} * {} = ", Complex::new(3.0, 2.0), v2);
println!("{}", v2 * Complex::new(3.0, 2.0));

let v3 = ComplexMatrix::new(vec![Complex::new(6.0, -4.0), Complex::new(7.0, 3.0), Complex::new(4.2, -8.1), Complex::new(0.0, -3.0)], 2, 2);
let v4 = ComplexMatrix::new(vec![Complex::new(16.0, 2.5), Complex::new(0.0, -7.0), Complex::new(6.0, 0.0), Complex::new(0.0, -4.0)], 2, 2);
let v3 = ComplexMatrix::new([[Complex::new(6.0, -4.0), Complex::new(7.0, 3.0)], [Complex::new(4.2, -8.1), Complex::new(0.0, -3.0)]]);
let v4 = ComplexMatrix::new([[Complex::new(16.0, 2.5), Complex::new(0.0, -7.0)], [Complex::new(6.0, 0.0), Complex::new(0.0, -4.0)]]);
print!("{} + {} = ", v3, v4);
println!("{}", v3+v4);
}
191 changes: 71 additions & 120 deletions src/utils/complex_matrix.rs
Expand Up @@ -5,62 +5,49 @@ use crate::utils::complex_number::Complex;
use crate::utils::complex_vector::ComplexVector;

#[derive(Debug, PartialEq)]
pub struct ComplexMatrix{
elements: Vec<Complex>,
rows: usize,
columns: usize,
}

impl ComplexMatrix {
pub fn new(vector: Vec<Complex>, rows: usize, columns: usize) -> Self {
if vector.len() != (rows * columns) {
panic!("Vector should be of size rows * columns.")
}

ComplexMatrix { elements: vector, rows, columns }
}
pub struct ComplexMatrix<const R: usize, const C: usize>([[Complex; C]; R]);

pub fn dimensions(&self) -> (usize, usize) {
(self.rows, self.columns)
impl<const R: usize, const C: usize> ComplexMatrix<R, C> {
pub fn new(values: [[Complex; C]; R]) -> Self {
ComplexMatrix(values)
}
}

impl From<ComplexVector> for ComplexMatrix {
fn from(ComplexVector(rhs): ComplexVector) -> Self {
let rows = rhs.len();
ComplexMatrix { elements: rhs, rows, columns: 1 }
impl<const N: usize> From<ComplexVector<N>> for ComplexMatrix<N, 1> {
fn from(ComplexVector(rhs): ComplexVector<N>) -> Self {
ComplexMatrix(rhs.map(|c| [c]))
}
}

impl Index<[usize; 2]> for ComplexMatrix {
impl<const R: usize, const C: usize> Index<[usize; 2]> for ComplexMatrix<R, C> {
type Output = Complex;

fn index(&self, index: [usize; 2]) -> &Self::Output {
let [row, column] = &index;
let [row, column] = index;

if row >= &self.rows || column >= &self.columns {
if row >= R || column >= C {
panic!("Index out of range.")
}

&self.elements[(row * &self.columns) + column]
&self.0[row][column]
}
}

impl IndexMut<[usize; 2]> for ComplexMatrix {
impl<const R: usize, const C: usize> IndexMut<[usize; 2]> for ComplexMatrix<R, C> {
fn index_mut(&mut self, index: [usize; 2]) -> &mut Self::Output {
let [row, column] = &index;
let [row, column] = index;

if row >= &self.rows || column >= &self.columns {
if row >= R || column >= C {
panic!("Index out of range.")
}

&mut self.elements[(row * &self.columns) + column]
&mut self.0[row][column]
}


}

impl Add for ComplexMatrix {
impl<const R: usize, const C: usize> Add for ComplexMatrix<R, C> {
type Output = Self;

fn add(self, rhs: Self) -> Self::Output {
Expand All @@ -69,7 +56,7 @@ impl Add for ComplexMatrix {
}

/// Support for scalar product on complex matrices.
impl Mul<Complex> for ComplexMatrix {
impl<const R: usize, const C: usize> Mul<Complex> for ComplexMatrix<R, C> {
type Output = Self;

fn mul(self, rhs: Complex) -> Self::Output {
Expand All @@ -78,41 +65,41 @@ impl Mul<Complex> for ComplexMatrix {
}

/// Support for vector-matrix product.
impl Mul<ComplexVector> for ComplexMatrix {
type Output = Self;
impl<const R: usize, const C: usize> Mul<ComplexVector<C>> for ComplexMatrix<R, C> {
type Output = ComplexMatrix<R, 1>;

fn mul(self, rhs: ComplexVector) -> Self::Output {
fn mul(self, rhs: ComplexVector<C>) -> Self::Output {
product_matrices(self, ComplexMatrix::from(rhs))
}
}

/// Support for product on complex matrices.
impl Mul<ComplexMatrix> for ComplexMatrix {
type Output = Self;
impl<const R: usize, const C: usize, const P: usize> Mul<ComplexMatrix<C, P>> for ComplexMatrix<R, C> {
type Output = ComplexMatrix<R, P>;

fn mul(self, rhs: ComplexMatrix) -> Self::Output {
fn mul(self, rhs: ComplexMatrix<C, P>) -> Self::Output {
product_matrices(self, rhs)
}
}

/// Support for negating complex matrices.
impl Neg for ComplexMatrix {
impl<const R: usize, const C: usize> Neg for ComplexMatrix<R, C> {
type Output = Self;

fn neg(self) -> Self::Output {
inverse_matrix(self)
negated_matrix(self)
}
}

/// Support for displaying complex matrices.
impl Display for ComplexMatrix {
impl<const R: usize, const C: usize> Display for ComplexMatrix<R, C> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut result_string = String::new();

for r in 0..self.rows {
for r in 0..R {
let mut row_display = String::new();

for c in 0..self.columns {
for c in 0..C {
row_display.push_str(format!("{}, ", self[[r, c]]).as_str());
}

Expand All @@ -127,61 +114,42 @@ impl Display for ComplexMatrix {
}

/// Coordinate-wise matrix addition.
fn add_matrices(matrix1: ComplexMatrix, matrix2: ComplexMatrix) -> ComplexMatrix {
if matrix1.dimensions() != matrix2.dimensions() {
panic!("Cannot add matrices of different dimensions.");
}

let result_vector: Vec<Complex> = matrix1.elements
.iter()
.zip(matrix2.elements.iter())
.map(|(&x,&y)| x + y)
.collect();
fn add_matrices<const R: usize, const C: usize>(matrix1: ComplexMatrix<R, C>, matrix2: ComplexMatrix<R, C>) -> ComplexMatrix<R, C> {
let mut result_array: [[Complex; C]; R] = [[Complex::new(0.0, 0.0); C]; R];

ComplexMatrix {
elements: result_vector,
rows: matrix1.rows,
columns: matrix1.columns
for y in 0..C {
for x in 0..R {
result_array[x][y] = matrix1[[x, y]] + matrix2[[x, y]];
}
}

ComplexMatrix(result_array)
}

/// Coordinate-wise complex scalar by complex matrix product.
fn product_matrix_scalar(matrix: ComplexMatrix, scalar: Complex) -> ComplexMatrix {
let new_elements = matrix.elements.iter().map(|&x| x * scalar).collect();
fn product_matrix_scalar<const R: usize, const C: usize>(matrix: ComplexMatrix<R, C>, scalar: Complex) -> ComplexMatrix<R, C> {
let new_elements = matrix.0.map(|arr| arr.map(|x| scalar * x));

ComplexMatrix {
elements: new_elements,
rows: matrix.rows,
columns: matrix.columns
}
ComplexMatrix(new_elements)
}

/// Matrix-Vector product.
pub fn product_matrix_vector(matrix: ComplexMatrix, vector: ComplexVector) -> ComplexVector {
//let vec_to_mat = ComplexMatrix::from(vector);
//let ComplexMatrix { elements, .. } = matrix * vec_to_mat;
//ComplexVector(elements)
//
pub fn product_matrix_vector<const R: usize, const C: usize>(matrix: ComplexMatrix<R, C>, vector: ComplexVector<C>) -> ComplexVector<R> {
let vec_to_mat = ComplexMatrix::from(vector);
let ComplexMatrix { elements, .. } = matrix * vec_to_mat;
ComplexVector(elements)
let result_matrix = matrix * vec_to_mat;
let result_vector = result_matrix.0.map(|row| row[0]);
ComplexVector(result_vector)
}

/// Standard complex matrices product.
fn product_matrices(m1: ComplexMatrix, m2: ComplexMatrix) -> ComplexMatrix {
if m1.columns != m2.rows {
panic!("Number of columns in the left-hand side matrix should be the \
same as number of rows in the right-hand side matrix.");
}

let mut m3 = ComplexMatrix::new(
vec![Complex::new(0.0, 0.0); m1.rows * m2.columns], m1.rows, m2.columns);
fn product_matrices<const R: usize, const C: usize, const P: usize>(m1: ComplexMatrix<R, C>, m2: ComplexMatrix<C, P>) -> ComplexMatrix<R, P> {
let mut m3 = ComplexMatrix::new([[Complex::new(0.0, 0.0); P]; R]);

for j in 0..m3.rows {
for k in 0..m3.columns {
for j in 0..R {
for k in 0..P {
let mut sum = Complex::new(0.0, 0.0);

for h in 0..m1.columns {
for h in 0..C {
sum += m1[[j,h]] * m2[[h,k]]
}

Expand All @@ -193,12 +161,8 @@ fn product_matrices(m1: ComplexMatrix, m2: ComplexMatrix) -> ComplexMatrix {
}

/// Inverse over addition matrix, by negating each coordinate.
fn inverse_matrix(matrix: ComplexMatrix) -> ComplexMatrix {
ComplexMatrix {
elements: matrix.elements.iter().map(|&x| -x).collect(),
rows: matrix.rows,
columns: matrix.columns
}
fn negated_matrix<const R: usize, const C: usize>(matrix: ComplexMatrix<R, C>) -> ComplexMatrix<R, C> {
ComplexMatrix(matrix.0.map(|row| row.map(|x| -x)))
}

#[cfg(test)]
Expand All @@ -207,68 +171,55 @@ mod tests {

#[test]
fn test_vector_matrix() {
let v = ComplexVector(vec![Complex::new(1.0, 0.0), Complex::new(0.0, 0.0), Complex::new(0.0, 0.0), Complex::new(1.0, 0.0)]);
let m = ComplexMatrix::new(vec![Complex::new(1.0, 0.0), Complex::new(0.0, 0.0), Complex::new(0.0, 0.0), Complex::new(1.0, 0.0)], 4, 1);
let v = ComplexVector([Complex::new(1.0, 0.0), Complex::new(0.0, 0.0), Complex::new(0.0, 0.0), Complex::new(1.0, 0.0)]);
let m = ComplexMatrix::new([[Complex::new(1.0, 0.0)], [Complex::new(0.0, 0.0)], [Complex::new(0.0, 0.0)], [Complex::new(1.0, 0.0)]]);
assert_eq!(ComplexMatrix::from(v), m);
}

#[test]
fn test_matrix_product_vector() {
let m = ComplexMatrix::new(vec![Complex::new(1.0, 0.0), Complex::new(2.0, 0.0), Complex::new(3.0, 0.0), Complex::new(4.0, 0.0)], 2, 2);
let v1 = ComplexVector(vec![Complex::new(1.0, 0.0), Complex::new(2.0, 0.0)]);
let v2 = ComplexVector(vec![Complex::new(5.0, 0.0), Complex::new(11.0, 0.0)]);
let m = ComplexMatrix::new([[Complex::new(1.0, 0.0), Complex::new(2.0, 0.0)], [Complex::new(3.0, 0.0), Complex::new(4.0, 0.0)]]);
let v1 = ComplexVector([Complex::new(1.0, 0.0), Complex::new(2.0, 0.0)]);
let v2 = ComplexVector([Complex::new(5.0, 0.0), Complex::new(11.0, 0.0)]);
assert_eq!(product_matrix_vector(m, v1), v2);
}

#[test]
fn test_matrix_add() {
let m1 = ComplexMatrix::new(vec![Complex::new(1.0, 0.0), Complex::new(0.0, 0.0), Complex::new(0.0, 0.0), Complex::new(1.0, 0.0)], 2, 2);
let m2 = ComplexMatrix::new(vec![Complex::new(1.0, 0.0), Complex::new(0.0, 0.0), Complex::new(0.0, 0.0), Complex::new(1.0, 0.0)], 2, 2);
let m3 = ComplexMatrix::new(vec![Complex::new(2.0, 0.0), Complex::new(0.0, 0.0), Complex::new(0.0, 0.0), Complex::new(2.0, 0.0)], 2, 2);
let m1 = ComplexMatrix::new([[Complex::new(1.0, 0.0), Complex::new(0.0, 0.0)], [Complex::new(0.0, 0.0), Complex::new(1.0, 0.0)]]);
let m2 = ComplexMatrix::new([[Complex::new(1.0, 0.0), Complex::new(0.0, 0.0)], [Complex::new(0.0, 0.0), Complex::new(1.0, 0.0)]]);
let m3 = ComplexMatrix::new([[Complex::new(2.0, 0.0), Complex::new(0.0, 0.0)], [Complex::new(0.0, 0.0), Complex::new(2.0, 0.0)]]);
assert_eq!(m1 + m2, m3);
}

#[test]
fn test_matrix_product_scalar() {
let m1 = ComplexMatrix::new(vec![Complex::new(0.0, 1.0), Complex::new(0.0, 0.0), Complex::new(0.0, 0.0), Complex::new(0.0, 1.0)], 2, 2);
let m2 = ComplexMatrix::new(vec![Complex::new(-1.0, 0.0), Complex::new(0.0, 0.0), Complex::new(0.0, 0.0), Complex::new(-1.0, 0.0)], 2, 2);
let m1 = ComplexMatrix::new([[Complex::new(0.0, 1.0), Complex::new(0.0, 0.0)], [Complex::new(0.0, 0.0), Complex::new(0.0, 1.0)]]);
let m2 = ComplexMatrix::new([[Complex::new(-1.0, 0.0), Complex::new(0.0, 0.0)], [Complex::new(0.0, 0.0), Complex::new(-1.0, 0.0)]]);

assert_eq!(m1 * Complex::new(0.0, 1.0), m2);
}

#[test]
fn test_matrix_inverse() {
let m1 = ComplexMatrix::new(vec![Complex::new(6.0, -4.0), Complex::new(7.0, 3.0), Complex::new(4.2, -8.1), Complex::new(0.0, -3.0)], 2, 2);
let m2 = ComplexMatrix::new(vec![Complex::new(-6.0, 4.0), Complex::new(-7.0, -3.0), Complex::new(-4.2, 8.1), Complex::new(0.0, 3.0)], 2, 2);
let m1 = ComplexMatrix::new([[Complex::new(6.0, -4.0), Complex::new(7.0, 3.0)], [Complex::new(4.2, -8.1), Complex::new(0.0, -3.0)]]);
let m2 = ComplexMatrix::new([[Complex::new(-6.0, 4.0), Complex::new(-7.0, -3.0)], [Complex::new(-4.2, 8.1), Complex::new(0.0, 3.0)]]);

assert_eq!(-m1, m2);
}

#[test]
fn test_matrix_product() {
let m1 = ComplexMatrix::new(vec![Complex::new(3.0, 2.0), Complex::new(0.0, 0.0), Complex::new(5.0, -6.0),
Complex::new(1.0, 0.0), Complex::new(4.0, 2.0), Complex::new(0.0, 1.0),
Complex::new(4.0, -1.0), Complex::new(0.0, 0.0), Complex::new(4.0, 0.0)], 3, 3);
let m2 = ComplexMatrix::new(vec![Complex::new(5.0, 0.0), Complex::new(2.0, -1.0), Complex::new(6.0, -4.0),
Complex::new(0.0, 0.0), Complex::new(4.0, 5.0), Complex::new(2.0, 0.0),
Complex::new(7.0, -4.0), Complex::new(2.0, 7.0), Complex::new(0.0, 0.0)], 3, 3);
let m3 = ComplexMatrix::new(vec![Complex::new(26.0, -52.0), Complex::new(60.0, 24.0), Complex::new(26.0, 0.0),
Complex::new(9.0, 7.0), Complex::new(1.0, 29.0), Complex::new(14.0, 0.0),
Complex::new(48.0, -21.0), Complex::new(15.0, 22.0), Complex::new(20.0, -22.0)], 3, 3);
let m1 = ComplexMatrix::new([[Complex::new(3.0, 2.0), Complex::new(0.0, 0.0), Complex::new(5.0, -6.0)],
[Complex::new(1.0, 0.0), Complex::new(4.0, 2.0), Complex::new(0.0, 1.0)],
[Complex::new(4.0, -1.0), Complex::new(0.0, 0.0), Complex::new(4.0, 0.0)]]);
let m2 = ComplexMatrix::new([[Complex::new(5.0, 0.0), Complex::new(2.0, -1.0), Complex::new(6.0, -4.0)],
[Complex::new(0.0, 0.0), Complex::new(4.0, 5.0), Complex::new(2.0, 0.0)],
[Complex::new(7.0, -4.0), Complex::new(2.0, 7.0), Complex::new(0.0, 0.0)]]);
let m3 = ComplexMatrix::new([[Complex::new(26.0, -52.0), Complex::new(60.0, 24.0), Complex::new(26.0, 0.0)],
[Complex::new(9.0, 7.0), Complex::new(1.0, 29.0), Complex::new(14.0, 0.0)],
[Complex::new(48.0, -21.0), Complex::new(15.0, 22.0), Complex::new(20.0, -22.0)]]);

assert_eq!(m1 * m2, m3);
}

#[test]
#[should_panic]
#[allow(unused_must_use)]
fn test_matrix_product_error() {
let m1 = ComplexMatrix::new(vec![Complex::new(3.0, 2.0), Complex::new(0.0, 0.0), Complex::new(5.0, -6.0),
Complex::new(1.0, 0.0), Complex::new(4.0, 2.0), Complex::new(0.0, 1.0),
Complex::new(4.0, -1.0), Complex::new(0.0, 0.0), Complex::new(4.0, 0.0)], 3, 3);
let m2 = ComplexMatrix::new(vec![Complex::new(5.0, 0.0), Complex::new(2.0, -1.0),
Complex::new(0.0, 0.0), Complex::new(4.0, 5.0)], 2, 2);

m1 * m2;
}
}

0 comments on commit c2c58a4

Please sign in to comment.