In [None]:
from deltamic import Gaussian_psf, compute_spatial_frequency_grid,normalize_tensor,render_image_from_ftmesh,Fourier3dMesh,compute_box_size,generate_gaussian_psf
import trimesh
import numpy as np
import torch
import skimage.io as io 
from skimage import img_as_ubyte
from time import time

"""
Creation of
initial 
configurations: 
"""

# Fail loudly when no GPU is present. An explicit raise is used instead of
# `assert`, which is stripped when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is required: this notebook renders on 'cuda:0'.")

device = 'cuda:0'
# Shape of the rendered grid: 200 samples along each of the three axes.
box_shape = np.array([200]*3)

filename = "../data/spot.obj"

# Ground-truth mesh to render.
Mesh_gt = trimesh.load(filename)
faces = np.array(Mesh_gt.faces)
verts = np.array(Mesh_gt.vertices)

# Box enclosing the vertices. NOTE(review): `offset` presumably pads the
# vertices' extent; see deltamic.compute_box_size.
box_size = compute_box_size(verts, offset=0.2)
Verts = torch.tensor(verts, dtype=torch.float, device=device, requires_grad=True)
Faces = torch.tensor(faces, dtype=torch.long, device=device)
# One coefficient per face, all equal to 1.
Faces_coeff = torch.ones(len(Faces), dtype=torch.float, device=device)
box_size = torch.tensor(box_size)

print("Vertices shape: ", verts.shape, "Faces shape: ", faces.shape)
print("Min/Max x/y/z position of vertices: ", verts.min(axis=0), verts.max(axis=0))
print("box_size: ", box_size, "box_shape", box_shape)

# NOTE(review): defined but never passed to Fourier3dMesh below. The
# optimization cell passes narrowband_thresh=1e-2 explicitly; confirm whether
# this value was meant to be forwarded here.
narrowband_thresh = torch.tensor(0, dtype=torch.float, device=device)
meshFT = Fourier3dMesh(box_size, box_shape, device=device, dtype=torch.float32)

# Isotropic Gaussian PSF (sigma_matrix = 1e-3 * I, presumably its covariance),
# evaluated on the frequency grid exposed by meshFT.
sigma_matrix = (1e-3*torch.eye(3, device=device))
OTF = generate_gaussian_psf(sigma_matrix, meshFT.xi0, meshFT.xi1, meshFT.xi2).to(device)

# ---- Image creation ----

t1 = time()
ftmesh = meshFT(Verts, Faces, Faces_coeff)

image_ft = normalize_tensor(render_image_from_ftmesh(ftmesh, OTF, box_shape))
# CUDA launches are asynchronous: wait for the GPU to finish before reading
# the clock. Otherwise the reported time can miss most of the rendering work.
torch.cuda.synchronize(device)
t2 = time()

io.imsave("Spot.tif", img_as_ubyte(image_ft.detach().cpu().numpy()))
print("forward pass computed successfully in", t2-t1, "seconds", "Image saved as Spot.tif")


In [None]:
# Copy the reference "spot" mesh from the author's benchmark folder into ./data
# (presumably executed through IPython's `cp` shell alias).
# NOTE(review): the source path is specific to the author's machine, and the
# rendering cell reads "../data/spot.obj" rather than "data/spot.obj". Confirm
# the intended destination; this copy must also happen before that cell runs.
cp ../../../Projects/Differentiable_rendering_3D/Meshes_benchmark/spot.obj data

In [None]:
from deltamic import normalize_tensor,render_image_from_ftmesh,Fourier3dMesh,generate_gaussian_psf, center_verts_in_box
from tqdm import tqdm
import trimesh
import numpy as np
import os
import torch
import skimage.io as io 
from largesteps.optimize import AdamUniform
from largesteps.geometry import compute_matrix
from largesteps.solvers import CholeskySolver


folder_result = "Results/"


def create_dir(folder_name):
    """Create directory *folder_name*, and any missing parents, if it does not exist.

    An already-existing directory is not an error. Any other failure (for
    example a permission problem, or a regular file occupying the path) is
    raised here. Swallowing it would only defer the failure to the first
    attempt to write into the directory, with a less obvious error.
    """
    os.makedirs(folder_name, exist_ok=True)


create_dir(folder_result)
create_dir(folder_result + "Images")
create_dir(folder_result + "Meshes")

