In [6]:
from pathlib import Path
from scipy.io import loadmat
import sys
import os

# Locate data.mat: probe <cwd>/data and the same folder under the two nearest
# ancestors of the working directory, then a fixed absolute location.
current_path = Path.cwd()
possible_data_paths = [
    base_dir / 'data' / 'data.mat'
    for base_dir in (current_path, current_path.parent, current_path.parent.parent)
]
# NOTE(review): absolute, user-specific fallback -- it can only match on the
# machine where this directory layout exists.
possible_data_paths.append(
    Path('/home/luky/skola/KalmanNet-for-state-estimation/data/data.mat')
)

# First candidate that exists on disk, or None when none of them does.
dataset_path = next(
    (candidate for candidate in possible_data_paths if candidate.exists()),
    None,
)

if not (dataset_path is not None and dataset_path.exists()):
    # Nothing matched: warn and continue with a path relative to the working directory.
    print("Warning: data.mat not found automatically.")
    dataset_path = Path('data/data.mat')

print(f"Dataset path: {dataset_path}")

# Put the directory two levels above the working directory at the front of
# sys.path so project-local packages can be imported by later cells.
notebook_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(notebook_dir, os.pardir, os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
print(f"Project root added: {project_root}")

# Read the MATLAB file and list the variables it contains.
mat_data = loadmat(dataset_path)
print(mat_data.keys())


Dataset path: /home/luky/skola/KalmanNet-main/data/data.mat
Project root added: /home/luky/skola/KalmanNet-main
dict_keys(['__header__', '__version__', '__globals__', 'hB', 'souradniceGNSS', 'souradniceX', 'souradniceY', 'souradniceZ'])


In [7]:
import torch
import matplotlib.pyplot as plt
from utils import trainer
from utils import utils
from Systems import DynamicSystem
import Filters
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from scipy.io import loadmat
from scipy.interpolate import RegularGridInterpolator
import random

# Seed Python's, NumPy's and PyTorch's global RNGs with the same value
# (and every CUDA device when one is available).
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Select the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"device: {device}")

device: cuda


In [8]:
# Read the MATLAB file and pull out the map grids and the GNSS coordinates.
mat_data = loadmat(dataset_path)

souradniceX_mapa, souradniceY_mapa, souradniceZ_mapa, souradniceGNSS = (
    mat_data[key]
    for key in ('souradniceX', 'souradniceY', 'souradniceZ', 'souradniceGNSS')
)

# 1-D axes: first row of the X grid and first column of the Y grid.
x_axis_unique = souradniceX_mapa[0, :]
y_axis_unique = souradniceY_mapa[:, 0]

print(
    f"Dimensions of 1D X axis: {x_axis_unique.shape}",
    f"Dimensions of 1D Y axis: {y_axis_unique.shape}",
    f"Dimensions of 2D elevation data Z: {souradniceZ_mapa.shape}",
    sep="\n",
)

# Interpolator over the elevation grid, indexed as (y, x).
# Queries outside the grid return NaN instead of raising.
terMap_interpolator = RegularGridInterpolator(
    points=(y_axis_unique, x_axis_unique),
    values=souradniceZ_mapa,
    bounds_error=False,
    fill_value=np.nan,
)

def terMap(px, py):
    """Look up terrain elevation at the given map positions.

    Args:
        px: x coordinates of the query points.
        py: y coordinates of the query points.

    Returns:
        Whatever the module-level ``terMap_interpolator`` returns for the
        query points, which are passed to it as rows of ``(y, x)``.
    """
    # The interpolator axes are ordered (y, x), so y goes in the first column.
    return terMap_interpolator(np.column_stack([py, px]))

Dimensions of 1D X axis: (2500,)
Dimensions of 1D Y axis: (2500,)
Dimensions of 2D elevation data Z: (2500, 2500)


In [9]:
import torch
from Systems import DynamicSystemTAN

# Model dimensions: 4 state components (the measurement function in this
# notebook reads them as px, py, vx, vy) and 3 observation components.
state_dim = 4
obs_dim = 3
dT = 1  # time step used in F and Q
q = 1   # scalar multiplier applied to the process-noise matrix Q

# State transition: positions advance by velocity * dT, velocities are unchanged.
# NOTE(review): this rebinds the name `F`, which other cells use as the alias of
# `torch.nn.functional`. Any code that calls `F.<function>` must re-import the
# alias after this cell has run.
F = torch.tensor([[1.0, 0.0, dT, 0.0],
                   [0.0, 1.0, 0.0, dT],
                   [0.0, 0.0, 1.0, 0.0],
                   [0.0, 0.0, 0.0, 1.0]])

# Process-noise covariance, scaled by q.
Q = q* torch.tensor([[dT**3/3, 0.0, dT**2/2, 0.0],
                   [0.0, dT**3/3, 0.0, dT**2/2],
                   [dT**2/2, 0.0, dT, 0.0],
                   [0.0, dT**2/2, 0.0, dT]])
# Measurement-noise covariance: diagonal with variances 3^2, 1^2, 1^2.
R = torch.tensor([[3.0**2, 0.0, 0.0],
                   [0.0, 1.0**2, 0.0],
                   [0.0, 0.0, 1.0**2]])

