Commit: A few fixes

GilesStrong committed Oct 6, 2020
1 parent fb54f6f commit 5f35ba8
Showing 6 changed files with 33 additions and 23 deletions.
CHANGES.md (4 additions, 0 deletions)

```diff
@@ -10,6 +10,10 @@
 
 ## Fixes
 
+- Potential bug in convolutional models where checking the out size of the head would affect the batchnorm averaging
+- Potential bug in `plot_sample_pred` to do with bin ranges
+
+
 ## Changes
 
 ## Depreciations
```
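The first fix matters because `nn.BatchNorm` layers update their running statistics on every forward pass made in training mode, so even a throwaway shape check shifts them. A minimal sketch of the effect, using a bare `nn.BatchNorm1d` (illustrative only, not LUMIN code):

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(4)      # running_mean starts at zeros, running_var at ones
print(bn.running_mean)      # tensor([0., 0., 0., 0.])

bn(torch.rand(1, 4, 8))     # a "shape check" forward pass in training mode...
print(bn.running_mean)      # ...has shifted the running mean away from zero

bn2 = nn.BatchNorm1d(4)
bn2.eval()                  # eval mode: batch statistics are not accumulated
bn2(torch.rand(1, 4, 8))
print(bn2.running_mean)     # still tensor([0., 0., 0., 0.])
```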
examples/RNNs_CNNs_and_GNNs_for_matrix_data.ipynb (1 addition, 1 deletion)

```diff
@@ -4290,7 +4290,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.3"
+   "version": "3.6.5"
  },
  "toc": {
   "base_numbering": 1,
```
lumin/nn/callbacks/abs_callback.py (15 additions, 15 deletions)

```diff
@@ -5,19 +5,19 @@
 
 class AbsCallback(ABC):
     '''Abstract callback class for typing'''
-    def __init__(self): pass
-    def set_model(self, **kargs): pass
-    def set_plot_settings(self, **kargs): pass
-    def on_train_begin(self, **kargs): pass
-    def on_train_end(self, **kargs): pass
-    def on_epoch_begin(self, **kargs): pass
-    def on_epoch_end(self, **kargs): pass
-    def on_batch_begin(self, **kargs): pass
-    def on_batch_end(self, **kargs): pass
-    def on_eval_begin(self, **kargs): pass
-    def on_eval_end(self, **kargs): pass
-    def on_backwards_begin(self, **kargs): pass
-    def on_backwards_end(self, **kargs): pass
-    def on_pred_begin(self, **kargs): pass
-    def on_pred_end(self, **kargs): pass
+    def __init__(self): pass
+    def set_model(self, **kargs): pass
+    def set_plot_settings(self, **kargs): pass
+    def on_train_begin(self, **kargs): pass
+    def on_train_end(self, **kargs): pass
+    def on_epoch_begin(self, **kargs): pass
+    def on_epoch_end(self, **kargs): pass
+    def on_batch_begin(self, **kargs): pass
+    def on_batch_end(self, **kargs): pass
+    def on_eval_begin(self, **kargs): pass
+    def on_eval_end(self, **kargs): pass
+    def on_backwards_begin(self, **kargs): pass
+    def on_backwards_end(self, **kargs): pass
+    def on_pred_begin(self, **kargs): pass
+    def on_pred_end(self, **kargs): pass
```
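For reference, `AbsCallback` deliberately implements every hook as a no-op so that concrete callbacks can override only the events they need. A hypothetical subclass (illustrative only, not part of this commit) might track epochs:

```python
from lumin.nn.callbacks.abs_callback import AbsCallback

class EpochCounter(AbsCallback):
    '''Hypothetical callback: counts completed epochs.'''
    def __init__(self):
        self.epoch = 0

    def on_epoch_end(self, **kargs):
        # Only this hook is overridden; every other event stays a no-op
        self.epoch += 1
        print(f'Completed epoch {self.epoch}')
```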

lumin/nn/models/blocks/head.py (6 additions, 0 deletions)

```diff
@@ -685,7 +685,10 @@ def check_out_sz(self) -> int:
         '''
 
         x = torch.rand((1, self.n_fpv,self.n_v))
+        training = self.training
+        self.eval()
         x = self.forward(x)
+        if training: self.train()
         return x.size(-1)
 
     def get_conv1d_block(self, in_c:int, out_c:int, kernel_sz:int, padding:Union[int,str]='auto', stride:int=1,act:str='relu', bn:bool=False) -> Conv1DBlock:
@@ -941,7 +944,10 @@ def check_out_sz(self) -> int:
         '''
 
         x = torch.rand((1,self.n_v, self.n_fpv))
+        training = self.training
+        self.eval()
         x = self.forward(x)
+        if training: self.train()
         return x.size(-1)
 
 
```
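Both hunks apply the same guard: record the module's training flag, switch to `eval()` so the dummy forward pass cannot update batchnorm running statistics, then restore the previous mode. The same pattern can be written once as a context manager; a minimal sketch (an illustration, not LUMIN API):

```python
from contextlib import contextmanager

import torch
from torch import nn

@contextmanager
def eval_mode(module: nn.Module):
    '''Temporarily put a module in eval mode, restoring its previous mode on exit.'''
    training = module.training
    module.eval()
    try:
        yield module
    finally:
        module.train(training)

# Usage mirroring check_out_sz: probe the output size without side effects.
# `head`, `n_fpv`, and `n_v` are stand-ins for a conv head and its input shape.
# with eval_mode(head):
#     out_sz = head(torch.rand((1, n_fpv, n_v))).size(-1)
```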
lumin/nn/models/initialisations.py (6 additions, 6 deletions)

```diff
@@ -21,14 +21,14 @@ def lookup_normal_init(act:str, fan_in:Optional[int]=None, fan_out:Optional[int]
         Callable to initialise weight tensor
     '''
 
-    if act == 'relu': return partial(nn.init.kaiming_normal_, nonlinearity='relu')
-    if act == 'prelu': return partial(nn.init.kaiming_normal_, nonlinearity='relu')
+    if act == 'relu': return partial(nn.init.kaiming_normal_, nonlinearity='relu', a=0)
+    if act == 'prelu': return partial(nn.init.kaiming_normal_, nonlinearity='relu', a=0)
     if act == 'selu': return partial(nn.init.normal_, std=1/np.sqrt(fan_in))
     if act == 'sigmoid': return nn.init.xavier_normal_
     if act == 'logsoftmax': return nn.init.xavier_normal_
     if act == 'softmax': return nn.init.xavier_normal_
     if act == 'linear': return nn.init.xavier_normal_
-    if 'swish' in act: return partial(nn.init.kaiming_normal_, nonlinearity='relu')
+    if 'swish' in act: return partial(nn.init.kaiming_normal_, nonlinearity='relu', a=0)
     raise ValueError("Activation not implemented")
 
 
@@ -45,13 +45,13 @@ def lookup_uniform_init(act:str, fan_in:Optional[int]=None, fan_out:Optional[int]
         Callable to initialise weight tensor
     '''
 
-    if act == 'relu': return partial(nn.init.kaiming_uniform_, nonlinearity='relu')
-    if act == 'prelu': return partial(nn.init.kaiming_uniform_, nonlinearity='relu')
+    if act == 'relu': return partial(nn.init.kaiming_uniform_, nonlinearity='relu', a=0)
+    if act == 'prelu': return partial(nn.init.kaiming_uniform_, nonlinearity='relu', a=0)
     if act == 'selu': return partial(nn.init.uniform_, a=-1/np.sqrt(fan_in), b=1/np.sqrt(fan_in))
     if act == 'sigmoid': return nn.init.xavier_uniform_
     if act == 'logsoftmax': return nn.init.xavier_uniform_
     if act == 'softmax': return nn.init.xavier_uniform_
     if act == 'linear': return nn.init.xavier_uniform_
-    if 'swish' in act: return partial(nn.init.kaiming_uniform_, nonlinearity='relu')
+    if 'swish' in act: return partial(nn.init.kaiming_uniform_, nonlinearity='relu', a=0)
     raise ValueError("Activation not implemented")
```
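These lookups return callables that are applied in place to a weight tensor; the explicit `a=0` pins the negative-slope argument of `kaiming_normal_`/`kaiming_uniform_` rather than relying on its default. A usage sketch (the `nn.Linear` layer and sizes are made up for illustration):

```python
import torch
from torch import nn

from lumin.nn.models.initialisations import lookup_normal_init

layer = nn.Linear(32, 64)
init_fn = lookup_normal_init('relu', fan_in=32)  # partial(kaiming_normal_, nonlinearity='relu', a=0)
init_fn(layer.weight)                            # initialises the weight tensor in place
nn.init.zeros_(layer.bias)                       # biases are handled separately
```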

lumin/plotting/results.py (1 addition, 1 deletion)

```diff
@@ -183,9 +183,9 @@ def plot_sample_pred(df:pd.DataFrame, pred_name:str='pred', targ_name:str='gen_t
         settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
     '''
 
-    hist_params = {'range': lim_x, 'bins': bins, 'density': density, 'alpha': 0.8, 'stacked':True, 'rwidth':1.0}
     sig,bkg = (df[targ_name] == 1),(df[targ_name] == 0)
     if not isinstance(bins,list): bins = np.linspace(df[pred_name].min(),df[pred_name].max(), bins if isinstance(bins, int) else 10)
+    hist_params = {'range': lim_x, 'bins': bins, 'density': density, 'alpha': 0.8, 'stacked':True, 'rwidth':1.0}
     sig_samples = _get_samples(df[sig], sample_name, wgt_name)
     bkg_samples = _get_samples(df[bkg], sample_name, wgt_name)
     sample2col = {k: v for v, k in enumerate(bkg_samples)} if settings.sample2col is None else settings.sample2col
```
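The bin-range fix is purely an ordering change: `hist_params` previously captured `bins` before the integer-to-edges conversion, so the `np.linspace` edges computed from the prediction range never reached the histogram call. A distilled sketch of the behaviour (values are hypothetical):

```python
import numpy as np

preds = np.array([0.2, 0.4, 0.9])
bins = 5

# Before the fix: the dict captures the int 5; the edges computed below are never used
hist_params = {'bins': bins}
bins = np.linspace(preds.min(), preds.max(), bins)
assert hist_params['bins'] == 5

# After the fix: compute the edges first, then build the dict
bins = 5
bins = np.linspace(preds.min(), preds.max(), bins)
hist_params = {'bins': bins}
assert len(hist_params['bins']) == 5
```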
