In [1]:
import numpy as np
import yaml
from pathlib import Path
import random
import torch
from torch.utils.data import DataLoader, Dataset, Subset
from tqdm import tqdm

In [2]:
def compute_channel_power(data):
    return np.mean(np.abs(data)**2)

def compute_autocorrelation(hp_ls):
    if hp_ls.ndim == 2:
        hp_ls = hp_ls.unsqueeze(0)  # add batch dimension
    assert hp_ls.ndim == 3

    batch_size = hp_ls.shape[0]

    hp_ls = hp_ls.reshape(batch_size, -1)  # row-major flatten

    hp_ls_autocorr = np.conj(hp_ls.T) @ hp_ls
    return hp_ls_autocorr / batch_size

def compute_cross_correlation(hp_ls, h_true):
    if hp_ls.ndim == 2:
        hp_ls = hp_ls.unsqueeze(0)  # add batch dimension
    
    if h_true.ndim == 2:
        h_true = h_true.unsqueeze(0)  # add batch dimension
    
    batch_size = hp_ls.shape[0]

    assert hp_ls.ndim == 3 and h_true.ndim == 3
    assert hp_ls.shape[0] == h_true.shape[0]

    hp_ls = hp_ls.reshape(batch_size, -1)  # row-major flatten
    h_true = h_true.reshape(batch_size, -1)  # row-major flatten

    hp_ls_h_true_cross_corr = np.conj(h_true.T) @ hp_ls
    
    return hp_ls_h_true_cross_corr/ batch_size
    

def lmmse_interpolation(hp_ls, sigma2, pilot_autocorr, channel_pilot_corr):
    sigma2_eye = np.eye(pilot_autocorr.shape[0]) * sigma2
    pilot_autocorr_inv = np.linalg.inv(pilot_autocorr + sigma2_eye)
    h_true_hat = channel_pilot_corr @ pilot_autocorr_inv @ hp_ls
    return h_true_hat


In [3]:
class TDLDataset(Dataset):
    def __init__(
        self, data_path, *, file_size, normalization_stats=None,return_pilots_only=True, num_subcarriers=120,
        num_symbols=14, SNRs=[0, 5, 10, 15, 20, 25, 30],
        pilot_symbols=[2, 11], pilot_every_n=2):
        """
        This class loads the data from the folder and returns a dataset of channels.

        data_path: path to the folder containing the data
        file_size: number of channels per file
        return_pilots_only: if True, only the LS channel estimate at pilots are returned
            if False, the LS channel estimate is returned as a sparse channel matrix with non-zero 
            values only at the pilot subcarriers and time instants.
        num_subcarriers: number of subcarriers
        num_symbols: number of OFDM symbols

        SNRs: list of SNR values to randomly sample from when return LS estimates.
            AWGN is added to simulate LS estimatation error
        pilot_symbols: list of OFDM symbol indices where pilots are placed
        pilot_every_n: number of subcarriers between pilot subcarriers
        """
        
        self.file_size = int(file_size)
        self.normalization_stats = normalization_stats
        self.return_pilots_only = return_pilots_only
        self.num_subcarriers = num_subcarriers
        self.num_symbols = num_symbols
        self.SNRs = SNRs
        self.pilot_symbols = pilot_symbols
        self.pilot_every_n = pilot_every_n
        self.noise_variance = self._get_noise_variance(SNRs)

        self.file_list = list(Path(data_path).glob("*.npy"))
        self.stats = self._get_stats_per_file(self.file_list)
        self.data = self._load_data_from_folder(self.file_list, self.normalization_stats)
        self.pilot_mask = self._get_pilot_mask()

        self.num_pilot_symbols = len(self.pilot_symbols)
        self.num_pilot_subcarriers = int(self.pilot_mask.sum()) // self.num_pilot_symbols

    def __len__(self):
        return len(self.file_list) * self.file_size

    def __getitem__(self, idx):
        file_idx = idx // self.file_size
        sample_idx = idx % self.file_size
        file_path = self.file_list[file_idx]
        channels = self.data[file_path]
        channel = channels[sample_idx].squeeze().T

        SNR = random.choice(self.SNRs)
        LS_channel_at_pilots = self._get_LS_estimate_at_pilots(channel, SNR)
        stats = self.stats[file_path]
        stats["SNR"] = SNR

        LS_channel_at_pilots_torch = torch.from_numpy(LS_channel_at_pilots).to(torch.complex64)
        channel_torch = torch.from_numpy(channel).to(torch.complex64)
        return LS_channel_at_pilots_torch, channel_torch, stats
    
    @staticmethod
    def _load_data_from_folder(file_list, normalization_stats=None):
        data = {}
        for file_path in file_list:
            file_data = np.load(file_path)
            if normalization_stats is not None:
                normalized_real = (file_data.real - normalization_stats["real_mean"]) / normalization_stats["real_std"]
                normalized_imag = (file_data.imag - normalization_stats["imag_mean"]) / normalization_stats["imag_std"]
                file_data = normalized_real + 1j * normalized_imag
            data[file_path] = file_data
        return data

    @staticmethod
    def _get_stats_per_file(file_list):
        stats = {}

        for file_path in file_list:
            file_name = str(file_path.stem)
            file_parts = file_name.split("_")

            if file_parts[0] == "delay":
                delay_spread = int(file_parts[2])  # [delay, spread, y, doppler, x]
                doppler_shift = int(file_parts[-1])
            elif file_parts[0] == "doppler":
                doppler_shift = int(file_parts[1])  # [doppler, x, delay, spread, y]
                delay_spread = int(file_parts[-1])
            else:
                raise ValueError(f"File {file_name} has unexpected format")
            
            if file_path not in stats:
                stats[file_path] = {"doppler_shift": doppler_shift, "delay_spread": delay_spread}
            else:
                raise ValueError(f"File {file_path} already in stats, but should not be")
            
        return stats
    
    def _get_noise_variance(self, SNRs):
        noise_variances = []
        for SNR in SNRs:
            noise_variance = 1 / (10**(SNR / 10))
            noise_variances.append(noise_variance)
        return np.mean(np.array(noise_variances))
    
    def _get_LS_estimate_at_pilots(self, channel_matrix, SNR):
        # unit symbol power and unit channel power --> rx noise var = LS error var
        noise_std = np.sqrt(1 / (10**(SNR / 10)))
        noise_real_imag = noise_std / np.sqrt(2)

        if self.return_pilots_only:
            pilot_mask_bool = self.pilot_mask.astype(bool)
            channel_at_pilots = channel_matrix[pilot_mask_bool]
            channel_at_pilots = channel_at_pilots.reshape(self.num_pilot_subcarriers, self.num_pilot_symbols)
            noise_real = noise_real_imag * np.random.randn(self.num_pilot_subcarriers, self.num_pilot_symbols)
            noise_imag = noise_real_imag * np.random.randn(self.num_pilot_subcarriers, self.num_pilot_symbols)
            noise = noise_real + 1j * noise_imag
        else:
            channel_at_pilots = self.pilot_mask * channel_matrix
            noise_real = noise_real_imag * np.random.randn(self.num_subcarriers, self.num_symbols)
            noise_imag = noise_real_imag * np.random.randn(self.num_subcarriers, self.num_symbols)
            noise = noise_real + 1j * noise_imag
            noise = noise * self.pilot_mask
        
        channel_at_pilots_LS = channel_at_pilots + noise
            
        return channel_at_pilots_LS

    def _get_pilot_mask(self):
        pilot_mask = np.zeros((self.num_subcarriers, self.num_symbols))
        pilot_mask_subcarrier_indices = np.arange(0, self.num_subcarriers, self.pilot_every_n)
        pilot_mask[np.ix_(pilot_mask_subcarrier_indices, self.pilot_symbols)] = 1
        return pilot_mask
        

