In [27]:
import numpy as np
import yaml
import pprint
from pathlib import Path
import random
import torch
from torch.utils.data import DataLoader, Dataset, Subset
from torch.nn import MSELoss
from torch.optim import Adam
from torch import nn
import matplotlib.pyplot as plt

In [56]:
def compute_channel_power(data):
    """Return the average power of ``data``, i.e. the mean of ``|x|**2`` over all of its elements."""
    magnitudes = np.abs(data)
    squared_magnitudes = magnitudes ** 2
    return np.mean(squared_magnitudes)

def compute_autocorrelation(hp_ls):
    """Estimate the autocorrelation matrix of (batched) channel estimates.

    Each sample is flattened in row-major order into a vector ``h`` of length ``N``
    (the product of the two trailing dimensions). The result is the sample average
    of ``conj(h_i) * h_j`` over the batch.

    Args:
        hp_ls: array of shape ``(batch, rows, cols)``, or a single sample of shape
            ``(rows, cols)`` which is treated as a batch of size one.

    Returns:
        ``(N, N)`` matrix whose entry ``[i, j]`` is the batch mean of
        ``conj(h[i]) * h[j]``.

    Raises:
        AssertionError: if the input is not 2-D or 3-D.
    """
    if hp_ls.ndim == 2:
        # Index with None instead of calling `.unsqueeze(0)`: `unsqueeze` exists only
        # on torch tensors and raised AttributeError for the NumPy arrays used here.
        hp_ls = hp_ls[None, ...]  # add batch dimension
    assert hp_ls.ndim == 3

    batch_size = hp_ls.shape[0]

    hp_ls = hp_ls.reshape(batch_size, -1)  # row-major flatten

    # (N, batch) @ (batch, N): sums conj(h_i) * h_j over the batch.
    hp_ls_autocorr = np.conj(hp_ls.T) @ hp_ls
    return hp_ls_autocorr / batch_size

def lmmse_interpolation(hp_ls, SNR, pilot_autocorr, channel_pilot_corr):
    """Placeholder for LMMSE interpolation; always raises.

    None of the arguments are used yet.

    Raises:
        NotImplementedError: unconditionally.
    """
    message = "LMMSE interpolation is not implemented yet"
    raise NotImplementedError(message)


