In [1]:
from pathlib import Path

import torch
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, roc_auc_score
from scipy import interpolate

import project_path
import nn.utils as utils
import nn.model.net as net
import nn.model.data as data
import nn.model.loss as loss

In [2]:
import torch.nn as nn
import torch.nn.functional as F

In [33]:
class ConstituentNetGraph(nn.Module):
    """Soft graph-pooling classifier.

    Each input element is softly assigned to a fixed set of ``NUM_NODES``
    graph nodes. The per-element embeddings are pooled onto those nodes. A
    stack of stride-2 1-D convolutions then shrinks the node axis, and a
    linear layer produces class log-probabilities.

    Args:
        in_dim: Feature size of each input element (last axis of ``x``).
        embbed_dim: Width C of the per-element embedding fed to the convs.
        num_classes: Number of output classes.
        dropout: Dropout probability applied to the flattened conv features
            right before the classifier (0 disables it).

    Input:  ``x`` of shape (batch, seq_len, in_dim); any seq_len works.
    Output: log-probabilities of shape (batch, num_classes).
    """

    # Number of graph nodes the sequence is pooled onto. The conv stack has
    # five stride-2 layers, so the node axis shrinks 64 -> 2. That matches the
    # 256 * 2 input features of ``out_layer``.
    NUM_NODES = 64

    def __init__(self, in_dim:int=16, embbed_dim:int=32, num_classes:int=5, dropout:float=0.) -> None:
        super().__init__()
        self.input_size = in_dim
        self.channel_in = in_dim
        self.embbed_dim = embbed_dim # C

        self.inp_layer = nn.Linear(in_dim, embbed_dim)
        # Per-element soft assignment over the graph nodes (softmax over nodes).
        self.graph_assign = nn.Sequential(
            nn.Linear(in_dim, self.NUM_NODES),
            nn.Softmax(dim=-1)
        )

        # The trailing comment on each conv is the node-axis length after it.
        # The first BatchNorm/Conv must match embbed_dim. This was previously
        # hard-coded to 32, which broke any non-default embbed_dim.
        self.convs = nn.Sequential(
            nn.BatchNorm1d(num_features=embbed_dim),
            nn.Conv1d(embbed_dim, 64, kernel_size=2, stride=2, bias=False),  # 32
            nn.ReLU(),
            nn.BatchNorm1d(num_features=64),
            nn.Conv1d(64, 128, kernel_size=2, stride=2, bias=False), # 16
            nn.ReLU(),
            nn.BatchNorm1d(num_features=128),
            nn.Conv1d(128, 128, kernel_size=2, stride=2, bias=False), # 8
            nn.ReLU(),
            nn.BatchNorm1d(num_features=128),
            nn.Conv1d(128, 128, kernel_size=2, stride=2, bias=False), # 4
            nn.ReLU(),
            nn.BatchNorm1d(num_features=128),
            nn.Conv1d(128, 256, kernel_size=2, stride=2, bias=False), # 2
            nn.ReLU(),
        )

        # Previously the `dropout` argument was accepted but never used.
        # This is kept as its own attribute (not inside `out_layer`) so that
        # existing state_dict keys such as "out_layer.0.weight" keep working.
        self.dropout = nn.Dropout(dropout)

        self.out_layer = nn.Sequential(
            nn.Linear(256 * 2, num_classes)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of sequences.

        Args:
            x: Tensor of shape (batch, seq_len, in_dim).

        Returns:
            Log-probabilities of shape (batch, num_classes).
        """
        m_batch, _, _ = x.size()            # (m_batch, seq_len, in_dim)

        # Graph node assignment
        graph_assign = self.graph_assign(x) # (m_batch, seq_len, num_nodes)

        # Embed each element, then pool the sequence onto the graph nodes:
        # (m_batch, C, seq_len) @ (m_batch, seq_len, num_nodes) -> (m_batch, C, num_nodes).
        # Assignment weights are normalized over nodes, not over seq_len, so a
        # node's pooled value is a sum (not a mean) over the elements.
        out = self.inp_layer(x).transpose(1, 2)
        out = torch.bmm(out, graph_assign)

        # Conv layers
        out = self.convs(out)               # (m_batch, 256, 2)
        out = self.dropout(out.view(m_batch, -1))  # (m_batch, 512)
        out = self.out_layer(out)

        return F.log_softmax(out, dim=-1)

In [None]:
# Shape check for nn.Conv1d (example adapted from the PyTorch docs).
# With padding=0 and dilation=1:
#   L_out = floor((L_in - kernel_size) / stride) + 1 = floor((50 - 3) / 2) + 1 = 24
conv_demo = nn.Conv1d(16, 33, 3, stride=2)
conv_demo_input = torch.randn(20, 16, 50)
conv_demo_output = conv_demo(conv_demo_input)
assert conv_demo_output.shape == (20, 33, 24)

In [4]:
# 2**6 == 64, the graph-node count used by ConstituentNetGraph; its five
# stride-2 convs reduce 64 nodes to 64 / 2**5 == 2 positions.
2 ** 6

64

In [34]:
# Fixed seed so the initial weights (and the smoke-test input) are reproducible.
SEED = 0
torch.manual_seed(SEED)

model = ConstituentNetGraph()

# Smoke test: a (batch, seq_len, in_dim) = (2, 100, 16) input should give
# (batch, num_classes) = (2, 5) log-probabilities at the default arguments.
# Run in eval mode under no_grad so BatchNorm running statistics are not
# touched. Then restore training mode, which is the default for a new module.
model.eval()
with torch.no_grad():
    smoke_out = model(torch.rand(2, 100, 16))
model.train()
assert smoke_out.shape == (2, 5)

In [35]:
# NOTE(review): the saved output of this cell (351264) differs from summing
# p.numel() over model.parameters() for the ConstituentNetGraph definition
# visible in this notebook at default arguments (156709). Either the class
# cell was edited after this ran, or utils.count_parameters counts something
# else. Confirm by re-running from a fresh kernel.
n_params = utils.count_parameters(model)
n_params

351264