diff --git a/deeprank/learn/NeuralNet.py b/deeprank/learn/NeuralNet.py index cd394ec5..9a15aadf 100644 --- a/deeprank/learn/NeuralNet.py +++ b/deeprank/learn/NeuralNet.py @@ -9,7 +9,6 @@ import matplotlib.ticker as mtick import numpy as np from torchsummary import summary -import pdb # xue import torch @@ -204,7 +203,7 @@ def __init__(self,data_set,model, device = torch.device("cuda") # PyTorch v0.4.0 else: device = torch.device("cpu") - summary(self.net.to(device), self.data_set.input_shape) + summary(self.net.to(device), self.data_set.input_shape, device = device.type) sys.stdout.flush() # load parameters of pretrained model if provided