In [None]:
import matplotlib.pyplot as plt
import numpy as np
import mitsuba as mi
import pyvista as pv
import sionna
import tensorflow as tf
from sionna.rt import load_scene, Transmitter, Receiver, PlanarArray, PathSolver, Camera
import h5py
from pathlib import Path
import sys

no_preview = False # Toggle to False to use the preview widget

mi.set_variant("llvm_ad_mono_polarized")

scene_path = "../scene/scene_01.xml"

In [36]:
# Load scene with Sionna for ray tracing capabilities
print("Loading scene with Sionna...")
scene = load_scene(scene_path)
print("Sionna scene loaded.")

Loading scene with Sionna...
Sionna scene loaded.


In [37]:
# 2×2 MIMO patch-like array (spacings ≈ λ/2 @ 2.4 GHz)
lambda_half = 0.0625  # half wavelength spacing at 2.4 GHz

scene.tx_array = PlanarArray(
    num_rows=2,
    num_cols=2,
    vertical_spacing=lambda_half,
    horizontal_spacing=lambda_half,
    pattern="tr38901",
    polarization="V"
)

scene.rx_array = PlanarArray(
    num_rows=2,
    num_cols=2,
    vertical_spacing=lambda_half,
    horizontal_spacing=lambda_half,
    pattern="dipole",
    polarization="cross"
)

In [38]:
# --- Transmitter & Receiver Positions ---
tx_positions = [
    np.array([0.0, 1.0, 1.5]),
    np.array([2.0, 1.5, 1.5]),
]
rx_positions = [
    np.array([1.0, 2.0, 1.0]),
    np.array([2.5, 0.8, 1.0]),
]

# --- Add TX/RX to Scene ---
tx_list, rx_list = [], []

for i, pos in enumerate(tx_positions):
    tx = Transmitter(name=f"tx_{i}", position=pos, display_radius=0.08)
    scene.add(tx)
    tx_list.append(tx)

for i, pos in enumerate(rx_positions):
    rx = Receiver(name=f"rx_{i}", position=pos, display_radius=0.08)
    scene.add(rx)
    rx_list.append(rx)

# Aim each transmitter toward each receiver
for tx in tx_list:
    for rx in rx_list:
        tx.look_at(rx)

print(f"Placed {len(tx_list)} transmitters and {len(rx_list)} receivers.\n")


Placed 2 transmitters and 2 receivers.



In [39]:
camera_pos = [1.5, 1.0, 1.6]  # ~human height indoors, *inside* the room
camera_look = [1.5, 1.5, 1.3]  # Look slightly downward into the room

my_cam = Camera(
    position=camera_pos,
    look_at=camera_look,
)

In [40]:
solver = PathSolver()
paths = solver(scene, max_depth=5, los=True, specular_reflection=True, refraction=True)

num_paths = paths.tau.shape[-1]  # number of paths per TX-RX pair
print(f"Computed {num_paths} propagation paths per TX-RX link.\n")

Computed 286 propagation paths per TX-RX link.



In [41]:
if no_preview:
    scene.render(camera=my_cam, paths=paths, clip_at=20)
else:
    scene.preview(paths=paths, clip_at=20)

