In [None]:
import torch
from AdaBins.models.unet_adaptive_bins import UnetAdaptiveBins
from AdaBins import model_io
import matplotlib as mpl
# Render inline figures at a higher resolution than the matplotlib default.
mpl.rcParams['figure.dpi'] = 150

# Configuration handed to the AdaBins model: number of adaptive bins and the
# min/max values of the depth range.
N_BINS = 256
MIN_DEPTH = 1e-3
MAX_DEPTH_KITTI = 80

pretrained_path = "AdaBins/AdaBins_kitti.pt"

# Build the network, then pass it through the checkpoint loader.
# load_checkpoint returns three values; only the first (the model) is kept.
model = UnetAdaptiveBins.build(
    n_bins=N_BINS,
    min_val=MIN_DEPTH,
    max_val=MAX_DEPTH_KITTI,
)
model, _, _ = model_io.load_checkpoint(pretrained_path, model)

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np

base_path = None # TODO Change to config file data read

# os.listdir returns entries in arbitrary order; sorting makes index `i`
# select the same image on every run and on every machine.
im_files = sorted(os.listdir(os.path.join(base_path, "image_2")))
i = 20

# --- RGB image -------------------------------------------------------------
test_im = plt.imread(os.path.join(base_path,"image_2", im_files[i]))
plt.imshow(test_im)
plt.show()

# --- Semantic image with the same file name --------------------------------
test_im = plt.imread(os.path.join(base_path,"semantic", im_files[i]))

# NOTE(review): the scale factor 200 and the label value 18 used below are
# unexplained magic numbers -- confirm against the semantic label encoding.
test_im = (test_im*200).astype(int)

# Histogram of the scaled label values (NumPy default: 10 equal-width bins).
hist, x = np.histogram(test_im)

print("Test im max: ", test_im.max())
print("Hist: ", hist.shape)
print("X shape: ", x.shape)
# Draw each bar over its whole bin instead of the default width of 0.8.
plt.bar(x[:-1], hist, width=x[1] - x[0], align="edge")
plt.show()

# Binary mask: 1 where the scaled label equals 18, 0 elsewhere.
test_im = np.where(test_im==18, 1, 0)

print(test_im.sum())
plt.imshow(test_im)
plt.show()

In [None]:
# Reload the RGB image selected by index `i`.
test_im = plt.imread(os.path.join(base_path, "image_2", im_files[i]))

# Cached normalisation statistics with two trailing singleton axes appended.
# They are only needed by the normalisation step that is disabled below.
mean = torch.load("normalization_cache/rgb_mean.pth")[..., None, None]
sigma = torch.load("normalization_cache/rgb_std.pth")[..., None, None]

# Move the last image axis to the front, then add a leading batch axis on the GPU.
# (Normalisation is currently switched off: rgb_tensor = (rgb_tensor-mean)/sigma)
rgb_tensor = torch.tensor(test_im).permute(2, 0, 1).to("cuda").unsqueeze(0)

model = model.to("cuda")
model.eval()

# Inference only, so gradient tracking is disabled.
with torch.no_grad():
    bin_edges, predicted_depth = model(rgb_tensor)

In [None]:
from torchvision.transforms import Resize
# Resize the prediction to the resolution of the image that was actually
# loaded instead of a fixed (375, 1242): a hard-coded size only works when
# every image has exactly that resolution, and the later fusion with the
# semantic mask (np.where) requires matching shapes.
resize = Resize(test_im.shape[:2])

# Keep the raw network output in `predicted_depth` untouched so this cell can
# be re-run without changing its own input.
predicted_depth_resized = resize(predicted_depth)
print(predicted_depth_resized.shape)
plt.imshow(predicted_depth_resized.squeeze().cpu().detach().numpy())
plt.show()

In [None]:
# Resized prediction as a NumPy array with singleton axes removed.
depth_resized = resize(predicted_depth)
depth = depth_resized.squeeze().cpu().numpy()
print(depth.shape)

# Semantic image for the same file, scaled by 200 and truncated to integers.
semantic_path = os.path.join(base_path, "semantic", im_files[i])
test_im = (200 * plt.imread(semantic_path)).astype(int)

# Wherever the scaled label equals 18 the depth is replaced by the maximum
# depth; all other pixels keep the predicted value.
is_label_18 = test_im == 18
fused_depth = np.where(is_label_18, MAX_DEPTH_KITTI, depth)
plt.imshow(fused_depth)
plt.show()

