In [62]:
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
from tqdm import tqdm


print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name(0))


Torch version: 2.5.1+cu121
CUDA available: True
Device: NVIDIA GeForce GTX 1650


In [63]:
BASE_DIR = "./data/word_nglegena_20260102_155715"

TRAIN_CSV_PATH = f"{BASE_DIR}/train_aug.csv"
VAL_CSV_PATH   = f"{BASE_DIR}/val_aug.csv"
TEST_CSV_PATH  = f"{BASE_DIR}/test_aug.csv"
IMAGE_DIR      = f"{BASE_DIR}/image_aug/"


NGLEGENA = [
    ("ꦲ","ha"), ("ꦤ","na"), ("ꦕ","ca"), ("ꦫ","ra"), ("ꦏ","ka"),
    ("ꦢ","da"), ("ꦠ","ta"), ("ꦱ","sa"), ("ꦮ","wa"), ("ꦭ","la"),
    ("ꦥ","pa"), ("ꦝ","dha"), ("ꦗ","ja"), ("ꦪ","ya"), ("ꦚ","nya"),
    ("ꦩ","ma"), ("ꦒ","ga"), ("ꦧ","ba"), ("ꦛ","tha"), ("ꦔ","nga"),
]


char_list = [c[0] for c in NGLEGENA]
char2idx = {c: i+1 for i, c in enumerate(char_list)}
idx2char = {i+1: c for i, c in enumerate(char_list)}
NUM_CLASSES = len(char_list) + 1


