In [31]:
import torch
import numpy
def pr_curve(query_code, retrieval_code, query_targets, retrieval_targets, device):
    """
    Precision-recall curve of hash-code retrieval over Hamming radii.

    For every query and every Hamming radius r in 0..num_bit, the retrieval
    items whose Hamming distance to the query is <= r count as "retrieved".
    A retrieval item is relevant when it shares at least one label with the
    query (positive dot product of the two target vectors).

    Args
        query_code(torch.Tensor): Query hash codes, shape (num_query, num_bit).
            The Hamming-distance formula below assumes entries in {-1, +1}.
        retrieval_code(torch.Tensor): Retrieval hash codes, shape (num_retrieval, num_bit).
        query_targets(torch.Tensor): Query label vectors, shape (num_query, num_class).
        retrieval_targets(torch.Tensor): Retrieval label vectors, shape (num_retrieval, num_class).
        device (torch.device): Device the computation runs on; all inputs are moved there.

    Returns
        P(torch.Tensor): Mean precision per radius, shape (num_bit + 1,).
        R(torch.Tensor): Mean recall per radius, shape (num_bit + 1,).
    """
    # Put every input on the same device and in float. This lets CPU-resident inputs
    # work with a GPU `device`, and lets integer label/code tensors go through matmul
    # (integer matmul is not supported on every backend).
    query_code = query_code.to(device).float()
    retrieval_code = retrieval_code.to(device).float()
    query_targets = query_targets.to(device).float()
    retrieval_targets = retrieval_targets.to(device).float()

    num_query = query_code.shape[0]
    num_bit = query_code.shape[1]
    P = torch.zeros(num_query, num_bit + 1, device=device)
    R = torch.zeros(num_query, num_bit + 1, device=device)

    # Radii 0..num_bit as a column vector so it broadcasts against a row of
    # distances. It is loop-invariant, so it is built only once.
    radii = torch.arange(0, num_bit + 1, device=device).reshape(-1, 1).float()

    for i in range(num_query):
        # gnd[j] == 1 iff retrieval item j shares at least one label with query i.
        gnd = (query_targets[i].unsqueeze(0).mm(retrieval_targets.t()) > 0).float().squeeze(0)
        tsum = torch.sum(gnd)
        if tsum == 0:
            # No relevant items: the row stays zero and is excluded from the
            # averages below by the P > 0 mask.
            continue
        # For {-1, +1} codes: hamming(a, b) = (num_bit - <a, b>) / 2.
        hamm = 0.5 * (num_bit - query_code[i, :] @ retrieval_code.t())
        # tmp[r, j] == 1 iff item j lies within Hamming radius r.
        tmp = (hamm <= radii).float()
        total = tmp.sum(dim=-1)
        # Avoid 0/0 when nothing is retrieved at a radius; precision is then 0.
        total = total + (total == 0).float() * 0.1
        count = (gnd * tmp).sum(dim=-1)
        P[i] = count / total
        R[i] = count / tsum

    # NOTE(review): both averages divide by the number of queries with non-zero
    # precision at each radius. Queries that retrieve no relevant item at a
    # radius (and queries with no relevant items at all) are therefore also left
    # out of the recall mean. Kept as-is so existing results do not change.
    mask = (P > 0).float().sum(dim=0)
    # Guard against division by zero when no query qualifies at a radius.
    mask = mask + (mask == 0).float() * 0.1
    P = P.sum(dim=0) / mask
    R = R.sum(dim=0) / mask

    return P, R

In [8]:
import os
import torch
# Use the first GPU when one is present, otherwise fall back to the CPU.
# (Previously `device` was left undefined on CPU-only machines, so any later
# use of it raised NameError.)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Directory holding the saved checkpoints. Set the LPTA_ADSH_CHECKPOINT_DIR
# environment variable to point elsewhere on another machine.
CHECKPOINT_DIR = os.environ.get(
    "LPTA_ADSH_CHECKPOINT_DIR",
    "C:\\Users\\annad\\Documents\\Python Scripts\\LPTA_ADSH\\checkpoints",
)
# With CUDA available, tensors load onto the device they were saved from (as
# before). On a CPU-only machine, GPU-saved tensors are remapped to the CPU so
# loading does not fail.
MAP_LOCATION = None if device.type == 'cuda' else device


def _load_checkpoint(filename):
    """Load one object saved with torch.save from CHECKPOINT_DIR."""
    return torch.load(os.path.join(CHECKPOINT_DIR, filename), map_location=MAP_LOCATION)


query_code = _load_checkpoint("query_code.t")
database_code = _load_checkpoint("database_code.t")
query_target = _load_checkpoint("query_targets.t")
database_target = _load_checkpoint("database_targets.t")
precision = _load_checkpoint("Precision.t")
recall = _load_checkpoint("Recall.t")

# Print the saved recall values, one per line (presumably one entry per
# Hamming radius as returned by pr_curve — TODO confirm).
items = list(recall)
for item in items:
    print(item)

# To recompute the curve instead of loading it:
# P, R = pr_curve(query_code, database_code, query_target, database_target, device=device)

tensor(0.6856)
tensor(0.8457)
tensor(0.9920)
tensor(0.9928)
tensor(0.9946)
tensor(0.9996)
tensor(0.9994)
tensor(0.9973)
tensor(0.9978)
tensor(1.0000)
tensor(0.9971)
tensor(0.9964)
tensor(0.9997)
tensor(0.9989)
tensor(0.9987)
tensor(0.9965)
tensor(0.9982)
tensor(0.9994)
tensor(0.9988)
tensor(0.9987)
tensor(0.9979)
tensor(0.9977)
tensor(0.9988)
tensor(0.9981)
tensor(0.9945)
tensor(0.9845)
tensor(0.9860)
tensor(0.9992)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
tensor(1.)
