In [53]:
import torch
import torch.nn as nn
import numpy as np
import constants
import math
import pytorch_lightning as pl

In [54]:
b, inner, outer, feat = 32, 5, 29, 128
input = torch.randn(b, inner, outer, feat)

In [55]:
input_plus = input.view(b*inner, outer, feat)

In [56]:
torch.sum(input[1,1,:,:] - input_plus[6,:,:])

tensor(0.)

In [57]:
input_plus_min = input_plus.view(b, inner, outer, feat)
torch.sum(input - input_plus_min)

tensor(0.)

In [58]:
encoder = nn.TransformerEncoderLayer(d_model=feat,
                                     nhead=8,
                                     dim_feedforward=1024,
                                     batch_first=True)
transformer_output = encoder(input_plus)
print(transformer_output.shape)

torch.Size([160, 29, 128])


In [59]:
class Aggregator(nn.Module):
    """
        This class performs the final attention to reduce the dimensionality after the transformer encoder layer from [b+ x inner x feat] to [b+ x feat]
        It does this by using attention which results in a weighted sum
    """
    def __init__(self, feat_dim, hidden_dim=None):
        """
            hidden_dim is the hidden dimension used by attention
            feat_dim is the feature dimension
        """
        super().__init__()
        if hidden_dim is None:
            hidden_dim = feat_dim
        # self.Wa = torch.randn((hidden_dim, feat_dim), requires_grad=True)
        # self.ba = torch.zeros(hidden_dim, requires_grad=True)
        self.ae = nn.Parameter(torch.randn((hidden_dim,1), requires_grad=True))
        self.linear = nn.Linear(in_features=feat_dim,
                                out_features=hidden_dim)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)
    def forward(self, x):
        b, l, f = x.shape
        #ats = self.tanh(x @ self.Wa.T + self.ba.repeat(l,1).unsqueeze(0).repeat(b,1,1))  # maybe this can faster
        ats = self.tanh(self.linear(x))
        alphas = self.softmax(ats @ self.ae)
        result = torch.bmm(x.transpose(1,2), alphas).squeeze(2)
        return result



In [60]:
agg = Aggregator(feat_dim=feat)
agg(transformer_output).shape

torch.Size([160, 128])

In [61]:
A, F = 9, 10
Wa = torch.randn((A, F), requires_grad=True)
ba = torch.zeros(A, requires_grad=True)
ae = torch.randn((A,1), requires_grad=True)
tanh = nn.Tanh()
softmax = nn.Softmax(dim=-1)

b = 5  # batch size
l = 3  # inner dimension
batch = torch.randn(b, l, F)

step1 = batch @ Wa.T
print(step1.shape)

torch.Size([5, 3, 9])


In [62]:
# Now add the bias row-wise
row = torch.arange(A)
rows = row.repeat(l,1).unsqueeze(0).repeat(b,1,1)
print(rows.shape)

torch.Size([5, 3, 9])


In [63]:
step2 = step1 + ba.repeat(l,1).unsqueeze(0).repeat(b,1,1)
print(step2.shape)

torch.Size([5, 3, 9])


In [64]:
step3 = tanh(step2)
step4 = step3 @ ae
print(step4.shape)

torch.Size([5, 3, 1])


In [65]:
step5 = softmax(step4)
print(step5.shape)

torch.Size([5, 3, 1])


In [66]:
# Now it remains to make the appropriate sum with the original matrix
step6 = torch.bmm(batch.transpose(1,2), step5).squeeze(2)
print(step6.shape)

torch.Size([5, 10])


In [83]:
class InnerTransformer(nn.Module):
    """
        This module takes as input tensors of the form [batch x outer x inner x feat] and outputs tensors of the form [batch x outer x feat]
        For the SleepTransformer, inner x feat is a time frequency image of a 1D EEG time-series (inner is time, feat is freq dimension)
    """
    def __init__(self, inner_dim, feat_dim, dim_feedforward, num_heads, num_layers):
        super().__init__()
        self.dim_feedforward = dim_feedforward  # Size of hidden dimension used in MLP within transformerencoder layer
        self.transformer = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=feat_dim,
                nhead=num_heads,
                dim_feedforward=dim_feedforward,
                batch_first=True
            ),
            num_layers=num_layers)
        self.aggregator = Aggregator(feat_dim=feat_dim)
        self.inner_position_encoding = PositionalEncoding(
            sequence_length=inner_dim,
            hidden_size=feat_dim
        )

    def forward(self, x):
        batch_dim, outer_dim, inner_dim, feat_dim = x.shape
        batch_plus = x.view(batch_dim*outer_dim, inner_dim, feat_dim)  # reshape before putting through transformer
        batch_plus = self.inner_position_encoding(batch_plus)  # Add positional encoding
        transformed_plus = self.transformer(batch_plus)
        aggregrated_plus = self.aggregator(transformed_plus)
        return aggregrated_plus.view(batch_dim, outer_dim, feat_dim)




