In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.distances import CosineSimilarity
from pytorch_metric_learning.utils import common_functions as c_f
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning.utils.inference import InferenceModel, MatchFinder
from torchvision import datasets, transforms, models
from torchvision.models import VGG16_Weights

import dataloader
from data_augmenter import enrich_and_shuffle_dataset

In [2]:
# Root directory of the HWR data release. Set the HWR_DATA_ROOT environment
# variable to run the notebook on a machine with a different layout; when it is
# unset (or empty) the original hard-coded location is used, so every path
# below keeps exactly its previous value.
DATA_ROOT = (os.environ.get("HWR_DATA_ROOT") or "/home/jovyan/data/HWR.2021-11-08").rstrip("/")

# Directory shared by the three label files below.
LABELS_DIR = DATA_ROOT + "/dataset_gt/pero.ceske_dopisy.ceske_kroniky.embed/final.2021-11-18"

# Handed to dataloader.DatasetFromLMDB as `lmdb_path`.
LMDB_PATH_HOST = DATA_ROOT + "/lmdb.hwr_40-1.0"

# Handed to dataloader.DatasetFromLMDB as `labels_path` (train / test / validation).
TRN_DATA = LABELS_DIR + "/lines.filtered_max_width.trn.550.shuf"
TST_DATA = LABELS_DIR + "/lines.filtered_max_width.tst.550"
VAL_DATA = LABELS_DIR + "/lines.filtered_max_width.val.550.shuf"

In [3]:
def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch, log_interval=20):
    """Run one training epoch over ``train_loader``.

    For every batch the inputs are embedded by ``model``, ``mining_func``
    selects index tuples from the embeddings and labels, ``loss_func`` is
    evaluated on them, and a single optimizer step is taken.

    Args:
        model: Module called as ``model(data)`` to produce embeddings. It is
            switched to training mode at the start of the epoch.
        loss_func: Callable ``loss_func(embeddings, labels, indices_tuple)``
            returning a single-element loss tensor that supports ``.backward()``.
        mining_func: Callable ``mining_func(embeddings, labels)``. Its
            ``num_triplets`` attribute is read when a log line is printed.
        device: Device the batch tensors are moved to before the forward pass.
        train_loader: Iterable yielding ``(data, labels)`` batches.
        optimizer: Optimizer whose gradients are reset and stepped per batch.
        epoch: Epoch number; used only in the log line.
        log_interval: Print a progress line every ``log_interval`` batches,
            starting with batch 0. Defaults to 20 (the previously hard-coded
            value). ``0`` or ``None`` disables progress output.

    Returns:
        None.
    """
    model.train()
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)
        optimizer.zero_grad()
        embeddings = model(data)
        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)
        loss.backward()
        optimizer.step()
        # Guard against a zero/None interval so logging can be switched off
        # without hitting a modulo-by-zero.
        if log_interval and batch_idx % log_interval == 0:
            print(
                "Epoch {} Iteration {}: Loss = {}, Number of mined triplets = {}".format(
                    epoch, batch_idx, loss.item(), mining_func.num_triplets
                )
            )

In [4]:
### convenient function from pytorch-metric-learning ###
def get_all_embeddings(dataset, model):
    """Embed every item of ``dataset`` with ``model``.

    Builds a fresh ``testers.BaseTester`` with its default settings and
    forwards both arguments to its ``get_all_embeddings`` method. The result
    is returned unchanged; ``test`` unpacks it as ``(embeddings, labels)``.
    """
    return testers.BaseTester().get_all_embeddings(dataset, model)


### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###
def test(train_set, test_set, model, accuracy_calculator):
    """Embed both datasets and print the ``precision_at_1`` accuracy.

    Args:
        train_set: Dataset embedded first; its embeddings and labels are the
            third and fourth arguments to ``get_accuracy``.
        test_set: Dataset embedded second; its embeddings and labels are the
            first and second arguments to ``get_accuracy``.
        model: Model passed through to ``get_all_embeddings``.
        accuracy_calculator: Object exposing ``get_accuracy``; the returned
            mapping must contain the key ``"precision_at_1"``.

    Returns:
        None. The accuracy is only printed.
    """
    trn_emb, trn_lbl = get_all_embeddings(train_set, model)
    tst_emb, tst_lbl = get_all_embeddings(test_set, model)

    # Drop dimension 1 of the label tensors (assumed to be size 1; squeeze(1)
    # leaves the tensor unchanged otherwise).
    trn_lbl = trn_lbl.squeeze(1)
    tst_lbl = tst_lbl.squeeze(1)

    print("Computing accuracy")
    # NOTE(review): positional order is (test embeddings, test labels,
    # train embeddings, train labels, False). This presumably maps to
    # query, query_labels, reference, reference_labels, ref_includes_query --
    # confirm against the installed pytorch-metric-learning version.
    results = accuracy_calculator.get_accuracy(
        tst_emb, tst_lbl, trn_emb, trn_lbl, False
    )
    precision_at_1 = results["precision_at_1"]
    print("Test set accuracy (Precision@1) = {}".format(precision_at_1))


In [6]:
# dataset = dataloader.DatasetFromLMDB(lmdb_path=LMDB_PATH_HOST, labels_path=TRN_DATA)

# labels_to_indices = c_f.get_labels_to_indices([row[1] for row in dataset])


In [7]:
from models import MyVGG, MyTransformer

# Seed torch's global RNG so repeated runs draw the same random numbers from it
# (for example the shuffled batch order of the training DataLoader). GPU
# kernels may still be nondeterministic, so results can differ slightly.
seed = 0
torch.manual_seed(seed)

