In [None]:
import torchvision
import pax
import numpy as np

In [None]:
# Reference model: torchvision's ResNet-18 with pretrained weights, switched
# to eval mode before any weights are read from it.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour
# of `weights=` — confirm which versions this notebook must support.
resnet18 = torchvision.models.resnet18(pretrained=True)
resnet18 = resnet18.eval()

In [None]:
import torch
def convert_conv(conv, name=None):
    """Build a ``pax.nn.Conv2D`` that carries the kernel of a torch ``Conv2d``.

    The torch kernel is stored as (out, in, kH, kW); it is permuted to
    (kH, kW, in, out) before being assigned to the pax module. Activations
    stay in NCHW (``data_format="NCHW"``).

    Args:
        conv: source convolution. It must have no bias, ``groups == 1`` and a
            dilation of 1, because none of those settings are forwarded to the
            pax layer.
        name: optional name passed through to the pax module.

    Returns:
        The pax convolution whose ``w`` has been replaced by the converted
        kernel.

    Raises:
        ValueError: if ``conv`` has a bias, is grouped, or is dilated.
            Converting such a layer would silently change its output.
    """
    # The pax layer is always created with with_bias=False, so a bias on the
    # source layer would be dropped without any warning.
    if getattr(conv, "bias", None) is not None:
        raise ValueError("convert_conv only supports convolutions without a bias")

    groups = getattr(conv, "groups", 1)
    if groups != 1:
        raise ValueError(
            "convert_conv only supports groups=1, got groups={}".format(groups)
        )

    dilation = getattr(conv, "dilation", 1)
    dilation = (dilation,) if isinstance(dilation, int) else tuple(dilation)
    if any(d != 1 for d in dilation):
        raise ValueError(
            "convert_conv only supports dilation=1, got dilation={}".format(dilation)
        )

    # (out, in, kH, kW) -> (kH, kW, in, out); contiguous() materialises the
    # permuted layout before handing the buffer to numpy.
    weight = conv.weight.detach().cpu().permute(2, 3, 1, 0).contiguous().numpy()
    # print(conv.in_channels, conv.out_channels, conv.stride, conv.kernel_size)
    pax_conv = pax.nn.Conv2D(
        in_features=conv.in_channels,
        out_features=conv.out_channels,
        kernel_shape=conv.kernel_size,
        stride=conv.stride,
        with_bias=False,
        # NOTE(review): torch's per-dimension padding tuple is forwarded
        # unchanged — confirm pax interprets it the same way.
        padding=conv.padding,
        data_format="NCHW",
        name=name,
    )
    # Guard against a layout mismatch before overwriting the initial kernel.
    assert pax_conv.w.shape == weight.shape
    pax_conv.w = weight
    return pax_conv


def convert_bn(bn, name=None):
    """Build a ``pax.nn.BatchNorm2D`` carrying a torch batch norm's parameters.

    The per-channel torch vectors of shape (C,) are reshaped to (1, C, 1, 1)
    and written into the pax module's ``params`` (scale / offset) and
    ``state`` (moving mean / variance).

    ``eps`` and the moving-average decay are taken from ``bn`` instead of
    being hard-coded, so layers created with non-default settings convert
    faithfully. For torch's defaults (``eps=1e-5``, ``momentum=0.1``) the
    values are exactly the previous constants ``1e-5`` and ``0.9``.

    Args:
        bn: source batch-norm layer; its ``weight``, ``bias``, ``running_mean``
            and ``running_var`` are read.
        name: optional name passed through to the pax module.

    Returns:
        The pax batch norm with its scale, offset and moving averages
        replaced by the torch values.
    """

    def as_nchw(tensor):
        # (C,) -> (1, C, 1, 1) so the values broadcast over N, H and W.
        return tensor.data.numpy()[None, :, None, None]

    weight = as_nchw(bn.weight)
    bias = as_nchw(bn.bias)
    running_mean = as_nchw(bn.running_mean)
    running_var = as_nchw(bn.running_var)

    # torch updates running stats as (1 - momentum) * old + momentum * new,
    # so the matching decay is 1 - momentum. momentum=None (cumulative
    # average) has no fixed decay; fall back to the previous constant.
    momentum = bn.momentum
    decay_rate = 0.9 if momentum is None else 1.0 - momentum

    pax_bn = pax.nn.BatchNorm2D(
        num_channels=bias.shape[1],
        create_offset=True,
        create_scale=True,
        decay_rate=decay_rate,
        eps=bn.eps,
        data_format='NC...',
        name=name,
    )

    # Verify every destination slot has the expected shape before overwriting.
    assert pax_bn.params['batch_norm']['scale'].shape == weight.shape
    assert pax_bn.params['batch_norm']['offset'].shape == bias.shape
    assert pax_bn.state['batch_norm/~/mean_ema']['hidden'].shape == running_mean.shape
    assert pax_bn.state['batch_norm/~/mean_ema']['average'].shape == running_mean.shape
    assert pax_bn.state['batch_norm/~/var_ema']['hidden'].shape == running_var.shape
    assert pax_bn.state['batch_norm/~/var_ema']['average'].shape == running_var.shape

    pax_bn.params['batch_norm']['scale'] = weight
    pax_bn.params['batch_norm']['offset'] = bias

    # For each moving average: reset the step counter, clear the `hidden`
    # accumulator and store the torch running statistic as `average`.
    # NOTE(review): presumably only `average` is read in eval mode, which is
    # why `hidden` can be None — confirm before training with this module.
    pax_bn.state['batch_norm/~/mean_ema']['counter'] = np.array(0, dtype=np.int32)
    pax_bn.state['batch_norm/~/mean_ema']['hidden'] = None
    pax_bn.state['batch_norm/~/mean_ema']['average'] = running_mean

    pax_bn.state['batch_norm/~/var_ema']['counter'] = np.array(0, dtype=np.int32)
    pax_bn.state['batch_norm/~/var_ema']['hidden'] = None
    pax_bn.state['batch_norm/~/var_ema']['average'] = running_var

    return pax_bn

