In [54]:
from PIL import Image
import numpy as np
import torchvision.transforms.functional as TF

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import models
from torchvision.models import inception_v3, Inception_V3_Weights

In [None]:
def conv_bn_relu(in_channels, out_channels, kernel_size, stride=1, padding=0):
    """Build a Conv2d -> BatchNorm2d -> ReLU block.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution.
        kernel_size: convolution kernel size (int or tuple).
        stride: convolution stride, 1 by default.
        padding: zero padding applied by the convolution, 0 by default.

    Returns:
        An ``nn.Sequential`` holding the convolution, the batch norm and an
        in-place ReLU at indices 0, 1 and 2 respectively.
    """
    convolution = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size,
        stride=stride,
        padding=padding,
    )
    normalization = nn.BatchNorm2d(out_channels)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(convolution, normalization, activation)

class InceptionV4Stem(nn.Module):
    """Stem that turns an image into a 384-channel feature map.

    Three plain convolutions are followed by three two-branch "mixed" stages
    whose branch outputs are concatenated along the channel axis.

    Every convolution is followed by BatchNorm2d and ReLU. The convolutions
    are bias-free because the batch-norm shift already provides an offset.
    Without the normalisation/activation the stacked convolutions would
    collapse into a single linear map (apart from the max-pool branches).

    Note: each conv attribute is an ``nn.Sequential(conv, bn, relu)``, so the
    parameter names carry a sub-index (e.g. ``conv2d_1a_3x3.0.weight``).

    Args:
        in_channels: number of channels of the input image.
    """

    def __init__(self, in_channels):
        super().__init__()
        unit = self._conv_unit

        self.conv2d_1a_3x3 = unit(in_channels, 32, 3, stride=2, padding=0)

        self.conv2d_2a_3x3 = unit(32, 32, 3, stride=1, padding=0)
        self.conv2d_2b_3x3 = unit(32, 64, 3, stride=1, padding=1)

        # Mixed 3a: max-pool and strided conv both halve the resolution.
        self.mixed_3a_branch_0 = nn.MaxPool2d(3, stride=2, padding=0)
        self.mixed_3a_branch_1 = unit(64, 96, 3, stride=2, padding=0)

        # Mixed 4a: both branches end in an unpadded 3x3 conv (-2 pixels each side total).
        self.mixed_4a_branch_0 = nn.Sequential(
            unit(160, 64, 1, stride=1, padding=0),
            unit(64, 96, 3, stride=1, padding=0),
        )
        self.mixed_4a_branch_1 = nn.Sequential(
            unit(160, 64, 1, stride=1, padding=0),
            unit(64, 64, (1, 7), stride=1, padding=(0, 3)),
            unit(64, 64, (7, 1), stride=1, padding=(3, 0)),
            unit(64, 96, 3, stride=1, padding=0),
        )

        # Mixed 5a: strided conv and max-pool both halve the resolution again.
        self.mixed_5a_branch_0 = unit(192, 192, 3, stride=2, padding=0)
        self.mixed_5a_branch_1 = nn.MaxPool2d(3, stride=2, padding=0)

    @staticmethod
    def _conv_unit(in_channels, out_channels, kernel_size, stride=1, padding=0):
        """Bias-free Conv2d followed by BatchNorm2d and an in-place ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size,
                      stride=stride, padding=padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run the stem.

        The size comments below are H x W x C for a 299 x 299 input. Other
        input sizes scale accordingly; a 448 x 448 input ends at 53 x 53 x 384.
        """
        x = self.conv2d_1a_3x3(x)  # 149 x 149 x 32
        x = self.conv2d_2a_3x3(x)  # 147 x 147 x 32
        x = self.conv2d_2b_3x3(x)  # 147 x 147 x 64
        x0 = self.mixed_3a_branch_0(x)
        x1 = self.mixed_3a_branch_1(x)
        x = torch.cat((x0, x1), dim=1)  # 73 x 73 x 160
        x0 = self.mixed_4a_branch_0(x)
        x1 = self.mixed_4a_branch_1(x)
        x = torch.cat((x0, x1), dim=1)  # 71 x 71 x 192
        x0 = self.mixed_5a_branch_0(x)
        x1 = self.mixed_5a_branch_1(x)
        x = torch.cat((x0, x1), dim=1)  # 35 x 35 x 384
        return x



