In [None]:
class ChannelAttentionHead(nn.Module):
    """Single attention head that re-mixes the second-to-last axis of its input.

    Queries and keys are projected from the mean over the last axis of ``x``;
    values are projected from ``x`` with its last two axes swapped.  The
    ``(C, C)`` affinity matrix ``q @ k^T`` is scaled by ``n_embed ** -0.5``,
    passed through dropout, and used to re-mix the value vectors.

    Expects ``x`` shaped ``(..., C, T)`` with ``C == n_embed`` (presumably
    channels x time — TODO confirm against callers).

    NOTE(review): unlike standard scaled dot-product attention, no softmax is
    applied to the affinity matrix — confirm this is intentional.

    Args:
        n_embed: Embedding width; must equal ``x.size(-2)`` because the
            ``value`` projection acts on that axis.
        dropout_rate: Dropout probability applied to the affinity matrix.
    """

    def __init__(self, n_embed, dropout_rate):
        super().__init__()
        # Names and creation order are kept stable so state_dict keys (and
        # seeded initialisation) match the existing layout.
        self.key = nn.Linear(1, n_embed, bias=False)
        self.query = nn.Linear(1, n_embed, bias=False)
        self.value = nn.Linear(n_embed, n_embed, bias=False)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Return a tensor with the same shape as ``x`` whose axis -2 is re-mixed."""
        summary = x.mean(dim=-1, keepdim=True)   # (..., C, 1)
        keys = self.key(summary)                 # (..., C, n_embed)
        queries = self.query(summary)            # (..., C, n_embed)
        values = self.value(x.transpose(-2, -1))  # (..., T, n_embed)

        scale = keys.size(-1) ** -0.5
        affinity = (queries @ keys.transpose(-2, -1)) * scale  # (..., C, C)
        affinity = self.dropout(affinity)
        mixed = values @ affinity                # (..., T, C)
        return mixed.transpose(-2, -1)

In [None]:
class Conv3dBatched(nn.Module):
    """Conv3d -> BatchNorm3d -> ELU -> Dropout building block.

    The convolution has no bias and uses the default (zero) padding, so the
    three spatial axes shrink by ``kernel_size - 1``.  BatchNorm is built with
    ``track_running_stats=False``, so batch statistics are used in both
    training and evaluation mode.

    Args:
        in_channels: Channels of the 5-D input ``(N, C_in, D, H, W)``.
        out_channels: Channels produced by the convolution.
        kernel_size: Conv3d kernel size.
        groups: Conv3d groups; must divide both channel counts.
        dropout_rate: Dropout probability applied last.
    """

    def __init__(self, in_channels, out_channels, kernel_size, groups, dropout_rate):
        super().__init__()
        conv = nn.Conv3d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            groups=groups,
            bias=False,
        )
        norm = nn.BatchNorm3d(
            num_features=out_channels,
            momentum=0.01,
            eps=0.001,
            track_running_stats=False,
        )
        # Order fixes the state_dict keys: net.0 = conv, net.1 = norm.
        self.net = nn.Sequential(conv, norm, nn.ELU(), nn.Dropout(dropout_rate))

    def forward(self, x):
        """Apply convolution, batch-norm, ELU and dropout in sequence."""
        return self.net(x)

class FeatureExtractor3d(nn.Module):
    """3-D convolutional feature extractor over sliding windows of the last axis.

    ``forward`` cuts the last axis of ``x`` into windows of length ``unfold``
    with hop ``unfold_step``, adds a singleton channel axis at position 1, and
    runs three ``Conv3dBatched`` blocks (1 -> 8 -> 16 -> 32 channels).  A
    ``Linear`` is then applied per channel and the result is flattened and
    passed through ELU, giving ``32 * (n_hidden // 32)`` features per sample.

    NOTE(review): the Linear's ``in_features=253`` is hard-coded.  With the
    kernels below it must equal ``(x.size(-2) - 21) * (n_windows - 7) *
    (unfold - 3)``, so it only fits specific input shapes (with the default
    ``unfold=256`` the first two factors must multiply to 1).  Changing
    ``unfold`` or the input shape will break it — TODO confirm the intended shape.

    Args:
        unfold: Window length along the last input axis.
        unfold_step: Hop between consecutive windows.
        dropout_rate: Dropout probability inside each conv block.
        n_hidden: Controls the output width, ``32 * (n_hidden // 32)``.
    """

    def __init__(
            self,
            unfold=256,
            unfold_step=128,
            dropout_rate=0.0,
            n_hidden=64,
    ):
        super().__init__()
        self.unfold = unfold
        self.unfold_step = unfold_step

        # (in_channels, out_channels, kernel_size, groups) for each conv block.
        conv_specs = (
            (1, 8, (8, 4, 2), 1),
            (8, 16, (8, 4, 2), 8),
            (16, 32, (8, 2, 2), 16),
        )
        layers = [
            Conv3dBatched(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=kernel,
                groups=n_groups,
                dropout_rate=dropout_rate,
            )
            for c_in, c_out, kernel, n_groups in conv_specs
        ]
        layers += [
            nn.Flatten(start_dim=2),  # merge the three spatial axes per channel
            nn.Linear(in_features=253, out_features=n_hidden // 32),
            nn.Flatten(start_dim=1),
            nn.ELU(),
        ]
        # List order fixes the state_dict keys (net.0 ... net.6).
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Window the last axis, add a channel axis and run the conv stack."""
        windows = x.unfold(-1, self.unfold, self.unfold_step)
        return self.net(windows.unsqueeze(1))

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings along the last axis of ``x``.

    Feature ``2i`` at position ``pos`` receives ``sin(pos / 10000^(2i/d_model))``
    and feature ``2i + 1`` the matching cosine.

    Args:
        d_model: Size of the feature axis (axis -2 of the input).  Odd values
            are supported.
        dropout: Dropout probability applied after adding the encoding.
        max_len: Longest supported sequence length.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd slot than even slots, so trim
        # div_term to avoid a shape mismatch (identical result for even d_model).
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Buffer layout (max_len, 1, d_model) is kept for checkpoint compatibility.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the positional encoding to ``x`` and apply dropout.

        Arguments:
            x: Tensor of shape ``[batch_size, embedding_dim, seq_len]`` (any
                number of leading dims is accepted).

        Returns:
            Tensor with the same shape and layout as ``x``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(-1)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds max_len {max_len}")
        # (seq_len, d_model) -> (d_model, seq_len), broadcast over the batch.
        # This replaces the old reshape-based round trip, which reinterpreted
        # memory instead of permuting axes.  For batch_size > 1 that gave each
        # sample encodings from the wrong positions, and the returned tensor was
        # not laid out as (batch, embedding, seq).
        encoding = self.pe[:seq_len, 0].transpose(0, 1)
        return self.dropout(x + encoding)

class FeatureExtractorPosition(nn.Module):
    """Convolutional feature extractor (F1/D/F2 naming as in EEGNet) with positional encoding.

    Pipeline: ``ViewConv`` (insert a height-1 axis), temporal Conv2d, batch-norm,
    depthwise conv with depth multiplier ``D``, batch-norm, ELU, average pool,
    dropout, separable conv to ``F2`` channels, batch-norm, ELU, average pool,
    dropout.  The trailing axes are then flattened, a sinusoidal positional
    encoding over ``F2`` features is added, and everything is flattened to
    ``(batch, features)``.

    All batch-norm layers use ``track_running_stats=False`` (batch statistics
    in train and eval mode).

    NOTE(review): after ``ViewConv`` the height axis has size 1, so the
    ``(n_channels, 1)`` depthwise kernel with ``padding='same'`` only ever sees
    one real row — confirm this is intended.

    Args:
        n_channels: Input channels (axis 1 of the input).
        kernel_length: Width of the temporal convolution.
        F1: Filters of the temporal convolution.
        D: Depth multiplier of the depthwise convolution.
        F2: Output channels of the separable convolution.
        pool1_stride: Width/stride of the first average pool.
        pool2_stride: Width/stride of the second average pool.
        dropout_rate: Dropout probability (also used by the positional encoding).
    """

    def __init__(
            self,
            n_channels,
            kernel_length,
            F1,
            D,
            F2,
            pool1_stride,
            pool2_stride,
            dropout_rate,
    ):
        super().__init__()

        def batch_norm(num_features):
            return nn.BatchNorm2d(num_features=num_features, momentum=0.01, eps=0.001,
                                  track_running_stats=False)

        layers = [
            ViewConv(),
            nn.Conv2d(in_channels=n_channels, out_channels=F1, kernel_size=(1, kernel_length),
                      bias=False, padding='same'),
            batch_norm(F1),
            DepthWiseConv2d(in_channels=F1, kernel_size=(n_channels, 1), kernels_per_layer=D,
                            bias=False),
            batch_norm(F1 * D),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, pool1_stride), stride=pool1_stride),
            nn.Dropout(dropout_rate),
            SeparableConv2d(in_channels=F1 * D, kernel_size=(1, 16), out_channels=F2, bias=False),
            batch_norm(F2),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, pool2_stride), stride=pool2_stride),
            nn.Dropout(dropout_rate),
            nn.Flatten(start_dim=2),
            PositionalEncoding(d_model=F2, dropout=dropout_rate),
            nn.Flatten(),
        ]
        # List positions become the state_dict indices (net.0, net.1, ...).
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the full pipeline and return a flattened feature vector per sample."""
        return self.net(x)

In [None]:
class FeatureExtractionUnfold(nn.Module):
    """Convolutional feature extractor applied to overlapping windows of the last axis.

    ``forward`` splits the last axis of ``x`` into windows of length ``unfold``
    with hop ``unfold // 2``.  The windowed tensor goes straight into the conv
    stack; no ``ViewConv`` is needed because unfolding already adds an axis.
    Presumably ``x`` is ``(batch, n_channels, time)``, giving
    ``(batch, n_channels, n_windows, unfold)`` — TODO confirm.

    Stack: temporal Conv2d, batch-norm, depthwise conv (multiplier ``D``),
    batch-norm, ELU, average pool, dropout, separable conv to ``F2``,
    batch-norm, ELU, average pool, dropout, flatten.  Batch-norm layers use
    ``track_running_stats=False``.

    Args:
        n_channels: Input channels (axis 1 of the input).
        kernel_length: Width of the temporal convolution.
        F1: Filters of the temporal convolution.
        D: Depth multiplier of the depthwise convolution.
        F2: Output channels of the separable convolution.
        pool1_stride: Width/stride of the first average pool.
        pool2_stride: Width/stride of the second average pool.
        dropout_rate: Dropout probability.
        unfold: Window length; windows overlap by half.
    """

    def __init__(
            self,
            n_channels,
            kernel_length,
            F1,
            D,
            F2,
            pool1_stride,
            pool2_stride,
            dropout_rate,
            unfold,
    ):
        super().__init__()
        self.unfold = unfold

        def batch_norm(num_features):
            return nn.BatchNorm2d(num_features=num_features, momentum=0.01, eps=0.001,
                                  track_running_stats=False)

        layers = [
            nn.Conv2d(in_channels=n_channels, out_channels=F1, kernel_size=(1, kernel_length),
                      bias=False, padding='same'),
            batch_norm(F1),
            DepthWiseConv2d(in_channels=F1, kernel_size=(n_channels, 1), kernels_per_layer=D,
                            bias=False),
            batch_norm(F1 * D),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, pool1_stride), stride=pool1_stride),
            nn.Dropout(dropout_rate),
            SeparableConv2d(in_channels=F1 * D, kernel_size=(1, 16), out_channels=F2, bias=False),
            batch_norm(F2),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, pool2_stride), stride=pool2_stride),
            nn.Dropout(dropout_rate),
            nn.Flatten(),
        ]
        # List positions become the state_dict indices (net.0, net.1, ...).
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Window the last axis with 50 % overlap and run the conv stack."""
        hop = self.unfold // 2
        windows = x.unfold(-1, self.unfold, hop)
        return self.net(windows)



