In [1]:
import torch
import torch.nn as nn
import numpy as np
from torchsummary import summary
from torch.nn import Conv2d
from einops.layers.torch import Rearrange, Reduce
from tensorboardX import SummaryWriter

In [2]:
def weight_init(m):
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal(m.weight)

In [3]:
class FeedForward(nn.Module):
    def __init__(self,dim,hidden_dim,dropout=0.):
        super().__init__()
        self.net=nn.Sequential(
            #由此可以看出 FeedForward 的输入和输出维度是一致的
            nn.Linear(dim,hidden_dim),
            #激活函数
            nn.GELU(),
            #防止过拟合
            nn.Dropout(dropout),
            #重复上述过程
            nn.Linear(hidden_dim,dim),
            nn.Dropout(dropout)
        )
    def forward(self,x):
        x=self.net(x)
        return x

In [4]:
class MixerBlock(nn.Module):
    def __init__(self,dim,token_dim,channel_dim,dropout=0.):
        super().__init__()
        self.token_mixer=nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b h w -> b w h'),
            FeedForward(dim,token_dim,dropout),
            Rearrange('b w h -> b h w')
 
         )
        self.channel_mixer=nn.Sequential(
            nn.LayerNorm(dim),
            FeedForward(dim,channel_dim,dropout)
        )
    def forward(self,x):
        
        x = x+self.token_mixer(x)
        
        x = x+self.channel_mixer(x)
        
        return x

In [5]:
class MLPMixer(nn.Module):
    def __init__(self,in_channels,dim,num_classes,image_size,depth,token_dim,channel_dim,dropout=0.):
        super().__init__()
        
        self.to_input_arrange = nn.Sequential(Rearrange('b c h w -> b h (c w)'))
        # w as the channels -> input size (N,48,48)
 
        # 输入为48*48的table
        # 以下为token-mixing MLPs（MLP1）和channel-mixing MLPs（MLP2）各一层
        self.mixer_blocks=nn.ModuleList([])
        for _ in range(depth):
            self.mixer_blocks.append(MixerBlock(dim,token_dim,channel_dim,dropout))
 
        #
        self.layer_normal=nn.LayerNorm(dim)
 
        #
        self.mlp_head=nn.Sequential(
            #nn.Linear(dim,num_classes),
            nn.ReLU()
        )
        
    def forward(self,x):
        print('x.shape:', x.shape)
        x = self.to_input_arrange(x)
        print('input_arrange.shape:', x.shape)
        for mixer_block in self.mixer_blocks:
            x = mixer_block(x)
        x = self.layer_normal(x)    
        print('x_combine.shape:', x.shape)
        
        #x = x.mean(dim=1)
 
        x = self.mlp_head(x)
    
        '''
        for a in x:
            size = a.shape[0]
            size = int(np.sqrt(size))
            for i in range(size):
                x[:, i + i * size] = 0
        '''
                
        return x


In [12]:
a = torch.randn(2,3)
print(a)
w1 = torch.nn.Linear(3,3)
b= w1(a)
print(b)
w2 = nn.LayerNorm(3)
c = w2(b)
print(c)

tensor([[ 1.0646e-04,  9.5762e-01,  2.7629e-01],
        [-7.7432e-01, -3.2898e-01, -1.3674e-01]])
tensor([[ 0.0324, -0.2371, -0.3519],
        [ 0.0359, -0.7703,  0.3345]], grad_fn=<AddmmBackward>)
tensor([[ 1.3528, -0.3201, -1.0327],
        [ 0.3626, -1.3651,  1.0025]], grad_fn=<NativeLayerNormBackward>)


In [6]:
if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = MLPMixer(in_channels=1, dim=48, num_classes=48*48, image_size=48, depth=1, token_dim=48,
                     channel_dim=48).to(device)
    summary(model,(1,48,48))
 
    # torch.Tensor([1, 2, 3, 4, 5, 6])
    inputs = torch.Tensor(1, 1, 48, 48)
    inputs = inputs.to(device)
    print(inputs.shape)
 
    # 将model保存为graph
    with SummaryWriter(log_dir='logs', comment='model') as w:
        w.add_graph(model, (inputs,))
        print("success")

x.shape: torch.Size([2, 1, 48, 48])
input_arrange.shape: torch.Size([2, 48, 48])
x_combine.shape: torch.Size([2, 48, 48])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
         Rearrange-1               [-1, 48, 48]               0
         LayerNorm-2               [-1, 48, 48]              96
         Rearrange-3               [-1, 48, 48]               0
            Linear-4               [-1, 48, 48]           2,352
              GELU-5               [-1, 48, 48]               0
           Dropout-6               [-1, 48, 48]               0
            Linear-7               [-1, 48, 48]           2,352
           Dropout-8               [-1, 48, 48]               0
       FeedForward-9               [-1, 48, 48]               0
        Rearrange-10               [-1, 48, 48]               0
        LayerNorm-11               [-1, 48, 48]              96
           Linear-12               [-1, 48, 4

To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ..\aten\src\ATen\native\BinaryOps.cpp:467.)
  return torch.floor_divide(self, other)
  for a in x:
  size = int(np.sqrt(size))


x.shape: torch.Size([1, 1, 48, 48])
input_arrange.shape: torch.Size([1, 48, 48])
x_combine.shape: torch.Size([1, 48, 48])
success
