In [9]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import torch
import h5py
import torchaudio
from torch.utils.data import random_split, Dataset, DataLoader
import torch.optim as optim
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from scipy.special import beta
from torch.special import psi
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Subset
import copy
import math
import wandb
import yaml
import time
import gc
import datetime
import optuna
import zarr
import seaborn as sns
from scipy import stats

# A run may still be open from a previous execution of this notebook:
# tag it as "junk" so it can be filtered out later, then close it.
# (wandb.finish() is called unconditionally, as before.)
if wandb.run is not None:
    wandb.run.tags = [*wandb.run.tags, "junk"]
wandb.finish()

In [10]:
'''
Import and cache dataset for fast loading in future
'''

# file_path = "C:/Users/maild/mldrivenpeled/data/channel_measurements/zarr_files/channel_3e5-15MHz_3.5V_scale2.zarr"
file_path = "C:/Users/maild/mldrivenpeled/data/channel_measurements/zarr_files/channel_3e5-15MHz_2.8V_scale2_v2.zarr"

# Pick the fastest available backend: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = torch.device("mps") # for M chip Macs
else:
    device = torch.device("cpu")
print("Device", device)

cache_path = file_path.replace(".zarr", "_cached.pt").replace(".h5", "_cached.pt")
if os.path.exists(cache_path):
    # The cache written below contains only tensors and plain ints, so the
    # restricted (weights_only) unpickler is sufficient and avoids executing
    # arbitrary pickled code.
    data = torch.load(cache_path, map_location=device, weights_only=True)
    sent_frames_time = data["sent_frames_time"].to(device)
    received_frames_time = data["received_frames_time"].to(device)
    FREQUENCIES = data["frequencies"].to(device)
    NUM_POINTS_SYMBOL = data["NUM_POINTS_SYMBOL"]
    CP_LENGTH = data["CP_LENGTH"]
    delta_f = FREQUENCIES[1] - FREQUENCIES[0]
    KS = (FREQUENCIES / delta_f).to(torch.int)
    K_MIN = int(KS[0].item())
    K_MAX = int(KS[-1].item())
    NUM_ZEROS = K_MIN - 1
    CP_RATIO = 0.25 # Known from experiment
    NUM_POINTS_FRAME = NUM_POINTS_SYMBOL - CP_LENGTH
    NUM_POS_FREQS_LOW_BAND = K_MAX + 1
    UPSAMPLING_ZEROS = (NUM_POINTS_FRAME  - 2 * NUM_POS_FREQS_LOW_BAND) // 2
    print("Loaded from cache!")

