In [1]:
from penzai import pz

In [2]:
import json
import urllib.request

import torch
from jaxonmodels.vision.resnet import resnet18 as resnet18_jax
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18, resnet50


# Input image and the PyTorch reference model.
# NOTE: weights=None gives randomly initialised parameters, so the class predicted
# below is arbitrary; pass e.g. ResNet18_Weights.IMAGENET1K_V1 for a meaningful label.
img_name = "doggo.jpeg"
resnet = resnet18(weights=None)
# Inference mode: BatchNorm normalises with its running statistics (not the
# statistics of this single-image batch) and the forward pass no longer
# updates those running statistics in place.
resnet.eval()

# ImageNet-style preprocessing: resize, centre-crop to 224x224, scale to [0, 1],
# then normalise with the per-channel mean/std given below.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
# Force 3 channels: grayscale/RGBA/palette files would otherwise yield a tensor
# whose channel count does not match the 3-channel Normalize above.
# The context manager closes the underlying file once the RGB copy is made.
with Image.open(img_name) as img_file:
    img = img_file.convert("RGB")
img_t = transform(img)
# Add a leading batch dimension of size 1.
batch_t = torch.unsqueeze(img_t, 0)  # pyright:ignore

# Predict: forward pass through the PyTorch model without recording autograd history.
with torch.no_grad():
    output = resnet(batch_t)
    print(output.shape)
    # Index of the largest logit per row (first index on ties, same as torch.max).
    # Avoids binding `_`, which IPython otherwise uses as the last-output alias.
    predicted = output.argmax(dim=1)

# ImageNet class index of the prediction.
print(f"Predicted: {predicted.item()}")

# Load ImageNet labels: a JSON object keyed by the class index as a string;
# element [1] of each entry is used below as the human-readable label.
url = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
# The response gets its own name so `url` keeps holding the URL string;
# the timeout keeps a stalled download from hanging the notebook.
with urllib.request.urlopen(url, timeout=30) as response:
    imagenet_labels = json.load(response)

predicted_idx = predicted.item()
label = imagenet_labels[str(predicted_idx)][1]
print(f"Label for index {predicted_idx}: {label}")


# Build the JAX port of ResNet-18; the call returns a pair unpacked as (model, state).
jax_resnet, state = resnet18_jax()

block=<class 'torchvision.models.resnet.BasicBlock'>
 layers=[2, 2, 2, 2]
 num_classes=1000
 zero_init_residual=False
 groups=1
 width_per_group=64
 replace_stride_with_dilation=None
 norm_layer=None

inplanes=64 planes=64 stride=1 downsample=None groups=1 base_width=64 dilation=1 norm_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
inplanes=64 planes=64 stride=1 downsample=None groups=1 base_width=64 dilation=1 norm_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
inplanes=64 planes=128 stride=2 downsample=Sequential(
  (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
  (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
) groups=1 base_width=64 dilation=1 norm_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
inplanes=128 planes=128 stride=1 downsample=None groups=1 base_width=64 dilation=1 norm_layer=<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
inplanes=128 planes=256 stride=2 downsample=Sequential(
  (0): Co

In [3]:
# Render the JAX ResNet tree with arrays drawn by penzai's array autovisualizer.
array_visualizer = pz.ts.ArrayAutovisualizer()
with pz.ts.active_autovisualizer.set_scoped(array_visualizer):
    pz.ts.display(jax_resnet)

In [4]:
import jax.numpy as jnp

# Copy every state-dict entry whose key contains "weight" into a JAX array,
# keeping state-dict order; keys without "weight" (e.g. biases) are skipped.
sd = resnet.state_dict()
resnet_state_dict = {
    param_name: jnp.array(tensor.numpy())
    for param_name, tensor in sd.items()
    if "weight" in param_name
}

# Render the converted weights with penzai's array autovisualizer.
array_visualizer = pz.ts.ArrayAutovisualizer()
with pz.ts.active_autovisualizer.set_scoped(array_visualizer):
    pz.ts.display(resnet_state_dict)