# Convert PyTorch model to TorchScript for Android

In [2]:
import torch
from alphabet_classifier.helper_functions import MNISTClassifier
from torch.utils.mobile_optimizer import optimize_for_mobile

# Load the fine-tuned classifier weights onto the CPU and switch the model to
# evaluation mode before exporting it.
device = torch.device("cpu")
model = MNISTClassifier().to(device)
# NOTE(review): torch.load unpickles the checkpoint file; only load checkpoints
# that come from a trusted source.
state_dict = torch.load("alphabet_classifier/models/finetune_model_100.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()

# Compile the model to TorchScript. Note that this uses torch.jit.script
# (scripting), not torch.jit.trace, despite the variable name below.
traced_script_module = torch.jit.script(model)

# This will result in minor differences in the model compared to the original PyTorch model
traced_script_module = optimize_for_mobile(traced_script_module)

# Write the optimized module with _save_for_lite_interpreter and report where
# it was saved.
output_path = "alphabet_classifier/models/wordle_alphabet_classifier_android.ptl"
traced_script_module._save_for_lite_interpreter(output_path)
print(f"Saved TorchScript model to {output_path}")

Saved TorchScript model to models/wordle_alphabet_classifier_android.ptl


# Compare predictions from the original and TorchScript models; they should be equal up to small floating-point differences

In [11]:
from image_processing import get_wordle_grid_boxes, crop_cell_margin, detect_letter
from alphabet_classifier.helper_functions import MNISTClassifier, wordle_cell_preprocessing, transform_handwritten_alphabet_dataset
import cv2
import torch
from PIL import Image
import os
# NOTE(review): presumably meant to switch off the XNNPACK backend so the
# TorchScript results are comparable with the eager model. TODO confirm that
# this variable is honoured when it is set after torch has been imported.
os.environ["XNNPACK_DISABLE"] = "1"

device = "cpu"

# The OpenCV preview below opens a GUI window and blocks until a key is
# pressed. Set this to False for an unattended "Restart & Run All" or on a
# headless machine.
SHOW_CELL_PREVIEW = True

# Path to your mobile-optimized TorchScript
ptl_path = "alphabet_classifier/models/wordle_alphabet_classifier_android.ptl"
jitted_model = torch.jit.load(ptl_path, map_location=device)
jitted_model.eval()

# Define testing image filename
image_path = 'images/row 1.PNG'
image = cv2.imread(image_path)
# cv2.imread does not raise on a missing/unreadable file, it returns None;
# fail here with a clear message instead of somewhere inside the grid detector.
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")

rows = get_wordle_grid_boxes(image)
# One logits tensor per cell in which a letter was detected, in row-major
# order. Read by the comparison cells further down.
jitted_logits = []

if rows:
    print(f"Extracted {len(rows)} rows from the Wordle board")

    # Crop the cells and put them into a nested list
    wordle_board_imgs = []
    for row in rows:
        wordle_board_row = []
        for cell in row:
            # Elements 1..4 of each cell entry are used as the box x, y, w, h.
            x, y, w, h = cell[1:5]
            cropped_cell = image[y:y+h, x:x+w]
            cropped_cell = crop_cell_margin(cropped_cell)
            wordle_board_row.append(cropped_cell)
        wordle_board_imgs.append(wordle_board_row)

    # Build the preprocessing transform once instead of once per cell.
    to_model_input = transform_handwritten_alphabet_dataset()

    # Run CNN model on each cell
    for i, row in enumerate(wordle_board_imgs):
        for j, cell in enumerate(row):
            if not detect_letter(cell.copy()):
                print(f"Row {i+1}, Column {j+1}: No letter detected")
                continue

            # Convert image to binary thresholded image
            thresh = wordle_cell_preprocessing(cell.copy())

            # Check if background is white or grey
            # CNN is trained on black background white letters
            # So need to invert the image if background is white or grey
            mean_pixel_value = cv2.mean(thresh)[0]
            if mean_pixel_value > 127:
                thresh = cv2.bitwise_not(thresh)

            # Show (blocks until a key is pressed in the preview window)
            if SHOW_CELL_PREVIEW:
                cv2.imshow("Croped Image", thresh)
                cv2.waitKey(0)
                cv2.destroyAllWindows()

            alphabet = Image.fromarray(thresh)

            # Convert to tensor, apply transformations and add batch dimension
            alphabet_tensor = to_model_input(alphabet).unsqueeze(0).to(device)

            # Inference
            with torch.no_grad():
                cell_logits = jitted_model(alphabet_tensor)
                jitted_logits.append(cell_logits)

                # Class index 0 maps to 'A', 1 to 'B', ...
                pred = cell_logits.argmax(dim=1).item()
                letter = chr(ord('A') + pred)

                # Confidence = highest softmax probability
                confidence = torch.softmax(cell_logits, dim=1).max().item()
            print(f"Row {i+1}, Column {j+1}: {letter}, Confidence: {confidence:.2f}")
else:
    # `rows` is falsy here; it may be an empty list or None, and len(None)
    # would raise a TypeError that hides the real failure.
    found_rows = len(rows) if rows is not None else 0
    print(f"Failed to extract wordle board, found {found_rows} rows instead of 6")

Extracted 6 rows from the Wordle board
Row 1, Column 1: H, Confidence: 1.00
Row 1, Column 2: V, Confidence: 1.00
Row 1, Column 3: I, Confidence: 0.62
Row 1, Column 4: D, Confidence: 1.00
Row 1, Column 5: A, Confidence: 1.00
Row 2, Column 1: No letter detected
Row 2, Column 2: No letter detected
Row 2, Column 3: No letter detected
Row 2, Column 4: No letter detected
Row 2, Column 5: No letter detected
Row 3, Column 1: No letter detected
Row 3, Column 2: No letter detected
Row 3, Column 3: No letter detected
Row 3, Column 4: No letter detected
Row 3, Column 5: No letter detected
Row 4, Column 1: No letter detected
Row 4, Column 2: No letter detected
Row 4, Column 3: No letter detected
Row 4, Column 4: No letter detected
Row 4, Column 5: No letter detected
Row 5, Column 1: No letter detected
Row 5, Column 2: No letter detected
Row 5, Column 3: No letter detected
Row 5, Column 4: No letter detected
Row 5, Column 5: No letter detected
Row 6, Column 1: No letter detected
Row 6, Column 2: No 

In [12]:
from image_processing import get_wordle_grid_boxes, crop_cell_margin, detect_letter
from alphabet_classifier.helper_functions import MNISTClassifier, wordle_cell_preprocessing, transform_handwritten_alphabet_dataset
import cv2
import torch
from PIL import Image

# load pytorch model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MNISTClassifier().to(device)
# NOTE(review): torch.load unpickles the checkpoint file; only load checkpoints
# that come from a trusted source.
state_dict = torch.load("alphabet_classifier/models/finetune_model_100.pth", map_location=device)
model.load_state_dict(state_dict)
model.eval()

# The OpenCV preview below opens a GUI window and blocks until a key is
# pressed. Set this to False for an unattended "Restart & Run All" or on a
# headless machine.
SHOW_CELL_PREVIEW = True

image_path = 'images/row 1.PNG'
image = cv2.imread(image_path)
# cv2.imread does not raise on a missing/unreadable file, it returns None;
# fail here with a clear message instead of somewhere inside the grid detector.
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")

rows = get_wordle_grid_boxes(image)
# One logits tensor per cell in which a letter was detected, in row-major
# order. The tensors stay on `device` (possibly CUDA); the comparison cells
# further down read this list.
logits = []

if rows:
    print(f"Extracted {len(rows)} rows from the Wordle board")

    # Crop the cells and put them into a nested list
    wordle_board_imgs = []
    for row in rows:
        wordle_board_row = []
        for cell in row:
            # Elements 1..4 of each cell entry are used as the box x, y, w, h.
            x, y, w, h = cell[1:5]
            cropped_cell = image[y:y+h, x:x+w]
            cropped_cell = crop_cell_margin(cropped_cell)
            wordle_board_row.append(cropped_cell)
        wordle_board_imgs.append(wordle_board_row)

    # Build the preprocessing transform once instead of once per cell.
    to_model_input = transform_handwritten_alphabet_dataset()

    # Run CNN model on each cell
    for i, row in enumerate(wordle_board_imgs):
        for j, cell in enumerate(row):
            if not detect_letter(cell.copy()):
                print(f"Row {i+1}, Column {j+1}: No letter detected")
                continue

            # Pytorch model inference
            # Convert image to binary thresholded image
            thresh = wordle_cell_preprocessing(cell.copy())

            # Check if background is white or grey
            # CNN is trained on black background white letters
            # So need to invert the image if background is white or grey
            mean_pixel_value = cv2.mean(thresh)[0]
            if mean_pixel_value > 127:
                thresh = cv2.bitwise_not(thresh)

            # Show (blocks until a key is pressed in the preview window)
            if SHOW_CELL_PREVIEW:
                cv2.imshow("Croped Image", thresh)
                cv2.waitKey(0)
                cv2.destroyAllWindows()

            alphabet = Image.fromarray(thresh)

            # Convert to tensor, apply transformations and add batch dimension
            alphabet_tensor = to_model_input(alphabet).unsqueeze(0).to(device)

            # Inference
            with torch.no_grad():
                cell_logits = model(alphabet_tensor)
                logits.append(cell_logits)

                # Class index 0 maps to 'A', 1 to 'B', ...
                pred = cell_logits.argmax(dim=1).item()
                letter = chr(ord('A') + pred)

                # Confidence = highest softmax probability
                confidence = torch.softmax(cell_logits, dim=1).max().item()
            print(f"Row {i+1}, Column {j+1}: {letter}, Confidence: {confidence:.2f}")
else:
    # `rows` is falsy here; it may be an empty list or None, and len(None)
    # would raise a TypeError that hides the real failure.
    found_rows = len(rows) if rows is not None else 0
    print(f"Failed to extract wordle board, found {found_rows} rows instead of 6")

Extracted 6 rows from the Wordle board
Row 1, Column 1: H, Confidence: 1.00
Row 1, Column 2: V, Confidence: 1.00
Row 1, Column 3: I, Confidence: 0.62
Row 1, Column 4: D, Confidence: 1.00
Row 1, Column 5: A, Confidence: 1.00
Row 2, Column 1: No letter detected
Row 2, Column 2: No letter detected
Row 2, Column 3: No letter detected
Row 2, Column 4: No letter detected
Row 2, Column 5: No letter detected
Row 3, Column 1: No letter detected
Row 3, Column 2: No letter detected
Row 3, Column 3: No letter detected
Row 3, Column 4: No letter detected
Row 3, Column 5: No letter detected
Row 4, Column 1: No letter detected
Row 4, Column 2: No letter detected
Row 4, Column 3: No letter detected
Row 4, Column 4: No letter detected
Row 4, Column 5: No letter detected
Row 5, Column 1: No letter detected
Row 5, Column 2: No letter detected
Row 5, Column 3: No letter detected
Row 5, Column 4: No letter detected
Row 5, Column 5: No letter detected
Row 6, Column 1: No letter detected
Row 6, Column 2: No 

In [13]:
# Absolute match
# Both inference cells must have produced one logits tensor per detected
# letter; a different count means they were not run on the same board, and an
# element-wise comparison would be meaningless.
if len(logits) != len(jitted_logits):
    raise ValueError(
        f"logit count mismatch: {len(logits)} eager vs {len(jitted_logits)} TorchScript"
    )

# Move every tensor to the CPU before comparing anything, so the lists are
# fully normalised even when the comparison below stops at the first mismatch
# (the eager model may have run on CUDA).
for i in range(len(logits)):
    logits[i] = logits[i].cpu()
    jitted_logits[i] = jitted_logits[i].cpu()

# A bit-exact mismatch is the expected outcome of this cell, so the failure is
# caught and displayed rather than left to abort a "Restart & Run All".
try:
    for eager_out, jitted_out in zip(logits, jitted_logits):
        if not torch.equal(eager_out, jitted_out):
            diff = (eager_out - jitted_out).abs().max().item()
            raise AssertionError(f"logits differ! max absolute difference = {diff}")
        print("Exact match!")
except AssertionError as exc:
    print(f"AssertionError: {exc}")

# optimize_for_mobile will fail this check

AssertionError: logits differ! max absolute difference = 1.52587890625e-05

In [14]:
# tiny floating-point discrepancies
# Require the same number of outputs from both models so that no cell is
# silently left out of the comparison.
if len(logits) != len(jitted_logits):
    raise ValueError(
        f"logit count mismatch: {len(logits)} eager vs {len(jitted_logits)} TorchScript"
    )

for eager_out, jitted_out in zip(logits, jitted_logits):
    # Compare on the CPU: the eager model may have produced CUDA tensors, and
    # this cell must not depend on another cell having moved them already.
    eager_cpu = eager_out.cpu()
    jitted_cpu = jitted_out.cpu()
    if not torch.allclose(eager_cpu, jitted_cpu, rtol=1e-5, atol=1e-6):
        diff = (eager_cpu - jitted_cpu).abs().max().item()
        raise AssertionError(f"logits not allclose! max absolute difference = {diff}")
    print("Allclose match!")

Allclose match!
Allclose match!
Allclose match!
Allclose match!
Allclose match!
