In [86]:
# OCR Model from Scratch - Using SROIE Dataset

# 1. Imports and configuration
!pip install py7zr
import os
import torch
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from datasets import Dataset as HFDataset
import json
from tqdm import tqdm
from huggingface_hub import hf_hub_download
import py7zr



In [87]:
# Shell escape: ask the NVIDIA driver tool to report the attached GPU(s), if the tool is installed.
!nvidia-smi

/bin/bash: line 1: nvidia-smi: command not found


In [88]:
# 2. Parameters
BATCH_SIZE = 32          # samples per DataLoader batch
IMAGE_SIZE = (128, 512)  # (height, width) handed to T.Resize; (64, 256) is a smaller alternative
MAX_TEXT_LENGTH = 64     # labels are cut to this many characters before encoding
EPOCHS = 10              # number of passes over the training loader

# Use the GPU whenever PyTorch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [89]:
# 3. Download and unpack the data
from google.colab import files

# NOTE(review): the original comment named "sroie_data.zip", but the extraction cell
# opens "sroie_data.7z" — presumably the .7z archive is the file to pick here; confirm.
uploaded = files.upload()  # choose the dataset archive from local disk


KeyboardInterrupt: 

In [None]:


# Unpack the 7z archive into the local "sroie_data" directory; the context
# manager closes the archive once extraction has finished.
with py7zr.SevenZipFile("sroie_data.7z", mode="r") as sevenzip_file:
    sevenzip_file.extractall(path="sroie_data")


In [None]:
# 4. Loading the metadata

def load_split(split):
    """Read one dataset split from its ``metadata.jsonl`` file.

    Args:
        split: Name of the sub-folder below ``sroie_data/sroie_data``
            (the notebook passes ``"train"`` and ``"test"``).

    Returns:
        An ``HFDataset`` built from the parsed JSON records. Every record gains
        an ``"image"`` entry: the split folder joined with its ``"file_name"``.
    """
    split_dir = f"sroie_data/sroie_data/{split}"

    def _parse_record(raw_line):
        # One JSON object per line; attach the full image path right away.
        record = json.loads(raw_line)
        record["image"] = os.path.join(split_dir, record["file_name"])
        return record

    with open(os.path.join(split_dir, "metadata.jsonl"), "r", encoding="utf-8") as meta_file:
        records = [_parse_record(raw_line) for raw_line in meta_file]
    return HFDataset.from_list(records)

# Parse the metadata of both splits up front; later cells wrap them in datasets.
raw_train = load_split(split="train")
raw_test = load_split(split="test")

In [None]:
# 5. Dataset and data preprocessing
class OCRDataset(Dataset):
    """Map-style dataset yielding ``(image, text)`` pairs for OCR training.

    Each record of the wrapped split must provide an ``"image"`` path and a
    ``"text"`` label. Images are opened in greyscale (mode ``"L"``) and the
    label is lower-cased.
    """

    def __init__(self, hf_dataset_split, transform=None):
        """Store the indexable split and the optional image transform."""
        self.transform = transform
        self.dataset = hf_dataset_split

    def __len__(self):
        """Return the number of records in the wrapped split."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the (optionally transformed) greyscale image and its lower-cased text."""
        record = self.dataset[idx]
        greyscale = Image.open(record["image"]).convert("L")
        # A falsy transform means the PIL image is returned untouched.
        image = self.transform(greyscale) if self.transform else greyscale
        return image, record['text'].lower()

