In [None]:
import torch
from PIL import Image
from torchvision.transforms.v2 import Compose, ToTensor, CenterCrop, ToDtype, ToImage

from model import EquationRecognitionModel
from tokenizer import LaTeXTokenizer

# Path to the trained checkpoint; change this to evaluate a different run/epoch.
CHECKPOINT_PATH = 'logs/equation_cnn/version_2/checkpoints/epoch=10-step=2673.ckpt'

# Restore the trained model and switch it to evaluation mode so layers such as
# dropout / batch-norm (if present) behave deterministically at inference time.
model = EquationRecognitionModel.load_from_checkpoint(CHECKPOINT_PATH)
model.eval()

In [None]:
# Inference-time preprocessing:
#   * CenterCrop((100, 300)) crops (or pads) to height=100, width=300.
#   * ToImage converts the PIL image to a uint8 image tensor.
#   * ToDtype(float32, scale=True) converts that tensor to float32 in [0, 1].
# ToDtype only acts on tensors, so it must come after ToImage; placed before a
# PIL->tensor conversion it is silently skipped.
# NOTE(review): confirm this matches the preprocessing used during training.
image_transform = Compose([
    CenterCrop((100, 300)),
    ToImage(),
    ToDtype(torch.float32, scale=True),
])

In [None]:
# Vocabulary file mapping LaTeX tokens to ids; must be the vocab the model was trained with.
VOCAB_PATH = 'data/vocab.json'

tokenizer = LaTeXTokenizer()
tokenizer.load_vocab(VOCAB_PATH)

In [None]:
@torch.no_grad()
def predict_equation(image_path, model, tokenizer, max_length=100, transform=None):
    """Greedily decode a LaTeX equation from an image file.

    Starting from ``<SOS>``, the model is repeatedly called on the image and the
    tokens produced so far; the argmax at the last position is appended until
    ``<EOS>`` is predicted or ``max_length`` steps have run.

    Args:
        image_path: Path to the input image file.
        model: Model called as ``model(images, tokens, lengths)``. The argmax over
            dim 2 of its output is used as the prediction, so it presumably
            returns ``(batch, seq, vocab)`` logits -- TODO confirm. The caller is
            expected to have put it in eval mode.
        tokenizer: Object exposing ``vocab`` (token -> id) with ``'<SOS>'`` and
            ``'<EOS>'`` entries; it is called on the list of predicted ids to
            produce the returned equation.
        max_length: Maximum number of decoding steps (default 100).
        transform: Image preprocessing pipeline; defaults to the notebook-level
            ``image_transform``.

    Returns:
        The result of ``tokenizer(ids)``, where ``ids`` is a list of ints that
        starts with the ``<SOS>`` id and does not include ``<EOS>``.
    """
    if transform is None:
        transform = image_transform

    # Put inputs on the same device as the model's weights (the restored model is
    # not guaranteed to live on the CPU).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Context manager closes the underlying file once the RGB copy is made.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    sos_id = tokenizer.vocab['<SOS>']
    eos_id = tokenizer.vocab['<EOS>']

    # (1, 1) tensor holding the decoded prefix, grown by one token per step.
    tokens = torch.tensor([[sos_id]], dtype=torch.long, device=device)
    # Sequence length is left on the CPU, as before (if the model packs
    # sequences, PyTorch requires CPU lengths).
    length_tensor = torch.tensor([1], dtype=torch.long)

    for _ in range(max_length):
        outputs = model(image_tensor, tokens, length_tensor)
        next_token = outputs.argmax(dim=2)[:, -1:]  # prediction at last position, shape (1, 1)

        if next_token.item() == eos_id:
            break

        tokens = torch.cat((tokens, next_token), dim=1)
        length_tensor += 1

    # Index the batch row instead of squeeze(): squeeze() on a (1, 1) tensor yields
    # a 0-d tensor whose tolist() is a bare int, not a list.
    predicted_sequence = tokens[0].tolist()
    return tokenizer(predicted_sequence)

In [None]:
# Strip surrounding whitespace and quotes, which often come along when a path is
# pasted from a terminal or file manager and would otherwise make the open fail.
image_path = input('Enter Path to Image: ').strip().strip('"\'')
predicted_equation = predict_equation(image_path, model, tokenizer)
print(f'Predicted Equation: {predicted_equation}')