In [1]:
from IPython.display import display, HTML

# Stretch the notebook page, menu bar and toolbar containers to 99% of the
# window width. NOTE(review): these element IDs presumably belong to the
# classic Notebook UI only — confirm whether other front-ends are affected.
NOTEBOOK_WIDE_CSS = """
<style>
    div#notebook-container    { width: 99%; }
    div#menubar-container     { width: 99%; }
    div#maintoolbar-container { width: 99%; }
</style>
"""
display(HTML(data=NOTEBOOK_WIDE_CSS))

In [2]:
import os
import torch 
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

%matplotlib inline

In [3]:
from democlassi_helpers import initialize_model, PretrainedMT

In [4]:
# Create the model objects; their trained weights are loaded in the following cells.
# NOTE(review): the positional arguments follow democlassi_helpers' signatures —
# presumably 7 is the number of FER2013 classes; confirm against the helper.
# The second value returned by initialize_model is not used here. It is bound to a
# named throwaway instead of `_`, because assigning `_` in the user namespace stops
# IPython from updating `_`, `__` and `___` with the latest cell outputs.
resnet_fer, _fer_unused = initialize_model("resnet", False, 7, "fer2013", False)
resnet_rag = PretrainedMT("resnet", False, False)

In [5]:
# Load both trained checkpoints, mapping every tensor onto the CPU regardless of
# the device it was saved from.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
CPU_DEVICE = torch.device("cpu")
fer_weight = torch.load(
    "./ml_outputs/democlassi/FER_resnet_model_109_val_accuracy_0.6227361.pth",
    map_location=CPU_DEVICE,
)
rag_weight = torch.load(
    "./ml_outputs/democlassi/RAG_resnet_model_21_val_loss_4.275671.pth",
    map_location=CPU_DEVICE,
)

In [6]:
# Copy the FER checkpoint into the model; strict key matching (the default) is
# spelled out. The returned key report is this cell's displayed output.
resnet_fer.load_state_dict(fer_weight, strict=True)

<All keys matched successfully>

In [7]:
# Copy the RAG checkpoint into the model; strict key matching (the default) is
# spelled out. The returned key report is this cell's displayed output.
resnet_rag.load_state_dict(rag_weight, strict=True)

<All keys matched successfully>

In [8]:
def count_parameters(module):
    """Return the total number of scalar elements across all parameters of ``module``."""
    return sum(param.numel() for param in module.parameters())


print(f"Number of parameters for FER model : {count_parameters(resnet_fer):_}")
print(f"Number of parameters for RAG model : {count_parameters(resnet_rag):_}")

Number of parameters for FER model : 23_522_375
Number of parameters for RAG model : 23_771_336


In [9]:
class FERModel(nn.Module):
    """Inference-only wrapper around the facial-expression (FER) network.

    The wrapper (and, through ``eval()``'s recursion, the wrapped network) is put
    in eval mode on construction, and ``forward`` runs with autograd disabled.

    Args:
        model: the ``nn.Module`` to wrap; stored as ``self.model``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # eval() recurses into child modules, so self.model is switched too.
        self.eval()

    @torch.no_grad()
    def forward(self, x):
        """Return ``self.model(x)`` computed without gradient tracking."""
        return self.model(x)


class RAGModel(nn.Module):
    """Inference-only wrapper around the three-output (age, gender, race) network.

    Mirrors ``FERModel``: eval mode on construction, no autograd in ``forward``.
    The wrapped model must return exactly three values.

    Args:
        model: the ``nn.Module`` to wrap; stored as ``self.model``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # eval() recurses into child modules, so self.model is switched too.
        self.eval()

    @torch.no_grad()
    def forward(self, x):
        """Return the ``(age, gender, race)`` tuple produced by ``self.model(x)``."""
        outputs = self.model(x)
        # Unpacking enforces the three-output contract before re-packing as a tuple.
        age, gender, race = outputs
        return (age, gender, race)

In [10]:
# Wrap the loaded FER network for inference-only use (FERModel sets eval mode and
# disables autograd in forward).
fer_model = FERModel(resnet_fer)
# Smoke-test the wrapper on a dummy batch shaped like the ONNX export input. This
# replaces a commented-out check that depended on an undefined `tmp_im`.
fer_model(torch.zeros(1, 3, 128, 128)).shape

In [11]:
# Wrap the loaded RAG network for inference-only use (RAGModel sets eval mode and
# disables autograd in forward).
rag_model = RAGModel(resnet_rag)
# Smoke-test on a dummy batch shaped like the ONNX export input and show the shape
# of each of the three outputs. This replaces a commented-out check that relied
# on an undefined `im`.
[head.shape for head in rag_model(torch.zeros(1, 3, 128, 128))]

In [12]:
# Export the FER wrapper to ONNX. The random example input is only used to trace
# the graph; it is laid out as (N, C, H, W).
torch.onnx.export(
    fer_model,                                 # model being run
    torch.randn((1, 3, 128, 128)),             # example input used for tracing
    "../../public/static/ml_models/fer.onnx",  # where to save the model
    export_params=True,                        # store the trained weights inside the model file
    opset_version=9,                           # ONNX opset to target
    do_constant_folding=True,                  # fold constant sub-graphs at export time
    input_names=["input"],                     # the model's input names
    output_names=["output"],                   # the model's output names
    # Variable-length axes. The input is NCHW, so height and width are axes 2 and 3.
    # The channel axis (1) is always 3 and must stay static. The batch axis (0) is
    # dynamic on both the input and the output.
    dynamic_axes={
        "input": {0: "batch", 2: "H", 3: "W"},
        "output": {0: "batch"},
    },
)

In [13]:
# Export the RAG wrapper to ONNX with one named output per head. The random example
# input is only used to trace the graph; it is laid out as (N, C, H, W).
torch.onnx.export(
    rag_model,                                 # model being run
    torch.randn((1, 3, 128, 128)),             # example input used for tracing
    "../../public/static/ml_models/rag.onnx",  # where to save the model
    export_params=True,                        # store the trained weights inside the model file
    opset_version=9,                           # ONNX opset to target
    do_constant_folding=True,                  # fold constant sub-graphs at export time
    input_names=["input"],                     # the model's input names
    output_names=["age", "gender", "race"],    # the model's output names
    # Variable-length axes. The input is NCHW, so height and width are axes 2 and 3.
    # The channel axis (1) is always 3 and must stay static.
    # NOTE(review): the output entries assume every head is batch-first — confirm
    # against PretrainedMT.
    dynamic_axes={
        "input": {0: "batch", 2: "H", 3: "W"},
        "age": {0: "batch"},
        "gender": {0: "batch"},
        "race": {0: "batch"},
    },
)