In [57]:
class TDLDataset(Dataset):
    """Channel realizations loaded from ``.npy`` files, paired with noisy LS estimates at the pilots.

    Every ``*.npy`` file directly under ``data_path`` is loaded into memory when the
    dataset is constructed. Indexing returns ``(LS_estimate, channel, stats)`` where
    ``channel`` is the (optionally normalized) channel as a ``complex64`` tensor,
    ``LS_estimate`` is that channel at the pilot positions plus complex AWGN at a
    randomly chosen SNR, and ``stats`` is a dict with the file's ``doppler_shift`` and
    ``delay_spread`` (parsed from the file name) plus the ``SNR`` used for this sample.
    """

    def __init__(
        self, data_path, *, file_size, normalization_stats=None, return_pilots_only=True,
        num_subcarriers=120, num_symbols=14, SNRs=(0, 5, 10, 15, 20, 25, 30),
        pilot_symbols=(2, 11), pilot_every_n=2):
        """
        This class loads the data from the folder and returns a dataset of channels.

        data_path: path to the folder containing the data
        file_size: number of channels per file
        normalization_stats: optional dict with keys "real_mean", "real_std", "imag_mean" and
            "imag_std"; when given, the real and imaginary parts of every file are standardized
            with these values. NOTE(review): the noise level assumes unit channel power, which
            normalization may change -- confirm this combination is intended.
        return_pilots_only: if True, only the LS channel estimate at pilots are returned
            if False, the LS channel estimate is returned as a sparse channel matrix with non-zero
            values only at the pilot subcarriers and time instants.
        num_subcarriers: number of subcarriers
        num_symbols: number of OFDM symbols

        SNRs: sequence of SNR values (dB) to randomly sample from when returning LS estimates.
            AWGN is added to simulate LS estimation error
        pilot_symbols: sequence of OFDM symbol indices where pilots are placed. With
            return_pilots_only=True the columns of the estimate follow ascending symbol index,
            whatever order is given here.
        pilot_every_n: spacing between pilot subcarriers (pilots sit on subcarriers
            0, pilot_every_n, 2 * pilot_every_n, ...)
        """

        self.file_size = int(file_size)
        self.normalization_stats = normalization_stats
        self.return_pilots_only = return_pilots_only
        self.num_subcarriers = num_subcarriers
        self.num_symbols = num_symbols
        # Copy into fresh lists so neither the defaults nor a caller's list is shared.
        self.SNRs = list(SNRs)
        self.pilot_symbols = list(pilot_symbols)
        self.pilot_every_n = pilot_every_n

        # glob() yields files in a filesystem-dependent order; sort so that the
        # index -> sample mapping is reproducible across runs and machines.
        self.file_list = sorted(Path(data_path).glob("*.npy"))
        self.stats = self._get_stats_per_file(self.file_list)
        self.data = self._load_data_from_folder(self.file_list, self.normalization_stats)
        self.pilot_mask = self._get_pilot_mask()

        self.num_pilot_symbols = len(self.pilot_symbols)
        self.num_pilot_subcarriers = int(self.pilot_mask.sum()) // self.num_pilot_symbols

    def __len__(self):
        """Total number of samples, assuming every file holds exactly ``file_size`` channels."""
        return len(self.file_list) * self.file_size

    def __getitem__(self, idx):
        """Return ``(LS_estimate, channel, stats)`` for the sample at flat index ``idx``."""
        file_idx = idx // self.file_size
        sample_idx = idx % self.file_size
        file_path = self.file_list[file_idx]
        channel = self.data[file_path][sample_idx].squeeze().T

        SNR = random.choice(self.SNRs)
        LS_channel_at_pilots = self._get_LS_estimate_at_pilots(channel, SNR)

        # Build a fresh dict per sample. Writing "SNR" into the shared per-file dict made
        # every sample of that file return the same object, so after batching they all
        # reported the SNR of whichever sample was fetched last.
        stats = {**self.stats[file_path], "SNR": SNR}

        LS_channel_at_pilots_torch = torch.from_numpy(LS_channel_at_pilots).to(torch.complex64)
        channel_torch = torch.from_numpy(channel).to(torch.complex64)
        return LS_channel_at_pilots_torch, channel_torch, stats

    @staticmethod
    def _load_data_from_folder(file_list, normalization_stats=None):
        """Load every file into a dict keyed by its path, optionally standardizing real/imag parts."""
        data = {}
        for file_path in file_list:
            file_data = np.load(file_path)
            if normalization_stats is not None:
                normalized_real = (file_data.real - normalization_stats["real_mean"]) / normalization_stats["real_std"]
                normalized_imag = (file_data.imag - normalization_stats["imag_mean"]) / normalization_stats["imag_std"]
                file_data = normalized_real + 1j * normalized_imag
            data[file_path] = file_data
        return data

    @staticmethod
    def _get_stats_per_file(file_list):
        """Parse Doppler shift and delay spread from each file name.

        Accepted stems are ``delay_spread_<y>_doppler_<x>`` and
        ``doppler_<x>_delay_spread_<y>`` with integer ``x`` and ``y``.

        Raises:
            ValueError: if a stem starts with neither "delay" nor "doppler", if a numeric
                field is not an integer, or if a path appears twice in ``file_list``.
        """
        stats = {}

        for file_path in file_list:
            file_name = str(file_path.stem)
            file_parts = file_name.split("_")

            if file_parts[0] == "delay":
                delay_spread = int(file_parts[2])  # [delay, spread, y, doppler, x]
                doppler_shift = int(file_parts[-1])
            elif file_parts[0] == "doppler":
                doppler_shift = int(file_parts[1])  # [doppler, x, delay, spread, y]
                delay_spread = int(file_parts[-1])
            else:
                raise ValueError(f"File {file_name} has unexpected format")

            if file_path in stats:
                raise ValueError(f"File {file_path} already in stats, but should not be")
            stats[file_path] = {"doppler_shift": doppler_shift, "delay_spread": delay_spread}

        return stats

    @staticmethod
    def _complex_awgn(shape, std_per_component):
        """Draw complex Gaussian noise of the given shape (real part first, then imaginary part)."""
        noise_real = std_per_component * np.random.randn(*shape)
        noise_imag = std_per_component * np.random.randn(*shape)
        return noise_real + 1j * noise_imag

    def _get_LS_estimate_at_pilots(self, channel_matrix, SNR):
        """Simulate an LS estimate by adding AWGN to the channel at the pilot positions.

        Args:
            channel_matrix: array of shape ``(num_subcarriers, num_symbols)``.
            SNR: signal-to-noise ratio in dB.

        Returns:
            If ``return_pilots_only``: array of shape
            ``(num_pilot_subcarriers, num_pilot_symbols)``. Otherwise an array of shape
            ``(num_subcarriers, num_symbols)`` that is zero away from the pilots.
        """
        # unit symbol power and unit channel power --> rx noise var = LS error var
        noise_std = np.sqrt(1 / (10**(SNR / 10)))
        noise_real_imag = noise_std / np.sqrt(2)  # split the variance evenly over real/imag

        if self.return_pilots_only:
            pilot_mask_bool = self.pilot_mask.astype(bool)
            # Boolean indexing walks the mask in row-major order (subcarrier, then symbol),
            # so the flat result reshapes directly to (pilot subcarriers, pilot symbols).
            pilot_shape = (self.num_pilot_subcarriers, self.num_pilot_symbols)
            channel_at_pilots = channel_matrix[pilot_mask_bool].reshape(pilot_shape)
            noise = self._complex_awgn(pilot_shape, noise_real_imag)
        else:
            channel_at_pilots = self.pilot_mask * channel_matrix
            grid_shape = (self.num_subcarriers, self.num_symbols)
            # Noise only exists where a pilot was actually observed.
            noise = self._complex_awgn(grid_shape, noise_real_imag) * self.pilot_mask

        return channel_at_pilots + noise

    def _get_pilot_mask(self):
        """Return a ``(num_subcarriers, num_symbols)`` array of 0/1 with 1 at the pilot positions."""
        pilot_mask = np.zeros((self.num_subcarriers, self.num_symbols))
        pilot_mask_subcarrier_indices = np.arange(0, self.num_subcarriers, self.pilot_every_n)
        pilot_mask[np.ix_(pilot_mask_subcarrier_indices, self.pilot_symbols)] = 1
        return pilot_mask
        