In [None]:
# Wrap the fused depth map in a tensor and prepend two singleton axes.
fused_tensor = torch.tensor(fused_depth)[None, None]
print(fused_tensor.shape)

In [None]:
from torch import nn

class SILogLoss(nn.Module):  # Main loss function used in AdaBins paper
    """Scale-invariant log loss between a prediction and a target.

    With ``g = log(input) - log(target)`` the loss is
    ``10 * sqrt(var(g) + 0.15 * mean(g) ** 2)``.
    """

    def __init__(self):
        super().__init__()
        self.name = 'SILog'

    def forward(self, input, target, mask=None, interpolate=True):
        """Compute the loss.

        Args:
            input: Predicted values; must be positive for the logarithm.
            target: Reference values; must be positive for the logarithm.
            mask: Optional boolean index applied to both ``input`` and
                ``target`` before the loss is computed.
            interpolate: When true, ``input`` is first bilinearly resized
                (``align_corners=True``) to the last two dimensions of
                ``target``.

        Returns:
            A scalar tensor holding the loss value.
        """
        if interpolate:
            input = nn.functional.interpolate(
                input, size=target.shape[-2:], mode='bilinear', align_corners=True
            )

        if mask is not None:
            input, target = input[mask], target[mask]

        log_diff = input.log() - target.log()

        # Variance of the log difference plus a down-weighted squared mean.
        variance_term = log_diff.var()
        squared_mean = log_diff.mean().pow(2)
        return 10 * (variance_term + 0.15 * squared_mean).sqrt()


# Fine-tune the depth model on a single image, using the fused depth map as
# the target.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
model = model.train().to("cuda")

criterion_ueff = SILogLoss()

# `i` is the image index chosen earlier in the notebook.
img = torch.tensor(plt.imread(os.path.join(base_path,"image_2", im_files[i]))).permute(2, 0, 1).unsqueeze(0).to("cuda")
print(img.shape)
fused_tensor = fused_tensor.to("cuda")

# The target never changes during the loop, so the validity mask is computed
# once. The comparison already yields a boolean tensor.
mask = fused_tensor > MIN_DEPTH

# The loop variable is deliberately NOT called `i`: `i` is the image index,
# and overwriting it here would make later cells load a different image
# (index 99) than the one this loop trains on.
for step in range(100):
    optimizer.zero_grad()

    bin_edges, pred = model(img)

    l_dense = criterion_ueff(pred, fused_tensor, mask=mask, interpolate=True)

    l_dense.backward()
    optimizer.step()

    print(l_dense.item())

In [None]:
# Run the model again on the image selected by `i` and show the result.
# This relies on `i` still holding the intended image index at this point.
image_path = os.path.join(base_path, "image_2", im_files[i])
test_im = plt.imread(image_path)

# Move the last image axis to the front, then add a batch axis on the GPU.
rgb_tensor = torch.tensor(test_im).permute(2, 0, 1)
rgb_tensor = rgb_tensor.to("cuda").unsqueeze(0)

model = model.to("cuda")
model.eval()

with torch.no_grad():
    bin_edges, predicted_depth = model(rgb_tensor)

resize = Resize((375, 1242))

# Resize once and reuse the result for both the shape print-out and the plot.
resized_depth = resize(predicted_depth)
print(resized_depth.shape)
plt.imshow(resized_depth.squeeze().cpu().detach().numpy())

In [None]:
from localization.localization_dataset import KittiUnetDataset
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 150

# Placeholder: must point at the dataset root before this cell can run.
DATA_PATH = None # TODO Change to config file data read
# Dataset for sequence "00".
dataset = KittiUnetDataset(data_path=DATA_PATH, sequence="00")
# The sample at index 10 unpacks into four items; only the first two are kept.
# NOTE(review): presumably im_true is the original image and im_rand an
# augmented variant of it -- confirm against KittiUnetDataset.
im_true, im_rand, _, _ = dataset[10]

def convert(im):
    """Turn an image tensor into a uint8 NumPy array with the first axis last.

    Singleton axes are removed first, then the three remaining axes are
    reordered from (0, 1, 2) to (1, 2, 0), values are cast to uint8, and the
    result is moved to the CPU.
    """
    squeezed = im.squeeze()
    channels_last = squeezed.permute(1, 2, 0)
    as_bytes = channels_last.byte()
    return as_bytes.cpu().numpy()

