In [106]:
from typing import List
import torch
from torch import nn
import torch.nn.functional as F
from timm.models.pvt_v2 import Attention


class CrossAttention(nn.Module):
    """Attention layer in which a query set and a support set attend to each other.

    The layer is built from an existing attention module and reuses all of its
    state.  The wrapped module is expected to expose ``q``, ``kv``, ``proj``,
    ``attn_drop``, ``proj_drop``, ``pool``, ``sr``, ``norm``, ``act``,
    ``num_heads``, ``head_dim``, ``scale`` and ``fused_attn``, because
    ``forward`` reads every one of them.

    Exactly one of the two inputs is treated as the "single" input (batch size
    1); the other one is the "batched" input:

    * the single input attends over its own keys/values concatenated with the
      batch-mean of the batched input's keys/values;
    * every sample of the batched input attends over its own keys/values
      concatenated with the single input's keys/values (broadcast over batch).
    """

    def __init__(self, module, **kwargs):
        super().__init__(**kwargs)
        # Adopt the wrapped module's complete instance state.  This includes
        # the ``_parameters`` / ``_buffers`` / ``_modules`` dicts, so the
        # parameters and sub-layers are shared with ``module``, not copied.
        for attr_name, attr_value in module.__dict__.items():
            setattr(self, attr_name, attr_value)

    def _project_kv(self, x, feat_size: List[int]):
        """Return ``(k, v)`` for ``x``, each shaped (B, heads, N_kv, head_dim)."""
        B, N, C = x.shape
        H, W = feat_size

        if self.pool is not None:
            # Pooled spatial reduction: tokens -> feature map -> pool + sr -> tokens.
            x = x.permute(0, 2, 1).reshape(B, C, H, W)
            x = self.sr(self.pool(x)).reshape(B, C, -1).permute(0, 2, 1)
            x = self.norm(x)
            x = self.act(x)
        elif self.sr is not None:
            # Spatial reduction without pooling.
            x = x.permute(0, 2, 1).reshape(B, C, H, W)
            x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
            x = self.norm(x)

        kv = self.kv(x).reshape(B, -1, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        return kv[0], kv[1]

    def _attend(self, q, k, v):
        """Attention of ``q`` over ``k``/``v`` (fused kernel or explicit softmax)."""
        if self.fused_attn:
            return F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        return attn @ v

    def forward(self,
                x_query, feat_size_query: List[int],
                x_support, feat_size_support: List[int]):
        """Run cross attention between the query and the support tokens.

        Args:
            x_query: query tokens, shape (B_q, N_q, C).
            feat_size_query: ``(H, W)`` of the query feature map (H * W == N_q).
            x_support: support tokens, shape (B_s, N_s, C).
            feat_size_support: ``(H, W)`` of the support feature map.

        Returns:
            Tuple ``(out_query, out_support)``; each output has the same shape
            as the corresponding input.

        Raises:
            ValueError: if neither input has a batch size of 1.
        """
        # Internally the batch-1 ("single") input always comes first.  Remember
        # whether that meant swapping the arguments so that the outputs can be
        # put back into (query, support) order.  Deriving this from
        # ``x_support.shape[0] == 1`` afterwards would be wrong when *both*
        # inputs have batch size 1.
        if x_query.shape[0] == 1:
            swapped = False
            xs = [x_query, x_support]
            feat_sizes = [feat_size_query, feat_size_support]
        elif x_support.shape[0] == 1:
            swapped = True
            xs = [x_support, x_query]
            feat_sizes = [feat_size_support, feat_size_query]
        else:
            raise ValueError('Either the query or support tensor should have a batch size of 1')

        # Queries: (B, N, C) -> (B, heads, N, head_dim)
        qs = [self.q(x).reshape(*x.shape[:2], self.num_heads, -1).permute(0, 2, 1, 3) for x in xs]

        ks, vs = [], []
        for x, feat_size in zip(xs, feat_sizes):
            k, v = self._project_kv(x, feat_size)
            ks.append(k)
            vs.append(v)

        # Single -> batched: broadcast the single keys/values over the batch.
        k_expanded = ks[0].expand(ks[1].shape[0], -1, -1, -1)
        v_expanded = vs[0].expand(vs[1].shape[0], -1, -1, -1)
        # Batched -> single: collapse the batch by averaging.
        k_mean = ks[1].mean(dim=0, keepdim=True)
        v_mean = vs[1].mean(dim=0, keepdim=True)

        # dim=2 is the token axis, so each input sees its own tokens followed
        # by the tokens coming from the other input.
        k_cats = [torch.cat((ks[0], k_mean), dim=2), torch.cat((ks[1], k_expanded), dim=2)]
        v_cats = [torch.cat((vs[0], v_mean), dim=2), torch.cat((vs[1], v_expanded), dim=2)]

        outputs = []
        for x, q, k_cat, v_cat in zip(xs, qs, k_cats, v_cats):
            B, N, C = x.shape
            out = self._attend(q, k_cat, v_cat)
            # (B, heads, N, head_dim) -> (B, N, C)
            out = out.transpose(1, 2).reshape(B, N, C)
            out = self.proj(out)
            out = self.proj_drop(out)
            outputs.append(out)

        if swapped:
            outputs.reverse()
        return tuple(outputs)


In [108]:
class CrossPyramidVisionTransformerV2(nn.Module):
    """Two-stream wrapper that runs a query and a support image through one model.

    Every ``Attention`` layer found in ``model`` is replaced (in place) by a
    ``CrossAttention`` layer, so that the two streams exchange information in
    each transformer block.  The wrapper then adopts the model's attributes
    (``patch_embed``, ``stages``, ``head``, ``head_drop``, ``global_pool``, ...)
    so they can be accessed directly on ``self``.
    """

    def __init__(self, model):
        super().__init__()

        # Replace Attention layers with CrossAttention layers
        self.replace_attention_layers(model, Attention, CrossAttention)

        # Adopt the wrapped model's instance state, including its
        # ``_parameters`` / ``_buffers`` / ``_modules`` dicts, which therefore
        # stay shared with ``model``.
        for attr_name, attr_value in model.__dict__.items():
            setattr(self, attr_name, attr_value)

        # Keep a handle on the wrapped model.  It has to be stored *after* the
        # copy above (which replaces ``_modules``) and *without* registering it
        # as a sub-module: ``_modules`` is shared with ``model``, so a normal
        # assignment would make the model a child of itself.
        object.__setattr__(self, 'model', model)

    def replace_attention_layers(self, module, target_class, replace_class):
        """Recursively swap every ``target_class`` child for ``replace_class(child)``.

        ``module`` is modified in place.
        """
        # Iterate over a snapshot because children are reassigned in the loop.
        for name, child in list(module.named_children()):
            if isinstance(child, target_class):
                new_layer = replace_class(child)
                setattr(module, name, new_layer)
            else:
                self.replace_attention_layers(child, target_class, replace_class)

    def forward_features(self, query, support):
        """Compute the final feature maps of both streams.

        Args:
            query: query image batch.
            support: support image batch.

        Returns:
            Tuple ``(x_query, x_support)`` of feature maps laid out as
            (B, C, H, W).
        """
        x_query = self.patch_embed(query)
        x_support = self.patch_embed(support)

        for stage in self.stages:
            if stage.downsample is not None:
                x_query = stage.downsample(x_query)
                x_support = stage.downsample(x_support)

            # Here the activations are channel-last: (B, H, W, C).  Flatten the
            # spatial grid into a token axis for the transformer blocks.
            B_q, H_q, W_q, C_q = x_query.shape
            feat_size_query = (H_q, W_q)
            x_query = x_query.reshape(B_q, -1, C_q)

            B_s, H_s, W_s, C_s = x_support.shape
            feat_size_support = (H_s, W_s)
            x_support = x_support.reshape(B_s, -1, C_s)

            for block in stage.blocks:
                if isinstance(block.attn, CrossAttention):
                    attn_query, attn_support = block.attn(
                        block.norm1(x_query), feat_size_query,
                        block.norm1(x_support), feat_size_support)

                    # Residual connections: the attention output is added to
                    # the block *input*; it must not replace it.
                    x_query = x_query + block.drop_path1(attn_query)
                    x_query = x_query + block.drop_path2(block.mlp(block.norm2(x_query), feat_size_query))

                    x_support = x_support + block.drop_path1(attn_support)
                    x_support = x_support + block.drop_path2(block.mlp(block.norm2(x_support), feat_size_support))
                else:
                    x_query = block(x_query, feat_size_query)
                    x_support = block(x_support, feat_size_support)

            x_query = stage.norm(x_query)
            x_support = stage.norm(x_support)

            # Tokens (B, N, C) -> feature map (B, C, H, W) for the next stage.
            x_query = x_query.reshape(B_q, feat_size_query[0], feat_size_query[1], -1).permute(0, 3, 1, 2).contiguous()
            x_support = x_support.reshape(B_s, feat_size_support[0], feat_size_support[1], -1).permute(0, 3, 1, 2).contiguous()
        return x_query, x_support

    def forward(self, query, support):
        """Return the head output for the query stream.

        The support stream only influences the result through the cross
        attention layers; its own features are not passed to the head.
        """
        x_query, x_support = self.forward_features(query, support)
        # With global pooling enabled, average over the spatial axes (H, W).
        x = self.head_drop(x_query.mean(dim=[2, 3])) if self.global_pool else x_query
        x = self.head(x)
        return x

# Load the pre-trained PVTv2 model from TIMM and wrap it.
# NOTE: the wrapper replaces the attention layers of `model` in place, so
# `model` itself is modified by the second statement below.
from timm import create_model

model = create_model('pvt_v2_b0', pretrained=True)
cross_pvtv2 = CrossPyramidVisionTransformerV2(model)


# Example usage: a batch of 10 query images (128x128) and a single support
# image (80x80).  CrossAttention requires one of the two inputs to have a
# batch size of 1; the spatial sizes of the two inputs may differ.
query = torch.randn(10, 3, 128, 128)
support = torch.randn(1, 3, 80, 80)
output = cross_pvtv2(query, support)
# One row of head outputs per query image; the recorded run printed
# torch.Size([10, 1000]).
print(output.shape)


torch.Size([10, 1000])
