In [1]:
# sys is used only to extend the module search path on the next line.
import sys
# NOTE(review): hard-coded, machine-specific absolute Windows path appended to
# the import search path. Presumably this is where project-local modules live —
# confirm, and consider a configurable/relative path so the notebook runs elsewhere.
sys.path.append("D:\\WETrak")

import torch.nn as nn

In [2]:
class CNN_classify(nn.Module):
    """Convolutional encoder/decoder followed by a fully connected classification head.

    Three conv + max-pool stages shrink the input, three conv + upsample stages
    expand it again, and a linear layer maps the flattened map to per-class outputs.

    The fixed sizes in this class (the ``1000`` in the linear layer's
    ``in_features`` and the pooling/upsampling factors) line up for inputs of
    shape ``[batch, 1, 1000, emg_channel]``.

    Args:
        emg_channel: Width of the input map. Only 2 and 5 have pooling settings.
        p: Dropout probability used in every conv stage.
        channel: Number of feature maps produced by the first conv stage.
        classify_num: Number of classes; must be at least 1 to call ``forward``.
        finger_type: Sequence whose length sets the final upsampling width and
            the output sizes. Only its length is used.

    Raises:
        ValueError: If ``emg_channel`` is neither 2 nor 5.

    Note:
        The attribute names ``up_blcok1..3`` are misspelled but kept as they are
        so that existing ``state_dict`` keys continue to load.
    """

    def __init__(self,emg_channel=5,p=0.1,channel=32,classify_num=0, finger_type=()):
        super(CNN_classify,self).__init__()

        finger_num = len(finger_type)

        self.classify_num = classify_num
        self.emg_channel = emg_channel
        # Outputs per class: 8 per finger (24 for three fingers).
        self.per_class_size = 8*finger_num
        self.output_size = self.per_class_size*classify_num # C*24 for three fingers

        # Pooling windows are chosen so the width collapses to 1 after the
        # second stage for both supported input widths.
        if self.emg_channel == 2:
            max_pool_paras = ((5,2),(4,1),(2,1))
        elif self.emg_channel == 5:
            max_pool_paras = ((5,2),(4,2),(2,1))
        else:
            # Previously this fell through and failed later with an
            # UnboundLocalError on max_pool_paras.
            raise ValueError(
                "unsupported emg_channel {!r}; expected 2 or 5".format(emg_channel)
            )

        # downsample
        # kernel_size: 3*3
        self.down_block1 = self._conv_stage(1, channel, p, nn.MaxPool2d(max_pool_paras[0], stride=max_pool_paras[0]))
        self.down_block2 = self._conv_stage(channel, channel*4, p, nn.MaxPool2d(max_pool_paras[1], stride=max_pool_paras[1]))
        self.down_block3 = self._conv_stage(channel*4, channel*8, p, nn.MaxPool2d(max_pool_paras[2], stride=max_pool_paras[2]))

        # upsample
        # kernel_size: 3*3
        self.up_blcok1 = self._conv_stage(channel*8, channel*4, p, nn.Upsample(scale_factor=(5,2)))
        self.up_blcok2 = self._conv_stage(channel*4, channel, p, nn.Upsample(scale_factor=(4,2)))
        self.up_blcok3 = self._conv_stage(channel, 1, p, nn.Upsample(scale_factor=(2,finger_num)))

        self.fc_block = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=1000*4*finger_num, out_features=self.output_size),
            nn.ReLU(True),
        )

        # initialize weights
        self._initialize_weights()

    @staticmethod
    def _conv_stage(in_channels, out_channels, p, resample):
        """Build Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> Dropout -> ``resample``."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True),
            nn.Dropout(p=p),
            resample,
        )

    def forward(self,x):
        """Run the network.

        Shape comments below assume ``x`` is ``[batch, 1, 1000, emg_channel]``.

        Returns:
            tuple: ``(feature, out)`` where ``feature`` is the upsampled map
            reshaped to ``[batch, 8*len(finger_type), 500]`` and ``out`` is
            ``[batch, classify_num, 8*len(finger_type)]``.

        Raises:
            ValueError: If the model was built with ``classify_num < 1``.
        """
        if self.classify_num < 1:
            # The original code hit a ZeroDivisionError here for classify_num=0.
            raise ValueError("classify_num must be at least 1 to run forward()")

        # downsample
        x = self.down_block1(x) # [batch, channel, 200, 1] (width 2 when emg_channel == 5)
        x = self.down_block2(x) # [batch, 4*channel, 50, 1]
        x = self.down_block3(x) # [batch, 8*channel, 25, 1]

        # upsample
        x = self.up_blcok1(x) # [batch, 4*channel, 125, 2]
        x = self.up_blcok2(x) # [batch, channel, 500, 4]
        x = self.up_blcok3(x) # [batch, 1, 1000, 4*len(finger_type)]

        # Equivalent to the former hard-coded view(batch, 24, 500) for three
        # fingers, but also valid for other finger counts.
        feature = x.view(x.shape[0], self.per_class_size, -1)

        x = self.fc_block(x)
        out = x.view(-1, self.classify_num, self.per_class_size)

        return feature,out

    def _initialize_weights(self):
        """Re-initialise conv, batch-norm and linear parameters in place."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)