In [None]:
# Cell 1: Setup
!wget -O data.tar.gz "https://aistages-api-public-prod.s3.amazonaws.com/app/Competitions/000377/data/data.tar.gz"
!tar -xzf data.tar.gz
!pip install -q segmentation-models-pytorch albumentations opencv-python-headless pyclipper shapely

In [None]:
# Cell 2: Imports + Config
import os, json, cv2, gc
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
from torchvision.models import resnet50, ResNet50_Weights
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import pyclipper
from shapely.geometry import Polygon as ShapelyPolygon

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True

# Seed the RNGs the pipeline draws from (Python `random`, NumPy, torch) so weight init
# and data order are repeatable across Restart & Run All. With cudnn.benchmark=True
# above, kernel selection may still make runs differ slightly.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

BASE = './data/datasets'
TRAIN_IMG = os.path.join(BASE, 'images/train')
VAL_IMG = os.path.join(BASE, 'images/val')
TEST_IMG = os.path.join(BASE, 'images/test')
TRAIN_JSON = os.path.join(BASE, 'jsons/train.json')
VAL_JSON = os.path.join(BASE, 'jsons/val.json')
TEST_JSON = os.path.join(BASE, 'jsons/test.json')
SAMPLE_SUB = os.path.join(BASE, 'sample_submission.csv')

DEVICE = torch.device('cuda')
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

SZ = 1024      # square train/val input size (A.Resize(SZ, SZ))
BS = 4         # per-step batch size
ACCUM = 8      # gradient-accumulation steps; effective batch = BS * ACCUM
EPOCHS = 25
LR = 1e-3      # AdamW lr and OneCycleLR max_lr
SHRINK = 0.4   # DB shrink ratio for the probability GT (passed to shrink_poly)
K_AMP = 50     # DB amplification factor k in sigmoid(k * (P - T))

