# Example: Physics-Compliant OTFS System

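This notebook demonstrates how to use the modular `otfs` package to train a neural channel estimator on a small OTFS delay-Doppler grid, and then to compare the resulting neural receiver's bit-error rate (BER) against MMSE and ZF detectors that are given the true channel.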

In [None]:
# Imports
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# OTFS modules
from otfs.modem import OTFS_Modem
from otfs.channel import generate_channel_params, ltv_channel_sim, get_effective_channel_matrix
from otfs.models.estimators import AttentionChannelEstimator
from otfs.models.detectors import DetectorNet
from otfs.models.end_to_end import NeuralReceiver
from otfs.data.datasets import OTFSPhysicsDataset, DetectorDataset, ChannelEstimatorDataset
from otfs.classical.detectors import mmse_detector, zf_detector
from otfs.utils.metrics import calculate_ber
from otfs.training.trainer import train_model, validate_model

# Setup: choose the compute device and seed the global RNGs so that
# "Restart Kernel -> Run All" reproduces the same sample data, training run and BER curves.
# NOTE(review): this only controls randomness drawn from NumPy's and PyTorch's global
# generators; any otfs helper that builds its own generator is not covered — TODO confirm.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU generator and all CUDA generators

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. System Parameters

In [None]:
# OTFS grid parameters
M = 4  # Delay bins
N = 4  # Doppler bins
num_symbols = M * N

# Pilot/data configuration: checkerboard pilots on the M x N grid, flattened row-major
# (the grid is later built with x_dd.reshape(M, N)), so cell (m, n) has flat index m * N + n
# and is a pilot when (m + n) is even. Deriving the pattern from M and N keeps it valid if
# the grid size changes; for M = N = 4 it is [0, 2, 5, 7, 8, 10, 13, 15].
pilot_indices = np.flatnonzero(np.add.outer(np.arange(M), np.arange(N)) % 2 == 0)
data_mask_flat = np.ones(num_symbols, dtype=bool)
data_mask_flat[pilot_indices] = False
data_indices = np.where(data_mask_flat)[0]

num_pilots = len(pilot_indices)
num_data = len(data_indices)

# Every grid cell is either a pilot or a data symbol, never both.
assert num_pilots + num_data == num_symbols

print(f"Grid: {M}x{N} = {num_symbols} symbols")
print(f"Pilots: {num_pilots} ({100*num_pilots/num_symbols:.1f}%)")
print(f"Data: {num_data} ({100*num_data/num_symbols:.1f}%)")

## 2. Generate Sample Data

In [None]:
# Push one random frame through modulate -> channel -> demodulate as a pipeline sanity check.
# Data cells carry random ±1 (BPSK) symbols; every pilot cell carries the known symbol 1.0.
bits = np.random.choice([-1, 1], size=num_data)

x_dd = np.zeros(num_symbols, dtype=np.complex128)
x_dd[pilot_indices] = 1.0       # known pilot symbols
x_dd[data_indices] = bits       # BPSK data symbols (pilot and data cells are disjoint)
x_dd_grid = x_dd.reshape(M, N)  # row-major M x N grid

# Transmit chain: modulate the grid, apply a freshly drawn channel at 10 dB SNR,
# then demodulate back onto the M x N grid.
tx_sig = OTFS_Modem.modulate(x_dd_grid, M, N)
paths = generate_channel_params()
rx_sig, noise_power = ltv_channel_sim(tx_sig, paths, snr_db=10, M=M, N=N)
y_dd_grid = OTFS_Modem.demodulate(rx_sig, M, N)

# Show the real parts of the transmitted and received grids side by side.
for label, grid in (("Transmitted", x_dd_grid), ("Received", y_dd_grid)):
    print(f"{label}: {grid.real}")

## 3. Train Channel Estimator

In [None]:
# Training configuration (ChannelEstimatorDataset is imported in the top imports cell).
NUM_TRAIN_SAMPLES = 5000
NUM_VAL_SAMPLES = 1000
TRAIN_SNR_RANGE = (0, 25)  # presumably dB, matching the BER sweep below — TODO confirm
BATCH_SIZE = 128
LEARNING_RATE = 0.001
NUM_EPOCHS = 50
ESTIMATOR_PATH = 'estimator.pth'  # checkpoint written by train_model

# Datasets and loaders: training and validation draw from the same SNR range.
train_dataset = ChannelEstimatorDataset(NUM_TRAIN_SAMPLES, M, N, pilot_indices, snr_range=TRAIN_SNR_RANGE)
val_dataset = ChannelEstimatorDataset(NUM_VAL_SAMPLES, M, N, pilot_indices, snr_range=TRAIN_SNR_RANGE)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Model, MSE regression loss and Adam optimiser
estimator = AttentionChannelEstimator().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(estimator.parameters(), lr=LEARNING_RATE)

# Train; history is read below via its 'train_loss' and 'val_loss' entries.
print("Training channel estimator...")
history = train_model(
    estimator, train_loader, criterion, optimizer, device,
    epochs=NUM_EPOCHS, val_loader=val_loader, save_path=ESTIMATOR_PATH
)

# Learning curves on a single explicit axes so the figure has no empty panel.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(history['train_loss'], label='Train')
ax.plot(history['val_loss'], label='Val')
ax.set(xlabel='Epoch', ylabel='MSE Loss', title='Channel Estimator Training')
ax.legend()
ax.grid(True)
plt.show()

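## 4. Evaluate BER Performance

The neural receiver is compared against MMSE and ZF detectors that are given the true effective channel matrix ("genie" channel knowledge). For these baselines, the known pilot contribution is subtracted from the received signal before the data symbols are detected.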

In [None]:
# Reload the estimator weights from the checkpoint saved during training
# (train_model stores them under 'model_state_dict').
estimator.load_state_dict(torch.load('estimator.pth', map_location=device)['model_state_dict'])

# End-to-end neural receiver built from the same checkpoint file, in inference mode.
receiver = NeuralReceiver(estimator_path='estimator.pth').to(device)
receiver.eval()

# Monte-Carlo settings. Each SNR point sees NUM_TRIALS * num_data bits, so the smallest
# non-zero BER that can be measured is 1 / (NUM_TRIALS * num_data); raise NUM_TRIALS
# for smoother curves at high SNR.
snr_range = range(0, 22, 2)
NUM_TRIALS = 50
PILOT_VALUE = 1.0  # known symbol transmitted at every pilot position


def _make_frame():
    """Draw random BPSK data bits and return (bits, M x N grid with data and pilots)."""
    frame_bits = np.random.choice([-1, 1], size=num_data)
    frame = np.zeros(num_symbols, dtype=np.complex128)
    frame[data_indices] = frame_bits
    frame[pilot_indices] = PILOT_VALUE
    return frame_bits, frame.reshape(M, N)


def _real_imag(grid):
    """Stack the real and imaginary parts of an M x N complex grid into a (2, M, N) float tensor."""
    return torch.stack([
        torch.from_numpy(grid.real).float(),
        torch.from_numpy(grid.imag).float(),
    ], dim=0)


def _ls_pilot_input(y_grid):
    """Least-squares pilot input: y / pilot at pilot cells, zero at data cells, as (2, M, N).

    NOTE(review): this must match the input format ChannelEstimatorDataset produced
    during training — TODO confirm against otfs.data.datasets.
    """
    ls_flat = np.zeros(num_symbols, dtype=np.complex128)
    ls_flat[pilot_indices] = y_grid.reshape(-1)[pilot_indices] / PILOT_VALUE
    return _real_imag(ls_flat.reshape(M, N))


def _for_log_axis(ber):
    """Replace zero BER values by NaN so they are omitted from a log-scale plot."""
    ber = np.asarray(ber, dtype=float)
    return np.where(ber > 0, ber, np.nan)


# Known pilot frame, used by the genie detectors to cancel the pilot contribution.
x_pilots = np.zeros(num_symbols, dtype=np.complex128)
x_pilots[pilot_indices] = PILOT_VALUE

ber_dl = []
ber_mmse = []
ber_zf = []

for snr in snr_range:
    errors_dl = errors_mmse = errors_zf = 0
    total_bits = 0

    for _ in range(NUM_TRIALS):
        bits, x_dd_grid = _make_frame()

        # Fresh channel realisation for every frame
        paths = generate_channel_params()
        tx_sig = OTFS_Modem.modulate(x_dd_grid, M, N)
        rx_sig, noise_power = ltv_channel_sim(tx_sig, paths, snr_db=snr, M=M, N=N)
        y_dd_grid = OTFS_Modem.demodulate(rx_sig, M, N)

        # Neural detection: batch of one (LS pilot input, full received grid)
        ls_input = _ls_pilot_input(y_dd_grid).unsqueeze(0).to(device)
        y_tensor = _real_imag(y_dd_grid).unsqueeze(0).to(device)
        with torch.no_grad():
            pred = receiver(ls_input, y_tensor)
        # assumes pred[0, 0] is the real-valued M x N symbol estimate — TODO confirm
        pred_bits = torch.sign(pred[0, 0].cpu()).numpy().flatten()[data_indices]
        errors_dl += np.sum(pred_bits != bits)

        # Genie MMSE/ZF: true effective channel, pilot contribution removed
        H_eff = get_effective_channel_matrix(paths, M, N)
        y_clean = y_dd_grid.flatten() - (H_eff @ x_pilots)
        H_data = H_eff[:, data_indices]

        bits_mmse = np.sign(mmse_detector(y_clean, H_data, noise_power).real)
        errors_mmse += np.sum(bits_mmse != bits)

        bits_zf = np.sign(zf_detector(y_clean, H_data).real)
        errors_zf += np.sum(bits_zf != bits)

        total_bits += num_data

    ber_dl.append(errors_dl / total_bits)
    ber_mmse.append(errors_mmse / total_bits)
    ber_zf.append(errors_zf / total_bits)
    print(f"SNR {snr}dB: DL={ber_dl[-1]:.5f}, MMSE={ber_mmse[-1]:.5f}, ZF={ber_zf[-1]:.5f}")

# BER curves; zero-error points are left out because log(0) cannot be drawn.
snr_axis = list(snr_range)
fig, ax = plt.subplots(figsize=(10, 6))
ax.semilogy(snr_axis, _for_log_axis(ber_dl), 'b-o', label='Neural Receiver')
ax.semilogy(snr_axis, _for_log_axis(ber_mmse), 'g--', label='Genie MMSE')
ax.semilogy(snr_axis, _for_log_axis(ber_zf), 'r:', label='Genie ZF')
ax.set(xlabel='SNR (dB)', ylabel='BER', title='OTFS Detection Performance')
ax.legend()
ax.grid(True, which='both', alpha=0.5)
plt.show()