# Conv Mode Test

In [1]:
import torch
import torch.nn as nn
import numpy as np
from models.eeg_AE import eeg_encoder as Encoder

from models.eeg_AE import eeg_decoder as Decoder

In [2]:
test =torch.randn(1, 440, 128)
en = Encoder(in_seq              = 440,
            in_channels          = 128,
            out_seq             = 768,
            dims               = [64, 128, 256, 512, 1024],
            shortcut           = True,
            dropout           = 0.5,
            groups              = 32,
            layer_mode          = 'conv',
            block_mode          = 'res',
            down_mode           = 'max',
            pos_mode           = 'trunc',
            n_layer           = 2,
            n_head              = 64,
            dff_factor        = 2,
            stride            = 4,
            skip_mode          = "conv")

out, skips = en(test)

ModuleList(
  (0): Conv1dLayer(
    (conv): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
  )
  (1): Conv1dLayer(
    (conv): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
  )
  (2): Conv1dLayer(
    (conv): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
  )
  (3): Conv1dLayer(
    (conv): Conv1d(512, 512, kernel_size=(1,), stride=(1,), bias=False)
  )
  (4): Conv1dLayer(
    (conv): Conv1d(1024, 1024, kernel_size=(1,), stride=(1,), bias=False)
  )
)
None
inputs :  torch.Size([1, 440, 128])
After In Layer : torch.Size([1, 64, 128])
BLOCK 0 after block0 : torch.Size([1, 64, 128])
BLOCK 0 after block1 : torch.Size([1, 64, 128])
BLOCK 0 after attn : torch.Size([1, 64, 128])
BLOCK 0 after down block : torch.Size([1, 64, 64])
BLOCK 1 after block0 : torch.Size([1, 128, 64])
BLOCK 1 after block1 : torch.Size([1, 128, 64])
BLOCK 1 after attn : torch.Size([1, 128, 64])
BLOCK 1 after down block : torch.Size([1, 128, 16])
BLOCK 2 after block0 : torch.Size([1

In [3]:
out.shape

torch.Size([1, 768, 1])

In [4]:
for i in skips:
    print(i.shape)

torch.Size([1, 1024, 1])
torch.Size([1, 512, 4])
torch.Size([1, 256, 16])
torch.Size([1, 128, 64])
torch.Size([1, 64, 128])


In [5]:
de = Decoder(in_seq              = 768,
            in_channels          = 1,
            out_seq             = 440,
            dims               = [64, 128, 256, 512, 1024],
            shortcut           = True,
            dropout           = 0.5,
            groups              = 32,
            layer_mode          = 'conv',
            block_mode          = 'res',
            up_mode           = 'trans',
            pos_mode           = 'trunc',
            n_layer           = 2,
            n_head              = 64,
            dff_factor        = 2,
            stride            = 4,
            skip_mode          = "conv")
outs = de(out, skips)

inputs :  torch.Size([1, 768, 1])
After In Layer : torch.Size([1, 1024, 1])
BLOCK 0 after concat : torch.Size([1, 2048, 1])
BLOCK 0 after block0 : torch.Size([1, 1024, 1])
BLOCK 0 after block1 : torch.Size([1, 1024, 1])
BLOCK 0 after attn : torch.Size([1, 1024, 1])
BLOCK 0 after up block : torch.Size([1, 512, 4])
BLOCK 1 after concat : torch.Size([1, 1024, 4])
BLOCK 1 after block0 : torch.Size([1, 512, 4])
BLOCK 1 after block1 : torch.Size([1, 512, 4])
BLOCK 1 after attn : torch.Size([1, 512, 4])
BLOCK 1 after up block : torch.Size([1, 256, 16])
BLOCK 2 after concat : torch.Size([1, 512, 16])
BLOCK 2 after block0 : torch.Size([1, 256, 16])
BLOCK 2 after block1 : torch.Size([1, 256, 16])
BLOCK 2 after attn : torch.Size([1, 256, 16])
BLOCK 2 after up block : torch.Size([1, 128, 64])
BLOCK 3 after concat : torch.Size([1, 256, 64])
BLOCK 3 after block0 : torch.Size([1, 128, 64])
BLOCK 3 after block1 : torch.Size([1, 128, 64])
BLOCK 3 after attn : torch.Size([1, 128, 64])
BLOCK 3 after up b

In [6]:
outs.shape

torch.Size([1, 440, 128])

# Down Mode Test

In [7]:
test =torch.randn(1, 440, 128)
en = Encoder(in_seq              = 440,
            in_channels             = 128,
            out_seq             = 768,
            dims               = [64, 128, 256, 512, 1024],
            shortcut           = True,
            dropout           = 0.5,
            groups              = 32,
            layer_mode          = 'conv',
            block_mode          = 'res',
            down_mode           = 'max',
            pos_mode           = 'sinusoidal',
            n_layer           = 2,
            n_head              = 64,
            dff_factor        = 2,
            stride            = 4,
            skip_mode          = "down")

out, skips = en(test)

ModuleList(
  (0): Conv1dLayer(
    (conv): Conv1d(64, 768, kernel_size=(1,), stride=(1,), bias=False)
  )
  (1): Conv1dLayer(
    (conv): Conv1d(128, 768, kernel_size=(1,), stride=(1,), bias=False)
  )
  (2): Conv1dLayer(
    (conv): Conv1d(256, 768, kernel_size=(1,), stride=(1,), bias=False)
  )
  (3): Conv1dLayer(
    (conv): Conv1d(512, 768, kernel_size=(1,), stride=(1,), bias=False)
  )
  (4): Conv1dLayer(
    (conv): Conv1d(1024, 768, kernel_size=(1,), stride=(1,), bias=False)
  )
)
LinearLayer(
  (linear): Linear(in_features=214, out_features=1, bias=False)
)
inputs :  torch.Size([1, 440, 128])
After In Layer : torch.Size([1, 64, 128])
BLOCK 0 after block0 : torch.Size([1, 64, 128])
BLOCK 0 after block1 : torch.Size([1, 64, 128])
BLOCK 0 after attn : torch.Size([1, 64, 128])
BLOCK 0 after down block : torch.Size([1, 64, 64])
BLOCK 1 after block0 : torch.Size([1, 128, 64])
BLOCK 1 after block1 : torch.Size([1, 128, 64])
BLOCK 1 after attn : torch.Size([1, 128, 64])
BLOCK 1 after 

In [8]:
out.shape

torch.Size([1, 768, 1])

In [9]:
de = Decoder(in_seq              = 768,
            in_channels          = 128,
            out_seq             = 440,
            dims               = [64, 128, 256, 512, 1024],
            shortcut           = True,
            dropout           = 0.5,
            groups              = 32,
            layer_mode          = 'conv',
            block_mode          = 'res',
            up_mode           = 'trans',
            pos_mode           = 'trunc',
            n_layer           = 2,
            n_head              = 64,
            dff_factor        = 2,
            stride            = 4,
            skip_mode          = "down")
outs = de(out)

inputs :  torch.Size([1, 768, 1])
After In Layer : torch.Size([1, 1024, 1])
BLOCK 0 after concat : torch.Size([1, 1024, 1])
BLOCK 0 after block0 : torch.Size([1, 1024, 1])
BLOCK 0 after block1 : torch.Size([1, 1024, 1])
BLOCK 0 after attn : torch.Size([1, 1024, 1])
BLOCK 0 after up block : torch.Size([1, 1024, 4])
BLOCK 1 after concat : torch.Size([1, 1024, 4])
BLOCK 1 after block0 : torch.Size([1, 512, 4])
BLOCK 1 after block1 : torch.Size([1, 512, 4])
BLOCK 1 after attn : torch.Size([1, 512, 4])
BLOCK 1 after up block : torch.Size([1, 512, 16])
BLOCK 2 after concat : torch.Size([1, 512, 16])
BLOCK 2 after block0 : torch.Size([1, 256, 16])
BLOCK 2 after block1 : torch.Size([1, 256, 16])
BLOCK 2 after attn : torch.Size([1, 256, 16])
BLOCK 2 after up block : torch.Size([1, 256, 64])
BLOCK 3 after concat : torch.Size([1, 256, 64])
BLOCK 3 after block0 : torch.Size([1, 128, 64])
BLOCK 3 after block1 : torch.Size([1, 128, 64])
BLOCK 3 after attn : torch.Size([1, 128, 64])
BLOCK 3 after up 

# None Mode Test

In [10]:
test =torch.randn(1, 440, 128)
en = Encoder(in_seq              = 440,
            in_channels             = 128,
            out_seq             = 768,
            dims               = [64, 128, 256, 512, 1024],
            shortcut           = True,
            dropout           = 0.5,
            groups              = 32,
            layer_mode          = 'conv',
            block_mode          = 'res',
            down_mode           = 'max',
            pos_mode           = 'trunc',
            n_layer           = 2,
            n_head              = 64,
            dff_factor        = 2,
            stride            = 4,
            skip_mode          = None)

out, skips = en(test)

None
None
inputs :  torch.Size([1, 440, 128])
After In Layer : torch.Size([1, 64, 128])
BLOCK 0 after block0 : torch.Size([1, 64, 128])
BLOCK 0 after block1 : torch.Size([1, 64, 128])
BLOCK 0 after attn : torch.Size([1, 64, 128])
BLOCK 0 after down block : torch.Size([1, 64, 64])
BLOCK 1 after block0 : torch.Size([1, 128, 64])
BLOCK 1 after block1 : torch.Size([1, 128, 64])
BLOCK 1 after attn : torch.Size([1, 128, 64])
BLOCK 1 after down block : torch.Size([1, 128, 16])
BLOCK 2 after block0 : torch.Size([1, 256, 16])
BLOCK 2 after block1 : torch.Size([1, 256, 16])
BLOCK 2 after attn : torch.Size([1, 256, 16])
BLOCK 2 after down block : torch.Size([1, 256, 4])
BLOCK 3 after block0 : torch.Size([1, 512, 4])
BLOCK 3 after block1 : torch.Size([1, 512, 4])
BLOCK 3 after attn : torch.Size([1, 512, 4])
BLOCK 3 after down block : torch.Size([1, 512, 1])
BLOCK 4 after block0 : torch.Size([1, 1024, 1])
BLOCK 4 after block1 : torch.Size([1, 1024, 1])
BLOCK 4 after attn : torch.Size([1, 1024, 1])


In [11]:
out.shape

torch.Size([1, 768, 1])

In [12]:
np.shape(skips)

(0,)

In [13]:
de = Decoder(in_seq              = 768,
            in_channels          = 1,
            out_seq             = 440,
            dims               = [64, 128, 256, 512, 1024],
            shortcut           = True,
            dropout           = 0.5,
            groups              = 32,
            layer_mode          = 'conv',
            block_mode          = 'res',
            up_mode           = 'trans',
            pos_mode           = 'sinusoidal',
            n_layer           = 2,
            n_head              = 64,
            dff_factor        = 2,
            stride            = 4,
            skip_mode          = None)
outs = de(out)

inputs :  torch.Size([1, 768, 1])
After In Layer : torch.Size([1, 1024, 1])
BLOCK 0 after concat : torch.Size([1, 1024, 1])
BLOCK 0 after block0 : torch.Size([1, 1024, 1])
BLOCK 0 after block1 : torch.Size([1, 1024, 1])
BLOCK 0 after attn : torch.Size([1, 1024, 1])
BLOCK 0 after up block : torch.Size([1, 1024, 4])
BLOCK 1 after concat : torch.Size([1, 1024, 4])
BLOCK 1 after block0 : torch.Size([1, 512, 4])
BLOCK 1 after block1 : torch.Size([1, 512, 4])
BLOCK 1 after attn : torch.Size([1, 512, 4])
BLOCK 1 after up block : torch.Size([1, 512, 16])
BLOCK 2 after concat : torch.Size([1, 512, 16])
BLOCK 2 after block0 : torch.Size([1, 256, 16])
BLOCK 2 after block1 : torch.Size([1, 256, 16])
BLOCK 2 after attn : torch.Size([1, 256, 16])
BLOCK 2 after up block : torch.Size([1, 256, 64])
BLOCK 3 after concat : torch.Size([1, 256, 64])
BLOCK 3 after block0 : torch.Size([1, 128, 64])
BLOCK 3 after block1 : torch.Size([1, 128, 64])
BLOCK 3 after attn : torch.Size([1, 128, 64])
BLOCK 3 after up 