In [1]:
import wandb
import optuna
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau
import sys
import os
import yaml
import numpy as np
import datetime
import matplotlib.pyplot as plt
import torch.nn.functional as F
import math
from scipy.fft import irfft, rfft
import gc
import time
sys.path.append(os.path.abspath("../lab_scripts"))
from constellation_diagram import QPSK_Constellation
from constellation_diagram import RingShapedConstellation

if wandb.run is not None:
    wandb.run.tags = list(wandb.run.tags) + ["junk"]
wandb.finish()
NUM_POINTS_FRAME = 6000
# NUM_POINTS_FRAME = 3040
CP_LENGTH = 2000
# CP_LENGTH = 1013
NUM_POINTS_SYMBOL = NUM_POINTS_FRAME + CP_LENGTH
POWER_NORMALIZATION = False

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
api = wandb.Api()
run = api.run("dylanbackprops-university-of-washington/mldrivenpeled/1zfa43s4") # Variable
model_name = "channel_model_final"
artifact = api.artifact("dylanbackprops-university-of-washington/mldrivenpeled/channel_model:v1943") # Variable
artifact_dir = artifact.download()
remote_config = run.config
run_name = run.name
print("Channel Run name:", run_name)

[34m[1mwandb[0m:   1 of 1 files downloaded.  


Channel Run name: earnest-bee-7750


In [3]:
remote_config

{'Nf': 1499,
 'Nt': 1,
 'lr': 0.001,
 'wd': 0.001,
 'flow': 300000,
 'gain': 20,
 'fhigh': '15e6',
 'ar_taps': 0,
 'fnyquist': '30e6',
 'num_taps': 10,
 'dc_offset': 3.5,
 'batch_size': 32,
 'state_size': 8,
 'grid_run_id': 'grid_20251202_103823',
 'hidden_size': 8,
 'linear_fast': False,
 'weight_init': 'default',
 'dilation_base': 2,
 'scheduler_type': 'reduce_lr_on_plateu',
 'detach_residuals': False,
 'training_schedule': [{'mode': 'normal', 'epochs': 10}],
 'subcarrier_spacing': '1e4'}

In [4]:
print("Contents of artifact_dir:", os.listdir(artifact_dir))
# Set device
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device("mps") # for M chip Macs
else:
    device = torch.device("cpu")
device = torch.device("cpu")
# print(device)

Contents of artifact_dir: ['channel_model_final.pth']


In [None]:
def zero_last_layer(m):
    if isinstance(m, nn.Linear):
        nn.init.zeros_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

class StateSpaceModel(nn.Module):
    def __init__(self,
                 deterministic_num_taps,
                 deterministic_state_size,
                 deterministic_hidden_size,
                 stochastic_state_size,
                 stochastic_hidden_size
                 ):
        super().__init__()
        self.deterministic_state_size = deterministic_state_size
        self.stochastic_state_size = stochastic_state_size
        self.deterministic_num_taps = deterministic_num_taps
        self.deterministic_state_map = nn.Sequential(
            nn.Linear(deterministic_num_taps + deterministic_state_size, deterministic_hidden_size),
            nn.SiLU(),
            nn.Linear(deterministic_hidden_size, deterministic_hidden_size),
            nn.SiLU(),
            nn.Linear(deterministic_hidden_size, deterministic_state_size)
        )
        self.deterministic_out_map = nn.Sequential(
            nn.Linear(deterministic_num_taps + deterministic_state_size, deterministic_hidden_size),
            nn.SiLU(),
            nn.Linear(deterministic_hidden_size, 1) # Predict next scalar ouptut
        )
        self.stochastic_state_map = nn.Sequential(
            nn.Linear(stochastic_state_size + deterministic_num_taps + 1, stochastic_hidden_size),
            nn.SiLU(),
            nn.Linear(stochastic_hidden_size, stochastic_hidden_size),
            nn.SiLU(),
            nn.Linear(stochastic_hidden_size, stochastic_state_size)
        )
        self.stochastic_out_map = nn.Sequential(
            nn.Linear(deterministic_num_taps + stochastic_state_size + deterministic_state_size, stochastic_hidden_size),
            nn.SiLU(),
            nn.Linear(stochastic_hidden_size, 1)
        )
        self.linear_det_out_map = nn.Linear(deterministic_num_taps + deterministic_state_size, 1)
        self.linear_det_state_map = nn.Linear(deterministic_num_taps + deterministic_state_size, deterministic_state_size)
        self.linear_stoch_state_map = nn.Linear(stochastic_state_size + deterministic_num_taps + 1, stochastic_state_size)
        self.linear_stoch_out_map = nn.Linear(deterministic_num_taps + stochastic_state_size + deterministic_state_size, 1)

        self.n0 = nn.Parameter(torch.zeros(deterministic_state_size))
        self.z0 = nn.Parameter(torch.zeros(stochastic_state_size))

        # Make it so that the stochastic output starts at zero
        self.deterministic_state_map[-1].apply(zero_last_layer)
        self.deterministic_out_map[-1].apply(zero_last_layer)
        self.stochastic_state_map[-1].apply(zero_last_layer)
        self.stochastic_out_map[-1].apply(zero_last_layer)
        self.mode = "nonlinear"

    def forward(self, x, y):
        mode = self.mode
        device = x.device
        y_pred = torch.zeros_like(x, device=device)
        e_pred = torch.zeros_like(x, device=device)
        T = x.size(-1)
        B = x.size(0)
        # n_t = torch.zeros(B, self.deterministic_state_size, device=device)
        # z_t = torch.zeros(B, self.stochastic_state_size, device=device)
        n_t = self.n0.unsqueeze(0).expand(B, -1)  # [B, nx]
        z_t = self.z0.unsqueeze(0).expand(B, -1)  # [B, nz]
        # add zeros in front equal to num_taps - 1
        x = torch.cat([torch.zeros(B, self.deterministic_num_taps - 1, device=device), x], dim=-1)
        for t in range(T):
            x_t = x[:, t: t + self.deterministic_num_taps]

            if mode == "linear":
                y_t_pred = self.linear_det_out_map(torch.cat([x_t, n_t], dim=-1))
            else:
                y_t_pred = self.deterministic_out_map(torch.cat([x_t, n_t], dim=-1)) + self.linear_det_out_map(torch.cat([x_t, n_t], dim=-1))

            if y is None:
                '''INFERENCE MODE'''
                y_pred[:, t] = y_t_pred.squeeze(-1)
                # Assume residuals are an innovation process with zero mean
                if mode == "linear":
                    n_t = n_t + self.linear_det_state_map(torch.cat([x_t, n_t], dim=-1)) # [B, nx + num_taps]
                else:
                    n_t = n_t + self.deterministic_state_map(torch.cat([x_t, n_t], dim=-1)) + self.linear_det_state_map(torch.cat([x_t, n_t], dim=-1)) # [B, nx + num_taps]

            else:
                '''TRAINING MODE'''
                # Training mode
                y_t = y[:, t]
                r_t = y_t.unsqueeze(-1) - y_t_pred

                if mode == "linear":
                    nonlinear_noise_t = self.linear_stoch_out_map(torch.cat([x_t, n_t, z_t], dim=-1))
                else:
                    nonlinear_noise_t = self.stochastic_out_map(torch.cat([x_t, n_t, z_t], dim=-1)) + self.linear_stoch_out_map(torch.cat([x_t, n_t, z_t], dim=-1))
                y_t_next_pred = y_t_pred + nonlinear_noise_t
                y_pred[:, t] = y_t_next_pred.squeeze(-1)
                e_t = r_t - nonlinear_noise_t
                e_pred[:, t] = e_t.squeeze(-1)
                # Make state updates

                if mode == "linear":
                    z_t = z_t + self.linear_stoch_state_map(torch.cat([x_t, z_t, r_t], dim=-1))
                    n_t = n_t + self.linear_det_state_map(torch.cat([x_t, n_t], dim=-1)) # [B, nx + num_taps]
                else:
                    z_t = z_t + self.stochastic_state_map(torch.cat([x_t, z_t, r_t], dim=-1)) + self.linear_stoch_state_map(torch.cat([x_t, z_t, r_t], dim=-1))
                    n_t = n_t + self.deterministic_state_map(torch.cat([x_t, n_t], dim=-1)) + self.linear_det_state_map(torch.cat([x_t, n_t], dim=-1)) # [B, nx + num_taps]
        return y_pred, e_pred


class TCNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=0
        )
        self.padding = (kernel_size - 1) * dilation
        self.SiLU = nn.SiLU()
        self.resample = None
        if in_channels != out_channels:
            self.resample = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        out = F.pad(x, (self.padding, 0))
        out = self.conv(out)
        out = self.SiLU(out)
        if out.size(2) > x.size(2):
            out = out[:, :, :x.size(2)]
        if self.resample:
            x = self.resample(x)
        return out + x  # residual connection


def sample_student_t_mps(mean, std, nu):
    '''
    Wilson-Hilferty Approximation for chi^2 converted to scaled and shifted student t
    '''
    z = torch.randn_like(mean)
    z_chi = torch.randn_like(mean)
    chi2_approx = nu * (1 - 2/(9*nu) + z_chi * torch.sqrt(2/(9*nu))).pow(3)
    chi2_approx = chi2_approx.clamp(min=0.01)
    scale = torch.sqrt(nu / chi2_approx)
    return mean + std * z * scale


class TCN(nn.Module):
    def __init__(self, nlayers=3, dilation_base=2, num_taps=10, hidden_channels=32):
        super().__init__()
        layers = []
        in_channels = 1
        for i in range(nlayers):
            dilation = dilation_base ** i
            layers.append(
                TCNBlock(in_channels, hidden_channels, num_taps, dilation)
            )
            in_channels = hidden_channels
        self.tcn = nn.Sequential(*layers)
        self.readout = nn.Conv1d(hidden_channels, 1, kernel_size=1)

        # Calculate the total receptive field for the whole TCN stack
        self.receptive_field = 1
        for i in range(nlayers):
            dilation = dilation_base ** i
            self.receptive_field += (num_taps - 1) * dilation

    def forward(self, xin):
        x = xin.unsqueeze(1)    # [B,1,T]
        out = self.tcn(x)     # [B,H,T]
        out = self.readout(out).squeeze(1)
        out = out - out.mean(dim=1, keepdim=True)  # [B,T]
        return out


class TCN_channel(nn.Module):
    def __init__(self, nlayers=3, dilation_base=2, num_taps=10, hidden_channels=32, learn_noise=False, gaussian=True):
        super().__init__()
        layers = []
        in_channels = 1
        for i in range(nlayers):
            dilation = dilation_base ** i
            layers.append(
                TCNBlock(in_channels, hidden_channels, num_taps, dilation)
            )
            in_channels = hidden_channels
        self.learn_noise = learn_noise
        self.tcn = nn.Sequential(*layers)
        if gaussian:
            self.readout = nn.Conv1d(hidden_channels, 2, kernel_size=1) # 2 channels mean | std
        else:
            self.readout = nn.Conv1d(hidden_channels, 3, kernel_size=1) # 3 channels mean | std | nu
        self.num_taps = num_taps
        self.gaussian = gaussian

        if not gaussian:
            with torch.no_grad():
                # Initialize log_nu bias
                self.readout.bias[2].fill_(48)

    def forward(self, xin):
        x = xin.unsqueeze(1)    # [B,1,T]
        out = self.tcn(x)     # [B,H,T]
        out = self.readout(out) # [B, 3, T] mean | std | nu
        mean_out = out[:, 0, :]
        log_std_out = out[:, 1, :]
        std_out = torch.exp(log_std_out)
        if not self.gaussian:
            log_nu_out = out[:, 2, :]
            nu_out = torch.nn.functional.softplus(log_nu_out)
            nu_out = torch.clamp(nu_out, 2, 50) # nu between 2 and 50
        mean_out = mean_out - mean_out.mean(dim=1, keepdim=True)  # [B ,T]

        # # Produce noisy output
        if self.gaussian:
            z = torch.randn_like(mean_out)
            noisy_out = mean_out + std_out * z
            nu_out = torch.zeros_like(mean_out)
        else:
            noisy_out = sample_student_t_mps(mean_out, std_out, nu_out)
        if self.learn_noise:
            return noisy_out, mean_out, std_out, nu_out
        else:
            return mean_out, mean_out, torch.zeros_like(mean_out)

class ProbabilisticStateSpaceModel(nn.Module):
    def __init__(self,
                 num_taps,
                 state_size,
                 hidden_size,
                 ar_taps,
                 detach_residuals=True,
                 linear_fast=True
                 ):
        super().__init__()
        self.state_size = state_size
        self.num_taps = num_taps
        self.linear_fast = linear_fast
        self.detach_residuals = detach_residuals
        self.state_map = nn.Sequential(
            nn.Linear(num_taps + state_size, hidden_size),
            # nn.SiLU(),
            # nn.Linear(hidden_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, state_size)
        )
        self.state_out_map = nn.Sequential(
            nn.Linear(num_taps + state_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, 2 + ar_taps) # Predict mean and std
        )
        if linear_fast:
            self.fast_feedthrough = nn.Linear(num_taps, 1, bias=False)
        else:
            self.fast_feedthrough = nn.Sequential(
                nn.Linear(num_taps, hidden_size, bias=False),
                nn.SiLU(),
                nn.Linear(hidden_size, 1, bias=False)
            )
        # nn.init.zeros_(self.fast_feedthrough.weight)
        # nn.init.zeros_(self.fast_feedthrough.weight)
        self.n0 = nn.Parameter(torch.zeros(state_size))
        # self.alpha = nn.Parameter(torch.tensor(0.5))
        self.alpha = nn.Parameter(torch.tensor(3 * torch.randn(state_size)))
        # self.state_map[-1].apply(zero_last_layer)
        # self.state_out_map[-1].apply(zero_last_layer)
        self.ar_coeffs = nn.Parameter(torch.zeros(ar_taps))
        self.ar_taps = ar_taps

    def _whiten(self, x, ar_preds):
        out = torch.zeros_like(x)
        for t in range(x.size(-1)):
            x_whitened_t = x[:, t]
            for i in range(self.ar_taps):
                j = t - (i + 1)
                if j >= 0: # Enforce boundary
                #   x_whitened_t = x_whitened_t - nn.Tanh(ar_preds[i]) * x[:, j]
                    if self.detach_residuals:
                        x_whitened_t = x_whitened_t - ar_preds[:, t, i] * x[:, j].detach()
                    else:
                        x_whitened_t = x_whitened_t - ar_preds[:, t, i] * x[:, j] # Remove the possibility of cheating by spiking previous errors
            out[:, t] = x_whitened_t
        return out

    def _step(self, xt, nt):
        inp = torch.cat([xt, nt], dim=-1)
        out = self.state_out_map(inp)
        y_pred = out[:, 0]
        y_fast = self.fast_feedthrough(xt).squeeze(-1)
        # y_fast_sq = self.fast_square(xt ** 2).squeeze(-1)
        y_pred = y_pred + y_fast
        std_pred = F.softplus(out[:, 1]) + 1e-4
        ar_preds = F.tanh(out[:, 2:])
        delta_n = self.state_map(inp)
        alpha = torch.sigmoid(self.alpha)
        nt_next = (1.0 - alpha) * nt + delta_n
        return y_pred, std_pred, ar_preds, nt_next

    def forward_train(self, x, y):
        device = x.device
        T = x.size(-1)
        B = x.size(0)
        y_pred = torch.zeros(B, T, device=device)
        std_pred = torch.zeros(B, T, device=device)
        residuals = torch.zeros(B, T, device=device)
        ar_preds = torch.zeros(B, T, self.ar_taps, device=device)
        nt = self.n0.unsqueeze(0).repeat(B, 1).clone()  # [B, nx]
        # add zeros in front equal to num_taps - 1
        x = torch.cat([torch.zeros(B, self.num_taps - 1, device=device), x], dim=-1)

        for t in range(T):
            xt = x[:, t: t + self.num_taps]
            y_pred_t, std_pred_t, ar_preds_t, nt = self._step(xt, nt)
            y_pred[:, t] = y_pred_t

            # get residuals
            r_t = (y[:, t] - y_pred_t)
            residuals[:, t] = r_t #
            std_pred[:, t] = std_pred_t
            ar_preds[:, t, :] = ar_preds_t

        # whiten the innovations
        innovations = self._whiten(residuals, ar_preds)
        return y_pred, std_pred, innovations


    def forward_simulate(self, x, eps=None):
        device = x.device
        T = x.size(-1)
        B = x.size(0)
        y_pred = torch.zeros(B, T, device=device)
        if eps is None:
            eps = torch.randn(B, T, device=device)
        residuals = []
        x = torch.cat([torch.zeros(B, self.num_taps - 1, device=device), x], dim=-1)
        n_t = self.n0.unsqueeze(0).repeat(B, 1).clone()
        for t in range(T):
            max_lag = min(self.ar_taps, t)
            xt = x[:, t: t + self.num_taps]
            y_pred_t, std_pred_t, ar_preds, n_t = self._step(xt, n_t)
            if max_lag > 0:
                past_resids = torch.stack(residuals[-max_lag:], dim=0) # [max_lag, B]
                past_resids = past_resids.transpose(0, 1) # [B, max_lag]
                past_resids = past_resids.flip(dims=[1]) # Flip so that first residual is closest in time
                ar_component = torch.sum(ar_preds[:, :max_lag] * past_resids, dim=1) # [B, max_lag] * [B, max_lag]
            else:
                ar_component = torch.zeros(B, device=device)

            r_t = ar_component + std_pred_t * eps[:, t]
            residuals.append(r_t)
            y_pred[:, t] = y_pred_t + r_t
        return y_pred

In [6]:
# channel_model = StateSpaceModel(
#     deterministic_num_taps=remote_config['deterministic_num_taps'],
#     deterministic_state_size=remote_config['deterministic_state_size'],
#     deterministic_hidden_size=remote_config['deterministic_hidden_size'],
#     stochastic_state_size=remote_config['stochastic_state_size'],
#     stochastic_hidden_size=remote_config['stochastic_hidden_size']
# )

channel_model = ProbabilisticStateSpaceModel(
    num_taps = remote_config["num_taps"],
    state_size = remote_config["state_size"],
    hidden_size = remote_config["hidden_size"],
    ar_taps = remote_config["ar_taps"],
    detach_residuals = remote_config["detach_residuals"],
    linear_fast = remote_config["linear_fast"]
)

channel_model_path = os.path.join(artifact_dir, model_name + ".pth")
checkpoint = torch.load(channel_model_path)
channel_model.load_state_dict(checkpoint["channel_model"])

# Freeze before moving to device
for param in channel_model.parameters():
    param.requires_grad = False

channel_model = channel_model.to(device).float()
channel_model.eval()

print("Channel model parameters frozen:",
      all(not param.requires_grad for param in channel_model.parameters()))

Channel model parameters frozen: True


  self.alpha = nn.Parameter(torch.tensor(3 * torch.randn(state_size)))
  checkpoint = torch.load(channel_model_path)


In [7]:
constellation_mode = "m7_apsk_constellation"

def get_constellation(mode: str):
        if mode == "qpsk":
            constellation = QPSK_Constellation()
        elif mode == "m5_apsk_constellation":
            constellation = RingShapedConstellation(filename=r'/Users/dylanjones/Desktop/mldrivenpeled/lab_scripts/saved_constellations/m5_apsk_constellation.npy')
        elif mode == "m6_apsk_constellation":
             constellation = RingShapedConstellation(filename=r'/Users/dylanjones/Desktop/mldrivenpeled/lab_scripts/saved_constellations/m6_apsk_constellation.npy')
        elif mode == "m7_apsk_constellation":
             # /Users/dylanjones/Desktop/mldrivenpeled/lab_scripts/saved_constellations/m7_apsk_constellation.npy
             # C:\Users\maild\mldrivenpeled\lab_scripts\saved_constellations\m7_apsk_constellation.npy
             constellation = RingShapedConstellation(filename=r'C:\Users\maild\mldrivenpeled\lab_scripts\saved_constellations\m7_apsk_constellation.npy')
        return constellation

constellation = get_constellation(constellation_mode)

In [8]:
script_dir = os.getcwd()
config_path = os.path.join(script_dir, "..", "offline_time_ae_config.yml")
with open(config_path, "r") as f:
    hyperparams = yaml.safe_load(f)

    wandb.init(project="mldrivenpeled",
            config=hyperparams,
            tags=['autoencoder'])
    config = wandb.config
    if wandb.run.notes is None:
        wandb.run.notes = ""
    config.modulator = constellation_mode
    wandb.run.notes += wandb.run.notes + f"\n | trained on channel model {run_name} \n | {constellation_mode}"
    print(f"WandB run info:")
    print(f"  Name: {wandb.run.name}")
    print(f"  ID: {wandb.run.id}")
    print(f"  URL: {wandb.run.url}")
    print("Chosen hyperparameters for this session:")
    print(config)

[34m[1mwandb[0m: Currently logged in as: [33mdylanbackprops[0m ([33mdylanbackprops-university-of-washington[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


WandB run info:
  Name: sleek-bee-7752
  ID: grppk2fb
  URL: https://wandb.ai/dylanbackprops-university-of-washington/mldrivenpeled/runs/grppk2fb
Chosen hyperparameters for this session:
{'CP_ratio': 0.25, 'batch_size': 16, 'dc_offset': 0, 'num_taps': 10, 'epochs': 300, 'gain': 20, 'lr': 0.001, 'nlayers': 2, 'hidden_channels': 16, 'dilation_base': 2, 'preamble_amplitude': 3, 'num_symbols_per_frame': 1, 'scheduler_type': 'reduce_lr_on_plateu', 'weight_init': 'default', 'Nf': 370, 'Nt': 1, 'flow': 300000, 'fhigh': '4e6', 'subcarrier_spacing': '1e4', 'modulator': 'm7_apsk_constellation'}


In [9]:
encoder = TCN(
    nlayers=config.nlayers,
    dilation_base=config.dilation_base,
    num_taps=config.num_taps,
    hidden_channels=config.hidden_channels,
).to(device)

decoder = TCN(
    nlayers=config.nlayers,
    dilation_base=config.dilation_base,
    num_taps=config.num_taps,
    hidden_channels=config.hidden_channels
).to(device)

In [10]:
optimizer = optim.AdamW(list(encoder.parameters()) + list(decoder.parameters()), lr=config.lr, weight_decay=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20, min_lr=1e-6)

In [11]:
NUM_BITS = config.Nt * config.Nf * constellation.modulation_order
FREQUENCIES = torch.arange(float(config.flow), float(config.fhigh), float(config.subcarrier_spacing))
delta_f = FREQUENCIES[1] - FREQUENCIES[0]
KS = (FREQUENCIES / delta_f).to(torch.int)
K_MIN = int(KS[0].item())
K_MAX = int(KS[-1].item())
NUM_ZEROS = K_MIN - 1
UPSAMPLING_ZEROS= (NUM_POINTS_FRAME  +  -2 * K_MIN + -2 * len(KS)) // 2
PREAMBLE_MAX = config.preamble_amplitude
UPSAMPLING_ZEROS

2600

In [12]:
NUM_BITS

2590

In [13]:
def make_time_validate_plots(sent_frames_time, received_frames_time, decoded_frames_time,
                             frame_BER, run_model, step=0, zoom_samples=200):

    # Convert to numpy
    enc_in = sent_frames_time.detach().cpu().numpy().flatten()
    enc_out = received_frames_time.detach().cpu().numpy().flatten()
    dec_in = received_frames_time.detach().cpu().numpy().flatten()
    dec_out = decoded_frames_time.detach().cpu().numpy().flatten()

    # Power and scaling
    enc_power_in = np.mean(enc_in**2)
    enc_power_out = np.mean(enc_out**2)
    enc_scale = enc_power_out / (enc_power_in + 1e-12)

    dec_power_in = np.mean(dec_in**2)
    dec_power_out = np.mean(dec_out**2)
    dec_scale = dec_power_out / (dec_power_in + 1e-12)

    # MSEs
    mse_encoder = np.mean((enc_in - enc_out) ** 2)
    mse_decoder = np.mean((dec_in - dec_out) ** 2)
    mse_total = np.mean((enc_in - dec_out) ** 2)

    # Log scalars
    prefix = "time_"
    wandb.log({f"{prefix}mse_loss": mse_total}, step=step)
    wandb.log({f"{prefix}frame_BER": frame_BER}, step=step)

    # Plot
    fig, axes = plt.subplots(3, 1, figsize=(12, 16))
    time_points = np.arange(zoom_samples)

    axes[0].plot(time_points, enc_in[:zoom_samples], 'r', alpha=0.5, label='Encoder Input')
    axes[0].plot(time_points, enc_out[:zoom_samples], 'b', alpha=0.8, label='Encoder Output')
    axes[0].set_title(
        f"Encoder Comparison (MSE: {mse_encoder:.2e}) | "
        f"In {enc_power_in:.3f} | Out {enc_power_out:.3f} | Scale {enc_scale:.3f}"
    )
    axes[0].legend(); axes[0].grid(True)

    axes[1].plot(time_points, dec_in[:zoom_samples], 'r', alpha=0.5, label='Decoder Input')
    axes[1].plot(time_points, dec_out[:zoom_samples], 'b', alpha=0.8, label='Decoder Output')
    axes[1].set_title(
        f"Decoder Comparison (MSE: {mse_decoder:.2e}) | "
        f"In {dec_power_in:.3f} | Out {dec_power_out:.3f} | Scale {dec_scale:.3f}"
    )
    axes[1].legend(); axes[1].grid(True)

    axes[2].plot(time_points, enc_in[:zoom_samples], 'r', alpha=0.5, label='Original Input')
    axes[2].plot(time_points, dec_out[:zoom_samples], 'b', alpha=0.8, label='Final Output')
    axes[2].set_title(
        f"End-to-End Comparison ({'Trained' if run_model else 'Untrained'})\n"
        f"MSE: {mse_total:.2e}, BER: {frame_BER:.2f}"
    )
    axes[2].legend(); axes[2].grid(True)

    fig.tight_layout()
    wandb.log({f"{prefix}time_signals": wandb.Image(fig)}, step=step)
    plt.close(fig)

In [14]:
CP_LENGTH

2000

In [15]:
KS

tensor([ 30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,
         44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,
         58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
         72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,  85,
         86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,  98,  99,
        100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113,
        114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
        128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141,
        142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155,
        156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169,
        170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183,
        184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
        198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 2

In [None]:
def evm_loss(true_symbols, predicted_symbols):
    return torch.mean(torch.abs(true_symbols - predicted_symbols) ** 2)

def in_band_filter(x, ks_indices, nfft):
    mask = torch.zeros(nfft, device=device)
    neg_ks_indices = nfft - ks_indices
    mask[ks_indices] = 1.0
    mask[neg_ks_indices] = 1.0

    impulse_response = torch.fft.ifftshift(torch.fft.ifft(mask).real)
    h = impulse_response.view(1, 1, -1)
    filtered_x = F.conv1d(x.unsqueeze(1), h, padding='same').squeeze(1)
    return filtered_x


def in_band_time_loss(sent_time, decoded_time, ks_indices, n_fft, num_taps):
    """Compute in-band loss directly in time domain using filtering"""
    # Create frequency mask
    mask = torch.zeros(n_fft, device=sent_time.device)
    neg_ks_indices = n_fft - ks_indices
    mask[ks_indices] = 1.0
    mask[neg_ks_indices] = 1.0

    # Convert to time-domain filter (this is differentiable)
    impulse_response = torch.fft.ifftshift(torch.fft.ifft(mask).real)
    h = impulse_response.view(1, 1, -1)

    # Filter both signals
    sent_filtered = F.conv1d(sent_time.unsqueeze(1), h, padding='same').squeeze(1)
    decoded_filtered = F.conv1d(decoded_time.unsqueeze(1), h, padding='same').squeeze(1)

    # Compute MSE on filtered signals (equivalent to in-band frequency loss)
    loss = torch.mean((sent_filtered[:, num_taps:] - decoded_filtered[:, num_taps:]).pow(2))
    return loss


def symbols_to_time(X, num_padding_zeros: int, num_leading_zeros=0):
    # Make hermetian symmetric
    Nt, Nf = X.shape
    padding_zeros = torch.zeros(Nt, num_padding_zeros, device=device)
    leading_zeros = torch.zeros(Nt, num_leading_zeros, device=device)
    X = torch.cat([leading_zeros, X.to(device), padding_zeros], dim=-1)
    DC_Nyquist = torch.zeros((X.shape[0], 1), device=X.device)
    X_hermitian = torch.flip(X, dims=[1]).conj()
    X_full = torch.hstack([DC_Nyquist, X, DC_Nyquist, X_hermitian])
    # Convert to time domain
    x_time = torch.fft.ifft(X_full, dim=-1, norm="ortho").real
    return x_time.to(device)

def add_noise(signal, SNR):
    signal_power = signal.abs().pow(2).mean()
    noise_power = signal_power / SNR
    noise_std = (noise_power / 2) ** 0.5 # real and complex
    noise = noise_std * torch.randn_like(signal) + noise_std * 1j * torch.randn_like(signal)
    signal += noise
    return signal

# def add_noise_time(signal, SNR):
#     signal_power = signal.pow(2).mean()
#     noise_power = signal_power / SNR
#     noise_std = noise_power.sqrt()
#     noise = noise_std * torch.randn_like(signal)
#     return signal + noise


def add_noise_time_cp(signal_with_cp, cp_length, snr_in, snr_low, snr_high, inband_idx, print_snr=False):
    """
    Adds spectrally-shaped noise with three regions:
      - In-band: indices in inband_idx
      - Low out-of-band: below min(inband_idx)
      - High out-of-band: above max(inband_idx)
    """
    B, N_with_cp = signal_with_cp.shape
    device = signal_with_cp.device

    signal_no_cp = signal_with_cp[:, cp_length:]
    P_sig = signal_no_cp.pow(2).mean(dim=-1, keepdim=True)

    Pn_in_target = P_sig / snr_in
    Pn_low_target = P_sig / snr_low
    Pn_high_target = P_sig / snr_high

    num_pos_freqs = (N_with_cp - 1) // 2
    pos_freq_slice = slice(1, num_pos_freqs + 1)
    neg_freq_slice = slice(N_with_cp - num_pos_freqs, N_with_cp)

    inband_mask = torch.zeros(num_pos_freqs, dtype=bool, device=device)
    valid_inband_indices = inband_idx[(inband_idx > 0) & (inband_idx <= num_pos_freqs)]
    if valid_inband_indices.numel() > 0:
        inband_mask[valid_inband_indices - 1] = True

    all_idx = torch.arange(num_pos_freqs, device=device)
    low_mask = (all_idx < inband_idx.min()) & ~inband_mask
    high_mask = (all_idx > inband_idx.max()) & ~inband_mask

    num_in_bins = inband_mask.sum()
    num_low_bins = low_mask.sum()
    num_high_bins = high_mask.sum()

    def make_noise(num_bins, target_power):
        if num_bins == 0:
            return torch.zeros((B, 0), dtype=torch.complex64, device=device)
        var_per_bin = (target_power * N_with_cp) / (2 * num_bins)
        std_per_bin = torch.sqrt(var_per_bin)
        noise = (torch.randn(B, num_bins, device=device) +
                 1j * torch.randn(B, num_bins, device=device)) / math.sqrt(2.0)
        return std_per_bin * noise

    noise_in_pos = make_noise(num_in_bins, Pn_in_target)
    noise_low_pos = make_noise(num_low_bins, Pn_low_target)
    noise_high_pos = make_noise(num_high_bins, Pn_high_target)

    noise_pos = torch.zeros(B, num_pos_freqs, dtype=torch.complex64, device=device)
    if num_in_bins > 0: noise_pos[:, inband_mask] = noise_in_pos
    if num_low_bins > 0: noise_pos[:, low_mask] = noise_low_pos
    if num_high_bins > 0: noise_pos[:, high_mask] = noise_high_pos

    noise_fft = torch.zeros(B, N_with_cp, dtype=torch.complex64, device=device)
    noise_fft[:, pos_freq_slice] = noise_pos
    noise_fft[:, neg_freq_slice] = torch.conj(torch.flip(noise_pos, dims=[1]))
    noise_fft[:, 0] = 0
    noise_time = torch.fft.ifft(noise_fft, norm="ortho").real

    if print_snr:
        P_sig_mean = P_sig.mean().item()
        def check(mask, noise_vals, target):
            if mask.sum() == 0: return
            tmp_fft = torch.zeros_like(noise_fft)
            tmp_pos = torch.zeros_like(noise_pos); tmp_pos[:, mask] = noise_vals
            tmp_fft[:, pos_freq_slice] = tmp_pos
            tmp_fft[:, neg_freq_slice] = torch.conj(torch.flip(tmp_pos, dims=[1]))
            Pn_actual = torch.fft.ifft(tmp_fft, norm="ortho").real.pow(2).mean().item()
            print(f"SNR Check: target={target:.2f}, actual={P_sig_mean/Pn_actual:.2f}")
        check(inband_mask, noise_in_pos, snr_in)
        check(low_mask, noise_low_pos, snr_low)
        check(high_mask, noise_high_pos, snr_high)

    return signal_with_cp + noise_time



def calculate_BER(received_symbols, true_bits, constellation):
    # Demap symbols to bits
    constellation_symbols = torch.tensor(
        list(constellation._symbols_to_bits_map.keys()),
        dtype=received_symbols.dtype,
        device=received_symbols.device
    )
    distances = abs(received_symbols.reshape(-1, 1) - constellation_symbols.reshape(1, -1))

    closest_idx = distances.argmin(axis=1)
    constellation_symbols_list = list(constellation._symbols_to_bits_map.keys())
    decided_bits = [constellation._symbols_to_bits_map[constellation_symbols_list[idx]] for idx in closest_idx.cpu().numpy()]

    # Flatten decided bits into a 1D array
    decided_bits_flat = [int(bit) for symbol_bits in decided_bits for bit in symbol_bits]


    # Convert to NumPy arrays for comparison
    true_bits_array = np.array(true_bits)
    decided_bits_flat_array = np.array(decided_bits_flat)

    # Take minimum length to avoid shape mismatch
    min_len = min(len(true_bits_array), len(decided_bits_flat_array))
    true_bits_array = true_bits_array[:min_len]
    decided_bits_flat_array = decided_bits_flat_array[:min_len]

    # Calculate BER
    BER = float(np.sum(true_bits_array != decided_bits_flat_array) / len(true_bits_array))
    return BER


def train(encoder, decoder, optimizer, scheduler, config, device, mask=None):

    encoder = encoder.to(device)
    decoder = decoder.to(device)

    for epoch in range(config["epochs"]):
        encoder.train()
        decoder.train()

        epoch_loss = 0
        epoch_freq_loss = 0
        optimizer.zero_grad()
        batch_entries = []
        true_bits_list = []
        for batch in range(config["batch_size"]):
            # Generate frame data
            true_bits = np.random.randint(0, 2, size=NUM_BITS)
            true_bits_list.append(torch.tensor(true_bits))
            # true_bits = np.zeros(NUM_BITS).astype(int)
            true_bits_str = ''.join(map(str, true_bits))
            true_symbols = torch.tensor(
                constellation.bits_to_symbols(true_bits_str),
                dtype=torch.complex64, device=device
            )
            true_frame = true_symbols.reshape(config["Nt"], config["Nf"])

            # Add known experimental noise

            # true_frame = torch.zeros(config["Nt"], config["Nf"])
            # true_frame[:, 100] = 10

            if POWER_NORMALIZATION:
                true_frame = true_frame / true_frame.abs().pow(2).mean(dim=1, keepdim=True).sqrt()
            batch_entries.append(true_frame)

        true_bits = torch.stack(true_bits_list)

        # Batch along time domain
        true_frame = torch.cat(batch_entries)

        # print("Inband symbol power", true_frame.abs().square().mean())
        # Convert to time domain
        sent_frames_time = symbols_to_time(true_frame, UPSAMPLING_ZEROS, NUM_ZEROS)
        sent_frames_time = torch.hstack((sent_frames_time[:, -CP_LENGTH:], sent_frames_time))

        encoded_frames_time = encoder(sent_frames_time)
        # encoded_frames_time = in_band_filter(encoded_frames_time, KS, NUM_POINTS_FRAME)

        # Clip to preamble make
        encoded_frames_time = torch.clip(encoded_frames_time, -PREAMBLE_MAX, PREAMBLE_MAX)
        # encoded_frames_time = add_noise_time_cp(encoded_frames_time,
        #                                              snr_in=float(100000),
        #                                                 snr_low=float(100000),
        #                                                snr_high=0,
        #                                               inband_idx=torch.arange(int(1), int(25e6)),
        #                                              cp_length=CP_LENGTH)

        received_frames_time = channel_model.forward_simulate(encoded_frames_time)


        # Filter out of band noise
        # received_frames_time_noisy = in_band_filter(received_frames_time_noisy, KS, NUM_POINTS_FRAME)

        # received_frames_time = add_noise(received_frames_time, SNR=10**(config["snr_db"]/10))
        decoded_frames_time = decoder(received_frames_time)

        # Convert to frequency domain for loss
        sent_frames_frequency = torch.tensor(rfft(sent_frames_time[:, CP_LENGTH:].detach().cpu().numpy(), norm='ortho', axis=1)[:, KS])
        decoded_frames_frequency = torch.tensor(rfft(decoded_frames_time[:, CP_LENGTH:].detach().cpu().numpy(), norm='ortho', axis=1)[:, KS])

        sent_norm = sent_frames_frequency.abs().square().mean(dim=1, keepdim=True).sqrt()
        decoded_norm = decoded_frames_frequency.abs().square().mean(dim=1, keepdim=True).sqrt()


        if POWER_NORMALIZATION:
            sent_frames_frequency = sent_frames_frequency / sent_norm
            decoded_frames_frequency = decoded_frames_frequency / decoded_norm


        # print("Time power through channel", sent_frames_time.abs().square().mean().item(), received_frames_time.abs().square().mean().item(), sent_frames_time.shape)
        loss = in_band_time_loss(sent_frames_time, decoded_frames_time, ks_indices=KS, n_fft=NUM_POINTS_FRAME, num_taps=config['num_taps'])
        # loss = evm_loss(sent_frames_time, decoded_frames_time)
        diff_complex = sent_frames_frequency.detach() - decoded_frames_frequency.detach()
        # print("evm shapes", sent_frames_frequency.shape, decoded_frames_frequency.shape)
        freq_loss = torch.mean(diff_complex.abs().pow(2))

        loss.backward()
        epoch_loss += loss.item()
        epoch_freq_loss += freq_loss.item()


        optimizer.step()
        scheduler.step(epoch_loss)

        wandb.log({"loss": epoch_loss}, step=epoch)
        wandb.log({"freq_loss": epoch_freq_loss}, step=epoch)
        lr = optimizer.param_groups[0]["lr"]
        wandb.log({"learning_rate": lr}, step=epoch)

        # Get BER
        ber = calculate_BER(decoded_frames_frequency.detach().flatten(), true_bits.flatten(), constellation=constellation)
        wandb.log({"BER": ber}, step=epoch)
        print(f"Epoch {epoch} Finished | Avg Loss: {epoch_loss}")
        if epoch % 5 == 0:
            make_time_validate_plots(
            sent_frames_time[0],
            received_frames_time[0],
            decoded_frames_time[0],
            frame_BER=ber,
            run_model=True,
            step=epoch
            )

            # Plot first example of sent and reconstructed time
            fig, ax = plt.subplots()
            ax.plot(sent_frames_time[0][:100].detach().cpu().numpy(), label="Sent (time)")
            ax.plot(decoded_frames_time[0][:100].detach().cpu().numpy(), label="Decoded (time)")
            ax.legend()
            ax.set_title(f"Sent vs Decoded (Time Domain) EVM: {loss: 0.3e}")
            wandb.log({"time_domain_plot": wandb.Image(fig)}, step=epoch)
            plt.close(fig)

            sent_symbols = sent_frames_frequency[0].detach().cpu().numpy()
            decoded_symbols = decoded_frames_frequency[0].detach().cpu().numpy()

            # Compute EVM for logging (per frame)
            evm_val = evm_loss(torch.tensor(sent_symbols), torch.tensor(decoded_symbols)).item()

            # Create constellation plot
            fig, ax = plt.subplots()
            ax.scatter(sent_symbols.real, sent_symbols.imag, color='blue', alpha=0.6, label='Sent')
            ax.scatter(decoded_symbols.real, decoded_symbols.imag, color='red', alpha=0.6, label='Decoded')
            ax.set_xlabel('In-phase')
            ax.set_ylabel('Quadrature')
            ax.set_title(f'Constellation Diagram EVM: {evm_val:0.3e}')
            ax.legend()
            ax.grid(True)

            # Log to wandb
            wandb.log({"constellation": wandb.Image(fig)}, step=epoch)
            plt.close(fig)

            evm_per_freq = ((sent_frames_frequency[0] - decoded_frames_frequency[0]).abs()**2).detach().cpu().numpy()
            fig, ax = plt.subplots()
            ax.plot(evm_per_freq)
            ax.set_title("EVM vs Frequency")
            wandb.log({"evm_vs_freq": wandb.Image(fig)}, step=epoch)
            plt.close(fig)


            # # Plot Model SNRs
            # sent_k = torch.fft.fft(encoded_frames_time[:, CP_LENGTH:], norm="ortho", dim=-1)
            # received_k = torch.fft.fft(received_frames_time[:, CP_LENGTH:], norm="ortho", dim=-1)

            # residual = received_frames_time - mean
            # received_noise_k = torch.fft.fft(residual[:, CP_LENGTH:] ** 2, norm="ortho", dim=-1)


            # signal_power = torch.mean(torch.abs(sent_k) ** 2, dim=0).detach().cpu().numpy()
            # received_power = torch.mean(torch.abs(received_k) ** 2, dim=0).detach().cpu().numpy()
            # received_noise_power = torch.mean(torch.abs(received_noise_k), dim=0).detach().cpu().numpy()
            # snr_vs_freq = (received_power / received_noise_power + 1e-12)

            # fig, ax = plt.subplots()
            # ax.plot(10 * np.log10(signal_power[:len(signal_power)//2]))
            # ax.axvline(30, c='r', linestyle='--')
            # ax.axvline(int(4e2), c='r', linestyle='--')
            # ax.set_title(f"Encoded Signal Frequency Power Spectrum for {run.name}")
            # ax.set_xlabel("Freq Index k (10 kHz)")
            # ax.set_ylabel("Power (dB)")
            # wandb.log({"encoded_signal_power_spectrum": wandb.Image(fig)}, step=epoch)
            # plt.close(fig)

            # fig, ax = plt.subplots()
            # ax.plot(10 * np.log10(received_power[:len(received_power)//2]))
            # ax.axvline(30, c='r', linestyle='--')
            # ax.axvline(int(4e2), c='r', linestyle='--')
            # ax.set_title(f"Received Signal Frequency Power Spectrum for {run.name}")
            # ax.set_xlabel("Freq Index k (10 kHz)")
            # ax.set_ylabel("Power (dB)")
            # wandb.log({"received_signal_power_spectrum": wandb.Image(fig)}, step=epoch)
            # plt.close(fig)

            # fig, ax = plt.subplots()
            # ax.plot(10 * np.log10(snr_vs_freq[:len(snr_vs_freq)//2]))
            # ax.set_title(f"SNR vs Freq Estimate for {run.name}")
            # ax.set_xlabel("Freq Index k (10 kHz)")
            # ax.set_ylabel("SNR (dB)")
            # ax.axvline(30, c='r', linestyle='--')
            # ax.axvline(int(4e2), c='r', linestyle='--')
            # wandb.log({"snr_vs_freq": wandb.Image(fig)}, step=epoch)
            # plt.close(fig)

            # fig, ax = plt.subplots()
            # ax.plot(10 * np.log10(received_noise_power[:len(snr_vs_freq)//2]))
            # ax.set_title(f"Noise power vs Freq Estimate for {run.name}")
            # ax.set_xlabel("Freq Index k (10 kHz)")
            # ax.set_ylabel("Power (dB)")
            # ax.axvline(30, c='r', linestyle='--')
            # ax.axvline(int(4e2), c='r', linestyle='--')
            # wandb.log({"noise_power_vs_freq": wandb.Image(fig)}, step=epoch)
            # plt.close(fig)

    # Save model
    torch.save({
        "time_encoder": encoder.state_dict(),
        "time_decoder": decoder.state_dict()
    }, "time_autoencoder.pth")

    artifact = wandb.Artifact("time_autoencoder", type="model")
    artifact.add_file("time_autoencoder.pth")
    wandb.log_artifact(artifact)

    return epoch_loss

# torch.autograd.set_detect_anomaly(True)
train(encoder, decoder, optimizer, scheduler, config, device)
wandb.finish()

  sent_filtered = F.conv1d(sent_time.unsqueeze(1), h, padding='same').squeeze(1)
  true_bits_array = np.array(true_bits)


Epoch 0 Finished | Avg Loss: 0.12044768035411835


  true_bits_array = np.array(true_bits)


Epoch 1 Finished | Avg Loss: 0.11955181509256363
Epoch 2 Finished | Avg Loss: 0.1183292493224144
Epoch 3 Finished | Avg Loss: 0.11794055998325348
Epoch 4 Finished | Avg Loss: 0.11445077508687973
Epoch 5 Finished | Avg Loss: 0.11368311941623688
Epoch 6 Finished | Avg Loss: 0.11236777901649475
Epoch 7 Finished | Avg Loss: 0.10975070297718048
Epoch 8 Finished | Avg Loss: 0.1090206578373909
Epoch 9 Finished | Avg Loss: 0.10685618966817856
Epoch 10 Finished | Avg Loss: 0.10436522960662842
Epoch 11 Finished | Avg Loss: 0.10026341676712036
Epoch 12 Finished | Avg Loss: 0.09785860031843185
Epoch 13 Finished | Avg Loss: 0.09673261642456055
Epoch 14 Finished | Avg Loss: 0.09390603005886078
Epoch 15 Finished | Avg Loss: 0.08937866985797882
Epoch 16 Finished | Avg Loss: 0.0863528922200203
Epoch 17 Finished | Avg Loss: 0.08346813172101974
Epoch 18 Finished | Avg Loss: 0.07959093898534775
Epoch 19 Finished | Avg Loss: 0.07460330426692963
Epoch 20 Finished | Avg Loss: 0.07132095098495483
Epoch 21 Fin

In [None]:
# Get SNR vs freq estimate
test_freqs = torch.arange(0, 1e4 * 2999, 1e4, device=device)
test_ks = (test_freqs / (1e4)).to(torch.int)
true_bits = np.random.randint(0, 2, size=7 * len(test_freqs) * 100)


true_bits_str = ''.join(map(str, true_bits))
true_symbols = torch.tensor(
    constellation.bits_to_symbols(true_bits_str),
    dtype=torch.complex64, device=device
)

test = (NUM_POINTS_FRAME  +  -2 * test_ks[0] + -2 * len(test_ks)) // 2

true_frame = true_symbols.reshape(100, 2999)
true_bits = torch.tensor(true_bits)
sent_frames_time = symbols_to_time(true_frame, test, 0)
sent_frames_time = torch.hstack((sent_frames_time[:, -CP_LENGTH:], sent_frames_time))
encoded_frames_time = encoder(sent_frames_time)
received_frames_time_noisy, mean, std, nu = channel_model(encoded_frames_time)
decoded_frames_time = decoder(received_frames_time_noisy)

sent_k = torch.fft.fft(encoded_frames_time[:, CP_LENGTH:], norm="ortho", dim=-1)
received_k = torch.fft.fft(mean[:, CP_LENGTH:], norm="ortho", dim=-1)

residual = received_frames_time_noisy - mean
received_noise_k = torch.fft.fft(residual[:, CP_LENGTH:] ** 2, norm="ortho", dim=-1)


signal_power = torch.mean(torch.abs(sent_k) ** 2, dim=0).detach().cpu().numpy()
received_power = torch.mean(torch.abs(received_k) ** 2, dim=0).detach().cpu().numpy()
received_noise_power = torch.mean(torch.abs(received_noise_k), dim=0).detach().cpu().numpy()
snr_vs_freq = (received_power / received_noise_power + 1e-12)


plt.plot(10 * np.log10(signal_power[:len(signal_power)//2]))
plt.axvline(30, c='r', linestyle='--')
plt.axvline(int(4e2), c='r', linestyle='--')
plt.title(f"Encoded Signal Frequency Power Spectrum for {run.name}")
plt.xlabel("Freq Index k (10 kHz)")
plt.ylabel("Power (dB)")
plt.show()

plt.plot(10 * np.log10(received_power[:len(received_power)//2]))
plt.axvline(30, c='r', linestyle='--')
plt.axvline(int(4e2), c='r', linestyle='--')
plt.title(f"Received Signal Frequency Power Spectrum for {run.name}")
plt.xlabel("Freq Index k (10 kHz)")
plt.ylabel("Power (dB)")
plt.show()


plt.plot(10 * np.log10(snr_vs_freq[:len(snr_vs_freq)//2]))
plt.title(f"SNR vs Freq Estimate for {run.name}")
plt.xlabel("Freq Index k (10 kHz)")
plt.ylabel("SNR (dB)")
plt.axvline(30, c='r', linestyle='--')
plt.axvline(int(4e2), c='r', linestyle='--')
plt.show()


plt.plot(10 * np.log10(received_noise_power[30:len(snr_vs_freq)//2]))
plt.title(f"Noise power vs Freq Estimate for {run.name}")
plt.xlabel("Freq Index k (10 kHz)")
plt.ylabel("Power (dB)")
plt.axvline(30, c='r', linestyle='--')
plt.axvline(int(4e2), c='r', linestyle='--')
plt.show()

In [None]:
def objective(trial, tag):
    # Sample hyperparameters
    dilation_base = trial.suggest_categorical("dilation_base", [2])
    num_taps = trial.suggest_int("num_taps", 10, 30, step=2)
    hidden_channels = trial.suggest_int("hidden_channels", 4, 64, step=8)
    lr = trial.suggest_loguniform("lr", 1e-5, 1e-2)
    nlayers = trial.suggest_categorical("nlayers", [2, 3, 4])

    local_config = {
        "dilation_base": dilation_base,
        "num_taps": num_taps,
        "hidden_channels": hidden_channels,
        "lr": lr,
        "epochs": 800,
        "batch_size": 16,
        "Nt": 1,
        "Nf": 370,
        "save_path": "./saved_models",
        "nlayers": nlayers,
        "weight_init": "default",
        "scheduler_type": "reduce_lr_on_plateu",
        "modulator": f"{constellation_mode}"
    }

    wandb.init(project="mldrivenpeled", config=local_config, reinit=True,
               tags=['autoencoder', f'{tag}', f'trial {trial.number}'], mode='online')

    wandb.run.notes = f"\n | trained on channel model {run_name} \n | {constellation_mode}"

    encoder = None
    decoder = None
    optimizer = None
    scheduler = None

    try:
        encoder = TCN(
            nlayers=nlayers,
            dilation_base=dilation_base,
            num_taps=num_taps,
            hidden_channels=hidden_channels
        ).to(device)

        decoder = TCN(
            nlayers=nlayers,
            dilation_base=dilation_base,
            num_taps=num_taps,
            hidden_channels=hidden_channels
        ).to(device)

        optimizer = torch.optim.Adam(
            list(encoder.parameters()) + list(decoder.parameters()), lr=lr
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

        final_loss = train(encoder, decoder, optimizer, scheduler, local_config, device)

        return final_loss

    finally:
        wandb.finish()
        time.sleep(0.5)

        del encoder, decoder, optimizer, scheduler

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()

In [None]:
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
study_name = f"optuna_offline_ae_model_{timestamp}"
# study_name= "optuna_offline_ae_model_20250919_174002"
study = optuna.create_study(direction="minimize", study_name=study_name, storage="sqlite:///optuna_results.db", load_if_exists=True)
study.optimize(lambda trial: objective(trial, study_name), n_trials=100)
print("Best trial:")
print(study.best_trial.params)

'''
{'dilation_base': 2, 'num_taps': 12, 'hidden_channels': 20, 'lr': 0.0005269745303114973, 'nlayers': 4, 'taps': 20} Current best with loss 2e-3
Trial 44 finished with value: 0.001965869450941682 and parameters: {'dilation_base': 2, 'num_taps': 10, 'hidden_channels': 44, 'lr': 0.009314226151764216, 'nlayers': 4}. Best is trial 44 with value: 0.001965869450941682
'''

In [None]:
storage = "sqlite:///optuna_results.db"
summaries = optuna.get_all_study_summaries(storage=storage)
for summary in summaries:
    print(f"Study name: {summary.study_name}")
    print(f"  Trial count: {summary.n_trials}")
    if summary.best_trial is not None:
        print(f"  Best value: {summary.best_trial.value}")
        print(f"  Best params: {summary.best_trial.params}")
    else:
        print("  No trials completed yet.")
    print("-" * 50)

In [None]:
from optuna.visualization import (
    plot_optimization_history,
    plot_param_importances,
    plot_slice,
    plot_parallel_coordinate
)

# Step 1: Choose your study name (copy it from the summaries you printed earlier)
study_name = "optuna_offline_ae_model_20250928_115643"
storage = "sqlite:///optuna_results.db"

# Step 2: Load the study
study = optuna.load_study(study_name=study_name, storage=storage)

# Step 3: Plot using interactive Plotly charts
plot_optimization_history(study).show()
plot_param_importances(study).show()
plot_slice(study).show()
plot_parallel_coordinate(study).show()

In [None]:
study.best_value

In [None]:
study.best_trial

In [None]:
study.best_params