## RawNet3 with ezkl

In this notebook, we demonstrate the integration of the RawNet3 model with ezkl, a toolkit for zero-knowledge proof systems. The goal is to run the RawNet3 model with ezkl and generate proofs for speaker verification.

#### Importing Dependencies

In [31]:
import torch
import torch.nn as nn
import os
import json
import ezkl
import librosa
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from asteroid_filterbanks import Encoder, ParamSincFB

#### RawNetBasicBlock

In [32]:
class PreEmphasis(torch.nn.Module):
    def __init__(self, coef: float = 0.97) -> None:
        super().__init__()
        self.coef = coef
        self.register_buffer(
            "flipped_filter",
            torch.FloatTensor([-self.coef, 1.0]).unsqueeze(0).unsqueeze(0),
        )

    def forward(self, input: torch.tensor) -> torch.tensor:
        assert (
            len(input.size()) == 2
        ), "The number of dimensions of input tensor must be 2!"
        input = input.unsqueeze(1)
        input = F.pad(input, (1, 0), "reflect")
        return F.conv1d(input, self.flipped_filter)

class AFMS(nn.Module):
    def __init__(self, nb_dim: int) -> None:
        super().__init__()
        self.alpha = nn.Parameter(torch.ones((nb_dim, 1)))
        self.fc = nn.Linear(nb_dim, nb_dim)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        y = F.adaptive_avg_pool1d(x, 1).view(x.size(0), -1)
        y = self.sig(self.fc(y)).view(x.size(0), x.size(1), -1)

        x = x + self.alpha
        x = x * y
        return x

