<a href="https://colab.research.google.com/github/FireMight/point-cloud-retrieval-from-image/blob/side_images/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Clone the project, including its git submodules, into the Colab workspace
# and make the repository root the working directory for the following cells.
%cd /content
!git clone --recurse-submodules https://github.com/FireMight/point-cloud-retrieval-from-image.git
%cd /content/point-cloud-retrieval-from-image/

In [0]:
# Switch the cloned repository to the `side_images` branch.
!git checkout side_images

In [0]:
# Fetch and merge the latest commits of the checked-out branch.
%cd /content/point-cloud-retrieval-from-image/
!git pull

In [0]:
# Load data: mount Google Drive and unpack the dataset archives stored there
# into the directory layout below data/oxford/data/reference/.
from google.colab import drive

drive.mount('/content/drive')
%cd /content/point-cloud-retrieval-from-image/
# Processed 20 m submaps archive, plus the accompanying metadata.csv
!mkdir -p data/oxford/data/reference
!tar -C data/oxford/data/reference -xf '/content/drive/My Drive/ADL4CV/downloads/oxford_dataset/reference/submaps_20m/submaps_20m_processed.tar.xz'
!cp -a '/content/drive/My Drive/ADL4CV/downloads/oxford_dataset/reference/submaps_20m/metadata.csv' data/oxford/data/reference/submaps_20m_processed/
# Center camera image archive -> stereo/centre
!mkdir -p data/oxford/data/reference/stereo/centre
!tar -C data/oxford/data/reference/stereo/centre -xf '/content/drive/My Drive/ADL4CV/downloads/oxford_dataset/reference/img_center_5/img_20_0-1921.tar.gz'
# Left camera image archive -> mono_left
!mkdir -p data/oxford/data/reference/mono_left
!tar -C data/oxford/data/reference/mono_left -xf '/content/drive/My Drive/ADL4CV/downloads/oxford_dataset/reference/img_left/img_20_0-1528.tar.gz'
# Right camera image archive -> mono_right
!mkdir -p data/oxford/data/reference/mono_right
!tar -C data/oxford/data/reference/mono_right -xf '/content/drive/My Drive/ADL4CV/downloads/oxford_dataset/reference/img_right/img_20_0-1528.tar.gz'

In [0]:
# Load pretrained model: unpack the VGG16-NetVLAD checkpoint archive from
# Google Drive into models/ (requires the Drive mount from the data cell).
!unzip '/content/drive/My Drive/ADL4CV/models/vgg16_netvlad_checkpoint.zip' -d 'models/'

In [0]:
%cd /content/point-cloud-retrieval-from-image/

from importlib import reload
from itertools import chain
import numpy as np
from sklearn.neighbors import KDTree
import matplotlib.pyplot as plt
import time
import sys

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision

import netvlad.netvlad as netvlad
import pointnet.pointnet.model as pointnet
import data.OxfordRobotcarDataset as OxfordDatasetPackage
reload(OxfordDatasetPackage)

# Checkpoint file that load_netvlad() restores the NetVLAD weights from
net_vlad_path = 'models/vgg16_netvlad_checkpoint/checkpoints/checkpoint.pth.tar'
# Root directory of the extracted images and submaps (relative to the repo root)
data_path = 'data/oxford/data/reference/'
# Google Drive directory that receives model checkpoints, loss histories and plots
results_path = '/content/drive/My Drive/ADL4CV/results/'