def imshow_torch(im):
    """Display an image tensor with matplotlib after passing it through ``convert``."""
    plt.imshow(convert(im))
    plt.show()


# Show the two images of the loaded sample, one figure each.
for sample_im in (im_true, im_rand):
    imshow_torch(sample_im)

In [None]:
import torch
from localization.localization import MappingUnet
from helpers import log
from localization.localization_dataset import KittiLocalizationDataset, KittiUnetDataset
from torch.utils.data import DataLoader
from torchvision import transforms
from localization.localization_losses import VAE_loss
from helpers import BetaScheduler
from torch.nn import functional as F


# ---------------------------------------------------------------------------
# Train MappingUnet on sequence "00" with a VAE-style loss.
# ---------------------------------------------------------------------------
DEVICE = "cuda"
DATA_PATH = None # TODO Change to config file data read
dataset = KittiLocalizationDataset(data_path=DATA_PATH, sequence="00", simplify=True, simplification_rate=5)
#dataset = KittiUnetDataset(data_path=DATA_PATH, sequence="00")

# Cached normalisation statistics, given one leading and two trailing
# singleton axes so they broadcast against batched image tensors.
rgb_mean = torch.load("normalization_cache/rgb_mean.pth").unsqueeze(0).unsqueeze(-1).unsqueeze(-1).to(DEVICE)
rgb_sigma = torch.load("normalization_cache/rgb_std.pth").unsqueeze(0).unsqueeze(-1).unsqueeze(-1).to(DEVICE)

log("Mean shape: ", rgb_mean)
log("Sigma shape? ", rgb_sigma)

model = MappingUnet()
#print(model)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
log("Trainable parameters:", trainable_params)

num_epochs = 10
batch_size = 8

model = model.train().to(DEVICE)

dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True)
beta_schedule = BetaScheduler(len(dataloader))

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
# Cosine decay spread over every optimisation step of the whole run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs*len(dataloader), eta_min=1e-6)


aug = transforms.Compose([
    transforms.ColorJitter(brightness=0.1, saturation=0.1, hue=1e-6),
])

# The checkpoint path is the same for every epoch; each save overwrites it.
model_name = "localization/MappingUnet_last_0_0.pth"

# Every beta handed to the loss, in order; saved to "betas.pth" after training.
betas = []
for epoch in range(num_epochs):
    print("-------------------- Epoch", epoch+1, "/", num_epochs, " --------------------")
    beta_schedule.reset()
    for batch, (im_true, _, _) in enumerate(dataloader):

        optimizer.zero_grad()

        im_true = im_true.to(DEVICE)
        # NOTE: im_rand (colour-jittered copy) is computed but not used while
        # the alternative loss below stays commented out. It is kept so the
        # random-number stream matches previous runs.
        im_rand = aug(im_true.byte()).float()
        im_rand = (im_rand-rgb_mean)/rgb_sigma
        im_true = (im_true-rgb_mean)/rgb_sigma

        im_pred, latent, latent_mu, latent_logvar = model(im_true)

        # Record the beta of this step. Previously `betas` was never filled,
        # so the file written after training was always empty.
        # Assumes BetaScheduler.step() returns a scalar -- TODO confirm.
        beta = beta_schedule.step()
        betas.append(float(beta))

        #loss = F.mse_loss(im_pred, im_rand)
        loss, kl_loss, recon_loss = VAE_loss(im_true, im_pred, latent_mu, latent_logvar, beta=beta)

        loss.backward()
        optimizer.step()
        scheduler.step()

        print("Iteration: ", batch, "/", len(dataloader), "\t\t Loss: ", loss.item(), "\tKLD loss: ", kl_loss.item(), "\tReconstruction loss: ", recon_loss.item(), "\t LR: ", scheduler.get_last_lr())

    # Checkpoint after every epoch.
    log("Saving model as ", model_name)
    torch.save(model.state_dict(), model_name)

torch.save(torch.tensor(betas), "betas.pth")

In [None]:
from localization.localization import MappingUnet
import torch

# Fresh network instance; its weights are replaced from the saved state dict below.
model = MappingUnet()

# Normalisation statistics with one leading and two trailing singleton axes.
rgb_mean = torch.load("normalization_cache/rgb_mean.pth")[None, ..., None, None].to(DEVICE)
rgb_sigma = torch.load("normalization_cache/rgb_std.pth")[None, ..., None, None].to(DEVICE)

