In [None]:
# Timing comparison: Strassen's matrix multiplication vs. standard multiplication (np.dot)

import numpy as np
import timeit
import matplotlib.pyplot as plt

def mat_mul(A, B):
    """Return the standard matrix product of A and B via ``np.dot``.

    Serves as the reference ("normal") multiplication in the benchmark.
    """
    product = np.dot(A, B)
    return product

def strsn_mul(A, B, leaf_size=1):
    """Multiply two matrices with Strassen's algorithm.

    Strassen's scheme splits each operand into 2x2 blocks and forms the
    product from seven recursive block multiplications instead of eight,
    giving O(n**log2(7)) arithmetic for n x n inputs.

    The recursion needs square operands whose side is a power of two, so any
    other shape is zero-padded up to the next power of two and the result is
    cropped back. Zero padding does not change the product.

    Args:
        A: Matrix of shape (m, k); anything ``np.asarray`` accepts.
        B: Matrix of shape (k, p).
        leaf_size: Blocks whose side is ``<= leaf_size`` are multiplied
            directly with ``@`` instead of recursing further. The default of
            1 recurses all the way down to 1x1 blocks, as the original did.

    Returns:
        The (m, p) product as a NumPy array.

    Raises:
        ValueError: if either operand is not 2-D, the inner dimensions
            differ, or ``leaf_size`` is less than 1.
    """
    A = np.asarray(A)
    B = np.asarray(B)
    if A.ndim != 2 or B.ndim != 2:
        raise ValueError("strsn_mul expects two 2-D matrices")
    rows, inner = A.shape
    if B.shape[0] != inner:
        raise ValueError(
            f"shapes {A.shape} and {B.shape} are not aligned for multiplication"
        )
    if leaf_size < 1:
        raise ValueError("leaf_size must be at least 1")
    cols = B.shape[1]

    # A product with a zero-length dimension is an all-zero (possibly empty)
    # matrix. Handle it up front: halving a 0-sized side never reaches the
    # base case.
    if rows == 0 or inner == 0 or cols == 0:
        return np.zeros((rows, cols), dtype=np.result_type(A, B))

    # Smallest power of two covering every dimension (1 -> 1, 5 -> 8, 8 -> 8).
    side = 1 << (max(rows, inner, cols) - 1).bit_length()
    if (rows, inner, cols) != (side, side, side):
        A = np.pad(A, ((0, side - rows), (0, side - inner)))
        B = np.pad(B, ((0, side - inner), (0, side - cols)))
    return _strassen_square(A, B, leaf_size)[:rows, :cols]


def _strassen_square(A, B, leaf_size):
    """Return the Strassen product of two n x n arrays, where n is a power of two."""
    n = A.shape[0]
    if n <= leaf_size:
        # Base case: small enough to multiply directly (1x1 when leaf_size=1).
        return A @ B

    half = n // 2
    # Quadrants: A = [[a, b], [c, d]], B = [[e, f], [g, h]].
    a, b = A[:half, :half], A[:half, half:]
    c, d = A[half:, :half], A[half:, half:]
    e, f = B[:half, :half], B[:half, half:]
    g, h = B[half:, :half], B[half:, half:]

    p1 = _strassen_square(a, f - h, leaf_size)
    p2 = _strassen_square(a + b, h, leaf_size)
    p3 = _strassen_square(c + d, e, leaf_size)
    p4 = _strassen_square(d, g - e, leaf_size)
    p5 = _strassen_square(a + d, e + h, leaf_size)
    p6 = _strassen_square(b - d, g + h, leaf_size)
    p7 = _strassen_square(a - c, e + f, leaf_size)

    # Each quadrant of C = A @ B is a signed sum of the seven products.
    top_left = p5 + p4 - p2 + p6      # = ae + bg
    top_right = p1 + p2               # = af + bh
    bottom_left = p3 + p4             # = ce + dg
    bottom_right = p1 + p5 - p3 - p7  # = cf + dh
    return np.block([[top_left, top_right], [bottom_left, bottom_right]])

def measure(function, A, B, repeat=1):
    """Return the wall-clock time in seconds taken by ``function(A, B)``.

    Args:
        function: Callable invoked as ``function(A, B)``; its result is
            discarded.
        A, B: Arguments passed through to ``function``.
        repeat: Number of timed runs. The fastest run is returned, which is
            the least noise-affected estimate. The default of 1 times a
            single call.

    Returns:
        The elapsed time of the fastest run, measured with
        ``timeit.default_timer``.

    Raises:
        ValueError: if ``repeat`` is less than 1.
    """
    if repeat < 1:
        raise ValueError("repeat must be at least 1")
    best = float("inf")
    for _ in range(repeat):
        start_time = timeit.default_timer()
        function(A, B)
        best = min(best, timeit.default_timer() - start_time)
    return best

def compare(matrix_sizes):
    """Time ``mat_mul`` and ``strsn_mul`` on random square matrices.

    For each n in ``matrix_sizes``, two n x n matrices are drawn with
    ``np.random.random``. Both multiplications are then timed on the same
    pair, normal first.

    Returns:
        A tuple ``(normal_time, strassen_time)`` of lists, holding one timing
        per entry of ``matrix_sizes``, in order.
    """
    timings = []
    for n in matrix_sizes:
        left = np.random.random((n, n))
        right = np.random.random((n, n))
        timings.append((measure(mat_mul, left, right),
                        measure(strsn_mul, left, right)))

    normal_time = [pair[0] for pair in timings]
    strassen_time = [pair[1] for pair in timings]
    return normal_time, strassen_time

matrix_sizes = [2**i for i in range(1, 9)]

# Sanity check before timing: benchmarking an implementation that gives
# wrong answers would be meaningless.
_check_A = np.random.random((16, 16))
_check_B = np.random.random((16, 16))
if not np.allclose(strsn_mul(_check_A, _check_B), mat_mul(_check_A, _check_B)):
    raise RuntimeError("strsn_mul disagrees with mat_mul on a 16x16 check")

normal_time, strassen_time = compare(matrix_sizes)

plt.figure(figsize=(8, 6))
plt.plot(matrix_sizes, normal_time, marker='o', label='Normal Matrix Multiplication')
plt.plot(matrix_sizes, strassen_time, marker='o', label='Strassen Matrix Multiplication')

# Sizes double at each step, and the timings typically span several orders of
# magnitude. Log axes keep the small sizes visible instead of flattening them
# against zero.
plt.xscale('log')
plt.yscale('log')
plt.xticks(matrix_sizes, [str(n) for n in matrix_sizes])

plt.xlabel('Matrix Size (n for n x n)')
plt.ylabel('Execution Time (s)')
plt.title('Comparison Graph')

plt.legend()
plt.show()