diff --git a/lumin/nn/models/blocks/conv_blocks.py b/lumin/nn/models/blocks/conv_blocks.py index 40c86e0..71f1068 100644 --- a/lumin/nn/models/blocks/conv_blocks.py +++ b/lumin/nn/models/blocks/conv_blocks.py @@ -327,6 +327,9 @@ def __init__(self, n_in:int, r:int, act:str='relu', lookup_init:Callable[[str,Op super().__init__() self.n_in,self.r,self.act,self.lookup_init,self.lookup_act = n_in,r,act,lookup_init,lookup_act self.layers = self._get_layers() + self._set_pooling() + + def _set_pooling(self) -> None: self.sz = [1] self.pool = nn.AdaptiveAvgPool1d(self.sz) @@ -370,11 +373,7 @@ class SEBlock2d(SEBlock1d): lookup_act: function taking choice of activation function and returning an activation function layer ''' - def __init__(self, n_in:int, r:int, act:str='relu', lookup_init:Callable[[str,Optional[int],Optional[int]],Callable[[Tensor],None]]=lookup_normal_init, - lookup_act:Callable[[str],Any]=lookup_act): - super().__init__() - self.n_in,self.r,self.act,self.lookup_init,self.lookup_act = n_in,r,act,lookup_init,lookup_act - self.layers = self._get_layers() + def _set_pooling(self) -> None: self.sz = [1,1] self.pool = nn.AdaptiveAvgPool2d(self.sz) @@ -393,10 +392,6 @@ class SEBlock3d(SEBlock1d): lookup_act: function taking choice of activation function and returning an activation function layer ''' - def __init__(self, n_in:int, r:int, act:str='relu', lookup_init:Callable[[str,Optional[int],Optional[int]],Callable[[Tensor],None]]=lookup_normal_init, - lookup_act:Callable[[str],Any]=lookup_act): - super().__init__() - self.n_in,self.r,self.act,self.lookup_init,self.lookup_act = n_in,r,act,lookup_init,lookup_act - self.layers = self._get_layers() + def _set_pooling(self) -> None: self.sz = [1,1,1] self.pool = nn.AdaptiveAvgPool3d(self.sz)