In [1]:
import os
import zipfile

# If you downloaded the dataset as a zip archive, extract it as follows:
zip_path = "/content/IIIT5K.zip"
extract_path = "/content/IIIT5K"
# Extract every member of the archive into extract_path.
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)
# Confirmation message (Vietnamese for "IIIT 5K-word extracted.").
print("Đã giải nén IIIT 5K-word.")

Đã giải nén IIIT 5K-word.


In [10]:
import scipy.io
import os

def mat_to_labeltxt(mat_path, save_txt, img_root="/content/IIIT5K/IIIT5K"):
    """Convert an IIIT5K ``.mat`` annotation file into a plain-text label file.

    Each output line has the form ``"<ImgName> <GroundTruth>"``. Entries whose
    image file cannot be found under ``img_root`` are silently skipped.

    Args:
        mat_path: Path to the MATLAB annotation file.
        save_txt: Path of the label text file to (over)write, UTF-8 encoded.
        img_root: Directory that the ``ImgName`` values are relative to. The
            default keeps the location that was previously hard-coded.

    Raises:
        IndexError: If the ``.mat`` file holds no variable besides the
            ``__*`` bookkeeping keys.
    """
    mat = scipy.io.loadmat(mat_path)
    # Drop the bookkeeping keys (__header__, __version__, __globals__); the
    # first remaining key is taken as the annotation struct (the original
    # author notes it is named traindata or testdata).
    keys = [k for k in mat.keys() if not k.startswith("__")]
    struct_key = keys[0]
    struct = mat[struct_key][0]  # array of per-sample structs

    with open(save_txt, "w", encoding="utf-8") as f:
        for entry in struct:
            img_name = str(entry["ImgName"][0])
            gt = str(entry["GroundTruth"][0])
            # Per the original author, img_name already carries the split
            # sub-directory, e.g. "train/xxx.png" or "test/xxx.png".
            img_path = os.path.join(img_root, img_name)
            # Only emit a label when the image is actually on disk.
            if os.path.exists(img_path):
                f.write(f"{img_name} {gt}\n")
    print(f"Saved {save_txt}")

# Annotation files shipped with the extracted dataset.
train_mat = "/content/IIIT5K/IIIT5K/traindata.mat"
test_mat = "/content/IIIT5K/IIIT5K/testdata.mat"

# Destination label files, one "<image name> <text>" pair per line.
train_labeltxt = "/content/IIIT5K/IIIT5K/train_label.txt"
test_labeltxt = "/content/IIIT5K/IIIT5K/test_label.txt"

# Convert both splits.
mat_to_labeltxt(train_mat, train_labeltxt)
mat_to_labeltxt(test_mat, test_labeltxt)

Saved /content/IIIT5K/IIIT5K/train_label.txt
Saved /content/IIIT5K/IIIT5K/test_label.txt