# Fall back to the CPU when no GPU is visible, so the cell still runs (slowly)
# instead of failing at the first `.to(device)` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch_size = 256

# Train / test / validation splits, all read from the same LMDB database.
dataset1 = dataloader.DatasetFromLMDB(lmdb_path=LMDB_PATH_HOST, labels_path=TRN_DATA)
dataset2 = dataloader.DatasetFromLMDB(lmdb_path=LMDB_PATH_HOST, labels_path=TST_DATA)
# NOTE(review): the validation split is loaded but not used in this cell, and
# the per-epoch monitoring below runs on the test split (dataset2). Consider
# monitoring on dataset3 and keeping dataset2 for one final evaluation.
dataset3 = dataloader.DatasetFromLMDB(lmdb_path=LMDB_PATH_HOST, labels_path=VAL_DATA)

train_loader = torch.utils.data.DataLoader(
    dataset1, batch_size=batch_size, shuffle=True
)
# NOTE(review): test_loader is not consumed in this cell; test() is handed the
# datasets directly.
test_loader = torch.utils.data.DataLoader(dataset2, batch_size=batch_size)

# model = MyVGG().to(device)
model = MyTransformer().to(device)
print(model)

optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 5

### pytorch-metric-learning stuff ###
# One margin shared by the miner and the loss so the two cannot drift apart.
margin = 0.2
distance = distances.CosineSimilarity()
reducer = reducers.ThresholdReducer(low=0)
loss_func = losses.TripletMarginLoss(margin=margin, distance=distance, reducer=reducer)
mining_func = miners.TripletMarginMiner(
    margin=margin, distance=distance, type_of_triplets="semihard"
)
accuracy_calculator = AccuracyCalculator(include=("precision_at_1",), k=1)

# Train for one epoch, then report precision@1 of the test split against the
# training split, once per epoch.
for epoch in range(1, num_epochs + 1):
    train(model, loss_func, mining_func, device, train_loader, optimizer, epoch)
    test(dataset1, dataset2, model, accuracy_calculator)

MyTransformer(
  (vit_model): VisionTransformer(
    (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    (encoder): Encoder(
      (dropout): Dropout(p=0.0, inplace=False)
      (layers): Sequential(
        (encoder_layer_0): EncoderBlock(
          (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (self_attention): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
          )
          (dropout): Dropout(p=0.0, inplace=False)
          (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (mlp): MLPBlock(
            (0): Linear(in_features=768, out_features=3072, bias=True)
            (1): GELU(approximate='none')
            (2): Dropout(p=0.0, inplace=False)
            (3): Linear(in_features=3072, out_features=768, bias=True)
            (4): Dropout(p=0.0, inplace=False)
          )
        )
        (encoder_layer_1): EncoderBlock(
          (ln_1

100%|██████████| 434/434 [03:42<00:00,  1.95it/s]
100%|██████████| 33/33 [00:17<00:00,  1.93it/s]


Computing accuracy


  x.storage().data_ptr() + x.storage_offset() * 4)


Test set accuracy (Precision@1) = 0.7904761904761904
Epoch 2 Iteration 0: Loss = 0.09823919087648392, Number of mined triplets = 46421
Epoch 2 Iteration 20: Loss = 0.09242875874042511, Number of mined triplets = 39337
Epoch 2 Iteration 40: Loss = 0.09880223870277405, Number of mined triplets = 37060


100%|██████████| 434/434 [03:42<00:00,  1.95it/s]
100%|██████████| 33/33 [00:17<00:00,  1.91it/s]


Computing accuracy
Test set accuracy (Precision@1) = 0.8647619047619047
Epoch 3 Iteration 0: Loss = 0.09305235743522644, Number of mined triplets = 35003
Epoch 3 Iteration 20: Loss = 0.09560642391443253, Number of mined triplets = 27741
Epoch 3 Iteration 40: Loss = 0.09013506770133972, Number of mined triplets = 25368


100%|██████████| 434/434 [03:42<00:00,  1.95it/s]
100%|██████████| 33/33 [00:17<00:00,  1.93it/s]


Computing accuracy
Test set accuracy (Precision@1) = 0.8961904761904762
Epoch 4 Iteration 0: Loss = 0.0898827314376831, Number of mined triplets = 30065
Epoch 4 Iteration 20: Loss = 0.08946233987808228, Number of mined triplets = 18143
Epoch 4 Iteration 40: Loss = 0.08846427500247955, Number of mined triplets = 22765


100%|██████████| 434/434 [03:42<00:00,  1.95it/s]
100%|██████████| 33/33 [00:17<00:00,  1.92it/s]


Computing accuracy
Test set accuracy (Precision@1) = 0.9152380952380952
Epoch 5 Iteration 0: Loss = 0.08383698761463165, Number of mined triplets = 13078
Epoch 5 Iteration 20: Loss = 0.09222208708524704, Number of mined triplets = 15974
Epoch 5 Iteration 40: Loss = 0.08342385292053223, Number of mined triplets = 13296


100%|██████████| 434/434 [03:42<00:00,  1.95it/s]
100%|██████████| 33/33 [00:17<00:00,  1.93it/s]


Computing accuracy
Test set accuracy (Precision@1) = 0.9304761904761905


In [9]:
# Persist the trained weights (the state_dict only, not the whole module).
# torch.save does not create missing parent directories, so make sure the
# target directory exists first rather than failing after the training run.
MODEL_PATH = "model/with_transformer_grayscale.pth"
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)