else:
    print("No cache found — loading original dataset...")

    # Open the Zarr root
    root = zarr.open(file_path, mode="r")

    # Metadata is read from the first readable frame. These names MUST exist
    # before the loop: previously `FREQUENCIES is None` raised NameError on a
    # fresh kernel, the bare `except:` swallowed it, and every frame was
    # reported as corrupted.
    FREQUENCIES = None
    NUM_POINTS_SYMBOL = None
    CP_LENGTH = None

    sent, received, received_time = [], [], []
    # Loop through frames
    num_skipped = 0
    for frame_key in root.group_keys():
        try:
            frame = root[frame_key]
            if FREQUENCIES is None:
                # Read all three values first so a partial failure cannot
                # leave the metadata half-initialised.
                frame_freqs = torch.tensor(frame["freqs"][:], dtype=torch.int).real
                frame_num_points = int(frame.attrs["num_points_symbol"])
                frame_cp_length = int(frame.attrs["cp_length"])
                FREQUENCIES = frame_freqs
                NUM_POINTS_SYMBOL = frame_num_points
                CP_LENGTH = frame_cp_length

            # Materialise every array before appending anything, so the three
            # lists stay index-aligned even when a frame is only partly readable.
            # The time-domain capture is required downstream (it is the training
            # target), so a frame without "received_time" is skipped as well.
            frame_sent = torch.tensor(frame["sent"][:], dtype=torch.complex64)
            frame_received = torch.tensor(frame["received"][:], dtype=torch.complex64)
            frame_received_time = torch.tensor(frame["received_time"][:], dtype=torch.float32)
        except Exception:
            # Best effort: skip unreadable frames, but let KeyboardInterrupt /
            # SystemExit propagate (a bare `except:` would swallow those too).
            num_skipped += 1
            continue
        sent.append(frame_sent)
        received.append(frame_received)
        received_time.append(frame_received_time)
    print(f"Skipped {num_skipped} corrupted frames")

    if not sent:
        raise RuntimeError(f"No readable frames found in {file_path}")

    sent_frames = torch.stack(sent).squeeze(1)
    received_frames = torch.stack(received).squeeze(1)

    # Establish globals
    delta_f = FREQUENCIES[1] - FREQUENCIES[0]
    KS = (FREQUENCIES / delta_f).to(torch.int)
    K_MIN = int(KS[0].item())
    K_MAX = int(KS[-1].item())
    NUM_ZEROS = K_MIN - 1
    CP_RATIO = 0.25
    NUM_POINTS_FRAME = NUM_POINTS_SYMBOL - CP_LENGTH
    NUM_POS_FREQS_LOW_BAND = K_MAX + 1
    UPSAMPLING_ZEROS = (NUM_POINTS_FRAME  - 2 * NUM_POS_FREQS_LOW_BAND) // 2


    def symbols_to_time(X, num_padding_zeros: int, num_leading_zeros=0):
        """Turn one-sided OFDM symbols into a real-valued time-domain signal.

        Args:
            X: 2-D tensor of frequency-domain symbols, one row per frame.
            num_padding_zeros: number of zero bins appended after the symbols.
            num_leading_zeros: number of zero bins inserted before the symbols.

        Returns:
            Real tensor on the notebook-level `device`; each row has
            2 * (num_leading_zeros + X.shape[1] + num_padding_zeros) + 2 samples.
        """
        num_rows, _ = X.shape
        lead = torch.zeros(num_rows, num_leading_zeros, device=device)
        tail = torch.zeros(num_rows, num_padding_zeros, device=device)
        one_sided = torch.cat([lead, X.to(device), tail], dim=-1)

        # Build a Hermitian-symmetric spectrum: a zero bin at DC and at the
        # midpoint, followed by the conjugated mirror image of the positive
        # bins, so that the inverse FFT is real.
        zero_bin = torch.zeros((one_sided.shape[0], 1), device=one_sided.device)
        mirrored = one_sided.flip(dims=[1]).conj()
        spectrum = torch.hstack([zero_bin, one_sided, zero_bin, mirrored])

        # Orthonormal inverse FFT; the imaginary part is discarded.
        return torch.fft.ifft(spectrum, dim=-1, norm="ortho").real.to(device)



    sent_frames_time = symbols_to_time(sent_frames, UPSAMPLING_ZEROS, NUM_ZEROS)
    # Add cyclic prefix. Guarded because a `-0:` slice selects the WHOLE frame
    # and would silently double its length.
    if CP_LENGTH > 0:
        sent_frames_time = torch.hstack((sent_frames_time[:, -CP_LENGTH:], sent_frames_time))

    # Handle received time symbols; perform some cleaning if necessary.
    # The index-based filtering below is only meaningful when every sent frame
    # has exactly one time-domain capture.
    if len(received_time) != sent_frames.shape[0]:
        raise ValueError(
            f"Got {len(received_time)} received_time captures for "
            f"{sent_frames.shape[0]} sent frames; they must pair one-to-one"
        )
    # Keep only captures that have the shortest observed length.
    N_shortest = min(t.size(-1) for t in received_time)
    good_indices = [i for i, x in enumerate(received_time) if x.size(-1) == N_shortest]
    received_frames_time = torch.stack([received_time[i] for i in good_indices], dim=0)
    received_frames_time = received_frames_time.squeeze(1)
    # Apply the SAME selection to both sent tensors. Previously only
    # `sent_frames` was filtered, so `sent_frames_time` (the tensor that is
    # cached and trained on) no longer lined up with `received_frames_time`
    # whenever a capture was dropped.
    sent_frames = sent_frames[good_indices]
    sent_frames_time = sent_frames_time[good_indices]



    DELAY_TIME = 0 # If for some reason there is a global delay with measure data adjust here
    if DELAY_TIME > 0:
        # Drop the tail of the sent signal so it still matches the received
        # signal once the latter's first DELAY_TIME samples are removed.
        sent_frames_time = sent_frames_time[:, :-DELAY_TIME]
    received_frames_time = received_frames_time[:, DELAY_TIME:]
    received_frames_time = received_frames_time - received_frames_time.mean(dim=1, keepdim=True) # Always zero mean
    sent_frames_time = sent_frames_time.to(device)
    received_frames_time = received_frames_time.to(device)


    # Write the cache next to the dataset; `cache_path` was already computed
    # before the cache-existence check. Tensors are stored on CPU so the file
    # can be loaded on any machine.
    torch.save({
        "sent_frames_time": sent_frames_time.cpu(),
        "received_frames_time": received_frames_time.cpu(),
        "frequencies": FREQUENCIES.cpu(),
        "NUM_POINTS_SYMBOL": NUM_POINTS_SYMBOL,
        "CP_LENGTH": CP_LENGTH
    }, cache_path)


