In [1]:
from pathlib import Path
from scipy.io import loadmat
import sys
import os

# Locate data.mat relative to the notebook's working directory.
# Candidates are probed in order; the first existing one wins.
current_path = Path.cwd()
possible_data_paths = [
    current_path / 'data' / 'data.mat',
    current_path.parent / 'data' / 'data.mat',
    current_path.parent.parent / 'data' / 'data.mat',
    # NOTE(review): machine-specific absolute fallback. Kept for backward
    # compatibility, but it can only resolve on the original author's machine.
    Path('/home/luky/skola/KalmanNet-for-state-estimation/data/data.mat'),
]

# None when no candidate exists.
dataset_path = next((candidate for candidate in possible_data_paths if candidate.exists()), None)

# Add project root to sys.path (2 levels up from debug/test) so that the
# project-local packages imported by later cells can be resolved.
notebook_dir = os.getcwd()
project_root = os.path.abspath(os.path.join(notebook_dir, '..', '..'))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

# Fail here with an explicit message instead of handing a path that is already
# known to be missing to loadmat (which would only produce an opaque error).
if dataset_path is None:
    searched = "\n  ".join(str(candidate) for candidate in possible_data_paths)
    raise FileNotFoundError(f"data.mat not found. Searched:\n  {searched}")

print(f"Dataset path: {dataset_path}")
print(f"Project root added: {project_root}")

mat_data = loadmat(dataset_path)
print(mat_data.keys())


Dataset path: /home/luky/skola/KalmanNet-main/data/data.mat
Project root added: /home/luky/skola/KalmanNet-main
dict_keys(['__header__', '__version__', '__globals__', 'hB', 'souradniceGNSS', 'souradniceX', 'souradniceY', 'souradniceZ'])


In [2]:
import torch
import matplotlib.pyplot as plt
from utils import trainer
from utils import utils
from Systems import DynamicSystem
import Filters
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from scipy.io import loadmat
from scipy.interpolate import RegularGridInterpolator
import random

# Fix the seed of every random number generator used in this notebook
# (Python, NumPy, PyTorch CPU and, when present, all CUDA devices).
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"device: {device}")

device: cpu


  return torch._C._cuda_getDeviceCount() > 0


In [3]:
# Read the MATLAB file located by the first cell.
mat_data = loadmat(dataset_path)

# Terrain grid (X / Y coordinate meshes and elevation Z) plus the GNSS track.
souradniceX_mapa, souradniceY_mapa, souradniceZ_mapa = (
    mat_data[key] for key in ('souradniceX', 'souradniceY', 'souradniceZ')
)
souradniceGNSS = mat_data['souradniceGNSS']

# 1-D axes: first row of the X mesh and first column of the Y mesh.
x_axis_unique = souradniceX_mapa[0, :]
y_axis_unique = souradniceY_mapa[:, 0]

print(
    f"Dimensions of 1D X axis: {x_axis_unique.shape}",
    f"Dimensions of 1D Y axis: {y_axis_unique.shape}",
    f"Dimensions of 2D elevation data Z: {souradniceZ_mapa.shape}",
    sep="\n",
)

# Interpolator over the elevation grid, indexed as (y, x).
# Queries outside the grid return NaN instead of raising.
terMap_interpolator = RegularGridInterpolator(
    points=(y_axis_unique, x_axis_unique),
    values=souradniceZ_mapa,
    bounds_error=False,
    fill_value=np.nan,
)

def terMap(px, py):
    """Return the interpolated terrain elevation at map coordinates (px, py).

    The inputs are stacked as (y, x) columns, which is the axis order the
    module-level ``terMap_interpolator`` was built with, and passed to it.
    """
    query_yx = np.column_stack((py, px))
    elevation = terMap_interpolator(query_yx)
    return elevation

Dimensions of 1D X axis: (2500,)
Dimensions of 1D Y axis: (2500,)
Dimensions of 2D elevation data Z: (2500, 2500)


In [4]:
import torch
from Systems import DynamicSystemTAN

# Problem dimensions and discretisation constants.
state_dim, obs_dim = 4, 3
dT, q = 1, 1

# State-transition matrix: each position component advances by dT times the
# matching velocity component; velocities are carried over unchanged.
# NOTE: rebinding `F` here shadows the `torch.nn.functional as F` alias that an
# earlier cell imports; later cells re-import that alias.
F = torch.eye(state_dim)
F[0, 2] = dT
F[1, 3] = dT

# Process-noise covariance, scaled by the intensity q.
Q = torch.tensor([
    [dT**3 / 3, 0.0, dT**2 / 2, 0.0],
    [0.0, dT**3 / 3, 0.0, dT**2 / 2],
    [dT**2 / 2, 0.0, dT, 0.0],
    [0.0, dT**2 / 2, 0.0, dT],
]) * q

# Measurement-noise covariance (diagonal: 3^2, 1^2, 1^2).
R = torch.diag(torch.tensor([3.0**2, 1.0**2, 1.0**2]))

# Difference between columns 1 and 0 of the first two GNSS rows. Computed for
# reference only: the filter is initialised with zero velocity below.
initial_velocity_np = souradniceGNSS[:2, 1] - souradniceGNSS[:2, 0]
initial_velocity = torch.from_numpy(np.array([0, 0]))

# Initial state mean: first GNSS column (first two rows) followed by the velocity.
initial_position = torch.from_numpy(souradniceGNSS[:2, 0])
x_0 = torch.cat((initial_position, initial_velocity)).float()
print(x_0)

# Initial state covariance (diagonal: 25 for positions, 0.5 for velocities).
P_0 = torch.diag(torch.tensor([25.0, 25.0, 0.5, 0.5]))
import torch.nn.functional as func

def h_nl_differentiable(x: torch.Tensor, map_tensor, x_min, x_max, y_min, y_max) -> torch.Tensor:
    """Differentiable measurement function: terrain height plus rotated velocity.

    Args:
        x: State batch of shape (batch, >=4). Columns 0-1 are read as the
            position (px, py) and columns 2-3 as the velocity (vx, vy).
        map_tensor: Elevation map that ``expand`` can broadcast to
            (batch, C, H, W). It must yield exactly one value per sample,
            because the sampled result is viewed as (batch,).
        x_min, x_max, y_min, y_max: Map extent used to scale positions to the
            [-1, 1] range expected by ``grid_sample``.

    Returns:
        Tensor of shape (batch, 3) with columns [terrain height, vx_b, vy_b].
    """
    batch_size = x.shape[0]

    px = x[:, 0]
    py = x[:, 1]

    # Scale map coordinates to [-1, 1]. With align_corners=True the values -1
    # and +1 address the first and last map samples along each axis.
    px_norm = 2.0 * (px - x_min) / (x_max - x_min) - 1.0
    py_norm = 2.0 * (py - y_min) / (y_max - y_min) - 1.0

    # One query point per sample, ordered (x, y) as grid_sample requires.
    sampling_grid = torch.stack((px_norm, py_norm), dim=1).view(batch_size, 1, 1, 2)

    # Bilinear lookup; positions outside the map are clamped to the border value.
    vyska_terenu_batch = func.grid_sample(
        map_tensor.expand(batch_size, -1, -1, -1),
        sampling_grid,
        mode='bilinear',
        padding_mode='border',
        align_corners=True
    )

    vyska_terenu = vyska_terenu_batch.view(batch_size)

    eps = 1e-12
    vx_w, vy_w = x[:, 2], x[:, 3]
    # Clamp the squared speed BEFORE taking the square root. sqrt has an
    # unbounded derivative at 0, so clamping after the root still produces NaN
    # gradients for a zero-velocity state (0 / 0 in the sqrt backward pass).
    # The forward value is mathematically unchanged: max(|v|, eps).
    norm_v_w = torch.sqrt((vx_w**2 + vy_w**2).clamp(min=eps**2))
    cos_psi = vx_w / norm_v_w
    sin_psi = vy_w / norm_v_w

    # Applies the rotation matrix [[cos, -sin], [sin, cos]] to the velocity.
    # NOTE(review): this is a rotation by +psi; confirm it matches the frame
    # convention used when the measurements were generated.
    vx_b = cos_psi * vx_w - sin_psi * vy_w
    vy_b = sin_psi * vx_w + cos_psi * vy_w

    result = torch.stack([vyska_terenu, vx_b, vy_b], dim=1)

    return result

# Map axes and their extent (same slices as in the map-loading cell).
x_axis_unique = souradniceX_mapa[0, :]
y_axis_unique = souradniceY_mapa[:, 0]
x_min, x_max = x_axis_unique.min(), x_axis_unique.max()
y_min, y_max = y_axis_unique.min(), y_axis_unique.max()

# Elevation grid as a float tensor with two leading singleton axes, on `device`.
terMap_tensor = torch.from_numpy(souradniceZ_mapa).float()[None, None].to(device)


def h_wrapper(x):
    """Measurement function bound to the loaded terrain map and its extent."""
    return h_nl_differentiable(
        x,
        map_tensor=terMap_tensor,
        x_min=x_min,
        x_max=x_max,
        y_min=y_min,
        y_max=y_max,
    )


# Assemble the system model from the matrices and initial state defined above.
system_model = DynamicSystemTAN(
    state_dim=state_dim, obs_dim=obs_dim,
    Q=Q.float(), R=R.float(),
    Ex0=x_0.float(), P0=P_0.float(),
    F=F.float(), h=h_wrapper,
    x_axis_unique=x_axis_unique, y_axis_unique=y_axis_unique,
    device=device,
)

tensor([1487547.1250, 6395520.5000,       0.0000,       0.0000])
INFO: DynamicSystemTAN inicializov√°n s hranicemi mapy:
  X: [1476611.42, 1489541.47]
  Y: [6384032.63, 6400441.34]


In [5]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from utils import utils
import torch
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import os
import random
from copy import deepcopy
from state_NN_models import TAN
from utils import trainer 

# Re-seed the Python, NumPy and PyTorch random number generators.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)


In [6]:
import torch
from torch.utils.data import TensorDataset, DataLoader
import os
from utils import trainer # P≈ôedpokl√°d√°m, ≈æe toto m√°≈°

# === 1. SIMPLIFIED DATA MANAGER (NO NORMALIZATION) ===
class NavigationDataManager:
    """Thin loader for pre-generated trajectory datasets stored as ``.pt`` files.

    Only holds the dataset root directory; it keeps no statistics and applies
    no normalization.
    """

    def __init__(self, data_dir):
        """Remember the root directory that contains the ``len_<N>`` folders."""
        self.data_dir = data_dir

    def get_dataloader(self, seq_len, split='train', shuffle=True, batch_size=32):
        """Build a DataLoader over ``<data_dir>/len_<seq_len>/<split>.pt``.

        Args:
            seq_len: Sequence length used in the folder name ``len_<seq_len>``.
            split: File stem inside that folder (e.g. ``'train'``, ``'val'``).
            shuffle: Forwarded to ``DataLoader``.
            batch_size: Forwarded to ``DataLoader``.

        Returns:
            DataLoader over a ``TensorDataset`` of the stored ``'x'`` and
            ``'y'`` tensors, yielding ``(x, y)`` batches.

        Raises:
            FileNotFoundError: If the dataset file does not exist.
        """
        # Path layout, e.g. ./generated_data/len_100/train.pt
        path = os.path.join(self.data_dir, f'len_{seq_len}', f'{split}.pt')

        if not os.path.exists(path):
            raise FileNotFoundError(f"‚ùå Dataset nenalezen: {path}")

        # Always materialise the tensors on the CPU: a file saved from a GPU
        # would otherwise fail to load on a CPU-only machine. The training
        # loops move each batch to the target device themselves.
        data = torch.load(path, map_location='cpu')
        x = data['x']  # state [Batch, Seq, DimX]
        y = data['y']  # measurement [Batch, Seq, DimY] - raw data

        return DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=shuffle)

# === 2. CURRICULUM CONFIGURATION ===
DATA_DIR = './generated_data_synthetic_controlled'

# The manager is just a wrapper that loads dataset files from DATA_DIR.
data_manager = NavigationDataManager(data_dir=DATA_DIR)

# Phase definitions (this is where the training progression is controlled).
# One row per phase: (phase_id, seq_len, epochs, lr, batch_size)
#   phase 1 - warm-up on short sequences
#   phase 2 - stabilisation on medium-length sequences
#   phase 3 - full-length sequences (smaller batch because of GPU memory)
curriculum_schedule = [
    dict(zip(('phase_id', 'seq_len', 'epochs', 'lr', 'batch_size'), phase_spec))
    for phase_spec in (
        (1, 10, 500, 1e-3, 256),
        (2, 100, 200, 1e-4, 256),
        (3, 300, 200, 1e-5, 128),
    )
]

# === 3. LOADING INTO MEMORY (CACHING) ===
# Builds one (train_loader, val_loader) pair per curriculum phase and stores it
# in `datasets_cache`, keyed by the phase id. A phase whose dataset file is
# missing is reported and skipped, so it will be absent from the cache.
print("\n=== NAƒå√çT√ÅN√ç RAW DAT Z DISKU (BEZ EXT. NORMALIZACE) ===")
datasets_cache = {} 

for phase in curriculum_schedule:
    seq_len = phase['seq_len']
    bs = phase['batch_size']
    
    print(f"üì• Naƒç√≠t√°m F√°zi {phase['phase_id']}: Seq={seq_len} | Batch={bs} ...")
    
    try:
        # Use the data manager
        train_loader = data_manager.get_dataloader(seq_len=seq_len, split='train', shuffle=True, batch_size=bs)
        val_loader = data_manager.get_dataloader(seq_len=seq_len, split='val', shuffle=False, batch_size=bs)
        
        # Store in the cache
        datasets_cache[phase['phase_id']] = (train_loader, val_loader)
        
        # Quick sanity check: one training batch is drawn for every phase,
        # but only phase 1 prints a sample from it.
        x_ex, y_ex = next(iter(train_loader))
        if phase['phase_id'] == 1:
            print(f"   üîé Uk√°zka RAW dat (y): {y_ex[0, 0, :].tolist()}") 
            # You should see large numbers (e.g. 250.0) and small ones (0.2), not ~0.0
        
    except FileNotFoundError as e:
        print(f"   ‚ö†Ô∏è CHYBA: {e}")
        # raise e # Uncomment if you want this to fail on a missing dataset

print("\n‚úÖ Data p≈ôipravena. Normalizaci ≈ôe≈°√≠ model.")


=== NAƒå√çT√ÅN√ç RAW DAT Z DISKU (BEZ EXT. NORMALIZACE) ===
üì• Naƒç√≠t√°m F√°zi 1: Seq=10 | Batch=256 ...
   üîé Uk√°zka RAW dat (y): [323.7707824707031, -13.519903182983398, -29.721908569335938]