class LocalizationNetwork(nn.Module):
    """Small CNN that regresses three alignment values per image.

    The input is resized to 128 x 128, passed through four conv/ReLU/max-pool
    stages (ending at 64 x 6 x 6) and two linear layers. ``GridSampler`` in
    this file consumes the three outputs as x-shift, y-shift and rotation
    angle.

    The last linear layer is zero-initialised so that a freshly built network
    outputs all zeros, i.e. the identity transform. With the default random
    initialisation an untrained model would shift/rotate every image by an
    arbitrary amount before the rest of the network sees it.

    Args:
        in_channels: number of channels of the input image (default 1).
    """

    def __init__(self, in_channels=1):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 24, 5), nn.ReLU(), nn.MaxPool2d(2),  # 128 -> 124 -> 62
            nn.Conv2d(24, 32, 3), nn.ReLU(), nn.MaxPool2d(2),           # 62 -> 60 -> 30
            nn.Conv2d(32, 48, 3), nn.ReLU(), nn.MaxPool2d(2),           # 30 -> 28 -> 14
            nn.Conv2d(48, 64, 3), nn.ReLU(), nn.MaxPool2d(2)            # 14 -> 12 -> 6
        )
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64*6*6, 64),
            nn.ReLU(),
            nn.Linear(64, 3)
        )

        # Start from the identity alignment: zero weights and bias give (0, 0, 0).
        # The layer still receives gradients, so it moves away from zero once
        # training starts.
        regressor = self.fc[-1]
        nn.init.zeros_(regressor.weight)
        nn.init.zeros_(regressor.bias)

    def forward(self, x):
        """Return a (batch, 3) tensor of alignment values for a 4-D image batch."""
        # Fixed working resolution so the flattened feature size matches the
        # first linear layer regardless of the input size.
        x = F.interpolate(x, size=(128, 128), mode='bilinear', align_corners=False)
        return self.fc(self.conv(x))


class GridSampler(nn.Module):
    """Resample an image batch with a per-sample rotation plus translation.

    ``params`` holds one row per image: column 0 is an x-shift, column 1 a
    y-shift and column 2 a rotation angle in radians (it is fed straight to
    ``torch.cos``/``torch.sin``). The shifts are divided by the image width
    and height before being placed in the affine matrix.

    NOTE(review): ``affine_grid`` works in normalised [-1, 1] coordinates, so
    a shift value equal to the image width moves the sampling grid by half the
    image. If the shifts are meant to be pixels the scale would be 2/W and
    2/H - confirm the intended units.
    """

    def forward(self, img, params):
        """Return ``img`` warped by the transform described by ``params``.

        Args:
            img: 4-D floating-point tensor (batch, channels, height, width).
            params: (batch, 3) tensor of x-shift, y-shift and angle.

        Returns:
            Tensor with the same shape and dtype as ``img``.
        """
        _, _, height, width = img.shape

        # grid_sample requires the grid to share the image's dtype and device,
        # so build the matrix in the image's dtype instead of always float32.
        params = params.to(device=img.device, dtype=img.dtype)

        tx = params[:, 0] / width
        ty = params[:, 1] / height
        theta = params[:, 2]

        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)

        # Rows of the 2 x 3 matrix: [cos, -sin, tx] and [sin, cos, ty].
        affine = torch.stack(
            (
                torch.stack((cos_theta, -sin_theta, tx), dim=1),
                torch.stack((sin_theta, cos_theta, ty), dim=1),
            ),
            dim=1,
        )

        grid = F.affine_grid(affine, img.size(), align_corners=False)
        return F.grid_sample(img, grid, align_corners=False)


class MinutiaeEmbeddingHead(nn.Module):
    """Reduce a feature map to a 96-dimensional vector.

    Five 3x3 conv/BN/ReLU blocks (two of them with stride 2) widen the
    features to 1024 channels; global average pooling and a linear layer then
    produce the 96 outputs.

    Args:
        in_channels: number of channels of the incoming feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        # (output channels, stride) of each 3x3 block, in order.
        plan = [(384, 1), (768, 2), (768, 1), (896, 2), (1024, 1)]

        layers = []
        channels = in_channels
        for width, stride in plan:
            layers.append(conv_bn_relu(channels, width, 3, stride, 1))
            channels = width

        layers.append(nn.AdaptiveAvgPool2d(1))
        layers.append(nn.Flatten())
        layers.append(nn.Linear(channels, 96))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the head and return a (batch, 96) tensor."""
        embedding = self.net(x)
        return embedding


class MinutiaeMapHead(nn.Module):
    """Upsample a feature map to a 6-channel map and centre-crop it.

    Three transposed convolutions enlarge an n x n input to 2n, 4n and then
    8n - 1 pixels per side (the last one uses ``output_padding=0``). For
    example 35 -> 70 -> 140 -> 279 and 53 -> 106 -> 212 -> 423. The result is
    centre-cropped to ``out_size`` x ``out_size``.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_size: side length of the returned square map (default 192).

    Raises:
        ValueError: from ``forward`` when the upsampled map is smaller than
            ``out_size``. A negative crop offset would otherwise silently
            return an off-centre map with the wrong shape.
    """

    def __init__(self, in_channels, out_size=192):
        super().__init__()
        self.out_size = out_size
        self.net = nn.Sequential(
            nn.ConvTranspose2d(in_channels, 384, 3, stride=2, padding=1, output_padding=1),   # n -> 2n
            conv_bn_relu(384, 128, 3, 1, 1),
            nn.ConvTranspose2d(128, 128, 3, stride=2, padding=1, output_padding=1),            # 2n -> 4n
            conv_bn_relu(128, 64, 3, 1, 1),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=0),              # 4n -> 8n - 1 (cropped in forward)
            conv_bn_relu(32, 32, 3, 1, 1),
            nn.Conv2d(32, 6, 1)               # Final 6 channels
        )

    def forward(self, x):
        """Return the centre ``out_size`` x ``out_size`` window of the upsampled map."""
        x = self.net(x)
        size = self.out_size
        height, width = x.shape[-2:]
        if height < size or width < size:
            raise ValueError(
                f"upsampled map is {height}x{width}, smaller than the "
                f"{size}x{size} crop; use a larger input or a smaller out_size"
            )
        top = (height - size) // 2
        left = (width - size) // 2
        return x[..., top:top + size, left:left + size]



