In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

### Define the U-Net models (UNetIso and UNetDir)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNetIso(nn.Module):
    """U-Net mapping a multi-channel input image to a single-channel map.

    Architecture:
      * encoder: five ``conv_block`` stages (two 3x3 conv + ReLU each) with
        2x2 max-pooling between stages; stage widths are
        ``base_channels * (1, 2, 4, 8, 16)``;
      * decoder: four stages, each a 2x2 stride-2 transposed conv, then
        concatenation (upsampled first, skip second) with the matching encoder
        feature map, then a ``conv_block``;
      * head: 1x1 conv to ``out_channels``.

    In this notebook the model is fed two channels (antenna path-gain map and
    building map). Its output (``PGmap``) is stacked as the first input channel
    of ``UNetDir``.

    Args:
        in_channels: number of input channels (default 2).
        out_channels: number of output channels (default 1).
        base_channels: width of the first encoder stage (default 64). The
            defaults reproduce the original layer names and shapes, so existing
            ``state_dict`` checkpoints still load unchanged.

    Input shape is ``(N, in_channels, H, W)`` with ``H, W >= 16``, because there
    are four pooling stages. H and W need not be multiples of 16: upsampled
    features are zero-padded to the skip connection's size before concatenation.
    """

    def __init__(self, in_channels=2, out_channels=1, base_channels=64):
        super().__init__()
        # Stage widths; the default base_channels gives 64, 128, 256, 512, 1024,
        # i.e. the widths the original implementation hard-coded.
        c1, c2, c3, c4, c5 = (base_channels * k for k in (1, 2, 4, 8, 16))

        self.enc1 = self.conv_block(in_channels, c1)
        self.enc2 = self.conv_block(c1, c2)
        self.enc3 = self.conv_block(c2, c3)
        self.enc4 = self.conv_block(c3, c4)
        self.enc5 = self.conv_block(c4, c5)

        self.pool = nn.MaxPool2d(2)

        # Each decoder block receives upsampled + skip features concatenated,
        # hence twice the stage width on its input.
        self.up1 = self.upconv(c5, c4)
        self.dec1 = self.conv_block(2 * c4, c4)

        self.up2 = self.upconv(c4, c3)
        self.dec2 = self.conv_block(2 * c3, c3)

        self.up3 = self.upconv(c3, c2)
        self.dec3 = self.conv_block(2 * c2, c2)

        self.up4 = self.upconv(c2, c1)
        self.dec4 = self.conv_block(2 * c1, c1)

        self.out_conv = nn.Conv2d(c1, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 same-padding convolutions, each followed by ReLU."""
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        return block

    def upconv(self, in_channels, out_channels):
        """2x2 stride-2 transposed convolution (doubles H and W)."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    @staticmethod
    def _pad_to(x, ref):
        """Zero-pad ``x`` on its last two dims so they match ``ref``'s last two dims."""
        dy = ref.size(-2) - x.size(-2)
        dx = ref.size(-1) - x.size(-1)
        if dy == 0 and dx == 0:
            return x
        # F.pad order: (left, right, top, bottom) for the last two dims.
        return F.pad(x, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])

    def _decode(self, up, block, x, skip):
        """Upsample ``x``, align it to ``skip``, concatenate on channels, refine."""
        upsampled = self._pad_to(up(x), skip)
        return block(torch.cat((upsampled, skip), dim=1))

    def forward(self, x):
        """Run the U-Net; returns ``(N, out_channels, H, W)`` for input ``(N, C, H, W)``."""
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        enc5 = self.enc5(self.pool(enc4))

        dec1 = self._decode(self.up1, self.dec1, enc5, enc4)
        dec2 = self._decode(self.up2, self.dec2, dec1, enc3)
        dec3 = self._decode(self.up3, self.dec3, dec2, enc2)
        dec4 = self._decode(self.up4, self.dec4, dec3, enc1)

        return self.out_conv(dec4)
    
