In [2]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from tqdm import tqdm

# Define data transforms
data_transform = transforms.Compose([
    transforms.Resize((256, 256)),  # Resize images to 256x256
    transforms.ToTensor(),           # Convert images to PyTorch tensors
])

# Path to the root directory of your dataset
data_dir = 'data'

# Create a dataset instance using ImageFolder
dataset = datasets.ImageFolder(root=data_dir, transform=data_transform)

# Create a DataLoader to efficiently load and batch data
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Initialize the OCR processor and model
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-small-stage1')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-small-stage1')

# Define optimizer and loss function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    total_loss = 0.0
    with tqdm(total=len(data_loader), desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch") as pbar:
        for images, labels in data_loader:
            # Forward pass
            pixel_values = processor(images, return_tensors="pt").pixel_values
            decoder_input_ids = torch.tensor([[model.config.decoder.decoder_start_token_id]] * pixel_values.shape[0])
            outputs = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids)
            
            # Compute loss
            loss = criterion(outputs.logits.view(-1, outputs.logits.shape[-1]), labels.view(-1))
            
            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            pbar.update(1)
            pbar.set_postfix({'Loss': total_loss / (pbar.n * data_loader.batch_size)})
    
    # Print average loss for the epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(data_loader)}")


Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-small-stage1 and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1/10: 100%|██████████| 313/313 [3:16:59<00:00, 37.76s/batch, Loss=0.104]   


Epoch 1/10, Loss: 3.323179184819182


Epoch 2/10: 100%|██████████| 313/313 [2:18:01<00:00, 26.46s/batch, Loss=0.0967]  


Epoch 2/10, Loss: 3.093468418517433


Epoch 3/10: 100%|██████████| 313/313 [2:32:47<00:00, 29.29s/batch, Loss=0.0959]  


Epoch 3/10, Loss: 3.06980657958375


Epoch 4/10: 100%|██████████| 313/313 [2:28:53<00:00, 28.54s/batch, Loss=0.0955]  


Epoch 4/10, Loss: 3.055168318672302


Epoch 5/10: 100%|██████████| 313/313 [2:27:15<00:00, 28.23s/batch, Loss=0.0952]  


Epoch 5/10, Loss: 3.0450121449967162


Epoch 6/10: 100%|██████████| 313/313 [2:27:07<00:00, 28.20s/batch, Loss=0.0949]  


Epoch 6/10, Loss: 3.0364519894694366


Epoch 7/10: 100%|██████████| 313/313 [2:26:13<00:00, 28.03s/batch, Loss=0.0947]  


Epoch 7/10, Loss: 3.029373561993194


Epoch 8/10: 100%|██████████| 313/313 [2:28:17<00:00, 28.43s/batch, Loss=0.0946]  


Epoch 8/10, Loss: 3.0264371241243504


Epoch 9/10:  94%|█████████▍| 294/313 [2:30:15<14:38, 46.25s/batch, Loss=0.0946]  