In [None]:
from scipy.io import loadmat
import os
import numpy as np
import matplotlib.pyplot as plt

import plotly.graph_objects as go

In [None]:
# Directory holding Testdata.mat. Override it with the AMATH482_HW01_DIR
# environment variable rather than editing a hard-coded absolute path; the
# file is opened by full path so the kernel's working directory is left alone.
DATA_DIR = os.environ.get('AMATH482_HW01_DIR', '/Users/kurtsmith/dev/AMATH_482/HW_01/')
testdata = loadmat(os.path.join(DATA_DIR, 'Testdata.mat'))
# Later cells treat each row Undata[j, :] as one realization and reshape it to (n, n, n).
Undata = testdata['Undata']

In [None]:
# Spatial domain [-L, L) in each of x, y, z, sampled with n = 2**6 points
# (one per Fourier mode).
L = 15
n = 64

# n + 1 equally spaced points on [-L, L]; the endpoint +L is the periodic
# image of -L, so only the first n are kept.
x2 = np.linspace(-L, L, n + 1)
x = x2[:n]
y = x
z = x

# Angular wavenumbers in np.fft order (0, 1, ..., n/2 - 1, -n/2, ..., -1)
# scaled by 2*pi / (2*L), plus the fftshift-ed, monotonically increasing copy.
k = (2 * np.pi / (2 * L)) * np.hstack([np.arange(n / 2), np.arange(-n / 2, 0)])
ks = np.fft.fftshift(k)

# 3-D grids. indexing='xy' is np.meshgrid's default: X varies along axis 1,
# Y along axis 0 and Z along axis 2.
# NOTE(review): the data cells reshape each Undata row in NumPy's C order; if
# Testdata.mat was produced with a MATLAB (column-major) reshape, the array
# axes may not line up with these Y/Z labels -- confirm against the data source.
X, Y, Z = np.meshgrid(x, y, z, indexing='xy')
Kx, Ky, Kz = np.meshgrid(ks, ks, ks, indexing='xy')

In [None]:
# Average the spectra of all realizations in Undata (one per row). Presumably
# the noise is roughly zero-mean, so averaging suppresses it while the
# signal's spectral peak survives.
n_realizations = Undata.shape[0]
spectrum_sum = np.zeros((n, n, n), dtype=complex)
for j in range(n_realizations):
    Un = Undata[j, :].reshape((n, n, n))
    spectrum_sum += np.fft.fftn(Un)

# Magnitude of the averaged spectrum (real, non-negative), in unshifted fftn order.
mean_fft = np.abs(spectrum_sum) / n_realizations

display(mean_fft.mean())

# argmax searches the flattened array; unravel_index converts the flat index
# back to an (axis0, axis1, axis2) index into the unshifted fftn output.
max_signal = mean_fft.max()
max_index = np.unravel_index(np.argmax(mean_fft), mean_fft.shape)

display(max_signal, max_index)

In [None]:
# Alternate method, do FFT last.
# Result is same as above
mean_fft = np.zeros([n,n,n]); 
for j in range(20):  
    Un = Undata[j,:].reshape((n,n,n))
    mean_fft = mean_fft + Un;

mean_fft = mean_fft/20;  
mean_fft = np.fft.fftn(mean_fft)

mean_fft = np.abs(mean_fft)

display(np.abs(mean_fft).mean())

# vectorizes the matrix and get max value
max_signal = np.max(mean_fft)  
max_index = np.argmax(mean_fft)
max_index = np.unravel_index(max_index, (n,n,n))

display(max_signal, max_index)

In [None]:
# mean_fft (and hence max_index) index the *unshifted* np.fft.fftn output,
# but Kx/Ky/Kz were built from the fftshift-ed wavenumbers ks. Undo the shift
# so the grids line up element-for-element with fftn output. Indexing the
# shifted grids with max_index would report the wrong center frequency and cut
# the Gaussian off at the array edge (e.g. only half a filter for a peak at k = 0).
Kx_fft = np.fft.ifftshift(Kx)
Ky_fft = np.fft.ifftshift(Ky)
Kz_fft = np.fft.ifftshift(Kz)

kx_center = Kx_fft[max_index]
ky_center = Ky_fft[max_index]
kz_center = Kz_fft[max_index]

# Offset of every wavenumber from the peak frequency.
Kx_shifted = Kx_fft - kx_center
Ky_shifted = Ky_fft - ky_center
Kz_shifted = Kz_fft - kz_center

# Gaussian filter exp(-tau * |k - k_center|^2), laid out in unshifted fftn
# order so it can multiply np.fft.fftn(...) output directly.
tau = 0.5
k_shift_sq = Kx_shifted**2 + Ky_shifted**2 + Kz_shifted**2
filter_weight = np.exp(-tau * k_shift_sq)

