In [20]:
import torch
import torch.nn as nn
import numpy as np
import constants

In [2]:
# Toy dimensions for the reshape experiments below.
b = 32
inner = 5
outer = 29
feat = 128
# Random 4-D tensor laid out as [b, inner, outer, feat].
input = torch.randn((b, inner, outer, feat))

In [3]:
# Merge the first two axes: every [outer x feat] slice becomes its own batch entry.
# -1 lets view infer the merged size (b * inner).
input_plus = input.view(-1, outer, feat)

In [4]:
# Position (1, 1) of the 4-D tensor lands at row 1 * inner + 1 = 6 of the merged
# tensor (inner = 5), so the difference must be exactly zero.
(input[1, 1] - input_plus[6]).sum()

tensor(0.)

In [5]:
# Undo the merge and confirm the round trip reproduces the original tensor.
input_plus_min = input_plus.view((b, inner, outer, feat))
(input - input_plus_min).sum()

tensor(0.)

In [14]:
# One encoder layer applied to the merged batch; batch_first=True means the
# input is read as [batch, sequence, feature].
encoder = nn.TransformerEncoderLayer(
    feat,
    8,
    dim_feedforward=1024,
    batch_first=True,
)
transformer_output = encoder(input_plus)
print(transformer_output.shape)

torch.Size([160, 29, 128])


