In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from IPython.display import clear_output

from src.torch.gridgen import gridgen
from src.torch.gridgen import mygriddata
# try adding this 
# from src.torch.gridgen import fast_sine_transform_y
import torch.nn.functional as nnf

import torch
# print(torch.__version__)
import matplotlib.pyplot as plt
import numpy as np 

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def display_im_grid(xgrid, ygrid, im, ngrid):
    """Display both grids and image"""
    clear_output(wait=True)
    plt.imshow(im, cmap="gray")
    plt.plot(ygrid[::ngrid, ::ngrid], xgrid[::ngrid, ::ngrid], "b", lw=1.0)
    plt.plot(ygrid[::ngrid, ::ngrid].T, xgrid[::ngrid, ::ngrid].T, "b", lw=1.0)

    plt.axis("off")
    plt.axis("equal")
    plt.pause(0.1)
    
# The original fast_sine_transform_y function ### 
def fast_sine_transform_y1(v):
    """Perform fast sine transform in y direction"""
    n, m = v.shape[-3], v.shape[-2]
    v = nnf.pad(v, (0, 0, 1, m + 1, 0, 0, 0, 0, 0, 0))
    v = torch.fft(v, signal_ndim=1)
    print ('v1: ' + str(v.dtype))
    # return imaginary value only
    v[..., 0] = v[..., 1]
    v[..., 1] = 0
    print ('v2: ' + str(v.dtype))
    return v[..., 1 : m + 1, :]

In [2]:
"""Example code"""
im = torch.zeros((10, 2, 81, 121), device=device)

bsz, c, szx, szy = im.shape

im[:, :, ::5, :], im[:, :, :, ::5] = 1, 1
im[:, :, ::5, ::5] = 1

nframes = bsz
j_lb = 0.3
j_ub = 3.0

f11 = torch.linspace(1.0, 0.3, nframes, device=device)
f12 = torch.linspace(1.0, 3.0, nframes, device=device)
f21 = torch.linspace(0.0, 1.0, nframes, device=device)
f22 = torch.linspace(0.0, -1.0, nframes, device=device)

cz = int(0.1 * min(szx, szy))

f = torch.ones((bsz, 2, szx, szy), device=device, requires_grad=True)
f[:, 1, :, :] = 0

for i in range(nframes):
    f[
        i, 0, szx // 3 - 1 - cz : szx // 3 + cz, szy // 3 - 1 - cz : szy // 3 + cz
    ] = f11[i]
    f[
        i,
        0,
        2 * szx // 3 - 1 - cz : 2 * szx // 3 + cz,
        2 * szy // 3 - 1 - cz : 2 * szy // 3 + cz,
    ] = f12[i]
    f[
        i, 1, szx // 3 - 1 - cz : szx // 3 + cz, szy // 3 - 1 - cz : szy // 3 + cz
    ] = f21[i]
    f[
        i,
        1,
        2 * szx // 3 - 1 - cz : 2 * szx // 3 + cz,
        2 * szy // 3 - 1 - cz : 2 * szy // 3 + cz,
    ] = f22[i]

ff = torch.zeros((f.shape + (2,)), device=device)
ff[..., 0] = f
temp1 = fast_sine_transform_y1(ff)
print ('temp1: ' + str(temp1.shape))
print ('temp1: ' + str(temp1.dtype)) 

import torch.fft

def fast_sine_transform_y2(v):
    """Perform fast sine transform in y direction"""
    n, m = v.shape[-3], v.shape[-2]
    v = nnf.pad(v, (0, 0, 1, m + 1, 0, 0, 0, 0, 0, 0))
    # v = torch.fft(v, signal_ndim=1) # original yields torch.float32
    # v = torch.fft.fft(v, dim=1) # yields torch.complex64, is still the correct size 
    # v = torch.fft.fft(v, dim=0) # yields torch.complex64, is still the correct size 
    v = torch.fft.fft(v, dim=(-1))
    print ('v1: ' + str(v.dtype))
    # return imaginary value only
    # v[..., 0] = v[..., 1]
    # v[..., 1] = 0
    print ('v2: ' + str(v.dtype))
    return v[..., 1 : m + 1, :]

### Get torch.fft.fft to match the old version ### 
temp2 = fast_sine_transform_y2(ff)
print ('temp2: ' + str(temp2.shape))
print (temp2.dtype)

print ('temp1: ' + str(np.squeeze(temp1[0,0,:,:,0])))
print ('temp2: ' + str(np.squeeze(temp2[0,0,:,:,0])))

# diff_temp = np.sum(temp1-temp2)
# print ('diff_temp: ' + str(diff_temp))

