<a href="https://colab.research.google.com/github/FireMight/point-cloud-retrieval-from-image/blob/master/main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Clone the project repository, including its git submodules, into /content
# and make it the working directory. On a re-run `git clone` just reports that
# the destination already exists; the shell error does not stop the cell.
%cd /content
!git clone --recurse-submodules https://github.com/FireMight/point-cloud-retrieval-from-image.git
%cd /content/point-cloud-retrieval-from-image/

In [0]:
# Update an existing clone to the latest commit of the current branch.
# NOTE: by default plain `git pull` does not update submodule checkouts.
%cd /content/point-cloud-retrieval-from-image/
!git pull

In [0]:
# Load data: mount Google Drive and unpack the Oxford RobotCar reference data
# into data/oxford/data/reference/ inside the repository:
#   - submaps_20m_processed/ (the point-cloud directory used as pcl_data_path),
#     plus its metadata.csv copied in next to it
#   - stereo/centre/ (the image directory used as img_data_path)
from google.colab import drive

drive.mount('/content/drive')
%cd /content/point-cloud-retrieval-from-image/
!mkdir -p data/oxford/data/reference
!tar -C data/oxford/data/reference -xf '/content/drive/My Drive/ADL4CV/downloads/oxford_dataset/reference/submaps_20m/submaps_20m_processed.tar.xz'
!cp -a '/content/drive/My Drive/ADL4CV/downloads/oxford_dataset/reference/submaps_20m/metadata.csv' data/oxford/data/reference/submaps_20m_processed/
!mkdir -p data/oxford/data/reference/stereo/centre
!tar -C data/oxford/data/reference/stereo/centre -xf '/content/drive/My Drive/ADL4CV/downloads/oxford_dataset/reference/img_center_5/img_20_0-1921.tar.gz'

In [0]:
# load pretrained model
!unzip '/content/drive/My Drive/ADL4CV/models/vgg16_netvlad_checkpoint.zip' -d 'models/'

In [0]:
# Imports and data/model paths used by the rest of the notebook.
# The %cd makes the project-local modules below importable and the relative paths valid.
%cd /content/point-cloud-retrieval-from-image/
from itertools import chain
import numpy as np

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision

# project-local modules (repository code / git submodules)
import netvlad.netvlad as netvlad
import pointnet.pointnet.model as pointnet
from data.OxfordRobotcarDataset import OxfordRobotcarDataset

# All paths are relative to the repository root.
# NetVLAD checkpoint, as extracted from vgg16_netvlad_checkpoint.zip into models/
net_vlad_path = 'models/vgg16_netvlad_checkpoint/checkpoints/checkpoint.pth.tar'
# centre stereo-camera images unpacked by the data-loading cell
img_data_path = 'data/oxford/data/reference/stereo/centre/'
# processed point-cloud submaps (with metadata.csv) unpacked by the data-loading cell
pcl_data_path = 'data/oxford/data/reference/submaps_20m_processed/'

