## Models

### TransNeXt

In [1]:
"""
TransNeXt: Robust Foveal Visual Perception for Vision Transformers
Paper: https://arxiv.org/abs/2311.17132
Code: https://github.com/DaiShiResearch/TransNeXt

Author: Dai Shi
Github: https://github.com/DaiShiResearch
Email: daishiresearch@gmail.com

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import math
import swattention

CUDA_NUM_THREADS = 128


class sw_qkrpb_cuda(torch.autograd.Function):
    """Autograd bridge to the ``swattention`` query-key (+ bias) kernels.

    ``forward`` delegates to ``swattention.qk_rpb_forward`` and ``backward`` to
    ``swattention.qk_rpb_backward``; both are launched with ``CUDA_NUM_THREADS``.
    NOTE(review): judging by the names, the kernels presumably compute
    sliding-window query-key similarities plus a relative position bias --
    confirm against the swattention extension.
    """

    @staticmethod
    def forward(ctx, query, key, rpb, height, width, kernel_size):
        """Run the forward kernel and remember what ``backward`` will need."""
        scores = swattention.qk_rpb_forward(
            query, key, rpb, height, width, kernel_size, CUDA_NUM_THREADS
        )

        # Only the two tensors are saved through autograd; the geometry values
        # are plain Python attributes on the context.
        ctx.save_for_backward(query, key)
        ctx.height = height
        ctx.width = width
        ctx.kernel_size = kernel_size

        return scores

    @staticmethod
    def backward(ctx, d_attn_weight):
        """Return gradients for (query, key, rpb); the three size arguments get None."""
        saved_query, saved_key = ctx.saved_tensors

        # The upstream gradient is made contiguous before being handed to the kernel.
        grad_query, grad_key, grad_rpb = swattention.qk_rpb_backward(
            d_attn_weight.contiguous(),
            saved_query,
            saved_key,
            ctx.height,
            ctx.width,
            ctx.kernel_size,
            CUDA_NUM_THREADS,
        )

        return grad_query, grad_key, grad_rpb, None, None, None


class sw_av_cuda(torch.autograd.Function):
    """Autograd bridge to the ``swattention`` attention-value kernels.

    ``forward`` delegates to ``swattention.av_forward`` and ``backward`` to
    ``swattention.av_backward``; both are launched with ``CUDA_NUM_THREADS``.
    NOTE(review): presumably aggregates values inside a sliding window using
    the given attention weights -- confirm against the swattention extension.
    """

    @staticmethod
    def forward(ctx, attn_weight, value, height, width, kernel_size):
        """Run the forward kernel and remember what ``backward`` will need."""
        aggregated = swattention.av_forward(
            attn_weight, value, height, width, kernel_size, CUDA_NUM_THREADS
        )

        # Tensors go through save_for_backward; geometry is kept as attributes.
        ctx.save_for_backward(attn_weight, value)
        ctx.height = height
        ctx.width = width
        ctx.kernel_size = kernel_size

        return aggregated

    @staticmethod
    def backward(ctx, d_output):
        """Return gradients for (attn_weight, value); the size arguments get None."""
        saved_weight, saved_value = ctx.saved_tensors

        # The upstream gradient is made contiguous before being handed to the kernel.
        grad_weight, grad_value = swattention.av_backward(
            d_output.contiguous(),
            saved_weight,
            saved_value,
            ctx.height,
            ctx.width,
            ctx.kernel_size,
            CUDA_NUM_THREADS,
        )

        return grad_weight, grad_value, None, None, None


class DWConv(nn.Module):
    """3x3 depth-wise convolution applied to a flattened token sequence.

    Args:
        dim: number of channels; the convolution uses one group per channel.
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(
            in_channels=dim,
            out_channels=dim,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=dim,
            bias=True,
        )

    def forward(self, x, H, W):
        """Convolve tokens ``x`` of shape (B, N, C) as a (B, C, H, W) map.

        Returns a tensor of shape (B, N, C); stride 1 with padding 1 keeps the
        spatial size unchanged.
        """
        batch, _, channels = x.shape
        # Channels first, then unfold the token axis into the H x W grid.
        feature_map = x.transpose(1, 2).view(batch, channels, H, W).contiguous()
        feature_map = self.dwconv(feature_map)
        # Back to token layout: (B, C, H*W) -> (B, H*W, C).
        return feature_map.flatten(2).transpose(1, 2)


