<a href="https://colab.research.google.com/github/FireMight/point-cloud-retrieval-from-image/blob/master/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Clone the project repository (with its git submodules) into /content
# and make the checkout the working directory for the following cells.
%cd /content
!git clone --recurse-submodules https://github.com/FireMight/point-cloud-retrieval-from-image.git
%cd /content/point-cloud-retrieval-from-image/

In [0]:
# Refresh an already cloned checkout with the latest remote commits.
%cd /content/point-cloud-retrieval-from-image/
!git pull

In [0]:
# load data
# Mounts Google Drive, then unpacks the dataset archives stored there into
# the repository's data/oxford/data/reference directory.
from google.colab import drive

drive.mount('/content/drive')
%cd /content/point-cloud-retrieval-from-image/
# Processed submap archive plus its metadata.csv (copied next to the extracted submaps).
!mkdir -p data/oxford/data/reference
!tar -C data/oxford/data/reference -xf '/content/drive/My Drive/ADL4CV/downloads/oxford_dataset/reference/submaps_20m/submaps_20m_processed.tar.xz'
!cp -a '/content/drive/My Drive/ADL4CV/downloads/oxford_dataset/reference/submaps_20m/metadata.csv' data/oxford/data/reference/submaps_20m_processed/
# Image archive, extracted into the stereo/centre directory.
!mkdir -p data/oxford/data/reference/stereo/centre
!tar -C data/oxford/data/reference/stereo/centre -xf '/content/drive/My Drive/ADL4CV/downloads/oxford_dataset/reference/img_center_5/img_20_0-1921.tar.gz'

In [0]:
# load pretrained model
# Extracts the VGG16/NetVLAD checkpoint archive from Google Drive into models/
# (requires the Drive mount from the data-loading cell).
!unzip '/content/drive/My Drive/ADL4CV/models/vgg16_netvlad_checkpoint.zip' -d 'models/'

In [0]:
%cd /content/point-cloud-retrieval-from-image/
from itertools import chain
import numpy as np
from sklearn.neighbors import KDTree
import matplotlib.pyplot as plt
import time

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision

import netvlad.netvlad as netvlad
import pointnet.pointnet.model as pointnet
from data.OxfordRobotcarDataset import OxfordRobotcarDataset

# Paths are relative to the repository root (the working directory set by the %cd above).
# Checkpoint file read by load_netvlad().
net_vlad_path = 'models/vgg16_netvlad_checkpoint/checkpoints/checkpoint.pth.tar'
# Image and point-cloud directories handed to OxfordRobotcarDataset.
img_data_path = 'data/oxford/data/reference/stereo/centre/'
pcl_data_path = 'data/oxford/data/reference/submaps_20m_processed/'

