# Inked fingerprint image Style-Transfer

- This notebook applies AdaIN-based ink style transfer to a few contact-based fingerprint images.
- The pretrained `decoder.pth` can be downloaded from https://github.com/naoto0804/pytorch-AdaIN/releases/tag/v0.0.0

In [7]:
import os
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import vgg19
from torchvision import transforms
from PIL import Image

In [None]:
# AdaIN operation
def adaptive_instance_normalization(content_feat, style_feat, eps=1e-5):
    """Re-scale content features to carry the style features' statistics.

    Mean and standard deviation are taken over dims 2 and 3 of each input
    (presumably the spatial axes of an (N, C, H, W) feature map -- confirm
    against the encoder output). The content features are standardised
    with their own statistics and then scaled/shifted with the style ones.

    Args:
        content_feat: Feature tensor to be restyled; needs at least 4 dims.
        style_feat: Feature tensor providing the target statistics.
        eps: Constant added to both standard deviations so the division
            never hits zero.

    Returns:
        Tensor shaped like ``content_feat``.
    """
    spatial_axes = (2, 3)

    def channel_stats(feat):
        # Per-sample, per-channel mean and (eps-shifted) standard deviation.
        mu = feat.mean(dim=spatial_axes, keepdim=True)
        sigma = feat.std(dim=spatial_axes, keepdim=True) + eps
        return mu, sigma

    content_mu, content_sigma = channel_stats(content_feat)
    style_mu, style_sigma = channel_stats(style_feat)

    whitened = (content_feat - content_mu) / content_sigma
    return whitened * style_sigma + style_mu