In [84]:
b, inner, outer, feat = 32, 5, 29, 128
input = torch.randn(b, inner, outer, feat)
inner_transformer = InnerTransformer(feat_dim=feat,
                                     dim_feedforward=1024,
                                     num_heads=8,
                                     num_layers=4)
output = inner_transformer(input)
print(output.shape)

TypeError: __init__() missing 1 required positional argument: 'inner_dim'

In [85]:
class OuterTransformer(nn.Module):
    """
        This module takes as input tensors of the form [batch x outer x feat] and outputs a tensors of
        the form [batch x outer x feat] that can be used as input to a classifier
    """
    def __init__(self, outer_dim, feat_dim, dim_feedforward, num_heads, num_layers):
        super().__init__()
        self.transformer = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(
                d_model=feat_dim,
                nhead=num_heads,
                dim_feedforward=dim_feedforward,
                batch_first=True
            ),
            num_layers=num_layers
        )
        self.outer_position_encoding = PositionalEncoding(
            sequence_length=outer_dim,
            hidden_size=feat_dim
        )

    def forward(self, x):
        return self.transformer(self.outer_position_encoding(x))

In [86]:
outer_transformer = OuterTransformer(feat_dim=feat,
                                     dim_feedforward=1024,
                                     num_heads=8,
                                     num_layers=4)
outer_output = outer_transformer(output)
print(outer_output.shape)

TypeError: __init__() missing 1 required positional argument: 'outer_dim'

