In [3]:
import torch 
from torch import nn
from torch.functional import F
from torch import Tensor

from typing import Optional, List, Tuple

In [4]:
# Weight-standardized 1D convolution: the layer re-normalizes its own kernel on every forward pass.


class WSConv1d(nn.Conv1d):
    r"""1D convolution whose kernel is weight-standardized on every forward pass.

    Before convolving, each output filter is shifted to zero mean and rescaled
    by ``gain / sqrt(max(var * fan_in, eps))``, where ``var`` is the filter's
    biased variance, ``fan_in`` is the number of weights in one filter and
    ``gain`` is a learnable per-filter scalar initialised to one.

    Args:
        in_channels, out_channels, stride, padding, dilation, groups,
        padding_mode: forwarded unchanged to :class:`torch.nn.Conv1d`.
        kernal_size: kernel size, forwarded as ``kernel_size``. The misspelt
            name is kept so that existing keyword callers keep working.
        bais: whether to learn an additive bias, forwarded as ``bias``. The
            misspelt name is kept for the same reason.

    Note:
        ``forward`` calls ``F.conv1d`` with ``self.padding`` directly, so a
        ``padding_mode`` other than ``'zeros'`` is stored but not applied.
    """
    def __init__(self, in_channels, out_channels, kernal_size, stride=1, padding=0, dilation=1, groups=1, bais=True, padding_mode='zeros'):
        # nn.Conv1d only knows the keyword ``bias``; passing ``bais`` through
        # verbatim raised a TypeError on construction.
        super().__init__(in_channels, out_channels, kernal_size, stride=stride, padding=padding,
                         dilation=dilation, groups=groups, bias=bais, padding_mode=padding_mode)
        nn.init.kaiming_normal_(self.weight)
        # One learnable gain per output filter (dim 0 of the weight).
        self.gain = nn.Parameter(torch.ones(
            self.weight.size()[0], requires_grad=True
        ))

    def standardize_weights(self, eps):
        """Return the standardized kernel; ``self.weight`` itself is not modified.

        Args:
            eps: lower bound applied to ``var * fan_in`` before the reciprocal
                square root, which keeps the scale finite for constant filters.
        """
        mean = torch.mean(self.weight, dim=(1, 2), keepdim=True)
        var = torch.std(self.weight, dim=(1, 2), keepdim=True, unbiased=False) ** 2
        # fan-in is the size of ONE filter (in_channels / groups * kernel), so
        # the out_channels dimension must be excluded from the product.
        fan_in = torch.prod(torch.tensor(self.weight.shape[1:]))

        scale = torch.rsqrt(torch.max(
            var * fan_in, torch.tensor(eps).to(var.device))) * self.gain.view_as(var).to(var.device)
        # weight * scale - mean * scale == (weight - mean) * scale
        shift = mean * scale
        return self.weight * scale - shift

    def forward(self, input, eps=1e-4):
        """Convolve ``input`` with the freshly standardized kernel."""
        weight = self.standardize_weights(eps)
        return F.conv1d(input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

In [5]:
class WSCon2d(nn.Conv2d):
    """2D convolution whose kernel is weight-standardized on every forward pass.

    Each output filter is shifted to zero mean and rescaled by
    ``gain / sqrt(max(var * fan_in, eps))``, where ``var`` is the filter's
    biased variance, ``fan_in`` is the number of weights in one filter and
    ``gain`` is a learnable per-filter scalar initialised to one.

    Constructor arguments are forwarded unchanged to :class:`torch.nn.Conv2d`.

    Note:
        ``forward`` calls ``F.conv2d`` with ``self.padding`` directly, so a
        ``padding_mode`` other than ``'zeros'`` is stored but not applied.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                         dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)

        nn.init.kaiming_normal_(self.weight)
        # One learnable gain per output filter (dim 0 of the weight).
        self.gain = nn.Parameter(torch.ones(
            self.weight.size(0), requires_grad=True))

    def standardize_weights(self, eps):
        """Return the standardized kernel; ``self.weight`` itself is not modified.

        Args:
            eps: lower bound applied to ``var * fan_in`` before the reciprocal
                square root, which keeps the scale finite for constant filters.
        """
        filter_dims = (1, 2, 3)
        mean = torch.mean(self.weight, dim=filter_dims, keepdim=True)
        # The variance must be reduced over the same dims as the mean; reducing
        # over (1, 2) only left a trailing kernel-width axis that the
        # per-filter gain could not be viewed as.
        var = torch.std(self.weight, dim=filter_dims, keepdim=True, unbiased=False) ** 2
        # fan-in is the size of ONE filter, so out_channels is excluded.
        fan_in = torch.prod(torch.tensor(self.weight.shape[1:]))

        scale = torch.rsqrt(torch.max(
            var * fan_in, torch.tensor(eps).to(var.device))) * self.gain.view_as(var).to(var.device)
        # weight * scale - mean * scale == (weight - mean) * scale
        shift = mean * scale
        return self.weight * scale - shift

    def forward(self, input, eps=1e-4):
        """Convolve ``input`` with the freshly standardized kernel."""
        weight = self.standardize_weights(eps)
        # 4-D kernels need the 2D functional op (the previous conv1d call rejected them).
        return F.conv2d(input, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)

In [6]:
class WSConvTranspose2d(nn.ConvTranspose2d):
    """2D transposed convolution whose kernel is weight-standardized on every forward pass.

    The kernel is standardized along its first dimension: every slice
    ``weight[i]`` is shifted to zero mean and rescaled by
    ``gain[i] / sqrt(max(var * fan_in, eps))`` where ``fan_in`` is the number
    of elements in one slice and ``gain`` is learnable, initialised to one.

    Constructor arguments are forwarded unchanged to
    :class:`torch.nn.ConvTranspose2d`.
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 groups: int = 1,
                 bias: bool = True,
                 dilation: int = 1,
                 padding_mode: str = 'zeros'):
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                         output_padding=output_padding, groups=groups, bias=bias, dilation=dilation, padding_mode=padding_mode)

        nn.init.kaiming_normal_(self.weight)
        # One learnable gain per slice along dim 0 of the weight.
        self.gain = nn.Parameter(torch.ones(
            self.weight.size(0), requires_grad=True))

    def standardize_weights(self, eps):
        """Return the standardized kernel; ``self.weight`` itself is not modified.

        Args:
            eps: lower bound applied to ``var * fan_in`` before the reciprocal
                square root.
        """
        slice_dims = (1, 2, 3)
        # torch.mean takes ``dim`` (the previous ``dims`` keyword raised TypeError).
        mean = torch.mean(self.weight, dim=slice_dims, keepdim=True)
        # Biased variance, matching the other weight-standardized layers and
        # staying finite when a slice holds a single element.
        var = torch.std(self.weight, dim=slice_dims, keepdim=True, unbiased=False) ** 2
        fan_in = torch.prod(torch.tensor(self.weight.shape[1:]))

        scale = torch.rsqrt(torch.max(
            var * fan_in, torch.tensor(eps).to(var.device))) * self.gain.view_as(var).to(var.device)
        # weight * scale - mean * scale == (weight - mean) * scale
        shift = mean * scale
        return self.weight * scale - shift

    def forward(self, input: Tensor, output_size: Optional[List[int]] = None, eps: float = 1e-4) -> Tensor:
        """Apply the transposed convolution with the freshly standardized kernel.

        NOTE(review): ``output_size`` is accepted for signature compatibility
        but is not used; the configured ``self.output_padding`` is always applied.
        """
        # Call the method under the name it is defined with (``standardize_weights``).
        weight = self.standardize_weights(eps)
        return F.conv_transpose2d(input, weight, self.bias, self.stride, self.padding, self.output_padding, self.groups, self.dilation)
    
    


