In [1]:
import os
import argparse
from glob import glob
import random, shutil, json
from math import log10, ceil

import torch.nn as nn
from torch.utils.data import DataLoader, SubsetRandomSampler
import torchvision.models as models

import faiss

In [2]:
# Command-line options (parsed below with an explicit empty argv so notebook defaults apply).
parser = argparse.ArgumentParser(description='pytorch-NetVlad')
parser.add_argument('--mode', type=str, default='train', help='Mode', choices=['train', 'test', 'cluster'])
parser.add_argument('--batchSize', type=int, default=4, help='Number of triplets (query, pos, negs). Each triplet consists of 12 images.')
parser.add_argument('--cacheBatchSize', type=int, default=1, help='Batch size for caching and testing')
parser.add_argument('--cacheRefreshRate', type=int, default=1000, help='How often to refresh cache, in number of queries. 0 for off')
parser.add_argument('--nEpochs', type=int, default=30, help='number of epochs to train for')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N', help='manual epoch number (useful on restarts)')
parser.add_argument('--nGPU', type=int, default=1, help='number of GPU to use.')
parser.add_argument('--optim', type=str, default='SGD', help='optimizer to use', choices=['SGD', 'ADAM'])
parser.add_argument('--lr', type=float, default=0.0001, help='Learning Rate.')
parser.add_argument('--lrStep', type=float, default=5, help='Decay LR every N steps.')
parser.add_argument('--lrGamma', type=float, default=0.5, help='Multiply LR by Gamma for decaying.')
parser.add_argument('--weightDecay', type=float, default=0.001, help='Weight decay for SGD.')
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum for SGD.')
parser.add_argument('--nocuda', action='store_true', help="Don't use cuda")
parser.add_argument('--threads', type=int, default=8, help='Number of threads for each data loader to use')
parser.add_argument('--seed', type=int, default=123, help='Random seed to use.')
parser.add_argument('--dataPath', type=str, default='/home/ubuntu/Desktop/pytorch-NetVlad/data/', help='Path for centroid data.')
parser.add_argument('--runsPath', type=str, default='/home/ubuntu/Desktop/pytorch-NetVlad/runs/', help='Path to save runs to.')
parser.add_argument('--savePath', type=str, default='checkpoints', help='Path to save checkpoints to in logdir. Default=checkpoints/')
parser.add_argument('--cachePath', type=str, default='/tmp', help='Path to save cache to.')
parser.add_argument('--resume', type=str, default='', help='Path to load checkpoint from, for resuming training or testing.')
parser.add_argument('--ckpt', type=str, default='latest', help='Resume from latest or best checkpoint.', choices=['latest', 'best'])
parser.add_argument('--evalEvery', type=int, default=1, help='Do a validation set run, and save, every N epochs.')
parser.add_argument('--patience', type=int, default=10, help='Patience for early stopping. 0 is off.')
parser.add_argument('--dataset', type=str, default='pittsburgh', help='Dataset to use', choices=['pittsburgh'])
parser.add_argument('--arch', type=str, default='vgg16', help='basenetwork to use', choices=['vgg16', 'alexnet'])
parser.add_argument('--vladv2', action='store_true', help='Use VLAD v2')
parser.add_argument('--pooling', type=str, default='netvlad', help='type of pooling to use', choices=['netvlad', 'max', 'avg'])
parser.add_argument('--num_clusters', type=int, default=64, help='Number of NetVlad clusters. Default=64')
parser.add_argument('--margin', type=float, default=0.1, help='Margin for triplet loss. Default=0.1')
parser.add_argument('--split', type=str, default='val', help='Data split to use for testing. Default is val', choices=['test', 'test250k', 'train', 'val'])
parser.add_argument('--fromscratch', action='store_true', help='Train from scratch rather than using pretrained models')

_StoreTrueAction(option_strings=['--fromscratch'], dest='fromscratch', nargs=0, const=True, default=False, type=None, choices=None, help='Train from scratch rather than using pretrained models', metavar=None)

In [3]:
import torch
import torchvision.transforms as transforms
import torch.utils.data as data