In [None]:
import math

import torch
from torch import nn
import torch.nn.functional as F
import torch
from x_transformers import TransformerWrapper, Encoder, ViTransformerWrapper


class DepthWiseConv2d(nn.Module):
    """Depthwise 2-D convolution: every input channel is filtered independently.

    ``groups`` equals ``in_channels`` and each channel gets ``kernels_per_layer``
    filters, so the output has ``in_channels * kernels_per_layer`` channels.
    ``padding='same'`` keeps the spatial size unchanged.

    Args:
        in_channels: Number of input channels.
        kernel_size: Convolution kernel size.
        kernels_per_layer: Depth multiplier (filters per input channel).
        bias: Whether the convolution has a bias term.
    """

    def __init__(self, in_channels, kernel_size, kernels_per_layer, bias=False):
        super().__init__()
        n_out = in_channels * kernels_per_layer
        self.depthwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=n_out,
            kernel_size=kernel_size,
            groups=in_channels,
            bias=bias,
            padding='same',
        )

    def forward(self, x):
        """Apply the depthwise convolution."""
        conv = self.depthwise
        return conv(x)


class PointWiseConv2d(nn.Module):
    """1x1 convolution that mixes channels independently at every spatial position.

    Args:
        in_channels: Base input channel count.  The convolution actually expects
            ``in_channels * kernels_per_layer`` channels, which matches the
            output of a depthwise convolution with that depth multiplier.
        out_channels: Number of output channels.
        kernels_per_layer: Multiplier applied to ``in_channels``.
        bias: Whether the convolution has a bias term.
    """

    def __init__(self, in_channels, out_channels, kernels_per_layer=1, bias=False):
        super().__init__()
        total_in = in_channels * kernels_per_layer
        self.pointwise = nn.Conv2d(
            in_channels=total_in,
            out_channels=out_channels,
            kernel_size=(1, 1),
            bias=bias,
            padding="valid",
        )

    def forward(self, x):
        """Apply the 1x1 convolution."""
        layer = self.pointwise
        return layer(x)


