In [1]:
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

__all__ = ['iresnet1','iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
using_ckpt = False

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    expansion = 1
    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.downsample = downsample
        self.stride = stride

    def forward_impl(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out        

    def forward(self, x):
        if self.training and using_ckpt:
            return checkpoint(self.forward_impl, x)
        else:
            return self.forward_impl(x)


class IResNet(nn.Module):
    fc_scale = 7 * 7
    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super(IResNet, self).__init__()
        self.extra_gflops = 0.0
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x


def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError()
    return model

def iresnet1(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet1', IBasicBlock, [1, 1, 1, 1], pretrained,
                    progress, **kwargs)


def iresnet18(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
                    progress, **kwargs)


def iresnet34(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
                    progress, **kwargs)


def iresnet50(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
                    progress, **kwargs)


def iresnet100(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
                    progress, **kwargs)


def iresnet200(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
                    progress, **kwargs)


In [2]:
import argparse

import cv2
import numpy as np
import torch



model = iresnet1(fp16=True)
#model.load_state_dict(torch.load("./backbone_16.pth",map_location=torch.device('mps')))

model.eval()


 # Export the model

def get_img(path):
    img = cv2.imread(path)
    img = cv2.resize(img, (112, 112))

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = np.transpose(img, (2, 0, 1))
    img = torch.from_numpy(img).unsqueeze(0).float()
    img.div_(255).sub_(0.5).div_(0.5)
    return img

@torch.no_grad()
def inference(model,img):
    img = get_img(img)
    feat = model(img).numpy()
    print(feat.shape)
    return feat.flatten()




In [3]:
#get model params and size
params = sum([np.prod(p.size()) for p in model.parameters()])
print("Number of Parameters: %.1fM"%(params/1e6))
print("Model Size: %.1fM"%(params*4/1e6))



Number of Parameters: 17.8M
Model Size: 71.0M


In [4]:

    f1 = inference(model, "./ema.png")
    f2 = inference(model, "./ema2.png")
    f3 = inference(model, "./john.png")
   

    #compoute cosine similarity
    from scipy.spatial.distance import cosine
    print("cosine similarity between ema and ema2: %.4f"%(1-cosine(f1,f2)))
    print("cosine similarity between ema and john: %.4f"%(1-cosine(f1,f3)))
    print("cosine similarity between ema2 and john: %.4f"%(1-cosine(f2,f3)))
    print("cosine similarity between ema and ema: %.4f"%(1-cosine(f1,f1)))




(1, 512)
(1, 512)
(1, 512)
cosine similarity between ema and ema2: 0.9500
cosine similarity between ema and john: 0.6493
cosine similarity between ema2 and john: 0.6450
cosine similarity between ema and ema: 1.0000




In [5]:
from torch import nn
import ezkl
import os
import json
import logging

# uncomment for more descriptive logging 
FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'
logging.basicConfig(format=FORMAT)
logging.getLogger().setLevel(logging.INFO)

In [6]:

x = get_img("./ema.png")

torch.onnx.export(model,               # model being run
                      x,                   # model input (or a tuple for multiple inputs)
                      "network.onnx",            # where to save the model (can be a file or file-like object)
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=10,          # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names = ['input'],   # the model's input names
                      output_names = ['output'], # the model's output names
                      dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                                    'output' : {0 : 'batch_size'}})

data_array = ((x).to("cpu").detach().numpy()).reshape([-1]).tolist()

data = dict(input_data = [data_array])

    # Serialize data into file:
json.dump( data, open("input.json", 'w' ))



verbose: False, log level: Level.ERROR



In [7]:
import ezkl

model_path = os.path.join('network.onnx')
compiled_model_path = os.path.join('network.compiled')
pk_path = os.path.join('test.pk')
vk_path = os.path.join('test.vk')
settings_path = os.path.join('settings.json')
srs_path = os.path.join('kzg.srs')
data_path = os.path.join('input.json')

run_args = ezkl.PyRunArgs()
run_args.input_visibility = "private"
run_args.param_visibility = "public"
run_args.output_visibility = "public"
run_args.variables = [("batch_size", 1)]

In [8]:
!RUST_LOG=trace
# TODO: Dictionary outputs
res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)
assert res == True

