In [None]:
import torch
import numpy as np
import pandas as pd
from torch.utils.data import DataLoader
from dataset import TestDataset
import config
from model import SiameseNet
from torch.utils.data import DataLoader
from sklearn.metrics import roc_curve, accuracy_score


In [None]:
# Checkpoint path and test-set CSV; the model is built on GPU when available.
PATH = config.PATH
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SiameseNet().to(device)
csv_path = 'csv/test_imgs_labels.csv'

# NOTE(review): torch.load unpickles the checkpoint file, so only load
# checkpoints from a trusted source (torch >= 1.13 also supports
# weights_only=True for plain state dicts).
model.load_state_dict(torch.load(PATH, map_location=device))
# Put dropout / batch-norm layers (if SiameseNet has any) into inference mode;
# without this the evaluation below would run the network in training mode.
model.eval()


In [None]:
# Build the evaluation set and collect one score and one label per image pair.
dataset = TestDataset(
    csv_path,
    transform=config.transform,
)

# Scores and labels are appended in lock-step, so pair order does not affect
# ROC / accuracy; shuffle=False keeps evaluation runs reproducible.
# NOTE: the flattening cell that follows was originally written for batch_size=1.
sampleloader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

# Ensure inference mode even when this cell is re-run on its own.
model.eval()

probs = []
target = []
with torch.no_grad():
    for anchor, sample, label in sampleloader:
        anchor, sample = anchor.to(device), sample.to(device)

        anchor_out, sample_out = model.forward_prediction(anchor, sample)
        # Squared Euclidean distance between the two embeddings (summed over
        # dim 1 — presumably (batch, embedding_dim); TODO confirm), squashed by
        # a sigmoid. The distance is >= 0, so scores lie in [0.5, 1];
        # a larger score means the pair is farther apart.
        ASD = torch.sigmoid((anchor_out - sample_out).pow(2).sum(1))
        probs.append(ASD.cpu().detach().numpy())
        target.append(label.detach().numpy())

In [None]:
# Flatten the per-batch arrays into flat Python lists of scores and labels.
# Every element of each batch is kept (indexing with [0] would silently keep
# only the first sample per batch), and np.ravel also accepts the plain
# scalars left by a previous run of this cell, so re-running it is harmless.
probs = np.concatenate([np.ravel(item) for item in probs]).tolist()
target = np.concatenate([np.ravel(item) for item in target]).tolist()

# NOTE(review): roc_curve treats higher scores as evidence for label 1, and a
# higher score here means a larger embedding distance — confirm that label 1
# denotes a "different" pair.
fpr, tpr, thresholds = roc_curve(target, probs)

# Choose the threshold maximising the geometric mean of TPR and TNR (1 - FPR).
gmean = np.sqrt(tpr * (1 - fpr))
index = np.argmax(gmean)
# Keep full precision for classification: rounding would move the decision
# boundary away from the ROC-optimal point (scores cluster in [0.5, 1] and can
# saturate near 1.0). Round only for display.
thresholdOpt = thresholds[index]
round(thresholdOpt, ndigits=4)


In [None]:
# Binarise each score against the tuned threshold: 1 when at or above it, else 0.
pred = [int(score >= thresholdOpt) for score in probs]

In [None]:
# NOTE(review): the decision threshold is tuned on these same test pairs, so
# this accuracy is optimistically biased; tune on a held-out split for an
# unbiased estimate.
accuracy = accuracy_score(target, pred)
print('Accuracy:', accuracy)