In [1]:
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.optim import Optimizer, Adam
from typing import Optional,Callable,Any
from torchvision.models.resnet import ResNet,_resnet,conv1x1,conv3x3
from torchvision.models import WeightsEnum
import torch
from torch import nn
import torch.nn.functional as F
from torchmetrics.functional.classification.accuracy import accuracy
import pytorch_lightning as pl
import numpy as np
import logging
import sys
# Make the local learn_poly_sampling checkout importable (it provides `layers`).
# The location can be overridden with the LPS_REPO_DIR environment variable.
# A raw string keeps the Windows backslashes literal instead of relying on
# invalid escape sequences such as "\F" (SyntaxWarning on Python 3.12+).
import os
LPS_REPO_DIR = os.environ.get("LPS_REPO_DIR", r"C:\Files\Github Repo\TrainLPS\learn_poly_sampling")
sys.path.append(LPS_REPO_DIR)
from layers import get_pool_method
from layers.polydown import set_pool

  np.object,


In [2]:
class AbstractBaseClassifierModel(pl.LightningModule):
    """Abstract base class for classifiers.

    Reusable training / validation / test logic for classifier models. To make
    a new classifier, just implement the initializer and the ``forward`` method.

    Contract with subclasses (read below but not set here):
      * ``forward`` must return log-probabilities, since every step uses
        ``F.nll_loss``.
      * ``self.learning_rate`` is read by ``configure_optimizers`` and by the
        warmup in ``optimizer_step``.
      * ``self.inc_conv1_support`` selects the ImageNet vs CIFAR-10 variant of
        ``eval_consistency_step``.
    """

    def __init__(self,
                 optimizer=None, optimizer_kwargs=None,
                 scheduler=None, scheduler_kwargs=None,
                 param_scheduler=None, param_scheduler_kwargs=None,
                 warmup_epochs=0, eval_mode='class_accuracy',
                 shift_seed=7,shift_max=None,
                 shift_samples=None,shift_patch_size=None
    ):
        """
        Args:
            optimizer: optimizer factory called as
                ``optimizer(param_groups, lr=..., **optimizer_kwargs)``;
                defaults to ``torch.optim.Adam``.
            optimizer_kwargs: extra keyword arguments for ``optimizer``.
            scheduler: lr-scheduler factory, or list of factories, each called
                as ``scheduler(optimizer, **kwargs)``; ``None`` disables
                scheduling.
            scheduler_kwargs: kwargs for ``scheduler`` (a list of dicts when
                ``scheduler`` is a list).
            param_scheduler: callback factory called as
                ``param_scheduler('gumbel_tau', **param_scheduler_kwargs)``.
            param_scheduler_kwargs: kwargs for ``param_scheduler``.
            warmup_epochs: epochs of per-iteration linear lr warmup (0 = off).
            eval_mode: test metric, 'class_accuracy' or 'shift_consistency'.
            shift_seed: numpy seed set when testing starts.
            shift_max: maximum offset for the ImageNet consistency metrics.
            shift_samples: number of shifted copies per batch (ImageNet).
            shift_patch_size: crop size for the ImageNet 'shift' metric.
        """
        super().__init__()
        # `None` defaults (instead of `{}`) so instances never share one mutable dict.
        self.optimizer_fn = optimizer if optimizer is not None else torch.optim.Adam
        self.optimizer_kwargs = optimizer_kwargs if optimizer_kwargs is not None else {}
        self.scheduler_fn = scheduler
        self.scheduler_kwargs = scheduler_kwargs if scheduler_kwargs is not None else {}
        self.param_scheduler_fn = param_scheduler
        self.param_scheduler_kwargs = param_scheduler_kwargs if param_scheduler_kwargs is not None else {}
        self.warmup_epochs = warmup_epochs

        # Evaluation settings
        self.eval_mode = eval_mode
        self.shift_seed = shift_seed
        self.shift_max = shift_max
        self.shift_samples = shift_samples
        self.shift_patch_size = shift_patch_size
        logging.info(f'Evaluation mode: {self.eval_mode}')

    def _nll_and_accuracy(self, batch):
        """Forward ``batch`` = (x, y) and return (NLL loss, top-1 accuracy)."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, accuracy(preds, y)

    # logic for a single training step
    def training_step(self, batch, batch_idx):
        """Log per-step and per-epoch training loss/accuracy; return the loss."""
        loss, acc = self._nll_and_accuracy(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        return loss

    # logic for a single validation step
    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy (synced across processes); return the loss."""
        loss, acc = self._nll_and_accuracy(batch)
        self.log('val_loss',loss,prog_bar=True,sync_dist=True)
        self.log('val_acc',acc,prog_bar=True,sync_dist=True)
        return loss

    def on_test_start(self) -> None:
        """Called when the test begins: seed numpy so the random shifts repeat."""
        np.random.seed(self.shift_seed)

    # logic for a single testing step
    def test_step(self, batch, batch_idx):
        """Log the metric selected by ``eval_mode``; return the NLL loss."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # test metrics
        if self.eval_mode=='class_accuracy':
            preds = torch.argmax(logits,dim=1)
            acc = accuracy(preds,y)
            self.log('test_loss',loss,prog_bar=True,sync_dist=True)
            self.log('test_acc',acc,prog_bar=True,sync_dist=True)
        elif self.eval_mode=='shift_consistency':
            self.eval_consistency_step(batch,batch_idx,'test','shift')
            self.eval_consistency_step(batch,batch_idx,'test','circular')
        return loss

    def eval_consistency_step(self, batch, batch_idx,
                              val_mode='test', mode='shift'):
        """Log prediction consistency under input shifts as
        '<val_mode>_<mode>_consistency'.

        * ImageNet (``self.inc_conv1_support`` truthy): ``shift_samples``
          copies of the batch are made, either as ``shift_patch_size`` crops
          at offsets drawn from [0, shift_max) (``mode='shift'``) or as
          circular rolls by offsets drawn from [-shift_max, shift_max)
          (``mode='circular'``). The metric is the percentage (0-100) of
          samples whose argmax prediction is identical across all copies.
        * CIFAR-10 otherwise: two random shifts in [-3, 2] per spatial axis,
          zero-padded (``'shift'``) or rolled (``'circular'``); the metric is
          ``accuracy`` between the two argmax predictions.

        NOTE(review): the branches log on different scales (percentage vs.
        whatever ``accuracy`` returns, presumably a fraction) — confirm before
        comparing runs across datasets.
        """
        assert mode in ['shift', 'circular']
        x,_ = batch

        # Read dataset from 'inc_conv1_support'
        if self.inc_conv1_support:
            dataset = 'imagenet'
        else:
            dataset = 'cifar10'

        if dataset=='imagenet':
            outputs = []
            if mode == 'shift':
                # Pass shifted crops
                offsets = [np.random.randint(self.shift_max,size=2) for j in range(0,self.shift_samples)]
                for j in range(0,self.shift_samples):
                    outputs.append(self(x[:,:,offsets[j][0]:offsets[j][0]+self.shift_patch_size,offsets[j][1]:offsets[j][1]+self.shift_patch_size]))
                # Compute consistency
                cur_agree = self.agreement(outputs,self.shift_samples).type(torch.FloatTensor).to(outputs[0].device)
            elif mode == 'circular':
                # Pass rolled inputs
                # -max to max for comparison purposes
                offsets = [np.random.randint(-self.shift_max,self.shift_max,size=2) for j in range(0,self.shift_samples)]
                for j in range(0,self.shift_samples):
                    outputs.append(self(torch.roll(x,shifts=(offsets[j][0],offsets[j][1]),dims=(2,3))))
                # Compute consistency
                cur_agree = self.agreement(outputs,self.shift_samples).type(torch.FloatTensor).to(outputs[0].device)
        elif dataset=='cifar10':
            max_shift = 3
            # randint's upper bound is exclusive: shifts lie in [-3, 2], which
            # also keeps the right-hand slice bounds below strictly negative.
            random_shift1 = torch.randint(-max_shift, max_shift, (2,))
            random_shift2 = torch.randint(-max_shift, max_shift, (2,))
            if mode == 'shift':
                pad_lengths = (max_shift, max_shift, max_shift, max_shift)
                i1_l, i1_r = random_shift1[0] + max_shift, random_shift1[0]-max_shift
                j1_l, j1_r = random_shift1[1] + max_shift, random_shift1[1]-max_shift
                i2_l, i2_r = random_shift2[0] + max_shift, random_shift2[0]-max_shift
                j2_l, j2_r = random_shift2[1] + max_shift, random_shift2[1]-max_shift
                shifted_x1 = F.pad(x, pad_lengths)[:, :, i1_l:i1_r, j1_l:j1_r ]
                shifted_x2 = F.pad(x, pad_lengths)[:, :, i2_l:i2_r, j2_l:j2_r ]
            elif mode == 'circular':
                shifted_x1 = torch.roll(x, shifts = (random_shift1[0], random_shift1[1]), dims = (2, 3))
                shifted_x2 = torch.roll(x, shifts = (random_shift2[0], random_shift2[1]), dims = (2, 3))

            # Compute consistency
            shifted_preds1 = torch.argmax(self(shifted_x1),1)
            shifted_preds2 = torch.argmax(self(shifted_x2),1)
            cur_agree = accuracy(shifted_preds1,shifted_preds2)
        self.log('%s_%s_consistency' % (val_mode, mode), cur_agree, prog_bar=True)
        return

    # (Core) Compute consistency
    def agreement(self,outputs,robust_num):
        """Percentage (0-100) of samples whose argmax agrees across all
        ``robust_num`` entries of ``outputs`` (each shaped (batch, classes))."""
        preds = torch.stack([output.argmax(dim=1,keepdim=False) for output in outputs], dim=0)
        similarity = torch.sum((preds == preds[0:1,:]).int(), dim=0)
        agree = 100*torch.mean((similarity == robust_num).float())
        return agree

    def optimizer_step( self,epoch,batch_idx,
                        optimizer,optimizer_idx,optimizer_closure,
                        on_tpu,using_native_amp,using_lbfgs):
        """Step the optimizer, applying a per-iteration linear lr warmup.

        While ``current_epoch < warmup_epochs`` the lr of every param group of
        the optimizer being stepped is set to
        ``learning_rate * it / (warmup_epochs * num_training_batches)`` with
        ``it`` counted from 1; afterwards this method leaves the lr alone.
        """
        if self.current_epoch < self.warmup_epochs:
            it_curr = self.trainer.num_training_batches*self.current_epoch+1+batch_idx
            it_max = self.trainer.num_training_batches*self.warmup_epochs
            lr_scale = float(it_curr) / it_max
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * self.learning_rate

        # Update params
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()

    def configure_optimizers(self):
        """Build the optimizer and, if configured, the lr schedulers.

        Trainable parameters whose name contains "component_selection" go in
        a second param group with ``weight_decay=0``; all other trainable
        parameters form the first group. With warmup enabled the lr starts at
        0 and is raised by ``optimizer_step``.
        """
        logging.info(f'Configuring optimizer: {self.optimizer_fn} with {self.optimizer_kwargs}')

        # Filter net/pool parameters. `requires_grad` is the flag;
        # `requires_grad_` is a bound method (always truthy) and would let
        # frozen parameters through.
        net_pars, pool_pars = [], []
        for n,p in self.named_parameters():
            if not p.requires_grad:
                continue
            if "component_selection" in n:
                pool_pars.append(p)
            else:
                net_pars.append(p)

        # Start lr at 0 if warming up; optimizer_step ramps it up.
        _lr= 0 if self.warmup_epochs!=0 else self.learning_rate
        optimizer = self.optimizer_fn([{'params': net_pars},
                                       {'params': pool_pars,'weight_decay': 0}],
                                      lr=_lr,
                                      **self.optimizer_kwargs)

        if self.scheduler_fn is None:
            return optimizer

        logging.info(f'Configuring lr scheduler: {self.scheduler_fn} with {self.scheduler_kwargs}')
        if isinstance(self.scheduler_fn, list):
            sch_kwargs_list = self.scheduler_kwargs
            if not sch_kwargs_list:
                # No kwargs given: default kwargs for every scheduler instead of
                # zip() silently dropping all of them.
                sch_kwargs_list = [{}] * len(self.scheduler_fn)
            if len(sch_kwargs_list) != len(self.scheduler_fn):
                raise ValueError(
                    f'Got {len(self.scheduler_fn)} schedulers but '
                    f'{len(sch_kwargs_list)} scheduler kwargs')
            schedulers = [sch_fn(optimizer, **sch_kwargs)
                          for sch_fn, sch_kwargs in zip(self.scheduler_fn, sch_kwargs_list)]
        else:
            schedulers = [
                {'scheduler': self.scheduler_fn(optimizer, **self.scheduler_kwargs),
                 'frequency': 1,
                 'name': 'main_lr_scheduler'
                 }]

        return [optimizer], schedulers

    def configure_callbacks(self):
        """Add the 'gumbel_tau' parameter-scheduler callback when configured."""
        if self.param_scheduler_fn is not None:
            logging.info(f'Configuring tau scheduler: {self.param_scheduler_fn} with {self.param_scheduler_kwargs}')
            return [self.param_scheduler_fn('gumbel_tau', **self.param_scheduler_kwargs)]
        return super().configure_callbacks()

In [3]:
# def get_pool_method(name, FLAGS):
#     #different pool methods uses different flags, this needs cleaning up
#     _available_pool_methods = ('max_2_norm', 'LPS', 'avgpool', 'Decimation', 'skip')
#     assert name in _available_pool_methods
#     antialias_layer = get_antialias(antialias_mode=FLAGS.antialias_mode,
#                                     antialias_size=FLAGS.antialias_size,
#                                     antialias_padding=FLAGS.antialias_padding,
#                                     antialias_padding_mode=FLAGS.antialias_padding_mode,
#                                     antialias_group=FLAGS.antialias_group)
#     pool_method = {
#         'max_2_norm': partial(
#             PolyphaseInvariantDown2D,
#             component_selection=max_p_norm,
#             antialias_layer=antialias_layer,
#             selection_noantialias=FLAGS.selection_noantialias,
#         ),
#         'LPS': partial(
#             PolyphaseInvariantDown2D,
#             component_selection= LPS,
#             get_logits=get_logits_model(FLAGS.logits_model),
#             logits_pad=FLAGS.LPS_pad,
#             comp_fix_train=FLAGS.LPS_gumbel,
#             comp_train_convex=FLAGS.LPS_train_convex,
#             comp_convex=FLAGS.LPS_convex,
#             antialias_layer=antialias_layer,
#             selection_noantialias=FLAGS.selection_noantialias,
#         ),
#         'Decimation': partial(
#             Decimation,
#             stride=FLAGS.pool_k,
#             antialias_layer=antialias_layer,
#         ),
#         'avgpool': partial(
#             nn.AvgPool2d,
#             kernel_size=FLAGS.pool_k
#         ),
#         'skip': None
#     }
#     return pool_method[name]

In [4]:
class cpad(nn.Module):
  """Circular padding as a module: a thin wrapper over F.pad(mode='circular').

  Args:
    pad: padding sizes in ``F.pad`` order, e.g. [left, right, top, bottom]
      for the last two dimensions.
  """

  def __init__(self, pad):
    super().__init__()
    self.pad = pad

  def forward(self, x):
    padded = F.pad(x, pad=self.pad, mode='circular')
    return padded

  def extra_repr(self):
    return f"pad={self.pad}"

In [5]:
def replace_conv(in_ch,out_ch,kernel_size,
                 padding,padding_mode,init,
                 bias=False,stride=1):
  """Create a fresh Conv2d, optionally Kaiming-initialised.

  Args:
    in_ch, out_ch: input / output channel counts.
    kernel_size, padding, padding_mode, bias, stride: forwarded to nn.Conv2d.
    init: if truthy, re-draw the weight with ``kaiming_normal_``
      (mode='fan_out', nonlinearity='relu'); a bias, if present, keeps
      PyTorch's default initialisation.

  Returns:
    The new nn.Conv2d.
  """
  conv_kwargs = dict(kernel_size=kernel_size, padding=padding,
                     padding_mode=padding_mode, bias=bias, stride=stride)
  conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, **conv_kwargs)
  if not init:
    return conv
  nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')
  return conv

In [6]:
def replace_pool(p,in_ch,out_ch,
                 kernel_size,padding,padding_mode,
                 init,bn,swap_conv_pool=False):
  """Wrap pool ``p`` with a fresh bias-free Conv2d (and optional BatchNorm).

  Layout of the returned ``nn.Sequential``:
    bn=True:  (p, conv, bn)  or (conv, bn, p) when swap_conv_pool
    bn=False: (p, conv)      or (conv, p)     when swap_conv_pool

  Args:
    p: pooling module, placed first by default or last when swap_conv_pool.
    in_ch, out_ch, kernel_size, padding, padding_mode: forwarded to nn.Conv2d.
    init: if truthy, Kaiming-normal conv weight (fan_out, relu) and
      BatchNorm weight=1 / bias=0.
    bn: add a BatchNorm2d(out_ch) right after the conv.
    swap_conv_pool: apply the pool last instead of first.
  """
  conv = nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding,
                   padding_mode=padding_mode, bias=False)
  if init:
    nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')

  body = [conv]
  if bn:
    norm = nn.BatchNorm2d(out_ch)
    if init:
      nn.init.constant_(norm.weight, 1)
      nn.init.constant_(norm.bias, 0)
    body.append(norm)

  ordered = body + [p] if swap_conv_pool else [p] + body
  return nn.Sequential(*ordered)

In [36]:
class BottleneckLiteCustom(nn.Module):
    """Slimmed-down ResNet bottleneck block with an optional pooling-aware path.

    Compared to torchvision's ``Bottleneck`` this block uses ``expansion = 2``
    and half the inner width (``planes * base_width / 64 / 2``). Like
    torchvision (ResNet V1.5), the stride sits on the 3x3 conv (``conv2``).

    ``ret_prob``, ``swap_conv_pool`` and ``forward_pool_method`` default to
    the plain forward and are normally overwritten by the model that owns the
    block.

    When ``stride > 1``, ``ret_prob`` is set and ``conv3`` is an
    ``nn.Sequential`` of (pool, conv) — or (conv, pool) with
    ``swap_conv_pool`` — the pool is called with ``ret_prob=True`` and must
    return ``(out, prob)``. ``prob`` is then handed to ``downsample[0]``,
    presumably so the shortcut reuses the main branch's polyphase selection.
    """

    expansion: int = 2

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
    ) -> None:
        super(BottleneckLiteCustom, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        # "Lite": half of torchvision's Bottleneck inner width.
        width = int(planes * (base_width / 64.0 / 2.0)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        # Forward-path options; defaults give the plain (torchvision-style)
        # forward so the block also works when nobody sets them.
        self.ret_prob = False
        self.swap_conv_pool = False
        self.forward_pool_method = 'LPS'

    def forward(self, x: Tensor, global_ret_prob=False) -> Tensor:
        """Run the block.

        Args:
            x: input feature map.
            global_ret_prob: if True, return ``(out, prob)`` where ``prob`` is
                what the main-branch pool returned next to its output, or
                ``None`` when the pooling-aware path was not taken.

        Returns:
            The block output, or ``(out, prob)`` when ``global_ret_prob``.
        """
        # TODO: Generalize to any sequential length
        identity = x
        _p = None

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        # The pooling-aware path needs conv3 to be the (pool, conv) Sequential
        # produced by replace_pool; a plain conv3 uses the standard pipeline.
        use_pool_path = (self.stride > 1 and self.ret_prob
                         and isinstance(self.conv3, nn.Sequential))
        if use_pool_path:
            # Return feats and prob
            if self.swap_conv_pool:
                # conv -> pool
                out = self.conv3[0](out)
                out, _p = self.conv3[1](x=out, ret_prob=True)
            else:
                # pool -> conv
                out, _p = self.conv3[0](x=out, ret_prob=True)
                out = self.conv3[1](out)
            out = self.bn3(out)
            if self.downsample is not None:
                if self.forward_pool_method == "LPS" and self.training:
                    # Train: If LPS, prob is first element of tuple
                    p = _p[0]
                else:
                    p = _p
                # Expects downsample = (pool, conv, bn): the pool gets the
                # original input and the main-branch prob.
                identity = self.downsample[0](x=x, prob=p)
                identity = self.downsample[1](identity)
                identity = self.downsample[2](identity)
        else:
            # Original pipeline
            out = self.conv3(out)
            out = self.bn3(out)
            if self.downsample is not None:
                identity = self.downsample(x)

        out += identity
        out = self.relu(out)
        if global_ret_prob:
            # Train: Return feats and probability-logits tuple
            # Test: return feats and logits
            return out, _p
        return out

In [37]:
def resnet50_fs(pretrained: Optional[WeightsEnum] = None, progress: bool = True, **kwargs: Any) -> ResNet:
    """Build the "lite" ResNet-50 backbone out of ``BottleneckLiteCustom`` blocks.

    Uses the ResNet-50 stage depths [3, 4, 6, 3]. ``pretrained`` and
    ``progress`` are passed positionally to torchvision's ``_resnet`` (its
    weights / progress arguments) and ``kwargs`` are forwarded unchanged.
    """
    stage_depths = [3, 4, 6, 3]
    return _resnet(BottleneckLiteCustom, stage_depths, pretrained, progress, **kwargs)

In [38]:
# Custom ResNet50 (ImageNet)
class ResNet50LiteCustom(AbstractBaseClassifierModel):
  """Lite ResNet-50 classifier whose downsampling uses a custom pooling layer.

  Builds ``resnet50_fs()``. When ``pooling_layer`` is given, it replaces the
  stem (conv1 / maxpool) and the first block of layer2-layer4 so that
  downsampling goes through pools created by ``set_pool``.
  ``forward`` returns log-probabilities (log_softmax along dim 1).

  Args:
    input_shape: not used by the model itself (only recorded by
      ``save_hyperparameters``).
    num_classes: output size of the final linear layer.
    padding_mode: ``padding_mode`` for every Conv2d of the backbone and for
      the replaced convolutions.
    learning_rate: stored as ``self.learning_rate`` for the optimizer setup.
    pooling_layer: pool factory passed to ``set_pool``; ``None`` keeps the
      original torchvision downsampling.
    extras_model: dict of model options. Required keys: 'logits_channels',
      'conv1_stride', 'maxpool_zpad', 'swap_conv_pool', 'inc_conv1_support',
      'apply_maxpool', 'ret_prob'. Optional: 'forward_pool_method' (default 'LPS').
    **kwargs: forwarded to ``AbstractBaseClassifierModel``.
  """

  # First block of each downsampling stage:
  # (stage name, conv2 width, main-branch pool / conv3 input channels,
  #  block output channels == shortcut channels).
  # NOTE(review): BottleneckLiteCustom builds conv3 as conv1x1(width, ...)
  # and the shortcut input has the previous stage's output channels (half of
  # the value used here). Verify these counts against set_pool's output
  # channels before relying on them.
  _STAGES = (
    ('layer2', 64, 128, 256),
    ('layer3', 128, 256, 512),
    ('layer4', 256, 512, 1024),
  )

  def __init__(self,input_shape,num_classes,
               padding_mode='zeros',learning_rate=0.1,pooling_layer=nn.AvgPool2d,
               extras_model=None,**kwargs):
    super().__init__(**kwargs)

    # log hyperparameters
    self.save_hyperparameters()
    self.learning_rate=learning_rate

    if extras_model is None:
      # All options below are required; fail early with a clear message.
      raise ValueError("ResNet50LiteCustom requires the 'extras_model' dict")

    # Model-specific extras
    self.logits_channels=extras_model['logits_channels']
    self.conv1_stride=extras_model['conv1_stride']
    self.maxpool_zpad=extras_model['maxpool_zpad']
    self.swap_conv_pool=extras_model['swap_conv_pool']
    self.inc_conv1_support=extras_model['inc_conv1_support']
    self.apply_maxpool=extras_model['apply_maxpool']
    self.ret_prob=extras_model['ret_prob']
    self.forward_pool_method=extras_model.get('forward_pool_method','LPS')

    # ResNet50 model with fixed shortcut
    self.core=resnet50_fs()

    # Modify Conv2d padding attribute
    for layer in self.core.modules():
      if isinstance(layer,nn.Conv2d):
        layer.padding_mode=padding_mode

    # Pass extras to every bottleneck block of layer1..layer4
    for stage in (self.core.layer1,self.core.layer2,self.core.layer3,self.core.layer4):
      for block in stage:
        block.ret_prob=self.ret_prob
        block.swap_conv_pool=self.swap_conv_pool
        block.forward_pool_method=self.forward_pool_method

    if pooling_layer is not None:
      # Replace pool. Creation order (and thus RNG use) is: stem, main-branch
      # pools, conv2/conv3, shortcut pools, shortcuts.
      maxpool_h_ch,stage_h_ch=self._logits_hidden_channels()
      self._replace_stem(pooling_layer,padding_mode,maxpool_h_ch)
      self._replace_stage_downsampling(pooling_layer,padding_mode,stage_h_ch)

    # Replace head
    self.core.fc=nn.Linear(1024,num_classes)

  def _logits_hidden_channels(self):
    """Return (maxpool hidden channels, {stage name: hidden channels}) for the pools."""
    if self.logits_channels:
      maxpool_h_ch=self.logits_channels["maxpool"]
      stage_h_ch={name:self.logits_channels[name] for name in ("layer2","layer3","layer4")}
    else:
      maxpool_h_ch=64
      stage_h_ch={"layer2":128,"layer3":256,"layer4":512}
    return maxpool_h_ch,stage_h_ch

  def _replace_stem(self,pooling_layer,padding_mode,maxpool_h_ch):
    """Rebuild conv1 and/or the maxpool so the stem downsamples via custom pools."""
    if self.inc_conv1_support:
      # ImageNet/Imagenette: Update conv1 stride
      conv1_stride=2 if self.conv1_stride else 1
      self.core.conv1=replace_conv(in_ch=3,
                                   out_ch=64,
                                   kernel_size=7,
                                   padding=3,
                                   padding_mode=padding_mode,
                                   stride=conv1_stride,
                                   init=True)

    if self.apply_maxpool:
      # ImageNet/Imagenette: Replace maxpool stride by custom pool
      _maxpool=[]
      if not self.conv1_stride:
        # Replace conv1 stride by custom pool
        _maxpool.append(set_pool(pooling_layer=pooling_layer,
                                 p_ch=64,
                                 h_ch=maxpool_h_ch,
                                 no_antialias=True))
      if self.maxpool_zpad:
        _maxpool.append(nn.ZeroPad2d((0,1,0,1)))
      else:
        _maxpool.append(cpad(pad=[0,1,0,1]))
      _maxpool.append(nn.MaxPool2d(kernel_size=2,stride=1))
      _maxpool.append(set_pool(pooling_layer=pooling_layer,
                               p_ch=64,
                               h_ch=maxpool_h_ch))
      self.core.maxpool=nn.Sequential(*_maxpool)

  def _replace_stage_downsampling(self,pooling_layer,padding_mode,stage_h_ch):
    """Move the stride of layer2-4's first block from conv2 into custom pools."""
    # Main-branch pools (with logits model)
    main_pools=[set_pool(pooling_layer=pooling_layer,
                         p_ch=main_ch,
                         h_ch=stage_h_ch[name])
                for name,_,main_ch,_ in self._STAGES]

    # Replace and init. layers: stride-1 conv2, pool + conv for conv3
    for (name,width,main_ch,out_ch),pool in zip(self._STAGES,main_pools):
      block=getattr(self.core,name)[0]
      block.conv2=replace_conv(in_ch=width,
                               out_ch=width,
                               kernel_size=3,
                               padding=1,
                               padding_mode=padding_mode,
                               init=True)
      block.conv3=replace_pool(p=pool,
                               in_ch=main_ch,
                               out_ch=out_ch,
                               kernel_size=1,
                               padding=0,
                               padding_mode=padding_mode,
                               swap_conv_pool=self.swap_conv_pool,
                               init=True,
                               bn=False)

    # Set shortcut branch pool
    # No component selection, indices precomputed
    # https://github.com/pytorch/vision/blob/863e904e4165fe42950c355325a93198d56e4271/torchvision/models/resnet.py#L78
    shortcut_pools=[set_pool(pooling_layer=pooling_layer,
                             p_ch=out_ch,
                             use_get_logits=False)
                    for _,_,_,out_ch in self._STAGES]

    # Replace and init. shortcuts (ksize=1, no padding required)
    for (name,_,_,out_ch),pool in zip(self._STAGES,shortcut_pools):
      getattr(self.core,name)[0].downsample=replace_pool(p=pool,
                                                         in_ch=out_ch,
                                                         out_ch=out_ch,
                                                         kernel_size=1,
                                                         padding=0,
                                                         padding_mode=padding_mode,
                                                         init=True,
                                                         bn=True)

  def forward(self,x):
    """Return log-probabilities over ``num_classes`` (log_softmax along dim 1)."""
    out=self.core(x)
    out=F.log_softmax(out,dim=1)
    return out

