Skip to content
This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

Commit

Permalink
Merge pull request #78 from DeepRank/issue77_modelSummary
Browse files Browse the repository at this point in the history
add model summary
  • Loading branch information
LilySnow committed May 13, 2019
2 parents dc61c0b + 837febd commit 37bac16
Showing 1 changed file with 11 additions and 0 deletions.
11 changes: 11 additions & 0 deletions deeprank/learn/NeuralNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import numpy as np
from torchsummary import summary
import pdb # xue


import torch
Expand Down Expand Up @@ -196,6 +198,15 @@ def __init__(self,data_set,model,
# load the model
self.net = model(self.data_set.input_shape)

# model summary
sys.stdout.flush()
if cuda is True:
device = torch.device("cuda") # PyTorch v0.4.0
else:
device = torch.device("cpu")
summary(self.net.to(device), self.data_set.input_shape)
sys.stdout.flush()

# load parameters of pretrained model if provided
if self.pretrained_model:
## a prefix 'module.' is added to parameter names if torch.nn.DataParallel was used
Expand Down

0 comments on commit 37bac16

Please sign in to comment.