# **Load Pretrained Weights**

In [1]:
using Pkg
using PyCall
using Knet
using Knet: atype
@pyimport torch
@pyimport torch.nn as nn

In [None]:
# model_path = "/home/mertcokelek/Downloads/80.7_T2T_ViT_t_14.pth.tar"
# weights = torch.load(model_path)
# weight_sample = Param(atype(weights["state_dict_ema"]["blocks.3.mlp.fc1.weight"][:cpu]()[:numpy]()))

# **Implementation**

## **gelu.jl**

In [2]:
using AutoGrad, Knet
using AutoGrad: @gcheck

const GConstant01 = sqrt(2/pi)
const GConstant02 = 0.044715 * sqrt(2/pi)
const GConstant03 = GConstant01 / 2

# Main definition, broadcasted version works on Arrays

gelu(x::T) where T = (x/2)*(1 + tanh(T(GConstant02)*x^3 + T(GConstant01)*x))
gelu(x::Int) = gelu(convert(Float32, x))
geluback(x::T,dy::T) where T = dy*(T(0.5)*tanh(T(GConstant02)*x^3 + T(GConstant01)*x) + (T(0.0535161)*x^3 + T(GConstant03)*x)*(1/cosh(T(GConstant02)*x^3 + T(GConstant01)*x))^2 + T(0.5))


# This defines gelu for AutoGrad

@primitive  gelu(x),dy  geluback.(x,dy)

import Base.Broadcast: broadcasted
import Knet: KnetArray

function KnetArray(x::CuArray{T,N}) where {T,N}
    p = Base.bitcast(Knet.Cptr, x.ptr)
    k = Knet.KnetPtr(p, sizeof(x), gpu(), x) 
    KnetArray{T,N}(k, size(x))
end

broadcasted(::typeof(gelu),x::KnetArray) = KnetArray(gelu.(CuArray(x)))
broadcasted(::typeof(geluback),x::KnetArray,dy::KnetArray) = KnetArray(geluback.(CuArray(x),CuArray(dy)));


## **transformer_block.jl**
**get_sinusoid_encoding, Mlp, Block**

### get_sinusoid_encoding

In [3]:
using Knet: atype
struct get_sinusoid_encoding 
    w 
end

function get_sinusoid_encoding(n_position, d_hid; λ=10000, atype=atype())
    x = exp.((0:2:d_hid-1) .* -(log(λ)/d_hid)) * (0:n_position-1)'
    pe = zeros(d_hid, n_position)
    pe[1:2:end,:] = sin.(x)
    pe[2:2:end,:] = cos.(x)
    get_sinusoid_encoding(atype(pe))
end

function (l::get_sinusoid_encoding)(x)
    x .+ l.w[:,1:size(x,2)]
end

### Mlp

In [4]:
using Knet.Layers21: dropout
mutable struct Mlp
    in_features; hidden_features; out_features; act_layer; drop;
    
    fc1; act; fc2;
    function Mlp(in_features; hidden_features=nothing, out_features=nothing, act_layer::Function=gelu, drop=0.)
        self = new(in_features, hidden_features, out_features, act_layer, drop)
        
        out_features = out_features == nothing ? in_features : out_features
        hidden_features = hidden_features == nothing ? in_features : hidden_features
        
        self.fc1 = Linear(in_features, hidden_features)
        self.act = act_layer#() ????
        self.fc2 = Linear(hidden_features, out_features)
        self.drop = drop # called in forward pass
        return self
    end
    
    function (self::Mlp)(x)
        x = dropout(self.act(self.fc1(x)), drop)
        x = dropout(self.fc2(x), drop)
        return x
    end
end

### Attention

In [5]:
mutable struct Attention
    dim; num_heads; qkv_bias; qk_scale; attn_drop; proj_drop;
    
    scale; qkv; proj;
    function Attention(;dim=784, num_heads=8, qkv_bias=false, qk_scale=false, attn_drop=0., proj_drop=0.)
        self = new(dim, num_heads, qkv_bias, qk_scale, attn_drop, proj_drop)
        self.num_heads = num_heads
        head_dim = dim ÷ num_heads
        
        self.scale = qk_scale || head_dim^-0.5
        
        self.qkv = Linear(dim, dim*3, bias=qkv_bias)
        # self.attn_drop = dropout(attn_drop) # dropouts applied at forward pass
        self.proj = Linear(dim, dim)
        # self.proj_drop = dropout(proj_drop)
        return self
    end
end


function transposedims(x, dim1, dim2)
    """
    Helper function equivalent to torch.Tensor.transpose(dim1, dim2)
    
    """
    size_ = [i for i in size(x)]
    dims = length(size_)
    dim1_ = dim1 > 0 ? dim1 : dims+dim1+1
    dim2_ = dim2 > 0 ? dim2 : dims+dim2+1
    
    size_[dim1_] = dim2_
    size_[dim2_] = dim1_
    
    return permutedims(x, size_)
