In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from utils import MakeDataset
from model import SiameseNetwork, ScatteringNetwork

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from model import SiameseClassifier

In [2]:
# Build the network that the checkpoint weights will be loaded into:
# a ScatteringNetwork instance is handed to the SiameseNetwork constructor.
scatnet = ScatteringNetwork(ds=4, theta_div=5)
model = SiameseNetwork(scatnet)

In [3]:
# Read the saved weights from disk.
# map_location="cpu" remaps every stored tensor onto the CPU, so the checkpoint
# loads on machines without CUDA even if it was saved from GPU tensors; the
# model in this notebook is only ever run on CPU inputs.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
state_dict = torch.load("siamese-models/siam_v2_final.pth", map_location="cpu")

In [4]:
# The model is only used for scoring below, so put it in inference mode first.
# This makes layers such as dropout / batch-norm (if the model has any) behave
# deterministically; it is a no-op for models without such layers.
model.eval()

# Copy the checkpoint weights into the model. Kept as the last expression so
# the key-matching report is still shown as the cell output.
model.load_state_dict(state_dict)

<All keys matched successfully>

In [5]:
# Dataset indexed below to obtain (anchor, positive, negative) image triplets.
dataset = MakeDataset(mode="train", dataset_type="ieee")

In [19]:
# Accumulators for the per-triplet anchor-positive and anchor-negative distances.
pos_dist_arr, neg_dist_arr = [], []

In [20]:
# Embed every (anchor, positive, negative) triplet and record the Euclidean
# distance from the anchor embedding to the positive and negative embeddings.

# Start from empty lists so that re-running this cell neither duplicates
# entries nor fails after the lists have been converted to arrays.
pos_dist_arr = []
neg_dist_arr = []

# Pure inference: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    for i in range(len(dataset)):
        anchor_img, pos_img, neg_img = dataset[i]

        # Each image is a 2-D numpy array; the model is fed a
        # (1, 1, H, W) tensor. All three images use the anchor's (H, W).
        H, W = anchor_img.shape
        anchor_img = torch.reshape(torch.from_numpy(anchor_img), (1, 1, H, W))
        pos_img = torch.reshape(torch.from_numpy(pos_img), (1, 1, H, W))
        neg_img = torch.reshape(torch.from_numpy(neg_img), (1, 1, H, W))

        anchor = model(anchor_img)
        pos = model(pos_img)
        neg = model(neg_img)

        # torch.norm with no dim argument: 2-norm over all elements.
        pos_dist_arr.append(torch.norm(anchor - pos).item())
        neg_dist_arr.append(torch.norm(anchor - neg).item())

In [21]:
# Turn the collected distance lists into numpy arrays for the statistics below.
pos_dist_arr, neg_dist_arr = np.array(pos_dist_arr), np.array(neg_dist_arr)

In [23]:
# Mean and standard deviation of the anchor-positive and anchor-negative
# distances collected above.
distance_stats = (
    ("mean pos dist: ", np.mean(pos_dist_arr)),
    ("std pos dist: ", np.std(pos_dist_arr)),
    ("mean neg dist", np.mean(neg_dist_arr)),
    ("std neg dist", np.std(neg_dist_arr)),
)
for caption, value in distance_stats:
    print(caption, value)

mean pos dist:  4.272645527408237
std pos dist:  1.23637133218063
mean neg dist 5.865790309414031
std neg dist 1.297391340279686


In [6]:
from utils import LabeledPairDataset

In [7]:
# Running counts for the pairwise evaluation:
#   true_*      - pairs whose label is positive / negative
#   predicted_* - pairs predicted positive / negative
#   correct_*   - pairs where label and prediction agree
true_positives = true_negatives = 0
predicted_positives = predicted_negatives = 0
correct_positives = correct_negatives = 0

In [8]:
# Dataset of labeled image pairs; batches from it are unpacked below as
# (label, img1, img2).
paired_dataset = LabeledPairDataset(dataset_type="ieee", mode="train")

In [9]:
# Iterate over the labeled pairs in batches of 10.
dataloader = torch.utils.data.DataLoader(dataset=paired_dataset, batch_size=10)

In [11]:
# Classify every labeled pair by thresholding the distance between the two
# embeddings, and tally the counts needed for the accuracy figures.

# A pair is predicted positive when its embedding distance is below this value.
# NOTE(review): how 5.53 was chosen is not recorded in this notebook.
DIST_THRESHOLD = 5.53

# Start from zero so that re-running this cell does not double-count.
true_positives = true_negatives = 0
predicted_positives = predicted_negatives = 0
correct_positives = correct_negatives = 0

# Pure inference: no gradients are needed, so skip building the autograd graph.
with torch.no_grad():
    for label, img1, img2 in dataloader:
        # Batches arrive as (N, H, W); the model is fed (N, 1, H, W).
        N, H, W = img1.shape
        img1 = torch.reshape(img1, (N, 1, H, W))
        img2 = torch.reshape(img2, (N, 1, H, W))

        f1 = model(img1)
        f2 = model(img2)

        # One 2-norm per pair, taken along dim 1 of the embedding difference.
        dist = torch.flatten(torch.norm(f1 - f2, dim=1))
        # 1.0 where the pair is predicted positive, 0.0 otherwise.
        preds = torch.lt(dist, DIST_THRESHOLD).to(torch.float)

        true_positives += torch.sum(label)
        predicted_positives += torch.sum(preds)
        correct_positives += torch.sum(torch.mul(label, preds))

        true_negatives += N - torch.sum(label)
        predicted_negatives += N - torch.sum(preds)
        correct_negatives += torch.sum(torch.mul(1 - label, 1 - preds))

In [13]:
# Overall accuracy: correctly classified pairs over all pairs.
total_correct = correct_negatives + correct_positives
total_pairs = true_negatives + true_positives
print("Accuracy: ", total_correct / total_pairs)

# Per-class accuracy: correct predictions divided by the number of pairs
# carrying that label.
print("Positives accuracy: ", correct_positives / true_positives)
print("Negatives accuracy: ", correct_negatives / true_negatives)

Accuracy:  tensor(0.7012, dtype=torch.float64)
Positives accuracy:  tensor(0.8510, dtype=torch.float64)
Negatives accuracy:  tensor(0.6359, dtype=torch.float64)