class ChannelData(Dataset):
    """Index-aligned pairs of transmitted and received frames.

    Item ``i`` is the tuple ``(sent_frames[i], received_frames[i])``.
    ``frequencies``, ``transform`` and ``target_transform`` are accepted for
    call-site compatibility but are not stored or used.
    """

    def __init__(self,
                sent_frames,
                received_frames,
                frequencies,
                transform=None,
                target_transform=None):
        # Both collections must describe the same number of frames.
        n_sent, n_received = len(sent_frames), len(received_frames)
        assert n_sent == n_received

        self.sent_frames = sent_frames
        self.received_frames = received_frames

    def __len__(self):
        """Number of (sent, received) pairs."""
        return len(self.sent_frames)

    def __getitem__(self, idx):
        """Return the pair stored at position ``idx``."""
        sent_item = self.sent_frames[idx]
        received_item = self.received_frames[idx]
        return sent_item, received_item


dataset = ChannelData(sent_frames_time, received_frames_time, FREQUENCIES)

# 90 % train / 10 % validation; whatever the integer truncation leaves over
# goes to the test split so the three sizes always sum to len(dataset).
num_frames = len(dataset)
train_size = int(0.9 * num_frames)
val_size = int(0.1 * num_frames)
test_size = num_frames - (train_size + val_size)

# Random, non-overlapping partition drawn from a freshly constructed generator.
split_sizes = [train_size, val_size, test_size]
train_dataset, val_dataset, test_dataset = random_split(
    dataset, split_sizes, generator=torch.Generator()
)
print("Train Size", train_size)

Device cuda


  data = torch.load(cache_path, map_location=device)


Loaded from cache!
Train Size 6202


In [11]:
class TCNBlock(nn.Module):
    """Dilated 1-D convolution with left-only padding, ReLU and a residual path.

    The input is zero-padded on the left by ``(kernel_size - 1) * dilation``
    samples, so the output has the same length as the input and each output
    sample depends only on the current and earlier input samples. When the
    channel counts differ, the residual path goes through a 1x1 convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation):
        super().__init__()
        self.conv = nn.Conv1d(in_channels, out_channels,
                              kernel_size=kernel_size, dilation=dilation, padding=0)
        # Amount of left padding applied manually in forward().
        self.padding = dilation * (kernel_size - 1)
        self.relu = nn.ReLU()
        # 1x1 projection for the skip connection, only when shapes disagree.
        needs_projection = in_channels != out_channels
        self.resample = (
            nn.Conv1d(in_channels, out_channels, kernel_size=1)
            if needs_projection else None
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape [batch, in_channels, time]."""
        padded = F.pad(x, (self.padding, 0))
        activated = self.relu(self.conv(padded))
        skip = x if self.resample is None else self.resample(x)
        return activated + skip  # residual connection


def sample_student_t_mps(mean, std, nu):
    '''
    Draw location-scale Student-t samples using only Gaussian draws.

    A chi-square variate with `nu` degrees of freedom is approximated with the
    Wilson-Hilferty cube-root transform of a standard normal, and the sample is
    formed as mean + std * z * sqrt(nu / chi2).

    Args:
        mean: tensor of locations; fixes the output shape, dtype and device.
        std: tensor of scales, broadcastable with `mean`.
        nu: degrees of freedom, a tensor broadcastable with `mean` or a
            Python scalar.

    Returns:
        Tensor shaped like `mean`.

    Note:
        For small `nu` the Wilson-Hilferty base can be negative (roughly 0.4 %
        of draws at nu = 2). Its cube is then negative and the square root
        below used to return NaN. The approximation is now floored at zero, so
        those rare draws become large but finite samples.
    '''
    # Accept Python scalars as well as tensors (torch.sqrt needs a tensor).
    nu = torch.as_tensor(nu, dtype=mean.dtype, device=mean.device)
    z = torch.randn_like(mean)
    z_chi = torch.randn_like(mean)
    chi2_approx = nu * (1 - 2/(9*nu) + z_chi * torch.sqrt(2/(9*nu))).pow(3)
    # A chi-square variate is non-negative; floor the approximation accordingly.
    chi2_approx = chi2_approx.clamp_min(0.0)
    # The epsilon keeps the division finite when chi2_approx is exactly zero.
    scale = torch.sqrt(nu / (chi2_approx + 1e-6))
    return mean + std * z * scale


class TCN(nn.Module):
    """Stack of TCNBlocks with exponentially growing dilation and a 1x1 readout.

    Layer ``i`` uses dilation ``dilation_base ** i``. The first block maps one
    input channel to ``hidden_channels``; the readout maps back to one channel.
    ``receptive_field`` holds ``1 + sum_i (num_taps - 1) * dilation_base ** i``.
    """

    def __init__(self, nlayers=3, dilation_base=2, num_taps=10, hidden_channels=32):
        super().__init__()
        dilations = [dilation_base ** level for level in range(nlayers)]

        blocks = []
        for level, dilation in enumerate(dilations):
            block_in = 1 if level == 0 else hidden_channels
            blocks.append(TCNBlock(block_in, hidden_channels, num_taps, dilation))
        self.tcn = nn.Sequential(*blocks)
        self.readout = nn.Conv1d(hidden_channels, 1, kernel_size=1)

        # Total receptive field of the whole stack.
        self.receptive_field = 1 + sum((num_taps - 1) * d for d in dilations)

    def forward(self, xin):
        """Map ``xin`` [B, T] to a per-row zero-mean prediction [B, T]."""
        features = self.tcn(xin.unsqueeze(1))          # [B, H, T]
        prediction = self.readout(features).squeeze(1)  # [B, T]
        return prediction - prediction.mean(dim=1, keepdim=True)
    

