Skip to content

Commit

Permalink
Merge pull request #43 from PolarizedLightFieldMicroscopy/NN
Browse files Browse the repository at this point in the history
NN
  • Loading branch information
gschlafly committed May 7, 2023
2 parents 76f7d2a + 0aac2f5 commit d0acbf8
Show file tree
Hide file tree
Showing 10 changed files with 361 additions and 218 deletions.
140 changes: 90 additions & 50 deletions src/napari_lf/_widgetLF.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,28 +220,26 @@ def input_model_change_call():
from napari_lf.lfa.neural_nets.LFNeuralNetworkProto import LFNeuralNetworkProto
except:
from lfa.neural_nets.LFNeuralNetworkProto import LFNeuralNetworkProto


# Define input shape, and extract it either from a calib file or the stored checkpoint
LFshape = None
# Load calib file
if self.gui.gui_elms["lfmnet"]["calibration_file"].value == None:
return
calibFile_path = str(os.path.join(str(self.gui.gui_elms["main"]["img_folder"].value), self.gui.gui_elms["lfmnet"]["calibration_file"].value))
path = Path(calibFile_path)
if path.is_file():
import h5py
with h5py.File(calibFile_path, "r") as f:
lf = f['geometry']
LFshape = [lf.attrs['nu'], lf.attrs['nv'], lf.attrs['ns'], lf.attrs['nt']]

# LFshape = [lf.nu, lf.nv, lf.ns, lf.nt]
# VCDNet
pass
else:
calibFile_path = str(os.path.join(str(self.gui.gui_elms["main"]["img_folder"].value), self.gui.gui_elms["lfmnet"]["calibration_file"].value))
path = Path(calibFile_path)
if path.is_file():
import h5py
with h5py.File(calibFile_path, "r") as f:
lf = f['geometry']
LFshape = [lf.attrs['nu'], lf.attrs['nv'], lf.attrs['ns'], lf.attrs['nt']]

if self.gui.gui_elms["lfmnet"]["input_model"].value == None:
return
checkpoint_path = str(os.path.join(str(self.gui.gui_elms["main"]["img_folder"].value), self.gui.gui_elms["lfmnet"]["input_model"].value))
# LFMNet
# checkpoint_path = '../checkpoints/*.ckpt'
# Load network based on checkpoint
#print(checkpoint_path)
#print(LFshape)

# Load Network
net = LFNeuralNetworkProto.load_network_from_file(checkpoint_path, LFshape)

# Set network into evaluation mode (faster ode)
Expand All @@ -266,13 +264,25 @@ def input_model_change_call():
self.gui.populate_cal_img_list()
self.gui.load_plugin_prefs()

if "mode_choice" in self.gui.settings["main"] and self.gui.settings["main"]["mode_choice"] == 'NeuralNet':
self.gui.LFAnalyze_btn_cont.hide()
self.gui.NeuralNet_btn_cont.show()
self.gui.widget_main_bottom_comps0.hide()
self.gui.widget_main_bottom_comps1.hide()
self.gui.widget_main_bottom_comps2.show()
self.gui._cont_btn_processing.hide()
self.gui._cont_btn_processing2.show()
self.gui.NeuralNet_btn.toggle()
else:
self.gui.LFAnalyze_btn.toggle()

#Layout
layout = QVBoxLayout()
self.setLayout(layout)

self.setMinimumWidth(480)
self.layout().addWidget(self.gui.widget_main_top_comps.native)
self.layout().addWidget(self.gui.scroll_bottom)
self.layout().addWidget(self.gui.qtab_widget_top)
self.layout().addWidget(self.gui.widget_main_proc_btn_comps.native)
self.layout().setAlignment(Qt.AlignTop)
self.layout().setContentsMargins(0,0,0,0)
Expand Down Expand Up @@ -320,12 +330,23 @@ def set_lfa_libs(self):
print('LFA could not be loaded from:', self.gui.gui_elms["misc"]["lib_folder"].value)
self.gui.gui_elms["misc"]["lib_ver_label"].value = 'Error!'
print(traceback.format_exc())

def closeEvent(self, event):
self.gui.save_plugin_prefs()
if self.gui.timer is not None:
self.gui.timer.stop()
# print('closeEvent')

def hideEvent(self, event):
self.gui.save_plugin_prefs()
if self.gui.timer is not None:
self.gui.timer.stop()
# print('hideEvent')

def showEvent(self, event):
if self.gui.timer is not None:
self.gui.timer.start(500)
# print('showEvent')

#Event Filter
def eventFilter(self, source, event):
Expand Down Expand Up @@ -666,71 +687,90 @@ def run_lf_proj_vol(self, args):
def run_lf_net(self, args):
try:
args += ['--solver','net']
gpu_id = 0
# torch_device = torch.device("cuda:1")
# torch_device = torch.device("cpu:1")
gpu_id = 1 #It seems that even when no GPU is selected, devices should be >=1 (Josué Page Vizcaíno)
try:
gpu_id = self.gui.gpu_choices.index(self.gui.gui_elms["hw"]["gpu_id"].value)
gpu_id = max(self.gui.gpu_choices.index(self.gui.gui_elms["hw"]["gpu_id"].value)+1,1)
except Exception as err:
pass
if self.gui.gui_elms["hw"]["disable_gpu"].value == True:
args += ['--disable_gpu']
args += ['--gpu_id', gpu_id]
else:
args += ['--gpu_id', gpu_id]