class DeepPrintNet(nn.Module):
    """Two-branch network built from the modules defined in this notebook.

    The input is first aligned (``LocalizationNetwork`` + ``GridSampler``) and
    passed through ``InceptionV4Stem``. The stem features then feed

    * a minutiae branch: six conv blocks, followed by an embedding head and a
      map head, and
    * a texture branch: three conv blocks, global pooling and a linear layer.

    The two 96-dimensional branch embeddings are concatenated and
    L2-normalised into a 192-dimensional embedding.
    """

    def __init__(self):
        super().__init__()
        # Alignment and shared stem.
        self.localization = LocalizationNetwork()
        self.sampler = GridSampler()
        self.stem = InceptionV4Stem(in_channels=1)

        # Minutiae branch.
        minutiae_blocks = [conv_bn_relu(384, 384, 3, 1, 1) for _ in range(6)]
        self.minutiae_stack = nn.Sequential(*minutiae_blocks)
        self.minutiae_embed = MinutiaeEmbeddingHead(384)
        self.minutiae_map = MinutiaeMapHead(384)

        # Texture branch.
        texture_blocks = [conv_bn_relu(384, 384, 3, 1, 1) for _ in range(3)]
        self.texture_stack = nn.Sequential(*texture_blocks)
        self.texture_fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(384, 96),
        )

    def forward(self, x):
        """Run the full model.

        Returns:
            dict with, in this order, the keys ``'embedding'`` (normalised
            joint embedding), ``'minutiae_map'``, ``'alignment'`` (raw output
            of the localization network) and ``'aligned'`` (the resampled
            input).
        """
        alignment = self.localization(x)
        aligned_image = self.sampler(x, alignment)

        features = self.stem(aligned_image)

        # TODO: E(x) inception A 6x, before going into D(x) and M(x)

        minutiae_features = self.minutiae_stack(features)
        minutiae_embedding = self.minutiae_embed(minutiae_features)
        minutiae_map = self.minutiae_map(minutiae_features)

        texture_embedding = self.texture_fc(self.texture_stack(features))

        joint = torch.cat([minutiae_embedding, texture_embedding], dim=1)

        outputs = {}
        outputs['embedding'] = F.normalize(joint, dim=1)
        outputs['minutiae_map'] = minutiae_map
        outputs['alignment'] = alignment
        outputs['aligned'] = aligned_image
        return outputs

In [119]:
# Build the network with freshly initialised weights (nothing is loaded from disk here).
model = DeepPrintNet()

In [120]:
# Smoke test: push one random single-channel 448 x 448 image through the model.
sample_input = torch.randn(1, 1, 448, 448)
# `output` is the dict returned by DeepPrintNet.forward
# (keys: 'embedding', 'minutiae_map', 'alignment', 'aligned').
output = model(sample_input)

In [122]:
# Look each tensor up by key instead of relying on the order of the dict
# values, and use names that do not shadow the builtin `map`.
embedding = output['embedding']
minutiae_map = output['minutiae_map']
alignment = output['alignment']
aligned = output['aligned']

# The printed labels are kept verbatim so the recorded cell output still matches.
print('embedding shape:', embedding.shape)
print('map shape:', minutiae_map.shape)
print('aligment shape:', alignment.shape)
print('aligned shape:', aligned.shape)

embedding shape: torch.Size([1, 192])
map shape: torch.Size([1, 6, 192, 192])
aligment shape: torch.Size([1, 3])
aligned shape: torch.Size([1, 1, 448, 448])


In [None]:
# Optional: draw the autograd graph of the embedding output with torchviz.
# Disabled by default; set the flag to True to render the diagram.
create_model_image = False

if create_model_image:
    # Imported here so torchviz is only needed when the diagram is requested.
    from torchviz import make_dot

    # Run a new forward pass to obtain an output to trace; note that this
    # overwrites the `sample_input` and `output` variables set in earlier cells.
    sample_input = torch.randn(1, 1, 448, 448)
    output = model(sample_input)
    # Render the graph in PNG format under the base name "deepprint_graph".
    make_dot(output['embedding'], params=dict(model.named_parameters())).render("deepprint_graph", format="png")


torch.Size([1, 6, 110, 110])