In [1]:
# Put the parent of the current working directory on the import path (at index 1) so that
# `modules.models` below can be imported from this notebook's location.
import sys
sys.path.insert(1, '../')

import torch
import warnings 
# NOTE(review): star import. The cells below use DispNetB, Encoder, CorrNet and Decoder from it;
# consider importing those names explicitly.
from modules.models import *

# Suppresses ALL warnings for the rest of the kernel session.
# NOTE(review): this may also hide warnings emitted by torch.onnx.export (e.g. tracer warnings
# about values being baked into the exported graph) — consider narrowing the filter.
warnings.filterwarnings('ignore') 

# Convert DispNetB to ONNX

In [2]:
# Load the trained DispNetB weights onto the CPU.
model = DispNetB()
model.load_state_dict(torch.load('../weights/disp.pth', map_location=torch.device('cpu')))
# Switch to inference mode *before* the sanity-check forward pass below. In training mode,
# torch.no_grad() does not stop BatchNorm layers (if the model has any) from updating their
# running statistics with the random inputs, and dropout (if any) would be active.
model.eval()

# Dummy left/right inputs: used for the shape check and as the tracing example for export.
input_1, input_2 = torch.rand((1, 3, 137, 137)), torch.rand((1, 3, 137, 137))
with torch.no_grad():
    pred_1, pred_2 = model(input_1, input_2)
    print(pred_1.shape, pred_2.shape)

# Export with a dynamic batch dimension (axis 0) on every input and output.
torch.onnx.export(model, (input_1, input_2), "../weights/disp.onnx",
                  input_names=['left', 'right'], output_names=['disp_left', 'disp_right'],
                  dynamic_axes={'left': {0: 'batch'}, 'right': {0: 'batch'},
                                'disp_left': {0: 'batch'}, 'disp_right': {0: 'batch'}},
                  opset_version=17)

torch.Size([1, 1, 137, 137]) torch.Size([1, 1, 137, 137])


# Convert Encoder to ONNX

In [3]:
# Load the trained Encoder weights onto the CPU.
model = Encoder()
model.load_state_dict(torch.load('../weights/enc.pth', map_location=torch.device('cpu')))
# Switch to inference mode *before* the sanity-check forward pass below. In training mode,
# torch.no_grad() does not stop BatchNorm layers (if the model has any) from updating their
# running statistics with the random input, and dropout (if any) would be active.
model.eval()

# Dummy 4-channel input (exported under the name 'rgbd'); also the tracing example for export.
input_1 = torch.rand((1, 4, 137, 137))
with torch.no_grad():
    pred_1, pred_2 = model(input_1)
    print(pred_1.shape, pred_2.shape)

# Export with a dynamic batch dimension (axis 0) on the input and both outputs.
# NOTE(review): this export uses opset 11 while every other model in this notebook uses
# opset 17 — confirm whether that is intentional (e.g. a downstream runtime constraint).
torch.onnx.export(model, input_1, "../weights/enc.onnx",
                  input_names=['rgbd'], output_names=['emb', 'cor'],
                  dynamic_axes={'rgbd': {0: 'batch'}, 'emb': {0: 'batch'}, 'cor': {0: 'batch'}},
                  opset_version=11)

torch.Size([1, 8192]) torch.Size([1, 256, 34, 34])


# Convert CorrNet to ONNX

In [4]:
# Load the trained CorrNet weights onto the CPU (the constructor takes the target device).
model = CorrNet(torch.device('cpu'))
model.load_state_dict(torch.load('../weights/cor.pth', map_location=torch.device('cpu')))
# Switch to inference mode *before* the sanity-check forward pass below. In training mode,
# torch.no_grad() does not stop BatchNorm layers (if the model has any) from updating their
# running statistics with the random inputs, and dropout (if any) would be active.
model.eval()

# Dummy inputs shaped like the Encoder's 'cor' output, (1, 256, 34, 34);
# they are also the tracing example for export.
input_1, input_2 = torch.rand((1, 256, 34, 34)), torch.rand((1, 256, 34, 34))
with torch.no_grad():
    pred_1 = model(input_1, input_2)
    print(pred_1.shape)

# Export with a dynamic batch dimension (axis 0) on both inputs and the output.
torch.onnx.export(model, (input_1, input_2), "../weights/cor.onnx",
                  input_names=['left', 'right'], output_names=['output'],
                  dynamic_axes={'left': {0: 'batch'}, 'right': {0: 'batch'}, 'output': {0: 'batch'}},
                  opset_version=17)

torch.Size([1, 4096])


# Convert Decoder to ONNX

In [5]:
# Load the trained Decoder weights onto the CPU.
model = Decoder()
model.load_state_dict(torch.load('../weights/dec.pth', map_location=torch.device('cpu')))
# Switch to inference mode *before* the sanity-check forward pass below. In training mode,
# torch.no_grad() does not stop BatchNorm layers (if the model has any) from updating their
# running statistics with the random inputs, and dropout (if any) would be active.
model.eval()

# Dummy 5-D inputs; also the tracing example for export.
# NOTE(review): 128*4*4*4 = 8192 matches the Encoder 'emb' width and 64*4*4*4 = 4096 matches
# the CorrNet output width — presumably those vectors are reshaped to these shapes before
# decoding. TODO confirm.
input_1, input_2, input_3 = torch.rand((1, 128, 4, 4, 4)), torch.rand((1, 128, 4, 4, 4)), torch.rand((1, 64, 4, 4, 4))
with torch.no_grad():
    pred_1 = model(input_1, input_2, input_3)
    print(pred_1.shape)

# Export with a dynamic batch dimension (axis 0) on every input and the output.
torch.onnx.export(model, (input_1, input_2, input_3), "../weights/dec.onnx",
                  input_names=['left', 'right', 'cor'], output_names=['output'],
                  dynamic_axes={'left': {0: 'batch'}, 'right': {0: 'batch'},
                                'cor': {0: 'batch'}, 'output': {0: 'batch'}},
                  opset_version=17)

torch.Size([1, 32, 32, 32])
