In [1]:
import os
import shutil
import sys
sys.path.extend(['../SamplingAssistedPathlossRadioMapPrediction/'])

import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import imageio.v3 as iio
from matplotlib import pyplot as plt
from skimage.transform import resize
from skimage.io import imread
from PIL import Image
from torchvision.io import read_image

from src.algorithms import MLSP
from src.networks import UNetModel
from src.datamodules import MLSPDatamodule
from src.datamodules.datasets import PathlossDataset
from src.utils.mlsp.featurizer import featurizer
from src.utils.mlsp.augmentations import normalize_size, resize_db
from src.utils.mlsp.types import RadarSampleInputs, RadarSample

%load_ext autoreload
%autoreload 2

In [2]:
network = UNetModel(
    **{
        "n_channels": 8, 
    }
)
alg_conf = {
    "fixed_scale": False,
    "use_sip2net": False,
    "sip2net_params": {
      "mse_weight": 0.02,
      "alpha1": 500.0,
      "alpha2": 1.0,
      "alpha3": 0.0
    },
    "network": network
}
alg = MLSP.load_from_checkpoint(
    "./task1.ckpt",
    **alg_conf
)
alg.network.eval()

UNetModel(
  (unet): Unet(
    (encoder): ResNetEncoder(
      (conv1): Conv2d(8, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchN

In [3]:
main_path = "./ICASSP_TEST_DATA/"
ss_path = "./ICASSP_TEST_DATA/rate0.5/"

input_dir = os.path.join(main_path, f"Inputs/Task_1/")
output_dir = os.path.join(main_path, f"Outputs/Task_1/")
positions_dir = os.path.join(main_path, "Test_Data_Positions/")
radiation_patterns_dir = os.path.join(main_path, "Test_Radiation_Patterns/")
sampling_dir = os.path.join(ss_path, f"sampledGT")
INITIAL_PIXEL_SIZE = 0.25
IMG_TARGET_SIZE = 640

In [15]:
def get_radar_sample_input(b, ant, f, sp):    
    input_file = f"B{b}_Ant{ant}_f{f}_S{sp}.png"
    output_file = f"B{b}_Ant{ant}_f{f}_S{sp}.png"
    radiation_file = f"Ant{ant}_Pattern.csv"
    position_file = f"Positions_B{b}_Ant{ant}_f{f}.csv"
    
    if os.path.exists(os.path.join(input_dir, input_file)):
        freq_mhz = freqs_mhz[f - 1]
        input_img_path = os.path.join(input_dir, input_file)
        output_img_path = os.path.join(output_dir, output_file)
        positions_path = os.path.join(positions_dir, position_file)
        radiation_pattern_file = os.path.join(radiation_patterns_dir, radiation_file)
        sampling_file = os.path.join(sampling_dir, output_file)
        if not os.path.exists(sampling_file):
            return None
        output_img_path = sampling_file

        radar_sample_inputs = RadarSampleInputs(
            file_name=input_file,
            freq_MHz=freq_mhz,
            input_file=input_img_path,
            output_file=output_img_path,
            position_file=positions_path,
            radiation_pattern_file=radiation_pattern_file,
            sampling_position=sp,
            ids=(b, ant, f, sp),
        )
        return radar_sample_inputs
    
def read_sample(inputs):
    if isinstance(inputs, RadarSampleInputs):
        inputs = inputs.asdict()
    file_name = inputs["file_name"]
    freq_MHz = inputs["freq_MHz"]
    input_file = inputs["input_file"]
    output_file = inputs["output_file"]
    position_file = inputs["position_file"]
    sampling_position = inputs["sampling_position"]
    radiation_pattern_file = inputs["radiation_pattern_file"]

    input_img = read_image(input_file).float()
    C, H, W = input_img.shape

    output_img = read_image(output_file).float()
    if output_img.size(0) == 1:  # If single channel, remove channel dimension
        output_img = output_img.squeeze(0)

    sampling_positions = pd.read_csv(position_file)
    x_ant, y_ant, azimuth = sampling_positions.loc[int(sampling_position), ["Y", "X", "Azimuth"]]
    radiation_pattern_np = np.genfromtxt(radiation_pattern_file, delimiter=',')
    radiation_pattern = torch.from_numpy(radiation_pattern_np).float()

#     pl_clip = float("inf")
    pl_clip = torch.tensor(160, dtype=torch.float32)

    sample = RadarSample(
        file_name=file_name,
        task_idx=1,
        pl_clip=pl_clip,
        H=H,
        W=W,
        x_ant=x_ant,
        y_ant=y_ant,
        azimuth=azimuth,
        freq_MHz=freq_MHz,
        input_img=input_img,
        output_img=output_img,
        radiation_pattern=radiation_pattern,
        pixel_size=INITIAL_PIXEL_SIZE,
        mask=torch.ones_like(input_img[0]),
    )

    # Ensure the antenna is within bounds
    sample = PathlossDataset.pad_sample(sample)

    return sample

def get_input(sample):
    orig_h, orig_w = sample.H, sample.W
    
    sparse_input = sample.output_img
    sample.input_img = torch.cat([sample.input_img, sparse_input.unsqueeze(0)], dim=0)

    sample = normalize_size(sample=sample, target_size=IMG_TARGET_SIZE)

    output_tensor = sample.output_img if sample.output_img is not None else None
    
    input_tensor = featurizer(sample=sample)
    mask = sample.mask
    sample.H, sample.W = orig_h, orig_w
    return input_tensor, output_tensor, mask, sample.asdict()

def pred(batch, pred):
    inputs, targets, masks, sample = batch

    old_h, old_w = sample["H"], sample["W"]
    scaling_factor = INITIAL_PIXEL_SIZE / sample["pixel_size"]
    norm_h, norm_w = int(old_h * scaling_factor), int(old_w * scaling_factor)
    
    pred = pred[torch.where(masks == 1)].reshape((norm_h, norm_w))
    pred = resize_db(pred.unsqueeze(0), new_size=(old_h, old_w)).squeeze(0)
    pred = pred.detach().cpu().numpy()

    return pred

In [18]:
Buildings = range(1, 7)
ant_ids = [1]
freq = [1]
freqs_mhz = [868]
solution = pd.DataFrame()

os.makedirs("./preds/task1")
for Antenna_ID in (ant_ids):  
    for f_i in  freq:
        for b in (Buildings):
            for sp in tqdm(range(0, 50), total=50):
                radar_sample_inputs = get_radar_sample_input(b, Antenna_ID, f_i, sp)
                if radar_sample_inputs is None:
                    continue
                sample = read_sample(radar_sample_inputs)
                your_input_tensor, output_tensor, mask, sample = get_input(sample)
                out = alg.network(your_input_tensor.unsqueeze(0).to("cuda:0")).squeeze(0)
                y_PL = pred(
                    (your_input_tensor, output_tensor, mask, sample), out
                ) # Note that y_PL should have the same dimensions, W x H,  as the input image
                y = Image.fromarray(y_PL).convert("RGB")
                image_name = f"B{b}_Ant{Antenna_ID}_f{f_i}_S{sp}.png"
                y.save(f"./preds/task1/{image_name}.png")





100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:03<00:00, 14.59it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 6209.55it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 6230.95it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 6431.21it/s]






100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:03<00:00, 14.03it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 7044.04it/s]


In [17]:
data = []
pred_path = "./preds/task1/"
for file_name in tqdm(os.listdir(pred_path)):
    if file_name.endswith(".png"):
        file_path = os.path.join(pred_path, file_name)
        image = Image.open(file_path).convert("L")
        pl_array = np.array(image)

        flat_pl = pl_array.flatten()
        for idx, value in enumerate(flat_pl):
            id_str = f"{file_name.split('.')[0]}_{idx}"
            data.append((id_str, value))

df = pd.DataFrame(data, columns=["ID", "PL"])
df = df.groupby("ID", as_index=False).mean()
df.to_csv(os.path.join(pred_path, "predictions.csv"), index=False)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 52/52 [00:00<00:00, 196.19it/s]