# AdaIN Style Transfer Class
class AdaINStyleTransfer:
    """AdaIN style transfer built from a VGG19 encoder and a loaded decoder.

    Images are normalised with ImageNet statistics, encoded with the first
    21 layers of torchvision's VGG19 feature stack (up to relu4_1), blended
    with ``adaptive_instance_normalization`` and decoded back to an image
    tensor that ``save_image`` de-normalises and writes to disk.
    """

    # ImageNet channel statistics shared by load_image and save_image.
    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    # The decoder upsamples by 2 three times, so inputs are padded to a
    # multiple of 8 before encoding.
    _SIZE_MULTIPLE = 8

    # Decoder stages as (in_channels, out_channels, relu, upsample_after).
    # The resulting nn.Sequential indices must stay exactly as they are:
    # the checkpoint's state-dict keys are derived from them.
    _DECODER_STAGES = (
        (512, 256, True, True),
        (256, 256, True, False),
        (256, 256, True, False),
        (256, 256, True, False),
        (256, 128, True, True),
        (128, 128, True, False),
        (128, 64, True, True),
        (64, 64, True, False),
        (64, 3, False, False),
    )

    def __init__(self, decoder_path, device='mps'):
        """Build the frozen encoder and load the decoder checkpoint.

        Args:
            decoder_path: Path to the decoder checkpoint (``decoder.pth``).
            device: Torch device the encoder, decoder and images are moved to.
        """
        self.device = device

        # Encoder: VGG19 up to relu4_1
        try:
            from torchvision.models import VGG19_Weights
        except ImportError:
            # Older torchvision releases only know the legacy flag.
            vgg = vgg19(pretrained=True).features
        else:
            # `pretrained=True` is deprecated in favour of `weights=`;
            # IMAGENET1K_V1 is the weight set the legacy flag selected.
            vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features
        self.encoder = nn.Sequential(*list(vgg.children())[:21]).to(device).eval()
        self.encoder.requires_grad_(False)

        # Decoder
        self.decoder = self._load_decoder(decoder_path).to(device).eval()

    def _load_decoder(self, path):
        """Create the decoder network and load its weights from ``path``.

        Accepts either a bare state dict or a dict wrapping it under the
        ``'state_dict'`` key. Loading is strict, so a checkpoint that does
        not match the architecture raises ``RuntimeError``.

        Returns:
            The decoder ``nn.Sequential`` with weights loaded (on CPU).
        """
        layers = []
        for in_ch, out_ch, relu, upsample in self._DECODER_STAGES:
            layers.append(nn.ReflectionPad2d(1))
            layers.append(nn.Conv2d(in_ch, out_ch, 3))
            if relu:
                layers.append(nn.ReLU())
            if upsample:
                layers.append(nn.Upsample(scale_factor=2, mode='nearest'))
        decoder = nn.Sequential(*layers)

        # NOTE(security): torch.load unpickles the file. Only load
        # checkpoints from a trusted source, or pass weights_only=True on
        # torch versions that support it.
        state = torch.load(path, map_location='cpu')
        if isinstance(state, dict) and 'state_dict' in state:
            state = state['state_dict']

        decoder.load_state_dict(state)
        return decoder

    # Image loader (keeps original size)
    def load_image(self, image_path):
        """Load an image as a normalised, padded batch of one.

        The image is converted to RGB, normalised with ImageNet statistics
        and padded on the right/bottom so both sides are multiples of 8.

        Args:
            image_path: Path of the image file to read.

        Returns:
            ``(tensor, (h, w))`` where ``tensor`` has shape
            ``(1, 3, H_padded, W_padded)`` on ``self.device`` and ``(h, w)``
            is the size before padding, used later to crop the output.
        """
        # The context manager closes the underlying file once converted.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        w, h = image.size

        multiple = self._SIZE_MULTIPLE
        pad_h = -h % multiple
        pad_w = -w % multiple

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=list(self._MEAN), std=list(self._STD)),
        ])

        image = transform(image).unsqueeze(0)
        if pad_h or pad_w:
            # Reflection padding requires pad < dimension; very small images
            # fall back to edge replication instead of raising.
            mode = 'reflect' if (pad_h < h and pad_w < w) else 'replicate'
            image = F.pad(image, (0, pad_w, 0, pad_h), mode=mode)

        return image.to(self.device), (h, w)

    # AdaIN forward pass
    @torch.no_grad()
    def transfer_style(self, content_path, style_path, alpha=0.6):
        """Stylise the content image with the statistics of the style image.

        Args:
            content_path: Path of the image whose structure is kept.
            style_path: Path of the image supplying the style statistics.
            alpha: Blend weight; 1 uses only the AdaIN features, 0 only the
                original content features.

        Returns:
            Tensor of shape ``(1, 3, h, w)`` cropped to the content image's
            original size, still in ImageNet-normalised space (see
            ``save_image``).
        """
        content, content_hw = self.load_image(content_path)
        style, _ = self.load_image(style_path)

        content_feat = self.encoder(content)
        style_feat = self.encoder(style)

        t = adaptive_instance_normalization(content_feat, style_feat)
        t = alpha * t + (1 - alpha) * content_feat

        output = self.decoder(t)

        # Drop the padding added by load_image.
        h, w = content_hw
        return output[:, :, :h, :w]

    # Save image
    def save_image(self, tensor, path):
        """De-normalise ``tensor`` and write it to ``path`` as an image.

        Args:
            tensor: Image tensor of shape ``(1, 3, H, W)`` (or ``(3, H, W)``)
                in ImageNet-normalised space.
            path: Destination file; its parent directory is created when
                missing.
        """
        image = tensor.detach().cpu().squeeze(0)

        std = torch.tensor(self._STD).view(3, 1, 1)
        mean = torch.tensor(self._MEAN).view(3, 1, 1)
        image = (image * std + mean).clamp(0, 1)

        # Saving into a directory that does not exist yet would fail.
        out_dir = os.path.dirname(path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        transforms.ToPILImage()(image).save(path)
        print(f"Saved: {path}")

## Run

In [None]:
# Input folder with the content images and the folder for stylised results.
content_dir = "/path/to/your/optical_fingerprints"
output_dir  = "/path/to/output_inked"

# Ink samples used as style references.
style_images = [
    "src/triplet-distil-net/ink/sample1.png",
    "src/triplet-distil-net/ink/sample2.png",
]

# os.listdir returns entries in arbitrary order; sort them so the images
# are always processed in the same sequence.
content_images = sorted(
    os.path.join(content_dir, f)
    for f in os.listdir(content_dir)
    if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))
)

print(f"Found {len(content_images)} content images")

In [None]:
# Pick the best available device; set `device` by hand (e.g. 'cuda:0' or
# 'mps') to override the automatic choice.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device = "cuda:0"
elif _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
else:
    device = "cpu"

decoder_path = "/path/to/decoder.pth"
content_dir  = "/path/to/optical_fingerprints"
output_dir   = "/path/to/output_inked"

# Ink samples used as style references.
style_images = [
    "src/triplet-distil-net/ink/sample1.png",
    "src/triplet-distil-net/ink/sample2.png",
]

# Initialize model
model = AdaINStyleTransfer(decoder_path, device=device)

# os.listdir returns entries in arbitrary order; sort them so the images
# are always processed in the same sequence.
content_images = sorted(
    os.path.join(content_dir, f)
    for f in os.listdir(content_dir)
    if f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp'))
)

print(f"Found {len(content_images)} images")

In [None]:
# Process images: stylise every content image with a randomly chosen ink
# sample and write the result under the same file name in output_dir.

# The output folder is not created anywhere else; saving would fail without it.
os.makedirs(output_dir, exist_ok=True)

total = len(content_images)
for idx, content_path in enumerate(content_images, start=1):
    style_path = random.choice(style_images)

    filename = os.path.basename(content_path)
    output_path = os.path.join(output_dir, filename)

    print(f"[{idx}/{total}] {filename} ← {os.path.basename(style_path)}")

    result = model.transfer_style(
        content_path=content_path,
        style_path=style_path,
        alpha=0.6,
    )

    model.save_image(result, output_path)