In [1]:
from src import *

In [8]:
# Restore the trained model from its final checkpoint (Colab Google Drive location).
CHECKPOINT_PATH = '/content/drive/MyDrive/portafolio/dlops/checkpoints/final.ckpt'
module = MNISTModule.load_from_checkpoint(CHECKPOINT_PATH)
module

MNISTModule(
  (conv1): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (mlp): Sequential(
    (0): Linear(in_features=6272, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=1, bias=True)
  )
)

In [9]:
import torch 

# Rebuild the data module from the kwargs stored in the checkpoint's
# hparams under 'datamodule', then prepare its splits.
datamodule_kwargs = module.hparams['datamodule']
dm = MNISTDataModule(**datamodule_kwargs)
dm.setup()

def torch_eval(model=None, datamodule=None, threshold=0.5):
    """Compute binary accuracy of a PyTorch model on the validation split.

    Args:
        model: object exposing ``eval()`` and ``predict(imgs)``. Defaults to
            the notebook-level ``module`` (looked up at call time).
        datamodule: object exposing ``val_dataloader()`` that yields
            ``(imgs, labels)`` tensor batches. Defaults to the notebook-level
            ``dm``.
        threshold: scores strictly above this value are predicted as class 1.
            Presumably ``predict`` returns probabilities -- TODO confirm.

    Returns:
        Accuracy as a Python float, or ``nan`` if the loader yields no batches.

    Note:
        Leaves ``model`` in eval mode.
    """
    if model is None:
        model = module
    if datamodule is None:
        datamodule = dm

    model.eval()
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for imgs, batch_labels in datamodule.val_dataloader():
            scores = model.predict(imgs)
            pred_chunks.append((scores > threshold).cpu().long())
            label_chunks.append(batch_labels.cpu())

    # No batches: report nan, matching the mean of an empty tensor.
    if not pred_chunks:
        return float('nan')

    # Concatenate once at the end instead of re-allocating every batch.
    preds = torch.cat(pred_chunks)
    labels = torch.cat(label_chunks)
    return (preds == labels).float().mean().item()

# Baseline: validation accuracy of the PyTorch model, thresholding predict() scores at 0.5.
torch_eval()

0.9649999737739563

In [12]:
import os

# Relative to the current working directory.
# NOTE(review): the ONNX load cell reads an absolute Drive path. The two only
# point at the same file when cwd is /content/drive/MyDrive/portafolio/dlops
# -- confirm.
ONNX_EXPORT_PATH = 'models/binary_classifier_3.onnx'
# Exporting fails if the target directory is missing.
os.makedirs(os.path.dirname(ONNX_EXPORT_PATH), exist_ok=True)

# Dummy batch used only to trace the graph; its dtype (uint8) becomes the
# graph's input dtype. randint's upper bound is exclusive, so 256 covers the
# full uint8 range [0, 255].
input_sample = torch.randint(0, 256, (1, 1, 28, 28), dtype=torch.uint8)
module.to_onnx(
    ONNX_EXPORT_PATH,   # file to write the exported model to
    input_sample,       # example input used for tracing
    export_params=True, # store the trained weights inside the ONNX file
    opset_version=11,   # may need changing depending on the ops in the model
    input_names=['input'],   # input name used at inference time
    output_names=['output'], # output name used at inference time
    dynamic_axes={  # allow any batch size at inference time
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)

In [16]:
import onnxruntime as ort 
import numpy as np

# Load the exported graph and smoke-test it on a random batch of 10 images.
ort_session = ort.InferenceSession(
    '/content/drive/MyDrive/portafolio/dlops/models/binary_classifier_3.onnx',
    # onnxruntime >= 1.9 requires explicit providers on builds that ship
    # several execution providers; pinning CPU keeps this cell portable.
    providers=['CPUExecutionProvider'],
)

# Seeded so the smoke test is reproducible. integers() excludes the upper
# bound, so 256 samples the full uint8 range. A batch of 10 exercises the
# dynamic 'batch_size' axis requested at export time.
smoke_rng = np.random.default_rng(0)
ort_inputs = {
    "input": smoke_rng.integers(0, 256, size=(10, 1, 28, 28), dtype=np.uint8),
}

ort_output = ort_session.run(['output'], ort_inputs)
ort_output[0].shape

(10,)

In [17]:
def onnx_eval(session=None, datamodule=None, threshold=0.5):
    """Compute binary accuracy of the exported ONNX model on the validation split.

    ONNX-runtime counterpart of ``torch_eval``.

    Args:
        session: onnxruntime-style session whose graph has an input named
            ``'input'`` and an output named ``'output'``. Defaults to the
            notebook-level ``ort_session`` (looked up at call time).
        datamodule: object exposing ``val_dataloader()`` that yields
            ``(imgs, labels)`` torch-tensor batches. Defaults to ``dm``.
        threshold: scores strictly above this value are predicted as class 1.

    Returns:
        Accuracy as a Python float, or ``nan`` if the loader yields no batches.
    """
    if session is None:
        session = ort_session
    if datamodule is None:
        datamodule = dm

    # Inference runs in onnxruntime, so no torch.no_grad() context is needed.
    pred_chunks, label_chunks = [], []
    for imgs, batch_labels in datamodule.val_dataloader():
        # The fed array's dtype must match the graph's input dtype.
        scores = session.run(['output'], {'input': imgs.cpu().numpy()})[0]
        pred_chunks.append((scores > threshold).astype(np.int64))
        label_chunks.append(batch_labels.cpu().numpy())

    # No batches: report nan rather than averaging an empty array.
    if not pred_chunks:
        return float('nan')

    preds = np.concatenate(pred_chunks)
    labels = np.concatenate(label_chunks)
    return float((preds == labels).mean())

# Validation accuracy of the ONNX export, to compare against the torch_eval() result.
# NOTE(review): the recorded results differ slightly (about 0.9650 vs 0.9625).
# Check whether module.predict() and the exported graph's 'output' are the same
# quantity (e.g. probabilities vs. logits) before both are thresholded at 0.5
# -- unverified.
onnx_eval()

0.9625