In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import random
import math
import time
import cv2
from tqdm.notebook import tqdm
from collections import Counter
import zipfile
import glob

In [2]:

zip_path = "/content/vintext.zip"
extract_path = "/content/vintext"

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

print("Đã giải nén VinText.")

Đã giải nén VinText.


In [3]:

LABEL_DIR = '/content/vintext/labels'

def collect_characters(label_dir):
    charset = set()
    for label_file in glob.glob(os.path.join(label_dir, 'gt_*.txt')):
        with open(label_file, encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split(',')
                if len(parts) < 9:
                    continue
                text = ','.join(parts[8:]).strip().lower()
                if text == '###' or len(text) < 2:
                    continue
                for ch in text:
                    charset.add(ch)
    return charset

def write_vocab(charset, save_path='/content/vintext/vocab.txt'):
    # Thêm các token đặc biệt
    tokens = ['<PAD>', '<BOS>', '<EOS>', '<UNK>']
    # Sắp xếp charset theo unicode để cố định thứ tự
    charset = sorted(list(charset))
    with open(save_path, 'w', encoding='utf-8') as f:
        for token in tokens:
            f.write(token + '\n')
        for ch in charset:
            f.write(ch + '\n')
    print(f'Vocab saved to {save_path}. Total chars (not count special tokens): {len(charset)}')

if __name__ == "__main__":
    charset = collect_characters(LABEL_DIR)
    write_vocab(charset)

Vocab saved to /content/vintext/vocab.txt. Total chars (not count special tokens): 126


In [19]:
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import os
import glob
from PIL import Image
import torch

class Vocab:
    def __init__(self, vocab_path):
        self.idx2char = []
        with open(vocab_path, encoding='utf-8') as f:
            for line in f:
                self.idx2char.append(line.strip('\n'))
        self.char2idx = {ch: idx for idx, ch in enumerate(self.idx2char)}
        self.pad_idx = self.char2idx['<PAD>']
        self.bos_idx = self.char2idx['<BOS>']
        self.eos_idx = self.char2idx['<EOS>']
        self.unk_idx = self.char2idx['<UNK>']

    def encode(self, text, max_length):
        ids = [self.bos_idx]
        for ch in text:
            ids.append(self.char2idx.get(ch, self.unk_idx))
        ids.append(self.eos_idx)
        # Pad to max_length
        ids = ids[:max_length]
        ids += [self.pad_idx] * (max_length - len(ids))
        return ids

    def decode(self, ids):
        chars = []
        for idx in ids:
            ch = self.idx2char[idx]
            if ch in ['<PAD>', '<BOS>', '<EOS>']:
                continue
            chars.append(ch)
        return ''.join(chars)

class VinTextOCREnd2EndDataset(Dataset):
    def __init__(self, img_dir, label_dir, vocab_path, max_label_length=128, img_height=64, img_width=256, transform=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.vocab = Vocab(vocab_path)
        self.max_label_length = max_label_length
        self.img_height = img_height
        self.img_width = img_width
        self.transform = transform or transforms.Compose([
            transforms.Resize((img_height, img_width)),
            transforms.ToTensor(),
        ])
        self.samples = []
        label_files = sorted(glob.glob(os.path.join(label_dir, "gt_*.txt")))
        for label_file in label_files:
            img_id = os.path.splitext(os.path.basename(label_file))[0][3:]  # gt_1.txt -> '1'
            img_name = f"im{int(img_id):04d}.jpg"
            img_path = os.path.join(img_dir, img_name)
            if not os.path.exists(img_path):
                continue
            with open(label_file, encoding='utf-8') as f:
                texts = []
                for line in f:
                    parts = line.strip().split(',')
                    if len(parts) < 9:
                        continue
                    text = ','.join(parts[8:]).strip().lower()
                    if text == '###' or len(text) < 2:
                        continue
                    texts.append(text)
                if texts:
                    merged_text = ' '.join(texts)
                    self.samples.append((img_path, merged_text))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, text = self.samples[idx]
        img = Image.open(img_path).convert("RGB")
        img = self.transform(img)
        label_ids = self.vocab.encode(text, self.max_label_length)
        label_ids = torch.tensor(label_ids, dtype=torch.long)
        return img, label_ids, text  # trả về cả text gốc để debug dễ dàng

# Example usage:
if __name__ == "__main__":
    img_dir = "/content/vintext/train_images"
    label_dir = "/content/vintext/labels"
    vocab_path = "/content/vintext/vocab.txt"
    dataset = VinTextOCREnd2EndDataset(img_dir, label_dir, vocab_path)
    print("Samples:", len(dataset))
    img, label_ids, gt_text = dataset[0]
    print("Image shape:", img.shape)
    print("Label ids:", label_ids)
    print("Decoded:", dataset.vocab.decode(label_ids.tolist()))
    print("Groundtruth text:", gt_text)

Samples: 1198
Image shape: torch.Size([3, 64, 256])
Label ids: tensor([  1,  39,  44,  88,  56,   4,  48,  85, 119,  50,  43,   4,  56, 110,
         56,   4,  81, 103,   4,  39,  74,   4,  58,  45, 105,  39,   4,  48,
         64,  49,  15,   4,  50,  80,  50,  43,   4,  55,  57,  88,  56,   4,
         39,  37,  51,   4,  81, 103,   4,  56,  80,  50,  43,   4,  56,  44,
         57,   4,  50,  44,  92,  52,   2,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0])
Decoded: chất lượng tốt để có việc làm, năng suất cao để tăng thu nhập
Groundtruth text: chất lượng tốt để có việc làm, năng suất cao để tăng thu nhập


In [20]:

class CNNBackbone(nn.Module):
    def __init__(self, out_channels=256):
        super().__init__()
        resnet = models.resnet18(weights="IMAGENET1K_V1")
        self.conv_layers = nn.Sequential(*list(resnet.children())[:-2])
        self.proj = nn.Conv2d(resnet.fc.in_features, out_channels, 1)
    def forward(self, x):
        feat = self.conv_layers(x)
        feat = self.proj(feat)
        return feat

In [21]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0))/d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0))
    def forward(self, x):
        return x + self.pe[:, :x.size(1), :]

