## Testing the localization network

In [None]:
import torch
from localization.localization import MappingVAE, MappingUnet
from helpers import log, Arguments
from localization.localization_dataset import KittiLocalizationDataset
from torch.utils.data import DataLoader

args = Arguments.get_arguments()
dataset = KittiLocalizationDataset(data_path=args.data_path, sequence="00", simplify=True, simplification_rate=10)

# Per-channel RGB normalization statistics. unsqueeze(0) plus two trailing
# unsqueeze(-1) make them broadcast against image batches (presumably a 1-D
# per-channel tensor becoming (1, C, 1, 1) — TODO confirm the cache layout).
# map_location="cpu" lets a cache that was saved from a GPU session load on a
# CPU-only machine; the tensors are then moved to args.device explicitly.
rgb_mean = torch.load("normalization_cache/rgb_mean.pth", map_location="cpu").unsqueeze(0).unsqueeze(-1).unsqueeze(-1).to(args.device)
rgb_sigma = torch.load("normalization_cache/rgb_std.pth", map_location="cpu").unsqueeze(0).unsqueeze(-1).unsqueeze(-1).to(args.device)

log("Mean: ", rgb_mean.squeeze())
log("Sigma: ", rgb_sigma.squeeze())

# MappingUnet is the active architecture; swap in MappingVAE to test the VAE variant.
# The model is only moved to args.device later (the training cell does `.to(args.device)`).
#model = MappingVAE()
model = MappingUnet()
#model.load_state_dict(torch.load("localization/MappingVAE_last_0_0.pth", map_location="cpu"))
#model = model.to(args.device)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
log("Trainable parameters:", trainable_params)

In [None]:
from helpers import BetaScheduler
from torchvision import transforms
from localization.localization_losses import VAE_loss
from torchvision.transforms import Resize
import torch.nn.functional as F

num_epochs = 10
batch_size = 8

model = model.train().to(args.device)
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# One cosine cycle over the whole run: T_max counts batches because the
# scheduler is stepped once per batch in the inner loop below.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs*len(dataloader), eta_min=1e-5)
# Only consumed by the (currently disabled) VAE_loss path below.
beta_schedule = BetaScheduler(len(dataloader))

# NOTE(review): a ColorJitter augmentation used to be computed here every batch
# but its result was never used. If it is re-enabled, apply it to im/255:
# ColorJitter expects float images in [0, 1], while `im` here is on the raw
# 0-255 scale that rgb_mean/rgb_sigma normalize.

for epoch in range(num_epochs):
    print("-------------------- Epoch", epoch+1, "/", num_epochs, " --------------------")

    beta_schedule.reset()
    for batch, (im, _, _) in enumerate(dataloader):

        optimizer.zero_grad()
        im = im.to(args.device)
        im_normalized = (im-rgb_mean)/rgb_sigma

        #mu, logvar, latent, im_pred = model(im_normalized, VAE=True)
        mu, logvar, latent, im_pred = model(im_normalized)

        #loss, kl_loss, reconstruction_loss = VAE_loss(im_normalized, im_pred, mu.flatten(1), logvar.flatten(1), beta=beta_schedule.step())
        # Reconstruction term in normalized-pixel space.
        loss1 = F.mse_loss(im_pred, im_normalized)
        # Pushes the L2 norm of exp(0.5*logvar), taken over the WHOLE batch
        # tensor, towards 1. NOTE(review): because the norm spans the batch,
        # this term's scale depends on batch_size — confirm that is intended.
        loss2 = torch.abs(torch.norm(torch.exp(0.5*logvar), p=2)-1.0)
        loss = loss1 + loss2

        loss.backward()
        optimizer.step()
        scheduler.step()

        print("Iteration: ", batch, "/", len(dataloader), "\t\t Loss: ", loss.item(), "\t LR: ", scheduler.get_last_lr())

    # Overwrites the same checkpoint every epoch. NOTE(review): the file name says
    # MappingVAE even when MappingUnet is trained; kept for compatibility with loaders.
    model_name = "localization/MappingVAE_last_0_0.pth"
    log("Saving model as ", model_name)
    torch.save(model.state_dict(), model_name)

### Reconstruction check

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl

mpl.rcParams['figure.dpi'] = 120

# Run the check in inference mode. If MappingUnet contains BatchNorm/Dropout
# (not visible here — TODO confirm), train mode would normalize with this single
# sample's statistics, update the running stats and apply dropout.
# The training cell switches back with model.train().
model.eval()

im, _, _ = dataset[100]
plt.imshow(im.permute(1, 2, 0).byte().cpu().detach().numpy())
plt.show()

# no_grad: this is a visual check, so no autograd graph is needed.
with torch.no_grad():
    im_normalized = (im.to(args.device)-rgb_mean)/rgb_sigma
    #mu, logvar, latent, im_pred = model(im_normalized, VAE=True)
    mu, logvar, latent, im_pred = model(im_normalized)

