In [10]:
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from DeepPrint.model import DeepPrintNet

In [11]:
# Params for training
# `num_classes` is passed both to DeepPrintNet and to the CenterLoss modules
# created further down, so the two must always agree.
num_classes = 2500
device = 'cpu'

# Place the network on `device` as well: the CenterLoss modules are created on
# `device`, so leaving the model on the default device would split the
# parameters across devices as soon as `device` is changed from 'cpu'.
# With device == 'cpu' this call is a no-op.
model = DeepPrintNet(num_classes).to(device)

In [12]:
# Smoke test: run one standard-normal tensor of shape (1, 1, 448, 448) through
# the model and unpack the eight values of the returned mapping, in the order
# the mapping yields them.
sample_input = torch.randn(1, 1, 448, 448)
output = model(sample_input)
(
    embedding,
    map,  # NOTE(review): shadows the builtin `map` for the rest of the session
    aligment,  # NOTE(review): presumably "alignment"; name kept for later cells
    aligned,
    R1,
    R2,
    logits_r1,
    logits_r2,
) = output.values()

In [13]:
def weights_init(m):
    """Re-initialise a single module in place (meant for ``Module.apply``).

    Only ``nn.Conv2d`` and ``nn.Linear`` modules are touched; every other
    module type is left unchanged.

    * weight: ``kaiming_uniform_`` with ``nonlinearity='relu'`` (PyTorch's
      variance-scaling style initialiser, default fan-in mode).
    * bias: set to zero when the layer has one.

    TODO: confirm that the uniform variant is the distribution intended by
    the reference "variance scaling" initialiser.
    """
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        return

    nn.init.kaiming_uniform_(m.weight, nonlinearity='relu')

    # Layers built with bias=False expose `bias` as None.
    if m.bias is None:
        return
    nn.init.zeros_(m.bias)

# Run weights_init on every submodule of `model` (and on `model` itself),
# re-initialising all Conv2d/Linear layers in place. `apply` returns the module,
# so as the last expression of the cell its repr is displayed.
# NOTE(review): this also re-initialises the localization head's final Linear
# layer; confirm that is intended.
model.apply(weights_init)

DeepPrintNet(
  (localization): LocalizationNetwork(
    (conv): Sequential(
      (0): Conv2d(1, 24, kernel_size=(5, 5), stride=(1, 1))
      (1): ReLU()
      (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(24, 32, kernel_size=(3, 3), stride=(1, 1))
      (4): ReLU()
      (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(32, 48, kernel_size=(3, 3), stride=(1, 1))
      (7): ReLU()
      (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (9): Conv2d(48, 64, kernel_size=(3, 3), stride=(1, 1))
      (10): ReLU()
      (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (fc): Sequential(
      (0): Flatten(start_dim=1, end_dim=-1)
      (1): Linear(in_features=2304, out_features=64, bias=True)
      (2): ReLU()
      (3): Linear(in_features=64, out_features=3, bias=True)
    )
  )
  (sampler): GridSampler()
  (stem): 

In [14]:
class CenterLoss(nn.Module):
    """Center loss with one learnable center per class.

    The loss is the sum of squared differences between each feature vector
    and the center of its label, divided by the batch size.

    Args:
        num_classes: number of rows in the center table.
        feat_dim: length of each center / feature vector.
        device: device the center table is moved to at construction time.
    """

    def __init__(self, num_classes, feat_dim, device='cpu'):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        # Centers start as standard-normal noise and are trained as a parameter.
        initial_centers = torch.randn(num_classes, feat_dim).to(device)
        self.centers = nn.Parameter(initial_centers)

    def forward(self, features, labels):
        """Return the scalar center loss for one batch.

        Args:
            features: tensor of shape [B, feat_dim].
            labels: tensor of shape [B] used to index the center table.
        """
        # Look up the center belonging to each sample: [B, feat_dim].
        residual = features - self.centers[labels]
        n_samples = features.size(0)
        return residual.pow(2).sum() / n_samples
    
def deepprint_loss(
    output,
    labels,
    minutiae_map_gt,
    center_loss_r1,
    center_loss_r2,
    λ1=1.0,
    λ2=0.00125,
    λ3=0.095,
):
    """Combine the three training losses into one weighted objective.

    Args:
        output: mapping with keys 'logits1', 'logits2', 'R1', 'R2' and
            'minutiae_map'.
        labels: class labels, shared by the cross-entropy and center terms.
        minutiae_map_gt: target for ``output['minutiae_map']``; both are
            expected to be [B, 6, 192, 192] per the original author's note.
        center_loss_r1: callable ``(features, labels) -> loss`` applied to 'R1'.
        center_loss_r2: callable ``(features, labels) -> loss`` applied to 'R2'.
        λ1: weight of the cross-entropy term.
        λ2: weight of the center-loss term.
        λ3: weight of the minutiae-map MSE term.

    Returns:
        ``(total_loss, parts)`` where ``total_loss`` is the weighted sum as a
        tensor and ``parts`` maps 'L1', 'L2', 'L3' to the unweighted terms as
        Python floats.
    """
    # L1: classification (cross-entropy), summed over both branches.
    classification = (
        F.cross_entropy(output['logits1'], labels)
        + F.cross_entropy(output['logits2'], labels)
    )

    # L2: center loss, summed over both branch representations.
    center = (
        center_loss_r1(output['R1'], labels)
        + center_loss_r2(output['R2'], labels)
    )

    # L3: mean squared error between predicted and ground-truth minutiae maps.
    minutiae = F.mse_loss(output['minutiae_map'], minutiae_map_gt)

    weighted_total = λ1 * classification + λ2 * center + λ3 * minutiae

    # Detached float copies of the individual terms, for logging.
    parts = {
        name: term.item()
        for name, term in (('L1', classification), ('L2', center), ('L3', minutiae))
    }
    return weighted_total, parts

In [15]:
# Learning rate for the network weights; tune as needed.
base_lr = 0.001

# One table of learnable class centers (feature size 96) per branch.
center_loss_r1 = CenterLoss(num_classes, 96, device=device)
center_loss_r2 = CenterLoss(num_classes, 96, device=device)

# The network uses the optimizer-wide learning rate; each center table gets
# its own parameter group at half of that rate.
center_lr = 0.5 * base_lr
param_groups = [{'params': model.parameters()}]
param_groups += [
    {'params': center_module.parameters(), 'lr': center_lr}
    for center_module in (center_loss_r1, center_loss_r2)
]

# NOTE(review): weight_decay is set optimizer-wide, so it also applies to the
# center parameters; confirm that is intended.
optimizer = torch.optim.RMSprop(param_groups, lr=base_lr, weight_decay=0.00004)