In [22]:
class OCRTransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3, dropout=0.1, max_len=48):
        super().__init__()
        self.d_model = d_model
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = PositionalEncoding(d_model, max_len)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, d_model*2, dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_proj = nn.Linear(d_model, vocab_size)
    def forward(self, memory, tgt, tgt_mask=None):
        tgt_emb = self.token_embed(tgt) * (self.d_model ** 0.5)
        tgt_emb = self.pos_embed(tgt_emb)
        tgt_emb = tgt_emb.transpose(0,1) # (tgt_len,B,E)
        memory = memory.transpose(0,1)
        output = self.transformer_decoder(tgt_emb, memory, tgt_mask=tgt_mask)
        output = output.transpose(0,1)   # (B, tgt_len, E)
        logits = self.output_proj(output)
        return logits

In [23]:
class OCRNet(nn.Module):
    def __init__(self, vocab_size, max_seq_len=48, d_model=256):
        super().__init__()
        self.cnn = CNNBackbone(out_channels=d_model)
        self.feat2seq = nn.Linear(d_model, d_model)
        self.decoder = OCRTransformerDecoder(vocab_size, d_model, nhead=4, num_layers=3, max_len=max_seq_len)
    def forward(self, images, tgt_input, tgt_mask=None):
        feat = self.cnn(images)    # (B, C, H', W')
        B, C, H, W = feat.shape
        feat = feat.permute(0,2,3,1).contiguous().view(B, H*W, C) # (B, src_len, C)
        feat = self.feat2seq(feat)
        logits = self.decoder(feat, tgt_input, tgt_mask)
        return logits

In [24]:
device = "cuda" if torch.cuda.is_available() else "cpu"
vocab_size = len(dataset.vocab.idx2char)
max_seq_len = 128  # hoặc đúng với giá trị bạn đặt ở Dataset

model = OCRNet(vocab_size=vocab_size, max_seq_len=max_seq_len).to(device)

images = torch.randn(2, 3, 64, 256).to(device)      # Đúng shape ảnh đã resize
labels_in = torch.randint(0, vocab_size, (2, max_seq_len)).to(device)

out = model(images, labels_in)
print(out.shape)  # kỳ vọng (2, max_seq_len, vocab_size)

torch.Size([2, 128, 130])


In [25]:
def ocr_loss_fn(logits, labels, pad_idx):
    """
    logits: (B, tgt_len, vocab_size)
    labels: (B, tgt_len)
    """
    logits = logits.reshape(-1, logits.size(-1))   # sửa view thành reshape
    labels = labels.reshape(-1)                    # sửa view thành reshape
    loss = nn.CrossEntropyLoss(ignore_index=pad_idx)(logits, labels)
    return loss

In [26]:
def create_tgt_mask(tgt_input, pad_idx):
    # tgt_input: (B, tgt_len)
    B, tgt_len = tgt_input.shape
    # Mask phía phải (không cho nhìn trước)
    tgt_mask = torch.triu(torch.ones((tgt_len, tgt_len)), diagonal=1).bool().to(tgt_input.device)
    return tgt_mask  # (tgt_len, tgt_len)

In [27]:
def train_one_epoch(model, dataloader, optimizer, device, pad_idx):
    model.train()
    total_loss = 0
    for images, labels, _ in dataloader:    # <-- SỬA DÒNG NÀY
        images = images.to(device)
        labels = labels.to(device)
        tgt_input = labels[:, :-1]
        tgt_out = labels[:, 1:]
        tgt_mask = create_tgt_mask(tgt_input, pad_idx)
        logits = model(images, tgt_input, tgt_mask)
        loss = ocr_loss_fn(logits, tgt_out, pad_idx)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