In [87]:
class Classifier(nn.Module):
    """
        This module takes as input tensors of the form [batch x outer x feat] and outputs tensors of
        the form
    """
    def __init__(self, feat_dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(
                in_features=feat_dim,
                out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(
                in_features=hidden_dim,
                out_features=constants.N_CLASSES
            )
        )

    def forward(self, x):
        return self.net(x)


In [88]:
# Classifier
classifier = Classifier(feat_dim=feat,
                        hidden_dim=1024)
classifier_output = classifier(outer_output)
print(classifier_output.shape)

torch.Size([32, 5, 5])


In [89]:
class SleepTransformer(pl.LightningModule):
    """
        The SleepTransformer takes as input batches of time series of the form [batch x window x time] and outputs labels for each epoch: [batch x window x label]
    """
    def __init__(self, outer_dim, inner_dim, feat_dim, dim_feedforward, num_heads, num_layers):
        super().__init__()

        inner_transformer = InnerTransformer(feat_dim=feat_dim,
                                                  inner_dim=inner_dim,
                                             dim_feedforward=dim_feedforward,
                                             num_heads=num_heads,
                                             num_layers=num_layers)
        outer_transformer = OuterTransformer(feat_dim=feat_dim,
                                                  outer_dim=outer_dim,
                                             dim_feedforward=dim_feedforward,
                                             num_heads=num_heads,
                                             num_layers=num_layers)

        classifier = Classifier(feat_dim=feat_dim,
                                     hidden_dim=dim_feedforward)

        self.net = nn.Sequential(inner_transformer,
                                 outer_transformer,
                                 classifier)
        self.save_hyperparameters()

    def forward(self, x):
        return self.net(x)

    def training_step(self, batch, batch_idx):
        pass


    def validation_step(self, batch, batch_idx):
        pass

    def test_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        pass

In [92]:
b, outer, inner, feat = 6, 5, 29, 128
input_batch = torch.randn(b, outer, inner, feat)
sleep_transformer = SleepTransformer(
    outer_dim=outer,
    inner_dim=inner,
    feat_dim=feat,
    dim_feedforward=1024,
    num_heads=8,
    num_layers=4
)
output_batch = sleep_transformer(input_batch)
print(output_batch)

tensor([[[ 0.0572,  0.1763, -0.1210,  0.1096,  0.1139],
         [-0.0537, -0.0031, -0.1423,  0.1247,  0.1492],
         [ 0.1174,  0.0579, -0.3153,  0.1934,  0.1110],
         [ 0.0488,  0.0058, -0.2966,  0.2370,  0.0303],
         [ 0.0709,  0.0361, -0.0958,  0.2541,  0.0894]],

        [[ 0.0845,  0.0899, -0.2799,  0.1379, -0.0215],
         [-0.0017, -0.0315, -0.2627,  0.1713,  0.0042],
         [-0.0870,  0.1827, -0.3365,  0.0996,  0.0747],
         [ 0.0353,  0.0851, -0.2914,  0.1960,  0.1436],
         [ 0.0257,  0.0687, -0.2340,  0.1568,  0.0506]],

        [[-0.0827,  0.0610, -0.1491,  0.2429,  0.2998],
         [-0.0006,  0.2079, -0.0657,  0.1148,  0.1149],
         [-0.1578,  0.0996, -0.0410,  0.2496,  0.3133],
         [ 0.0598,  0.1035, -0.1630,  0.0992,  0.3429],
         [ 0.0202,  0.1283, -0.2296,  0.2161,  0.2460]],

        [[-0.0375,  0.0739, -0.1890,  0.2242,  0.1991],
         [-0.0170,  0.1516, -0.2030,  0.2181,  0.1715],
         [-0.0008,  0.0701, -0.2443,  0.18

In [74]:
# Test STFT
b, w, t = 12, 5, 3000
time_series = torch.randn(b, w, t)
time_series_plus = time_series.view(b*w, t)

# define STFT parameters
window_size = 4

batch_stft = torch.stft(time_series_plus,
                        win_length=256,
                        window=torch.hamming_window(256),
                        hop_length=23,
                        n_fft=56,
                        onesided=True
                        )
print(batch_stft.shape)

RuntimeError: stft(torch.FloatTensor[60, 3056], n_fft=56, hop_length=23, win_length=256, window=torch.FloatTensor{[256]}, normalized=0, onesided=1, return_complex=None) : expected 0 < win_length <= n_fft, but got win_length=256

In [75]:
encoding = torch.ones(10,5)
begin = torch.zeros(3,10,5)
begin+encoding

tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]],

        [[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]])

In [76]:
class PositionalEncoding(nn.Module):
    def __init__(self, sequence_length, hidden_size):
        super().__init__()
        self.sequence_length = sequence_length
        self.hidden_size = hidden_size
        self.encoding = self.get_positional_encoding()

    def forward(self, x):
        x = x + self.encoding
        return x

    def get_positional_encoding(self):
        # create a matrix of shape (sequence_length, hidden_size)
        position = torch.arange(self.sequence_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.hidden_size, 2) * -(math.log(10000.0) / self.hidden_size))
        sin = torch.sin(position * div_term)
        cos = torch.cos(position * div_term)
        encoding = torch.cat([sin, cos], dim=1)
        return encoding

In [50]:
pos_encoding = PositionalEncoding(sequence_length=5, hidden_size=10)
begin = torch.zeros(8,5,10)
pos_encoding(begin)

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  1.5783e-01,  2.5116e-02,  3.9811e-03,  6.3096e-04,
           5.4030e-01,  9.8747e-01,  9.9968e-01,  9.9999e-01,  1.0000e+00],
         [ 9.0930e-01,  3.1170e-01,  5.0217e-02,  7.9621e-03,  1.2619e-03,
          -4.1615e-01,  9.5018e-01,  9.9874e-01,  9.9997e-01,  1.0000e+00],
         [ 1.4112e-01,  4.5775e-01,  7.5285e-02,  1.1943e-02,  1.8929e-03,
          -9.8999e-01,  8.8908e-01,  9.9716e-01,  9.9993e-01,  1.0000e+00],
         [-7.5680e-01,  5.9234e-01,  1.0031e-01,  1.5924e-02,  2.5238e-03,
          -6.5364e-01,  8.0569e-01,  9.9496e-01,  9.9987e-01,  1.0000e+00]],

        [[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
           1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00,  1.0000e+00],
         [ 8.4147e-01,  1.5783e-01,  2.5116e-02,  3.9811e-03,  6.3096e-04,
           5.4030