In [1]:
from einops import rearrange#, reduce, repeat
import torch
import numpy as np
import torch.nn as nn


In [2]:
class Base_Transformer_Block(nn.Module):
    """Single-input gated (GRU-style) block that maps features to an update vector.

    The input is projected, split into ``n_head`` heads of ``head_dim`` channels,
    passed through update/reset gates and a candidate layer, and the gated
    difference ``u * (g - h)`` is projected back to ``ftr_dim`` features.

    Args:
        n_head: number of heads the feature axis is split into.
        ftr_dim: size of the feature axis; must be a positive multiple of ``n_head``.

    Raises:
        ValueError: if ``n_head`` or ``ftr_dim`` is not positive, or ``ftr_dim``
            is not divisible by ``n_head``.
    """

    def __init__(self, n_head, ftr_dim):
        super().__init__()

        # head_dim is computed with floor division, so an indivisible ftr_dim
        # would build without complaint and only fail later inside rearrange()
        # in forward(). Reject it here with a clear message instead.
        if n_head < 1 or ftr_dim < 1 or ftr_dim % n_head != 0:
            raise ValueError(
                "ftr_dim must be a positive multiple of n_head, got "
                "ftr_dim=%r, n_head=%r" % (ftr_dim, n_head))

        self.n_head = n_head
        self.ftr_dim = ftr_dim
        head_dim = ftr_dim // n_head
        self.head_dim = head_dim

        # The input h is un-normalized and may take any sign.
        # GroupNorm with a single group normalises over all channels of a
        # sample, so on a (B, ftr_dim) input it plays the role of a LayerNorm.
        self.project_h = nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=ftr_dim, affine=True),
            nn.Softplus(),
            nn.Linear(in_features=ftr_dim, out_features=ftr_dim, bias=False),
            nn.GroupNorm(num_groups=1, num_channels=ftr_dim, affine=True),
            nn.Softplus())

        #########################################################################################################

        # Update and reset gates computed together: (B, head_dim, n_head) -> (B, 2*head_dim, n_head).
        # The kernel_size=1 convolution runs along the head axis, so the same
        # weights are applied to every head. The single-group GroupNorm that
        # follows normalises over channels and heads jointly.
        self.update_reset_lyr = nn.Sequential(
            nn.Conv1d(in_channels=head_dim, out_channels=2 * head_dim,
                      kernel_size=1, stride=1, padding=0, groups=1, bias=False),
            nn.GroupNorm(num_groups=1, num_channels=2 * head_dim, affine=True),
            nn.Sigmoid())

        # Candidate activation: (B, head_dim, n_head) -> (B, head_dim, n_head).
        self.candidate_activation_vector = nn.Sequential(
            nn.Conv1d(in_channels=head_dim, out_channels=head_dim,
                      kernel_size=1, stride=1, padding=0, groups=1, bias=False),
            nn.GroupNorm(num_groups=1, num_channels=head_dim, affine=True),
            nn.Softplus())

        ##########################################################################################################

        # Output projection; its result is un-normalized and may take any sign.
        self.project_v = nn.Linear(in_features=ftr_dim, out_features=ftr_dim, bias=True)

        ##########################################################################################################

    def forward(self, h):
        """Compute the gated update for ``h``.

        Args:
            h: tensor of shape (B, ftr_dim); the rearrange pattern below
                requires a 2-D input whose last axis equals head_dim * n_head.

        Returns:
            Tensor of shape (B, ftr_dim).
        """
        # Project h, then split the feature axis into n_head sub-spaces of
        # head_dim channels each.
        h = self.project_h(h)  # [h^1, h^2, ... h^M]
        h = rearrange(h, 'b (d m) -> b d m', d=self.head_dim, m=self.n_head)  # B,head_dim,n_head

        # Sigmoid gates; the first head_dim channels are the update gate,
        # the remaining head_dim channels are the reset gate.
        update_reset = self.update_reset_lyr(h)
        u = update_reset[:, 0:self.head_dim, :]  # B,head_dim,n_head
        r = update_reset[:, self.head_dim:, :]   # B,head_dim,n_head

        # Candidate computed from the reset-gated state.
        g = self.candidate_activation_vector(h * r)  # B,head_dim,n_head

        # Gated difference between the candidate and the (projected) state.
        v = u * (g - h)  # B,head_dim,n_head

        # Concatenate all heads and apply the fully connected output layer.
        v = rearrange(v, 'b d m -> b (d m)')
        v = self.project_v(v)
        # No residual connection here: this block has no second input to add back.
        return v