class Bottle2neck(nn.Module):
    def __init__(
        self,
        inplanes,
        planes,
        kernel_size=None,
        dilation=None,
        scale=4,
        pool=False,
    ):
        super().__init__()
        width = int(math.floor(planes / scale))
        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)
        self.nums = scale - 1
        convs = []
        bns = []
        num_pad = math.floor(kernel_size / 2) * dilation
        for i in range(self.nums):
            convs.append(
                nn.Conv1d(
                    width,
                    width,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    padding=num_pad,
                )
            )
            bns.append(nn.BatchNorm1d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.width = width
        self.mp = nn.MaxPool1d(pool) if pool else False
        self.afms = AFMS(planes)
        if inplanes != planes:
            self.residual = nn.Sequential(
                nn.Conv1d(inplanes, planes, kernel_size=1, stride=1, bias=False)
            )
        else:
            self.residual = nn.Identity()

    def forward(self, x):
        residual = self.residual(x)
        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(sp)
            sp = self.bns[i](sp)
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        out = torch.cat((out, spx[self.nums]), 1)
        out = self.conv3(out)
        out = self.relu(out)
        out = self.bn3(out)
        out += residual
        if self.mp:
            out = self.mp(out)
        out = self.afms(out)
        return out

#### RawNet3

In [33]:
# RawNet3.py
class RawNet3(nn.Module):
    def __init__(self, block, model_scale, context, summed, C=1024, **kwargs):
        super().__init__()

        nOut = kwargs["nOut"]

        self.context = context
        self.encoder_type = kwargs["encoder_type"]
        self.log_sinc = kwargs["log_sinc"]
        self.norm_sinc = kwargs["norm_sinc"]
        self.out_bn = kwargs["out_bn"]
        self.summed = summed

        self.preprocess = nn.Sequential(
            PreEmphasis(), nn.InstanceNorm1d(1, eps=1e-4, affine=True)
        )
        self.conv1 = Encoder(
            ParamSincFB(
                C // 4,
                251,
                stride=kwargs["sinc_stride"],
            )
        )
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(C // 4)

        self.layer1 = block(
            C // 4, C, kernel_size=3, dilation=2, scale=model_scale, pool=5
        )
        self.layer2 = block(
            C, C, kernel_size=3, dilation=3, scale=model_scale, pool=3
        )
        self.layer3 = block(C, C, kernel_size=3, dilation=4, scale=model_scale)
        self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)

        if self.context:
            attn_input = 1536 * 3
        else:
            attn_input = 1536
        print("self.encoder_type", self.encoder_type)
        if self.encoder_type == "ECA":
            attn_output = 1536
        elif self.encoder_type == "ASP":
            attn_output = 1
        else:
            raise ValueError("Undefined encoder")

        self.attention = nn.Sequential(
            nn.Conv1d(attn_input, 128, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Conv1d(128, attn_output, kernel_size=1),
            nn.Softmax(dim=2),
        )

        self.bn5 = nn.BatchNorm1d(3072)

        self.fc6 = nn.Linear(3072, nOut)
        self.bn6 = nn.BatchNorm1d(nOut)

        self.mp3 = nn.MaxPool1d(3)

    def forward(self, x):
        """
        :param x: input mini-batch (bs, samp)
        """

        with torch.cuda.amp.autocast(enabled=False):
            x = self.preprocess(x)
            x = torch.abs(self.conv1(x))
            if self.log_sinc:
                x = torch.log(x + 1e-6)
            if self.norm_sinc == "mean":
                x = x - torch.mean(x, dim=-1, keepdim=True)
            elif self.norm_sinc == "mean_std":
                m = torch.mean(x, dim=-1, keepdim=True)
                s = torch.std(x, dim=-1, keepdim=True)
                s[s < 0.001] = 0.001
                x = (x - m) / s

        if self.summed:
            x1 = self.layer1(x)
            x2 = self.layer2(x1)
            x3 = self.layer3(self.mp3(x1) + x2)
        else:
            x1 = self.layer1(x)
            x2 = self.layer2(x1)
            x3 = self.layer3(x2)

        x = self.layer4(torch.cat((self.mp3(x1), x2, x3), dim=1))
        x = self.relu(x)

        t = x.size()[-1]

        if self.context:
            global_x = torch.cat(
                (
                    x,
                    torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
                    torch.sqrt(
                        torch.var(x, dim=2, keepdim=True).clamp(
                            min=1e-4, max=1e4
                        )
                    ).repeat(1, 1, t),
                ),
                dim=1,
            )
        else:
            global_x = x

        w = self.attention(global_x)

        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt(
            (torch.sum((x**2) * w, dim=2) - mu**2).clamp(min=1e-4, max=1e4)
        )

        x = torch.cat((mu, sg), 1)

        x = self.bn5(x)

        x = self.fc6(x)

        if self.out_bn:
            x = self.bn6(x)

        return x

In [34]:
def MainModel(**kwargs):

    model = RawNet3(
        Bottle2neck, model_scale=8, context=True, summed=True, **kwargs
    )
    return model

In [35]:
# Load the pre-trained RawNet3 model
model = MainModel(nOut=256, encoder_type="ECA", log_sinc=True, norm_sinc="mean", out_bn=False, sinc_stride=10)
model.load_state_dict(torch.load("./models/model.pt", map_location=lambda storage, loc: storage)["model"])
model.eval()

self.encoder_type ECA


  model.load_state_dict(torch.load("./models/model.pt", map_location=lambda storage, loc: storage)["model"])


RawNet3(
  (preprocess): Sequential(
    (0): PreEmphasis()
    (1): InstanceNorm1d(1, eps=0.0001, momentum=0.1, affine=True, track_running_stats=False)
  )
  (conv1): Encoder(
    (filterbank): ParamSincFB()
  )
  (relu): ReLU()
  (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Bottle2neck(
    (conv1): Conv1d(256, 1024, kernel_size=(1,), stride=(1,))
    (bn1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (convs): ModuleList(
      (0-6): 7 x Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(2,), dilation=(2,))
    )
    (bns): ModuleList(
      (0-6): 7 x BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (conv3): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,))
    (bn3): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (mp): MaxPool1d(kernel_size=5, stride=5, padding=0, dilation=1, ceil_

#### Load the pre-trained RawNet3 model

#### Specifying Files and Paths

In [36]:
# Specify all the files we need
model_path = os.path.join('rawnet3.onnx')
data_path = os.path.join('input.json')
cal_data_path = os.path.join('calibration.json')

#### Preparing Input Data

In [37]:
# Load an audio file for testing
audio_file = "./sample1.wav"
audio, sample_rate = librosa.load(audio_file, sr=16000, mono=True)
audio = audio[:48000]  # Truncate to 3 seconds (48000 samples)
audio_tensor = torch.from_numpy(audio).float().unsqueeze(0)

## Integration with ezkl

#### Exporting the Model to ONNX

In [38]:
# Export the model to ONNX
torch.onnx.export(model,                     # model being run
                  audio_tensor,              # model input
                  'network.onnx',            # where to save the model
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names = ['input'],   # the model's input names
                  output_names = ['output'], # the model's output names
                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                'output' : {0 : 'batch_size'}})

  with torch.cuda.amp.autocast(enabled=False):


In [39]:
# Prepare input data
data_array = audio_tensor.flatten().tolist()
data = dict(input_data = [data_array])
json.dump(data, open(data_path, 'w'))

#### Generating Settings

In [40]:
res = ezkl.gen_settings()
assert res == True

RuntimeError: Failed to generate settings: [graph] [tract] Translating node #71 "If_109" If ToTypedTranslator

#### Error Explanation

This error indicates that there is an issue with generating the ezkl settings due to a specific node in the exported ONNX model. The error suggests that the exported ONNX model contains an "If" node (conditional node) that is not compatible with ezkl's type translator.

"If" nodes are control flow structures that allow conditional execution of subgraphs based on a boolean condition. ezkl has limitations in handling complex operations like conditional statements in the model.

## Limitations and Considerations

When attempting to run a complex model like RawNet3 with ezkl, it's important to consider the limitations of ezkl in handling certain operations. ezkl is designed for zero-knowledge proofs and may not support all the complex operations present in the RawNet3 model. 

Some possible limitations and considerations include:

1. **Conditional Statements**: ezkl may have difficulties in handling conditional statements like "If" nodes in the exported ONNX model. These nodes allow for conditional execution of subgraphs based on a boolean condition, which can be challenging to translate into a format compatible with ezkl's type translator.

2. **Complex Architectures**: RawNet3 has a complex architecture with various layers and operations. Some of these operations may not be directly supported by ezkl, leading to compatibility issues during the translation process.

3. **ONNX Export Configuration**: The configuration used while exporting the model to ONNX format can also impact compatibility with ezkl. Experimenting with different export configurations, such as changing the opset version or adjusting other export options, may

5. **ONNX Model Preprocessing**: If the exported ONNX model contains problematic nodes like "If" nodes, preprocessing the model using tools like ONNX Simplifier or ONNX Optimizer may help in simplifying or removing these nodes before passing the model to ezkl.

## Conclusion

In this notebook, we attempted to integrate the RawNet3 model with ezkl for speaker verification using zero-knowledge proofs. However, we encountered an error related to the compatibility of the exported ONNX model with ezkl's type translator.

The error suggests that the presence of an "If" node in the exported model is causing issues during the generation of ezkl settings. This highlights the limitations of ezkl in handling complex operations like conditional statements.

To overcome these limitations, we may need to simplify the RawNet3 model architecture, experiment with different ONNX export configurations, or preprocess the exported ONNX model to remove problematic nodes.

It's important to note that running complex models like RawNet3 with ezkl may require significant modifications to the model architecture and careful consideration of the supported operations to ensure compatibility.

Further investigation and experimentation would be necessary to successfully integrate RawNet3 with ezkl for speaker verification using zero-knowledge proofs. 

Despite the encountered challenges, the integration of RawNet3 with ezkl remains an interesting avenue for future research and development in the field of privacy-preserving speaker verification.

## Next Steps

To further explore the integration of RawNet3 with ezkl, the following steps can be considered:

1. **Model Simplification**: Analyze the RawNet3 model architecture and identify parts that use unsupported operations like conditional statements. Modify the model's source code to simplify or remove these parts while preserving the core functionality.

2. **ONNX Export Optimization**: Experiment with different ONNX export configurations, such as changing the opset version or adjusting other export options, to find a configuration that generates an ONNX model compatible with ezkl.

3. **ONNX Model Preprocessing**: Explore tools like ONNX Simplifier or ONNX Optimizer to preprocess the exported ONNX model and remove problematic nodes before passing it to ezkl.
