In [1]:
import os

import torch
import torch.nn as nn
import torch.onnx
from torchvision.models import mobilenet_v3_small
from torchvision.models.feature_extraction import create_feature_extractor

import onnx
from torchsummary import summary

from SZModel import SZModel, SZModel_SIGMOID, SZModel_Softmax, SZModel_Combined

  from .autonotebook import tqdm as notebook_tqdm


# Set Model Configuration


In [2]:
# All loading and conversion happens on the CPU.
device = torch.device("cpu")

# "combined" builds SZModel_Combined from three sub-models; any other value
# falls back to a single SZModel in the model-building cell.
seizure_type, model_type = "combined", "default"
# NOTE(review): model_type is not referenced in the cells shown here.

# Load Model


In [3]:
def _load_feature_extractor(weights_path, return_nodes):
    """Build an SZModel, restore its weights and wrap it as a feature extractor.

    Args:
        weights_path: path to a saved ``state_dict`` for ``SZModel``.
        return_nodes: mapping of graph node name -> output key, forwarded to
            ``create_feature_extractor``.

    Returns:
        The torchvision feature-extractor module built from the loaded model,
        already moved to ``device``.
    """
    sub_model = SZModel()
    # map_location lets checkpoints that were saved from a GPU load on this
    # CPU-only device. NOTE(review): torch.load unpickles the file; only load
    # trusted checkpoints (consider weights_only=True on torch >= 1.13).
    sub_model.load_state_dict(torch.load(weights_path, map_location=device))
    sub_model = sub_model.to(device)
    return create_feature_extractor(sub_model, return_nodes=return_nodes)


if seizure_type == "combined":
    # Every sub-model exposes the output of node "block2.4" under the key "rr".
    return_nodes = {
        "block2.4": "rr",
    }
    model_absence = _load_feature_extractor(
        "model/pytorch_models/absence_model.pt", return_nodes
    )
    model_tonic_clonic = _load_feature_extractor(
        "model/pytorch_models/tonic-clonic_model.pt", return_nodes
    )
    model_general = _load_feature_extractor(
        "model/pytorch_models/general_model.pt", return_nodes
    )

In [4]:
# Build the architecture for the selected variant. The combined model stacks
# the three feature extractors loaded in the previous cell.
if seizure_type == "combined":
    model = SZModel_Combined(model_absence, model_tonic_clonic, model_general)
else:
    model = SZModel()

# Per-parameter table (header row + [name parts, element count]) and the total
# parameter count. This is now computed for both variants, not only "combined".
param_list_format = [["Layer", "Number of Parameters"]] + [
    [name.split("."), param.numel()] for name, param in model.named_parameters()
]
n_params = sum(count for _, count in param_list_format[1:])

print(f"Model has {n_params} parameters")

Model has 18676 parameters


In [5]:
# Restore the trained weights for the selected variant.
# map_location keeps this working when the checkpoint was saved from a GPU.
trained_weights = torch.load(
    f"model/pytorch_models/{seizure_type}_2_model.pt", map_location=device
)
model.load_state_dict(trained_weights)

# eval() puts any dropout / batch-norm layers into inference mode before export.
# Assigning the result (instead of a trailing `pass`) keeps the cell from echoing
# the module repr. Module.to/eval return the same module object.
model = model.to(device).eval()

# Convert PyTorch to ONNX


In [6]:
# Shape of the dummy tensor used to trace the model.
# NOTE(review): presumably (batch, channel, height, width) of the model input — confirm.
input_shape = (1, 1, 40, 26)
opset_version = 11
onnx_path = f"model/onnx_models/{seizure_type}_2.onnx"

# torch.onnx.export does not create missing parent directories.
os.makedirs(os.path.dirname(onnx_path), exist_ok=True)

# The dummy input only fixes shapes for tracing; create it on the model's device
# so the export still works if `device` is switched away from the CPU.
dummy_input = torch.randn(input_shape, device=device)
torch.onnx.export(
    model,
    dummy_input,
    onnx_path,
    opset_version=opset_version,
)

# Validate the exported graph before anything downstream consumes it.
onnx.checker.check_model(onnx.load(onnx_path))