In [7]:
class ScaledStdConv2d(nn.Conv2d):
    """Conv2d that convolves with a scaled, standardized copy of its kernel.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: forwarded to :class:`torch.nn.Conv2d`.
        gain: if true, learn a per-output-channel multiplier (initialised to
            one, shape ``(out_channels, 1, 1, 1)``); otherwise no multiplier.
        gamma: constant multiplied into the ``fan_in ** -0.5`` scale.
        eps: stabiliser for the normalization; it is squared when
            ``use_layernorm`` is set.
        use_layernorm: standardize with ``F.layer_norm`` instead of the
            explicit mean / standard-deviation formula.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1,
                 groups=1, bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False) -> None:
        super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                         dilation=dilation, groups=groups, bias=bias)
        if gain:
            self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1))
        else:
            self.gain = None
        # Number of weights feeding one output channel.
        per_filter = self.weight[0].numel()
        self.scale = gamma * per_filter ** -0.5
        if use_layernorm:
            self.eps = eps ** 2
        else:
            self.eps = eps
        self.use_layernorm = use_layernorm

    def get_weight(self):
        """Return the standardized, scaled (and optionally gained) kernel."""
        kernel = self.weight
        if self.use_layernorm:
            normalized = F.layer_norm(kernel, kernel.shape[1:], eps=self.eps)
            result = self.scale * normalized
        else:
            per_filter_dims = [1, 2, 3]
            mu = torch.mean(kernel, dim=per_filter_dims, keepdim=True)
            sigma = torch.std(kernel, dim=per_filter_dims, keepdim=True, unbiased=False)
            result = self.scale * (kernel - mu) / (sigma + self.eps)
        if self.gain is None:
            return result
        return result * self.gain

    def forward(self, x):
        """Convolve ``x`` with the kernel produced by :meth:`get_weight`."""
        kernel = self.get_weight()
        return F.conv2d(x, kernel, self.bias, self.stride, self.padding, self.dilation, self.groups)

In [8]:
import torch 
from torch import nn 
from function.base import WSCon2d
import warnings

In [9]:
# Helper that replaces every nn.Conv2d layer in a model with a custom convolution class.
def replace_conv(module : nn.Module, conv_class:WSCon2d):
    """Swap convolutions for ``conv_class`` and drop batch norm, recursively.

    Direct children whose type is exactly ``torch.nn.Conv2d`` are replaced by
    a new ``conv_class`` built from the child's channel counts, kernel size,
    stride, padding, dilation, groups and bias flag (the new layer is freshly
    constructed; no weights are copied). Direct children whose type is exactly
    ``torch.nn.BatchNorm2d`` are replaced by ``torch.nn.Identity``. The
    function then recurses into every child; ``module`` is modified in place.
    """
    warnings.warn("Make sure to use it with non-residual model only")
    for child_name, child in module.named_children():
        child_type = type(child)
        if child_type == torch.nn.Conv2d:
            has_bias = child.bias is not None
            replacement = conv_class(child.in_channels, child.out_channels, child.kernel_size,
                                     child.stride, child.padding, child.dilation, child.groups, has_bias)
            setattr(module, child_name, replacement)

        if child_type == torch.nn.BatchNorm2d:
            setattr(module, child_name, torch.nn.Identity())

    # Second pass so the recursion sees the children as they are after replacement.
    for _, child in module.named_children():
        replace_conv(child, conv_class)


In [10]:
def unitwise_norm(x:torch.Tensor):
    """Return the L2 norm of ``x`` computed per unit.

    The reduction depends on the tensor rank:

    * rank 0 or 1: a single norm over dim 0 (dimension not kept);
    * rank 2 or 3: norm over dim 0, kept as a size-1 dimension;
    * rank 4: norm over dims 1, 2 and 3, kept as size-1 dimensions.

    Raises:
        ValueError: if ``x`` has more than four dimensions.
    """
    # ``x.dim`` is a bound method; comparing it with an int raised TypeError.
    # ``x.ndim`` is the integer rank.
    if x.ndim <= 1:
        dim = 0
        keepdim = False
    elif x.ndim in [2, 3]:
        dim = 0
        keepdim = True
    elif x.ndim == 4:
        dim = [1, 2, 3]
        keepdim = True
    else:
        raise ValueError('Wrong input dimensions')

    return torch.sum(x**2, dim=dim, keepdim=keepdim) ** 0.5

In [11]:
import torch 
from torch.optim.optimizer import Optimizer, required
from torch import optim, nn
from function.utils import unitwise_norm


In [12]:

# Stochastic gradient descent with momentum, weight decay, Nesterov momentum and adaptive gradient clipping (AGC).
class SGD_AGC(Optimizer):
    """SGD (momentum / dampening / weight decay / Nesterov) with adaptive gradient clipping.

    Every ``step`` first rescales each gradient whose unit-wise norm exceeds
    ``clipping`` times the unit-wise parameter norm (the parameter norm is
    floored at ``eps``), then performs the SGD update.

    Raises:
        ValueError: for a negative ``lr``, ``momentum``, ``weight_decay``,
            ``clipping`` or ``eps``, or when ``nesterov`` is requested without
            positive momentum and zero dampening.
    """
    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, clipping=1e-2, eps=1e-3):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        # The remaining non-negativity checks share one message template.
        for label, value in (("momentum", momentum), ("weight_decay", weight_decay),
                             ("clipping", clipping), ("eps", eps)):
            if value < 0.0:
                raise ValueError("Invalid {} value: {}".format(label, value))

        defaults = {
            "lr": lr,
            "momentum": momentum,
            "dampening": dampening,
            "weight_decay": weight_decay,
            "nesterov": nesterov,
            "clipping": clipping,
            "eps": eps,
        }
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening")
        super(SGD_AGC, self).__init__(params, defaults)

    def __setstate__(self, state):
        """Restore state, giving groups without a 'nesterov' entry the value False."""
        super(SGD_AGC, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Pass 1: adaptive gradient clipping, written back into p.grad in place.
        for group in self.param_groups:
            for p in group['params']:
                grad = p.grad
                if grad is None:
                    continue
                norm_floor = torch.tensor(group['eps']).to(p.device)
                param_norm = torch.max(unitwise_norm(p.detach()), norm_floor)
                grad_norm = unitwise_norm(grad.detach())
                max_norm = param_norm * group['clipping']

                exceeds = grad_norm > max_norm

                # Guard the division against (near-)zero gradient norms.
                safe_grad_norm = torch.max(
                    grad_norm, torch.tensor(1e-6).to(grad_norm.device))
                rescaled = grad * (max_norm / safe_grad_norm)
                grad.detach().copy_(torch.where(exceeds, rescaled, grad))

        # Pass 2: the SGD update on the (possibly clipped) gradients.
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            nesterov = group['nesterov']
            step_size = group['lr']

            for p in group['params']:
                if p.grad is None:
                    continue
                update = p.grad
                if weight_decay != 0:
                    update = update.add(p, alpha=weight_decay)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' in param_state:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(update, alpha=1 - dampening)
                    else:
                        # First step for this parameter: seed the buffer with the update.
                        buf = torch.clone(update).detach()
                        param_state['momentum_buffer'] = buf
                    update = update.add(buf, alpha=momentum) if nesterov else buf

                p.add_(update, alpha=-step_size)

        return loss

In [14]:
import torch 
from torch import nn, optim 

from function.utils import unitwise_norm
from collections.abc import Iterable

In [16]:
# Adaptive gradient clipping (AGC) as a wrapper around an existing optimizer.

class AGC(optim.Optimizer):
    """Generic implementation of Adaptive Gradient Clipping around another optimizer.

    ``step`` rescales every selected gradient whose unit-wise norm exceeds
    ``clipping`` times the unit-wise parameter norm (floored at ``eps``), then
    delegates the parameter update to the wrapped optimizer.

    Args:
        params: parameters (or parameter-group dicts) to clip. Only used when
            ``model`` is ``None``; any iterable, including a generator.
        optim: the already constructed optimizer that performs the update.
            Its ``param_groups`` and ``state`` are shared with this wrapper.
        clipping: clipping factor; must be non-negative.
        eps: floor for the parameter norm; must be non-negative.
        model: if given, the parameters to clip are taken from it, skipping
            every module named in ``ignore_agc``.
        ignore_agc: module name or iterable of module names (as reported by
            ``model.named_modules()``) whose parameters are not clipped.

    Raises:
        ValueError: for a negative ``clipping`` or ``eps``.
        ModuleNotFoundError: if a name in ``ignore_agc`` is not a module of ``model``.
    """

    def __init__(self, params, optim: optim.Optimizer, clipping: float = 1e-2, eps: float = 1e-3, model=None, ignore_agc=["fc"]):
        if clipping < 0.0:
            raise ValueError("Invalid clipping value: {}".format(clipping))
        if eps < 0.0:
            raise ValueError("Invalid eps value: {}".format(eps))

        self.optim = optim

        # The base-class constructor is deliberately not called (it would
        # register a second set of parameter groups); keep the merged defaults
        # available under the attribute the base class normally provides.
        self.defaults = {"clipping": clipping, "eps": eps, **optim.defaults}

        # A bare string is itself iterable; wrap it so "fc" means the module
        # named "fc" rather than the characters "f" and "c".
        if isinstance(ignore_agc, str) or not isinstance(ignore_agc, Iterable):
            ignore_agc = [ignore_agc]
        else:
            ignore_agc = list(ignore_agc)

        if model is not None:
            assert ignore_agc not in [
                None, []], "You must specify ignore_agc for AGC to ignore fc-like(or other) layers"
            names = {name for name, _ in model.named_modules()}

            for module_name in ignore_agc:
                if module_name not in names:
                    raise ModuleNotFoundError(
                        "Module name {} not found in the model".format(module_name))

            # named_modules() also yields the root module, whose parameters()
            # contain every parameter of the model. Selecting by module name
            # therefore never excluded anything; select by parameter identity.
            ignored_ids = {
                id(p)
                for name, module in model.named_modules() if name in ignore_agc
                for p in module.parameters()
            }
            agc_groups = [{"params": [
                p for p in model.parameters() if id(p) not in ignored_ids]}]
        else:
            # Materialise once: a generator such as model.parameters() would
            # otherwise be exhausted after the first step.
            params = list(params)
            if params and isinstance(params[0], dict):
                agc_groups = [{"params": list(group["params"])} for group in params]
            else:
                agc_groups = [{"params": params}]

        self.agc_params = agc_groups
        self.eps = eps
        self.clipping = clipping

        self.param_groups = optim.param_groups
        self.state = optim.state

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss. It is evaluated exactly once, before
                clipping.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.agc_params:
            for p in group['params']:
                if p.grad is None:
                    continue

                param_norm = torch.max(unitwise_norm(
                    p.detach()), torch.tensor(self.eps).to(p.device))
                grad_norm = unitwise_norm(p.grad.detach())
                max_norm = param_norm * self.clipping

                trigger = grad_norm > max_norm

                # Guard the division against (near-)zero gradient norms.
                clipped_grad = p.grad * \
                    (max_norm / torch.max(grad_norm,
                                          torch.tensor(1e-6).to(grad_norm.device)))
                p.grad.detach().data.copy_(torch.where(trigger, clipped_grad, p.grad))

        # Do not hand the closure on: re-running it inside the wrapped
        # optimizer would recompute the gradients and discard the clipping.
        self.optim.step()
        return loss

    def zero_grad(self, set_to_none: bool = False):
        r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.

        Every parameter of the wrapped optimizer is covered, including those
        excluded from clipping, so their gradients do not accumulate.
        """
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                if set_to_none:
                    p.grad = None
                    continue
                if p.grad.grad_fn is not None:
                    p.grad.detach_()
                else:
                    p.grad.requires_grad_(False)
                p.grad.zero_()

