In [1]:
import utilities
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import open_clip

In [2]:
class CFG:
    model_name = 'ViT-L-14-336' 
    model_data = 'openai'
    emb_size = 512

In [3]:
vit_backbone, model_transforms, _ = open_clip.create_model_and_transforms(CFG.model_name)

In [None]:
class Head(nn.Module):
    def __init__(self, hidden_size):
        super(Head, self).__init__()

        self.emb = nn.Linear(hidden_size, CFG.emb_size, bias=False)
        self.arc = None
        self.dropout = utilities.Multisample_Dropout()

    def forward(self, x):
        embeddings = self.dropout(x, self.emb)
        
        output = self.arc(embeddings)

        return output, embeddings

In [None]:
class Model(nn.Module):
    def __init__(self, vit_backbone):
        super(Model, self).__init__()

        self.vit_backbone = vit_backbone

        self.head = Head(768)

In [None]:
path_list =  [
              '../models/soup-v1/ViT-L-14-336',
              '../models/soup-v2/ViT-L-14-336',
              '../models/soup-v3/ViT-L-14-336',
              '../models/soup-v4/ViT-L-14-336'
              ]

# Load models weights
weight_list = []

for path in path_list:
    model = Model(vit_backbone)
    model.load_state_dict(torch.load(path), strict=False)
    weight_list.append(model.state_dict())

# Average weights
state_dict = dict((k, torch.stack([v[k] for v in weight_list]).mean(0)) for k in weight_list[0])
model.load_state_dict(state_dict)


In [None]:
model_name = CFG.model_name.replace('/','-')
torch.save(model.state_dict(), f'../models/{model_name}-soup')

In [6]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# load model and tokenizer
model_id = "distilbert-base-uncased-finetuned-sst-2-english"
model = AutoModelForSequenceClassification.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
dummy_model_input = tokenizer("This is a sample", return_tensors="pt")

# export
torch.onnx.export(
    model, 
    tuple(dummy_model_input.values()),
    f="torch-model.onnx",  
    input_names=['input_ids', 'attention_mask'], 
    output_names=['logits'], 
    dynamic_axes={'input_ids': {0: 'batch_size', 1: 'sequence'}, 
                  'attention_mask': {0: 'batch_size', 1: 'sequence'}, 
                  'logits': {0: 'batch_size', 1: 'sequence'}}, 
    do_constant_folding=True, 
    opset_version=13, 
)

In [7]:
import onnx

# Load the ONNX model
model = onnx.load("torch-model.onnx")

# Check that the model is well formed
onnx.checker.check_model(model)

# Print a human readable representation of the graph
print(onnx.helper.printable_graph(model.graph))

graph torch_jit (
  %input_ids[INT64, batch_sizexsequence]
  %attention_mask[INT64, batch_sizexsequence]
) initializers (
  %distilbert.embeddings.word_embeddings.weight[FLOAT, 30522x768]
  %distilbert.embeddings.position_embeddings.weight[FLOAT, 512x768]
  %distilbert.embeddings.LayerNorm.weight[FLOAT, 768]
  %distilbert.embeddings.LayerNorm.bias[FLOAT, 768]
  %distilbert.transformer.layer.0.attention.q_lin.bias[FLOAT, 768]
  %distilbert.transformer.layer.0.attention.k_lin.bias[FLOAT, 768]
  %distilbert.transformer.layer.0.attention.v_lin.bias[FLOAT, 768]
  %distilbert.transformer.layer.0.attention.out_lin.bias[FLOAT, 768]
  %distilbert.transformer.layer.0.sa_layer_norm.weight[FLOAT, 768]
  %distilbert.transformer.layer.0.sa_layer_norm.bias[FLOAT, 768]
  %distilbert.transformer.layer.0.ffn.lin1.bias[FLOAT, 3072]
  %distilbert.transformer.layer.0.ffn.lin2.bias[FLOAT, 768]
  %distilbert.transformer.layer.0.output_layer_norm.weight[FLOAT, 768]
  %distilbert.transformer.layer.0.output_lay

In [10]:
import onnxruntime as ort

ort_session = ort.InferenceSession("torch-model.onnx")

outputs = ort_session.run(
    None,
    tuple(dummy_model_input.values()),
)
print(outputs[0])

TypeError: run(): incompatible function arguments. The following argument types are supported:
    1. (self: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg0: List[str], arg1: Dict[str, object], arg2: onnxruntime.capi.onnxruntime_pybind11_state.RunOptions) -> List[object]

Invoked with: <onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession object at 0x7fa60c5badf0>, ['logits'], (tensor([[ 101, 2023, 2003, 1037, 7099,  102]]), tensor([[1, 1, 1, 1, 1, 1]])), None

In [2]:
dummy_model_input

{'input_ids': tensor([[ 101, 2023, 2003, 1037, 7099,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}

In [3]:
tuple(dummy_model_input.values())

(tensor([[ 101, 2023, 2003, 1037, 7099,  102]]), tensor([[1, 1, 1, 1, 1, 1]]))