In [None]:
import random
from logging import INFO
from pathlib import Path

import numpy as np
import torch

from src.data.core.pairs import PairsDataModule
from src.data.mnist import MNIST
from src.models.mlp import MLP, PretrainedMLPBackbone
from src.core.utils import MyTrainer, ROOT_PATH
from src.models.similarity.threshold_siamese import ThresholdSiamese

In [None]:
# Reproducibility: seed every RNG the training stack may draw from
# (Python, NumPy, PyTorch) before the data module and models are built,
# so Restart & Run All produces comparable results.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Checkpoint written by the classifier cell and loaded by the Siamese cell.
pretrained_classifier_path = ROOT_PATH / 'pretrained' / 'mnist_3_features' / 'classifier.ckpt'
mnist = MNIST()

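# Classifier

Train an MLP classifier (3-dimensional feature layer, 10 output logits) on MNIST and evaluate it with `trainer.test`. Then save its checkpoint to `pretrained_classifier_path` so the Siamese section can reuse it as a backbone.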

In [None]:
# Train the classifier whose checkpoint the Siamese cell loads as its backbone,
# then persist it to `pretrained_classifier_path`.
# Set to False to reuse an existing checkpoint instead of retraining on every
# Restart & Run All; training still runs if no checkpoint exists yet.
RETRAIN_CLASSIFIER = True

if RETRAIN_CLASSIFIER or not pretrained_classifier_path.exists():
    # Create the output directory first so an invalid/unwritable path fails
    # before training rather than after it.
    pretrained_classifier_path.parent.mkdir(exist_ok=True, parents=True)
    model = MLP(features_dim=3, logits=10)
    trainer = MyTrainer()
    trainer.fit(model, mnist)
    trainer.test(model, mnist)
    trainer.save_checkpoint(pretrained_classifier_path)

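# Siamese Threshold

Load the saved classifier checkpoint as a `PretrainedMLPBackbone` and wrap it in a `ThresholdSiamese` model. Then train and evaluate that model on the MNIST pairs data module (`PairsDataModule`).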

In [None]:
# Build a ThresholdSiamese model on top of the classifier checkpoint stored at
# `pretrained_classifier_path`, then fit and test it on MNIST pairs.
# NOTE(review): `.eval()` is applied before `trainer.fit`. If MyTrainer is a
# PyTorch Lightning Trainer, fit() puts the whole model (backbone included)
# back into train mode, so this call may have no effect during training unless
# ThresholdSiamese / PretrainedMLPBackbone override that. Confirm, and freeze
# the backbone explicitly if that is the intent.
backbone = PretrainedMLPBackbone(pretrained_classifier_path).eval()
model = ThresholdSiamese(backbone=backbone)
# Pairs data module for the 'fit' stage, derived from the single-image `mnist`
# module (presumably pairs are sampled from it; verify in PairsDataModule.load).
mnist_pairs = PairsDataModule.load('fit', singles=mnist)
trainer = MyTrainer()
trainer.fit(model, mnist_pairs)
trainer.test(model, mnist_pairs)