In [18]:
import torch
from torch import nn

from function.base import WSCon2d, ScaledStdConv2d

In [19]:
class SqueezeExcite(nn.Module):
    """Squeeze-and-excite style channel descriptor broadcast over the input.

    The input is averaged over its two trailing (spatial) dimensions, passed
    through two linear layers with the chosen activation in between, and the
    result is broadcast back to the shape of the input.

    Args:
        in_channels: channel count of the input (size of the first linear layer's input).
        out_channels: size of the second linear layer's output; it must be
            broadcastable to the input's channel dimension for ``expand_as``.
        se_ratio: if not ``None``, the hidden width is
            ``max(1, int(se_ratio * in_channels))``.
        hidden_channels: hidden width used when ``se_ratio`` is ``None``.
        activation: key into ``activation_fn``.

    NOTE(review): no sigmoid gate is applied to the result; confirm whether
    the caller is expected to apply one.
    """

    def __init__(self, in_channels, out_channels, se_ratio=0.5, hidden_channels=None, activation='relu'):
        # nn.Module must be initialised before any submodule is assigned;
        # otherwise the assignments below raise AttributeError.
        super(SqueezeExcite, self).__init__()
        assert (se_ratio is not None) or (hidden_channels is not None)

        if se_ratio is not None:
            # nn.Linear needs an integer width; se_ratio * in_channels is a float.
            hidden_channels = max(1, int(se_ratio * in_channels))

        self.fc0 = nn.Linear(in_channels, hidden_channels)
        self.fc1 = nn.Linear(hidden_channels, out_channels)

        self.activation = activation_fn[activation]

    def forward(self, x):
        """Return the per-channel descriptor expanded to the shape of ``x``."""
        h = torch.mean(x, [2, 3])
        h = self.fc0(h)
        h = self.fc1(self.activation(h))

        # (N, C) -> (N, C, 1, 1) so the channel axis lines up before broadcasting.
        return h[:, :, None, None].expand_as(x)

