In [1]:
# Install MindSpore. The recorded run resolved this to mindspore 1.9.0 (see output below).
# NOTE(review): the requirement is unpinned, so a fresh run may install a different
# major version; consider pinning (e.g. `%pip install -q mindspore==<version>`) so
# Restart & Run All reproduces these results.
!pip install mindspore

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting mindspore
  Downloading mindspore-1.9.0-cp38-cp38-manylinux1_x86_64.whl (158.7 MB)
[K     |████████████████████████████████| 158.7 MB 4.2 kB/s 
Collecting asttokens>=2.0.4
  Downloading asttokens-2.2.1-py2.py3-none-any.whl (26 kB)
Collecting psutil>=5.6.1
  Downloading psutil-5.9.4-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (280 kB)
[K     |████████████████████████████████| 280 kB 65.9 MB/s 
Installing collected packages: psutil, asttokens, mindspore
  Attempting uninstall: psutil
    Found existing installation: psutil 5.4.8
    Uninstalling psutil-5.4.8:
      Successfully uninstalled psutil-5.4.8
Successfully installed asttokens-2.2.1 mindspore-1.9.0 psutil-5.9.4


In [2]:
import mindspore as ms
import mindspore.nn as nn
import numpy as np
from mindspore.ops import operations as P
import mindspore.numpy as mnp

In [3]:
class ImgPatches(nn.Cell):
    """Split an image into non-overlapping patches and embed each patch.

    A Conv2d with kernel_size == stride == patch_size embeds every patch
    independently. The result is returned as a token sequence.

    Args:
        in_ch: number of input image channels.
        embed_dim: embedding size of each patch token.
        patch_size: side length of each square patch.

    Returns (construct):
        Tensor of shape [batch, n_patches, embed_dim].
    """
    def __init__(self, in_ch=3, embed_dim=768, patch_size=16):
        super().__init__()
        self.patch_embed = nn.Conv2d(in_ch, embed_dim, kernel_size=patch_size, stride=patch_size)

    def construct(self, img):
        # Assumes NCHW input (MindSpore Conv2d default), so the output is [b, embed_dim, h', w'].
        patches = self.patch_embed(img)
        batch, channels = patches.shape[0], patches.shape[1]
        # Flatten the spatial grid first, then move channels last. A direct
        # view(b, -1, c) would reinterpret memory and mix spatial positions
        # from one channel into a "token" instead of yielding per-patch embeddings.
        patches = patches.view(batch, channels, -1)        # [b, embed_dim, n_patches]
        return mnp.transpose(patches, (0, 2, 1))           # [b, n_patches, embed_dim]

In [4]:
class Head(nn.Cell):
    def __init__(self, d_h):
        super().__init__()
        self.d_h = d_h
        self.Qu = ms.Parameter(ms.Tensor(np.random.randn(d_h,d_h), ms.float32),name='Q')
        self.Ke = ms.Parameter(ms.Tensor(np.random.randn(d_h,d_h), ms.float32),name='K')
        self.Va = ms.Parameter(ms.Tensor(np.random.randn(d_h,d_h), ms.float32),name='V')
        
    def construct(self,x):
        batch_s,n,emb_d = x.shape[0],x.shape[1],x.shape[2]
        Q = mnp.matmul(x, self.Qu.expand_as(ms.Tensor(np.random.randn(batch_s,self.d_h,self.d_h), ms.float32)))  #[batch_s,n,dh]
        K = mnp.matmul(x, self.Ke.expand_as(ms.Tensor(np.random.randn(batch_s,self.d_h,self.d_h), ms.float32)))  #[batch_s,n,dh]
        V = mnp.matmul(x, self.Va.expand_as(ms.Tensor(np.random.randn(batch_s,self.d_h,self.d_h), ms.float32)))   #[batch_s,n,dh]

        A = mnp.matmul(Q, mnp.transpose(K,[0,2,1]) * 1./(self.d_h**0.5))
        SA= mnp.matmul(A,V)
        return SA

In [5]:
class Attention(nn.Cell):
    def __init__(self, dim, num_heads=8, attn_dropout=0.01, proj_dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        #self.scale = 1./dim**0.5
        #self.qkv = #...
        self.heads = nn.CellList([Head(dim//num_heads) for i in range(num_heads)])
        self.attn_dropout = nn.Dropout(attn_dropout)
        #self.out = #...

    def construct(self, x):
        x = x.view(x.shape[0], x.shape[1], self.num_heads, -1)
        SA = []
        for i in range(self.num_heads):
             Sa = nn.Softmax(axis = 2)(self.heads[i](x[:,:,i,:]))
             SA.append(Sa)                 
        result = mnp.concatenate(SA, axis = 2) #[b, n, dim_embed]
        return result

In [6]:
class MLP(nn.Cell):
    """Feed-forward head.

    With hidden_features: Dense -> GELU -> Dropout -> Dense.
    Without hidden_features: a single Dense layer.

    Args:
        in_features: input feature size.
        hidden_features: hidden layer size, or None for a single Dense layer.
        out_features: output feature size; defaults to in_features when None.
        dropout: dropout *probability* applied after the GELU.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 dropout=0.01):
        super().__init__()
        if out_features is None:
            # Previously nn.Dense(in_features, None) would fail; mirror the input width.
            out_features = in_features

        self.hidden_features = hidden_features
        if hidden_features is not None:
            self.linear1 = nn.Dense(in_features, hidden_features)
            self.linear2 = nn.Dense(hidden_features, out_features)
        else:
            self.linear1 = nn.Dense(in_features, out_features)
        self.GeLU = nn.GELU()
        # MindSpore 1.x nn.Dropout's first argument is keep_prob. Passing the drop
        # rate directly (e.g. 0.01) would discard 99% of activations during training.
        self.drop = nn.Dropout(keep_prob=1.0 - dropout)

    def construct(self, x):
        if self.hidden_features is not None:
            x = self.linear1(x)
            x = self.GeLU(x)
            x = self.drop(x)
            x = self.linear2(x)
        else:
            x = self.linear1(x)
        return x

In [7]:
class Block(nn.Cell):
    """Pre-norm transformer encoder block.

    x = x + Attention(LN1(x))
    x = x + LN2(x) @ MLP

    Args:
        dim: token embedding size.
        n_patches: number of tokens; no longer needed by the per-token LayerNorms,
            kept for interface compatibility.
        num_heads: attention heads.
        mlp_ratio, drop_rate: currently unused (the feed-forward sublayer is a
            single dim x dim projection).
    """
    def __init__(self, dim, n_patches, num_heads=8, mlp_ratio=4, drop_rate=0.):
        super().__init__()
        # Normalize each token over its embedding axis only. The previous
        # [n_patches, dim] shape normalized jointly across all tokens, so every
        # token's statistics depended on every other token.
        self.l_norm1 = nn.LayerNorm((dim,), begin_norm_axis=-1, begin_params_axis=-1)
        self.Attention = Attention(dim=dim, num_heads=num_heads)
        self.l_norm2 = nn.LayerNorm((dim,), begin_norm_axis=-1, begin_params_axis=-1)
        # TODO: a real MLP (dim -> dim*mlp_ratio -> dim) would use mlp_ratio/drop_rate.
        self.MLP = ms.Parameter(ms.Tensor(np.random.randn(dim, dim), ms.float32))

    def construct(self, x):
        x = x + self.Attention(self.l_norm1(x))
        # Residual connection around the feed-forward sublayer (previously missing,
        # which discarded the attention output's skip path at every block).
        x = x + mnp.matmul(self.l_norm2(x), self.MLP)
        return x

In [8]:
class Transformer(nn.Cell):
    """Sequential stack of `depth` identical-configuration Blocks.

    Args:
        depth: number of Blocks.
        dim, n_patches, num_heads, mlp_ratio, drop_rate: forwarded to every Block.
    """
    def __init__(self, depth, dim, n_patches, num_heads=8, mlp_ratio=4, drop_rate=0.01):
        super().__init__()
        layers = []
        for _ in range(depth):
            layers.append(Block(dim, n_patches, num_heads, mlp_ratio, drop_rate))
        self.blocks = nn.CellList(layers)

    def construct(self, x):
        out = x
        for layer in self.blocks:
            out = layer(out)
        return out

In [96]:
class ViT(nn.Cell):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> prepend learnable class token -> add fixed
    sinusoidal positional encoding -> dropout -> `depth` transformer blocks ->
    LateFusionClass -> class-token features -> MLP head -> softmax.

    Args:
        img_size: side length of the square input image.
        patch_size: side length of each square patch.
        in_ch: number of input image channels.
        num_classes: number of output classes.
        embed_dim: token embedding size (must be divisible by num_heads).
        depth: number of transformer blocks.
        num_heads: attention heads per block.
        mlp_ratio: forwarded to Transformer.
        drop_rate: dropout *probability* after the positional encoding;
            also forwarded to Transformer.

    Returns (construct):
        [batch, num_classes] class probabilities (softmax over the last axis).
    """
    def __init__(self, img_size=224, patch_size=16, in_ch=3, num_classes=10,
                 embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
                 drop_rate=0.01):
        super().__init__()
        self.n_patches = (img_size // patch_size) ** 2

        self.learn_tok = ms.Parameter(ms.Tensor(np.random.randn(1, embed_dim), ms.float32))
        # One row per token (class token + patches). Fixed, not trained.
        self.pos_enc = ms.Tensor(self._sinusoidal_pos_encoding(self.n_patches + 1, embed_dim),
                                 ms.float32)

        self.get_patches = ImgPatches(in_ch, embed_dim, patch_size)

        # MindSpore 1.x nn.Dropout's first argument is keep_prob, not the drop rate.
        self.drop = nn.Dropout(keep_prob=1.0 - drop_rate)
        self.transformer = Transformer(depth=depth,
                                       dim=embed_dim,
                                       n_patches=self.n_patches + 1,
                                       num_heads=num_heads,
                                       mlp_ratio=mlp_ratio,
                                       drop_rate=drop_rate)

        self.late_fuse = LateFusionClass(in_chan=self.n_patches + 1, out_chan=self.n_patches + 1)

        # LateFusionClass output is consumed as 8 * embed_dim features per token.
        self.MLP = MLP(in_features=8 * embed_dim,
                       out_features=num_classes)
        # Built once here instead of on every construct call.
        self.softmax = nn.Softmax(axis=-1)

    @staticmethod
    def _sinusoidal_pos_encoding(n_pos, dim):
        """Transformer sinusoidal encoding: sin on even, cos on odd feature indices.

        The previous formula omitted sin/cos, producing unbounded values
        (up to n_pos - 1) instead of an encoding in [-1, 1].
        """
        pos = np.arange(n_pos, dtype=np.float64)[:, None]
        idx = np.arange(dim)[None, :]
        angles = pos / np.power(10000.0, (2 * (idx // 2)) / dim)
        return np.where(idx % 2 == 0, np.sin(angles), np.cos(angles)).astype(np.float32)

    def construct(self, x):
        x = self.get_patches(x)                              # [b, n_patches, emb]
        b, _, emb = x.shape
        # Broadcast the (1, emb) class token over the batch without allocating
        # a throwaway random tensor just to obtain a target shape.
        cls_tok = mnp.broadcast_to(self.learn_tok, (b, 1, emb))
        x = mnp.concatenate([cls_tok, x], axis=1)            # [b, n_patches + 1, emb]
        x = x + self.pos_enc                                 # broadcasts over batch
        x = self.drop(x)

        x = self.transformer(x)
        x = self.late_fuse(x)

        x = x[:, 0, :]                                       # class-token features
        x = self.MLP(x)
        return self.softmax(x)

In [95]:
class LateFusionClass(nn.Cell):
    """Single transposed 1-D convolution (kernel 8, stride 8) over the last axis.

    Channels are the second axis (in_chan -> out_chan). With MindSpore's default
    pad_mode this presumably stretches the last axis 8x; ViT sizes its MLP
    head as 8 * embed_dim based on that — TODO confirm.
    """
    def __init__(self, in_chan, out_chan):
        super().__init__()
        self.res = nn.Conv1dTranspose(in_chan, out_chan, kernel_size=8, stride=8)

    def construct(self, x):
        return self.res(x)
         

In [82]:
class ChanWiseClass(nn.Cell):
    """Classifier that average-pools over the token axis and then applies a Dense layer.

    Args:
        in_chan: size of the input's last axis (the Dense input width).
        num_classes: number of output logits.
    """
    def __init__(self, in_chan=32, num_classes=10):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.cll = nn.Dense(in_chan, num_classes)

    def construct(self, x):
        # Assumes x is [batch, tokens, in_chan]; swap so the pool reduces the token axis.
        channels_first = mnp.transpose(x, (0, 2, 1))
        pooled = self.pool(channels_first)
        flat = pooled.view(pooled.shape[0], -1)
        return self.cll(flat)