diff --git a/deeprank/learn/NeuralNet.py b/deeprank/learn/NeuralNet.py index 0eeedb76..cd394ec5 100644 --- a/deeprank/learn/NeuralNet.py +++ b/deeprank/learn/NeuralNet.py @@ -8,6 +8,8 @@ import matplotlib.pyplot as plt import matplotlib.ticker as mtick import numpy as np +from torchsummary import summary +import pdb # xue import torch @@ -196,6 +198,15 @@ def __init__(self,data_set,model, # load the model self.net = model(self.data_set.input_shape) + # model summary + sys.stdout.flush() + if cuda is True: + device = torch.device("cuda") # PyTorch v0.4.0 + else: + device = torch.device("cpu") + summary(self.net.to(device), self.data_set.input_shape) + sys.stdout.flush() + # load parameters of pretrained model if provided if self.pretrained_model: ## a prefix 'module.' is added to parameter names if torch.nn.DataParallel was used