In [1]:
# OCR Model from Scratch - Using SROIE Dataset

# 1. Importy i konfiguracja
!pip install py7zr
import os
import torch
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from datasets import Dataset as HFDataset
import json
from tqdm import tqdm
from huggingface_hub import hf_hub_download
import py7zr

Collecting py7zr
  Downloading py7zr-1.0.0-py3-none-any.whl.metadata (17 kB)
Collecting texttable (from py7zr)
  Downloading texttable-1.7.0-py2.py3-none-any.whl.metadata (9.8 kB)
Collecting brotli>=1.1.0 (from py7zr)
  Downloading Brotli-1.1.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)
Collecting pyzstd>=0.16.1 (from py7zr)
  Downloading pyzstd-0.17.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)
Collecting pyppmd<1.3.0,>=1.1.0 (from py7zr)
  Downloading pyppmd-1.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.4 kB)
Collecting pybcj<1.1.0,>=1.0.0 (from py7zr)
  Downloading pybcj-1.0.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.7 kB)
Collecting multivolumefile>=0.2.3 (from py7zr)
  Downloading multivolumefile-0.2.3-py3-none-any.whl.metadata (6.3 kB)
Collecting inflate64<1.1.0,>=1.0.0 (from py7zr)
  Downloading inflate64-1.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinu

In [2]:
!nvidia-smi

Tue Jun 10 19:34:04 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   34C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [15]:
# 2. Parametry
BATCH_SIZE = 32
IMAGE_SIZE = (128, 512)
#IMAGE_SIZE = (64, 256)
MAX_TEXT_LENGTH = 64
EPOCHS = 10
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
# 3. Pobieranie i rozpakowanie danych
from google.colab import files

uploaded = files.upload()  # wybierz plik sroie_data.zip z dysku


Saving sroie_data.7z to sroie_data.7z


In [5]:


with py7zr.SevenZipFile("sroie_data.7z", mode='r') as archive:
    archive.extractall(path="sroie_data")


In [16]:
# 4. Wczytanie metadanych

def load_split(split):
    data = []
    folder = f"sroie_data/sroie_data/{split}"
    meta_path = os.path.join(folder, "metadata.jsonl")
    with open(meta_path, "r", encoding="utf-8") as f:
        for line in f:
            rec = json.loads(line)
            rec["image"] = os.path.join(folder, rec["file_name"])
            data.append(rec)
    return HFDataset.from_list(data)

raw_train = load_split("train")
raw_test = load_split("test")

In [17]:
# 5. Dataset i przetwarzanie danych
class OCRDataset(Dataset):
    def __init__(self, hf_dataset_split, transform=None):
        self.dataset = hf_dataset_split
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = Image.open(item["image"]).convert("L")
        if self.transform:
            image = self.transform(image)
        text = item['text'].lower()
        return image, text

transform = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize((0.5,), (0.5,))
])

train_ds = OCRDataset(raw_train, transform=transform)
test_ds = OCRDataset(raw_test, transform=transform)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE)

In [18]:
# 6. Tokenizacja
import string
VOCAB = list(string.ascii_lowercase + string.digits + " .,;:/-'")
VOCAB_DICT = {c: i+1 for i, c in enumerate(VOCAB)}
VOCAB_DICT['<eos>'] = len(VOCAB_DICT) + 1
INV_VOCAB_DICT = {v: k for k, v in VOCAB_DICT.items()}

def text_to_seq(text):
    seq = [VOCAB_DICT.get(c, 0) for c in text[:MAX_TEXT_LENGTH]]
    #seq.append(VOCAB_DICT['<eos>'])
    #seq += [0] * (MAX_TEXT_LENGTH + 1 - len(seq))
    return torch.tensor(seq)

def seq_to_text(seq):
    chars = [INV_VOCAB_DICT.get(i.item(), '') for i in seq if i.item() > 0]
    return ''.join(chars).replace('<eos>', '')