class TCN_channel(nn.Module):
    """TCN channel model that predicts a per-sample output distribution.

    The readout has 2 channels (mean, log-std) when ``gaussian`` is true and
    3 channels (mean, log-std, pre-softplus nu) otherwise.

    forward() returns ``(noisy_out, mean_out, std_out, nu_out)`` when
    ``learn_noise`` is true and only ``mean_out`` otherwise.
    """

    def __init__(self, nlayers=3, dilation_base=2, num_taps=10,
                 hidden_channels=32, learn_noise=False, gaussian=True):
        super().__init__()
        blocks = []
        for level in range(nlayers):
            block_in = 1 if level == 0 else hidden_channels
            blocks.append(
                TCNBlock(block_in, hidden_channels, num_taps, dilation_base ** level)
            )
        self.learn_noise = learn_noise
        self.tcn = nn.Sequential(*blocks)
        # 2 channels: mean | std        (Gaussian)
        # 3 channels: mean | std | nu   (Student-t)
        num_outputs = 2 if gaussian else 3
        self.readout = nn.Conv1d(hidden_channels, num_outputs, kernel_size=1)
        self.num_taps = num_taps
        self.gaussian = gaussian

        # Scalar parameters registered on the module; forward() does not read them.
        self.ar_tap = nn.Parameter(torch.tensor(0.0))
        self.input_tap = nn.Parameter(torch.tensor(0.0))
        self.cross_tap = nn.Parameter(torch.tensor(0.0))
        self.cross_tap_2 = nn.Parameter(torch.tensor(0.0))
        self.cross_tap_3 = nn.Parameter(torch.tensor(0.0))

        if not gaussian:
            with torch.no_grad():
                # Start with a large nu (close to Gaussian) for stability.
                self.readout.bias[2].fill_(48)

    def forward(self, xin):
        """Run the model on ``xin`` of shape [B, T]."""
        features = self.tcn(xin.unsqueeze(1))   # [B, H, T]
        params = self.readout(features)         # [B, 2 or 3, T]

        mean_out = params[:, 0, :]
        std_out = torch.exp(params[:, 1, :])    # channel 1 is the log-std
        if not self.gaussian:
            # Softplus keeps nu positive; it is then limited to [2, 50].
            nu_out = torch.clamp(F.softplus(params[:, 2, :]), 2, 50)
        mean_out = mean_out - mean_out.mean(dim=1, keepdim=True)  # [B, T]

        # Draw one noisy realisation from the predicted distribution.
        if self.gaussian:
            noisy_out = mean_out + std_out * torch.randn_like(mean_out)
            nu_out = torch.zeros_like(mean_out)  # placeholder, keeps the 4-tuple shape
        else:
            noisy_out = sample_student_t_mps(mean_out, std_out, nu_out)

        if not self.learn_noise:
            return mean_out
        return noisy_out, mean_out, std_out, nu_out

In [12]:
def train(model, optimizer, loss_fn, loop):
    """Run one training epoch and log per-batch metrics to Weights & Biases.

    Args:
        model: network exposing a boolean ``learn_noise`` attribute. When it is
            true, ``model(x)`` must return ``(noisy, mean, std, nu)`` and the
            model must also expose ``gaussian``; otherwise ``model(x)`` returns
            the mean prediction only.
        optimizer: optimizer updating the model parameters.
        loss_fn: called as ``loss_fn(residual, std)`` (Gaussian),
            ``loss_fn(residual, std, nu)`` (Student-t) or
            ``loss_fn(y, y_pred)`` when the model does not learn noise.
        loop: iterable of ``(x, y)`` batches that also provides
            ``set_postfix()`` and ``close()`` (e.g. a tqdm progress bar).

    Returns:
        Mean training loss over the batches seen (NaN if there were none).
        Callers that ignored the previous ``None`` return are unaffected.
    """
    model.train()
    total_loss = 0.0
    batch_count = 0
    for batch in loop:
        x, y = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()

        # The model's return type depends on learn_noise. Unpacking four values
        # unconditionally broke (or silently mis-split the batch axis) for
        # mean-only models.
        outputs = model(x)
        if model.learn_noise:
            _, y_pred, y_pred_std, y_pred_nu = outputs
            r = y - y_pred  # residual
            if model.gaussian:
                loss = loss_fn(r, y_pred_std)
            else:
                loss = loss_fn(r, y_pred_std, y_pred_nu)
        else:
            y_pred = outputs
            loss = loss_fn(y, y_pred)

        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        # Monitoring metric only: keep it out of the autograd graph.
        mse_value = F.mse_loss(y_pred.detach(), y).item()
        total_loss += loss_value
        batch_count += 1

        # A single log call per batch keeps the three metrics on the same W&B
        # step instead of spreading them over three consecutive steps.
        wandb.log({
            "nnl_train_loss": loss_value,
            "mse_train_loss": mse_value,
            "learning_rate": optimizer.param_groups[0]["lr"],
        })
        loop.set_postfix(loss=loss_value)
    loop.close()
    return total_loss / batch_count if batch_count else float("nan")


