In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import random
import math
import time
import cv2
from tqdm.notebook import tqdm
from collections import Counter

In [1]:
import zipfile
import os

zip_path = "/content/vintext.zip"
extract_path = "/content/vintext"

with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

print("Đã giải nén VinText.")

Đã giải nén VinText.


In [3]:
import os
import glob

LABEL_DIR = '/content/vintext/labels'

def collect_characters(label_dir):
    charset = set()
    for label_file in glob.glob(os.path.join(label_dir, 'gt_*.txt')):
        with open(label_file, encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split(',')
                if len(parts) < 9:
                    continue
                text = ','.join(parts[8:]).strip().lower()
                if text == '###' or len(text) < 2:
                    continue
                for ch in text:
                    charset.add(ch)
    return charset

def write_vocab(charset, save_path='/content/vintext/vocab.txt'):
    # Thêm các token đặc biệt
    tokens = ['<PAD>', '<BOS>', '<EOS>', '<UNK>']
    # Sắp xếp charset theo unicode để cố định thứ tự
    charset = sorted(list(charset))
    with open(save_path, 'w', encoding='utf-8') as f:
        for token in tokens:
            f.write(token + '\n')
        for ch in charset:
            f.write(ch + '\n')
    print(f'Vocab saved to {save_path}. Total chars (not count special tokens): {len(charset)}')

if __name__ == "__main__":
    charset = collect_characters(LABEL_DIR)
    write_vocab(charset)

Vocab saved to /content/vintext/vocab.txt. Total chars (not count special tokens): 126


In [4]:
import os
import glob
from PIL import Image
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset

class Vocab:
    def __init__(self, vocab_path):
        self.idx2char = []
        with open(vocab_path, encoding='utf-8') as f:
            for line in f:
                self.idx2char.append(line.strip('\n'))
        self.char2idx = {ch: idx for idx, ch in enumerate(self.idx2char)}
        self.pad_idx = self.char2idx['<PAD>']
        self.bos_idx = self.char2idx['<BOS>']
        self.eos_idx = self.char2idx['<EOS>']
        self.unk_idx = self.char2idx['<UNK>']

    def encode(self, text, max_length):
        ids = [self.bos_idx]
        for ch in text:
            ids.append(self.char2idx.get(ch, self.unk_idx))
        ids.append(self.eos_idx)
        # Pad to max_length
        ids = ids[:max_length]
        ids += [self.pad_idx] * (max_length - len(ids))
        return ids

    def decode(self, ids):
        chars = []
        for idx in ids:
            ch = self.idx2char[idx]
            if ch in ['<PAD>', '<BOS>', '<EOS>']:
                continue
            chars.append(ch)
        return ''.join(chars)

def crop_poly(img, poly):
    # poly: [x1, y1, x2, y2, x3, y3, x4, y4]
    poly_xy = [(poly[i], poly[i+1]) for i in range(0, 8, 2)]
    min_x = max(min([p[0] for p in poly_xy]), 0)
    max_x = max([p[0] for p in poly_xy])
    min_y = max(min([p[1] for p in poly_xy]), 0)
    max_y = max([p[1] for p in poly_xy])
    img_crop = img.crop((min_x, min_y, max_x, max_y))
    return img_crop

class VinTextOCRDataset(Dataset):
    def __init__(self, img_dir, label_dir, vocab_path, max_label_length=48, img_height=32, img_width=128, transform=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.vocab = Vocab(vocab_path)
        self.max_label_length = max_label_length
        self.img_height = img_height
        self.img_width = img_width
        self.transform = transform or transforms.Compose([
            transforms.Resize((img_height, img_width)),
            transforms.ToTensor(),
        ])
        self.samples = []
        label_files = sorted(glob.glob(os.path.join(label_dir, "gt_*.txt")))
        for label_file in label_files:
            img_id = os.path.splitext(os.path.basename(label_file))[0][3:]  # gt_1.txt -> '1'
            img_name = f"im{int(img_id):04d}.jpg"
            img_path = os.path.join(img_dir, img_name)
            if not os.path.exists(img_path):
                continue
            with open(label_file, encoding='utf-8') as f:
                for line in f:
                    parts = line.strip().split(',')
                    if len(parts) < 9:
                        continue
                    poly = [int(float(x)) for x in parts[:8]]
                    text = ','.join(parts[8:]).strip().lower()
                    if text == '###' or len(text) < 2:
                        continue
                    self.samples.append((img_path, poly, text))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, poly, text = self.samples[idx]
        img = Image.open(img_path).convert("RGB")
        img_crop = crop_poly(img, poly)
        img_crop = self.transform(img_crop)
        label_ids = self.vocab.encode(text, self.max_label_length)
        label_ids = torch.tensor(label_ids, dtype=torch.long)
        return img_crop, label_ids

# Example usage:
if __name__ == "__main__":
    img_dir = "/content/vintext/train_images"
    label_dir = "/content/vintext/labels"
    vocab_path = "/content/vintext/vocab.txt"
    dataset = VinTextOCRDataset(img_dir, label_dir, vocab_path)
    print("Samples:", len(dataset))
    img, label_ids = dataset[0]
    print("Image shape:", img.shape)
    print("Label ids:", label_ids)
    print("Decoded:", dataset.vocab.decode(label_ids.tolist()))

Samples: 24074
Image shape: torch.Size([3, 32, 128])
Label ids: tensor([ 1, 39, 44, 88, 56,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0])
Decoded: chất


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CNNBackbone(nn.Module):
    def __init__(self, out_channels=256):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(128, out_channels, 3, padding=1), nn.BatchNorm2d(out_channels), nn.ReLU(),
        )
    def forward(self, x):
        feat = self.conv_layers(x)
        return feat

In [11]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0))/d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe.unsqueeze(0)) # (1, max_len, d_model)
    def forward(self, x):
        # x: (B, seq_len, d_model)
        return x + self.pe[:, :x.size(1), :]