class UNetDir(nn.Module):
    """U-Net mapping a multi-channel input image to a single-channel map.

    Same architecture as the notebook's other U-Net, differing only in the
    default number of input channels:
      * encoder: five ``conv_block`` stages (two 3x3 conv + ReLU each) with
        2x2 max-pooling between stages; stage widths are
        ``base_channels * (1, 2, 4, 8, 16)``;
      * decoder: four stages, each a 2x2 stride-2 transposed conv, then
        concatenation (upsampled first, skip second) with the matching encoder
        feature map, then a ``conv_block``;
      * head: 1x1 conv to ``out_channels``.

    In this notebook the model is fed three channels: the ``UNetIso`` output
    (``PGmap``), the building map and the sparse signal-strength map. Its output
    is the map plotted as the predicted signal strength.

    Args:
        in_channels: number of input channels (default 3).
        out_channels: number of output channels (default 1).
        base_channels: width of the first encoder stage (default 64). The
            defaults reproduce the original layer names and shapes, so existing
            ``state_dict`` checkpoints still load unchanged.

    Input shape is ``(N, in_channels, H, W)`` with ``H, W >= 16``, because there
    are four pooling stages. H and W need not be multiples of 16: upsampled
    features are zero-padded to the skip connection's size before concatenation.
    """

    def __init__(self, in_channels=3, out_channels=1, base_channels=64):
        super().__init__()
        # Stage widths; the default base_channels gives 64, 128, 256, 512, 1024,
        # i.e. the widths the original implementation hard-coded.
        c1, c2, c3, c4, c5 = (base_channels * k for k in (1, 2, 4, 8, 16))

        self.enc1 = self.conv_block(in_channels, c1)
        self.enc2 = self.conv_block(c1, c2)
        self.enc3 = self.conv_block(c2, c3)
        self.enc4 = self.conv_block(c3, c4)
        self.enc5 = self.conv_block(c4, c5)

        self.pool = nn.MaxPool2d(2)

        # Each decoder block receives upsampled + skip features concatenated,
        # hence twice the stage width on its input.
        self.up1 = self.upconv(c5, c4)
        self.dec1 = self.conv_block(2 * c4, c4)

        self.up2 = self.upconv(c4, c3)
        self.dec2 = self.conv_block(2 * c3, c3)

        self.up3 = self.upconv(c3, c2)
        self.dec3 = self.conv_block(2 * c2, c2)

        self.up4 = self.upconv(c2, c1)
        self.dec4 = self.conv_block(2 * c1, c1)

        self.out_conv = nn.Conv2d(c1, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Two 3x3 same-padding convolutions, each followed by ReLU."""
        block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        return block

    def upconv(self, in_channels, out_channels):
        """2x2 stride-2 transposed convolution (doubles H and W)."""
        return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)

    @staticmethod
    def _pad_to(x, ref):
        """Zero-pad ``x`` on its last two dims so they match ``ref``'s last two dims."""
        dy = ref.size(-2) - x.size(-2)
        dx = ref.size(-1) - x.size(-1)
        if dy == 0 and dx == 0:
            return x
        # F.pad order: (left, right, top, bottom) for the last two dims.
        return F.pad(x, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2])

    def _decode(self, up, block, x, skip):
        """Upsample ``x``, align it to ``skip``, concatenate on channels, refine."""
        upsampled = self._pad_to(up(x), skip)
        return block(torch.cat((upsampled, skip), dim=1))

    def forward(self, x):
        """Run the U-Net; returns ``(N, out_channels, H, W)`` for input ``(N, C, H, W)``."""
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))
        enc5 = self.enc5(self.pool(enc4))

        dec1 = self._decode(self.up1, self.dec1, enc5, enc4)
        dec2 = self._decode(self.up2, self.dec2, dec1, enc3)
        dec3 = self._decode(self.up3, self.dec3, dec2, enc2)
        dec4 = self._decode(self.up4, self.dec4, dec3, enc1)

        return self.out_conv(dec4)

### Run prediction

In [None]:
########################################################################################################
#                                   Change the path below                                              #
########################################################################################################

# Load an input sample
PG_Uma_path = 'Input_data/antenna.npy'     # Your path to antenna path gain map
Building_path = 'Input_data/building.npy'  # Your path to building map
Sparse_SS_path = 'Input_data/Sparse_SSmap/Sparse_SSmap.npy'  # Your path to sparse signal strength map

# Pretrained weights for the two cascaded models
Unet_iso_weight_path = 'Weight/Unet_iso.pth'
Unet_dir_weight_path = 'Weight/UnetDir_geo2sigmap.pth'

########################################################################################################
#                                   Do not modify the code below                                       #
########################################################################################################

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def load_model(model, weight_path):
    """Move ``model`` to ``device``, load weights from ``weight_path``, set eval mode.

    ``map_location=device`` remaps tensors that were saved on a GPU, so the
    notebook also runs on CPU-only machines.
    """
    # NOTE(review): torch.load unpickles the file - only load trusted
    # checkpoints (torch >= 1.13 also accepts weights_only=True).
    model = model.to(device)
    model.load_state_dict(torch.load(weight_path, map_location=device))
    model.eval()  # Set the model to evaluation mode
    return model


def to_batch(*maps):
    """Stack equally-shaped 2-D maps as channels into a (1, C, H, W) tensor on ``device``."""
    return torch.from_numpy(np.stack(maps, axis=0)).unsqueeze(0).to(device)


Unet_iso = load_model(UNetIso(), Unet_iso_weight_path)
Unet_dir = load_model(UNetDir(), Unet_dir_weight_path)

PG_Uma = np.load(PG_Uma_path).astype(np.float32)
Building = np.load(Building_path).astype(np.float32)
Sparse_SS = np.load(Sparse_SS_path).astype(np.float32)

# The maps are stacked as image channels, so each must be 2-D and all must
# share one grid; fail early with a readable message instead of inside np.stack.
for map_name, map_arr in (("PG_Uma", PG_Uma), ("Building", Building), ("Sparse_SS", Sparse_SS)):
    if map_arr.ndim != 2 or map_arr.shape != PG_Uma.shape:
        raise ValueError(
            f"{map_name} must be a 2-D array with shape {PG_Uma.shape}, got {map_arr.shape}"
        )

# Run the cascaded models to get the prediction
with torch.no_grad():
    # Stage 1: UNetIso on (antenna path gain map, building map).
    PGmap = Unet_iso(to_batch(PG_Uma, Building))[0, 0].cpu().numpy()

    # Stage 2: UNetDir on (stage-1 output, building map, sparse signal strength map).
    SSmap = Unet_dir(to_batch(PGmap, Building, Sparse_SS))

# Process the output: drop the batch and channel dims explicitly and move to CPU
SSmap = SSmap[0, 0].cpu().numpy()


### Visualize result

In [None]:
# Explicit Figure/Axes interface; labels let the figure stand on its own.
fig, ax = plt.subplots()
im = ax.imshow(SSmap, cmap='hot', interpolation='bilinear')
fig.colorbar(im, ax=ax, label='Predicted signal strength')
ax.set(title='Predicted Signal Strength Map', xlabel='Column (pixel)', ylabel='Row (pixel)')
plt.show()