# Importing The Trained Model

In [5]:
import torch
import pickle

from Modules.Encoder import CNNEncoder
from Modules.Decoder import RNNDecoder
from Modules.Sequence import Seq2Seq

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load tokenizer
with open('tokenizer.pkl', 'rb') as f:
    tokenizer = pickle.load(f)

# Model setup (must match training setup)
encoder = CNNEncoder(output_dim=256).to(device)
decoder = RNNDecoder(hidden_dim=256, vocab_size=tokenizer.vocab_size()).to(device)

# Load checkpoint
checkpoint = torch.load('model_checkpoint.pth', map_location=device)

sos_token_id = checkpoint['sos_token_id']
eos_token_id = checkpoint['eos_token_id']

model = Seq2Seq(encoder, decoder, sos_token_id, eos_token_id, device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()


[nltk_data] Downloading package punkt to C:\Users\Sumit Washimkar
[nltk_data]     SRW\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Seq2Seq(
  (encoder): CNNEncoder(
    (feature_extractor): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(

# Demonstraction of Image and Output

In [6]:
from PIL import Image
import torch
from torchvision import transforms

def predict_single_image(model, tokenizer, image_path, device, transform, max_len=50):
    model.eval()

    # Load and preprocess image
    image = Image.open(image_path).convert("RGB")
    image = transform(image).unsqueeze(0).to(device)  # [1, C, H, W]

    with torch.no_grad():
        encoder_out = model.encoder(image)

        # Get token IDs
        sos_token_id = tokenizer.token_to_id.get("<SOS>", 1)
        eos_token_id = tokenizer.token_to_id.get("<EOS>", 2)

        # Init decoder
        inputs = torch.tensor([sos_token_id]).to(device)

        # Init hidden from encoder output
        encoder_mean = encoder_out.mean(dim=1)  # [1, H]
        h_0 = encoder_mean.unsqueeze(0)         # [1, 1, H]
        c_0 = torch.zeros_like(h_0)             # [1, 1, H]
        hidden = (h_0, c_0)

        decoded_tokens = []

        for _ in range(max_len):
            output, hidden, _ = model.decoder(inputs, hidden, encoder_out)
            top1 = output.argmax(1)

            if top1.item() == eos_token_id:
                break

            decoded_tokens.append(top1.item())
            inputs = top1

        return tokenizer.decode(decoded_tokens) if decoded_tokens else ""


In [8]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),               # Slightly larger for random crop
    transforms.RandomHorizontalFlip(p=0.5),      # Flip with 50% chance
    transforms.RandomRotation(15),                # Smaller rotation range (more realistic)
    transforms.ToTensor(),
    transforms.Normalize(                         # Normalize with ImageNet stats (if using ResNet pretrained)
        mean=[0.485, 0.456, 0.406], 
        std=[0.229, 0.224, 0.225]
    ),
])

In [12]:
image_path1 = "18_em_6.bmp"  # Path to your image
predicted_latex = predict_single_image(model, tokenizer, image_path1, device, transform)
print("Predicted LaTeX:", predicted_latex)

Predicted LaTeX: ( f ( a ) ) = = a a b b


In [10]:
image_path2 = "Demo1.png"  # Path to your image
predicted_latex = predict_single_image(model, tokenizer, image_path2, device, transform)
print("Predicted LaTeX:", predicted_latex)

Predicted LaTeX: [ [ n n ] ] ] ] ] ] ] ] ]


In [11]:
image_path3 = "Demo2.png"  # Path to your image
predicted_latex = predict_single_image(model, tokenizer, image_path3, device, transform)
print("Predicted LaTeX:", predicted_latex)

Predicted LaTeX: [ [ [ [ m m m m m m m
