Skip to content

Commit

Permalink
Merge pull request #46 from PolarizedLightFieldMicroscopy/gpu_fix
Browse files Browse the repository at this point in the history
Fixed gpu issues with network
  • Loading branch information
pvjosue committed May 10, 2023
2 parents 4f4def7 + 720bd36 commit 8de1a38
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
8 changes: 3 additions & 5 deletions src/napari_lf/_widgetLF.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,11 +687,9 @@ def run_lf_proj_vol(self, args):
def run_lf_net(self, args):
try:
args += ['--solver','net']
# torch_device = torch.device("cuda:1")
# torch_device = torch.device("cpu:1")
gpu_id = 1 #It seems that even when no GPU is selected, devices should be >=1 (Josué Page Vizcaíno)
gpu_id = 0 #It seems that even when no GPU is selected, devices should be >=1 (Josué Page Vizcaíno)
try:
gpu_id = max(self.gui.gpu_choices.index(self.gui.gui_elms["hw"]["gpu_id"].value)+1,1)
gpu_id = self.gui.gpu_choices.index(self.gui.gui_elms["hw"]["gpu_id"].value)
except Exception as err:
pass

Expand Down Expand Up @@ -770,7 +768,7 @@ def run_lf_net(self, args):
## Process image:
with torch.no_grad():
# Move network to device (GPU/CPU)
torch_device = torch.device("cuda:"+str(gpu_id) if torch.cuda.is_available() and not '--disable-gpu' in args else "cpu")
torch_device = torch.device(gpu_id)
net = net.to(torch_device)
# Prepare input to network
im_lenslet = lf.asimage(representation = lfdeconvolve.LightField.TILED_LENSLET)
Expand Down
9 changes: 5 additions & 4 deletions src/napari_lf/lfa/neural_nets/LFNeuralNetworkProto.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,11 @@ def load_network_from_file(file_name, LFshape=None):

# Extract which network to use from checkpoint file
checkpoint_file = glob.glob(file_name)
device = 'cpu'
if torch.cuda.is_available():
network_params = torch.load(checkpoint_file[0])
else:
network_params = torch.load(checkpoint_file[0], map_location='cpu')
device = 'cuda'

network_params = torch.load(checkpoint_file[0], map_location=device)
# Extract name
try:
network_name = network_params['hyper_parameters']['name']
Expand Down Expand Up @@ -138,7 +139,7 @@ def load_network_from_file(file_name, LFshape=None):
# Create network with stored hyperparameters
net = network_class(**network_hp)
# net = net_lib.Net(im_lenslet.shape, (64,)+im_lenslet.shape, network_settings_dict={'LFshape' : LFshape})
net.load_from_checkpoint(checkpoint_file[0], strict=False)
net.load_from_checkpoint(checkpoint_file[0], strict=False, map_location=device)
net.load_state_dict(network_params['state_dict'], strict=False)

return net

0 comments on commit 8de1a38

Please sign in to comment.