def _snr_figure(noisy_ys, y_preds):
    """Build the model-implied SNR-vs-frequency figure for validation logging.

    Noise is taken as ``noisy_ys - y_preds``. The first CP_LENGTH samples of
    every row are dropped before the FFT, power spectra are averaged over rows,
    and only the first half of the frequency axis is plotted.
    """
    noise_pred = noisy_ys - y_preds
    noise_power_pred_k = torch.fft.fft(noise_pred[:, CP_LENGTH:], norm='ortho', dim=-1).abs().square().mean(dim=0)
    signal_power_model = torch.fft.fft(y_preds[:, CP_LENGTH:], norm='ortho', dim=-1).abs().square().mean(dim=0)
    snr_k_model = signal_power_model / (noise_power_pred_k + 1e-8)
    snr_mag_model = 10 * torch.log10(torch.abs(snr_k_model) + 1e-8)

    # fftfreq expects a plain float sample spacing; delta_f may be a tensor
    # living on the accelerator, so convert explicitly.
    sample_rate = float(delta_f * NUM_POINTS_FRAME)
    freqs = torch.fft.fftfreq(len(snr_mag_model), d=1 / sample_rate)
    half = len(freqs) // 2

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(freqs[:half].numpy(), snr_mag_model[:half].cpu().numpy(), lw=1.5, color="orange")
    ax.set_title("SNR vs Frequency (Model)", fontsize=11)
    ax.set_xlabel("Frequency", fontsize=9)
    ax.set_ylabel("SNR Magnitude (dB)", fontsize=9)
    ax.grid(True, linestyle='--', alpha=0.6)
    return fig


def val(model, loss_fn, val_loader):
    """Evaluate the model on the validation set and log metrics to W&B.

    Args:
        model: network exposing ``learn_noise`` (and ``gaussian`` when
            ``learn_noise`` is true); see ``loss_fn`` for the call conventions.
        loss_fn: called as ``loss_fn(residual, std)``,
            ``loss_fn(residual, std, nu)`` or ``loss_fn(y, mean)`` depending on
            the model flags. The predicted mean is always used.
        val_loader: iterable of ``(x, y)`` batches.

    Returns:
        The validation loss averaged over batches.

    Raises:
        ValueError: if ``val_loader`` yields no batches.
    """
    model.eval()
    val_loss = 0.0
    val_mse_loss = 0.0
    nrmse_pct_loss = 0.0
    batch_count = 0
    y_preds = []
    noisy_ys = []

    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(device)
            y = y.to(device)

            # The model's return type depends on learn_noise; unpacking four
            # values unconditionally failed for mean-only models.
            outputs = model(x)
            if model.learn_noise:
                noisy_y, mean_y, std_y, nu_y = outputs
                r = y - mean_y
                if model.gaussian:
                    loss = loss_fn(r, std_y)  # Use mean for validation
                else:
                    loss = loss_fn(r, std_y, nu_y)  # Use mean for validation
                # Only these two are needed afterwards (for the SNR plot).
                y_preds.append(mean_y)
                noisy_ys.append(noisy_y)
            else:
                mean_y = outputs
                r = y - mean_y
                loss = loss_fn(y, mean_y)

            nrmse_pct_loss += (torch.sqrt(torch.mean(r ** 2) / torch.mean(y ** 2)) * 100).item()
            val_mse_loss += F.mse_loss(y, mean_y).item()
            val_loss += loss.item()
            batch_count += 1

    if batch_count == 0:
        raise ValueError("val_loader produced no batches")

    avg_val_loss = val_loss / batch_count
    metrics = {
        'val_nll_loss': avg_val_loss,
        "avg_val_mse_loss": val_mse_loss / batch_count,
        "avg_nrmse_pct_loss": nrmse_pct_loss / batch_count,
    }

    # The SNR plot needs sampled noise, which only noise-learning models emit.
    if noisy_ys:
        fig = _snr_figure(torch.vstack(noisy_ys), torch.vstack(y_preds))
        metrics["SNR_Frequency"] = wandb.Image(fig)
        plt.close(fig)

    # One log call so the figure and the scalars share a W&B step.
    wandb.log(metrics)
    return avg_val_loss


