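Lectures 13-16: Extracting the dominant components of a 2D matrix with the singular value decomposition (SVD), and reconstructing the matrix from its leading singular modes. This is the idea behind principal component analysis (PCA), which additionally mean-centers the data before taking the SVD.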

In [None]:
%pylab

First example: the surface $f(x,t) = e^{-|(x-0.5)(t-1)|} + \sin(xt)$, sampled on $x \in [0,1]$, $t \in [0,2]$:

In [None]:
from mpl_toolkits.mplot3d import axes3d

# Number of leading SVD modes used for the reconstructions, the energy ratio,
# and the mode plot below.
N_MODES = 3

x = linspace(0, 1, 25)
t = linspace(0, 2, 50)

# meshgrid(t, x) returns arrays of shape (len(x), len(t)) = (25, 50), so rows of f
# index x and columns index t. Swapping the argument order would transpose f and
# swap the interpretation of u (x-dependent modes) and vh (t-dependent modes).
T, X = meshgrid(t, x)

f = exp(-abs((X - 0.5)*(T - 1))) + sin(X*T)   # original matrix (surface)

# One panel for the original surface plus one per reconstruction, two per row.
n_rows = (N_MODES + 2)//2

fig = figure()
ax = fig.add_subplot(n_rows, 2, 1, projection='3d')
ax.plot_surface(X, T, f, cmap='gist_earth')
ax.set_title('main matrix')

# Thin (economy) SVD: f = u @ diag(s) @ vh with u: (25, 25), s: (25,), vh: (25, 50).
u, s, vh = linalg.svd(f, full_matrices=False)

print("u.shape: {} \ns.shape: {} \nvh.shape: {}".format(u.shape, s.shape, vh.shape))

# Singular values (numpy returns them sorted in descending order).
figure()
plot(s, 'ro')
plt.title('singular values')
plt.xlabel('index')
plt.ylabel(r'$\sigma_k$')

# Rank-k reconstructions for k = 1..N_MODES. Broadcasting u[:, :k]*s[:k] scales
# column j of u by s[j], i.e. computes u_k @ diag(s_k) without forming the diagonal.
for k in range(1, N_MODES + 1):
    f_k = dot(u[:, :k]*s[:k], vh[:k, :])
    ax = fig.add_subplot(n_rows, 2, k + 1, projection='3d')
    ax.plot_surface(X, T, f_k, cmap='gist_earth')
    ax.set_title('using {} components'.format(k))

# Share of f captured by the N_MODES modes used above, measured as the
# Frobenius-norm ratio ||f_k||_F / ||f||_F = sqrt(sum_{j<k} s_j^2) / sqrt(sum_j s_j^2).
energy_ratio = sqrt(sum(s[:N_MODES]**2))/sqrt(sum(s**2))
print("energy ratio : {}".format(energy_ratio))

# Leading spatial modes (columns of u, functions of x) used above to reconstruct f.
plt.figure()
for k in range(N_MODES):
    plt.plot(x, u[:, k], label='mode {}'.format(k + 1))
plt.title('main spatial modes')
plt.xlabel('x')
plt.legend();
plt.legend(['1st','2nd','3rd'])

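Second example: $f(x,t) = (1 - 0.5\cos 2t)\,\operatorname{sech} x + (1 - 0.5\sin 2t)\,\tanh x\,\operatorname{sech} x$, sampled on $x \in [-10,10]$, $t \in [0,10]$. Each term is a function of $t$ times a function of $x$, so $f$ has rank 2. Its first two SVD modes therefore reconstruct it up to round-off error: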

In [None]:
from mpl_toolkits.mplot3d import axes3d

x = linspace(-10, 10, 100)
t = linspace(0, 10, 30)

# meshgrid(x, t) returns arrays of shape (len(t), len(x)) = (30, 100): rows of f
# index t and columns index x. Hence columns of u are temporal modes and rows of
# vh are spatial modes (the opposite of the first example).
X, T = meshgrid(x, t)

f = (1 - 0.5*cos(2*T))/cosh(X) + (1 - 0.5*sin(2*T))*tanh(X)/cosh(X)

# f = a(t)*sech(x) + b(t)*tanh(x)*sech(x) is a sum of two separable terms, so it
# has rank 2: singular values beyond the second are at round-off level, and the
# corresponding singular vectors are dominated by numerical noise.
N_MODES = 2    # modes shown in the mode plots and counted in the energy ratio
N_RECON = 3    # reconstructions drawn next to the original surface

# One panel for the original surface plus one per reconstruction, two per row.
n_rows = (N_RECON + 2)//2

fig1 = figure()
ax1 = fig1.add_subplot(n_rows, 2, 1, projection='3d')
ax1.plot_wireframe(X, T, f, rstride=2, cstride=0)
ax1.set_title('main matrix')

# Thin (economy) SVD: f = u @ diag(s) @ vh with u: (30, 30), s: (30,), vh: (30, 100).
u, s, vh = linalg.svd(f, full_matrices=False)

# Singular values (numpy returns them sorted in descending order).
figure()
plot(s, 'ro')
plt.title('singular values')
plt.xlabel('index')
plt.ylabel(r'$\sigma_k$')

# Rank-k reconstructions for k = 1..N_RECON; since f has rank 2, the k = 3
# reconstruction matches k = 2 up to round-off.
for k in range(1, N_RECON + 1):
    f_k = dot(u[:, :k]*s[:k], vh[:k, :])
    ax = fig1.add_subplot(n_rows, 2, k + 1, projection='3d')
    ax.plot_wireframe(X, T, f_k, rstride=2, cstride=0)
    ax.set_title('using {} components'.format(k))

# main spatial modes: rows of vh, functions of x
figure()
for k in range(N_MODES):
    plot(x, vh[k, :], label='mode {}'.format(k + 1))
plt.title('main spatial modes')
plt.xlabel('x')
plt.legend()

# main temporal modes: columns of u, functions of t. Only the first N_MODES are
# meaningful; u[:, 2] and beyond belong to round-off-level singular values.
figure()
for k in range(N_MODES):
    plot(t, u[:, k], label='mode {}'.format(k + 1))
plt.title('main temporal modes')
plt.xlabel('t')
plt.legend()

# Frobenius-norm ratio ||f_k||_F / ||f||_F captured by the first N_MODES modes.
energy_ratio = sqrt(sum(s[:N_MODES]**2))/sqrt(sum(s**2))
print("energy ratio : {}".format(energy_ratio))