In [19]:
# 7. Model (CNN + RNN + CTC)
class CRNN(nn.Module):
    def __init__(self, vocab_size):
        super(CRNN, self).__init__()
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2)
        )
        self.rnn = nn.LSTM(128 * (IMAGE_SIZE[0] // 4), 256, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(512, vocab_size + 1)  # +1 for blank in CTC

    def forward(self, x):
        x = self.cnn(x)
        b, c, h, w = x.size()
        x = x.permute(0, 3, 1, 2)
        x = x.reshape(b, w, -1)
        x, _ = self.rnn(x)
        x = self.fc(x)
        return x

model = CRNN(len(VOCAB_DICT)).to(DEVICE)
criterion = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = optim.Adam(model.parameters(), lr=0.001)


In [24]:
# 8. Trening
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    for imgs, texts in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        imgs = imgs.to(DEVICE)

        targets_raw = [text_to_seq(t) for t in texts]
        targets = torch.cat(targets_raw).to(DEVICE)
        target_lengths = torch.tensor([len(t) for t in targets_raw], dtype=torch.long).to(DEVICE)

        preds = model(imgs)
        preds = preds.log_softmax(2).permute(1, 0, 2)

        #input_lengths = torch.full((preds.size(1),), preds.size(0), dtype=torch.long).to(DEVICE)
        input_lengths = torch.full((imgs.size(0),), preds.size(0), dtype=torch.long).to(DEVICE)

        loss = criterion(preds, targets, input_lengths, target_lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f"Epoch {epoch+1} Loss: {epoch_loss / len(train_loader):.4f}")


Epoch 1: 100%|██████████| 1051/1051 [03:56<00:00,  4.44it/s]


Epoch 1 Loss: -0.1330


Epoch 2: 100%|██████████| 1051/1051 [03:57<00:00,  4.43it/s]


Epoch 2 Loss: -0.1378


Epoch 3: 100%|██████████| 1051/1051 [03:56<00:00,  4.45it/s]


Epoch 3 Loss: -0.1480


Epoch 4: 100%|██████████| 1051/1051 [03:55<00:00,  4.46it/s]


Epoch 4 Loss: -0.1630


Epoch 5: 100%|██████████| 1051/1051 [03:56<00:00,  4.44it/s]


Epoch 5 Loss: -0.1619


Epoch 6: 100%|██████████| 1051/1051 [03:56<00:00,  4.45it/s]


Epoch 6 Loss: -0.1606


Epoch 7: 100%|██████████| 1051/1051 [03:55<00:00,  4.45it/s]


Epoch 7 Loss: -0.1448


Epoch 8: 100%|██████████| 1051/1051 [03:56<00:00,  4.45it/s]


Epoch 8 Loss: -0.1664


Epoch 9: 100%|██████████| 1051/1051 [03:56<00:00,  4.44it/s]


Epoch 9 Loss: -0.1788


Epoch 10: 100%|██████████| 1051/1051 [03:56<00:00,  4.44it/s]

Epoch 10 Loss: -0.1706





In [25]:

# 9. Inference i testowanie
def ctc_decode(seq):
    # seq: tensor([timestep preds])
    tokens = []
    prev = -1
    for i in seq:
        i = i.item()
        if i != prev and i != 0:
            tokens.append(i)
        prev = i
    return ''.join([INV_VOCAB_DICT.get(t, '') for t in tokens])

model.eval()
with torch.no_grad():
    for imgs, texts in test_loader:
        imgs = imgs.to(DEVICE)
        preds = model(imgs)
        pred_seq = preds.softmax(2).argmax(2)
        for i in range(min(5, len(pred_seq))):
            #print("Pred:", seq_to_text(pred_seq[i]))
            print("Pred:", ctc_decode(pred_seq[i]))
            print("True:", texts[i])
            print("---")
        break

Pred: tan chay yee
True: tan chay yee
---
Pred:  copy 
True: *** copy ***
---
Pred: ojc marketing sdn bhd
True: ojc marketing sdn bhd
---
Pred: roc no: 53835-h
True: roc no: 538358-h
---
Pred: no z 84, jalan bayu 4,
True: no 2 & 4, jalan bayu 4,
---


In [26]:
# 10. Zapis wytrenowanego modelu
torch.save(model.state_dict(), "crnn_sroie.pth")
print("Model zapisany jako crnn_sroie.pth")

Model zapisany jako crnn_sroie.pth


In [None]:
# 11. Wczytanie modelu z pliku
model = CRNN(len(VOCAB_DICT)).to(DEVICE)
model.load_state_dict(torch.load("crnn_sroie.pth", map_location=DEVICE))
model.eval()