In [None]:
from PIL import Image
import nibabel
import os, pathlib, numpy, json
import numpy as np
import h5py
from fenics import *
from fenics_adjoint import *

In [None]:
datadir = pathlib.Path("/home/bastian/Oscar-Image-Registration-via-Transport-Equation/testdata_3d")

In [None]:
datadir = pathlib.Path("/home/bastian/Oscar-Image-Registration-via-Transport-Equation/testdata_2d")

In [None]:
hyperparameters = {}
hyperparameters["state_functiondegree"] = 1
hyperparameters["state_functionspace"] = "DG"

In [None]:
hyperparameters["normalize"] = True

In [None]:
hyperparameters["image"] = str(datadir / "input.mgz")

In [None]:
def read_image(hyperparameters, name, mesh=None, printout=True, normalize=True, degree=0):
    """Load an image file and return it as a finite element function.

    The file path is ``hyperparameters[name]``. ``.mgz`` files are read with
    nibabel, ``.png`` files with PIL (converted to greyscale). Two-dimensional
    data gets a trailing axis of length one. Every degree-0 dof receives the
    value of the voxel that contains the dof coordinate.

    Args:
        hyperparameters: dict providing ``hyperparameters[name]`` (path as
            ``str``) and ``hyperparameters["state_functionspace"]`` (element
            family). The key ``name + ".shape"`` is written with the image
            shape as a list.
        name: key of the image path in ``hyperparameters``.
        mesh: mesh to use. If ``None``, a ``UnitSquareMesh`` (single slice) or
            ``UnitCubeMesh`` with as many subdivisions per direction as the
            image has voxels is created. In both cases the mesh coordinates
            are multiplied in place by the image dimensions, so a supplied
            mesh is modified and is expected in unit coordinates.
        printout: print progress information.
        normalize: multiply by ``1 / max`` and set negative values to zero.
        degree: if >= 1, the degree-0 image is projected onto the space of
            that degree; otherwise the degree-0 function is returned.

    Returns:
        Tuple ``(mesh, u_data, 1)``. Callers in this notebook unpack the last
        entry as the channel count.

    Raises:
        ValueError: the file extension is neither ``.mgz`` nor ``.png``, or a
            dof coordinate of the mesh lies outside the voxel grid.
    """
    path = hyperparameters[name]

    if path.endswith(".mgz"):
        data = nibabel.load(path).get_fdata()
    elif path.endswith(".png"):
        data = np.array(Image.open(path).convert("L"))
    else:
        # Without this branch `data` stayed unbound and failed further down.
        raise ValueError("Unsupported image format (expected .mgz or .png): " + str(path))

    if data.ndim == 2:
        # Treat a 2D image as a volume with a single slice.
        data = np.expand_dims(data, -1)

    if normalize:
        peak = data.max()

        if printout:
            print("Normalizing image")
            print("Img.vector()[:].max()", peak)

        # Out-of-place on purpose: PNG data is uint8 and cannot be scaled in
        # place by a float. An all-zero image is left as is instead of being
        # turned into NaN by a division by zero.
        if peak != 0:
            data = data * (1 / peak)

        if printout:
            print("Applying ReLU() to image")

        data = np.where(data < 0, 0, data)

    if printout:
        print("dimension of image:", data.shape, "(", data.size, "voxels)")

    nx, ny, nz = data.shape[:3]
    hyperparameters[name + ".shape"] = list(data.shape)

    if mesh is None:
        if nz == 1:
            mesh = UnitSquareMesh(MPI.comm_world, nx, ny)
        else:
            mesh = UnitCubeMesh(MPI.comm_world, nx, ny, nz)

    # Stretch the mesh to voxel coordinates. A 2D mesh has no third
    # coordinate column, so the z scaling only applies to 3D meshes.
    coords = mesh.coordinates()
    coords[:, 0] *= nx
    coords[:, 1] *= ny
    if coords.shape[1] > 2:
        coords[:, 2] *= nz

    # The degree-0 space has to exist before its dof coordinates can be read.
    space = FunctionSpace(mesh, hyperparameters["state_functionspace"], 0)
    u_data = Function(space)

    xyz = space.tabulate_dof_coordinates().transpose()
    if xyz.shape[0] == 2:
        # Pad 2D coordinates with z = 0 so they index the single slice.
        xyz = np.vstack([xyz, np.zeros((1, xyz.shape[1]))])

    # A dof at coordinate c lies in voxel floor(c). Rounding to the nearest
    # integer would map cell midpoints such as n - 1/3 to the index n, which
    # is out of range.
    voxel = np.floor(xyz).astype("int")
    upper = np.array([[nx], [ny], [nz]])
    if (voxel < 0).any() or (voxel >= upper).any():
        raise ValueError("Mesh dof coordinates lie outside the voxel grid of the image")

    i, j, k = voxel
    u_data.vector()[:] = data[i, j, k]

    if degree >= 1:
        space = FunctionSpace(mesh, hyperparameters["state_functionspace"], degree)
        u_data = project(u_data, space)

    assert np.sum(np.isnan(u_data.vector()[:])) == 0
    return mesh, u_data, 1