üì• Naƒç√≠t√°m F√°zi 2: Seq=100 | Batch=256 ...
üì• Naƒç√≠t√°m F√°zi 3: Seq=300 | Batch=128 ...

‚úÖ Data p≈ôipravena. Normalizaci ≈ôe≈°√≠ model.


In [7]:
import torch
import torch.nn.functional as F
import numpy as np
from copy import deepcopy

def gaussian_nll_safe(target, preds, var, min_var=1e-6, max_error_sq=100.0):
    """Numerically safe Gaussian negative log-likelihood, averaged over all elements.

    ``min_var`` is added to the variance, and the squared error is capped at
    ``max_error_sq`` before it is divided by that variance. Capping the error
    (the numerator) keeps the loss bounded while the gradient with respect to
    the variance (the denominator) is preserved.
    """
    variance = var + min_var
    bounded_sq_error = (preds - target).pow(2).clamp(max=max_error_sq)
    per_element_nll = 0.5 * (variance.log() + bounded_sq_error / variance)
    return per_element_nll.mean()

def train_BayesianKalmanNet_TwoPhase(
    model, train_loader, val_loader, device,
    total_train_iter, learning_rate, clip_grad,
    J_samples, validation_period, logging_period,
    mse_warmup_iters=0,  # number of initial iterations trained on MSE only
    weight_decay_=1e-5
):
    """Train a Bayesian KalmanNet with an MSE warm-up followed by NLL optimisation.

    Every training step runs ``J_samples`` stochastic roll-outs of ``model``
    over the batch. The ensemble mean is the state estimate and the ensemble
    variance is its diagonal covariance. During the first ``mse_warmup_iters``
    iterations the loss is ``MSE + regularisation``; afterwards it is
    ``NLL + regularisation``. The model with the lowest validation MSE is
    restored before returning.

    Args:
        model: Object providing ``reset(batch_size=..., initial_state=...)``,
            ``step(y) -> (estimate, regularisation)``,
            ``dnn.concrete_dropout1.p_logit`` / ``dnn.concrete_dropout2.p_logit``
            (read for logging only) and the usual ``nn.Module`` API.
        train_loader, val_loader: Iterables of ``(x_true, y_meas)`` batches;
            ``x_true`` is unpacked as ``(batch, seq, dim)``.
        device: Device the batches are moved to.
        total_train_iter: Number of optimiser steps to perform.
        learning_rate, weight_decay_: AdamW hyper-parameters.
        clip_grad: Gradient-norm clipping threshold; ``<= 0`` disables clipping.
        J_samples: Roll-outs per batch. The ensemble variance is the unbiased
            sample variance, so at least 2 are needed for a finite value.
        validation_period, logging_period: Periods (in iterations) of the
            validation pass and of the console log.
        mse_warmup_iters: Iterations of the MSE-only warm-up phase.

    Returns:
        ``{"final_model": model}``, with the best validated weights loaded if a
        validation pass improved on the initial metric.

    Raises:
        ValueError: If a roll-out produces NaN during training.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay_)

    # The functional API is referenced through its full module path because the
    # notebook also binds the name `F` to the state-transition matrix.
    mse_fn = torch.nn.functional.mse_loss

    best_val_metric = float('inf')  # validation MSE of the best model so far
    best_model_state = None
    best_iter_count = 0
    train_iter_count = 0
    done = False

    print(f"üöÄ START Two-Phase Training")
    print(f"    Phase 1: MSE Warmup (0 - {mse_warmup_iters} iters)")
    print(f"    Phase 2: NLL Optimization ({mse_warmup_iters} - {total_train_iter} iters)")

    while not done:
        model.train()
        batches_used_this_epoch = 0

        for x_true_batch, y_meas_batch in train_loader:
            if train_iter_count >= total_train_iter:
                done = True
                break

            # Skip batches whose ground truth contains NaN.
            if torch.isnan(x_true_batch).any():
                print(f"!!! SKIP BATCH iter {train_iter_count}: NaN found in x_true !!!")
                continue

            batches_used_this_epoch += 1

            x_true_batch = x_true_batch.to(device)
            y_meas_batch = y_meas_batch.to(device)
            batch_size, seq_len, _ = x_true_batch.shape

            # --- Training step ---
            optimizer.zero_grad()

            all_trajectories_for_ensemble = []
            all_regs_for_ensemble = []

            # 1. Ensemble forward pass: J independent roll-outs, each started
            #    from the true initial state.
            for j in range(J_samples):
                model.reset(batch_size=batch_size, initial_state=x_true_batch[:, 0, :])
                current_trajectory_x_hats = []
                current_trajectory_regs = []

                for t in range(1, seq_len):
                    y_t = y_meas_batch[:, t, :]
                    x_filtered_t, reg_t = model.step(y_t)

                    if torch.isnan(x_filtered_t).any():
                        raise ValueError(f"NaN in x_filtered_t at sample {j}, step {t}")

                    current_trajectory_x_hats.append(x_filtered_t)
                    current_trajectory_regs.append(reg_t)

                all_trajectories_for_ensemble.append(torch.stack(current_trajectory_x_hats, dim=1))
                all_regs_for_ensemble.append(torch.sum(torch.stack(current_trajectory_regs)))

            # 2. Ensemble statistics
            ensemble_trajectories = torch.stack(all_trajectories_for_ensemble, dim=0)
            x_hat_sequence = ensemble_trajectories.mean(dim=0)

            # Variance across the roll-outs (small constant avoids exact zeros).
            cov_diag_sequence = ensemble_trajectories.var(dim=0) + 1e-9

            # Regularisation normalised by the sequence length.
            regularization_loss = torch.stack(all_regs_for_ensemble).mean() / seq_len

            target_sequence = x_true_batch[:, 1:, :]

            # --- 3. Loss (two phases) ---
            # Both terms are always evaluated so that they can be logged.
            mse_loss = mse_fn(x_hat_sequence, target_sequence)
            nll_loss = gaussian_nll_safe(
                target=target_sequence,
                preds=x_hat_sequence,
                var=cov_diag_sequence,
                min_var=1e-5,
                max_error_sq=100.0
            )

            if train_iter_count < mse_warmup_iters:
                # PHASE 1: warm-up on MSE only (NLL is ignored).
                loss = mse_loss + regularization_loss
                loss_mode = "MSE_WARMUP"
            else:
                # PHASE 2: NLL optimisation; MSE is only logged.
                loss = nll_loss + regularization_loss
                loss_mode = "NLL_OPTIM"

            if torch.isnan(loss):
                print("Collapse detected (NaN loss)")
                done = True
                break

            loss.backward()

            # --- Gradient diagnostics ---
            # Gathered before clipping, on exactly the iteration that is logged
            # below, so the printed norms belong to the printed loss values.
            log_this_iter = (train_iter_count + 1) % logging_period == 0
            if log_this_iter:
                total_norm = 0.0
                max_grad = 0.0
                nan_grad_detected = False
                for param in model.parameters():
                    if param.grad is None:
                        continue
                    grad = param.grad.detach()
                    total_norm += grad.norm(2).item() ** 2
                    max_grad = max(max_grad, grad.abs().max().item())
                    if torch.isnan(grad).any():
                        nan_grad_detected = True
                total_norm = total_norm ** 0.5

                if nan_grad_detected:
                    print(f"!!! WARNING: NaN gradient detected at iter {train_iter_count} !!!")

            if clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

            optimizer.step()
            train_iter_count += 1

            # --- Logging ---
            if log_this_iter:
                with torch.no_grad():
                    # Variance statistics
                    min_variance = cov_diag_sequence.min().item()
                    max_variance = cov_diag_sequence.max().item()
                    mean_variance = cov_diag_sequence.mean().item()

                    # Learned dropout probabilities
                    p1 = torch.sigmoid(model.dnn.concrete_dropout1.p_logit).item()
                    p2 = torch.sigmoid(model.dnn.concrete_dropout2.p_logit).item()

                    # Mean absolute (L1) error
                    mean_error = (x_hat_sequence - target_sequence).abs().mean().item()

                print(f"--- Iter [{train_iter_count}/{total_train_iter}] ({loss_mode}) ---")
                print(f"    Total Loss:     {loss.item():.4f}")
                print(f"    MSE (Metric):   {mse_loss.item():.4f}")
                print(f"    NLL (Metric):   {nll_loss.item():.4f}")
                print(f"    Reg Loss:       {regularization_loss.item():.6f}")
                print(f"    Var Stats:      Min={min_variance:.2e}, Max={max_variance:.2e}, Mean={mean_variance:.2e}")
                print(f"    Mean Error L1:  {mean_error:.4f} m")
                print(f"    Grad Norm:      {total_norm:.4f} (Max: {max_grad:.4f})")
                print(f"    Dropout:        p1={p1:.4f}, p2={p2:.4f}")

            # --- Validation ---
            if train_iter_count > 0 and train_iter_count % validation_period == 0:
                print(f"\n--- Validation at iteration {train_iter_count} ---")
                # NOTE(review): validation runs in eval mode; whether the model
                # still samples dropout there depends on its implementation.
                model.eval()
                val_mse_list = []

                # Collected for ANEES
                all_val_x_true, all_val_x_hat, all_val_P_hat = [], [], []

                with torch.no_grad():
                    for x_true_val, y_meas_val in val_loader:
                        v_bs, v_seq, _ = x_true_val.shape
                        x_true_val = x_true_val.to(device)
                        y_meas_val = y_meas_val.to(device)

                        val_ensemble_trajs = []

                        for j in range(J_samples):
                            model.reset(batch_size=v_bs, initial_state=x_true_val[:, 0, :])
                            v_x_hats = []
                            for t in range(1, v_seq):
                                est, _ = model.step(y_meas_val[:, t, :])
                                v_x_hats.append(est)
                            val_ensemble_trajs.append(torch.stack(v_x_hats, dim=1))

                        val_ens_stack = torch.stack(val_ensemble_trajs, dim=0)  # [J, B, T-1, D]
                        val_mean = val_ens_stack.mean(dim=0)
                        val_var_diag = val_ens_stack.var(dim=0) + 1e-9

                        val_mse_list.append(mse_fn(val_mean, x_true_val[:, 1:, :]).item())

                        # --- Data for ANEES ---
                        # 1. States: prepend the (known) start point so the
                        #    estimate covers the whole sequence.
                        full_x_hat = torch.cat([x_true_val[:, :1, :], val_mean], dim=1)

                        # 2. Covariances: full matrices built from the diagonal
                        #    variance; the size follows the estimate dimension
                        #    instead of a hard-coded 4.
                        val_covs_full = torch.diag_embed(val_var_diag)  # [B, T-1, D, D]
                        est_dim = val_var_diag.shape[-1]

                        # P0: small uncertainty for the known initial state.
                        P0 = 1e-6 * torch.eye(
                            est_dim, dtype=val_covs_full.dtype, device=val_covs_full.device
                        ).expand(v_bs, 1, est_dim, est_dim)
                        full_P_hat = torch.cat([P0, val_covs_full], dim=1)

                        all_val_x_true.append(x_true_val.cpu())
                        all_val_x_hat.append(full_x_hat.cpu())
                        all_val_P_hat.append(full_P_hat.cpu())

                avg_val_mse = np.mean(val_mse_list)

                # --- ANEES ---
                # Best effort: any failure here is reported and yields NaN so
                # that training itself is not interrupted.
                try:
                    cat_true = torch.cat(all_val_x_true, dim=0)
                    cat_hat = torch.cat(all_val_x_hat, dim=0)
                    cat_P = torch.cat(all_val_P_hat, dim=0)

                    if hasattr(trainer, 'calculate_anees_vectorized'):
                        avg_val_anees = trainer.calculate_anees_vectorized(cat_true, cat_hat, cat_P)
                    else:
                        avg_val_anees = float('nan')  # placeholder

                except Exception as e:
                    print(f"Error calculating ANEES: {e}")
                    avg_val_anees = float('nan')

                print(f"  Avg MSE: {avg_val_mse:.4f} | Avg ANEES: {avg_val_anees:.4f}")

                # --- Checkpoint selection ---
                # The validation MSE is the only selection criterion.
                current_metric = avg_val_mse
                if current_metric < best_val_metric:
                    print(f"  >>> New Best Model! (MSE: {best_val_metric:.4f} -> {current_metric:.4f}) <<<")
                    best_val_metric = current_metric
                    best_iter_count = train_iter_count
                    best_model_state = deepcopy(model.state_dict())

                print("-" * 50)
                model.train()

        # A full pass that produced no optimiser step (empty loader, or every
        # batch skipped) can never reach total_train_iter: stop instead of
        # looping forever.
        if not done and batches_used_this_epoch == 0:
            print("!!! STOP: train_loader produced no usable batch in a full pass; aborting training !!!")
            break

    print("\nTraining completed.")
    if best_model_state is not None:
        print(f"Loading best model from iteration {best_iter_count}")
        model.load_state_dict(best_model_state)

    return {"final_model": model}

In [8]:
import torch
import torch.nn.functional as F
import numpy as np
from copy import deepcopy

def calculate_anees_vectorized(x_true, x_hat, P_hat, eps=1e-6):
    """Average normalised estimation error squared (ANEES) via the pseudo-inverse.

    3-D state tensors and 4-D covariance tensors are first collapsed to
    ``[N, Dim]`` and ``[N, Dim, Dim]``. For every sample the quadratic form
    ``e^T pinv(P) e`` is evaluated, negative values are clamped to zero, and
    the mean over all samples is returned as a Python float. The pseudo-inverse
    copes with singular (low-rank) covariance matrices.

    ``eps`` is accepted for interface compatibility but is not used.
    """
    # 1. Flatten leading axes
    if x_true.dim() == 3:
        x_true = x_true.reshape(-1, x_true.shape[-1])
    if x_hat.dim() == 3:
        x_hat = x_hat.reshape(-1, x_hat.shape[-1])
    if P_hat.dim() == 4:
        P_hat = P_hat.reshape(-1, P_hat.shape[-2], P_hat.shape[-1])

    # 2. Estimation error as a column vector per sample: [N, Dim, 1]
    residual = (x_true - x_hat).unsqueeze(-1)

    # 3. Pseudo-inverse of P; hermitian=True because P is passed as symmetric.
    precision = torch.linalg.pinv(P_hat, hermitian=True)

    # 4. Quadratic form: [N, 1, Dim] @ [N, Dim, Dim] @ [N, Dim, 1] -> [N, 1, 1]
    row_times_precision = torch.bmm(residual.transpose(1, 2), precision)
    nees = torch.bmm(row_times_precision, residual).squeeze()

    # Guard against tiny negative values caused by round-off.
    nees = torch.clamp(nees, min=0.0)

    return nees.mean().item()

def gaussian_nll_safe(target, preds, var, min_var=1e-6, max_error_sq=100.0):
    """Gaussian negative log-likelihood with a variance floor and a capped error.

    The squared error itself is capped at ``max_error_sq``, not the resulting
    penalty. For a capped error the term is ``const / var``, whose derivative
    with respect to the variance is negative, so increasing the variance still
    lowers the loss. Returns the mean over all elements.
    """
    # Variance offset by min_var for numerical safety.
    stabilized_var = var + min_var

    # Log-variance term.
    log_term = torch.log(stabilized_var)

    # Squared error, capped, then scaled by the variance.
    capped_sq_error = torch.clamp(torch.square(preds - target), max=max_error_sq)
    quadratic_term = capped_sq_error / stabilized_var

    return (0.5 * (log_term + quadratic_term)).mean()

def train_BayesianKalmanNet_Hybrid(
    model, train_loader, val_loader, device,
    total_train_iter, learning_rate, clip_grad,
    J_samples, validation_period, logging_period,
    warmup_iterations=0, weight_decay_=1e-5,
    lambda_mse=100.0  # <--- NOV√ù PARAMETR: Kotva pro MSE
):
    # torch.autograd.set_detect_anomaly(True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay_)
    
    # Scheduler: Pokud se loss zasekne, sn√≠≈æ√≠me LR (pom√°h√° stabilizovat konvergenci)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    #     optimizer, mode='min', factor=0.5, patience=50
    # )

    best_val_anees = float('inf')
    score_at_best = {"val_nll": 0.0, "val_mse": 0.0}
    best_iter_count = 0
    best_model_state = None
    train_iter_count = 0
    done = False

    print(f"üöÄ START Hybrid Training: Loss = NLL + {lambda_mse} * MSE")
    print(f"    Logging period: {logging_period} iterations")

    while not done:
        model.train()
        for x_true_batch, y_meas_batch in train_loader:
            if train_iter_count >= total_train_iter: done = True; break
            if torch.isnan(x_true_batch).any():
                print(f"!!! SKIP BATCH iter {train_iter_count}: NaN found in x_true (Ground Truth) !!!")
                continue
            
            x_true_batch = x_true_batch.to(device)
            y_meas_batch = y_meas_batch.to(device)
            
            # --- Training ---
            optimizer.zero_grad()
            batch_size, seq_len, _ = x_true_batch.shape
            
            all_trajectories_for_ensemble = []
            all_regs_for_ensemble = []

            # 1. Ensemble Forward Pass
            for j in range(J_samples):
                model.reset(batch_size=batch_size, initial_state=x_true_batch[:, 0, :])
                current_trajectory_x_hats = []
                current_trajectory_regs = []
                for t in range(1, seq_len):
                    y_t = y_meas_batch[:, t, :]
                    x_filtered_t, reg_t = model.step(y_t)
                    if torch.isnan(x_filtered_t).any():
                            raise ValueError(f"NaN in x_filtered_t at sample {j}, step {t}")
                    current_trajectory_x_hats.append(x_filtered_t)
                    current_trajectory_regs.append(reg_t)
                all_trajectories_for_ensemble.append(torch.stack(current_trajectory_x_hats, dim=1))
                all_regs_for_ensemble.append(torch.sum(torch.stack(current_trajectory_regs)))

            # 2. Statistiky Ensemble
            ensemble_trajectories = torch.stack(all_trajectories_for_ensemble, dim=0)
            x_hat_sequence = ensemble_trajectories.mean(dim=0)
            
            # Epistemick√° variance (ƒçist√Ω rozptyl s√≠tƒõ)
            # P≈ôiƒç√≠t√°me 1e-9 jen proti dƒõlen√≠ nulou, nen√≠ to "noise floor"
            cov_diag_sequence = ensemble_trajectories.var(dim=0) + 1e-9 
            
            regularization_loss = torch.stack(all_regs_for_ensemble).mean()/seq_len
            target_sequence = x_true_batch[:, 1:, :]
            
            # --- 3. V√ùPOƒåET HYBRIDN√ç LOSS ---
            
            # A) MSE ƒå√°st (P≈ôesnost)
            mse_loss = F.mse_loss(x_hat_sequence, target_sequence)
            
            # B) NLL ƒå√°st (Konzistence)
            # 0.5 * (log(var) + (target - pred)^2 / var)
            cov_diag_clamped = torch.clamp(cov_diag_sequence, min=1e-4, max=1e6)
            error_sq = (x_hat_sequence - target_sequence) ** 2
            nll_term = 0.5 * (torch.log(cov_diag_clamped) + error_sq / cov_diag_clamped)
            nll_loss = gaussian_nll_safe(
                target=target_sequence, 
                preds=x_hat_sequence, 
                var=cov_diag_sequence, 
                min_var=1e-5,       # Epsilon pro stabilitu
                max_error_sq=100.0 # O≈ôez√°n√≠ extr√©mn√≠ch chyb (pokud je implementov√°no)
            )
            mean_var = cov_diag_sequence.mean()
            var_penalty = torch.relu(mean_var - 100.0) * 0.01
            
            # C) Celkov√° Loss (Hybrid)
            # Zde je ta magie: I kdy≈æ NLL chce ut√©ct s varianc√≠, lambda_mse * mse ho dr≈æ√≠ zp√°tky
            weighted_mse = lambda_mse * mse_loss
            loss = nll_loss + weighted_mse + regularization_loss * 10.0 + var_penalty
            
            if torch.isnan(loss): 
                print("Collapse detected (NaN loss)"); done = True; break
            
            loss.backward()

            # --- DIAGNOSTIC LOGGING (Gradients) ---
            # Zaznamen√°me statistiky gradient≈Ø p≈ôed o≈ô√≠znut√≠m (clippingem)
            if train_iter_count % logging_period == 0:
                total_norm = 0.0
                max_grad = 0.0
                min_grad = float('inf')
                nan_grad_detected = False
                for p in model.parameters():
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2).item()
                        total_norm += param_norm ** 2
                        p_max = p.grad.data.abs().max().item()
                        p_min = p.grad.data.abs().min().item()
                        if p_max > max_grad: max_grad = p_max
                        if p_min < min_grad: min_grad = p_min
                        if torch.isnan(p.grad).any():
                            nan_grad_detected = True
                total_norm = total_norm ** 0.5
                
                if nan_grad_detected:
                     print(f"!!! WARNING: NaN gradient detected at iter {train_iter_count} !!!")

            if clip_grad > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
            optimizer.step()
            train_iter_count += 1
            
            # --- Logging ---
            diff = x_hat_sequence - target_sequence
            mean_error = diff.abs().mean().item()
            min_variance = cov_diag_sequence.min().item()
            max_variance = cov_diag_sequence.max().item()
            mean_variance = cov_diag_sequence.mean().item()

            if train_iter_count % logging_period == 0:
                with torch.no_grad():
                    # Zjist√≠me dropout pravdƒõpodobnosti (jen pro info)
                    p1 = torch.sigmoid(model.dnn.concrete_dropout1.p_logit).item()
                    p2 = torch.sigmoid(model.dnn.concrete_dropout2.p_logit).item()
                
                print(f"--- Iter [{train_iter_count}/{total_train_iter}] ---")
                print(f"    Total Loss:     {loss.item():.4f}")
                print(f"    MSE (Raw):      {mse_loss.item():.6f}")
                print(f"    MSE (Weighted): {weighted_mse.item():.4f} (lambda={lambda_mse})")
                print(f"    NLL Component:  {nll_loss.item():.4f}")
                print(f"    Reg Loss:       {regularization_loss.item():.6f}")
                print(f"    Variance stats: Min={min_variance:.2e}, Max={max_variance:.2e}, Mean={mean_variance:.2e}")
                print(f"    Mean Error L1:  {mean_error:.4f}")
                print(f"    Grad Norm:      {total_norm:.4f} (Max abs grad: {max_grad:.4f})")
                print(f"    Dropout probs:  p1={p1:.4f}, p2={p2:.4f}")
                
                # Check pro "Variance collapse"
                if mean_variance < 1e-8:
                    print("    !!! WARNING: Variance is extremely low (Collapse risk) !!!")

            # --- Validation step ---
            if train_iter_count > 0 and train_iter_count % validation_period == 0:
                # Step scheduleru podle tr√©novac√≠ loss (nebo validace, pokud bys to p≈ôedƒõlal)
                # scheduler.step(loss)
                
                print(f"\n--- Validation at iteration {train_iter_count} ---")
                model.eval()
                val_mse_list = []
                all_val_x_true_cpu, all_val_x_hat_cpu, all_val_P_hat_cpu = [], [], []

                with torch.no_grad():
                    for x_true_val_batch, y_meas_val_batch in val_loader:
                        val_batch_size, val_seq_len, _ = x_true_val_batch.shape
                        x_true_val_batch = x_true_val_batch.to(device)
                        y_meas_val_batch = y_meas_val_batch.to(device)
                        val_ensemble_trajectories = []
                        for j in range(J_samples):
                            model.reset(batch_size=val_batch_size, initial_state=x_true_val_batch[:, 0, :])
                            val_current_x_hats = []
                            for t in range(1, val_seq_len):
                                y_t_val = y_meas_val_batch[:, t, :]
                                x_filtered_t, _ = model.step(y_t_val)
                                val_current_x_hats.append(x_filtered_t)
                            val_ensemble_trajectories.append(torch.stack(val_current_x_hats, dim=1))
                        
                        # Agregace validace
                        val_ensemble = torch.stack(val_ensemble_trajectories, dim=0)
                        val_preds_seq = val_ensemble.mean(dim=0)
                        
                        val_target_seq = x_true_val_batch[:, 1:, :]
                        val_mse_list.append(F.mse_loss(val_preds_seq, val_target_seq).item())
                        
                        # P≈ô√≠prava pro ANEES
                        initial_state_val = x_true_val_batch[:, 0, :].unsqueeze(1)
                        full_x_hat = torch.cat([initial_state_val, val_preds_seq], dim=1)
                        
                        # Epistemick√° variance
                        val_covs_diag = val_ensemble.var(dim=0) + 1e-9
                        
                        # Vytvo≈ôen√≠ diagon√°ln√≠ch matic P
                        # (Zjednodu≈°en√° konstrukce pro ANEES calc)
                        # Pro p≈ôesn√© ANEES bychom mƒõli dƒõlat outer product, 
                        # ale diagon√°la z var() je dobr√° aproximace pro BKN
                        val_covs_full = torch.zeros(val_batch_size, val_seq_len-1, 4, 4, device=device)
                        for b in range(val_batch_size):
                            for t in range(val_seq_len-1):
                                val_covs_full[b, t] = torch.diag(val_covs_diag[b, t])

                        P0 = model.system_model.P0.unsqueeze(0).repeat(val_batch_size, 1, 1).unsqueeze(1)
                        full_P_hat = torch.cat([P0, val_covs_full], dim=1)
                        
                        all_val_x_true_cpu.append(x_true_val_batch.cpu())
                        all_val_x_hat_cpu.append(full_x_hat.cpu())
                        all_val_P_hat_cpu.append(full_P_hat.cpu())

                avg_val_mse = np.mean(val_mse_list)
                final_x_true_list = torch.cat(all_val_x_true_cpu, dim=0)
                final_x_hat_list = torch.cat(all_val_x_hat_cpu, dim=0)
                final_P_hat_list = torch.cat(all_val_P_hat_cpu, dim=0)
                
                # V√Ωpoƒçet ANEES
                try:
                    avg_val_anees = calculate_anees_vectorized(final_x_true_list, final_x_hat_list, final_P_hat_list)
                except Exception as e:
                    print(f"  !!! Error calculating ANEES: {e}")
                    avg_val_anees = float('nan')
                
                print(f"  Average MSE: {avg_val_mse:.4f}, Average ANEES: {avg_val_anees:.4f}")
                
                # Ukl√°d√°n√≠ modelu:
                if not np.isnan(avg_val_anees) and avg_val_anees < best_val_anees and avg_val_anees > 0:
                    print(f"  >>> New best VALIDATION ANEES! Saving model. (Old: {best_val_anees:.4f} -> New: {avg_val_anees:.4f}) <<<")
                    best_val_anees = avg_val_anees
                    best_iter_count = train_iter_count
                    score_at_best['val_mse'] = avg_val_mse
                    best_model_state = deepcopy(model.state_dict())
                print("-" * 50)
                model.train()

    print("\nTraining completed.")
    if best_model_state:
        print(f"Loading best model from iteration {best_iter_count} with ANEES {best_val_anees:.4f}")
        model.load_state_dict(best_model_state)
    else:
        print("No best model was saved; returning last state.")

    return {
        "best_val_anees": best_val_anees,
        "best_val_nll": score_at_best['val_nll'],
        "best_val_mse": score_at_best['val_mse'],
        "best_iter": best_iter_count,
        "final_model": model
    }

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from copy import deepcopy

# Numerically safe Gaussian NLL helper used by the training loop defined below.
def gaussian_nll_safe(target, preds, var, min_var=1e-6, max_error_sq=100.0):
    """Mean Gaussian negative log-likelihood with a variance offset and a clipped error.

    Args:
        target: Ground-truth tensor.
        preds: Predicted mean, broadcastable against ``target``.
        var: Predicted variance, broadcastable against the squared error.
        min_var: Constant added to ``var`` (an additive offset, not a clamp)
            before taking the logarithm and dividing.
        max_error_sq: Upper bound applied element-wise to the squared error.

    Returns:
        0-dim tensor: mean over all elements of
        ``0.5 * (log(var + min_var) + clipped_error_sq / (var + min_var))``.
    """
    variance = var + min_var
    # Clip the squared residual so a single outlier cannot dominate the loss.
    residual_sq = ((preds - target) ** 2).clamp(max=max_error_sq)
    per_element = 0.5 * (variance.log() + residual_sq / variance)
    return per_element.mean()

def train_BayesianKalmanNet_TBPTT_TwoPhase(
    model, train_loader, val_loader, device,
    total_train_iter, learning_rate, clip_grad=1.0,
    J_samples=5, tbptt_steps=20,
    validation_period=50, logging_period=10,
    mse_warmup_iters=0,  # number of optimizer updates spent in the MSE phase
    weight_decay_=1e-5
):
    """Train a Bayesian KalmanNet with truncated BPTT and a two-phase loss.

    Every training batch is repeated ``J_samples`` times along the batch axis
    (``repeat_interleave``) so that one forward pass yields ``J_samples``
    stochastic predictions per trajectory. Their mean is used as the state
    estimate and their variance (plus 1e-9) as a diagonal covariance.

    The sequence is processed in windows of ``tbptt_steps`` time steps. Each
    window produces one optimizer update, after which ``model.detach_hidden()``
    is called. While the update counter is below ``mse_warmup_iters`` the loss
    is ``MSE + reg``; afterwards it is ``NLL + reg + 0.1 * MSE``.

    Every ``validation_period`` updates the model is evaluated on full
    validation sequences. Because validation calls ``model.reset`` in the
    middle of a training sequence, the model's recurrent attributes are
    snapshotted before and restored after validation, so the next training
    window continues from the training state. The state dict with the lowest
    validation MSE is reloaded before returning.

    Args:
        model: Object providing ``reset(batch_size=..., initial_state=...)``,
            ``step(y_t) -> (estimate, regularization)``, ``detach_hidden()``,
            ``parameters()``/``state_dict()`` and
            ``dnn.concrete_dropout{1,2}.p_logit`` (read for logging only).
        train_loader: Iterable of ``(x_true, y_meas)`` batches; ``x_true`` is
            unpacked as ``[batch, seq_len, state_dim]`` and both tensors are
            indexed by time on axis 1.
        val_loader: Iterable of validation batches in the same layout.
        device: Device the batches are moved to.
        total_train_iter: Number of optimizer updates to perform.
        learning_rate: AdamW learning rate.
        clip_grad: Max gradient norm; clipping is skipped when ``<= 0``.
        J_samples: Ensemble size per trajectory. The variance uses
            ``Tensor.var`` defaults, so values below 2 give NaN variance.
        tbptt_steps: Window length in time steps (one update per window).
        validation_period: Validate every this many updates.
        logging_period: Print training statistics every this many updates.
        mse_warmup_iters: Number of updates trained with the MSE-only loss.
        weight_decay_: AdamW weight decay.

    Returns:
        dict: ``{"final_model": model}``.

    Raises:
        RuntimeError: If a complete pass over ``train_loader`` performs no
            update (empty loader, or sequences too short to form a window).
            The loop would otherwise never terminate.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay_)

    # NOTE(review): these are the recurrent attributes that the windowed TBPTT
    # trainer in this notebook saves/restores around validation; confirm that
    # they cover every piece of per-sequence state of the model. Attributes the
    # model does not have (or that are not tensors) are skipped.
    recurrent_attrs = (
        'h_prev', 'x_filtered_t_minus_1', 'x_filtered_t_minus_2',
        'x_pred_t_minus_1', 'y_t_minus_1', 'P_t_minus_1',
    )

    def _snapshot_recurrent_state():
        """Return detached copies of the model's tensor-valued recurrent attributes."""
        snapshot = {}
        for attr in recurrent_attrs:
            value = getattr(model, attr, None)
            if torch.is_tensor(value):
                snapshot[attr] = value.detach().clone()
        return snapshot

    def _restore_recurrent_state(snapshot):
        """Write a snapshot taken by ``_snapshot_recurrent_state`` back onto the model."""
        for attr, value in snapshot.items():
            setattr(model, attr, value)

    def _run_validation():
        """Evaluate on ``val_loader``; returns ``(mean batch MSE, ANEES or NaN)``."""
        model.eval()
        batch_mse = []
        true_chunks, hat_chunks, cov_chunks = [], [], []

        with torch.no_grad():
            for x_v, y_v in val_loader:
                x_v, y_v = x_v.to(device), y_v.to(device)
                b_v, s_v, d_v = x_v.shape

                # Validation runs over the whole sequence (no TBPTT windows).
                x_v_sup = x_v.repeat_interleave(J_samples, dim=0)
                y_v_sup = y_v.repeat_interleave(J_samples, dim=0)
                model.reset(batch_size=b_v * J_samples, initial_state=x_v_sup[:, 0, :])

                preds = [model.step(y_v_sup[:, ti, :])[0] for ti in range(1, s_v)]
                # State dimension comes from the data instead of a hard-coded 4.
                preds_stack = torch.stack(preds, dim=1).view(b_v, J_samples, s_v - 1, d_v)
                val_mean = preds_stack.mean(dim=1)
                val_var = preds_stack.var(dim=1) + 1e-9

                batch_mse.append(F.mse_loss(val_mean, x_v[:, 1:, :]).item())

                # ANEES inputs: the true initial state with a 1e-6 * I covariance
                # is prepended at t=0, followed by the ensemble statistics.
                full_hat = torch.cat([x_v[:, 0, :].unsqueeze(1), val_mean], dim=1)
                # Vectorized diagonal-matrix construction (replaces a Python double loop).
                val_P_full = torch.diag_embed(val_var)
                P0 = (
                    torch.eye(d_v, device=val_var.device, dtype=val_var.dtype)
                    .mul(1e-6)
                    .expand(b_v, 1, d_v, d_v)
                )
                full_P = torch.cat([P0, val_P_full], dim=1)

                true_chunks.append(x_v.cpu())
                hat_chunks.append(full_hat.cpu())
                cov_chunks.append(full_P.cpu())

        avg_mse = np.mean(batch_mse)

        # `calculate_anees_vectorized` is looked up in the notebook globals at
        # call time. Failures are reported instead of being silently swallowed.
        try:
            avg_nees = calculate_anees_vectorized(
                torch.cat(true_chunks, dim=0),
                torch.cat(hat_chunks, dim=0),
                torch.cat(cov_chunks, dim=0),
            )
        except Exception as exc:
            print(f"ANEES Error: {exc}")
            avg_nees = float('nan')

        return avg_mse, avg_nees

    best_val_mse = float('inf')
    best_model_state = None
    train_iter_count = 0  # counts optimizer updates (gradient steps)
    done = False

    print(f"üöÄ START TBPTT Two-Phase Training (Window={tbptt_steps})")
    print(f"    Phase 1: MSE Warmup (0 - {mse_warmup_iters} steps)")
    print(f"    Phase 2: NLL Optimization ({mse_warmup_iters} - {total_train_iter} steps)")

    while not done:
        model.train()
        iters_at_epoch_start = train_iter_count

        for x_true_batch, y_meas_batch in train_loader:
            if train_iter_count >= total_train_iter:
                done = True
                break

            x_true_batch = x_true_batch.to(device)
            y_meas_batch = y_meas_batch.to(device)

            batch_size, seq_len, dim_x = x_true_batch.shape

            # Super batch: every trajectory is repeated J_samples times so the
            # whole ensemble is evaluated in a single vectorized pass.
            x_true_super = x_true_batch.repeat_interleave(J_samples, dim=0)
            y_meas_super = y_meas_batch.repeat_interleave(J_samples, dim=0)
            super_batch_size = x_true_super.shape[0]

            # 1. Reset the filter at the start of the sequence.
            model.reset(batch_size=super_batch_size, initial_state=x_true_super[:, 0, :])

            # Clear gradients before the sequence.
            optimizer.zero_grad()

            # 2. TBPTT loop over windows.
            for t_start in range(1, seq_len, tbptt_steps):
                if train_iter_count >= total_train_iter:
                    done = True
                    break

                t_end = min(t_start + tbptt_steps, seq_len)
                current_window_len = t_end - t_start

                # A) Forward pass through the window.
                window_x_preds = []
                window_regs = []
                for t in range(t_start, t_end):
                    x_est, reg = model.step(y_meas_super[:, t, :])
                    window_x_preds.append(x_est)
                    window_regs.append(reg)

                # B) Ensemble statistics for the window.
                preds_super = torch.stack(window_x_preds, dim=1)  # [batch*J, window, dim_x]
                regs_super = torch.stack(window_regs)

                # [batch, J, window, dim_x]; the J axis holds the ensemble members.
                preds_reshaped = preds_super.view(batch_size, J_samples, current_window_len, dim_x)
                x_hat_seq = preds_reshaped.mean(dim=1)
                cov_diag_seq = preds_reshaped.var(dim=1) + 1e-9
                target_seq = x_true_batch[:, t_start:t_end, :]

                # C) Two-phase loss.
                mse_loss = F.mse_loss(x_hat_seq, target_seq)
                nll_loss = gaussian_nll_safe(target_seq, x_hat_seq, cov_diag_seq, max_error_sq=100.0)
                reg_loss = regs_super.mean()  # regularization averaged per step

                if train_iter_count < mse_warmup_iters:
                    loss = mse_loss + reg_loss
                    mode = "MSE"
                else:
                    # NLL phase with a small MSE anchor for stability.
                    loss = nll_loss + reg_loss + (0.1 * mse_loss)
                    mode = "NLL"

                # D) Backward pass and update.
                loss.backward()
                if clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
                optimizer.step()
                optimizer.zero_grad()  # gradients are cleared after every TBPTT update

                # E) Cut the graph between windows (the truncation in TBPTT).
                model.detach_hidden()

                train_iter_count += 1

                # --- Logging ---
                if train_iter_count % logging_period == 0:
                    with torch.no_grad():
                        p1 = torch.sigmoid(model.dnn.concrete_dropout1.p_logit).item()
                        p2 = torch.sigmoid(model.dnn.concrete_dropout2.p_logit).item()
                        mae = (x_hat_seq - target_seq).abs().mean().item()

                    print(f"Iter {train_iter_count} ({mode}): Loss {loss.item():.4f} | MSE {mse_loss.item():.2f} | NLL {nll_loss.item():.2f} | MAE {mae:.2f}m")
                    print(f"    Dropout: p1={p1:.3f}, p2={p2:.3f} | VarMean: {cov_diag_seq.mean().item():.1f}")

                # --- Validation (inside the TBPTT loop) ---
                if train_iter_count % validation_period == 0:
                    # Validation resets the filter for the validation batch, so
                    # keep the training state and put it back afterwards.
                    train_state = _snapshot_recurrent_state()

                    avg_val_mse, avg_anees = _run_validation()

                    print(f"\n--- VALIDATION: MSE {avg_val_mse:.2f} | ANEES {avg_anees:.2f} ---")

                    # Model selection is based on validation MSE.
                    if avg_val_mse < best_val_mse:
                        print(f"  >>> New Best Model! (Old: {best_val_mse:.2f} -> New: {avg_val_mse:.2f}) <<<")
                        best_val_mse = avg_val_mse
                        best_model_state = deepcopy(model.state_dict())
                    print("-" * 40)

                    model.train()
                    _restore_recurrent_state(train_state)

        # A full pass without a single update can never reach total_train_iter.
        if not done and train_iter_count == iters_at_epoch_start:
            raise RuntimeError(
                "train_loader produced no optimizer updates "
                "(empty loader or sequences too short for a TBPTT window)."
            )

    print("TBPTT Training completed.")
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return {"final_model": model}

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from copy import deepcopy

