In [1]:
import sys
import os
import torch 
import torch.nn as nn 
import torch.optim as optim 
from tqdm import tqdm

# Go up one directory to reach the root
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

from torch.utils.data import DataLoader
from data import collate_fn
from data import HWLD
from model import CRNN
from data import create_label


In [2]:
create_label.create_dataset_json(
    image_dir="data/cropped",
    label_dir="data/labels",
    output_file='dataset.json'
)


create_label.convert_json_to_csv("dataset.json","dataset.csv")

✅ Saved 88 items to dataset.json
✅ Converted to dataset.csv with 88 entries.


In [3]:
dataset = HWLD.HandwritingLineDataset("dataset.csv")
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn.ctc_collate_fn)

charset = dataset.get_charset()
num_classes = len(charset) + 1 


model = CRNN.CRNN(img_height=32, num_classes=num_classes)

sample_batch = next(iter(dataloader))
logits = model(sample_batch["images"])

In [4]:
def cer(preds, targets):
    import editdistance
    total_dist, total_chars = 0, 0
    for p, t in zip(preds, targets):
        dist = editdistance.eval(p, t)
        total_dist += dist
        total_chars += len(t)
    return total_dist / total_chars if total_chars > 0 else 1.0

In [5]:
def decode_predictions(logits, charset, blank=0):
    # Greedy decode, collapse repeats and remove blanks
    probs = logits.softmax(2)
    pred_indices = probs.argmax(2)  # [T, B]
    pred_indices = pred_indices.permute(1, 0)  # [B, T]

    results = []
    for seq in pred_indices:
        prev = blank
        text = ""
        for idx in seq:
            idx = idx.item()
            if idx != blank and idx != prev:
                text += charset[idx]
            prev = idx
        results.append(text)
    return results

In [None]:
def train(model, train_loader, val_loader, charset,num_epochs=20,lr=1e-3,device="cuda" if torch.cuda.is_available() else "cpu",save_path="crnn_best.pth"):
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, verbose=True)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)

    best_val_cer = float("inf")

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0.0

        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)
        for batch in loop:
            images = batch["images"].to(device)  # [B, 1, 32, W]
            targets = batch["label"].to(device)
            label_lengths = batch["label_lengths"].to(device)
            input_lengths = batch["input_lengths"].to(device)
            #label_lengths = torch.tensor([len(l) for l in labels], dtype=torch.long)

            # Flatten labels into a 1D tensor for CTC
            #targets = torch.cat(labels).to(device)

            logits = model(images)  # [T, B, C]
            #input_lengths = torch.full(size=(logits.size(1),), fill_value=logits.size(0), dtype=torch.long)

            loss = criterion(logits.log_softmax(2), targets, input_lengths, label_lengths)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            loop.set_postfix(loss=loss.item())

        # Validation step
        model.eval()
        all_preds, all_targets = [], []
        with torch.no_grad():
            for batch in val_loader:
                images = batch["images"].to(device)
                labels = batch["label_strs"]
                logits = model(images)
                decoded = decode_predictions(logits.cpu(), charset)
                all_preds.extend(decoded)
                all_targets.extend(labels)

        val_cer = cer(all_preds, all_targets)
        scheduler.step(val_cer)

        print(f"Epoch {epoch+1}: Train Loss = {epoch_loss / len(train_loader):.4f}, Val CER = {val_cer:.4f}")

        # Save best model
        if val_cer < best_val_cer:
            best_val_cer = val_cer
            torch.save(model.state_dict(), save_path)
            print(f"✅ Saved best model with CER = {val_cer:.4f}")

In [7]:
dataset = HWLD.HandwritingLineDataset("dataset.csv")
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

indices = list(range(len(dataset)))
train_idx, val_idx = train_test_split(indices, test_size=0.3, random_state=42)

train_data = Subset(dataset, train_idx)
val_data = Subset(dataset, val_idx)

#train_data, val_data = random_split(dataset, [62, 26])

train_loader = DataLoader(train_data, batch_size=4, shuffle=True, collate_fn=collate_fn.ctc_collate_fn)
val_loader = DataLoader(val_data, batch_size=4, shuffle=True, collate_fn=collate_fn.ctc_collate_fn)

#charset = train_loader.dataset.charset
charset = dataset.get_charset()
num_classes = len(charset) + 1
model = CRNN.CRNN(img_height=32, num_classes=num_classes)

#sample_batch = next(iter(train_loader))

for batch in train_loader:
    print("Image batch:", batch['images'].shape)
    print("Label concat:", batch['label'].shape)
    print("Label lengths:", batch['label_lengths'])
    
    offset = 0
    for l in batch['label_lengths']:
        print(batch['label'][offset:offset + l.item()])
        offset += l.item()
    break

labels = batch['label']
lengths = batch['label_lengths']

offset = 0
for i, length in enumerate(lengths):
    label = labels[offset:offset + length]
    print(f"Label {i}: {label}, shape: {label.shape}")
    offset += length

train(model=model,train_loader=train_loader,val_loader=val_loader,charset=charset,num_epochs=5)



Image batch: torch.Size([4, 1, 32, 838])
Label concat: torch.Size([203])
Label lengths: tensor([38, 54, 56, 55])
tensor([33, 25,  1, 38, 27, 24,  1, 25, 24, 20, 38, 39, 36, 24, 37,  1, 38, 27,
        20, 38,  1, 27, 20, 40, 24,  1, 21, 24, 24, 32,  1, 24, 32, 22, 33, 23,
        24, 23])
tensor([38, 27, 24,  1,  5, 27, 39, 36, 22, 27,  1, 20, 32, 23,  1, 27, 20, 37,
         1, 38, 27, 24,  1, 34, 33, 41, 24, 36,  1, 38, 33,  1, 31, 20, 29, 24,
         1, 21, 28, 32, 23, 28, 32, 26,  1, 23, 24, 22, 28, 37, 28, 33, 32, 37])
tensor([22, 33, 31, 34, 36, 24, 37, 37, 28, 33, 32,  1, 20, 32, 23,  1, 36, 24,
        23, 39, 22, 38, 28, 33, 32,  1, 38, 27, 24, 36, 24, 21, 43,  1, 24, 32,
        27, 20, 32, 22, 28, 32, 26,  1, 38, 36, 20, 32, 37, 31, 28, 37, 37, 28,
        33, 32])
tensor([20, 22, 22, 39, 36, 20, 22, 43,  1, 33, 25,  1, 38, 36, 20, 32, 37, 31,
        28, 38, 38, 24, 23,  1, 21, 28, 38, 37,  1, 17, 27, 28, 37,  1, 20, 34,
        34, 36, 33, 20, 22, 27,  1, 28, 32, 40, 33, 

                                                                     

KeyError: 'label_str'