In [3]:
class Transformer_Block(nn.Module):
    """Two-input gated (GRU-style) block with a residual connection on ``x_inp``.

    Both inputs are projected and split into ``n_head`` heads of ``head_dim``
    channels. Update/reset gates and a candidate are computed from the
    concatenated heads, the gated difference ``u * (g - h)`` is projected back
    to ``ftr_dim`` features, and ``x_inp`` is added to the result.

    Args:
        n_head: number of heads the feature axis is split into.
        ftr_dim: size of the feature axis; must be a positive multiple of ``n_head``.

    Raises:
        ValueError: if ``n_head`` or ``ftr_dim`` is not positive, or ``ftr_dim``
            is not divisible by ``n_head``.
    """

    @staticmethod
    def _make_projection(ftr_dim):
        """Build the norm -> softplus -> linear -> norm -> softplus input projection."""
        # GroupNorm with a single group normalises over all channels of a
        # sample, so on a (B, ftr_dim) input it plays the role of a LayerNorm.
        return nn.Sequential(
            nn.GroupNorm(num_groups=1, num_channels=ftr_dim, affine=True),
            nn.Softplus(),
            nn.Linear(in_features=ftr_dim, out_features=ftr_dim, bias=False),
            nn.GroupNorm(num_groups=1, num_channels=ftr_dim, affine=True),
            nn.Softplus())

    def __init__(self, n_head, ftr_dim):
        super().__init__()

        # head_dim is computed with floor division, so an indivisible ftr_dim
        # would build without complaint and only fail later inside rearrange()
        # in forward(). Reject it here with a clear message instead.
        if n_head < 1 or ftr_dim < 1 or ftr_dim % n_head != 0:
            raise ValueError(
                "ftr_dim must be a positive multiple of n_head, got "
                "ftr_dim=%r, n_head=%r" % (ftr_dim, n_head))

        self.n_head = n_head
        self.ftr_dim = ftr_dim
        head_dim = ftr_dim // n_head
        self.head_dim = head_dim

        # Inputs x and h are un-normalized and may take any sign.
        # Built in this order (x, then h) so parameter registration and
        # random initialisation order stay as they always were.
        self.project_x = self._make_projection(ftr_dim)
        self.project_h = self._make_projection(ftr_dim)

        #########################################################################################################

        # Update and reset gates computed together from [x; h]:
        # (B, 2*head_dim, n_head) -> (B, 2*head_dim, n_head).
        # The kernel_size=1 convolution runs along the head axis, so the same
        # weights are applied to every head. The single-group GroupNorm that
        # follows normalises over channels and heads jointly.
        self.update_reset_lyr = nn.Sequential(
            nn.Conv1d(in_channels=2 * head_dim, out_channels=2 * head_dim,
                      kernel_size=1, stride=1, padding=0, groups=1, bias=False),
            nn.GroupNorm(num_groups=1, num_channels=2 * head_dim, affine=True),
            nn.Sigmoid())

        # Candidate activation from [x; h*r]:
        # (B, 2*head_dim, n_head) -> (B, head_dim, n_head).
        self.candidate_activation_vector = nn.Sequential(
            nn.Conv1d(in_channels=2 * head_dim, out_channels=head_dim,
                      kernel_size=1, stride=1, padding=0, groups=1, bias=False),
            nn.GroupNorm(num_groups=1, num_channels=head_dim, affine=True),
            nn.Softplus())

        ##########################################################################################################

        # Output projection; its result is un-normalized and may take any sign.
        self.project_v = nn.Linear(in_features=ftr_dim, out_features=ftr_dim, bias=True)

        ##########################################################################################################

    def forward(self, h, x_inp):
        """Compute the gated update from ``h`` and ``x_inp`` and add ``x_inp`` back.

        Args:
            h: tensor of shape (B, ftr_dim); the rearrange pattern below
                requires a 2-D input whose last axis equals head_dim * n_head.
            x_inp: tensor of shape (B, ftr_dim); also used as the residual term.

        Returns:
            Tensor of shape (B, ftr_dim).
        """
        # Project x and h, then split each feature axis into n_head sub-spaces
        # of head_dim channels.
        x = self.project_x(x_inp)  # [x^1, x^2, ... x^M]
        h = self.project_h(h)      # [h^1, h^2, ... h^M]

        x = rearrange(x, 'b (d m) -> b d m', d=self.head_dim, m=self.n_head)  # B,head_dim,n_head
        h = rearrange(h, 'b (d m) -> b d m', d=self.head_dim, m=self.n_head)  # B,head_dim,n_head

        # Sigmoid gates; the first head_dim channels are the update gate,
        # the remaining head_dim channels are the reset gate.
        gate_inp = torch.cat((x, h), dim=1)  # B,2*head_dim,n_head
        update_reset = self.update_reset_lyr(gate_inp)
        u = update_reset[:, 0:self.head_dim, :]  # B,head_dim,n_head
        r = update_reset[:, self.head_dim:, :]   # B,head_dim,n_head

        # Candidate computed from x and the reset-gated state.
        cand_inp = torch.cat((x, h * r), dim=1)  # B,2*head_dim,n_head
        g = self.candidate_activation_vector(cand_inp)  # B,head_dim,n_head

        # Gated difference between the candidate and the (projected) state.
        v = u * (g - h)  # B,head_dim,n_head

        # Concatenate all heads and apply the fully connected output layer.
        v = rearrange(v, 'b d m -> b (d m)')
        v = self.project_v(v)
        # Residual connection on the raw (un-projected) x_inp.
        v = v + x_inp
        return v

