Skip to content
This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

Added class weights to the CrossEntropy Loss #120

Merged
merged 5 commits into from
Feb 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
sudo: false
distr : "trusty"
distr : "bionic"

language: generic

Expand Down
9 changes: 8 additions & 1 deletion deeprank/learn/NeuralNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class NeuralNet():

def __init__(self, data_set, model,
model_type='3d', proj2d=0, task='reg',
class_weights = None,
pretrained_model=None,
cuda=False, ngpu=0,
plot=True,
Expand Down Expand Up @@ -59,6 +60,10 @@ def __init__(self, data_set, model,
The loss function, the target datatype and plot functions
will be autmatically adjusted depending on the task.

class_weights (Tensor): a manual rescaling weight given to
each class. If given, has to be a Tensor of size #classes.
Only applicable on 'class' task.

pretrained_model (str): Saved model to be used for further
training or testing.

Expand Down Expand Up @@ -108,6 +113,8 @@ def __init__(self, data_set, model,
# pretrained model
self.pretrained_model = pretrained_model

self.class_weights = class_weights

if isinstance(self.data_set, (str, list)) and pretrained_model is None:
raise ValueError(
'Argument data_set must be a DeepRankDataSet object\
Expand Down Expand Up @@ -169,7 +176,7 @@ def __init__(self, data_set, model,
self._plot_scatter = self._plot_scatter_reg

elif self.task == 'class':
self.criterion = nn.CrossEntropyLoss(reduction='sum')
self.criterion = nn.CrossEntropyLoss(weight = self.class_weights, reduction='mean')
self._plot_scatter = self._plot_boxplot_class
self.data_set.normalize_targets = False

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
'pandas',
'matplotlib',
'torchsummary',
'torch',
'torch < 1.4.0',
'pdb2sql >= 0.2.1',
'freesasa==2.0.3.post7;platform_system=="Linux"',
'freesasa==2.0.3.post6;platform_system=="Darwin"'
Expand Down
8 changes: 4 additions & 4 deletions test/test_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def test_learn_3d_reg_mapfly():
model.train(
nepoch=5,
divide_trainset=0.8,
train_batch_size=5,
train_batch_size=2,
preshuffle_seed=2019,
num_workers=0)

Expand Down Expand Up @@ -116,7 +116,7 @@ def test_learn_3d_reg():
model.train(
nepoch=5,
divide_trainset=0.8,
train_batch_size=5,
train_batch_size=2,
num_workers=0,
preshuffle_seed=2019,
save_model='all')
Expand Down Expand Up @@ -161,7 +161,7 @@ def test_learn_3d_class():
model.train(
nepoch=5,
divide_trainset=0.8,
train_batch_size=5,
train_batch_size=2,
num_workers=0,
save_epoch='all')

Expand Down Expand Up @@ -210,7 +210,7 @@ def test_learn_2d_reg():
model.train(
nepoch=5,
divide_trainset=0.8,
train_batch_size=5,
train_batch_size=2,
num_workers=0)

@unittest.skipIf(skip, "torch fails on Travis")
Expand Down