# Difference of the first two GNSS samples. It is computed but NOT used for the
# initial state below, which starts from zero velocity.
initial_velocity_np = souradniceGNSS[:2, 1] - souradniceGNSS[:2, 0]

initial_position = torch.from_numpy(souradniceGNSS[:2, 0])
# Zero initial velocity, created directly in the position's dtype so the
# concatenation does not depend on implicit integer-to-float promotion.
initial_velocity = torch.zeros(2, dtype=initial_position.dtype)

x_0 = torch.cat([
    initial_position,
    initial_velocity
]).float()
print(x_0)

# Initial state covariance: 25 on the first two diagonal entries, 0.5 on the last two.
P_0 = torch.tensor([[25.0, 0.0, 0.0, 0.0],
                    [0.0, 25.0, 0.0, 0.0],
                    [0.0, 0.0, 0.5, 0.0],
                    [0.0, 0.0, 0.0, 0.5]])
import torch.nn.functional as func

def h_nl_differentiable(x: torch.Tensor, map_tensor, x_min, x_max, y_min, y_max) -> torch.Tensor:
    """Differentiable measurement function: terrain height plus rotated velocity.

    Args:
        x: State batch indexed as ``x[:, 0..3]`` = (px, py, vx, vy).
        map_tensor: 4-D elevation map that is expanded along its first
            dimension to the batch size and sampled with ``grid_sample``
            (so it is expected to have a leading dimension of 1).
        x_min, x_max: Map extent along x, mapped to grid coordinates -1 and +1.
        y_min, y_max: Map extent along y, mapped to grid coordinates -1 and +1.

    Returns:
        Tensor of shape ``(batch, 3)`` with columns
        ``[terrain_height, vx_b, vy_b]``.

    Notes:
        Positions outside the map are clamped to the border value
        (``padding_mode='border'``).
        The squared speed is clamped *before* the square root. Taking the
        square root first and clamping afterwards produces NaN gradients when
        the velocity is exactly zero (the square-root derivative is infinite
        at 0 and is multiplied by the zero gradient of the clamp).
    """
    batch_size = x.shape[0]

    px = x[:, 0]
    py = x[:, 1]

    # Map world coordinates linearly onto the [-1, 1] range used by grid_sample.
    px_norm = 2.0 * (px - x_min) / (x_max - x_min) - 1.0
    py_norm = 2.0 * (py - y_min) / (y_max - y_min) - 1.0

    # grid_sample expects (x, y) pairs in the last dimension: one point per batch item.
    sampling_grid = torch.stack((px_norm, py_norm), dim=1).view(batch_size, 1, 1, 2)

    vyska_terenu_batch = func.grid_sample(
        map_tensor.expand(batch_size, -1, -1, -1),
        sampling_grid,
        mode='bilinear',
        padding_mode='border',
        align_corners=True
    )

    vyska_terenu = vyska_terenu_batch.view(batch_size)

    eps = 1e-12
    vx_w, vy_w = x[:, 2], x[:, 3]
    # Clamp the squared norm (floor eps**2) so the sqrt is never evaluated at 0;
    # the forward value still bottoms out at eps, as before.
    norm_v_w = torch.sqrt((vx_w**2 + vy_w**2).clamp(min=eps * eps))
    cos_psi = vx_w / norm_v_w
    sin_psi = vy_w / norm_v_w

    # Rotate the velocity by the angle psi of the velocity vector itself.
    vx_b = cos_psi * vx_w - sin_psi * vy_w
    vy_b = sin_psi * vx_w + cos_psi * vy_w

    result = torch.stack([vyska_terenu, vx_b, vy_b], dim=1)

    return result

# 1-D map axes and their extents (used to normalise positions for grid_sample).
x_axis_unique = souradniceX_mapa[0, :]
y_axis_unique = souradniceY_mapa[:, 0]
x_min, x_max = x_axis_unique.min(), x_axis_unique.max()
y_min, y_max = y_axis_unique.min(), y_axis_unique.max()

# Elevation grid as a float tensor with two leading singleton dimensions, on the compute device.
terMap_tensor = torch.from_numpy(souradniceZ_mapa).float()[None, None].to(device)


def h_wrapper(x):
    """Measurement function of the state only, with the map and its extents bound in."""
    return h_nl_differentiable(x, terMap_tensor, x_min, x_max, y_min, y_max)


# System model assembled from the matrices defined above and the measurement wrapper.
system_model = DynamicSystemTAN(
    state_dim=state_dim,
    obs_dim=obs_dim,
    F=F.float(),
    Q=Q.float(),
    R=R.float(),
    Ex0=x_0.float(),
    P0=P_0.float(),
    h=h_wrapper,
    x_axis_unique=x_axis_unique,
    y_axis_unique=y_axis_unique,
    device=device,
)

tensor([1487547.1250, 6395520.5000,       0.0000,       0.0000])
INFO: DynamicSystemTAN inicializov√°n s hranicemi mapy:
  X: [1476611.42, 1489541.47]
  Y: [6384032.63, 6400441.34]


In [10]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from utils import utils
import torch
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import os
import random
from copy import deepcopy
from state_NN_models import TAN
from utils import trainer 

