Skip to content
This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

Commit

Permalink
Merge pull request #120 from DeepRank/NeuralNet_class_weights
Browse files Browse the repository at this point in the history
Added class weights to the CrossEntropy Loss
  • Loading branch information
CunliangGeng committed Feb 12, 2020
2 parents 9f86546 + d021ad0 commit 3ab9097
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
sudo: false
distr : "trusty"
distr : "bionic"

language: generic

Expand Down
9 changes: 8 additions & 1 deletion deeprank/learn/NeuralNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class NeuralNet():

def __init__(self, data_set, model,
model_type='3d', proj2d=0, task='reg',
class_weights = None,
pretrained_model=None,
cuda=False, ngpu=0,
plot=True,
Expand Down Expand Up @@ -59,6 +60,10 @@ def __init__(self, data_set, model,
The loss function, the target datatype and plot functions
will be autmatically adjusted depending on the task.
class_weights (Tensor): a manual rescaling weight given to
each class. If given, has to be a Tensor of size #classes.
Only applicable on 'class' task.
pretrained_model (str): Saved model to be used for further
training or testing.
Expand Down Expand Up @@ -108,6 +113,8 @@ def __init__(self, data_set, model,
# pretrained model
self.pretrained_model = pretrained_model

self.class_weights = class_weights

if isinstance(self.data_set, (str, list)) and pretrained_model is None:
raise ValueError(
'Argument data_set must be a DeepRankDataSet object\
Expand Down Expand Up @@ -169,7 +176,7 @@ def __init__(self, data_set, model,
self._plot_scatter = self._plot_scatter_reg

elif self.task == 'class':
self.criterion = nn.CrossEntropyLoss(reduction='sum')
self.criterion = nn.CrossEntropyLoss(weight = self.class_weights, reduction='mean')
self._plot_scatter = self._plot_boxplot_class
self.data_set.normalize_targets = False

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
'pandas',
'matplotlib',
'torchsummary',
'torch',
'torch < 1.4.0',
'pdb2sql >= 0.2.1',
'freesasa==2.0.3.post7;platform_system=="Linux"',
'freesasa==2.0.3.post6;platform_system=="Darwin"'
Expand Down
8 changes: 4 additions & 4 deletions test/test_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_learn_3d_reg_mapfly():
model.train(
nepoch=5,
divide_trainset=0.8,
train_batch_size=5,
train_batch_size=2,
preshuffle_seed=2019,
num_workers=0)

Expand Down Expand Up @@ -116,7 +116,7 @@ def test_learn_3d_reg():
model.train(
nepoch=5,
divide_trainset=0.8,
train_batch_size=5,
train_batch_size=2,
num_workers=0,
preshuffle_seed=2019,
save_model='all')
Expand Down Expand Up @@ -161,7 +161,7 @@ def test_learn_3d_class():
model.train(
nepoch=5,
divide_trainset=0.8,
train_batch_size=5,
train_batch_size=2,
num_workers=0,
save_epoch='all')

Expand Down Expand Up @@ -210,7 +210,7 @@ def test_learn_2d_reg():
model.train(
nepoch=5,
divide_trainset=0.8,
train_batch_size=5,
train_batch_size=2,
num_workers=0)

@unittest.skipIf(skip, "torch fails on Travis")
Expand Down

0 comments on commit 3ab9097

Please sign in to comment.