# LocalINN Implementation: Training and Testing Pipeline

## 1.1. Install and Import Dependencies

We will need several Python libraries for data processing and model creation. It's better to check beforehand that all of them are properly installed and imported, to avoid unnecessary errors later.

In [None]:
'''
  Check if the libraries are installed properly

'''

!pip install torch   # Allow us to build our neural network. Chosen PyTorch for this task.
!pip install gdown
!pip install numpy





In [None]:
'''
  Install ROS2 support packages. (Not Required: Colab not supported)

'''

# !sudo apt install ros-humble-rosbag2* ros-humble-tf2* python3-pandas python3-numpy

'\n  Install ROS2 support packages. (Not Required: Colab not supported)\n\n'

In [None]:
'''
  Import the required dependencies.

'''

import os
import torch
import numpy as np
from typing import List, Tuple
import gdown

In [None]:
'''
  Check if the libraries are properly imported.

'''


print("NumPy version: ", np.__version__)

NumPy version:  2.0.2


In [None]:
'''
  Check if GPU is available in the system. We will shift the model to GPU for faster training.

'''

# Query CUDA once and report what was found.
gpu_aval = torch.cuda.is_available()
print(f"GPU available: {gpu_aval}")
if gpu_aval:
    print(f"GPU Name: {torch.cuda.get_device_name(0)}")

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

GPU available: False
Using device: cpu


## 1.2. Build Data Loading Functions

In [None]:
'''
  Download the data for the lidar sensors and pose values to feed in the local INN.

'''


# Remote locations (Google Drive) and local file names for the two arrays.
lidar_url = 'https://drive.google.com/uc?export=download&id=1BXJkxE3FpkupMBrWtIYoKkI4x2-d4i9z'
lidar_output_file = 'lidar.npy'
pose_url = 'https://drive.google.com/uc?export=download&id=1ipTKsUAQ-YY7s-TKOZ6KYfjw28Mv_oyv'
pose_output_file = 'pose.npy'

# Fetch the lidar scans first, then the poses, showing download progress.
gdown.download(lidar_url, lidar_output_file, quiet=False)
gdown.download(pose_url, pose_output_file, quiet=False)


Downloading...
From: https://drive.google.com/uc?export=download&id=1BXJkxE3FpkupMBrWtIYoKkI4x2-d4i9z
To: /content/lidar.npy
100%|██████████| 15.5M/15.5M [00:00<00:00, 41.8MB/s]
Downloading...
From: https://drive.google.com/uc?export=download&id=1ipTKsUAQ-YY7s-TKOZ6KYfjw28Mv_oyv
To: /content/pose.npy
100%|██████████| 43.1k/43.1k [00:00<00:00, 68.8MB/s]


'pose.npy'

In [None]:
# The downloaded files hold plain float32 arrays, so pickle support is not
# needed. Keeping it disabled prevents np.load from unpickling arbitrary
# objects out of a file that was just fetched from the network.
lidar_data = np.load('lidar.npy', allow_pickle=False)
poses_data = np.load('pose.npy', allow_pickle=False)

# Every scan must have exactly one pose; fail early instead of mis-pairing
# samples later in the dataset.
if lidar_data.shape[0] != poses_data.shape[0]:
    raise ValueError(
        f"Sample count mismatch: {lidar_data.shape[0]} scans vs {poses_data.shape[0]} poses"
    )

print("LIDAR shape:", lidar_data.shape)
print("POSES shape:", poses_data.shape)

LIDAR shape: (3579, 1081)
POSES shape: (3579, 3)


In [None]:
# Show long tensors in full and with 8 decimal places when inspecting data.
torch.set_printoptions(precision=8, threshold=10000, edgeitems=50)

In [None]:
lidar_data[0]   # Lidar Input data

array([0.9538596, 0.9498828, 0.9459569, ..., 0.8936209, 0.8968134,
       0.9000462], dtype=float32)

In [None]:
poses_data[0]   # Pose data x, y, z

array([0.7522, 7.1221, 0.0592], dtype=float32)

In [None]:
lidar_data.shape

(3579, 1081)

In [None]:
import random
from torch.utils.data import Dataset, DataLoader
import torch

'''
  PyTorch Dataset class to fetch items and other operation requirements.

'''

class LocalINNDataset(Dataset):
    """Pairs each lidar scan with its pose and the pose of the preceding sample.

    Args:
        lidar: indexable array of scans; its first axis is the sample axis.
        poses: indexable array of poses, indexed in step with ``lidar``.
    """

    def __init__(self, lidar, poses):
        # f_* are the scan -> pose pairing, r_* the same arrays with roles
        # swapped (presumably "forward"/"reverse"; only f_* are read below).
        self.f_in, self.f_out = lidar, poses
        self.r_in, self.r_out = poses, lidar

    def __len__(self):
        num_samples = self.f_in.shape[0]
        return num_samples

    def __getitem__(self, idx):
        """Return ``(scan, pose, previous_pose)`` for sample ``idx``."""
        # The first sample has no predecessor, so it reuses its own pose.
        previous_idx = max(idx - 1, 0)
        return self.f_in[idx], self.f_out[idx], self.f_out[previous_idx]