# Reset the Python, NumPy and PyTorch global RNGs to the same fixed seed.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)


In [11]:
import torch
from torch.utils.data import TensorDataset, DataLoader
import os
from utils import trainer # P≈ôedpokl√°d√°m, ≈æe toto m√°≈°

# === 1. ZJEDNODU≈†EN√ù DATA MANAGER (BEZ NORMALIZACE) ===
class NavigationDataManager:
    """Loads pre-generated trajectory tensors from disk into DataLoaders.

    The manager only stores the root directory; it computes no statistics and
    applies no normalisation to the loaded tensors.
    """

    def __init__(self, data_dir):
        """Store the root directory that contains the ``len_<seq_len>`` folders."""
        self.data_dir = data_dir

    def get_dataloader(self, seq_len, split='train', shuffle=True, batch_size=32):
        """Build a DataLoader for one sequence length and split.

        Args:
            seq_len: Sequence length; selects the ``len_<seq_len>`` sub-folder.
            split: File stem inside that folder (``<split>.pt``).
            shuffle: Passed through to ``DataLoader``.
            batch_size: Passed through to ``DataLoader``.

        Returns:
            DataLoader over ``TensorDataset(data['x'], data['y'])`` from the loaded file.

        Raises:
            FileNotFoundError: If ``<data_dir>/len_<seq_len>/<split>.pt`` does not exist.
        """
        path = os.path.join(
            self.data_dir,
            'len_{}'.format(seq_len),
            '{}.pt'.format(split),
        )

        if os.path.exists(path):
            # The file holds a mapping with the states under 'x' and the raw measurements under 'y'.
            payload = torch.load(path)
            states, measurements = payload['x'], payload['y']
            return DataLoader(
                TensorDataset(states, measurements),
                batch_size=batch_size,
                shuffle=shuffle,
            )

        raise FileNotFoundError(f"‚ùå Dataset nenalezen: {path}")

# === 2. CURRICULUM CONFIGURATION ===
DATA_DIR = './generated_data_synthetic_controlled'

# Create the manager (it is only a wrapper around loading files from DATA_DIR)
data_manager = NavigationDataManager(DATA_DIR)

# Phase definitions. The loading loop below reads only 'phase_id', 'seq_len'
# and 'batch_size'; 'epochs' and 'lr' are not read anywhere in this cell.
curriculum_schedule = [
    # PHASE 1: warm-up (short sequences)
    {
        'phase_id': 1,
        'seq_len': 10,
        'epochs': 500,
        'lr': 1e-3,
        'batch_size': 256
    },

    # PHASE 2: stabilisation (medium length)
    {
        'phase_id': 2,
        'seq_len': 100,
        'epochs': 200,
        'lr': 1e-4,
        'batch_size': 128
    },
       # PHASE 3: long sequences
    {
        'phase_id': 3,
        'seq_len': 200,
        'epochs': 200,
        'lr': 5e-5,
        'batch_size': 64       # smaller batch because of GPU memory with long sequences
    },
    # PHASE 4: longest sequences (the original comment mislabelled this as phase 3)
    {
        'phase_id': 4,
        'seq_len': 300,
        'epochs': 200,
        'lr': 1e-5,
        'batch_size': 64       # smaller batch because of GPU memory with long sequences
    }
]

# === 3. LOADING INTO MEMORY (CACHING) ===
print("\n=== NAƒå√çT√ÅN√ç RAW DAT Z DISKU (BEZ EXT. NORMALIZACE) ===")
# Maps phase_id -> (train_loader, val_loader)
datasets_cache = {}

for phase in curriculum_schedule:
    seq_len = phase['seq_len']
    bs = phase['batch_size']

    print(f"üì• Naƒç√≠t√°m F√°zi {phase['phase_id']}: Seq={seq_len} | Batch={bs} ...")

    try:
        # Build the loaders through the data manager (validation is not shuffled)
        train_loader = data_manager.get_dataloader(seq_len=seq_len, split='train', shuffle=True, batch_size=bs)
        val_loader = data_manager.get_dataloader(seq_len=seq_len, split='val', shuffle=False, batch_size=bs)

        # Store in the cache
        datasets_cache[phase['phase_id']] = (train_loader, val_loader)

        # Quick sanity check: one batch is drawn for every phase, but it is
        # printed only for phase 1.
        x_ex, y_ex = next(iter(train_loader))
        if phase['phase_id'] == 1:
            print(f"   üîé Uk√°zka RAW dat (y): {y_ex[0, 0, :].tolist()}")
            # Expect a mix of large values (e.g. 250.0) and small ones (0.2), not ~0.0

    except FileNotFoundError as e:
        # Best effort: a missing file is only reported, so that phase is left
        # out of datasets_cache and a later lookup of its id raises KeyError.
        print(f"   ‚ö†Ô∏è CHYBA: {e}")
        # raise e # Uncomment to make a missing file fatal

print("\n‚úÖ Data p≈ôipravena. Normalizaci ≈ôe≈°√≠ model.")


