In [2]:
from data_loader import get_train_validation_loader, get_test_loader
from config_maker import get_config
from torchvision import transforms
from trainer import Trainer
from model import SiameseNet
import os
import torch
from glob import glob
from tqdm.notebook import tqdm

In [2]:
#
# train_loader, valid_loader = get_train_validation_loader(config.data_dir, config.batch_size,
#                                                                  config.num_train,
#                                                                  config.augment, config.way,
#                                                                  config.valid_trials,
#                                                                  config.shuffle, config.seed,
#                                                                  config.num_workers, config.pin_memory)
#

In [3]:
# --- Configuration, checkpoint loading and test data -------------------------
config = get_config()
is_best = True  # True: evaluate best_model.pt; False: evaluate the latest checkpoint

# Inference is pinned to the CPU in this notebook. The GPU branch is kept
# (disabled) for reference.
# if config.use_gpu:
#     model.cuda()
#     device = 'cuda'
# else:
device = 'cpu'

# Build every path with os.path.join so the result does not depend on whether
# config.logs_dir ends with a path separator.
models_dir = os.path.join(config.logs_dir, 'models')

if is_best:
    model_path = os.path.join(models_dir, 'best_model.pt')
else:
    ckpt_paths = glob(os.path.join(models_dir, 'model_ckpt_*.pt'))
    if not ckpt_paths:
        raise FileNotFoundError(f"No checkpoint matching 'model_ckpt_*.pt' found in {models_dir}")
    # Order by (length, name): this sorts non-negative integer suffixes
    # numerically (e.g. ..._9.pt < ..._10.pt) and breaks length ties
    # deterministically instead of relying on glob's arbitrary order.
    # NOTE(review): assumes the '*' part is an epoch/step counter -- confirm.
    model_path = max(ckpt_paths, key=lambda path: (len(path), path))

model = SiameseNet()

# map_location makes a checkpoint saved on a GPU loadable on the selected device.
ckpt = torch.load(model_path, map_location=device)
model.load_state_dict(ckpt['model_state'])

model.to(device)
# Switch to inference mode (matters if the network has dropout / batch-norm layers).
model.eval()

test_loader = get_test_loader(config.data_dir, config.way, config.test_trials,
                              config.seed, config.num_workers, config.pin_memory)

[*] use GPU Quadro RTX 4000


In [None]:
# --- One-shot evaluation over the test trials --------------------------------
correct_sum = 0
num_test = test_loader.dataset.trials
print(f"[*] Test on {num_test} pairs.")

# Make sure the network is in inference mode even if this cell is re-run
# after the model was used for training in the same kernel.
model.eval()

pbar = tqdm(enumerate(test_loader), total=num_test, desc="Test")
with torch.no_grad():
    for i, (x1, x2, _) in pbar:
        # Always place the inputs on the device the model lives on; a no-op
        # when they are already there.
        x1, x2 = x1.to(device), x2.to(device)

        # Raw scores for every pair in this trial; sigmoid turns them into
        # match probabilities (it is monotonic, so the argmax is unchanged).
        out = model(x1, x2)
        y_pred = torch.argmax(torch.sigmoid(out)).item()

        # A trial counts as correct when the highest-scoring pair is index 0.
        # NOTE(review): presumably the test loader puts the matching pair
        # first in each trial -- confirm against data_loader.
        if y_pred == 0:
            correct_sum += 1

        # Running accuracy over the trials evaluated so far.
        pbar.set_postfix_str(f"accuracy: {correct_sum / (i + 1)}")

test_acc = (100. * correct_sum) / num_test
print(f"Test Acc: {correct_sum}/{num_test} ({test_acc:.2f}%)")