In [20]:
class NFBlock(nn.Module):
    """Convolutional part of a normalizer-free block (constructor only).

    Builds a 1x1 convolution followed by a grouped ``kernel_shape``
    convolution. The working width is ``int(out_channels * expansion)``
    rounded down to a multiple of ``group_size``.

    NOTE(review): ``se_ratio``, ``stride`` and ``activation`` are accepted but
    not used, and no ``forward`` is defined in the visible source; the block
    looks unfinished.

    Raises:
        ValueError: if the expanded width is smaller than ``group_size``,
            which would give zero groups and a zero-width convolution.
    """

    def __init__(self, in_channels, out_channels, expansion=0.5, se_ratio=0.5, kernel_shape=3, group_size=128, stride=1, beta=1.0, alpha=0.2, conv=ScaledStdConv2d, activation='gelu'):
        # Required before submodules can be registered on an nn.Module.
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Use the constructor argument: self.out_channels was previously read
        # before it had ever been assigned.
        width = int(out_channels * expansion)
        self.groups = width // group_size
        if self.groups < 1:
            raise ValueError(
                "expanded width {} is smaller than group_size {}".format(width, group_size))
        # Round the width down to a whole number of groups.
        self.width = group_size * self.groups

        self.conv0 = conv(in_channels, self.width, 1)

        # Honour kernel_shape (default 3) instead of a hard-coded 3.
        self.conv1 = conv(self.width, self.width, kernel_shape, groups=self.groups)

        self.alpha = alpha
        self.beta = beta