# Shared preprocessing: resize to IMAGE_SIZE, convert to a tensor, then
# normalise with mean 0.5 and std 0.5.
transform = T.Compose([
    T.Resize(size=IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize(mean=(0.5,), std=(0.5,)),
])

# Wrap both splits so they yield (image, lower-cased text) pairs.
train_ds = OCRDataset(hf_dataset_split=raw_train, transform=transform)
test_ds = OCRDataset(hf_dataset_split=raw_test, transform=transform)

# Only the training loader shuffles; the test loader keeps the stored order.
train_loader = DataLoader(dataset=train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_ds, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# 6. Tokenization
import string

# Character vocabulary. Indices start at 1 so that 0 stays free for the CTC
# blank (the loss is created with blank=0); '<eos>' takes the next free index.
VOCAB = list(string.ascii_lowercase + string.digits + " .,;:/-'")
VOCAB_DICT = {c: i+1 for i, c in enumerate(VOCAB)}
VOCAB_DICT['<eos>'] = len(VOCAB_DICT) + 1
INV_VOCAB_DICT = {v: k for k, v in VOCAB_DICT.items()}


def text_to_seq(text, max_length=None):
    """Encode ``text`` as a 1-D ``torch.long`` tensor of vocabulary indices.

    Characters that are not in ``VOCAB_DICT`` are dropped. They used to be
    encoded as 0, but 0 is the CTC blank and must never occur inside a target
    sequence. The result is always ``torch.long``, including for an empty
    label, so ``torch.cat`` over a batch of targets keeps an integer dtype.

    Args:
        text: Label string (the dataset already lower-cases it).
        max_length: Maximum number of indices to keep; ``None`` (default)
            uses the notebook-level ``MAX_TEXT_LENGTH``.

    Returns:
        Tensor of shape ``(k,)`` with ``0 <= k <= max_length`` and every
        value ``>= 1``.
    """
    if max_length is None:
        max_length = MAX_TEXT_LENGTH
    # Filter first, then truncate, so the limit counts encodable characters only.
    seq = [VOCAB_DICT[c] for c in text if c in VOCAB_DICT]
    #seq.append(VOCAB_DICT['<eos>'])
    #seq += [0] * (MAX_TEXT_LENGTH + 1 - len(seq))
    return torch.tensor(seq[:max_length], dtype=torch.long)


def seq_to_text(seq):
    """Decode a sequence of scalar index tensors back into text.

    Indices ``<= 0`` are skipped, indices missing from ``INV_VOCAB_DICT``
    contribute nothing, and any ``'<eos>'`` marker is removed from the result.
    """
    chars = [INV_VOCAB_DICT.get(i.item(), '') for i in seq if i.item() > 0]
    return ''.join(chars).replace('<eos>', '')



In [None]:
# 7. Model (CNN + RNN + CTC)
class CRNN(nn.Module):
    """Convolutional-recurrent network producing per-column class scores for CTC.

    Two conv + ReLU + 2x2 max-pool stages shrink height and width by a factor
    of 4 each; every remaining image column is flattened into one feature
    vector, passed through a bidirectional LSTM and projected to
    ``vocab_size + 1`` scores (the extra class is the CTC blank).

    Args:
        vocab_size: Number of non-blank classes.
        img_height: Height of the input images in pixels. ``None`` (default)
            falls back to the notebook-level ``IMAGE_SIZE[0]``, which is what
            the original implementation always used.
    """

    def __init__(self, vocab_size, img_height=None):
        super(CRNN, self).__init__()
        if img_height is None:
            img_height = IMAGE_SIZE[0]
        # Attribute names and layer order are kept so saved state_dicts still load.
        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, 1, 1), nn.ReLU(), nn.MaxPool2d(2, 2)
        )
        # After two 2x2 poolings a column holds 128 channels x (img_height // 4) rows.
        self.rnn = nn.LSTM(128 * (img_height // 4), 256, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(512, vocab_size + 1)  # +1 for blank in CTC

    def forward(self, x):
        """Map images ``(batch, 1, H, W)`` to scores ``(batch, W // 4, vocab_size + 1)``.

        The returned values are raw scores; callers apply (log-)softmax themselves.
        """
        x = self.cnn(x)
        b, c, h, w = x.size()
        # Put width first so each image column becomes one step of the sequence.
        x = x.permute(0, 3, 1, 2)
        x = x.reshape(b, w, -1)
        x, _ = self.rnn(x)
        x = self.fc(x)
        return x

# Build the network on the selected device, plus its loss and optimiser.
# Class 0 is the CTC blank; zero_infinity=True zeroes infinite losses and their gradients.
model = CRNN(vocab_size=len(VOCAB_DICT)).to(DEVICE)
criterion = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)


In [None]:
# 8. Training
for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}")
    for imgs, texts in progress:
        imgs = imgs.to(DEVICE)

        # CTC takes the labels concatenated into one vector plus one length per sample.
        targets_raw = list(map(text_to_seq, texts))
        targets = torch.cat(targets_raw).to(DEVICE)
        target_lengths = torch.tensor(list(map(len, targets_raw)), dtype=torch.long).to(DEVICE)

        # Model output is (batch, time, classes); the loss wants (time, batch, classes) log-probs.
        preds = model(imgs).log_softmax(2).permute(1, 0, 2)

        # Every sample uses all time steps produced by the network.
        time_steps = preds.size(0)
        input_lengths = torch.full((imgs.size(0),), time_steps, dtype=torch.long).to(DEVICE)

        loss = criterion(preds, targets, input_lengths, target_lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Report the mean batch loss of the epoch.
    print(f"Epoch {epoch+1} Loss: {epoch_loss / len(train_loader):.4f}")


In [None]:

# 9. Inference and testing
def ctc_decode(seq):
    """Greedy CTC decoding of per-timestep class indices.

    Consecutive repeats are merged and index 0 (the blank) is discarded. The
    remaining indices are looked up in ``INV_VOCAB_DICT`` (unknown indices
    contribute nothing) and joined into one string.

    Args:
        seq: Iterable of scalar tensors, e.g. a 1-D tensor of argmax indices.

    Returns:
        The decoded text.
    """
    ids = [step.item() for step in seq]
    # Pair every index with its predecessor; -1 stands in before the first step.
    predecessors = [-1] + ids[:-1]
    kept = [cur for prev, cur in zip(predecessors, ids) if cur != prev and cur != 0]
    return ''.join(INV_VOCAB_DICT.get(token, '') for token in kept)

# Print predictions next to the ground truth for up to five samples of the
# first test batch only.
model.eval()
with torch.no_grad():
    first_batch = next(iter(test_loader), None)
    if first_batch is not None:
        imgs, texts = first_batch
        imgs = imgs.to(DEVICE)
        preds = model(imgs)
        # Most likely class per time step, one row per sample.
        pred_seq = preds.softmax(2).argmax(2)
        for sample_ids, true_text in zip(pred_seq[:5], texts):
            print("Pred:", ctc_decode(sample_ids))
            print("True:", true_text)
            print("---")

In [None]:
# 10. Save the trained model (parameters only, via its state_dict)
trained_weights = model.state_dict()
torch.save(trained_weights, "crnn_sroie.pth")
print("Model zapisany jako crnn_sroie.pth")

In [186]:
# 11. Load the model from a file
# NOTE(review): the save cell writes "crnn_sroie.pth" while this cell reads
# "modelv1_10_epoch.pth" — presumably a renamed or separately uploaded checkpoint;
# confirm the file is present before running.
model = CRNN(len(VOCAB_DICT)).to(DEVICE)
saved_weights = torch.load("modelv1_10_epoch.pth", map_location=DEVICE)
model.load_state_dict(saved_weights)
# Switch to evaluation mode; as the last expression this also displays the model layout.
model.eval()

CRNN(
  (cnn): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (rnn): LSTM(4096, 256, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=512, out_features=46, bias=True)
)

In [197]:
# Upload an image to run recognition on
from google.colab import files
uploaded = files.upload()  # choose an image file

Saving Zrzut ekranu 2025-06-11 101312.png to Zrzut ekranu 2025-06-11 101312 (2).png


In [198]:
from PIL import Image
import torchvision.transforms as T

# Same preprocessing as used for training.
transform = T.Compose([
    T.Resize(size=IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize(mean=(0.5,), std=(0.5,)),
])

# Path of the uploaded file, e.g. "my_image.png" (assumes exactly one file was uploaded).
uploaded_names = list(uploaded.keys())
image_path = uploaded_names[0]

# Load as greyscale, preprocess, and add a leading batch dimension of size 1.
image = Image.open(image_path).convert("L")
image = transform(image)
image = image.unsqueeze(0).to(DEVICE)

In [162]:
def ctc_decode(seq):
    """Greedy CTC decoding of per-timestep class indices.

    Consecutive repeats are merged and index 0 (the blank) is discarded. The
    remaining indices are looked up in ``INV_VOCAB_DICT`` (unknown indices
    contribute nothing) and joined into one string.

    NOTE(review): this repeats the ``ctc_decode`` defined in the inference
    cell earlier in the notebook — presumably so the load-and-predict cells
    can run without it; consider keeping a single definition.

    Args:
        seq: Iterable of scalar tensors, e.g. a 1-D tensor of argmax indices.

    Returns:
        The decoded text.
    """
    ids = [step.item() for step in seq]
    # Pair every index with its predecessor; -1 stands in before the first step.
    predecessors = [-1] + ids[:-1]
    kept = [cur for prev, cur in zip(predecessors, ids) if cur != prev and cur != 0]
    return ''.join(INV_VOCAB_DICT.get(token, '') for token in kept)

In [199]:
# Run the network on the single prepared image and decode the result greedily.
model.eval()
with torch.no_grad():
    preds = model(image)
    # Most likely class per time step; [0] selects the only sample in the batch.
    class_ids = preds.softmax(dim=2).argmax(dim=2)
    pred_seq = class_ids[0]
result = ctc_decode(pred_seq)
print("Rozpoznany tekst:", result)


Rozpoznany tekst: parhgon flskuny