INFO ezkl.graph.model 2023-09-16 23:23:35,621 model.rs:708 set batch_size to 1
INFO ezkl.graph.model 2023-09-16 23:23:51,744 model.rs:416 model has 1 instances
INFO ezkl.graph.model 2023-09-16 23:23:51,745 model.rs:1265 calculating num of constraints using dummy model layout...
INFO ezkl.graph.model 2023-09-16 23:37:49,697 model.rs:430 model generates 1597662973 constraints (excluding modules)


In [9]:

cal_data = {
    "input_data": [data_array],
}

print(cal_data)



# cal_path = os.path.join('val_data.json')
# # save as json file
# with open(cal_path, "w") as f:
#     json.dump(cal_data, f)

# res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")

{'input_data': [[0.5372549295425415, 0.5372549295425415, 0.5372549295425415, 0.5372549295425415, 0.5372549295425415, 0.5372549295425415, 0.5372549295425415, 0.5529412031173706, 0.5529412031173706, 0.5529412031173706, 0.5529412031173706, 0.5529412031173706, 0.5529412031173706, 0.5529412031173706, 0.5529412031173706, 0.5529412031173706, 0.5529412031173706, 0.5529412031173706, 0.5529412031173706, 0.5529412031173706, 0.5529412031173706, 0.5607843399047852, 0.5607843399047852, 0.5607843399047852, 0.5529412031173706, 0.5529412031173706, 0.5529412031173706, 0.5607843399047852, 0.5686274766921997, 0.5686274766921997, 0.5686274766921997, 0.5764706134796143, 0.5764706134796143, 0.5764706134796143, 0.5843137502670288, 0.5921568870544434, 0.6078431606292725, 0.615686297416687, 0.615686297416687, 0.615686297416687, 0.6313725709915161, 0.6392157077789307, 0.6392157077789307, 0.6313725709915161, 0.5529412031173706, 0.458823561668396, 0.5843137502670288, 0.3803921937942505, 0.4901961088180542, 0.12941

In [9]:
res = ezkl.compile_model(model_path, compiled_model_path, settings_path)
assert res == True

INFO ezkl.graph.model 2023-09-16 23:42:52,991 model.rs:708 set batch_size to 1


In [10]:
res = ezkl.get_srs(srs_path, settings_path)


INFO ezkl.execute 2023-09-16 23:43:13,885 execute.rs:418 SRS downloaded


In [11]:
!export RUST_BACKTRACE=1

witness_path = "witness.json"

res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, settings_path = settings_path)

INFO ezkl.graph 2023-09-16 23:43:17,492 mod.rs:598 input scales: [7]


In [13]:
res = ezkl.mock(witness_path, compiled_model_path, settings_path)

INFO ezkl.graph 2023-09-16 23:13:41,552 mod.rs:548 public inputs lengths: [512]
INFO ezkl.execute 2023-09-16 23:13:41,553 execute.rs:809 Mock proof
INFO ezkl.graph.model 2023-09-16 23:13:42,160 model.rs:981 configuring model
INFO ezkl.graph 2023-09-16 23:13:45,914 mod.rs:1037 circuit size: 
 {
  "num_advice_columns": 39471,
  "num_challenges": 0,
  "num_fixed": 4,
  "num_instances": 1,
  "num_selectors": 171041
}


: 

In [7]:
# HERE WE SETUP THE CIRCUIT PARAMS
# WE GOT KEYS
# WE GOT CIRCUIT PARAMETERS
# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK
res = ezkl.setup(
        compiled_model_path,
        vk_path,
        pk_path,
        srs_path,
        settings_path,
    )

assert res == True
assert os.path.isfile(vk_path)
assert os.path.isfile(pk_path)
assert os.path.isfile(settings_path)

INFO ezkl.pfsys.srs 2023-09-16 23:16:44,869 srs.rs:23 loading srs from "kzg.srs"
INFO ezkl.execute 2023-09-16 23:16:44,879 execute.rs:1698 downsizing params to 17 logrows
INFO ezkl.graph.model 2023-09-16 23:16:45,578 model.rs:981 configuring model


: 