<a href="https://colab.research.google.com/github/ChoiDae1/2022_KCCV_Programming/blob/main/Few%20Shot%20Segmentation/VAT_Code_Explanation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Code Explanation of Cost Aggregation Is All You Need for Few-Shot Segmentation

[[Project Page](https://seokju-cho.github.io/VAT/)] [[arXiv](https://arxiv.org/abs/2112.11685)]

![Qualitative](https://seokju-cho.github.io/VAT/resources/figure1.png)



<!-- 
Before executing the code, open this [link](https://drive.google.com/drive/folders/1lB-VhUXMXoTnaCgRiqvRSBIWcXD32pYd?usp=sharing) and add a shortcut to the folder in your Google Drive:
![](https://github.com/Seokju-Cho/CATs-Demo/blob/main/drive_vat.png?raw=true) -->


In [None]:
# Fetch the VAT reference implementation, move its contents into the working directory,
# delete the now-empty clone, and install the extra packages the code below imports (einops, timm).
# NOTE(review): `mv` reports "Directory not empty" when the target directories already exist
# (e.g. the cell was run before, as in the output below); in that case the existing copies are kept.
!git clone https://github.com/Seokju-Cho/Volumetric-Aggregation-Transformer.git
!mv Volumetric-Aggregation-Transformer/* .
!rm -rf Volumetric-Aggregation-Transformer 
!pip install einops timm

Cloning into 'Volumetric-Aggregation-Transformer'...
remote: Enumerating objects: 194, done.[K
remote: Counting objects: 100% (176/176), done.[K
remote: Compressing objects: 100% (133/133), done.[K
remote: Total 194 (delta 54), reused 129 (delta 26), pack-reused 18[K
Receiving objects: 100% (194/194), 14.05 MiB | 19.52 MiB/s, done.
Resolving deltas: 100% (57/57), done.
mv: cannot move 'Volumetric-Aggregation-Transformer/common' to './common': Directory not empty
mv: cannot move 'Volumetric-Aggregation-Transformer/config' to './config': Directory not empty
mv: cannot move 'Volumetric-Aggregation-Transformer/data' to './data': Directory not empty
mv: cannot move 'Volumetric-Aggregation-Transformer/images' to './images': Directory not empty
mv: cannot move 'Volumetric-Aggregation-Transformer/model' to './model': Directory not empty
Collecting timm
  Downloading timm-0.5.4-py3-none-any.whl (431 kB)
[K     |████████████████████████████████| 431 kB 8.0 MB/s 
Installing collected package

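### Introducing Einops
Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation (ICLR'22 Oral)

- `rearrange`: reorders, merges, and splits axes according to a pattern string. For example, `rearrange(x, 'B L H_q W_q H_s W_s -> (B H_q W_q) L H_s W_s')` folds the query positions into the batch axis so a 2D operation can run over the support axes.
- `repeat`: like `rearrange`, but can also create new axes by repeating (tiling) elements.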

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class MaxPool4d(nn.Module):
    """Apply 2D max pooling to one spatial pair of a 4D correlation volume.

    The volume is laid out as ``B L H_q W_q H_s W_s``. With ``dim='support'``
    the pooling runs over (H_s, W_s) separately for every query position; with
    ``dim='query'`` it runs over (H_q, W_q) separately for every support
    position. The other spatial pair is left untouched.

    Args:
        kernel_size, stride, padding: forwarded to ``nn.MaxPool2d`` (which is
            built with ``ceil_mode=True``).
        dim: ``'support'`` or ``'query'``, the spatial pair to pool.
    """
    def __init__(self, kernel_size, stride, padding, dim='support'):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size, stride, padding, ceil_mode=True)
        self.dim = dim

    def forward(self, x):
        """
        x: Hyper correlation.
            shape: B L H_q W_q H_s W_s

        Returns:
            The pooled volume, still in ``B L H_q W_q H_s W_s`` layout.

        Raises:
            ValueError: if ``self.dim`` is neither ``'support'`` nor ``'query'``.
        """
        _, _, H_q, W_q, H_s, W_s = x.size()

        if self.dim == 'support':
            # Fold the query positions into the batch so every support map is pooled on its own.
            x = rearrange(x, 'B L H_q W_q H_s W_s -> (B H_q W_q) L H_s W_s')
            x = self.pool(x)
            x = rearrange(x, '(B H_q W_q) L H_s W_s -> B L H_q W_q H_s W_s', H_q=H_q, W_q=W_q)
        elif self.dim == 'query':
            # Fold the support positions into the batch so every query map is pooled on its own.
            x = rearrange(x, 'B L H_q W_q H_s W_s -> (B H_s W_s) L H_q W_q')
            x = self.pool(x)
            x = rearrange(x, '(B H_s W_s) L H_q W_q -> B L H_q W_q H_s W_s', H_s=H_s, W_s=W_s)
        else:
            # `NotImplemented` is a constant, not an exception class; calling it used to
            # raise an unrelated TypeError instead of this message.
            raise ValueError(f'Invalid dimension {self.dim}. dim should be "support" or "query"')
        return x
    

class Interpolate4d(nn.Module):
    """Bilinearly resize one spatial pair of a 4D correlation volume.

    The volume is laid out as ``B L H_q W_q H_s W_s``. With ``dim='support'``
    the (H_s, W_s) maps are resized to ``size`` for every query position; with
    ``dim='query'`` the (H_q, W_q) maps are resized for every support position.

    Args:
        size: target 2D size passed to ``F.interpolate``.
        dim: ``'support'`` or ``'query'``, the spatial pair to resize.
    """
    def __init__(self, size, dim='support'):
        super().__init__()
        self.size = size
        self.dim = dim

    def forward(self, x):
        """
        x: Hyper correlation.
            shape: B L H_q W_q H_s W_s

        Returns:
            The resized volume, still in ``B L H_q W_q H_s W_s`` layout.

        Raises:
            ValueError: if ``self.dim`` is neither ``'support'`` nor ``'query'``.
        """
        _, _, H_q, W_q, H_s, W_s = x.size()
        if self.dim == 'support':
            # Fold the query positions into the batch and resize each support map.
            x = rearrange(x, 'B L H_q W_q H_s W_s -> (B H_q W_q) L H_s W_s')
            x = F.interpolate(x, size=self.size, mode='bilinear', align_corners=True)
            x = rearrange(x, '(B H_q W_q) L H_s W_s -> B L H_q W_q H_s W_s', H_q=H_q, W_q=W_q)
        elif self.dim == 'query':
            # Fold the support positions into the batch and resize each query map.
            x = rearrange(x, 'B L H_q W_q H_s W_s -> (B H_s W_s) L H_q W_q')
            x = F.interpolate(x, size=self.size, mode='bilinear', align_corners=True)
            x = rearrange(x, '(B H_s W_s) L H_q W_q -> B L H_q W_q H_s W_s', H_s=H_s, W_s=W_s)
        else:
            # `NotImplemented` is not an exception; raise a real one with the same message.
            raise ValueError(f'Invalid dimension {self.dim}. dim should be "support" or "query"')
        return x
        

class Conv4d(nn.Module):
    """Separable 4D convolution over a ``B L H_q W_q H_s W_s`` correlation volume.

    The 4D kernel is factorised into two 2D convolutions whose outputs are summed:

    * ``query_conv`` convolves over (H_q, W_q) for every support position;
    * ``supp_conv`` convolves over (H_s, W_s) for every query position.

    The two branches must produce tensors of the same shape, so the axes a
    branch does not convolve are resized first. They are max-pooled by the
    stride (``MaxPool4d``) or, for a transposed branch, interpolated to
    ``target_size`` (``Interpolate4d``).

    Args:
        in_channels, out_channels: channel counts along the ``L`` axis.
        kernel_size, stride, padding: 4-element sequences. Entries ``[:2]``
            configure the query convolution and entries ``[2:]`` the support one.
        bias: whether both 2D convolutions use a bias.
        transposed_query, transposed_supp: use ``nn.ConvTranspose2d`` for that
            branch; requires ``output_padding``.
        target_size: 4-element output size (H_q, W_q, H_s, W_s); required when a
            transposed branch's axes must be resized for the other branch.
        output_padding: 4-element sequence split like ``stride`` and forwarded to
            the transposed convolutions.
    """
    def __init__(self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=True,
            transposed_query=False,
            transposed_supp=False,
            target_size=None,
            output_padding=None
        ):
        super().__init__()

        if transposed_query:
            assert output_padding is not None, 'output_padding cannot be None'
            self.query_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size[:2], stride=stride[:2],
                bias=bias, padding=padding[:2], output_padding=output_padding[:2])
        else:
            self.query_conv = nn.Conv2d(in_channels, out_channels, kernel_size[:2], stride=stride[:2],
                bias=bias, padding=padding[:2])

        if transposed_supp:
            assert output_padding is not None, 'output_padding cannot be None'
            self.supp_conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size[2:], stride=stride[2:],
                bias=bias, padding=padding[2:], output_padding=output_padding[2:])
        else:
            self.supp_conv = nn.Conv2d(in_channels, out_channels, kernel_size[2:], stride=stride[2:],
                bias=bias, padding=padding[2:])

        # Whether the support axes change size, in which case the query branch must see
        # support axes already resized to the output size.
        # NOTE(review): the second clause tests the *query* stride/kernel (index 0) even
        # though it decides the support branch; confirm this is intended.
        self.change_supp = stride[-1] > 1 or stride[0] == 1 and kernel_size[0] == 1
        if self.change_supp:
            if transposed_supp:
                assert target_size is not None, 'Invalid size'
                self.pool_supp = Interpolate4d(target_size[-2:], dim='support')
            else:
                self.pool_supp = MaxPool4d(kernel_size=stride[-2:], stride=stride[-2:], padding=(0, 0), dim='support')

        # Same decision for the query axes, used by the support branch.
        self.change_query = stride[0] > 1 or stride[0] == 1 and kernel_size[0] == 1
        if self.change_query:
            if transposed_query:
                assert target_size is not None, 'Invalid size'
                self.pool_query = Interpolate4d(target_size[:2], dim='query')
            else:
                self.pool_query = MaxPool4d(kernel_size=stride[:2], stride=stride[:2], padding=(0, 0), dim='query')

        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        """
        x: Hyper correlation map.
            shape: B L H_q W_q H_s W_s

        Returns:
            ``query_conv`` output + ``supp_conv`` output, in ``B L H_q W_q H_s W_s`` layout.
        """
        _, _, H_q, W_q, H_s, W_s = x.size()

        # Neither the resizing modules nor the convolutions modify their input in
        # place, so both branches can read `x` directly. Copying the (large) 6D
        # volume is unnecessary.
        if self.change_supp:
            x_query = self.pool_supp(x)
            H_s, W_s = x_query.shape[-2:]
        else:
            x_query = x

        if self.change_query:
            x_supp = self.pool_query(x)
            H_q, W_q = x_supp.shape[2:4]
        else:
            x_supp = x

        # Query branch: 2D conv over (H_q, W_q), with support positions folded into the batch.
        x_query = rearrange(x_query, 'B L H_q W_q H_s W_s -> (B H_s W_s) L H_q W_q')
        x_query = self.query_conv(x_query)
        x_query = rearrange(x_query, '(B H_s W_s) L H_q W_q -> B L H_q W_q H_s W_s', H_s=H_s, W_s=W_s)

        # Support branch: 2D conv over (H_s, W_s), with query positions folded into the batch.
        x_supp = rearrange(x_supp, 'B L H_q W_q H_s W_s -> (B H_q W_q) L H_s W_s')
        x_supp = self.supp_conv(x_supp)
        x_supp = rearrange(x_supp, '(B H_q W_q) L H_s W_s -> B L H_q W_q H_s W_s', H_q=H_q, W_q=W_q)

        return x_query + x_supp


class Encoder4D(nn.Module):
    """Stack of 4D convolution stages that encodes a hypercorrelation volume.

    Stage ``i`` is ``Conv4d(corr_levels[i] -> corr_levels[i + 1])`` followed by
    ``GroupNorm(group[i], corr_levels[i + 1])`` and a ReLU. The number of stages
    equals the length of the shortest of ``kernel_size``, ``stride`` and
    ``padding``.
    """
    def __init__(self,
        corr_levels,
        kernel_size,
        stride,
        padding,
        group=(4,),
        residual=True
    ):
        super().__init__()
        stages = []
        for idx, (ksize, strd, pad) in enumerate(zip(kernel_size, stride, padding)):
            stages.append(nn.Sequential(
                Conv4d(corr_levels[idx], corr_levels[idx + 1], ksize, strd, pad),
                nn.GroupNorm(group[idx], corr_levels[idx + 1]),
                nn.ReLU(),  # kept out-of-place (original note: "No inplace for residual")
            ))
        self.conv4d = nn.ModuleList(stages)

        self.residual = residual

    def forward(self, x):
        """Run every stage in order.

        x: Hyper correlation. B L H_q W_q H_s W_s

        Returns:
            ``(out, residuals)``. ``residuals`` holds the input of each stage when
            ``self.residual`` is set and is empty otherwise.
        """
        keep_inputs = self.residual
        residuals = []
        out = x
        for stage in self.conv4d:
            if keep_inputs:
                residuals.append(out)
            out = stage(out)
        return out, residuals

In [None]:
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
from functools import reduce

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_

from einops import rearrange


class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> activation -> Dropout -> Linear -> Dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    not given. One dropout module (rate ``drop``) is applied both after the
    activation and after the output projection.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width_out = in_features if not out_features else out_features
        width_hidden = in_features if not hidden_features else hidden_features
        self.fc1 = nn.Linear(in_features, width_hidden)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


def window_partition_4d(x, window_size):
    """Cut a 4D correlation volume into non-overlapping cubic 4D windows.

    Args:
        x: (B, H_q, W_q, H_s, W_s, C); every spatial extent must be divisible by
            ``window_size``.
        window_size (int): edge length of the window along each of the 4 axes.
    Returns:
        x: (num_windows*B, window_size, window_size, window_size, window_size, C)
    """
    window_axes = dict.fromkeys(('W_1', 'W_2', 'W_3', 'W_4'), window_size)
    pattern = 'B (H_q W_1) (W_q W_2) (H_s W_3) (W_s W_4) C -> (B H_q W_q H_s W_s) W_1 W_2 W_3 W_4 C'
    return rearrange(x, pattern, **window_axes)


def window_reverse_4d(windows, window_size, H_q, W_q, H_s, W_s):
    """Stitch 4D windows back into a full volume (the inverse of ``window_partition_4d``).

    Args:
        windows: (num_windows*B, window_size, window_size, window_size, window_size, C)
        window_size (int): size of window
        H_q (int): Height of query image
        W_q (int): Width of query image
        H_s (int): Height of support image
        W_s (int): Width of support image
    Returns:
        x: (B, H_q, W_q, H_s, W_s, C)
    """
    # einops needs the number of windows along each axis, not the full extent.
    extents = (('H_q', H_q), ('W_q', W_q), ('H_s', H_s), ('W_s', W_s))
    windows_per_axis = {axis: extent // window_size for axis, extent in extents}
    pattern = '(B H_q W_q H_s W_s) W_1 W_2 W_3 W_4 C -> B (H_q W_1) (W_q W_2) (H_s W_3) (W_s W_4) C'
    return rearrange(windows, pattern, **windows_per_axis)


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with 4D relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): Window extent along (H_q, W_q, H_s, W_s); must have length 4.
            The extents need not be equal.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional):  If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (H_q, W_q, H_s, W_s)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias: one row per possible
        # relative offset. Along an axis of extent w there are 2*w - 1 offsets, so the
        # table has prod(2*w_i - 1) rows. It is computed per axis so non-cubic windows work.
        num_offsets = reduce(lambda a, b: a * b, [2 * w - 1 for w in window_size])
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(num_offsets, num_heads)
        )  # (2*H_q-1)*(2*W_q-1)*(2*H_s-1)*(2*W_s-1), nH

        relative_position_index = self.relative_bias_4d(window_size)
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def relative_bias_4d(self, window_size=(5, 5, 5, 5)):
        """Map every (query token, key token) pair of a window to a bias-table row.

        Each 4D relative offset is shifted to be non-negative (range [0, 2*w-1) per
        axis) and encoded in mixed radix with radices ``2*w_i - 1``, which is a
        bijection onto ``[0, prod(2*w_i - 1))``.

        Returns:
            int64 tensor of shape (N, N), with N = prod(window_size).
        """
        assert len(window_size) == 4, f'Invalid window size {window_size}'

        coords = torch.stack(torch.meshgrid(tuple(torch.arange(w) for w in window_size)))  # 4, H_q, W_q, H_s, W_s

        coords_flatten = coords.flatten(1)  # 4, N
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 4, N, N
        relative_coords = relative_coords.permute(1, 2, 0)  # N, N, 4

        # Shift each offset from [-(w-1), w-1] to [0, 2*w-2].
        for axis, w in enumerate(window_size):
            relative_coords[..., axis] += w - 1

        # Mixed-radix weights, last axis least significant. For cubic windows this
        # reduces to (2w-1)**3, (2w-1)**2, (2w-1), 1.
        weight = 1
        for axis in reversed(range(4)):
            relative_coords[..., axis] *= weight
            weight *= 2 * window_size[axis] - 1

        return relative_coords.sum(-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C), N = H_q*W_q*H_s*W_s of a window
            mask: additive mask of shape (num_windows, N, N) or None. As built by
                SwinTransformerBlock it holds 0 for allowed pairs and -100 otherwise.
        Returns:
            Tensor of shape (num_windows*B, N, C).
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))  # B_, nH, N, N

        window_volume = reduce(lambda a, b: a * b, self.window_size)
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            window_volume, window_volume, -1)  # N, N, nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, N, N
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Broadcast the per-window mask over the batch and head dimensions.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block operating on a flattened 4D correlation volume.

    Tokens are the H_q*W_q*H_s*W_s positions of the volume. The block applies
    (shifted) 4D window attention and then an MLP, each preceded by a norm layer
    and wrapped in a residual connection.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution (H_q, W_q, H_s, W_s).
        num_heads (int): Number of attention heads.
        window_size (int): Window size, used along all four axes.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if the window is at least as large as the smallest input extent, use one
            # window of that extent and disable shifting
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size,)*4, num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA:
            # label every position with the id of the region it falls into after the
            # cyclic shift (3 slices per axis -> up to 3**4 regions); tokens that share a
            # window but come from different regions must not attend to each other.
            H_q, W_q, H_s, W_s = self.input_resolution
            img_mask = torch.zeros((1, H_q, W_q, H_s, W_s, 1))  # 1 H_q W_q H_s W_s 1
            slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h_q in slices:
                for w_q in slices:
                    for h_s in slices:
                        for w_s in slices:
                            img_mask[:, h_q, w_q, h_s, w_s, :] = cnt
                            cnt += 1

            mask_windows = window_partition_4d(img_mask, self.window_size)
            # nW, window_size, window_size, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size ** 4)
            # nonzero difference => the two tokens come from different regions
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # additive mask: -100 drives the softmax weight of masked pairs to ~0
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        """
        Args:
            x: tokens of shape (B, H_q*W_q*H_s*W_s, C).
        Returns:
            Tensor of the same shape.
        """
        H_q, W_q, H_s, W_s = self.input_resolution
        B, L, C = x.shape
        assert L == H_q * W_q * H_s * W_s, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = rearrange(x, 'B (H_q W_q H_s W_s) C -> B H_q W_q H_s W_s C', H_q=H_q, W_q=W_q, H_s=H_s, W_s=W_s)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size,)*4, dims=(1, 2, 3, 4))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition_4d(shifted_x, self.window_size)  # nW*B, window_size, window_size, window_size, window_size, C
        x_windows = rearrange(x_windows, 'B W_1 W_2 W_3 W_4 C -> B (W_1 W_2 W_3 W_4) C')

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size**4, C

        # merge windows
        kwargs = {f'W_{i}':self.window_size for i in range(1, 5)}
        attn_windows = rearrange(attn_windows, 'B (W_1 W_2 W_3 W_4) C -> B W_1 W_2 W_3 W_4 C', **kwargs)
        shifted_x = window_reverse_4d(attn_windows, self.window_size, H_q=H_q, W_q=W_q, H_s=H_s, W_s=W_s)  # B H_q W_q H_s W_s C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size,)*4, dims=(1, 2, 3, 4))
        else:
            x = shifted_x
        x = rearrange(x, 'B H_q W_q H_s W_s C -> B (H_q W_q H_s W_s) C')

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Blocks alternate between regular windows (even index, no shift) and shifted
    windows (odd index, shift of ``window_size // 2``).

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate, either one
            value for all blocks or one value per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Per-block rates may arrive as a list (as SwinTransformer passes them) or as a
        # tuple (as documented); a tuple used to reach `drop_path > 0.` and crash.
        per_block_drop_path = isinstance(drop_path, (list, tuple))

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if per_block_drop_path else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        """Run every block (optionally with activation checkpointing), then the optional downsample."""
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Splits the image into non-overlapping ``patch_size`` patches with a strided
    convolution and returns one token per patch.

    Args:
        img_size (int): Image size.  Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_hw = to_2tuple(img_size)
        patch_hw = to_2tuple(patch_size)
        grid = [img_hw[0] // patch_hw[0], img_hw[1] // patch_hw[1]]

        self.img_size = img_hw
        self.patch_size = patch_hw
        self.patches_resolution = grid
        self.num_patches = grid[0] * grid[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
        self.norm = None if norm_layer is None else norm_layer(embed_dim)

    def forward(self, x):
        """Embed images of shape (B, C, H, W) into tokens of shape (B, Ph*Pw, embed_dim)."""
        _, _, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        tokens = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        return tokens if self.norm is None else self.norm(tokens)


class SwinTransformer(nn.Module):
    r""" Single-stage Swin Transformer over a flattened 4D correlation volume.
        Adapted from `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows`  -
          https://arxiv.org/pdf/2103.14030
        There is no patch embedding, no patch merging and no classification head:
        tokens go through one BasicLayer and come out with the same shape.
    Args:
        corr_size (tuple[int]): Resolution (H_q, W_q, H_s, W_s) of the correlation volume.
            Default: (16, 16, 16, 16)
        embed_dim (int): Token channel dimension. Default: 64
        depth (int): Number of Swin Transformer blocks. Default: 2
        num_head (int): Number of attention heads. Default: 4
        window_size (int): Window size along each of the four axes. Default: 4
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate (linearly increased over the blocks). Default: 0
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, corr_size=(16, 16, 16, 16),
                 embed_dim=64, depth=2, num_head=4,
                 window_size=4, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 use_checkpoint=False):
        super().__init__()

        self.embed_dim = embed_dim
        self.mlp_ratio = mlp_ratio
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        # build layers
        self.layer = BasicLayer(dim=embed_dim,
                            input_resolution=corr_size,
                            depth=depth,
                            num_heads=num_head,
                            window_size=window_size,
                            mlp_ratio=self.mlp_ratio,
                            qkv_bias=qkv_bias, qk_scale=qk_scale,
                            drop=drop_rate, attn_drop=attn_drop_rate,
                            drop_path=dpr,
                            norm_layer=norm_layer,
                            downsample=None,
                            use_checkpoint=use_checkpoint)

        # NOTE(review): `self.norm` is never used by `forward`; presumably kept so existing
        # checkpoints (which contain its weights) still load with strict=True. Confirm before removing.
        self.norm = norm_layer(self.embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear weights; zero biases; unit-scale LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward(self, x):
        """
        Args:
            x: tokens of shape (B, H_q*W_q*H_s*W_s, embed_dim) matching ``corr_size``.
        Returns:
            Tensor of the same shape.
        """
        x = self.pos_drop(x)
        return self.layer(x)


class TransformerWarper(nn.Module):
    """Adapt a token-sequence transformer to the ``B L H_q W_q H_s W_s`` correlation layout.

    The channel axis ``L`` becomes the token feature axis and the four spatial
    axes are flattened into the token axis; the wrapped module's output is
    reshaped back to the original layout.
    """
    def __init__(self, transformer):
        super().__init__()
        self.transformer = transformer

    def forward(self, x):
        _, _, H_q, W_q, H_s, W_s = x.size()
        spatial = {'H_q': H_q, 'W_q': W_q, 'H_s': H_s, 'W_s': W_s}
        tokens = self.transformer(rearrange(x, 'B L H_q W_q H_s W_s -> B (H_q W_q H_s W_s) L'))
        return rearrange(tokens, 'B (H_q W_q H_s W_s) L -> B L H_q W_q H_s W_s', **spatial)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from model.base.swin import SwinTransformer2d, TransformerWarper2d
from model.base.our_conv4d import Interpolate4d, Encoder4D
from model.base.swin4d import SwinTransformer, TransformerWarper


class OurModel(nn.Module):
    """VAT-style cost-aggregation head for few-shot segmentation.

    Three hypercorrelation volumes (conv_5, conv_4, conv_3 levels) are encoded
    with 4D convolutions and refined with 4D Swin transformers. The results are
    fused coarse-to-fine by upsampling the query axes, averaged over the support
    axes, and decoded with 2D Swin blocks. Projected query features are
    optionally concatenated along the way (``feature_affinity``).

    Args:
        inch: input channel counts of the three hypercorrelations (conv_5, conv_4, conv_3).
        feature_affinity: per decoder stage, whether to concatenate projected query features.
    """
    def __init__(
        self,
        inch=(3, 23, 4),
        feature_affinity=(True, True, True),
    ):
        super().__init__()

        self.encoders = nn.ModuleList([
            Encoder4D( # Encoder for conv_5
                corr_levels=(inch[0], 64, 128),
                kernel_size=(
                    (3, 3, 3, 3),
                    (3, 3, 3, 3),
                ),
                stride=(
                    (2, 2, 1, 1),
                    (1, 1, 2, 2),
                ),
                padding=(
                    (1, 1, 1, 1),
                    (1, 1, 1, 1),
                ),
                group=(4, 8),
                residual=False
            ),
            Encoder4D( # Encoder for conv_4
                corr_levels=(inch[1], 64, 128),
                kernel_size=(
                    (3, 3, 3, 3),
                    (3, 3, 3, 3),
                ),
                stride=(
                    (2, 2, 2, 2),
                    (1, 1, 2, 2),
                ),
                padding=(
                    (1, 1, 1, 1),
                    (1, 1, 1, 1),
                ),
                group=(4, 8),
                residual=False
            ),
            Encoder4D( # Encoder for conv_3
                corr_levels=(inch[2], 32, 64, 128),
                kernel_size=(
                    (3, 3, 3, 3),
                    (3, 3, 3, 3),
                    (3, 3, 3, 3),
                ),
                stride=(
                    (2, 2, 2, 2),
                    (1, 1, 2, 2),
                    (1, 1, 2, 2),
                ),
                padding=(
                    (1, 1, 1, 1),
                    (1, 1, 1, 1),
                    (1, 1, 1, 1),
                ),
                group=(2, 4, 8,),
                residual=False
            ),
        ])

        # corr_size must match the encoded volume of each level (H_q, W_q, H_s, W_s);
        # presumably 8^4, 16^2x8^2 and 32^2x8^2 for the backbone's input size -- TODO confirm.
        self.transformer = nn.ModuleList([
            TransformerWarper(SwinTransformer(
                corr_size=(8, 8, 8, 8),
                embed_dim=128,
                depth=4,
                num_head=4,
                window_size=4,
            )),
            TransformerWarper(SwinTransformer(
                corr_size=(16, 16, 8, 8),
                embed_dim=128,
                depth=2,
                num_head=4,
                window_size=4,
            )),
            TransformerWarper(SwinTransformer(
                corr_size=(32, 32, 8, 8),
                embed_dim=128,
                depth=2,
                num_head=4,
                window_size=4,
            )),
        ])

        # Upsample only the query axes to the next finer level's query resolution.
        self.upscale = nn.ModuleList([
            Interpolate4d(size=(16, 16), dim='query'),
            Interpolate4d(size=(32, 32), dim='query'),
            Interpolate4d(size=(64, 64), dim='query'),
        ])

        self.feature_affinity = feature_affinity
        decoder_dim = [
            (128 + 64) if feature_affinity[0] else 128,
            (96 + 32) if feature_affinity[1] else 96,
            (48 + 16) if feature_affinity[2] else 48
        ]

        self.swin_decoder = nn.ModuleList([
            nn.Sequential(
                TransformerWarper2d(
                    SwinTransformer2d(img_size=(32, 32), embed_dim=decoder_dim[0], window_size=8)
                ),
                nn.Conv2d(decoder_dim[0], 96, 1)
            ),
            nn.Sequential(
                TransformerWarper2d(
                    SwinTransformer2d(img_size=(64, 64), embed_dim=decoder_dim[1], window_size=8)
                ),
                nn.Conv2d(decoder_dim[1], 48, 1)
            ),
            nn.Sequential(
                TransformerWarper2d(
                    SwinTransformer2d(img_size=(128, 128), embed_dim=decoder_dim[2], window_size=8)
                ),
            )
        ])

        self.decoder = nn.Sequential(
            nn.Conv2d(decoder_dim[2], 32, (3, 3), padding=(1, 1), bias=True),
            nn.ReLU(True),
            nn.Conv2d(32, 2, (3, 3), padding=(1, 1), bias=True)
        )

        self.dropout2d = nn.Dropout2d(p=0.5)

        # 1x1 projections of the query features concatenated in the decoder.
        self.proj_query_feat = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(1024, 64, 1),
                nn.ReLU(),
            ),
            nn.Sequential(
                nn.Conv2d(512, 32, 1),
                nn.ReLU(),
            ),
            nn.Sequential(
                nn.Conv2d(256, 16, 1),
                nn.ReLU(),
            )
        ])

    def extract_last(self, x):
        """Return ``k[:, -1]`` for every tensor ``k`` in ``x``."""
        return [k[:, -1] for k in x]

    def apply_dropout(self, dropout, *feats):
        """Apply one shared dropout mask across feature maps of different sizes.

        Every map is resized (nearest) to the largest size, the maps are
        concatenated along channels, dropped out together, split back, and each
        piece is resized to its original size.
        """
        sizes = [x.shape[-2:] for x in feats]
        max_size = max(sizes)
        resized_feats = [F.interpolate(x, size=max_size, mode='nearest') for x in feats]

        channel_list = [x.size(1) for x in feats]
        feats = dropout(torch.cat(resized_feats, dim=1))
        feats = torch.split(feats, channel_list, dim=1)
        recovered_feats = [F.interpolate(x, size=size, mode='nearest') for x, size in zip(feats, sizes)]
        return recovered_feats

    def forward(self, hypercorr_pyramid, query_feats, support_mask):
        """Predict 2-channel segmentation logits for the query image.

        Args:
            hypercorr_pyramid: three ``B L H_q W_q H_s W_s`` volumes, ordered like
                ``self.encoders`` (conv_5, conv_4, conv_3).
            query_feats: exactly four groups of query features. ``[:, -1]`` of each
                is taken and the first group is ignored (presumably the last layer of
                each backbone level -- TODO confirm).
            support_mask: not used by this forward pass; kept for interface compatibility.
        Returns:
            Output of ``self.decoder``: B x 2 x H x W logits (presumably 128 x 128, the
            size of the last interpolation).
        """
        _, query_feat4, query_feat3, query_feat2 = self.extract_last(query_feats)
        query_feat4, query_feat3, query_feat2 = [
            self.proj_query_feat[i](x) for i, x in enumerate((query_feat4, query_feat3, query_feat2))
        ]

        query_feat4, query_feat3 = self.apply_dropout(self.dropout2d, query_feat4, query_feat3)

        corr5 = self.encoders[0](hypercorr_pyramid[0])[0]
        corr4 = self.encoders[1](hypercorr_pyramid[1])[0]
        corr3 = self.encoders[2](hypercorr_pyramid[2])[0]

        corr5 = corr5 + self.transformer[0](corr5)
        corr5_upsampled = self.upscale[0](corr5)

        # Out-of-place on purpose: corr4/corr3 are the outputs of each encoder's final
        # ReLU, which autograd saves for ReLU's backward. An in-place `+=` would bump
        # their version and make backward() raise a RuntimeError.
        corr4 = corr4 + corr5_upsampled
        corr4 = corr4 + self.transformer[1](corr4)
        corr4_upsampled = self.upscale[1](corr4)

        corr3 = corr3 + corr4_upsampled
        corr3 = corr3 + self.transformer[2](corr3)
        # Collapse the support axes: B L H_q W_q H_s W_s -> B L H_q W_q.
        x = corr3.mean(dim=(-2, -1))

        x = self.swin_decoder[0](torch.cat((x, query_feat4), dim=1) if self.feature_affinity[0] else x)
        x = F.interpolate(x, size=(64, 64), mode='bilinear', align_corners=True)
        x = self.swin_decoder[1](torch.cat((x, query_feat3), dim=1) if self.feature_affinity[1] else x)
        x = F.interpolate(x, size=(128, 128), mode='bilinear', align_corners=True)
        x = self.swin_decoder[2](torch.cat((x, query_feat2), dim=1) if self.feature_affinity[2] else x)

        return self.decoder(x)