In [13]:
def students_t_loss(difference, y_pred_std, y_pred_nu):
    """Mean negative log-likelihood of residuals under a scaled Student-t.

    Args:
        difference: residual tensor (target minus predicted mean).
        y_pred_std: predicted scale, same shape as ``difference``.
        y_pred_nu: predicted degrees of freedom (tensor), used as given.

    Returns:
        Scalar tensor: the mean over all elements.

    Raises:
        ValueError: if the resulting loss is NaN.
    """
    nu = y_pred_nu
    standardized = difference / y_pred_std

    # Normalisation part of the negative log-density.
    log_norm = (
        -1 * torch.lgamma((nu + 1) / 2)
        + 0.5 * torch.log(torch.pi * nu)
        + torch.lgamma(nu / 2)
        + torch.log(y_pred_std + 1e-8)
    )
    # Data-dependent part.
    log_kernel = ((nu + 1) / 2) * torch.log(
        1 + (1 / nu) * torch.square(standardized) + 1e-8
    )

    loss = torch.mean(log_norm + log_kernel)
    if torch.isnan(loss):
        raise ValueError("NaN in loss")
    return loss

def gaussian_nll(difference, y_pred_std):
    """Mean Gaussian negative log-likelihood of residuals.

    Args:
        difference: residual tensor (target minus predicted mean).
        y_pred_std: predicted standard deviation, same shape as ``difference``.

    Returns:
        Scalar tensor: mean of ``0.5*log(2*pi*std^2) + 0.5*(difference/std)^2``.

    Raises:
        ValueError: if the resulting loss is NaN.
    """
    variance = y_pred_std ** 2
    log_norm = 0.5 * torch.log(2 * torch.pi * variance)
    quadratic = 0.5 * torch.square(difference / y_pred_std)

    loss = torch.mean(log_norm + quadratic)
    if torch.isnan(loss):
        raise ValueError("NaN in loss")
    return loss

# No separate noise network is defined in this notebook; the "noise_only" and
# "joint" modes below need this to be set to a module first.
noise_model = None

def make_optimizer(mode):
    """Create the AdamW optimizer for one training phase.

    Args:
        mode: "channel_only", "noise_only" or "joint". Selects which
            notebook-level model(s) are optimised and which
            ``config.lr_*`` / ``config.wd_*`` pair is used.

    Returns:
        A ``torch.optim.AdamW`` instance.

    Raises:
        ValueError: for any other ``mode``.
    """
    if mode not in ("channel_only", "noise_only", "joint"):
        raise ValueError("Unknown mode")

    if mode == "channel_only":
        params = list(channel_model.parameters())
        lr, weight_decay = config.lr_channel, config.wd_channel
    elif mode == "noise_only":
        params = list(noise_model.parameters())
        lr, weight_decay = config.lr_noise, config.wd_noise
    else:  # "joint"
        params = list(channel_model.parameters()) + list(noise_model.parameters())
        lr, weight_decay = config.lr_joint, config.wd_joint

    # The values are cast with float() because the config may hold strings
    # such as '1e-3'.
    return optim.AdamW(params, lr=float(lr), weight_decay=float(weight_decay))


# Hyperparameters are read from a YAML file one directory above the current
# working directory.
script_dir = os.getcwd()
config_path = os.path.join(script_dir, "..", "offline_time_channel_config.yml")
with open(config_path, "r") as f:
    hyperparams = yaml.safe_load(f)

# Start Weights and Biases session
wandb.init(
    project="mldrivenpeled",
    config=hyperparams,
    tags=['channel_model'],
)
config = wandb.config
schedule = config.training_schedule

# Closed form of 1 + (num_taps - 1) * sum_i dilation_base**i for i < nlayers.
RECEPTIVE_FIELD = (1 + (config.num_taps - 1) * (config.dilation_base**config.nlayers - 1) // (config.dilation_base - 1))

print("WandB run info:")
print(f"  Name: {wandb.run.name}")
print(f"  ID: {wandb.run.id}")
print(f"  URL: {wandb.run.url}")
print("Chosen hyperparameters for this session:")
print(config)

# Create dataloaders
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size, drop_last=False)
test_loader = DataLoader(test_dataset)
val_loader = DataLoader(val_dataset, shuffle=True, batch_size=config.batch_size, drop_last=False)

channel_model = TCN_channel(
    nlayers=config.nlayers,
    dilation_base=config.dilation_base,
    num_taps=config.num_taps,
    hidden_channels=config.hidden_channels,
    learn_noise=config.learn_noise,
    gaussian=config.gaussian,
).to(device)

# Snapshot of the untrained weights.
initial_model_state = copy.deepcopy(channel_model.state_dict())

# The likelihood used for training follows the model's output distribution.
loss_fn = gaussian_nll if channel_model.gaussian else students_t_loss

