In [None]:
import sys

sys.path.append("..")

from math import ceil

import torch
import torch.nn.functional as F
from IPython.display import display
from torch import optim, nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import efficientnet_v2_s
from torchvision.transforms import InterpolationMode

from tri_dataset import TriDataset
from tri_net import TriNet

from comparator import Comparator

In [None]:
# --- Data locations and optimisation hyper-parameters -------------------------
train_path = "dataset/train/"
test_path = "dataset/test"
batch_size = 32
lr = 1e-4
momentum = 0.9  # NOTE(review): not read by the Adam optimizer built in train()
epochs = 100
input_size = 224

# --- Geometric augmentation: small rotations / shifts; scale (1, 1) = no zoom --
aug_rotation = 10
aug_translate = (0.05, 0.05)
aug_scale = (1, 1)

augmentation = transforms.RandomAffine(
	aug_rotation,
	translate=aug_translate,
	scale=aug_scale,
	interpolation=InterpolationMode.BILINEAR,
	fill=255,  # value for pixels exposed by the affine warp (white for 8-bit images)
)


def _preprocess(image):
	"""Apply the project's ``Comparator.preprocess_image`` at ``input_size``."""
	return Comparator.preprocess_image(image, input_size)


# Pipeline: random affine augmentation -> Comparator preprocessing -> tensor.
transform = transforms.Compose([
	augmentation,
	_preprocess,
	Comparator.image_to_tensor(),
])

In [None]:
def _train_one_epoch(model, data_loader, optimizer, loss_fn, epoch, num_epochs, target_device):
	"""Run one optimisation pass over ``data_loader`` and return the mean batch loss.

	Progress is printed on a single console line (rewritten with ``'\\r'``).
	Returns 0.0 when the loader yields no batches.
	"""
	model.train()
	num_batches = len(data_loader)
	running_loss = 0.0

	for batch_num, (anchor, positive, negative) in enumerate(data_loader, start=1):
		optimizer.zero_grad()

		anchor_emb, positive_emb, negative_emb = model.forward_tri(
			anchor.to(target_device),
			positive.to(target_device),
			negative.to(target_device),
		)
		loss = loss_fn(anchor_emb, positive_emb, negative_emb)
		loss.backward()
		optimizer.step()

		loss_item = loss.item()
		running_loss += loss_item
		# '\r' returns to the start of the line so per-batch progress overwrites itself.
		print(f"\rEpoch: {epoch + 1}/{num_epochs}, Iter: {batch_num}/{num_batches}, Loss: {loss_item}", end='')

	# Guard against an empty dataset (the original divided by zero here).
	return running_loss / max(num_batches, 1)


def train(checkpoint_path="weights/best.pt", save_every=100):
	"""Fine-tune a TriNet with a cosine-distance triplet loss and return it.

	Resumes from the whole model pickled at ``checkpoint_path`` (to start from
	scratch, build ``TriNet(efficientnet_v2_s())`` with a 1-channel stem instead)
	and trains it for ``epochs`` epochs on ``TriDataset(train_path)``, printing
	per-batch loss and the per-epoch average batch loss.

	Reads notebook-level configuration: ``train_path``, ``transform``,
	``batch_size``, ``lr``, ``epochs`` and ``device``. ``device`` is defined in a
	later cell, so that cell must run before ``train()`` is called.

	Args:
		checkpoint_path: file written by ``torch.save(model, ...)`` to resume from.
		save_every: save ``model_<epoch>.pt`` every this many epochs;
			``0``/``None`` disables intermediate checkpoints.

	Returns:
		The trained model, left on ``device``.
	"""
	# torch.load unpickles arbitrary objects -- only load trusted checkpoints.
	# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
	# NOTE(review): on torch>=2.6 loading a whole pickled module may additionally
	# need weights_only=False.
	model = torch.load(checkpoint_path, map_location=device)

	dataset = TriDataset(train_path, transform=transform)
	data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
	optimizer = optim.Adam(model.parameters(), lr=lr)

	# Distance 1 - cos(x, y), in [0, 2]; 0 when embeddings point the same way.
	loss_fn = nn.TripletMarginWithDistanceLoss(
		distance_function=lambda x, y: 1.0 - F.cosine_similarity(x, y)
	)

	model.to(device)
	display(model)

	for epoch in range(epochs):
		avg_loss = _train_one_epoch(model, data_loader, optimizer, loss_fn, epoch, epochs, device)
		print(f"\nAvg: {avg_loss}")

		if save_every and (epoch + 1) % save_every == 0:
			torch.save(model, f"model_{epoch + 1}.pt")
	return model

