# Tutorials
https://github.com/NielsRogge/Transformers-Tutorials/blob/master/VisionTransformer/Quick_demo_of_HuggingFace_version_of_Vision_Transformer_inference.ipynb
https://huggingface.co/docs/transformers/en/model_doc/trocr

In [1]:
import numpy as numpy
import pandas as pd
import matplotlib as plt
import torch
import transformers
from transformers import ViTForImageClassification

  from .autonotebook import tqdm as notebook_tqdm


# Import and the model and processor


In [None]:
from transformers import ViTImageProcessor, BertTokenizer, VisionEncoderDecoderModel
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

processor =  ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained("google/vit-base-patch16-224-in21k", "google-bert/bert-base-uncased")


model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.48.3"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 768,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder

VisionEncoderDecoderModel(
  (encoder): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=False)
              (key): Linear(in_features=768, out_features=768, bias=False)
              (value): Linear(in_features=768, out_features=768, bias=False)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linea

In [None]:
# Assemble dataset
import os
from torch.utils.data import DataLoader, Dataset
from PIL import Image

#dataset = load_dataset("dataset/IM2LATEX-100K-HANDWRITTEN/images")

class IM2LATEXDataset(Dataset):
    def __init__(self, images_folder, list_file, formulas_file, transform=None):
        self.images_folder = images_folder
        self.transform = transform
        with open(list_file, 'r') as f:
            self.image_formula_pairs = [line.strip().split() for line in f]
        with open(formulas_file, 'r') as f:
            self.formulas = f.read().splitlines()

    def __len__(self):
        return len(self.image_formula_pairs)

    def __getitem__(self, idx):
        img_name, formula_idx = self.image_formula_pairs[idx]
        img_name = os.path.join(self.images_folder, img_name)
        image = Image.open(img_name)
        formula = self.formulas[int(formula_idx)]
        if self.transform:
            image = self.transform(image)
        return image, formula

def load_im2latex_dataset(images_folder, train_list, test_list, val_list, formulas_file, batch_size=32, transform=None):
    train_dataset = IM2LATEXDataset(images_folder, train_list, formulas_file, transform)
    test_dataset = IM2LATEXDataset(images_folder, test_list, formulas_file, transform)
    val_dataset = IM2LATEXDataset(images_folder, val_list, formulas_file, transform)
    
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    
    return train_loader, test_loader, val_loader

# Example usage
images_folder = "dataset/IM2LATEX-100K-HANDWRITTEN/images"
train_list = "dataset/IM2LATEX-100K-HANDWRITTEN/train.lst"
test_list = "dataset/IM2LATEX-100K-HANDWRITTEN/test.lst"
val_list = "dataset/IM2LATEX-100K-HANDWRITTEN/val.lst"
formulas_file = "dataset/IM2LATEX-100K-HANDWRITTEN/formulas.lst"
batch_size = 32

transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

train_loader, test_loader, val_loader = load_im2latex_dataset(images_folder, train_list, test_list, val_list, formulas_file, batch_size, transform)


Downloading data: 100%|██████████| 99552/99552 [00:01<00:00, 96902.28files/s] 
Generating train split: 99552 examples [00:01, 89025.39 examples/s]


In [None]:
import matplotlib.pyplot as plt

dataset

# Iterate through data
for images, formulas in train_loader:
    # Your training code here
    pass

#image = dataset["train"]['image'][0]

# plt.imshow(image)
# plt.show()


DatasetDict({
    train: Dataset({
        features: ['image'],
        num_rows: 99552
    })
})