In [64]:
class JavaneseOCRDataset(Dataset):
    def __init__(self, csv_path, img_dir, img_height=32, img_width=128):
        self.df = pd.read_csv(csv_path)
        self.img_dir = img_dir

        self.transform = T.Compose([
            T.Grayscale(1),
            T.Resize((img_height, img_width)),
            T.ToTensor(),
            T.Normalize(mean=[0.5], std=[0.5])
        ])

    def encode(self, text):
        return torch.tensor(
            [char2idx[c] for c in text if c in char2idx],
            dtype=torch.long
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        img_path = os.path.join(self.img_dir, row["image"])
        image = Image.open(img_path).convert("L")
        image = self.transform(image)

        label = self.encode(row["transcription"])
        label_len = len(label)

        return image, label, label_len


def ctc_collate_fn(batch):
    images, labels, label_lens = zip(*batch)

    images = torch.stack(images)
    labels = torch.cat(labels)
    label_lens = torch.tensor(label_lens, dtype=torch.long)

    return images, labels, label_lens


In [65]:
class CRNN(nn.Module):
    def __init__(self, num_classes):
        super().__init__()

        self.cnn = nn.Sequential(
            nn.Conv2d(1, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, 3, 1, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(),

            nn.Conv2d(256, 256, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d((2, 1)),

            nn.Conv2d(256, 512, 3, 1, 1),
            nn.BatchNorm2d(512),
            nn.ReLU(),

            nn.MaxPool2d((2, 1)),
            nn.AdaptiveAvgPool2d((1, None))
        )

        # FIX: Define multi-layer LSTM here instead of Sequential
        self.rnn = nn.LSTM(
            input_size=512, 
            hidden_size=256, 
            num_layers=2, 
            bidirectional=True, 
            batch_first=True
        )

        # Bidirectional LSTM output size is hidden_size * 2
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        # x shape: [batch, 1, 32, 128]
        x = self.cnn(x) 
        
        # After CNN & Adaptive Pool: [batch, 512, 1, width_reduced]
        x = x.squeeze(2)          # [batch, 512, width_reduced]
        x = x.permute(0, 2, 1)    # [batch, width_reduced, 512]
        
        # FIX: LSTM returns (output, hidden_state), we only need output
        x, _ = self.rnn(x)        # [batch, width_reduced, 512]
        
        x = self.fc(x)            # [batch, width_reduced, num_classes]
        return x

In [66]:
def ctc_decode(logits):
    preds = logits.argmax(2)
    texts = []

    for pred in preds:
        prev = 0
        text = []
        for p in pred:
            p = p.item()
            if p != prev and p != 0:
                text.append(idx2char[p])
            prev = p
        texts.append("".join(text))
    return texts


def decode_targets(labels, label_lens):
    texts = []
    idx = 0
    for l in label_lens:
        seq = labels[idx:idx+l].tolist()
        texts.append("".join(idx2char[i] for i in seq))
        idx += l
    return texts


In [67]:
def levenshtein(a, b):
    n, m = len(a), len(b)
    if n == 0: return m
    if m == 0: return n

    dp = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dp[i][0] = i
    for j in range(m + 1):
        dp[0][j] = j

    for i in range(1, n + 1):
        for j in range(1, m + 1):
            cost = 0 if a[i-1] == b[j-1] else 1
            dp[i][j] = min(
                dp[i-1][j] + 1,
                dp[i][j-1] + 1,
                dp[i-1][j-1] + cost
            )
    return dp[n][m]


def cer(preds, refs):
    dist, total = 0, 0
    for p, r in zip(preds, refs):
        dist += levenshtein(p, r)
        total += len(r)
    return dist / max(total, 1)


def wer(preds, refs):
    dist, total = 0, 0
    for p, r in zip(preds, refs):
        dist += levenshtein(list(p), list(r))
        total += len(r)
    return dist / max(total, 1)


In [68]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

model = CRNN(NUM_CLASSES).to(device)

criterion = nn.CTCLoss(blank=0, zero_infinity=True)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.5, patience=3
)


Device: cuda


In [69]:
def train_one_epoch(model, loader):
    model.train()
    total_loss = 0

    loop = tqdm(loader, total=len(loader), desc="Training", leave=False)

    for images, labels, label_lens in loop:
        images = images.to(device)
        labels = labels.to(device)
        label_lens = label_lens.to(device)

        logits = model(images)
        log_probs = logits.log_softmax(2)

        T = log_probs.size(1)
        input_lens = torch.full(
            (images.size(0),), T, dtype=torch.long, device=device
        )

        assert T >= label_lens.max(), "CTC input length < target length"

        loss = criterion(
            log_probs.permute(1, 0, 2),
            labels,
            input_lens,
            label_lens
        )

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        total_loss += loss.item()

        loop.set_postfix(loss=total_loss / (loop.n + 1))

    return total_loss / len(loader)


In [70]:
def validate(model, loader):
    model.eval()
    total_loss = 0
    all_preds, all_refs = [], []

    with torch.no_grad():
        for images, labels, label_lens in loader:
            images = images.to(device)
            labels = labels.to(device)
            label_lens = label_lens.to(device)

            logits = model(images)
            log_probs = logits.log_softmax(2)

            T = log_probs.size(1)
            input_lens = torch.full(
                (images.size(0),), T, dtype=torch.long, device=device
            )

            loss = criterion(
                log_probs.permute(1, 0, 2),
                labels,
                input_lens,
                label_lens
            )

            preds = ctc_decode(logits)
            refs = decode_targets(labels, label_lens)

            all_preds.extend(preds)
            all_refs.extend(refs)
            total_loss += loss.item()

    val_loss = total_loss / len(loader)
    return val_loss, cer(all_preds, all_refs), wer(all_preds, all_refs)


In [72]:
EPOCHS = 60

patience = 8
trigger_times = 0

best_val_loss = float("inf")
best_cer = float("inf")

train_ds = JavaneseOCRDataset(csv_path=TRAIN_CSV_PATH, img_dir=IMAGE_DIR)
val_ds   = JavaneseOCRDataset(csv_path=VAL_CSV_PATH, img_dir=IMAGE_DIR)

train_loader = DataLoader(
    train_ds,
    batch_size=32,
    shuffle=True,
    collate_fn=ctc_collate_fn,
    pin_memory=True,
    num_workers=0 # Trouble if set 2 or more
)

val_loader = DataLoader(
    val_ds,
    batch_size=32,
    shuffle=False,
    collate_fn=ctc_collate_fn,
    pin_memory=True,
    num_workers=0 # Trouble if set 2 or more
)

for epoch in range(EPOCHS):
    train_loss = train_one_epoch(model, train_loader)
    val_loss, val_cer, val_wer = validate(model, val_loader)

    scheduler.step(val_loss)

    print(
        f"Epoch {epoch+1:02d} | "
        f"Train Loss: {train_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | "
        f"CER: {val_cer:.4f} | "
        f"WER: {val_wer:.4f}"
    )

    if val_cer < best_cer:
        best_cer = val_cer
        torch.save(model.state_dict(), "best_crnn_nglegena.pt")
        print("Saved best model!")

for epoch in range(EPOCHS):
    train_loss = train_one_epoch(model, train_loader)
    val_loss, val_cer, val_wer = validate(model, val_loader)

    scheduler.step(val_loss)

    print(
        f"Epoch {epoch+1:02d} | "
        f"Train Loss: {train_loss:.4f} | "
        f"Val Loss: {val_loss:.4f} | "
        f"CER: {val_cer:.4f} | "
        f"WER: {val_wer:.4f}"
    )

    if val_cer < best_cer:
        best_cer = val_cer
        torch.save(model.state_dict(), "best_crnn_nglegena.pt")
        print("Saved best model (CER improved)!")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        trigger_times = 0
    else:
        trigger_times += 1

        if trigger_times >= patience:
            print("Early stopping triggered!")
            break 


Training:   0%|          | 0/63 [00:00<?, ?it/s]

                                                                       

Epoch 01 | Train Loss: 0.0037 | Val Loss: 0.2340 | CER: 0.0509 | WER: 0.0509
Saved best model!


                                                                       

Epoch 02 | Train Loss: 0.0039 | Val Loss: 0.2343 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 03 | Train Loss: 0.0038 | Val Loss: 0.2384 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 04 | Train Loss: 0.0039 | Val Loss: 0.2388 | CER: 0.0515 | WER: 0.0515


                                                                       

Epoch 05 | Train Loss: 0.0038 | Val Loss: 0.2349 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 06 | Train Loss: 0.0036 | Val Loss: 0.2333 | CER: 0.0503 | WER: 0.0503
Saved best model!


                                                                       

Epoch 07 | Train Loss: 0.0037 | Val Loss: 0.2403 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 08 | Train Loss: 0.0036 | Val Loss: 0.2308 | CER: 0.0491 | WER: 0.0491
Saved best model!


                                                                       

Epoch 09 | Train Loss: 0.0037 | Val Loss: 0.2344 | CER: 0.0515 | WER: 0.0515


                                                                       

Epoch 10 | Train Loss: 0.0036 | Val Loss: 0.2354 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 11 | Train Loss: 0.0037 | Val Loss: 0.2353 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 12 | Train Loss: 0.0036 | Val Loss: 0.2328 | CER: 0.0503 | WER: 0.0503


                                                                       

Epoch 13 | Train Loss: 0.0036 | Val Loss: 0.2369 | CER: 0.0527 | WER: 0.0527


                                                                       

Epoch 14 | Train Loss: 0.0037 | Val Loss: 0.2343 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 15 | Train Loss: 0.0036 | Val Loss: 0.2378 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 16 | Train Loss: 0.0036 | Val Loss: 0.2329 | CER: 0.0503 | WER: 0.0503


                                                                       

Epoch 17 | Train Loss: 0.0036 | Val Loss: 0.2326 | CER: 0.0503 | WER: 0.0503


                                                                       

Epoch 18 | Train Loss: 0.0036 | Val Loss: 0.2331 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 19 | Train Loss: 0.0036 | Val Loss: 0.2299 | CER: 0.0491 | WER: 0.0491


                                                                       

Epoch 20 | Train Loss: 0.0036 | Val Loss: 0.2356 | CER: 0.0515 | WER: 0.0515


                                                                       

Epoch 21 | Train Loss: 0.0036 | Val Loss: 0.2330 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 22 | Train Loss: 0.0036 | Val Loss: 0.2369 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 23 | Train Loss: 0.0036 | Val Loss: 0.2370 | CER: 0.0503 | WER: 0.0503


                                                                       

Epoch 24 | Train Loss: 0.0036 | Val Loss: 0.2374 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 25 | Train Loss: 0.0036 | Val Loss: 0.2370 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 26 | Train Loss: 0.0036 | Val Loss: 0.2331 | CER: 0.0515 | WER: 0.0515


                                                                       

Epoch 27 | Train Loss: 0.0036 | Val Loss: 0.2326 | CER: 0.0527 | WER: 0.0527


                                                                       

Epoch 28 | Train Loss: 0.0036 | Val Loss: 0.2414 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 29 | Train Loss: 0.0036 | Val Loss: 0.2411 | CER: 0.0533 | WER: 0.0533


                                                                       

Epoch 30 | Train Loss: 0.0036 | Val Loss: 0.2329 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 31 | Train Loss: 0.0036 | Val Loss: 0.2323 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 32 | Train Loss: 0.0036 | Val Loss: 0.2358 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 33 | Train Loss: 0.0038 | Val Loss: 0.2328 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 34 | Train Loss: 0.0036 | Val Loss: 0.2377 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 35 | Train Loss: 0.0036 | Val Loss: 0.2352 | CER: 0.0503 | WER: 0.0503


                                                                       

Epoch 36 | Train Loss: 0.0036 | Val Loss: 0.2357 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 37 | Train Loss: 0.0036 | Val Loss: 0.2362 | CER: 0.0515 | WER: 0.0515


                                                                       

Epoch 38 | Train Loss: 0.0036 | Val Loss: 0.2394 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 39 | Train Loss: 0.0036 | Val Loss: 0.2376 | CER: 0.0527 | WER: 0.0527


                                                                       

Epoch 40 | Train Loss: 0.0036 | Val Loss: 0.2356 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 41 | Train Loss: 0.0035 | Val Loss: 0.2356 | CER: 0.0527 | WER: 0.0527


                                                                       

Epoch 42 | Train Loss: 0.0035 | Val Loss: 0.2365 | CER: 0.0527 | WER: 0.0527


                                                                       

Epoch 43 | Train Loss: 0.0036 | Val Loss: 0.2306 | CER: 0.0503 | WER: 0.0503


                                                                       

Epoch 44 | Train Loss: 0.0038 | Val Loss: 0.2362 | CER: 0.0527 | WER: 0.0527


                                                                       

Epoch 45 | Train Loss: 0.0036 | Val Loss: 0.2353 | CER: 0.0515 | WER: 0.0515


                                                                       

Epoch 46 | Train Loss: 0.0035 | Val Loss: 0.2367 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 47 | Train Loss: 0.0036 | Val Loss: 0.2400 | CER: 0.0533 | WER: 0.0533


                                                                       

Epoch 48 | Train Loss: 0.0036 | Val Loss: 0.2354 | CER: 0.0527 | WER: 0.0527


                                                                       

Epoch 49 | Train Loss: 0.0037 | Val Loss: 0.2358 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 50 | Train Loss: 0.0036 | Val Loss: 0.2422 | CER: 0.0515 | WER: 0.0515


                                                                       

Epoch 51 | Train Loss: 0.0036 | Val Loss: 0.2400 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 52 | Train Loss: 0.0036 | Val Loss: 0.2398 | CER: 0.0527 | WER: 0.0527


                                                                       

Epoch 53 | Train Loss: 0.0037 | Val Loss: 0.2339 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 54 | Train Loss: 0.0036 | Val Loss: 0.2330 | CER: 0.0503 | WER: 0.0503


                                                                       

Epoch 55 | Train Loss: 0.0036 | Val Loss: 0.2333 | CER: 0.0515 | WER: 0.0515


                                                                       

Epoch 56 | Train Loss: 0.0036 | Val Loss: 0.2347 | CER: 0.0515 | WER: 0.0515


                                                                       

Epoch 57 | Train Loss: 0.0036 | Val Loss: 0.2441 | CER: 0.0515 | WER: 0.0515


                                                                       

Epoch 58 | Train Loss: 0.0037 | Val Loss: 0.2350 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 59 | Train Loss: 0.0036 | Val Loss: 0.2386 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 60 | Train Loss: 0.0036 | Val Loss: 0.2427 | CER: 0.0527 | WER: 0.0527


                                                                       

Epoch 01 | Train Loss: 0.0036 | Val Loss: 0.2375 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 02 | Train Loss: 0.0036 | Val Loss: 0.2335 | CER: 0.0515 | WER: 0.0515


                                                                       

Epoch 03 | Train Loss: 0.0036 | Val Loss: 0.2379 | CER: 0.0533 | WER: 0.0533


                                                                       

Epoch 04 | Train Loss: 0.0036 | Val Loss: 0.2376 | CER: 0.0509 | WER: 0.0509


                                                                       

Epoch 05 | Train Loss: 0.0036 | Val Loss: 0.2385 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 06 | Train Loss: 0.0036 | Val Loss: 0.2365 | CER: 0.0521 | WER: 0.0521


                                                                       

Epoch 07 | Train Loss: 0.0036 | Val Loss: 0.2396 | CER: 0.0527 | WER: 0.0527


                                                                       

Epoch 08 | Train Loss: 0.0037 | Val Loss: 0.2356 | CER: 0.0515 | WER: 0.0515


                                                                       

Epoch 09 | Train Loss: 0.0036 | Val Loss: 0.2412 | CER: 0.0515 | WER: 0.0515


                                                                       

Epoch 10 | Train Loss: 0.0037 | Val Loss: 0.2439 | CER: 0.0527 | WER: 0.0527
Early stopping triggered!


In [None]:
def test_model(model, loader, num_samples=5):
    model.eval()
    all_preds, all_refs = [], []
    samples_shown = 0

    print(f"{'PREDICTION':<20} | {'GROUND TRUTH':<20} | {'STATUS'}")
    print("-" * 60)

    with torch.no_grad():
        for images, labels, label_lens in loader:
            images = images.to(device)
            labels = labels.to(device)
            label_lens = label_lens.to(device)

            logits = model(images)
            
            preds = ctc_decode(logits)
            refs = decode_targets(labels, label_lens)

            all_preds.extend(preds)
            all_refs.extend(refs)

            if samples_shown < num_samples:
                for p, r in zip(preds, refs):
                    if samples_shown < num_samples:
                        status = "✅" if p == r else "❌"
                        print(f"{p:<20} | {r:<20} | {status}")
                        samples_shown += 1

    final_cer = cer(all_preds, all_refs)
    final_wer = wer(all_preds, all_refs)

    print("-" * 60)
    print(f"Test CER: {final_cer:.4f}")
    print(f"Test WER: {final_wer:.4f}")
    
    return final_cer, final_wer


In [73]:
test_ds = JavaneseOCRDataset(csv_path=TEST_CSV_PATH, img_dir=IMAGE_DIR)

test_loader = DataLoader(
    test_ds,
    batch_size=16,
    shuffle=False,
    collate_fn=ctc_collate_fn,
    num_workers=0 # Trouble if set 2 or more
)

model.load_state_dict(torch.load("best_crnn_nglegena.pt", weights_only=False))
test_model(model, test_loader)


PREDICTION           | GROUND TRUTH         | STATUS
------------------------------------------------------------
ꦭꦮ                   | ꦭꦮ                   | ✅
ꦭꦮ                   | ꦭꦮ                   | ✅
ꦭꦮ                   | ꦭꦮ                   | ✅
ꦭꦪꦕꦱꦧ                | ꦭꦪꦕꦱꦧ                | ✅
ꦭꦪꦕꦱꦧ                | ꦭꦪꦕꦱꦧ                | ✅
------------------------------------------------------------
Test CER: 0.0673
Test WER: 0.0673


(0.06726726726726727, 0.06726726726726727)