def gaussian_nll_safe(target, preds, var, min_var=1e-6, max_error_sq=100.0):
    """Gaussian NLL averaged over all elements, stabilised for training.

    ``min_var`` is added to ``var`` before the logarithm and the division, and
    the squared error is capped at ``max_error_sq`` element-wise.

    Returns:
        0-dim tensor with the mean of
        ``0.5 * (log(var + min_var) + min((preds - target)^2, max_error_sq) / (var + min_var))``.
    """
    stabilised_var = var + min_var
    capped_sq_err = torch.clamp(torch.square(preds - target), max=max_error_sq)
    log_term = torch.log(stabilised_var)
    data_term = capped_sq_err / stabilised_var
    return torch.mean(0.5 * (log_term + data_term))
def calculate_anees_vectorized(x_true, x_hat, P_hat, eps=1e-6):
    """Average normalized estimation error squared (ANEES) via pseudo-inverse.

    For every sample the NEES ``e^T P^+ e`` is computed with
    ``e = x_true - x_hat`` and ``P^+`` the Moore-Penrose pseudo-inverse of the
    covariance, which stays defined for singular (low-rank) covariances.

    All leading dimensions are flattened, so the inputs may be a single
    un-batched sample (``[dim]`` / ``[dim, dim]``), a flat batch
    (``[N, dim]`` / ``[N, dim, dim]``) or carry any number of leading batch
    axes (e.g. ``[batch, time, dim]`` / ``[batch, time, dim, dim]``).

    Args:
        x_true: Ground-truth states, last axis is the state dimension.
        x_hat: Estimated states, same layout as ``x_true``.
        P_hat: Covariance matrices, last two axes are ``dim x dim``; treated
            as symmetric (``hermitian=True``).
        eps: Unused; kept only so existing callers that pass it keep working.

    Returns:
        float: Mean NEES over all samples (negative round-off values are
        clamped to zero before averaging).
    """
    # 1. Flatten every leading axis into one sample axis.
    x_true = x_true.reshape(-1, x_true.shape[-1])
    x_hat = x_hat.reshape(-1, x_hat.shape[-1])
    P_hat = P_hat.reshape(-1, P_hat.shape[-2], P_hat.shape[-1])

    # 2. Estimation error as column vectors: [N, dim, 1].
    error = (x_true - x_hat).unsqueeze(-1)

    # 3. Pseudo-inverse of the covariance.
    P_inv = torch.linalg.pinv(P_hat, hermitian=True)

    # 4. NEES = e^T P^+ e: [N, 1, dim] @ [N, dim, dim] @ [N, dim, 1] -> [N, 1, 1].
    quad_form = torch.bmm(torch.bmm(error.transpose(1, 2), P_inv), error)
    # Explicit reshape keeps a 1-D result even when N == 1.
    nees_per_sample = quad_form.reshape(-1)

    # Guard against tiny negative values caused by round-off (e.g. -1e-10).
    nees_per_sample = torch.clamp(nees_per_sample, min=0.0)

    return nees_per_sample.mean().item()

