Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 65 additions & 20 deletions math/matrix/strassenmatrixmultiply.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,30 @@
// This program takes two matrices as input and performs matrix multiplication
// using the Strassen algorithm, which is an optimized divide-and-conquer
// approach. It allows for efficient multiplication of large matrices.
// time complexity: O(n^2.81)
// space complexity: O(n^2)
//
// Example:
// | 1 2 | * | 5 6 | = | (1*5 + 2*7) (1*6 + 2*8) | = | 19 22 |
// | 3 4 | | 7 8 | | (3*5 + 4*7) (3*6 + 4*8) | | 43 50 |
//
// Strassen's algorithm achieves a better time complexity by using
// 7 recursive multiplications instead of the 8 used in a standard
// divide-and-conquer approach.
//
// time complexity: O(n^log2(7)) ≈ O(n^2.81)
// space complexity: O(n^2) for storing submatrices
// author(s): Mohit Raghav(https://github.com/mohit07raghav19)
// See strassenmatrixmultiply_test.go for test cases
package matrix

// Perform matrix multiplication using Strassen's algorithm
// StrassenMatrixMultiply performs matrix multiplication using Strassen's
// divide-and-conquer algorithm.
//
// NOTE: This implementation expects the matrices to be square (n x n).
// The base case for the recursion is a 1x1 matrix.
func (A Matrix[T]) StrassenMatrixMultiply(B Matrix[T]) (Matrix[T], error) {
n := A.rows
// Check if matrices are 2x2 or smaller
// Base case for the recursion:
// If the matrix is 1x1, perform a single multiplication.
if n == 1 {
a1, err := A.Get(0, 0)
if err != nil {
Expand All @@ -26,10 +40,15 @@ func (A Matrix[T]) StrassenMatrixMultiply(B Matrix[T]) (Matrix[T], error) {
result := New(1, 1, a1*b1)
return result, nil
} else {
// Calculate the size of submatrices
// --- 1. DIVIDE ---
// Calculate the size of submatrices (quadrants)
mid := n / 2

// Create submatrices
// Create submatrices (A11, A12, A21, A22 and B11, B12, B21, B22)
// A = | A11 A12 |
// | A21 A22 |
// B = | B11 B12 |
// | B21 B22 |
A11, err := A.SubMatrix(0, 0, mid, mid)
if err != nil {
return Matrix[T]{}, err
Expand Down Expand Up @@ -64,83 +83,95 @@ func (A Matrix[T]) StrassenMatrixMultiply(B Matrix[T]) (Matrix[T], error) {
return Matrix[T]{}, err
}

// Calculate result submatrices
A1, err := A11.Add(A22)
// --- 2. CALCULATE INTERMEDIATE TERMS ---
// These are the 10 additions/subtractions required to
// set up the 7 recursive multiplication steps.
A1, err := A11.Add(A22) // A1 = A11 + A22
if err != nil {
return Matrix[T]{}, err
}

A2, err := B11.Add(B22)
A2, err := B11.Add(B22) // A2 = B11 + B22
if err != nil {
return Matrix[T]{}, err
}

A3, err := A21.Add(A22)
A3, err := A21.Add(A22) // A3 = A21 + A22
if err != nil {
return Matrix[T]{}, err
}

A4, err := A11.Add(A12)
A4, err := A11.Add(A12) // A4 = A11 + A12
if err != nil {
return Matrix[T]{}, err
}

A5, err := B11.Add(B12)
A5, err := B11.Add(B12) // A5 = B11 + B12
if err != nil {
return Matrix[T]{}, err
}

A6, err := B21.Add(B22)
A6, err := B21.Add(B22) // A6 = B21 + B22
if err != nil {
return Matrix[T]{}, err
}
//
S1, err := B12.Subtract(B22)
S1, err := B12.Subtract(B22) // S1 = B12 - B22
if err != nil {
return Matrix[T]{}, err
}
S2, err := B21.Subtract(B11)
S2, err := B21.Subtract(B11) // S2 = B21 - B11
if err != nil {
return Matrix[T]{}, err
}
S3, err := A21.Subtract(A11)
S3, err := A21.Subtract(A11) // S3 = A21 - A11
if err != nil {
return Matrix[T]{}, err
}
S4, err := A12.Subtract(A22)
S4, err := A12.Subtract(A22) // S4 = A12 - A22
if err != nil {
return Matrix[T]{}, err
}
// Recursive steps

// --- 3. CONQUER ---
// Recursive steps: Calculate the 7 Strassen products (M1 to M7).
// M1 = (A11 + A22) * (B11 + B22)
M1, err := A1.StrassenMatrixMultiply(A2)
if err != nil {
return Matrix[T]{}, err
}
// M2 = (A21 + A22) * B11
M2, err := A3.StrassenMatrixMultiply(B11)
if err != nil {
return Matrix[T]{}, err
}
// M3 = A11 * (B12 - B22)
M3, err := A11.StrassenMatrixMultiply(S1)
if err != nil {
return Matrix[T]{}, err
}
// M4 = A22 * (B21 - B11)
M4, err := A22.StrassenMatrixMultiply(S2)
if err != nil {
return Matrix[T]{}, err
}
// M5 = (A11 + A12) * B22
M5, err := A4.StrassenMatrixMultiply(B22)
if err != nil {
return Matrix[T]{}, err
}
// M6 = (A21 - A11) * (B11 + B12)
M6, err := S3.StrassenMatrixMultiply(A5)
if err != nil {
return Matrix[T]{}, err
}
// M7 = (A12 - A22) * (B21 + B22)
M7, err := S4.StrassenMatrixMultiply(A6)

if err != nil {
return Matrix[T]{}, err
} //
// (Temporary combinations for calculating C submatrices)
A7, err := M1.Add(M4)

if err != nil {
Expand All @@ -161,30 +192,39 @@ func (A Matrix[T]) StrassenMatrixMultiply(B Matrix[T]) (Matrix[T], error) {
if err != nil {
return Matrix[T]{}, err
}
// Calculate result submatrices

// --- 4. COMBINE ---
// Calculate result submatrices (C11, C12, C21, C22)
// C11 = M1 + M4 - M5 + M7
C11, err := A8.Subtract(M5)
if err != nil {
return Matrix[T]{}, err
}
// C12 = M3 + M5
C12, err := M3.Add(M5)
if err != nil {
return Matrix[T]{}, err
}
// C21 = M2 + M4
C21, err := M2.Add(M4)
if err != nil {
return Matrix[T]{}, err
}
// C22 = M1 - M2 + M3 + M6
C22, err := A10.Subtract(M2)
if err != nil {
return Matrix[T]{}, err
}

// Combine subMatrices into the result matrix
// Combine subMatrices into the final result matrix C
// C = | C11 C12 |
// | C21 C22 |
var zeroVal T
C := New(n, n, zeroVal)

for i := 0; i < mid; i++ {
for j := 0; j < mid; j++ {
// Set C11 (top-left quadrant)
val, err := C11.Get(i, j)
if err != nil {
return Matrix[T]{}, err
Expand All @@ -195,6 +235,7 @@ func (A Matrix[T]) StrassenMatrixMultiply(B Matrix[T]) (Matrix[T], error) {
return Matrix[T]{}, err
}

// Set C12 (top-right quadrant)
val, err = C12.Get(i, j)
if err != nil {
return Matrix[T]{}, err
Expand All @@ -204,6 +245,8 @@ func (A Matrix[T]) StrassenMatrixMultiply(B Matrix[T]) (Matrix[T], error) {
if err1 != nil {
return Matrix[T]{}, err1
}

// Set C21 (bottom-left quadrant)
val, err = C21.Get(i, j)
if err != nil {
return Matrix[T]{}, err
Expand All @@ -213,6 +256,8 @@ func (A Matrix[T]) StrassenMatrixMultiply(B Matrix[T]) (Matrix[T], error) {
if err2 != nil {
return Matrix[T]{}, err2
}

// Set C22 (bottom-right quadrant)
val, err = C22.Get(i, j)
if err != nil {
return Matrix[T]{}, err
Expand Down