# Undo the per-channel normalization, then clip to the valid 8-bit range before
# the uint8 cast. clamp works on any device, unlike comparing with CPU scalar tensors.
pred_back = (im_pred*rgb_sigma)+rgb_mean
pred_back = pred_back.clamp(0.0, 255.0)
pred_back = pred_back.cpu().squeeze().byte().permute(1, 2, 0).numpy()
plt.imshow(pred_back)
plt.show()
plt.imsave("Unet_sigmaloss.png", pred_back)

## Relocalization capability testing code

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl
from helpers import log
import numpy as np

# Relocalization test: encode one query frame from a dataset subsampled at
# simplification_rate_1, encode every frame of a dataset subsampled at
# simplification_rate_2, and take the frame whose `mu` output is closest (L2)
# to the query's as the predicted location.
latent_space_vectors = []
model.to(args.device).eval()

simplification_rate_1=7
simplification_rate_2=9

dataset = KittiLocalizationDataset(data_path=args.data_path, sequence="00", simplify=True, simplification_rate=simplification_rate_1)

random_index = 74 #int(torch.randint(low=0, high=len(dataset), size=(1, 1)).squeeze())
# Presumably the frame index in the full sequence (every k-th frame kept) — TODO confirm.
log("True index: ", random_index*simplification_rate_1)

with torch.no_grad():

    true_im, _, _ = dataset[random_index]
    im_normalized = (true_im.to(args.device)-rgb_mean)/rgb_sigma
    #true_mu, logvar, true_latent, prediction = model(im_normalized, VAE=True)
    true_mu, logvar, true_latent, prediction = model(im_normalized)

    # NOTE: this rebinds the notebook-global `dataset` (the other cells use it too);
    # the display code after this cell relies on it indexing the second dataset.
    dataset = KittiLocalizationDataset(data_path=args.data_path, sequence="00", simplify=True, simplification_rate=simplification_rate_2)

    for i in range(len(dataset)):
        test_im, _, _ = dataset[i]
        im_normalized = (test_im.to(args.device)-rgb_mean)/rgb_sigma
        #mu, logvar, latent, im_pred = model(im_normalized, VAE=True)
        mu, logvar, latent, im_pred = model(im_normalized)
        latent_space_vectors.append(mu)

    # L2 distance of every database vector to the query, computed in one pass:
    # stack to (N, *mu.shape), subtract (broadcast), flatten all non-leading dims.
    stacked = torch.stack(latent_space_vectors, dim=0)
    distances = torch.linalg.vector_norm((stacked - true_mu).flatten(start_dim=1), ord=2, dim=1).cpu()

bins = 1000
hist, bin_edges = np.histogram(distances.numpy(), bins=bins)

# ---------
# Histogram
# ---------

# Each bar spans exactly its bin [edge_k, edge_{k+1}); the default width=0.8
# (in data units) with centre alignment would overlap and shift the bars.
plt.bar(bin_edges[:-1], hist, width=np.diff(bin_edges), align="edge")
plt.xlabel("Distance from sample")
plt.ylabel("Count of elements")
plt.show()

# Predicted index
distances_mean = distances.mean()
pred_index = torch.argmin(distances)

log("Pred index", pred_index*simplification_rate_2)

plt.plot(distances.numpy())
plt.xlabel("Index of latent space vector")
plt.ylabel("Distance from sample")
plt.show()

# How far the best match stands out from the average distance.
mean_distance1 = (distances[pred_index]-distances_mean).abs()
log("Mean: ", distances_mean)
log("Pred-mean:",  mean_distance1)
def prepare_im(im):
    """Convert an image tensor into an ``(H, W, C)`` uint8 NumPy array for ``plt.imshow``.

    The tensor is detached and moved to the CPU, so CUDA and grad-tracking
    inputs also work. It is then clipped to ``[0, 255]`` and cast to uint8
    (fractional parts are truncated). All size-1 dims are squeezed away and
    the channel axis is moved last.

    Args:
        im: image tensor, presumably on the 0-255 scale with shape ``(C, H, W)``
            or ``(1, C, H, W)`` — TODO confirm against the dataset. ``squeeze``
            also drops H or W if either is 1, which would break the permute.

    Returns:
        numpy.ndarray of dtype uint8 and shape ``(H, W, C)``.
    """
    return im.detach().cpu().clamp(0, 255).byte().squeeze().permute(1, 2, 0).numpy()


# Fetch the best-matching frame from the search dataset, then show the query
# frame followed by its match.
pred_im, _, _ = dataset[int(pred_index.squeeze())]

for shown_im in (true_im, pred_im):
    plt.imshow(prepare_im(shown_im))
    plt.show()