HBox(children=(Renderer(camera=PerspectiveCamera(aspect=1.31, children=(DirectionalLight(intensity=0.25, posit…

HBox(children=(Label(value='Clipping plane', layout=Layout(flex='2 2 auto', width='auto')), Checkbox(value=Tru…

In [None]:
# --- System parameters ---
fc = 2.4e9
bw = 20e6
nsub = 32

subcarriers = fc + (np.arange(nsub) - nsub//2) * (bw / nsub)
rng = np.random.default_rng(123) # rng is still used by helper functions if you kept them

# --- Extract CIR from Sionna paths ---
a, tau = paths.cir(normalize_delays=True, out_type="numpy")
a_snap = a[..., 0]  # first time sample
tau_snap = tau

# --- Determine array dimensions dynamically ---
if a_snap.ndim == 7:  # [max_depth, num_rx, num_rx_ant, num_tx, num_tx_ant, num_paths, num_time]
    _, num_rx, num_rx_ant, num_tx, num_tx_ant, num_paths, _ = a_snap.shape
    def get_paths(rx_idx, rx_ant, tx_idx, tx_ant):
        delays = tau_snap[:, rx_idx, rx_ant, tx_idx, tx_ant, :]
        powers = np.abs(a_snap[:, rx_idx, rx_ant, tx_idx, tx_ant, :]).mean(axis=0)
        return delays, powers
        
elif a_snap.ndim == 5:  # [num_rx, num_rx_ant, num_tx, num_tx_ant, num_paths]
    num_rx, num_rx_ant, num_tx, num_tx_ant, num_paths = a_snap.shape
    def get_paths(rx_idx, rx_ant, tx_idx, tx_ant):
        # --- THIS IS THE FIX ---
        # Delays (tau_snap) are 3D: [num_rx, num_tx, num_paths]
        delays = tau_snap[rx_idx, tx_idx, :]
        # Powers (a_snap) are 5D: [num_rx, num_rx_ant, num_tx, num_tx_ant, num_paths]
        powers = np.abs(a_snap[rx_idx, rx_ant, tx_idx, tx_ant, :])
        return delays, powers

elif a_snap.ndim == 3:  # [num_rx, num_tx, num_paths]
    num_rx, num_tx, num_paths = a_snap.shape
    num_rx_ant = 1
    num_tx_ant = 1
    def get_paths(rx_idx, rx_ant=None, tx_idx=None, tx_ant=None):
        delays = tau_snap[rx_idx, tx_idx, :]
        powers = np.abs(a_snap[rx_idx, tx_idx, :])
        return delays, powers
else:
    raise ValueError(f"Unexpected a_snap shape: {a_snap.shape}")

out_file = "../257A-sionna-csi/data/sionna_csi_BASE_CHANNEL.h5"
Path(out_file).parent.mkdir(parents=True, exist_ok=True)

with h5py.File(out_file, "w") as f:
    f.attrs['fc'] = fc
    f.attrs['bw'] = bw
    f.attrs['nsub'] = nsub
    f.attrs['ntx'] = num_tx * num_tx_ant
    f.attrs['nrx'] = num_rx * num_rx_ant
    f.attrs['num_paths'] = num_paths

    # This loop assumes num_rx is the number of receiver *positions*
    for rx_idx in range(num_rx): 
        g = f.create_group(f"pos_{rx_idx:04d}")
        
        # We need to get the correct 'rx' object from rx_list
        # Assuming rx_idx maps directly to rx_list
        if rx_idx >= len(rx_list):
            print(f"Warning: rx_idx {rx_idx} out of range for rx_list (len {len(rx_list)})")
            # Use a mock position or break
            gt_pos_data = np.array([-1, -1, -1]) 
        else:
            rx = rx_list[rx_idx]
            gt_pos_data = rx.position.numpy()


        H_freq_base = np.zeros((nsub, num_rx_ant, num_tx_ant), dtype=np.complex128)
        delays_all = np.zeros((num_rx_ant, num_tx_ant, num_paths))
        powers_all = np.zeros_like(delays_all)

        for rx_ant in range(num_rx_ant):
            # This loop assumes num_tx is the number of transmitter *positions*
            for tx_idx in range(num_tx):
                for tx_ant in range(num_tx_ant):
                    # Dynamically call get_paths with correct axes
                    if a_snap.ndim == 3:  # single-antenna case
                        delays, powers = get_paths(rx_idx, tx_idx=tx_idx)
                    else:
                        delays, powers = get_paths(rx_idx, rx_ant, tx_idx, tx_ant)

                    delays_all[rx_ant, tx_ant, :] = delays
                    powers_all[rx_ant, tx_ant, :] = powers

                    for l in range(num_paths):
                        # Ensure delays and subcarriers broadcast correctly
                        H_freq_base[:, rx_ant, tx_ant] += powers[l] * np.exp(-1j * 2*np.pi* delays[l] * subcarriers)

        # --- Save datasets ---
        # Save the 3D base channel, not the 4D snapshots
        g.create_dataset("H_freq_real", data=np.real(H_freq_base), compression="gzip")
        g.create_dataset("H_freq_imag", data=np.imag(H_freq_base), compression="gzip")
        
        g.create_dataset("path_delays", data=delays_all, compression="gzip")
        g.create_dataset("path_powers", data=powers_all, compression="gzip")
        g.create_dataset("gt_pos", data=gt_pos_data)

        # --- Attributes ---
        g.attrs['is_los'] = 1 # You might want to determine this dynamically
        g.attrs['subcarrier_spacing_hz'] = bw/nsub

print("Saved BASE CHANNEL 3D HDF5:", out_file)

Saved 2×2 MIMO CSI HDF5: ./sionna_csi_2x2MIMO.h5


In [None]:
# --- Visualization with PyVista ---
plotter = pv.Plotter(notebook=False)

# Load scene geometry and compute bounding box
mi_scene = mi.load_file(scene_path)
all_vertices = []
for shape in mi_scene.shapes():
    params = mi.traverse(shape)
    vertices = np.array(params["vertex_positions"], dtype=np.float32).reshape(-1, 3)
    faces = np.array(params["faces"], dtype=np.int32).reshape(-1, 3)
    faces_pv = np.hstack([np.full((faces.shape[0], 1), 3), faces]).flatten()
    mesh = pv.PolyData(vertices, faces_pv)
    plotter.add_mesh(mesh, color="lightgray", opacity=0.6, show_edges=False)
    all_vertices.append(vertices)

all_vertices = np.vstack(all_vertices)
bbox_min = np.min(all_vertices, axis=0)
bbox_max = np.max(all_vertices, axis=0)

# Add transmitters and receivers
for tx in tx_list:
    tx_center = np.squeeze(tx.position.numpy())
    plotter.add_mesh(pv.Sphere(radius=0.1, center=tx_center), color="red", label="Transmitter")

for rx in rx_list:
    rx_center = np.squeeze(rx.position.numpy())
    plotter.add_mesh(pv.Sphere(radius=0.1, center=rx_center), color="green", label="Receiver")

# --- Add ray paths ---
vertices = paths.vertices.numpy()  # convert to NumPy

def clip_path(path_pts, bbox_min, bbox_max):
    """Keep points inside the scene bounding box."""
    mask = np.all((path_pts >= bbox_min) & (path_pts <= bbox_max), axis=1)
    return path_pts[mask]

if vertices.ndim == 7:
    max_depth, num_rx, num_rx_ant, num_tx, num_tx_ant, num_paths, _ = vertices.shape
    for rx in range(num_rx):
        for rx_ant in range(num_rx_ant):
            for tx in range(num_tx):
                for tx_ant in range(num_tx_ant):
                    for p in range(num_paths):
                        path_pts = vertices[:, rx, rx_ant, tx, tx_ant, p, :]
                        if np.allclose(path_pts, 0):
                            continue
                        path_pts = clip_path(path_pts, bbox_min, bbox_max)
                        if len(path_pts) < 2:
                            continue
                        line = pv.lines_from_points(path_pts)
                        plotter.add_mesh(line, color="yellow", line_width=2)
elif vertices.ndim == 5:
    max_depth, num_rx, num_tx, num_paths, _ = vertices.shape
    for rx in range(num_rx):
        for tx in range(num_tx):
            for p in range(num_paths):
                path_pts = vertices[:, rx, tx, p, :]
                if np.allclose(path_pts, 0):
                    continue
                path_pts = clip_path(path_pts, bbox_min, bbox_max)
                if len(path_pts) < 2:
                    continue
                line = pv.lines_from_points(path_pts)
                plotter.add_mesh(line, color="yellow", line_width=2)
else:
    raise ValueError(f"Unexpected vertices shape: {vertices.shape}")

# Final touches
plotter.add_axes()
plotter.add_legend()
plotter.add_title("Ray Tracing Paths Visualization")
plotter.show()

# --- Scene Summary ---
print("\n=== Scene Summary ===")
print(f"Number of objects in scene: {len(scene.objects)}")
print(f"Number of TXs: {len(tx_list)}, RXs: {len(rx_list)}")
print(f"Path solver max_depth: {solver.max_depth if hasattr(solver, 'max_depth') else 'unknown'}")


=== Scene Summary ===
Number of objects in scene: 22
Number of TXs: 2, RXs: 2
Path solver max_depth: unknown