v1: torch.float32
v2: torch.float32
temp1: torch.Size([10, 2, 81, 121, 2])
temp1: torch.float32
v1: torch.complex64
v2: torch.complex64
temp2: torch.Size([10, 2, 81, 121, 2])
torch.complex64
temp1: tensor([[-7.7663e+01,  0.0000e+00, -2.5876e+01,  ..., -3.8646e-02,
          0.0000e+00, -1.2876e-02],
        [-7.7663e+01,  0.0000e+00, -2.5876e+01,  ..., -3.8646e-02,
          0.0000e+00, -1.2876e-02],
        [-7.7663e+01,  0.0000e+00, -2.5876e+01,  ..., -3.8646e-02,
          0.0000e+00, -1.2876e-02],
        ...,
        [-7.7663e+01,  0.0000e+00, -2.5876e+01,  ..., -3.8646e-02,
          0.0000e+00, -1.2876e-02],
        [-7.7663e+01,  0.0000e+00, -2.5876e+01,  ..., -3.8646e-02,
          0.0000e+00, -1.2876e-02],
        [-7.7663e+01,  0.0000e+00, -2.5876e+01,  ..., -3.8646e-02,
          0.0000e+00, -1.2876e-02]], grad_fn=<SqueezeBackward0>)
temp2: tensor([[1.+0.j, 1.+0.j, 1.+0.j,  ..., 1.+0.j, 1.+0.j, 1.+0.j],
        [1.+0.j, 1.+0.j, 1.+0.j,  ..., 1.+0.j, 1.+0.j, 1.+0.j],
       



In [78]:
  
    
# v = div_curl_solver_2d(f, inv=False)
# print ('v: ' + str(v.dtype))

############# """Perform div curl solver""" ######

weights = torch.tensor(
    [[[[-1, -4, -1], [0, 0, 0], [1, 4, 1]]]], dtype=torch.float, device=device
)

weights = weights.repeat(2, 1, 1, 1)

dfdx = nnf.conv2d(f, weights, padding=1, groups=2)
dfdy = nnf.conv2d(f, weights.permute(0, 1, 3, 2), padding=1, groups=2)

F = torch.empty_like(f)

F[:, 0, :, :] = (dfdx[:, 0, :, :] + dfdy[:, 1, :, :]) / 12
F[:, 1, :, :] = (-dfdx[:, 1, :, :] + dfdy[:, 0, :, :]) / 12
    
print ('F: ' + str(F.dtype)) # torch.float32


#############"""Poisson solver""" #########
n, m = F.shape[-2], F.shape[-1]
pi = 3.141592653589793
XI, YI = torch.meshgrid(
    torch.arange(1, n + 1, device=device), torch.arange(1, m + 1, device=device)
)
LL = torch.zeros((XI.shape + (2,)), device=device)
LL[:, :, 0] = 1.0 / (
    4.0 - 2.0 * torch.cos(XI * pi / (n + 1)) - 2.0 * torch.cos(YI * pi / (m + 1))
)

FF = torch.zeros((F.shape + (2,)), device=device)
FF[..., 0] = F
    
print ('FF: ' + str(FF.dtype)) # torch.float32

LL = LL.repeat(F.shape[0], F.shape[1], 1, 1, 1)

F: torch.float32
FF: torch.float32


In [80]:

temp = fast_sine_transform_y(FF)
print ('temp: ' + str(temp.shape))


# X = (
#     4.0
#     / ((n + 1.0) * (m + 1.0))
#     * LL
#     * fast_sine_transform_y(fast_sine_transform_x(FF))
# )

# v = -1.0 * fast_sine_transform_y(fast_sine_transform_x(X))
# print ('v: ' + str(v.dtype)) # pretty sure this should be float32. 




# F = poisson_solver_2d_fft(v)
# print ('F: ' + str(F.dtype))

# pos = gridgen(f, j_lb, j_ub, inv=False)
# pos_inv = gridgen(f, j_lb, j_ub, inv=True)

# im[:, 0:1, :, :] = mygriddata(pos_inv, im[:, 0:1, :, :])

# np_pos = pos.detach().cpu().numpy()
# np_imw = im.detach().cpu().numpy()

# for i in range(nframes):
#     display_im_grid(
#         (szx - 1) * (np_pos[i, 0] + 1) / 2,
#         (szy - 1) * (np_pos[i, 1] + 1) / 2,
#         np_imw[i, 0],
#         5,
#     )

TypeError: 'module' object is not callable

In [45]:
### Copy parts of the div_curl_solver 3d ###

f = torch.zeros((10,4,5,7,9))
f[:,0,:,:,:] = 1 # f1 =1, f2,f3,f4=0 

dx_kernel = torch.tensor([ [[1,6,1], [1,6,1], [1,6,1]], \
                           [[0,0,0], [0,0,0], [0,0,0]], \
                           [[-1,-6,-1], [-1,-6,-1], [-1,-6,-1]] ], dtype=torch.float, device=device)

dy_kernel = torch.tensor([ [[1,6,1], [0,0,0], [-1,-6,-1]], \
                           [[1,6,1], [0,0,0], [-1,-6,-1]], \
                           [[1,6,1], [0,0,0], [-1,-6,-1]] ], dtype=torch.float, device=device)

dz_kernel = torch.tensor([ [[1,0,-1], [6,0,-6], [1,0,-1]], \
                           [[1,0,-1], [6,0,-6], [1,0,-1]], \
                           [[1,0,-1], [6,0,-6], [1,0,-1]] ], dtype=torch.float, device=device)

dx_kernel = torch.unsqueeze(torch.unsqueeze(dx_kernel,0),0)
dy_kernel = torch.unsqueeze(torch.unsqueeze(dy_kernel,0),0)
dz_kernel = torch.unsqueeze(torch.unsqueeze(dz_kernel,0),0)

dx_kernel = dx_kernel.repeat(4, 1, 1, 1, 1)
dy_kernel = dy_kernel.repeat(4, 1, 1, 1, 1)
dz_kernel = dz_kernel.repeat(4, 1, 1, 1, 1)

dfdx = nnf.conv3d(f, dx_kernel, padding=1, groups=4) # check groups
dfdy = nnf.conv3d(f, dy_kernel, padding=1, groups=4) # check groups
dfdz = nnf.conv3d(f, dz_kernel, padding=1, groups=4) # check groups

# This worked before because 1 div, 1 curl input, and v1, v2 output
# But now, 1 div, 3 curl input, and v1, v2, v3 output
# F = torch.empty_like(f)
F = torch.zeros((f.shape[0], 3, f.shape[2], f.shape[3], f.shape[4]))

# From my pycardiac 3D functions
# F1 = poisson_solver_3d_fft((df1dx - df4dy + df3dz) / 12.0)
# F2 = poisson_solver_3d_fft((df4dx + df1dy - df2dz) / 12.0)
# F3 = poisson_solver_3d_fft((-df3dx + df2dy + df1dz) / 12.0)

F[:, 0, :, :, :] = ( dfdx[:, 0, :, :, :] - dfdy[:, 3, :, :, :] + dfdz[:, 2, :, :, :]) / 12.0
F[:, 1, :, :, :] = ( dfdx[:, 3, :, :, :] + dfdy[:, 0, :, :, :] - dfdz[:, 1, :, :, :]) / 12.0
F[:, 2, :, :, :] = (-dfdx[:, 2, :, :, :] + dfdy[:, 1, :, :, :] + dfdz[:, 0, :, :, :]) / 12.0

print ('F type: ' + str(F.dtype)) # torch.float32

### copy parts of poisson_solver_3d_fft ###

o, n, m = F.shape[-3], F.shape[-2], F.shape[-1]
pi = 3.141592653589793
XI, YI, ZI = torch.meshgrid(
                    torch.arange(1, o + 1, device=device), \
                    torch.arange(1, n + 1, device=device), \
                    torch.arange(1, m + 1, device=device)
)

LL = torch.zeros((XI.shape + (2,)), device=device)

LL[:, :, :, 0] = 1.0 / (
        6.0 - 2.0 * torch.cos(XI * pi / (o + 1)) - 2.0 * torch.cos(YI * pi / (n + 1)) - 2.0 * torch.cos(ZI * pi / (m + 1))
)

FF = torch.zeros((F.shape + (2,)), device=device)
FF[..., 0] = F

LL = LL.repeat(F.shape[0], F.shape[1], 1, 1, 1, 1)

print ('LL type: ' + str(LL.dtype)) # torch.float32
print ('FF type: ' + str(FF.dtype)) # torch.float32

X = (
    4.0
    / ((n + 1.0) * (m + 1.0) * (o + 1.0))
    * LL
    * fast_sine_transform_z_3d(fast_sine_transform_y_3d(fast_sine_transform_x_3d(FF)))
)
print ('X: ' + str(X.dtype)) # torch.complex64

F_x = fast_sine_transform_x_3d(FF)
print ('F_x: ' + str(F_x.dtype)) # torch.complex64

F_y = fast_sine_transform_y_3d(FF)
print ('F_y: ' + str(F_y.dtype)) # torch.complex64

F_z = fast_sine_transform_z_3d(FF)
print ('F_z: ' + str(F_z.dtype)) # torch.complex64
 
v = -1.0 * fast_sine_transform_z_3d(fast_sine_transform_y_3d(fast_sine_transform_x_3d(X)))
print ('v dtype: ' + str(v.dtype)) # torch.complex64

### Since this is complex and it shouldn't be -- copy parts of poisson_solver_3d_fft into here ####
# v = poisson_solver_3d_fft(F)
# print ('v: type: ' + str(v.dtype))

F type: torch.float32
LL type: torch.float32
FF type: torch.float32
X: torch.complex64
F_x: torch.complex64
F_y: torch.complex64
F_z: torch.complex64
v dtype: torch.complex64