In [39]:
# Model options consumed by ResNet50LiteCustom
extras_model = dict(
    logits_channels=None,     # None -> built-in default hidden channels for the pools
    conv1_stride=False,       # drop conv1's stride; a custom pool downsamples instead
    inc_conv1_support=True,   # rebuild the 7x7 ImageNet conv1
    apply_maxpool=True,       # replace the stem maxpool by pad + stride-1 maxpool + pool
    maxpool_zpad=True,        # zero padding (instead of circular) before that maxpool
    swap_conv_pool=False,     # pool before conv in the replaced layers
    ret_prob=True,            # pass the main-branch pool probabilities to the shortcut
)

class flag_struct():
  """Plain attribute container standing in for command-line FLAGS.

  Holds the options passed to ``get_pool_method`` together with a few
  related run settings.
  """

  def __init__(self):
    # Run
    self.dryrun = False

    # Pooling method / LPS component selection
    self.pool_method = 'LPS'
    self.pool_k = 2
    self.logits_model = 'LPSLogitLayers'
    self.LPS_pad = 'circular'
    self.LPS_convex = False
    self.LPS_gumbel = False
    self.LPS_train_convex = False
    self.LPS_debug = False
    self.swap_conv_pool = False
    self.maxpool_zpad = False

    # Anti-aliasing
    self.antialias_mode = 'LowPassFilter'
    self.antialias_size = 3
    self.antialias_padding = 'same'
    self.antialias_padding_mode = "circular"
    self.antialias_group = 8
    self.selection_noantialias = False


