In [1]:
from src import *

# Checkpoint to evaluate/export; the filename suggests it was saved at epoch 8 with val_loss 0.15781.
CHECKPOINT_PATH = 'checkpoints/006-val_loss=0.15781-epoch=8.ckpt'

module = MNISTModule.load_from_checkpoint(CHECKPOINT_PATH)
module.mlp

Sequential(
  (0): Linear(in_features=784, out_features=100, bias=True)
  (1): ReLU()
  (2): Linear(in_features=100, out_features=1, bias=True)
)

In [2]:
import torch

# Rebuild the datamodule from the 'datamodule' entry of the loaded module's hparams,
# then prepare its splits so val_dataloader() can be used below.
dm = MNISTDataModule(
    **module.hparams["datamodule"],
)
dm.setup()

def torch_eval(model=None, dataloader=None, threshold=0.5):
    """Compute the binary-classification accuracy of the PyTorch model.

    Args:
        model: object exposing ``to``, ``eval`` and ``predict``. Defaults to the
            notebook-global ``module``. As a side effect it is moved to CPU and
            switched to eval mode.
        dataloader: iterable of ``(imgs, labels)`` batches. Defaults to
            ``dm.val_dataloader()``.
        threshold: a sample is predicted positive when ``model.predict(imgs)``
            exceeds this value. Assumes ``predict`` returns probabilities in
            [0, 1] — TODO confirm.

    Returns:
        Accuracy as a Python float, or ``nan`` when the dataloader yields no batches.
    """
    if model is None:
        model = module

    # Evaluation runs on CPU; to use a GPU instead, move the model and the inputs with .to('cuda').
    model.to('cpu')
    model.eval()

    if dataloader is None:
        dataloader = dm.val_dataloader()

    batch_preds, batch_labels = [], []
    with torch.no_grad():
        for imgs, labels in dataloader:
            is_positive = model.predict(imgs) > threshold
            # Flatten both sides so a (B, 1) prediction cannot broadcast against (B,) labels.
            batch_preds.append(is_positive.cpu().long().reshape(-1))
            batch_labels.append(torch.as_tensor(labels).reshape(-1))

    if not batch_preds:
        return float('nan')

    # Concatenate once at the end instead of re-copying a growing tensor every batch.
    preds = torch.cat(batch_preds)
    labels = torch.cat(batch_labels)
    return (preds == labels).float().mean().item()

torch_eval()  # validation accuracy of the PyTorch model; reference value for the ONNX comparison below

0.9375

In [4]:
# Dummy uint8 image used only to trace the graph. Its dtype and shape define the exported
# input signature, and axis 0 is made dynamic below. randint's `high` is exclusive, so 256
# is needed to cover the full 0..255 pixel range.
input_sample = torch.randint(0, 256, (1, 28, 28), dtype=torch.uint8)
module.to_onnx(
    'models/binary_classifier_3.onnx',  # file path to save the model
    input_sample,  # example input used for tracing
    export_params=True,  # embed the trained weights in the ONNX file
    opset_version=11,  # may need to change depending on the ops used by the model
    input_names=['input'],  # input name to use at inference time
    output_names=['output'],  # output name to use at inference time
    dynamic_axes={
        'input': {0: 'batch_size'},  # accept any batch size at inference time
        'output': {0: 'batch_size'},
    },
)

verbose: False, log level: Level.ERROR



In [5]:
import onnxruntime as ort
import numpy as np

# Pin the CPU execution provider explicitly. Since ORT 1.9, builds with non-CPU providers
# raise if `providers` is omitted. CPU also matches the device torch_eval uses.
ort_session = ort.InferenceSession(
    'models/binary_classifier_3.onnx',
    providers=['CPUExecutionProvider'],
)

# Random uint8 batch of 10 images to check that the dynamic batch axis works.
# randint's `high` is exclusive, so 256 covers the full 0..255 range.
ort_inputs = {
    "input": np.random.randint(0, 256, (10, 28, 28), dtype=np.uint8)
}

ort_output = ort_session.run(['output'], ort_inputs)
ort_output[0].shape

(10,)

In [8]:
def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))``.

    For large negative ``x``, ``exp(-x)`` overflows to ``inf``, which still gives the
    correct limit 0.0. The overflow warning is therefore silenced rather than letting
    it spam the notebook output.
    """
    with np.errstate(over='ignore'):
        return 1 / (1 + np.exp(-x))

def onnx_eval(session=None, dataloader=None, threshold=0.5):
    """Compute the binary-classification accuracy of the exported ONNX model.

    This is the onnxruntime counterpart of ``torch_eval``.

    Args:
        session: object with an onnxruntime-style ``run(output_names, feeds)`` method.
            Defaults to the notebook-global ``ort_session``.
        dataloader: iterable of ``(imgs, labels)`` torch batches. Defaults to
            ``dm.val_dataloader()``.
        threshold: probability cut-off, strictly between 0 and 1.

    Returns:
        Accuracy as a Python float, or ``nan`` when the dataloader yields no batches.

    Raises:
        ValueError: if ``threshold`` is not in the open interval (0, 1).
    """
    if not 0.0 < threshold < 1.0:
        raise ValueError(f"threshold must be in (0, 1), got {threshold}")
    if session is None:
        session = ort_session
    if dataloader is None:
        dataloader = dm.val_dataloader()

    # The ONNX "output" is treated as logits (a sigmoid maps it to a probability).
    # Because sigmoid is monotonic, sigmoid(x) > t  <=>  x > log(t / (1 - t)).
    # Comparing in logit space avoids exp(), which overflows for large negative logits.
    logit_threshold = np.log(threshold / (1.0 - threshold))

    batch_preds, batch_labels = [], []
    for imgs, labels in dataloader:
        # Assumes imgs already have the dtype the graph was exported with (uint8 sample).
        logits = session.run(["output"], {"input": imgs.numpy()})[0]
        # Flatten so a (B, 1) output cannot broadcast against (B,) labels.
        batch_preds.append((np.asarray(logits).reshape(-1) > logit_threshold).astype(int))
        batch_labels.append(torch.as_tensor(labels).reshape(-1).numpy())

    if not batch_preds:
        return float("nan")

    # Concatenate once at the end instead of re-copying a growing array every batch.
    preds = np.concatenate(batch_preds)
    labels = np.concatenate(batch_labels)
    return float((preds == labels).astype(float).mean())

onnx_eval()  # validation accuracy of the exported ONNX model; should match torch_eval() above

0.9375