In [0]:
class TripletLoss(nn.Module):
    """Hinge-style triplet loss on squared Euclidean distances.

    For row-wise embeddings of anchors ``a``, positives ``p`` and negatives
    ``n``, each triplet contributes
    ``max(0, |a - p|^2 - |a - n|^2 + margin)``, where each squared distance
    is summed over dimension 1.

    Args:
        margin: gap by which the negative must be farther from the anchor
            than the positive before the triplet stops contributing.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        """Return the mean (``size_average=True``) or the sum of the per-triplet losses."""
        def squared_l2(u, v):
            # squared Euclidean distance per row (summed over dim 1, no sqrt)
            return ((u - v) ** 2).sum(dim=1)

        pos_dist = squared_l2(anchor, positive)
        neg_dist = squared_l2(anchor, negative)
        per_triplet = nn.functional.relu(pos_dist - neg_dist + self.margin)
        reduce = torch.mean if size_average else torch.sum
        return reduce(per_triplet)

# Appends a fully connected (linear) layer that maps the NetVLAD output descriptor to the desired dimension.
# TODO: make a nice wrapper for NetVLAD
class ModifiedNetVLAD(nn.Module):
    """NetVLAD model followed by a fully connected projection layer.

    Runs ``model.encoder`` and then ``model.pool`` (the two children built by
    ``load_netvlad``). It flattens the pooled descriptor to
    ``(batch, in_features)`` and maps it to ``out_features`` with a linear layer.

    Args:
        model: module exposing ``encoder`` and ``pool`` submodules.
        out_features: dimension of the output descriptor.
        in_features: size of the flattened pooled descriptor. The default
            32768 matches load_netvlad's 64 clusters x 512 encoder channels.
    """

    def __init__(self, model, out_features, in_features=32768):
        super(ModifiedNetVLAD, self).__init__()
        self.vlad = model
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        """Map a batch of inputs to ``(batch, out_features)`` descriptors."""
        x = self.vlad.pool(self.vlad.encoder(x))
        # read the size from the FC layer so the flatten and the layer can never disagree
        x = x.view((x.shape[0], self.fc.in_features))
        x = self.fc(x)
        return x

def load_netvlad(checkpoint_path):
    """Build the VGG16 + NetVLAD image model and load its trained weights.

    The model is a bare ``nn.Module`` with two children:
      * ``encoder``: the convolutional ``features`` of an untrained VGG16 with
        its final two layers dropped (512 output channels);
      * ``pool``: a ``netvlad.NetVLAD`` layer with 64 clusters (``vladv2=False``).

    Args:
        checkpoint_path: checkpoint file whose ``'state_dict'`` entry matches
            this architecture.

    Returns:
        The model with the checkpoint weights loaded onto the CPU.
    """
    feature_channels = 512
    vgg_features = models.vgg16(pretrained=False).features
    backbone = nn.Sequential(*list(vgg_features.children())[:-2])

    model = nn.Module()
    model.add_module('encoder', backbone)
    model.add_module('pool', netvlad.NetVLAD(num_clusters=64, dim=feature_channels, vladv2=False))

    # torch.load unpickles the file, so only load trusted checkpoints.
    # The map_location callable leaves every storage where it was deserialized (CPU).
    checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint['state_dict'])
    return model

def index_nn(img_desc, pcl_desc_list):
    """Return the index of the point-cloud descriptor closest to an image descriptor.

    The distance is the Euclidean norm of ``img_desc - pcl_desc``. On ties the
    lowest index wins.

    Args:
        img_desc: descriptor tensor of the query image.
        pcl_desc_list: non-empty sequence of descriptor tensors, each
            broadcast-compatible with ``img_desc``.

    Returns:
        int: position in ``pcl_desc_list`` of the nearest descriptor.

    Raises:
        ValueError: if ``pcl_desc_list`` is empty.
    """
    if len(pcl_desc_list) == 0:
        raise ValueError("pcl_desc_list must contain at least one descriptor")
    # min() returns the first of several equal keys, i.e. the lowest index on ties;
    # each norm is computed exactly once.
    return min(range(len(pcl_desc_list)),
               key=lambda idx: (img_desc - pcl_desc_list[idx]).norm().item())


# Overfitting sanity check on a fixed 64-sample subset of image/point-cloud pairs: after training, the image and point-cloud descriptors of each pair should (nearly) coincide.
if __name__ == '__main__':
    # Train the FC head of the image network together with PointNet so that the
    # image and point-cloud descriptors of each pair match (summed MSE loss).
    # The NetVLAD backbone stays frozen: it runs under no_grad below.
    if not torch.cuda.is_available():
        print('Failed to connect to a GPU. Are you sure you are using the correct runtime type?')

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Seed the RNGs so weight init, the random splits below and loader
    # shuffling are reproducible across Restart & Run All.
    torch.manual_seed(0)
    np.random.seed(0)

    # set up models
    torch.cuda.empty_cache()

    # input: image, output: 32K NetVLAD descriptor
    img_net = load_netvlad(net_vlad_path)
    # append an FC layer to reduce it to a 1K descriptor
    img_net = ModifiedNetVLAD(img_net, 1024)

    # input: point cloud, output: 1K descriptor
    pcl_net = pointnet.PointNetfeat(True, True)

    img_net.to(device)
    pcl_net.to(device)

    # NOTE(review): batches are never moved to `device` here; presumably the
    # dataset already returns tensors on `device` — confirm.
    dataset = OxfordRobotcarDataset(img_dir=img_data_path,
                                    img_net=img_net,
                                    pcl_dir=pcl_data_path,
                                    pcl_net=pcl_net,
                                    tuple_type='simple',
                                    device=device)

    # 80% / 10% / 10% train / validation / test split
    test_size = int(0.1 * len(dataset))
    train_size = len(dataset) - 2 * test_size
    train_set, val_set, test_set = torch.utils.data.random_split(
        dataset, [train_size, test_size, test_size])

    # Fixed 64-sample subset for the overfitting check. It is drawn from the
    # whole dataset, so it may overlap val_set / test_set.
    small_set, _ = torch.utils.data.random_split(dataset, [64, len(dataset) - 64])
    small_set_loader = torch.utils.data.DataLoader(small_set, batch_size=64)

    train_loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=64, shuffle=False)

    optim = torch.optim.Adam(chain(img_net.parameters(), pcl_net.parameters()), lr=5e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, verbose=True)

    # not used by the MSE objective below; kept for triplet-loss experiments
    tl = TripletLoss(5)
    train_losses_history = []
    val_losses_history = []
    num_epochs = 500
    # train
    for i in range(num_epochs):
        # plain float, not a tensor, so the history can be printed and plotted directly
        train_loss_sum = 0.0
        img_net.train()
        pcl_net.train()
        for img, pos in small_set_loader:
            # we need a batch size of at least 2 to run the sample through PointNet
            if img.size(0) == 1:
                continue
            # Clear gradients every step; otherwise they accumulate across all
            # batches and epochs.
            optim.zero_grad()
            # frozen NetVLAD backbone: only img_net.fc receives gradients
            with torch.no_grad():
                vlad_desc = img_net.vlad.pool(img_net.vlad.encoder(img))
                vlad_desc = vlad_desc.view((vlad_desc.shape[0], img_net.fc.in_features))

            img_desc = img_net.fc(vlad_desc)
            pos_desc, _, _ = pcl_net(pos)
            loss = nn.functional.mse_loss(img_desc, pos_desc, reduction='sum')
            train_loss_sum += loss.item()
            loss.backward()
            optim.step()

        # summed loss averaged per sample of the subset
        train_loss_sum /= len(small_set)
        train_losses_history.append(train_loss_sum)
        print("Loss: {}".format(train_loss_sum))
        scheduler.step(train_loss_sum)
        
            
# Recall@1 on the 64-sample subset used for training above. A sample counts as
# correct when its nearest point-cloud descriptor (index_nn) is the one from the
# same pair.
with torch.no_grad():
    img_net.eval()
    pcl_net.eval()
    pcl_descs = []
    img_descs = []

    for img, pos in small_set_loader:
        img_desc = img_net(img)
        pos_desc, _, _ = pcl_net(pos)
        # split each batch into per-sample descriptor rows (no_grad: nothing to detach)
        img_descs.extend(img_desc.unbind(0))
        pcl_descs.extend(pos_desc.unbind(0))

    num_correct = 0
    for j, img_desc in enumerate(img_descs):
        if index_nn(img_desc, pcl_descs) == j:
            num_correct += 1
    # normalize by the number of evaluated samples (the subset), not len(val_set)
    print("recall@1 train:  {}\n".format(num_correct / max(len(img_descs), 1)))

In [0]:
        # Validation pass over val_set: summed-MSE loss averaged per sample, and
        # recall@1 (the nearest point-cloud descriptor is the matching one).
        # `i` and `train_loss_sum` come from the last epoch of the training loop.
        # That loop already averaged and recorded the training loss, so it is
        # only reported here, not re-normalized or appended again.
        with torch.no_grad():
            val_loss_sum = 0.0
            pcl_descs = []
            img_descs = []

            img_net.eval()
            pcl_net.eval()

            for img, pos in val_loader:
                img_desc = img_net(img)
                pos_desc, _, _ = pcl_net(pos)
                img_descs.extend(img_desc.unbind(0))
                pcl_descs.extend(pos_desc.unbind(0))
                # TODO: switch to TripletLoss once negatives are sampled by the dataset
                loss = nn.functional.mse_loss(img_desc, pos_desc, reduction='sum')
                val_loss_sum += loss.item()

            val_loss_sum /= len(val_set)
            val_losses_history.append(val_loss_sum)

            num_correct = 0
            for j, img_desc in enumerate(img_descs):
                if index_nn(img_desc, pcl_descs) == j:
                    num_correct += 1
            # i is 0-based, so i + 1 is the number of completed epochs
            print("Epoch {}/{}\n".format(i + 1, num_epochs) +
                  "training loss:   {}\n".format(train_loss_sum) +
                  "validation loss: {}\n".format(val_loss_sum) +
                  "recall@1:        {}\n".format(num_correct / max(len(img_descs), 1)))

In [0]:
# Recall@1 on the small training subset (small_set_loader). This is slow
# (index_nn scans every sample once per sample), so it is run here rather
# than during training.
with torch.no_grad():
    img_net.eval()
    pcl_net.eval()
    pcl_descs = []
    img_descs = []

    for img, pos in small_set_loader:
        img_desc = img_net(img)
        pos_desc, _, _ = pcl_net(pos)
        # split each batch into per-sample descriptor rows (no_grad: nothing to detach)
        img_descs.extend(img_desc.unbind(0))
        pcl_descs.extend(pos_desc.unbind(0))

    num_correct = 0
    for j, img_desc in enumerate(img_descs):
        if index_nn(img_desc, pcl_descs) == j:
            num_correct += 1
    # normalize by the number of evaluated samples (the subset), not len(val_set)
    print("recall@1 train:  {}\n".format(num_correct / max(len(img_descs), 1)))

In [0]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import figure

# Training loss per epoch, saved to loss.jpg.
fig, ax = plt.subplots(figsize=(16, 8))
# float() turns 0-d tensors (possibly on the GPU) into plain numbers matplotlib can plot
ax.plot([float(v) for v in train_losses_history])
ax.set(xlabel='epochs', ylabel='loss', title='Training loss')
fig.savefig('loss.jpg')