# EAST Scene Text Detection (Scratch Training) – ICDAR2015

**Objective**  
To build an end-to-end scene text detection pipeline using EAST, train it from scratch,
and establish a baseline for further lightweight detector research.

**Key Focus**
- End-to-end pipeline correctness
- Proper evaluation using ICDAR2015 protocol
- Recording accuracy + efficiency metrics


## Project setup

In [1]:
import os
import sys
import time
import random
import numpy as np
import torch

# Move to project root.
# ".." is resolved against the current working directory, so the root is
# computed only once per kernel: re-running this cell after the chdir below
# would otherwise climb one directory higher on every execution.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath("..")
os.chdir(PROJECT_ROOT)

# Make project packages (e.g. `src.*`) importable from the notebook.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

print("Working directory:", os.getcwd())


Working directory: /DATA/akash/akash_cnn/lightweight-text-detector


  from .autonotebook import tqdm as notebook_tqdm


## Experiment CONFIG

In [2]:
# ===============================
# EXPERIMENT CONTROL
# ===============================

# Name of the run; every output folder is created under experiments/<name>.
EXPERIMENT_NAME = "exp3_imagenet_vgg16_long"

# When False, the model-initialization cell builds EAST with weights=None.
USE_PRETRAINED = True
# Name of the pretrained weight set (e.g. "imagenet").
PRETRAINED_TYPE = "imagenet"

# Passed as `length` to the training dataset and used as the side length of
# the dummy input in the FLOPs cell.
INPUT_SIZE = 512
# Number of passes over the training loader.
EPOCHS = 60
# Samples per training batch.
BATCH_SIZE = 20
# Learning rate handed to the Adam optimizer.
LEARNING_RATE = 1e-4

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


## Experiment Folder Setup

In [3]:
import os

# All artefacts of this run live below experiments/<EXPERIMENT_NAME>.
EXP_ROOT = "experiments/" + EXPERIMENT_NAME

WEIGHTS_DIR = EXP_ROOT + "/weights"          # per-epoch checkpoints
LOG_DIR = EXP_ROOT + "/logs"                 # loss log
RESULTS_DIR = EXP_ROOT + "/results"          # submission archive
PRED_DIR = RESULTS_DIR + "/predictions"      # per-image prediction files

# RESULTS_DIR is created implicitly as the parent of PRED_DIR.
for exp_dir in (WEIGHTS_DIR, LOG_DIR, PRED_DIR):
    os.makedirs(exp_dir, exist_ok=True)

print("Experiment directories ready")


Experiment directories ready


## Dataset Paths

In [4]:
# Image folders, relative to the project root.
_ICDAR2015_ROOT = "data/icdar2015"
TRAIN_IMG_DIR = _ICDAR2015_ROOT + "/ch4_train_images"
TEST_IMG_DIR = _ICDAR2015_ROOT + "/ch4_test_images"


## Model Initialization

In [5]:
from src.models.east import EAST

# Build the detector. The weight source comes from the experiment config
# (PRETRAINED_TYPE) instead of a hardcoded literal, so changing the config
# cell actually changes what is loaded; weights=None is passed when
# USE_PRETRAINED is False.
model = EAST(
    cfg="D",
    weights=PRETRAINED_TYPE if USE_PRETRAINED else None
).to(DEVICE)

print(model.__class__.__name__, "initialized")


EAST initialized


## Loss & Optimizer

In [6]:
from src.losses.loss import Loss

# Loss object; the training loop reads "geo_loss" and "cls_loss" from its result.
criterion = Loss()

# Adam over every model parameter, using the learning rate from the config cell.
trainable_weights = model.parameters()
optimizer = torch.optim.Adam(trainable_weights, lr=LEARNING_RATE)


## DATASET + DATALOADER CELL

In [7]:
from torch.utils.data import DataLoader
from src.data.dataset import Dataset

# -------------------------
# Dataset
# -------------------------
# Training dataset built from the training image folder.
train_dataset = Dataset(data_path=TRAIN_IMG_DIR, scale=0.25, length=INPUT_SIZE)

# Loader settings are gathered in one place so they are easy to tweak.
loader_options = dict(
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    drop_last=True,
)
train_loader = DataLoader(train_dataset, **loader_options)

print("Train loader ready")
print("Total training samples:", len(train_dataset))
print("Batches per epoch:", len(train_loader))


Train loader ready
Total training samples: 1000
Batches per epoch: 50


## Training Loop

In [10]:
from tqdm import tqdm
import torch

# Per-epoch average losses; later cells read this list when writing the
# metrics file and the loss log.
loss_log = []

