# Experiment

Here, I try to re-implement the whole WordDetectorNN in a single Jupyter notebook to keep things simple. Let's see if I can get it done :-D

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.resnet import BasicBlock, ResNet

## First experiments

In [None]:
# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
device

In [None]:
# Sanity check: allocate a one-element tensor on the selected device and show it.
sample_values = [3]
t = torch.tensor(sample_values, device=device)
t

## Helper functions

In [None]:
def count_parameters(net):
    """Count the scalar parameters of a module.

    Args:
        net: any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        dict with ``"total_params"`` (all parameter elements) and
        ``"trainable_params"`` (elements of parameters with ``requires_grad``).
    """
    total = 0
    trainable = 0
    for param in net.parameters():
        n_elements = param.numel()
        total += n_elements
        if param.requires_grad:
            trainable += n_elements
    return {"total_params": total, "trainable_params": trainable}

## Neural network

In [None]:
# If you were using Bottleneck for other ResNet versions:
# from torchvision.models.resnet import ResNet, BasicBlock, Bottleneck


class ModifiedResNet18(ResNet):
    """ResNet-18 feature extractor for single-channel (grayscale) images.

    Compared to torchvision's ResNet-18:
      * the stem convolution takes 1 input channel instead of 3,
      * the classification head ``fc`` is removed,
      * ``forward`` returns the intermediate feature maps (deepest first)
        instead of class logits. ``self.avgpool`` is inherited but unused.

    Extra keyword arguments are forwarded to ``torchvision`` ``ResNet``.
    """

    def __init__(self, **kwargs):
        # num_classes only sizes ``fc``, which is deleted below.
        super().__init__(BasicBlock, [2, 2, 2, 2], num_classes=1000, **kwargs)

        # Swap the 3-channel stem conv for a 1-channel one with identical
        # geometry (kernel size, stride, padding; no bias, as in torchvision).
        original_conv1 = self.conv1
        self.conv1 = nn.Conv2d(
            1,
            original_conv1.out_channels,
            kernel_size=original_conv1.kernel_size,
            stride=original_conv1.stride,
            padding=original_conv1.padding,
            bias=False,
        )
        # ResNet.__init__ already applied kaiming_normal_(fan_out, relu) to
        # every Conv2d it created. The replacement conv is built afterwards and
        # would otherwise keep PyTorch's default Conv2d init, so re-apply the
        # same scheme to keep the stem consistent with the other layers.
        nn.init.kaiming_normal_(
            self.conv1.weight, mode="fan_out", nonlinearity="relu"
        )

        # The classification head is not needed for feature extraction.
        del self.fc

    def _forward_impl(self, x: torch.Tensor):
        """Return the backbone feature maps ``(out5, out4, out3, out2, out1)``.

        Mirrors torchvision's ``ResNet._forward_impl`` but stops before the
        pooling/classification head and collects intermediate outputs:
          * out1: stem output after ReLU, before max-pooling (bb1)
          * out2..out5: outputs of ``layer1``..``layer4`` (bb2..bb5)
        The tuple is ordered deepest-first, as WordDetectorNet expects.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        out1 = self.relu(x)  # bb1 (taken before maxpool)
        x = self.maxpool(out1)

        out2 = self.layer1(x)  # bb2
        out3 = self.layer2(out2)  # bb3
        out4 = self.layer3(out3)  # bb4
        out5 = self.layer4(out4)  # bb5

        # WordDetectorNet expects (bb5, bb4, bb3, bb2, bb1)
        return out5, out4, out3, out2, out1

    def forward(self, x: torch.Tensor):
        """Return the feature-map tuple produced by ``_forward_impl``."""
        return self._forward_impl(x)

Try it out:

In [None]:
backbone = ModifiedResNet18()

H, W = 400, 500
test_input = torch.randn((1, 1, H, W))

# Only the output shapes are inspected here, so skip building the autograd
# graph (it would retain every intermediate activation for a backward pass).
with torch.no_grad():
    output = backbone(test_input)
out5, out4, out3, out2, out1 = output

print("Print output sizes:")
for o in output:
    print("\t", o.shape)

nr_params = count_parameters(backbone)
print(f"Total params: {nr_params['total_params']}")
print(f"Trainable params: {nr_params['trainable_params']}")

Now on to the `WordDetectorNN` (for now largely copied from the external repo; the backbone is the `ModifiedResNet18` defined above):

In [None]:
class MapOrdering:
    """Channel indices of the maps encoding the axis-aligned boxes around words.

    Channels 0-2 are the segmentation maps (word, surrounding, background),
    channels 3-6 are the geometry maps (top, bottom, left, right), and
    ``NUM_MAPS`` is the total number of channels.
    """

    (
        SEG_WORD,
        SEG_SURROUNDING,
        SEG_BACKGROUND,
        GEO_TOP,
        GEO_BOTTOM,
        GEO_LEFT,
        GEO_RIGHT,
        NUM_MAPS,
    ) = range(8)


def compute_scale_down(input_size, output_size):
    """Return the network's scale-down factor, ``output / input``.

    Only the first dimension of each size is used, so the result assumes
    both dimensions are scaled by the same factor.
    """
    in_first, out_first = input_size[0], output_size[0]
    return out_first / in_first


class UpscaleAndConcatLayer(torch.nn.Module):
    """Fuse a coarse feature map into a finer one.

    ``forward(x, y, s)`` resizes the small map ``x`` (``cx`` channels) to the
    spatial size ``s`` (nearest-neighbour, ``F.interpolate``'s default),
    concatenates it in front of the large map ``y`` (``cy`` channels) along
    the channel axis, then applies a 3x3 convolution (padding 1) producing
    ``cz`` channels followed by ReLU.
    """

    def __init__(self, cx, cy, cz):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=cx + cy, out_channels=cz, kernel_size=3, padding=1
        )

    def forward(self, x, y, s):
        upscaled = F.interpolate(x, size=s)
        stacked = torch.cat([upscaled, y], dim=1)
        return F.relu(self.conv(stacked))


class WordDetectorNet(torch.nn.Module):
    """Word detector: grayscale ResNet-18 encoder with a U-Net-style decoder.

    The backbone yields five feature maps. Four ``UpscaleAndConcatLayer``
    stages fuse them coarse-to-fine, and a final 3x3 convolution emits
    ``MapOrdering.NUM_MAPS`` maps at half the input resolution: the
    segmentation channels first, then the geometry channels.
    """

    # Reference input size and the corresponding output map size (H/2 x W/2).
    input_size = (448, 448)
    output_size = (224, 224)
    scale_down = compute_scale_down(input_size, output_size)

    def __init__(self):
        super().__init__()

        # Grayscale ResNet-18 encoder; all its weights start randomly initialized.
        self.backbone = ModifiedResNet18()

        # Decoder: each stage upsamples the coarser map to the next backbone
        # resolution and fuses it with that skip connection.
        self.up1 = UpscaleAndConcatLayer(512, 256, 256)  # input//16
        self.up2 = UpscaleAndConcatLayer(256, 128, 128)  # input//8
        self.up3 = UpscaleAndConcatLayer(128, 64, 64)  # input//4
        self.up4 = UpscaleAndConcatLayer(64, 64, 32)  # input//2

        # Project the 32 decoder channels onto the output maps.
        self.conv1 = torch.nn.Conv2d(32, MapOrdering.NUM_MAPS, 3, 1, padding=1)

    @staticmethod
    def scale_shape(s, f):
        """Return ``(s[0] // f, s[1] // f)``.

        Raises:
            ValueError: if either dimension of ``s`` is not divisible by ``f``.
                This is checked explicitly rather than with ``assert``, so it
                still happens under ``python -O``. Otherwise the decoder would
                later fail with an opaque size mismatch in ``torch.cat``.
        """
        if s[0] % f != 0 or s[1] % f != 0:
            raise ValueError(
                f"input size {tuple(s)} must be divisible by {f} in both dimensions"
            )
        return s[0] // f, s[1] // f

    def output_activation(self, x, apply_softmax):
        """Post-process raw decoder output.

        The segmentation channels (``SEG_WORD``..``SEG_BACKGROUND``) are kept
        as logits, or turned into per-pixel probabilities with a softmax over
        those channels when ``apply_softmax`` is true. The geometry channels
        (``GEO_TOP`` onwards) go through a sigmoid and are multiplied by
        ``input_size[0]``, so they lie in ``(0, input_size[0])``.
        """
        seg = x[:, MapOrdering.SEG_WORD : MapOrdering.SEG_BACKGROUND + 1]
        if apply_softmax:
            seg = torch.softmax(seg, dim=1)
        # NOTE(review): geometry is scaled by the fixed class ``input_size``,
        # not by the actual input size. These are presumably pixel distances
        # relative to a 448px image; confirm.
        geo = torch.sigmoid(x[:, MapOrdering.GEO_TOP :]) * self.input_size[0]
        return torch.cat([seg, geo], dim=1)

    def forward(self, x, apply_softmax=False):
        """Run the detector on a batch of images.

        Args:
            x: image batch, presumably (B, 1, H, W). The backbone's first conv
                takes 1 input channel. H and W must be divisible by 16.
            apply_softmax: if True, return softmax probabilities for the
                segmentation channels instead of logits.

        Returns:
            Tensor with ``MapOrdering.NUM_MAPS`` channels at (H/2, W/2).

        Raises:
            ValueError: if H or W is not divisible by 16.
        """
        s = x.shape[2:]  # spatial size (H, W) of the input
        bb5, bb4, bb3, bb2, bb1 = self.backbone(x)

        # up1: upscale bb5 (512ch) to H/16 and fuse with bb4 (256ch) -> 256ch.
        y = self.up1(bb5, bb4, self.scale_shape(s, 16))
        # up2: upscale y to H/8 and fuse with bb3 (128ch) -> 128ch.
        y = self.up2(y, bb3, self.scale_shape(s, 8))
        # up3: upscale y to H/4 and fuse with bb2 (64ch) -> 64ch.
        y = self.up3(y, bb2, self.scale_shape(s, 4))
        # up4: upscale y to H/2 and fuse with bb1 (64ch) -> 32ch.
        y = self.up4(y, bb1, self.scale_shape(s, 2))

        # Final convolution to NUM_MAPS channels, still at H/2 x W/2.
        y = self.conv1(y)

        return self.output_activation(y, apply_softmax)

Now test it:

In [None]:
net = WordDetectorNet()

H, W = net.input_size
test_input = torch.randn((1, 1, H, W))

# Only the output shape is inspected here, so skip building the autograd
# graph (it would retain every intermediate activation for a backward pass).
with torch.no_grad():
    output = net(test_input)

print("Print output sizes:", output.shape)

nr_params = count_parameters(net)
print(f"Total params: {nr_params['total_params']}")
print(f"Trainable params: {nr_params['trainable_params']}")

## Data

## Training