In [21]:
import torch
from torch import Tensor
import torch.nn as nn
from typing import Type, Any, Callable, Union, List, Optional

from function.base import WSCon2d, ScaledStdConv2d

from functools import partial

In [22]:
__all__=["nf_resnet18"]

In [23]:
# Per-activation multipliers applied to the activation output by the entries
# of ``activation_fn``.
_nonlin_gamma = dict(
    identity=1.0,
    celu=1.270926833152771,
    elu=1.2716004848480225,
    gelu=1.7015043497085571,
    leaky_relu=1.70590341091156,
    log_sigmoid=1.9193484783172607,
    log_softmax=1.0002083778381348,
    relu=1.7139588594436646,
    relu6=1.7131484746932983,
    selu=1.0008515119552612,
    sigmoid=4.803835391998291,
    silu=1.7881293296813965,
    softsign=2.338853120803833,
    softplus=1.9203323125839233,
    tanh=1.5939117670059204,
)

# Activations that must NOT be built with ``inplace=True``. The residual
# blocks pass ``inplace=True`` to every activation that is not listed here, so
# every activation whose constructor has no ``inplace`` argument has to be
# listed, otherwise the first forward pass raises TypeError.
ignore_inplace = ['gelu', 'silu', 'softplus',
                  'sigmoid', 'tanh', 'log_sigmoid', 'log_softmax', ]