In [12]:
class OCRTransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=3, dropout=0.1, max_len=48):
        super().__init__()
        self.d_model = d_model
        self.token_embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = PositionalEncoding(d_model, max_len)
        decoder_layer = nn.TransformerDecoderLayer(d_model, nhead, d_model*2, dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers)
        self.output_proj = nn.Linear(d_model, vocab_size)
    def forward(self, memory, tgt, tgt_mask=None):
        tgt_emb = self.token_embed(tgt) * (self.d_model ** 0.5)
        tgt_emb = self.pos_embed(tgt_emb)
        tgt_emb = tgt_emb.transpose(0,1) # (tgt_len,B,E)
        memory = memory.transpose(0,1)
        output = self.transformer_decoder(tgt_emb, memory, tgt_mask=tgt_mask)
        output = output.transpose(0,1)   # (B, tgt_len, E)
        logits = self.output_proj(output)
        return logits

In [13]:
class OCRNet(nn.Module):
    def __init__(self, vocab_size, max_seq_len=48, d_model=256):
        super().__init__()
        self.cnn = CNNBackbone(out_channels=d_model)
        self.feat2seq = nn.Linear(d_model, d_model)
        self.decoder = OCRTransformerDecoder(vocab_size, d_model, nhead=4, num_layers=3, max_len=max_seq_len)
    def forward(self, images, tgt_input, tgt_mask=None):
        feat = self.cnn(images)    # (B, C, H', W')
        B, C, H, W = feat.shape
        feat = feat.permute(0,2,3,1).contiguous().view(B, H*W, C) # (B, src_len, C)
        feat = self.feat2seq(feat)
        logits = self.decoder(feat, tgt_input, tgt_mask)
        return logits

In [15]:
device = "cuda" if torch.cuda.is_available() else "cpu"
vocab_size = 100
model = OCRNet(vocab_size=vocab_size, max_seq_len=48).to(device)

images = torch.randn(2, 3, 32, 128).to(device)       # Đưa lên device
labels_in = torch.randint(0, vocab_size, (2, 48)).to(device)  # Đưa lên device

out = model(images, labels_in)
print(out.shape)  # kỳ vọng (2, 48, vocab_size)

torch.Size([2, 48, 100])


In [16]:
import torch.nn as nn

def ocr_loss_fn(logits, labels, pad_idx):
    """
    logits: (B, tgt_len, vocab_size)
    labels: (B, tgt_len)
    """
    logits = logits.view(-1, logits.size(-1))
    labels = labels.view(-1)
    loss = nn.CrossEntropyLoss(ignore_index=pad_idx)(logits, labels)
    return loss

In [17]:
def create_tgt_mask(tgt_input, pad_idx):
    # tgt_input: (B, tgt_len)
    B, tgt_len = tgt_input.shape
    # Mask phía phải (không cho nhìn trước)
    tgt_mask = torch.triu(torch.ones((tgt_len, tgt_len)), diagonal=1).bool().to(tgt_input.device)
    return tgt_mask  # (tgt_len, tgt_len)

In [18]:
def train_one_epoch(model, dataloader, optimizer, device, pad_idx):
    model.train()
    total_loss = 0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        # Input cho decoder là labels[:, :-1], target là labels[:, 1:]
        tgt_input = labels[:, :-1]
        tgt_out = labels[:, 1:]
        tgt_mask = create_tgt_mask(tgt_input, pad_idx)
        logits = model(images, tgt_input, tgt_mask)
        loss = ocr_loss_fn(logits, tgt_out, pad_idx)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(dataloader)

In [19]:
@torch.no_grad()
def evaluate(model, dataloader, device, pad_idx):
    model.eval()
    total_loss = 0
    total_acc = 0
    total_count = 0
    for images, labels in dataloader:
        images = images.to(device)
        labels = labels.to(device)
        tgt_input = labels[:, :-1]
        tgt_out = labels[:, 1:]
        tgt_mask = create_tgt_mask(tgt_input, pad_idx)
        logits = model(images, tgt_input, tgt_mask)
        loss = ocr_loss_fn(logits, tgt_out, pad_idx)
        total_loss += loss.item()
        # Đo char-level accuracy
        preds = logits.argmax(dim=-1)
        mask = (tgt_out != pad_idx)
        acc = ((preds == tgt_out) & mask).sum().item() / mask.sum().item()
        total_acc += acc
        total_count += 1
    return total_loss / total_count, total_acc / total_count

In [21]:
import torch.optim as optim
from torch.utils.data import DataLoader

# Giả sử bạn đã có dataset, vocab, model, pad_idx sẵn
pad_idx = dataset.vocab.pad_idx
batch_size = 64
num_epochs = 5
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tạo DataLoader (giả sử class VinTextOCRDataset đã được import)
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
# Nếu có tập validation, bạn cũng tạo val_loader tương tự

optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(num_epochs):
    train_loss = train_one_epoch(model, train_loader, optimizer, device, pad_idx)
    print(f"Epoch {epoch+1} | Train loss: {train_loss:.4f}")
    # Nếu có val_loader thì thêm đánh giá:
    # val_loss, val_acc = evaluate(model, val_loader, device, pad_idx)
    # print(f"Val loss: {val_loss:.4f} | Val acc: {val_acc:.4f}")

# Lưu model
torch.save(model.state_dict(), "/content/ocr_cnn_transformer.pth")

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
