In this tutorial we will learn to:
- Instantiate a DeepPrintExtractor
- Prepare a training dataset
- Train a DeepPrintExtractor

## Instantiate a DeepPrintExtractor

This package implements a number of variants of the DeepPrint architecture. The wrapper class for all these variants is called `DeepPrintExtractor`.
It has a `fit` method to train (and save) the model as well as an `extract` method to extract the DeepPrint features for fingerprint images. 

You can also try to implement your own models, but currently this is not directly supported by the package.

In [12]:
from flx.data.dataset import IdentifierSet, Identifier
from flx.extractor.fixed_length_extractor import (
    get_DeepPrint_TexMinu,
    get_DeepPrint_Minu,
    get_DeepPrint_LocTex,
    DeepPrintExtractor,
)

# We use subjects 100-109 of the example dataset with 8 impressions per subject (80 samples)
training_ids: IdentifierSet = IdentifierSet([Identifier(i, j) for i in range(100, 110) for j in range(8)])

# We use the LocTex variant with a 256-dimensional texture embedding.
# The imported get_DeepPrint_TexMinu and get_DeepPrint_Minu build other variants.
extractor: DeepPrintExtractor = get_DeepPrint_LocTex(num_training_subjects=training_ids.num_subjects, num_texture_dims=256)

Created IdentifierSet with 10 subjects and a total of 80 samples.
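
The counts in this output follow directly from the two ranges in the list comprehension. As a minimal stand-in sketch (it uses a plain `NamedTuple` called `SampleId`, a hypothetical name, instead of the package's `Identifier` and `IdentifierSet` classes), the same enumeration looks like this:

```python
from typing import NamedTuple


class SampleId(NamedTuple):
    """Hypothetical stand-in for an identifier: one impression of one subject."""
    subject: int
    impression: int


# Subjects 100..109, impressions 0..7 for each subject
sample_ids = [SampleId(s, i) for s in range(100, 110) for i in range(8)]

# Count the distinct subjects and the total number of samples
num_subjects = len({sid.subject for sid in sample_ids})
print(num_subjects, len(sample_ids))  # 10 80
```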


In [5]:
training_ids

<flx.data.dataset.IdentifierSet at 0x7f0cafe1f190>

## Training the model

Instantiating the model was easy. To train it, first we will load the training data (see the [data tutorial](./dataset_tutorial.ipynb) for how to implement your own dataset).

Besides the fingerprint images, we also need a mapping from subjects to integer labels (for PyTorch). Some variants also need minutiae data. To see how a more complex dataset can be loaded, have a look at `flx/setup/datasets.py`.
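
To illustrate the label mapping: classification losses generally expect class labels as consecutive integers starting at 0, while the subject IDs here start at 100. The sketch below shows the idea in plain Python. It is not the implementation of the package's `LabelIndex`, and `build_label_index` is a hypothetical helper name.

```python
def build_label_index(subject_ids):
    """Map each distinct subject ID to a consecutive integer label, starting at 0."""
    return {subject: label for label, subject in enumerate(sorted(set(subject_ids)))}


# Subject IDs as used in this tutorial: 100..109
label_of = build_label_index(range(100, 110))
print(label_of[100], label_of[109])  # 0 9
```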

Finally, we call the `fit` method, which trains the model and saves it to the specified path.

There is also the option to add a validation set, which will be used to evaluate the embeddings during training. This is useful to monitor the training progress and to avoid overfitting.
In this example we will not use a validation set for simplicity.
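
The training log reports a `validation_equal_error_rate` field, so a validation set is presumably scored with the equal error rate (EER): the operating point where the rate of rejected genuine comparisons equals the rate of accepted impostor comparisons. The following is a rough sketch of that metric under the assumption that higher scores mean more similar embeddings. It is not the package's implementation, and `equal_error_rate` is a hypothetical function name.

```python
def equal_error_rate(genuine, impostor):
    """Approximate the EER by scanning every observed score as a threshold.

    genuine:  similarity scores of same-finger comparisons
    impostor: similarity scores of different-finger comparisons
    """
    best_gap, eer = float("inf"), 1.0
    for t in sorted(set(genuine) | set(impostor)):
        # A comparison is accepted when its score is at least the threshold
        fnmr = sum(s < t for s in genuine) / len(genuine)     # false non-match rate
        fmr = sum(s >= t for s in impostor) / len(impostor)   # false match rate
        gap = abs(fnmr - fmr)
        if gap < best_gap:
            best_gap, eer = gap, (fnmr + fmr) / 2
    return eer


print(equal_error_rate([0.9, 0.8, 0.7, 0.4], [0.1, 0.2, 0.3, 0.6]))  # 0.25
```

A lower EER on held-out data indicates embeddings that separate fingers better; a training loss that keeps falling while the validation EER rises is a sign of overfitting.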

In [11]:
import os

from flx.data.dataset import Dataset
from flx.data.image_loader import SFingeLoader
from flx.data.minutia_map_loader import SFingeMinutiaMapLoader
from flx.data.label_index import LabelIndex
from flx.data.transformed_image_loader import TransformedImageLoader
from flx.image_processing.binarization import LazilyAllocatedBinarizer
from flx.data.image_helpers import pad_and_resize_to_deepprint_input_size

# NOTE: If this does not work, enter the absolute paths manually here! 
DATASET_DIR: str = os.path.abspath("fvc_dataset/fvc2002_DB3_B capacitive/converted")
MODEL_OUTDIR: str = os.path.abspath("trial_model")

# We will use the SFingeLoader to load the images from the dataset,
# then binarize them and pad/resize them to the DeepPrint input size
image_loader = TransformedImageLoader(
    images=SFingeLoader(DATASET_DIR),
    poses=None,
    transforms=[
        LazilyAllocatedBinarizer(5.0),
        pad_and_resize_to_deepprint_input_size,
    ],
)

image_dataset = Dataset(image_loader, training_ids)

# For PyTorch, we need to map the subjects to integer labels from [0 ... num_subjects-1]
label_dataset = Dataset(LabelIndex(training_ids), training_ids)

minutia_maps_dataset = Dataset(SFingeMinutiaMapLoader(DATASET_DIR), training_ids)

# Train for 5 epochs without a validation set; the model is saved to MODEL_OUTDIR
extractor.fit(
    fingerprints=image_dataset,
    minutia_maps=minutia_maps_dataset,
    labels=label_dataset,
    validation_fingerprints=None,
    validation_benchmark=None,
    num_epochs=5,
    out_dir=MODEL_OUTDIR,
)

Created IdentifierSet with 10 subjects and a total of 80 samples.
Created IdentifierSet with 10 subjects and a total of 80 samples.
Using device cpu
No model file found at /home/rs/21CS91R01/research/fixed-length-fingerprint-extractors/notebooks/trial_model/model.pyt


 --- Starting Epoch 1 of 5 ---

Training:


100%|██████████| 5/5 [02:20<00:00, 28.18s/it]


Average Loss: 1604164.1298899048
Multiclass accuracy: 0.125
TrainingLogEntry(
    epoch=1,
    training_loss=1604164.1298899048,
    loss_statistics={'minutia_loss': {'crossent_loss_sum': 0.23300681114196778, 'center_loss_sum': 0.22791186571121216}, 'minutia_map_loss': 100259.79719943623},
    training_accuracy=0.125,
    validation_equal_error_rate=0.125,
}


 --- Starting Epoch 2 of 5 ---

Training:


100%|██████████| 5/5 [02:42<00:00, 32.47s/it]


Average Loss: 190913.8847218323
Multiclass accuracy: 0.03750000149011612
TrainingLogEntry(
    epoch=2,
    training_loss=190913.8847218323,
    loss_statistics={'minutia_loss': {'crossent_loss_sum': 0.08071020245552063, 'center_loss_sum': 0.09443754255771637}, 'minutia_map_loss': 5965.883749809265},
    training_accuracy=0.03750000149011612,
    validation_equal_error_rate=0.03750000149011612,
}


 --- Starting Epoch 3 of 5 ---

Training:


100%|██████████| 5/5 [02:48<00:00, 33.68s/it]


Average Loss: 53685.13687746048
Multiclass accuracy: 0.07500000298023224
TrainingLogEntry(
    epoch=3,
    training_loss=53685.13687746048,
    loss_statistics={'minutia_loss': {'crossent_loss_sum': 0.055977290868759154, 'center_loss_sum': 0.05175041953722636}, 'minutia_map_loss': 1118.332623901367},
    training_accuracy=0.07500000298023224,
    validation_equal_error_rate=0.07500000298023224,
}


 --- Starting Epoch 4 of 5 ---

Training:


100%|██████████| 5/5 [01:57<00:00, 23.55s/it]


Average Loss: 32334.814621772763
Multiclass accuracy: 0.11249999701976776
TrainingLogEntry(
    epoch=4,
    training_loss=32334.814621772763,
    loss_statistics={'minutia_loss': {'crossent_loss_sum': 0.041099635511636735, 'center_loss_sum': 0.032095634192228314}, 'minutia_map_loss': 505.1582831954955},
    training_accuracy=0.11249999701976776,
    validation_equal_error_rate=0.11249999701976776,
}


 --- Starting Epoch 5 of 5 ---

Training:


100%|██████████| 5/5 [02:41<00:00, 32.32s/it]


Average Loss: 32432.93615842819
Multiclass accuracy: 0.05000000074505806
TrainingLogEntry(
    epoch=5,
    training_loss=32432.93615842819,
    loss_statistics={'minutia_loss': {'crossent_loss_sum': 0.031037132143974303, 'center_loss_sum': 0.021133530139923095}, 'minutia_map_loss': 405.3595313186645},
    training_accuracy=0.05000000074505806,
    validation_equal_error_rate=0.05000000074505806,
}


In [14]:
import os

from flx.data.dataset import Dataset
from flx.data.image_loader import SFingeLoader
from flx.data.minutia_map_loader import SFingeMinutiaMapLoader
from flx.data.label_index import LabelIndex
from flx.data.transformed_image_loader import TransformedImageLoader
from flx.image_processing.binarization import LazilyAllocatedBinarizer
from flx.data.image_helpers import pad_and_resize_to_deepprint_input_size

# NOTE: If this does not work, enter the absolute paths manually here! 
DATASET_DIR: str = os.path.abspath("fvc_dataset/fvc2002_DB3_B capacitive/converted")
MODEL_OUTDIR: str = os.path.abspath("loc_tex")

# We will use the SFingeLoader to load the images from the dataset,
# then binarize them and pad/resize them to the DeepPrint input size
image_loader = TransformedImageLoader(
    images=SFingeLoader(DATASET_DIR),
    poses=None,
    transforms=[
        LazilyAllocatedBinarizer(5.0),
        pad_and_resize_to_deepprint_input_size,
    ],
)

image_dataset = Dataset(image_loader, training_ids)

# For PyTorch, we need to map the subjects to integer labels from [0 ... num_subjects-1]
label_dataset = Dataset(LabelIndex(training_ids), training_ids)

minutia_maps_dataset = Dataset(SFingeMinutiaMapLoader(DATASET_DIR), training_ids)

# Train for 5 epochs without a validation set; the model is saved to MODEL_OUTDIR
extractor.fit(
    fingerprints=image_dataset,
    minutia_maps=minutia_maps_dataset,
    labels=label_dataset,
    validation_fingerprints=None,
    validation_benchmark=None,
    num_epochs=5,
    out_dir=MODEL_OUTDIR,
)

Created IdentifierSet with 10 subjects and a total of 80 samples.
Created IdentifierSet with 10 subjects and a total of 80 samples.
Using device cpu
No model file found at /home/rs/21CS91R01/research/fixed-length-fingerprint-extractors/notebooks/loc_tex/model.pyt


 --- Starting Epoch 1 of 5 ---

Training:


  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [05:09<00:00, 61.89s/it]


Average Loss: 5.457501411437988
Multiclass accuracy: 0.05000000074505806
TrainingLogEntry(
    epoch=1,
    training_loss=5.457501411437988,
    loss_statistics={'crossent_loss_sum': 0.14773398041725158, 'center_loss_sum': 0.19335986077785491},
    training_accuracy=0.05000000074505806,
    validation_equal_error_rate=0.05000000074505806,
}


 --- Starting Epoch 2 of 5 ---

Training:


100%|██████████| 5/5 [05:03<00:00, 60.61s/it]


Average Loss: 4.7597064018249515
Multiclass accuracy: 0.05000000074505806
TrainingLogEntry(
    epoch=2,
    training_loss=4.7597064018249515,
    loss_statistics={'crossent_loss_sum': 0.07405644208192826, 'center_loss_sum': 0.07468437999486924},
    training_accuracy=0.05000000074505806,
    validation_equal_error_rate=0.05000000074505806,
}


 --- Starting Epoch 3 of 5 ---

Training:


100%|██████████| 5/5 [05:22<00:00, 64.57s/it]


Average Loss: 4.375380802154541
Multiclass accuracy: 0.0625
TrainingLogEntry(
    epoch=3,
    training_loss=4.375380802154541,
    loss_statistics={'crossent_loss_sum': 0.04918948113918305, 'center_loss_sum': 0.04196428606907527},
    training_accuracy=0.0625,
    validation_equal_error_rate=0.0625,
}


 --- Starting Epoch 4 of 5 ---

Training:


100%|██████████| 5/5 [04:53<00:00, 58.64s/it]


Average Loss: 4.029953575134277
Multiclass accuracy: 0.11249999701976776
TrainingLogEntry(
    epoch=4,
    training_loss=4.029953575134277,
    loss_statistics={'crossent_loss_sum': 0.03668036088347435, 'center_loss_sum': 0.026287664100527762},
    training_accuracy=0.11249999701976776,
    validation_equal_error_rate=0.11249999701976776,
}


 --- Starting Epoch 5 of 5 ---

Training:


100%|██████████| 5/5 [04:22<00:00, 52.44s/it]


Average Loss: 3.792448043823242
Multiclass accuracy: 0.07500000298023224
TrainingLogEntry(
    epoch=5,
    training_loss=3.792448043823242,
    loss_statistics={'crossent_loss_sum': 0.029407902359962462, 'center_loss_sum': 0.017997698187828065},
    training_accuracy=0.07500000298023224,
    validation_equal_error_rate=0.07500000298023224,
}