In [24]:
def _scaled_activation(name, module_cls):
    """Build ``f(x, *args, **kwargs)`` applying ``module_cls`` and the gamma for ``name``.

    A fresh ``module_cls(*args, **kwargs)`` is constructed on every call and
    the multiplier is looked up in ``_nonlin_gamma`` at call time.
    """
    def apply(x, *args, **kwargs):
        return module_cls(*args, **kwargs)(x) * _nonlin_gamma[name]
    return apply


# Activation name -> callable returning the activation output times its gamma.
activation_fn = {
    name: _scaled_activation(name, module_cls)
    for name, module_cls in (
        ('identity', nn.Identity),
        ('celu', nn.CELU),
        ('elu', nn.ELU),
        ('gelu', nn.GELU),
        ('leaky_relu', nn.LeakyReLU),
        ('log_sigmoid', nn.LogSigmoid),
        ('log_softmax', nn.LogSoftmax),
        ('relu', nn.ReLU),
        ('relu6', nn.ReLU6),
        ('selu', nn.SELU),
        ('sigmoid', nn.Sigmoid),
        ('silu', nn.SiLU),
        ('softplus', nn.Softplus),
        ('tanh', nn.Tanh),
    )
}

In [25]:
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1, base_conv: nn.Conv2d = ScaledStdConv2d) -> nn.Conv2d:
    """Build a bias-free 3x3 ``base_conv`` whose padding equals its dilation."""
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return base_conv(in_planes, out_planes, **conv_kwargs)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1, base_conv: nn.Conv2d = ScaledStdConv2d) -> nn.Conv2d:
    """Build a bias-free 1x1 ``base_conv`` with the given stride."""
    conv_kwargs = dict(kernel_size=1, stride=stride, bias=False)
    return base_conv(in_planes, out_planes, **conv_kwargs)