def convert_basic_block(block):
    """Convert one residual block into nested tuples of pax layers.

    Returns:
        A 1-tuple ``(main_path,)`` when ``block.downsample`` is None, otherwise
        a 2-tuple ``(main_path, (proj_conv, proj_bn))``. ``main_path`` is
        ``((conv1, bn1), (conv2, bn2))``.
    """
    main_path = (
        (convert_conv(block.conv1, name="conv1"), convert_bn(block.bn1, name="bn1")),
        (convert_conv(block.conv2, name="conv2"), convert_bn(block.bn2, name="bn2")),
    )

    # No shortcut projection: the result only holds the main path.
    if block.downsample is None:
        return (main_path,)

    # The shortcut is indexed as [conv, batch norm].
    shortcut = block.downsample
    projection = (
        convert_conv(shortcut[0], name="proj_conv"),
        convert_bn(shortcut[1], name="proj_bn"),
    )
    return main_path, projection

def convert_block_group(group):
    """Convert every block of a stage, in order, and return them as a list."""
    return [convert_basic_block(group[idx]) for idx in range(len(group))]

def convert_linear(linear):
    """Build a ``pax.nn.Linear`` holding the weights of a torch linear layer.

    The torch weight matrix is (out, in); it is transposed to (in, out) and
    made contiguous before being stored as ``W``. The bias is copied as is.
    """
    torch_w = linear.weight.data.numpy()
    torch_b = linear.bias.data.numpy()
    # print('linear', weight.shape)
    layer = pax.nn.Linear(
        in_dim=torch_w.shape[1],
        out_dim=torch_w.shape[0],
        with_bias=True,
    )
    kernel = np.ascontiguousarray(torch_w.T)

    # Both destination slots must already have the shapes being written.
    assert layer.b.shape == torch_b.shape
    assert layer.W.shape == kernel.shape
    layer.W = kernel
    layer.b = torch_b
    return layer

In [None]:
# Converted pieces in forward order:
#   [0] stem conv, [1] stem batch norm, [2..5] the four stages, [-1] fc head.
pax_resnet = [convert_conv(resnet18.conv1), convert_bn(resnet18.bn1)]
pax_resnet.extend(
    convert_block_group(torch_stage)
    for torch_stage in (
        resnet18.layer1,
        resnet18.layer2,
        resnet18.layer3,
        resnet18.layer4,
    )
)
pax_resnet.append(convert_linear(resnet18.fc))

In [None]:
# Start from a freshly built pax ResNet18 and overwrite its layers with the
# converted torch ones.
# NOTE(review): the arguments 3 and 1000 are presumably input channels and
# number of classes — confirm against pax.nets.ResNet18.
rnet = pax.nets.ResNet18(3, 1000)
rnet.initial_conv, rnet.initial_batchnorm = pax_resnet[0], pax_resnet[1]

# Stage k of the pax model takes its blocks from pax_resnet[2 + k].
for stage_idx, stage in enumerate(rnet.block_groups):
    converted_stage = pax_resnet[2 + stage_idx]
    for block_idx, block in enumerate(stage.blocks):
        converted = converted_stage[block_idx]
        block.layers = converted[0]
        if not block.use_projection:
            continue
        # Blocks with a shortcut projection carry a (conv, batch norm) pair.
        block.proj_conv, block.proj_batchnorm = converted[1]

rnet.logits = pax_resnet[-1]

In [None]:
import einops
import jax

# Fixed-seed Gaussian probe image of shape (1, 3, 224, 224).
probe_key = jax.random.PRNGKey(11)
img = jax.random.normal(probe_key, (1, 3, 224, 224))
# The pax input is a full slice of the same array; the NHWC rearrange kept in
# the trailing comment is disabled.
pax_img = img[:]  # np.ascontiguousarray(einops.rearrange(img, "N C H W -> N H W C"))

In [None]:
import torch

In [None]:
# Reference output: fetch the JAX array to host memory, copy it into a numpy
# buffer, wrap that as a torch tensor, and run the torch model on it.
torch_input = torch.from_numpy(np.copy(jax.device_get(img)))
resnet18(torch_input)

In [None]:
# Rebind rnet to the result of eval().
# NOTE(review): presumably this is the inference-mode copy, in which batch
# norm uses the stored moving averages — confirm against the pax docs.
rnet = rnet.eval()

In [None]:
# Forward the probe image through the converted pax model; the displayed
# result is meant to be compared by eye with the torch output above.
rnet(pax_img)

In [None]:
# resnet18(torch.from_numpy(jax.device_get(img)))

In [None]:
# resnet18(torch.from_numpy(np.copy(jax.device_get(img))))

In [None]:
# print(rnet.summary())

In [None]:
# Same forward pass on the probe image, with eval() applied inline.
rnet.eval()(pax_img)

In [None]:
# Check the converted stem batch norm in isolation.
# resnet18.bn1 normalises the 64-channel output of conv1, so it must be fed
# the converted stem conv's activations. Applying it to the raw 3-channel
# image cannot broadcast against the (1, 64, 1, 1) parameters.
cbn = convert_bn(resnet18.bn1).eval()
stem_features = convert_conv(resnet18.conv1)(pax_img)
cbn(stem_features)[0, 0]

In [None]:
# List the parameter names stored under the 'batch_norm' entry.
cbn.params['batch_norm'].keys()

In [None]:
# List the fields kept for the variance moving average.
cbn.state['batch_norm/~/var_ema'].keys()