In [8]:
import torch
import torch.nn as nn
from functools import partial

from timm.models.vision_transformer import VisionTransformer, _cfg, Mlp, HybridEmbed, PatchEmbed
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_, DropPath

from curves import compute_curve_order, coords_to_index, index_to_coords_indexes
import numpy as np

class Violin_Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., 
                 pos_emb = True, cls_tok = True, curve_list = ['s', 'sr', 'h', 'hr', 'm', 'mr','z', 'zr'],num_patches=196, qk_norm=False,mask='learned',scale=False,method='mul_v1',initialize=False,mask_sum='weighted'):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.cls_tok = cls_tok
        self.curve_list = curve_list
        self.qk_norm = qk_norm
        self.mask = mask
        self.scale = scale
        self.method = method

        self.curve_indices_inv = []
        self.ai_list = []

        N = num_patches 
        order = torch.range(0,N-1)
        S = int(np.sqrt(N))
        grid = order.view(S,S).clone()

        # for curve in curve_list:
        #     if curve not in ['s', 'sr', 'h', 'hr', 'm', 'mr', 'z', 'zr']:
        #         raise ValueError("Invalid value for curve. Allowed values are: 's', 'sr', 'h', 'hr', 'm', 'mr', 'z', 'zr'.")
            
        #     curve_coords = compute_curve_order(grid, curve)
        #     self.curve_indices_inv.append(torch.tensor(index_to_coords_indexes(curve_coords, S,S)  , dtype=torch.long ))  
        #     if mask == 'fixed':
        #         self.ai_list.append(torch.ones(num_heads) * 0.996)
        #     else:
        self.ai_list = nn.ParameterList([nn.Parameter(torch.randn(num_heads)) for _ in range(len(curve_list))])
        # self.ai_list.append(nn.Parameter(torch.randn(num_heads)))

        if qk_norm:
            self.q_norm = nn.LayerNorm(head_dim)
            self.k_norm = nn.LayerNorm(head_dim)

        if mask == 'weighted':
            self.mask_weights = nn.Parameter(torch.randn(len(self.ai_list)))
        else:
            self.mask_weights =torch.ones(len(self.ai_list)) / len(self.ai_list)

        if scale:
            self.normalize = nn.Parameter(torch.randn(num_heads))
        else:
            self.normalize = torch.ones(num_heads)

    def forward(self, x):
        return x

In [9]:
a = Violin_Attention(64)

  order = torch.range(0,N-1)


In [10]:
for param in a.parameters():
    print(type(param.data), param.size())

<class 'torch.Tensor'> torch.Size([192, 64])
<class 'torch.Tensor'> torch.Size([64, 64])
<class 'torch.Tensor'> torch.Size([64])
<class 'torch.Tensor'> torch.Size([8])
<class 'torch.Tensor'> torch.Size([8])
<class 'torch.Tensor'> torch.Size([8])
<class 'torch.Tensor'> torch.Size([8])
<class 'torch.Tensor'> torch.Size([8])
<class 'torch.Tensor'> torch.Size([8])
<class 'torch.Tensor'> torch.Size([8])
<class 'torch.Tensor'> torch.Size([8])


In [18]:
torch.sigmoid(torch.tensor(5))

tensor(0.9933)