In [54]:
import jax
import jax.numpy as jnp
import numpy as np
import flax
from flax import linen as nn
from functools import partial

In [121]:
from typing import Sequence

In [12]:
# Seed the notebook's PRNG with 0 and immediately split it: `rng` is the key
# carried forward, `init_rng` is the key set aside for parameter initialisation.
rng, init_rng = jax.random.split(jax.random.PRNGKey(0))



In [78]:
def conv1x1(out_planes, strides=1):
    """Build a bias-free 1x1 convolution layer.

    Args:
        out_planes: number of output feature channels.
        strides: single int stride, applied to both spatial dimensions.

    Returns:
        An unbound ``nn.Conv`` module.
    """
    stride_hw = (strides,) * 2
    return nn.Conv(features=out_planes, kernel_size=(1, 1), strides=stride_hw, use_bias=False)

In [79]:
def conv3x3(out_planes, strides=1, groups=1, dilation=1):
    """Build a bias-free 3x3 convolution layer.

    Args:
        out_planes: number of output feature channels.
        strides: single int stride, applied to both spatial dimensions.
        groups: number of feature groups for a grouped convolution
            (1 means an ordinary dense convolution).
        dilation: kernel dilation rate, applied to both spatial dimensions.

    Returns:
        An unbound ``nn.Conv`` module.
    """
    # Explicit padding equal to the dilation keeps the spatial size unchanged
    # at stride 1 for a 3x3 kernel (effective kernel extent is 2*dilation + 1).
    pad = (dilation, dilation)
    return nn.Conv(
        out_planes,
        kernel_size=(3, 3),
        strides=(strides, strides),
        padding=[pad, pad],
        # Previously `groups` was silently ignored and `dilation` only affected
        # the padding; forward both so the arguments do what they say.
        feature_group_count=groups,
        kernel_dilation=(dilation, dilation),
        use_bias=False,
    )

In [100]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm layers and a shortcut.

    Attributes:
        planes: number of output channels of both 3x3 convolutions.
        stride: stride of the first 3x3 convolution, and of the shortcut
            projection when ``downsample`` is enabled.
        expansion: ratio of block output channels to ``planes``; callers read
            it from the class (``BasicBlock.expansion``).
        downsample: when True, the shortcut goes through a 1x1 convolution and
            batch norm so that it matches the main branch before the addition.
    """
    planes: int
    stride: int = 1
    expansion: int = 1
    downsample: bool = False

    def setup(self):
        """Create the sub-layers; the projection pair only exists if requested."""
        self.conv1 = conv3x3(self.planes, strides=self.stride)
        # Batch norm always uses the statistics of the current batch here.
        self.bn1 = nn.BatchNorm(use_running_average=False)
        self.conv2 = conv3x3(self.planes)
        self.bn2 = nn.BatchNorm(use_running_average=False)
        if self.downsample:
            # The shortcut must follow the same stride as conv1. A hard-coded
            # stride of 2 broke the residual addition whenever a projection was
            # requested with stride 1 (channel change only).
            self.dw = conv1x1(self.planes, strides=self.stride)
            self.dw_bn = nn.BatchNorm(use_running_average=False)

    def __call__(self, x):
        """Apply the block to ``x`` and return ``relu(main_branch(x) + shortcut(x))``."""
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = nn.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.downsample:
            identity = self.dw(identity)
            identity = self.dw_bn(identity)
        x = x + identity
        x = nn.relu(x)
        return x

In [123]:
class ResNet(nn.Module):
    """ResNet classifier: conv stem, four stages of residual blocks, dense head.

    The forward pass assumes a 4-D channels-last input (batch, height, width,
    channels), as used by the example input in this notebook.

    Attributes:
        depths: number of residual blocks in each of the four stages,
            e.g. ``[2, 2, 2, 2]``.
        n_classes: output size of the classification head.
        inplanes: number of channels produced by the stem convolution, i.e.
            the input width of the first stage.
        block: (class attribute) residual block class; it must expose an
            ``expansion`` attribute and accept
            ``(planes, stride=..., downsample=...)``.
    """
    block = BasicBlock
    depths: Sequence[int]
    n_classes: int = 1000
    inplanes: int = 64

    def setup(self):
        """Create the stem, the flattened list of residual blocks and the head."""
        # No super().__init__() call here: that is a PyTorch idiom and has no
        # place in a Flax ``setup``.
        self.conv1 = nn.Conv(self.inplanes, (7, 7), strides=(2, 2), padding=[(3, 3), (3, 3)], use_bias=False)
        self.bn1 = nn.BatchNorm(use_running_average=False)

        # (planes, stride) for each of the four stages.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        layers = []
        # Track the running channel width locally instead of mutating
        # ``self.inplanes`` while the module is being set up.
        inplanes = self.inplanes
        for index, (planes, stride) in enumerate(stage_specs):
            layers.extend(
                self._make_layer(self.block, self.depths[index], planes, stride=stride, inplanes=inplanes)
            )
            inplanes = planes * self.block.expansion
        self.layers = layers
        self.head = nn.Dense(self.n_classes)

    def __call__(self, x):
        """Map a batch of images to logits of shape (batch, n_classes)."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = nn.relu(x)
        # 3x3 / stride-2 max-pool with padding 1, as in the reference ResNet stem.
        x = nn.pooling.max_pool(x, (3, 3), strides=(2, 2), padding=((1, 1), (1, 1)))
        for layer in self.layers:
            x = layer(x)
        # Global average pooling over the spatial axes. The previous
        # avg_pool with a (1, 1) window was a no-op, so the dense head was
        # applied per pixel and returned (batch, H, W, n_classes).
        x = jnp.mean(x, axis=(1, 2))
        x = self.head(x)
        return x

    def _make_layer(self, block, depth, planes, stride=1, inplanes=None):
        """Build the list of ``depth`` blocks forming one stage.

        Args:
            block: residual block class to instantiate.
            depth: number of blocks in the stage.
            planes: channel width passed to every block of the stage.
            stride: stride of the first block; later blocks use the default.
            inplanes: channel width entering the stage; defaults to
                ``self.inplanes`` when omitted.

        Returns:
            A list of unbound block modules, first block first.
        """
        if inplanes is None:
            inplanes = self.inplanes
        # The first block needs a projection shortcut whenever it changes the
        # spatial resolution or the channel width.
        needs_projection = stride != 1 or inplanes != planes * block.expansion
        layers = [block(planes, stride=stride, downsample=needs_projection)]
        layers.extend([block(planes) for _ in range(1, depth)])
        return layers
        

In [124]:
# Dummy all-ones input batch with shape (1, 224, 224, 3), used to trace the
# models during initialisation.
x = jnp.ones(shape=(1, 224, 224, 3))

In [125]:
# Two model definitions that differ only in the number of blocks per stage.
resnet18, resnet34 = (
    ResNet(depths=stage_depths)
    for stage_depths in ([2, 2, 2, 2], [3, 4, 6, 3])
)

In [132]:
# Initialise each model with its own key derived from `init_rng` (the key that
# was split off for initialisation), instead of reusing the carried `rng` key
# for both models.
init_rng18, init_rng34 = jax.random.split(init_rng)
# Only the 'params' collection is kept; any other collections returned by
# `init` (presumably the batch-norm statistics -- TODO confirm) are discarded.
params18 = resnet18.init(init_rng18, x)['params']
params34 = resnet34.init(init_rng34, x)['params']