In [0]:
class TripletLoss(nn.Module):
    """
    Margin-based triplet loss on squared Euclidean distances.

    For every row of the batch the loss is
    ``max(0, ||anchor - positive||^2 - ||anchor - negative||^2 + margin)``.
    (https://github.com/adambielski/siamese-triplet)
    """

    def __init__(self, margin = 1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        """Return the mean (``size_average=True``) or the sum of the per-row losses."""
        # Squared distances, reduced over dimension 1 (no square root is taken).
        sq_dist_pos = ((anchor - positive) ** 2).sum(dim=1)
        sq_dist_neg = ((anchor - negative) ** 2).sum(dim=1)
        per_sample = nn.functional.relu(sq_dist_pos - sq_dist_neg + self.margin)
        if size_average:
            return per_sample.mean()
        return per_sample.sum()

# Appends a fully connected layer that maps the NetVLAD output descriptor
# to the requested dimension.
#TODO: make a nice wrapper for NetVLAD
class ModifiedNetVLAD(nn.Module):
    """Wraps a model exposing ``encoder`` and ``pool`` and adds a linear head.

    The pooled output is flattened to 32768 values per sample before being
    passed through ``fc`` (32768 -> ``out_features``).
    """

    def __init__(self, model,out_features):
        super().__init__()
        self.vlad = model
        self.fc = nn.Linear(32768, out_features)

    def forward(self, x):
        encoded = self.vlad.encoder(x)
        pooled = self.vlad.pool(encoded)
        # One flat 32768-vector per batch element; view() fails if the size does not match.
        flat = pooled.view((pooled.shape[0], 32768))
        return self.fc(flat)
    
class CustomScheduler(object):
    """Plateau-based learning-rate scheduler that also reports convergence.

    ``step(metric)`` compares the metric against the best of the last
    ``patience`` metrics. After more than ``patience`` consecutive steps
    without a relative improvement of at least ``threshold`` the learning
    rate of every param group is multiplied by ``factor`` (bounded below by
    the group's ``min_lr``). When no group can be reduced any further by more
    than ``eps``, ``step`` returns True to signal convergence.

    Args:
        optimizer: a ``torch.optim.Optimizer`` whose param groups are updated.
        factor: multiplicative learning-rate decay.
        patience: number of tolerated non-improving steps; also the length of
            the metric history used to determine the current best value.
        verbose: print a message whenever a learning rate is reduced.
        threshold: relative improvement required to count as progress.
        min_lr: scalar, or list/tuple with one lower bound per param group.
        eps: minimal decrease for a learning-rate update to be applied.

    Raises:
        TypeError: if ``optimizer`` is not a ``torch.optim.Optimizer``.
        ValueError: if ``min_lr`` is a sequence of the wrong length.
    """

    def __init__(self,optimizer,factor=0.2,patience=4,verbose=True,threshold=1e-4,min_lr=0,eps=1e-6):
        if not isinstance(optimizer,torch.optim.Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer

        if isinstance(min_lr, (list, tuple)):
            if len(min_lr) != len(optimizer.param_groups):
                raise ValueError("expected {} min_lrs, got {}".format(
                    len(optimizer.param_groups), len(min_lr)))
            self.min_lrs = list(min_lr)
        else:
            self.min_lrs = [min_lr] * len(optimizer.param_groups)

        self.factor = factor
        self.patience = patience
        # Plain bool: a trailing comma here would store a 1-tuple, which is
        # always truthy and would make verbose=False ineffective.
        self.verbose = verbose
        self.threshold = threshold
        self.eps = eps
        self.num_bad_epochs = 0
        # Keep at least one slot so min()/pop() in step() work for patience=0.
        self.history = [float("inf")] * max(patience, 1)

    def step(self,metric):
        """Record ``metric``; return True if convergence is detected, else False."""
        # Store a plain float so that a loss tensor passed in does not keep
        # its autograd graph alive inside the history.
        metric = float(metric)
        best = min(self.history)
        self.history.pop(0)
        self.history.append(metric)
        if metric < best * (1 - self.threshold):
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.num_bad_epochs > self.patience:
            self.num_bad_epochs = 0
            return self._reduce_lr()
        return False

    def _reduce_lr(self):
        """Decay the learning rate of every param group.

        Returns:
            True if no group could be reduced by more than ``eps``
            (converged), False if at least one learning rate was lowered.
        """
        reduced_any = False
        for i, param_group in enumerate(self.optimizer.param_groups):
            old_lr = float(param_group['lr'])
            new_lr = max(old_lr * self.factor, self.min_lrs[i])
            if old_lr - new_lr > self.eps:
                param_group['lr'] = new_lr
                reduced_any = True
                if self.verbose:
                    print('Reducing learning rate'
                          ' of group {} to {:.4e}.'.format(i, new_lr))
        return not reduced_any

def load_netvlad(checkpoint_path):
    """Build the VGG16 + NetVLAD model and load its weights from a checkpoint.

    The returned module has two children: ``encoder`` (the VGG16 feature
    layers without the last two) and ``pool`` (a NetVLAD layer with 64
    clusters over 512-dimensional features). Weights are taken from the
    ``'state_dict'`` entry of the checkpoint at ``checkpoint_path``.
    """
    encoder_dim = 512

    # Untrained VGG16; only its feature extractor minus the last two layers is kept.
    vgg = models.vgg16(pretrained=False)
    backbone = nn.Sequential(*list(vgg.features.children())[:-2])

    net = nn.Module()
    net.add_module('encoder', backbone)
    net.add_module('pool',
                   netvlad.NetVLAD(num_clusters=64, dim=encoder_dim, vladv2=False))

    # The map_location callable returns the storage unchanged instead of
    # moving it to the device recorded in the checkpoint.
    state = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    net.load_state_dict(state['state_dict'])
    return net

def index_nn(img_desc, pcl_desc_list):
    """Return the index of the entry in ``pcl_desc_list`` closest to ``img_desc``.

    Closeness is the norm of the difference; on ties the lowest index wins.
    """
    best_idx = 0
    best_dist = (img_desc - pcl_desc_list[0]).norm()
    for candidate_idx in range(len(pcl_desc_list)):
        candidate_dist = (img_desc - pcl_desc_list[candidate_idx]).norm()
        # Strict comparison keeps the earliest index among equal distances.
        if candidate_dist < best_dist:
            best_idx, best_dist = candidate_idx, candidate_dist
    return best_idx

def eval_descriptors(img_net, pcl_net, data_loader, dataset, desc_dim=1024):
    """Compute image and point-cloud descriptors for every sample of a loader.

    Triplet sampling is switched off on ``dataset`` while iterating and the
    previous ``use_triplet`` value is restored afterwards, even on error.

    Args:
        img_net: module mapping an image batch to ``(batch, desc_dim)`` descriptors.
        pcl_net: module whose first return value is the ``(batch, desc_dim)``
            point-cloud descriptor.
        data_loader: loader yielding ``(indices, img, pcl, _)`` batches.
        dataset: object exposing the ``use_triplet`` flag read by the loader's data.
        desc_dim: descriptor length (default 1024).

    Returns:
        ``(indices, img_descs, pcl_descs)`` where ``indices`` lists the
        per-sample indices in iteration order and both arrays are float32
        of shape ``(len(data_loader.dataset), desc_dim)``.
    """
    n_samples = len(data_loader.dataset)
    img_descs = np.empty((n_samples, desc_dim), dtype=np.float32)
    pcl_descs = np.empty((n_samples, desc_dim), dtype=np.float32)
    indices = []

    # not nice but ok
    use_triplet_cache = dataset.use_triplet
    dataset.use_triplet = False

    try:
        with torch.no_grad():
            img_net.eval()
            pcl_net.eval()

            # Running row offset. Deriving it from batch_num * batch_size is
            # wrong whenever the last batch is smaller than the others.
            offset = 0
            for batch_indices, img, pcl, _ in data_loader:
                batch_size = img.size()[0]

                img_desc = img_net(img)
                pcl_desc, _, _ = pcl_net(pcl)

                rows = slice(offset, offset + batch_size)
                img_descs[rows] = img_desc.detach().cpu().numpy().reshape(batch_size, desc_dim)
                pcl_descs[rows] = pcl_desc.detach().cpu().numpy().reshape(batch_size, desc_dim)
                indices.extend(batch_indices[j] for j in range(batch_size))
                offset += batch_size
    finally:
        dataset.use_triplet = use_triplet_cache

    return indices, img_descs, pcl_descs
    
def triplet_loss(anchor, positive, negative, margin=1.0, reduction='mean'):
    """Functional triplet loss on squared Euclidean distances.

    Per row: ``max(0, ||anchor - positive||^2 - ||anchor - negative||^2 + margin)``.
    ``reduction='sum'`` sums the per-row losses; any other value averages them.
    """
    sq_dist_pos = ((anchor - positive) ** 2).sum(dim=1)
    sq_dist_neg = ((anchor - negative) ** 2).sum(dim=1)
    per_sample = nn.functional.relu(sq_dist_pos - sq_dist_neg + margin)

    return per_sample.sum() if reduction == 'sum' else per_sample.mean()
            
def calc_eval_metrics(indices, img_descs, pcl_descs, dataset, n_max=25, d_retr=25.0):
    """Compute retrieval metrics for image queries against point-cloud descriptors.

    Row ``i`` of ``img_descs`` / ``pcl_descs`` belongs to dataset index
    ``indices[i]``; positions are obtained via ``dataset.get_center_pos``.

    Args:
        indices: dataset index for each descriptor row.
        img_descs: ``(N, D)`` array of query (image) descriptors.
        pcl_descs: ``(N, D)`` array of database (point-cloud) descriptors.
        dataset: object providing ``get_center_pos(idx)``.
        n_max: number of nearest database entries retrieved per query
            (the KD-tree query requires ``n_max <= N``).
        d_retr: position error up to which a retrieval counts as a hit.

    Returns:
        ``(top1_errors, top5_errors, recall_over_n)``:
        position error of the best match, mean position error of the (up to)
        five best matches, and an ``(N, n_max)`` array whose column ``n``
        (for ``1 <= n < n_max``) holds the fraction of hits among the top
        ``n`` retrievals. Column 0 is left at zero.
    """
    N = img_descs.shape[0]

    # Build KD Tree of computed pcl descriptors.
    # leaf_size must be at least 1, which N // 10 violates for fewer than 10 samples.
    leaf_size = max(1, N // 10)
    kd_tree = KDTree(pcl_descs, leaf_size=leaf_size, metric='euclidean')

    # Initialize metrics
    top1_errors = np.empty(N)
    top5_errors = np.empty(N)
    recall_over_n = np.zeros((N,n_max))

    # Get closest pcl descriptors for every query image
    for i in range(N):
        idx_query = indices[i]

        rows_retr = kd_tree.query(img_descs[i].reshape(1, -1), k=n_max,
                                  sort_results=True, return_distance=False)
        # Translate descriptor rows back to dataset indices.
        indices_retr = [indices[row] for row in rows_retr[0]]

        # Ground truth position: center of the query submap
        pos_query = dataset.get_center_pos(idx_query)

        pos_errors = np.empty(n_max)
        for n, idx_retr in enumerate(indices_retr):
            pos_retr = dataset.get_center_pos(idx_retr)
            pos_errors[n] = np.linalg.norm(pos_query - pos_retr)

        # Top-1 position error
        top1_errors[i] = pos_errors[0]

        # Top-5 avg. position error; the mean also stays correct when
        # fewer than five candidates were retrieved (n_max < 5).
        top5_errors[i] = np.mean(pos_errors[:5])

        # Fraction of the top-n retrievals within d_retr of the query
        for n in range(1, n_max):
            n_tp = np.count_nonzero(pos_errors[:n] <= d_retr)
            recall_over_n[i,n] = float(n_tp) / n

    return top1_errors, top5_errors, recall_over_n
                
            


# Train the image branch (FC head on top of frozen NetVLAD) and PointNet so
# that image and point-cloud descriptors match; currently overfits a small
# 32-sample subset as a sanity check.
if __name__ == '__main__':
    if not torch.cuda.is_available():
        print('Failed to connect to a GPU. Are you sure you are using the correct runtime type?')
        
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    #set up models
    
    torch.cuda.empty_cache()
    
    #input: image, output 32K desc
    img_net = load_netvlad(net_vlad_path)
    #append FC layer to reduce to 1K desc
    img_net = ModifiedNetVLAD(img_net,1024)
    
    #input: pcl. output 1K desc
    pcl_net = pointnet.PointNetfeat(True,True)
    
    img_net.to(device)
    pcl_net.to(device)
    
    
    dataset = OxfordRobotcarDataset(img_dir=img_data_path,
                                    pcl_dir=pcl_data_path,
                                    device=device)
    
    test_size = int(0.1*len(dataset))
    train_size = len(dataset) - 2*test_size 
    train_set, val_set, test_set = torch.utils.data.random_split(dataset,[train_size,test_size,test_size])
    
    small_set,_ = torch.utils.data.random_split(dataset,[32,len(dataset)-32])
    small_set_loader = torch.utils.data.DataLoader(small_set,batch_size=32,shuffle=False)
    
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=False)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=32, shuffle=False)
    
    optim = torch.optim.Adam(chain(img_net.parameters(),pcl_net.parameters()),lr=5e-4)
    scheduler = CustomScheduler(optim,verbose=True)
    optim.zero_grad()
    
    tl=TripletLoss(1)
    train_losses_history = []
    val_losses_history = []
    num_epochs = 5
    # First epoch index that trains with the triplet loss; earlier epochs use MSE.
    epoch_use_triplet = 2
    
    #use small set
    train_loader = small_set_loader
    train_set = small_set
    
    # Defined up front so the convergence check after each epoch is valid
    # even if every batch of an epoch was skipped.
    converged = False
    
    #train
    for i in range(num_epochs):
        # Check if we use triplet loss
        if i < epoch_use_triplet:
            use_triplet = False
        else:
            use_triplet = True
            # Refresh the descriptors the dataset is given before triplet training.
            indices, img_descs, pcl_descs = eval_descriptors(img_net, pcl_net, 
                                                             train_loader, dataset)
            dataset.update_train_descriptors(indices, img_descs, pcl_descs)
            
        dataset.use_triplet = use_triplet
        
        
        train_loss_sum = 0.0
        img_net.train()
        pcl_net.train()
        
        for _, img, pos, neg in train_loader:
            # we need a batch size of at least 2 to run the sample trough PointNet
            if(img.size()[0]==1):
                continue
            
            # Gradients accumulate across backward() calls, so they must be
            # cleared for every batch, not just once before training.
            optim.zero_grad()
            
            # The NetVLAD part runs without gradients; only the FC head of
            # the image branch (and PointNet) receives updates.
            with torch.no_grad():
                tmp_desc = img_net.vlad.pool(img_net.vlad.encoder(img))
                tmp_desc = tmp_desc.view((tmp_desc.shape[0],32768))

            img_desc = img_net.fc(tmp_desc)
            pos_desc,_,_ = pcl_net(pos)
            
            if use_triplet:
                assert len(neg) > 0
                neg_desc,_,_ = pcl_net(neg)                
                loss = triplet_loss(img_desc,pos_desc,neg_desc,reduction='sum')
            else:
                loss = nn.functional.mse_loss(img_desc,pos_desc,reduction='sum')

            # Plain Python floats: no autograd graph is retained by the
            # scheduler history and the loss histories stay plottable.
            batch_loss = loss.item()
            train_loss_sum += batch_loss
            converged = scheduler.step(batch_loss/img.size()[0])
            loss.backward()
            optim.step()
        
        train_loss = train_loss_sum / len(train_set)
        train_losses_history.append(train_loss)
        
        
        # Validation
        with torch.no_grad():
            val_loss_sum = 0.0
            img_net.eval()
            pcl_net.eval()
            
            # Calculate descriptors for neg anchor of triplet loss
            if use_triplet:
                indices, img_descs, pcl_descs = eval_descriptors(img_net, pcl_net, 
                                                                 val_loader, dataset)
                dataset.update_train_descriptors(indices, img_descs, pcl_descs)
            
            for _, img, pos, neg in val_loader:
                # we need a batch size of at least 2 to run the sample trough PointNet
                if(img.size()[0]==1):
                    continue
                tmp_desc = img_net.vlad.pool(img_net.vlad.encoder(img))
                tmp_desc = tmp_desc.view((tmp_desc.shape[0],32768))

                img_desc = img_net.fc(tmp_desc)
                pos_desc,_,_ = pcl_net(pos)

                if use_triplet:
                    assert len(neg) > 0
                    neg_desc,_,_ = pcl_net(neg)                
                    loss = triplet_loss(img_desc,pos_desc,neg_desc,reduction='sum')
                else:
                    loss = nn.functional.mse_loss(img_desc,pos_desc,reduction='sum')

                val_loss_sum += loss.item()
        
            val_loss = val_loss_sum / len(val_set)
            val_losses_history.append(val_loss)

            print ("Epoch {}/{}\n".format(i,num_epochs) +\
                   "training loss:   {:.4}\n".format(train_loss) +\
                   "validation loss: {:.4}\n".format(val_loss))
        
        if converged:
            print("Convergence after {} epochs".format(i))
            break
        
    
    # Retrieval metrics on the training subset after training.
    indices, img_descs, pcl_descs = eval_descriptors(img_net, pcl_net, train_loader,
                                                     dataset)
    eval_metrics = calc_eval_metrics(indices, img_descs, pcl_descs, dataset, d_retr=5.0)
    top1_errors, top5_errors, recall_over_n = eval_metrics

    print('Avg pos error: Top1 {} top5 {}'.format(np.average(top1_errors),
                                                  np.average(top5_errors)))

    plt.plot(np.average(recall_over_n*100, axis=0))
    plt.xlabel('N - Number of top database candidates')
    plt.ylabel('Average Recall @N [%]')
    
    

In [0]:
# Retrieval metrics on the training subset.
# The dataset yields (index, image, pcl, negative) tuples and
# calc_eval_metrics expects the dataset indices as its first argument, so the
# descriptors are computed with eval_descriptors instead of a hand-written loop.
eval_loader = torch.utils.data.DataLoader(train_set,batch_size=64,shuffle=False)
print(len(train_set))

indices, img_descs, pcl_descs = eval_descriptors(img_net, pcl_net, eval_loader, dataset)

eval_metrics = calc_eval_metrics(indices, img_descs, pcl_descs, dataset, d_retr=5.0)
top1_errors, top5_errors, recall_over_n = eval_metrics

print('Avg pos error: Top1 {} top5 {}'.format(np.average(top1_errors),
                                              np.average(top5_errors)))

plt.plot(np.average(recall_over_n*100, axis=0))
plt.xlabel('N - Number of top database candidates')
plt.ylabel('Average Recall @N [%]')

In [0]:
        # Stand-alone validation pass: MSE loss plus recall@1 on the validation set.
        # NOTE(review): relies on `i`, `num_epochs` and `train_loss_sum` left over
        # from the training cell and divides `train_loss_sum` in place, so running
        # this cell repeatedly keeps changing the values - confirm it is still needed.
        with torch.no_grad():
            val_loss_sum = 0
            pcl_descs = []
            img_descs = []
            
            img_net.eval()
            pcl_net.eval()

            # Batches are (index, image, positive pcl, negative pcl) tuples;
            # only the image and its matching point cloud are used here.
            for _idx, img, pos, _neg in val_loader:
                img_desc = img_net(img)
                pos_desc,_,_ = pcl_net(pos)
                #neg_desc,_,_ = pcl_net(neg)
                for j in range(0,img_desc.size()[0]):
                    img_descs.append(img_desc[j,:].detach())
                    pcl_descs.append(pos_desc[j,:].detach())
                #loss = tl(img_desc,pos_desc,neg_desc,False)
                loss = nn.functional.mse_loss(img_desc,pos_desc,reduction='sum')
                val_loss_sum += loss.detach()
            
            train_loss_sum /= len(train_set)
            train_losses_history.append(train_loss_sum)
            val_loss_sum /= len(val_set)
            val_losses_history.append(val_loss_sum)
            
            # A query counts as correct when its own point cloud is the nearest descriptor.
            num_correct = 0
            for j,img_desc in enumerate(img_descs):
                idx_nn = index_nn(img_desc,pcl_descs)
                if idx_nn == j:
                    num_correct += 1
            print ("Epoch {}/{}\n".format(i,num_epochs) +\
                   "training loss:   {}\n".format(train_loss_sum) +\
                   "validation loss: {}\n".format(val_loss_sum) +\
                   "recall@1:        {}\n".format(num_correct/len(val_set)))

In [0]:
# check recall for training set, quite time consuming, so don't do it during training
with torch.no_grad():
    img_net.eval()
    pcl_net.eval()
    pcl_descs = []
    img_descs = []

    # Batches are (index, image, positive pcl, negative pcl) tuples;
    # only the image and its matching point cloud are needed.
    for _idx, img, pos, _neg in small_set_loader:
        img_desc = img_net(img)
        pos_desc,_,_ = pcl_net(pos)
        for j in range(0,img_desc.size()[0]):
            img_descs.append(img_desc[j,:].detach())
            pcl_descs.append(pos_desc[j,:].detach())

    # A query counts as correct when its own point cloud is the nearest descriptor.
    num_correct = 0
    for j,img_desc in enumerate(img_descs):
        idx_nn = index_nn(img_desc,pcl_descs)
        if idx_nn == j:
            num_correct += 1
    # Normalise by the number of evaluated training samples, not by the
    # size of the validation set.
    print ("recall@1 train:  {}\n".format(num_correct/len(img_descs)))

In [0]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

# Plot training and validation loss per epoch and save the figure.
# History entries may be 0-dim (possibly CUDA) tensors, which matplotlib
# cannot convert itself, so they are turned into plain floats first.
fig = figure(figsize=(16,8))
ax = fig.gca()
ax.plot([float(v) for v in train_losses_history], label='training')
ax.plot([float(v) for v in val_losses_history], label='validation')
ax.set_xlabel('epochs')
ax.set_ylabel('loss')
ax.legend()
fig.savefig('loss.jpg')