FLAGS = flag_struct()
pool_layer = get_pool_method(FLAGS.pool_method, FLAGS)
# ImageNet-sized setup: 224x224 inputs, 1000 classes
model = ResNet50LiteCustom(input_shape=224, num_classes=1000,
                           pooling_layer=pool_layer, extras_model=extras_model)

In [40]:
dummy_batch = torch.rand(2, 3, 224, 224)  # two random 3x224x224 images
model(dummy_batch)

tensor([[-7.2343, -7.2820, -6.7135,  ..., -7.2695, -7.5290, -7.1022],
        [-7.0196, -7.3989, -7.0052,  ..., -7.6473, -6.9428, -6.9855]],
       grad_fn=<LogSoftmaxBackward0>)

In [41]:
# Count every parameter (trainable or not) of the model
total_params = sum(map(torch.Tensor.numel, model.parameters()))
print(total_params)

14302376


In [42]:
# Display the module tree to check which conv/pool layers were replaced.
model

ResNet50LiteCustom(
  (core): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): Sequential(
      (0): PolyphaseInvariantDown2D(
        (component_selection): LPS(
          (get_logits): LPSLogitLayersV2(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same, padding_mode=circular)
            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same, padding_mode=circular)
            (relu): ReLU()
          )
        )
      )
      (1): ZeroPad2d((0, 1, 0, 1))
      (2): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
      (3): PolyphaseInvariantDown2D(
        (component_selection): LPS(
          (get_logits): LPSLogitLayersV2(
            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=same, padding_mode=c