class ConvolutionalGLU(nn.Module):
    """Gated feed-forward block whose gate passes through a depth-wise conv.

    ``fc1`` produces two halves; the first goes through ``DWConv`` and the
    activation and then multiplies the second, after which ``fc2`` projects to
    ``out_features``.

    Args:
        in_features: input channel count.
        hidden_features: nominal hidden width (defaults to ``in_features``);
            the width actually used per half is ``int(2 * hidden_features / 3)``.
        out_features: output channel count (defaults to ``in_features``).
        act_layer: activation class, instantiated without arguments.
        drop: dropout probability, applied after the gating and after ``fc2``.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        # Two thirds of the nominal width for each of the two halves.
        gated_width = int(2 * hidden_features / 3)

        self.fc1 = nn.Linear(in_features, gated_width * 2)
        self.dwconv = DWConv(gated_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(gated_width, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x, H, W):
        """Apply the gated MLP to tokens ``x`` of shape (B, N, C) laid out on an H x W grid."""
        gate, value = self.fc1(x).chunk(2, dim=-1)
        gated = self.act(self.dwconv(gate, H, W)) * value
        projected = self.fc2(self.drop(gated))
        return self.drop(projected)


@torch.no_grad()
def get_relative_position_cpb(
    query_size,
    key_size,
    pretrain_size=None,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
):
    """Build the lookup structures for the continuous relative position bias.

    Query coordinates are the integer grid positions; key coordinates are the
    query coordinates adaptively average-pooled down to ``key_size``. Every
    query/key pair gets a relative offset, scaled by ``8 / (pretrain_size - 1)``
    and then compressed with ``sign(x) * log2(|x| + 1) / log2(8)``.

    Args:
        query_size: ``(H, W)`` of the query grid.
        key_size: ``(H, W)`` of the key grid.
        pretrain_size: ``(H, W)`` used to normalise offsets; defaults to
            ``query_size``.
        device: device on which the tensors are created. The default is
            evaluated once, when the function is defined.

    Returns:
        ``(idx_map, relative_coords_table)`` where ``relative_coords_table`` has
        shape ``(U, 2)`` holding the unique compressed offsets and ``idx_map``
        is a 1-D tensor of length ``Hq*Wq*Hk*Wk`` mapping each query/key pair
        (query-major order) to a row of the table.
    """
    pretrain_size = pretrain_size or query_size
    axis_qh = torch.arange(query_size[0], dtype=torch.float32, device=device)
    axis_kh = F.adaptive_avg_pool1d(axis_qh.unsqueeze(0), key_size[0]).squeeze(0)
    axis_qw = torch.arange(query_size[1], dtype=torch.float32, device=device)
    axis_kw = F.adaptive_avg_pool1d(axis_qw.unsqueeze(0), key_size[1]).squeeze(0)

    # Pass indexing="ij" explicitly: it is the historical default, and omitting
    # it triggers a deprecation warning on current PyTorch releases.
    axis_kh, axis_kw = torch.meshgrid(axis_kh, axis_kw, indexing="ij")
    axis_qh, axis_qw = torch.meshgrid(axis_qh, axis_qw, indexing="ij")

    axis_kh = torch.reshape(axis_kh, [-1])
    axis_kw = torch.reshape(axis_kw, [-1])
    axis_qh = torch.reshape(axis_qh, [-1])
    axis_qw = torch.reshape(axis_qw, [-1])

    # A 1-pixel pretrain grid would make the denominator zero and fill the
    # table with NaN/inf; clamp it so degenerate stages stay finite.
    span_h = max(pretrain_size[0] - 1, 1)
    span_w = max(pretrain_size[1] - 1, 1)

    relative_h = (axis_qh[:, None] - axis_kh[None, :]) / span_h * 8
    relative_w = (axis_qw[:, None] - axis_kw[None, :]) / span_w * 8
    relative_hw = torch.stack([relative_h, relative_w], dim=-1).view(-1, 2)

    # Deduplicate offsets so the bias MLP only has to be evaluated once per
    # distinct offset; idx_map scatters the results back to all pairs.
    relative_coords_table, idx_map = torch.unique(
        relative_hw, return_inverse=True, dim=0
    )

    relative_coords_table = (
        torch.sign(relative_coords_table)
        * torch.log2(torch.abs(relative_coords_table) + 1.0)
        / torch.log2(torch.tensor(8, dtype=torch.float32))
    )

    return idx_map, relative_coords_table


@torch.no_grad()
def get_seqlen_scale(input_resolution, window_size, device):
    """Count, per position, how many window cells fall inside the feature map.

    A map filled with ``window_size**2`` is average-pooled with a
    ``window_size`` kernel, stride 1 and zero padding of ``window_size // 2``;
    padded cells contribute zero, so border positions receive smaller values.

    Returns:
        Tensor of shape ``(H * W, 1)`` created on ``device``.
    """
    height, width = input_resolution[0], input_resolution[1]
    window_area = window_size**2
    filled_map = torch.ones(1, height, width, device=device) * window_area
    valid_counts = F.avg_pool2d(
        filled_map,
        kernel_size=window_size,
        stride=1,
        padding=window_size // 2,
    )
    return valid_counts.reshape(-1, 1)


class AggregatedAttention(nn.Module):
    """Attention over a local window and a pooled copy of the feature map at once.

    For every query the module builds two sets of logits -- ``window_size**2``
    local ones produced by the ``sw_qkrpb_cuda`` kernel and
    ``(H // sr_ratio) * (W // sr_ratio)`` pooled ones -- concatenates them and
    normalises them with a single softmax. The resulting weights are split
    again and applied to local and pooled values separately; the two results
    are summed and projected.

    Args:
        dim: channel width of the tokens; must be divisible by ``num_heads``.
        input_resolution: ``(H, W)`` the precomputed position-bias tables were
            built for. Only read when ``is_extrapolation`` is False.
        num_heads: number of attention heads.
        window_size: odd side length of the local window.
        qkv_bias: whether the ``q`` and ``kv`` projections have a bias.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        sr_ratio: spatial reduction factor of the pooled branch.
        is_extrapolation: if True the bias tables passed to ``forward`` already
            match the input size; otherwise the generated bias is bilinearly
            interpolated from ``input_resolution`` to the actual size.
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads=8,
        window_size=3,
        qkv_bias=True,
        attn_drop=0.0,
        proj_drop=0.0,
        sr_ratio=1,
        is_extrapolation=False,
    ):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads

        self.sr_ratio = sr_ratio

        self.is_extrapolation = is_extrapolation

        if not is_extrapolation:
            # The estimated training resolution is used for bilinear interpolation of the generated relative position bias.
            # These attributes are deliberately absent in extrapolation mode.
            self.trained_H, self.trained_W = input_resolution
            self.trained_len = self.trained_H * self.trained_W
            self.trained_pool_H, self.trained_pool_W = (
                input_resolution[0] // self.sr_ratio,
                input_resolution[1] // self.sr_ratio,
            )
            self.trained_pool_len = self.trained_pool_H * self.trained_pool_W

        assert window_size % 2 == 1, "window size must be odd"
        self.window_size = window_size
        self.local_len = window_size**2

        # One learnable temperature per head, stored in inverse-softplus form.
        self.temperature = nn.Parameter(
            torch.log((torch.ones(num_heads, 1, 1) / 0.24).exp() - 1)
        )  # Initialize softplus(temperature) to 1/0.24.

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        # Per-head offset added to the L2-normalised queries.
        self.query_embedding = nn.Parameter(
            nn.init.trunc_normal_(
                torch.empty(self.num_heads, 1, self.head_dim), mean=0, std=0.02
            )
        )
        # A single key/value projection is shared by the local and pooled branches.
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Pooled branch: 1x1 conv -> GELU -> adaptive average pool -> LayerNorm (see forward).
        self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1, padding=0)
        self.norm = nn.LayerNorm(dim)
        self.act = nn.GELU()

        # mlp to generate continuous relative position bias
        # (maps a 2-D relative coordinate to one bias value per head)
        self.cpb_fc1 = nn.Linear(2, 512, bias=True)
        self.cpb_act = nn.ReLU(inplace=True)
        self.cpb_fc2 = nn.Linear(512, num_heads, bias=True)

        # relative_bias_local:
        # one learnable bias per head and per window position, handed to the local kernel.
        self.relative_pos_bias_local = nn.Parameter(
            nn.init.trunc_normal_(
                torch.empty(num_heads, self.local_len), mean=0, std=0.0004
            )
        )

        # dynamic_local_bias:
        # query-dependent term added to the local attention weights after the softmax.
        self.learnable_tokens = nn.Parameter(
            nn.init.trunc_normal_(
                torch.empty(num_heads, self.head_dim, self.local_len), mean=0, std=0.02
            )
        )
        self.learnable_bias = nn.Parameter(torch.zeros(num_heads, 1, self.local_len))

    def forward(
        self, x, H, W, relative_pos_index, relative_coords_table, seq_length_scale
    ):
        """Apply aggregated attention to tokens ``x`` of shape (B, N, C) with N = H * W.

        ``relative_coords_table`` is fed to the bias MLP and
        ``relative_pos_index`` gathers the per-pair bias from its output;
        ``seq_length_scale`` multiplies the queries. Returns a (B, N, C) tensor.
        """
        B, N, C = x.shape
        pool_H, pool_W = H // self.sr_ratio, W // self.sr_ratio
        pool_len = pool_H * pool_W

        # Generate queries, normalize them with L2, add query embedding, and then magnify with sequence length scale and temperature.
        # Use softplus function ensuring that the temperature is not lower than 0.
        q_norm = F.normalize(
            self.q(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3),
            dim=-1,
        )
        q_norm_scaled = (
            (q_norm + self.query_embedding)
            * F.softplus(self.temperature)
            * seq_length_scale
        )

        # Generate per-token keys and values for the local branch, laid out as
        # (B, heads, N, head_dim). Keys are L2-normalised below, right before
        # the kernel call; values are used as-is.
        k_local, v_local = (
            self.kv(x)
            .reshape(B, N, 2 * self.num_heads, self.head_dim)
            .permute(0, 2, 1, 3)
            .chunk(2, dim=1)
        )

        # Compute local similarity
        # (the last dimension of the result holds local_len entries, see the split below)
        attn_local = sw_qkrpb_cuda.apply(
            q_norm_scaled.contiguous(),
            F.normalize(k_local, dim=-1).contiguous(),
            self.relative_pos_bias_local,
            H,
            W,
            self.window_size,
        )

        # Generate pooled features
        x_ = x.permute(0, 2, 1).reshape(B, -1, H, W).contiguous()
        x_ = (
            F.adaptive_avg_pool2d(self.act(self.sr(x_)), (pool_H, pool_W))
            .reshape(B, -1, pool_len)
            .permute(0, 2, 1)
        )
        x_ = self.norm(x_)

        # Generate pooled keys and values
        # (same self.kv projection as the local branch)
        kv_pool = (
            self.kv(x_)
            .reshape(B, pool_len, 2 * self.num_heads, self.head_dim)
            .permute(0, 2, 1, 3)
        )
        k_pool, v_pool = kv_pool.chunk(2, dim=1)

        if self.is_extrapolation:
            # Use MLP to generate continuous relative positional bias for pooled features.
            # The tables already match the current input, so the bias is used directly.
            pool_bias = (
                self.cpb_fc2(self.cpb_act(self.cpb_fc1(relative_coords_table)))
                .transpose(0, 1)[:, relative_pos_index.view(-1)]
                .view(-1, N, pool_len)
            )
        else:
            # Use MLP to generate continuous relative positional bias for pooled features.
            # The tables describe the stored training resolution, so the bias
            # is generated at that size first and resized afterwards.
            pool_bias = (
                self.cpb_fc2(self.cpb_act(self.cpb_fc1(relative_coords_table)))
                .transpose(0, 1)[:, relative_pos_index.view(-1)]
                .view(-1, self.trained_len, self.trained_pool_len)
            )

            # bilinear interpolation:
            # first pass resizes the key (pooled) axis to (pool_H, pool_W) ...
            pool_bias = pool_bias.reshape(
                -1, self.trained_len, self.trained_pool_H, self.trained_pool_W
            )
            pool_bias = F.interpolate(pool_bias, (pool_H, pool_W), mode="bilinear")
            # ... then query and key axes are swapped so the query axis can be
            # treated as the spatial map ...
            pool_bias = (
                pool_bias.reshape(-1, self.trained_len, pool_len)
                .transpose(-1, -2)
                .reshape(-1, pool_len, self.trained_H, self.trained_W)
            )
            # ... and the second pass resizes the query axis to (H, W) before
            # swapping back to (heads, N, pool_len).
            pool_bias = (
                F.interpolate(pool_bias, (H, W), mode="bilinear")
                .reshape(-1, pool_len, N)
                .transpose(-1, -2)
            )

        # Compute pooled similarity
        attn_pool = (
            q_norm_scaled @ F.normalize(k_pool, dim=-1).transpose(-2, -1) + pool_bias
        )

        # Concatenate local & pooled similarity matrices and calculate attention weights through the same Softmax
        attn = torch.cat([attn_local, attn_pool], dim=-1).softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Split the attention weights and separately aggregate the values of local & pooled features
        attn_local, attn_pool = torch.split(attn, [self.local_len, pool_len], dim=-1)
        # The query-dependent local term is added after the softmax, so the
        # local weights are not re-normalised. It uses the unscaled q_norm.
        attn_local = (q_norm @ self.learnable_tokens) + self.learnable_bias + attn_local
        x_local = sw_av_cuda.apply(
            attn_local.type_as(v_local), v_local.contiguous(), H, W, self.window_size
        )
        x_pool = attn_pool @ v_pool
        # Sum both branches and merge heads back into the channel dimension.
        x = (x_local + x_pool).transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Attention(nn.Module):
    """Global multi-head attention with normalised queries/keys and a generated position bias.

    Queries and keys are L2-normalised; queries additionally receive a
    per-head embedding and are scaled by ``softplus(temperature)`` and the
    supplied sequence-length scale. A small MLP turns relative coordinates
    into a per-head additive bias.

    Args:
        dim: channel width of the tokens; must be divisible by ``num_heads``.
        input_resolution: ``(H, W)`` the precomputed bias tables were built
            for. Only read when ``is_extrapolation`` is False.
        num_heads: number of attention heads.
        qkv_bias: whether the fused qkv projection has a bias.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
        is_extrapolation: if True the bias tables passed to ``forward`` already
            match the input size; otherwise the generated bias is bilinearly
            interpolated from ``input_resolution`` to the actual size.
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads=8,
        qkv_bias=True,
        attn_drop=0.0,
        proj_drop=0.0,
        is_extrapolation=False,
    ):
        super().__init__()
        assert (
            dim % num_heads == 0
        ), f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.is_extrapolation = is_extrapolation

        if not is_extrapolation:
            # Resolution the precomputed bias tables correspond to; the bias is
            # resized from here to the real input size in forward().
            self.trained_H, self.trained_W = input_resolution
            self.trained_len = self.trained_H * self.trained_W

        # Stored in inverse-softplus form so that softplus(temperature) starts at 1/0.24.
        initial_scale = torch.ones(num_heads, 1, 1) / 0.24
        self.temperature = nn.Parameter(torch.log(initial_scale.exp() - 1))

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        query_offset = torch.empty(self.num_heads, 1, self.head_dim)
        nn.init.trunc_normal_(query_offset, mean=0, std=0.02)
        self.query_embedding = nn.Parameter(query_offset)

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # MLP mapping a 2-D relative coordinate to one bias value per head.
        self.cpb_fc1 = nn.Linear(2, 512, bias=True)
        self.cpb_act = nn.ReLU(inplace=True)
        self.cpb_fc2 = nn.Linear(512, num_heads, bias=True)

    def forward(
        self, x, H, W, relative_pos_index, relative_coords_table, seq_length_scale
    ):
        """Attend over all tokens of ``x`` (B, N, C) with N = H * W; returns (B, N, C)."""
        B, N, C = x.shape
        q, k, v = (
            self.qkv(x)
            .reshape(B, -1, 3 * self.num_heads, self.head_dim)
            .permute(0, 2, 1, 3)
            .chunk(3, dim=1)
        )

        # The bias MLP and the gather are the same in both modes; only the
        # reshaping afterwards differs.
        bias_per_offset = self.cpb_fc2(
            self.cpb_act(self.cpb_fc1(relative_coords_table))
        )
        bias_per_pair = bias_per_offset.transpose(0, 1)[:, relative_pos_index.view(-1)]

        if self.is_extrapolation:
            rel_bias = bias_per_pair.view(-1, N, N)
        else:
            t_h, t_w, t_len = self.trained_H, self.trained_W, self.trained_len
            rel_bias = bias_per_pair.view(-1, t_len, t_len)

            # First pass: resize the key axis from the trained grid to (H, W).
            rel_bias = rel_bias.reshape(-1, t_len, t_h, t_w)
            rel_bias = F.interpolate(rel_bias, (H, W), mode="bilinear")

            # Swap query/key axes so the query axis becomes the spatial map.
            rel_bias = rel_bias.reshape(-1, t_len, N).transpose(-1, -2)
            rel_bias = rel_bias.reshape(-1, N, t_h, t_w)

            # Second pass: resize the query axis, then swap back.
            rel_bias = F.interpolate(rel_bias, (H, W), mode="bilinear")
            rel_bias = rel_bias.reshape(-1, N, N).transpose(-1, -2)

        scaled_q = (
            (F.normalize(q, dim=-1) + self.query_embedding)
            * F.softplus(self.temperature)
            * seq_length_scale
        )
        unit_k = F.normalize(k, dim=-1)
        logits = scaled_q @ unit_k.transpose(-2, -1) + rel_bias

        weights = self.attn_drop(logits.softmax(dim=-1))
        merged = (weights @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(merged))


class Block(nn.Module):
    """Pre-norm transformer block: attention then ConvolutionalGLU, both residual.

    ``Attention`` is used when ``sr_ratio == 1``; any other ``sr_ratio``
    selects ``AggregatedAttention`` (which additionally receives
    ``window_size`` and ``sr_ratio``). The same stochastic-depth module wraps
    both residual branches.
    """

    def __init__(
        self,
        dim,
        num_heads,
        input_resolution,
        window_size=3,
        mlp_ratio=4.0,
        qkv_bias=False,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        sr_ratio=1,
        is_extrapolation=False,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)

        # Options common to both attention flavours.
        common_attn_options = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            is_extrapolation=is_extrapolation,
        )
        if sr_ratio == 1:
            self.attn = Attention(dim, input_resolution, **common_attn_options)
        else:
            self.attn = AggregatedAttention(
                dim,
                input_resolution,
                window_size=window_size,
                sr_ratio=sr_ratio,
                **common_attn_options,
            )

        self.norm2 = norm_layer(dim)
        self.mlp = ConvolutionalGLU(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        # Stochastic depth; a no-op module when the rate is zero.
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(
        self, x, H, W, relative_pos_index, relative_coords_table, seq_length_scale
    ):
        """Run both residual branches on tokens ``x``; the position arguments go to the attention."""
        attn_out = self.attn(
            self.norm1(x),
            H,
            W,
            relative_pos_index,
            relative_coords_table,
            seq_length_scale,
        )
        x = x + self.drop_path(attn_out)

        mlp_out = self.mlp(self.norm2(x), H, W)
        return x + self.drop_path(mlp_out)


class OverlapPatchEmbed(nn.Module):
    """Image to patch embedding with overlapping patches.

    A strided convolution whose kernel is larger than its stride, followed by
    a LayerNorm over the embedding dimension.
    """

    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()

        kernel = to_2tuple(patch_size)
        assert max(kernel) > stride, "Set larger patch_size than stride"

        # Padding of half the kernel on each side.
        half_kernel = (kernel[0] // 2, kernel[1] // 2)
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=kernel,
            stride=stride,
            padding=half_kernel,
        )
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        """Embed a (B, C, H, W) input; returns ``(tokens, H_out, W_out)`` with tokens of shape (B, H_out*W_out, embed_dim)."""
        feature_map = self.proj(x)
        H, W = feature_map.shape[2:]
        tokens = self.norm(feature_map.flatten(2).transpose(1, 2))
        return tokens, H, W


class TransNeXt(nn.Module):
    """
    Hierarchical TransNeXt backbone that returns one feature map per stage.

    The parameter "img size" is primarily utilized for generating relative spatial coordinates,
    which are used to compute continuous relative positional biases. As this TransNeXt implementation can accept multi-scale inputs,
    it is recommended to set the "img size" parameter to a value close to the resolution of the inference images.
    It is not advisable to set the "img size" parameter to a value exceeding 800x800.
    The "pretrain size" refers to the "img size" used during the initial pre-training phase,
    which is used to scale the relative spatial coordinates for better extrapolation by the MLP.
    For models trained on ImageNet-1K at a resolution of 224x224,
    as well as downstream task models fine-tuned based on these pre-trained weights,
    the "pretrain size" parameter should be set to 224x224.

    Args:
        img_size: integer side length used to precompute the position tables.
        pretrain_size: integer side length used to normalise relative
            coordinates; falls back to ``img_size`` when None.
        window_size: per-stage local window size (unused for stages that use
            global attention).
        patch_size: stride of the first patch embedding; its kernel is
            ``2 * patch_size - 1``. Later stages use kernel 3, stride 2.
        in_chans: number of input image channels.
        num_classes: kept for interface compatibility; no head is built.
        embed_dims, num_heads, mlp_ratios, depths, sr_ratios: per-stage settings.
        qkv_bias, drop_rate, attn_drop_rate, norm_layer: forwarded to every block.
        drop_path_rate: maximum stochastic-depth rate; rates grow linearly
            from 0 over all blocks.
        num_stages: number of stages to build.
        pretrained: must be falsy; checkpoint loading is not implemented here.
        is_extrapolation: if True the position tables are recomputed from the
            actual input size on every forward pass instead of being
            precomputed as buffers and interpolated.

    NOTE(review): the precomputed tables assume stage ``i`` downsamples by
    ``2 ** (i + 2)``, which matches ``patch_size=4`` -- confirm before using a
    different ``patch_size``.
    """

    def __init__(
        self,
        img_size=224,
        pretrain_size=None,
        window_size=[3, 3, 3, None],
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dims=[64, 128, 256, 512],
        num_heads=[1, 2, 4, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=False,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
        num_stages=4,
        pretrained=None,
        is_extrapolation=False,
    ):
        super().__init__()
        if pretrained:
            # Fail before building the model: checkpoint loading was not ported.
            raise NotImplementedError(
                "loading pretrained weights is not implemented"
            )

        self.depths = depths
        self.num_stages = num_stages
        self.window_size = window_size
        self.sr_ratios = sr_ratios
        # Kept so reset_classifier() can size a head from the last stage width.
        self.embed_dims = embed_dims
        self.is_extrapolation = is_extrapolation
        self.pretrain_size = pretrain_size or img_size

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule
        cur = 0

        for i in range(num_stages):
            downsample = 2 ** (i + 2)
            stage_resolution = to_2tuple(img_size // downsample)

            if not self.is_extrapolation:
                # Use self.pretrain_size (which already falls back to
                # img_size); the raw argument may be None and would raise a
                # TypeError on the floor division.
                relative_pos_index, relative_coords_table = get_relative_position_cpb(
                    query_size=stage_resolution,
                    key_size=to_2tuple(img_size // (downsample * sr_ratios[i])),
                    pretrain_size=to_2tuple(self.pretrain_size // downsample),
                )

                # Non-persistent: the tables are derived data and are rebuilt
                # from the constructor arguments rather than stored in checkpoints.
                self.register_buffer(
                    f"relative_pos_index{i + 1}", relative_pos_index, persistent=False
                )
                self.register_buffer(
                    f"relative_coords_table{i + 1}",
                    relative_coords_table,
                    persistent=False,
                )

            patch_embed = OverlapPatchEmbed(
                patch_size=patch_size * 2 - 1 if i == 0 else 3,
                stride=patch_size if i == 0 else 2,
                in_chans=in_chans if i == 0 else embed_dims[i - 1],
                embed_dim=embed_dims[i],
            )

            block = nn.ModuleList(
                [
                    Block(
                        dim=embed_dims[i],
                        input_resolution=stage_resolution,
                        window_size=window_size[i],
                        num_heads=num_heads[i],
                        mlp_ratio=mlp_ratios[i],
                        qkv_bias=qkv_bias,
                        drop=drop_rate,
                        attn_drop=attn_drop_rate,
                        drop_path=dpr[cur + j],
                        norm_layer=norm_layer,
                        sr_ratio=sr_ratios[i],
                        is_extrapolation=is_extrapolation,
                    )
                    for j in range(depths[i])
                ]
            )
            norm = norm_layer(embed_dims[i])
            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

        # No classification head is built: forward() returns the per-stage
        # feature maps.

        for n, m in self.named_modules():
            self._init_weights(m, n)

    def _init_weights(self, m: nn.Module, name: str = ""):
        """Initialise Linear, Conv2d and normalisation layers; other modules are left untouched."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Conv2d):
            # Normal init with variance 2 / fan_out, fan_out taken per group.
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.zeros_(m.bias)
            nn.init.ones_(m.weight)

    @torch.jit.ignore
    def no_weight_decay(self):
        """No parameter is excluded from weight decay by exact name."""
        return {}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Substrings of parameter names that should be excluded from weight decay."""
        return {"query_embedding", "relative_pos_bias_local", "cpb", "temperature"}

    def get_classifier(self):
        """Return the head set by ``reset_classifier``, or None if none was set."""
        return getattr(self, "head", None)

    def reset_classifier(self, num_classes, global_pool=""):
        """Attach a linear head sized to the last stage (Identity if ``num_classes`` is 0).

        The head is only stored; ``forward`` does not apply it.
        ``global_pool`` is accepted for interface compatibility and ignored.
        """
        self.num_classes = num_classes
        final_dim = self.embed_dims[self.num_stages - 1]
        self.head = (
            nn.Linear(final_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

    def forward_features(self, x):
        """Run all stages on an image batch and return a list of (B, C_i, H_i, W_i) feature maps."""
        B = x.shape[0]
        outs = []

        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            block = getattr(self, f"block{i + 1}")
            norm = getattr(self, f"norm{i + 1}")
            x, H, W = patch_embed(x)
            sr_ratio = self.sr_ratios[i]
            if self.is_extrapolation:
                # Tables are rebuilt for the actual feature-map size.
                relative_pos_index, relative_coords_table = get_relative_position_cpb(
                    query_size=(H, W),
                    key_size=(H // sr_ratio, W // sr_ratio),
                    pretrain_size=to_2tuple(self.pretrain_size // (2 ** (i + 2))),
                    device=x.device,
                )
            else:
                relative_pos_index = getattr(self, f"relative_pos_index{i + 1}")
                relative_coords_table = getattr(self, f"relative_coords_table{i + 1}")

            with torch.no_grad():
                pooled_len = (H // sr_ratio) * (W // sr_ratio)
                if i != (self.num_stages - 1):
                    # Per-position count of valid local-window cells plus the
                    # number of pooled tokens.
                    local_seq_length = get_seqlen_scale(
                        (H, W), self.window_size[i], device=x.device
                    )
                    seq_length_scale = torch.log(local_seq_length + pooled_len)
                else:
                    # The last stage has no local-window term.
                    seq_length_scale = torch.log(
                        torch.as_tensor(pooled_len, device=x.device)
                    )

            for blk in block:
                x = blk(
                    x, H, W, relative_pos_index, relative_coords_table, seq_length_scale
                )

            x = norm(x)
            # Tokens back to an image-like layout; this is also the next stage's input.
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
            outs.append(x)

        return outs

    def forward(self, x):
        """Return the list of per-stage feature maps (no classification head is applied)."""
        x = self.forward_features(x)

        return x


class transnext_tiny(TransNeXt):
    """TransNeXt "tiny" preset: widths 72-576, depths [2, 2, 15, 2], drop path 0.4.

    Recognised keyword arguments (all optional): ``pretrained``, ``img_size``,
    ``pretrain_size``, ``is_extrapolation`` and ``in_chans``. Missing ones fall
    back to the ``TransNeXt`` defaults instead of raising ``KeyError``; any
    other keyword is ignored, as before.
    """

    def __init__(self, **kwargs):
        img_size = kwargs.get("img_size", 224)
        super().__init__(
            window_size=[3, 3, 3, None],
            patch_size=4,
            embed_dims=[72, 144, 288, 576],
            num_heads=[3, 6, 12, 24],
            mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            depths=[2, 2, 15, 2],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
            drop_path_rate=0.4,
            pretrained=kwargs.get("pretrained"),
            img_size=img_size,
            # Resolve the img_size fallback here so the base class never sees None.
            pretrain_size=kwargs.get("pretrain_size") or img_size,
            is_extrapolation=kwargs.get("is_extrapolation", False),
            in_chans=kwargs.get("in_chans", 3),
        )


class transnext_small(TransNeXt):
    """TransNeXt "small" preset: widths 72-576, depths [5, 5, 22, 5], drop path 0.6.

    Recognised keyword arguments (all optional): ``pretrained``, ``img_size``,
    ``pretrain_size``, ``is_extrapolation`` and ``in_chans``. Missing ones fall
    back to the ``TransNeXt`` defaults instead of raising ``KeyError``; any
    other keyword is ignored, as before.
    """

    def __init__(self, **kwargs):
        img_size = kwargs.get("img_size", 224)
        super().__init__(
            window_size=[3, 3, 3, None],
            patch_size=4,
            embed_dims=[72, 144, 288, 576],
            num_heads=[3, 6, 12, 24],
            mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            depths=[5, 5, 22, 5],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
            drop_path_rate=0.6,
            pretrained=kwargs.get("pretrained"),
            img_size=img_size,
            # Resolve the img_size fallback here so the base class never sees None.
            pretrain_size=kwargs.get("pretrain_size") or img_size,
            is_extrapolation=kwargs.get("is_extrapolation", False),
            in_chans=kwargs.get("in_chans", 3),
        )


class transnext_base(TransNeXt):
    """TransNeXt "base" preset: widths 96-768, depths [5, 5, 23, 5], drop path 0.7.

    Recognised keyword arguments (all optional): ``pretrained``, ``img_size``,
    ``pretrain_size``, ``is_extrapolation`` and ``in_chans``. Missing ones fall
    back to the ``TransNeXt`` defaults instead of raising ``KeyError``; any
    other keyword is ignored, as before.
    """

    def __init__(self, **kwargs):
        img_size = kwargs.get("img_size", 224)
        super().__init__(
            window_size=[3, 3, 3, None],
            patch_size=4,
            embed_dims=[96, 192, 384, 768],
            num_heads=[4, 8, 16, 32],
            mlp_ratios=[8, 8, 4, 4],
            qkv_bias=True,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),
            depths=[5, 5, 23, 5],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
            drop_path_rate=0.7,
            pretrained=kwargs.get("pretrained"),
            img_size=img_size,
            # Resolve the img_size fallback here so the base class never sees None.
            pretrain_size=kwargs.get("pretrain_size") or img_size,
            is_extrapolation=kwargs.get("is_extrapolation", False),
            in_chans=kwargs.get("in_chans", 3),
        )

  from .autonotebook import tqdm as notebook_tqdm


### EMCAD

In [2]:
# https://github.com/SLDGroup/EMCAD

import torch
import torch.nn as nn
from functools import partial

import math
from timm.models.layers import trunc_normal_tf_
from timm.models.helpers import named_apply


def gcd(a, b):
    """Greatest common divisor of ``a`` and ``b`` via the Euclidean algorithm.

    Returns ``a`` unchanged when ``b`` is zero. The sign of the result follows
    Python's ``%`` semantics, so it can be negative for negative inputs.
    """
    while b != 0:
        remainder = a % b
        a = b
        b = remainder
    return a


# Other types of layers can go here (e.g., nn.Linear, etc.)
def _init_weights(module, name, scheme=""):
    if isinstance(module, nn.Conv2d) or isinstance(module, nn.Conv3d):
        if scheme == "normal":
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif scheme == "trunc_normal":
            trunc_normal_tf_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif scheme == "xavier_normal":
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif scheme == "kaiming_normal":
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        else:
            # efficientnet like
            fan_out = (
                module.kernel_size[0] * module.kernel_size[1] * module.out_channels
            )
            fan_out //= module.groups
            nn.init.normal_(module.weight, 0, math.sqrt(2.0 / fan_out))
            if module.bias is not None:
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm3d):
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.LayerNorm):
        nn.init.constant_(module.weight, 1)
        nn.init.constant_(module.bias, 0)


def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
    """Build an activation module from its (case-insensitive) name.

    Args:
        act: one of ``relu``, ``relu6``, ``leakyrelu``, ``prelu``, ``gelu``,
            ``hswish``.
        inplace: forwarded to the activations that accept it (not PReLU/GELU).
        neg_slope: negative slope for LeakyReLU and initial value for PReLU.
        n_prelu: number of learnable PReLU parameters.

    Raises:
        NotImplementedError: if the name is not one of the supported ones.
    """
    key = act.lower()
    # Factories are lazy so that only the requested module is constructed.
    factories = {
        "relu": lambda: nn.ReLU(inplace),
        "relu6": lambda: nn.ReLU6(inplace),
        "leakyrelu": lambda: nn.LeakyReLU(neg_slope, inplace),
        "prelu": lambda: nn.PReLU(num_parameters=n_prelu, init=neg_slope),
        "gelu": lambda: nn.GELU(),
        "hswish": lambda: nn.Hardswish(inplace),
    }
    if key not in factories:
        raise NotImplementedError("activation layer [%s] is not found" % key)
    return factories[key]()


def channel_shuffle(x, groups):
    """Interleave the channels of a (B, C, H, W) tensor across ``groups`` groups.

    Channels are viewed as (groups, C // groups), the two axes are swapped and
    the result is flattened back to C channels, so channel ``g * (C // groups) + i``
    moves to position ``i * groups + g``.
    """
    batch, channels, rows, cols = x.shape
    per_group = channels // groups
    # Split the channel axis into (groups, per_group) ...
    grouped = x.view(batch, groups, per_group, rows, cols)
    # ... swap the two, materialising the new order in memory ...
    interleaved = grouped.transpose(1, 2).contiguous()
    # ... and merge them back into a single channel axis.
    return interleaved.view(batch, -1, rows, cols)


#   Multi-scale depth-wise convolution (MSDC)
class MSDC(nn.Module):
    """Multi-scale depth-wise convolution (MSDC).

    Holds one depth-wise Conv2d -> BatchNorm2d -> activation branch per entry
    of ``kernel_sizes`` and returns the list of branch outputs.

    Args:
        in_channels: Channel count of the input (and of every branch output).
        kernel_sizes: Iterable of kernel sizes, one branch each.
        stride: Stride used by every depth-wise convolution.
        activation: Activation name understood by ``act_layer``.
        dw_parallel: When equal to ``False`` each branch output is added to
            the running input before the next branch is applied; otherwise
            every branch sees the same input.
    """

    def __init__(
        self, in_channels, kernel_sizes, stride, activation="relu6", dw_parallel=True
    ):
        super().__init__()

        self.in_channels = in_channels
        self.kernel_sizes = kernel_sizes
        self.activation = activation
        self.dw_parallel = dw_parallel

        branches = []
        for ks in self.kernel_sizes:
            # groups == channels makes the convolution depth-wise.
            depthwise = nn.Conv2d(
                self.in_channels,
                self.in_channels,
                ks,
                stride,
                ks // 2,
                groups=self.in_channels,
                bias=False,
            )
            branches.append(
                nn.Sequential(
                    depthwise,
                    nn.BatchNorm2d(self.in_channels),
                    act_layer(self.activation, inplace=True),
                )
            )
        self.dwconvs = nn.ModuleList(branches)

        self.init_weights("normal")

    def init_weights(self, scheme=""):
        """Re-initialise every sub-module with ``_init_weights`` and ``scheme``."""
        initializer = partial(_init_weights, scheme=scheme)
        named_apply(initializer, self)

    def forward(self, x):
        """Return a list with the output of every depth-wise branch."""
        # Equality (not truthiness) is kept on purpose to match prior behavior.
        chained = self.dw_parallel == False  # noqa: E712
        branch_outputs = []
        for branch in self.dwconvs:
            result = branch(x)
            branch_outputs.append(result)
            if chained:
                # Sequential mode: feed the accumulated signal to the next branch.
                x = x + result
        return branch_outputs


class MSCB(nn.Module):
    """Multi-scale convolution block (MSCB).

    Pipeline: 1x1 expansion conv -> multi-scale depth-wise convs (``MSDC``)
    -> sum or concatenation of the branches -> channel shuffle -> 1x1
    projection conv. With ``stride == 1`` the input is added back as a skip
    connection (through an extra 1x1 conv when channel counts differ).

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        stride: 1 or 2; forwarded to the depth-wise convolutions.
        kernel_sizes: Kernel sizes of the depth-wise branches.
        expansion_factor: Multiplier for the hidden (expanded) channel count.
        dw_parallel: Forwarded to ``MSDC``.
        add: When equal to ``True`` branch outputs are summed, otherwise they
            are concatenated along the channel axis.
        activation: Activation name understood by ``act_layer``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        kernel_sizes=[1, 3, 5],
        expansion_factor=2,
        dw_parallel=True,
        add=True,
        activation="relu6",
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.kernel_sizes = kernel_sizes
        self.expansion_factor = expansion_factor
        self.dw_parallel = dw_parallel
        self.add = add
        self.activation = activation
        self.n_scales = len(self.kernel_sizes)

        # Only strides 1 and 2 are supported; the residual path needs stride 1.
        assert self.stride in [1, 2]
        self.use_skip_connection = bool(self.stride == 1)

        # Point-wise expansion.
        self.ex_channels = int(self.in_channels * self.expansion_factor)
        expand_conv = nn.Conv2d(self.in_channels, self.ex_channels, 1, 1, 0, bias=False)
        self.pconv1 = nn.Sequential(
            expand_conv,
            nn.BatchNorm2d(self.ex_channels),
            act_layer(self.activation, inplace=True),
        )

        # Multi-scale depth-wise stage.
        self.msdc = MSDC(
            self.ex_channels,
            self.kernel_sizes,
            self.stride,
            self.activation,
            dw_parallel=self.dw_parallel,
        )

        # Summed branches keep the width; concatenated ones multiply it.
        width_multiplier = 1 if self.add == True else self.n_scales  # noqa: E712
        self.combined_channels = self.ex_channels * width_multiplier

        # Point-wise projection.
        project_conv = nn.Conv2d(
            self.combined_channels, self.out_channels, 1, 1, 0, bias=False
        )
        self.pconv2 = nn.Sequential(project_conv, nn.BatchNorm2d(self.out_channels))

        needs_channel_match = self.in_channels != self.out_channels
        if self.use_skip_connection and needs_channel_match:
            self.conv1x1 = nn.Conv2d(
                self.in_channels, self.out_channels, 1, 1, 0, bias=False
            )
        self.init_weights("normal")

    def init_weights(self, scheme=""):
        """Re-initialise every sub-module with ``_init_weights`` and ``scheme``."""
        initializer = partial(_init_weights, scheme=scheme)
        named_apply(initializer, self)

    def forward(self, x):
        """Apply the block; adds the (possibly projected) input when stride is 1."""
        expanded = self.pconv1(x)
        branch_outs = self.msdc(expanded)

        if self.add == True:  # noqa: E712
            # sum() starts from 0, i.e. ((0 + b0) + b1) + ...
            merged = sum(branch_outs)
        else:
            merged = torch.cat(branch_outs, dim=1)

        shuffle_groups = gcd(self.combined_channels, self.out_channels)
        merged = channel_shuffle(merged, shuffle_groups)
        projected = self.pconv2(merged)

        if not self.use_skip_connection:
            return projected

        if self.in_channels != self.out_channels:
            x = self.conv1x1(x)
        return x + projected


#   Multi-scale convolution block (MSCB)
def MSCBLayer(
    in_channels,
    out_channels,
    n=1,
    stride=1,
    kernel_sizes=[1, 3, 5],
    expansion_factor=2,
    dw_parallel=True,
    add=True,
    activation="relu6",
):
    """Stack ``MSCB`` blocks into an ``nn.Sequential``.

    The first block maps ``in_channels`` to ``out_channels`` with the given
    ``stride``; every further block (there are ``n - 1`` of them when
    ``n > 1``) keeps ``out_channels`` and uses stride 1.

    Returns:
        ``nn.Sequential`` holding the blocks in order.
    """
    shared = dict(
        kernel_sizes=kernel_sizes,
        expansion_factor=expansion_factor,
        dw_parallel=dw_parallel,
        add=add,
        activation=activation,
    )
    blocks = [MSCB(in_channels, out_channels, stride, **shared)]
    # range(1, n) is empty for n <= 1, so only the first block is created then.
    blocks.extend(MSCB(out_channels, out_channels, 1, **shared) for _ in range(1, n))
    return nn.Sequential(*blocks)


#   Efficient up-convolution block (EUCB)
class EUCB(nn.Module):
    """Efficient up-convolution block (EUCB).

    Upsamples by a factor of 2, applies a depth-wise conv + BatchNorm +
    activation, shuffles channels, and finishes with a 1x1 convolution that
    maps ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        kernel_size: Kernel size of the depth-wise convolution.
        stride: Stride of the depth-wise convolution.
        activation: Activation name understood by ``act_layer``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size=3, stride=1, activation="relu"
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        upsample = nn.Upsample(scale_factor=2)
        # groups == channels makes the convolution depth-wise.
        depthwise = nn.Conv2d(
            self.in_channels,
            self.in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=kernel_size // 2,
            groups=self.in_channels,
            bias=False,
        )
        self.up_dwc = nn.Sequential(
            upsample,
            depthwise,
            nn.BatchNorm2d(self.in_channels),
            act_layer(activation, inplace=True),
        )

        pointwise = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        # Kept inside a Sequential so parameter names stay "pwc.0.*".
        self.pwc = nn.Sequential(pointwise)

        self.init_weights("normal")

    def init_weights(self, scheme=""):
        """Re-initialise every sub-module with ``_init_weights`` and ``scheme``."""
        initializer = partial(_init_weights, scheme=scheme)
        named_apply(initializer, self)

    def forward(self, x):
        """Upsample ``x`` and project it to ``out_channels``."""
        upsampled = self.up_dwc(x)
        shuffled = channel_shuffle(upsampled, self.in_channels)
        return self.pwc(shuffled)


#   Large-kernel grouped attention gate (LGAG)
class LGAG(nn.Module):
    """Large-kernel grouped attention gate (LGAG).

    Projects the gating signal ``g`` and the skip feature ``x`` to ``F_int``
    channels with (grouped) convolutions, adds them, applies the activation,
    reduces the result to a single-channel sigmoid map and multiplies ``x``
    with that map.

    Args:
        F_g: Channel count of the gating signal.
        F_l: Channel count of the skip feature.
        F_int: Channel count of the intermediate projection.
        kernel_size: Kernel size of both projection convolutions.
        groups: Group count of both projection convolutions; forced to 1
            when ``kernel_size`` is 1.
        activation: Activation name understood by ``act_layer``.
    """

    def __init__(self, F_g, F_l, F_int, kernel_size=3, groups=1, activation="relu"):
        super().__init__()

        if kernel_size == 1:
            groups = 1

        def projection(source_channels):
            # Conv + BatchNorm branch mapping `source_channels` to F_int.
            conv = nn.Conv2d(
                source_channels,
                F_int,
                kernel_size=kernel_size,
                stride=1,
                padding=kernel_size // 2,
                groups=groups,
                bias=True,
            )
            return nn.Sequential(conv, nn.BatchNorm2d(F_int))

        self.W_g = projection(F_g)
        self.W_x = projection(F_l)

        squeeze = nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True)
        self.psi = nn.Sequential(squeeze, nn.BatchNorm2d(1), nn.Sigmoid())
        self.activation = act_layer(activation, inplace=True)

        self.init_weights("normal")

    def init_weights(self, scheme=""):
        """Re-initialise every sub-module with ``_init_weights`` and ``scheme``."""
        initializer = partial(_init_weights, scheme=scheme)
        named_apply(initializer, self)

    def forward(self, g, x):
        """Return ``x`` scaled by the attention map computed from ``g`` and ``x``."""
        gate_feat = self.W_g(g)
        skip_feat = self.W_x(x)
        attention = self.psi(self.activation(gate_feat + skip_feat))
        return x * attention


#   Channel attention block (CAB)
class CAB(nn.Module):
    """Channel attention block (CAB).

    Average- and max-pools the input to 1x1, passes both through the same
    two-layer 1x1-conv bottleneck, adds the results and applies a sigmoid.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count; defaults to ``in_channels``.
        ratio: Bottleneck reduction ratio; capped at ``in_channels``.
        activation: Activation name understood by ``act_layer``.
    """

    def __init__(self, in_channels, out_channels=None, ratio=16, activation="relu"):
        super().__init__()

        self.in_channels = in_channels
        # Never reduce by more than the channel count itself.
        ratio = min(ratio, self.in_channels)
        self.reduced_channels = self.in_channels // ratio
        self.out_channels = in_channels if out_channels is None else out_channels

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.activation = act_layer(activation, inplace=True)
        self.fc1 = nn.Conv2d(self.in_channels, self.reduced_channels, 1, bias=False)
        self.fc2 = nn.Conv2d(self.reduced_channels, self.out_channels, 1, bias=False)

        self.sigmoid = nn.Sigmoid()

        self.init_weights("normal")

    def init_weights(self, scheme=""):
        """Re-initialise every sub-module with ``_init_weights`` and ``scheme``."""
        initializer = partial(_init_weights, scheme=scheme)
        named_apply(initializer, self)

    def forward(self, x):
        """Return the per-channel attention weights for ``x``."""

        def bottleneck(pooled):
            # Shared two-layer MLP implemented with 1x1 convolutions.
            return self.fc2(self.activation(self.fc1(pooled)))

        from_avg = bottleneck(self.avg_pool(x))
        from_max = bottleneck(self.max_pool(x))
        return self.sigmoid(from_avg + from_max)


#   Spatial attention block (SAB)
class SAB(nn.Module):
    """Spatial attention block (SAB).

    Stacks the channel-wise mean and max of the input into a two-channel map,
    convolves it down to one channel and applies a sigmoid.

    Args:
        kernel_size: Kernel size of the convolution; must be 3, 7 or 11.
    """

    def __init__(self, kernel_size=7):
        super().__init__()

        assert kernel_size in (3, 7, 11), "kernel must be 3 or 7 or 11"

        # "Same" padding for an odd kernel.
        self.conv = nn.Conv2d(
            2, 1, kernel_size, padding=kernel_size // 2, bias=False
        )
        self.sigmoid = nn.Sigmoid()

        self.init_weights("normal")

    def init_weights(self, scheme=""):
        """Re-initialise every sub-module with ``_init_weights`` and ``scheme``."""
        initializer = partial(_init_weights, scheme=scheme)
        named_apply(initializer, self)

    def forward(self, x):
        """Return the single-channel spatial attention map for ``x``."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        descriptor = torch.cat([channel_mean, channel_max], dim=1)
        return self.sigmoid(self.conv(descriptor))


#   Efficient multi-scale convolutional attention decoding (EMCAD)
class EMCAD(nn.Module):
    """Efficient multi-scale convolutional attention decoder (EMCAD).

    Four decoding stages. Every stage refines its feature map with channel
    attention (``CAB``), a shared spatial attention (``SAB``) and an
    ``MSCBLayer``. Between stages the map is upsampled by an ``EUCB`` and
    fused with a skip feature gated by an ``LGAG``.

    Args:
        channels: Channel counts of the four stages, first decoded stage
            first.
        kernel_sizes: Depth-wise kernel sizes forwarded to ``MSCBLayer``.
        expansion_factor: Expansion factor forwarded to ``MSCBLayer``.
        dw_parallel: Forwarded to ``MSCBLayer``.
        add: Forwarded to ``MSCBLayer``.
        lgag_ks: Kernel size of the ``LGAG`` projections.
        activation: Activation name forwarded to ``MSCBLayer``.
    """

    def __init__(
        self,
        channels=[512, 320, 128, 64],
        kernel_sizes=[1, 3, 5],
        expansion_factor=6,
        dw_parallel=True,
        add=True,
        lgag_ks=3,
        activation="relu6",
    ):
        super().__init__()
        eucb_ks = 3  # kernel size for eucb

        def make_mscb(width):
            # Single stride-1 MSCB that keeps the channel count.
            return MSCBLayer(
                width,
                width,
                n=1,
                stride=1,
                kernel_sizes=kernel_sizes,
                expansion_factor=expansion_factor,
                dw_parallel=dw_parallel,
                add=add,
                activation=activation,
            )

        def make_eucb(src, dst):
            # Up-convolution from one stage width to the next.
            return EUCB(
                in_channels=src,
                out_channels=dst,
                kernel_size=eucb_ks,
                stride=eucb_ks // 2,
            )

        def make_lgag(width, half):
            # Attention gate with `half` intermediate channels and groups.
            return LGAG(
                F_g=width,
                F_l=width,
                F_int=half,
                kernel_size=lgag_ks,
                groups=half,
            )

        # Modules are created in the same order as their attribute names are
        # listed here, so parameter registration order is unchanged.
        self.mscb4 = make_mscb(channels[0])

        self.eucb3 = make_eucb(channels[0], channels[1])
        self.lgag3 = make_lgag(channels[1], channels[1] // 2)
        self.mscb3 = make_mscb(channels[1])

        self.eucb2 = make_eucb(channels[1], channels[2])
        self.lgag2 = make_lgag(channels[2], channels[2] // 2)
        self.mscb2 = make_mscb(channels[2])

        self.eucb1 = make_eucb(channels[2], channels[3])
        self.lgag1 = make_lgag(channels[3], int(channels[3] / 2))
        self.mscb1 = make_mscb(channels[3])

        self.cab4 = CAB(channels[0])
        self.cab3 = CAB(channels[1])
        self.cab2 = CAB(channels[2])
        self.cab1 = CAB(channels[3])

        self.sab = SAB()

    def forward(self, x, skips):
        """Decode ``x`` using ``skips[0]``, ``skips[1]`` and ``skips[2]``.

        Returns:
            List ``[d4, d3, d2, d1]`` with the output of every stage, in the
            order the stages are computed.
        """

        def refine(feat, cab, mscb):
            # Channel attention, shared spatial attention, then the MSCB.
            feat = cab(feat) * feat
            feat = self.sab(feat) * feat
            return mscb(feat)

        decoded = [refine(x, self.cab4, self.mscb4)]

        later_stages = (
            (self.eucb3, self.lgag3, self.cab3, self.mscb3),
            (self.eucb2, self.lgag2, self.cab2, self.mscb2),
            (self.eucb1, self.lgag1, self.cab1, self.mscb1),
        )
        for idx, (eucb, lgag, cab, mscb) in enumerate(later_stages):
            upsampled = eucb(decoded[-1])
            gated_skip = lgag(g=upsampled, x=skips[idx])
            # Additive aggregation of the upsampled map and the gated skip.
            fused = upsampled + gated_skip
            decoded.append(refine(fused, cab, mscb))

        return decoded



### IPRM

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class IPRM_Net(nn.Module):
    """TransNeXt-base encoder + EMCAD decoder producing a sigmoid output map.

    The forward pass appends a constant "frequency" plane to the input image,
    encodes it with the backbone, decodes with ``EMCAD``, upsamples the last
    decoder output by 4x and refines it together with the network input
    through three 3x3 convolutions.

    Args:
        in_chans: Channel count seen by the backbone, i.e. the channels of
            ``input_image`` plus the one frequency plane added in ``forward``.
        encoder: Accepted for interface compatibility. NOTE(review): the
            value is not used here - ``transnext_base`` is always built.
        out_chans: Channel count of the output map.
        drop_rate: Forwarded to the backbone.
        img_size: Forwarded to the backbone as ``img_size`` and
            ``pretrain_size``.
        **kwargs: Accepted and ignored.
    """

    def __init__(
        self,
        in_chans=12,
        encoder="transnext_base",
        out_chans=1,
        drop_rate=0.3,
        img_size=256,
        **kwargs,
    ):
        super(IPRM_Net, self).__init__()

        # Initialize encoder
        self.backbone = transnext_base(
            pretrained=None,
            img_size=img_size,
            pretrain_size=img_size,
            is_extrapolation=False,
            in_chans=in_chans,
            drop_rate=drop_rate,
        )
        # Stage widths listed from the last backbone output to the first,
        # which is the order EMCAD indexes them in.
        encoder_out_chans = [768, 384, 192, 96]

        #   decoder initialization
        self.decoder = EMCAD(
            channels=encoder_out_chans,
            kernel_sizes=[1, 3, 5],
            expansion_factor=2,
            dw_parallel=True,
            add=True,
            lgag_ks=3,
        )

        # Initialize output head and final convolutional layers
        self.out_head1 = nn.Conv2d(encoder_out_chans[3], encoder_out_chans[3], 1)
        self.conv_final = nn.Sequential(
            # The refined features are concatenated with the network input.
            nn.Conv2d(encoder_out_chans[3] + in_chans, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, out_chans, kernel_size=3, padding=1),
        )

    def forward(self, input_image, frequency):
        """Predict the output map for a batch.

        Args:
            input_image: 4-D tensor (batch, channels, height, width).
            frequency: Integer index (or tensor of indices) into the
                frequency table ``[0.868, 1.8, 3.5, 2.4]`` GHz. One index per
                sample, shaped ``(batch,)`` or ``(batch, 1)``, or a single
                index shared by the whole batch.

        Returns:
            Tensor of sigmoid activations with ``out_chans`` channels.
        """
        batch, _, height, width = input_image.shape

        # Build the lookup table directly on the target device instead of
        # creating it on the CPU and copying it on every call.
        freqs_GHz = torch.tensor([0.868, 1.8, 3.5, 2.4], device=input_image.device)
        # Reshape to (N, 1, 1, 1) first so that both (batch,) and (batch, 1)
        # index tensors broadcast per sample; expand() then yields the
        # constant plane without allocating a tensor of ones.
        freq_embedding = (
            freqs_GHz[frequency].reshape(-1, 1, 1, 1).expand(batch, 1, height, width)
        )
        input_feat = torch.cat([input_image, freq_embedding], dim=1)

        # encoder
        x1, x2, x3, x4 = self.backbone(input_feat)

        # decoder: last backbone output first, remaining stages as skips
        dec_outs = self.decoder(x4, [x3, x2, x1])

        # prediction head on the final decoder stage
        p1 = self.out_head1(dec_outs[3])

        p1 = F.interpolate(p1, scale_factor=4, mode="bilinear")

        out = torch.cat([p1, input_feat], dim=1)
        out = self.conv_final(out)
        out = torch.sigmoid(out)

        return out

### PathFormer

In [4]:
class PathFormer(nn.Module):
    """TransNeXt encoder + EMCAD decoder that predicts a residual map.

    The prediction is added to the last channel of the input, which the
    forward pass treats as the output of a previous stage.

    Args:
        in_chans: Channel count of the input tensor.
        encoder: "transnext_tiny", "transnext_small" or "transnext_base".
        img_size: Forwarded to the backbone as ``img_size`` and
            ``pretrain_size``.
        features_dim: Width of the output head and the refinement convs.
    """

    def __init__(
        self,
        in_chans=3,
        encoder="transnext_tiny",
        img_size=256,
        features_dim=128,
    ):
        super().__init__()

        self.backbone, stage_widths = self._initialize_encoder(
            encoder=encoder, in_chans=in_chans, img_size=img_size
        )
        # EMCAD expects the widths starting from the last backbone stage.
        decoder_widths = stage_widths[::-1]

        #   decoder initialization
        self.decoder = EMCAD(
            channels=decoder_widths,
            kernel_sizes=[1, 3, 5],
            expansion_factor=2,
            dw_parallel=True,
            add=True,
            lgag_ks=3,
            activation="relu6",
        )

        # Output head on the final decoder stage.
        self.out_head1 = nn.Conv2d(decoder_widths[3], features_dim, 1)

        # Refinement over the decoded features concatenated with the input.
        refine_layers = [
            nn.Conv2d(features_dim + in_chans, features_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(features_dim),
            nn.LeakyReLU(negative_slope=1e-2, inplace=True),
            nn.Conv2d(features_dim, features_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(features_dim),
            nn.LeakyReLU(negative_slope=1e-2, inplace=True),
            nn.Conv2d(features_dim, 1, kernel_size=3, padding=1),
        ]
        self.conv_final = nn.Sequential(*refine_layers)

    def _initialize_encoder(
        self,
        encoder,
        in_chans,
        img_size,
    ):
        """Build the requested backbone.

        Returns:
            Tuple ``(backbone, stage_widths)`` where ``stage_widths`` lists
            the channel counts from the first to the last backbone stage.

        Raises:
            ValueError: If ``encoder`` is not a supported name.
        """
        # Lambdas defer the name lookup so only the chosen factory is touched.
        specs = {
            "transnext_tiny": (
                lambda **kw: transnext_tiny(**kw),
                [72, 144, 288, 576],
            ),
            "transnext_small": (
                lambda **kw: transnext_small(**kw),
                [72, 144, 288, 576],
            ),
            "transnext_base": (
                lambda **kw: transnext_base(**kw),
                [96, 192, 384, 768],
            ),
        }
        if encoder not in specs:
            raise ValueError(f"Encoder {encoder} not supported")

        build, encoder_out_chans = specs[encoder]
        backbone = build(
            pretrained=None,
            img_size=img_size,
            pretrain_size=img_size,
            is_extrapolation=False,
            in_chans=in_chans,
        )
        return backbone, encoder_out_chans

    def forward(self, input_feat, test=False):
        """Predict the refined map.

        Args:
            input_feat: 4-D input tensor; its last channel is used as the
                previous-stage output that the prediction is added to.
            test: When true, also return that last input channel.

        Returns:
            The refined map, or ``(refined map, last input channel)`` when
            ``test`` is true.
        """
        first_stage_output = input_feat[:, -1:, :, :]

        # encoder / decoder
        x1, x2, x3, x4 = self.backbone(input_feat)
        dec_outs = self.decoder(x4, [x3, x2, x1])

        # prediction head, resized back to the input resolution
        decode_feat = self.out_head1(dec_outs[3])
        decode_feat = F.interpolate(
            decode_feat, size=input_feat.shape[-2:], mode="bilinear", align_corners=True
        )

        residual = self.conv_final(torch.cat([decode_feat, input_feat], dim=1))
        final_out = residual + first_stage_output

        if test:
            return final_out, first_stage_output
        return final_out

## Datasets

In [27]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
import os
import pandas as pd

# Carrier frequencies in GHz (compute_fspl_map multiplies them by 1e9 to get
# Hz). The dataset looks an entry up with the integer "frequency" index that
# it stores for every sample.
frequences = [0.868, 1.8, 3.5, 2.4]


def compute_fspl_map(distance_map, frequency):
    """Free-space path loss in dB for every entry of ``distance_map``.

    Uses ``FSPL = 20*log10(d) + 20*log10(f) - 147.55``, the form of the
    free-space formula for a distance in metres and a frequency in Hz
    (-147.55 is about 20*log10(4*pi/c)).

    Args:
        distance_map: Array of distances. A 1e-6 offset is added before the
            logarithm so that a zero distance does not produce ``-inf``.
        frequency: Carrier frequency in GHz (scalar or broadcastable array).

    Returns:
        Array of path-loss values with the broadcast shape of the inputs.
    """
    freq_hz = frequency * 1e9
    return 20 * np.log10(distance_map + 1e-6) + 20 * np.log10(freq_hz) - 147.55


def split_file_name(file_name):
    """Split an underscore-separated sample file name into its four fields.

    The fourth field has its extension (everything from the first ".")
    removed and its first character dropped.

    Returns:
        Tuple ``(building_id, antenna_id, frequency, scenario)`` of strings.
    """
    fields = file_name.split("_")
    # Text before the first "." minus its leading character.
    scenario = fields[3].partition(".")[0][1:]
    return fields[0], fields[1], fields[2], scenario


def compute_ray_throughput(distance_map, reflectance, transmittance, center_tx):
    """Count, per pixel, the obstacles met on the way from the transmitter.

    Pixels are binned by their integer bearing (degrees) around
    ``center_tx``. Within a bin the distinct distance values are visited in
    ascending order; every pixel receives the counters accumulated so far,
    and the counters are then increased if any pixel at that distance has a
    non-zero reflectance / transmittance.

    Args:
        distance_map: 2-D array of distances.
        reflectance: 2-D array; non-zero marks a reflecting pixel.
        transmittance: 2-D array; non-zero marks a transmitting pixel.
        center_tx: Indexable pair; element 0 is subtracted from row indices
            and element 1 from column indices.

    Returns:
        Tuple ``(transmittance_map, reflectance_map, direct_path)`` of float
        arrays shaped like ``distance_map``.
    """
    rows, cols = np.indices(distance_map.shape)
    bearing = np.arctan2(rows - center_tx[0], cols - center_tx[1]) * 180 / np.pi
    # Wrap into [0, 360) and truncate to whole degrees.
    bearing = ((bearing + 360) % 360).astype(int)

    # -1 marks pixels that no bearing bin has written yet.
    transmittance_map = np.full_like(distance_map, -1, dtype=float)
    reflectance_map = np.full_like(distance_map, -1, dtype=float)
    direct_path = np.full_like(distance_map, -1, dtype=float)

    # One running counter per bearing bin.
    crossed_any = np.zeros(361, dtype=int)
    crossed_transmit = np.zeros(361, dtype=int)
    crossed_reflect = np.zeros(361, dtype=int)

    for bin_idx in range(361):
        in_bin = bearing == bin_idx
        # np.unique returns the distances sorted ascending.
        for dist in np.unique(distance_map[in_bin]):
            ring = in_bin & (distance_map == dist)

            # Write the counts accumulated *before* this distance ...
            transmittance_map[ring] = crossed_transmit[bin_idx]
            reflectance_map[ring] = crossed_reflect[bin_idx]
            direct_path[ring] = crossed_any[bin_idx]

            # ... then account for what sits at this distance.
            hits_reflect = np.any(reflectance[ring] != 0)
            hits_transmit = np.any(transmittance[ring] != 0)
            if hits_reflect or hits_transmit:
                crossed_any[bin_idx] += 1
            if hits_transmit:
                crossed_transmit[bin_idx] += 1
            if hits_reflect:
                crossed_reflect[bin_idx] += 1

    return transmittance_map, reflectance_map, direct_path


class RadioMapDatasetStage1(Dataset):
    """Dataset that assembles a 12-channel tensor per sample.

    For every file name the dataset loads the RGB input image and the
    grayscale output image and derives additional maps: an FSPL map, a
    three-channel ray encoding, an antenna-gain map, an angle map and two
    normalised grid-coordinate maps. Everything is stacked, divided by fixed
    per-channel constants and passed through ``transform``. The derived maps
    are cached as ``.npy`` files under ``dataset_dir``.

    Each item is a dict holding ``input_image`` (first 11 channels),
    ``output_image`` (last channel) and the metadata returned by
    ``get_metadata``.
    """

    def __init__(
        self,
        file_names,
        dataset_dir,
        resize_size,
        task_name="Task_3_ICASSP",
        transform=None,
        test_set=False,
        with_augmentation=True,
    ):
        """Store the configuration and pre-load positions and antenna patterns.

        Args:
            file_names: Sequence of sample file names.
            dataset_dir: Root directory of the dataset.
            resize_size: Side length used by the default resize transform.
            task_name: Sub-directory name below ``Inputs`` / ``Outputs`` /
                ``Ray_Encode_Images``.
            transform: Optional callable applied to the stacked tensor; when
                falsy the result of ``get_transform`` is used.
            test_set: Selects test-set file naming and frequency mapping.
            with_augmentation: Forwarded to ``get_transform``.
        """
        self.file_names = file_names
        self.dataset_dir = dataset_dir
        self.transform = transform or self.get_transform(resize_size, with_augmentation)
        self.position_infos = self.read_position_infos()
        self.single_antenna_info = self.read_single_antenna_info()
        self.test_set = test_set
        self.task_name = task_name

    def read_single_antenna_info(self):
        """Load every file in ``Radiation_Patterns`` with ``np.loadtxt``.

        Returns:
            Dict mapping the file-name prefix before the first "_" (used as
            antenna id) to the loaded array.
        """
        single_antenna_info = {}
        path = os.path.join(self.dataset_dir, "Radiation_Patterns")
        for antenna_info_file in os.listdir(path):
            antenna_id = antenna_info_file.split("_")[0]
            file_path = os.path.join(path, antenna_info_file)
            with open(file_path, "r") as file:
                single_antenna_info[antenna_id] = np.loadtxt(file)
        return single_antenna_info

    def read_position_infos(self):
        """Parse every CSV in ``Positions`` into a lookup table.

        Returns:
            Dict keyed by ``(building_id, antenna_id, frequency, id)`` whose
            values are arrays ``[x, y, azimuth]``.
        """
        position_infos = {}
        path = os.path.join(self.dataset_dir, "Positions")
        for position_csv_file in os.listdir(path):
            file_path = os.path.join(path, position_csv_file)
            building_id, antenna_id, frequency = (
                self.extract_ids_from_position_filename(position_csv_file)
            )
            with open(file_path, "r") as file:
                for line in file:
                    if line.strip():
                        id, x, y, Azimuth = line.split(",")
                        # Rows with an empty first field are skipped
                        # (presumably the header row - TODO confirm).
                        if id == "":
                            continue
                        position_num = np.array([float(x), float(y), float(Azimuth)])
                        position_infos[(building_id, antenna_id, frequency, id)] = (
                            position_num
                        )
        return position_infos

    def extract_ids_from_position_filename(self, filename):
        """Return fields 1, 2 and 3 (extension stripped) of an "_"-split name."""
        parts = filename.split("_")
        return parts[1], parts[2], parts[3].split(".")[0]

    def __len__(self):
        """Number of samples, i.e. the number of file names."""
        return len(self.file_names)

    def __getitem__(self, idx):
        """Build the sample dict for ``self.file_names[idx]``.

        Returns:
            Dict with ``input_image`` (11 channels), ``output_image``
            (1 channel) and every key produced by ``get_metadata``.
        """
        file_name = self.file_names[idx]
        meta_info = self.get_metadata(file_name)
        input_image, output_image = self.load_images(file_name)

        # The three image channels are read as reflectance, transmittance
        # and distance, in that order.
        reflectance, transmittance, distance_map = (
            input_image[:, :, 0],
            input_image[:, :, 1],
            input_image[:, :, 2],
        )

        ray_encode_image_path = os.path.join(
            self.dataset_dir,
            "Ray_Encode_Images",
            self.task_name,
            file_name.split(".")[0] + ".npy",
        )

        # Ray encoding: load the cached array or compute and cache it.
        if os.path.exists(ray_encode_image_path):
            ray_encode_image = np.load(ray_encode_image_path)
        else:
            transmittance_map, reflectance_map, direct_path = compute_ray_throughput(
                distance_map,
                reflectance,
                transmittance,
                meta_info["antenna_position"][:2].numpy(),
            )

            ray_encode_image = np.concatenate(
                (
                    reflectance_map[:, :, np.newaxis],
                    transmittance_map[:, :, np.newaxis],
                    direct_path[:, :, np.newaxis],
                ),
                axis=2,
            )
            os.makedirs(os.path.dirname(ray_encode_image_path), exist_ok=True)
            np.save(ray_encode_image_path, ray_encode_image)

        fspl_map_path = os.path.join(
            self.dataset_dir,
            "Inputs",
            "FSPL_Map",
            file_name.split(".")[0] + ".npy",
        )
        # FSPL map: load the cached array or compute and cache it.
        if os.path.exists(fspl_map_path):
            fspl_map = np.load(fspl_map_path)
        else:
            # NOTE(review): 0.25 / 0.125 look like a pixel-to-distance scale
            # and half-pixel offset - confirm against the dataset spec.
            fspl_map = compute_fspl_map(
                distance_map * 0.25 + 0.125, frequences[meta_info["frequency"].item()]
            )
            os.makedirs(os.path.dirname(fspl_map_path), exist_ok=True)
            np.save(fspl_map_path, fspl_map)

        antenna_map = meta_info["antenna_map"]
        angle_map = meta_info["angle_map"]

        # Normalised coordinate grids: the first varies along axis 0, the
        # second along axis 1, both scaled into [0, 1).
        img_x = antenna_map.shape[0]
        img_y = antenna_map.shape[1]
        GirdEmbedding_x = (
            np.arange(0, img_x, 1).repeat(img_y).reshape(img_x, img_y) / img_x
        )
        GirdEmbedding_y = (
            np.arange(0, img_y, 1).repeat(img_x).reshape(img_y, img_x).transpose()
            / img_y
        )

        # Channel layout: 3 image + 1 FSPL + 3 ray encoding + antenna map +
        # angle map + 2 grid maps + 1 target = 12 channels.
        combined_image = np.concatenate(
            (
                input_image,
                fspl_map[:, :, np.newaxis],
                ray_encode_image,
                antenna_map[:, :, np.newaxis],
                angle_map[:, :, np.newaxis],
                GirdEmbedding_x[:, :, np.newaxis],
                GirdEmbedding_y[:, :, np.newaxis],
                output_image[:, :, np.newaxis],
            ),
            axis=2,
        )
        # HWC -> CHW
        combined_image = torch.tensor(combined_image, dtype=torch.float32).permute(
            2, 0, 1
        )

        # Fixed per-channel divisors, one per channel in the order above.
        calibration = torch.tensor(
            [
                20.0,
                25.0,
                255.0,
                160.0,
                10.0,
                10.0,
                10.0,
                5.0,
                360.0,
                1.0,
                1.0,
                160.0,
            ],
            dtype=torch.float32,
        ).reshape(12, 1, 1)
        combined_image = combined_image / calibration

        # Inputs and target are transformed together as one tensor.
        combined_image = self.transform(combined_image)

        input_image = combined_image[:-1, :, :]
        output_image = combined_image[-1:, :, :]
        images = {
            "input_image": input_image,
            "output_image": output_image,
        }

        return {**images, **meta_info}

    def load_images(self, file_name):
        """Load the input (RGB) and output (grayscale) images as arrays.

        When ``self.test_set`` is true the output image is read from the
        input path instead of the output path.
        """
        input_path = os.path.join(self.dataset_dir, "Inputs", self.task_name, file_name)
        output_path = os.path.join(
            self.dataset_dir, "Outputs", self.task_name, file_name
        )
        input_image = Image.open(input_path).convert("RGB")
        if self.test_set:
            output_image = Image.open(input_path).convert("L")
        else:
            output_image = Image.open(output_path).convert("L")

        return np.array(input_image), np.array(output_image)

    def get_metadata(self, file_name):
        """Collect per-sample metadata and the antenna / angle maps.

        The two maps are loaded from ``Inputs/Antenna_Map`` and
        ``Inputs/Angle_Map`` when the antenna map file exists; otherwise they
        are computed from the position and building CSV files and saved.

        Returns:
            Dict with ``antenna_position``, ``signal_conditions``,
            ``frequency`` (index tensor), ``original_output_path``,
            ``original_input_path``, ``antenna_map`` and ``angle_map``.
        """
        building_id, antenna_id, frequency, scenario = split_file_name(file_name)
        antenna_position = self.position_infos[
            (building_id, antenna_id, frequency, scenario)
        ]
        signal_conditions = self.single_antenna_info[antenna_id]

        antenna_map_path = os.path.join(
            self.dataset_dir,
            'Inputs',
            'Antenna_Map',
            file_name.split(".")[0] + ".npy",
        )
        angle_map_path = os.path.join(
            self.dataset_dir,
            'Inputs',
            'Angle_Map',
            file_name.split(".")[0] + ".npy",
        )
        # Only the antenna map's existence is checked; the angle map is
        # expected to have been written alongside it.
        if os.path.exists(antenna_map_path):
            antenna_map = np.load(antenna_map_path)
            angle_map = np.load(angle_map_path)
        else:
            position_csv_path = os.path.join(
                self.dataset_dir,
                "Positions",
                "Positions_" + building_id + "_" + antenna_id + "_" + frequency + ".csv",
            )
            Sampling_positions = pd.read_csv(position_csv_path)
            # Test-set building detail files carry a "T_" prefix.
            if self.test_set:
                building_detail_path = os.path.join(
                    self.dataset_dir,
                    "Building_Details",
                    "T_" + building_id + "_Details.csv",
                )
            else:
                building_detail_path = os.path.join(
                    self.dataset_dir, "Building_Details", building_id + "_Details.csv"
                )
            Building_Details = pd.read_csv(building_detail_path)
            W, H = Building_Details["W"].iloc[0], Building_Details["H"].iloc[0]

            # Coordinate grids of shape (H, W): X_points varies along
            # axis 1, Y_points along axis 0.
            X_points = (
                np.repeat(np.linspace(0, W - 1, W), H, axis=0).reshape(W, H).transpose()
            )
            Y_points = np.repeat(np.linspace(0, H - 1, H), W, axis=0).reshape(H, W)

            scenario = int(scenario)
            Antenna_Azimuth_Pattern = np.array(signal_conditions)
            # The CSV "Y" column is used as x and the "X" column as y.
            x_ant = Sampling_positions["Y"].loc[scenario]
            y_ant = Sampling_positions["X"].loc[scenario]

            # Angle of every pixel relative to the antenna, offset by the
            # antenna azimuth, in degrees.
            Angles = (
                -(180 / np.pi) * np.arctan2((y_ant - Y_points), (x_ant - X_points))
                + 180
                + Sampling_positions["Azimuth"].iloc[scenario]
            )
            # Values above 359 are shifted down by 360 once, then truncated
            # to integers so they can index the azimuth pattern.
            Angles = np.where(Angles > 359, Angles - 360, Angles).astype(int)
            antenna_map = Antenna_Azimuth_Pattern[Angles]
            angle_map = Angles

            os.makedirs(os.path.dirname(antenna_map_path), exist_ok=True)
            os.makedirs(os.path.dirname(angle_map_path), exist_ok=True)
            np.save(antenna_map_path, antenna_map)
            np.save(angle_map_path, angle_map)



        # Map the file-name frequency tag to an integer index; "f2" maps to
        # index 3 for the test set and to index 1 otherwise.
        frequency_index = {"f1": 0, "f2": 3 if self.test_set else 1, "f3": 2}
        frequency = torch.tensor([frequency_index[frequency]], dtype=torch.long)

        return {
            "antenna_position": torch.tensor(antenna_position, dtype=torch.float32),
            "signal_conditions": torch.tensor(signal_conditions, dtype=torch.float32),
            "frequency": frequency,
            "original_output_path": os.path.join(
                self.dataset_dir, "Outputs", self.task_name, file_name
            ),
            "original_input_path": os.path.join(
                self.dataset_dir, "Inputs", self.task_name, file_name
            ),
            "antenna_map": torch.tensor(antenna_map, dtype=torch.float32),
            "angle_map": torch.tensor(angle_map, dtype=torch.float32),
        }

    def get_transform(self, resize_size, with_augmentation=True):
        """Build the default transform pipeline.

        Always resizes to ``(resize_size, resize_size)`` with bicubic
        interpolation; with augmentation it additionally applies random
        horizontal / vertical flips (p=0.5 each) and ``RandomRotation(360)``.
        """
        if with_augmentation:
            transform_list = [
                # transforms.Resize(RESIZE_SIZE, interpolation=Image.NEAREST),
                transforms.Resize(
                    (resize_size, resize_size), interpolation=Image.BICUBIC
                ),
                transforms.RandomHorizontalFlip(0.5),
                transforms.RandomVerticalFlip(0.5),
                transforms.RandomRotation(360),
            ]
        else:
            transform_list = [
                transforms.Resize(
                    (resize_size, resize_size), interpolation=Image.BICUBIC
                ),
            ]
        return transforms.Compose(transform_list)

In [28]:
import cv2


def get_line_pixels(point1, point2):
    """
    Rasterise the segment from ``point1`` to ``point2`` with Bresenham's algorithm.

    :param point1: Tuple[int, int], start point (x1, y1)
    :param point2: Tuple[int, int], end point (x2, y2)
    :return: List[Tuple[int, int]], visited pixels in order, both endpoints included
    """
    cur_x, cur_y = point1
    end_x, end_y = point2

    span_x = abs(end_x - cur_x)
    span_y = abs(end_y - cur_y)
    step_x = 1 if cur_x < end_x else -1
    step_y = 1 if cur_y < end_y else -1
    error = span_x - span_y

    line = [(cur_x, cur_y)]
    while not (cur_x == end_x and cur_y == end_y):
        # Both tests use the error value from before this iteration's updates.
        doubled = 2 * error
        if doubled > -span_y:
            error -= span_y
            cur_x += step_x
        if doubled < span_x:
            error += span_x
            cur_y += step_y
        line.append((cur_x, cur_y))

    return line


def compute_transmittance_encode(dataset_dir, file_name, transmittance):
    """Accumulate transmittance along the straight line from the antenna to every pixel.

    The antenna pixel is read from the ``Positions`` CSV that matches
    ``file_name`` (column "Y" is used as the x coordinate and column "X" as
    the y coordinate). For each pixel, the values of ``transmittance`` on the
    Bresenham line from the antenna to that pixel are summed in float16;
    line pixels outside the map are ignored.

    :param dataset_dir: dataset root containing the ``Positions`` folder.
    :param file_name: sample file name, parsed by ``split_file_name``.
    :param transmittance: 2-D array indexed as ``[y, x]``.
    :return: float16 array with the same shape as ``transmittance``.
    """
    building_id, antenna_id, frequency, scenario = split_file_name(file_name)

    positions_csv = os.path.join(
        dataset_dir,
        "Positions",
        "_".join(["Positions", building_id, antenna_id, frequency]) + ".csv",
    )
    Sampling_positions = pd.read_csv(positions_csv)

    scenario_row = int(scenario)
    antenna_pixel = (
        Sampling_positions["Y"].loc[scenario_row],
        Sampling_positions["X"].loc[scenario_row],
    )

    n_rows, n_cols = transmittance.shape[0], transmittance.shape[1]
    encoded = np.zeros_like(transmittance).astype(np.float16)

    for col in range(n_cols):
        for row in range(n_rows):
            accumulated = 0.0
            for px, py in get_line_pixels(antenna_pixel, (col, row)):
                if 0 <= px < n_cols and 0 <= py < n_rows:
                    accumulated = accumulated + transmittance[py, px].astype(
                        np.float16
                    )
            encoded[row, col] = accumulated

    return encoded


class RadioMapDatasetStage2(Dataset):
    """Second-stage radio-map dataset.

    Every item is a dict with ``input_feat`` (11 channels), ``target_feat``
    (1 channel) and the metadata returned by :meth:`get_metadata`.

    The 12 channels are stacked in this order before the last one is split
    off as the target: reflectance, transmittance, distance, FSPL map,
    transmittance encoding, antenna map, angle map, grid x, grid y,
    frequency embedding, first-stage output, target image.

    Derived maps (FSPL, transmittance encoding, antenna map, angle map) are
    cached as ``.npy`` files under ``<dataset_dir>/Inputs/<folder>/``. The
    first-stage output must already exist on disk.
    """

    def __init__(
        self,
        file_names,
        dataset_dir,
        resize_size,
        task_name="Task_3_ICASSP",
        transform=None,
        test_set=False,
        with_augmentation=True,
        CALIBRIATE_FACTOR=180.0,
    ):
        """
        :param file_names: sample image names found in ``Inputs/<task_name>``.
        :param dataset_dir: dataset root directory.
        :param resize_size: side length of the square resize.
        :param task_name: sub-folder of ``Inputs``/``Outputs`` holding the images.
        :param transform: optional replacement for the default transform.
        :param test_set: when true, no ground truth is read (the input image
            stands in for the target) and test-set naming/frequency rules apply.
        :param with_augmentation: enable random flips/rotation in the default
            transform.
        :param CALIBRIATE_FACTOR: normalisation divisor for the FSPL map,
            transmittance encoding, first-stage output and target.
        """
        self.file_names = file_names
        self.dataset_dir = dataset_dir
        self.resize_size = resize_size
        self.transform = transform or self.get_transform(resize_size, with_augmentation)
        self.position_infos = self.read_position_infos()
        self.single_antenna_info = self.read_single_antenna_info()
        self.test_set = test_set
        self.CALIBRIATE_FACTOR = CALIBRIATE_FACTOR
        self.task_name = task_name

    def _cache_path(self, folder, file_name):
        """Return ``<dataset_dir>/Inputs/<folder>/<stem>.npy`` for ``file_name``."""
        return os.path.join(
            self.dataset_dir, "Inputs", folder, file_name.split(".")[0] + ".npy"
        )

    def _find_first_stage_output(self, file_name):
        """Return the path of the cached first-stage prediction for ``file_name``.

        ``stage1_task1_output`` is preferred, ``stage1_task3_output`` is the
        fallback.

        :raises ValueError: if neither file exists.
        """
        for folder in ("stage1_task1_output", "stage1_task3_output"):
            candidate = self._cache_path(folder, file_name)
            if os.path.exists(candidate):
                return candidate
        raise ValueError(f"First stage output not found for {file_name}")

    def read_single_antenna_info(self):
        """Load every file in ``Radiation_Patterns`` with ``np.loadtxt``.

        :return: dict mapping the antenna id (file-name part before the first
            underscore) to the loaded array.
        """
        single_antenna_info = {}
        path = os.path.join(self.dataset_dir, "Radiation_Patterns")
        for antenna_info_file in os.listdir(path):
            antenna_id = antenna_info_file.split("_")[0]
            file_path = os.path.join(path, antenna_info_file)
            with open(file_path, "r") as file:
                single_antenna_info[antenna_id] = np.loadtxt(file)
        return single_antenna_info

    def read_position_infos(self):
        """Parse every CSV in ``Positions``.

        Each non-empty line is split into ``id, x, y, azimuth``; lines whose
        id field is empty (the header row) are skipped.

        :return: dict mapping ``(building_id, antenna_id, frequency, id)`` to
            ``np.array([x, y, azimuth])``; ``id`` stays a string.
        """
        position_infos = {}
        path = os.path.join(self.dataset_dir, "Positions")
        for position_csv_file in os.listdir(path):
            file_path = os.path.join(path, position_csv_file)
            building_id, antenna_id, frequency = (
                self.extract_ids_from_position_filename(position_csv_file)
            )
            with open(file_path, "r") as file:
                for line in file:
                    if not line.strip():
                        continue
                    position_id, x, y, azimuth = line.split(",")
                    if position_id == "":
                        continue
                    position_infos[
                        (building_id, antenna_id, frequency, position_id)
                    ] = np.array([float(x), float(y), float(azimuth)])
        return position_infos

    def extract_ids_from_position_filename(self, filename):
        """Split ``Positions_<building>_<antenna>_<freq>.csv`` into its three ids."""
        parts = filename.split("_")
        return parts[1], parts[2], parts[3].split(".")[0]

    def __len__(self):
        """Number of samples (one per file name)."""
        return len(self.file_names)

    def __getitem__(self, idx):
        """Assemble, normalise and transform the 12-channel sample at ``idx``.

        :return: dict with ``input_feat``, ``target_feat`` and all metadata keys.
        :raises ValueError: if no first-stage output exists for the sample.
        """
        file_name = self.file_names[idx]
        meta_info = self.get_metadata(file_name)
        input_image, output_image = self.load_images(file_name)

        reflectance = input_image[:, :, 0]
        transmittance = input_image[:, :, 1]
        distance_map = input_image[:, :, 2]

        fspl_map_path = self._cache_path("FSPL_Map", file_name)
        if os.path.exists(fspl_map_path):
            fspl_map = np.load(fspl_map_path)
        else:
            fspl_map = compute_fspl_map(
                distance_map * 0.25 + 0.125, frequences[meta_info["frequency"].item()]
            )
            os.makedirs(os.path.dirname(fspl_map_path), exist_ok=True)
            np.save(fspl_map_path, fspl_map)

        antenna_map = meta_info["antenna_map"]
        angle_map = meta_info["angle_map"]

        # The encoding is computed from the raw transmittance, before the blur below.
        transmittance_encode_path = self._cache_path("Transmittance_Encode", file_name)
        if os.path.exists(transmittance_encode_path):
            transmittance_encode = np.load(transmittance_encode_path)
        else:
            transmittance_encode = compute_transmittance_encode(
                self.dataset_dir, file_name, transmittance
            ).clip(0, self.CALIBRIATE_FACTOR)
            os.makedirs(os.path.dirname(transmittance_encode_path), exist_ok=True)
            np.save(transmittance_encode_path, transmittance_encode)

        first_stage_output = np.load(self._find_first_stage_output(file_name))

        # Normalised row / column index of every pixel.
        img_x = antenna_map.shape[0]
        img_y = antenna_map.shape[1]
        grid_embedding_x = (
            np.arange(0, img_x, 1).repeat(img_y).reshape(img_x, img_y) / img_x
        )
        grid_embedding_y = (
            np.arange(0, img_y, 1).repeat(img_x).reshape(img_y, img_x).transpose()
            / img_y
        )

        # Constant plane holding the carrier frequency in GHz.
        frequency = meta_info["frequency"]
        freqs_GHz = torch.tensor([0.868, 1.8, 3.5, 2.4])
        freq_embedding = freqs_GHz[frequency].unsqueeze(-1) * torch.ones(
            img_x, img_y, 1
        )

        reflectance = cv2.GaussianBlur(reflectance, (5, 5), 0)
        transmittance = cv2.GaussianBlur(transmittance, (5, 5), 0)

        combined_image = np.concatenate(
            (
                reflectance[:, :, np.newaxis],
                transmittance[:, :, np.newaxis],
                distance_map[:, :, np.newaxis],
                fspl_map[:, :, np.newaxis],
                transmittance_encode[:, :, np.newaxis],
                antenna_map[:, :, np.newaxis],
                angle_map[:, :, np.newaxis],
                grid_embedding_x[:, :, np.newaxis],
                grid_embedding_y[:, :, np.newaxis],
                freq_embedding,
                first_stage_output[:, :, np.newaxis],
                output_image[:, :, np.newaxis],
            ),
            axis=2,
        )
        combined_image = torch.tensor(combined_image, dtype=torch.float32).permute(
            2, 0, 1
        )

        # One divisor per channel, in stacking order.
        calibration = torch.tensor(
            [
                30.0,  # reflectance
                30.0,  # transmittance
                255.0,  # distance
                self.CALIBRIATE_FACTOR,  # fspl_map
                self.CALIBRIATE_FACTOR,  # transmittance_encode
                -40.0,  # antenna_map
                360.0,  # angle_map
                1.0,  # grid_embedding_x
                1.0,  # grid_embedding_y
                4.0,  # freq_embedding
                self.CALIBRIATE_FACTOR,  # first_stage_output
                self.CALIBRIATE_FACTOR,  # output_image
            ],
            dtype=torch.float32,
        ).reshape(12, 1, 1)
        combined_image = combined_image / calibration  # Normalize

        # Input and target are transformed together so they stay aligned.
        combined_image = self.transform(combined_image)

        images = {
            "input_feat": combined_image[:-1, :, :],
            "target_feat": combined_image[-1:, :, :],
        }

        return {**images, **meta_info}

    def load_images(self, file_name):
        """Read the RGB input image and the grayscale target image.

        For a test set the input image, converted to grayscale, is returned
        in place of the target.

        :return: ``(input_array, output_array)`` as numpy arrays.
        """
        input_path = os.path.join(self.dataset_dir, "Inputs", self.task_name, file_name)
        output_path = os.path.join(
            self.dataset_dir, "Outputs", self.task_name, file_name
        )
        input_image = Image.open(input_path).convert("RGB")
        target_source = input_path if self.test_set else output_path
        output_image = Image.open(target_source).convert("L")

        return np.array(input_image), np.array(output_image)

    def get_metadata(self, file_name):
        """Collect per-sample metadata plus the antenna-gain and angle maps.

        The two maps are loaded from the ``Antenna_Map`` / ``Angle_Map``
        caches only when BOTH files exist; otherwise both are recomputed and
        written again, so a half-populated cache can no longer cause a
        ``FileNotFoundError``.

        :return: dict with ``antenna_position``, ``signal_conditions``,
            ``frequency``, ``original_output_path``, ``original_input_path``,
            ``antenna_map`` and ``angle_map``.
        """
        building_id, antenna_id, frequency, scenario = split_file_name(file_name)
        antenna_position = self.position_infos[
            (building_id, antenna_id, frequency, scenario)
        ]
        signal_conditions = self.single_antenna_info[antenna_id]

        antenna_map_path = self._cache_path("Antenna_Map", file_name)
        angle_map_path = self._cache_path("Angle_Map", file_name)

        if os.path.exists(antenna_map_path) and os.path.exists(angle_map_path):
            antenna_map = np.load(antenna_map_path)
            angle_map = np.load(angle_map_path)
        else:
            Sampling_positions = pd.read_csv(
                os.path.join(
                    self.dataset_dir,
                    "Positions",
                    "Positions_"
                    + building_id
                    + "_"
                    + antenna_id
                    + "_"
                    + frequency
                    + ".csv",
                )
            )
            # Test-set building files carry a "T_" prefix.
            details_name = (
                "T_" + building_id + "_Details.csv"
                if self.test_set
                else building_id + "_Details.csv"
            )
            Building_Details = pd.read_csv(
                os.path.join(self.dataset_dir, "Building_Details", details_name)
            )
            W, H = Building_Details["W"].iloc[0], Building_Details["H"].iloc[0]

            # (H, W) grids holding the x and y coordinate of each pixel.
            X_points = (
                np.repeat(np.linspace(0, W - 1, W), H, axis=0).reshape(W, H).transpose()
            )
            Y_points = np.repeat(np.linspace(0, H - 1, H), W, axis=0).reshape(H, W)

            scenario = int(scenario)
            Antenna_Azimuth_Pattern = np.array(signal_conditions)
            x_ant = Sampling_positions["Y"].loc[scenario]
            y_ant = Sampling_positions["X"].loc[scenario]

            # Angle of each pixel as seen from the antenna, offset by its azimuth.
            Angles = (
                -(180 / np.pi) * np.arctan2((y_ant - Y_points), (x_ant - X_points))
                + 180
                + Sampling_positions["Azimuth"].iloc[scenario]
            )
            Angles = np.where(Angles > 359, Angles - 360, Angles).astype(int)
            antenna_map = Antenna_Azimuth_Pattern[Angles]
            angle_map = Angles

            os.makedirs(os.path.dirname(antenna_map_path), exist_ok=True)
            os.makedirs(os.path.dirname(angle_map_path), exist_ok=True)
            np.save(antenna_map_path, antenna_map)
            np.save(angle_map_path, angle_map)

        # "f2" maps to index 3 on the test set and to index 1 otherwise.
        frequency_index = {"f1": 0, "f2": 3 if self.test_set else 1, "f3": 2}
        frequency = torch.tensor([frequency_index[frequency]], dtype=torch.long)

        return {
            "antenna_position": torch.tensor(antenna_position, dtype=torch.float32),
            "signal_conditions": torch.tensor(signal_conditions, dtype=torch.float32),
            "frequency": frequency,
            "original_output_path": os.path.join(
                self.dataset_dir, "Outputs", self.task_name, file_name
            ),
            "original_input_path": os.path.join(
                self.dataset_dir, "Inputs", self.task_name, file_name
            ),
            "antenna_map": torch.tensor(antenna_map, dtype=torch.float32),
            "angle_map": torch.tensor(angle_map, dtype=torch.float32),
        }

    def get_transform(self, resize_size, with_augmentation=True):
        """Build the default transform: bicubic square resize, optionally
        followed by random flips (p=0.5 each) and a random rotation of up to
        360 degrees."""
        transform_list = [
            transforms.Resize(
                (resize_size, resize_size), interpolation=Image.BICUBIC
            ),
        ]
        if with_augmentation:
            transform_list += [
                transforms.RandomHorizontalFlip(0.5),
                transforms.RandomVerticalFlip(0.5),
                transforms.RandomRotation(360),
            ]
        return transforms.Compose(transform_list)

## Training Models

In [9]:
from tqdm import tqdm


@torch.no_grad()
def evaluate(
    eval_model,
    data_loader,
    device,
    epoch=None,
    test=False,
    writer=None,
    CALIBRIATE_FACTOR=180.0,
):
    """Evaluate ``eval_model`` sample by sample over ``data_loader``.

    :param eval_model: model to evaluate; switched to eval mode.
    :param data_loader: yields batches with ``input_feat``, ``target_feat``
        and ``original_output_path``.
    :param device: device the tensors are moved to.
    :param epoch: when given, the first batch's outputs and targets are
        concatenated along dim 2 and logged to ``writer`` (if any).
    :param test: when true the model is called with ``test=True`` and the
        second element of its return value is used as the prediction.
    :param writer: optional summary writer exposing ``add_images``.
    :param CALIBRIATE_FACTOR: scale that maps normalised values back to
        the original range.
    :return: ``(mean resize MSE, mean MSE, mean RMSE)``. The resize MSE
        compares the bilinearly resized, rescaled prediction with the image
        stored at ``original_output_path``; the MSE is computed on rescaled
        values and the RMSE on normalised values.
    """
    eval_model.eval()
    sum_resize_mse = 0
    sum_mse = 0
    sum_rmse = 0
    n_samples = 0

    progress = tqdm(data_loader, total=len(data_loader), desc="Evaluation")
    for batch_idx, batch in enumerate(progress):
        inputs = batch["input_feat"].to(device)
        targets = batch["target_feat"].to(device)
        if test:
            _, outputs = eval_model(inputs, test=True)
        else:
            outputs = eval_model(inputs)

        is_first_logged_batch = epoch is not None and batch_idx == 0
        if is_first_logged_batch:
            show_batch = torch.cat([outputs, targets], dim=2).cpu().numpy()
            if writer is not None:
                writer.add_images(
                    "eval_output_target",
                    show_batch,
                    global_step=epoch,
                    dataformats="NCHW",
                )

        for sample_idx in range(inputs.size(0)):
            prediction = outputs[sample_idx]
            reference = targets[sample_idx]
            n_samples += 1

            sum_rmse += (prediction - reference).pow(2).mean().sqrt().item()

            scaled_error = (
                prediction * CALIBRIATE_FACTOR - reference * CALIBRIATE_FACTOR
            )
            sum_mse += (scaled_error**2).mean().item()

            ground_truth = Image.open(
                batch["original_output_path"][sample_idx]
            ).convert("L")
            ground_truth = torch.tensor(np.array(ground_truth)).to(device).float()

            upsampled = F.interpolate(
                prediction.unsqueeze(0),
                size=ground_truth.shape,
                mode="bilinear",
                align_corners=False,
                antialias=True,
            ).squeeze()
            resize_error = upsampled * CALIBRIATE_FACTOR - ground_truth
            sum_resize_mse += (resize_error**2).mean().item()

    return sum_resize_mse / n_samples, sum_mse / n_samples, sum_rmse / n_samples


def train(
    model,
    train_loader,
    test_loader,
    device,
    optimizer,
    scheduler,
    criterion,
    num_epochs,
    stop_mse,
    save_dir,
    best_eval_mse=float("inf"),
    writer=None,
):
    """Train ``model`` and evaluate it on ``test_loader`` after every epoch.

    :param scheduler: optional object whose ``step()`` is called once per
        batch, right after ``optimizer.step()``.
    :param stop_mse: training stops early once a new best evaluation MSE
        falls below this value.
    :param save_dir: receives ``best_model.pth`` whenever the evaluation MSE
        improves and ``final_model.pth`` when training ends.
    :param best_eval_mse: starting value for the best-so-far evaluation MSE.
    :param writer: optional summary writer exposing ``add_scalar``.
    """
    for epoch in range(num_epochs):
        model.train()
        epoch_tag = f"Epoch {epoch+1}/{num_epochs}"
        loss_sum = 0.0
        progress_bar = tqdm(
            enumerate(train_loader),
            total=len(train_loader),
            desc=epoch_tag,
        )

        for step, batch in progress_bar:
            inputs = batch["input_feat"].to(device)
            targets = batch["target_feat"].to(device)

            loss = criterion(model(inputs), targets)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            step_loss = loss.item()
            loss_sum += step_loss
            if writer is not None:
                writer.add_scalar(
                    "Loss/train", step_loss, epoch * len(train_loader) + step
                )
            # Running mean of the loss over the batches seen so far this epoch.
            show_loss = loss_sum / (step + 1)
            progress_bar.set_postfix({"loss": show_loss, "rmse": show_loss**0.5})

        print(f"{epoch_tag}: loss: {show_loss:.4f}, rmse: {show_loss**0.5:.4f}")
        if writer is not None:
            writer.add_scalar("Loss/train_rmse", show_loss**0.5, epoch)

        eval_resize_mse, eval_mse, eval_rmse = evaluate(
            model, test_loader, device, epoch
        )

        print(
            f"{epoch_tag}: eval resize mse: {eval_resize_mse:.4f}, eval mse: {eval_mse:.4f}, eval rmse: {eval_rmse:.4f}"
        )
        if writer is not None:
            for tag, value in (
                ("Loss/eval_mse", eval_mse),
                ("Loss/eval_rmse", eval_rmse),
                ("Loss/eval_resize_mse", eval_resize_mse),
            ):
                writer.add_scalar(tag, value, epoch)

        if not eval_mse < best_eval_mse:
            continue

        best_model_save_path = os.path.join(save_dir, "best_model.pth")
        print(f"New best model saved to {best_model_save_path}")
        torch.save(model.state_dict(), best_model_save_path)
        best_eval_mse = eval_mse

        # Early stopping is only checked right after an improvement.
        if eval_mse < stop_mse:
            break

    print("Training complete.")
    final_model_save_path = os.path.join(save_dir, "final_model.pth")
    torch.save(model.state_dict(), final_model_save_path)


@torch.no_grad()
def run_post_process_forward_stage1(model, sample, device="cuda"):
    """Run the stage-1 model with eight-view test-time augmentation.

    The sample's ``input_image`` is presented unchanged, rotated by
    90/180/270 degrees, flipped vertically, flipped horizontally, and
    rotated by 90 degrees followed by each flip. Every prediction is mapped
    back to the original orientation and the eight results are averaged.

    :param model: callable taking ``(inputs, frequency)``.
    :param sample: dict with ``input_image``, ``output_image`` and
        ``frequency`` tensors; a batch dimension is added here.
    :param device: device the tensors are moved to.
    :return: ``(averaged prediction, batched target)``.
    """
    inputs = sample["input_image"].unsqueeze(0).to(device)
    targets = sample["output_image"].unsqueeze(0).to(device)
    frequency = sample["frequency"].unsqueeze(0).to(device)

    plane = [2, 3]

    def rotate(turns):
        return lambda t: torch.rot90(t, turns, plane)

    def mirror(dim):
        return lambda t: torch.flip(t, [dim])

    def then(first, second):
        return lambda t: second(first(t))

    def keep(t):
        return t

    # (name, forward transform, inverse transform), in model-call order.
    views = [
        ("plain", keep, keep),
        ("rot90", rotate(1), rotate(-1)),
        ("rot180", rotate(2), rotate(-2)),
        ("rot270", rotate(3), rotate(-3)),
        ("vflip", mirror(2), mirror(2)),
        ("hflip", mirror(3), mirror(3)),
        ("rot90_vflip", then(rotate(1), mirror(2)), then(mirror(2), rotate(-1))),
        ("rot90_hflip", then(rotate(1), mirror(3)), then(mirror(3), rotate(-1))),
    ]

    # Build every augmented input before the first forward pass.
    augmented = [(name, forward(inputs), inverse) for name, forward, inverse in views]

    restored = {}
    for name, view, inverse in augmented:
        restored[name] = inverse(model(view, frequency))

    # Accumulate in a fixed order so the result is bit-for-bit reproducible.
    sum_order = (
        "plain",
        "rot90",
        "rot180",
        "rot270",
        "vflip",
        "rot90_vflip",
        "hflip",
        "rot90_hflip",
    )
    total = restored[sum_order[0]]
    for name in sum_order[1:]:
        total = total + restored[name]

    final_output = total / 8
    return final_output, targets


@torch.no_grad()
def run_post_process_forward_stage2(model, sample, device="cuda"):
    """Run the stage-2 model with eight-view test-time augmentation.

    The eight views are the input unchanged, rotated by 90/180/270 degrees,
    flipped vertically, flipped horizontally, and rotated by 90 degrees
    followed by each flip. Every prediction is mapped back to the original
    orientation and the eight results are averaged.

    The previous version used ``flip(rot90(x, 3), [2])`` as its eighth view.
    That is the same tensor as ``flip(rot90(x, 1), [3])``, so one view was
    counted twice and the 180-degree rotation was never evaluated. The
    duplicate is replaced by the 180-degree rotation, which matches the
    stage-1 helper.

    :param model: callable taking the batched input tensor.
    :param sample: dict with ``input_feat`` and ``target_feat`` tensors; a
        batch dimension is added here.
    :param device: device the tensors are moved to.
    :return: ``(averaged prediction, batched target)``.
    """
    inputs = sample["input_feat"].unsqueeze(0).to(device)
    targets = sample["target_feat"].unsqueeze(0).to(device)

    # Apply augmentations
    inputs_rot90 = torch.rot90(inputs, 1, [2, 3])
    inputs_rot180 = torch.rot90(inputs, 2, [2, 3])
    inputs_rot270 = torch.rot90(inputs, 3, [2, 3])

    inputs_vflip = torch.flip(inputs, [2])
    inputs_hflip = torch.flip(inputs, [3])

    inputs_rot90_vflip = torch.flip(inputs_rot90, [2])
    inputs_rot90_hflip = torch.flip(inputs_rot90, [3])

    # Model predictions on original and augmented inputs, each mapped back
    outputs = model(inputs)
    outputs_rot90 = torch.rot90(model(inputs_rot90), -1, [2, 3])
    outputs_rot180 = torch.rot90(model(inputs_rot180), -2, [2, 3])
    outputs_rot270 = torch.rot90(model(inputs_rot270), -3, [2, 3])

    outputs_vflip = torch.flip(model(inputs_vflip), [2])
    outputs_hflip = torch.flip(model(inputs_hflip), [3])

    # Composite views are undone in reverse order: un-flip, then rotate back.
    outputs_rot90_vflip = torch.rot90(
        torch.flip(model(inputs_rot90_vflip), [2]), -1, [2, 3]
    )
    outputs_rot90_hflip = torch.rot90(
        torch.flip(model(inputs_rot90_hflip), [3]), -1, [2, 3]
    )

    # Average the eight aligned predictions
    final_output = (
        outputs
        + outputs_rot90
        + outputs_rot180
        + outputs_rot270
        + outputs_vflip
        + outputs_hflip
        + outputs_rot90_vflip
        + outputs_rot90_hflip
    ) / 8
    return final_output, targets


@torch.no_grad()
def save_stage1_outputs(model, dataset, save_dir, stage_type="stage1", device="cuda"):
    """Save the TTA prediction of every sample as ``<save_dir>/<stem>.npy``.

    Each prediction is resized bilinearly to the resolution of the sample's
    original input image and multiplied by the calibration scale of the
    selected stage: 160.0 for ``"stage1"`` and 180.0 otherwise, the same
    rule ``evaluate_post_process`` uses. The scale used to be a fixed 160.0,
    which mis-scaled any non-stage-1 prediction.

    :param model: model passed to the stage-specific TTA helper; switched to
        eval mode.
    :param dataset: iterable of sample dicts containing ``original_input_path``.
    :param save_dir: output directory, created if it does not exist.
    :param stage_type: ``"stage1"`` selects the stage-1 TTA helper; any other
        value selects the stage-2 helper.
    :param device: device used for inference.
    """
    model.eval()
    calibration_scale = 160.0 if stage_type == "stage1" else 180.0
    os.makedirs(save_dir, exist_ok=True)

    for sample in tqdm(dataset, desc="Saving stage1 outputs"):
        if stage_type == "stage1":
            output, target = run_post_process_forward_stage1(model, sample, device)
        else:
            output, target = run_post_process_forward_stage2(model, sample, device)

        # The original input image is opened only to obtain the target resolution.
        reference_image = Image.open(sample["original_input_path"]).convert("L")
        reference_image = torch.tensor(np.array(reference_image)).to(device).float()

        resize_output = (
            F.interpolate(
                output,
                size=reference_image.shape,
                mode="bilinear",
                align_corners=False,
                antialias=True,
            )
            .squeeze()
            .cpu()
            .numpy()
        ) * calibration_scale

        save_path = os.path.join(
            save_dir,
            os.path.basename(sample["original_input_path"]).split(".")[0] + ".npy",
        )
        np.save(save_path, resize_output)


@torch.no_grad()
def evaluate_post_process(eval_model, dataset, stage_type="stage1", device="cuda"):
    """Evaluate the TTA-averaged predictions of ``eval_model`` on ``dataset``.

    :param eval_model: model to evaluate; switched to eval mode.
    :param dataset: sized iterable of sample dicts containing
        ``original_output_path``.
    :param stage_type: ``"stage1"`` uses the stage-1 TTA helper and a scale
        of 160.0; any other value uses the stage-2 helper and 180.0.
    :param device: device used for inference.
    :return: ``(mean resize MSE, mean MSE, mean RMSE)`` over the dataset.
    """
    eval_model.eval()
    is_stage1 = stage_type == "stage1"
    calibration_scale = 160.0 if is_stage1 else 180.0
    forward = (
        run_post_process_forward_stage1 if is_stage1 else run_post_process_forward_stage2
    )

    sum_mse = 0
    sum_resize_mse = 0
    sum_rmse = 0

    for sample in tqdm(dataset, desc="Evaluation"):
        output, target = forward(eval_model, sample, device)

        sum_rmse += (output - target).pow(2).mean().sqrt().item()

        scaled_error = output * calibration_scale - target * calibration_scale
        sum_mse += (scaled_error**2).mean().item()

        ground_truth = Image.open(sample["original_output_path"]).convert("L")
        ground_truth = torch.tensor(np.array(ground_truth)).to(device).float()

        upsampled = F.interpolate(
            output,
            size=ground_truth.shape,
            mode="bilinear",
            align_corners=False,
            antialias=True,
        ).squeeze()

        resize_error = upsampled * calibration_scale - ground_truth
        sum_resize_mse += (resize_error**2).mean().item()

    n_samples = len(dataset)
    return (
        sum_resize_mse / n_samples,
        sum_mse / n_samples,
        sum_rmse / n_samples,
    )

In [None]:
import torch

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Stage-1 network.
# NOTE(review): unlike model_stage2 below, this model is not moved to `device` here — confirm.
model_stage1 = IPRM_Net(in_chans=12, encoder='transnext_base', out_chans=1, drop_rate=0.3)

In [None]:
# Side length (pixels) of the square images fed to the stage-2 model.
RESIZE_SIZE = 384

# Stage-2 network. in_chans=11 matches the `input_feat` produced by
# RadioMapDatasetStage2 (12 stacked channels minus the target channel).
model_stage2 = PathFormer(
    in_chans=11, encoder="transnext_base", img_size=RESIZE_SIZE, features_dim=256
).to(device)

In [None]:
# Challenge task whose input folder is listed below (Task_<task_id>_ICASSP).
task_id = 1
# Buildings whose samples are held out for evaluation.
test_building_ids=["B5", "B23"]
# Square resize applied by the datasets.
resize_size=RESIZE_SIZE
# Batch size of the training loader.
batch_size=4


from loguru import logger
import os
from torch.utils.data import DataLoader


# List every sample image of the selected task.
full_file_names = os.listdir(
    f"ICASSP2025_Dataset/Inputs/Task_{task_id}_ICASSP"
)

# Sort so the sample order does not depend on the file system.
full_file_names = sorted(full_file_names)
logger.info(f"Number of files: {len(full_file_names)}")

# Dataset root. NOTE(review): the listing above hard-codes the same directory
# name instead of using this variable — keep the two in sync.
dataset_dir = "ICASSP2025_Dataset"

def splite_train_test(full_file_names, test_building_ids):
    """Split file names into train and test lists by building id.

    The building id is the part of the file name before the first
    underscore. Names whose building id is in ``test_building_ids`` go to
    the test list, all others to the train list; input order is kept.

    :return: ``(train_file_names, test_file_names)``.
    """
    train_file_names = []
    test_file_names = []
    for file_name in full_file_names:
        building_id = file_name.split("_")[0]
        if building_id in test_building_ids:
            test_file_names.append(file_name)
        else:
            train_file_names.append(file_name)
    return train_file_names, test_file_names

# Hold out every sample of the test buildings.
train_file_names, test_file_names = splite_train_test(
    full_file_names, test_building_ids
)

# Train only on the training split. The dataset used to be built from
# `full_file_names`, which leaked the held-out buildings into training and
# made the per-epoch evaluation (and best-model selection) optimistic.
# NOTE(review): both datasets rely on the class default task_name
# ("Task_3_ICASSP") while the file list comes from Task_{task_id}_ICASSP —
# confirm this is intended.
train_dataset = RadioMapDatasetStage2(
    file_names=train_file_names,
    dataset_dir=dataset_dir,
    resize_size=resize_size,
)

# Evaluation data: same pre-processing, no random augmentation.
test_dataset = RadioMapDatasetStage2(
    file_names=test_file_names,
    dataset_dir=dataset_dir,
    resize_size=resize_size,
    with_augmentation=False,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=4)

In [None]:
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingWarmRestarts
from torch import optim

# Parameter-name fragments that are exempt from weight decay.
paramwise_cfg = {
    key: {"decay_mult": 0.0}
    for key in (
        "query_embedding",
        "relative_pos_bias_local",
        "cpb",
        "temperature",
        "norm",
    )
}

no_decay_keys = paramwise_cfg.keys()
# Two groups: parameters matching a no-decay key (weight decay 0.0) first,
# then all remaining parameters (weight decay 0.05).
optimizer_grouped_parameters = [
    {
        "params": [
            param
            for name, param in model_stage2.named_parameters()
            if any(key in name for key in no_decay_keys) == matches_no_decay
        ],
        "weight_decay": weight_decay,
    }
    for matches_no_decay, weight_decay in ((True, 0.0), (False, 0.05))
]

class WarmUpCosineAnnealingLR:
    """Linear warm-up followed by cosine annealing with warm restarts.

    During the first ``warmup_iters`` calls to :meth:`step`, each parameter
    group's learning rate is set to
    ``base_lr * (warmup_ratio + (1 - warmup_ratio) * step / warmup_iters)``.
    ``base_lr`` is the learning rate that group had when the scheduler was
    created, so the class does not depend on a module-level ``initial_lr``
    and groups with different learning rates keep their ratio. After
    warm-up, stepping is delegated to ``CosineAnnealingWarmRestarts`` with
    the step index counted from the end of warm-up.

    :param optimizer: wrapped optimizer.
    :param warmup_iters: number of warm-up steps (0 disables warm-up).
    :param warmup_ratio: fraction of the base learning rate used at step 0.
    :param T_0: passed to ``CosineAnnealingWarmRestarts``.
    :param T_mult: passed to ``CosineAnnealingWarmRestarts``.
    :param eta_min: passed to ``CosineAnnealingWarmRestarts``.
    """

    def __init__(self, optimizer, warmup_iters, warmup_ratio, T_0, T_mult, eta_min):
        self.optimizer = optimizer
        self.warmup_iters = warmup_iters
        self.warmup_ratio = warmup_ratio
        # Captured before the inner scheduler is created.
        self.base_lrs = [group["lr"] for group in optimizer.param_groups]
        self.cosine_scheduler = CosineAnnealingWarmRestarts(
            optimizer, T_0=T_0, T_mult=T_mult, eta_min=eta_min
        )
        self.step_count = 0

    def step(self):
        """Advance the schedule by one step and update the learning rates."""
        if self.step_count < self.warmup_iters:
            # Warm-up phase: linearly increase the learning rate
            scale = self.warmup_ratio + (1 - self.warmup_ratio) * (
                self.step_count / self.warmup_iters
            )
            for param_group, base_lr in zip(
                self.optimizer.param_groups, self.base_lrs
            ):
                param_group["lr"] = base_lr * scale
        else:
            # Cosine annealing phase
            self.cosine_scheduler.step(self.step_count - self.warmup_iters)
        self.step_count += 1


class WarmUpMultiStepLR:
    """Linear warm-up followed by ``MultiStepLR`` decay.

    During the first ``warmup_iters`` calls to :meth:`step`, each parameter
    group's learning rate is set to
    ``base_lr * (warmup_ratio + (1 - warmup_ratio) * step / warmup_iters)``.
    ``base_lr`` is the learning rate that group had when the scheduler was
    created, so the class does not depend on a module-level ``initial_lr``
    and groups with different learning rates keep their ratio. After
    warm-up, every call steps the wrapped ``MultiStepLR`` once.

    :param optimizer: wrapped optimizer.
    :param milestones: passed to ``MultiStepLR``; counted in calls to the
        wrapped scheduler, i.e. in post-warm-up steps.
    :param gamma: multiplicative decay applied at each milestone.
    :param warmup_iters: number of warm-up steps (0 disables warm-up).
    :param warmup_ratio: fraction of the base learning rate used at step 0.
    """

    def __init__(self, optimizer, milestones, gamma, warmup_iters, warmup_ratio):
        self.optimizer = optimizer
        # Captured before the inner scheduler is created.
        self.base_lrs = [group["lr"] for group in optimizer.param_groups]
        self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=milestones, gamma=gamma
        )
        self.warmup_iters = warmup_iters
        self.warmup_ratio = warmup_ratio
        self.step_count = 0

    def step(self):
        """Advance the schedule by one step and update the learning rates."""
        if self.step_count < self.warmup_iters:
            # Warm-up phase: linearly increase the learning rate
            scale = self.warmup_ratio + (1 - self.warmup_ratio) * (
                self.step_count / self.warmup_iters
            )
            for param_group, base_lr in zip(
                self.optimizer.param_groups, self.base_lrs
            ):
                param_group["lr"] = base_lr * scale
        else:
            # MultiStepLR scheduler
            self.lr_scheduler.step()
        self.step_count += 1

# Warm-up and schedule configuration
warmup_iters = 0 # 0 disables warm-up; 1500 is an alternative value left by the author
warmup_ratio = 1e-6
initial_lr = 1e-4
num_epochs = 30  # number of training epochs
steps_per_epoch = len(train_loader)  # optimizer steps (batches) per epoch
total_steps = num_epochs * steps_per_epoch

# NOTE(review): the optimizer is built from model_stage2.parameters(); the
# `optimizer_grouped_parameters` list defined in this cell is not passed
# here — confirm that is intended.
optimizer = optim.Adam(
    model_stage2.parameters(),
    lr=initial_lr,
)


# Milestones are expressed in scheduler steps: train() calls scheduler.step()
# once per batch, so the learning rate is halved at 50% and 75% of all steps.
scheduler = WarmUpMultiStepLR(
    optimizer=optimizer,
    milestones=[int(0.5 * total_steps), int(0.75 * total_steps)],
    gamma=0.5,
    warmup_iters=warmup_iters,
    warmup_ratio=warmup_ratio,
)


# Mean squared error between the normalised prediction and target.
criterion = nn.MSELoss()

In [None]:
# Directory that receives best_model.pth and final_model.pth.
save_dir = 'training_output'
os.makedirs(save_dir, exist_ok=True)

# `stop_mse` is a required parameter of train(); omitting it raised a
# TypeError before training could start. 0.0 disables early stopping,
# because the evaluation MSE can never fall below zero.
train(
    model=model_stage2,
    train_loader=train_loader,
    test_loader=test_loader,
    device=device,
    optimizer=optimizer,
    scheduler=scheduler,
    criterion=criterion,
    num_epochs=num_epochs,
    stop_mse=0.0,
    save_dir=save_dir,
)

In [None]:
# Final evaluation of the stage-2 model on the held-out loader. No epoch or
# writer is passed, so nothing is logged as images.
eval_resize_mse, eval_mse, eval_rmse = evaluate(model_stage2, test_loader, device)
print(
    f"eval resize mse: {eval_resize_mse:.4f}, eval mse: {eval_mse:.4f}, eval rmse: {eval_rmse:.4f}"
)

## Evaluation

In [10]:
from tqdm import tqdm
from PIL import Image
import numpy as np
import pandas as pd


def get_eval_file_names(ant_ids, freq, Buildings):
    """List the evaluation image names for every requested combination.

    Names follow ``B<building>_Ant<antenna>_f<freq>_S<sp>.png`` with ``sp``
    running from 0 to 24. The order is antenna (outermost), frequency,
    building, sampling position (innermost).

    :param ant_ids: iterable of antenna ids.
    :param freq: iterable of frequency ids.
    :param Buildings: iterable of building ids.
    :return: list of file names.
    """
    return [
        f"B{building}_Ant{antenna_id}_f{freq_id}_S{position}.png"
        for antenna_id in ant_ids
        for freq_id in freq
        for building in Buildings
        for position in range(0, 25)
    ]


def get_eval_dataset(eval_dir, task_name, task_id, resize_size):
    """Create the stage-1 and stage-2 evaluation datasets for one task.

    The evaluated file names are generated for buildings "1" and "5" and for
    the antenna/frequency combinations that belong to ``task_id``.

    Args:
        eval_dir: Root directory of the evaluation data (forwarded to the
            dataset classes as ``dataset_dir``).
        task_name: Task name forwarded to the dataset classes.
        task_id: Task number; must be 1, 2 or 3.
        resize_size: Resize setting forwarded to the dataset classes.

    Returns:
        tuple: ``(eval_dataset_stage1, eval_dataset_stage2)``.

    Raises:
        ValueError: If ``task_id`` is not one of the supported tasks.
    """
    # task_id -> (antenna ids, frequency ids) covered by that task.
    task_configs = {
        1: ([1], [1]),
        2: ([1], [1, 2]),
        3: ([1, 3], [1, 2]),
    }
    if task_id not in task_configs:
        # Fail early with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unsupported task_id {task_id!r}; expected one of {sorted(task_configs)}"
        )
    ant_ids, freq = task_configs[task_id]
    Buildings = ["1", "5"]

    eval_file_names = get_eval_file_names(ant_ids, freq, Buildings)

    # Both stages share the same files and settings: no augmentation, test mode.
    eval_dataset_stage1 = RadioMapDatasetStage1(
        file_names=eval_file_names,
        dataset_dir=eval_dir,
        task_name=task_name,
        with_augmentation=False,
        test_set=True,
        resize_size=resize_size,
    )
    eval_dataset_stage2 = RadioMapDatasetStage2(
        file_names=eval_file_names,
        dataset_dir=eval_dir,
        task_name=task_name,
        with_augmentation=False,
        test_set=True,
        resize_size=resize_size,
    )
    return eval_dataset_stage1, eval_dataset_stage2


def get_final_eval_dataset(eval_dir, task_name, task_id, resize_size):
    """Create stage-1 and stage-2 datasets from every file in the task's input folder.

    File names are discovered by listing ``<eval_dir>/Inputs/<task_name>``
    instead of being generated from antenna/frequency/building combinations.

    Args:
        eval_dir: Root directory of the evaluation data (forwarded to the
            dataset classes as ``dataset_dir``).
        task_name: Task name; also the sub-folder of ``Inputs`` that is listed.
        task_id: Unused here; kept so the signature matches ``get_eval_dataset``.
        resize_size: Resize setting forwarded to the dataset classes.

    Returns:
        tuple: ``(eval_dataset_stage1, eval_dataset_stage2)``.
    """
    input_file_dir = os.path.join(eval_dir, "Inputs", task_name)

    # os.listdir returns entries in arbitrary order and also lists
    # sub-directories; sort and keep regular files only so the dataset order
    # (and therefore the saved solution order) is reproducible across runs.
    eval_file_names = [
        entry
        for entry in sorted(os.listdir(input_file_dir))
        if os.path.isfile(os.path.join(input_file_dir, entry))
    ]

    # Both stages share the same files and settings: no augmentation, test mode.
    eval_dataset_stage1 = RadioMapDatasetStage1(
        file_names=eval_file_names,
        dataset_dir=eval_dir,
        task_name=task_name,
        with_augmentation=False,
        test_set=True,
        resize_size=resize_size,
    )
    eval_dataset_stage2 = RadioMapDatasetStage2(
        file_names=eval_file_names,
        dataset_dir=eval_dir,
        task_name=task_name,
        with_augmentation=False,
        test_set=True,
        resize_size=resize_size,
    )
    return eval_dataset_stage1, eval_dataset_stage2

In [11]:
from tqdm import tqdm


def eval_solution(
    eval_model,
    eval_dataset,
    device="cuda",
    CALIBRIATE_FACTOR=180.0,
    with_post_process=True,
):
    """Run the model over an evaluation dataset and collect full-size predictions.

    Each prediction is bilinearly resized (with antialiasing) to the size of
    the sample's original input image and multiplied by ``CALIBRIATE_FACTOR``.

    Args:
        eval_model: Model to evaluate; it is switched to eval mode.
        eval_dataset: Dataset whose samples provide ``"original_input_path"``
            and, when ``with_post_process`` is False, ``"input_image"`` and
            ``"frequency"`` tensors.
        device: Device the inputs are moved to when ``with_post_process`` is
            False. It is not forwarded to the post-processing helper.
        CALIBRIATE_FACTOR: Scale applied to the resized model output.
        with_post_process: If True, predictions come from
            ``run_post_process_forward_stage2``; otherwise from a direct
            forward pass ``eval_model(inputs, frequency)``.

    Returns:
        dict: Maps the base file name of each original input image to its
        resized, scaled prediction as a NumPy array.
    """
    eval_model.eval()
    solution_dict = {}
    with torch.no_grad():
        for sample in tqdm(eval_dataset, total=len(eval_dataset), desc="Evaluation"):
            if with_post_process:
                # The helper also returns targets, which are not needed here.
                outputs, _ = run_post_process_forward_stage2(eval_model, sample)
            else:
                inputs = sample["input_image"].unsqueeze(0).to(device)
                frequency = sample["frequency"].unsqueeze(0).to(device)

                outputs = eval_model(inputs, frequency)

            original_input_path = sample["original_input_path"]
            # basename handles the platform's path separator, so the key matches
            # the plain file names that save_solution uses for lookup.
            file_name = os.path.basename(original_input_path)

            # Only the original resolution is needed; read it from the image
            # header instead of decoding the pixels and moving them to the device.
            with Image.open(original_input_path) as original_image:
                width, height = original_image.size

            resize_output = F.interpolate(
                outputs,
                size=(height, width),
                mode="bilinear",
                align_corners=False,
                antialias=True,
            ).squeeze()

            resize_output = (resize_output * CALIBRIATE_FACTOR).cpu().numpy()
            solution_dict[file_name] = resize_output
    return solution_dict


def save_solution(solution_dict, eval_dataset, solution_save_path):
    """Write all predictions to a CSV file with one row per pixel.

    For every file in ``eval_dataset.file_names`` the matching prediction is
    flattened, and each value gets the ID ``<file name up to the first
    '.'>_<flat index>``. The rows of all files are written in dataset order.

    Args:
        solution_dict: Maps file names to NumPy prediction arrays.
        eval_dataset: Dataset exposing ``file_names`` and ``__len__``; it
            defines which predictions are written and in what order.
        solution_save_path: Destination path of the CSV file.

    Raises:
        KeyError: If a dataset file name is missing from ``solution_dict``.
    """
    columns = ["ID", "PL (dB)"]
    frames = []

    for sample_index in tqdm(range(len(eval_dataset))):
        file_name = eval_dataset.file_names[sample_index]
        # Column vector of the flattened prediction, shape (num_values, 1).
        output = np.expand_dims(solution_dict[file_name].flatten(), 1)

        base_name = file_name.split(".")[0] + "_"
        y_names = np.expand_dims(
            [base_name + str(idx) for idx in range(output.size)], 1
        )

        # Pair each ID with its value as two columns.
        y_data = np.concatenate((y_names, output), axis=1)
        frames.append(pd.DataFrame(data=y_data, columns=columns))

    # Concatenate once at the end: growing a DataFrame inside the loop copies
    # all previous rows on every iteration.
    if frames:
        solution_pd = pd.concat(frames, ignore_index=True)
    else:
        # Empty dataset: still write a well-formed, header-only CSV.
        solution_pd = pd.DataFrame(columns=columns)

    solution_pd.to_csv(solution_save_path, index=False)
    print(f"Solution saved to {solution_save_path}")

### Kaggle submission

In [48]:
# Build the Task 1 evaluation datasets, then run save_stage1_outputs with the
# stage-1 model on the stage-1 dataset, targeting the stage1_task1_output folder.
eval_dataset_stage1, eval_dataset_stage2 = get_eval_dataset(
    eval_dir="iprm-task/Evaluation_Data_T3",
    task_name='Task_1',
    task_id=1,
    resize_size=RESIZE_SIZE,
)
save_stage1_outputs(model_stage1, eval_dataset_stage1, 'iprm-task/Evaluation_Data_T3/Inputs/stage1_task1_output')

Saving stage1 outputs: 100%|██████████| 50/50 [00:25<00:00,  1.95it/s]


In [44]:
# Build the Task 3 evaluation datasets, then run save_stage1_outputs with the
# stage-1 model on the stage-1 dataset, targeting the stage1_task3_output folder.
eval_dataset_stage1, eval_dataset_stage2 = get_eval_dataset(
    eval_dir="iprm-task/Evaluation_Data_T3",
    task_name='Task_3',
    task_id=3,
    resize_size=RESIZE_SIZE,
)
save_stage1_outputs(model_stage1, eval_dataset_stage1, 'iprm-task/Evaluation_Data_T3/Inputs/stage1_task3_output')

Saving stage1 outputs: 100%|██████████| 200/200 [02:27<00:00,  1.36it/s]


In [None]:
# Build the evaluation datasets, predict with the stage-2 model (post-processing
# enabled) and write the submission CSV.
# NOTE(review): task_name is 'Task_3' while task_id is 1 and the output file is
# solution_task1.csv — confirm this combination is intentional and not a
# copy-paste leftover from the Task 3 cell.
eval_dataset_stage1, eval_dataset_stage2 = get_eval_dataset(
    eval_dir="iprm-task/Evaluation_Data_T3",
    task_name='Task_3',
    task_id=1,
    resize_size=RESIZE_SIZE,
)

solution_dict = eval_solution(model_stage2, eval_dataset_stage2, with_post_process=True)

# NOTE(review): the kaggle_results directory is not created in this cell —
# verify it exists before running.
save_solution(solution_dict, eval_dataset_stage2, "kaggle_results/solution_task1.csv")