In [2]:
from torch._C import device
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys 
from torchsummary import summary
import math

In [18]:
class Conv_com(nn.Module):
    """Three-layer 1-D CNN branch for a single-channel signal.

    Applies conv(1->32, k=3, s=2) -> ReLU -> conv(32->32, k=3, s=2) -> ReLU ->
    conv(32->32, k=3) -> ReLU -> max-pool(2), then flattens everything but the
    batch axis.

    Input:  (batch, 1, length).
    Output: (batch, 32 * pooled_length).
    """

    def __init__(self, ):
        super().__init__()
        # Registration order is kept as-is so parameter init / state_dict match.
        self.conv_1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=3, stride=2)
        self.conv_2 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3, stride=2)
        self.conv_3 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3)
        self.MaxPool = nn.MaxPool1d(2)
        self.Flatten = torch.nn.Flatten()

    def forward(self, x):
        hidden = x
        for conv in (self.conv_1, self.conv_2, self.conv_3):
            hidden = F.relu(conv(hidden))
        return self.Flatten(self.MaxPool(hidden))


class CRNN(nn.Module):
    """Three-branch 1-D CNN + single-step RNN classifier for 3-channel signals.

    Each of the three input channels is passed through its own ``Conv_com``
    branch; the flattened branch features are concatenated and fed to an
    ``nn.RNN`` as a length-1 sequence, followed by a two-layer MLP head that
    produces log-probabilities over 2 classes.

    Args:
        seq_len: length of each input signal. It determines the RNN input size
            (2208 for the default of 200).

    Input:  ``data`` of shape (batch, 3, seq_len).
    Output: log-probabilities of shape (batch, 2).
    """

    def __init__(self, seq_len: int = 200):
        super(CRNN, self).__init__()
        self.conv_x = Conv_com()
        self.conv_y = Conv_com()
        self.conv_z = Conv_com()
        # Derive the concatenated feature size from the branch itself instead of
        # hard-coding it (3 * 736 = 2208 for seq_len=200). Zeros consume no RNG,
        # so parameter initialisation order is unchanged.
        with torch.no_grad():
            branch_features = self.conv_x(torch.zeros(1, 1, seq_len)).shape[-1]
        # batch_first=True: the RNN receives (batch, seq=1, features). With the
        # default (seq, batch, features) layout, the batch axis was treated as
        # time, leaking hidden state across samples in a batch.
        self.rnn = nn.RNN(input_size=3 * branch_features, hidden_size=64, batch_first=True)
        self.fla = nn.Flatten()
        self.fc1 = nn.Linear(64, 100)
        self.fc2 = nn.Linear(100, 2)

    def forward(self, data):
        # Split the three channels, each with a singleton channel dim: (batch, 1, len).
        x = data[:, 0, :].unsqueeze(1)
        y = data[:, 1, :].unsqueeze(1)
        z = data[:, 2, :].unsqueeze(1)
        new_feature = torch.cat([self.conv_x(x), self.conv_y(y), self.conv_z(z)], dim=1)
        # (batch, features) -> (batch, seq=1, features) for the batch_first RNN.
        out, _ = self.rnn(new_feature.unsqueeze(1))
        out = self.fla(out)  # (batch, 1, 64) -> (batch, 64)
        out = F.relu(self.fc1(out))
        out = self.fc2(out)
        return F.log_softmax(out, dim=1)

In [19]:
# Smoke test: a random batch of 128 three-channel signals, 200 samples each,
# should yield one pair of log-probabilities per sample.
model = CRNN()
data_in = torch.randn(128, 3, 200)
model(data_in).shape

torch.Size([128, 2])

In [3]:
# Attention 

class Value(torch.nn.Module):
    """Bias-free linear projection that produces attention values.

    Maps the last dimension of the input from ``dim_input`` to ``dim_val``.
    """

    def __init__(self, dim_input, dim_val):
        super().__init__()
        self.dim_val = dim_val
        self.fc1 = nn.Linear(dim_input, dim_val, bias=False)

    def forward(self, x):
        return self.fc1(x)

class Key(torch.nn.Module):
    """Bias-free linear projection that produces attention keys.

    Maps the last dimension of the input from ``dim_input`` to ``dim_attn``.
    """

    def __init__(self, dim_input, dim_attn):
        super().__init__()
        self.dim_attn = dim_attn
        self.fc1 = nn.Linear(dim_input, dim_attn, bias=False)

    def forward(self, x):
        return self.fc1(x)