In [4]:
TRAIN_SET_PATH = "/opt/shared/datasets/NeoRadiumTDLdataset/train/TDLA"
TEST_SNR_LEVELS = [100]
LIMIT_CHANNELS = 1000

In [5]:
file_list = list(Path(TRAIN_SET_PATH).glob("*.npy"))

print("There are", len(file_list), "files in the dataset")

channel_power_list = []
for file in file_list:
    channel_power_list.append(compute_channel_power(np.load(file).squeeze()))

print("Average channel power:", np.array(channel_power_list).mean())


There are 240 files in the dataset
Average channel power: 1.0


In [6]:
with open(Path(TRAIN_SET_PATH, "metadata.yaml"), "r") as f:
    train_metadata = yaml.safe_load(f)

large_dataset = TDLDataset(
    TRAIN_SET_PATH, 
    file_size=train_metadata["config"]["num_channels_per_config"],
    SNRs=TEST_SNR_LEVELS,
    return_pilots_only=True)

small_dataset_indices = np.random.choice(len(large_dataset), size=LIMIT_CHANNELS, replace=False)

small_dataset = Subset(large_dataset, small_dataset_indices)

data_loader = DataLoader(small_dataset, batch_size=512, shuffle=True)

In [7]:
size_hp_ls = large_dataset.num_pilot_subcarriers * large_dataset.num_pilot_symbols
size_h_true = large_dataset.num_subcarriers * large_dataset.num_symbols
pilot_autocorr = np.zeros((size_hp_ls, size_hp_ls), dtype=np.complex128)
h_true_hp_ls_crosscorr = np.zeros((size_h_true, size_hp_ls), dtype=np.complex128)

for batch in data_loader:
    hp_ls, h_true, batch_stats = batch
    hp_ls = hp_ls.numpy()
    h_true = h_true.numpy()
    pilot_autocorr += compute_autocorrelation(hp_ls)
    h_true_hp_ls_crosscorr += compute_cross_correlation(hp_ls, h_true)

pilot_autocorr /= len(data_loader)
h_true_hp_ls_crosscorr /= len(data_loader)

In [8]:
print("R_hphp shape", pilot_autocorr.shape)
print("R_hhp shape", h_true_hp_ls_crosscorr.shape)

print("Channel power:", np.mean(np.diag(pilot_autocorr).real), "when SNRs are", TEST_SNR_LEVELS)
print("Noise Power:", large_dataset.noise_variance, "when SNRs are", TEST_SNR_LEVELS)

R_hphp shape (120, 120)
R_hhp shape (1680, 120)
Channel power: 0.9770109874506792 when SNRs are [100]
Noise Power: 1e-10 when SNRs are [100]


In [9]:
errors = []
for channel in tqdm(small_dataset):
    hp_ls, h_true, batch_stats = channel
    hp_ls = hp_ls.numpy()
    h_true = h_true.numpy()

    hp_ls = hp_ls.flatten()  # row-major flatten
    
    h_true_hat = lmmse_interpolation(hp_ls, large_dataset.noise_variance, pilot_autocorr, h_true_hp_ls_crosscorr)
    h_true_hat = h_true_hat.reshape(large_dataset.num_subcarriers, large_dataset.num_symbols)

    error = np.mean(np.abs(h_true - h_true_hat)**2)
    errors.append(error) 

100%|██████████| 1000/1000 [00:01<00:00, 677.38it/s]


In [10]:
np.array(errors).mean()

416.9051680059532

In [11]:
np.array(errors).min()

0.1229649040793855

In [12]:
np.array(errors).max()

11066.227087509948