In [None]:
from logging import WARNING, INFO

from torch.utils.data import default_collate

from src.core.utils import MyTrainer, filter_lightning_logs
from src.data.lfw_pairs import LFWPairsDev, LFWPairsTest
from src.models.facenet import FacenetBackbone
from src.data.sllfw_pairs import SLLFWPairsTest
from src.models.similarity.threshold_siamese import ThresholdSiamese

In [None]:
def test(data, pretrained, log_level=WARNING):
    """Fit a ``ThresholdSiamese`` model on one data split and test it.

    A ``FacenetBackbone`` is wrapped in a ``ThresholdSiamese`` model, fitted
    with ``MyTrainer`` on ``data.fit_dataloader()``, its ``threshold``
    attribute is printed, and the model is then evaluated with
    ``trainer.test`` on ``data``. The whole run happens inside
    ``filter_lightning_logs(log_level)``.

    Args:
        data: a split exposing ``fit_dataloader()`` that the trainer's
            ``test`` also accepts directly (e.g. ``LFWPairsDev.load('fit')``
            or one fold yielded by a datamodule's ``folds()``).
        pretrained: pretrained-weights identifier forwarded to
            ``FacenetBackbone``, e.g. ``'vggface2'`` or ``'casia-webface'``.
        log_level: logging level handed to ``filter_lightning_logs``.

    Returns:
        Whatever ``trainer.test`` returns — presumably Lightning's list of
        per-dataloader metric dicts (confirm against ``MyTrainer``).
    """
    with filter_lightning_logs(log_level):
        backbone = FacenetBackbone(pretrained=pretrained)
        siamese = ThresholdSiamese(backbone)
        lightning_trainer = MyTrainer()

        # The fit step is presumably where the threshold gets chosen,
        # so report it before evaluating.
        lightning_trainer.fit(siamese, data.fit_dataloader())
        print(f'Chosen threshold: {siamese.threshold}')

        # Kept inside the ``with`` so log filtering also covers testing.
        return lightning_trainer.test(siamese, data)

def evaluate(datamodule, pretrained, log_level=WARNING):
    """Run :func:`test` on every fold of ``datamodule`` and collate the results.

    Args:
        datamodule: the dataset class (or object) whose ``folds()`` yields the
            splits to evaluate, e.g. ``LFWPairsTest``. Each fold is passed to
            :func:`test` as its ``data`` argument.
        pretrained: pretrained-weights identifier forwarded to :func:`test`
            (e.g. ``'vggface2'`` or ``'casia-webface'``).
        log_level: logging level forwarded to :func:`test`.

    Returns:
        ``default_collate`` applied to the list of per-fold :func:`test`
        results, which stacks matching entries across folds. Assuming each
        result is Lightning's list of metric dicts (confirm), ``out[i][key]``
        is a tensor with one value per fold for test dataloader ``i``.
    """
    # Folds are evaluated sequentially, in the order ``folds()`` yields them.
    per_fold_results = [test(fold, pretrained, log_level) for fold in datamodule.folds()]
    return default_collate(per_fold_results)


# Backward-compatible alias: the original, misspelled name is still used by
# the evaluation cells of this notebook.
evaludate = evaluate


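# Testing on the Dev View

A single fit-and-test run on the LFW development view with the VGGFace2-pretrained backbone, with Lightning logs shown at `INFO` level.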

In [None]:
# Dev-view run with INFO-level logs; the returned test metrics are the cell output.
test(data=LFWPairsDev.load('fit'), pretrained='vggface2', log_level=INFO)

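# Testing on the Test View

For each backbone / dataset pair, the threshold is fitted and tested separately on every fold, and the mean accuracy over the folds is reported.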

## Pretrained on VGGFace2, evaluated on LFW

In [None]:
# Per-fold test metrics: VGGFace2-pretrained backbone on the LFW test view.
results = evaludate(datamodule=LFWPairsTest, pretrained='vggface2')

In [None]:
# results[0] holds the first test dataloader's metrics, collated across folds.
fold_accuracies = results[0]["test/0/Accuracy"]
print(f'mean benign accuracy: {fold_accuracies.mean()}')

## Pretrained on CASIA-WebFace, evaluated on LFW

In [None]:
# Per-fold test metrics: CASIA-WebFace-pretrained backbone on the LFW test view.
results = evaludate(datamodule=LFWPairsTest, pretrained='casia-webface')

In [None]:
# results[0] holds the first test dataloader's metrics, collated across folds.
fold_accuracies = results[0]["test/0/Accuracy"]
print(f'mean benign accuracy: {fold_accuracies.mean()}')

## Pretrained on VGGFace2, evaluated on SLLFW

In [None]:
# Per-fold test metrics: VGGFace2-pretrained backbone on the SLLFW test pairs.
results = evaludate(datamodule=SLLFWPairsTest, pretrained='vggface2')

In [None]:
# results[0] holds the first test dataloader's metrics, collated across folds.
fold_accuracies = results[0]["test/0/Accuracy"]
print(f'mean benign accuracy: {fold_accuracies.mean()}')

## Pretrained on CASIA-WebFace, evaluated on SLLFW

In [None]:
# Per-fold test metrics: CASIA-WebFace-pretrained backbone on the SLLFW test pairs.
results = evaludate(datamodule=SLLFWPairsTest, pretrained='casia-webface')

In [None]:
# results[0] holds the first test dataloader's metrics, collated across folds.
fold_accuracies = results[0]["test/0/Accuracy"]
print(f'mean benign accuracy: {fold_accuracies.mean()}')