In [None]:
# Cell 3: DBNet Model (ResNet50 + FPN + Differentiable Binarization)
class DBNet(nn.Module):
    """DBNet text detector: ResNet-50 backbone, FPN neck, differentiable-binarization heads.

    forward(x) returns (prob_logits, prob, thresh, binary). Each is a single-channel map
    upsampled bilinearly to x's spatial size. `binary` is sigmoid(K_AMP * (prob - thresh)).
    """

    def __init__(self):
        super().__init__()
        backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
        # Backbone stages (out channels, stride): layer1 256 /4, layer2 512 /8,
        # layer3 1024 /16, layer4 2048 /32.
        self.stem = nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool)
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

        neck_ch = 256
        # 1x1 lateral projections to a common width. Registration order (deepest first)
        # is kept so parameter init consumes the RNG in the same order.
        self.lat4 = nn.Conv2d(2048, neck_ch, 1)
        self.lat3 = nn.Conv2d(1024, neck_ch, 1)
        self.lat2 = nn.Conv2d(512, neck_ch, 1)
        self.lat1 = nn.Conv2d(256, neck_ch, 1)

        # 3x3 smoothing convs; four 64-channel outputs concatenate back to neck_ch.
        self.sm4 = nn.Conv2d(neck_ch, 64, 3, padding=1)
        self.sm3 = nn.Conv2d(neck_ch, 64, 3, padding=1)
        self.sm2 = nn.Conv2d(neck_ch, 64, 3, padding=1)
        self.sm1 = nn.Conv2d(neck_ch, 64, 3, padding=1)

        self.prob_head = self._make_head(neck_ch)
        self.thresh_head = self._make_head(neck_ch)

    @staticmethod
    def _make_head(in_ch):
        """Conv3x3-BN-ReLU followed by a 1x1 conv to one output channel (logits)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, 64, 3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 1, 1),
        )

    def forward(self, x):
        out_hw = x.shape[2:]
        c1 = self.layer1(self.stem(x))
        c2 = self.layer2(c1)
        c3 = self.layer3(c2)
        c4 = self.layer4(c3)

        # Top-down pathway: add each lateral projection to the upsampled coarser level.
        p4 = self.lat4(c4)
        p3 = self.lat3(c3) + F.interpolate(p4, size=c3.shape[2:], mode='nearest')
        p2 = self.lat2(c2) + F.interpolate(p3, size=c2.shape[2:], mode='nearest')
        p1 = self.lat1(c1) + F.interpolate(p2, size=c1.shape[2:], mode='nearest')

        finest_hw = c1.shape[2:]

        def to_finest(feat):
            return F.interpolate(feat, size=finest_hw, mode='nearest')

        fused = torch.cat(
            [self.sm1(p1), self.sm2(to_finest(p2)), self.sm3(to_finest(p3)), self.sm4(to_finest(p4))],
            dim=1,
        )

        def head_at_input_res(head):
            return F.interpolate(head(fused), out_hw, mode='bilinear', align_corners=False)

        prob_logits = head_at_input_res(self.prob_head)
        thresh_logits = head_at_input_res(self.thresh_head)

        prob, thresh = torch.sigmoid(prob_logits), torch.sigmoid(thresh_logits)
        # Differentiable binarization: a steep sigmoid approximating (prob > thresh).
        binary = torch.sigmoid(K_AMP * (prob - thresh))
        return prob_logits, prob, thresh, binary

# Build the detector on the configured device and report its size in millions of parameters.
model = DBNet().to(DEVICE)
print(f"Params: {sum(map(torch.numel, model.parameters())) / 1e6:.1f}M")

In [None]:
# Cell 4: GT Generation + Dataset

def shrink_poly(polygon, ratio=0.4):
    """Shrink a polygon inward by the DB offset D = A * (1 - ratio**2) / L.

    Args:
        polygon: sequence of (x, y) points; coordinates are truncated to int for pyclipper.
        ratio: shrink ratio r; a smaller r gives a larger offset (more shrinking).

    Returns:
        An (N, 2) int32 array holding the first polygon pyclipper returns, or None when the
        polygon is degenerate (area < 1), the inward offset erases it, or shapely/pyclipper
        raise on the input.
    """
    try:
        poly = ShapelyPolygon(polygon)
        if not poly.is_valid:
            # buffer(0) repairs self-intersections so area/length are meaningful.
            poly = poly.buffer(0)
        if poly.area < 1:
            return None
        offset = poly.area * (1 - ratio ** 2) / (poly.length + 1e-6)
        pco = pyclipper.PyclipperOffset()
        pco.AddPath([(int(p[0]), int(p[1])) for p in polygon],
                    pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        shrunk = pco.Execute(int(-offset))
        if not shrunk:
            return None
        return np.array(shrunk[0], dtype=np.int32)
    except Exception:
        # Best-effort: malformed geometry means "no shrunk region" rather than a crashed
        # DataLoader worker. Narrowed from a bare `except:` so KeyboardInterrupt and
        # SystemExit still propagate.
        return None

def expand_poly(polygon, ratio=1.5):
    """Grow a polygon outward by D = A * (1 - 1/ratio**2) / L pixels.

    NOTE(review): this is not DBNet's usual unclip distance (D = A * ratio / L). At
    ratio=1.5 it offsets by about 0.56 * A / L instead of 1.5 * A / L, so predicted
    boxes stay tighter. Presumably this was tuned empirically; confirm before changing.

    Args:
        polygon: sequence/array of (x, y) points; coordinates are truncated to int.
        ratio: expansion ratio; larger values expand more.

    Returns:
        An (N, 2) int32 array holding the first expanded polygon pyclipper returns.
        The input `polygon` object is returned unchanged when its area is < 1, the
        offset yields nothing, or shapely/pyclipper raise on the input.
    """
    try:
        poly = ShapelyPolygon(polygon)
        if not poly.is_valid:
            # buffer(0) repairs self-intersections so area/length are meaningful.
            poly = poly.buffer(0)
        if poly.area < 1:
            return polygon
        offset = poly.area * (1 - 1 / ratio ** 2) / (poly.length + 1e-6)
        pco = pyclipper.PyclipperOffset()
        pco.AddPath([(int(p[0]), int(p[1])) for p in polygon],
                    pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
        expanded = pco.Execute(int(offset))
        if not expanded:
            return polygon
        return np.array(expanded[0], dtype=np.int32)
    except Exception:
        # Best-effort fallback to the unexpanded polygon. Narrowed from a bare `except:`
        # so KeyboardInterrupt and SystemExit still propagate.
        return polygon

def make_gt(h, w, polygons):
    """Rasterize the four DB training targets for one image.

    Args:
        h, w: output map height and width in pixels.
        polygons: iterable of point lists; entries with fewer than 3 points are skipped.

    Returns:
        (gt_prob, gt_binary, gt_thresh, gt_thresh_mask), each an (h, w) float32 array:
        - gt_prob: 1 inside each polygon after shrink_poly(.., SHRINK), else 0.
        - gt_binary: 1 inside the original polygons, else 0.
        - gt_thresh: 1 - min(dist / 8, 1), where dist is the distance to the nearest
          pixel of the opposite class (text vs. background); peaks at text borders.
        - gt_thresh_mask: 1 where dist < 8, else 0.
    """
    border_px = 8.0
    gt_prob = np.zeros((h, w), dtype=np.float32)
    gt_binary = np.zeros((h, w), dtype=np.float32)

    for raw in polygons:
        pts = np.array(raw, dtype=np.int32)
        if len(pts) < 3:
            continue
        cv2.fillPoly(gt_binary, [pts], 1.0)
        shrunk = shrink_poly(pts, SHRINK)
        if shrunk is None or len(shrunk) < 3:
            continue
        cv2.fillPoly(gt_prob, [shrunk], 1.0)

    text_mask = gt_binary.astype(np.uint8)
    # Background pixels: distance to the nearest text pixel.
    dist_outside = cv2.distanceTransform(1 - text_mask, cv2.DIST_L2, 5)
    # Text pixels: distance to the nearest background pixel.
    dist_inside = cv2.distanceTransform(text_mask, cv2.DIST_L2, 5)
    dist = np.where(text_mask > 0, dist_inside, dist_outside)

    gt_thresh = (1.0 - np.clip(dist / border_px, 0, 1)).astype(np.float32)
    gt_thresh_mask = (dist < border_px).astype(np.float32)
    return gt_prob, gt_binary, gt_thresh, gt_thresh_mask

class OCRDataset(Dataset):
    """Images listed in a JSON index shaped like {'images': {filename: {...}}}.

    Train/val mode items: (image, gt_prob, gt_binary, gt_thresh, gt_thresh_mask), with
    the GT maps built by make_gt from each image's 'words' -> 'points' polygons.
    Test mode items: (image, filename, (orig_h, orig_w)).
    """

    def __init__(self, img_dir, json_path, transform=None, is_test=False):
        self.img_dir = img_dir
        self.transform = transform
        self.is_test = is_test
        with open(json_path, 'r') as fh:
            index = json.load(fh)
        self.data = index['images']
        self.names = list(self.data.keys())

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        name = self.names[idx]
        bgr = cv2.imread(os.path.join(self.img_dir, name))
        if bgr is None:
            # Unreadable or missing file: substitute the next index (wrapping around).
            return self.__getitem__((idx + 1) % len(self))
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]

        if self.is_test:
            if self.transform:
                img = self.transform(image=img)['image']
            return img, name, (h, w)

        words = self.data[name].get('words', {})
        polygons = [
            entry['points']
            for entry in words.values()
            if len(entry.get('points', [])) >= 3
        ]
        targets = make_gt(h, w, polygons)

        if self.transform:
            # The GT maps go through the same pipeline as `masks` so spatial
            # augmentations stay aligned with the image.
            out = self.transform(image=img, masks=list(targets))
            img = out['image']
            targets = out['masks']

        gt_p, gt_b, gt_t, gt_tm = targets
        return img, gt_p, gt_b, gt_t, gt_tm

# Training pipeline: resize to SZ x SZ, then light geometric and photometric jitter.
# OCRDataset passes its GT maps as `masks=`, so the spatial transforms move them with the image.
train_tf = A.Compose([
    A.Resize(SZ, SZ),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(
        shift_limit=0.05,
        scale_limit=0.1,
        rotate_limit=10,
        border_mode=cv2.BORDER_CONSTANT,
        p=0.4,
    ),
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
    A.Normalize(),
    ToTensorV2(),
])

# Validation pipeline: deterministic resize + normalization only.
val_tf = A.Compose([
    A.Resize(SZ, SZ),
    A.Normalize(),
    ToTensorV2(),
])

print("Dataset ready")

In [None]:
# Cell 5: Loss + Training

def dice_loss(pred, target, smooth=1.0):
    """Soft Dice loss pooled over every element (batch-global, not per-sample).

    Returns 1 - (2 * sum(pred * target) + smooth) / (sum(pred) + sum(target) + smooth)
    as a 0-d tensor.
    """
    flat_pred = pred.flatten()
    flat_target = target.flatten()
    overlap = torch.sum(flat_pred * flat_target)
    denom = torch.sum(flat_pred) + torch.sum(flat_target) + smooth
    return 1 - (2 * overlap + smooth) / denom

def db_loss(prob_logits, prob, thresh, binary, gt_p, gt_b, gt_t, gt_tm):
    """DBNet loss: (BCE + Dice) on prob, Dice on binary, plus 10 x masked L1 on thresh.

    Args:
        prob_logits, prob, thresh, binary: model outputs, each (B, 1, H, W).
        gt_p, gt_b, gt_t, gt_tm: targets without a channel dim, each (B, H, W). They are
            the shrunk-text prob map, the full-text binary map, the threshold map and
            the threshold-loss mask.

    Returns:
        Scalar loss tensor.
    """
    # Add the channel dim so targets line up with the (B, 1, H, W) predictions.
    gt_p, gt_b, gt_t, gt_tm = (g.unsqueeze(1) for g in (gt_p, gt_b, gt_t, gt_tm))

    # Prob loss: BCE on logits (autocast-safe) + Dice on probabilities.
    l_prob = F.binary_cross_entropy_with_logits(prob_logits, gt_p) + dice_loss(prob, gt_p)

    # Binary loss: Dice on the differentiable-binarization output.
    l_binary = dice_loss(binary, gt_b)

    # Thresh loss: mean L1 inside the mask. There is deliberately no `if gt_tm.sum() > 0`
    # branch: that Python-side check forces a GPU->CPU sync on every step. With an empty
    # mask the numerator is exactly 0, so 0 / 1e-6 == 0, the same value such a branch
    # would return.
    l_thresh = (torch.abs(thresh - gt_t) * gt_tm).sum() / (gt_tm.sum() + 1e-6)

    return l_prob + l_binary + 10.0 * l_thresh

def _evaluate(val_dl):
    """Return the mean DB loss of the global `model` over `val_dl` (eval mode, no grad, bf16)."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for imgs, gp, gb, gt, gtm in val_dl:
            imgs = imgs.to(DEVICE)
            gp, gb, gt, gtm = gp.to(DEVICE), gb.to(DEVICE), gt.to(DEVICE), gtm.to(DEVICE)
            with autocast('cuda', dtype=torch.bfloat16):
                pl, p, th, bn = model(imgs)
                total += db_loss(pl, p, th, bn, gp, gb, gt, gtm).item()
    return total / len(val_dl)