In [None]:
def Pic2FEM(FName, mesh=None, color=False, normalize=True):
    """Convert an image file into a finite element function on a pixel mesh.

    Args:
        FName: path of the image. Names ending in ``"mgz"`` are read with
            nibabel (slice 0 of the last axis, rescaled to 8-bit grey values);
            anything else is opened with PIL.
        mesh: mesh to use. If ``None``, a ``RectangleMesh`` from ``(0, 0)`` to
            ``(width, height)`` with one subdivision per pixel and ``"right"``
            diagonals is created. The pixel lookup assumes the global cell
            numbering of such a mesh.
        color: if true, the three RGB channels are stored in a DG0 vector
            space; otherwise a single greyscale channel is stored in a scalar
            DG0 space.
        normalize: if true, grey values are divided by 255.

    Returns:
        Tuple ``(mesh, ImgFunction, number_of_channels)``. ``ImgFunction`` is
        the projection of ``sqrt(inner(f, f))`` of the DG0 image ``f`` onto
        the space given by the notebook-level ``hyperparameters`` dict
        (``"state_functionspace"`` / ``"state_functiondegree"``).
    """
    if FName.endswith("mgz"):
        myarray = nibabel.load(FName).get_fdata()[:, :, 0]
        myarray = myarray / np.max(myarray)
        img = Image.fromarray(np.uint8(myarray * 255))
        img = img.convert("RGB")
    else:
        img = Image.open(FName)

    # Which of the color channels to process. Image.convert returns a new
    # image, so the result has to be kept for the conversion to take effect.
    if color:
        Channels = (0, 1, 2)
        img = img.convert("RGB")
    else:
        Channels = (0,)
        img = img.convert("L")

    width, height = img.size

    if mesh is None:
        mesh = RectangleMesh(MPI.comm_world, Point(0.0, 0.0), Point(width, height), int(width), int(height), "right")

    # Key mapping between global cell index (input) and (row, column) pixel
    # coordinate (output). Each pixel square width*j + i holds two cells, and
    # rows are flipped because the image origin is the top-left corner.
    # Needs to be changed if the diagonal is not "right".
    square = np.arange(width * height)
    PixID = np.empty([2 * width * height, 2], dtype="uint")
    PixID[:, 0] = np.repeat(height - 1 - square // width, 2)
    PixID[:, 1] = np.repeat(square % width, 2)

    if color:
        ImgSpace = VectorFunctionSpace(mesh, "DG", 0, len(Channels))
    else:
        ImgSpace = FunctionSpace(mesh, "DG", 0)

    ImgFunction = Function(ImgSpace)
    ImgFunction.rename("image", "")
    Fvalues = np.zeros(ImgFunction.vector().local_size())

    # Map 0..255 grey steps to [0, 1] when normalizing.
    normval = 255.0 if normalize else 1

    for chan in Channels:
        ImgDofs = ImgSpace.sub(chan).dofmap() if color else ImgSpace.dofmap()
        cData = np.array(img.getchannel(chan))

        for c in cells(mesh):
            GID = c.global_index()
            # Local dof of this cell in the DG0 function.
            FID = ImgDofs.cell_dofs(c.index())[0]
            Fvalues[FID] = cData[PixID[GID, 0], PixID[GID, 1]] / normval

    # Set function values
    ImgFunction.vector().set_local(Fvalues)
    ImgFunction.vector().apply("")

    Space = FunctionSpace(mesh, hyperparameters["state_functionspace"], hyperparameters["state_functiondegree"])
    ImgFunction = project(sqrt(inner(ImgFunction, ImgFunction)), Space)

    return mesh, ImgFunction, len(Channels)

In [None]:
# Load the raw voxel data of the configured image as a reference for the
# finite element representations.
FName2 = str(hyperparameters["image"])

rawF = nibabel.load(FName2).get_fdata()

# Scale by the maximum in place, then replace negative values by zero.
if hyperparameters["normalize"]:
    rawF /= np.max(rawF)

    rawF = np.where(rawF < 0, 0, rawF)

In [None]:
# NOTE(review): in the visible cells `mesh2` is first assigned further down,
# so this cell presumably relies on leftover kernel state - confirm or move it.
mesh2.coordinates().shape

In [None]:
# Build the image with read_image twice: once with the default degree and
# once with the configured state degree. Each call gets mesh=None.
mymesh, my_img, channels = read_image(hyperparameters, name="image", mesh=None, printout=True, 
                                      normalize=hyperparameters["normalize"])
mymesh1, my_img1, channels1 = read_image(hyperparameters, name="image", mesh=None, degree=hyperparameters["state_functiondegree"],
                                         printout=True, normalize=hyperparameters["normalize"])

In [None]:
# Same image through Pic2FEM (single channel) for comparison.
fenmesh, fen_img1, fenchannels = Pic2FEM(FName=hyperparameters["image"], 
                                         color=False, normalize=hyperparameters["normalize"],
                                         mesh=None)#hyperparameters["state_functiondegree"])

In [None]:
# Project the Pic2FEM result onto degree 0 of the configured element family.
Space = FunctionSpace(fenmesh, hyperparameters["state_functionspace"], 
                      0
                      #hyperparameters["state_functiondegree"]
                     )
fen_img = project(fen_img1, Space)

In [None]:
# Ratio of vector sizes: image with the configured degree vs. default degree.
my_img1.vector()[:].size / my_img.vector()[:].size

In [None]:
print(np.mean(rawF))

print(assemble(fen_img1*dx(domain=fenmesh)) / assemble(1*dx(domain=fenmesh)))
print(assemble(fen_img*dx(domain=fenmesh)) / assemble(1*dx(domain=fenmesh)))
print(assemble(my_img*dx(domain=mymesh)) / assemble(1*dx(domain=mymesh)))
print(assemble(my_img1*dx(domain=mymesh1)) / assemble(1*dx(domain=mymesh1)))

In [None]:
rawF.size

In [None]:
rawF.shape

In [None]:
nx, ny = rawF.shape[:2]
V = my_img1.function_space()  # rebound a few lines below

# Unit square with `dn` fewer subdivisions per direction than pixels.
dn = 1
mesh2 = UnitSquareMesh(MPI.comm_world, nx - dn, ny - dn)
V = FunctionSpace(mesh2, "DG", 1)

# Shape of the de-duplicated DG1 dof coordinates, along each axis in turn.
for unique_axis in (0, 1):
    print(np.unique(V.tabulate_dof_coordinates(), axis=unique_axis).shape)

In [None]:
# One figure with colorbar per representation.
# NOTE(review): `plt` is not imported in the visible import cell - confirm
# where it comes from.
for img_func in (fen_img1, fen_img, my_img):
    ax = plot(img_func)
    plt.colorbar(ax)
    plt.show()

# New figure/axes created before the last plot.
fig, ax1 = plt.subplots(1)

"""
for x in range(nx):
    for y in range(ny):
        ax1.set_title("image[x,y]")
        ax1.scatter(y,x, c=my_img[x,y], 
                    cmap="viridis", vmin=0, vmax=1,
                    s=240, marker="s")
for ax in [ax1]:
    ax.invert_yaxis()
    ax.set_aspect(1)
plt.colorbar()
plt.show()
"""

ax = plot(my_img1)
plt.colorbar(ax)
plt.show()

In [None]:
# Integrated difference between the two representations of each reader.
# NOTE(review): my_img and my_img1 come from separate read_image calls with
# mesh=None, i.e. presumably live on different mesh objects - confirm that
# combining them in one form is intended.
print(assemble((my_img1-my_img)*dx(domain=mymesh1)))
print(assemble((fen_img-fen_img1)*dx(domain=fenmesh)))

In [None]:
# Cell and vertex counts of the Pic2FEM mesh and the read_image mesh.
print(fenmesh.num_cells(), fenmesh.num_vertices())
print(mymesh.num_cells(), mymesh.num_vertices())

In [None]:
# Function spaces of the individual representations.
print(fen_img1.function_space())
print(fen_img.function_space())
print(my_img.function_space())

In [None]:
# Shapes of the dof coordinate arrays of those spaces.
print(fen_img1.function_space().tabulate_dof_coordinates().shape)
print(fen_img.function_space().tabulate_dof_coordinates().shape)
print(my_img.function_space().tabulate_dof_coordinates().shape)

In [None]:
# Sizes of the coefficient vectors.
print(fen_img.vector()[:].shape)
print(my_img.vector()[:].shape)

In [None]:
# Reload the unmodified voxel data (rawF above may have been normalized).
image = nibabel.load(FName2).get_fdata()

In [None]:
# Unit square with one subdivision per pixel in each direction, plus empty
# DG0 and DG1 functions on it for counting dofs.
nx, ny = image.shape[0], image.shape[1]

mesh2 = UnitSquareMesh(MPI.comm_world, nx, ny)

VDG0=FunctionSpace(mesh2, "DG", 0)
VDG1=FunctionSpace(mesh2, "DG", 1)
u0=Function(VDG0)
u1=Function(VDG1)

In [None]:
# Number of voxels, for comparison with the dof counts below.
image.size

In [None]:
u0.vector()[:].shape

In [None]:
u1.vector()[:].shape

In [None]:
# (n + 1) * (m + 1) for the first two image dimensions.
(rawF.shape[0]+1)*(rawF.shape[1]+1)

In [None]:
VDG0.tabulate_dof_coordinates().shape

In [None]:
VDG1.tabulate_dof_coordinates().shape

In [None]:
# De-duplicated dof coordinates (axis=0: unique rows, axis=1: unique columns).
np.unique(VDG0.tabulate_dof_coordinates(), axis=0).shape

In [None]:
np.unique(VDG0.tabulate_dof_coordinates(), axis=1).shape

In [None]:
np.unique(VDG1.tabulate_dof_coordinates(), axis=0).shape

In [None]:
np.unique(VDG1.tabulate_dof_coordinates(), axis=1).shape

In [None]:
VDG1.tabulate_dof_coordinates().shape

In [None]:
# Dof coordinates of the read_image result.
xy = my_img.function_space().tabulate_dof_coordinates()

In [None]:
xy.shape

In [None]:
# Number of pixels in the first two image dimensions.
image.shape[0]*image.shape[1]

In [None]:
newimg = np.zeros_like(image) + np.nan
for x in range(nx):
    for y in range(ny):
        newimg[x,y] = u0(x/(nx-1),y/(ny-1))
        # newimg[x,y] = x/(nx-1)*y/(ny-1)
        
assert np.sum(np.isnan(newimg)) == 0

In [None]:
print(np.max(newimg))

In [None]:
plt.imshow(newimg)
plt.colorbar()

In [None]:
nx, ny = image.shape[0], image.shape[1]

# dn = 1: one subdivision fewer per direction than pixels.
dn = 1

mesh2 = UnitSquareMesh(MPI.comm_world, nx - dn, ny - dn)

VDG0=FunctionSpace(mesh2, "DG", 0)
VDG1=FunctionSpace(mesh2, "DG", 1)
u0=Function(VDG0)  # rebound by the interpolation below

# Test field x*y: interpolate into DG0, then project onto DG1, printing the
# integral after each step.
u0 = interpolate(Expression("x[0]*x[1]", degree=3), VDG0)
print(assemble(u0*dx))
u0 = project(u0, VDG1)
print(assemble(u0*dx))

u1 = Function(VDG1)

In [None]:
# Mesh vertices (default color) and the dof coordinates of u0 (red).
plt.scatter(mesh2.coordinates()[:, 0], mesh2.coordinates()[:, 1])

xy = u0.function_space().tabulate_dof_coordinates()
plt.scatter(xy[:, 0], xy[:, 1], c="r", s=1)

# For the spaces of u0 and u1: number of distinct dof coordinates with x == 0,
# with y == 0, and the shape of the de-duplicated coordinate array.
for func in (u0, u1):
    xy = func.function_space().tabulate_dof_coordinates()
    unique_xy = np.unique(xy, axis=0)  # computed once instead of three times
    print(np.sum(unique_xy[:, 0] == 0))
    print(np.sum(unique_xy[:, 1] == 0))
    print(unique_xy.shape)

# Pixel grid (x/nx, y/ny) in black, drawn with a single scatter call instead
# of one call (and one artist) per pixel.
grid_x, grid_y = np.meshgrid(np.arange(nx) / nx, np.arange(ny) / ny, indexing="ij")
plt.scatter(grid_x.ravel(), grid_y.ravel(), s=1, c="k");

#for cell in cells(mesh2):
#    plt.scatter(x=cell.midpoint()[:][0],y=cell.midpoint()[:][1], c="navy")

In [None]:
do_norm = True

In [None]:
if do_norm:
    image2 = np.copy(image)
    image2 /= np.max(image2)

    image2 = np.where(image2 < 0, 0, image2)
    
else:
    image2 = np.copy(image)

In [None]:
# Rebuild all representations with the `do_norm` setting. Only the functions
# are kept; the returned meshes and channel counts are discarded.
# NOTE(review): later cells still integrate over mymesh / mymesh1 / fenmesh
# from the earlier calls - confirm that this is intended.
_, my_img, _ = read_image(hyperparameters, name="image", 
                                      mesh=None, printout=True, normalize=do_norm)
_, my_img1, _ = read_image(hyperparameters, name="image", degree=1,
                                      mesh=None, printout=True, normalize=do_norm)

# Pic2FEM with all three color channels ...
_, fen_img0, _ = Pic2FEM(FName=hyperparameters["image"], 
                                         color=True, normalize=do_norm,
                                         mesh=None)#hyperparameters["state_functiondegree"])

# ... and with a single channel.
_, fen_img1, _ = Pic2FEM(FName=hyperparameters["image"], 
                                         color=False, normalize=do_norm,
                                         mesh=None)#hyperparameters["state_functiondegree"])

In [None]:
print(np.mean(image2))

print(assemble(fen_img1*dx(domain=fenmesh)) / assemble(1*dx(domain=fenmesh)))
print(assemble(fen_img0*dx(domain=fenmesh)) / assemble(1*dx(domain=fenmesh)))
print(assemble(my_img*dx(domain=mymesh)) / assemble(1*dx(domain=mymesh)))
print(assemble(my_img1*dx(domain=mymesh1)) / assemble(1*dx(domain=mymesh1)))

In [None]:
ax = plot(my_img)
plt.colorbar(ax)
plt.show()
ax = plot(my_img1)
plt.colorbar(ax)
plt.show()
ax = plot(fen_img0)
plt.colorbar(ax)
plt.show()
ax = plot(fen_img1)
plt.colorbar(ax)
plt.show()

In [None]:
# Five separate figures, one per representation.
(fig, ax1), (fig2, ax2), (fig3, ax3), (fig4, ax4), (fig5, ax5) = [
    plt.subplots(1) for _panel in range(5)
]

# (axes, title, value at pixel (x, y)) for every panel. The read_image
# functions are sampled at (x/nx, y/ny), the Pic2FEM functions at (y, x).
panels = (
    (ax1, "image[x,y]", lambda x, y: image[x,y]),
    (ax2, "my_img(x/nx,y/ny)", lambda x, y: my_img(x/nx,y/ny)),
    (ax3, "my_img1(x/nx,y/ny)", lambda x, y: my_img1(x/nx,y/ny)),
    (ax4, "fen_img0(y, x)", lambda x, y: fen_img0(y,x)),
    (ax5, "fen_img1(y, x)", lambda x, y: fen_img1(y,x)),
)

# One square marker per pixel and panel, colored on a fixed [0, 1] scale.
for x in range(nx):
    for y in range(ny):
        for panel_ax, panel_title, value_at in panels:
            panel_ax.set_title(panel_title)
            panel_ax.scatter(y, x, c=value_at(x, y),
                             cmap="viridis", vmin=0, vmax=1,
                             s=240, marker="s")

# Image orientation: first index downwards, square pixels.
for ax in (ax1, ax2, ax3, ax4, ax5):
    ax.invert_yaxis()
    ax.set_aspect(1)
plt.show()

# Reference rendering of the (optionally normalized) array itself.
plt.title("imshow image")
plt.imshow(image2, vmin=0, vmax=1)
plt.show()