In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Channel Attention Module
class ChannelAttention(nn.Module):
    """Channel gate built from global average pooling and a bottleneck MLP.

    Every channel of the input is averaged over its spatial extent, the
    resulting per-channel descriptor goes through a two-layer bias-free MLP,
    and the sigmoid of the MLP output rescales the input channel-wise.

    Args:
        channels: Number of channels ``C`` of the input feature map.
        reduction_ratio: Divisor for the MLP's hidden width. The hidden
            width is ``max(channels // reduction_ratio, 1)`` so it never
            collapses to zero for small ``channels``.
    """

    def __init__(self, channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        reduced_channels = max(channels // reduction_ratio, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, reduced_channels, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(reduced_channels, channels, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Rescale ``x`` channel-wise by its learned gate.

        Args:
            x: Tensor of shape ``(B, C, H, W)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If anything fails during the pass (for example a
                non-4D input). The original exception is attached as
                ``__cause__``.
        """
        try:
            b, c, h, w = x.size()
            # Squeeze: one scalar per (sample, channel).
            y = self.avg_pool(x).view(b, c)
            # Excite: bottleneck MLP, reshaped so it broadcasts over H and W.
            y = self.fc(y).view(b, c, 1, 1)
            y = self.sigmoid(y)
            return x * y
        except Exception as e:
            # Chain explicitly so the root cause stays visible in tracebacks.
            raise ValueError(f"Error in ChannelAttention forward pass: {e}") from e

# Window Attention Module
class WindowAttention(nn.Module):
    """Multi-head self-attention inside one square window.

    A learned relative position bias, looked up per pair of token positions
    inside the window, is added to the attention logits before the softmax.

    Args:
        dim: Channel width ``C`` of the tokens; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        window_size: Side length of the window. ``forward`` expects
            ``window_size ** 2`` tokens per window.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads, window_size):
        super(WindowAttention, self).__init__()
        if dim % num_heads != 0:
            # forward() splits C into num_heads equal chunks; fail early
            # instead of at the first reshape.
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.dim = dim
        self.window_size = window_size
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)

        # One learnable bias per head for every possible (dh, dw) offset
        # between two positions of the window.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size - 1) * (2 * window_size - 1), num_heads)
        )
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

        # Relative position index for each token pair inside the window.
        coords_h = torch.arange(self.window_size)
        coords_w = torch.arange(self.window_size)
        # indexing="ij" is the behaviour the original call relied on
        # implicitly; spelling it out silences the torch.meshgrid warning.
        coords = torch.stack(
            torch.meshgrid(coords_h, coords_w, indexing="ij")
        )  # 2, Wh, Ww
        coords_flatten = coords.reshape(2, -1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, N, N
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # N, N, 2
        relative_coords[:, :, 0] += self.window_size - 1  # Shift to start from 0
        relative_coords[:, :, 1] += self.window_size - 1
        # Row-major flattening of the (dh, dw) offset into one table index.
        relative_coords[:, :, 0] *= 2 * self.window_size - 1
        relative_position_index = relative_coords.sum(-1)  # N, N
        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self, x):
        """Apply window attention.

        Args:
            x: Tensor of shape ``(B_, N, C)`` where ``B_`` is
                batch * number of windows, ``N == window_size ** 2`` and
                ``C == dim``.

        Returns:
            Tensor of shape ``(B_, N, C)``.

        Raises:
            ValueError: If anything fails during the pass. The original
                exception is attached as ``__cause__``.
        """
        try:
            B_, N, C = x.shape  # x is of shape (B*num_windows, N, C)
            qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads)
            qkv = (
                qkv.permute(2, 0, 3, 1, 4).contiguous()
            )  # (3, B_, num_heads, N, head_dim)
            q, k, v = qkv.unbind(0)  # Each has shape (B_, num_heads, N, head_dim)

            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))  # (B_, num_heads, N, N)

            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)
            ].view(
                N, N, -1
            )  # (N, N, num_heads)
            relative_position_bias = (
                relative_position_bias.permute(2, 0, 1).contiguous()
            )  # (num_heads, N, N)
            # unsqueeze(0) broadcasts the same bias over every window.
            attn = attn + relative_position_bias.unsqueeze(0)
            attn = attn.softmax(dim=-1)

            x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
            x = self.proj(x)
            return x
        except Exception as e:
            # Chain explicitly so the root cause stays visible in tracebacks.
            raise ValueError(f"Error in WindowAttention forward pass: {e}") from e

