diff --git a/libs/math/Matrix3.h b/libs/math/Matrix3.h index d01bb42055..4d4721e18b 100644 --- a/libs/math/Matrix3.h +++ b/libs/math/Matrix3.h @@ -1,5 +1,6 @@ #pragma once +#include "Vector2.h" #include "Vector3.h" #include "eigen.h" @@ -126,6 +127,59 @@ class alignas(16) Matrix3 static Matrix3 byRows(double xx, double yx, double zx, double xy, double yy, double zy, double xz, double yz, double zz); + + /// Return the result of this matrix post-multiplied by another matrix. + Matrix3 getMultipliedBy(const Matrix3& other) const + { + return Matrix3(_transform * other.eigen()); + } + + /// Post-multiply this matrix by another matrix, in-place. + void multiplyBy(const Matrix3& other) + { + *this = getMultipliedBy(other); + } + + /// Returns this matrix pre-multiplied by the other + Matrix3 getPremultipliedBy(const Matrix3& other) const + { + return other.getMultipliedBy(*this); + } + + /// Pre-multiplies this matrix by other in-place. + void premultiplyBy(const Matrix3& other) + { + *this = getPremultipliedBy(other); + } + + /// Return the full inverse of this matrix. + Matrix3 getFullInverse() const + { + return Matrix3(_transform.inverse(Eigen::Projective)); + } + + /// Invert this matrix in-place. + void invertFull() + { + *this = getFullInverse(); + } + + /** + * \brief Returns the given 2-component point transformed by this matrix. + * + * The point is assumed to have a W component of 1, and no division by W is + * performed before returning the 3-component vector. + */ + template + BasicVector2 transformPoint(const BasicVector2& point) const + { + auto transformed = transform(BasicVector3(point.x(), point.y(), 1)); + return BasicVector2(transformed.x(), transformed.y()); + } + + /// Return the given 3-component vector transformed by this matrix. + template + BasicVector3 transform(const BasicVector3& vector3) const; }; // Private constructor @@ -164,6 +218,14 @@ inline Matrix3 Matrix3::byRows(double xx, double yx, double zx, zx, zy, zz); } +template +BasicVector3 Matrix3::transform(const BasicVector3& vector3) const +{ + Eigen::Matrix eVec(static_cast(vector3)); + auto result = _transform * eVec; + return BasicVector3(result[0], result[1], result[2]); +} + /// Compare two matrices elementwise for equality inline bool operator==(const Matrix3& l, const Matrix3& r) { @@ -175,3 +237,32 @@ inline bool operator!=(const Matrix3& l, const Matrix3& r) { return !(l == r); } + +/// Multiply two matrices together +inline Matrix3 operator*(const Matrix3& m1, const Matrix3& m2) +{ + return m1.getMultipliedBy(m2); +} + +/** + * \brief Multiply a 3-component vector by this matrix. + * + * Equivalent to m.transform(v). + */ +template +BasicVector3 operator*(const Matrix3& m, const BasicVector3& v) +{ + return m.transform(v); +} + +/** + * \brief Multiply a 2-component vector by this matrix. + * + * The vector is upgraded to a 3-component vector with a Z (or W) component of 1, i.e. + * equivalent to m.transformPoint(v). + */ +template +BasicVector2 operator*(const Matrix3& m, const BasicVector2& v) +{ + return m.transformPoint(v); +} diff --git a/test/math/Matrix3.cpp b/test/math/Matrix3.cpp index f7a5b0853c..bd18db511d 100644 --- a/test/math/Matrix3.cpp +++ b/test/math/Matrix3.cpp @@ -87,4 +87,71 @@ TEST(Matrix3Test, MatrixEquality) EXPECT_TRUE(m2 != Matrix3::getIdentity()); } +TEST(Matrix3Test, MatrixMultiplication) +{ + auto a = Matrix3::byColumns(3, 5, 7, 11, 13, 17, 19, 23, 29); + auto b = Matrix3::byColumns(61, 67, 71, 73, 79, 83, 89, 97, 101); + + // Check multiplied result + auto c = a.getMultipliedBy(b); + EXPECT_EQ(c, Matrix3::byColumns(2269, 2809, 3625, + 2665, 3301, 4261, + 3253, 4029, 5201)); + + // Multiplication has not changed original + EXPECT_NE(a, c); + + // Check operator multiplication as well + EXPECT_EQ(a * b, c); + + // Test Pre-Multiplication + EXPECT_EQ(b.getMultipliedBy(a), a.getPremultipliedBy(b)) << "Matrix pre-multiplication mismatch"; +} + +TEST(Matrix3Test, MatrixFullInverse) +{ + auto a = Matrix3::byColumns(3, 5, 7, 11, 13, 17, 19, 23, 29); + + auto inv = a.getFullInverse(); + + EXPECT_DOUBLE_EQ(inv.xx(), -7.0 / 10) << "Matrix inversion failed on xx"; + EXPECT_DOUBLE_EQ(inv.xy(), 4.0 / 5) << "Matrix inversion failed on xy"; + EXPECT_DOUBLE_EQ(inv.xz(), -3.0 / 10) << "Matrix inversion failed on xz"; + + EXPECT_DOUBLE_EQ(inv.yx(), 1.0 / 5) << "Matrix inversion failed on yx"; + EXPECT_DOUBLE_EQ(inv.yy(), -23.0 / 10) << "Matrix inversion failed on yy"; + EXPECT_DOUBLE_EQ(inv.yz(), 13.0 / 10) << "Matrix inversion failed on yz"; + + EXPECT_DOUBLE_EQ(inv.zx(), 3.0 / 10) << "Matrix inversion failed on zx"; + EXPECT_DOUBLE_EQ(inv.zy(), 13.0 / 10) << "Matrix inversion failed on zy"; + EXPECT_DOUBLE_EQ(inv.zz(), -4.0 / 5) << "Matrix inversion failed on zz"; +} + +TEST(Matrix3Test, MatrixTransformation) +{ + auto a = Matrix3::byColumns(3, 5, 7, 11, 13, 17, 19, 23, 29); + + { + Vector2 v(61, 67); + + Vector2 transformed = a.transformPoint(v); + + EXPECT_EQ(transformed.x(), 939) << "Vector2 transformation failed"; + EXPECT_EQ(transformed.y(), 1199) << "Vector2 transformation failed"; + + EXPECT_EQ(a * v, a.transformPoint(v)); + } + + { + Vector3 vector(83, 89, 97); + Vector3 transformed = a.transform(vector); + + EXPECT_EQ(transformed.x(), 3071) << "Vector3 transformation failed"; + EXPECT_EQ(transformed.y(), 3803) << "Vector3 transformation failed"; + EXPECT_EQ(transformed.z(), 4907) << "Vector3 transformation failed"; + + EXPECT_EQ(a * vector, a.transform(vector)); + } +} + }