def train_BayesianKalmanNet_TBPTT_Windowed(
    model, train_loader, val_loader, device,
    total_train_iter, learning_rate, clip_grad=1.0,
    J_samples=5, 
    tbptt_k=5,   
    tbptt_w=20,  
    validation_period=50, logging_period=10,
    mse_warmup_iters=0, 
    weight_decay_=1e-5,
    lambda_mse=100.0,
    calibration_parameter=10.0
):
    """Train a Bayesian KalmanNet with windowed truncated BPTT.

    Each batch is repeated ``J_samples`` times along the batch axis so one
    forward pass yields an ensemble per trajectory; the ensemble mean is the
    estimate and the ensemble variance (plus 1e-9) a diagonal covariance.

    The sequence is processed in windows of ``tbptt_w`` steps with one
    optimizer update per window. Inside a window ``model.detach_hidden()`` is
    called every ``tbptt_k`` steps, and again after the window and after the
    update. The loss is ``MSE + reg`` while the update counter is below
    ``mse_warmup_iters`` and ``NLL + reg + lambda_mse * MSE`` afterwards.

    Every ``validation_period`` updates the recurrent state of the model is
    snapshotted, validation runs on full sequences (t=0 excluded from the
    metrics), and the training state is restored. Checkpoints are selected
    by ``MSE + calibration_parameter * |ANEES - state_dim|`` (a penalty of 100
    is used when ANEES is NaN); the best state dict is reloaded at the end.

    Args:
        model: Object providing ``reset(batch_size=..., initial_state=...)``,
            ``step(y_t) -> (estimate, regularization)``, ``detach_hidden()``,
            ``parameters()``/``state_dict()`` and
            ``dnn.concrete_dropout{1,2}.p_logit`` (read for logging only).
        train_loader: Iterable of ``(x_true, y_meas)`` batches; ``x_true`` is
            unpacked as ``[batch, seq_len, state_dim]``.
        val_loader: Iterable of validation batches in the same layout.
        device: Device the batches are moved to.
        total_train_iter: Number of optimizer updates to perform.
        learning_rate: AdamW learning rate.
        clip_grad: Max gradient norm; clipping is skipped when ``<= 0``.
        J_samples: Ensemble size per trajectory. The variance uses
            ``Tensor.var`` defaults, so values below 2 give NaN variance.
        tbptt_k: Detach the hidden state every this many steps in a window.
        tbptt_w: Window length in time steps (one update per window).
        validation_period: Validate every this many updates.
        logging_period: Print training statistics every this many updates.
        mse_warmup_iters: Number of updates trained with the MSE-only loss.
        weight_decay_: AdamW weight decay.
        lambda_mse: Weight of the MSE term in the NLL phase.
        calibration_parameter: Weight of the ANEES deviation in the
            model-selection score.

    Returns:
        dict: ``{"final_model": model}``.

    Raises:
        AttributeError: If ``model`` has no ``detach_hidden`` method.
        RuntimeError: If a complete pass over ``train_loader`` performs no
            update (empty loader, or sequences too short to form a window).
            The loop would otherwise never terminate.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay_)

    best_val_score = float('inf')
    best_model_state = None
    train_iter_count = 0
    done = False

    if not hasattr(model, 'detach_hidden'):
         raise AttributeError("Modelu chyb√≠ metoda 'detach_hidden()'.")

    # Attributes cleared before a reset so that reset() re-creates them with
    # the right batch size, and attributes saved/restored around validation.
    hard_reset_attrs = ('h_prev', 'x_filtered_t_minus_1', 'x_filtered_t_minus_2')
    recurrent_attrs = (
        'h_prev', 'x_filtered_t_minus_1', 'x_pred_t_minus_1',
        'x_filtered_t_minus_2', 'y_t_minus_1', 'P_t_minus_1',
    )

    def _hard_reset(batch_size, initial_state):
        """Clear cached state (where present) and reset the filter."""
        for attr in hard_reset_attrs:
            if hasattr(model, attr):
                setattr(model, attr, None)
        model.reset(batch_size=batch_size, initial_state=initial_state)

    def _snapshot_recurrent_state():
        """Return detached copies of the model's tensor-valued recurrent attributes."""
        snapshot = {}
        for attr in recurrent_attrs:
            value = getattr(model, attr, None)
            if torch.is_tensor(value):
                snapshot[attr] = value.detach().clone()
        return snapshot

    def _restore_recurrent_state(snapshot):
        """Write a snapshot taken by ``_snapshot_recurrent_state`` back onto the model."""
        for attr, value in snapshot.items():
            setattr(model, attr, value)

    def _run_validation():
        """Evaluate on ``val_loader``; returns ``(mean batch MSE, ANEES or NaN)``."""
        model.eval()
        batch_mse = []
        true_chunks, hat_chunks, cov_chunks = [], [], []

        with torch.no_grad():
            for x_v, y_v in val_loader:
                x_v, y_v = x_v.to(device), y_v.to(device)
                b_v, s_v, d_v = x_v.shape
                x_v_sup = x_v.repeat_interleave(J_samples, dim=0)
                y_v_sup = y_v.repeat_interleave(J_samples, dim=0)

                # Same hard reset as in training, sized for the validation batch.
                _hard_reset(b_v * J_samples, x_v_sup[:, 0, :])

                preds = [model.step(y_v_sup[:, ti, :])[0] for ti in range(1, s_v)]
                # State dimension comes from the data instead of a hard-coded 4.
                preds_stack = torch.stack(preds, dim=1).view(b_v, J_samples, s_v - 1, d_v)
                val_mean = preds_stack.mean(dim=1)
                val_var = preds_stack.var(dim=1) + 1e-9
                target_v = x_v[:, 1:, :]

                batch_mse.append(F.mse_loss(val_mean, target_v).item())

                # Only comparable tensors are stored: there is no estimate for
                # t=0, so neither the target nor a P0 is included.
                true_chunks.append(target_v.cpu())
                hat_chunks.append(val_mean.cpu())
                cov_chunks.append(torch.diag_embed(val_var).cpu())

        avg_mse = np.mean(batch_mse)

        try:
            avg_nees = calculate_anees_vectorized(
                torch.cat(true_chunks, dim=0),
                torch.cat(hat_chunks, dim=0),
                torch.cat(cov_chunks, dim=0),
            )
        except Exception as e:
            # Report the failure instead of silently returning NaN.
            print(f"ANEES Error: {e}")
            avg_nees = float('nan')

        return avg_mse, avg_nees

    print(f"üöÄ START TBPTT Windowed Training (k={tbptt_k}, w={tbptt_w})")
    print(f"    Phase 1: MSE Warmup (0 - {mse_warmup_iters} updates)")
    print(f"    Phase 2: NLL Optimization (> {mse_warmup_iters} updates)")

    while not done:
        model.train()
        iters_at_epoch_start = train_iter_count

        for x_true_batch, y_meas_batch in train_loader:
            if train_iter_count >= total_train_iter:
                done = True
                break

            x_true_batch = x_true_batch.to(device)
            y_meas_batch = y_meas_batch.to(device)

            batch_size, seq_len, dim_x = x_true_batch.shape

            x_true_super = x_true_batch.repeat_interleave(J_samples, dim=0)
            y_meas_super = y_meas_batch.repeat_interleave(J_samples, dim=0)
            super_batch_size = x_true_super.shape[0]

            # 1. Hard reset at the start of the sequence.
            _hard_reset(super_batch_size, x_true_super[:, 0, :])

            # 2. TBPTT loop over windows.
            for t_start in range(1, seq_len, tbptt_w):
                if train_iter_count >= total_train_iter:
                    done = True
                    break

                t_end = min(t_start + tbptt_w, seq_len)
                current_window_len = t_end - t_start

                optimizer.zero_grad()
                window_x_preds = []
                window_regs = []

                # A) Forward pass; the graph is cut every tbptt_k steps.
                for t in range(t_start, t_end):
                    x_est, reg = model.step(y_meas_super[:, t, :])
                    window_x_preds.append(x_est)
                    window_regs.append(reg)

                    if (t - t_start + 1) % tbptt_k == 0:
                        model.detach_hidden()

                model.detach_hidden()

                # B) Loss calculation from the ensemble statistics.
                preds_super = torch.stack(window_x_preds, dim=1)
                regs_super = torch.stack(window_regs)
                preds_reshaped = preds_super.view(batch_size, J_samples, current_window_len, dim_x)

                x_hat_seq = preds_reshaped.mean(dim=1)
                cov_diag_seq = preds_reshaped.var(dim=1) + 1e-9
                target_seq = x_true_batch[:, t_start:t_end, :]

                mse_loss = F.mse_loss(x_hat_seq, target_seq)
                nll_loss = gaussian_nll_safe(target_seq, x_hat_seq, cov_diag_seq, max_error_sq=100.0)
                reg_loss = regs_super.mean()

                if train_iter_count < mse_warmup_iters:
                    loss = mse_loss + reg_loss
                    mode = "MSE_WARMUP"
                else:
                    loss = nll_loss + reg_loss + (lambda_mse * mse_loss)
                    mode = "NLL_OPTIM"

                # C) Update.
                loss.backward()
                if clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
                optimizer.step()
                optimizer.zero_grad()
                model.detach_hidden()

                train_iter_count += 1

                # --- Logging (printed once per logging period) ---
                if train_iter_count % logging_period == 0:
                    with torch.no_grad():
                        mae = (x_hat_seq - target_seq).abs().mean().item()

                        # Dropout probabilities (sigmoid of the logits).
                        p1 = torch.sigmoid(model.dnn.concrete_dropout1.p_logit).item()
                        p2 = torch.sigmoid(model.dnn.concrete_dropout2.p_logit).item()

                    print(f"Iter {train_iter_count} ({mode}): "
                          f"Loss {loss.item():.4f} | "
                          f"MSE {mse_loss.item():.2f} | "
                          f"NLL {nll_loss.item():.2f} | "
                          f"Reg {reg_loss.item():.5f} | "
                          f"MAE {mae:.2f}m | "
                          f"p1={p1:.3f}, p2={p2:.3f}")

                # --- Validation with full training-state restore ---
                if train_iter_count % validation_period == 0:
                    # 1. Save everything that changes with time step t.
                    train_state = _snapshot_recurrent_state()

                    # 2. Run validation.
                    avg_val_mse, avg_anees = _run_validation()

                    # A consistent filter has an expected NEES equal to the
                    # state dimension, so the deviation from dim_x is penalised.
                    anees_penalty = abs(avg_anees - dim_x) if not np.isnan(avg_anees) else 100.0

                    # Score = MSE + calibration_parameter * ANEES deviation.
                    hybrid_score = avg_val_mse + (calibration_parameter * anees_penalty)

                    print(f"\n--- VALIDATION: MSE {avg_val_mse:.2f} | ANEES {avg_anees:.2f} | Score {hybrid_score:.2f} ---")

                    # Keep the checkpoint with the lowest hybrid score.
                    if hybrid_score < best_val_score:
                        print(f"  >>> New Best Model! (Old Score: {best_val_score:.2f} -> New: {hybrid_score:.2f}) <<<")
                        best_val_score = hybrid_score
                        best_model_state = deepcopy(model.state_dict())
                    print("-" * 40)

                    # 3. Restore the training state.
                    model.train()
                    _restore_recurrent_state(train_state)

        # A full pass without a single update can never reach total_train_iter.
        if not done and train_iter_count == iters_at_epoch_start:
            raise RuntimeError(
                "train_loader produced no optimizer updates "
                "(empty loader or sequences too short for a TBPTT window)."
            )

    print("Training completed.")
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return {"final_model": model}

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from copy import deepcopy

