In [24]:
import torch
from PIL import Image
from torchvision.transforms.v2 import CenterCrop, Compose, ToDtype, ToImage, ToTensor

from model import EquationRecognitionModel
from tokenizer import EquationTokenizer

# Restore the trained recognizer from its checkpoint (presumably written by
# PyTorch Lightning, given `load_from_checkpoint` and the logs/version_0 layout)
# and switch it to eval mode. `model.eval()` stays the last expression so the
# cell still displays the module summary.
CHECKPOINT_PATH = 'logs/equation_cnn/version_0/checkpoints/epoch=3-step=2400.ckpt'
model = EquationRecognitionModel.load_from_checkpoint(CHECKPOINT_PATH)
model.eval()

EquationRecognitionModel(
  (cnn): Sequential(
    (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (cnn_to_rnn): Linear(in_features=56832, out_features=256, bias=True)
  (embedding): Embedding(45, 256)
  (lstm): LSTM(256, 512, num_layers=2, batch_first=True)
  (output_layer): Linear(in_features=512, out_features=45, bias=True)
  (loss): CrossEntropyLoss()
  (accuracy): MulticlassAccuracy()
)

In [25]:
# Inference-time preprocessing: PIL image -> float32 tensor in [0, 1].
#
# CenterCrop takes (height, width). A 100x300 crop halved by the three
# MaxPool2d(2) layers becomes 12x37, and 128 * 12 * 37 = 56832, consistent with
# the model's cnn_to_rnn in_features. Changing the crop will therefore break
# the model. Images smaller than the crop are zero-padded by CenterCrop.
#
# ToImage + ToDtype(scale=True) is torchvision v2's documented replacement for
# the deprecated ToTensor. ToDtype must come *after* ToImage, because it leaves
# PIL images untouched.
# NOTE(review): this assumes training used the same pipeline (e.g. no
# Normalize); confirm against the training transforms.
image_transform = Compose([
    CenterCrop((100, 300)),
    ToImage(),
    ToDtype(torch.float32, scale=True),
])



In [26]:
# Build the tokenizer used at inference: its `vocab` supplies the <SOS>/<EOS>
# ids, and calling it on a list of ids decodes the prediction back to text.
# NOTE(review): the vocab must match the one used in training. The model's
# Embedding and output_layer are sized for 45 tokens, so consider asserting
# the vocab size here.
tokenizer = EquationTokenizer()
tokenizer.build_vocab()

In [27]:
@torch.no_grad()
def predict_equation(image_path, model, tokenizer, max_len=100, transform=None):
    """Greedily decode the equation shown in an image.

    Starts from the ``<SOS>`` token and repeatedly feeds the growing token
    sequence back into ``model``. At each step it appends the arg-max token of
    the last time step. Decoding stops when ``<EOS>`` is produced or after
    ``max_len`` steps.

    Args:
        image_path: Path to the image file; it is converted to RGB.
        model: Recognition model, called as ``model(images, tokens, lengths)``
            and returning logits whose dim 2 indexes the vocabulary.
        tokenizer: Object exposing ``vocab`` (token -> id). Calling it with a
            list of ids produces the decoded equation.
        max_len: Maximum number of decoding steps (default 100).
        transform: Image preprocessing; defaults to the module-level
            ``image_transform``.

    Returns:
        The result of ``tokenizer(ids)``. The ids start with ``<SOS>`` and
        exclude ``<EOS>``. It is assumed the tokenizer strips ``<SOS>``.
    """
    if transform is None:
        transform = image_transform

    # Feed inputs on whatever device the model's weights live on. The checkpoint
    # loader may restore the model onto a GPU.
    device = next(model.parameters()).device

    # Context manager so the file handle is released deterministically.
    with Image.open(image_path) as img:
        image_tensor = transform(img.convert('RGB')).unsqueeze(0).to(device)

    sos_id = tokenizer.vocab['<SOS>']
    eos_id = tokenizer.vocab['<EOS>']

    # Token ids of shape (1, seq_len), seeded with <SOS>.
    tokens = torch.tensor([[sos_id]], dtype=torch.long, device=device)
    # Lengths are kept on CPU as before. If the model packs sequences, PyTorch
    # requires lengths on CPU (NOTE(review): packing is assumed, not visible here).
    length_tensor = torch.tensor([1], dtype=torch.long)

    for _ in range(max_len):
        outputs = model(image_tensor, tokens, length_tensor)
        # Only the last time step matters; shape (1, 1).
        next_token = outputs[:, -1, :].argmax(dim=-1, keepdim=True)

        if next_token.item() == eos_id:
            break

        tokens = torch.cat((tokens, next_token), dim=1)
        length_tensor += 1

    # Index the batch dimension rather than squeeze(). With only <SOS> present,
    # squeeze() gives a 0-d tensor and tolist() would return an int, not a list.
    predicted_sequence = tokens[0].tolist()
    return tokenizer(predicted_sequence)

In [28]:
# Decode one sample image with the greedy decoder and show the resulting string.
image_path = 'New Project.png'
predicted_equation = predict_equation(image_path, model=model, tokenizer=tokenizer)
print(f'Predicted Equation: {predicted_equation}')

Predicted Equation: 4x + 12y = 234
