# Experiment

Here, I try to re-implement the whole WordDetectorNN in a single Jupyter notebook to keep things simple. Let's see if I can get it done :-D

In [None]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import BasicBlock, ResNet

## First experiments

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Sanity check: create a tiny tensor and place it on the selected device.
t = torch.tensor([3]).to(device)
t

## Helper functions

In [None]:
def count_parameters(net):
    """Count the parameters of a ``torch.nn.Module``.

    Returns a dict with ``"total_params"`` (number of elements over all
    parameters) and ``"trainable_params"`` (only parameters with
    ``requires_grad=True``).
    """
    total = 0
    trainable = 0
    for param in net.parameters():
        n_elements = param.numel()
        total += n_elements
        if param.requires_grad:
            trainable += n_elements
    return {"total_params": total, "trainable_params": trainable}

## Neural network

In [None]:
# If you were using Bottleneck for other ResNet versions:
# from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck


class ModifiedResNet18(ResNet):
    """ResNet-18 feature extractor for grayscale (or ``in_channels``-channel) images.

    Built from torchvision's ``ResNet`` with ``BasicBlock`` and the ResNet-18
    layout ``[2, 2, 2, 2]``. The stem convolution is replaced so that it accepts
    ``in_channels`` input channels (1 by default), and the classification head
    (``fc``) is removed. ``forward`` returns the feature maps of all five stages,
    deepest first: ``(out5, out4, out3, out2, out1)``.

    Args:
        in_channels: Number of channels of the input images (keyword-only).
            Defaults to 1 (grayscale).
        **kwargs: Forwarded to ``torchvision.models.resnet.ResNet`` (e.g.
            ``norm_layer``). ``num_classes`` is fixed internally because the
            classifier is deleted, so it must not be passed.
    """

    def __init__(self, *, in_channels: int = 1, **kwargs):
        # num_classes is irrelevant here: the fc layer is deleted below.
        super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=1000, **kwargs)

        # Replace the 3-channel stem conv with one that accepts `in_channels`,
        # copying every other hyper-parameter of the original layer.
        original_conv1 = self.conv1
        self.conv1 = nn.Conv2d(
            in_channels,
            original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            bias=False,  # bias is False in the original ResNet conv1; bn1 follows it
        )
        # ResNet.__init__ initializes its Conv2d layers with Kaiming-normal
        # (fan_out, relu). The replacement conv is created after that loop ran,
        # so apply the same scheme to keep the stem consistent with the backbone.
        nn.init.kaiming_normal_(self.conv1.weight, mode="fan_out", nonlinearity="relu")

        # The classification head is not needed for feature extraction.
        # (self.avgpool is unused too, but it holds no parameters.)
        del self.fc

    def _forward_impl(self, x: torch.Tensor):
        """Run the backbone and return the intermediate feature maps.

        Adapted from ``torchvision.models.resnet.ResNet._forward_impl``, but
        returns the output of every stage instead of classification logits.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        out1 = self.relu(x)  # Corresponds to bb1 in WordDetectorNet (before maxpool)
        x = self.maxpool(out1)

        out2 = self.layer1(x)  # Corresponds to bb2
        out3 = self.layer2(out2)  # Corresponds to bb3
        out4 = self.layer3(out3)  # Corresponds to bb4
        out5 = self.layer4(out4)  # Corresponds to bb5

        # WordDetectorNet expects (bb5, bb4, bb3, bb2, bb1)
        return out5, out4, out3, out2, out1

    def forward(self, x: torch.Tensor):
        return self._forward_impl(x)

Try it out:

In [None]:
backbone = ModifiedResNet18()

H, W = 400, 500
test_input = torch.randn((1, 1, H, W))

# Forward-only shape check: no autograd graph is needed, so don't build one.
with torch.no_grad():
    output = backbone(test_input)
out5, out4, out3, out2, out1 = output

print("Print output sizes:")
for o in output:
    print("\t", o.shape)

nr_params = count_parameters(backbone)
print(f"Total params: {nr_params['total_params']}")
print(f"Trainable params: {nr_params['trainable_params']}")

Now on to the `WordDetectorNet` itself (for now, just copied from the external repo):

In [None]:
class MapOrdering:
    """Channel indices of the maps encoding the AABBs around the words.

    Indices 0-2 are the segmentation maps (word / surrounding / background),
    indices 3-6 the geometry maps (top / bottom / left / right). ``NUM_MAPS`` is
    the total number of maps.
    """

    # Consecutive indices in declaration order; NUM_MAPS comes last as the count.
    (
        SEG_WORD,
        SEG_SURROUNDING,
        SEG_BACKGROUND,
        GEO_TOP,
        GEO_BOTTOM,
        GEO_LEFT,
        GEO_RIGHT,
        NUM_MAPS,
    ) = range(8)


def compute_scale_down(input_size, output_size):
    """Return the network's scale-down factor ``output_size[0] / input_size[0]``.

    Only the first dimension of each size is used, i.e. the same factor is
    assumed along both axes.
    """
    in_extent = input_size[0]
    out_extent = output_size[0]
    return out_extent / in_extent


class UpscaleAndConcatLayer(torch.nn.Module):
    """Decoder step: upscale a coarse map and fuse it with a finer skip map.

    The small map ``x`` (``cx`` channels) is resized to spatial size ``s`` and
    concatenated along the channel axis with the large map ``y`` (``cy``
    channels). A 3x3 convolution (padding 1) followed by ReLU then produces a
    map with ``cz`` channels.
    """

    def __init__(self, cx, cy, cz):
        super().__init__()
        # The conv sees the upscaled x and the skip map y stacked channel-wise.
        fused_channels = cx + cy
        self.conv = torch.nn.Conv2d(fused_channels, cz, 3, padding=1)

    def forward(self, x, y, s):
        upscaled = F.interpolate(x, size=s)  # default interpolation mode (nearest)
        fused = torch.cat([upscaled, y], dim=1)
        return F.relu(self.conv(fused))


class WordDetectorNet(torch.nn.Module):
    """U-Net-style word detector: ResNet-18 encoder plus an upsampling decoder.

    Maps a ``(N, 1, H, W)`` image to ``MapOrdering.NUM_MAPS`` maps at half the
    input resolution ``(H/2, W/2)``. ``H`` and ``W`` must be divisible by 16
    (see ``scale_shape``). The ``SEG_*`` channels are segmentation scores
    (softmax-normalized over those channels when ``apply_softmax=True``). The
    ``GEO_*`` channels are passed through a sigmoid and multiplied by
    ``input_size[0]`` (presumably distances in input pixels — TODO confirm).
    """

    input_size = (448, 448)
    output_size = (224, 224)
    scale_down = compute_scale_down(input_size, output_size)

    def __init__(self):
        super().__init__()

        # Use the modified ResNet18 for feature extraction.
        # All weights in the backbone will be randomly initialized.
        self.backbone = ModifiedResNet18()

        self.up1 = UpscaleAndConcatLayer(512, 256, 256)  # input//16
        self.up2 = UpscaleAndConcatLayer(256, 128, 128)  # input//8
        self.up3 = UpscaleAndConcatLayer(128, 64, 64)  # input//4
        self.up4 = UpscaleAndConcatLayer(64, 64, 32)  # input//2

        self.conv1 = torch.nn.Conv2d(32, MapOrdering.NUM_MAPS, 3, 1, padding=1)

    @staticmethod
    def scale_shape(s, f):
        """Return ``(s[0] // f, s[1] // f)``.

        Raises:
            ValueError: if either side of ``s`` is not divisible by ``f``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`
        # and would otherwise surface later as an opaque torch.cat size mismatch.
        if s[0] % f != 0 or s[1] % f != 0:
            raise ValueError(
                f"Input size {tuple(s)} must be divisible by {f} in both dimensions."
            )
        return s[0] // f, s[1] // f

    def output_activation(self, x, apply_softmax):
        """Apply the output activations to the raw ``NUM_MAPS``-channel map ``x``."""
        seg = x[:, MapOrdering.SEG_WORD : MapOrdering.SEG_BACKGROUND + 1]
        if apply_softmax:
            seg = torch.softmax(seg, dim=1)
        geo = torch.sigmoid(x[:, MapOrdering.GEO_TOP :]) * self.input_size[0]
        return torch.cat([seg, geo], dim=1)

    def forward(self, x, apply_softmax=False):
        s = x.shape[2:]  # Original image shape HxW
        bb5, bb4, bb3, bb2, bb1 = self.backbone(x)

        # up1 takes bb5 (512ch) and bb4 (H/16, 256ch). Upscales bb5 to H/16. Output: H/16, 256ch.
        y = self.up1(bb5, bb4, self.scale_shape(s, 16))
        # up2 takes y (H/16, 256ch) and bb3 (H/8, 128ch). Upscales y to H/8. Output: H/8, 128ch.
        y = self.up2(y, bb3, self.scale_shape(s, 8))
        # up3 takes y (H/8, 128ch) and bb2 (H/4, 64ch). Upscales y to H/4. Output: H/4, 64ch.
        y = self.up3(y, bb2, self.scale_shape(s, 4))
        # up4 takes y (H/4, 64ch) and bb1 (H/2, 64ch). Upscales y to H/2. Output: H/2, 32ch.
        y = self.up4(y, bb1, self.scale_shape(s, 2))

        # Final convolution to get NUM_MAPS channels. Output: H/2, NUM_MAPS ch.
        y = self.conv1(y)

        return self.output_activation(y, apply_softmax)

Now test it:

In [None]:
net = WordDetectorNet()

H, W = net.input_size
test_input = torch.randn((1, 1, H, W))

# Forward-only shape check: no autograd graph is needed, so don't build one.
with torch.no_grad():
    output = net(test_input)

print("Print output sizes:", output.shape)

nr_params = count_parameters(net)
print(f"Total params: {nr_params['total_params']}")
print(f"Trainable params: {nr_params['trainable_params']}")

## Dataset

In [None]:
class BoundingBox:
    """Axis-aligned bounding box given by its top-left and bottom-right corners.

    Coordinates are stored as floats. ``translate`` and ``scale`` return new
    boxes and leave the original unchanged.
    """

    def __init__(self, x_min, y_min, x_max, y_max):
        """
        Initialize a bounding box.

        Args:
            x_min, y_min: top-left corner.
            x_max, y_max: bottom-right corner.
        """
        self.x_min = float(x_min)
        self.y_min = float(y_min)
        self.x_max = float(x_max)
        self.y_max = float(y_max)

    def translate(self, dx, dy) -> "BoundingBox":
        """Return a copy of the box shifted by ``dx`` along x and ``dy`` along y."""
        return BoundingBox(
            x_min=self.x_min + dx,
            y_min=self.y_min + dy,
            x_max=self.x_max + dx,
            y_max=self.y_max + dy,
        )

    def scale(self, sx, sy) -> "BoundingBox":
        """Return a copy with x-coordinates multiplied by ``sx`` and y-coordinates by ``sy``."""
        # Pass coordinates by keyword: the constructor order is
        # (x_min, y_min, x_max, y_max), and a positional (x_min, x_max, ...)
        # call silently mixes up the x- and y-coordinates.
        return BoundingBox(
            x_min=self.x_min * sx,
            y_min=self.y_min * sy,
            x_max=self.x_max * sx,
            y_max=self.y_max * sy,
        )

    def area(self):
        """Return the area of the bounding box (0.0 for degenerate boxes)."""
        return max(0.0, self.x_max - self.x_min) * max(0.0, self.y_max - self.y_min)

    def intersect(self, other):
        """Return the intersection area with another bounding box (0.0 if disjoint)."""
        x_min = max(self.x_min, other.x_min)
        y_min = max(self.y_min, other.y_min)
        x_max = min(self.x_max, other.x_max)
        y_max = min(self.y_max, other.y_max)
        if x_min < x_max and y_min < y_max:
            return (x_max - x_min) * (y_max - y_min)
        return 0.0

    def iou(self, other):
        """Return the Intersection over Union (IoU) with another bounding box.

        Returns 0.0 when the union area is zero.
        """
        inter = self.intersect(other)
        union = self.area() + other.area() - inter
        if union == 0:
            return 0.0
        return inter / union

    def __repr__(self):
        return f"BoundingBox({self.x_min}, {self.y_min}, {self.x_max}, {self.y_max})"

In [None]:
import pickle
import xml.etree.ElementTree as ET
from typing import List, Tuple
from typing import TypedDict

import cv2
import numpy as np
from torch.utils.data import Dataset
from tqdm import tqdm

# TODO later: Replace print with logging.

# Dict shape of one sample as returned by `IAM_Dataset.__getitem__`:
# the image array, its word bounding boxes and the sample's file name.
IAM_Dataset_Element = TypedDict(
    "IAM_Dataset_Element",
    {
        "image": np.ndarray,
        "bounding_boxes": List[BoundingBox],
        "filename": str,
    },
)

class IAM_Dataset(Dataset):
    """
    Loads, pre-processes, and caches the IAM Handwriting Database.

    On the first run (or when ``force_rebuild_cache`` is set), every page image
    in ``root_dir / 'img'`` is loaded in grayscale together with its XML ground
    truth from ``root_dir / 'gt'``. The page is cropped to the region covering
    all word boxes and resized to ``(input_height, input_width)``. The results
    are pickled to ``root_dir / 'dataset_cache.pickle'`` so later runs only need
    to unpickle them.

    Inherits from `torch.utils.data.Dataset`, making it compatible with
    PyTorch's DataLoader.
    """

    _GT_DIR_NAME = 'gt'
    _IMG_DIR_NAME = 'img'
    _CACHE_FILENAME = 'dataset_cache.pickle'
    _IMG_EXT = '.png'
    _GT_EXT = '*.xml'  # glob pattern used to find ground truth files

    def __init__(
        self,
        root_dir: Path,
        input_width: int,
        input_height: int,
        force_rebuild_cache: bool = False,
    ):
        """
        Initializes the dataset. Loads the cache file if it exists (unless a
        rebuild is forced); otherwise builds the cache from the raw files.

        Args:
            root_dir (Path): The root directory of the dataset, containing 'gt' and 'img' subdirectories.
            input_width (int): Target width of the pre-processed images.
            input_height (int): Target height of the pre-processed images.
            force_rebuild_cache (bool): Rebuild the cache even if a cache file exists.
        """
        super().__init__()
        self.root_dir = root_dir
        self.input_width = input_width
        self.input_height = input_height

        self.img_cache: List[np.ndarray] = []
        self.gt_cache: List[List[BoundingBox]] = []
        self.filename_cache: List[str] = []

        # NOTE(review): the cache does not record input_width/input_height, so
        # changing the target size requires force_rebuild_cache=True.
        cache_path = self.root_dir / self._CACHE_FILENAME
        if cache_path.exists() and not force_rebuild_cache:
            print(f"Loading cached data from {cache_path}...")
            self._load_from_cache(cache_path)
        else:
            print(f"Cache not found. Building and caching data from {self.root_dir}...")
            self._preprocess_and_cache(cache_path)

    def _load_from_cache(self, cache_path: Path):
        """Loads pre-processed data from a pickle file written by `_preprocess_and_cache`."""
        # NOTE: unpickling can execute arbitrary code; only load cache files this class wrote.
        with open(cache_path, 'rb') as f:
            self.img_cache, self.gt_cache, self.filename_cache = pickle.load(f)

    def _preprocess_and_cache(self, cache_path: Path):
        """Finds, processes, and caches all data samples.

        Pages are skipped if their image is missing or unreadable, if their
        ground truth contains no word boxes, or if cropping leaves no pixels.
        """
        gt_dir = self.root_dir / self._GT_DIR_NAME
        img_dir = self.root_dir / self._IMG_DIR_NAME

        fn_gts = sorted(gt_dir.glob(self._GT_EXT))
        print(f"Found {len(fn_gts)} ground truth files. Processing...")

        for fn_gt in tqdm(fn_gts, desc="Preprocessing IAM Dataset"):
            fn_img = img_dir / (fn_gt.stem + self._IMG_EXT)
            if not fn_img.exists():
                continue

            # Load image and GT. str() keeps this working with OpenCV versions
            # that reject pathlib.Path; imread returns None if decoding fails.
            img = cv2.imread(str(fn_img), cv2.IMREAD_GRAYSCALE)
            if img is None:
                print(f"Could not read image {fn_img}, skipping.")
                continue
            gt = self._parse_gt(fn_gt)
            if not gt:
                # Cropping to content needs at least one word box.
                continue

            # Pre-processing pipeline
            img, gt = self._crop_page_to_content(img, gt)
            if img.size == 0:
                continue
            img, gt = self._adjust_to_size(img, gt)

            self.img_cache.append(img)
            self.gt_cache.append(gt)
            self.filename_cache.append(fn_gt.stem)

        print(f"Preprocessing complete. Saving cache to {cache_path}...")
        with open(cache_path, 'wb') as f:
            pickle.dump([self.img_cache, self.gt_cache, self.filename_cache], f)
        print("Cache saved successfully.")

    def _parse_gt(self, fn_gt: Path) -> List[BoundingBox]:
        """Parses an XML ground truth file into one bounding box per word.

        Each word's box is the union of the boxes of its ``cmp`` elements;
        words without ``cmp`` elements are skipped.
        """
        tree = ET.parse(fn_gt)
        root = tree.getroot()
        aabbs = []

        for line in root.findall("./handwritten-part/line"):
            for word in line.findall('./word'):
                components = word.findall('./cmp')
                if not components:
                    continue

                x_min, x_max, y_min, y_max = float('inf'), 0, float('inf'), 0
                for cmp in components:
                    x = float(cmp.attrib['x'])
                    y = float(cmp.attrib['y'])
                    w = float(cmp.attrib['width'])
                    h = float(cmp.attrib['height'])
                    x_min = min(x_min, x)
                    x_max = max(x_max, x + w)
                    y_min = min(y_min, y)
                    y_max = max(y_max, y + h)

                # Pass by keyword: the constructor order is (x_min, y_min, x_max, y_max).
                aabbs.append(BoundingBox(x_min=x_min, y_min=y_min, x_max=x_max, y_max=y_max))
        return aabbs

    def _crop_page_to_content(self, img: np.ndarray, gt: List[BoundingBox]) -> Tuple[np.ndarray, List[BoundingBox]]:
        """Crops the image to the bounding box containing all words.

        ``gt`` must be non-empty. The boxes are translated into the coordinate
        frame of the cropped image.
        """
        # Integer pixel bounds: floor the top-left, ceil the bottom-right so the
        # crop covers every box, and clamp the start to the image origin.
        x0 = max(0, int(np.floor(min(aabb.x_min for aabb in gt))))
        y0 = max(0, int(np.floor(min(aabb.y_min for aabb in gt))))
        x1 = int(np.ceil(max(aabb.x_max for aabb in gt)))
        y1 = int(np.ceil(max(aabb.y_max for aabb in gt)))

        # Shift by the same integer offsets used for slicing so boxes stay aligned with pixels.
        gt_crop = [aabb.translate(-x0, -y0) for aabb in gt]
        img_crop = img[y0:y1, x0:x1]
        return img_crop, gt_crop

    def _adjust_to_size(self, img: np.ndarray, gt: List[BoundingBox]) -> Tuple[np.ndarray, List[BoundingBox]]:
        """Resizes the image and AABBs to the final network input size.

        x and y are scaled independently, so the aspect ratio is not preserved.
        """
        h, w = img.shape
        fx = self.input_width / w
        fy = self.input_height / h

        gt_resized = [aabb.scale(fx, fy) for aabb in gt]
        img_resized = cv2.resize(img, dsize=(self.input_width, self.input_height))  # cv2 uses (w, h)
        return img_resized, gt_resized

    def __len__(self) -> int:
        """Returns the total number of samples in the dataset."""
        return len(self.img_cache)

    def __getitem__(self, idx: int) -> IAM_Dataset_Element:
        """
        Retrieves a sample from the dataset.

        Args:
            idx (int): The index of the sample to retrieve.

        Returns:
            A dict with the keys:
            - 'image': the pre-processed grayscale image as a NumPy array.
            - 'bounding_boxes': list of BoundingBox objects for the words.
            - 'filename': stem of the sample's ground truth file.
        """
        return {
            'image': self.img_cache[idx],
            'bounding_boxes': self.gt_cache[idx],
            'filename': self.filename_cache[idx],
        }

    def store_element_as_image(self, idx: int, output_path: Path, draw_bboxes: bool = False) -> None:
        """
        Saves a dataset element as a color image, optionally with its bounding boxes.

        Args:
            idx (int): The index of the dataset element to save.
            output_path (Path): Where the image is written (a str path also works).
            draw_bboxes (bool): Draw the word bounding boxes in green if True.

        Raises:
            OSError: if OpenCV fails to write the image.
        """
        element = self[idx]
        img = element['image'].copy()  # Copy to avoid modifying the cached image
        bboxes = element['bounding_boxes']

        # Convert grayscale to BGR for colored bounding boxes
        img_color = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        if draw_bboxes:
            for bbox in bboxes:
                # Convert float coordinates to integers
                x_min = int(bbox.x_min)
                y_min = int(bbox.y_min)
                x_max = int(bbox.x_max)
                y_max = int(bbox.y_max)

                # Draw rectangle (green color, thickness=2)
                cv2.rectangle(img_color, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)

        # cv2.imwrite reports failure via its return value instead of raising.
        if not cv2.imwrite(str(output_path), img_color):
            raise OSError(f"Could not write image to {output_path}")

# TODO: Add plotting functions to visualize the dataset samples and their ground truth annotations. In particular useful given many transformations.

In [None]:
# Experiment w/ dataset class: (re)build the cached, pre-processed training set.
data_path = Path.home().joinpath("Development", "WordDetectorNN", "data", "train")
dataset = IAM_Dataset(
    root_dir=data_path,
    input_width=448,
    input_height=448,
    force_rebuild_cache=True,
)

In [None]:
# TODO: Fix dataset!
# __getitem__ returns a dict, so the boxes are looked up by key. A notebook cell
# only displays its last expression, hence the explicit print for the first one.
print(dataset[43]["bounding_boxes"][0])  # -> the 0.0 is suspicious
dataset[349]["bounding_boxes"][0]  # -> the 0.0 is suspicious
# -> it seems like the cropping AND bounding box extraction process is the problem

In [None]:
# Write sample 578 with its word boxes drawn to ./test.png for visual inspection.
dataset.store_element_as_image(idx=578, output_path=Path("test.png"), draw_bboxes=True)

In [None]:
# Show which ground truth file sample 578 came from.
sample = dataset[578]
sample["filename"]

## Dataloader

## Loss

## Training