In [10]:
import torch
import torch.nn as nn
import torchvision.models as models
import torch.onnx
import onnx

In [2]:
class ResNet18(nn.Module):
    def __init__(self, embedding_size=512, dropout=0.3):
        super(ResNet18, self).__init__()
        base = models.resnet18(weights=None)
        base.fc = nn.Identity()
        self.backbone = base
        self.dropout = nn.Dropout(p=dropout)
        self.embedding = nn.Linear(512, embedding_size)

    def forward(self, x):
        features = self.backbone(x)
        embeddings = self.dropout(features)
        embeddings = self.embedding(embeddings)
        embeddings = nn.functional.normalize(embeddings, p=2, dim=1)
        return embeddings

In [3]:
model = ResNet18(embedding_size=512, dropout=0.3)

In [4]:
model.load_state_dict(torch.load('./best_student.pth', map_location=torch.device('cpu')))

<All keys matched successfully>

In [5]:
model.eval()

ResNet18(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_r

In [6]:
dummy_input = torch.randn(1, 3, 112, 112)

In [11]:
torch.onnx.export(model, dummy_input, 'student_resnet18.onnx', export_params=True, opset_version=11,
                  do_constant_folding=True, input_names=['input'], output_names=['embedding'],
                  dynamic_axes={
                      'input': {0: 'batch_size'},
                      'embedding': {0: 'batch_size'}
                  }, verbose=False)

print("✅ Converting successed: student_resnet18.onnx")

  torch.onnx.export(model, dummy_input, 'student_resnet18.onnx', export_params=True, opset_version=11,


✅ Converting successed: student_resnet18.onnx