class MaxNormLayer(nn.Linear):
    """``nn.Linear`` whose weight rows are clipped to a maximum L2 norm.

    On every forward call, in training and evaluation mode alike, each row of
    ``weight`` (the incoming weights of one output unit) is rescaled with
    ``torch.renorm`` so that its L2 norm is at most ``max_norm``.  Rows already
    within the limit are left unchanged.  ``max_norm=None`` disables the
    constraint.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        max_norm: Maximum L2 norm per output unit, or ``None``.
    """

    def __init__(self, in_features, out_features, max_norm=1.0):
        super().__init__(in_features=in_features, out_features=out_features)
        self.max_norm = max_norm

    def forward(self, x):
        """Apply the norm constraint (if enabled), then the affine map."""
        if self.max_norm is None:
            return super().forward(x)
        with torch.no_grad():
            # Rebinding ``.data`` swaps the values but keeps the Parameter
            # object (the one an optimiser holds) the same.
            clipped = torch.renorm(self.weight.data, p=2, dim=0, maxnorm=self.max_norm)
            self.weight.data = clipped
        return super().forward(x)


class SeparableConv2d(nn.Module):
    """Depthwise-separable 2-D convolution: depthwise conv followed by a 1x1 conv.

    Args:
        in_channels: Number of input channels.
        out_channels: Output channels of the pointwise (1x1) step.
        kernel_size: Kernel size of the depthwise step.
        kernels_per_layer: Depth multiplier of the depthwise step; the
            pointwise step is sized to accept its output.
        bias: Whether both convolutions have a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size, kernels_per_layer=1, bias=False):
        super().__init__()
        shared = dict(in_channels=in_channels, kernels_per_layer=kernels_per_layer, bias=bias)
        self.depthwise = DepthWiseConv2d(kernel_size=kernel_size, **shared)
        self.pointwise = PointWiseConv2d(out_channels=out_channels, **shared)

    def forward(self, x):
        """Run the depthwise step, then the pointwise step."""
        return self.pointwise(self.depthwise(x))


class ViewConv(nn.Module):
    """Insert a singleton axis at position 2: ``(B, C, T) -> (B, C, 1, T)``.

    Uses ``Tensor.view``, so the result shares memory with the input, and the
    input must be view-compatible (e.g. contiguous) and 3-D.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return ``x`` viewed as a 4-D tensor with height 1."""
        batch, channels, length = x.shape[0], x.shape[1], x.shape[2]
        return x.view((batch, channels, 1, length))



class FeatureExtraction(nn.Module):
    """Convolutional feature extractor (F1/D/F2 naming as in EEGNet), without dropout.

    Pipeline: ``ViewConv`` (insert a height-1 axis), temporal Conv2d,
    batch-norm, depthwise conv with depth multiplier ``D``, batch-norm, ELU,
    average pool, separable conv to ``F2`` channels, batch-norm, ELU, average
    pool, flatten to ``(batch, features)``.  Batch-norm layers use
    ``track_running_stats=False`` (batch statistics in train and eval mode).

    NOTE(review): after ``ViewConv`` the height axis has size 1, so the
    ``(n_channels, 1)`` depthwise kernel with ``padding='same'`` only ever sees
    one real row — confirm this is intended.

    Args:
        n_channels: Input channels (axis 1 of the input).
        kernel_length: Width of the temporal convolution.
        F1: Filters of the temporal convolution.
        D: Depth multiplier of the depthwise convolution.
        F2: Output channels of the separable convolution.
        pool1_stride: Width/stride of the first average pool.
        pool2_stride: Width/stride of the second average pool.
    """

    def __init__(
            self,
            n_channels,
            kernel_length,
            F1,
            D,
            F2,
            pool1_stride,
            pool2_stride,
    ):
        super().__init__()

        def batch_norm(num_features):
            return nn.BatchNorm2d(num_features=num_features, momentum=0.01, eps=0.001,
                                  track_running_stats=False)

        layers = [
            ViewConv(),
            nn.Conv2d(in_channels=n_channels, out_channels=F1, kernel_size=(1, kernel_length),
                      bias=False, padding='same'),
            batch_norm(F1),
            DepthWiseConv2d(in_channels=F1, kernel_size=(n_channels, 1), kernels_per_layer=D,
                            bias=False),
            batch_norm(F1 * D),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, pool1_stride), stride=pool1_stride),
            SeparableConv2d(in_channels=F1 * D, kernel_size=(1, 16), out_channels=F2, bias=False),
            batch_norm(F2),
            nn.ELU(),
            nn.AvgPool2d(kernel_size=(1, pool2_stride), stride=pool2_stride),
            nn.Flatten(),
        ]
        # List positions become the state_dict indices (net.0, net.1, ...).
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the conv stack and return a flattened feature vector per sample."""
        return self.net(x)