=== NAƒå√çT√ÅN√ç RAW DAT Z DISKU (BEZ EXT. NORMALIZACE) ===
üì• Naƒç√≠t√°m F√°zi 1: Seq=10 | Batch=256 ...
   üîé Uk√°zka RAW dat (y): [323.7707824707031, -13.519903182983398, -29.721908569335938]
üì• Naƒç√≠t√°m F√°zi 2: Seq=100 | Batch=128 ...
üì• Naƒç√≠t√°m F√°zi 3: Seq=200 | Batch=64 ...
üì• Naƒç√≠t√°m F√°zi 4: Seq=300 | Batch=64 ...

‚úÖ Data p≈ôipravena. Normalizaci ≈ôe≈°√≠ model.


In [12]:
import torch
import torch.nn.functional as F
import numpy as np
from copy import deepcopy

def train_BayesianKalmanNet_Hybrid(
    model, train_loader, val_loader, device,
    total_train_iter, learning_rate, clip_grad,
    J_samples, validation_period, logging_period,
    warmup_iterations=0, weight_decay_=1e-5,
    lambda_mse=100.0
):
    """Train a Bayesian KalmanNet-style model with a hybrid NLL + MSE loss.

    For every batch the model is run ``J_samples`` times over the sequence.
    The ensemble mean is the state estimate and the per-element ensemble
    variance (+1e-9) is the predicted variance. The loss is
    ``NLL + lambda_mse * MSE + regularisation``, where the regularisation is
    the ensemble mean of the summed per-step values returned by ``model.step``.

    Args:
        model: Module that provides ``reset(batch_size=, initial_state=)``,
            ``step(y_t) -> (x_filtered, reg)``, ``dnn.concrete_dropout1/2.p_logit``
            (read for logging) and ``system_model.P0`` (read in validation).
        train_loader: Iterable of ``(x_true, y_meas)`` batches; ``x_true`` is
            unpacked as ``(batch, seq_len, state_dim)``.
        val_loader: Iterable of validation batches in the same layout.
        device: Device the batches are moved to.
        total_train_iter: Number of optimiser steps to perform.
        learning_rate: AdamW learning rate.
        clip_grad: Max gradient norm; clipping is skipped when ``<= 0``.
        J_samples: Ensemble size per batch; must be at least 2.
        validation_period: Validate every this many optimiser steps.
        logging_period: Print training statistics every this many steps.
        warmup_iterations: Accepted for interface compatibility; not used.
        weight_decay_: AdamW weight decay.
        lambda_mse: Weight of the MSE term in the loss.

    Returns:
        dict with ``best_val_anees``, ``best_val_nll``, ``best_val_mse``,
        ``best_iter`` and ``final_model``. When a best checkpoint exists, its
        weights are loaded back into ``model`` before returning.

    Raises:
        ValueError: If ``J_samples < 2``, or if a filtered state contains NaN
            during the training forward pass.
    """
    # A single sample has no variance (torch returns NaN), which would only
    # surface later as a "NaN loss" collapse. Fail early with a clear message.
    if J_samples < 2:
        raise ValueError(
            f"J_samples must be >= 2 to estimate an ensemble variance, got {J_samples}"
        )

    # Use the full module path instead of the global alias `F`: elsewhere in the
    # notebook `F` is rebound to the state-transition matrix.
    functional = torch.nn.functional

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay_)

    # Halve the LR when the monitored loss stops improving. The deprecated
    # `verbose` argument is not passed; LR changes are reported manually below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=50
    )

    best_val_anees = float('inf')
    score_at_best = {"val_nll": 0.0, "val_mse": 0.0}
    best_iter_count = 0
    best_model_state = None
    train_iter_count = 0
    done = False

    print(f"üöÄ START Hybrid Training: Loss = NLL + {lambda_mse} * MSE")

    while not done:
        model.train()
        iters_before_pass = train_iter_count
        for x_true_batch, y_meas_batch in train_loader:
            if train_iter_count >= total_train_iter:
                done = True
                break
            if torch.isnan(x_true_batch).any():
                print(f"!!! SKIP BATCH iter {train_iter_count}: NaN found in x_true (Ground Truth) !!!")
                continue

            x_true_batch = x_true_batch.to(device)
            y_meas_batch = y_meas_batch.to(device)

            # --- Training ---
            optimizer.zero_grad()
            batch_size, seq_len, _ = x_true_batch.shape

            all_trajectories_for_ensemble = []
            all_regs_for_ensemble = []

            # 1. Ensemble forward pass: every sample restarts from the true initial state.
            for j in range(J_samples):
                model.reset(batch_size=batch_size, initial_state=x_true_batch[:, 0, :])
                current_trajectory_x_hats = []
                current_trajectory_regs = []
                for t in range(1, seq_len):
                    y_t = y_meas_batch[:, t, :]
                    x_filtered_t, reg_t = model.step(y_t)
                    if torch.isnan(x_filtered_t).any():
                        raise ValueError(f"NaN in x_filtered_t at sample {j}, step {t}")
                    current_trajectory_x_hats.append(x_filtered_t)
                    current_trajectory_regs.append(reg_t)
                all_trajectories_for_ensemble.append(torch.stack(current_trajectory_x_hats, dim=1))
                all_regs_for_ensemble.append(torch.sum(torch.stack(current_trajectory_regs)))

            # 2. Ensemble statistics
            ensemble_trajectories = torch.stack(all_trajectories_for_ensemble, dim=0)
            x_hat_sequence = ensemble_trajectories.mean(dim=0)

            # Ensemble variance; 1e-9 only guards the log/division, it is not a noise floor.
            cov_diag_sequence = ensemble_trajectories.var(dim=0) + 1e-9

            regularization_loss = torch.stack(all_regs_for_ensemble).mean()
            target_sequence = x_true_batch[:, 1:, :]

            # 3. Hybrid loss
            # A) MSE term (accuracy)
            mse_loss = functional.mse_loss(x_hat_sequence, target_sequence)

            # B) NLL term (consistency): 0.5 * (log(var) + (target - pred)^2 / var)
            error_sq = (x_hat_sequence - target_sequence) ** 2
            nll_loss = 0.5 * (torch.log(cov_diag_sequence) + error_sq / cov_diag_sequence).mean()

            # C) Total: the MSE term anchors the estimate even if the NLL term
            #    would prefer to inflate the variance.
            loss = nll_loss + (lambda_mse * mse_loss) + regularization_loss * 1.0

            if torch.isnan(loss):
                print("Collapse detected (NaN loss)")
                done = True
                break

            loss.backward()
            if clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
            optimizer.step()
            train_iter_count += 1

            # --- Logging ---
            if train_iter_count % logging_period == 0:
                with torch.no_grad():
                    # Current dropout probabilities (informational only)
                    p1 = torch.sigmoid(model.dnn.concrete_dropout1.p_logit).item()
                    p2 = torch.sigmoid(model.dnn.concrete_dropout2.p_logit).item()
                min_variance = cov_diag_sequence.min().item()

                print(f"--- Iter [{train_iter_count}/{total_train_iter}] ---")
                print(f"    Total Loss: {loss.item():.4f}")
                print(f"    MSE Component: {mse_loss.item():.4f} (Weighted: {(mse_loss.item() * lambda_mse):.4f})")
                print(f"    NLL Component: {nll_loss.item():.4f}")
                print(f"    Min Variance: {min_variance:.2e}")
                print(f"    p1={p1:.3f}, p2={p2:.3f}")

            # --- Validation step ---
            if train_iter_count > 0 and train_iter_count % validation_period == 0:
                # The scheduler monitors the latest *training* loss (as a plain float).
                lr_before = optimizer.param_groups[0]['lr']
                scheduler.step(loss.item())
                lr_after = optimizer.param_groups[0]['lr']
                if lr_after < lr_before:
                    print(f"    Learning rate reduced: {lr_before:.2e} -> {lr_after:.2e}")

                print(f"\n--- Validation at iteration {train_iter_count} ---")
                model.eval()
                val_mse_list = []
                val_nll_list = []
                all_val_x_true_cpu, all_val_x_hat_cpu, all_val_P_hat_cpu = [], [], []

                with torch.no_grad():
                    for x_true_val_batch, y_meas_val_batch in val_loader:
                        val_batch_size, val_seq_len, _ = x_true_val_batch.shape
                        x_true_val_batch = x_true_val_batch.to(device)
                        y_meas_val_batch = y_meas_val_batch.to(device)
                        val_ensemble_trajectories = []
                        for j in range(J_samples):
                            model.reset(batch_size=val_batch_size, initial_state=x_true_val_batch[:, 0, :])
                            val_current_x_hats = []
                            for t in range(1, val_seq_len):
                                y_t_val = y_meas_val_batch[:, t, :]
                                x_filtered_t, _ = model.step(y_t_val)
                                val_current_x_hats.append(x_filtered_t)
                            val_ensemble_trajectories.append(torch.stack(val_current_x_hats, dim=1))

                        # Aggregate the validation ensemble
                        val_ensemble = torch.stack(val_ensemble_trajectories, dim=0)
                        val_preds_seq = val_ensemble.mean(dim=0)

                        val_target_seq = x_true_val_batch[:, 1:, :]
                        val_mse_list.append(functional.mse_loss(val_preds_seq, val_target_seq).item())

                        # Ensemble variance and the matching NLL (same formula as in training)
                        val_covs_diag = val_ensemble.var(dim=0) + 1e-9
                        val_error_sq = (val_preds_seq - val_target_seq) ** 2
                        val_nll_list.append(
                            (0.5 * (torch.log(val_covs_diag) + val_error_sq / val_covs_diag)).mean().item()
                        )

                        # Diagonal covariance matrices built in one vectorised call.
                        # Cross-covariances are ignored (diagonal approximation).
                        val_covs_full = torch.diag_embed(val_covs_diag)

                        # Prepend the known initial state and P0 so the sequences
                        # line up with the ground truth for the ANEES computation.
                        initial_state_val = x_true_val_batch[:, 0, :].unsqueeze(1)
                        full_x_hat = torch.cat([initial_state_val, val_preds_seq], dim=1)

                        P0 = model.system_model.P0.to(val_covs_full)
                        P0 = P0.unsqueeze(0).repeat(val_batch_size, 1, 1).unsqueeze(1)
                        full_P_hat = torch.cat([P0, val_covs_full], dim=1)

                        all_val_x_true_cpu.append(x_true_val_batch.cpu())
                        all_val_x_hat_cpu.append(full_x_hat.cpu())
                        all_val_P_hat_cpu.append(full_P_hat.cpu())

                avg_val_mse = np.mean(val_mse_list)
                avg_val_nll = np.mean(val_nll_list)
                final_x_true_list = torch.cat(all_val_x_true_cpu, dim=0)
                final_x_hat_list = torch.cat(all_val_x_hat_cpu, dim=0)
                final_P_hat_list = torch.cat(all_val_P_hat_cpu, dim=0)

                # ANEES over the whole validation set
                avg_val_anees = trainer.calculate_anees_vectorized(final_x_true_list, final_x_hat_list, final_P_hat_list)

                print(f"  Average MSE: {avg_val_mse:.4f}, Average ANEES: {avg_val_anees:.4f}")

                # Checkpoint selection: lowest positive, non-NaN validation ANEES.
                if not np.isnan(avg_val_anees) and avg_val_anees < best_val_anees and avg_val_anees > 0:
                    print("  >>> New best VALIDATION ANEES! Saving model. <<<")
                    best_val_anees = avg_val_anees
                    best_iter_count = train_iter_count
                    score_at_best['val_mse'] = avg_val_mse
                    score_at_best['val_nll'] = avg_val_nll
                    best_model_state = deepcopy(model.state_dict())
                print("-" * 50)
                model.train()

        # An empty loader, or one whose every batch was skipped, would otherwise
        # spin forever because no iteration is ever counted.
        if not done and train_iter_count == iters_before_pass:
            print("No usable training batch in a full pass over train_loader; stopping.")
            break

    print("\nTraining completed.")
    if best_model_state is not None:
        print(f"Loading best model from iteration {best_iter_count} with ANEES {best_val_anees:.4f}")
        model.load_state_dict(best_model_state)
    else:
        print("No best model was saved; returning last state.")

    return {
        "best_val_anees": best_val_anees,
        "best_val_nll": score_at_best['val_nll'],
        "best_val_mse": score_at_best['val_mse'],
        "best_iter": best_iter_count,
        "final_model": model
    }

