In this tutorial we will learn to:
- Instantiate a DeepPrintExtractor
- Prepare a training dataset
- Train a DeepPrintExtractor

## Instantiate a DeepPrintExtractor

This package implements a number of variants of the DeepPrint architecture. The wrapper class for all these variants is called `DeepPrintExtractor`.
It has a `fit` method to train (and save) the model as well as an `extract` method to extract the DeepPrint features for fingerprint images. 

You can also try to implement your own models, but currently this is not directly supported by the package.

In [None]:
from flx.data.dataset import IdentifierSet, Identifier
from flx.extractor.fixed_length_extractor import get_DeepPrint_TexMinu, DeepPrintExtractor

# We will use the example dataset with 10 subjects and 10 impressions per subject
training_ids: IdentifierSet = IdentifierSet([Identifier(i, j) for i in range(10) for j in range(10)])

# We choose a dimension of 512 for the fixed-length representation (TexMinu has two outputs with num_dims=256 each)
extractor: DeepPrintExtractor = get_DeepPrint_TexMinu(num_training_subjects=training_ids.num_subjects, num_dims=256)

## Training the model

Instantiating the model was easy. To train it, first we will load the training data (see the [data tutorial](./dataset_tutorial.ipynb) for how to implement your own dataset).

Besides the fingerprint images, we also need a mapping from subjects to integer labels (for PyTorch). Some variants also need minutiae data. To see how a more complex dataset can be loaded, have a look at `flx/setup/datasets.py`.
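
To illustrate what such a subject-to-label mapping does, here is a minimal plain-Python sketch. It is not the package's `LabelIndex` implementation; the function name `build_label_map` and the `(subject, impression)` tuples are hypothetical stand-ins chosen for this example. The idea is that arbitrary subject ids are mapped to contiguous integers in the range 0 to `num_subjects - 1`, which is the form a classification loss typically expects.

```python
from typing import Dict, Iterable, Tuple


def build_label_map(ids: Iterable[Tuple[int, int]]) -> Dict[int, int]:
    """Map each distinct subject id to a contiguous label in [0, num_subjects - 1]."""
    # Collect the distinct subjects and sort them so the mapping is deterministic.
    subjects = sorted({subject for subject, _impression in ids})
    return {subject: label for label, subject in enumerate(subjects)}


# Hypothetical (subject, impression) pairs with non-contiguous subject ids
example_ids = [(7, 0), (7, 1), (42, 0), (3, 0), (3, 1)]

label_map = build_label_map(example_ids)
labels = [label_map[subject] for subject, _impression in example_ids]

print(label_map)  # {3: 0, 7: 1, 42: 2}
print(labels)     # [1, 1, 2, 0, 0]
```

All impressions of the same subject share one label, so the number of distinct labels equals the number of training subjects.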

Finally, we call the `fit` method, which trains the model and saves it to the specified path.

There is also the option to add a validation set, which is used to evaluate the embeddings during training. This is useful for monitoring training progress and for spotting overfitting early.
For simplicity, this example does not use a validation set.
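
If you do want a validation set, one common approach is to hold out whole subjects, so that the embeddings are evaluated on identities the model has not seen during training. The sketch below shows this split on plain `(subject, impression)` tuples. It is an assumption-laden illustration, not part of the package: `split_by_subject` is a hypothetical helper, and the resulting ids would still have to be wrapped in the package's own identifier and dataset types before being passed to `fit`.

```python
from typing import List, Tuple

# (subject, impression); plain tuples stand in for the package's identifiers
Id = Tuple[int, int]


def split_by_subject(ids: List[Id], num_validation_subjects: int) -> Tuple[List[Id], List[Id]]:
    """Hold out the highest-numbered subjects for validation."""
    subjects = sorted({subject for subject, _impression in ids})
    if not 0 < num_validation_subjects < len(subjects):
        raise ValueError("num_validation_subjects must be between 1 and num_subjects - 1")
    held_out = set(subjects[-num_validation_subjects:])
    train = [i for i in ids if i[0] not in held_out]
    validation = [i for i in ids if i[0] in held_out]
    return train, validation


# Same shape as the example dataset: 10 subjects with 10 impressions each
all_ids = [(s, i) for s in range(10) for i in range(10)]
train_ids, val_ids = split_by_subject(all_ids, num_validation_subjects=2)

print(len(train_ids), len(val_ids))  # 80 20
```

Splitting by subject rather than by impression keeps the two sets disjoint in identity, which is usually what you want when the goal is to judge how well the embeddings generalize.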

In [None]:
import os

import torch 

from flx.data.dataset import Dataset
from flx.data.image_loader import SFingeLoader
from flx.data.minutia_map_loader import SFingeMinutiaMapLoader
from flx.data.label_index import LabelIndex
from flx.data.transformed_image_loader import TransformedImageLoader
from flx.image_processing.binarization import LazilyAllocatedBinarizer
from flx.data.image_helpers import pad_and_resize_to_deepprint_input_size

# NOTE: If this does not work, enter the absolute paths manually here! 
DATASET_DIR: str = os.path.abspath("example-dataset")
MODEL_OUTDIR: str = os.path.abspath("example-model")

# We will use the SFingeLoader to load the images from the dataset
image_loader = TransformedImageLoader(
    images=SFingeLoader(DATASET_DIR),
    poses=None,
    transforms=[
        LazilyAllocatedBinarizer(5.0),
        pad_and_resize_to_deepprint_input_size,
    ],
)

image_dataset = Dataset(image_loader, training_ids)

# For PyTorch, we need to map the subjects to integer labels from [0 ... num_subjects-1]
label_dataset = Dataset(LabelIndex(training_ids), training_ids)

minutia_maps_dataset = Dataset(SFingeMinutiaMapLoader(DATASET_DIR), training_ids)

extractor.fit(
    fingerprints=image_dataset,
    minutia_maps=minutia_maps_dataset,
    labels=label_dataset,
    validation_fingerprints=None,
    validation_benchmark=None,
    num_epochs=20,
    out_dir=MODEL_OUTDIR
)