# Kept as the last expression so the notebook shows the key-matching report.
model.load_state_dict(
    torch.load("localization/MappingUnet_last_0_0.pth")
)

In [None]:
# Three-layer MLP mapping a 1024-dimensional input to a single scalar
# (1024 -> 256 -> 64 -> 1), with PReLU activations and 20% dropout after the
# first two linear layers.
IndexRegressor = nn.Sequential(
    nn.Linear(1024, 256),
    nn.PReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(256, 64),
    nn.PReLU(),
    nn.Dropout(p=0.2),
    nn.Linear(64, 1),
)

In [None]:
from localization.localization_dataset import KittiLocalizationDataset

import matplotlib.pyplot as plt
import matplotlib as mpl
from helpers import log
import numpy as np

# ---------------------------------------------------------------------------
# Latent-space retrieval: encode one query image, encode every image of a
# second (differently subsampled) dataset, and pick the entry whose latent
# mean is closest to the query in L2 distance.
# ---------------------------------------------------------------------------
DEVICE = "cuda"

mpl.rcParams['figure.dpi'] = 150

latent_space_vectors = []
model.to(DEVICE).eval()

DATA_PATH = None # TODO Change to config file data read
dataset = KittiLocalizationDataset(data_path=DATA_PATH, sequence="00", simplify=True, simplification_rate=8)
log("Dataset shape: ", len(dataset))

# Query sample, drawn at random from the rate-8 dataset.
random_index = int(torch.randint(low=0, high=len(dataset), size=(1, 1)).squeeze())
log("Index: ", random_index)

with torch.no_grad():

    true_im, _, _ = dataset[random_index]
    # Use DEVICE (not a hard-coded "cuda") so the image lands on the same
    # device as rgb_mean / rgb_sigma.
    im_normalized = (true_im.to(DEVICE)-rgb_mean)/rgb_sigma
    prediction, test_latent, test_mu, _ = model(im_normalized)

    # Database to search: same sequence, subsampled at a different rate.
    dataset = KittiLocalizationDataset(data_path=DATA_PATH, sequence="00", simplify=True, simplification_rate=7)
    log("Dataset shape: ", len(dataset))

    for sample_idx in range(len(dataset)):
        test_im, _, _ = dataset[sample_idx]
        im_normalized = (test_im.to(DEVICE)-rgb_mean)/rgb_sigma
        im_pred, latent, mu, _ = model(im_normalized)
        # The latent mean is stored, not the sampled latent vector.
        latent_space_vectors.append(mu)

    # L2 distance between every stored latent mean and the query's latent mean.
    distances = [
        torch.norm(candidate_mu - test_mu, p=2).detach().to('cpu')
        for candidate_mu in latent_space_vectors
    ]

distances = torch.stack(distances, dim=0)
log("Distances shape", distances.shape)

plt.plot(distances.numpy())
plt.xlabel("Index of latent space vector")
plt.ylabel("Distance from sample")
plt.show()

max_dist = torch.max(distances)
bins = 1000
[hist, bins] = np.histogram(distances.numpy(), bins=bins)

# ---------
# Histogram
# ---------
# Bars are drawn with the actual bin width and aligned to the left bin edge.
# With the default width of 0.8 the 1000 narrow bins would overlap heavily.
plt.bar(bins[:-1], hist, width=bins[1] - bins[0], align="edge")
plt.xlabel("Distance from sample")
plt.ylabel("Count of elements")
plt.show()

# Predicted index
pred_index = torch.argmin(distances)
log("Pred index", pred_index)

def prepare_im(im):
    """Convert an image tensor into a uint8 NumPy array for ``plt.imshow``.

    The tensor is detached from the autograd graph, cast to uint8, stripped
    of singleton axes, and its three remaining axes are reordered from
    (0, 1, 2) to (1, 2, 0).

    The tensor is moved to the CPU before the NumPy conversion, because
    ``Tensor.numpy()`` only works on CPU tensors. CPU inputs behave exactly
    as before.
    """
    return im.detach().byte().squeeze().permute(1, 2, 0).cpu().numpy()

# Fetch the database sample with the smallest latent distance to the query.
best_match = int(pred_index.squeeze())
pred_im, _, _ = dataset[best_match]

# Show the query image first, then the retrieved one.
for shown_im in (true_im, pred_im):
    plt.imshow(prepare_im(shown_im))
    plt.show()