In [13]:
from state_NN_models import TAN
# --- A) Network configuration ---
print("=== INICIALIZACE NOV√âHO MODELU ===")
state_knet2 = TAN.StateBayesianKalmanNetTAN(
    system_model=system_model,
    device=device,
    num_gru_layers=1,
    hidden_size_multiplier=8,
    output_layer_multiplier=4,
    init_min_dropout=0.2,
    init_max_dropout=0.4,
).to(device)

# Loaders cached under phase id 2; print the number of training batches.
train_loader, val_loader = datasets_cache[2]
print(len(train_loader))

# Hybrid NLL + MSE training run. The returned summary dict is not kept;
# the model object itself is updated by the call.
train_BayesianKalmanNet_Hybrid(
    model=state_knet2,
    device=device,
    train_loader=train_loader,
    val_loader=val_loader,
    # optimisation
    total_train_iter=2000,
    learning_rate=1e-5,
    weight_decay_=1e-3,
    clip_grad=0.25,
    warmup_iterations=0,
    # ensemble size (the original note suggests 10+ for a final run)
    J_samples=5,
    # reporting cadence, in optimiser steps
    validation_period=10,
    logging_period=1,
    # weight of the MSE anchor in the loss
    lambda_mse=100.0,
)

print("üéâ Tr√©nink dokonƒçen.")

