# Convert PyTorch model to TorchScript for Android

In [2]:
import torch
from helper_functions_pt import MNISTClassifier
from torch.utils.mobile_optimizer import optimize_for_mobile

device = torch.device("cpu")
model = MNISTClassifier().to(device)
state_dict = torch.load("models/pt_cnn/ft_model_epoch12.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()

# Trace the model and save as TorchScript
traced_script_module = torch.jit.script(model)

# This will result in minor differences in the model compared to the original PyTorch model
traced_script_module = optimize_for_mobile(traced_script_module)

output_path = "models/pt_cnn/sudoku_digit_classifier_android.ptl"
traced_script_module._save_for_lite_interpreter(output_path)
print(f"Saved TorchScript model to {output_path}")

Saved TorchScript model to models/pt_cnn/sudoku_digit_classifier_android.ptl


# Compare predictions from the original and TorchScript models, should be equal

In [3]:
from digits_classifier import sudoku_cells_reduce_noise
import cv2
from helper_functions_pt import MNISTClassifier, get_mnist_transform
import torch
from PIL import Image
import os
os.environ["XNNPACK_DISABLE"] = "1"

device = "cpu"

# Path to your mobile‐optimized TorchScript
ptl_path = "models/pt_cnn/sudoku_digit_classifier_android.ptl"
jitted_model = torch.jit.load(ptl_path, map_location=device)
jitted_model.eval()

# Define testing image filename
test_img = "test/1/1.jpg.png"

# Load testing image
digit = cv2.imread(test_img, cv2.IMREAD_GRAYSCALE)
digit_inv = cv2.adaptiveThreshold(digit, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 27, 11)
denoised_digit = sudoku_cells_reduce_noise(digit_inv)
if denoised_digit is not None:
    digit = Image.fromarray(denoised_digit)
    # Reshape to fit model input, [1,28,28]
    digit_tensor = get_mnist_transform()(digit)
    # Add batch dim, send to device
    digit_tensor = digit_tensor.unsqueeze(0).to(device)

    jitted_logits = jitted_model(digit_tensor)

    arg_max = torch.argmax(jitted_logits, dim=1)+1
    print(f"Predicted digit: {arg_max.item()}")


Predicted digit: 1


In [4]:
# Path
path = "models/pt_cnn/ft_model_epoch10.pth"
device = "cpu"
model = MNISTClassifier().to(device)
state_dict = torch.load(path, map_location=device)
model.load_state_dict(state_dict)
model.eval()

# Define testing image filename
test_img = "test/1/1.jpg.png"

# Load testing image
digit = cv2.imread(test_img, cv2.IMREAD_GRAYSCALE)
digit_inv = cv2.adaptiveThreshold(digit, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 27, 11)
denoised_digit = sudoku_cells_reduce_noise(digit_inv)
if denoised_digit is not None:
    digit = Image.fromarray(denoised_digit)
    # Reshape to fit model input, [1,28,28]
    digit_tensor = get_mnist_transform()(digit)
    # Add batch dim, send to device
    digit_tensor = digit_tensor.unsqueeze(0).to(device)

    # Make prediction
    with torch.no_grad():
        logits = model(digit_tensor)
        prediction = torch.argmax(logits, dim=1).item()+1
        print(f"Predicted digit: {prediction}")


Predicted digit: 1


In [5]:
# Absolute match
if not torch.equal(logits, jitted_logits):
    diff = (logits - jitted_logits).abs().max().item()
    raise AssertionError(f"logits differ! max absolute difference = {diff}")
print("Exact match!")

# optimize_for_mobile will fail this check

AssertionError: logits differ! max absolute difference = 7.62939453125e-06

In [6]:
# tiny floating‐point discrepancies
if not torch.allclose(logits, jitted_logits, rtol=1e-5, atol=1e-6):
    diff = (logits - jitted_logits).abs().max().item()
    raise AssertionError(f"logits not allclose! max absolute difference = {diff}")
print("Allclose match!")

Allclose match!