end

function(self::Attention)(x)
    B, N, C = size(x)
    qkv = permutedims(reshape(self.qkv(x), (B, N, 3, self.num_heads, C ÷ self.num_heads)), (3,1,4,2,5))
    q, k, v = qkv[1], qkv[1], qkv[2]
    
    attn = (q * transposedims(k, -2, -1)) .* self.scale
    softmax_dim = length(size(attn))
    attn = softmax(attn, dims=softmax_dim)
    attn = dropout(attn, self.attn_drop)
    
    x = reshape(transposedims(attn * v, 1, 2), (B, N, C))
    x = self.proj(x)
    x = dropout(x, self.proj_drop)
    return x
end 
    

### Block

In [6]:
mutable struct Block
    dim; num_heads; mlp_ratio; qkv_bias; qk_scale; drop; attn_drop;
    drop_path; act_layer; norm_layer;
    
    norm1;
    attn;
    norm2;
    mlp;
    
    function Block(;dim=784, num_heads=8, mlp_ratio=4.0, qkv_bias=false, qk_scale=nothing, drop=0., attn_drop=0.,
        drop_path=0., act_layer::Function=gelu, norm_layer="LayerNorm")
    
        self = new(dim, num_heads, mlp_ratio, qkv_bias, qk_scale, drop, attn_drop, drop_path, act_layer, norm_layer)
        
        @assert norm_layer == "LayerNorm"
        self.norm1 = LayerNorm(dim)
        
        # TODO: Attention
        self.attn =  Attention(dim=dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
        
        # TODO: DropPath
        # self.drop_path = drop_path > 0.? DropPath(drop_path) : nn.Identity()  # ????
                
        self.norm2 = LayerNorm(dim)
        mlp_hidden_dim = convert(Int, dim * mlp_ratio)
        self.mlp = Mlp(dim; hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        return self
    end
end

function (self::Block)(x)
    # x = x + self.drop_path(self.attn(self.norm1(x)))
    x = x + self.drop_path(self.mlp(self.norm2(x)))
    return x
end 


## **utils.jl**
**Linear, 
LayerNorm**

In [7]:
struct LayerNorm; a; b; ϵ; end

function LayerNorm(dmodel; eps=1e-6)
    a = param(dmodel; init=ones)
    b = param(dmodel; init=zeros)
    LayerNorm(a, b, eps)
end

function (l::LayerNorm)(x, o...)
    μ = mean(x,dims=1)
    σ = std(x,mean=μ,dims=1)
    ϵ = eltype(x)(l.ϵ)
    l.a .* (x .- μ) ./ (σ .+ ϵ) .+ l.b # TODO: doing x .- μ twice?
end


struct Linear; w; b; end

function Linear(input::Int,outputs...; bias=true)
    Linear(param(outputs...,input), bias ? param0(outputs...) : nothing)
end

function (l::Linear)(x)
    W1,W2,X1,X2 = size(l.w)[1:end-1], size(l.w)[end], size(x,1), size(x)[2:end]; 
    @assert W2===X1
    y = reshape(l.w,:,W2) * reshape(x,X1,:)
    y = reshape(y, W1..., X2...)
    if l.b !== nothing; y = y .+ l.b; end
    return y
end


## **t2t_vit.jl**
**T2T_module, T2T_ViT**

### T2T_module

In [8]:
mutable struct T2T_module
    """
    Tokens-to-Token encoding module
    """
    img_size; tokens_type;
    in_chans
    embed_dim; token_dim;
    
    
    soft_split0; soft_split1; soft_split2
    attention1; attention2
    project
    num_patches
    
    function T2T_module(img_size=224, 
                        tokens_type="performer", 
                        in_chans=3, 
                        embed_dim=768, 
                        token_dim=64)
        self = new(img_size, tokens_type, in_chans, embed_dim, token_dim)
        
        if tokens_type == "performer"
            println("adopt performer encoder for tokens-to-token")
            self.soft_split0 = nn.Unfold(kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
            self.soft_split1 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            self.soft_split2 = nn.Unfold(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
            
            # Token_performer TODO
#             self.attention1 = Token_performer(dim=in_chans*7*7, in_dim=token_dim, kernel_ratio=0.5)
#             self.attention2 = Token_performer(dim=token_dim*3*3, in_dim=token_dim, kernel_ratio=0.5)
            self.project = Linear(token_dim * 3 * 3, embed_dim)
        end
        
        self.num_patches = (img_size ÷ (4 * 2 * 2)) * (img_size ÷ (4 * 2 * 2))  # there are 3 soft split, stride are 4,2,2 seperately

        return self
        """
        TODO / NOTE: initially, let's just consider 'performer' model. 'transformer' & 'convolution' can be done later.
        TODO: Unfold currently taken from PyTorch. Might need conversion between datatypes.
        """
    end
    
end

function (self::T2T_module)(x)
    # step0: soft split
    x = permutedims(self.soft_split0(x), [2, 3])
    
    # x [B, 56*56, 147=7*7*3]
    # iteration1: restructurization/reconstruction
    
    x_1_4 = self.attention1(x)
    B, new_HW, C = size(x_1_4)
    x = reshape(permutedims(x_1_4, [2, 3]), (B, C, convert(Int, sqrt(new_HW)), convert(Int, sqrt(new_HW))))
    # iteration1: soft_split
    x = permutedims(self.soft_split1(x), [2, 3])

    # iteration2: restructurization/reconstruction
    x_1_8 = self.attention2(x)
    B, new_HW, C = size(x_1_4)
    x = reshape(permutedims(x_1_8, [2, 3]), (B, C, convert(Int, sqrt(new_HW)), convert(Int, sqrt(new_HW))))
    # iteration1: soft_split
    x = permutedims(self.soft_split2(x), [2, 3])

    # final_tokens
    x = self.project(x)
    
    return x, x_1_8, x_1_4
end
    

### T2T_ViT

In [9]:
mutable struct T2T_ViT; 
    img_size; tokens_type
    in_chans; num_classes; num_features
    embed_dim; depth
    num_heads; mlp_ratio
    qkv_bias; qk_scale
    drop_rate; attn_drop_rate; drop_path_rate; 
    norm_layer;

    blocks
    norm
    head
    tokens_to_token
    cls_token
    pos_embed

    function T2T_ViT(img_size=224, 
            tokens_type="performer", 
            in_chans=3, num_classes=1000,
            embed_dim=784,depth=12,
            num_heads=12, mlp_ratio=4., 
            qkv_bias=false, qk_scale=false, 
            drop_rate=0., attn_drop_rate=0., drop_path_rate=0., 
            norm_layer="LayerNorm"
            )
            
        self = new(img_size, tokens_type, 
                    in_chans, num_classes, 
                    embed_dim,depth,num_heads,
                    mlp_ratio,qkv_bias,qk_scale,
                    drop_rate,attn_drop_rate,
                    drop_path_rate,norm_layer)
        self.num_features = embed_dim

        self.tokens_to_token = T2T_module()
        self.cls_token = convert(KnetArray, zeros(1,1,768))
        self.pos_embed = reshape(get_sinusoid_encoding(16, 784).w, (1, 16, 784))  #
        

        # NOTE: In PyTorch implementation, dropout is added as a portable layer.
        # We will add it as an operation, in the forward pass. 
        # UPDATE: probably we won't need it, since we'll use pretrained weights. 
        
        self.drop_rate = drop_rate
        dpr = [i for i in 0:0.25/11:0.25] # Stochastic depth decay rule

        self.blocks = [Block(dim=embed_dim, 
                            num_heads=num_heads, 
                            mlp_ratio=mlp_ratio, 
                            qkv_bias=qkv_bias, 
                            qk_scale=qk_scale,
                            drop=drop_rate, 
                            attn_drop=attn_drop_rate,
                            drop_path=dpr[i], 
                            act_layer=gelu,
                            norm_layer=norm_layer) for i in 1:depth]
        @assert norm_layer == "LayerNorm"                   
        self.norm = LayerNorm(embed_dim)

        # Classifier head
        self.head = Linear(embed_dim, num_classes)
        
#         For initializing weights, probably we won't need the rest.
#         We will use pretrained weights.
    
    end
end


function expand(x, B)
    """
    Helper function for torch.Tensor.expand()
    expand(x, B) is equivalent to x.expand(B, -1, -1)
    """
    res = x
    dimx = length(size(x))
    for i in 1:B-1
        res = cat(res, x, dims=dimx+1)
    end
    permute_order = [i for i in 1:dimx]
    # @show permute_order
    insert!(permute_order, 1, dimx+1)
    # @show permute_order, dimx, size(x)
    # @show size(res)
    res = permutedims(res, permute_order)
end

function (self::T2T_ViT)(x)
    B = size(x)[1]
    x, x_1_8, x_1_4 = self.tokens_to_token(x)
    
    cls_tokens = expand(self.cls_token, B) # self.cls_token.expand(B, -1, -1) in Torch
    x = hcat(cls_tokens, x)
    x += self.pos_embed
    # x = self.pos_drop # Dropout will not be used for now.
    
    # T2T-ViT backbone
    for blk in self.blocks
        x = blk(x)
    end
    
    x = self.norm(x)
    return x[:,2:end,:], x_1_8, x_1_4
end    


dummyT = T2T_ViT()

adopt performer encoder for tokens-to-token


Linear(P(Knet.KnetArrays.KnetMatrix{Float32}(1000,784)), P(Knet.KnetArrays.KnetVector{Float32}(1000)))