In [4]:
class ODE_network(nn.Module):
    """Stack of one Base_Transformer_Block followed by ``depth - 1`` Transformer_Blocks.

    Args:
        n_head: number of heads passed to every block.
        depth: total number of blocks; ``depth - 1`` Transformer_Blocks are
            created after the base block.
        ftr_dim: feature size passed to every block.
    """

    def __init__(self, n_head=12, depth=3, ftr_dim=768):
        super().__init__()

        # First stage: works on the hidden state alone.
        self.base_ode_gru_lyr = Base_Transformer_Block(n_head, ftr_dim)

        # Remaining stages: each one receives the hidden state together with
        # the output of the stage before it.
        refinement_blocks = [Transformer_Block(n_head, ftr_dim) for _ in range(0, depth - 1)]
        self.deep_ode_lyr = nn.ModuleList(refinement_blocks)

        self.depth = depth

    def forward(self, t, h):
        """Return the update computed from hidden state ``h``.

        ``t`` is accepted but not used by the computation (presumably kept to
        match an ODE-solver callback signature -- confirm against the caller).
        """
        # h is the hidden state, i.e. the feature embedding F.
        out = self.base_ode_gru_lyr(h)

        # Every later stage sees the same h plus the previous stage's output.
        for stage_idx in range(0, self.depth - 1):
            stage = self.deep_ode_lyr[stage_idx]
            out = stage(h, out)

        return out

# Smoke test: build the default network and push one dummy batch through it.
ode_model=ODE_network(n_head=12, depth=3, ftr_dim=768)
print(ode_model)

# Dummy batch of 12 all-ones feature vectors and a single time value.
# Created directly in torch (float32, the dtype torch.FloatTensor produced)
# instead of round-tripping through NumPy and the legacy tensor constructor.
ftr=torch.ones((12, 768), dtype=torch.float32)
t=torch.tensor([0.1], dtype=torch.float32)

print(ftr.shape)

delta_h=ode_model(t, ftr)

# The network is expected to return one update value per input feature.
assert delta_h.shape == ftr.shape, "ODE_network output shape %s != input shape %s" % (
    tuple(delta_h.shape), tuple(ftr.shape))

print(delta_h)
print(delta_h.shape)