In [26]:
class BasicBlock(nn.Module):
    """Two-convolution residual block without normalization layers.

    ``forward`` computes ``alpha * conv2(act(conv1(act(x) * beta))) + shortcut``
    where ``shortcut`` is ``x`` or, when a ``downsample`` module is given,
    ``downsample(x)``.

    Raises:
        ValueError: unless ``groups == 1`` and ``base_width == 64``.
        NotImplementedError: if ``dilation > 1``.
    """
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        alpha: float = 0.2,
        beta: float = 1.0,
        activation: str = 'relu',
        base_conv: nn.Conv2d = ScaledStdConv2d
    ) -> None:
        super(BasicBlock, self).__init__()
        if not (groups == 1 and base_width == 64):
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")
        # conv1 carries the stride, so it (like self.downsample) reduces the
        # spatial size when stride != 1.
        self.conv1 = conv3x3(inplanes, planes, stride, base_conv=base_conv)
        self.activation = activation

        # Request in-place execution unless the activation is on the ignore list.
        act_kwargs = {} if activation in ignore_inplace else {"inplace": True}
        self.act = partial(activation_fn[activation], **act_kwargs)
        self.conv2 = conv3x3(planes, planes, base_conv=base_conv)
        self.downsample = downsample
        self.stride = stride
        self.alpha = alpha
        self.beta = beta

    def forward(self, x: Tensor) -> Tensor:
        """Apply the residual branch, scale it by ``alpha`` and add the shortcut."""
        shortcut = x

        # Pre-activation (never in place, so x stays intact for the shortcut).
        branch = activation_fn[self.activation](x=x) * self.beta

        branch = self.conv1(branch)
        branch = self.act(x=branch)
        branch = self.conv2(branch)

        if self.downsample is not None:
            shortcut = self.downsample(x)

        branch *= self.alpha
        branch += shortcut

        return branch