"""
Parameters:
"""

# Gaussian_psf is only imported by the forward-rendering cell. Import it here
# so this cell also runs on a fresh kernel.
from deltamic import Gaussian_psf

device = 'cuda:0'
# Target volume produced by the forward-rendering cell.
name_micim = "Spot.tif"

# ---- Creation of initial configurations ----

# Target image, normalized with the same helper applied to rendered images.
micim = normalize_tensor(torch.tensor(io.imread(name_micim).astype(np.float32), dtype=torch.float)).to(device)
# Hard-coded (min, max) per axis. NOTE(review): presumably the box printed by
# the forward-rendering cell for spot.obj; confirm if the target changes.
box_size = np.array([[-0.9058,  1.2208],
                     [-0.9058,  1.2208],
                     [-0.9058,  1.2208]])
# Initial guess: subdivided sphere, scaled by 2 and centered in the box.
mesh = trimesh.primitives.Sphere(subdivisions=4)
verts, faces = np.array(mesh.vertices)*2, np.array(mesh.faces)
verts = center_verts_in_box(verts, box_size, offset=.4)
Verts = torch.tensor(verts, dtype=torch.float, device=device)
Faces = torch.tensor(faces, dtype=torch.long, device=device)
Faces_coeff = torch.ones(len(Faces), dtype=torch.float, device=device)
box_size = torch.tensor(box_size)
box_shape = np.array([200]*3)

meshFT = Fourier3dMesh(box_size, box_shape, device=device, dtype=torch.float32, narrowband_thresh=1e-2)

# Gaussian PSF
model = Gaussian_psf(box_shape, box_size[:, 1], sigma=1e3, device=device)

print("box_size: ", box_size, "Box_ft_shape", box_shape)

# ---- Large Steps optimization ----

lr_base = 0.01
lambda_ = 50.0
alpha = 1.0  # exponent of the target-intensity weighting in the loss
Verts.requires_grad = True
optimizer_geometry = AdamUniform([{'params': Verts}], lr=lr_base)
# NOTE: never stepped in the loop below; the PSF is held fixed here.
optimizer_psf = torch.optim.Adam(model.parameters(), lr=lr_base)

losses = []

# Vertex gradients are preconditioned by solving (M @ M) x = grad each step.
M = compute_matrix(Verts, Faces, lambda_)
solver = CholeskySolver(M@M)

# Nothing in the loop updates the PSF parameters, so the OTF is constant.
# Compute it once, without autograd, so no gradients pile up on `model`.
with torch.no_grad():
    PSF = model.forward()
    OTF = torch.abs(torch.fft.fftshift(torch.fft.fftn(PSF)))
meshFT.OTF = OTF

# Loop invariants: loss weighting and the (fixed) face array saved each step.
loss_weight = micim**alpha
faces_np = Faces.detach().cpu().numpy()

for k in (pbar := tqdm(range(10000))):
    # Optimize vertex positions with the already-approximated PSF held fixed.
    optimizer_geometry.zero_grad()

    ftmesh = meshFT(Verts, Faces, Faces_coeff)
    image_ft = normalize_tensor(render_image_from_ftmesh(ftmesh, OTF, box_shape))

    # MSE weighted by the target intensity: bright target voxels count more.
    loss_mse = torch.mean(((image_ft-micim)**2)*loss_weight)
    loss = loss_mse
    loss.backward()

    pbar.set_description("Current_Loss: "+str([loss.item()]))

    with torch.no_grad():
        Verts.grad = solver.solve(Verts.grad)

    # Save the mesh rendered at this iteration (before the step). Vertex and
    # face arrays have different lengths, so saving them as a plain tuple
    # builds a ragged array, which NumPy >= 1.24 rejects with ValueError.
    # Build the 2-element object array explicitly instead; reload with
    # np.load(path, allow_pickle=True).
    mesh_record = np.empty(2, dtype=object)
    mesh_record[0] = Verts.detach().cpu().numpy()
    mesh_record[1] = faces_np
    np.save(folder_result+'Meshes/'+str(k)+".npy", mesh_record)

    optimizer_geometry.step()

    # Unweighted MSE of this iteration's render, recorded in the loss history.
    with torch.no_grad():
        loss_mse = torch.mean((image_ft-micim)**2)

    losses.append(loss_mse.item())
    np.save(folder_result+'Meshes/loss.npy', losses)
  