In [10]:
class Aggregator(nn.Module):
    """
        This class performs the final attention to reduce the dimensionality after the transformer encoder layer from [b+ x inner x feat] to [b+ x feat]
        It does this by computing one attention weight per position along the inner axis
        (the weights of each sequence sum to 1) and returning the weighted sum of the positions
    """
    def __init__(self, feat_dim, hidden_dim=None):
        """
            feat_dim is the feature dimension of the input
            hidden_dim is the hidden dimension used by attention (defaults to feat_dim)
        """
        super().__init__()
        if hidden_dim is None:
            hidden_dim = feat_dim
        # nn.Parameter registers the tensors with the module, so they appear in
        # .parameters() / state_dict() and follow .to(device); plain tensors with
        # requires_grad=True would be invisible to an optimizer built from the module.
        self.Wa = nn.Parameter(torch.randn(hidden_dim, feat_dim))
        self.ba = nn.Parameter(torch.zeros(hidden_dim))
        self.ae = nn.Parameter(torch.randn(hidden_dim, 1))
        self.tanh = nn.Tanh()
        # The scores have shape [b, l, 1]; normalise over the sequence axis (dim=1).
        # A softmax over the last axis would see a single element and always return 1.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """
            x has shape [b, l, f]; returns the attention-weighted sum over l with shape [b, f]
        """
        # [b, l, f] @ [f, h] -> [b, l, h]; the bias of shape [h] broadcasts over b and l
        ats = self.tanh(x @ self.Wa.T + self.ba)
        # [b, l, h] @ [h, 1] -> [b, l, 1], turned into weights that sum to 1 over l
        alphas = self.softmax(ats @ self.ae)
        # [b, f, l] @ [b, l, 1] -> [b, f, 1] -> [b, f]
        result = torch.bmm(x.transpose(1, 2), alphas).squeeze(2)
        return result



In [15]:
# Pool the encoder output over its sequence axis and inspect the resulting shape.
agg = Aggregator(feat)
agg(transformer_output).size()

torch.Size([160, 128])

In [46]:
# Step-by-step walkthrough of the attention pooling on a tiny example.
A, F = 9, 10  # attention hidden size, feature size
Wa = torch.randn((A, F), requires_grad=True)
ba = torch.zeros(A, requires_grad=True)
ae = torch.randn((A, 1), requires_grad=True)
tanh = nn.Tanh()
# The attention scores further down have shape [b, l, 1], so the weights must be
# normalised over the sequence axis (dim=1). With dim=-1 the softmax runs over the
# singleton axis and every weight comes out as 1.0.
softmax = nn.Softmax(dim=1)

b = 5  # batch size
l = 3  # inner dimension
batch = torch.randn(b, l, F)

# [b, l, F] @ [F, A] -> [b, l, A]
step1 = batch @ Wa.T
print(step1.shape)

torch.Size([5, 3, 9])


In [47]:
# Now add the bias row-wise
row = torch.arange(A)
rows = row.repeat(l,1).unsqueeze(0).repeat(b,1,1)
print(rows.shape)

torch.Size([5, 3, 9])


In [48]:
# ba has shape [A] and broadcasts against the trailing axis of step1 ([b, l, A]),
# so there is no need to tile it to the full shape first.
step2 = step1 + ba
print(step2.shape)

torch.Size([5, 3, 9])


In [49]:
# Squash the projected features, then score every position against ae:
# [b, l, A] @ [A, 1] -> [b, l, 1]
step3 = tanh(step2)
step4 = torch.matmul(step3, ae)
print(step4.shape)

torch.Size([5, 3, 1])


In [50]:
# step4 is [b, l, 1]: normalise across the l positions (dim=1) so that the weights
# of each sequence sum to 1. A softmax over the last (singleton) axis would return
# 1.0 for every position.
step5 = torch.softmax(step4, dim=1)
print(step5.shape)

torch.Size([5, 3, 1])


In [53]:
# Now it remains to make the appropriate sum with the original matrix
step6 = torch.bmm(batch.transpose(1,2), step5).squeeze(2)
print(step6.shape)

torch.Size([5, 10])


torch.Size([5, 10])

In [16]:
class InnerTransformer(nn.Module):
    """
        This module takes as input tensors of the form [batch x outer x inner x feat] and outputs tensors of the form [batch x outer x feat]
        For the SleepTransformer, inner x feat is a time frequency image of a 1D EEG time-series (inner is time, feat is freq dimension)

        Every [inner x feat] slice is encoded independently by a transformer encoder and then
        pooled over the inner axis by an Aggregator
    """
    def __init__(self, feat_dim, dim_feedforward, num_heads, num_layers):
        """
            feat_dim is the feature dimension (d_model of the encoder layers)
            dim_feedforward is the hidden size of the MLP inside each encoder layer
            num_heads is the number of attention heads per encoder layer
            num_layers is the number of stacked encoder layers
        """
        super().__init__()
        self.dim_feedforward = dim_feedforward  # Size of hidden dimension used in MLP within transformerencoder layer
        self.transformer = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=feat_dim,
                nhead=num_heads,
                dim_feedforward=dim_feedforward,
                batch_first=True
            ),
            num_layers=num_layers)
        self.aggregator = Aggregator(feat_dim=feat_dim)

    def forward(self, x):
        """
            x has shape [batch x outer x inner x feat]; returns a tensor of shape [batch x outer x feat]
        """
        batch_dim, outer_dim, inner_dim, feat_dim = x.shape
        # Fold outer into the batch axis so each inner sequence is encoded on its own.
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. a permuted tensor.
        batch_plus = x.reshape(batch_dim * outer_dim, inner_dim, feat_dim)
        transformed_plus = self.transformer(batch_plus)
        aggregated_plus = self.aggregator(transformed_plus)
        # Unfold the batch axis again: [batch * outer, feat] -> [batch, outer, feat]
        return aggregated_plus.reshape(batch_dim, outer_dim, feat_dim)




In [17]:
# Smoke test of InnerTransformer on random data.
# NOTE(review): InnerTransformer.forward unpacks axis 1 as "outer" and axis 2 as
# "inner", so the 5 named `inner` here sits on the module's outer axis - confirm
# that this axis order is the intended one.
b = 32
inner = 5
outer = 29
feat = 128
input = torch.randn((b, inner, outer, feat))
inner_transformer = InnerTransformer(
    num_layers=4,
    num_heads=8,
    dim_feedforward=1024,
    feat_dim=feat,
)
output = inner_transformer(input)
print(output.shape)

torch.Size([32, 5, 128])


In [18]:
class OuterTransformer(nn.Module):
    """
        Sequence-level encoder: maps tensors of the form [batch x outer x feat] to tensors of
        the same form [batch x outer x feat], ready to be fed to a classifier
    """
    def __init__(self, feat_dim, dim_feedforward, num_heads, num_layers):
        """
            feat_dim is the feature dimension (d_model of the encoder layers)
            dim_feedforward is the hidden size of the MLP inside each encoder layer
            num_heads is the number of attention heads per encoder layer
            num_layers is the number of stacked encoder layers
        """
        super().__init__()
        # Template layer; nn.TransformerEncoder stacks num_layers copies of it.
        layer = nn.TransformerEncoderLayer(
            feat_dim,
            num_heads,
            dim_feedforward=dim_feedforward,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers)

    def forward(self, x):
        encoded = self.transformer(x)
        return encoded

In [19]:
# Run the pooled sequences through the sequence-level encoder; the shape is preserved.
outer_transformer = OuterTransformer(
    num_layers=4,
    num_heads=8,
    dim_feedforward=1024,
    feat_dim=feat,
)
outer_output = outer_transformer(output)
print(outer_output.shape)

torch.Size([32, 5, 128])


In [None]:
class Classifier(nn.Module):
    """
        Two-layer MLP head applied to the last axis of its input: takes tensors of the form
        [batch x outer x feat] and outputs tensors of the form [batch x outer x constants.N_CLASSES]
        The outputs are the raw values of the final linear layer (no softmax is applied here)
    """
    def __init__(self, feat_dim, hidden_dim):
        """
            feat_dim is the feature dimension of the input
            hidden_dim is the width of the hidden layer
        """
        super().__init__()
        hidden_layer = nn.Linear(feat_dim, hidden_dim)
        output_layer = nn.Linear(hidden_dim, constants.N_CLASSES)
        self.net = nn.Sequential(hidden_layer, nn.ReLU(), output_layer)

    def forward(self, x):
        scores = self.net(x)
        return scores


In [None]:
# Classifier