def train():
    """Train the global `model` with db_loss and keep the best-validation checkpoint.

    Setup:
    - AdamW (weight_decay=1e-2) with OneCycleLR (max_lr=LR, pct_start=0.1).
    - bf16 autocast, gradient accumulation over ACCUM batches, gradient clipping at 1.0.
    - The scheduler is stepped once per optimizer update.

    After each epoch the mean validation loss is computed. Whenever it improves, the
    weights are saved to 'dbnet_best.pth'.

    Raises:
        ValueError: if the training loader yields no full batch (drop_last=True).
    """
    torch.cuda.empty_cache(); gc.collect()

    train_ds = OCRDataset(TRAIN_IMG, TRAIN_JSON, train_tf)
    val_ds = OCRDataset(VAL_IMG, VAL_JSON, val_tf)
    train_dl = DataLoader(train_ds, BS, shuffle=True, num_workers=4,
                          pin_memory=True, drop_last=True, persistent_workers=True)
    val_dl = DataLoader(val_ds, BS, shuffle=False, num_workers=4,
                        pin_memory=True, persistent_workers=True)

    n_batches = len(train_dl)
    if n_batches == 0:
        raise ValueError(f"Training set yields no full batch of size {BS} (drop_last=True)")
    # One optimizer update per ACCUM batches, plus one for a trailing partial group.
    # No batch's gradients are discarded, and OneCycleLR's step budget matches exactly
    # (it would be 0, and crash, if n_batches < ACCUM with floor division).
    updates_per_epoch = (n_batches + ACCUM - 1) // ACCUM

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-2)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR, epochs=EPOCHS,
        steps_per_epoch=updates_per_epoch, pct_start=0.1)
    # NOTE(review): loss scaling targets fp16 underflow; with bf16 autocast it is
    # likely redundant but kept for parity.
    scaler = GradScaler('cuda')
    best = float('inf')

    print(f"=== DBNet Training ===")
    print(f"    ResNet50 + FPN + DB Head, {SZ}x{SZ}")
    print(f"    Batch {BS} x Accum {ACCUM} = {BS*ACCUM}")

    for ep in range(1, EPOCHS + 1):
        model.train()
        t_loss = 0.0
        optimizer.zero_grad(set_to_none=True)

        for i, (imgs, gp, gb, gt, gtm) in enumerate(tqdm(train_dl, desc=f"E{ep}")):
            imgs = imgs.to(DEVICE, non_blocking=True)
            gp = gp.to(DEVICE, non_blocking=True)
            gb = gb.to(DEVICE, non_blocking=True)
            gt = gt.to(DEVICE, non_blocking=True)
            gtm = gtm.to(DEVICE, non_blocking=True)

            with autocast('cuda', dtype=torch.bfloat16):
                pl, p, th, bn = model(imgs)
                loss = db_loss(pl, p, th, bn, gp, gb, gt, gtm) / ACCUM

            scaler.scale(loss).backward()
            # Update every ACCUM batches and on the epoch's last batch, so a trailing
            # partial accumulation group is applied instead of being zeroed next epoch.
            if (i + 1) % ACCUM == 0 or (i + 1) == n_batches:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer); scaler.update()
                optimizer.zero_grad(set_to_none=True)
                scheduler.step()
            t_loss += loss.item() * ACCUM

        avg_t = t_loss / n_batches
        avg_v = _evaluate(val_dl)
        print(f"E{ep}: train={avg_t:.4f} val={avg_v:.4f} lr={optimizer.param_groups[0]['lr']:.6f}")

        if avg_v < best:
            best = avg_v
            torch.save(model.state_dict(), 'dbnet_best.pth')
            print(f"  -> Saved (best={best:.4f})")

        mem = torch.cuda.max_memory_allocated()/1e9
        print(f"  -> Peak VRAM: {mem:.1f}GB")
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

    print(f"Done. Best val={best:.4f}")

