# Image retrieval
Based on the [CNN image retrieval approach by Filip Radenović](https://github.com/filipradenovic/cnnimageretrieval-pytorch).

In [None]:
import os
import numpy as np
import torch
from torchvision import transforms
import matplotlib.pyplot as plt

from models.cirtorch_network import init_network, extract_vectors
from dataset_loaders.txt_to_db import get_images, get_points
from evaluate import get_files
from dataset_loaders.utils import load_image

In [None]:
# Pretrained retrieval-network checkpoint. Its 'meta' and 'state_dict'
# entries are read in the next cell. map_location='cpu' lets a checkpoint
# saved on a GPU load on a CPU-only machine; the network is moved to the GPU
# later, when one is available.
state = torch.load(
    'data/teacher_models/retrievalSfM120k-resnet101-gem-b80fb85.pth',
    map_location='cpu',
)

In [None]:
# Rebuild the network constructor arguments from the checkpoint metadata.
# The three pooling/whitening flags may be missing from the metadata and then
# default to False. Weights come from the checkpoint, hence pretrained=False.
net_params = {
    'architecture': state['meta']['architecture'],
    'pooling': state['meta']['pooling'],
    **{
        flag: state['meta'].get(flag, False)
        for flag in ('local_whitening', 'regional', 'whitening')
    },
    'mean': state['meta']['mean'],
    'std': state['meta']['std'],
    'pretrained': False,
}
net = init_network(net_params)
net.load_state_dict(state['state_dict'])
# Carry over the checkpoint's 'Lw' entry (presumably a learned whitening)
# into the network metadata when present.
if 'Lw' in state['meta']:
    net.meta['Lw'] = state['meta']['Lw']
print(net.meta_repr())

In [None]:
# Image scales used for descriptor extraction. A plain list literal replaces
# the former needless ``list(eval('[1]'))``. ``[1]`` means single-scale
# extraction; add more scale factors to extract at several scales.
ms = [1]
# Multi-scale exponent passed to extract_vectors. With several scales on a
# plain GeM-pooled network (no regional pooling, no whitening), the
# network's pooling parameter ``net.pool.p`` (presumably the learned GeM
# exponent) is reused. Otherwise it is 1.
if len(ms) > 1 and net.meta['pooling'] == 'gem' and not net.meta['regional'] and not net.meta['whitening']:
    msp = net.pool.p.item()
    print(">> Set-up multiscale:")
    print(">>>> ms: {}".format(ms))
    print(">>>> msp: {}".format(msp))
else:
    msp = 1

In [None]:
# Inference setup: use the GPU when present and switch to evaluation mode.
if torch.cuda.is_available():
    net.cuda()
net.eval()

# Inputs are converted to tensors and normalised with the per-channel
# mean/std stored in the network metadata.
normalize = transforms.Normalize(mean=net.meta['mean'], std=net.meta['std'])
transform = transforms.Compose([transforms.ToTensor(), normalize])

# Lw is left unset (None); it is not used by the cells shown here.
Lw = None

In [None]:
# Database image records (dict-like). Each record's ``.name`` is used below
# as a path relative to the image directory.
images = get_images()

In [None]:
# Paths of the database images, of their augmented versions (night images,
# per the directory name; stored as .png without the 'db/' prefix), and of
# the query images.
dataroot = 'data/AachenDayNight/images_upright'
image_names = [
    os.path.join(dataroot, img.name)
    for img in images.values()
]
augmented_names = [
    os.path.join(
        'data/AachenDayNight/AugmentedNightImages_high_res',
        img.name.replace('.jpg', '.png').replace('db/', ''),
    )
    for img in images.values()
]
query_image_names = get_files(
    'data/AachenDayNight/images_upright/query', '*.jpg'
)

# Image size passed to extract_vectors.
im_size = 1024

# On-disk caches for the extracted global descriptors.
data_desc_path = 'data/cirtorch_data_descs.npy'
augmented_desc_path = 'data/cirtorch_augmented_descs.npy'

In [None]:
# Global descriptors of the database images, cached on disk so they are
# extracted only once. The cached array is wrapped with torch.from_numpy so
# that both branches leave ``vecs`` as a torch tensor. The later
# ``vecs.numpy()`` conversion then works whether or not the cache was hit.
if os.path.exists(data_desc_path):
    print('Loading data from path', end='')
    vecs = torch.from_numpy(np.load(data_desc_path))
    print('\rData loaded')
else:
    vecs = extract_vectors(net, image_names, im_size, transform, ms=ms, msp=msp)
    np.save(data_desc_path, vecs.cpu().numpy())

In [None]:
# Global descriptors of the augmented database images, cached on disk like
# the plain database descriptors. The result is stored in ``augmented_vecs``
# so this cell no longer overwrites the database descriptors held in
# ``vecs``. Both branches yield a torch tensor.
if os.path.exists(augmented_desc_path):
    print('Loading data from path', end='')
    augmented_vecs = torch.from_numpy(np.load(augmented_desc_path))
    print('\rData loaded')
else:
    augmented_vecs = extract_vectors(net, augmented_names, im_size, transform, ms=ms, msp=msp)
    np.save(augmented_desc_path, augmented_vecs.cpu().numpy())

In [None]:
# Global descriptors of the query images. They use the same network, scales
# and transform as the database images and are not cached.
qvecs = extract_vectors(
    net, query_image_names, im_size, transform, ms=ms, msp=msp
)

In [None]:
# Convert the descriptors to NumPy. ``vecs`` may already be a NumPy array
# (e.g. when loaded with np.load, or when this cell is re-run), so only
# tensors are converted. ``.cpu()`` makes the conversion work for GPU
# tensors too.
vecs = vecs.cpu().numpy() if torch.is_tensor(vecs) else np.asarray(vecs)
qvecs = qvecs.cpu().numpy() if torch.is_tensor(qvecs) else np.asarray(qvecs)

In [None]:
# Sanity check: shape of the query descriptors.
print(qvecs.shape)

## Tripletnet

In [None]:
from models.cirtorch_utils.genericdataset import PointCloudImagesFromList, PCDataLoader
import models.pointnet2_classification as ptnet


In [None]:
# 3D points (presumably from the dataset's reconstruction), consumed by
# PointCloudImagesFromList below.
points3d = get_points()

In [None]:
# Load the trained triplet networks from the checkpoints of ``epoch`` in
# ``log_dir``:
#   trptnet - point-cloud network (models.pointnet2_classification.NetAachen)
#   trcnnet - image CNN (resnet34 built with init_network)
# Another available run: 'logs/triplet_baseline_w_schedule_no_normalize/'.
log_dir = 'logs/triplet_baseline_w_schedule_1/'
epoch = 19

# map_location='cpu' lets checkpoints saved on a GPU load on a CPU-only
# machine; both models are moved to the selected device in a later cell.
trptnet = ptnet.NetAachen()
trptnet.load_state_dict(
    torch.load(
        os.path.join(log_dir, 'ptnet_epoch_{:03d}.pth.tar'.format(epoch)),
        map_location='cpu',
    )['model_state_dict']
)
trptnet.eval()

trcnnet = init_network({'architecture': 'resnet34'})
trcnnet.load_state_dict(
    torch.load(
        os.path.join(log_dir, 'cnn2d_epoch_{:03d}.pth.tar'.format(epoch)),
        map_location='cpu',
    )['model_state_dict']
)
trcnnet.eval()
print('Done')

In [None]:
# Per-channel image statistics for the triplet CNN input: row 0 is used as
# the mean and row 1 as the std when the transform is built.
stats = np.loadtxt('data/img_stats.txt')

In [None]:
# Input transform for the triplet CNN, normalised with the statistics
# loaded from data/img_stats.txt.
normalize = transforms.Normalize(mean=stats[0], std=stats[1])
transform = transforms.Compose(
    [
        # NOTE(review): CenterCrop(1) crops every image to a single pixel,
        # which looks like a leftover debug value. Confirm it is intended.
        transforms.CenterCrop(1),
        transforms.ToTensor(),
        normalize,
    ]
)

In [None]:
# Image / point-cloud samples for the database images: data[0] feeds the CNN
# and data[1] the point-cloud network below. triplet=False presumably yields
# single samples rather than triplets, and min_num_points presumably filters
# out images with too few 3D points; confirm both in
# PointCloudImagesFromList.
dataset = PointCloudImagesFromList(
    'data/AachenDayNight/images_upright',
    images,
    points3d,
    imsize=1024,
    transform=transform,
    triplet=False,
    min_num_points=100,
)
# One sample per batch, in dataset order.
dataloader = PCDataLoader(dataset, batch_size=1, shuffle=False)

In [None]:
# Run both triplet networks on the GPU when one is available, else the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
trptnet.to(device)
trcnnet.to(device)
print('Device:', device)

In [None]:
def torch_norm(x, eps=1e-12):
    """Return ``x`` with every row (dim 1) scaled to unit L2 norm.

    The previous hand-written version divided by the raw norm, which turns
    all-zero rows into NaN. ``torch.nn.functional.normalize`` clamps the
    denominator at ``eps``, so zero rows stay zero. Results for non-zero rows
    are unchanged.

    Args:
        x: tensor with at least two dimensions; normalised along dim 1.
        eps: lower bound on the norm used as the divisor.

    Returns:
        Tensor of the same shape as ``x``.
    """
    return torch.nn.functional.normalize(x, p=2, dim=1, eps=eps)

In [None]:
# Compute L2-normalised image descriptors (trcnnet) and point-cloud
# descriptors (trptnet) for every sample of ``dataloader`` and cache them as
# .npy files. Set ``recalculate = False`` to skip the computation.
recalculate = True
# Debug aid: set to an int to stop after that many batches. Partial runs are
# not saved, so they cannot overwrite the cached descriptor files with an
# incomplete set (the old hard-coded ``if i > 5: break`` saved only 7).
max_batches = None
# Print the dot product and norms of the two descriptors for each sample.
verbose = False

if recalculate:
    img_descs = []
    point_cloud_descs = []
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            # The CNN output gets a leading axis before row-wise
            # normalisation (it presumably returns a 1-D vector for a batch
            # of one). The axis is squeezed away again afterwards.
            fv1 = torch_norm(trcnnet(data[0].to(device)).unsqueeze(0)).squeeze(0).cpu().numpy()
            fv2 = torch_norm(trptnet(data[1].to(device))).cpu().squeeze(0).numpy()
            img_descs.append(fv1)
            point_cloud_descs.append(fv2)
            if verbose:
                print('Difference: {:f}\ttotal norm v1: {}, v2: {}'.format(
                    np.dot(fv1, fv2), np.linalg.norm(fv1), np.linalg.norm(fv2)))
            print('\r{}/{}'.format(i + 1, len(dataloader)), end='')
            if max_batches is not None and i + 1 >= max_batches:
                break
    print('')
    if not img_descs:
        # np.vstack raises on an empty list.
        print('The dataloader produced no samples; nothing to save.')
    else:
        img_descs = np.vstack(img_descs)
        print(img_descs.shape)
        point_cloud_descs = np.vstack(point_cloud_descs)
        print(point_cloud_descs.shape)
        if max_batches is None:
            np.save('data/triplet_img_descriptors.npy', img_descs)
            np.save('data/triplet_pointnet_descriptors.npy', point_cloud_descs)
        else:
            print('Partial run ({} samples): descriptors not saved.'.format(len(img_descs)))

In [None]:
# Sanity check on the first batch: run both triplet networks, L2-normalise
# their outputs row by row, and print shapes and norms (each row should come
# out with norm 1).
def normalize(t):
    """Divide each row (dim 1) of the 2-D tensor ``t`` by its L2 norm."""
    return (t.transpose(0, 1) / torch.norm(t, p=2, dim=1)).transpose(0, 1)


for data in dataloader:
    y = trcnnet(data[0].to(device)).detach()
    x = trptnet(data[1].to(device)).detach()
    break

print(x.size())
print(y.size())
print(x)

x = normalize(x)
# An un-batched 1-D CNN output first gets a leading batch axis.
y = normalize(y.unsqueeze(0) if y.dim() == 1 else y)

x = x.cpu().numpy()
print(x.shape)
print(np.linalg.norm(x))
y = y.cpu().numpy()
print(y.shape)
print(np.linalg.norm(y, axis=1))