In [58]:
# Folder holding the training channel files (*.npy) and the accompanying metadata.yaml.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
TRAIN_SET_PATH = "/opt/shared/datasets/NeoRadiumTDLdataset/train/TDLA"
# SNR values passed to TDLDataset as `SNRs`; one is drawn at random for every sample.
TEST_SNR_LEVELS = [0, 5, 10, 15, 20, 25, 30]

In [59]:
# Sanity check: report the number of channel files and the mean per-file channel power.
file_list = [npy_path for npy_path in Path(TRAIN_SET_PATH).glob("*.npy")]

print("There are", len(file_list), "files in the dataset")

channel_power_list = [
    compute_channel_power(np.load(npy_path).squeeze())
    for npy_path in file_list
]

print("Average channel power:", np.array(channel_power_list).mean())


There are 240 files in the dataset
Average channel power: 1.0


In [60]:
# Read the dataset metadata; it provides the number of channels stored in each file.
metadata_path = Path(TRAIN_SET_PATH) / "metadata.yaml"
with metadata_path.open("r") as f:
    train_metadata = yaml.safe_load(f)

channels_per_file = train_metadata["config"]["num_channels_per_config"]

# Dataset returning only the LS estimates at the pilot positions.
large_dataset = TDLDataset(
    TRAIN_SET_PATH,
    return_pilots_only=True,
    SNRs=TEST_SNR_LEVELS,
    file_size=channels_per_file,
)

data_loader = DataLoader(large_dataset, shuffle=True, batch_size=512)

In [None]:
# Take only the first batch: convert it to NumPy, show the shapes and estimate the
# autocorrelation of the LS pilot estimates from it.
for batch in data_loader:
    hp_ls_batch, h_true_batch, batch_stats = batch
    hp_ls, h_true = hp_ls_batch.numpy(), h_true_batch.numpy()
    print(hp_ls.shape, h_true.shape, sep="\n")
    pilot_autocorr = compute_autocorrelation(hp_ls)
    print(pilot_autocorr.shape)
    break

(512, 60, 2)
(512, 120, 14)
(120, 120)


In [66]:
# Diagonal of the autocorrelation matrix: for each flattened pilot position, the
# batch-averaged squared magnitude of the noisy LS estimate.
np.diag(pilot_autocorr)

array([1.1274128+0.0000000e+00j, 1.1944957+0.0000000e+00j,
       1.0576942+0.0000000e+00j, 1.1438823+0.0000000e+00j,
       1.0601453+0.0000000e+00j, 1.1531299+0.0000000e+00j,
       1.0842636+0.0000000e+00j, 1.1956844+0.0000000e+00j,
       1.068183 +0.0000000e+00j, 1.1836448+0.0000000e+00j,
       1.0921831+0.0000000e+00j, 1.2332484+0.0000000e+00j,
       1.072501 +0.0000000e+00j, 1.1754954+0.0000000e+00j,
       1.1798823+0.0000000e+00j, 1.2376853+0.0000000e+00j,
       1.0406507+1.2106793e-10j, 1.1957436+2.1899077e-10j,
       1.0838684-2.6505378e-10j, 1.192205 +3.5072492e-10j,
       1.1083679+0.0000000e+00j, 1.1594903+0.0000000e+00j,
       1.1049502+0.0000000e+00j, 1.2553414+0.0000000e+00j,
       1.1489837+0.0000000e+00j, 1.2509367+0.0000000e+00j,
       1.0591795+0.0000000e+00j, 1.1884764+0.0000000e+00j,
       1.1305239+0.0000000e+00j, 1.3046973+0.0000000e+00j,
       1.088654 +0.0000000e+00j, 1.2595648+0.0000000e+00j,
       1.1152133+0.0000000e+00j, 1.203526 +0.0000000e+00