# Run training; writes 'dbnet_best.pth' whenever the validation loss improves.
train()

In [None]:
# Cell 6: Inference + Submission

def _predict_prob_map(img, scales, normalize):
    """Average the model's probability map over `scales` and horizontal-flip TTA.

    `img` is an RGB image as read by cv2 (H, W, 3). Each scale resizes it to a square
    sc x sc input. Returns an (H, W) float32 map at the original resolution. The caller
    is expected to run this under torch.no_grad().
    """
    oh, ow = img.shape[:2]
    acc = np.zeros((oh, ow), dtype=np.float32)
    for sc in scales:
        resized = cv2.resize(img, (sc, sc))
        normed = normalize(image=resized)['image']
        t = torch.from_numpy(normed.transpose(2, 0, 1)).float().unsqueeze(0).to(DEVICE)

        with autocast('cuda', dtype=torch.bfloat16):
            _, p1, _, _ = model(t)
            _, p2, _, _ = model(torch.flip(t, [3]))
            p2 = torch.flip(p2, [3])  # un-flip so it aligns with p1

        avg = ((p1 + p2) / 2).float().cpu().numpy()[0, 0]
        acc += cv2.resize(avg, (ow, oh))
    return acc / len(scales)


def _extract_polygons(prob_map, thresh=0.6, min_area=80, min_side=4, unclip_ratio=1.5):
    """Convert a probability map into expanded text polygons in pixel coordinates.

    Steps:
    - Threshold the map and apply a 3x3 morphological opening.
    - Keep external contours with area >= min_area and a min-area-rect short side >= min_side.
    - Simplify each contour with approxPolyDP, falling back to its rotated box when fewer
      than 4 vertices remain.
    - Expand with expand_poly(ratio=unclip_ratio) and clamp vertices to the image.

    Returns a list of polygons, each a list of [x, y] int pairs.
    """
    oh, ow = prob_map.shape
    binary = (prob_map > thresh).astype(np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3))
    binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=1)

    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    polys = []
    for cnt in contours:
        if cv2.contourArea(cnt) < min_area:
            continue
        rect = cv2.minAreaRect(cnt)
        if min(rect[1]) < min_side:
            continue

        eps = 0.01 * cv2.arcLength(cnt, True)
        approx = cv2.approxPolyDP(cnt, eps, True)
        if len(approx) >= 4:
            pts = approx.reshape(-1, 2)
        else:
            pts = np.int32(cv2.boxPoints(rect))

        # Expand polygon (recover from shrunk training targets).
        expanded = expand_poly(pts, ratio=unclip_ratio)
        if len(expanded) < 4:
            continue
        # Unclipping and boxPoints can place vertices outside the image (e.g. negative
        # coordinates); clamp each to valid pixel coordinates.
        clipped = np.clip(np.asarray(expanded), 0, [ow - 1, oh - 1])
        polys.append(clipped.tolist())
    return polys


