Skip to content

Commit

Permalink
Fix SE blocks
Browse files Browse the repository at this point in the history
  • Loading branch information
GilesStrong committed Oct 13, 2021
1 parent db49b7d commit 0db61cb
Showing 1 changed file with 5 additions and 10 deletions.
15 changes: 5 additions & 10 deletions lumin/nn/models/blocks/conv_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ def __init__(self, n_in:int, r:int, act:str='relu', lookup_init:Callable[[str,Op
super().__init__()
self.n_in,self.r,self.act,self.lookup_init,self.lookup_act = n_in,r,act,lookup_init,lookup_act
self.layers = self._get_layers()
self._set_pooling()

def _set_pooling(self) -> None:
self.sz = [1]
self.pool = nn.AdaptiveAvgPool1d(self.sz)

Expand Down Expand Up @@ -370,11 +373,7 @@ class SEBlock2d(SEBlock1d):
lookup_act: function taking choice of activation function and returning an activation function layer
'''

def __init__(self, n_in:int, r:int, act:str='relu', lookup_init:Callable[[str,Optional[int],Optional[int]],Callable[[Tensor],None]]=lookup_normal_init,
lookup_act:Callable[[str],Any]=lookup_act):
super().__init__()
self.n_in,self.r,self.act,self.lookup_init,self.lookup_act = n_in,r,act,lookup_init,lookup_act
self.layers = self._get_layers()
def _set_pooling(self) -> None:
self.sz = [1,1]
self.pool = nn.AdaptiveAvgPool2d(self.sz)

Expand All @@ -393,10 +392,6 @@ class SEBlock3d(SEBlock1d):
lookup_act: function taking choice of activation function and returning an activation function layer
'''

def __init__(self, n_in:int, r:int, act:str='relu', lookup_init:Callable[[str,Optional[int],Optional[int]],Callable[[Tensor],None]]=lookup_normal_init,
lookup_act:Callable[[str],Any]=lookup_act):
super().__init__()
self.n_in,self.r,self.act,self.lookup_init,self.lookup_act = n_in,r,act,lookup_init,lookup_act
self.layers = self._get_layers()
def _set_pooling(self) -> None:
self.sz = [1,1,1]
self.pool = nn.AdaptiveAvgPool3d(self.sz)

0 comments on commit 0db61cb

Please sign in to comment.