In [3]:
import os
os.chdir("..")  # move up from notebooks to repo root
print("Now in:", os.getcwd())
# Add project root to sys.path so 'src' can be imported
import sys
project_root = os.path.abspath(os.getcwd())
if project_root not in sys.path:
    sys.path.append(project_root)

Now in: c:\Users\faizan\Desktop\EAST


In [4]:
print(os.getcwd())
from src.dataset import EASTDataset
ds = EASTDataset("data/icdar2015/train_images", "data/icdar2015/train_maps", size=512)

print("Dataset length:", len(ds))
print("First 3 image paths:", ds.img_paths[:3])

c:\Users\faizan\Desktop\EAST
Dataset length: 1000
First 3 image paths: ['data/icdar2015/train_images\\img_1.jpg', 'data/icdar2015/train_images\\img_10.jpg', 'data/icdar2015/train_images\\img_100.jpg']


In [5]:
import torch
from torch.utils.data import DataLoader
from src.dataset import EASTDataset

# create a small dataset instance
train_dataset = EASTDataset(
    img_dir="data/icdar2015/train_images",
    map_dir="data/icdar2015/train_maps",
    size=512,              # keep same as paper
    training=True          # enable random cropping/augmentation
)

# try a small dataloader
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)

# fetch one batch
imgs, scores, geos = next(iter(train_loader))
print("Images :", imgs.shape)    # expect [B, 3, 512, 512]
print("Scores :", scores.shape)  # expect [B, 1, 512, 512]
print("Geos   :", geos.shape)    # expect [B, 8, 512, 512] 

# check value ranges
print("Image range:", imgs.min().item(), "to", imgs.max().item())
print("Score unique values:", torch.unique(scores))
print("Geo stats: mean", geos.mean().item(), "std", geos.std().item())

Images : torch.Size([4, 3, 512, 512])
Scores : torch.Size([4, 1, 512, 512])
Geos   : torch.Size([4, 8, 512, 512])
Image range: -2.1179039478302 to 2.640000104904175
Score unique values: tensor([0., 1.])
Geo stats: mean 0.006690621376037598 std 2.7807297706604004


In [7]:
import torch
from src.model import EAST

model = EAST(pretrained=False)
dummy = torch.randn(1, 3, 512, 512)  # batch of one
score, geo = model(dummy)

print("Score map:", score.shape)  # expect [1, 1, 512, 512]
print("Geo map:", geo.shape)      # expect [1, 8, 512, 512]

Score map: torch.Size([1, 1, 512, 512])
Geo map: torch.Size([1, 8, 512, 512])
