# In [None]:
#%matplotlib inline
import scipy.io as sio
import matplotlib.pyplot as plt
import numpy as np
from scipy import linalg as LA

N = 20                          # number of pixels along each side of the square image
A = np.zeros((N, N))            # image array (all zeros; unused in the code shown here)
Adj = np.zeros((N * N, N * N))  # adjacency matrix of the pixel graph, one row/column per pixel

# Row/column offsets of the 8 neighbours of a pixel: every step in {-1, 0, 1}^2
# except (0, 0).
dx = [-1, 0, 1, -1, 1, -1, 0, 1]
dy = [-1, -1, -1, 0, 0, 1, 1, 1]

# Link every pixel to each of its in-bounds neighbours (pixels on the border
# simply end up with fewer neighbours). Pixel (r, c) gets flat index r*N + c.
for r in range(N):
    for c in range(N):
        src = r * N + c
        for off_r, off_c in zip(dx, dy):
            nr = r + off_r
            nc = c + off_c
            if not (0 <= nr < N and 0 <= nc < N):
                continue  # neighbour falls off the grid
            Adj[src, nr * N + nc] = 1
           
# Key step: everything needed to solve the diffusion equation dC/dt = -L C
# spectrally is computed here.

# Graph Laplacian: degree matrix (diagonal of the row sums of Adj) minus adjacency.
Deg = np.diag(np.sum(Adj, axis=1))
L = Deg - Adj
# Eigendecomposition of the symmetric Laplacian: eigenvalues in D,
# eigenvectors in the columns of V.
D, V = LA.eigh(L)
# Turn the eigenvalues into an (N*N, 1) column so they broadcast against
# column vectors of eigen-coefficients.
D = D[:, np.newaxis]


# Initial condition: zero everywhere except three rectangular patches of
# constant positive value (row/column ranges are half-open slices).
C0 = np.zeros((N, N))
for rows, cols, level in (
    (slice(1, 5), slice(1, 5), 5),
    (slice(9, 15), slice(9, 15), 10),
    (slice(1, 5), slice(7, 13), 7),
):
    C0[rows, cols] = level
# Flatten in column-major ('F') order into an (N*N, 1) column vector.
C0 = C0.reshape((N * N, 1), order='F')

# Spectral solution of the diffusion equation dC/dt = -L C on the pixel graph.
# With D, V the eigenvalues/eigenvectors of the Laplacian L:
#     C(t) = V @ (exp(-D t) * (V^H @ C0)).
# Project the initial condition onto the eigenvectors once; each time step then
# only rescales these coefficients.
C0V = np.dot((V.conj().transpose()), C0)

# Plotting results: render one PNG frame per time step.
fig = plt.figure()
# 101 samples 0.00, 0.05, ..., 5.00. An explicit count avoids relying on the
# rounding of a float stop value (as np.arange(0, 5.05, 0.05) did) to decide
# whether the end point is included.
times = np.linspace(0.0, 5.0, 101)
# Create the image once with a fixed colour axis (equivalent to plt.clim(0, 10))
# and only update its data each frame. Calling imshow inside the loop added a
# new AxesImage to the axes on every iteration, so each save re-rendered all
# previously stacked images.
imgplot = plt.imshow(np.zeros((N, N)), cmap="jet", vmin=0, vmax=10)
for frame, t in enumerate(times):
    Phi = C0V * np.exp(-D * t)                # exponential decay of each eigen-component
    Phi = np.dot(V, Phi)                      # back to the original (pixel) coordinate system
    Phi = np.reshape(Phi, (N, N), order='F')  # shape Phi to the image lattice
    imgplot.set_data(Phi)
    # Integer frame index gives clean, sortable file names (img000.png ... img100.png).
    # The previous "img" + str(t*100) produced names like "img15.000000000000002",
    # and savefig uses fname verbatim when format= is given, so no ".png" was added.
    plt.savefig("img{:03d}.png".format(frame), format='png')
