In [None]:
import yaml

# Load the run configuration from `hyperparameters.yaml`.
# The file is opened in a `with` block so the handle is closed as soon as
# parsing finishes instead of being left to the garbage collector.
# NOTE(review): yaml.load with FullLoader is kept as-is; if this file could
# ever come from an untrusted source, yaml.safe_load is the stricter choice.
with open('hyperparameters.yaml') as config_file:
    hyperparameters = yaml.load(config_file, Loader=yaml.FullLoader)
hyperparameters

In [None]:
from model import Image2LatexModel

# Restore the trained model from its checkpoint file.
# map_location='cpu' pins the restored weights to the CPU: the rest of this
# notebook builds plain CPU tensors and never moves anything to another
# device, so a checkpoint saved from a GPU run must not come back on CUDA.
# NOTE(review): assumes load_from_checkpoint is the Lightning classmethod
# (which accepts map_location) - confirm against model.py.
model = Image2LatexModel.load_from_checkpoint(
    'logs/image2latex/version_0/checkpoints/epoch=3-step=10000.ckpt',
    map_location='cpu',
)
# Switch to evaluation mode; the returned module is displayed as cell output.
model.eval()

In [None]:
from torchvision.transforms.v2 import Resize, Compose

# Preprocessing pipeline: a single resize step whose target size is read
# from the `model.image_size` entry of the loaded hyperparameters.
transform = Compose([Resize(size=hyperparameters['model']['image_size'])])

In [None]:
from torchvision.io import read_image

# Read the input picture, convert it to float and divide by 255, resize it
# with `transform`, then insert a leading dimension of size 1
# (presumably the batch dimension expected by the model - confirm).
input_image = 'original.png'
tensor_image = read_image(input_image).float() / 255.0
transformed_image = transform(tensor_image).unsqueeze(0)
transformed_image.shape

In [None]:
import torch

# Pass the prepared image through the model's image-embedding layer and then
# its encoder, with gradient tracking disabled.
with torch.no_grad():
    encoder_output = model.encoder(model.image_embeddings(transformed_image))

# Swap the first two dimensions of the encoder result
# (presumably batch-first -> sequence-first for the decoder - confirm).
encoder_output = encoder_output.transpose(0, 1)
encoder_output.shape

In [None]:
# Tensor of shape (1, max_length) with every entry set to 92
# (presumably the start-of-sequence token id - confirm against the vocabulary).
decoded_tokens = torch.full(
    size=(1, hyperparameters['model']['max_length']),
    fill_value=92,
)
decoded_tokens

In [None]:
# Seed for generation: a 1x1 tensor holding the single value 92
# (presumably the start-of-sequence token id - confirm against the vocabulary).
output_sequence = torch.full(size=(1, 1), fill_value=92)
output_sequence

In [None]:
from tqdm import tqdm

# Greedy token-by-token decoding: at every step the whole sequence generated
# so far is fed to the decoder and the highest-scoring token is appended.

# Keep only the first token so that re-running this cell regenerates the
# sequence from scratch instead of appending to the result of a previous run
# (which would also let the sequence grow past max_length).
output_sequence = output_sequence[:, :1]

# Gradient tracking is disabled once for the whole loop rather than per step.
with torch.no_grad():
    for i in tqdm(range(hyperparameters['model']['max_length']-1), desc='Generating LaTeX tokens'):
        # Embed the tokens produced so far and swap the first two dimensions,
        # mirroring the transpose applied to encoder_output.
        target_embeddings = model.target_embeddings(output_sequence).transpose(0, 1)
        target_mask = model.generate_square_subsequent_mask(output_sequence.size(1))
        decoder_output = model.decoder(target_embeddings, encoder_output, tgt_mask=target_mask, tgt_is_causal=True)

        # Only the last entry along the first dimension is projected to logits.
        logits = model.output_projection(decoder_output[-1])

        # Greedy choice; keepdim=True keeps a trailing dimension so the new
        # token can be concatenated along dim=1.
        next_token = logits.argmax(dim=-1, keepdim=True)
        output_sequence = torch.cat([output_sequence, next_token], dim=1)

        # Stop early once the predicted token is 93
        # (presumably the end-of-sequence token id - confirm against the vocabulary).
        if torch.all(next_token == 93):
            break