def gaussian_nll_safe(target, preds, var, min_var=1e-6, max_error_sq=100.0):
    """Element-wise Gaussian NLL, averaged, with two numerical safeguards.

    * ``min_var`` is added to the variance before it is used.
    * the squared prediction error is limited to ``max_error_sq``.

    Returns a 0-dim tensor holding the mean over all elements.
    """
    sq_err = torch.clamp((preds - target).pow(2), max=max_error_sq)
    denom = var + min_var
    return (0.5 * (torch.log(denom) + sq_err / denom)).mean()

def calculate_anees_internal(x_true, x_hat, P_hat):
    """ANEES for diagonal covariances given as per-component variances.

    When ``P_hat`` has the same number of dimensions as ``x_hat`` it is
    treated as a variance per state component and the NEES of each sample is
    the sum over the last axis of ``error^2 / (variance + 1e-9)``. Any other
    rank of ``P_hat`` is not supported and yields NaN.

    Returns:
        float: Mean NEES over all samples, or NaN for unsupported input.
    """
    if P_hat.dim() != x_hat.dim():
        # Only the diagonal-variance layout is handled here.
        return float('nan')

    squared_error = (x_true - x_hat) ** 2
    normalized = squared_error / (P_hat + 1e-9)
    return normalized.sum(dim=-1).mean().item()

def _snapshot_recurrent_state(model):
    """Clone the model's recurrent/filter buffers so a validation pass can be undone."""
    snapshot = {}
    if model.h_prev is not None:
        snapshot['h_prev'] = model.h_prev.detach().clone()
    snapshot['x_filt_1'] = model.x_filtered_t_minus_1.detach().clone()
    snapshot['x_pred_1'] = model.x_pred_t_minus_1.detach().clone()
    # Optional buffers: captured only when the model defines them and they are set.
    for key, attr in (('x_filt_2', 'x_filtered_t_minus_2'),
                      ('y_1', 'y_t_minus_1'),
                      ('P', 'P_t_minus_1')):
        value = getattr(model, attr, None)
        if value is not None:
            snapshot[key] = value.detach().clone()
    return snapshot


def _restore_recurrent_state(model, snapshot):
    """Write back the buffers captured by ``_snapshot_recurrent_state``."""
    if 'h_prev' in snapshot:
        model.h_prev = snapshot['h_prev']
    model.x_filtered_t_minus_1 = snapshot['x_filt_1']
    model.x_pred_t_minus_1 = snapshot['x_pred_1']
    if 'x_filt_2' in snapshot:
        model.x_filtered_t_minus_2 = snapshot['x_filt_2']
    if 'y_1' in snapshot:
        model.y_t_minus_1 = snapshot['y_1']
    if 'P' in snapshot:
        model.P_t_minus_1 = snapshot['P']