for epoch_idx in range(EPOCHS):
    epoch_no = epoch_idx + 1  # 1-based number used in messages and file names
    model.train()

    # Running sums of the per-batch loss values for this epoch.
    geo_running = 0.0
    cls_running = 0.0

    pbar = tqdm(
        train_loader,
        desc=f"Epoch [{epoch_no}/{EPOCHS}]",
        dynamic_ncols=True
    )

    for batch in pbar:
        # Each batch unpacks into four tensors; move all of them to the device.
        imgs, gt_score, gt_geo, ignored_map = (t.to(DEVICE) for t in batch)

        # Forward pass.
        pred_score, pred_geo = model(imgs)

        # The criterion returns a dict holding the two loss terms.
        losses = criterion(
            gt_score, pred_score,
            gt_geo, pred_geo,
            ignored_map
        )
        geo_term = losses["geo_loss"]
        cls_term = losses["cls_loss"]

        # Backward pass on the sum of both terms, then one optimizer step.
        optimizer.zero_grad()
        (geo_term + cls_term).backward()
        optimizer.step()

        # Accumulate plain Python floats for the epoch summary.
        geo_value = geo_term.item()
        cls_value = cls_term.item()
        geo_running += geo_value
        cls_running += cls_value

        pbar.set_postfix(geo=f"{geo_value:.3f}", cls=f"{cls_value:.3f}")

    # Epoch summary: average over the number of batches in the loader.
    num_batches = len(train_loader)
    avg_geo = geo_running / num_batches
    avg_cls = cls_running / num_batches

    loss_log.append({"epoch": epoch_no, "geo_loss": avg_geo, "cls_loss": avg_cls})

    print(
        f"\nEpoch {epoch_no} Summary | "
        f"Geo Loss: {avg_geo:.4f} | "
        f"Cls Loss: {avg_cls:.4f}"
    )

    # One checkpoint (weights only) per epoch.
    torch.save(model.state_dict(), f"{WEIGHTS_DIR}/epoch_{epoch_no}.pth")

print("\nTraining finished successfully")


Epoch [1/60]:   0%|          | 0/50 [00:02<?, ?it/s]


KeyboardInterrupt: 

## Dataset Statistics

In [None]:
# Count only image files, using the same extension filter as the batch
# inference cell, so stray files in the folders (annotations, hidden files)
# do not inflate the reported dataset size.
print("===== DATASET STATISTICS =====")
for label, img_dir in (("Train images:", TRAIN_IMG_DIR), ("Test images :", TEST_IMG_DIR)):
    n_images = sum(
        name.lower().endswith((".jpg", ".png")) for name in os.listdir(img_dir)
    )
    print(label, n_images)


## Model Size

