In [1]:
%load_ext autoreload
%autoreload 2

import glob
import json
import os

import matplotlib.pyplot as plt
import numpy as np 
import scipy
import scipy.io
import torch
from IPython import display
from torch.utils.data import DataLoader
from torchsummary import summary

from src.mat_dataset import MAT_Dataset,EVAL_Dateset
from src.metrics import complex_MSE
from src.models import Beam_DnCNN_3D, train_model

In [None]:
def gen_mimo_noise(H, SNR):
    """Draw complex white Gaussian noise scaled to a target SNR relative to ``H``.

    For every index of the leading axes, the noise power equals the average
    power of ``H`` along its last axis, ``mean(|H|**2)``, divided by
    ``10**(SNR/10)``.

    Args:
        H: channel tensor. The notebook passes a 4-D batch whose axes the
            original code named (N_tti, N_ue_ant, N_bs, N_subc); any rank
            works, and the power is always averaged over the last axis.
        SNR: target signal-to-noise ratio in dB.

    Returns:
        Complex noise tensor with the same shape as ``H``, on ``H.device``.
    """
    # CN(0, 1): independent real/imag parts, each N(0, 1/2), so E|n|^2 = 1.
    # torch.randn (Gaussian, zero mean) is required here; torch.rand would be
    # uniform on [0, 1) and bias the noise.
    real = torch.randn(H.shape, device=H.device)
    imag = torch.randn(H.shape, device=H.device)
    noise_SRS = (0.5 ** 0.5) * torch.complex(real, imag)
    # RMS amplitude of H along the last axis. keepdim=True lets it broadcast
    # against the noise for any last-axis length (no hard-coded 288).
    gain = torch.sqrt(torch.mean(torch.abs(H) ** 2, dim=-1, keepdim=True))
    # SNR is in dB and scales an amplitude, hence /20.
    noise_SRS_normed = 10 ** (-SNR / 20) * gain * noise_SRS
    return noise_SRS_normed

In [None]:
# Denoise the first dataset file at every SNR in SNR_RANGE_DB and write one
# denoised_5kmh_<snr>_SNR.mat file per SNR into DATA_DIR.
DATA_DIR = './data'
CHECKPOINT_DIR = './checkpoints'
SNR_RANGE_DB = range(-20, -4)  # -20 dB ... -5 dB inclusive
BATCH_SIZE = 5
SEED = 0

# glob() order is filesystem-dependent: sort for a deterministic pick. Also skip
# the denoised_*.mat files this cell writes into the same folder; otherwise a
# re-run could load a previous output as its input.
file_list = sorted(
    path for path in glob.glob(os.path.join(DATA_DIR, '*.mat'))
    if not os.path.basename(path).startswith('denoised_')
)
mode_path = sorted(glob.glob(os.path.join(CHECKPOINT_DIR, '*.pt')))
if not file_list:
    raise FileNotFoundError(f"no input .mat files found in {DATA_DIR}")
if not mode_path:
    raise FileNotFoundError(f"no .pt checkpoints found in {CHECKPOINT_DIR}")

dataset = MAT_Dataset(path = file_list[0], UEs = [0,1,2,3])
dataloader = DataLoader(dataset, BATCH_SIZE, shuffle = False)

with open("model_config.txt") as cfg_file:
    cfg = json.load(cfg_file)
# NOTE(review): n_layers / n_features presumably must match the checkpoint's
# architecture, otherwise load_state_dict below will reject the weights.
model = Beam_DnCNN_3D(cfg = cfg, n_layers = 15, n_features = 20)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
weights = torch.load(mode_path[0], map_location=torch.device('cpu'))
model.load_state_dict(weights)
model.to(device)
model.eval()

# Number of elements the total noise power is averaged over to get the
# per-element noise variance passed to the model as `norma`.
n_elements = cfg['N_time'] * cfg['N_Az'] * cfg['N_El'] * cfg['N_pol']

torch.manual_seed(SEED)  # reproducible noise realisations across runs
with torch.no_grad():
    for snr in SNR_RANGE_DB:
        # Start from an empty complex64 tensor so torch.cat promotes the result
        # to (at least) complex64. Collect CPU batches and concatenate once,
        # instead of re-copying the accumulator on every batch.
        pieces = [torch.tensor([], dtype = torch.complex64)]

        for signal in dataloader:
            noise = gen_mimo_noise(signal, snr)
            signal, noise = signal.to(device), noise.to(device)

            Power_noise = torch.sum(torch.abs(noise)**2 , dim = (1,2,3))
            sigma = Power_noise / n_elements

            denoised_data, _, _ = model(signal + noise, norma = torch.sqrt(sigma))
            pieces.append(denoised_data.cpu())
        out = torch.cat(pieces, dim = 0)
        name = 'denoised_5kmh_{0}_SNR.mat'.format(snr)
        scipy.io.savemat(os.path.join(DATA_DIR, name), {"H_denoised": out.numpy()})