def train_BayesianKalmanNet_TBPTT_Windowed(
    model, train_loader, val_loader, device,
    total_train_iter, learning_rate, clip_grad=1.0,
    J_samples=5, 
    tbptt_k=5,   
    tbptt_w=20,  
    validation_period=50, logging_period=10,
    mse_warmup_iters=0, 
    weight_decay_=1e-5,
    lambda_mse=1.0,
    calibration_parameter=10.0
):
    """Train a Bayesian KalmanNet with windowed truncated BPTT.

    Every batch is replicated ``J_samples`` times along the batch axis, so one
    forward pass yields an ensemble. The ensemble mean is the state estimate and
    the unbiased ensemble variance is the (diagonal) uncertainty. One optimizer
    update is performed per window of ``tbptt_w`` time steps.

    The model must provide ``reset(batch_size=..., initial_state=...)``,
    ``step(y) -> (x_est, reg)``, ``detach_hidden()``, the attributes ``h_prev``,
    ``x_filtered_t_minus_1`` and ``x_pred_t_minus_1``, and
    ``dnn.concrete_dropout{1,2}.p_logit`` (read for logging only).

    Args:
        model: Network to train (modified in place).
        train_loader: Iterable of ``(x_true, y_meas)`` batches; ``x_true`` is
            unpacked as ``(batch, seq_len, dim_x)``.
        val_loader: Iterable of ``(x_true, y_meas)`` validation batches.
        device: Device the batches are moved to.
        total_train_iter: Number of optimizer updates to perform.
        learning_rate: AdamW learning rate.
        clip_grad: Max gradient norm; clipping is skipped when ``<= 0``.
        J_samples: Ensemble size; must be at least 2 (see Raises).
        tbptt_k: Inside a window, ``detach_hidden()`` is called every
            ``tbptt_k`` steps.
        tbptt_w: Window length in time steps (one update per window).
        validation_period: Validate every this many updates.
        logging_period: Log every this many updates.
        mse_warmup_iters: Updates during which the MSE weight is fixed to 1.0
            instead of ``lambda_mse``.
        weight_decay_: AdamW weight decay.
        lambda_mse: MSE weight after the warm-up.
        calibration_parameter: Weight of ``|ANEES - 4.0|`` in the validation score.

    Returns:
        ``{"final_model": model}``; the weights with the lowest validation
        score are loaded when at least one validation improved the score.

    Raises:
        AttributeError: The model has no ``detach_hidden`` method.
        ValueError: ``J_samples < 2``.
        RuntimeError: A full pass over ``train_loader`` produced no update.
    """
    if not hasattr(model, 'detach_hidden'):
         raise AttributeError("Modelu chyb√≠ metoda 'detach_hidden()'.")
    if J_samples < 2:
        # The unbiased variance of a single sample is NaN, and the NLL term is
        # part of the loss in both phases, so training could never succeed.
        raise ValueError(f"J_samples must be >= 2 to estimate the ensemble variance, got {J_samples}")

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay_)

    best_val_score = float('inf')
    best_model_state = None
    train_iter_count = 0
    done = False

    print(f"üöÄ START TBPTT Windowed Training (k={tbptt_k}, w={tbptt_w})")
    print(f"    Phase 1: MSE Focused (Soft NLL) (0 - {mse_warmup_iters} updates)")
    print(f"    Phase 2: NLL Optimization (> {mse_warmup_iters} updates)")

    while not done:
        model.train()
        iters_at_epoch_start = train_iter_count

        for x_true_batch, y_meas_batch in train_loader:
            if train_iter_count >= total_train_iter:
                done = True
                break

            x_true_batch = x_true_batch.to(device)
            y_meas_batch = y_meas_batch.to(device)
            batch_size, seq_len, dim_x = x_true_batch.shape

            # Replicate every trajectory J times: rows [b*J, (b+1)*J) are the
            # ensemble members of trajectory b.
            x_true_super = x_true_batch.repeat_interleave(J_samples, dim=0)
            y_meas_super = y_meas_batch.repeat_interleave(J_samples, dim=0)
            super_batch_size = x_true_super.shape[0]

            # 1. Hard reset of the recurrent state for the new batch.
            if hasattr(model, 'h_prev'): model.h_prev = None
            if hasattr(model, 'x_filtered_t_minus_1'): model.x_filtered_t_minus_1 = None
            if hasattr(model, 'x_filtered_t_minus_2'): model.x_filtered_t_minus_2 = None
            model.reset(batch_size=super_batch_size, initial_state=x_true_super[:, 0, :])

            # 2. One optimizer update per window of tbptt_w steps.
            for t_start in range(1, seq_len, tbptt_w):
                if train_iter_count >= total_train_iter:
                    done = True
                    break

                t_end = min(t_start + tbptt_w, seq_len)

                # Teacher forcing with probability 0.2 (never for the first
                # window): restart the window from the ground-truth state.
                use_teacher_forcing = (np.random.rand() < 0.2)
                if use_teacher_forcing and t_start > 1:
                    gt_state_prev = x_true_super[:, t_start-1, :]
                    model.x_filtered_t_minus_1 = gt_state_prev.detach().clone()
                    model.x_pred_t_minus_1 = gt_state_prev.detach().clone()
                model.detach_hidden()

                optimizer.zero_grad()
                window_x_preds = []
                window_regs = []

                # A) Forward pass over the window.
                for t in range(t_start, t_end):
                    y_t = y_meas_super[:, t, :]
                    x_est, reg = model.step(y_t)
                    window_x_preds.append(x_est)
                    window_regs.append(reg)

                    # Truncate the gradient path every tbptt_k steps.
                    if (t - t_start + 1) % tbptt_k == 0:
                        model.detach_hidden()

                model.detach_hidden()

                # B) Ensemble statistics and losses.
                preds_super = torch.stack(window_x_preds, dim=1)
                regs_super = torch.stack(window_regs)
                preds_reshaped = preds_super.view(batch_size, J_samples, -1, dim_x)

                x_hat_seq = preds_reshaped.mean(dim=1)
                cov_diag_seq = preds_reshaped.var(dim=1) + 1e-9
                target_seq = x_true_batch[:, t_start:t_end, :]

                mse_loss = F.mse_loss(x_hat_seq, target_seq)
                nll_loss = gaussian_nll_safe(target_seq, x_hat_seq, cov_diag_seq, max_error_sq=100.0)
                reg_loss = regs_super.mean()

                # Soft warm-up: NLL is always part of the loss; only the MSE
                # weight differs between the two phases.
                if train_iter_count < mse_warmup_iters:
                    loss = (1.0 * mse_loss) + nll_loss + reg_loss
                    mode = "MSE_WARMUP"
                else:
                    loss = nll_loss + reg_loss + (lambda_mse * mse_loss)
                    mode = "NLL_OPTIM"

                # C) Parameter update.
                loss.backward()
                if clip_grad > 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
                optimizer.step()
                optimizer.zero_grad()
                model.detach_hidden()

                train_iter_count += 1

                # --- LOGGING ---
                if train_iter_count % logging_period == 0:
                    with torch.no_grad():
                        mae = (x_hat_seq - target_seq).abs().mean().item()
                        # Square root of the mean ensemble variance.
                        avg_sigma = cov_diag_seq.mean().sqrt().item()

                        p1 = torch.sigmoid(model.dnn.concrete_dropout1.p_logit).item()
                        p2 = torch.sigmoid(model.dnn.concrete_dropout2.p_logit).item()

                    print(f"Iter {train_iter_count} ({mode}): "
                          f"Loss {loss.item():.2f} | "
                          f"MSE {mse_loss.item():.2f} | "
                          f"NLL {nll_loss.item():.2f} | "
                          f"Sigma {avg_sigma:.2f}m | "
                          f"MAE {mae:.2f}m | "
                          f"p1={p1:.2f}, p2={p2:.2f}")

                # --- VALIDATION ---
                if train_iter_count % validation_period == 0:
                    # Validation resets the model mid-sequence, so keep a copy
                    # of the recurrent state to continue training afterwards.
                    train_state = _snapshot_recurrent_state(model)

                    model.eval()
                    val_mse_list = []
                    all_val_x_true, all_val_x_hat, all_val_P = [], [], []
                    val_sigma_sum = 0.0
                    val_batch_count = 0

                    with torch.no_grad():
                        for x_v, y_v in val_loader:
                            x_v, y_v = x_v.to(device), y_v.to(device)
                            b_v, s_v, d_v = x_v.shape
                            x_v_sup = x_v.repeat_interleave(J_samples, dim=0)
                            y_v_sup = y_v.repeat_interleave(J_samples, dim=0)

                            if hasattr(model, 'h_prev'): model.h_prev = None
                            model.reset(batch_size=b_v*J_samples, initial_state=x_v_sup[:,0,:])

                            preds_list = []
                            for ti in range(1, s_v):
                                est, _ = model.step(y_v_sup[:, ti, :])
                                preds_list.append(est)

                            # Use the actual state dimension instead of a hard-coded 4.
                            preds_stack = torch.stack(preds_list, dim=1).view(b_v, J_samples, s_v-1, d_v)
                            val_mean = preds_stack.mean(dim=1)
                            val_var = preds_stack.var(dim=1) + 1e-9
                            target_v = x_v[:, 1:, :]

                            val_mse_list.append(F.mse_loss(val_mean, target_v).item())

                            val_sigma_sum += val_var.mean().sqrt().item()
                            val_batch_count += 1

                            all_val_x_true.append(target_v)
                            all_val_x_hat.append(val_mean)
                            all_val_P.append(val_var)

                    avg_val_mse = np.mean(val_mse_list)
                    avg_val_sigma = val_sigma_sum / max(1, val_batch_count)

                    try:
                        cat_true = torch.cat(all_val_x_true, dim=0)
                        cat_hat = torch.cat(all_val_x_hat, dim=0)
                        cat_P = torch.cat(all_val_P, dim=0)
                        avg_anees = calculate_anees_internal(cat_true, cat_hat, cat_P)
                    except Exception as e:
                        print(f"ANEES Error: {e}")
                        avg_anees = 100.0

                    # NOTE(review): 4.0 is the ANEES target used by this
                    # notebook; presumably it equals the state dimension and
                    # must change if a state of another size is used.
                    anees_penalty = abs(avg_anees - 4.0)
                    hybrid_score = avg_val_mse + (calibration_parameter * anees_penalty)

                    print(f"--- VALIDATION: MSE {avg_val_mse:.2f} | Sigma {avg_val_sigma:.2f}m | ANEES {avg_anees:.2f} | Score {hybrid_score:.2f} ---")

                    if hybrid_score < best_val_score:
                        print(f"  >>> New Best Model! (Score: {best_val_score:.2f} -> {hybrid_score:.2f}) <<<")
                        best_val_score = hybrid_score
                        best_model_state = deepcopy(model.state_dict())
                    print("-" * 40)

                    # 3. Back to training exactly where the window loop left off.
                    model.train()
                    _restore_recurrent_state(model, train_state)

        if not done and train_iter_count == iters_at_epoch_start:
            # Empty loader or sequences with a single time step: nothing can
            # ever be trained, so fail loudly instead of spinning forever.
            raise RuntimeError(
                "train_loader produced no optimizer update in a full pass "
                "(empty loader or sequences shorter than 2 time steps)."
            )

    print("Training completed.")
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return {"final_model": model}

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from copy import deepcopy
import sys

# Ensure utils is importable if needed
# import utils.utils as utils 

def gaussian_nll_safe(target, preds, var, min_var=1e-6, max_error_sq=100.0):
    """
    Numerically guarded Gaussian NLL loss.

    Adds ``min_var`` to the variance, caps the squared error at
    ``max_error_sq`` and returns the mean of
    ``0.5 * (log(variance) + capped_error / variance)`` over all elements.

    NOTE(review): a function with the same name and body is defined in an
    earlier cell of this notebook; this definition shadows it when the cell runs.
    """
    stabilized_var = var + min_var
    # Cap the squared error in the numerator so the loss cannot explode.
    capped_sq_err = (preds - target).pow(2).clamp(max=max_error_sq)
    per_element = (stabilized_var.log() + capped_sq_err / stabilized_var) * 0.5
    return per_element.mean()