=== INICIALIZACE NOV√âHO MODELU ===
INFO: Aplikuji 'Start Zero' inicializaci pro Kalman Gain.
DEBUG: V√Ωstupn√≠ vrstva vynulov√°na (Soft Start).
16




üöÄ START Hybrid Training: Loss = NLL + 100.0 * MSE
--- Iter [1/2000] ---
    Total Loss: 20019800064.0000
    MSE Component: 244559.2656 (Weighted: 24455926.5625)
    NLL Component: 19995344896.0000
    Min Variance: 1.00e-09
    p1=0.371, p2=0.212
--- Iter [2/2000] ---
    Total Loss: 13119876096.0000
    MSE Component: 251342.7812 (Weighted: 25134278.1250)
    NLL Component: 13094742016.0000
    Min Variance: 1.00e-09
    p1=0.371, p2=0.212
--- Iter [3/2000] ---
    Total Loss: 14000433152.0000
    MSE Component: 245195.4688 (Weighted: 24519546.8750)
    NLL Component: 13975913472.0000
    Min Variance: 1.00e-09
    p1=0.371, p2=0.212
--- Iter [4/2000] ---
    Total Loss: 16730799104.0000
    MSE Component: 233098.5000 (Weighted: 23309850.0000)
    NLL Component: 16707488768.0000
    Min Variance: 1.00e-09
    p1=0.371, p2=0.212
--- Iter [5/2000] ---
    Total Loss: 12090051584.0000
    MSE Component: 242148.6406 (Weighted: 24214864.0625)
    NLL Component: 12065837056.0000
    Min

KeyboardInterrupt: 