from os.path import join, exists
from scipy.io import loadmat
import numpy as np
from collections import namedtuple
from PIL import Image

from sklearn.neighbors import NearestNeighbors
import h5py

# Root of the visual-localization-challenge-2020 data; set VLC_ROOT_DIR to point elsewhere.
root_dir = os.environ.get('VLC_ROOT_DIR', '/home/ubuntu/Desktop/visual-localization-challenge-2020')

def input_transform():
    """Return the per-image preprocessing pipeline.

    The PIL image is converted to a float tensor, then each channel is
    normalised with the standard ImageNet mean/std values.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    pipeline = [
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(pipeline)

def get_whole_training_set(onlyDB=False):
    """Build the full image set used for descriptor sampling.

    Args:
        onlyDB: accepted for interface compatibility but currently ignored.

    Returns:
        A ``WholeDatasetFromStruct`` using the standard ``input_transform()``.
    """
    preprocessing = input_transform()
    return WholeDatasetFromStruct(input_transform=preprocessing)

# Record type for dataset metadata (field names as a whitespace-separated spec).
dbStruct = namedtuple(
    'dbStruct',
    'whichSet dataset dbImage utmDb qImage utmQ numDb numQ '
    'posDistThr posDistSqThr nonTrivPosDistSqThr',
)

class WholeDatasetFromStruct(data.Dataset):
    """Image dataset over one capture session of the indoor VLC-2020 data.

    Keeps up to ``max_images`` sorted ``*.jpg`` frames from ``image_dir`` as
    both the database and the query list. It also keeps the first two pose
    components of the same number of rows from the ground-truth HDF5 file.

    Args:
        input_transform: optional callable applied to each loaded PIL image.
        image_dir: directory of ``*.jpg`` frames. Defaults to the
            ``indoor_dataset/1f/train/2019-04-16_14-35-00/images`` session
            under ``root_dir``.
        groundtruth_path: HDF5 ground-truth file. Defaults to that session's
            ``groundtruth.hdf5``.
        pose_key: HDF5 dataset holding the camera poses.
        max_images: cap on the number of images/poses kept (default 1000).
    """

    def __init__(self, input_transform=None, image_dir=None, groundtruth_path=None,
                 pose_key='22970285_pose', max_images=1000):
        super().__init__()

        self.input_transform = input_transform

        # NOTE(review): these look swapped relative to dbStruct's whichSet/dataset
        # meaning; kept as-is because ``dataset`` is embedded in the centroid cache filename.
        self.dataset = 'train'
        self.whichSet = 'vlc'

        session_dir = os.path.join(root_dir, 'indoor_dataset', '1f', 'train', '2019-04-16_14-35-00')
        if image_dir is None:
            image_dir = os.path.join(session_dir, 'images')
        if groundtruth_path is None:
            groundtruth_path = os.path.join(session_dir, 'groundtruth.hdf5')

        # Glob once: database and query lists are the same frames here.
        image_paths = sorted(glob(os.path.join(image_dir, '*.jpg')))[:max_images]
        self.dbImage = image_paths
        self.qImage = list(image_paths)  # separate list, as before

        self.images = self.dbImage

        with h5py.File(groundtruth_path, "r") as f:
            # First two pose components only (presumably planar x, y — TODO confirm).
            self.dbCameraPose = np.asarray(f[pose_key][:max_images, :2])

        # Positive radius and its square (25 ** 2 == 625); units follow the pose data.
        self.nonTrivPosDistSqThr = 625
        self.posDistThr = 25

        # NOTE(review): counts come from the poses, while __len__ uses the image list.
        self.numDb = len(self.dbCameraPose)
        self.numQ = len(self.dbCameraPose)

    def __getitem__(self, index):
        """Return ``(image, index)`` where ``image`` is the (transformed) RGB frame."""
        # Close the file handle right away, and force 3 channels so the 3-channel
        # Normalize in input_transform applies even to grayscale/CMYK frames.
        with Image.open(self.images[index]) as raw:
            img = raw.convert('RGB')

        if self.input_transform:
            img = self.input_transform(img)

        return img, index

    def __len__(self):
        """Number of images (not poses) available."""
        return len(self.images)

#     def getPositives(self):
#         # positives for evaluation are those within trivial threshold range
#         #fit NN to find them, search by radius
#         if  self.positives is None:
#             knn = NearestNeighbors(n_jobs=-1)
#             knn.fit(self.dbStruct.utmDb)

#             self.distances, self.positives = knn.radius_neighbors(self.dbStruct.utmQ,
#                     radius=self.dbStruct.posDistThr)

#         return self.positives


In [4]:
# Parse an explicit empty argv: inside a notebook kernel, sys.argv typically holds
# the kernel's own launch flags, which argparse would otherwise try to read.
opt = parser.parse_args(args=[])
print(opt)

Namespace(arch='vgg16', batchSize=4, cacheBatchSize=1, cachePath='/tmp', cacheRefreshRate=1000, ckpt='latest', dataPath='/home/ubuntu/Desktop/pytorch-NetVlad/data/', dataset='pittsburgh', evalEvery=1, fromscratch=False, lr=0.0001, lrGamma=0.5, lrStep=5, margin=0.1, mode='train', momentum=0.9, nEpochs=30, nGPU=1, nocuda=False, num_clusters=64, optim='SGD', patience=10, pooling='netvlad', resume='', runsPath='/home/ubuntu/Desktop/pytorch-NetVlad/runs/', savePath='checkpoints', seed=123, split='val', start_epoch=0, threads=8, vladv2=False, weightDecay=0.001)


In [5]:
# Use the GPU unless --nocuda was given. Fall back to CPU (with a notice) when no
# CUDA device exists, instead of failing later at model.to(device).
cuda = not opt.nocuda and torch.cuda.is_available()
if not opt.nocuda and not cuda:
    print('====> CUDA requested but not available; falling back to CPU')
device = torch.device("cuda" if cuda else "cpu")

# Seed every RNG used below: python, numpy (image/location sampling) and torch.
random.seed(opt.seed)
np.random.seed(opt.seed)
torch.manual_seed(opt.seed)
if cuda:
    torch.cuda.manual_seed(opt.seed)

In [6]:
# Full image set for descriptor sampling (get_whole_training_set ignores its onlyDB flag).
whole_train_set = get_whole_training_set(True)

In [7]:
class L2Norm(nn.Module):
    """L2-normalise the input along ``dim`` (default 1, the channel axis of a conv feature map)."""

    def __init__(self, dim=1):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        return torch.nn.functional.normalize(input, p=2, dim=self.dim)


pretrained = not opt.fromscratch
arch = opt.arch.lower()

if arch == 'alexnet':
    encoder_dim = 256
    encoder = models.alexnet(pretrained=pretrained)
    # capture only feature part and remove last relu and maxpool (last kept layer is conv5)
    layers = list(encoder.features.children())[:-2]
    if pretrained:
        # if using pretrained then only train conv5
        for layer in layers[:-1]:
            for param in layer.parameters():
                param.requires_grad = False
elif arch == 'vgg16':
    encoder_dim = 512
    encoder = models.vgg16(pretrained=pretrained)
    # capture only feature part and remove last relu and maxpool
    layers = list(encoder.features.children())[:-2]
    if pretrained:
        # if using pretrained then only train conv5_1, conv5_2, and conv5_3
        for layer in layers[:-5]:
            for param in layer.parameters():
                param.requires_grad = False
else:
    raise ValueError('Unknown architecture: ' + opt.arch)

# For VLAD v1 clustering, make each local descriptor unit-length before sampling.
if opt.mode.lower() == 'cluster' and not opt.vladv2:
    layers.append(L2Norm())

In [8]:
# A bare nn.Module serves only as a container. The truncated backbone is registered
# as its child ``encoder`` so later cells can call ``model.encoder(...)``.
encoder = nn.Sequential(*layers)
model = nn.Module()
model.encoder = encoder  # Module.__setattr__ registers the child, same as add_module
model = model.to(device)

In [9]:
# Alias: the whole training set is the pool that clustering descriptors are sampled from.
cluster_set = whole_train_set

In [10]:
# Sample nPerImage local descriptors from each of nIm randomly chosen images,
# giving about nDescriptors vectors for k-means.
nDescriptors = 50000
nPerImage = 100
nIm = ceil(nDescriptors/nPerImage)

if nIm > len(cluster_set):
    raise ValueError('Need {} images for {} descriptors but cluster_set has only {}'.format(
        nIm, nDescriptors, len(cluster_set)))

# The sampler yields the chosen indices in random order. DataLoader forbids shuffle=True
# when an explicit sampler is passed.
sampler = SubsetRandomSampler(np.random.choice(len(cluster_set), nIm, replace=False))
data_loader = DataLoader(dataset=cluster_set,
            num_workers=opt.threads, batch_size=opt.cacheBatchSize, shuffle=False,
            pin_memory=cuda,
            sampler=sampler)

# exist_ok replaces the separate exists() check (and its check-then-create race).
os.makedirs(join(opt.dataPath, 'centroids'), exist_ok=True)

In [11]:
# Cache file: <dataPath>/centroids/<arch>_<dataset>_<num_clusters>_desc_cen.hdf5
initcache = join(opt.dataPath, 'centroids', opt.arch + '_' + cluster_set.dataset + '_' + str(opt.num_clusters) + '_desc_cen.hdf5')

# The encoder cell appends L2Norm only when opt.mode == 'cluster' (and not vladv2). This
# notebook runs with the default mode, so normalise here to cluster the same descriptors
# that mode would produce. L2 normalisation is idempotent, so this is harmless if L2Norm
# is already part of the encoder.
normalize_descriptors = not opt.vladv2
# One row per sampled descriptor. Equals nDescriptors when it is a multiple of nPerImage,
# and never under-allocates otherwise (nIm was rounded up with ceil).
nRows = nIm * nPerImage
nBatches = ceil(nIm/opt.cacheBatchSize)

with h5py.File(initcache, mode='w') as h5:
    with torch.no_grad():
        model.eval()
        print('====> Extracting Descriptors')
        dbFeat = h5.create_dataset("descriptors",
                    [nRows, encoder_dim],
                    dtype=np.float32)

        for iteration, (images, _indices) in enumerate(data_loader, 1):
            images = images.to(device)
            featureMap = model.encoder(images)
            if normalize_descriptors:
                # unit-length descriptor at every spatial location (dim=1 is encoder_dim)
                featureMap = torch.nn.functional.normalize(featureMap, p=2, dim=1)
            # -> (batch, locations, encoder_dim): one descriptor per spatial location
            image_descriptors = featureMap.view(images.size(0), encoder_dim, -1).permute(0, 2, 1)

            batchix = (iteration-1)*opt.cacheBatchSize*nPerImage
            for ix in range(image_descriptors.size(0)):
                # sample different location for each image in batch
                sample = np.random.choice(image_descriptors.size(1), nPerImage, replace=False)
                startix = batchix + ix*nPerImage
                dbFeat[startix:startix+nPerImage, :] = image_descriptors[ix, sample, :].cpu().numpy()

            if iteration % 50 == 0 or len(data_loader) <= 10:
                print("==> Batch ({}/{})".format(iteration, nBatches), flush=True)

            # free device memory before the next batch
            del images, featureMap, image_descriptors

    print('====> Clustering..')
    niter = 100
    kmeans = faiss.Kmeans(encoder_dim, opt.num_clusters, niter=niter, verbose=False)
    kmeans.train(dbFeat[...])

    print('====> Storing centroids', kmeans.centroids.shape)
    h5.create_dataset('centroids', data=kmeans.centroids)
    print('====> Done!')

====> Extracting Descriptors
==> Batch (50/500)
==> Batch (100/500)
==> Batch (150/500)
==> Batch (200/500)
==> Batch (250/500)
==> Batch (300/500)
==> Batch (350/500)
==> Batch (400/500)
==> Batch (450/500)
==> Batch (500/500)
====> Clustering..
====> Storing centroids (64, 512)
====> Done!