In [None]:
# Wrap the raw arrays and work out an 80/20 train/test split.
full_dataset = LocalINNDataset(lidar_data, poses_data)

test_size = len(full_dataset) - int(0.8 * len(full_dataset))
train_size = len(full_dataset) - test_size

In [None]:
train_size

2863

In [None]:
from torch.utils.data import random_split

# Create the train and test dataloaders to facilitate the training and
# testing process.

full_dataset = LocalINNDataset(lidar_data, poses_data)

# 80% of the samples go to training; whatever is left is used for testing.
train_size = int(0.8 * len(full_dataset))
test_size = len(full_dataset) - train_size
train_dataset, test_dataset = random_split(full_dataset, [train_size, test_size])

# Training batches are reshuffled every epoch; test batches keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=2)

print(f"Total samples: {len(full_dataset)}")
print(f"Training samples: {len(train_dataset)}")
print(f"Testing samples: {len(test_dataset)}")

Total samples: 3579
Training samples: 2863
Testing samples: 716


In [None]:
len(test_loader)

179

In [None]:
'''
  Creating an iterator from train_loader dataloader

'''

data_iterator = iter(train_loader)

i_lidar, c_pose, p_pose = next(data_iterator)

print(f"Lidar tensor shape: {i_lidar.shape}")
print(f"Lidar tensor dtype: {i_lidar.dtype}")
print(f"Current Poses tensor shape: {c_pose.shape}")
print(f"Current Poses tensor dtype: {c_pose.dtype}")
print(f"Previous Poses tensor shape: {p_pose.shape}")
print(f"Previous Poses tensor dtype: {p_pose.dtype}")

Lidar tensor shape: torch.Size([2, 1081])
Lidar tensor dtype: torch.float32
Current Poses tensor shape: torch.Size([2, 3])
Current Poses tensor dtype: torch.float32
Previous Poses tensor shape: torch.Size([2, 3])
Previous Poses tensor dtype: torch.float32


In [None]:
sample_iterator = iter(train_loader)

In [None]:
l, cp, pp = next(sample_iterator)

In [None]:
l