In [None]:
# Manual switch: the body never runs as written. Change the condition to True
# to write the current model weights to disk.
if False:
    save_path = 'knet_curriculum_model_1GRU_relativne_ok_vysledky.pth'
    weights = state_knet2.state_dict()
    torch.save(weights, save_path)
    print(f"Model saved to '{save_path}'.")

# Test na syntetické trajektorii

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import Filters
import os
from tqdm import tqdm

# === CONFIGURATION ===
TEST_DATA_PATH = './generated_data_synthetic_controlled/test_set/test.pt'
PLOT_PER_ITERATION = True  # draw one XY plot per evaluated trajectory
MAX_TEST_SAMPLES = 1      # upper bound on how many test trajectories are evaluated

print(f"=== VYHODNOCEN√ç NA TESTOVAC√ç SADƒö ===")
print(f"Naƒç√≠t√°m data z: {TEST_DATA_PATH}")

# 1. Load the test set
if not os.path.exists(TEST_DATA_PATH):
    raise FileNotFoundError(f"Soubor {TEST_DATA_PATH} neexistuje! Spus≈•te generov√°n√≠ testovac√≠ sady.")

test_data = torch.load(TEST_DATA_PATH, map_location=device)
X_test_all = test_data['x']  # ground truth, indexed as [trajectory, step, state]
Y_test_all = test_data['y']  # measurements, indexed as [trajectory, step, obs]

n_samples = min(X_test_all.shape[0], MAX_TEST_SAMPLES)
print(f"Poƒçet testovac√≠ch trajektori√≠: {n_samples}")
print(f"D√©lka sekvence: {X_test_all.shape[1]}")
print("Modely: KalmanNet vs. UKF vs. PF")

# 2. Result containers
detailed_results = []
agg_mse = {"KNet": [], "UKF": [], "PF": []}
agg_pos = {"KNet": [], "UKF": [], "PF": []}

# Make sure the network is in eval mode
state_knet2.eval()

# --- MAIN LOOP (one iteration per test trajectory) ---
for i in tqdm(range(n_samples), desc="Evaluace"):

    # A) Data for this run
    x_gt_tensor = X_test_all[i]
    y_obs_tensor = Y_test_all[i]

    x_gt = x_gt_tensor.cpu().numpy()
    seq_len = x_gt.shape[0]

    # True starting state, used to initialise every filter
    true_init_state = x_gt_tensor[0]

    # B) Inference: KalmanNet (fed one step at a time with a batch dimension of 1)
    with torch.no_grad():
        initial_state_batch = true_init_state.unsqueeze(0)

        # Reset the network state
        state_knet2.reset(batch_size=1, initial_state=initial_state_batch)

        knet_preds = []
        y_input_batch = y_obs_tensor.unsqueeze(0)

        for t in range(1, seq_len):
            y_t = y_input_batch[:, t, :]
            # step() returns (filtered_state, regularisation), as in the training
            # and validation loops. Only the state is needed here, so it must be
            # unpacked before stacking.
            x_est, _ = state_knet2.step(y_t)
            knet_preds.append(x_est)

        # Assemble the estimate, prepending the initial state
        if len(knet_preds) > 0:
            knet_preds_tensor = torch.stack(knet_preds, dim=1)
            full_knet_est = torch.cat([initial_state_batch.unsqueeze(1), knet_preds_tensor], dim=1)
        else:
            full_knet_est = initial_state_batch.unsqueeze(1)

        # Drop only the batch axis so a length-1 sequence keeps its 2-D shape.
        x_est_knet = full_knet_est.squeeze(0).cpu().numpy()

    # C) Inference: UKF & PF (given the un-batched measurement sequence)

    # UKF
    ukf_ideal = Filters.UnscentedKalmanFilter(system_model)
    ukf_res = ukf_ideal.process_sequence(
        y_seq=y_obs_tensor,
        Ex0=true_init_state,
        P0=system_model.P0
    )
    x_est_ukf = ukf_res['x_filtered'].cpu().numpy()

    # PF with 10000 particles
    pf = Filters.ParticleFilter(system_model, num_particles=10000)
    pf_res = pf.process_sequence(
        y_seq=y_obs_tensor,
        Ex0=true_init_state,
        P0=system_model.P0
    )
    x_est_pf = pf_res['x_filtered'].cpu().numpy()

    # D) Errors
    # Trim everything to the shortest sequence, including the PF estimate.
    min_len = min(len(x_gt), len(x_est_knet), len(x_est_ukf), len(x_est_pf))

    # KNet
    diff_knet = x_est_knet[:min_len] - x_gt[:min_len]
    mse_knet = np.mean(np.sum(diff_knet[:, :2]**2, axis=1))  # XY position error only
    pos_err_knet = np.mean(np.sqrt(diff_knet[:, 0]**2 + diff_knet[:, 1]**2))

    # UKF
    diff_ukf = x_est_ukf[:min_len] - x_gt[:min_len]
    mse_ukf = np.mean(np.sum(diff_ukf[:, :2]**2, axis=1))
    pos_err_ukf = np.mean(np.sqrt(diff_ukf[:, 0]**2 + diff_ukf[:, 1]**2))

    # PF
    diff_pf = x_est_pf[:min_len] - x_gt[:min_len]
    mse_pf = np.mean(np.sum(diff_pf[:, :2]**2, axis=1))
    pos_err_pf = np.mean(np.sqrt(diff_pf[:, 0]**2 + diff_pf[:, 1]**2))

    # Store
    agg_mse["KNet"].append(mse_knet)
    agg_pos["KNet"].append(pos_err_knet)
    agg_mse["UKF"].append(mse_ukf)
    agg_pos["UKF"].append(pos_err_ukf)
    agg_mse["PF"].append(mse_pf)
    agg_pos["PF"].append(pos_err_pf)

    detailed_results.append({
        "Run_ID": i + 1,
        "KNet_MSE": mse_knet,
        "UKF_MSE": mse_ukf,
        "PF_MSE": mse_pf,
        "KNet_PosErr": pos_err_knet,
        "UKF_PosErr": pos_err_ukf,
        "PF_PosErr": pos_err_pf
    })

    # E) Plot the XY trajectory of each estimator against the ground truth
    if PLOT_PER_ITERATION:
        fig = plt.figure(figsize=(12, 6))
        plt.plot(x_gt[:, 0], x_gt[:, 1], 'k-', linewidth=3, alpha=0.3, label='Ground Truth')
        plt.plot(x_est_knet[:, 0], x_est_knet[:, 1], 'g-', linewidth=2, label=f'KalmanNet (Err: {pos_err_knet:.1f}m)')
        plt.plot(x_est_ukf[:, 0], x_est_ukf[:, 1], 'b--', linewidth=1, label=f'UKF (Err: {pos_err_ukf:.1f}m)')
        plt.plot(x_est_pf[:, 0], x_est_pf[:, 1], 'r:', linewidth=1, alpha=0.6, label=f'PF (Err: {pos_err_pf:.1f}m)')

        plt.title(f"Test Trajectory {i+1} (Length {seq_len})")
        plt.xlabel("X [m]")
        plt.ylabel("Y [m]")
        plt.legend()
        plt.axis('equal')
        plt.grid(True)
        plt.show()
        # Release the figure so a long evaluation does not accumulate open figures.
        plt.close(fig)

