Skip to content
This repository has been archived by the owner on Nov 28, 2023. It is now read-only.

Commit

Permalink
update class weight
Browse files Browse the repository at this point in the history
  • Loading branch information
CunliangGeng committed Feb 12, 2020
1 parent 36a7417 commit 7633882
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions deeprank/learn/NeuralNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@

class NeuralNet():

def __init__(self, data_set, model, class_weights=torch.FloatTensor([1,1]),
def __init__(self, data_set, model,
model_type='3d', proj2d=0, task='reg',
class_weights = None,
pretrained_model=None,
cuda=False, ngpu=0,
plot=True,
Expand Down Expand Up @@ -59,6 +60,10 @@ def __init__(self, data_set, model, class_weights=torch.FloatTensor([1,1]),
The loss function, the target datatype and plot functions
will be autmatically adjusted depending on the task.
class_weights (Tensor): a manual rescaling weight given to
each class. If given, has to be a Tensor of size #classes.
Only applicable on 'class' task.
pretrained_model (str): Saved model to be used for further
training or testing.
Expand Down Expand Up @@ -107,7 +112,7 @@ def __init__(self, data_set, model, class_weights=torch.FloatTensor([1,1]),

# pretrained model
self.pretrained_model = pretrained_model

self.class_weights = class_weights

if isinstance(self.data_set, (str, list)) and pretrained_model is None:
Expand Down Expand Up @@ -171,7 +176,6 @@ def __init__(self, data_set, model, class_weights=torch.FloatTensor([1,1]),
self._plot_scatter = self._plot_scatter_reg

elif self.task == 'class':
#self.criterion = nn.CrossEntropyLoss(reduction='sum')
self.criterion = nn.CrossEntropyLoss(weight = self.class_weights, reduction='mean')
self._plot_scatter = self._plot_boxplot_class
self.data_set.normalize_targets = False
Expand Down

0 comments on commit 7633882

Please sign in to comment.