In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

In [None]:
class OCRModel(nn.Module):
    """CRNN-style text recognizer: CNN features -> bidirectional LSTM -> linear classifier.

    The CNN downsamples the image by 4 in both height and width. The height axis
    of the resulting feature map is then averaged away, so each feature-map column
    becomes one time step of 128 features that the LSTM reads left to right.

    Args:
        vocab_size: Number of output classes per time step. When paired with
            ``nn.CTCLoss(blank=0)`` this count must include the blank class.
        img_channels: Channels of the input image (1 for grayscale).
        hidden_size: LSTM hidden size per direction.
        num_lstm_layers: Number of stacked LSTM layers.

    Shape:
        Input: ``(B, img_channels, H, W)`` with ``H >= 4`` and ``W >= 4``.
        Output: ``(B, W // 4, vocab_size)`` raw logits (no softmax applied).
            ``nn.CTCLoss`` expects log-probabilities laid out as ``(T, B, C)``,
            i.e. ``output.log_softmax(2).permute(1, 0, 2)``.
    """

    def __init__(self, vocab_size, img_channels=1, hidden_size=256, num_lstm_layers=2):
        super().__init__()
        cnn_out_channels = 128

        # CNN feature extractor: two conv/ReLU/max-pool stages, each halving H and W.
        self.cnn = nn.Sequential(
            nn.Conv2d(img_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, cnn_out_channels, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        # Sequence model over the width axis. Its input size equals the CNN channel
        # count because forward() collapses the height axis before the LSTM.
        self.lstm = nn.LSTM(
            input_size=cnn_out_channels,
            hidden_size=hidden_size,
            num_layers=num_lstm_layers,
            batch_first=True,
            bidirectional=True,
        )

        # Per-time-step classifier; *2 because both LSTM directions are concatenated.
        self.fc = nn.Linear(hidden_size * 2, vocab_size)

    def forward(self, x):
        """Map images ``(B, C, H, W)`` to per-time-step logits ``(B, W // 4, vocab_size)``."""
        features = self.cnn(x)  # (B, 128, H // 4, W // 4)

        # Average over height so every time step has exactly 128 features,
        # matching the LSTM's input_size for any image height. Flattening C*H
        # would only match when H // 4 == 1 (a 32-pixel-tall image would give
        # 128 * 8 = 1024 features, which the LSTM rejects).
        features = features.mean(dim=2)  # (B, 128, W // 4)

        # Width becomes the time axis: (B, T, features) as batch_first=True expects.
        features = features.permute(0, 2, 1).contiguous()  # (B, W // 4, 128)

        lstm_out, _ = self.lstm(features)  # (B, W // 4, 2 * hidden_size)
        return self.fc(lstm_out)  # (B, W // 4, vocab_size)



In [3]:
# Connectionist Temporal Classification loss. Class index 0 is reserved for the
# CTC blank; reduction="mean" (the PyTorch default, spelled out here) divides each
# sample's loss by its target length and then averages over the batch.
# NOTE: CTCLoss takes log-probabilities shaped (T, B, C) plus input/target lengths.
criterion = nn.CTCLoss(reduction="mean", blank=0)

In [1]:
class OCRDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs for OCR training.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of label sequences (e.g. lists of class indices),
            aligned index-by-index with ``image_paths``.
        transform: Optional callable applied to the grayscale ``PIL.Image``
            (e.g. a ``torchvision.transforms.Compose``).

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Fail fast on misaligned inputs instead of an IndexError mid-epoch.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length "
                f"(got {len(image_paths)} and {len(labels)})"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is read from disk, converted to single-channel grayscale
        ("L" mode), then passed through ``transform`` if one was given. The
        label is returned exactly as stored.
        """
        # `with` releases the file handle even if decoding fails; `convert`
        # returns a new, fully loaded image that stays valid after close.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('L')
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[idx]

# Preprocessing: resize every image to a fixed 32x128 (H x W) so samples can be
# stacked into one batch tensor, convert to a [0, 1] tensor, then map to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((32, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])


def ctc_collate(batch):
    """Collate ``(image, label)`` pairs into the layout ``nn.CTCLoss`` expects.

    The default collate function cannot batch variable-length label lists
    (and transposes equal-length ones), so labels are instead concatenated into
    one 1-D target tensor with a companion lengths tensor.

    Args:
        batch: List of ``(image_tensor, label_sequence)`` pairs from the dataset.

    Returns:
        images: Tensor stacked from the per-sample image tensors, ``(B, C, H, W)``.
        targets: 1-D ``torch.long`` tensor of all labels concatenated in batch order.
        target_lengths: ``(B,)`` ``torch.long`` tensor with each sample's label length.
    """
    images, labels = zip(*batch)
    targets = [int(token) for label in labels for token in label]
    return (
        torch.stack(images),
        torch.tensor(targets, dtype=torch.long),
        torch.tensor([len(label) for label in labels], dtype=torch.long),
    )


# Example dataset -- the paths below are placeholders; replace them with real files.
dataset = OCRDataset(image_paths=['path/to/image1.png', 'path/to/image2.png'],
                     labels=[[1, 2, 3], [4, 5, 6]],
                     transform=transform)

# Each batch is (images, targets, target_lengths); see `ctc_collate`.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=ctc_collate)

NameError: name 'Dataset' is not defined

In [None]:
# Number of output classes per time step. With nn.CTCLoss(blank=0), index 0 is
# the blank, so this presumably means 50 symbols + blank -- adjust to your charset.
vocab_size = 51
model = OCRModel(vocab_size=vocab_size)
model  # last expression: Jupyter renders the module summary as the cell output