def inference():
    """Predict text polygons for every test image and write 'submission_dbnet.csv'.

    Loads 'dbnet_best.pth' into the global `model` and runs multi-scale (896/1024/1152)
    plus horizontal-flip TTA. Each image's polygons are serialized as
    "x1 y1 x2 y2 ...|x1 y1 ...". Images that fail to load get an empty string.
    """
    print("=== DBNet Inference (Multi-Scale TTA) ===")
    # The checkpoint is a plain state_dict, so refuse arbitrary pickled objects.
    model.load_state_dict(torch.load('dbnet_best.pth', map_location=DEVICE, weights_only=True))
    model.eval()

    with open(TEST_JSON, 'r') as f:
        test_data = json.load(f)['images']
    names = list(test_data.keys())
    normalize = A.Normalize()

    scales = [896, 1024, 1152]
    predictions = {}

    with torch.no_grad():
        for name in tqdm(names, desc="Inference"):
            img = cv2.imread(os.path.join(TEST_IMG, name))
            if img is None:
                predictions[name] = ""
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            prob_map = _predict_prob_map(img, scales, normalize)
            polys = _extract_polygons(prob_map)
            # An empty polygon list joins to "", matching the no-detection case.
            predictions[name] = "|".join(
                " ".join(f"{int(x)} {int(y)}" for x, y in poly) for poly in polys
            )

    df = pd.read_csv(SAMPLE_SUB)
    df['polygons'] = df['filename'].map(predictions).fillna("")
    df.to_csv('submission_dbnet.csv', index=False)

    counts = df['polygons'].apply(lambda x: len(x.split('|')) if x else 0)
    print(f"Saved: submission_dbnet.csv")
    print(f"Avg polygons/img: {counts.mean():.1f}, Min: {counts.min()}, Max: {counts.max()}")

# Run test-set inference; writes 'submission_dbnet.csv' from the best checkpoint.
inference()