In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from torchvision import transforms
from PIL import Image
import torch.optim as optim
import string
import numpy as np
import pandas as pd
import json

In [2]:
class OCRModel(nn.Module):
    def __init__(self, vocab_size, img_channels=1, hidden_size=256, num_lstm_layers=2):
        super(OCRModel, self).__init__()
        
        # CNN Feature Extractor
        self.cnn = nn.Sequential(
            nn.Conv2d(img_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        
        # LSTM Sequence Model
        self.lstm = nn.LSTM(input_size=3200, hidden_size=hidden_size, num_layers=num_lstm_layers, batch_first=True, bidirectional=True)
        
        # Final fully connected layer
        self.fc = nn.Linear(hidden_size * 2, vocab_size)  # *2 because bidirectional
    
    def forward(self, x):
        batch_size = x.size(0)
        
        # CNN forward
        features = self.cnn(x)  # Shape: (B, C, H, W)
        
        # Reshape for LSTM (treat width as time-steps)
        b, c, h, w = features.size()
        features = features.permute(0, 3, 1, 2).contiguous().view(b, w, -1)  # Shape: (B, W, C*H)
        
        # LSTM forward
        lstm_out, _ = self.lstm(features)  # Shape: (B, W, Hidden*2)
        
        # Fully connected
        output = self.fc(lstm_out)  # Shape: (B, W, vocab_size)
        
        return output



In [None]:


# class OCRDataset(Dataset):
#     def __init__(self, image_paths, labels, transform=None):
#         self.image_paths = image_paths  # List of image file paths
#         self.labels = labels  # List of label sequences
#         self.transform = transform

#     def __len__(self):
#         return len(self.image_paths)

#     def __getitem__(self, idx):
#         # Load image
#         image = Image.open(self.image_paths[idx]).convert('L')  # Convert to grayscale
#         if self.transform:
#             image = self.transform(image)
        
#         # Get label sequence
#         label = self.labels[idx]

#         return image, label

# # Updated Transformations with Fixed Size (500, 100)
# transform = transforms.Compose([
#     transforms.Resize((100, 500)),  # Resize to 500x100 (H x W)
#     transforms.ToTensor(),
#     transforms.Normalize((0.5,), (0.5,))  # Normalize (mean=0.5, std=0.5)
# ])


In [3]:
df = pd.read_csv("../dataset/mini_qa_images/mini_qa.csv")

In [None]:
# import string
# import numpy as np



# # Step 1: Extract unique characters from your labels
# unique_chars = set(''.join(labels))  # Join all the labels and get unique characters

# # Step 2: Create char-to-id and id-to-char mappings
# char_to_id = {char: idx + 1 for idx, char in enumerate(sorted(unique_chars))}  # Start ids from 1 to avoid 0 for padding
# id_to_char = {idx: char for char, idx in char_to_id.items()}

# # Step 3: Convert labels from text to IDs
# def text_to_ids(text, char_to_id):
#     return [char_to_id.get(char, 0) for char in text]  # Defaulting to 0 for unknown characters

# # Convert all labels to IDs
# label_ids = [text_to_ids(label, char_to_id) for label in labels]



# # Print out the mappings and converted labels
# print("Character to ID mapping:", char_to_id)
# print("ID to Character mapping:", id_to_char)
# print("Converted labels to IDs:", label_ids)

# # Convert back a label's ID list to text
# def ids_to_text(ids, id_to_char):
#     return ''.join([id_to_char.get(id, '?') for id in ids])  # Use '?' for unknown IDs

# # Example: Decode the first label back from IDs to text
# decoded_text = ids_to_text(label_ids[0], id_to_char)
# print("Decoded text from IDs:", decoded_text)
# import json

# # Save char_to_id and id_to_char mappings to a JSON file
# mapping = {"char_to_id": char_to_id, "id_to_char": id_to_char}

# with open("../dataset/mini_qa_images/char_mappings.json", "w", encoding="utf-8") as f:
#     json.dump(mapping, f, ensure_ascii=False, indent=4)

# print("Character mappings saved to 'char_mappings.json'.")


In [4]:
with open("../dataset/mini_qa_images/char_mappings.json", "r", encoding="utf-8") as f:
    loaded_mapping = json.load(f)

# Convert keys back to int for id_to_char (JSON keys are saved as strings)
char_to_id = loaded_mapping["char_to_id"]
id_to_char = {int(k): v for k, v in loaded_mapping["id_to_char"].items()}

print("Character mappings loaded successfully!")
print("Loaded char_to_id:", char_to_id)
print("Loaded id_to_char:", id_to_char)

def text_to_ids(text):
    return [char_to_id.get(char, 0) for char in text]  # Defaulting to 0 for unknown characters
def ids_to_text(ids):
    return ''.join([id_to_char.get(id, '?') for id in ids])  # Use '?' for unknown IDs

Character mappings loaded successfully!
Loaded char_to_id: {' ': 1, '?': 2, '᠂': 3, '᠃': 4, '᠋': 5, '᠌': 6, '᠍': 7, '\u180e': 8, 'ᠠ': 9, 'ᠡ': 10, 'ᠢ': 11, 'ᠣ': 12, 'ᠤ': 13, 'ᠥ': 14, 'ᠦ': 15, 'ᠧ': 16, 'ᠨ': 17, 'ᠩ': 18, 'ᠪ': 19, 'ᠬ': 20, 'ᠭ': 21, 'ᠮ': 22, 'ᠯ': 23, 'ᠰ': 24, 'ᠱ': 25, 'ᠲ': 26, 'ᠳ': 27, 'ᠴ': 28, 'ᠵ': 29, 'ᠶ': 30, 'ᠷ': 31, 'ᠹ': 32, '\u202f': 33, '︖': 34, '？': 35}
Loaded id_to_char: {1: ' ', 2: '?', 3: '᠂', 4: '᠃', 5: '᠋', 6: '᠌', 7: '᠍', 8: '\u180e', 9: 'ᠠ', 10: 'ᠡ', 11: 'ᠢ', 12: 'ᠣ', 13: 'ᠤ', 14: 'ᠥ', 15: 'ᠦ', 16: 'ᠧ', 17: 'ᠨ', 18: 'ᠩ', 19: 'ᠪ', 20: 'ᠬ', 21: 'ᠭ', 22: 'ᠮ', 23: 'ᠯ', 24: 'ᠰ', 25: 'ᠱ', 26: 'ᠲ', 27: 'ᠳ', 28: 'ᠴ', 29: 'ᠵ', 30: 'ᠶ', 31: 'ᠷ', 32: 'ᠹ', 33: '\u202f', 34: '︖', 35: '？'}


In [25]:
image_paths = []
image_labels = []
for index, row in df.iterrows():
    # if(index%20==0):
    #     print(f"{index} out of 100")
    text = row['question']
    path = f"../dataset/mini_qa_images/question/{index}.png"
    image_paths.append(path)
    image_labels.append(text_to_ids(text))

In [None]:
class OCRDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths  # List of image file paths
        self.labels = labels  # List of label sequences
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load image
        image = Image.open(self.image_paths[idx]).convert('L')  # Convert to grayscale
        if self.transform:
            image = self.transform(image)
        
        # Get label sequence
        label = self.labels[idx]

        return image, torch.tensor(label, dtype=torch.long)

# Define transformations (e.g., resize, normalize)
transform = transforms.Compose([
    transforms.Resize((100, 500)),  # Resize to fixed size for the model
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Example normalization
])



In [29]:
def collate_fn(batch):
    images, labels = zip(*batch)

    # Pad labels to the maximum length in the batch
    max_len = max(len(label) for label in labels)
    print(max_len)
    # Pad each label sequence to max_len
    padded_labels = [torch.cat([label, torch.full((max_len - len(label),), -1)]) for label in labels]

    return (torch.stack(images), torch.stack(padded_labels))

# Create dataset
dataset = OCRDataset(
    image_paths=image_paths,
    labels=image_labels,
    transform=transform
)

dataloader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)



In [30]:


# Define device (use GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move model to device
model = OCRModel(vocab_size=len(char_to_id)).to(device)

# Loss function (CTC Loss)
criterion = nn.CTCLoss(blank=0)  # Set blank label as 0 (assumes char_to_id does not use 0 for real characters)

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [32]:
device

device(type='cpu')

In [31]:
num_epochs = 10  # Number of epochs
for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for images, labels in dataloader:
        images = images.to(device)  # Move images to GPU if available
        
        # Convert labels to tensor and move to device
        label_lengths = torch.tensor([len(label) for label in labels], dtype=torch.long).to(device)
        labels = torch.cat([torch.tensor(label, dtype=torch.long) for label in labels]).to(device)

        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(images)  # Shape: (B, W, vocab_size)
        log_probs = outputs.log_softmax(2)  # CTC loss expects log probabilities
        
        # Input length for CTC loss (W from output shape)
        input_lengths = torch.full((log_probs.size(0),), log_probs.size(1), dtype=torch.long).to(device)

        # Compute loss
        loss = criterion(log_probs.permute(1, 0, 2), labels, input_lengths, label_lengths)  # CTC expects (T, N, C)
        loss.backward()
        
        optimizer.step()
        
        total_loss += loss.item()
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / len(dataloader)}")

43


  labels = torch.cat([torch.tensor(label, dtype=torch.long) for label in labels]).to(device)


39
42
31
59
41
79
39
40
34
51
51
24
45
36
102
111
68
51
35
65
37
104
67
39
65
23
74
42
39
42
31
58
55
27
51
35
37
35
64
56
47
33
36
45
42
24
58
78
48
Epoch [1/10], Loss: nan
78
102
59
79
37
104
39
48
51
55
111
51
56
51
46
64
27
39
42
49
42
42
58
42
31
62
29
34
45
24
40
58
42
28
30
51
45
39
64
65
39
33
42
35


KeyboardInterrupt: 