In [1]:
import numpy as np
from collections import Counter
from tqdm import tqdm
from matplotlib import pyplot as plt
from sklearn.metrics import classification_report 

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [2]:


class MyDataset(Dataset):
    """Map-style dataset that pairs each sample with its label.

    Every access converts the sample to a ``torch.float`` tensor and the
    label to a ``torch.long`` tensor; nothing is converted up front.
    """

    def __init__(self, data, label):
        # Kept exactly as passed in; both must support integer indexing,
        # and ``data`` must also support ``len()``.
        self.data = data
        self.label = label

    def __getitem__(self, index):
        features = torch.tensor(self.data[index], dtype=torch.float)
        target = torch.tensor(self.label[index], dtype=torch.long)
        return features, target

    def __len__(self):
        # The dataset length is taken from ``data`` only.
        return len(self.data)


In [8]:

class ACNN(nn.Module):
    """Segment-wise 1-D CNN with additive attention pooling over segments.

    The signal is cut into non-overlapping segments of ``n_len_seg`` steps.
    One shared ``Conv1d`` runs over every segment, its output is averaged
    over time to give an ``out_channels`` feature vector per segment, and the
    per-segment vectors are combined with learned attention weights before a
    linear layer and a softmax.

    Input:
        x: (n_samples, n_channel, n_length), where n_channel == in_channels
           and n_length is a multiple of n_len_seg (checked by an assert).

    Output:
        out: (n_samples, n_classes) class probabilities; each row sums to 1.
        NOTE(review): the softmax is already applied here, so do not feed this
        output to ``nn.CrossEntropyLoss`` (it expects raw logits).

    Parameters:
        in_channels: number of channels of the input signal.
        out_channels: number of convolution filters (feature size per segment).
        att_channels: hidden size of the attention scoring projection.
        n_len_seg: segment length; must be >= kernel_size for the convolution.
        n_classes: number of classes.
        device: stored as ``self.device``; nothing in this class reads it, so
            the module must still be moved with ``.to(...)`` by the caller.
        verbose: if True, ``forward`` prints every intermediate tensor shape.
        kernel_size: Conv1d kernel size (default 16).
        stride: Conv1d stride (default 4).
    """

    def __init__(self, in_channels, out_channels, att_channels, n_len_seg, n_classes, device, verbose=False,
                 kernel_size=16, stride=4):
        super().__init__()

        self.n_len_seg = n_len_seg
        self.n_classes = n_classes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.att_channels = att_channels

        self.device = device
        self.verbose = verbose

        # applied to (batch * n_seg, in_channels, n_len_seg)
        self.cnn = nn.Conv1d(in_channels=self.in_channels,
                             out_channels=self.out_channels,
                             kernel_size=kernel_size,
                             stride=stride)

        # additive attention: score_s = v^T tanh(W^T h_s) for segment feature h_s
        self.W_att_channel = nn.Parameter(torch.randn(self.out_channels, self.att_channels))
        self.v_att_channel = nn.Parameter(torch.randn(self.att_channels, 1))

        self.dense = nn.Linear(out_channels, n_classes)
        self.softmax = nn.Softmax(dim=1)

    def _log_shape(self, t):
        # Debug aid: print the tensor's shape when verbose=True.
        if self.verbose:
            print(t.shape)

    def forward(self, x):
        """Return class probabilities of shape (n_samples, n_classes) for x."""
        # Stored on the module so the last input geometry can be inspected
        # after a call.
        self.n_channel, self.n_length = x.shape[-2], x.shape[-1]
        assert (self.n_length % self.n_len_seg == 0), "Input n_length should divided by n_len_seg"
        self.n_seg = self.n_length // self.n_len_seg

        out = x
        self._log_shape(out)

        # (n_samples, n_channel, n_length) -> (n_samples, n_length, n_channel)
        out = out.permute(0, 2, 1)
        self._log_shape(out)
        # (n_samples, n_length, n_channel) -> (n_samples*n_seg, n_len_seg, n_channel)
        # reshape, not view: after the permute the tensor is non-contiguous
        # when n_channel > 1, and view() raises a RuntimeError in that case.
        out = out.reshape(-1, self.n_len_seg, self.n_channel)
        self._log_shape(out)
        # (n_samples*n_seg, n_len_seg, n_channel) -> (n_samples*n_seg, n_channel, n_len_seg)
        out = out.permute(0, 2, 1)
        self._log_shape(out)
        # shared convolution over every segment
        out = self.cnn(out)
        self._log_shape(out)
        # global average over time, (n_samples*n_seg, out_channels)
        out = out.mean(-1)
        self._log_shape(out)
        # (n_samples, n_seg, out_channels)
        out = out.view(-1, self.n_seg, self.out_channels)
        self._log_shape(out)

        # attention scores, (n_samples, n_seg, 1)
        scores = torch.matmul(torch.tanh(torch.matmul(out, self.W_att_channel)), self.v_att_channel)
        # normalise over the segment axis; torch.softmax subtracts the max
        # internally, so large scores cannot overflow exp()
        gama = torch.softmax(scores, dim=1)
        # attention-weighted sum of segment features, (n_samples, out_channels)
        out = torch.sum(gama * out, dim=1)
        self._log_shape(out)

        # (n_samples, n_classes)
        out = self.dense(out)
        self._log_shape(out)
        out = self.softmax(out)
        return out

In [9]:
# Seed so the random weight init (and the demo batch drawn next) is reproducible.
torch.manual_seed(0)
# A real torch.device, not the torch.cuda.device context-manager class.
# ACNN only stores this value; the module itself stays on the CPU here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Net = ACNN(1, 10, 10, 1024, 2, device)

In [10]:
# Random demo batch: 20 samples x 1 channel x 1024 steps, values uniform in [0, 1).
inp = torch.rand((20, 1, 1024))

In [12]:
# Evaluation mode (same as Net.eval()); returns the module, so its repr is displayed.
Net.train(False)

ACNN(
  (cnn): Conv1d(1, 10, kernel_size=(16,), stride=(4,))
  (dense): Linear(in_features=10, out_features=2, bias=True)
  (softmax): Softmax(dim=1)
)

In [11]:
# Inference only, so autograd bookkeeping is skipped. ACNN ends in a softmax,
# so each row is a probability distribution over the classes.
with torch.no_grad():
    probs = Net(inp)
probs

tensor([[0.5112, 0.4888],
        [0.5118, 0.4882],
        [0.5134, 0.4866],
        [0.5087, 0.4913],
        [0.5082, 0.4918],
        [0.5086, 0.4914],
        [0.5100, 0.4900],
        [0.5102, 0.4898],
        [0.5091, 0.4909],
        [0.5097, 0.4903],
        [0.5090, 0.4910],
        [0.5096, 0.4904],
        [0.5121, 0.4879],
        [0.5120, 0.4880],
        [0.5110, 0.4890],
        [0.5090, 0.4910],
        [0.5105, 0.4895],
        [0.5082, 0.4918],
        [0.5105, 0.4895],
        [0.5120, 0.4880]], grad_fn=<SoftmaxBackward0>)