class Query(torch.nn.Module):
    """Bias-free linear projection that produces attention queries.

    Maps the last dimension of the input from ``dim_input`` to ``dim_attn``.
    """

    def __init__(self, dim_input, dim_attn):
        super().__init__()
        self.dim_attn = dim_attn
        self.fc1 = nn.Linear(dim_input, dim_attn, bias=False)

    def forward(self, x):
        return self.fc1(x)


def a_norm(Q, K):
    """Scaled dot-product attention weights.

    Computes ``softmax(Q @ K^T / sqrt(d))`` over the last axis, where ``d`` is
    the size of the last dimension of ``Q``.

    Args:
        Q: query tensor of shape (..., seq_q, d).
        K: key tensor of shape (..., seq_k, d), with leading dims broadcastable
            against ``Q``'s.

    Returns:
        Attention weights of shape (..., seq_q, seq_k); each row sums to 1.
    """
    # Transpose only the last two axes so any number of leading batch/head dims
    # works. Cast K to Q's dtype rather than unconditionally to float32, which
    # made float64 (or half) queries fail in matmul.
    scores = torch.matmul(Q, K.transpose(-2, -1).to(Q.dtype))
    scores = scores / math.sqrt(Q.shape[-1])
    return torch.softmax(scores, -1)


def attention(Q, K, V):
    """Scaled dot-product attention: ``a_norm(Q, K) @ V``.

    For 3-D inputs: Q is (batch, seq_q, dim_attn), K is (batch, seq_k, dim_attn),
    V is (batch, seq_k, dim_val); the result is (batch, seq_q, dim_val).
    """
    weights = a_norm(Q, K)  # (batch, seq_q, seq_k)
    return weights @ V

class AttentionBlock(torch.nn.Module):
    """Single attention head with its own value/key/query projections.

    When ``kv`` is None this is self-attention (``x`` supplies Q, K and V).
    Otherwise ``x`` supplies the queries and ``kv`` supplies keys and values.
    """

    def __init__(self, dim_val, dim_attn):
        super().__init__()
        self.value = Value(dim_val, dim_val)
        self.key = Key(dim_val, dim_attn)
        self.query = Query(dim_val, dim_attn)

    def forward(self, x, kv = None):
        source = x if kv is None else kv
        return attention(self.query(x), self.key(source), self.value(source))

class MultiHeadAttentionBlock(torch.nn.Module):
    """``n_heads`` independent ``AttentionBlock`` heads fused by a linear layer.

    Each head returns ``dim_val`` features. The head outputs are stacked on a
    new last axis, flattened together, and projected back to ``dim_val`` by a
    bias-free linear layer.
    """

    def __init__(self, dim_val, dim_attn, n_heads):
        super().__init__()
        self.heads = nn.ModuleList(
            [AttentionBlock(dim_val, dim_attn) for _ in range(n_heads)]
        )
        self.fc = nn.Linear(n_heads * dim_val, dim_val, bias = False)

    def forward(self, x, kv = None):
        head_outputs = [head(x, kv = kv) for head in self.heads]
        # For 3-D head outputs: (batch, seq, dim_val, n_heads) -> (batch, seq, dim_val * n_heads)
        merged = torch.stack(head_outputs, dim = -1).flatten(start_dim = 2)
        return self.fc(merged)

In [13]:
# Test case: trace tensor shapes through Conv1d -> multi-head attention ->
# two more Conv1d layers on one 200-sample, single-channel signal.

data_in = torch.randn(1, 1, 200)  # [batch, input channels, length]

# Layers (created in this fixed order so parameter initialisation is reproducible).
conv_x1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=3, stride=2)
msattn = MultiHeadAttentionBlock(dim_val=32, dim_attn=4, n_heads=8)
conv_x2 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3, stride=2)
conv_x3 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3)
dropout = nn.Dropout(p=0.5)
MaxPool = nn.MaxPool1d(2)
Flatten = torch.nn.Flatten()

out = F.relu(conv_x1(data_in))
print('conv_x1 shape: ', out.shape)