def strain_BayesianKalmanNet_TwoPhase(
    model, train_loader, val_loader, device,
    total_train_iter, learning_rate, clip_grad,
    J_samples, validation_period, logging_period,
    mse_warmup_iters=0,
    calibration_parameter=0.0,
    weight_decay_=1e-5
):
    """Two-phase (MSE warm-up, then NLL) training of a Bayesian KalmanNet.

    For every batch the model is run ``J_samples`` times over the whole
    sequence. The ensemble mean is the estimate and the unbiased ensemble
    variance is the diagonal uncertainty. During the first
    ``mse_warmup_iters`` updates the loss is ``MSE + regularisation``;
    afterwards it is ``NLL + regularisation``. The checkpoint with the lowest
    ``val_MSE + calibration_parameter * |val_ANEES - 4.0|`` is kept.

    NOTE(review): the curriculum cell of this notebook calls
    ``train_BayesianKalmanNet_TwoPhase`` (no leading "s"), so this definition
    is not bound under the name that cell uses. Confirm whether the leading
    "s" is intentional.

    Args:
        model: Network to train (modified in place); must provide
            ``reset(batch_size=..., initial_state=...)``,
            ``step(y) -> (x_est, reg)`` and
            ``dnn.concrete_dropout{1,2}.p_logit`` (read for logging).
        train_loader: Iterable of ``(x_true, y_meas)`` batches; ``x_true`` is
            unpacked as ``(batch, seq_len, state_dim)``.
        val_loader: Iterable of ``(x_true, y_meas)`` validation batches.
        device: Device the batches are moved to.
        total_train_iter: Number of optimizer updates to perform.
        learning_rate: AdamW learning rate.
        clip_grad: Max gradient norm; clipping is skipped when ``<= 0``.
        J_samples: Number of stochastic forward passes per batch.
        validation_period: Validate every this many updates.
        logging_period: Log every this many updates.
        mse_warmup_iters: Number of updates in the MSE warm-up phase.
        calibration_parameter: Weight of the ANEES term in the hybrid score.
        weight_decay_: AdamW weight decay.

    Returns:
        Dict with ``final_model``, ``best_iter``, ``best_val_mse``,
        ``best_val_anees`` and ``best_val_nll``.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay_)

    # Model selection is driven by the hybrid score.
    best_hybrid_score = float('inf')
    best_model_state = None
    best_iter_count = 0

    # Metrics of the best checkpoint, reported in the returned dict.
    best_val_mse = float('inf')
    best_val_anees = float('inf')
    best_val_nll = float('inf')

    train_iter_count = 0
    done = False

    print(f"üöÄ START Two-Phase Training")
    print(f"    Phase 1: MSE Warmup (0 - {mse_warmup_iters} iters)")
    print(f"    Phase 2: NLL Optimization ({mse_warmup_iters} - {total_train_iter} iters)")
    print(f"    Saving Strategy: Hybrid Score = MSE + ({calibration_parameter} * |ANEES - 4.0|)")

    while not done:
        model.train()
        iters_at_epoch_start = train_iter_count

        for x_true_batch, y_meas_batch in train_loader:
            if train_iter_count >= total_train_iter:
                done = True
                break

            # Skip batches whose ground truth contains NaN.
            if torch.isnan(x_true_batch).any():
                print(f"!!! SKIP BATCH iter {train_iter_count}: NaN found in x_true !!!")
                continue

            x_true_batch = x_true_batch.to(device)
            y_meas_batch = y_meas_batch.to(device)
            batch_size, seq_len, _ = x_true_batch.shape

            # --- Training step ---
            optimizer.zero_grad()

            all_trajectories_for_ensemble = []
            all_regs_for_ensemble = []

            # 1. Ensemble forward pass: J independent runs over the sequence.
            for _ in range(J_samples):
                model.reset(batch_size=batch_size, initial_state=x_true_batch[:, 0, :])
                current_trajectory_x_hats = []
                current_trajectory_regs = []

                for t in range(1, seq_len):
                    y_t = y_meas_batch[:, t, :]
                    x_filtered_t, reg_t = model.step(y_t)

                    if torch.isnan(x_filtered_t).any():
                        # Numerical instability: abandon this trajectory.
                        print(f"NaN detected in forward pass at iter {train_iter_count}")
                        break

                    current_trajectory_x_hats.append(x_filtered_t)
                    current_trajectory_regs.append(reg_t)

                # An incomplete trajectory means the inner loop hit NaN.
                if len(current_trajectory_x_hats) != (seq_len - 1):
                    break

                all_trajectories_for_ensemble.append(torch.stack(current_trajectory_x_hats, dim=1))
                all_regs_for_ensemble.append(torch.sum(torch.stack(current_trajectory_regs)))

            if len(all_trajectories_for_ensemble) < J_samples:
                # At least one ensemble member failed: skip the update.
                optimizer.zero_grad()
                continue

            # 2. Ensemble statistics.
            ensemble_trajectories = torch.stack(all_trajectories_for_ensemble, dim=0)
            x_hat_sequence = ensemble_trajectories.mean(dim=0)
            cov_diag_sequence = ensemble_trajectories.var(dim=0) + 1e-9

            regularization_loss = torch.stack(all_regs_for_ensemble).mean() / seq_len

            target_sequence = x_true_batch[:, 1:, :]

            # 3. Losses.
            mse_loss = F.mse_loss(x_hat_sequence, target_sequence)
            nll_loss = gaussian_nll_safe(
                target=target_sequence,
                preds=x_hat_sequence,
                var=cov_diag_sequence
            )

            if train_iter_count < mse_warmup_iters:
                loss = mse_loss + regularization_loss
                mode = "Warmup"
            else:
                loss = nll_loss + regularization_loss
                mode = "Optim"

            if torch.isnan(loss):
                print("Collapse detected (NaN loss)")
                done = True
                break

            loss.backward()

            if clip_grad > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

            optimizer.step()
            train_iter_count += 1

            # --- LOGGING ---
            if train_iter_count % logging_period == 0:
                with torch.no_grad():
                    # Mean of the per-element standard deviation.
                    avg_sigma = torch.sqrt(cov_diag_sequence).mean().item()
                    mae = (x_hat_sequence - target_sequence).abs().mean().item()

                    # Current dropout probabilities.
                    p1 = torch.sigmoid(model.dnn.concrete_dropout1.p_logit).item()
                    p2 = torch.sigmoid(model.dnn.concrete_dropout2.p_logit).item()

                print(f"Iter {train_iter_count} ({mode}): "
                      f"Loss {loss.item():.2f} | "
                      f"MSE {mse_loss.item():.2f} | "
                      f"NLL {nll_loss.item():.2f} | "
                      f"Sigma {avg_sigma:.2f}m | " 
                      f"MAE {mae:.2f}m | "
                      f"p1={p1:.2f}, p2={p2:.2f}")

            # --- VALIDATION ---
            if train_iter_count > 0 and train_iter_count % validation_period == 0:
                print(f"\n--- Validation at iteration {train_iter_count} ---")
                model.eval()
                val_mse_list = []
                val_nll_list = []

                all_val_x_true, all_val_x_hat, all_val_P_hat = [], [], []

                with torch.no_grad():
                    for x_true_val, y_meas_val in val_loader:
                        v_bs, v_seq, v_dim = x_true_val.shape
                        x_true_val = x_true_val.to(device)
                        y_meas_val = y_meas_val.to(device)

                        val_ensemble_trajs = []

                        for _ in range(J_samples):
                            model.reset(batch_size=v_bs, initial_state=x_true_val[:, 0, :])
                            v_x_hats = []
                            for t in range(1, v_seq):
                                est, _ = model.step(y_meas_val[:, t, :])
                                v_x_hats.append(est)
                            val_ensemble_trajs.append(torch.stack(v_x_hats, dim=1))

                        val_ens_stack = torch.stack(val_ensemble_trajs, dim=0)
                        val_mean = val_ens_stack.mean(dim=0)
                        val_var_diag = val_ens_stack.var(dim=0) + 1e-9

                        val_mse_list.append(F.mse_loss(val_mean, x_true_val[:, 1:, :]).item())
                        val_nll_list.append(gaussian_nll_safe(x_true_val[:, 1:, :], val_mean, val_var_diag).item())

                        # ANEES inputs: prepend the known initial state (with a
                        # tiny covariance) to the estimated trajectory.
                        full_x_hat = torch.cat([x_true_val[:, 0, :].unsqueeze(1), val_mean], dim=1)

                        # Diagonal variances -> full covariance matrices.
                        val_covs_full = torch.diag_embed(val_var_diag)

                        # Sized from the data instead of a hard-coded 4.
                        P0 = torch.eye(
                            v_dim, dtype=val_covs_full.dtype, device=val_covs_full.device
                        ).unsqueeze(0).unsqueeze(0).repeat(v_bs, 1, 1, 1) * 1e-6
                        full_P_hat = torch.cat([P0, val_covs_full], dim=1)

                        all_val_x_true.append(x_true_val.cpu())
                        all_val_x_hat.append(full_x_hat.cpu())
                        all_val_P_hat.append(full_P_hat.cpu())

                avg_val_mse = np.mean(val_mse_list)
                avg_val_nll = np.mean(val_nll_list)

                try:
                    # Local import: the helper lives in the project's utils.trainer module.
                    from utils import trainer as tr_utils
                    avg_val_anees = tr_utils.calculate_anees_vectorized(
                        torch.cat(all_val_x_true, dim=0),
                        torch.cat(all_val_x_hat, dim=0),
                        torch.cat(all_val_P_hat, dim=0)
                    )
                except Exception as exc:
                    # Report the failure instead of hiding it; the large
                    # fallback keeps this checkpoint from being selected by
                    # accident when calibration_parameter > 0.
                    print(f"Warning: ANEES calculation failed ({exc!r}); using fallback 100.0")
                    avg_val_anees = 100.0

                # --- Hybrid score ---
                anees_diff = abs(avg_val_anees - 4.0)
                hybrid_score = avg_val_mse + (calibration_parameter * anees_diff)

                print(f"  > Val MSE: {avg_val_mse:.4f} | Val ANEES: {avg_val_anees:.4f} | Hybrid Score: {hybrid_score:.4f}")

                if hybrid_score < best_hybrid_score:
                    print(f"  >>> ‚≠ê New Best Model! (Score: {best_hybrid_score:.4f} -> {hybrid_score:.4f})")
                    best_hybrid_score = hybrid_score
                    best_iter_count = train_iter_count
                    best_model_state = deepcopy(model.state_dict())

                    best_val_mse = avg_val_mse
                    best_val_anees = avg_val_anees
                    best_val_nll = avg_val_nll

                print("-" * 60)
                model.train()

        if not done and train_iter_count == iters_at_epoch_start:
            # Every batch was skipped (or the loader is empty). Another pass
            # would repeat the same skips forever, so stop and return what we have.
            print("No optimizer update in a full pass over train_loader "
                  "(empty loader or every batch skipped); stopping.")
            break

    print("\nTraining completed.")
    if best_model_state is not None:
        print(f"Loading best model from iteration {best_iter_count} (Score: {best_hybrid_score:.4f})")
        model.load_state_dict(best_model_state)

    return {
        "final_model": model,
        "best_iter": best_iter_count,
        "best_val_mse": best_val_mse,
        "best_val_anees": best_val_anees,
        "best_val_nll": best_val_nll
    }

In [15]:
import torch
import copy

# --- 1. CURRICULUM DEFINITION (tuned for stability) ---
curriculum_schedule = [
    # PHASE 1: stabilisation on short sequences.
    # The model learns the basics of the dynamics here.
    dict(
        phase_id=1,
        seq_len=10,
        iters=2000,
        lr=1e-3,
        lambda_mse=1.0,         # a small weight is enough on short sequences
        clip_grad=1.0,
        use_tbptt=False,
        mse_warmup_iters=1000,  # warm-up covers the first 1000 of the 2000 iterations
    ),

    # PHASE 2: extension to length 100 (critical phase).
    # The variance used to explode here, so the settings are stricter.
    dict(
        phase_id=2,
        seq_len=100,
        iters=3000,             # more iterations than phase 1
        lambda_mse=5.0,         # raised from 1.0 to keep the estimate anchored
        lr=1e-5,                # slower learning
        clip_grad=0.05,
        use_tbptt=False,
        tbptt_w=50,             # read only when use_tbptt is True
        tbptt_k=5,              # read only when use_tbptt is True
        # Longer warm-up so the trajectory estimate is reliable before the
        # NLL-only phase starts.
        mse_warmup_iters=1500,
        calibration_parameter=10.0,
    ),

    # PHASE 3 (disabled): long-term sequences (300).
    # dict(
    #     phase_id=3,
    #     seq_len=300,
    #     iters=2000,
    #     lr=1e-6,                # fine-tuning
    #     clip_grad=0.01,
    #     lambda_mse=10.0,        # strong anchor: MSE is huge over 300 steps
    #     use_tbptt=True,
    #     tbptt_w=50,
    #     tbptt_k=5,
    #     mse_warmup_iters=500,   # short re-warm-up at the new length
    #     calibration_parameter=10.0,
    # ),
]

# --- 2. MODEL INITIALISATION ---
# NOTE(review): `TAN`, `system_model`, `device` and `datasets_cache` must already
# exist in the kernel; none of them is defined in this cell.
print("=== INICIALIZACE BKN MODELU ===")
state_knet2 = TAN.StateBayesianKalmanNetTAN(
    system_model=system_model,
    device=device,
    hidden_size_multiplier=12,
    output_layer_multiplier=4,
    num_gru_layers=1,
    init_max_dropout=0.6,
    init_min_dropout=0.4,
)
state_knet2 = state_knet2.to(device)

# --- 3. CURRICULUM LOOP ---
for phase in curriculum_schedule:
    phase_id, seq_len = phase['phase_id'], phase['seq_len']

    # Phases without prepared data loaders are skipped.
    if phase_id not in datasets_cache:
        print(f"‚ö†Ô∏è Skipping Phase {phase_id}: Data not in cache.")
        continue

    print("\n" + "=" * 60)
    print(f"üöÄ START PHASE {phase_id}: SeqLen {seq_len} | LR {phase['lr']} | Lambda MSE: {phase['lambda_mse']}")
    print("=" * 60)

    train_loader_phase = datasets_cache[phase_id][0]
    val_loader_phase = datasets_cache[phase_id][1]

    # Arguments passed identically to both training routines.
    _common_kwargs = dict(
        model=state_knet2,
        train_loader=train_loader_phase,
        val_loader=val_loader_phase,
        device=device,
        total_train_iter=phase['iters'],
        learning_rate=phase['lr'],
        clip_grad=phase['clip_grad'],
        mse_warmup_iters=phase['mse_warmup_iters'],
        weight_decay_=1e-3,
        calibration_parameter=phase.get('calibration_parameter', 10.0),
    )

    if phase['use_tbptt']:
        if not hasattr(state_knet2, 'detach_hidden'):
            raise AttributeError("Modelu chyb√≠ metoda 'detach_hidden()'")

        result = train_BayesianKalmanNet_TBPTT_Windowed(
            J_samples=7,
            tbptt_w=phase.get('tbptt_w', 10),
            tbptt_k=phase.get('tbptt_k', 2),
            validation_period=20,
            logging_period=10,
            lambda_mse=phase['lambda_mse'],
            **_common_kwargs,
        )
    else:
        # The two-phase routine takes no lambda_mse argument.
        print("   -> Using Standard Hybrid Training")
        result = train_BayesianKalmanNet_TwoPhase(
            J_samples=10,
            validation_period=10,
            logging_period=10,
            **_common_kwargs,
        )

    # Checkpoint after every finished phase.
    save_path = f"bkn_curriculum_phase{phase_id}_len{seq_len}.pth"
    torch.save(state_knet2.state_dict(), save_path)
    print(f"‚úÖ F√°ze {phase_id} dokonƒçena. Model ulo≈æen do: {save_path}")

print("\nüéâ Cel√Ω tr√©nink dokonƒçen.")

=== INICIALIZACE BKN MODELU ===
INFO: Aplikuji upravenou inicializaci pro BKN.
DEBUG: V√Ωstupn√≠ vrstva inicializov√°na konzervativnƒõ (interval -0.1 a≈æ 0.1).

üöÄ START PHASE 1: SeqLen 10 | LR 0.001 | Lambda MSE: 1.0
   -> Using Standard Hybrid Training
üöÄ START Two-Phase Training
    Phase 1: MSE Warmup (0 - 1000 iters)
    Phase 2: NLL Optimization (1000 - 2000 iters)
    Saving Strategy: Hybrid Score = MSE + (10.0 * |ANEES - 4.0|)
Iter 10 (Warmup): Loss 326.40 | MSE 326.40 | NLL 16.82 | Sigma 7.13m | MAE 10.30m | p1=0.43, p2=0.44

--- Validation at iteration 10 ---
  > Val MSE: 284.0581 | Val ANEES: 15.8909 | Hybrid Score: 402.9667
  >>> ‚≠ê New Best Model! (Score: inf -> 402.9667)
------------------------------------------------------------
Iter 20 (Warmup): Loss 80.82 | MSE 80.81 | NLL 2.60 | Sigma 5.90m | MAE 5.54m | p1=0.43, p2=0.44

--- Validation at iteration 20 ---
  > Val MSE: 74.0805 | Val ANEES: 9.6974 | Hybrid Score: 131.0545
  >>> ‚≠ê New Best Model! (Score: 402.966

KeyboardInterrupt: 

In [None]:
# Manual checkpoint export, disabled by default: change the condition to True
# to write the current weights of `state_knet2` to disk.
if False:
    save_path = 'best_mse_and_anees_bknet.pth'
    model_weights = state_knet2.state_dict()
    torch.save(model_weights, save_path)
    print(f"Model saved to '{save_path}'.")

# Test na synteticke trajektorii

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import Filters
import os
from tqdm import tqdm
from Filters import TAN

# === CONFIGURATION ===
TEST_DATA_PATH = './generated_data_synthetic_controlled/test_set/test.pt'
PLOT_PER_ITERATION = True    # draw a figure for every trajectory?
MAX_TEST_SAMPLES = 20        # how many test trajectories to evaluate
J_EVALUATION = 100           # Monte Carlo samples for BKN (ensemble size)

print("=== VYHODNOCEN√ç BKN NA TESTOVAC√ç SADƒö (s ANEES) ===")
print(f"Naƒç√≠t√°m data z: {TEST_DATA_PATH}")

# 1. Load the test set (fail early when the file is missing).
if not os.path.exists(TEST_DATA_PATH):
    raise FileNotFoundError(f"Soubor {TEST_DATA_PATH} neexistuje!")

# `device` normally comes from an earlier cell; create it only when absent.
if 'device' not in globals():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

test_data = torch.load(TEST_DATA_PATH, map_location=device)
# 'x' holds the ground truth and 'y' the measurements; per the original
# comments the shapes are [N, Seq, 4] and [N, Seq, 3] -- TODO confirm.
X_test_all, Y_test_all = test_data['x'], test_data['y']

n_samples = min(X_test_all.shape[0], MAX_TEST_SAMPLES)
print(f"Poƒçet testovac√≠ch trajektori√≠: {n_samples}")
print(f"Ensemble size (BKN): {J_EVALUATION}")
print(f"D√©lka sekvence: {X_test_all.shape[1]}")
print("Modely: BKN vs. UKF vs. PF vs. APF")

# 2. Result containers: one independent list per compared filter.
detailed_results = []
agg_mse = {name: [] for name in ("BKN", "UKF", "PF", "APF")}
agg_pos = {name: [] for name in ("BKN", "UKF", "PF", "APF")}
agg_anees = {name: [] for name in ("BKN", "UKF", "PF", "APF")}

# Put the BKN into eval mode before the evaluation loop.
state_knet2.eval()

# --- POMOCN√Å FUNKCE PRO ANEES ---
def calculate_anees(gt, est, P):
    """
    Compute the Average Normalized Estimation Error Squared (ANEES).

    For every time step the normalized error e^T P^-1 e is evaluated, where
    e = gt[t] - est[t]; the result is the mean over all steps that could be
    evaluated.

    Args:
        gt: Ground truth, indexed [T, Dim] (NumPy).
        est: Estimate, indexed [T, Dim] (NumPy).
        P: Covariance matrices, indexed [T, Dim, Dim] (NumPy).

    Returns:
        The mean normalized error, or NaN when there is no usable time step
        (empty input, or every step failed / produced NaN).
    """
    # The three inputs are trimmed to their common length.
    T = min(len(gt), len(est), len(P))
    if T == 0:
        # Nothing to average; return NaN directly instead of letting NumPy
        # warn about the mean of an empty slice.
        return float('nan')

    nees_values = []
    for t in range(T):
        e_t = gt[t] - est[t]  # estimation error at time t
        P_t = P[t]

        try:
            # Add a small epsilon to the diagonal when the covariance is
            # (near-)singular so that the linear solve stays well defined.
            if np.linalg.cond(P_t) > 1e10:
                P_t = P_t + np.eye(P_t.shape[0]) * 1e-6

            # Squared Mahalanobis distance e^T * P^-1 * e, obtained from a
            # linear solve rather than an explicit matrix inverse.
            nees_t = e_t.T @ np.linalg.solve(P_t, e_t)
        except np.linalg.LinAlgError:
            # This step cannot be evaluated; leave it out of the average.
            continue

        if not np.isnan(nees_t):
            nees_values.append(nees_t)

    if not nees_values:
        return float('nan')
    return np.mean(nees_values)

# --- HLAVN√ç SMYƒåKA ---
def _covariance_to_numpy(filter_result):
    """Extract a filter's covariance history as a NumPy array.

    Looks up 'P_filtered' first and falls back to 'P'. Returns None when the
    result exposes neither key, in which case ANEES is skipped for that filter.
    """
    P = filter_result.get('P_filtered', filter_result.get('P', None))
    return None if P is None else P.cpu().numpy()


def calc_metrics(est, gt, P_mat, length=None):
    """Compute position MSE, mean position error and ANEES for one trajectory.

    Args:
        est: Estimated states, indexed [time, state].
        gt: Ground-truth states, indexed [time, state].
        P_mat: Per-step covariance matrices, or None (ANEES is then NaN).
        length: Number of leading time steps to evaluate; defaults to the
            shorter of `est` and `gt`.

    Returns:
        Tuple (mse, pos_err, anees). Only the first two state columns enter
        the MSE and the position error.
    """
    if length is None:
        length = min(len(est), len(gt))

    diff = est[:length] - gt[:length]
    mse = np.mean(np.sum(diff[:, :2]**2, axis=1))
    pos_err = np.mean(np.sqrt(diff[:, 0]**2 + diff[:, 1]**2))

    anees = np.nan
    if P_mat is not None:
        anees = calculate_anees(gt[:length], est[:length], P_mat[:length])

    return mse, pos_err, anees


for i in tqdm(range(n_samples), desc="Evaluace"):

    # A) Data preparation
    x_gt_tensor = X_test_all[i].to(device)
    y_obs_tensor = Y_test_all[i].to(device)

    x_gt = x_gt_tensor.cpu().numpy()
    seq_len = x_gt.shape[0]
    true_init_state = x_gt_tensor[0]

    # --- B) BKN (ensemble) ---
    with torch.no_grad():
        # Every ensemble member starts from the true initial state.
        init_batch = true_init_state.unsqueeze(0).repeat(J_EVALUATION, 1)
        state_knet2.reset(batch_size=J_EVALUATION, initial_state=init_batch)

        bkn_preds = []
        # The same measurement sequence is fed to all ensemble members.
        y_input_batch = y_obs_tensor.unsqueeze(0).repeat(J_EVALUATION, 1, 1)

        # Index 0 is the known initial state, so filtering starts at t = 1.
        for t in range(1, seq_len):
            y_t = y_input_batch[:, t, :]
            x_est, _ = state_knet2.step(y_t)
            bkn_preds.append(x_est)

        if len(bkn_preds) > 0:
            bkn_preds_tensor = torch.stack(bkn_preds, dim=1) # [J, Seq-1, 4]
            full_bkn_ensemble = torch.cat([init_batch.unsqueeze(1), bkn_preds_tensor], dim=1) # [J, Seq, 4]

            # Ensemble mean is the point estimate.
            x_est_mean = full_bkn_ensemble.mean(dim=0)
            x_est_bkn = x_est_mean.cpu().numpy()

            # --- Sample covariance of the ensemble ---
            # P = 1/(J-1) * sum((x_j - x_mean) * (x_j - x_mean)^T)
            residuals = full_bkn_ensemble - x_est_mean.unsqueeze(0) # [J, Seq, 4]
            # Move the ensemble axis last so bmm contracts over it.
            residuals = residuals.permute(1, 2, 0) # [Seq, 4, J]

            # Batch matrix multiplication: (Seq, 4, J) @ (Seq, J, 4) -> (Seq, 4, 4)
            P_bkn_tensor = torch.bmm(residuals, residuals.transpose(1, 2)) / (J_EVALUATION - 1)
            # Optional extension point from the original notes: a process-noise /
            # stabilising term could be added here (the author notes that the
            # BKN variance is epistemic). At index 0 all members are identical,
            # so that covariance is exactly zero.
            P_bkn = P_bkn_tensor.cpu().numpy()

        else:
            # Single-step sequence: nothing was filtered, fall back to the
            # ground truth with identity covariances.
            x_est_bkn = x_gt
            P_bkn = np.eye(4)[np.newaxis, :, :].repeat(len(x_gt), axis=0)

    # --- C) Classical filters ---

    # UKF
    ukf_ideal = Filters.UnscentedKalmanFilter(system_model)
    ukf_res = ukf_ideal.process_sequence(y_seq=y_obs_tensor, Ex0=true_init_state, P0=system_model.P0)
    x_est_ukf = ukf_res['x_filtered'].cpu().numpy()
    P_ukf = _covariance_to_numpy(ukf_res)

    # PF
    pf = TAN.ParticleFilterTAN(system_model, num_particles=1000)
    pf_res = pf.process_sequence(y_seq=y_obs_tensor, Ex0=true_init_state, P0=system_model.P0)
    x_est_pf = pf_res['x_filtered'].cpu().numpy()
    P_pf = _covariance_to_numpy(pf_res)

    # APF
    apf = TAN.AuxiliaryParticleFilterTAN(system_model, num_particles=2000)
    apf_res = apf.process_sequence(y_seq=y_obs_tensor, Ex0=true_init_state, P0=system_model.P0)
    x_est_apf = apf_res['x_filtered'].cpu().numpy()
    P_apf = _covariance_to_numpy(apf_res)

    # --- D) Errors and ANEES ---
    # Use the common length of ALL estimates so that every model is scored on
    # the same time steps and no subtraction runs into mismatched lengths.
    min_len = min(len(x_gt), len(x_est_bkn), len(x_est_ukf), len(x_est_pf), len(x_est_apf))

    mse_bkn, pos_bkn, anees_bkn = calc_metrics(x_est_bkn, x_gt, P_bkn, length=min_len)
    mse_ukf, pos_ukf, anees_ukf = calc_metrics(x_est_ukf, x_gt, P_ukf, length=min_len)
    mse_pf, pos_pf, anees_pf = calc_metrics(x_est_pf, x_gt, P_pf, length=min_len)
    mse_apf, pos_apf, anees_apf = calc_metrics(x_est_apf, x_gt, P_apf, length=min_len)

    # Store the per-run metrics.
    agg_mse["BKN"].append(mse_bkn); agg_pos["BKN"].append(pos_bkn); agg_anees["BKN"].append(anees_bkn)
    agg_mse["UKF"].append(mse_ukf); agg_pos["UKF"].append(pos_ukf); agg_anees["UKF"].append(anees_ukf)
    agg_mse["PF"].append(mse_pf);   agg_pos["PF"].append(pos_pf);   agg_anees["PF"].append(anees_pf)
    agg_mse["APF"].append(mse_apf); agg_pos["APF"].append(pos_apf); agg_anees["APF"].append(anees_apf)

    detailed_results.append({
        "Run_ID": i + 1,
        "BKN_PosErr": pos_bkn, "BKN_ANEES": anees_bkn,
        "UKF_PosErr": pos_ukf, "UKF_ANEES": anees_ukf,
        "PF_PosErr": pos_pf,   "PF_ANEES": anees_pf,
        "APF_PosErr": pos_apf, "APF_ANEES": anees_apf
    })

    # E) Plotting
    if PLOT_PER_ITERATION:
        fig = plt.figure(figsize=(12, 6))
        plt.plot(x_gt[:, 0], x_gt[:, 1], 'k-', linewidth=3, alpha=0.3, label='Ground Truth')
        plt.plot(x_est_bkn[:, 0], x_est_bkn[:, 1], 'g-', linewidth=2, label=f'BKN (Err: {pos_bkn:.1f}m, ANEES: {anees_bkn:.1f})')
        plt.plot(x_est_ukf[:, 0], x_est_ukf[:, 1], 'b--', linewidth=1, label=f'UKF (Err: {pos_ukf:.1f}m, ANEES: {anees_ukf:.1f})')
        # Only BKN and UKF are drawn for readability; uncomment to add PF/APF.
        # plt.plot(x_est_pf[:, 0], x_est_pf[:, 1], 'r:', linewidth=1, alpha=0.6, label='PF')
        plt.title(f"Test Trajectory {i+1}")
        plt.xlabel("X [m]")
        plt.ylabel("Y [m]")
        plt.legend()
        plt.axis('equal')
        plt.grid(True)
        plt.show()
        # Release the figure so a long evaluation does not accumulate open figures.
        plt.close(fig)

# --- V√ùPIS V√ùSLEDK≈Æ ---
# Per-run results table; floats are shown with two decimals and thousands separators.
df_results = pd.DataFrame(detailed_results)
pd.options.display.float_format = '{:,.2f}'.format

print(
    "\n" + "=" * 120,
    "DETAILN√ç V√ùSLEDKY (Pozice v metrech | ANEES - ide√°l ~4.0)",
    "=" * 120,
    sep="\n",
)
# Columns shown in the detailed listing (the PF/APF ANEES columns are not listed).
print(df_results[[
    "Run_ID",
    "BKN_PosErr", "BKN_ANEES",
    "UKF_PosErr", "UKF_ANEES",
    "PF_PosErr", "APF_PosErr",
]])

print(
    "\n" + "=" * 120,
    f"SOUHRNN√Å STATISTIKA ({n_samples} trajektori√≠)",
    "=" * 120,
    sep="\n",
)

def get_stats(key):
    """Summarise the per-run metrics collected for one model.

    Args:
        key: Model name used as key in `agg_mse`, `agg_pos` and `agg_anees`.

    Returns:
        Tuple (mse_mean, mse_std, pos_mean, pos_std, anees_mean, anees_std).
        NaN entries are ignored; when a metric has no finite value at all
        (e.g. a filter that returned no covariance, so every ANEES is NaN)
        its mean and std are NaN, without NumPy runtime warnings.
    """
    def _nan_mean_std(values):
        # Drop NaNs explicitly so that all-NaN / empty inputs do not trigger
        # the "Mean of empty slice" / "Degrees of freedom <= 0" warnings.
        finite = np.asarray(values, dtype=float)
        finite = finite[~np.isnan(finite)]
        if finite.size == 0:
            return np.nan, np.nan
        # Population standard deviation (ddof=0), as np.nanstd uses by default.
        return finite.mean(), finite.std()

    mse_mean, mse_std = _nan_mean_std(agg_mse[key])
    pos_mean, pos_std = _nan_mean_std(agg_pos[key])
    anees_mean, anees_std = _nan_mean_std(agg_anees[key])
    return (mse_mean, mse_std,
            pos_mean, pos_std,
            anees_mean, anees_std)

# Aggregate statistics per model; each tuple is
# (mse_mean, mse_std, pos_mean, pos_std, anees_mean, anees_std).
bkn_s, ukf_s, pf_s, apf_s = (
    get_stats(model) for model in ("BKN", "UKF", "PF", "APF")
)

# Summary table: one row per model, position error and ANEES as mean and std.
header = f"{'Model':<10} | {'Pos Error [m] (Mean ¬± Std)':<30} | {'ANEES (Mean ¬± Std)':<30}"
print(header)
print("-" * len(header))
for model_label, model_stats in (("BKN", bkn_s), ("UKF", ukf_s), ("PF", pf_s), ("APF", apf_s)):
    print(f"{model_label:<10} | {model_stats[2]:.2f} ¬± {model_stats[3]:.2f} m {'':<14} | {model_stats[4]:.2f} ¬± {model_stats[5]:.2f}")
print("="*120)

# Grafick√© porovn√°n√≠ (Boxplot Position Error)
# Models shown in the box plots. APF is evaluated and tabulated, but only
# these three are drawn here.
boxplot_models = ['BKN', 'UKF', 'PF']
# Box plots are drawn at x = 1..N by default; the tick labels are set
# explicitly via plt.xticks instead of boxplot's `labels=` keyword, which is
# deprecated in recent Matplotlib releases.
boxplot_ticks = list(range(1, len(boxplot_models) + 1))

plt.figure(figsize=(12, 5))

# Left panel: distribution of the mean position error per trajectory.
plt.subplot(1, 2, 1)
plt.boxplot([agg_pos[model] for model in boxplot_models], patch_artist=True, boxprops=dict(facecolor='lightblue'))
plt.xticks(boxplot_ticks, boxplot_models)
plt.title("Position Error [m]")
plt.grid(True, axis='y', linestyle='--', alpha=0.7)

# Right panel: distribution of ANEES per trajectory.
plt.subplot(1, 2, 2)
# NaN entries (runs without a usable covariance) are removed before plotting.
anees_data = [
    [x for x in agg_anees[model] if not np.isnan(x)]
    for model in boxplot_models
]
plt.boxplot(anees_data, patch_artist=True, boxprops=dict(facecolor='lightgreen'))
plt.xticks(boxplot_ticks, boxplot_models)
# Reference line at 4.0, the value the notebook treats as ideal ANEES.
plt.axhline(y=4.0, color='r', linestyle='--', label='Ideal (4.0)')
plt.title("ANEES (Consistency)")
plt.legend()
plt.grid(True, axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()