diff --git a/CHANGES.md b/CHANGES.md
index 21f31d6..045f9b3 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -10,6 +10,10 @@
 ## Fixes
 
+- Potential bug in convolutional models where checking the out size of the head would affect the batchnorm averaging
+- Potential bug in `plot_sample_pred` to do with bin ranges
+
+
 ## Changes
 
 ## Depreciations
 
diff --git a/examples/RNNs_CNNs_and_GNNs_for_matrix_data.ipynb b/examples/RNNs_CNNs_and_GNNs_for_matrix_data.ipynb
index 6ee1744..390cd14 100644
--- a/examples/RNNs_CNNs_and_GNNs_for_matrix_data.ipynb
+++ b/examples/RNNs_CNNs_and_GNNs_for_matrix_data.ipynb
@@ -4290,7 +4290,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.8.3"
+   "version": "3.6.5"
   },
  "toc": {
   "base_numbering": 1,
diff --git a/lumin/nn/callbacks/abs_callback.py b/lumin/nn/callbacks/abs_callback.py
index 2e350c8..87dbf9a 100644
--- a/lumin/nn/callbacks/abs_callback.py
+++ b/lumin/nn/callbacks/abs_callback.py
@@ -5,19 +5,19 @@
 
 
 class AbsCallback(ABC):
     '''Abstract callback class for typing'''
-    def __init__(self): pass
-    def set_model(self, **kargs): pass
-    def set_plot_settings(self, **kargs): pass
-    def on_train_begin(self, **kargs): pass
-    def on_train_end(self, **kargs): pass
-    def on_epoch_begin(self, **kargs): pass
-    def on_epoch_end(self, **kargs): pass
-    def on_batch_begin(self, **kargs): pass
-    def on_batch_end(self, **kargs): pass
-    def on_eval_begin(self, **kargs): pass
-    def on_eval_end(self, **kargs): pass
-    def on_backwards_begin(self, **kargs): pass
-    def on_backwards_end(self, **kargs): pass
-    def on_pred_begin(self, **kargs): pass
-    def on_pred_end(self, **kargs): pass
+    def __init__(self): pass
+    def set_model(self, **kargs): pass
+    def set_plot_settings(self, **kargs): pass
+    def on_train_begin(self, **kargs): pass
+    def on_train_end(self, **kargs): pass
+    def on_epoch_begin(self, **kargs): pass
+    def on_epoch_end(self, **kargs): pass
+    def on_batch_begin(self, **kargs): pass
+    def on_batch_end(self, **kargs): pass
+    def on_eval_begin(self, **kargs): pass
+    def on_eval_end(self, **kargs): pass
+    def on_backwards_begin(self, **kargs): pass
+    def on_backwards_end(self, **kargs): pass
+    def on_pred_begin(self, **kargs): pass
+    def on_pred_end(self, **kargs): pass
\ No newline at end of file
diff --git a/lumin/nn/models/blocks/head.py b/lumin/nn/models/blocks/head.py
index 6d2d279..dca0d68 100644
--- a/lumin/nn/models/blocks/head.py
+++ b/lumin/nn/models/blocks/head.py
@@ -685,7 +685,10 @@ def check_out_sz(self) -> int:
         '''
         x = torch.rand((1, self.n_fpv,self.n_v))
+        training = self.training
+        self.eval()
         x = self.forward(x)
+        if training: self.train()
         return x.size(-1)
 
     def get_conv1d_block(self, in_c:int, out_c:int, kernel_sz:int, padding:Union[int,str]='auto', stride:int=1,act:str='relu', bn:bool=False) -> Conv1DBlock:
@@ -941,7 +944,10 @@ def check_out_sz(self) -> int:
         '''
         x = torch.rand((1,self.n_v, self.n_fpv))
+        training = self.training
+        self.eval()
         x = self.forward(x)
+        if training: self.train()
         return x.size(-1)
 
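Note on the `check_out_sz` changes above: running the dummy forward pass in training mode updates the running statistics of any batchnorm layers in the head, so a purely diagnostic shape check could silently shift the normalisation later used at inference (the first fix listed in CHANGES.md). Below is a minimal, standalone sketch of the effect and of the save/eval/restore pattern now used, in plain PyTorch rather than lumin code:

```python
import torch
from torch import nn

# A probe forward pass in training mode moves BatchNorm's running statistics,
# even when the pass is only meant to check the output shape.
bn = nn.BatchNorm1d(4)
start = bn.running_mean.clone()
bn(torch.rand(8, 4))                        # training-mode probe
print(torch.equal(bn.running_mean, start))  # False: statistics have shifted

# The save/eval/restore pattern adopted in check_out_sz:
bn = nn.BatchNorm1d(4)
start = bn.running_mean.clone()
training = bn.training                      # remember the current mode
bn.eval()
bn(torch.rand(8, 4))                        # eval-mode probe: no stat updates
if training: bn.train()                     # restore the original mode
print(torch.equal(bn.running_mean, start))  # True: statistics untouched
```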
diff --git a/lumin/nn/models/initialisations.py b/lumin/nn/models/initialisations.py
index b3bb981..fd1dc29 100644
--- a/lumin/nn/models/initialisations.py
+++ b/lumin/nn/models/initialisations.py
@@ -21,14 +21,14 @@ def lookup_normal_init(act:str, fan_in:Optional[int]=None, fan_out:Optional[int]
         Callable to initialise weight tensor
     '''
 
-    if act == 'relu': return partial(nn.init.kaiming_normal_, nonlinearity='relu')
-    if act == 'prelu': return partial(nn.init.kaiming_normal_, nonlinearity='relu')
+    if act == 'relu': return partial(nn.init.kaiming_normal_, nonlinearity='relu', a=0)
+    if act == 'prelu': return partial(nn.init.kaiming_normal_, nonlinearity='relu', a=0)
     if act == 'selu': return partial(nn.init.normal_, std=1/np.sqrt(fan_in))
     if act == 'sigmoid': return nn.init.xavier_normal_
     if act == 'logsoftmax': return nn.init.xavier_normal_
     if act == 'softmax': return nn.init.xavier_normal_
     if act == 'linear': return nn.init.xavier_normal_
-    if 'swish' in act: return partial(nn.init.kaiming_normal_, nonlinearity='relu')
+    if 'swish' in act: return partial(nn.init.kaiming_normal_, nonlinearity='relu', a=0)
     raise ValueError("Activation not implemented")
 
 
@@ -45,13 +45,13 @@ def lookup_uniform_init(act:str, fan_in:Optional[int]=None, fan_out:Optional[int
         Callable to initialise weight tensor
     '''
 
-    if act == 'relu': return partial(nn.init.kaiming_uniform_, nonlinearity='relu')
-    if act == 'prelu': return partial(nn.init.kaiming_uniform_, nonlinearity='relu')
+    if act == 'relu': return partial(nn.init.kaiming_uniform_, nonlinearity='relu', a=0)
+    if act == 'prelu': return partial(nn.init.kaiming_uniform_, nonlinearity='relu', a=0)
     if act == 'selu': return partial(nn.init.uniform_, a=-1/np.sqrt(fan_in), b=1/np.sqrt(fan_in))
     if act == 'sigmoid': return nn.init.xavier_uniform_
     if act == 'logsoftmax': return nn.init.xavier_uniform_
     if act == 'softmax': return nn.init.xavier_uniform_
     if act == 'linear': return nn.init.xavier_uniform_
-    if 'swish' in act: return partial(nn.init.kaiming_uniform_, nonlinearity='relu')
+    if 'swish' in act: return partial(nn.init.kaiming_uniform_, nonlinearity='relu', a=0)
     raise ValueError("Activation not implemented")
\ No newline at end of file
diff --git a/lumin/plotting/results.py b/lumin/plotting/results.py
index 947b1a8..48fba86 100644
--- a/lumin/plotting/results.py
+++ b/lumin/plotting/results.py
@@ -183,9 +183,9 @@ def plot_sample_pred(df:pd.DataFrame, pred_name:str='pred', targ_name:str='gen_t
         settings: :class:`~lumin.plotting.plot_settings.PlotSettings` class to control figure appearance
     '''
 
-    hist_params = {'range': lim_x, 'bins': bins, 'density': density, 'alpha': 0.8, 'stacked':True, 'rwidth':1.0}
     sig,bkg = (df[targ_name] == 1),(df[targ_name] == 0)
     if not isinstance(bins,list): bins = np.linspace(df[pred_name].min(),df[pred_name].max(), bins if isinstance(bins, int) else 10)
+    hist_params = {'range': lim_x, 'bins': bins, 'density': density, 'alpha': 0.8, 'stacked':True, 'rwidth':1.0}
     sig_samples = _get_samples(df[sig], sample_name, wgt_name)
     bkg_samples = _get_samples(df[bkg], sample_name, wgt_name)
     sample2col = {k: v for v, k in enumerate(bkg_samples)} if settings.sample2col is None else settings.sample2col
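Note on the `initialisations.py` changes: for `nonlinearity='relu'`, `torch.nn.init.calculate_gain` ignores the negative-slope parameter `a`, so pinning `a=0` leaves the gain at sqrt(2); the change appears to be defensive, fixing the argument explicitly in the partial rather than leaving it open to override. A quick standalone check of the resulting scale:

```python
import torch
from torch import nn
from functools import partial

# calculate_gain('relu') ignores the slope parameter, so a=0 does not change
# the init scale; it only pins the argument explicitly in the partial.
print(nn.init.calculate_gain('relu'))  # sqrt(2) ~ 1.414, with or without a=0

init = partial(nn.init.kaiming_normal_, nonlinearity='relu', a=0)
w = torch.empty(64, 32)                # weight of a Linear(32, 64): fan_in = 32
init(w)
print(w.std())                         # ~ sqrt(2 / 32) = 0.25
```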
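Note on the `plot_sample_pred` change: `hist_params` was previously built before an integer `bins` was resolved into an array of edges spanning the prediction range, so the dict captured the raw value and the computed edges never reached the histogram call (the bin-range fix in CHANGES.md). A simplified, standalone illustration of the ordering bug, not the lumin API itself:

```python
import numpy as np

preds = np.random.rand(1000)  # stand-in for df[pred_name]

bins = 30
hist_params = {'bins': bins}                        # old order: captures the int 30
bins = np.linspace(preds.min(), preds.max(), bins)  # reassignment comes too late
print(hist_params['bins'])                          # 30, not the computed edges

bins = 30
bins = np.linspace(preds.min(), preds.max(), bins)  # new order: resolve edges first
hist_params = {'bins': bins}                        # dict now holds the edges
print(hist_params['bins'][:3])                      # actual bin edges
```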