# Hybrid Attention Block (HAB)
class HAB(nn.Module):
    """Hybrid Attention Block: channel attention, then window attention.

    Both sub-layers are pre-normalised with ``LayerNorm`` over the channel
    dimension and wrapped in a residual connection.

    Args:
        channels: Number of channels ``C`` of the feature map.
        window_size: Side length of the square attention windows.
        num_heads: Number of heads used by the window attention.
    """

    def __init__(self, channels, window_size, num_heads):
        super(HAB, self).__init__()
        self.window_size = window_size
        self.norm1 = nn.LayerNorm(channels)
        self.channel_attention = ChannelAttention(channels)
        self.norm2 = nn.LayerNorm(channels)
        self.window_attention = WindowAttention(channels, num_heads, window_size)

    def forward(self, x):
        """Run the block on a feature map.

        Args:
            x: Tensor of shape ``(B, C, H, W)``. ``H`` and ``W`` need not be
                multiples of ``window_size``; the map is zero-padded at the
                bottom/right for the attention step and cropped afterwards.

        Returns:
            Tensor of shape ``(B, C, H, W)``.

        Raises:
            ValueError: If anything fails during the pass. The original
                exception is attached as ``__cause__``.
        """
        try:
            ws = self.window_size
            shortcut = x  # (B, C, H, W)

            # --- LayerNorm + channel attention -------------------------
            # LayerNorm normalises the last dimension, so go channels-last
            # for the norm and back to channels-first afterwards.
            x = self.norm1(x.permute(0, 2, 3, 1).contiguous())  # (B, H, W, C)
            x = x.permute(0, 3, 1, 2).contiguous()  # (B, C, H, W)
            x = self.channel_attention(x) + shortcut

            shortcut = x

            # --- LayerNorm + window attention --------------------------
            x = self.norm2(x.permute(0, 2, 3, 1).contiguous())  # (B, H, W, C)
            B, H, W, C = x.shape

            # Pad H and W up to the next multiple of the window size.
            # NOTE: the padded positions are zero tokens and are not masked,
            # so they take part in the softmax of the border windows.
            pad_h = (ws - H % ws) % ws
            pad_w = (ws - W % ws) % ws
            if pad_h > 0 or pad_w > 0:
                # F.pad pairs run from the last dim backwards: C, W, H.
                x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
            Hp = H + pad_h
            Wp = W + pad_w
            rows, cols = Hp // ws, Wp // ws

            # Partition into non-overlapping windows.
            x = x.view(B, rows, ws, cols, ws, C)
            x = x.permute(0, 1, 3, 2, 4, 5).contiguous()  # (B, rows, cols, ws, ws, C)
            x = x.view(-1, ws * ws, C)  # (num_windows*B, N, C)

            x = self.window_attention(x)

            # Merge windows back into a feature map (inverse of the above).
            x = x.view(B, rows, cols, ws, ws, C)
            x = x.permute(0, 1, 3, 2, 4, 5).contiguous()  # (B, rows, ws, cols, ws, C)
            x = x.view(B, Hp, Wp, C)

            # Remove padding.
            if pad_h > 0 or pad_w > 0:
                x = x[:, :H, :W, :].contiguous()

            x = x.permute(0, 3, 1, 2).contiguous()  # (B, C, H, W)
            return x + shortcut
        except Exception as e:
            # Chain explicitly so the root cause stays visible in tracebacks.
            raise ValueError(f"Error in HAB forward pass: {e}") from e

