In [1]:
import torch
import torch.nn as nn
import torchaudio

import math

import mne

import matplotlib.pyplot as plt

import numpy as np

from utils.dataset import CustomDataset

In [2]:
# Prefer the GPU when one is visible, but fall back to the CPU so the notebook
# still runs top to bottom on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# pulled from Dr. Karpathy's minGPT implementation
class GELU(nn.Module):
    """
    Tanh approximation of the Gaussian Error Linear Unit, applied element-wise:

        0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))

    This is the form used in the Google BERT repo (identical to OpenAI GPT).
    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    def forward(self, x):
        # x: tensor of any shape; the result has the same shape
        cubic_term = 0.044715 * torch.pow(x, 3.0)
        tanh_arg = math.sqrt(2.0 / math.pi) * (x + cubic_term)
        return 0.5 * x * (1.0 + torch.tanh(tanh_arg))

In [4]:
class BrainDecoderBlock(nn.Module):
    """
    One convolutional stage of the brain decoder.

    Two dilated Conv1d -> BatchNorm1d -> GELU sub-layers (320 channels each,
    optionally wrapped in residual connections) followed by a 640-channel
    Conv1d whose output is halved back to 320 channels by a GLU over the
    channel dimension. All convolutions use padding="same", so the temporal
    length is preserved.

    Args:
        k: block index; the dilations of the first two convolutions are
            2 ** ((2 * k) % 5) and 2 ** ((2 * k + 1) % 5).
        input_dims: number of input channels.
        skip: when True, add residual connections around both sub-layers.
            The first residual adds the block input directly to a
            320-channel tensor, so input_dims has to be 320 in that case.
    """
    def __init__(self, k, input_dims=320, skip=True):
        super().__init__()

        self.skip = skip

        # dilation exponents cycle with period 5 as the block index grows
        first_dilation = 2 ** ((2 * k) % 5)
        second_dilation = 2 ** ((2 * k + 1) % 5)

        self.conv1 = nn.Conv1d(input_dims, 320, kernel_size=3, dilation=first_dilation, padding="same")
        self.conv2 = nn.Conv1d(320, 320, kernel_size=3, dilation=second_dilation, padding="same")
        self.conv3 = nn.Conv1d(320, 640, kernel_size=3, dilation=2, padding="same")

        self.bnorm1 = nn.BatchNorm1d(320)
        self.bnorm2 = nn.BatchNorm1d(320)

        self.gelu = GELU()

        # gate over the channel dim: 640 -> 320
        self.glu = nn.GLU(dim=1)

    def forward(self, x):
        # x dims - batch_size, input_dims, T
        hidden = self.gelu(self.bnorm1(self.conv1(x)))
        if self.skip:
            # residual from the block input (channel counts must match)
            hidden = hidden + x

        # output of the first sub-layer, reused as the second residual
        first_stage = hidden

        hidden = self.gelu(self.bnorm2(self.conv2(hidden)))
        if self.skip:
            hidden = hidden + first_stage

        # 640-channel projection, gated back down to 320 channels
        return self.glu(self.conv3(hidden))

In [5]:
class SpatialAttention(nn.Module):
    """
    Sensor-position-aware input layer.

    A 1x1 convolution mixes the `in_channels` sensor signals into
    `out_channels` channels. Each output channel is then scaled by a learned
    scalar gain a_j, obtained by evaluating a trainable 2-D Fourier series
    (complex coefficients `z_trainable`) at every sensor's normalised (x, y)
    position and summing over sensors and harmonics:

        a_j = sum_{c,k,l} Re(z[j,c,k,l]) * cos(2*pi*(k*x_c + l*y_c))
                        + Im(z[j,c,k,l]) * sin(2*pi*(k*x_c + l*y_c))

    with k, l in 1..num_harmonics.

    Args:
        in_channels: number of input sensor channels; channels are named
            "1".."in_channels" when they are matched against the montage.
        out_channels: number of output channels.
        num_harmonics: number of harmonics per spatial axis.
        montage_path: electrode position file read with
            mne.channels.read_custom_montage. Defaults to the path that was
            previously hard-coded.

    Note:
        The constant tensors `k`, `l` and `input_channels` are registered as
        non-persistent buffers, so they follow `.to()` / `.cuda()` without
        overriding the private `nn.Module._apply` hook, and they stay out of
        `state_dict()`.
    """
    def __init__(self, in_channels, out_channels, num_harmonics,
                 montage_path="./data/umich/electrode_positions.sfp"):
        super().__init__()
        # position preprocessing
        easycap_montage = mne.channels.read_custom_montage(montage_path)

        info = mne.create_info([str(i+1) for i in range(in_channels)], sfreq=500, ch_types="eeg")
        info.set_montage(easycap_montage, on_missing="ignore")

        layout = mne.channels.find_layout(info)
        # NOTE(review): forward() broadcasts these positions against
        # in_channels, so the layout is assumed to have one row per channel.
        two_dim_pos = layout.pos[:, :2]

        # normalize each axis to 0-1
        two_dim_pos -= two_dim_pos.min(axis=0)
        two_dim_pos /= two_dim_pos.max(axis=0)

        # spatial attention calculation params

        # One random coefficient grid per output channel, copied to every
        # input channel -> (out_channels, in_channels, H, H). The copies are
        # independent trainable entries once training starts.
        z_init = torch.randn(out_channels, num_harmonics, num_harmonics, dtype=torch.cfloat)
        self.z_trainable = torch.nn.parameter.Parameter(z_init.unsqueeze(1).repeat(1, in_channels, 1, 1))

        # harmonic indices 1..H along the last axis, shape (in_channels, H, H)
        harmonics = torch.linspace(1, num_harmonics, num_harmonics).repeat(in_channels, num_harmonics, 1)

        # Buffers (not plain attributes) so that device / dtype moves are
        # handled by nn.Module itself; persistent=False keeps checkpoints
        # free of these constants.
        self.register_buffer("k", harmonics, persistent=False)
        self.register_buffer("l", harmonics.clone(), persistent=False)
        self.register_buffer("input_channels", torch.tensor(two_dim_pos), persistent=False)

        # other stuff
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1)

        # not used by forward(); kept so the module tree is unchanged
        self.softmax = nn.Softmax()

    def forward(self, x):
        # x dims - batch_size, C, T
        x_pos = self.input_channels[:, 0].view(-1, 1, 1)
        y_pos = self.input_channels[:, 1].view(1, -1, 1, 1)

        # angle[0, c, i, j] = 2*pi*((i+1)*x_c + (j+1)*y_c); computed once and
        # shared by the cosine and sine terms
        angle = 2 * np.pi * (self.k * x_pos + torch.transpose(self.l, 1, -1) * y_pos)
        angle = torch.transpose(angle, 2, 3)

        # one scalar gain per output channel, shape (out_channels,)
        a_j = torch.sum(
            self.z_trainable.real * torch.cos(angle) + self.z_trainable.imag * torch.sin(angle),
            dim=(1, 2, 3),
        )

        output = self.conv1(x)

        # broadcast the gains over batch and time
        output = a_j.view(1, -1, 1) * output

        # positions may be float64, which promotes the product; cast back
        return output.type(torch.float32)

In [6]:
class BrainDecoder(nn.Module):
    """
    Maps multi-channel brain recordings to a sequence of feature vectors.

    Pipeline: SpatialAttention (input_channels -> 270, 32 harmonics) ->
    two 1x1 convolutions (270 -> 270) -> `num_k` BrainDecoderBlocks ->
    1x1 convolution (320 -> 640) -> 1x1 convolution (640 -> num_freq_bands).

    Args:
        input_channels: number of sensor channels in the input.
        num_k: number of BrainDecoderBlocks. The first one takes 270
            channels and has no residual connections; the remaining ones
            take 320 channels and use residual connections.
        num_freq_bands: number of channels in the output.
    """
    def __init__(self, input_channels, num_k, num_freq_bands):
        super().__init__()

        self.spatial_attention = SpatialAttention(input_channels, 270, 32)

        self.conv1 = nn.Conv1d(270, 270, kernel_size=1)
        self.subject_layer = nn.Conv1d(270, 270, kernel_size=1)

        # blocks are numbered from 1; only the first changes the channel count
        blocks = [
            BrainDecoderBlock(index + 1, 270, False) if index == 0
            else BrainDecoderBlock(index + 1, 320, True)
            for index in range(num_k)
        ]
        self.decoder_blocks = nn.ModuleList(blocks)

        self.conv2 = nn.Conv1d(320, 640, kernel_size=1)
        self.final_conv = nn.Conv1d(640, num_freq_bands, kernel_size=1)

    def forward(self, x):
        # x dims - batch_size, input_channels, T
        hidden = self.spatial_attention(x)
        hidden = self.subject_layer(self.conv1(hidden))

        for decoder_block in self.decoder_blocks:
            hidden = decoder_block(hidden)

        # dims - batch_size, num_freq_bands, T
        return self.final_conv(self.conv2(hidden))

In [7]:
C = 62        # number of input sensor channels
F = 768       # number of output feature channels (num_freq_bands)
T_out = 149   # number of time steps per example

brain_decoder = BrainDecoder(input_channels=C, num_k=5, num_freq_bands=F)
brain_decoder = brain_decoder.to(device)

# batch_size, C, T
test_data = torch.randn((32, C, T_out))

# Shape smoke test. Run it in eval mode under no_grad: in train mode this
# forward pass on random noise would update the BatchNorm running statistics
# before any real data is seen, and it would also build an autograd graph
# that nothing uses.
brain_decoder.eval()
with torch.no_grad():
    output = brain_decoder(test_data.to(device))
brain_decoder.train()

# expected output dims: batch_size, F, T_out
assert output.shape == (32, F, T_out), output.shape
output.shape

  info.set_montage(easycap_montage, on_missing="ignore")


torch.Size([32, 768, 149])

In [8]:
# Pretrained wav2vec 2.0 ASR bundle from torchaudio's pipeline registry;
# instantiate the model first, then move it to the selected device.
bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
wave2vec = bundle.get_model()
wave2vec = wave2vec.to(device)

In [9]:
batch_size = 16

dataset = CustomDataset(subject_path="./data/umich/S01.mat", audio_dir="./data/umich/audio/", T_out=T_out)

# Seed the 80/20 split so that the same examples land in train / val after
# every kernel restart; otherwise validation numbers from different runs are
# not comparable.
split_generator = torch.Generator().manual_seed(0)
train_set, val_set = torch.utils.data.random_split(dataset, [0.8, 0.2], generator=split_generator)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
# The contrastive loss compares examples within a batch, so the validation
# loss depends on how examples are grouped. Keep the order fixed so that the
# metric is computed on the same batches every epoch.
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False)

Brain data shape: torch.Size([62, 244, 149])
Waveform shape: torch.Size([241, 48000])
Audio sampling rate (Hz): 16000
Brain data sampling rate (Hz): 500


In [10]:
sample = next(iter(train_loader))

# sample[1] is the audio half of the batch; transcribe only its first
# waveform as a sanity check that audio is loaded correctly
with torch.inference_mode():
    emission, _ = wave2vec(sample[1][0].view(1, -1).to(device).type(torch.float32))

# greedy decoding: best label per frame, then collapse consecutive repeats
token_ids = torch.argmax(emission[0], 1)
token_ids = torch.unique_consecutive(token_ids, dim=0)

labels = bundle.get_labels()

# Build the string in one pass. Distinct names are used so that `sample`
# (the batch) and the earlier `output` tensor are not overwritten.
transcript = "".join(labels[token_id] for token_id in token_ids.tolist())

# drop the "-" entries (presumably the CTC blank in the bundle's label set - confirm)
transcript = transcript.replace("-", "")

print(transcript)

ELF|FOR|THIS|CURIOUS|CHILD|WAS|FOND|


In [11]:
cross_entropy = nn.CrossEntropyLoss()

def CLIP(brain_latents, audio_latents, temperature):
    """
    Symmetric contrastive (CLIP-style) loss between paired latents.

    Row i of `brain_latents` and row i of `audio_latents` form a positive
    pair; every other row in the batch acts as a negative. Both inputs are
    flattened to one vector per example, the pairwise dot products are used
    as logits, and the cross-entropy is averaged over both matching
    directions (brain -> audio and audio -> brain).

    Args:
        brain_latents: tensor whose first dimension is the batch
            (e.g. batch_size, frequency_dim, temporal_dim).
        audio_latents: tensor with the same batch size and the same number
            of elements per example as `brain_latents`.
        temperature: softmax temperature; logits are divided by it, so 1.0
            leaves them unscaled.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: if the two inputs have different batch sizes.
    """
    # Use the actual batch size of the inputs. Relying on a global batch size
    # breaks on a short final batch: the reshape either fails or silently
    # splits the examples into the wrong number of rows.
    num_pairs = brain_latents.shape[0]
    if audio_latents.shape[0] != num_pairs:
        raise ValueError(
            "brain_latents and audio_latents must have the same batch size, got "
            + str(num_pairs) + " and " + str(audio_latents.shape[0])
        )

    brain_latents = brain_latents.reshape(num_pairs, -1) # [num_pairs, frequency_dim * temporal_dim]
    audio_latents = audio_latents.reshape(num_pairs, -1) # [num_pairs, frequency_dim * temporal_dim]

    logits = (brain_latents @ audio_latents.T) / temperature

    # the matching pair sits on the diagonal; build the targets on the same
    # device as the logits instead of depending on a global device
    labels = torch.arange(num_pairs, device=logits.device)

    loss = (cross_entropy(logits, labels) + cross_entropy(logits.T, labels)) / 2

    return loss

In [12]:
# Adam over the brain decoder's parameters only; the wav2vec2 model is not
# handed to the optimizer.
LEARNING_RATE = 3e-4
optimizer = torch.optim.Adam(brain_decoder.parameters(), lr=LEARNING_RATE)

In [17]:
EPOCHS = 50

# number of training batches averaged for the per-epoch train-loss estimate
N_loss = 10


def contrastive_batch_loss(brain_data, audio_data):
    """
    CLIP loss between the brain decoder output and wav2vec2 features for one batch.

    Args:
        brain_data: batch of brain recordings from the data loader.
        audio_data: batch of matching audio waveforms from the data loader.

    Returns:
        Scalar loss tensor. It carries gradients with respect to
        `brain_decoder` unless the caller has disabled autograd.
    """
    brain_data = brain_data.to(device).type(torch.float32)
    audio_data = audio_data.to(device).type(torch.float32)

    # wave2vec processing. no_grad rather than inference_mode: these features
    # feed a loss that is backpropagated, and tensors created under
    # inference_mode cannot be saved for backward.
    with torch.no_grad():
        features, _ = wave2vec.extract_features(audio_data)

    # pull from 10th layer
    semantic_features = features[9]     # dims -> batch_size, T, F
    semantic_features = torch.transpose(semantic_features, 1, 2)

    # brain decoder processing
    brain_output = brain_decoder(brain_data)

    return CLIP(brain_output, semantic_features, 1.0)


def mean_loss(loader, max_batches=None):
    """
    Average `contrastive_batch_loss` over a loader with autograd disabled.

    The caller is responsible for putting `brain_decoder` in eval mode.

    Args:
        loader: iterable yielding (brain_data, audio_data) batches.
        max_batches: stop after this many batches; None uses every batch.

    Returns:
        Mean of the per-batch losses (nan if the loader yields nothing).
    """
    losses = []
    with torch.no_grad():
        for i, (brain_data, audio_data) in enumerate(loader):
            losses.append(contrastive_batch_loss(brain_data, audio_data).item())

            # stop right after the last wanted batch so that no extra batch is loaded
            if max_batches is not None and i + 1 >= max_batches:
                break

    return np.mean(np.array(losses))


wave2vec.eval()

for epoch in range(EPOCHS):
    # training loop
    brain_decoder.train()
    for (brain_data, audio_data) in train_loader:
        optimizer.zero_grad()

        # propagate gradients
        loss = contrastive_batch_loss(brain_data, audio_data)
        loss.backward()

        optimizer.step()

    # set to eval mode
    brain_decoder.eval()

    # Train loss on the first N_loss batches. It is reported even when the
    # loader has fewer than N_loss batches.
    print("Train loss: " + str(mean_loss(train_loader, max_batches=N_loss)))

    # val loss on the full validation set
    print("Val loss: " + str(mean_loss(val_loader)))

Train loss: 0.07031268363633654
Val loss: 27.080707550048828
Train loss: 0.09968605803840092
Val loss: 29.909603118896484
Train loss: 0.09830459285692239
Val loss: 29.824817021687824
Train loss: 0.15619338023335558


KeyboardInterrupt: 