# loss_fn = F.mse_loss

num_epochs = config.epochs

# Run every phase of the schedule in order. Each phase gets a fresh optimizer
# and LR scheduler; epochs are numbered continuously across phases.
epoch_counter = 0
for phase in schedule:
    mode = phase["mode"]
    # if None, all batches run (the value is read here but not used below)
    num_batches = phase["batches"] if "batches" in phase else None
    num_phase_epochs = phase["epochs"]

    optimizer = make_optimizer(mode)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1, min_lr=1e-6)

    for local_epoch in range(num_phase_epochs):
        epoch_counter += 1

        loop = tqdm(train_loader, desc=f'Epoch {epoch_counter} [{mode}]')
        train(channel_model, optimizer, loss_fn, loop)

        avg_val_loss = val(channel_model, loss_fn, val_loader)

        # Halve the learning rate when the validation loss stops improving.
        scheduler.step(avg_val_loss)


def count_parameters(model):
    """Return the total number of elements in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

print(f"Trainable parameters: {count_parameters(channel_model):,}")

# Freeze model: no parameter receives gradients from here on.
channel_model.requires_grad_(False)

# Save model weights locally, then attach the file to the W&B run.
checkpoint_file = "channel_model_final.pth"
torch.save({"channel_model": channel_model.state_dict()}, checkpoint_file)

artifact = wandb.Artifact("channel_model", type="model")
artifact.add_file(checkpoint_file)
wandb.log_artifact(artifact)
print("Finished!")

# Keep the run name available after the run is closed.
run_name = wandb.run.name
wandb.finish()

WandB run info:
  Name: sleek-violet-7785
  ID: tx4ofcbb
  URL: https://wandb.ai/dylanbackprops-university-of-washington/mldrivenpeled/runs/tx4ofcbb
Chosen hyperparameters for this session:
{'CP_ratio': 0.25, 'batch_size': 16, 'num_taps': 10, 'epochs': 10, 'gain': 20, 'lr': 0.001, 'nlayers': 2, 'hidden_channels': 8, 'dilation_base': 2, 'num_points_symbol': 4000, 'learn_noise': True, 'num_symbols_per_frame': 1, 'scheduler_type': 'reduce_lr_on_plateu', 'weight_init': 'default', 'gaussian': False, 'Nf': 1499, 'Nt': 1, 'flow': 300000, 'fhigh': '15e6', 'fnyquist': '30e6', 'subcarrier_spacing': '1e4', 'dc_offset': 3.5, 'lr_channel': '1e-3', 'lr_noise': '1e-3', 'lr_joint': '1e-3', 'wd_channel': '1e-3', 'wd_noise': '1e-3', 'wd_joint': '1e-4', 'training_schedule': [{'epochs': 10, 'mode': 'channel_only'}]}


Epoch 1 [channel_only]: 100%|██████████| 388/388 [00:04<00:00, 96.19it/s, loss=-3.4]  
Epoch 2 [channel_only]: 100%|██████████| 388/388 [00:03<00:00, 112.05it/s, loss=-3.46]
Epoch 3 [channel_only]: 100%|██████████| 388/388 [00:03<00:00, 117.46it/s, loss=-3.49]
Epoch 4 [channel_only]: 100%|██████████| 388/388 [00:03<00:00, 110.58it/s, loss=-3.5] 
Epoch 5 [channel_only]: 100%|██████████| 388/388 [00:03<00:00, 101.75it/s, loss=-3.51]
Epoch 6 [channel_only]: 100%|██████████| 388/388 [00:03<00:00, 105.42it/s, loss=-3.5] 
Epoch 7 [channel_only]: 100%|██████████| 388/388 [00:03<00:00, 108.88it/s, loss=-3.53]
Epoch 8 [channel_only]: 100%|██████████| 388/388 [00:03<00:00, 106.67it/s, loss=-3.52]
Epoch 9 [channel_only]: 100%|██████████| 388/388 [00:03<00:00, 106.33it/s, loss=-3.52]
Epoch 10 [channel_only]: 100%|██████████| 388/388 [00:03<00:00, 107.21it/s, loss=-3.51]


Trainable parameters: 784


[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


Finished!


0,1
avg_nrmse_pct_loss,█▄▃▃▁▁▁▂▂▁
avg_val_mse_loss,█▄▃▃▁▁▁▂▂▁
learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mse_train_loss,▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
nnl_train_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_nll_loss,█▅▄▄▃▃▂▂▂▁

0,1
avg_nrmse_pct_loss,12.1868
avg_val_mse_loss,5e-05
learning_rate,0.001
mse_train_loss,5e-05
nnl_train_loss,-3.50941
val_nll_loss,-3.51894


### Memory Polynomial for Comparison

In [6]:
# Rename for convenience: NumPy copies of the time-domain frames on the CPU.
X = sent_frames_time.cpu().numpy()
Y = received_frames_time.cpu().numpy()

In [7]:


def create_regressors(X, memory_linear, memory_nonlinear, nonlinearity_order):
    """Build the memory-polynomial regressor matrix for a batch of signals.

    Args:
        X: array of shape [B, T], one signal per row.
        memory_linear: largest delay used for the linear terms; delays
            0..memory_linear are included.
        memory_nonlinear: largest delay used for the power terms; delays
            0..memory_nonlinear are included for every power.
        nonlinearity_order: highest power; powers 2..nonlinearity_order are
            included (none when it is below 2).

    Returns:
        Array of shape [B * T, F] with
        F = (memory_linear + 1) + (nonlinearity_order - 1) * (memory_nonlinear + 1)
        for nonlinearity_order >= 1. Row ``b * T + t`` holds the features of
        sample ``t`` of signal ``b``. Columns are ordered: linear delays first,
        then for each power the delays in increasing order. Delayed samples
        that would come from before the start of a signal are zero.
    """
    X = np.asarray(X)
    B, T = X.shape

    def delayed(lag):
        # Shift right by `lag` samples and zero-fill the start (no wrap-around).
        shifted = np.zeros_like(X)
        if lag < T:
            shifted[:, lag:] = X[:, :T - lag]
        return shifted

    # Each delayed copy is built once and shared by the linear and power terms.
    max_lag = max(memory_linear, memory_nonlinear)
    lagged = [delayed(lag) for lag in range(max_lag + 1)]

    columns = lagged[:memory_linear + 1]
    for k in range(2, nonlinearity_order + 1):
        columns.extend(np.power(lagged[j], k) for j in range(memory_nonlinear + 1))

    # The column count comes from what was actually built. The former
    # closed-form count was only right for nonlinearity_order == 2 and made
    # the reshape fail for every other order.
    stack = np.stack(columns, axis=-1)  # [B, T, features]
    return stack.reshape(B * T, len(columns))

def memory_polynomial(X, Y, memory_linear, memory_nonlinear, nonlinearity_order):
    """Fit a memory polynomial from X to Y by ordinary least squares.

    Args:
        X: input signals, shape [B, T].
        Y: target signals, same shape as X.
        memory_linear, memory_nonlinear, nonlinearity_order: passed unchanged
            to ``create_regressors``.

    Returns:
        Tuple ``(weights, y_pred, residuals)``: the fitted coefficient vector,
        the model output reshaped to [B, T], and ``Y - y_pred``.
    """
    design = create_regressors(X, memory_linear, memory_nonlinear, nonlinearity_order)
    targets = Y.flatten()

    # Only the coefficient vector of the least-squares solution is needed.
    weights = np.linalg.lstsq(design, targets, rcond=None)[0]

    # Reshape the flat prediction back to (B, T) for analysis.
    num_signals, num_samples = X.shape
    fitted = (design @ weights).reshape(num_signals, num_samples)
    return weights, fitted, Y - fitted

    

weights, y_pred, residuals = memory_polynomial(
    X,
    Y,
    memory_linear=10,
    memory_nonlinear=10,
    nonlinearity_order=2,
)

# NRMSE in percent: RMS of the residual relative to the RMS of the measured signal.
signal_power = np.square(Y).mean()
error_power = np.square(residuals).mean()
nrmse_pct = 100 * np.sqrt(error_power / signal_power)
print("NRMSE  %", nrmse_pct)

NRMSE  % 13.434357


In [None]:
class memory_polynomial_channel(nn.Module):
    def __init__(self, weights):
        super().__init__()
        self.weights = torch.tensor(weights, device=device)

        def _create_regressors(x):
                B, T = X.shape
        # pad data with longest memory
        max_taps = max(memory_linear, memory_nonlinear) # memory here is strictly in the past (current time step not considered)
        # Each example and target will get a matrix and column vector. All will be stacked
        # to form a A with shape [NxT, memory_linear + memory_nonlinearxnonlinear_order] regressor matrix

        batched_regressor_cols = []
        num_regressors = memory_linear + (memory_nonlinear * (nonlinearity_order - 1)) + 2
        regressor_length = T * B


        for i in range(memory_linear + 1):
            X_shifted = np.roll(X, i, axis=1)
            X_shifted[:, :i] = 0.0
            batched_regressor_cols.append(X_shifted)

        for k in range(2, nonlinearity_order + 1):
            for j in range(memory_nonlinear + 1):
                X_shifted = np.roll(X, j, axis=1)
                X_shifted[:, :j] = 0.0
                batched_regressor_cols.append(np.power(X_shifted, k))

        stack = np.array(batched_regressor_cols) # [features, B, T]
        stack = stack.transpose(1, 2, 0) # [B, T, freatures]
        A = stack.reshape(regressor_length, num_regressors)