In [0]:
class ModifiedNetVLAD(nn.Module):
    """NetVLAD feature extractor followed by a trainable linear projection.

    The wrapped model must expose an ``encoder`` and a ``pool`` sub-module.
    Both are evaluated without gradient tracking, so only the final linear
    layer ``fc`` receives gradients during training.

    Args:
        model: module providing ``encoder`` and ``pool``.
        out_features: size of the descriptor produced by ``fc``.
        in_features: flattened size of the ``pool`` output. The default of
            32768 equals 64 clusters x 512 encoder channels, the
            configuration built by ``load_netvlad``.
    """
    def __init__(self, model, out_features, in_features=32768):
        super(ModifiedNetVLAD, self).__init__()
        self.vlad = model
        self.in_features = in_features
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Return one ``out_features``-dimensional descriptor per input sample."""
        # The backbone acts as a fixed feature extractor: no graph is recorded
        # for it, so backpropagation stops at the input of `fc`.
        with torch.no_grad():
            x = self.vlad.pool(self.vlad.encoder(x))
            x = x.view((x.shape[0], self.in_features))
        x = self.fc(x)
        return x
    
class CustomScheduler(object):
    """Plateau-based learning-rate scheduler that also reports convergence.

    ``step`` is fed one metric value per epoch. When the metric has failed to
    improve on the best value of the recent history for more than ``patience``
    consecutive calls, the learning rate of every parameter group is multiplied
    by ``factor`` (but never drops below its ``min_lr``). If no learning rate
    can be lowered by more than ``eps`` any more, convergence is reported.

    Args:
        optimizer: the ``torch.optim.Optimizer`` whose learning rates are adapted.
        factor: multiplicative learning-rate decay.
        patience: number of tolerated non-improving steps; also the length of
            the metric history window.
        verbose: print a message whenever a learning rate is reduced.
        threshold: relative improvement required to count as "better".
        min_lr: lower bound for the learning rate, either one scalar for all
            groups or a list/tuple with one entry per parameter group.
        eps: minimal learning-rate decrease that is still applied.

    Raises:
        TypeError: if ``optimizer`` is not a ``torch.optim.Optimizer``.
        ValueError: if ``min_lr`` is a sequence of the wrong length.
    """
    def __init__(self,optimizer,factor=0.2,patience=4,verbose=True,threshold=1e-4,min_lr=0,eps=1e-6):
        if not isinstance(optimizer,torch.optim.Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.factor = factor
        self.patience = patience
        # Stored as a plain value; a trailing comma here would turn it into an
        # always-truthy tuple and make verbose=False ineffective.
        self.verbose = verbose
        self.threshold = threshold
        self.eps = eps
        self.reset()

    def reset(self):
        """Forget all recorded metrics and the bad-epoch counter."""
        self.num_bad_epochs = 0
        # Keep at least one entry so that min() in step() never sees an empty list.
        self.history = [float("inf")] * max(1, self.patience)

    def step(self,metric):
        """Record ``metric`` and adapt the learning rates if it has plateaued.

        Returns:
            True if convergence is detected (a reduction was due but no
            learning rate could be lowered any further), otherwise False.
        """
        best = min(self.history)
        self.history.pop(0)
        self.history.append(metric)
        if metric<best*(1-self.threshold):
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs > self.patience:
            self.num_bad_epochs = 0
            return self._reduce_lr()
        return False

    def _reduce_lr(self):
        """Decay the learning rate of every parameter group.

        Returns:
            False if at least one group's learning rate was reduced, True if
            none could be reduced by more than ``eps`` (i.e. converged).
        """
        reduced_any = False
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
                reduced_any = True
                if self.verbose:
                    print('Reducing learning rate'
                          ' of group {} to {:.4e}.'.format(i, new_lr))
        return not reduced_any

def load_netvlad(checkpoint_path):
    """Build the VGG16 + NetVLAD model and restore its weights.

    The returned module has two children: ``encoder`` (the VGG16 feature
    layers without the last two) and ``pool`` (a NetVLAD layer with 64
    clusters over 512-dimensional features).

    Args:
        checkpoint_path: checkpoint file whose ``'state_dict'`` entry holds
            the weights for this model.

    Returns:
        The assembled ``nn.Module`` with the checkpoint weights loaded.
    """
    vgg_features = models.vgg16(pretrained=False).features
    # Drop the last two layers of the VGG16 feature stack.
    truncated_layers = list(vgg_features.children())[:-2]

    model = nn.Module()
    model.add_module('encoder', nn.Sequential(*truncated_layers))
    model.add_module('pool',
                     netvlad.NetVLAD(num_clusters=64, dim=512, vladv2=False))

    # The identity map_location returns every storage as it was deserialized
    # instead of moving it to the device it was saved from.
    state = torch.load(checkpoint_path,
                       map_location=lambda storage, loc: storage)
    model.load_state_dict(state['state_dict'])
    return model

def eval_descriptors(img_net, pcl_net, data_loader, dataset, desc_dim=1024):
    """Compute L2-normalized image and point cloud descriptors for a loader.

    Both networks are switched to eval mode and run without gradient
    tracking. ``dataset.use_triplet`` is set to False while iterating and is
    restored afterwards, even if an exception is raised.

    Args:
        img_net: network mapping an image batch to descriptors.
        pcl_net: network returning a 3-tuple whose first element holds the
            point cloud descriptors.
        data_loader: yields ``(indices, img, pcl, _)`` batches.
        dataset: object whose ``use_triplet`` flag is temporarily disabled.
        desc_dim: dimensionality of the descriptors of both networks.

    Returns:
        Tuple ``(indices, img_descs, pcl_descs)``: ``indices`` is a list with
        one entry of the loader's index batch per sample, in iteration order;
        the descriptor arrays are float32 of shape ``(len(indices), desc_dim)``
        with row ``k`` belonging to ``indices[k]``.
    """
    num_samples = len(data_loader.dataset)
    img_descs = np.empty((num_samples, desc_dim), dtype=np.float32)
    pcl_descs = np.empty((num_samples, desc_dim), dtype=np.float32)
    indices = []

    # not nice but ok
    use_triplet_cache = dataset.use_triplet
    dataset.use_triplet = False

    # Running write position. It must not be derived from
    # batch_num * batch_size, because the last batch can be smaller than the
    # others, which would overwrite earlier rows and leave the tail unset.
    offset = 0
    try:
        with torch.no_grad():
            img_net.eval()
            pcl_net.eval()

            for data in data_loader:
                batch_indices, img, pcl, _ = data
                batch_size = img.size()[0]

                img_desc = img_net(img)
                pcl_desc,_,_ = pcl_net(pcl)

                img_desc = norm_descriptor(img_desc)
                pcl_desc = norm_descriptor(pcl_desc)

                rows = slice(offset, offset + batch_size)
                img_descs[rows] = img_desc.cpu().detach().numpy().reshape(batch_size, desc_dim)
                pcl_descs[rows] = pcl_desc.cpu().detach().numpy().reshape(batch_size, desc_dim)
                for j in range(batch_size):
                    indices.append(batch_indices.data[j])
                offset += batch_size
    finally:
        dataset.use_triplet = use_triplet_cache

    # Only return rows that were actually written (relevant if the loader
    # yields fewer samples than len(data_loader.dataset)).
    return indices, img_descs[:offset], pcl_descs[:offset]
    
def norm_descriptor(desc, eps=1e-12):
    """Scale every row of ``desc`` to unit Euclidean length.

    Args:
        desc: 2-D tensor with one descriptor per row.
        eps: lower bound applied to the row norms before dividing, so that an
            all-zero descriptor maps to zeros instead of NaN.

    Returns:
        Tensor of the same shape as ``desc`` with L2-normalized rows.
    """
    desc_norms = torch.norm(desc, dim=1, keepdim=True)
    return desc / desc_norms.clamp(min=eps)
    
    
def triplet_loss(anchor, positive, negative, margin=0.5, reduction='mean'):
    """Triplet margin loss on squared Euclidean distances.

    Per sample the loss is ``max(0, d(a, p) - d(a, n) + margin)``, where ``d``
    is the squared L2 distance taken over dimension 1.

    Args:
        anchor, positive, negative: 2-D tensors with one descriptor per row.
        margin: required gap between negative and positive distance.
        reduction: ``'sum'`` sums the per-sample losses; any other value
            returns their mean.

    Returns:
        Scalar tensor with the reduced loss.
    """
    sq_dist_pos = (anchor - positive).pow(2).sum(dim=1)
    sq_dist_neg = (anchor - negative).pow(2).sum(dim=1)
    per_sample = nn.functional.relu(sq_dist_pos - sq_dist_neg + margin)

    return per_sample.sum() if reduction == 'sum' else per_sample.mean()
            
def calc_eval_metrics(indices, img_descs, pcl_descs, dataset, n_max=25, d_retr=25.0):
    """Evaluate image-to-point-cloud retrieval by position error and recall.

    Every image descriptor is used as a query against a KD-tree of all point
    cloud descriptors. The position error of a retrieved candidate is the
    Euclidean distance between ``dataset.get_center_pos`` of the query and of
    the candidate.

    Args:
        indices: ``indices[k]`` is the dataset index belonging to row ``k`` of
            both descriptor arrays (rows need not be in dataset order).
        img_descs: array of shape ``(N, D)`` with the query descriptors.
        pcl_descs: array of shape ``(N, D)`` with the database descriptors.
        dataset: object providing ``get_center_pos(idx)``.
        n_max: largest number of candidates for the recall curve; clipped to
            ``N - 1``.
        d_retr: maximum position error for a candidate to count as a hit.

    Returns:
        Tuple ``(top1_errors, top5_errors, recall_over_n)``:
        ``top1_errors[i]`` is the position error of the best candidate,
        ``top5_errors[i]`` the mean error of the (up to) five best candidates,
        and ``recall_over_n[i, n]`` is 1.0 if any of the first ``n`` candidates
        is a hit (column 0 is therefore always 0).
    """
    N = img_descs.shape[0]
    if N == 0:
        # Nothing to query; a KD-tree cannot be built on an empty database.
        return np.empty(0), np.empty(0), np.zeros((0, n_max + 1))

    n_max = min(N-1, n_max)
    num_retr = n_max + 1

    # Build KD Tree of computed pcl descriptors
    leaf_size = max(1, int(N / 10))
    kd_tree = KDTree(pcl_descs, leaf_size=leaf_size, metric='euclidean')

    # Initialize metrics
    top1_errors = np.empty(N)
    top5_errors = np.empty(N)
    recall_over_n = np.zeros((N, num_retr))

    # Get closest pcl descriptors for every query image
    for i in range(N):
        idx_query = indices[i]

        rows_retr = kd_tree.query(img_descs[i].reshape(1, -1), k=num_retr,
                                  sort_results=True, return_distance=False)[0]
        # Translate descriptor rows back to dataset indices
        indices_retr = [indices[row] for row in rows_retr]

        # Ground truth position: center of the query's submap
        pos_query = dataset.get_center_pos(idx_query)

        pos_errors = np.empty(num_retr)
        for n, idx_retr in enumerate(indices_retr):
            pos_retr = dataset.get_center_pos(idx_retr)
            pos_errors[n] = np.linalg.norm(pos_query - pos_retr)

        # Top-1 position error
        top1_errors[i] = pos_errors[0]

        # Top-5 avg. position error; averages over the candidates that exist,
        # which can be fewer than five for very small datasets.
        top5_errors[i] = np.mean(pos_errors[:5])

        # Recall over n: column n is 1 once any of the first n candidates is
        # within d_retr, i.e. a running "any" over the hit flags.
        hits = pos_errors[:n_max] <= d_retr
        recall_over_n[i, 1:] = np.maximum.accumulate(hits)

    return top1_errors, top5_errors, recall_over_n
                
            

In [0]:
# Configure run
camera = 'side'                 # Choose from 'center', 'side', 'all'
num_epochs = 500                # Number of training epochs
epoch_use_triplet = 0           # First epoch trained with the triplet loss (MSE before)
use_small = False               # Train on a tiny subset to test overfitting
use_same_dataset_size = True    # Equal number of img for comparison
small_size = 8                  # Size of the tiny subset when use_small is set

# The run name prefixes every file written to the results directory
dataset_tag = 'small{}'.format(small_size) if use_small else 'full'
run_name = '{}_{}_{}_epochs'.format(camera, dataset_tag, num_epochs)

In [0]:
# Setup data and models
if not torch.cuda.is_available():
    print('Failed to connect to a GPU. Are you sure you are using the correct runtime type?')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so that weight initialization and the random dataset
# splits below are reproducible across runs. Without it every run (and every
# re-execution of this cell) evaluates on a different test split.
random_seed = 0
torch.manual_seed(random_seed)

#set up models

torch.cuda.empty_cache()

#input: image, output 32K desc
img_net = load_netvlad(net_vlad_path)
#append FC layer to reduce to 1K desc
img_net = ModifiedNetVLAD(img_net,1024)

#input: pcl. output 1K desc
pcl_net = pointnet.PointNetfeat(True,True)

img_net.to(device)
pcl_net.to(device)


# Collect all required image directories
img_data_paths = []
if camera == 'center' or camera == 'all':
    img_data_paths.append(data_path + 'stereo/centre/')
if camera == 'side' or camera == 'all':
    img_data_paths.append(data_path + 'mono_left/')
    img_data_paths.append(data_path + 'mono_right/')

# Currently we only use 20m submaps
pcl_data_path = data_path + 'submaps_20m_processed/'

assert len(img_data_paths) > 0
dataset = OxfordDatasetPackage.OxfordRobotcarDataset(img_dirs=img_data_paths,
                                                     pcl_dir=pcl_data_path,
                                                     device=device)

# Use same amount of data for comparison of different cameras.
# 1921 presumably is the number of center-camera samples (cf. the archive name
# img_20_0-1921) -- TODO confirm. Clamped so that a smaller dataset does not
# produce a negative split length.
comparison_size = min(1921, len(dataset))
if use_same_dataset_size and camera in ['side', 'all']:
    dataset_used, _ = torch.utils.data.random_split(
        dataset, [comparison_size, len(dataset) - comparison_size])
else:
    dataset_used = dataset

# 80-10-10 train-val-test split
test_size = int(0.1*len(dataset_used))
train_size = len(dataset_used) - 2*test_size
train_set, val_set, test_set = torch.utils.data.random_split(dataset_used,
                                               [train_size,test_size,test_size])

# Use small training set to test overfitting
if use_small:
    small_count = min(small_size, len(train_set))
    small_set,_ = torch.utils.data.random_split(train_set,
                                         [small_count,len(train_set)-small_count])
    train_set = small_set

# Training with triplet loss requires smaller batch size
if epoch_use_triplet < num_epochs:
    batchsize_train = min(32, len(train_set))
else:
    batchsize_train = min(64, len(train_set))

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batchsize_train,
                                           shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=64, shuffle=False)

In [0]:
# Train model
optim = torch.optim.Adam(chain(img_net.parameters(),pcl_net.parameters()),lr=1e-4)
scheduler = CustomScheduler(optim,verbose=True)

# Per-epoch average losses, stored as plain Python floats so that they can be
# saved with np.save and plotted without touching GPU tensors.
train_losses_history = []
val_losses_history = []
val_loss_min = float('inf')


print('Start training: ' + run_name)
for i in range(num_epochs):
    # Check if we use triplet loss
    use_triplet = i >= epoch_use_triplet
    if use_triplet:
        # Refresh the descriptors the dataset uses to pick triplet negatives
        indices, img_descs, pcl_descs = eval_descriptors(img_net, pcl_net,
                                                         train_loader, dataset)
        dataset.update_train_descriptors(indices, img_descs, pcl_descs)

    if i == epoch_use_triplet:
        # The loss function changes here, so earlier metrics are not comparable
        scheduler.reset()

    dataset.use_triplet = use_triplet


    train_loss_sum = 0.0
    # eval_descriptors() leaves the networks in eval mode
    img_net.train()
    pcl_net.train()
    for _, img, pos, neg in train_loader:
        # we need a batch size of at least 2 to run the sample trough PointNet
        if(img.size()[0]==1):
            continue

        # Gradients accumulate in .grad across backward() calls, so they have
        # to be cleared before every optimization step.
        optim.zero_grad()

        img_desc = img_net(img)
        pos_desc,_,_ = pcl_net(pos)

        img_desc = norm_descriptor(img_desc)
        pos_desc = norm_descriptor(pos_desc)

        if use_triplet:
            assert len(neg) > 0
            neg_desc,_,_ = pcl_net(neg) # This line causes the training to fail!!
            neg_desc = norm_descriptor(neg_desc)
            loss = triplet_loss(img_desc,pos_desc,neg_desc,reduction='sum')
        else:
            loss = nn.functional.mse_loss(img_desc,pos_desc,reduction='sum')

        loss.backward()
        optim.step()
        train_loss_sum += loss.item()

    train_loss = train_loss_sum / len(train_set)
    converged = scheduler.step(train_loss)
    train_losses_history.append(train_loss)


    # Validation
    with torch.no_grad():
        val_loss_sum = 0.0
        img_net.eval()
        pcl_net.eval()

        # Calculate descriptors for neg anchor of triplet loss
        if use_triplet:
            indices, img_descs, pcl_descs = eval_descriptors(img_net, pcl_net,
                                                             val_loader, dataset)
            dataset.update_train_descriptors(indices, img_descs, pcl_descs)

        for _, img, pos, neg in val_loader:
            # we need a batch size of at least 2 to run the sample trough PointNet
            if(img.size()[0]==1):
                continue
            img_desc = img_net(img)
            pos_desc,_,_ = pcl_net(pos)

            img_desc = norm_descriptor(img_desc)
            pos_desc = norm_descriptor(pos_desc)

            if use_triplet:
                assert len(neg) > 0
                neg_desc,_,_ = pcl_net(neg)
                neg_desc = norm_descriptor(neg_desc)
                loss = triplet_loss(img_desc,pos_desc,neg_desc,reduction='sum')
            else:
                loss = nn.functional.mse_loss(img_desc,pos_desc,reduction='sum')

            val_loss_sum += loss.item()

        val_loss = val_loss_sum / len(val_set)
        val_losses_history.append(val_loss)


    print ("Epoch {}/{}\n".format(i+1,num_epochs) +\
           "training loss:   {:.4}\n".format(train_loss) +\
           "validation loss: {:.4}\n".format(val_loss))


    # Save best performing model
    if val_loss < val_loss_min:
        print('Save new model checkpoint\n')
        torch.save({'img_net_state_dict' : img_net.state_dict(),
                    'pcl_net_state_dict' : pcl_net.state_dict()},
                   results_path + run_name + '.pt')
        val_loss_min = val_loss


    # Save loss history (rewritten every epoch so an interrupted run keeps it)
    np.save(results_path + run_name + '_train_loss.npy', train_losses_history)
    np.save(results_path + run_name + '_val_loss.npy', val_losses_history)

    #if converged:
    #    print("Convergence after {} epochs".format(i))
    #    break

In [0]:
# Load best performing models
print('Load model from run: ' + run_name)
# map_location lets a checkpoint written on the GPU be restored on whatever
# device the current runtime provides (e.g. a CPU-only session).
checkpoint = torch.load(results_path + run_name + '.pt', map_location=device)
img_net.load_state_dict(checkpoint['img_net_state_dict'])
pcl_net.load_state_dict(checkpoint['pcl_net_state_dict'])



# Training performance
indices, img_descs, pcl_descs = eval_descriptors(img_net, pcl_net, train_loader,
                                                 dataset)
eval_metrics = calc_eval_metrics(indices, img_descs, pcl_descs, dataset,
                                 d_retr=25.0, n_max=25)
top1_errors, top5_errors, recall_over_n = eval_metrics
avg_recall_train = np.average(recall_over_n, axis=0)

print('Avg pos error TRAIN: Top1 {} top5 {}'.format(np.average(top1_errors),
                                                    np.average(top5_errors)))



# Test performance
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, shuffle=False)
indices, img_descs, pcl_descs = eval_descriptors(img_net, pcl_net, test_loader,
                                                 dataset)
eval_metrics = calc_eval_metrics(indices, img_descs, pcl_descs, dataset,
                                 d_retr=25.0, n_max=25)
top1_errors, top5_errors, recall_over_n = eval_metrics
avg_recall_test = np.average(recall_over_n, axis=0)

print('Avg pos error TEST: Top1 {} top5 {}'.format(np.average(top1_errors),
                                                   np.average(top5_errors)))


# Recall curves; column 0 of the recall arrays (zero candidates) is skipped
fig, ax = plt.subplots(figsize=(9.0, 6.0))
ax.plot(np.arange(1,avg_recall_train.shape[0]), avg_recall_train[1:]*100,
        label='Train', color='blue')
ax.plot(np.arange(1,avg_recall_test.shape[0]), avg_recall_test[1:]*100,
        label='Test', color='orange')
ax.set_xlim(0,max(avg_recall_train.shape[0], avg_recall_test.shape[0]))
ax.set_xlabel('N - Number of top database candidates')
ax.set_ylabel('Average Recall @N [%]')
ax.legend()

fig.savefig(results_path + run_name + '_metrics.png', dpi=100)

In [0]:
# Load loss histories written by the training cell for the configured run
history_prefix = results_path + run_name
train_losses_history = np.load(history_prefix + '_train_loss.npy', allow_pickle=True)
val_losses_history = np.load(history_prefix + '_val_loss.npy', allow_pickle=True)

In [0]:
# Plot training and validation loss per epoch and save the figure
fig, ax = plt.subplots(figsize=(9.0, 6.0))
ax.plot(train_losses_history, label='Train', color='blue')
ax.plot(val_losses_history, label='Val', color='orange')
#plt.yscale('log')
ax.set_ylim(0, 2)
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()

fig.savefig(results_path + run_name + '_loss.png', dpi=100)