In [1]:
import sys

sys.path.append("..")

from math import ceil

import torch
import torch.nn.functional as F
from IPython.display import display
from torch import optim, nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import efficientnet_v2_s
from torchvision.transforms import InterpolationMode

from tri_dataset import TriDataset
from tri_net import TriNet

from comparator import Comparator

In [2]:
# --- Dataset locations -------------------------------------------------------
train_path = "dataset/train/"
test_path = "dataset/test"

# --- Optimisation hyper-parameters -------------------------------------------
batch_size = 32
lr = 0.0001
momentum = 0.9  # NOTE(review): not referenced by any code visible in this notebook
epochs = 100
input_size = 224  # forwarded to Comparator.preprocess_image for every sample

# --- Augmentation ranges (consumed by RandomAffine below) ---------------------
aug_rotation = 10
aug_translate = (0.05, 0.05)
aug_scale = (1, 1)  # equal bounds: scaling is effectively disabled


def _preprocess(image):
	"""Apply ``Comparator.preprocess_image`` to ``image`` with the configured ``input_size``."""
	return Comparator.preprocess_image(image, input_size)


# Random rotation/translation applied to each sample before preprocessing.
# Areas exposed by the transform are filled with 255
# (presumably a white background -- confirm against the image mode).
augmentation = transforms.RandomAffine(
	aug_rotation,
	translate=aug_translate,
	scale=aug_scale,
	interpolation=InterpolationMode.BILINEAR,
	fill=255,
)

# Per-sample pipeline: augment -> preprocess -> convert to tensor.
# Debug aid: insert `lambda image: display(image) or image` after `_preprocess`
# to preview each preprocessed image inline.
transform = transforms.Compose([
	augmentation,
	_preprocess,
	Comparator.image_to_tensor(),
])

In [3]:
def train():
	"""Fine-tune the triplet network stored in ``weights/best.pt`` and return it.

	Reads the notebook-level settings ``train_path``, ``transform``,
	``batch_size``, ``lr``, ``epochs`` and ``device``. For every batch the
	anchor/positive/negative images are embedded with ``model.forward_tri`` and
	optimised with a triplet margin loss whose distance is
	``1 - cosine_similarity``. Progress is printed on a single self-overwriting
	line per epoch, followed by the epoch's average loss. A full-model
	checkpoint ``model_<epoch>.pt`` is written every ``checkpoint_every`` epochs.

	Returns:
		The trained model, left on ``device``.
	"""
	# To train from scratch instead of resuming, build the model like this:
	# architecture = efficientnet_v2_s()
	# architecture.features[0][0] = nn.Conv2d(1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
	# model = TriNet(architecture)

	# map_location lets a checkpoint that was saved from a CUDA model load on a
	# CPU-only machine (the notebook explicitly falls back to CPU).
	# NOTE(security): torch.load unpickles arbitrary objects -- only load trusted files.
	# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
	# which rejects fully pickled modules; pass weights_only=False there if needed.
	model = torch.load("weights/best.pt", map_location=device)

	dataset = TriDataset(train_path, transform=transform)
	data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
	optimizer = optim.Adam(model.parameters(), lr=lr)

	loss_fn = nn.TripletMarginWithDistanceLoss(
		distance_function=lambda x, y: (1.0 - F.cosine_similarity(x, y))
	)

	model.to(device)
	display(model)

	# Number of batches per epoch; same value as ceil(len(dataset) / batch_size).
	iter_count = len(data_loader)
	# Save an intermediate checkpoint every this many epochs.
	checkpoint_every = 100

	for epoch in range(epochs):
		model.train()
		running_loss = 0.0

		for i, (anchor, positive, negative) in enumerate(data_loader, start=1):
			optimizer.zero_grad()

			anchor = anchor.to(device)
			positive = positive.to(device)
			negative = negative.to(device)

			(anchor, positive, negative) = model.forward_tri(anchor, positive, negative)
			loss = loss_fn(anchor, positive, negative)

			loss.backward()
			optimizer.step()

			loss_item = loss.item()
			running_loss += loss_item

			# Carriage return first so each progress line overwrites the previous one.
			print('\r', end='')
			print(f"Epoch: {epoch + 1}/{epochs}, Iter: {i}/{iter_count}, Loss: {loss_item}", end='')

		# Guard against an empty loader so the average never divides by zero.
		if iter_count:
			print(f"\nAvg: {running_loss / iter_count}")

		if (epoch + 1) % checkpoint_every == 0:
			torch.save(model, f"model_{epoch + 1}.pt")
	return model

