## Exploring ONNX Export Optimization and Model Preprocessing for RawNet3 Integration with ezkl

This notebook explores two ways to address the compatibility issues that arise when integrating the RawNet3 speaker verification model with ezkl, a zero-knowledge proof framework: adjusting the ONNX export configuration and preprocessing the exported ONNX model.

### Importing Dependencies

In [38]:
import json
import math
import os

import ezkl
import librosa
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from asteroid_filterbanks import Encoder, ParamSincFB

### Defining RawNet3 Model Components

In [39]:
class PreEmphasis(torch.nn.Module):
    def __init__(self, coef: float = 0.97) -> None:
        super().__init__()
        self.coef = coef
        self.register_buffer(
            "flipped_filter",
            torch.FloatTensor([-self.coef, 1.0]).unsqueeze(0).unsqueeze(0),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert (
            len(input.size()) == 2
        ), "The number of dimensions of input tensor must be 2!"
        input = input.unsqueeze(1)
        input = F.pad(input, (1, 0), "reflect")
        return F.conv1d(input, self.flipped_filter)

class AFMS(nn.Module):
    def __init__(self, nb_dim: int) -> None:
        super().__init__()
        self.alpha = nn.Parameter(torch.ones((nb_dim, 1)))
        self.fc = nn.Linear(nb_dim, nb_dim)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        y = F.adaptive_avg_pool1d(x, 1).view(x.size(0), -1)
        y = self.sig(self.fc(y)).view(x.size(0), x.size(1), -1)

        x = x + self.alpha
        x = x * y
        return x

class Bottle2neck(nn.Module):
    def __init__(
        self,
        inplanes,
        planes,
        kernel_size=None,
        dilation=None,
        scale=4,
        pool=False,
    ):
        super().__init__()
        width = int(math.floor(planes / scale))
        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)
        self.nums = scale - 1
        convs = []
        bns = []
        num_pad = math.floor(kernel_size / 2) * dilation
        for i in range(self.nums):
            convs.append(
                nn.Conv1d(
                    width,
                    width,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=num_pad,
                )
            )
            bns.append(nn.BatchNorm1d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.width = width
        self.mp = nn.MaxPool1d(pool) if pool else False
        self.afms = AFMS(planes)
        if inplanes != planes:
            self.residual = nn.Sequential(
                nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False)
            )
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        residual = self.residual(x)
        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(sp)
            sp = self.bns[i](sp)
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        out = torch.cat((out, spx[self.nums]), 1)
        out = self.conv3(out)
        out = self.relu(out)
        out = self.bn3(out)
        out += residual
        if self.mp:
            out = self.mp(out)
        out = self.afms(out)
        return out

In [40]:
# RawNet3.py
class RawNet3(nn.Module):
    def __init__(self, block, model_scale, context, summed, C=1024, **kwargs):
        super().__init__()

        nOut = kwargs["nOut"]

        self.context = context
        self.encoder_type = kwargs["encoder_type"]
        self.log_sinc = kwargs["log_sinc"]
        self.norm_sinc = kwargs["norm_sinc"]
        self.out_bn = kwargs["out_bn"]
        self.summed = summed

        self.preprocess = nn.Sequential(
            PreEmphasis(), nn.InstanceNorm1d(1, eps=1e-4, affine=True)
        )
        self.conv1 = Encoder(
            ParamSincFB(
                C // 4,
                251,
                stride=kwargs["sinc_stride"],
            )
        )
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(C // 4)

        self.layer1 = block(
            C // 4, C, kernel_size=3, dilation=2, scale=model_scale, pool=5
        )
        self.layer2 = block(
            C, C, kernel_size=3, dilation=3, scale=model_scale, pool=3
        )
        self.layer3 = block(C, C, kernel_size=3, dilation=4, scale=model_scale)
        self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)

        if self.context:
            attn_input = 1536 * 3
        else:
            attn_input = 1536
        print("self.encoder_type", self.encoder_type)
        if self.encoder_type == "ECA":
            attn_output = 1536
        elif self.encoder_type == "ASP":
            attn_output = 1
        else:
            raise ValueError("Undefined encoder")

        self.attention = nn.Sequential(
            nn.Conv1d(attn_input, 128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, attn_output, kernel_size=1),
            nn.Softmax(dim=2),
        )

        self.bn5 = nn.BatchNorm1d(3072)

        self.fc6 = nn.Linear(3072, nOut)
        self.bn6 = nn.BatchNorm1d(nOut)

        self.mp3 = nn.MaxPool1d(3)

    def forward(self, x):
        """
        :param x: input mini-batch (bs, samp)
        """

        with torch.cuda.amp.autocast(enabled=False):
            x = self.preprocess(x)
            x = torch.abs(self.conv1(x))
            if self.log_sinc:
                x = torch.log(x + 1e-6)
            if self.norm_sinc == "mean":
                x = x - torch.mean(x, dim=-1, keepdim=True)
            elif self.norm_sinc == "mean_std":
                m = torch.mean(x, dim=-1, keepdim=True)
                s = torch.std(x, dim=-1, keepdim=True)
                s[s < 0.001] = 0.001
                x = (x - m) / s

        if self.summed:
            x1 = self.layer1(x)
            x2 = self.layer2(x1)
            x3 = self.layer3(self.mp3(x1) + x2)
        else:
            x1 = self.layer1(x)
            x2 = self.layer2(x1)
            x3 = self.layer3(x2)

        x = self.layer4(torch.cat((self.mp3(x1), x2, x3), dim=1))
        x = self.relu(x)

        t = x.size()[-1]

        if self.context:
            global_x = torch.cat(
                (
                    x,
                    torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
                    torch.sqrt(
                        torch.var(x, dim=2, keepdim=True).clamp(
                            min=1e-4, max=1e4
                        )
                    ).repeat(1, 1, t),
                ),
                dim=1,
            )
        else:
            global_x = x

        w = self.attention(global_x)

        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt(
            (torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4, max=1e4)
        )

        x = torch.cat((mu, sg), 1)

        x = self.bn5(x)

        x = self.fc6(x)

        if self.out_bn:
            x = self.bn6(x)

        return x

In [41]:
def MainModel(**kwargs):

    model = RawNet3(
        Bottle2neck, model_scale=8, context=True, summed=True, **kwargs
    )
    return model

### Loading the Pre-trained RawNet3 Model

In [42]:
model = MainModel(nOut=256, encoder_type="ECA", log_sinc=True, norm_sinc="mean", out_bn=False, sinc_stride=10)
model.load_state_dict(torch.load("./models/model.pt", map_location=lambda storage, loc: storage)["model"])
model.eval()

self.encoder_type ECA




RawNet3(
  (preprocess): Sequential(
    (0): PreEmphasis()
    (1): InstanceNorm1d(1, eps=0.0001, momentum=0.1, affine=True, track_running_stats=False)
  )
  (conv1): Encoder(
    (filterbank): ParamSincFB()
  )
  (relu): ReLU()
  (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Bottle2neck(
    (conv1): Conv1d(256, 1024, kernel_size=(1,), stride=(1,))
    (bn1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (convs): ModuleList(
      (0-6): 7 x Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
    )
    (bns): ModuleList(
      (0-6): 7 x BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv3): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))
    (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (mp): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_

## Integration with ezkl

### Exporting the Model to ONNX 

Here, we export the RawNet3 model to ONNX with `torch.onnx.export()`, passing the model, a sample input tensor, the output file path, and the export options: opset version, constant folding, input and output names, and a dynamic batch axis.

In [43]:
# Load an audio file for testing
audio_file = "./sample1.wav"
audio, sample_rate = librosa.load(audio_file, sr=16000, mono=True)
audio = audio[:48000]  # Truncate to 3 seconds (48000 samples)
audio_tensor = torch.from_numpy(audio).float().unsqueeze(0)


In [44]:
# Export the model to ONNX
torch.onnx.export(model,                     # model to export
                  audio_tensor,              # example model input
                  'network.onnx',            # where to save the model
                  export_params=True,        # store the trained weights inside the model file
                  opset_version=10,          # ONNX opset version to export to
                  do_constant_folding=True,  # whether to run constant folding for optimization
                  input_names = ['input'],   # model input names
                  output_names = ['output'], # model output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable-length axes
                                'output' : {0 : 'batch_size'}})


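One export variation that could be tried is a fixed-shape export at a newer opset, since shape-dependent operators can be a source of control flow nodes in exported graphs. The helper below is only a sketch: `build_export_kwargs` is a hypothetical name, the opset value 17 is an arbitrary illustrative choice, and whether this configuration removes the `If` node for RawNet3 has not been tested here.

```python
def build_export_kwargs(opset_version=17, dynamic_batch=False):
    """Return keyword arguments for torch.onnx.export (hypothetical helper).

    With dynamic_batch=False no dynamic_axes entry is added, so every
    dimension of the exported graph is fixed to the example input's shape.
    """
    kwargs = {
        "export_params": True,           # keep trained weights in the file
        "opset_version": opset_version,  # assumed value, adjust as needed
        "do_constant_folding": True,
        "input_names": ["input"],
        "output_names": ["output"],
    }
    if dynamic_batch:
        # Same dynamic batch axis as in the export cell above
        kwargs["dynamic_axes"] = {
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        }
    return kwargs


static_kwargs = build_export_kwargs()
print(static_kwargs)
```

With PyTorch available, the call would look like `torch.onnx.export(model, audio_tensor, 'network_static.onnx', **build_export_kwargs())`, where `network_static.onnx` is an assumed file name.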


### Optimizing the ONNX Model using ONNX Optimizer

We load the exported ONNX model, run ONNX Optimizer's default optimization passes with `onnxoptimizer.optimize()`, and save the optimized model.

In [45]:
# Optimize the ONNX model using ONNX Optimizer
import onnx
from onnxoptimizer import optimize

# Load the ONNX model (named onnx_model so it does not shadow the PyTorch model)
onnx_model = onnx.load('network.onnx')

# Apply the default optimization passes to the ONNX model
optimized_model = optimize(onnx_model)

# Save the optimized ONNX model
onnx.save(optimized_model, 'network_optimized.onnx')
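To see what the optimizer actually changed, the operator counts of the graph before and after optimization can be compared. The sketch below works on plain lists of operator type names. `op_type_diff` is a hypothetical helper, and the two lists are toy values for illustration, not results from RawNet3.

```python
from collections import Counter


def op_type_diff(before_ops, after_ops):
    """Return {op_type: count_after - count_before} for operators whose count changed."""
    before = Counter(before_ops)
    after = Counter(after_ops)
    return {
        op: after[op] - before[op]
        for op in sorted(set(before) | set(after))
        if after[op] != before[op]
    }


# Toy operator lists (illustrative only)
toy_before = ["Conv", "BatchNormalization", "Relu", "If", "Identity"]
toy_after = ["Conv", "Relu", "If"]

diff = op_type_diff(toy_before, toy_after)
print(diff)
```

With `onnx` installed, the real lists could be built as `[n.op_type for n in loaded_model.graph.node]` for each loaded model. An operator such as `If` that is missing from the diff was left untouched by the optimization passes.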


### Specifying Files and Paths

We specify the paths for the optimized ONNX model, input data, and calibration data files.

In [46]:
model_path = os.path.join('network_optimized.onnx')
data_path = os.path.join('input.json')
cal_data_path = os.path.join('calibration.json')
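The cells shown here do not write `input.json`. A minimal sketch of how it might be produced follows, assuming the `{"input_data": [...]}` layout used in ezkl's example notebooks, with each input flattened to a one-dimensional list. `write_ezkl_input` and `flatten` are hypothetical helpers, the output file name is an assumption, and the four-sample waveform is a stand-in for `audio_tensor`.

```python
import json
import os
import tempfile


def flatten(values):
    """Recursively flatten nested lists or tuples of numbers into a flat list of floats."""
    flat = []
    for value in values:
        if isinstance(value, (list, tuple)):
            flat.extend(flatten(value))
        else:
            flat.append(float(value))
    return flat


def write_ezkl_input(samples, path):
    """Write samples to path as {"input_data": [flat_list]} and return the payload."""
    payload = {"input_data": [flatten(samples)]}
    with open(path, "w") as f:
        json.dump(payload, f)
    return payload


# Toy waveform with shape (1, 4) standing in for the real (1, 48000) tensor
example_path = os.path.join(tempfile.gettempdir(), "input_example.json")
payload = write_ezkl_input([[0.0, 0.25, -0.5, 1.0]], example_path)
print(payload)
```

For the real waveform, `audio_tensor.numpy().reshape(-1).tolist()` could be passed as `samples`, with `data_path` as the destination.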


### Generating ezkl Settings

We now call `ezkl.gen_settings()` to generate the ezkl settings from the optimized ONNX model. The call fails with the error shown after the cell.

In [47]:
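# Generate ezkl settings using the optimized ONNX model.
# Pass the paths explicitly: a call with no arguments does not use model_path.
settings_path = os.path.join('settings.json')  # assumed output file name
res = ezkl.gen_settings(model_path, settings_path)
assert res == True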

RuntimeError: Failed to generate settings: [graph] [tract] Translating node #71 "If_109" If ToTypedTranslator

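### Error Explanation

The error is raised while tract, the library ezkl uses to parse the ONNX graph, translates node #71 (`If_109`), an ONNX `If` operator. It suggests that ezkl cannot currently handle the conditional (control flow) operator present in the exported RawNet3 graph.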
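To locate the offending operator, the exported graph can be scanned for control flow nodes. The sketch below only relies on the `node`, `name`, `op_type`, and `input` fields of an ONNX graph, so a toy stand-in built with `SimpleNamespace` is enough to demonstrate it. `find_control_flow_nodes` is a hypothetical helper, and it inspects top-level nodes only.

```python
from types import SimpleNamespace

# ONNX control flow operators
CONTROL_FLOW_OPS = {"If", "Loop", "Scan"}


def find_control_flow_nodes(graph):
    """Return (index, name, op_type, inputs) for each top-level control flow node."""
    return [
        (index, node.name, node.op_type, list(node.input))
        for index, node in enumerate(graph.node)
        if node.op_type in CONTROL_FLOW_OPS
    ]


# Toy stand-in for an ONNX graph (illustrative names, not the RawNet3 graph)
toy_graph = SimpleNamespace(node=[
    SimpleNamespace(name="Conv_0", op_type="Conv", input=["input", "w"]),
    SimpleNamespace(name="If_1", op_type="If", input=["cond"]),
])

found = find_control_flow_nodes(toy_graph)
print(found)
```

With `onnx` installed, `find_control_flow_nodes(onnx.load('network_optimized.onnx').graph)` would list any such nodes together with the tensors feeding them, which helps trace the operation that produced the `If`. The index reported here is the position in `graph.node` and may not match the node number in the tract error message.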

#### Implications and Limitations

The error highlights ezkl's limited support for control flow structures in neural network models like RawNet3. Optimizing the ONNX export configuration and preprocessing the model with ONNX Optimizer did not resolve the issue.

ezkl is designed primarily for zero-knowledge proofs and may not support the full range of operations and architectures used in deep learning models. Conditional operators such as the `If` node are a problem for the type translator, which cannot convert them into a format suitable for zero-knowledge proofs.

The complexity of RawNet3, with its many layers and operations, compounds the issue: beyond the `If` node, other operations in the graph may also turn out to be unsupported.

#### Conclusion

Optimizing the ONNX export configuration and preprocessing the model with ONNX Optimizer were not enough to integrate RawNet3 with ezkl: settings generation still fails on the `If` node. The incompatibility stems from the gap between the operations ezkl supports and those the exported model contains, and from the complexity of the model architecture.

Making RawNet3 and similar models work with ezkl would require significant changes to the model architecture, such as simplifying or removing the unsupported operations, and those changes may affect the model's performance and functionality.

Closing the gap needs further work on both sides: extending ezkl to support a wider range of operations, ideally in collaboration with the ezkl development team, or exploring alternative approaches that can accommodate the intricacies of models like RawNet3.