In [None]:
def test(net=None, augment=False):
	"""Print the cosine distances of every test triplet and a Pos<Neg summary.

	For each (anchor, positive, negative) triplet from ``TriDataset(test_path)``,
	the embeddings returned by ``net.forward_tri`` are compared with the
	normalised cosine distance ``(1 - cos) / 2``, which lies in [0, 1]
	(0 = same direction). A good model yields a small ``Pos`` and a larger
	``Neg``. At the end it prints how many triplets satisfy ``Pos < Neg``.

	Args:
		net: model to evaluate; defaults to the notebook-global ``model``.
		augment: if True, reuse the training ``transform`` including its random
			affine augmentation (the previous behaviour). By default only the
			deterministic Comparator preprocessing is applied, so repeated runs
			give the same distances.
	"""
	if net is None:
		net = model  # notebook-global set by the training / loading cells

	if augment:
		eval_transform = transform
	else:
		# Same steps as the training pipeline minus RandomAffine.
		eval_transform = transforms.Compose([
			lambda image: Comparator.preprocess_image(image, input_size),
			Comparator.image_to_tensor(),
		])

	dataset = TriDataset(test_path, transform=eval_transform)
	data_loader = DataLoader(dataset, batch_size=1, shuffle=False)

	net.eval()
	num_correct = 0
	num_total = 0
	with torch.no_grad():
		for anchor, positive, negative in data_loader:
			anchor_emb, positive_emb, negative_emb = net.forward_tri(
				anchor.to(device), positive.to(device), negative.to(device)
			)

			pos = (1 - F.cosine_similarity(anchor_emb, positive_emb)) / 2
			neg = (1 - F.cosine_similarity(anchor_emb, negative_emb)) / 2

			pos_value, neg_value = pos.item(), neg.item()
			print(f"Pos: {pos_value}, Neg: {neg_value}")

			num_correct += int(pos_value < neg_value)
			num_total += 1

	if num_total:
		print(f"Pos < Neg for {num_correct}/{num_total} triplets ({num_correct / num_total:.2%})")


In [None]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Run the training loop defined above, then pickle the whole returned model.
model = train()
torch.save(obj=model, f="model.pt")

We compare two signatures with the normalised cosine distance between their embeddings, $d(a, b) = \frac{1 - \cos(a, b)}{2} \in [0, 1]$: the closer the value is to 0, the more alike the signatures are.
In the next cell, `Pos` is the distance between the `anchor` and `positive` images (it should be close to 0), and `Neg` is the distance between the `anchor` and `negative` images (it should be noticeably larger).

In [None]:
# Evaluate the checkpoint at weights/best.pt (not the model.pt saved by the training cell).
# map_location lets a GPU-saved checkpoint load on a CPU-only machine;
# torch.load unpickles arbitrary objects, so only load trusted files.
model = torch.load("weights/best.pt", map_location=device)
test()

In [None]:
# NOTE(review): scratch cell that only prints a constant string; it does not
# read or modify the model and can likely be removed.
print("Hello, World!", end="\n")

In [None]:
# NOTE(review): repeats the earlier load-and-evaluate cell; consider removing one copy.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine;
# torch.load unpickles arbitrary objects, so only load trusted files.
model = torch.load("weights/best.pt", map_location=device)
test()

In [None]:
# NOTE(review): duplicate scratch cell that only prints a constant string;
# it does not read or modify the model and can likely be removed.
print("Hello, World!", end="\n")