# --- Results output: per-trajectory table ---
df_results = pd.DataFrame(detailed_results)
print(f"\n{'=' * 80}")
print("DETAILN√ç V√ùSLEDKY PO JEDNOTLIV√ùCH TRAJEKTORI√çCH")
print('=' * 80)
# Thousands separator and two decimals for every float pandas prints from here on.
pd.options.display.float_format = '{:,.2f}'.format
print(df_results.loc[:, [
    "Run_ID",
    "KNet_MSE", "UKF_MSE", "PF_MSE",
    "KNet_PosErr", "UKF_PosErr", "PF_PosErr",
]])

# Header of the aggregate section that follows.
print(f"\n{'=' * 80}")
print(f"SOUHRNN√Å STATISTIKA ({n_samples} trajektori√≠)")
print('=' * 80)

def get_stats(key):
    """Summarise the per-trajectory errors collected for one model.

    Args:
        key: Model name used as the key in the module-level ``agg_mse`` and
            ``agg_pos`` dictionaries.

    Returns:
        Tuple ``(mean MSE, std MSE, mean position error, std position error)``.
    """
    mse_values = agg_mse[key]
    pos_values = agg_pos[key]
    return (
        np.mean(mse_values),
        np.std(mse_values),
        np.mean(pos_values),
        np.std(pos_values),
    )

# Aggregate (mean MSE, std MSE, mean position error, std position error) per model.
knet_stats = get_stats("KNet")
ukf_stats = get_stats("UKF")
pf_stats = get_stats("PF")

print(f"{'Model':<15} | {'MSE (Mean ¬± Std)':<25} | {'Pos Error (Mean ¬± Std)':<25}")
print("-" * 75)
print(f"{'KalmanNet':<15} | {knet_stats[0]:.1f} ¬± {knet_stats[1]:.1f} | {knet_stats[2]:.2f} ¬± {knet_stats[3]:.2f} m")
print(f"{'UKF':<15} | {ukf_stats[0]:.1f} ¬± {ukf_stats[1]:.1f} | {ukf_stats[2]:.2f} ¬± {ukf_stats[3]:.2f} m")
print(f"{'PF':<15} | {pf_stats[0]:.1f} ¬± {pf_stats[1]:.1f} | {pf_stats[2]:.2f} ¬± {pf_stats[3]:.2f} m")
print("="*80)

# Distribution of the per-trajectory mean position error for each model.
plt.figure(figsize=(10, 6))
plt.boxplot([agg_pos["KNet"], agg_pos["UKF"], agg_pos["PF"]], patch_artist=True)
# Tick labels are set explicitly instead of through boxplot's `labels=` keyword,
# which is deprecated in recent Matplotlib releases. Boxes sit at positions 1..3.
plt.xticks([1, 2, 3], ['KalmanNet', 'UKF', 'PF'])
plt.title(f"Position Error Distribution ({n_samples} test trajectories)")
plt.ylabel("Avg Position Error [m]")
plt.grid(True, axis='y')
plt.show()