# Export des FNN-Modells
Da es zwei Modelle (FNN und CNN) gibt, werden beide einzeln nach ONNX konvertiert und eingebaut.

In [1]:
from vespag.utils import load_model, get_device
from vespag.utils.type_hinting import Architecture, EmbeddingType
import torch
from vespag.models import fnn, cnn

# Model configuration, copied from "DEFAULT_MODEL_PARAMETERS" in vespag.utils.
# All four entries are forwarded to load_model as keyword arguments.
params = {
    "architecture": Architecture.fnn,
    "model_parameters": {
        "hidden_dims": [256],
        "dropout_rate": 0.2,
    },
    "embedding_type": EmbeddingType.esm2,
    "onnx_model_path": "",
}

# Load the model, switch it to evaluation mode, then move it to the device
# chosen by get_device() with torch.float as its dtype.
device = get_device()
fnn_model = load_model(**params)
fnn_model = fnn_model.eval()
fnn_model = fnn_model.to(device, dtype=torch.float)

In [20]:
from pathlib import Path
from vespag.utils import get_embedding_dim


def export_vespag_fnn_to_onnx(fnn_model: torch.nn.Module, onnx_dir_path: str, params: dict) -> None:
    """Export the VespaG FNN model to ``<onnx_dir_path>/fnn.onnx``.

    A random dummy batch of shape ``(2, 10, embedding_dim)`` is passed to
    ``torch.onnx.export``. All three input axes and the first output axis are
    declared dynamic.

    Args:
        fnn_model: Model to export.
        onnx_dir_path: Target directory. It is created, including missing
            parent directories, if it does not exist yet.
        params: Model configuration. Only ``params['embedding_type']`` is
            read, to look up the size of the last input axis.

    Returns:
        None. The ONNX file is written to disk and its path is printed.
    """
    # Path.mkdir is an instance method and must be called on a Path object;
    # parents/exist_ok make this safe for nested and already-existing folders.
    Path(onnx_dir_path).mkdir(parents=True, exist_ok=True)
    onnx_file_path = f'{onnx_dir_path}/fnn.onnx'

    batch_size = 2
    sequence_length = 10
    input_length = get_embedding_dim(params['embedding_type'])

    # Build the dummy input on the device the model lives on; a CPU tensor
    # would cause a device mismatch for a model that was moved to a GPU.
    first_param = next(iter(fnn_model.parameters()), None)
    dummy_device = first_param.device if first_param is not None else torch.device('cpu')
    x = torch.randn(batch_size, sequence_length, input_length, device=dummy_device)

    # NOTE(review): only axis 0 of the output is marked dynamic. If the model
    # emits one row per sequence position, axis 1 should probably be declared
    # as 'sequence_length' as well -- confirm against the model's output shape.
    torch.onnx.export(
        fnn_model,                           # model being run
        x,                                   # model input (or a tuple for multiple inputs)
        onnx_file_path,                      # where to save the model
        export_params=True,                  # store the trained parameter weights inside the model file
        opset_version=12,                    # the ONNX version to export the model to
        do_constant_folding=True,            # whether to execute constant folding for optimization
        input_names=['input'],               # the model's input names
        output_names=['output'],             # the model's output names
        dynamic_axes={                       # variable length axes
            'input': {0: 'batch_size', 1: 'sequence_length', 2: 'embedding_size'},
            'output': {0: 'batch_size'},
        },
    )
    print(f'Model has been successfully exported to {onnx_file_path}')

In [21]:
# Export into an "onnx_models" folder below the current working directory.
root_dir = Path.cwd()
onnx_dir_path = f'{root_dir}/onnx_models'
export_vespag_fnn_to_onnx(
    fnn_model=fnn_model,
    onnx_dir_path=onnx_dir_path,
    params=params,
)

Model has been successfully exported to /home/paula/projects/biocentral/vespag/onnx_models/fnn.onnx




# Testen, ob die Ergebnisse identisch sind

In [6]:
import pandas as pd

# Load the two score tables that are compared below.
# Presumably "org" is the original model's output and "onnx" the exported
# model's output (inferred from the file names -- confirm).
original_results, onnx_results = [
    pd.read_csv(csv_path)
    for csv_path in (
        'output/vespag_scores_all_org.csv',
        'output/vespag_scores_all_onnx.csv',
    )
]

In [16]:
# Render floats with 100 decimal places in all following DataFrame/Series output.
# NOTE(review): this is far beyond float64 precision and makes pandas truncate
# the cells with "..."; a scientific format such as '{:.6e}' may read better.
pd.set_option('display.float_format', '{:.100f}'.format)

# DataFrame.compare labels values of the calling frame "self" and values of
# the argument frame "other".
comparison = original_results.compare(onnx_results)
scores_differ = comparison[('VespaG', 'self')].ne(comparison[('VespaG', 'other')])
comparison[scores_differ]

Unnamed: 0_level_0,VespaG,VespaG
Unnamed: 0_level_1,self,other
0,0.117311270116262605922585748885467182844877243...,0.117311333330984493561466308619856135919690132...
1,0.082738412807202496579428441236814251169562339...,0.082738483280551194942731285664194729179143905...
2,-0.11772366176011249405686243107993504963815212...,-0.11772354919943039952556773641845211386680603...
3,-0.18746656643736309133529971404641401022672653...,-0.18746643923404660014853106986265629529953002...
4,-0.07681562571595799970847195936585194431245326...,-0.07681552174398340038230514892347855493426322...
...,...,...
5448,0.720043148284801759473339188843965530395507812...,0.720043084955001022606779770285356789827346801...
5449,0.729664666226049529740294019575230777263641357...,0.729664600876195557077608100371435284614562988...
5450,0.605876766541208144900565457646735012531280517...,0.605876727180821861296067254443187266588211059...
5451,0.415064978604241296977761521702632308006286621...,0.415064979305095393957714122734614647924900054...


In [19]:
# Signed deviation per row: "self" score minus "other" score.
comparison['diff'] = comparison[('VespaG', 'self')].sub(comparison[('VespaG', 'other')])

# Print the full column, followed by its largest and smallest value.
diff_column = comparison['diff']
print(diff_column)
print(diff_column.max())
print(diff_column.min())

0      -0.00000006321472188763888055973438895307481288...
1      -0.00000007047334869836330284442738047800958156...
2      -0.00000011256068209453129469466148293577134609...
3      -0.00000012720331649118676864418375771492719650...
4      -0.00000010397197459932616681044237338937819004...
                              ...                        
5448   0.000000063329800736866559418558608740568161010...
5449   0.000000065349853972662685919203795492649078369...
5450   0.000000039360386283604498203203547745943069458...
5451   -0.00000000070085409697995260103198233991861343...
5452   0.000000004575012646501619428818230517208576202...
Name: diff, Length: 5453, dtype: float64
0.00023994055454412688
-0.0004840852139935681
