In [1]:
import torch
from src.net.model import ViTSTRTransducer

In [2]:
# Load the FP32 checkpoint for inference on the first GPU, falling back to the CPU.
checkpoint_path = "ViTSTR-FP32-base.ckpt"
DEVICE = 'cpu'
if torch.cuda.is_available():
    DEVICE = 'cuda:0'

model = ViTSTRTransducer.load_from_checkpoint(
    checkpoint_path,
    training=False,
    map_location=DEVICE,
)
model.eval()
model.freeze()  # LightningModule API; presumably disables gradient tracking — inference only

# Per-sample image shape (C, H, W); the batch dimension is added at call time.
input_size = (model.input_channels, *model.input_size)
vocab_size = model.vocab_size

# FP16 TorchScript export

In [3]:
fp16_weights_path = "ViTSTR-FP16-base.torchscript"
# `.half()` casts a module *in place*. Calling it on the shared `model` would
# silently turn it into an FP16 model, and every later conversion (e.g. to BF16)
# would then start from FP16-rounded weights. Export from a fresh FP32 load
# instead, which keeps `model` untouched and makes this cell safe to re-run.
model_fp16 = ViTSTRTransducer.load_from_checkpoint(
    checkpoint_path, training=False, map_location=DEVICE
).eval()
model_fp16.freeze()
model_fp16.half().to_torchscript(fp16_weights_path)

# Reload the saved TorchScript file and check that every parameter is FP16.
model_ts_fp16 = torch.jit.load(fp16_weights_path, map_location=DEVICE)
param_is_fp16 = [param.dtype == torch.half for param in model_ts_fp16.parameters()]
print(f'Total params: {len(param_is_fp16)}\nTotal FP16 params: {sum(param_is_fp16)}')
assert all(param_is_fp16), "reloaded TorchScript module has non-FP16 parameters"

Total params: 175
Total FP16 params: 175


In [4]:
# Smoke-test the FP16 TorchScript module on random inputs.
# Image input shape: [1, C, H, W]; don't forget to normalize real images.
img_input_fp16 = torch.rand(1, *input_size, device=DEVICE).half()
seq_len = 5
target_input = torch.randint(0, vocab_size, (1, seq_len), dtype=torch.int32, device=DEVICE)
output = model_ts_fp16(img_input_fp16, target_input)
print(output.shape, output.dtype)  # Output shape: [1, seq, vocab_size]
# Fail loudly rather than relying on a reader eyeballing the printed values.
assert output.shape == (1, seq_len, vocab_size), output.shape
assert output.dtype == torch.half, output.dtype

torch.Size([1, 5, 48]) torch.float16


# BF16 TorchScript export

In [5]:
bf16_weights_path = "ViTSTR-BF16-base.torchscript"
# `.bfloat16()` casts a module *in place*, and an earlier in-place `.half()` may
# already have turned the shared `model` into FP16. Converting that model would
# give FP32 -> FP16 -> BF16: weights are rounded twice, and FP16's narrow exponent
# range can overflow/underflow values BF16 represents fine. Export from a fresh
# FP32 load so BF16 weights are rounded once, straight from the checkpoint.
model_bf16 = ViTSTRTransducer.load_from_checkpoint(
    checkpoint_path, training=False, map_location=DEVICE
).eval()
model_bf16.freeze()
model_bf16.bfloat16().to_torchscript(bf16_weights_path)

# Reload the saved TorchScript file and check that every parameter is BF16.
model_ts_bf16 = torch.jit.load(bf16_weights_path, map_location=DEVICE)
param_is_bf16 = [param.dtype == torch.bfloat16 for param in model_ts_bf16.parameters()]
print(f'Total params: {len(param_is_bf16)}\nTotal BF16 params: {sum(param_is_bf16)}')
assert all(param_is_bf16), "reloaded TorchScript module has non-BF16 parameters"

Total params: 175
Total BF16 params: 175


In [6]:
# Image input shape: [1, C, H, W], dont forget about normalizing
# Smoke-test the BF16 TorchScript module on random inputs.
# Image input shape: [1, C, H, W]; don't forget to normalize real images.
img_input_bf16 = torch.rand(1, *input_size, device=DEVICE).bfloat16()
seq_len = 5
target_input = torch.randint(0, vocab_size, (1, seq_len), dtype=torch.int32, device=DEVICE)
output = model_ts_bf16(img_input_bf16, target_input)
print(output.shape, output.dtype)  # Output shape: [1, seq, vocab_size]
# Fail loudly rather than relying on a reader eyeballing the printed values.
assert output.shape == (1, seq_len, vocab_size), output.shape
assert output.dtype == torch.bfloat16, output.dtype

torch.Size([1, 5, 48]) torch.bfloat16