In [11]:
def collect_characters_iiit5k(label_file):
    """Return the set of distinct characters used by the labels in a label file.

    Every line is stripped and split once on the first space; the part after
    that space is treated as the word. Lines that do not contain a space
    (after stripping) are ignored.

    Args:
        label_file: Path to a UTF-8 text file with "<image name> <word>" lines.

    Returns:
        A ``set`` of single-character strings.
    """
    seen = set()
    with open(label_file, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            name_and_word = raw_line.strip().split(' ', 1)
            if len(name_and_word) >= 2:
                # set.update over a string adds each of its characters.
                seen.update(name_and_word[1])
    return seen

def write_vocab_iiit5k(charset, save_path='/content/IIIT5K/IIIT5K/vocab.txt'):
    """Write a vocabulary file: four special tokens followed by the sorted characters.

    The file holds one token per line. The special tokens ``<PAD>``, ``<BOS>``,
    ``<EOS>`` and ``<UNK>`` come first, in that order, so they occupy line
    positions 0 to 3.

    Args:
        charset: Iterable of characters to include.
        save_path: Destination file, written as UTF-8.
    """
    special_tokens = ['<PAD>', '<BOS>', '<EOS>', '<UNK>']
    ordered_chars = sorted(charset)
    with open(save_path, 'w', encoding='utf-8') as out:
        out.writelines(entry + '\n' for entry in special_tokens + ordered_chars)
    print(f'Vocab saved to {save_path}. Total chars (not count special tokens): {len(ordered_chars)}')

# Build the character vocabulary from the training labels only and write it
# to write_vocab_iiit5k's default location.
label_file = "/content/IIIT5K/IIIT5K/train_label.txt"
charset = collect_characters_iiit5k(label_file)
write_vocab_iiit5k(charset)

Vocab saved to /content/IIIT5K/IIIT5K/vocab.txt. Total chars (not count special tokens): 36


In [12]:
class Vocab:
    """Character-level vocabulary loaded from a one-token-per-line text file.

    A token's index is its zero-based line number in the file. The file must
    contain the special tokens ``<PAD>``, ``<BOS>``, ``<EOS>`` and ``<UNK>``.
    """

    def __init__(self, vocab_path):
        """Load the vocabulary.

        Args:
            vocab_path: Path to the UTF-8 vocabulary file.

        Raises:
            KeyError: If one of the four special tokens is missing.
        """
        # idx2char: list mapping index -> token.
        self.idx2char = []
        with open(vocab_path, encoding='utf-8') as f:
            for line in f:
                # Strip only the newline so that whitespace tokens survive.
                self.idx2char.append(line.strip('\n'))
        # char2idx: dict mapping token -> index.
        self.char2idx = {ch: idx for idx, ch in enumerate(self.idx2char)}
        self.pad_idx = self.char2idx['<PAD>']
        self.bos_idx = self.char2idx['<BOS>']
        self.eos_idx = self.char2idx['<EOS>']
        self.unk_idx = self.char2idx['<UNK>']

    def encode(self, text, max_length):
        """Encode ``text`` as ``<BOS> chars... <EOS>`` padded to ``max_length``.

        Characters not in the vocabulary map to ``<UNK>``. If the text is too
        long, the *text* is truncated (to ``max_length - 2`` characters) so the
        sequence still ends with ``<EOS>``; truncating after appending
        ``<EOS>`` would drop the end marker from long labels.

        Args:
            text: String to encode.
            max_length: Exact length of the returned list.

        Returns:
            A list of ``max_length`` integer token ids.
        """
        # Reserve two slots for <BOS> and <EOS>.
        room = max(max_length - 2, 0)
        ids = [self.bos_idx]
        ids.extend(self.char2idx.get(ch, self.unk_idx) for ch in text[:room])
        ids.append(self.eos_idx)
        # Only trims further when max_length < 2.
        ids = ids[:max_length]
        ids += [self.pad_idx] * (max_length - len(ids))
        return ids

    def decode(self, ids):
        """Turn token ids back into a string, skipping ``<PAD>``/``<BOS>``/``<EOS>``.

        ``<UNK>`` is kept verbatim in the output.

        Args:
            ids: Iterable of integer token ids.

        Returns:
            The concatenated tokens as a string.
        """
        chars = []
        for idx in ids:
            ch = self.idx2char[idx]
            if ch in ['<PAD>', '<BOS>', '<EOS>']:
                continue
            chars.append(ch)
        return ''.join(chars)

In [15]:
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image
import torch

class IIIT5KWordDataset(Dataset):
    """Word-image dataset built from a "<image name> <text>" label file.

    Each item is ``(image_tensor, label_ids, text)``: the transformed RGB
    image, the text encoded by ``Vocab.encode`` as a ``torch.long`` tensor of
    length ``max_label_length``, and the raw ground-truth string.
    """

    def __init__(self, label_file, vocab_path, max_label_length=32, img_height=32, img_width=100,
                 transform=None, img_root="/content/IIIT5K/IIIT5K"):
        """Index the samples listed in ``label_file``.

        Args:
            label_file: UTF-8 text file with one "<image name> <text>" per line.
            vocab_path: Path to the vocabulary file consumed by ``Vocab``.
            max_label_length: Length every encoded label is padded/truncated to.
            img_height: Target image height for the default transform.
            img_width: Target image width for the default transform.
            transform: Optional callable applied to the PIL image. When omitted,
                a resize + ToTensor + Normalize pipeline is used.
            img_root: Directory the image names are relative to. The default
                keeps the location that was previously hard-coded.
        """
        self.vocab = Vocab(vocab_path)
        self.max_label_length = max_label_length
        self.img_height = img_height
        self.img_width = img_width
        self.img_root = img_root
        self.transform = transform or transforms.Compose([
            transforms.Resize((img_height, img_width)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        # List of (absolute image path, ground-truth text) pairs.
        self.samples = []
        with open(label_file, encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split(' ', 1)
                if len(parts) < 2:
                    # No label on this line.
                    continue
                img_name, text = parts[0], parts[1]
                img_path = os.path.join(img_root, img_name)
                # Entries whose image is missing on disk are skipped.
                if os.path.exists(img_path):
                    self.samples.append((img_path, text))

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, transform and encode the sample at position ``idx``."""
        img_path, text = self.samples[idx]
        # The context manager closes the underlying file deterministically.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        img = self.transform(img)
        label_ids = self.vocab.encode(text, self.max_label_length)
        label_ids = torch.tensor(label_ids, dtype=torch.long)
        return img, label_ids, text

In [16]:
# Dataset configuration shared by the cells below.
label_file = "/content/IIIT5K/IIIT5K/train_label.txt"
vocab_path = "/content/IIIT5K/IIIT5K/vocab.txt"
max_label_length = 32
img_height, img_width = 32, 100

# Build the training dataset and sanity-check the first sample: the decoded
# label ids should round-trip to the ground-truth text.
dataset = IIIT5KWordDataset(label_file, vocab_path, max_label_length, img_height, img_width)
print("Samples:", len(dataset))
img, label_ids, text = dataset[0]
print("Image shape:", img.shape)
print("Label ids:", label_ids)
print("Decoded:", dataset.vocab.decode(label_ids.tolist()))
print("Groundtruth text:", text)

Samples: 2000
Image shape: torch.Size([3, 32, 100])
Label ids: tensor([ 1, 38, 28, 34,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])
Decoded: YOU
Groundtruth text: YOU


In [17]:
import torch.nn as nn
import torchvision.models as models

class CNNBackbone(nn.Module):
    """ResNet-18 feature extractor followed by a 1x1 channel projection.

    The network is torchvision's ``resnet18`` with ``IMAGENET1K_V1`` weights,
    minus its last two child modules (presumably the pooling and
    classification head in the standard layout - confirm against torchvision).
    """

    def __init__(self, out_channels=256):
        """Build the backbone.

        Args:
            out_channels: Number of channels produced by the 1x1 projection.
        """
        super().__init__()
        backbone = models.resnet18(weights="IMAGENET1K_V1")
        stages = list(backbone.children())
        # Keep everything except the final two children.
        self.conv_layers = nn.Sequential(*stages[:-2])
        # The projection input width equals the width the dropped fc layer expected.
        self.proj = nn.Conv2d(backbone.fc.in_features, out_channels, 1)

    def forward(self, x):
        """Return the projected feature map for the image batch ``x``."""
        return self.proj(self.conv_layers(x))

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a sequence of embeddings.

    Even feature columns hold ``sin(position * freq)`` and odd columns hold
    ``cos(position * freq)``. The table is stored as the buffer ``pe`` with
    shape ``(1, max_len, d_model)``.
    """

    def __init__(self, d_model, max_len=256):
        """Precompute the encoding table.

        Args:
            d_model: Embedding width. Both even and odd values are supported.
            max_len: Number of positions to precompute.
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0))/d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even column, so
        # drop the surplus frequency instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the first ``x.size(1)`` positional rows to ``x``.

        Indexing fixes ``x`` as ``(batch, seq_len, d_model)``; ``seq_len``
        must not exceed ``max_len``.
        """
        return x + self.pe[:, :x.size(1), :]

class OCRTransformerDecoder(nn.Module):
    """Transformer decoder that maps target tokens plus visual memory to vocabulary logits."""

    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3, dropout=0.1, max_len=32):
        """Build the embedding, positional encoding, decoder stack and output head.

        Args:
            vocab_size: Number of tokens (embedding rows and output logits).
            d_model: Embedding / model width.
            nhead: Number of attention heads per decoder layer.
            num_layers: Number of stacked decoder layers.
            dropout: Dropout probability inside each decoder layer.
            max_len: Number of positions available in the positional encoding.
        """
        super().__init__()
        self.d_model = d_model
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = PositionalEncoding(d_model, max_len)
        # The feed-forward width is twice the model width.
        layer = nn.TransformerDecoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_model * 2,
            dropout=dropout,
        )
        self.transformer_decoder = nn.TransformerDecoder(layer, num_layers=num_layers)
        self.output_proj = nn.Linear(d_model, vocab_size)

    def forward(self, memory, tgt, tgt_mask=None):
        """Compute logits for every target position.

        Args:
            memory: Encoder features; its first two axes are swapped before
                decoding (assumes batch-first ``(B, S, d_model)`` - TODO confirm).
            tgt: Integer token ids fed to the embedding (assumes ``(B, T)``).
            tgt_mask: Optional attention mask forwarded to the decoder.

        Returns:
            Logits with the batch axis first again and ``vocab_size`` last.
        """
        embedded = self.pos_embed(self.token_embed(tgt) * (self.d_model ** 0.5))
        # The decoder layers were built without batch_first, so swap to
        # sequence-first for the call and back afterwards.
        decoded = self.transformer_decoder(
            embedded.transpose(0, 1),
            memory.transpose(0, 1),
            tgt_mask=tgt_mask,
        )
        return self.output_proj(decoded.transpose(0, 1))

class OCRNet(nn.Module):
    """End-to-end OCR model: CNN backbone, linear feature adapter, Transformer decoder."""

    def __init__(self, vocab_size, max_seq_len=32, d_model=256):
        """Assemble the sub-modules.

        Args:
            vocab_size: Size of the output vocabulary.
            max_seq_len: Maximum target length supported by the decoder.
            d_model: Shared feature / embedding width.
        """
        super().__init__()
        self.cnn = CNNBackbone(out_channels=d_model)
        self.feat2seq = nn.Linear(d_model, d_model)
        self.decoder = OCRTransformerDecoder(vocab_size, d_model, nhead=4, num_layers=3, max_len=max_seq_len)

    def forward(self, images, tgt_input, tgt_mask=None):
        """Return decoder logits for ``tgt_input`` conditioned on ``images``.

        The ``(B, C, H, W)`` feature map is flattened row-major into a
        ``(B, H*W, C)`` sequence that serves as the decoder memory.
        """
        fmap = self.cnn(images)
        batch, channels, rows, cols = fmap.shape
        # Same element order as permute(0, 2, 3, 1) followed by a view.
        tokens = fmap.reshape(batch, channels, rows * cols).transpose(1, 2).contiguous()
        memory = self.feat2seq(tokens)
        return self.decoder(memory, tgt_input, tgt_mask)

In [18]:
def ocr_loss_fn(logits, labels, pad_idx):
    """Mean cross-entropy between logits and labels, ignoring padding positions.

    Args:
        logits: Tensor whose last axis indexes the vocabulary.
        labels: Integer tensor with one label per leading position of ``logits``.
        pad_idx: Label value excluded from the loss.

    Returns:
        A scalar loss tensor.
    """
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_labels = labels.reshape(-1)
    return nn.functional.cross_entropy(flat_logits, flat_labels, ignore_index=pad_idx)

def create_tgt_mask(tgt_input, pad_idx):
    """Build a boolean causal mask for a 2-D batch of target tokens.

    Entry ``[i, j]`` is ``True`` when ``j > i``, i.e. the strict upper
    triangle. The mask lives on the same device as ``tgt_input``.

    Args:
        tgt_input: 2-D tensor; only its second dimension and device are used.
        pad_idx: Accepted for interface compatibility; not used.

    Returns:
        A ``(T, T)`` ``torch.bool`` tensor, where ``T`` is ``tgt_input.shape[1]``.
    """
    _, seq_len = tgt_input.shape
    steps = torch.arange(seq_len, device=tgt_input.device)
    # Column index greater than row index <=> strictly above the diagonal.
    return steps.unsqueeze(0) > steps.unsqueeze(1)

In [None]:
from torch.utils.data import DataLoader
import torch.optim as optim

def train_one_epoch(model, dataloader, optimizer, device, pad_idx):
    """Run one teacher-forced training pass over ``dataloader``.

    For every batch the label sequence is split into decoder input (all but
    the last token) and decoder target (all but the first token).

    Args:
        model: Module called as ``model(images, tgt_input, tgt_mask)``.
        dataloader: Iterable yielding ``(images, labels, _)`` batches.
        optimizer: Optimizer updating the model's parameters.
        device: Device the batches are moved to.
        pad_idx: Padding label id, ignored by the loss.

    Returns:
        The mean of the per-batch losses.
    """
    model.train()
    running_loss = 0
    for batch_images, batch_labels, _ in dataloader:
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
        decoder_in, decoder_target = batch_labels[:, :-1], batch_labels[:, 1:]
        causal_mask = create_tgt_mask(decoder_in, pad_idx)
        step_loss = ocr_loss_fn(model(batch_images, decoder_in, causal_mask), decoder_target, pad_idx)
        optimizer.zero_grad()
        step_loss.backward()
        optimizer.step()
        running_loss += step_loss.item()
    return running_loss / len(dataloader)

@torch.no_grad()
def evaluate(model, dataloader, device, pad_idx):
    """Compute teacher-forced loss and token accuracy over ``dataloader``.

    Accuracy is measured per batch over non-padding target positions, and
    both metrics are averaged over batches (not over tokens).

    Args:
        model: Module called as ``model(images, tgt_input, tgt_mask)``.
        dataloader: Iterable yielding ``(images, labels, _)`` batches.
        device: Device the batches are moved to.
        pad_idx: Padding label id, excluded from loss and accuracy.

    Returns:
        ``(mean_batch_loss, mean_batch_accuracy)``.
    """
    model.eval()
    loss_sum = 0
    acc_sum = 0
    n_batches = 0
    for batch_images, batch_labels, _ in dataloader:
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)
        decoder_in, decoder_target = batch_labels[:, :-1], batch_labels[:, 1:]
        causal_mask = create_tgt_mask(decoder_in, pad_idx)
        logits = model(batch_images, decoder_in, causal_mask)
        loss_sum += ocr_loss_fn(logits, decoder_target, pad_idx).item()
        predicted = logits.argmax(dim=-1)
        # Score only the positions that carry a real (non-padding) target.
        valid = decoder_target != pad_idx
        n_correct = ((predicted == decoder_target) & valid).sum().item()
        acc_sum += n_correct / valid.sum().item()
        n_batches += 1
    return loss_sum / n_batches, acc_sum / n_batches

# Training configuration, derived from the dataset built in an earlier cell.
pad_idx = dataset.vocab.pad_idx
batch_size = 64
num_epochs = 30
device = "cuda" if torch.cuda.is_available() else "cpu"
vocab_size = len(dataset.vocab.idx2char)
max_seq_len = max_label_length

# Model, data loader and optimizer.
model = OCRNet(vocab_size=vocab_size, max_seq_len=max_seq_len).to(device)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Train and report the mean training loss after every epoch.
for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, device, pad_idx)
    print(f"Epoch {epoch+1} | Train loss: {train_loss:.4f}")
    # If a val_loader is available:
    # val_loss, val_acc = evaluate(model, val_loader, device, pad_idx)
    # print(f"Val loss: {val_loss:.4f} | Val acc: {val_acc:.4f}")

# Persist the trained weights.
torch.save(model.state_dict(), "/content/ocr_cnn_transformer_iiit5k.pth")

Epoch 1 | Train loss: 2.7799
Epoch 2 | Train loss: 2.3283
Epoch 3 | Train loss: 2.1018
Epoch 4 | Train loss: 1.9240
Epoch 5 | Train loss: 1.7874
Epoch 6 | Train loss: 1.7892
Epoch 7 | Train loss: 1.6285
Epoch 8 | Train loss: 1.4662
Epoch 9 | Train loss: 1.2980
Epoch 10 | Train loss: 1.2064
Epoch 11 | Train loss: 1.1545
Epoch 12 | Train loss: 1.0989


In [21]:
from PIL import Image

@torch.no_grad()
def predict_image(model, image, vocab, device, max_len=32, img_size=(32, 100)):
    """Greedily decode the text in a single image.

    Decoding starts from ``<BOS>`` and appends the highest-scoring token at
    each step until ``<EOS>``/``<PAD>`` is produced or ``max_len`` steps have
    run. Gradients are disabled because this is inference only; otherwise
    every step would build an autograd graph. The model is switched to eval
    mode and left there on return.

    Args:
        model: Module called as ``model(images, tgt_input, tgt_mask)``.
        image: PIL image to recognise.
        vocab: Vocabulary providing ``bos_idx``, ``eos_idx``, ``pad_idx`` and
            ``decode``.
        device: Device to run inference on.
        max_len: Maximum number of decoding steps.
        img_size: ``(height, width)`` the image is resized to. The default
            keeps the size that was previously hard-coded.

    Returns:
        The decoded string.
    """
    model.eval()
    transform = transforms.Compose([
        transforms.Resize(tuple(img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    img_tensor = transform(image).unsqueeze(0).to(device)
    tgt_input = torch.tensor([[vocab.bos_idx]], dtype=torch.long, device=device)
    decoded = []
    for _ in range(max_len):
        cur_len = tgt_input.shape[1]
        # Causal mask: each position may attend only to itself and earlier ones.
        tgt_mask = torch.triu(torch.ones((cur_len, cur_len), device=device), diagonal=1).bool()
        logits = model(img_tensor, tgt_input, tgt_mask)
        # Greedy choice from the logits of the last position.
        next_token = logits[:, -1, :].argmax(-1).item()
        if next_token == vocab.eos_idx or next_token == vocab.pad_idx:
            break
        decoded.append(next_token)
        step = torch.tensor([[next_token]], dtype=torch.long, device=device)
        tgt_input = torch.cat([tgt_input, step], dim=1)
    return vocab.decode(decoded)

img_path = "/content/IIIT5K/IIIT5K/test/1002_1.png"  # Change the test image path if needed
image = Image.open(img_path).convert("RGB")
# Run greedy decoding with the trained model and the training vocabulary.
predicted_text = predict_image(model, image, dataset.vocab, device)
# The printed labels are Vietnamese: "Image" and "Text generated by the model".
print(f"Ảnh: {img_path}")
print(f"Chuỗi văn bản model sinh ra: {predicted_text}")

Ảnh: /content/IIIT5K/IIIT5K/test/1002_1.png
Chuỗi văn bản model sinh ra: PRINAT