print(args)
print("LFMNet process")
print("Neural Net process")

print('\t--> hostname:{host}'.format(host=lfdeconvolve.socket.gethostname()))
print('\t--> specified gpu-id:{gpuid}'.format(gpuid=gpu_id))

if '--disable-gpu' not in args:
print('\t--> specified gpu-id:{gpuid}'.format(gpuid=gpu_id))

import torch
if '--disable-gpu' not in args:
print('\t--> cuda available:{gpu}'.format(gpu=torch.cuda.is_available()))
print("\t--> device: ","cuda:"+str(gpu_id) if torch.cuda.is_available() and not '--disable-gpu' in args else "cpu")

try:
from napari_lf.lfa.neural_nets.LFNeuralNetworkProto import LFNeuralNetworkProto
from napari_lf.lfa.lflib.lightfield import LightField
except:
from lfa.neural_nets.LFNeuralNetworkProto import LFNeuralNetworkProto
# Load calib file
if self.gui.gui_elms["lfmnet"]["calibration_file"].value == None:
return

calibFile_path = str(os.path.join(str(self.gui.gui_elms["main"]["img_folder"].value), self.gui.gui_elms["lfmnet"]["calibration_file"].value))
path = Path(calibFile_path)
if path.is_file() == False:
return
# Loadim the calibration data
calibration_file = lfdeconvolve.retrieve_calibration_file(calibFile_path, id=str(gpu_id))
lfcal = lfdeconvolve.LightFieldCalibration.load(calibration_file)
print('\t--> loaded calibration file: %s' % (calibFile_path))
from lfa.lflib.lightfield import LightField


cal_present = False
# # Load calib file
# if self.gui.gui_elms["lfmnet"]["calibration_file"].value != None:
# calibFile_path = str(os.path.join(str(self.gui.gui_elms["main"]["img_folder"].value), self.gui.gui_elms["lfmnet"]["calibration_file"].value))
# path = Path(calibFile_path)
# if path.is_file() != False:
# # Loadim the calibration data
# calibration_file = lfdeconvolve.retrieve_calibration_file(calibFile_path, id=str(gpu_id))
# lfcal = lfdeconvolve.LightFieldCalibration.load(calibration_file)
# print('\t--> loaded calibration file: %s' % (calibFile_path))

# cal_present = True

# Check if input file selected
if self.gui.gui_elms["lfmnet"]["input_file"].value == None:
return
LF_File_path = str(os.path.join(str(self.gui.gui_elms["main"]["img_folder"].value), self.gui.gui_elms["lfmnet"]["input_file"].value))

#Import LF image
im = lfdeconvolve.load_image(LF_File_path, dtype=lfdeconvolve.np.float32, normalize = False)
print('\t--> loaded LF file: %s. Pixel values range: [%d, %d]' % (LF_File_path, int(im.min()), int(im.max())))

# Rectify the image
# skip-alignment parameter is set by calib file
print('\t--> skip_alignment: %s' % (lfcal.skip_alignment))

lf = lfcal.rectify_lf(im)
LFshape = [lf.nu, lf.nv, lf.ns, lf.nt]

# VCDNet
LFshape = None
# if cal_present:
# # Rectify the image
# # skip-alignment parameter is set by calib file
# print('\t--> skip_alignment: %s' % (lfcal.skip_alignment))
# lf = lfcal.rectify_lf(im)
# LFshape = [lf.nu, lf.nv, lf.ns, lf.nt]

# Network path present?
if self.gui.gui_elms["lfmnet"]["input_model"].value == None:
return
checkpoint_path = str(os.path.join(str(self.gui.gui_elms["main"]["img_folder"].value), self.gui.gui_elms["lfmnet"]["input_model"].value))
# LFMNet
# Load network based on checkpoint


# Load network based on checkpoint
net = LFNeuralNetworkProto.load_network_from_file(checkpoint_path, LFshape)
print('\t--> loaded model-checkpoint file: %s' % (checkpoint_path))

# Set network into evaluation mode (faster ode)
net.eval()

# If there was no calibration file, extract info about lightfield from network

if not cal_present: # Load LFshape from model
LFshape = net.LF_in_shape
print(LFshape)
lf = LightField(im, LFshape[0], LFshape[1], LFshape[2], LFshape[3],
representation = LightField.TILED_LENSLET)
## Process image:
with torch.no_grad():
# Move network to device (GPU/CPU)
torch_device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() and not '--disable_gpu' in args else "cpu")
torch_device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() and not '--disable-gpu' in args else "cpu")
net = net.to(torch_device)
# Prepare input to network
im_lenslet = lf.asimage(representation = lfdeconvolve.LightField.TILED_LENSLET)
Expand Down

0 comments on commit d0acbf8

Please sign in to comment.