In [28]:
@torch.no_grad()
def evaluate(model, dataloader, device, pad_idx):
    model.eval()
    total_loss = 0
    total_acc = 0
    total_count = 0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        tgt_input = labels[:, :-1]
        tgt_out = labels[:, 1:]
        tgt_mask = create_tgt_mask(tgt_input, pad_idx)
        logits = model(images, tgt_input, tgt_mask)
        loss = ocr_loss_fn(logits, tgt_out, pad_idx)
        total_loss += loss.item()
        # Đo char-level accuracy
        preds = logits.argmax(dim=-1)
        mask = (tgt_out != pad_idx)
        acc = ((preds == tgt_out) & mask).sum().item() / mask.sum().item()
        total_acc += acc
        total_count += 1
    return total_loss / total_count, total_acc / total_count

In [29]:

pad_idx = dataset.vocab.pad_idx
batch_size = 64
num_epochs = 10
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tạo DataLoader (giả sử class VinTextOCRDataset đã được import)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
# Nếu có tập validation, bạn cũng tạo val_loader tương tự

optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, device, pad_idx)
    print(f"Epoch {epoch+1} | Train loss: {train_loss:.4f}")
    # Nếu có val_loader thì thêm đánh giá:
    # val_loss, val_acc = evaluate(model, val_loader, device, pad_idx)
    # print(f"Val loss: {val_loss:.4f} | Val acc: {val_acc:.4f}")

# Lưu model
torch.save(model.state_dict(), "/content/ocr_cnn_transformer.pth")

Epoch 1 | Train loss: 3.3751
Epoch 2 | Train loss: 2.7316
Epoch 3 | Train loss: 2.5947
Epoch 4 | Train loss: 2.5345
Epoch 5 | Train loss: 2.4917
Epoch 6 | Train loss: 2.4565
Epoch 7 | Train loss: 2.4216
Epoch 8 | Train loss: 2.3981
Epoch 9 | Train loss: 2.3683
Epoch 10 | Train loss: 2.3434


In [31]:
from PIL import Image
import torch
from torchvision import transforms

def predict_image(model, image, vocab, device, max_len=48):
    model.eval()
    transform = transforms.Compose([
        transforms.Resize((32, 128)),
        transforms.ToTensor(),
    ])
    img_tensor = transform(image).unsqueeze(0).to(device)
    tgt_input = torch.tensor([[vocab.bos_idx]], dtype=torch.long).to(device)
    decoded = []
    for _ in range(max_len):
        tgt_mask = torch.triu(torch.ones((tgt_input.shape[1], tgt_input.shape[1]), device=device), diagonal=1).bool()
        logits = model(img_tensor, tgt_input, tgt_mask)
        next_token = logits[:, -1, :].argmax(-1).item()
        print(f"Step {_}: Token idx {next_token} - char '{vocab.idx2char[next_token]}'")
        if next_token == vocab.eos_idx or next_token == vocab.pad_idx:
            break
        decoded.append(next_token)
        tgt_input = torch.cat([tgt_input, torch.tensor([[next_token]], dtype=torch.long).to(device)], dim=1)
    print("Full token idx sequence:", decoded)
    return vocab.decode(decoded)

# Đường dẫn tới ảnh bạn muốn test
img_path = "/content/vintext/test_image/im1249.jpg"  # Đổi tên file nếu cần

# Đọc ảnh và dự đoán
image = Image.open(img_path).convert("RGB")
predicted_text = predict_image(model, image, dataset.vocab, device)
print(f"Ảnh: {img_path}")
print(f"Chuỗi văn bản model sinh ra: {predicted_text}")

Step 0: Token idx 56 - char 't'
Step 1: Token idx 44 - char 'h'
Step 2: Token idx 37 - char 'a'
Step 3: Token idx 50 - char 'n'
Step 4: Token idx 4 - char ' '
Step 5: Token idx 56 - char 't'
Step 6: Token idx 44 - char 'h'
Step 7: Token idx 4 - char ' '
Step 8: Token idx 56 - char 't'
Step 9: Token idx 44 - char 'h'
Step 10: Token idx 4 - char ' '
Step 11: Token idx 56 - char 't'
Step 12: Token idx 44 - char 'h'
Step 13: Token idx 64 - char 'à'
Step 14: Token idx 50 - char 'n'
Step 15: Token idx 43 - char 'g'
Step 16: Token idx 4 - char ' '
Step 17: Token idx 56 - char 't'
Step 18: Token idx 44 - char 'h'
Step 19: Token idx 4 - char ' '
Step 20: Token idx 56 - char 't'
Step 21: Token idx 44 - char 'h'
Step 22: Token idx 4 - char ' '
Step 23: Token idx 56 - char 't'
Step 24: Token idx 44 - char 'h'
Step 25: Token idx 4 - char ' '
Step 26: Token idx 56 - char 't'
Step 27: Token idx 44 - char 'h'
Step 28: Token idx 64 - char 'à'
Step 29: Token idx 50 - char 'n'
Step 30: Token idx 44 - cha