# Residual Hybrid Attention Group (RHAG)
class RHAG(nn.Module):
    """Residual Hybrid Attention Group.

    A stack of ``HAB`` blocks followed by a 3x3 convolution, with a residual
    connection from the group's input to its output.

    Args:
        channels: Number of channels ``C`` of the feature map.
        num_habs: Number of ``HAB`` blocks in the group.
        window_size: Window side length passed to every ``HAB``.
        num_heads: Number of attention heads passed to every ``HAB``.
    """

    def __init__(self, channels, num_habs, window_size, num_heads):
        super(RHAG, self).__init__()
        self.habs = nn.ModuleList(
            [HAB(channels, window_size, num_heads) for _ in range(num_habs)]
        )
        self.conv = nn.Conv2d(
            channels, channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        """Run the group on a feature map.

        Args:
            x: Tensor of shape ``(B, C, H, W)``.

        Returns:
            Tensor of shape ``(B, C, H, W)``.

        Raises:
            ValueError: If anything fails during the pass. The original
                exception is attached as ``__cause__``.
        """
        try:
            residual = x  # x: (B, C, H, W)

            for hab in self.habs:
                x = hab(x)  # shape is preserved by every HAB

            # kernel 3 / stride 1 / padding 1 keeps H and W, so the skip
            # connection lines up.
            return self.conv(x) + residual
        except Exception as e:
            # Chain explicitly so the root cause stays visible in tracebacks.
            raise ValueError(f"Error in RHAG forward pass: {e}") from e

# Full HAT Network Without ConvLSTM
class HAT(nn.Module):
    """Hybrid-attention super-resolution network (variant without ConvLSTM).

    Pipeline: 3x3 entry conv -> ``num_groups`` RHAG groups -> 3x3 conv with
    a skip from the entry features -> pixel-shuffle upsampling -> 3x3 exit
    conv. The bilinearly upscaled input (projected to ``out_channels`` by a
    1x1 conv when the channel counts differ) is added to the result.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output image.
        channels: Width of the internal feature maps.
        num_groups: Number of RHAG groups.
        num_habs: Number of HAB blocks per group.
        window_size: Attention window side length.
        num_heads: Number of attention heads.
        upscale_factor: Spatial upscaling factor. Must be a power of two
            (1, 2, 4, 8, ...) or 3.
        device: Device the sub-modules are created on; defaults to CPU.

    Raises:
        ValueError: If ``upscale_factor`` is not supported.
    """

    def __init__(
        self,
        in_channels,
        out_channels=1,
        channels=64,
        num_groups=4,
        num_habs=6,
        window_size=8,
        num_heads=8,
        upscale_factor=4,
        device=None,
    ):
        super(HAT, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.channels = channels
        self.upscale_factor = upscale_factor
        self.device = device if device is not None else torch.device('cpu')

        self.entry = nn.Conv2d(
            in_channels, channels, kernel_size=3, stride=1, padding=1
        ).to(self.device)
        self.groups = nn.ModuleList(
            [
                RHAG(channels, num_habs, window_size, num_heads)
                for _ in range(num_groups)
            ]
        ).to(self.device)
        self.conv_after_body = nn.Conv2d(
            channels, channels, kernel_size=3, stride=1, padding=1
        ).to(self.device)
        self.upsample = self._make_upsample_layer()
        self.exit = nn.Conv2d(
            channels, out_channels, kernel_size=3, stride=1, padding=1
        ).to(self.device)

        # Adjust residual channels if in_channels != out_channels
        if in_channels != out_channels:
            self.residual_conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            ).to(self.device)
        else:
            self.residual_conv = nn.Identity()

    def _make_upsample_layer(self):
        """Build the pixel-shuffle upsampler for ``self.upscale_factor``.

        A power-of-two factor ``2**k`` is realised as ``k`` stages of
        (conv -> PixelShuffle(2)); factor 3 as one (conv -> PixelShuffle(3))
        stage. Factor 1 yields an empty (identity) ``nn.Sequential``.

        Raises:
            ValueError: For non-integer, non-positive or otherwise
                unsupported factors.
        """
        factor = self.upscale_factor
        if factor != int(factor) or factor < 1:
            raise ValueError(
                f"upscale_factor must be a positive integer, got {factor!r}"
            )
        factor = int(factor)

        layers = []
        if factor & (factor - 1) == 0:
            # Power of two: log2(factor) x2-stages. (The previous
            # int(factor / 2) was only right for factors 1, 2 and 4.)
            for _ in range(factor.bit_length() - 1):
                layers += [
                    nn.Conv2d(
                        self.channels, self.channels * 4, kernel_size=3, stride=1, padding=1
                    ),
                    nn.PixelShuffle(2),  # Upsample x2
                ]
        elif factor == 3:
            layers += [
                nn.Conv2d(
                    self.channels, self.channels * 9, kernel_size=3, stride=1, padding=1
                ),
                nn.PixelShuffle(3),  # Upsample x3
            ]
        else:
            raise ValueError(
                f"Unsupported upscale_factor {factor}; expected a power of two or 3"
            )
        return nn.Sequential(*layers).to(self.device)

    def forward(self, x):
        """Super-resolve a batch of images.

        Args:
            x: Tensor of shape ``(B, in_channels, H, W)``. It is moved to
                the device the model's weights currently live on.

        Returns:
            Tensor of shape
            ``(B, out_channels, H * upscale_factor, W * upscale_factor)``.

        Raises:
            ValueError: If anything fails during the pass. The original
                exception is attached as ``__cause__``.
        """
        try:
            if x.dim() != 4:
                raise ValueError(
                    f"expected a 4D (B, C, H, W) input, got shape {tuple(x.shape)}"
                )
            # Follow the weights rather than the device recorded at
            # construction time, so model.to(...) keeps working.
            x = x.to(self.entry.weight.device)

            # Global residual: bilinearly upscaled input, channel-adjusted.
            residual = F.interpolate(
                x, scale_factor=self.upscale_factor, mode='bilinear', align_corners=False
            )
            residual = self.residual_conv(residual)

            x = self.entry(x)  # (B, channels, H, W)
            # Nothing downstream modifies this tensor in place, so the skip
            # branch can share it without a clone.
            res = x

            for group in self.groups:
                x = group(x)  # x: (B, channels, H, W)

            x = self.conv_after_body(x) + res  # Residual connection after body

            x = self.upsample(x)
            x = self.exit(x)  # (B, out_channels, H*upscale_factor, W*upscale_factor)

            # Explicit check instead of `assert`, which is stripped under -O.
            if x.shape != residual.shape:
                raise ValueError(
                    f"Shape mismatch: x.shape={x.shape}, residual.shape={residual.shape}"
                )

            return x + residual
        except Exception as e:
            # Chain explicitly so the root cause stays visible in tracebacks.
            raise ValueError(f"Error in HAT forward pass: {e}") from e
# Example usage
if __name__ == "__main__":
    # Smoke test: push one random batch through the network and print the
    # input and output shapes.

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Data preparation
    batch_size = 2  # Number of samples
    channels = 46    # Number of input channels
    height = 168    # Height
    width = 64      # Width

    # Simulate data directly on the target device (no host->device copy).
    data = torch.randn(batch_size, channels, height, width, device=device)  # Shape: (B, C, H, W)

    # Initialize the model
    model = HAT(
        in_channels=channels,
        out_channels=1,   # Configured number of output channels
        channels=64,
        num_groups=4,
        num_habs=6,
        window_size=8,
        num_heads=8,
        upscale_factor=4,
        device=device,
    ).to(device)
    model.eval()

    # Run the model. This is inference only, so skip autograd bookkeeping;
    # otherwise every intermediate activation is kept alive for a backward
    # pass that never happens.
    with torch.no_grad():
        output = model(data)

    print("Input shape:", data.shape)
    print("Output shape:", output.shape)  # Should be (B, out_channels, H*4, W*4)

Input shape: torch.Size([2, 46, 168, 64])
Output shape: torch.Size([2, 1, 672, 256])


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