# Attention operates on [batch, len, hidden], so permute into that layout and back.
out = msattn(out.permute(0, 2, 1))
print('msattn shape: ', out.shape)
out = out.permute(0, 2, 1)

out = F.relu(conv_x2(out))
print('conv2 shape: ', out.shape)
out = F.relu(conv_x3(out))
print('conv3 shape: ', out.shape)
out = MaxPool(dropout(out))
print('MaxPool shape: ', out.shape)
out = Flatten(out)
print('Flatten shape: ', out.shape)

conv_x1 shape:  torch.Size([1, 32, 99])
msattn shape:  torch.Size([1, 99, 32])
conv2 shape:  torch.Size([1, 32, 49])
conv3 shape:  torch.Size([1, 32, 47])
MaxPool shape:  torch.Size([1, 32, 23])
Flatten shape:  torch.Size([1, 736])


In [50]:
data_in = torch.randn(1, 3, 200)  # [batch=1, channels=3, length=200]

class Conv3(nn.Module):
    """Single-channel feature extractor: Conv1d -> multi-head self-attention -> 2x Conv1d.

    Input:  (batch, 1, length).
    Output: flattened features of shape (batch, 32 * pooled_length); for
    length 200 that is (batch, 736).

    Args:
        verbose: if True, print intermediate tensor shapes (debugging aid).
        dropout_p: dropout probability applied after the last convolution.
    """

    def __init__(self, verbose=False, dropout_p=0.5):
        super(Conv3, self).__init__()
        self.verbose = verbose

        self.conv_1 = nn.Conv1d(in_channels=1, out_channels=32, kernel_size=3, stride=2)
        self.msattn = MultiHeadAttentionBlock(dim_val=32, dim_attn=4, n_heads=8)
        self.conv_2 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3, stride=2)
        self.conv_3 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=3)
        # forward() applies dropout, so the module must exist. Previously this
        # line was commented out, which raised AttributeError. p=0.5 matches the
        # shape-tracing prototype; pass dropout_p=0.0 to disable it.
        self.dropout = nn.Dropout(p=dropout_p)
        self.MaxPool = nn.MaxPool1d(2)
        self.Flatten = torch.nn.Flatten()

    def _trace(self, label, tensor):
        """Print ``tensor``'s shape under ``label`` when verbose tracing is on."""
        if self.verbose:
            print(label, tensor.shape)

    def forward(self, x):
        out = F.relu(self.conv_1(x))
        self._trace('conv_x1 shape: ', out)

        # Attention expects [batch, len, hidden]; permute in and back out.
        out = self.msattn(out.permute(0, 2, 1))
        self._trace('msattn shape: ', out)
        out = out.permute(0, 2, 1)

        out = F.relu(self.conv_2(out))
        self._trace('conv2 shape: ', out)
        out = F.relu(self.conv_3(out))
        self._trace('conv3 shape: ', out)
        out = self.dropout(out)
        out = self.MaxPool(out)
        self._trace('MaxPool shape: ', out)
        out = self.Flatten(out)
        self._trace('Flatten shape: ', out)

        # Return the extracted features (the original returned the input x).
        return out

class Net(nn.Module):
    """Applies a separate ``Conv3`` branch to each of the three input channels.

    Input:  (batch, 3, length).
    Returns a tuple ``(x_out, y_out, z_out)`` of per-channel flattened features.

    Raises:
        ValueError: if the input is not 3-D with exactly 3 channels.
    """

    def __init__(self, ):
        super(Net, self).__init__()
        self.conv_x = Conv3()
        self.conv_y = Conv3()
        self.conv_z = Conv3()

    def forward(self, x):
        # input shape:  [batch, channel, len]
        if x.dim() != 3 or x.shape[1] != 3:
            raise ValueError(f"expected input of shape [batch, 3, len], got {tuple(x.shape)}")
        # Slice with i:i+1 to keep the channel dim, so each branch sees [batch, 1, len].
        # (The original referenced undefined names y and z here.)
        x_in, y_in, z_in = x[:, 0:1, :], x[:, 1:2, :], x[:, 2:3, :]
        x_out = self.conv_x(x_in)
        y_out = self.conv_y(y_in)
        z_out = self.conv_z(z_in)
        # Return from inside forward(); the original `return` sat at class-body
        # level, which is a SyntaxError.
        return x_out, y_out, z_out