In [27]:
class Bottleneck(nn.Module):
    """1x1 -> 3x3 -> 1x1 residual block without normalization layers.

    The stride is placed on the 3x3 convolution (``conv2``), the layout known
    as ResNet V1.5 (see
    https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch),
    rather than on the first 1x1 convolution as in the original
    "Deep residual learning for image recognition"
    (https://arxiv.org/abs/1512.03385).

    ``forward`` computes
    ``alpha * conv3(act(conv2(act(conv1(act(x) * beta))))) + shortcut`` where
    ``shortcut`` is ``x`` or, when a ``downsample`` module is given,
    ``downsample(x)``.
    """

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        alpha: float = 0.2,
        beta: float = 1.0,
        activation: str = 'relu',
        base_conv: nn.Conv2d = ScaledStdConv2d,
    ) -> None:
        super(Bottleneck, self).__init__()
        # Width of the inner 3x3 convolution.
        width = int(planes * (base_width / 64.)) * groups
        out_planes = planes * self.expansion
        # conv2 carries the stride, so it (like self.downsample) reduces the
        # spatial size when stride != 1.
        self.conv1 = conv1x1(inplanes, width, base_conv=base_conv)
        self.conv2 = conv3x3(width, width, stride, groups,
                             dilation, base_conv=base_conv)
        self.conv3 = conv1x1(width, out_planes, base_conv=base_conv)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

        self.alpha = alpha
        self.beta = beta
        self.activation = activation
        # Request in-place execution unless the activation is on the ignore list.
        act_kwargs = {} if activation in ignore_inplace else {"inplace": True}
        self.act = partial(activation_fn[activation], **act_kwargs)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the residual branch, scale it by ``alpha`` and add the shortcut."""
        shortcut = x

        # Pre-activation (never in place, so x stays intact for the shortcut).
        branch = activation_fn[self.activation](x) * self.beta

        branch = self.conv1(branch)
        branch = self.act(x=branch)

        branch = self.conv2(branch)
        branch = self.act(x=branch)

        branch = self.conv3(branch)

        if self.downsample is not None:
            shortcut = self.downsample(x)

        branch *= self.alpha
        branch += shortcut

        return branch

In [28]:
class NFResNet(nn.Module):
    """ResNet backbone assembled from normalization-free residual blocks.

    Args:
        block: residual block class (``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
        zero_init_residual: zero the last convolution of every residual branch
            so each block starts out as its shortcut.
        groups, width_per_group: forwarded to every block.
        replace_stride_with_dilation: three booleans, one per strided stage;
            a true entry trades that stage's stride for dilation.
        alpha, beta, activation, base_conv: forwarded to every block;
            ``base_conv`` also builds the stem and the downsample convolutions.

    Raises:
        ValueError: if ``replace_stride_with_dilation`` does not have three
            entries, or if ``zero_init_residual`` is requested for a
            convolution that can be zeroed neither through a gain nor safely
            through its weight.
    """

    def __init__(
        self,
        block: Type[Union[BasicBlock, Bottleneck]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        alpha: float = 0.2,
        beta: float = 1.0,
        activation: str = 'relu',
        base_conv: nn.Conv2d = ScaledStdConv2d
    ) -> None:
        super(NFResNet, self).__init__()

        assert activation in activation_fn.keys()

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = base_conv(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(
            block, 64, layers[0], alpha=alpha, beta=beta, activation=activation, base_conv=base_conv)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0], alpha=alpha, beta=beta, activation=activation, base_conv=base_conv)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1], alpha=alpha, beta=beta, activation=activation, base_conv=base_conv)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2], alpha=alpha, beta=beta, activation=activation, base_conv=base_conv)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_in', nonlinearity='linear')

        # Zero-initialize the last convolution in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # The blocks have no batch-norm layers (the previous code addressed
        # non-existent ``bn3`` / ``bn2`` attributes), so the convolution itself is zeroed.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    self._zero_init_conv(m.conv3)
                elif isinstance(m, BasicBlock):
                    self._zero_init_conv(m.conv2)

    @staticmethod
    def _zero_init_conv(conv: nn.Module) -> None:
        """Make ``conv`` output zeros at initialisation."""
        gain = getattr(conv, 'gain', None)
        if isinstance(gain, torch.Tensor):
            # Weight-standardized layers multiply the kernel by a learnable
            # gain; zeroing the gain zeroes the output while leaving the
            # kernel statistics untouched.
            nn.init.zeros_(gain)
        elif type(conv) is nn.Conv2d:
            nn.init.zeros_(conv.weight)
        else:
            # Zeroing the raw weight of a standardizing layer would give it
            # zero standard deviation, which the standardization divides by.
            raise ValueError(
                "zero_init_residual needs a learnable gain on {}".format(type(conv).__name__))

    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False, alpha: float = 0.2, beta: float = 1.0, activation: str = 'relu', base_conv: nn.Conv2d = ScaledStdConv2d) -> nn.Sequential:
        """Build one stage of ``blocks`` residual blocks and update ``self.inplanes``."""
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Trade the stride for dilation in this stage.
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            # The shortcut has to match the branch's resolution and width.
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion,
                        stride, base_conv=base_conv),
            )

        layers = []
        # Only the first block of a stage strides / downsamples.
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, alpha=alpha, beta=beta, activation=activation, base_conv=base_conv))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                alpha=alpha, beta=beta, activation=activation,
                                base_conv=base_conv))

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Stem, four stages, global average pooling and the linear classifier."""
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

    def forward(self, x: Tensor) -> Tensor:
        """Return the ``num_classes`` outputs of the final linear layer for ``x``."""
        return self._forward_impl(x)

In [29]:
def _nf_resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck]],
    layers: List[int],
    alpha: float,
    beta: float,
    activation: str,
    base_conv: nn.Conv2d,
    **kwargs: Any
) -> NFResNet:
    """Construct an ``NFResNet`` from a block type and per-stage depths.

    ``arch`` is accepted for call-site symmetry but is not used; all other
    arguments, including ``kwargs``, are passed on to ``NFResNet``.
    """
    return NFResNet(
        block,
        layers,
        alpha=alpha,
        beta=beta,
        activation=activation,
        base_conv=base_conv,
        **kwargs,
    )


In [30]:
def nf_resnet18(alpha: float = 0.2, beta: float = 1.0, activation: str = 'relu', base_conv: nn.Conv2d = ScaledStdConv2d, **kwargs: Any) -> NFResNet:
    r"""Normalizer-free ResNet-18, after
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
    and `"High-Performance Large-Scale Image Recognition Without Normalization" <https://arxiv.org/pdf/2102.06171v1>`_.

    Args:
        alpha: multiplier applied to every residual branch.
        beta: multiplier applied to the activated block input.
        activation: key into ``activation_fn``.
        base_conv: convolution class used throughout the network.
        **kwargs: passed on to ``NFResNet``.
    """
    stage_depths = [2, 2, 2, 2]
    return _nf_resnet(
        'resnet18', BasicBlock, stage_depths,
        alpha=alpha, beta=beta, activation=activation, base_conv=base_conv,
        **kwargs)
