Skip to content
This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

update paper data example #245

Merged
merged 9 commits into from
Dec 22, 2021
Merged

update paper data example #245

merged 9 commits into from
Dec 22, 2021

Conversation

manonreau
Copy link
Contributor

Add examples of code to run the CNN models from the paper on new data:

  • Example with 3DeepFace

  • Example with the docking scoring model

  • Changed NeuralNet and DataSet to handle test set with no target value

@coveralls
Copy link

coveralls commented Dec 20, 2021

Pull Request Test Coverage Report for Build 1611058010

  • 0 of 0 changed or added relevant lines in 0 files are covered.
  • No unchanged relevant lines lost coverage.
  • Overall coverage remained the same at 77.12%

Totals Coverage Status
Change from base Build 1602467931: 0.0%
Covered Lines: 1628
Relevant Lines: 2111

💛 - Coveralls

try:
self.hit_cutoff = self.state['hit_cutoff']
except Exception:
print('No hit_cutoff provided')
Copy link
Member

@CunliangGeng CunliangGeng Dec 21, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use logger to output messages.
I guess this Exception only applies to our published models. If so, we have to tell users what to do in the message, like No "hit_cutoff" found in {model}. Please set it in function "test()" when doing prediction".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The hit cutoff is used only if target values are provided. In prediction mode you do not need it, you need it for benchmark only.

@@ -766,7 +770,8 @@ def _train(self, index_train, index_valid, index_test,
return torch.cat([param.data.view(-1)
for param in self.net.parameters()], 0)

def _epoch(self, data_loader, train_model):

def _epoch(self, data_loader, train_model, test_model=False):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The name test_model is misleading. It's acutally a flag to tell whether there is target values or not in the data. So I suggest using e.g. has_target=True?
Also, please add the docstring for new parameters.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very good suggestion, thanks !

Copy link
Member

@CunliangGeng CunliangGeng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @manonreau I left a few comments for you to check

Copy link
Contributor Author

@manonreau manonreau left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Liang for the review, I modified the code accordingly.

deeprank/learn/DataSet.py Outdated Show resolved Hide resolved
deeprank/learn/DataSet.py Outdated Show resolved Hide resolved
try:
self.hit_cutoff = self.state['hit_cutoff']
except Exception:
print('No hit_cutoff provided')
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The hit cutoff is used only if target values are provided. In prediction mode you do not need it, you need it for benchmark only.

@@ -766,7 +770,8 @@ def _train(self, index_train, index_valid, index_test,
return torch.cat([param.data.view(-1)
for param in self.net.parameters()], 0)

def _epoch(self, data_loader, train_model):

def _epoch(self, data_loader, train_model, test_model=False):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very good suggestion, thanks !

Copy link
Member

@CunliangGeng CunliangGeng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @manonreau , the changes look good!
One reminder for you is that we don't use print but logger for outputing message in DeepRank :-)

@manonreau manonreau merged commit 6783df2 into master Dec 22, 2021
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants