In [17]:
import math
import os
import random
import sys
import warnings
from pathlib import Path
from typing import List, Optional

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from tqdm import tqdm

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Dataset / CascadeTabNet locations. The defaults are hard-coded Windows paths;
# set the TABLEBANK_ROOT / CASCADE_ROOT environment variables to point at a
# different checkout without editing this cell.
TABLEBANK_ROOT = Path(
    os.environ.get(
        "TABLEBANK_ROOT",
        r"C:\Users\ahmed\Dropbox\PC\Desktop\Ahmed Sajid\Office - NCV\NCV - HTR\TableBank\Recognition",
    )
)
IMAGES_DIR = TABLEBANK_ROOT / "Images"
ANNOT_DIR = TABLEBANK_ROOT / "Annotations"
SRC_FILE = ANNOT_DIR / "src-all_train.txt"  # one image reference per line
TGT_FILE = ANNOT_DIR / "tgt-all_train.txt"  # matching target string per line
CASCADE_ROOT = Path(
    os.environ.get(
        "CASCADE_ROOT",
        r"C:\Users\ahmed\Dropbox\PC\Desktop\Ahmed Sajid\Office - NCV\NCV - HTR\CascadeTabNet",
    )
)

# Training hyper-parameters.
SUBSET_SIZE = 256  # passed to the dataset as max_samples
BATCH_SIZE = 8
IMG_H = 64  # every image is resized to IMG_H x IMG_W
IMG_W = 512
EMBED_DIM = 512  # decoder width / encoder projection size
NUM_EPOCHS = 8
LEARNING_RATE = 1e-4
MAX_SEQ_LEN = 512  # cap on encoded target length, including <bos>/<eos>
MODEL_SAVE_PATH = Path.cwd() / "tsr_castabnet_tablebank.pth"
SEED = 42

# Seed Python's RNG (used by quick_test sampling) and torch's RNG (weight init,
# DataLoader shuffling, dropout).
random.seed(SEED)
torch.manual_seed(SEED)


def resolve_image_path(entry: str, images_dir: Path) -> Path:
    """Map one line of the ``src`` annotation file to an existing image file.

    Candidates are tried in order and the first one that is a regular file wins:

    1. ``entry`` itself, when it is an absolute path;
    2. ``images_dir / entry``;
    3. ``images_dir / entry`` with any leading ``./`` prefixes removed;
    4. ``images_dir / <basename of entry>`` (splitting on ``/`` and ``\\``).

    Args:
        entry: Raw line from the source file; surrounding whitespace is ignored.
        images_dir: Directory that relative entries are resolved against.

    Returns:
        The first candidate that is an existing file.

    Raises:
        FileNotFoundError: If no candidate exists (the message is ``entry``).
    """
    name = entry.strip()
    direct = Path(name)
    # is_file() rather than exists(): a blank entry would otherwise resolve to
    # images_dir itself, and directories are never valid images.
    if direct.is_absolute() and direct.is_file():
        return direct
    # Remove only literal "./" prefixes; str.lstrip("./") strips *characters*
    # and would turn ".hidden.png" into "hidden.png" or "../x" into "x".
    relative = name
    while relative.startswith("./"):
        relative = relative[2:]
    # Annotation files may use either separator, so split by hand instead of
    # relying on Path, whose parsing depends on the host OS.
    basename = name.replace("\\", "/").rsplit("/", 1)[-1]
    for candidate in (images_dir / name, images_dir / relative, images_dir / basename):
        if candidate.is_file():
            return candidate
    raise FileNotFoundError(entry)


class TableBankDataset(Dataset):
    """Image/target pairs read from TableBank's parallel ``src``/``tgt`` files.

    Line *i* of ``src_path`` names an image and line *i* of ``tgt_path`` holds
    its target string. Each sample is ``(image_tensor, target_ids, target_str,
    image_basename)``. Targets are tokenised one character at a time against a
    vocabulary built from the loaded targets; ids 0-3 are ``<pad>``, ``<bos>``,
    ``<eos>`` and ``<unk>``.
    """

    def __init__(
        self,
        src_path: Path,
        tgt_path: Path,
        images_dir: Path,
        max_samples: Optional[int] = None,
    ):
        """Pair the annotation files, build the transform and the vocabulary.

        Args:
            src_path: File with one image reference per line.
            tgt_path: File with one target string per line, aligned with
                ``src_path``.
            images_dir: Directory image references are resolved against (via
                ``resolve_image_path``).
            max_samples: Keep at most this many usable pairs, in file order;
                ``None`` keeps all of them.

        Pairs where either line is blank, or whose image cannot be found, are
        skipped.
        """
        with open(src_path, "r", encoding="utf-8") as f:
            srcs = [line.rstrip("\n") for line in f]
        with open(tgt_path, "r", encoding="utf-8") as f:
            tgts = [line.rstrip("\n") for line in f]

        pairs = []
        # zip stops at the end of the shorter file.
        for src, tgt in zip(srcs, tgts):
            if max_samples is not None and len(pairs) >= max_samples:
                # Stop early rather than resolving every remaining image path
                # only to throw the result away.
                break
            # Skip blank lines per *pair*: filtering each file on its own would
            # shift one file against the other and pair images with the wrong
            # targets.
            if not src.strip() or not tgt.strip():
                continue
            try:
                img_path = resolve_image_path(src, images_dir)
            except FileNotFoundError:
                continue
            pairs.append((str(img_path), tgt))
        self.pairs = pairs

        # NOTE(review): every image is resized to a fixed IMG_H x IMG_W
        # regardless of its aspect ratio. The normalisation constants look like
        # ImageNet statistics, presumably to match the pretrained backbone.
        self.transform = transforms.Compose(
            [
                transforms.Resize((IMG_H, IMG_W)),
                transforms.Grayscale(num_output_channels=3),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )

        # Character vocabulary, built only from the targets that were loaded.
        chars = set()
        for _, tgt in pairs:
            chars.update(tgt)
        self.vocab = ["<pad>", "<bos>", "<eos>", "<unk>"] + sorted(chars)
        self.stoi = {c: i for i, c in enumerate(self.vocab)}
        self.itos = self.vocab

    def encode_target(self, text: str) -> List[int]:
        """Encode ``text`` as ``[<bos>, char ids..., <eos>]``.

        Characters missing from the vocabulary map to ``<unk>``. At most
        ``MAX_SEQ_LEN - 2`` characters are kept so the result, including both
        markers, is no longer than ``MAX_SEQ_LEN``.
        NOTE(review): truncated targets lose their tail (including any closing
        tags) with no warning.
        """
        seq = [self.stoi["<bos>"]]
        for ch in text:
            seq.append(self.stoi.get(ch, self.stoi["<unk>"]))
            if len(seq) >= MAX_SEQ_LEN - 1:
                break
        seq.append(self.stoi["<eos>"])
        return seq

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx: int):
        """Return ``(image_tensor, target_ids, target_str, image_basename)``."""
        img_path, tgt = self.pairs[idx]
        # Close the file handle deterministically once the pixels are decoded.
        with Image.open(img_path) as im:
            img = im.convert("RGB")
        img = self.transform(img)
        tgt_tensor = torch.tensor(self.encode_target(tgt), dtype=torch.long)
        return img, tgt_tensor, tgt, os.path.basename(img_path)


def collate_fn(batch):
    """Merge a list of dataset samples into one training batch.

    Returns ``(images, padded_targets, lengths, raw_targets, names)``:

    * ``images`` -- the image tensors stacked along a new leading batch dim;
    * ``padded_targets`` -- a (batch, longest) long tensor, right-padded with
      id 0 (the ``<pad>`` id);
    * ``lengths`` -- list of the unpadded target lengths;
    * ``raw_targets`` / ``names`` -- tuples passed through unchanged.
    """
    images, targets, raw_texts, file_names = zip(*batch)
    stacked = torch.stack(images)
    target_lengths = [len(seq) for seq in targets]
    padded = nn.utils.rnn.pad_sequence(
        list(targets), batch_first=True, padding_value=0
    ).to(torch.long)
    return stacked, padded, target_lengths, raw_texts, file_names


def try_add_cascade_path(root: Path):
    """Put a CascadeTabNet checkout at the front of ``sys.path``.

    After the call, ``root / "Table Structure Recognition"`` is ``sys.path[0]``
    and ``root`` is ``sys.path[1]``. Existing copies of both entries are
    removed first, so calling this repeatedly (e.g. re-running a notebook
    cell) does not keep growing ``sys.path``.
    """
    # Root first, sub-directory second: each insert(0, ...) lands in front of
    # the previous one, leaving the sub-directory at the very front.
    for entry in (str(root), str(root / "Table Structure Recognition")):
        while entry in sys.path:
            sys.path.remove(entry)
        sys.path.insert(0, entry)


# Make the local CascadeTabNet checkout importable, then try to import it.
# The import is best-effort: any failure leaves CASCADE_AVAILABLE False, keeps
# the exception in CASCADE_IMPORT_ERROR and emits a warning instead of being
# silently swallowed.
try_add_cascade_path(CASCADE_ROOT)
CASCADE_IMPORT_ERROR = None
try:
    from CascadeTabNet.model import CascadeTabNet  # type: ignore
except Exception as exc:  # deliberately broad: optional dependency
    CASCADE_AVAILABLE = False
    CASCADE_IMPORT_ERROR = exc
    warnings.warn(f"CascadeTabNet could not be imported; continuing without it: {exc!r}")
else:
    CASCADE_AVAILABLE = True


class EncoderWrapper(nn.Module):
    """ResNet-18 backbone that turns an image batch into a feature sequence.

    The backbone's last two children (global pooling and the classifier) are
    dropped, so the spatial feature map is kept. Each spatial cell becomes one
    memory vector, linearly projected to ``out_dim``.
    NOTE(review): no positional encoding is added to this memory, so
    attention over it cannot tell spatial cells apart by location.
    """

    def __init__(self, out_dim: int, pretrained: bool = True):
        """Build the encoder.

        Args:
            out_dim: Size of each output feature vector.
            pretrained: Load torchvision's default pretrained ResNet-18
                weights. Pass ``False`` when the weights will be overwritten
                anyway (e.g. before ``load_state_dict``) to skip the download.
        """
        super().__init__()
        if hasattr(models, "ResNet18_Weights"):
            # torchvision >= 0.13: the ``pretrained`` flag is deprecated.
            weights = models.ResNet18_Weights.DEFAULT if pretrained else None
            backbone = models.resnet18(weights=weights)
        else:
            backbone = models.resnet18(pretrained=pretrained)
        # Drop the trailing avgpool + fc so the spatial map survives.
        self.encoder = nn.Sequential(*list(backbone.children())[:-2])
        # fc.in_features is the backbone's output channel count (512 here).
        self.proj = nn.Linear(backbone.fc.in_features, out_dim)

    def forward(self, x):
        """Encode an image batch into a sequence-first memory tensor.

        Args:
            x: Image batch, presumably (B, 3, H, W) as produced by the dataset
                transform -- TODO confirm for other callers.

        Returns:
            Tensor of shape (h * w, B, out_dim), where h x w is the spatial size
            of the backbone feature map.
        """
        feats = self.encoder(x)  # (B, C, h, w)
        seq = feats.flatten(2).transpose(1, 2)  # (B, h * w, C)
        return self.proj(seq).transpose(0, 1)  # (h * w, B, out_dim)


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    A table for ``max_len`` positions is precomputed and registered as the
    buffer ``pe`` with shape (max_len, 1, d_model), so it broadcasts over the
    batch dimension of an input shaped (seq_len, batch, d_model). Even columns
    hold sines and odd columns cosines; an odd ``d_model`` is supported.
    """

    def __init__(self, d_model, max_len=2000):
        super().__init__()
        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        inv_freq = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = positions * inv_freq  # (max_len, ceil(d_model / 2))
        table = torch.zeros(max_len, d_model)
        table[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column.
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", table.unsqueeze(1))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(0)`` positions."""
        seq_len = x.size(0)
        return x + self.pe[:seq_len]


def generate_mask(sz: int):
    """Return an additive causal (look-ahead) mask of shape (sz, sz).

    Entry (i, j) is 0.0 when j <= i (position i may attend to j) and -inf
    otherwise; the result is float32, the additive mask form accepted as
    ``tgt_mask`` by ``nn.TransformerDecoder``.
    """
    # Start from all -inf, then keep only the strictly upper triangle; triu
    # zeroes the diagonal and everything below it.
    blocked = torch.full((sz, sz), float("-inf"), dtype=torch.float32)
    return torch.triu(blocked, diagonal=1)


class Seq2SeqModel(nn.Module):
    """Image-to-sequence model: ResNet feature encoder + Transformer decoder.

    ``forward`` performs teacher-forced decoding. It takes an image batch and
    (B, T) input token ids and returns (B, T, vocab_size) logits, with a causal
    mask so position t only attends to input positions <= t.
    """

    def __init__(self, vocab_size: int, embed_dim: int):
        super().__init__()
        self.embed_dim = embed_dim
        # Submodule creation order fixes the random-initialisation sequence
        # and the state_dict key order, so keep it stable.
        self.encoder = EncoderWrapper(embed_dim)
        self.token_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_enc = PositionalEncoding(embed_dim)
        layer = nn.TransformerDecoderLayer(d_model=embed_dim, nhead=8)
        self.decoder = nn.TransformerDecoder(layer, num_layers=3)
        self.output_fc = nn.Linear(embed_dim, vocab_size)

    def _embed_targets(self, tgt_seq):
        """Map (B, T) token ids to scaled, position-encoded (T, B, D) embeddings."""
        scaled = self.token_emb(tgt_seq) * math.sqrt(self.embed_dim)
        return self.pos_enc(scaled.transpose(0, 1))

    def forward(self, imgs, tgt_seq):
        memory = self.encoder(imgs)
        queries = self._embed_targets(tgt_seq)
        causal_mask = generate_mask(queries.size(0)).to(queries.device)
        hidden = self.decoder(queries, memory, tgt_mask=causal_mask)
        # Back to batch-first: (T, B, V) -> (B, T, V).
        return self.output_fc(hidden).transpose(0, 1)


def train_model():
    """Train a ``Seq2SeqModel`` on a TableBank subset and save a checkpoint.

    Reads its settings from the module-level constants (SRC_FILE, TGT_FILE,
    IMAGES_DIR, SUBSET_SIZE, BATCH_SIZE, EMBED_DIM, NUM_EPOCHS, LEARNING_RATE,
    DEVICE). Training is teacher-forced: the decoder input is each target
    without its last token, and the loss is computed against the target
    without its first (``<bos>``) token, ignoring padding. After the final
    epoch the weights and vocabulary are written to MODEL_SAVE_PATH.

    Returns:
        ``(model, dataset)``: the trained model (on DEVICE) and the dataset
        whose vocabulary it uses.
    """
    dataset = TableBankDataset(SRC_FILE, TGT_FILE, IMAGES_DIR, max_samples=SUBSET_SIZE)
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    model = Seq2SeqModel(len(dataset.vocab), EMBED_DIM).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    # collate_fn pads with id 0 (the <pad> id); those positions carry no loss.
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        loop = tqdm(loader, desc=f"Epoch {epoch}", leave=True)
        total_loss = 0.0
        for step, (imgs, tgt_padded, _, _, _) in enumerate(loop, start=1):
            imgs = imgs.to(DEVICE)
            tgt_in = tgt_padded[:, :-1].to(DEVICE)
            tgt_out = tgt_padded[:, 1:].to(DEVICE)
            optimizer.zero_grad()
            logits = model(imgs, tgt_in)
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()
            total_loss += loss.item()
            # Running mean over the batches seen so far. Count them explicitly:
            # tqdm only syncs ``loop.n`` on display refreshes, so it can lag.
            loop.set_postfix(loss=total_loss / step)
    torch.save({"model_state_dict": model.state_dict(), "vocab": dataset.vocab}, MODEL_SAVE_PATH)
    return model, dataset