In [4]:
def test():
	"""Print the anchor-positive and anchor-negative distances for every test triplet.

	Relies on the notebook-level ``model``, ``device``, ``test_path`` and
	``transform``. Triplets are processed one at a time, in dataset order,
	with gradients disabled. Each reported value is
	``(1 - cosine_similarity) / 2`` between the two embeddings returned by
	``model.forward_tri``.
	"""
	triplets = TriDataset(test_path, transform=transform, display_info=True)
	loader = DataLoader(triplets, batch_size=1, shuffle=False)

	def halved_cosine_distance(left, right):
		# Cosine similarity turned into a distance and halved.
		return (1 - F.cosine_similarity(left, right)) / 2

	model.eval()
	with torch.no_grad():
		for anchor, positive, negative in loader:
			emb_anchor, emb_positive, emb_negative = model.forward_tri(
				anchor.to(device),
				positive.to(device),
				negative.to(device),
			)

			pos = halved_cosine_distance(emb_anchor, emb_positive)
			neg = halved_cosine_distance(emb_anchor, emb_negative)

			print(f"Pos: {pos.item()}, Neg: {neg.item()}")


In [5]:
# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Run the full training loop, then persist the returned model object as a whole
# (the entire module is pickled, not just a state_dict).
model = train()
torch.save(obj=model, f="model.pt")

TriNet(
  (model): EfficientNet(
    (features): Sequential(
      (0): Conv2dNormActivation(
        (0): Conv2d(1, 24, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
        (2): SiLU(inplace=True)
      )
      (1): Sequential(
        (0): FusedMBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(24, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
              (2): SiLU(inplace=True)
            )
          )
          (stochastic_depth): StochasticDepth(p=0.0, mode=row)
        )
        (1): FusedMBConv(
          (block): Sequential(
            (0): Conv2dNormActivation(
              (0): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
              (1): BatchNorm2d(24, eps

Epoch: 1/100, Iter: 21/21, Loss: 0.34575825929641724
Avg: 0.22868129540057408
Epoch: 2/100, Iter: 21/21, Loss: 0.16789892315864563
Avg: 0.193085368900072
Epoch: 3/100, Iter: 21/21, Loss: 0.38388526439666752
Avg: 0.19160312414169312
Epoch: 4/100, Iter: 21/21, Loss: 0.24938370287418365
Avg: 0.21872200497559138
Epoch: 5/100, Iter: 21/21, Loss: 0.057269152253866196
Avg: 0.18192345844138236
Epoch: 6/100, Iter: 21/21, Loss: 0.15463419258594513
Avg: 0.17569203355482646
Epoch: 7/100, Iter: 21/21, Loss: 0.30138558149337778
Avg: 0.19815607581819808
Epoch: 8/100, Iter: 21/21, Loss: 0.23643171787261963
Avg: 0.19819107722668422
Epoch: 9/100, Iter: 21/21, Loss: 0.28539717197418213
Avg: 0.1873427735907691
Epoch: 10/100, Iter: 21/21, Loss: 0.14154356718063354
Avg: 0.18748505555448078
Epoch: 11/100, Iter: 21/21, Loss: 0.28087109327316284
Avg: 0.19498857642923081
Epoch: 12/100, Iter: 21/21, Loss: 0.22590678930282593
Avg: 0.17278869840360822
Epoch: 13/100, Iter: 21/21, Loss: 0.21535073220729828
Avg: 0.18

We compare two signatures using the cosine distance between their embeddings, rescaled to the range [0, 1] as `(1 - cosine_similarity) / 2`. The closer the value is to 0, the more alike the signatures are.
In the next cell, `Pos` is the distance between the `anchor` and `positive` images (should be close to 0), and `Neg` is the distance between the `anchor` and `negative` images (should be noticeably larger than `Pos`).

In [None]:
# Load the best checkpoint straight onto the active device. test() moves its
# inputs to `device` but never moves the model, so without map_location a
# checkpoint saved on a different device either fails to load (CUDA checkpoint on
# a CPU-only machine) or triggers a device mismatch in the forward pass.
# NOTE(security): torch.load unpickles arbitrary objects -- only load trusted files.
model = torch.load("weights/best.pt", map_location=device)
test()

In [None]:
# NOTE(review): leftover smoke-test cell; it is not part of the training or
# evaluation flow and can probably be deleted before sharing the notebook.
print("Hello, World!")