In [None]:
def count_params(model):
    """Count the parameters of a model.

    Returns a ``(total, trainable)`` tuple of element counts, where
    ``trainable`` only includes parameters with ``requires_grad`` set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        n_elements = param.numel()
        total += n_elements
        if param.requires_grad:
            trainable += n_elements
    return total, trainable

# total_params is reused by the metrics cell further down.
total_params, trainable_params = count_params(model)

print("===== MODEL SIZE =====")
# Report both counts in millions of parameters.
for label, n_params in (
    ("Total params     ", total_params),
    ("Trainable params ", trainable_params),
):
    print(f"{label}: {n_params/1e6:.2f} M")


## FLOPs

In [None]:
from thop import profile

import copy

# Profile a throwaway copy of the network: thop instruments the module it is
# given (counting hooks and, in some releases, extra bookkeeping buffers), and
# none of that should reach the trained model used by the inference cells.
profile_model = copy.deepcopy(model).eval()

dummy = torch.randn(1, 3, INPUT_SIZE, INPUT_SIZE).to(DEVICE)
with torch.no_grad():
    flops, _ = profile(profile_model, inputs=(dummy,), verbose=False)

del profile_model  # release the copy's memory

# NOTE(review): thop documents the first return value as MACs; confirm which
# convention should be reported under the "GFLOPs" label.
print("===== COMPUTE COST =====")
print(f"GFLOPs @ {INPUT_SIZE}x{INPUT_SIZE}: {flops/1e9:.2f}")


## Single Image Inference

In [None]:
from PIL import Image
from src.detect import detect

# Switch to inference mode before running the detector.
model.eval()

# Use the alphabetically first entry of the test folder as the sample image.
test_entries = sorted(os.listdir(TEST_IMG_DIR))
img_name = test_entries[0]
img_path = os.path.join(TEST_IMG_DIR, img_name)
img = Image.open(img_path).convert("RGB")

# `detect` may return None, which is reported as zero boxes.
boxes = detect(img, model, DEVICE)
num_boxes = len(boxes) if boxes is not None else 0

print("Image:", img_name)
print("Detected boxes:", num_boxes)


## Single Image Visualization

In [None]:
from PIL import ImageDraw
import matplotlib.pyplot as plt

# Draw on a copy: `img` is fed to the model again by the timing cell, and
# drawing in place would also stack polygons every time this cell is re-run.
vis_img = img.copy()
draw = ImageDraw.Draw(vis_img)

if boxes is not None:
    for box in boxes:
        # The first eight values of a box are read as four (x, y) corners.
        pts = [(box[i], box[i + 1]) for i in range(0, 8, 2)]
        draw.polygon(pts, outline="lime", width=2)

plt.figure(figsize=(8,8))
plt.imshow(vis_img)
plt.axis("off")
plt.title("Qualitative Result (Single Image)")
plt.show()


## Inference Time + FPS + GPU Memory

In [None]:
import time
from torchvision import transforms

# Number of timed forward passes; averaging smooths out per-call jitter that
# dominates a single measurement.
N_TIMED_RUNS = 20

# CUDA-only calls are skipped when the notebook runs on CPU (DEVICE falls back
# to "cpu" when no GPU is available), where they would raise.
use_cuda = torch.device(DEVICE).type == "cuda"

if use_cuda:
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

# Round each side down to a multiple of 32 before feeding the network
# (presumably required by the model's downsampling — TODO confirm).
w, h = img.size
img_r = img.resize(((w//32)*32, (h//32)*32))

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,)*3, (0.5,)*3)
])

img_tensor = transform(img_r).unsqueeze(0).to(DEVICE)

model.eval()
with torch.no_grad():
    # Warm-up passes, excluded from the measurement.
    for _ in range(5):
        _ = model(img_tensor)

    # GPU kernels run asynchronously: synchronize before starting and before
    # stopping the clock so the interval covers the actual computation.
    if use_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()

    for _ in range(N_TIMED_RUNS):
        _ = model(img_tensor)

    if use_cuda:
        torch.cuda.synchronize()
    end = time.perf_counter()

# Mean latency per forward pass in milliseconds.
infer_time = (end - start) * 1000 / N_TIMED_RUNS
fps = 1000 / infer_time
# Peak allocated GPU memory in GiB; reported as 0.0 on CPU.
peak_mem = torch.cuda.max_memory_allocated() / (1024**3) if use_cuda else 0.0

print("===== INFERENCE PERFORMANCE =====")
print(f"Inference time : {infer_time:.2f} ms")
print(f"FPS            : {fps:.2f}")
print(f"Peak GPU mem   : {peak_mem:.2f} GB")


## Metrics JSON

In [None]:
import json

# Summary of the run. `loss_log` only gains an entry when an epoch completes,
# so it is empty if training was interrupted during the first epoch; record
# None in that case instead of failing with an IndexError.
final_metrics = {
    "experiment": EXPERIMENT_NAME,
    "dataset": "ICDAR2015",
    "training": loss_log[-1] if loss_log else None,
    "model": {
        "params_million": total_params / 1e6,
        "gflops": flops / 1e9
    },
    "inference": {
        "time_ms": infer_time,
        "fps": fps,
        "gpu_mem_gb": peak_mem
    }
}

# Written to the experiment root; the ICDAR evaluation cell later reads this
# file back and adds its own section.
with open(os.path.join(EXP_ROOT, "metrics.json"), "w") as f:
    json.dump(final_metrics, f, indent=2)

print("metrics.json saved")

## Save Loss Logs

In [None]:
import json

# Persist the per-epoch loss history as pretty-printed JSON.
loss_log_path = f"{LOG_DIR}/loss_log.json"
with open(loss_log_path, "w") as log_file:
    json.dump(loss_log, log_file, indent=2)

print("Loss log saved")


## Batch Inference

In [None]:
from tqdm import tqdm
from PIL import Image
from src.detect import detect

model.eval()

with torch.no_grad():
    pbar = tqdm(sorted(os.listdir(TEST_IMG_DIR)), desc="Batch inference")

    for name in pbar:
        # Skip anything that is not a .jpg/.png file.
        if not name.lower().endswith((".jpg", ".png")):
            continue

        pbar.set_postfix_str(name)

        # Close the file handle as soon as the pixels are loaded, and use
        # loop-specific names so the `img` / `boxes` variables of the
        # single-image cells are not overwritten.
        with Image.open(os.path.join(TEST_IMG_DIR, name)) as raw_img:
            test_img = raw_img.convert("RGB")
        test_boxes = detect(test_img, model, DEVICE)

        # One result file per image: res_<image stem>.txt
        txt_path = os.path.join(
            PRED_DIR, f"res_{os.path.splitext(name)[0]}.txt"
        )

        # The file is always created; it stays empty when nothing is detected.
        with open(txt_path, "w") as f:
            if test_boxes is not None:
                for box in test_boxes:
                    # First eight values of each box, written as integers.
                    f.write(",".join(str(int(v)) for v in box[:8]) + "\n")

print("Batch inference done")


## ZIP Creation

In [None]:
import zipfile

# Bundle every prediction text file into a flat archive (no sub-folders).
submit_zip = f"{RESULTS_DIR}/submit.zip"

with zipfile.ZipFile(submit_zip, "w", zipfile.ZIP_DEFLATED) as archive:
    txt_names = [entry for entry in os.listdir(PRED_DIR) if entry.endswith(".txt")]
    for txt_name in txt_names:
        archive.write(os.path.join(PRED_DIR, txt_name), arcname=txt_name)

print("submit.zip created")


In [None]:
!python evaluate/script.py \
-g=evaluate/gt.zip \
-s=experiments/exp1_scratch_vgg16/results/submit.zip


In [None]:
# =====================================================
# ICDAR METRIC EXTRACTION
# =====================================================

import subprocess
import re
import json
import os

import sys

GT_ZIP = "evaluate/gt.zip"
SUBMIT_ZIP = f"{EXP_ROOT}/results/submit.zip"
METRICS_JSON = f"{EXP_ROOT}/metrics.json"

# Run the evaluator with the interpreter that backs this kernel; a bare
# "python" is looked up on PATH and may belong to a different environment.
cmd = [
    sys.executable, "evaluate/script.py",
    f"-g={GT_ZIP}",
    f"-s={SUBMIT_ZIP}"
]

# stderr is merged into stdout so that error messages show up in the printed
# output as well.
result = subprocess.run(
    cmd,
    stdout=subprocess.PIPE,
    stderr=subprocess.STDOUT,
    text=True
)

output = result.stdout
print(output)

# ----------------------------
# Safe metric extractor
# ----------------------------
def find_metric(pattern, text, group_idx):
    """Search ``text`` for ``pattern`` (case-insensitive) and return one group as a float.

    ``group_idx`` selects the capture group holding the number.
    Returns None when the pattern does not occur in ``text``.
    """
    match = re.search(pattern, text, re.IGNORECASE)
    if match is None:
        return None
    return float(match.group(group_idx))

# ----------------------------
# Try standard patterns
# ----------------------------
# A metric value may be printed as an integer ("0", "1"), a decimal ("0.83")
# or in exponent notation ("1.2e-05"). Requiring a decimal point would miss
# the first form and mis-read the last one as "1.2".
# NOTE(review): confirm the exact output format against evaluate/script.py.
NUMBER = r"[0-9]+(?:\.[0-9]+)?(?:[eE][-+]?[0-9]+)?"

# ----------------------------
# Try standard patterns: "<metric name> ... <number>"
# ----------------------------
precision = find_metric(rf"precision[^0-9]*({NUMBER})", output, 1)

recall = find_metric(rf"recall[^0-9]*({NUMBER})", output, 1)

f1 = find_metric(rf"(f[- ]?measure|hmean)[^0-9]*({NUMBER})", output, 2)

# ----------------------------
# Fallback: table format ("| p | r | f |"); the last matching row wins
# ----------------------------
if f1 is None:
    rows = re.findall(
        rf"\|\s*({NUMBER})\s*\|\s*({NUMBER})\s*\|\s*({NUMBER})\s*\|",
        output
    )
    if rows:
        precision, recall, f1 = map(float, rows[-1])

# ----------------------------
# Final sanity check
# ----------------------------
if f1 is None:
    raise RuntimeError("ICDAR metrics could not be extracted. Check printed output.")

# ----------------------------
# Save to metrics.json
# ----------------------------
# Merge the ICDAR scores into the experiment's metrics file. If the file was
# never written (e.g. the metrics cell was skipped or failed), start from an
# empty record so the evaluation result is not lost to a FileNotFoundError.
if os.path.exists(METRICS_JSON):
    with open(METRICS_JSON, "r") as f:
        metrics = json.load(f)
else:
    metrics = {}

metrics["icdar"] = {
    "precision": precision,
    "recall": recall,
    "f_measure": f1
}

with open(METRICS_JSON, "w") as f:
    json.dump(metrics, f, indent=2)

print("ICDAR metrics saved successfully")
# Last expression of the cell: display the saved scores.
metrics["icdar"]