tensor([[1.09908402, 1.09175801, 1.08455002, 1.07745600, 1.07047606, 1.06360400,
         1.05684102, 1.05018306, 1.04362798, 1.03717399, 1.03081799, 1.02455997,
         1.01839602, 1.01232600, 1.00634599, 1.00045598, 0.99465239, 0.98893511,
         0.98330128, 0.97775048, 0.97227979, 0.96688819, 0.96157491, 0.95633709,
         0.95117468, 0.94608498, 0.94106787, 0.93612111, 0.93124408, 0.92643458,
         0.92169207, 0.91701579, 0.91240329, 0.90785491, 0.90336812, 0.89894289,
         0.89457750, 0.89027143, 0.88602293, 0.88183171, 0.87769681, 0.87361681,
         0.86959130, 0.86561888, 0.86169922, 0.85783082, 0.85401362, 0.85024571,
         0.84652722, 0.84285730, 0.83923471, 0.83565933, 0.83212972, 0.82864583,
         0.82520610, 0.82181102, 0.81845880, 0.81514949, 0.81188220, 0.80865622,
         0.80547118, 0.80232620, 0.79922122, 0.79615468, 0.79312718, 0.79013729,
         0.78718501, 0.78426939, 0.78139031, 0.77854681, 0.77573848, 0.77296549,
         0.77022672, 0.76752

In [None]:
cp

tensor([[ 0.09060000, -7.44469976,  0.05800000],
        [-1.09850001,  4.03490019,  0.05730000]])

In [None]:
pp

tensor([[ 0.10220000, -7.42059994,  0.05890000],
        [-1.10450006,  3.94330001,  0.05790000]])

## 1.3. Neural Network Design

In [None]:
'''
  Importing required models from PyTorch for the neural network creation.

'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import math

In [None]:
"""
    Implements Eq (4) from the PDF that contains positional encoders.
    Expanding input dimension D to D * 2L, where L is a tunable parameter

"""

class PositionalEncoding(nn.Module):
    """Sin/cos positional encoding over ``L`` octave-spaced frequencies.

    A value ``x`` is mapped to ``sin(2**l * pi * x)`` and ``cos(2**l * pi * x)``
    for ``l = 0 .. L-1``. The "even" variant doubles every frequency.

    Args:
        L: number of frequencies per encoded value.
        device: device on which the frequency buffers are created.
        dtype: dtype of the frequency buffers.
    """

    def __init__(self, L, device=None, dtype=torch.float32):
        super().__init__()
        self.L = L
        vals = torch.tensor([2.0 ** l for l in range(L)], dtype=dtype, device=device)
        # Buffers follow the module across devices and are stored in state_dict.
        self.register_buffer("val_list", vals)
        self.register_buffer("pi", torch.tensor(torch.pi, dtype=dtype, device=device))

    def _phase(self, x):
        """Return ``x * 2**l * pi`` with a new trailing axis of length ``L``."""
        vals = self.val_list.view(*(1 for _ in range(x.dim())), self.L)
        return x.unsqueeze(-1) * vals * self.pi

    def encode(self, x):
        """Encode ``x`` (any shape); returns ``(sin, cos)`` of shape ``x.shape + (L,)``."""
        arg = self._phase(x)
        return torch.sin(arg), torch.cos(arg)

    def encode_even(self, x):
        """Same as :meth:`encode` but with every frequency doubled."""
        arg = self._phase(x) * 2
        return torch.sin(arg), torch.cos(arg)

    def decode_from_encoded_list(self, encoded_list, even_dims=(2,), max_shift=12):
        """Recover the encoded values from their sin/cos components.

        The lowest frequency fixes the value unambiguously inside one base
        period; each higher frequency then refines that estimate by snapping
        to the candidate of its own (shorter) period closest to the current
        estimate.

        The encoding is periodic, so a value can only be recovered modulo the
        base period: 2 for regular components, 1 for ``even_dims`` components.
        The result is the representative in the principal interval around
        zero (``[-1, 1]`` resp. ``[-0.5, 0.5]``).

        Args:
            encoded_list: sequence (or stacked tensor) laid out as
                ``[sin_0, cos_0, sin_1, cos_1, ...]``, each entry of shape ``(L,)``.
            even_dims: indices of components that were produced by
                :meth:`encode_even`.
            max_shift: kept for backward compatibility and ignored. Shifting
                by whole base periods yields identical sin/cos values at every
                frequency, so a search over such shifts cannot prefer any
                candidate and only made the result depend on rounding noise.

        Returns:
            Tensor of shape ``(len(encoded_list) // 2,)`` with the decoded values.
        """
        first = encoded_list[0]
        device, dtype = first.device, first.dtype
        num_freqs = first.shape[-1]
        num_dims = len(encoded_list) // 2

        sin = torch.stack([encoded_list[2 * d] for d in range(num_dims)], dim=0)      # (D, L)
        cos = torch.stack([encoded_list[2 * d + 1] for d in range(num_dims)], dim=0)  # (D, L)

        # Per-component frequency multiplier: 2 for "even" components, else 1.
        mult = torch.tensor(
            [2.0 if d in even_dims else 1.0 for d in range(num_dims)],
            device=device,
            dtype=dtype,
        )  # (D,)
        octaves = 2.0 ** torch.arange(num_freqs, device=device, dtype=dtype)  # (L,)
        freqs = octaves.view(1, num_freqs) * torch.pi * mult.view(num_dims, 1)  # (D, L)

        # Value implied by each frequency on its own, known modulo that
        # frequency's period.
        x_pf = torch.atan2(sin, cos) / freqs  # (D, L)

        # Coarse-to-fine unwrapping, starting from the unambiguous lowest frequency.
        estimate = x_pf[:, 0].clone()
        for i in range(1, num_freqs):
            period_i = (2.0 ** (1 - i)) / mult  # (D,)
            k = torch.round((estimate - x_pf[:, i]) / period_i)
            estimate = x_pf[:, i] + k * period_i
        return estimate

In [None]:
iterator = iter(test_loader)

In [None]:
## encode the prev_state
p_encoding = PositionalEncoding(L=10)

for i in range(8):
    lidar, pose, prev_pose = next(iterator)
    print(pose[0])
    # Only the first pose of the batch is used, scaled down by 0.01.
    pose = pose[0] * 0.01

    # Layout is [sin, cos] per component; component 2 uses the doubled-frequency encoder.
    encoded = []
    for k in range(3):
        encode_fn = p_encoding.encode_even if k == 2 else p_encoding.encode
        sine_part, cosine_part = encode_fn(pose[k])
        encoded += [sine_part, cosine_part]

    # Round-trip check: decoding should give back the scaled pose.
    decoded = p_encoding.decode_from_encoded_list(encoded, even_dims=(2,), max_shift=20)
    print(f"Decoded {i}: ",(decoded)*100)
    signed_remainder = (decoded - torch.trunc(decoded)) * 100
    print(f"Signed Remainder {i}: ", signed_remainder)
    print()

tensor([-3.82909989, -7.54769993,  0.06030000])
Decoded 0:  tensor([-3.82909989, -7.54769945,  0.06030000])
Signed Remainder 0:  tensor([-3.82909989, -7.54769945,  0.06030000])

tensor([-0.41990000,  9.24339962,  0.05920000])
Decoded 1:  tensor([-0.41990000,  9.24340057,  0.05919999])
Signed Remainder 1:  tensor([-0.41990000,  9.24340057,  0.05919999])

tensor([-1.12380004,  3.31349993,  0.06130000])
Decoded 2:  tensor([-1.12380004,  3.31349969,  0.06129999])
Signed Remainder 2:  tensor([-1.12380004,  3.31349969,  0.06129999])

tensor([ 0.23590000, -6.92700005,  0.05890000])
Decoded 3:  tensor([ 0.23590000, -6.92699909,  0.05889999])
Signed Remainder 3:  tensor([ 0.23590000, -6.92699909,  0.05889999])

tensor([-3.52480006, -7.37570000,  0.05680000])
Decoded 4:  tensor([-3.52480006, -7.37570000,  0.05680000])
Signed Remainder 4:  tensor([-3.52480006, -7.37570000,  0.05680000])

tensor([-3.38989997, -8.88650036,  0.05690000])
Decoded 5:  tensor([-3.38989973, -8.88650131,  0.05690000])
Si

In [None]:
iterator = iter(test_loader)

In [None]:
poses_data

array([[ 0.7522,  7.1221,  0.0592],
       [ 0.7522,  7.1221,  0.0592],
       [ 0.7522,  7.1221,  0.0592],
       ...,
       [-0.4198,  9.2434,  0.0592],
       [-0.4198,  9.2434,  0.0592],
       [-0.4198,  9.2434,  0.0592]], dtype=float32)

In [None]:
"""
    Compresses Lidar scan into latent space.
    Encoder using MLP converting scan to a normal distribution with mu, log_var
    Decoder using MLP converting back to Reconstructed Scan

"""

class VAE(nn.Module):
    """Variational autoencoder that compresses a lidar scan into a latent vector.

    The encoder is a single hidden layer with two heads (mean and
    log-variance); the decoder maps a latent sample back to a scan-sized
    vector through a sigmoid, so reconstructions lie in (0, 1).

    Args:
        input_dim: length of one scan.
        latent_dim: size of the latent space.
        hidden_dim: width of the hidden layer in encoder and decoder.
    """

    def __init__(self, input_dim, latent_dim=32, hidden_dim=128):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder trunk and its two heads.
        self.encoder_linear = nn.Linear(input_dim, hidden_dim)
        self.encoder_act = nn.Tanh()
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)

        decoder_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mu, log_var)`` for a batch of scans (log_var is not clamped here)."""
        # Bound the raw scan and the activations to keep values finite.
        bounded_scan = x.clamp(min=-20.0, max=20.0)
        pre_activation = self.encoder_linear(bounded_scan).clamp(min=-10.0, max=10.0)
        hidden = self.encoder_act(pre_activation).clamp(min=-10.0, max=10.0)
        return self.fc_mu(hidden), self.fc_var(hidden)

    def reparameterize(self, mu, log_var):
        """Draw ``mu + eps * sigma`` with ``eps ~ N(0, I)`` and ``sigma = exp(log_var / 2)``."""
        sigma = torch.exp(0.5 * log_var)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var, z)``; ``log_var`` is clamped to [-6, 0.5]."""
        mu, raw_log_var = self.encode(x)
        log_var = raw_log_var.clamp(min=-6.0, max=0.5)
        z = self.reparameterize(mu, log_var)
        return self.decoder(z), mu, log_var, z


In [None]:
"""
    RealNVP Affine Coupling Layer introduced in Real-NVP paper.
    Splits input u into u1, u2 to avoid detection of common section during global localization.
    Transformed u2 based on u1 and condition c based on the equation 1 of the paper.

"""

class CouplingLayer(nn.Module):
    """Conditional affine coupling layer in the style of RealNVP.

    The input is split into a first part of ``input_dim // 2`` features and a
    second part holding the rest. The second part passes through unchanged
    and, together with the condition, drives the scale and shift applied to
    the first part.

    Args:
        input_dim: number of features of the vector being transformed.
        cond_dim: number of features of the conditioning vector.
        hidden_dim: hidden width of the scale and shift networks.
        alpha: bound of the soft clamp applied to the log-scale.
    """

    def __init__(self, input_dim, cond_dim, hidden_dim=128, alpha=2):
        super().__init__()
        self.mid = input_dim // 2
        self.input_dim = input_dim
        self.alpha = alpha
        net_in_dim = (input_dim - self.mid) + cond_dim

        # Log-scale network (bounded by tanh) and shift network (unbounded).
        self.s_net = nn.Sequential(
            nn.Linear(net_in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.mid),
            nn.Tanh(),
        )
        self.t_net = nn.Sequential(
            nn.Linear(net_in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.mid),
        )

    def _scale_and_shift(self, passthrough, c):
        """Return the clamped log-scale and the shift for the transformed half."""
        features = torch.cat([passthrough, c], dim=1)
        log_scale = self.clamp(self.s_net(features))
        shift = self.t_net(features)
        return log_scale, shift

    def forward(self, u, c):
        """Apply ``v1 = u1 * exp(s) + t`` and keep ``u2`` as is."""
        u1, u2 = u[:, :self.mid], u[:, self.mid:]
        log_scale, shift = self._scale_and_shift(u2, c)
        v1 = u1 * torch.exp(log_scale) + shift
        return torch.cat([v1, u2], dim=1)

    def reverse(self, v, c):
        """Invert :meth:`forward`: ``u1 = (v1 - t) * exp(-s)``."""
        v1, v2 = v[:, :self.mid], v[:, self.mid:]
        log_scale, shift = self._scale_and_shift(v2, c)
        u1 = (v1 - shift) * torch.exp(-log_scale)
        return torch.cat([u1, v2], dim=1)

    def clamp(self, val):
        """Soft-clamp ``val`` into ``(-alpha, alpha)`` with a scaled arctangent."""
        return (2 * self.alpha / torch.pi) * torch.atan(val / self.alpha)

In [None]:
'''
  Model is created for the LocalINN model with fc layers and encoders

'''

class LocalINN(nn.Module):
    """Invertible network linking positionally encoded poses and lidar latents.

    Forward direction: encoded pose -> (scan latent, extra latent ``z``),
    conditioned on the encoded previous pose. Reverse direction: scan latent
    from the VAE plus random ``z`` -> encoded pose.

    Args:
        pose_dim: number of pose components.
        scan_dim: length of one lidar scan.
        vae_latent_dim: size of the VAE latent space.
        num_coupling: number of coupling layers.

    Raises:
        ValueError: if the encoded pose is not larger than the VAE latent.

    Note:
        The feature permutations are registered as buffers (``perms``,
        ``inv_perms``), so they are saved with ``state_dict`` and moved by
        ``.to(device)``. They are drawn randomly at construction, so a
        checkpoint is only meaningful together with its permutations.
    """

    def __init__(self, pose_dim=3, scan_dim=1081, vae_latent_dim=32, num_coupling=6):
        super().__init__()

        self.vae = VAE(scan_dim, vae_latent_dim)

        # Number of encoding frequencies for the pose and for the condition.
        self.L_pose = 7
        self.L_cond = 1
        self.pos_enc_pose = PositionalEncoding(L=self.L_pose)
        self.pos_enc_cond = PositionalEncoding(L=self.L_cond)

        pose_enc_dim = pose_dim * 2 * self.L_pose
        cond_enc_dim = pose_dim * 2 * self.L_cond

        # The INN keeps dimensionality, so its output is split into the VAE
        # latent plus z_dim remaining features.
        self.z_dim = pose_enc_dim - vae_latent_dim

        if self.z_dim <= 0:
            raise ValueError(f"Pose encoding dim ({pose_enc_dim}) must be > VAE latent dim ({vae_latent_dim})")

        print(f"Dimensions -> Pose Enc: {pose_enc_dim}, VAE Latent: {vae_latent_dim}, Z Noise: {self.z_dim}, Cond Enc: {cond_enc_dim}")

        self.coupling_layers = nn.ModuleList([
            CouplingLayer(pose_enc_dim, cond_enc_dim) for _ in range(num_coupling)
        ])

        # One fixed random feature permutation per coupling layer (row i),
        # plus its inverse for the reverse pass.
        perm_rows = [torch.randperm(pose_enc_dim) for _ in range(num_coupling)]
        if perm_rows:
            perms = torch.stack(perm_rows)
        else:
            perms = torch.empty((0, pose_enc_dim), dtype=torch.long)
        self.register_buffer("perms", perms)
        self.register_buffer("inv_perms", torch.argsort(perms, dim=1))

    @staticmethod
    def _encode_nested(encoder, values):
        """Encode ``values`` (batch, dims) into per-sample ``[sin, cos, ...]`` lists."""
        encoded_batch = []
        for sample in values:
            encoded = []
            for k, component in enumerate(sample):
                # Component 2 uses the doubled-frequency encoding.
                encode_fn = encoder.encode_even if k == 2 else encoder.encode
                sine_part, cosine_part = encode_fn(component)
                encoded.append(sine_part)
                encoded.append(cosine_part)
            encoded_batch.append(encoded)
        return encoded_batch

    @staticmethod
    def _encode_flat(encoder, values):
        """Encode ``values`` (batch, dims) into one flat tensor per sample.

        Same layout as flattening the output of :meth:`_encode_nested`
        (``[sin_0, cos_0, sin_1, cos_1, ...]``), but each component is encoded
        for the whole batch at once.
        """
        parts = []
        for k in range(values.shape[1]):
            encode_fn = encoder.encode_even if k == 2 else encoder.encode
            parts.extend(encode_fn(values[:, k]))  # two tensors of shape (batch, L)
        return torch.cat(parts, dim=1)

    def process_condition(self, c):
        """Return the nested-list encoding of a batch of conditioning poses."""
        return self._encode_nested(self.pos_enc_cond, c)

    def process_pose(self, p):
        """Return the nested-list encoding of a batch of poses."""
        return self._encode_nested(self.pos_enc_pose, p)

    def forward_inn(self, x_enc, c_enc):
        """Run the coupling stack: permute features, then apply each layer."""
        h = x_enc

        for i, layer in enumerate(self.coupling_layers):
            h = h[:, self.perms[i]]
            h = layer(h, c_enc)

        return h

    def reverse_inn(self, y_latent, z, c_enc):
        """Invert :meth:`forward_inn` starting from ``[y_latent, z]``."""
        if y_latent.dim() > 2:
            y_latent = y_latent.flatten(1)
        h = torch.cat([y_latent, z], dim=1)

        # Undo the layers last-to-first; each layer is undone before its permutation.
        for i in range(len(self.coupling_layers) - 1, -1, -1):
            h = self.coupling_layers[i].reverse(h, c_enc)
            h = h[:, self.inv_perms[i]]

        return h

    def forward(self, pose, scan, prev_pose):
        """Run the VAE, the forward INN pass and the reverse INN pass.

        Args:
            pose: batch of poses, shape (batch, pose_dim).
            scan: batch of lidar scans, shape (batch, scan_dim).
            prev_pose: batch of conditioning poses, shape (batch, pose_dim).

        Returns:
            Dict with the VAE outputs (``scan_recon``, ``mu``, ``log_var``,
            ``y_vae``), the forward INN outputs (``y_inn``, ``z_inn``,
            ``scan_inn_recon``), the encoded ground-truth pose (``x_enc_gt``)
            and the encoded pose predicted by the reverse pass (``x_enc_pred``).
        """
        batch_size = pose.shape[0]

        recon_scan, mu, log_var, y_vae = self.vae(scan)

        x_enc = self._encode_flat(self.pos_enc_pose, pose)
        c_enc = self._encode_flat(self.pos_enc_cond, prev_pose)

        out_inn = self.forward_inn(x_enc, c_enc)

        y_inn = out_inn[:, :self.vae.latent_dim]
        z_inn = out_inn[:, self.vae.latent_dim:]

        scan_inn_recon = self.vae.decoder(y_inn)

        # Reverse pass: pair the VAE latent with fresh Gaussian noise for z.
        z_rand = torch.randn(batch_size, self.z_dim, device=pose.device)
        x_enc_pred = self.reverse_inn(y_vae, z_rand, c_enc)

        return {
            "scan_recon": recon_scan,
            "scan_inn_recon": scan_inn_recon,
            "mu": mu,
            "log_var": log_var,
            "y_vae": y_vae,
            "y_inn": y_inn,
            "z_inn": z_inn,
            "x_enc_gt": x_enc,
            "x_enc_pred": x_enc_pred
        }

## 1.4. Training the Neural Network

We need to define our loss function, optimizer, callbacks, and learning rate scheduler, all of which require hyperparameters that need to be tuned for a better model.


In [None]:
model = LocalINN()

Dimensions -> Pose Enc: 42, VAE Latent: 32, Z Noise: 10, Cond Enc: 6


In [None]:
"""
    Initializes the weights of Linear layers using Xavier Uniform.

"""

def init_weights(m):
    """Xavier-uniform the weights of a Linear layer and set its bias to 0.01.

    Intended for ``Module.apply``; modules that are not ``torch.nn.Linear``
    are left untouched.
    """
    if not isinstance(m, torch.nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    if m.bias is None:
        return
    with torch.no_grad():
        m.bias.fill_(0.01)

In [None]:
model.apply(init_weights)

model.to(device)

LocalINN(
  (vae): VAE(
    (encoder_linear): Linear(in_features=1081, out_features=128, bias=True)
    (encoder_act): Tanh()
    (fc_mu): Linear(in_features=128, out_features=32, bias=True)
    (fc_var): Linear(in_features=128, out_features=32, bias=True)
    (decoder): Sequential(
      (0): Linear(in_features=32, out_features=128, bias=True)
      (1): Tanh()
      (2): Linear(in_features=128, out_features=1081, bias=True)
      (3): Sigmoid()
    )
  )
  (pos_enc_pose): PositionalEncoding()
  (pos_enc_cond): PositionalEncoding()
  (coupling_layers): ModuleList(
    (0-5): 6 x CouplingLayer(
      (s_net): Sequential(
        (0): Linear(in_features=27, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=21, bias=True)
        (3): Tanh()
      )
      (t_net): Sequential(
        (0): Linear(in_features=27, out_features=128, bias=True)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=21, bias=True)
      )
    )
  

In [None]:
from torch import optim
from torch.optim import lr_scheduler
import torch.nn as nn

'''
  Initiating the Adam optimizer and learning rate scheduler (#TODO).

'''

optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
'''
  Complete training pipeline with callbacks (#TODO), optimizer, learning rate scheduler (#TODO)
  and example prediction (#TODO).

'''

epochs = 5  # Small number for demo
mse = nn.MSELoss()  # stateless criterion, built once and reused for every term

print("Starting Training...")

loss_history = []

for epoch in range(epochs):
    total_loss = 0
    for batch_idx, (lidar, pose, prev_pose) in enumerate(train_loader):
        optimizer.zero_grad()

        # Move the batch to the training device; both poses are scaled by 0.01.
        lidar = lidar.to(device)
        pose = pose.to(device) * 0.01
        prev_pose = prev_pose.to(device) * 0.01

        # Replace readings beyond +/-100 with small fixed values.
        lidar = lidar.masked_fill(lidar > 100, 0.06)
        lidar = lidar.masked_fill(lidar < -100, -0.06)

        outputs = model(pose, lidar, prev_pose)
        mu, log_var = outputs['mu'], outputs['log_var']

        # 1. VAE loss: scan reconstruction plus a lightly weighted KL term.
        recon_loss = mse(outputs['scan_recon'], lidar)
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        vae_loss = recon_loss + 0.001 * kl_loss

        # 2. Forward loss: the INN output should decode to the scan, match the
        #    (detached) VAE latent and keep the extra latent z small.
        inn_recon_loss = mse(outputs['scan_inn_recon'], lidar)
        latent_match_loss = mse(outputs['y_inn'], outputs['y_vae'].detach())
        z_loss = torch.mean(outputs['z_inn']**2)
        forward_loss = inn_recon_loss + latent_match_loss + z_loss

        # 3. Reverse loss: the pose encoding recovered from the scan latent
        #    should match the ground-truth encoding.
        reverse_loss = mse(outputs['x_enc_pred'], outputs['x_enc_gt'])

        loss = vae_loss + forward_loss + reverse_loss

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=20.0)
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(train_loader)
    loss_history.append(avg_loss)
    # Report every fifth epoch only.
    if (epoch+1) % 5 == 0:
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[[tensor([0.01576386]), tensor([0.99987572]), tensor([-0.14542827]), tensor([0.98936880]), tensor([0.00382645]), tensor([0.99999267])], [tensor([0.02024618]), tensor([0.99979502]), tensor([0.17416312]), tensor([0.98471683]), tensor([0.00371964]), tensor([0.99999309])]]
torch.Size([2, 42])
torch.Size([2, 6])
torch.Size([2, 3])
[[tensor([-0.09591594]), tensor([0.99538946]), tensor([-0.27020958]), tensor([0.96280152]), tensor([0.00380760]), tensor([0.99999273])], [tensor([0.02024618]), tensor([0.99979502]), tensor([0.17416312]), tensor([0.98471683]), tensor([0.00371964]), tensor([0.99999309])]]
torch.Size([2, 42])
torch.Size([2, 6])
torch.Size([2, 3])
[[tensor([-0.06264002]), tensor([0.99803621]), tensor([-0.28514269]), tensor([0.95848507]), tensor([0.00356256]), tensor([0.99999368])], [tensor([0.02024618]), tensor([0.99979502]), tensor([0.17416312]), tensor([0.98471683]), tensor([0.00371964]), tensor([0.99999309])]]
torch.S

## 1.5. Testing the trained Model

In [None]:
test_iterator = iter(test_loader)

In [None]:
def split_encoded_vector(x, L):
    """Reshape a flat encoding ``(B, D)`` into ``(B, D // L, L)``.

    Each group of ``L`` consecutive features becomes one row, so ``D`` has to
    be a multiple of ``L``.
    """
    batch, width = x.shape
    groups, leftover = divmod(width, L)
    assert leftover == 0, "D must be divisible by L"
    return x.reshape(batch, groups, L)


In [None]:
model.eval()
L = model.L_pose  # frequencies per encoded pose component, taken from the model
POSE_SCALE = 0.01  # the training loop feeds poses scaled by 0.01; do the same here

with torch.no_grad():
    lidar, pose, prev_pose = next(test_iterator)
    lidar = lidar.to(device)
    pose = pose.to(device)
    prev_pose = prev_pose.to(device)
    lidar = torch.where(lidar > 100, torch.tensor(0.06, device=lidar.device), lidar)
    lidar = torch.where(lidar < -100, torch.tensor(-0.06, device=lidar.device), lidar)

    # --- Model Inference and Pose Prediction ---
    # `pose` itself stays unscaled so it can be compared with the decoded
    # values, which are multiplied by 100 below.
    outputs = model(pose * POSE_SCALE, lidar, prev_pose * POSE_SCALE)

    # Regroup the flat encodings into rows of L frequencies: (batch, 6, L).
    x_enc_gt = split_encoded_vector(outputs['x_enc_gt'], L)
    x_enc_pred = split_encoded_vector(outputs['x_enc_pred'], L)

    decoded_gt = []
    decoded_pred = []

    decoder = PositionalEncoding(L)

    for i in range(x_enc_gt.shape[0]):
        # Drop any whole-number offset left by the decoder, then undo the 0.01 scaling.
        decoded = decoder.decode_from_encoded_list(x_enc_gt[i], even_dims=(2,), max_shift=20)
        signed_remainder = (decoded - torch.trunc(decoded)) * 100
        decoded_gt.append(signed_remainder)

        decoded = decoder.decode_from_encoded_list(x_enc_pred[i], even_dims=(2,), max_shift=20)
        signed_remainder = (decoded - torch.trunc(decoded)) * 100
        decoded_pred.append(signed_remainder)

    print()
    print(pose)
    print(decoded_gt)
    print(decoded_pred)
    print()






torch.Size([4, 3])
[[tensor([-0.17935160]), tensor([-0.98378503]), tensor([0.97012681]), tensor([0.24259840]), tensor([0.35875902]), tensor([0.93343019])], [tensor([-0.96842670]), tensor([0.24929839]), tensor([-0.69229442]), tensor([-0.72161520]), tensor([0.36344635]), tensor([0.93161511])], [tensor([0.62964553]), tensor([0.77688253]), tensor([-0.99887502]), tensor([-0.04742087]), tensor([0.35347489]), tensor([0.93544400])], [tensor([-0.96873921]), tensor([0.24808128]), tensor([-0.69229442]), tensor([-0.72161520]), tensor([0.36344635]), tensor([0.93161511])]]
torch.Size([4, 42])
torch.Size([4, 6])

tensor([[-0.93199998,  8.45330048,  0.05760000],
        [-0.41980001,  9.24339962,  0.05920000],
        [-1.72510004, -4.52850008,  0.05710000],
        [-0.42019999,  9.24339962,  0.05920000]])
[tensor([-93.19999695,  45.33004761,   5.75999975]), tensor([-41.97999954, -75.65998840,   5.92000008]), tensor([ 27.48998642, -52.85000610,   5.71002960]), tensor([-42.02000046, -75.65998840,   5.

In [None]:
pose

tensor([[-1.26929998,  0.86390001,  0.06010000],
        [-0.99980003, -1.07539999,  0.05540000],
        [-0.41980001,  9.24339962,  0.05920000],
        [ 0.75150001,  1.60189998,  0.05720000],
        [-3.19600010, -4.24350023,  0.05830000],
        [-2.95040011, -4.48680019,  0.06010000],
        [-0.99479997, -1.45850003,  0.05560000],
        [-2.68330002, -4.24830008,  0.05840000]], device='cuda:0')

In [None]:
decoded_gt

[tensor([-26.92999840, -13.60999298,   6.00999594], device='cuda:0'),
 tensor([-99.98000336,  -7.53999949,   5.53999996], device='cuda:0'),
 tensor([-41.97999954,  24.33996201,   5.92000008], device='cuda:0'),
 tensor([ 75.15000916, -39.80998993,   5.71999979], device='cuda:0'),
 tensor([-19.60003281, -24.35002327,   5.83000040], device='cuda:0'),
 tensor([-95.04000854, -48.68001938,   6.00999594], device='cuda:0'),
 tensor([-99.47999573,  54.15000916,   5.56000042], device='cuda:0'),
 tensor([-68.33002472, -24.83000755, -94.15998840], device='cuda:0')]

In [None]:
decoded_pred

[tensor([ 9.01660919, -5.59730530, -2.02589035], device='cuda:0'),
 tensor([-88.73081207, -94.21348572,  87.21561432], device='cuda:0'),
 tensor([ -0.85144043,  73.96068573, -10.40821075], device='cuda:0'),
 tensor([ 18.39256287,  73.35071564, -99.01190186], device='cuda:0'),
 tensor([-85.79978943, -21.52252197, -11.90528870], device='cuda:0'),
 tensor([  8.33778381, -30.92842102,  84.12399292], device='cuda:0'),
 tensor([ 21.67816162, -84.40322876,  98.52294922], device='cuda:0'),
 tensor([-1.22814178, 76.60560608, 88.71250153], device='cuda:0')]

In [None]:
## encode the prev_state
p_encoding = PositionalEncoding(L=10)

for i in range(8):
    lidar, pose, prev_pose = next(iterator)
    print(pose[0])
    # Only the first pose of the batch is used, scaled down by 0.01.
    pose = pose[0] * 0.01

    # Layout is [sin, cos] per component; component 2 uses the doubled-frequency encoder.
    encoded = []
    for k in range(3):
        encode_fn = p_encoding.encode_even if k == 2 else p_encoding.encode
        sine_part, cosine_part = encode_fn(pose[k])
        encoded += [sine_part, cosine_part]

    # Round-trip check: decoding should give back the scaled pose.
    decoded = p_encoding.decode_from_encoded_list(encoded, even_dims=(2,), max_shift=20)
    print(f"Decoded {i}: ",(decoded)*100)
    signed_remainder = (decoded - torch.trunc(decoded)) * 100
    print(f"Signed Remainder {i}: ", signed_remainder)
    print()