def greedy_decode(model, img_tensor, stoi, itos, max_len=MAX_SEQ_LEN):
    """Decode one image into a target string by greedy (argmax) search.

    Args:
        model: A ``Seq2SeqModel`` (anything exposing ``encoder``,
            ``token_emb``, ``pos_enc``, ``decoder``, ``output_fc`` and
            ``embed_dim``).
        img_tensor: A single un-batched image tensor, e.g. one dataset sample.
        stoi: Token -> id map containing ``<bos>`` and ``<eos>`` (``<pad>`` is
            looked up too, defaulting to id 0).
        itos: Id -> token list used to turn ids back into text.
        max_len: Maximum number of tokens generated after ``<bos>``.

    Returns:
        The generated tokens joined with no separator, with ``<bos>``,
        ``<eos>`` and ``<pad>`` ids removed.

    The model runs in eval mode during decoding; its previous train/eval mode
    is restored afterwards.
    """
    bos = stoi["<bos>"]
    eos = stoi["<eos>"]
    pad = stoi.get("<pad>", 0)
    # Decode on whichever device holds the model rather than assuming DEVICE.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE

    was_training = model.training
    model.eval()
    ys = torch.tensor([[bos]], device=device, dtype=torch.long)
    try:
        with torch.no_grad():
            memory = model.encoder(img_tensor.unsqueeze(0).to(device))
            for _ in range(max_len):
                # Re-run the decoder over the whole prefix each step (no cache).
                emb = model.token_emb(ys) * math.sqrt(model.embed_dim)
                emb_t = model.pos_enc(emb.permute(1, 0, 2))
                tgt_mask = generate_mask(emb_t.size(0)).to(device)
                out = model.decoder(emb_t, memory, tgt_mask=tgt_mask)
                next_tok = model.output_fc(out[-1, 0, :]).argmax(dim=-1)
                ys = torch.cat([ys, next_tok.view(1, 1)], dim=1)
                if next_tok.item() == eos:
                    break
    finally:
        model.train(was_training)

    special = {bos, eos, pad}
    tokens = [t for t in ys.squeeze(0).tolist() if t not in special]
    return "".join(itos[t] if t < len(itos) else "<unk>" for t in tokens)


def quick_test(model, dataset, n=5):
    """Print greedy predictions next to ground truth for up to ``n`` samples.

    Samples are drawn without replacement using the module-level ``random``
    state. Each image is loaded from ``dataset.pairs`` and preprocessed with
    the dataset's own ``transform``, so evaluation uses exactly the training
    preprocessing. A non-positive ``n`` prints nothing.
    """
    tf = dataset.transform
    k = max(0, min(n, len(dataset)))
    for i in random.sample(range(len(dataset)), k):
        img_path, tgt = dataset.pairs[i]
        with Image.open(img_path) as im:
            img_t = tf(im.convert("RGB"))
        pred = greedy_decode(model, img_t, dataset.stoi, dataset.itos)
        print("IMAGE:", os.path.basename(img_path))
        print("PRED :", pred)
        print("GT   :", tgt)
        print("-" * 80)


if __name__ == "__main__":
    # Train on the configured subset (train_model also saves the checkpoint),
    # then print greedy predictions next to the ground truth for 8 randomly
    # chosen samples of that same training subset. `model` and `ds` are left
    # bound at module level.
    model, ds = train_model()
    quick_test(model, ds, n=8)


Epoch 1: 100%|██████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  8.78it/s, loss=0.54]
Epoch 2: 100%|█████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  9.73it/s, loss=0.293]
Epoch 3: 100%|█████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  9.85it/s, loss=0.265]
Epoch 4: 100%|█████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  9.73it/s, loss=0.239]
Epoch 5: 100%|█████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  9.91it/s, loss=0.247]
Epoch 6: 100%|█████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  9.73it/s, loss=0.222]
Epoch 7: 100%|█████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  9.72it/s, loss=0.216]
Epoch 8: 100%|█████████████████████████████████████████████████████████████| 32/32 [00:03<00:00,  9.77it/s, loss=0.205]


IMAGE: 1712.00077.table_0.png
PRED : <table> <thead> <tr> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> </tr> <tr> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> 
GT   : <table> <thead> <tr> <td> <td> </tr> </thead> <tbody> <tr> <td> <td> </tr> <tr> <td> <td> </tr> <tr> <td> <td> </tr> <tr> <td> <td> </tr> <tr> <td> <td> </tr> </tbody> </table>
--------------------------------------------------------------------------------
IMAGE: 1809.07129.table_1.png
PRED : <table> <thead> <tr> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <td> <t