In [None]:
# Center wavenumber (kx, ky, kz) of the Gaussian filter built above.
display(*[kx_center, ky_center, kz_center])

In [None]:
def plot_iso(Un_filtered, X, Y, Z, iso_frac_min=0.5, iso_frac_max=1.0):
    """Show an interactive Plotly isosurface of |Un_filtered| on the grid X, Y, Z.

    Isosurface levels are relative to the peak magnitude, so the plot adapts
    to the scale of the data.

    Parameters
    ----------
    Un_filtered : ndarray
        Field values (complex allowed); only the magnitude is plotted. Must
        have the same number of elements as X, Y and Z.
    X, Y, Z : ndarray
        Coordinates of every sample, e.g. from np.meshgrid.
    iso_frac_min, iso_frac_max : float, optional
        Lowest / highest isosurface level as a fraction of max |Un_filtered|.
    """
    # Compute the magnitude once; it is used for the values and both levels.
    magnitude = np.abs(Un_filtered)
    peak = magnitude.max()
    fig = go.Figure(data=go.Isosurface(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=magnitude.flatten(),
        isomin=iso_frac_min * peak,
        isomax=iso_frac_max * peak,
    ))
    fig.show()

In [None]:
def plot_slice(j, Undata, filter_weight, X, Y, Z):
    """Filter realization ``j`` in frequency space and plot it as an isosurface.

    Row ``Undata[j, :]`` is reshaped to the grid shape, transformed with
    ``np.fft.fftn``, multiplied by ``filter_weight`` and transformed back with
    ``np.fft.ifftn``. Because the product is taken with raw fftn output,
    ``filter_weight`` must be laid out in unshifted fftn order.

    Returns
    -------
    Un_filtered : ndarray (complex)
        Filtered signal in space.
    Un_fft_filtered : ndarray (complex)
        Filtered spectrum in unshifted fftn order.
    """
    # Use the coordinate grid's shape instead of the notebook-global ``n`` so
    # the function depends only on its arguments.
    Un = Undata[j, :].reshape(X.shape)
    Un_fft_filtered = np.fft.fftn(Un) * filter_weight
    Un_filtered = np.fft.ifftn(Un_fft_filtered)
    plot_iso(Un_filtered, X, Y, Z)
    return Un_filtered, Un_fft_filtered

In [None]:
# Filter and plot realization 0; keep the spatial and spectral results for the
# plots that follow.
Un_filtered, Un_fft_filtered = plot_slice(
    j=0,
    Undata=Undata,
    filter_weight=filter_weight,
    X=X,
    Y=Y,
    Z=Z,
)

In [None]:
# Un_fft_filtered is in unshifted np.fft.fftn order, whereas Kx/Ky/Kz were
# built from fftshift-ed wavenumbers; shift the spectrum so every magnitude is
# drawn at its own frequency.
Un_fft_filtered_centered = np.fft.fftshift(Un_fft_filtered)

fig = go.Figure(data=go.Isosurface(
    x=Kx.flatten(),
    y=Ky.flatten(),
    z=Kz.flatten(),
    value=np.abs(Un_fft_filtered_centered).flatten(),
    isomin=100.,
    isomax=200.,
))

fig.show()

In [None]:
# Reorder the filtered spectrum to match the fftshift-ed Kx/Ky/Kz grids. A
# dedicated name avoids overwriting ``x``, the spatial coordinate vector.
Un_fft_filtered_shifted = np.fft.fftshift(Un_fft_filtered)

fig = go.Figure(data=go.Isosurface(
    x=Kx.flatten(),
    y=Ky.flatten(),
    z=Kz.flatten(),
    value=np.abs(Un_fft_filtered_shifted).flatten(),
    isomin=100.,
    isomax=200.,
))

fig.show()

In [None]:
# filter_weight is multiplied with raw np.fft.fftn output in plot_slice, so it
# is in unshifted fftn order; fftshift it to match the fftshift-ed Kx/Ky/Kz
# coordinates used for display.
filter_weight_centered = np.fft.fftshift(filter_weight)

fig = go.Figure(data=go.Isosurface(
    x=Kx.flatten(),
    y=Ky.flatten(),
    z=Kz.flatten(),
    value=filter_weight_centered.flatten(),
    isomin=0.2,
    isomax=1.0,
))

fig.show()

In [None]:
# Contour of |filtered signal| on one constant-z plane. With np.meshgrid's
# default 'xy' indexing, the third array axis of X/Y/Z indexes z.
z_index = 10
fig, ax = plt.subplots()
ax.contour(X[:, :, z_index], Y[:, :, z_index], np.abs(Un_filtered[:, :, z_index]))
ax.set(
    xlabel='x',
    ylabel='y',
    title=f'|filtered signal| at z = {Z[0, 0, z_index]:.3f}',
)
plt.show()