In [11]:
import sys

#install torch and torchvision if not already installed
!{sys.executable} -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

import torch.nn as nn
import torch, torchvision
import torch.nn.functional as F
import regex as re

Looking in indexes: https://download.pytorch.org/whl/cu118


In [12]:
class DenseNet121(nn.Module):
    """Model modified.
    The architecture of our model is the same as standard DenseNet121
    except the classifier layer which has an additional sigmoid function.
    """
    def __init__(self, out_size):
        super(DenseNet121, self).__init__()
        self.densenet121 = torchvision.models.densenet121(pretrained=True)
        # embeddings
        num_ftrs = self.densenet121.classifier.in_features
        self.densenet121.classifier = nn.Sequential(
            nn.Linear(num_ftrs, out_size),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.densenet121(x)
        return x


model = DenseNet121(14).cuda()

modelCheckpoint = torch.load("../models/chexnet.pth.tar")
pattern = re.compile(r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')
state_dict = modelCheckpoint['state_dict']
for i, key in enumerate(list(state_dict.keys())):
    res = pattern.match(key)

    if res:
        new_key = (res.group(1) + res.group(2))[7:]
        state_dict[new_key] = state_dict[key]
        del state_dict[key]
    else:
        new_key = key[7:]
        state_dict[new_key] = state_dict[key]
        del state_dict[key]
model.load_state_dict(state_dict, strict=False)



<All keys matched successfully>

In [20]:
!{sys.executable} -m pip install onnx onnx2keras onnxruntime jax

input_data = torch.randn(1, 3, 224, 224).cuda()
torch.onnx.export(
    model,
    input_data,
    "densenet121.onnx",
    export_params=True,
    opset_version=11,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output']
)

Collecting jax
  Using cached jax-0.6.0-py3-none-any.whl.metadata (22 kB)
Collecting jaxlib<=0.6.0,>=0.6.0 (from jax)
  Using cached jaxlib-0.6.0-cp312-cp312-manylinux2014_x86_64.whl.metadata (1.2 kB)
Collecting scipy>=1.11.1 (from jax)
  Using cached scipy-1.15.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Using cached jax-0.6.0-py3-none-any.whl (2.3 MB)
Using cached jaxlib-0.6.0-cp312-cp312-manylinux2014_x86_64.whl (87.8 MB)
Using cached scipy-1.15.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.3 MB)
Installing collected packages: scipy, jaxlib, jax
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3/3[0m [jax][32m2/3[0m [jax]ib]
[1A[2KSuccessfully installed jax-0.6.0 jaxlib-0.6.0 scipy-1.15.3


In [22]:
import os
os.environ["KERAS_BACKEND"] = "torch"

import onnx
onnx_model = onnx.load('densenet121.onnx')

from onnx2keras import onnx_to_keras
keras_model = onnx_to_keras(onnx_model, ['input'])

# Load weights into Keras model
weights = model.state_dict()
keras_model.set_weights([weights[key].numpy() for key in weights.keys()])

# Use Keras model for inference
output = keras_model.predict(input_data)


INFO:onnx2keras:Converter is called.
DEBUG:onnx2keras:List input shapes:
DEBUG:onnx2keras:None
DEBUG:onnx2keras:List inputs:
DEBUG:onnx2keras:Input 0 -> input.
DEBUG:onnx2keras:List outputs:
DEBUG:onnx2keras:Output 0 -> output.
DEBUG:onnx2keras:Gathering weights to dictionary.
DEBUG:onnx2keras:Found weight densenet121.features.denseblock1.denselayer1.norm1.weight with shape (64,).
DEBUG:onnx2keras:Found weight densenet121.features.denseblock1.denselayer1.norm1.bias with shape (64,).
DEBUG:onnx2keras:Found weight densenet121.features.denseblock1.denselayer1.norm1.running_mean with shape (64,).
DEBUG:onnx2keras:Found weight densenet121.features.denseblock1.denselayer1.norm1.running_var with shape (64,).
DEBUG:onnx2keras:Found weight densenet121.features.denseblock1.denselayer1.conv2.weight with shape (32, 128, 3, 3).
DEBUG:onnx2keras:Found weight densenet121.features.denseblock1.denselayer2.norm1.weight with shape (96,).
DEBUG:onnx2keras:Found weight densenet121.features.denseblock1.dens

ValueError: Argument `name` must be a string and cannot contain character `/`. Received: name=/densenet121/features/conv0/Conv_output_0_pad (of type <class 'str'>)