In [2]:
import torch
import torch.nn as nn

import math

In [3]:
# pulled from Dr. Karpathy's minGPT implementation
class GELU(nn.Module):
    """
    Tanh-based approximation of the Gaussian Error Linear Unit activation:

        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))

    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    def forward(self, x):
        # Cubic correction term inside the tanh.
        inner = x + 0.044715 * torch.pow(x, 3.0)
        # Smooth gate; tanh is in (-1, 1), so the gate lies in (0, 2).
        gate = 1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * inner)
        return 0.5 * x * gate

In [22]:
class BrainDecoderBlock(nn.Module):
    """
    Dilated-convolution block: two conv -> batch-norm -> GELU stages (each
    optionally wrapped in a residual connection) followed by a convolution
    that doubles the channels and a GLU that halves them again.

    Args:
        k: block index; sets the dilation of the first two convolutions to
            2**((2*k) % 5) and 2**((2*k + 1) % 5) respectively.
        input_dims: number of input channels.
        skip: when True, add a residual connection around each of the first
            two stages. The first residual adds the block input to a
            320-channel activation, so input_dims must be 320 (a single
            input channel also works, through broadcasting).

    Raises:
        ValueError: if skip is True and input_dims cannot be added to the
            320-channel activation.

    Shapes:
        input  (batch_size, input_dims, T)
        output (batch_size, 320, T)
    """
    def __init__(self, k, input_dims=320, skip=True):
        super().__init__()

        hidden = 320

        # Fail fast: without this check a mismatched residual only surfaces
        # as a shape error deep inside forward().
        if skip and input_dims not in (1, hidden):
            raise ValueError(
                f"skip=True requires input_dims == {hidden} so the block input can be "
                f"added to the first conv stage's output (got input_dims={input_dims}); "
                "pass skip=False for blocks that change the channel count"
            )

        self.skip = skip

        # padding="same" keeps the time dimension unchanged for every dilation.
        self.conv1 = nn.Conv1d(input_dims, hidden, kernel_size=3, dilation=2**((2*k)%5), padding="same")
        self.conv2 = nn.Conv1d(hidden, hidden, kernel_size=3, dilation=2**((2*k+1)%5), padding="same")
        # Twice the hidden width: the GLU below consumes half the channels as gates.
        self.conv3 = nn.Conv1d(hidden, 2 * hidden, kernel_size=3, dilation=2, padding="same")

        self.bnorm1 = nn.BatchNorm1d(hidden)
        self.bnorm2 = nn.BatchNorm1d(hidden)

        self.gelu = GELU()

        # channel dim
        self.glu = nn.GLU(dim=1)

    def forward(self, x):
        # Stage 1: conv -> batch-norm -> GELU (+ residual from the block input).
        output = self.gelu(self.bnorm1(self.conv1(x)))
        if self.skip:
            output = output + x

        # Stage 2: conv -> batch-norm -> GELU (+ residual from stage 1's output).
        residual = output
        output = self.gelu(self.bnorm2(self.conv2(output)))
        if self.skip:
            output = output + residual

        # Stage 3: expand to 640 channels, then gate back down to 320.
        return self.glu(self.conv3(output))

In [25]:
# Smoke test: a non-residual block should map 270 input channels to 320
# while leaving the batch and time dimensions untouched.
block = BrainDecoderBlock(1, input_dims=270, skip=False)

# batch_size, C, T
test_data = torch.randn((32, 270, 3600))

output = block(test_data)

print(output.shape)

torch.Size([32, 320, 3600])


In [79]:
class SpatialAttention(nn.Module):
    """
    Gate the input with a mask computed from channel-pooled statistics, then
    remix the channels with a pointwise (kernel_size=1) convolution.

    Args:
        in_channels: number of channels in the input.
        out_channels: number of channels produced by the final convolution.

    Shapes:
        input  (batch_size, in_channels, T)
        output (batch_size, out_channels, T)
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Two pooled descriptors (mean, max) in, one mask channel out.
        self.conv1 = nn.Conv1d(2, 1, kernel_size=3, padding="same")
        # Pointwise projection from in_channels to out_channels.
        self.conv2 = nn.Conv1d(in_channels, out_channels, kernel_size=1)

        self.bnorm1 = nn.BatchNorm1d(1)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Collapse the channel axis two ways, keeping it as a size-1 dim.
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values

        # (batch_size, 2, T) descriptor -> (batch_size, 1, T) mask in (0, 1)
        pooled = torch.cat([channel_mean, channel_max], dim=1)
        gate = self.sigmoid(self.bnorm1(self.conv1(pooled)))

        # The single-channel mask broadcasts over every input channel.
        return self.conv2(gate * x)

In [80]:
# Smoke test: 61 input channels should come out as 270, with the batch
# and time dimensions unchanged.
spatial_attention = SpatialAttention(in_channels=61, out_channels=270)

# layout: (batch_size, C, T)
test_data = torch.randn(32, 61, 3600)

output = spatial_attention(test_data)
output.shape

torch.Size([32, 270, 3600])

In [81]:
class BrainDecoder(nn.Module):
    """
    Full decoder: spatial attention -> two pointwise convolutions -> a stack
    of `num_k` BrainDecoderBlocks -> two pointwise convolutions projecting to
    `num_freq_bands` output channels.

    Args:
        input_channels: number of channels in the input.
        num_k: number of BrainDecoderBlocks to stack; must be at least 1,
            because the first block is what widens 270 channels to the 320
            that the output head expects.
        num_freq_bands: number of output channels.

    Raises:
        ValueError: if num_k is smaller than 1.

    Shapes:
        input  (batch_size, input_channels, T)
        output (batch_size, num_freq_bands, T)
    """
    def __init__(self, input_channels, num_k, num_freq_bands):
        super().__init__()

        # Fail fast: with no blocks, forward() would feed 270 channels into
        # the 320-channel conv2 and crash with a shape error.
        if num_k < 1:
            raise ValueError(f"num_k must be >= 1 (got {num_k})")

        self.spatial_attention = SpatialAttention(input_channels, 270)

        self.conv1 = nn.Conv1d(270, 270, kernel_size=1)
        self.subject_layer = nn.Conv1d(270, 270, kernel_size=1)

        # Block indices are 1-based. The first block changes the channel
        # count (270 -> 320) and therefore cannot use residual connections;
        # every later block is 320 -> 320 with residuals enabled.
        self.decoder_blocks = nn.ModuleList(
            [BrainDecoderBlock(1, 270, False)]
            + [BrainDecoderBlock(k, 320, True) for k in range(2, num_k + 1)]
        )

        self.conv2 = nn.Conv1d(320, 640, kernel_size=1)
        self.final_conv = nn.Conv1d(640, num_freq_bands, kernel_size=1)

    def forward(self, x):
        output = self.spatial_attention(x)

        output = self.conv1(output)
        output = self.subject_layer(output)

        for block in self.decoder_blocks:
            output = block(output)

        # NOTE(review): there is no activation between conv2 and final_conv,
        # so these two pointwise convolutions compose into a single linear
        # map — confirm whether a nonlinearity was intended here.
        output = self.conv2(output)
        output = self.final_conv(output)

        return output

In [82]:
# Smoke-test dimensions: input channels, output frequency bands, time steps.
C, F, T = 61, 100, 3600

# positional arguments: input_channels, num_k, num_freq_bands
brain_decoder = BrainDecoder(C, 5, F)

# layout: (batch_size, C, T)
test_data = torch.randn(32, C, T)

# the decoder should return (batch_size, F, T)
output = brain_decoder(test_data)
output.shape

torch.Size([32, 100, 3600])