In [6]:
import torch
import torch.nn as nn
import torchaudio

import math

import mne

import matplotlib.pyplot as plt

import numpy as np

from tqdm import tqdm

from ignite.metrics import TopKCategoricalAccuracy
from ignite.engine import Engine

from segment_prediction_dataset import CustomDataset

In [7]:
device = torch.device("cuda")

In [8]:
# pulled from Dr. Karpathy's minGPT implementation
class GELU(nn.Module):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT).
    Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415
    """
    def forward(self, x):
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))

In [9]:
class BrainDecoderBlock(nn.Module):
    def __init__(self, k, input_dims=320, skip=True):
        super().__init__()

        self.skip = skip

        self.conv1 = nn.Conv1d(input_dims, 320, kernel_size=3, dilation=2**((2*k)%5), padding="same")
        self.conv2 = nn.Conv1d(320, 320, kernel_size=3, dilation=2**((2*k+1)%5), padding="same")
        self.conv3 = nn.Conv1d(320, 640, kernel_size=3, dilation=2, padding="same")

        self.bnorm1 = nn.BatchNorm1d(320)
        self.bnorm2 = nn.BatchNorm1d(320)

        self.gelu = GELU()

        # channel dim
        self.glu = nn.GLU(dim=1)

    def forward(self, x):
        output = self.conv1(x)
        output = self.bnorm1(output)
        output = self.gelu(output)

        if self.skip:
            # channel dim res connection
            output = output + x

            skip = output

        output = self.conv2(output)
        output = self.bnorm2(output)
        output = self.gelu(output)

        if self.skip:
            output = output + skip

        output = self.conv3(output)
        output = self.glu(output)

        return output

In [10]:
temperatures = []

class SpatialAttention(nn.Module):
    def __init__(self, in_channels, out_channels, num_harmonics, dropout=0.1):
        super().__init__()
        # position preprocessing
        easycap_montage = mne.channels.read_custom_montage("../data/umich/electrode_positions.sfp")

        info = mne.create_info([str(i+1) for i in range(in_channels)], sfreq=500, ch_types="eeg")
        info.set_montage(easycap_montage, on_missing="ignore")

        layout = mne.channels.find_layout(info)
        two_dim_pos = layout.pos[:, :2]

        # normalize 0-1
        two_dim_pos[:, 0] -= min(two_dim_pos[:, 0])
        two_dim_pos[:, 1] -= min(two_dim_pos[:, 1])

        two_dim_pos[:, 0] /= max(two_dim_pos[:, 0])
        two_dim_pos[:, 1] /= max(two_dim_pos[:, 1])

        self.input_channels = torch.tensor(two_dim_pos)
        
        # spatial attention calculation params

        self.z_trainable = torch.randn((out_channels, num_harmonics, num_harmonics), dtype=torch.cfloat)
        self.z_trainable = torch.nn.parameter.Parameter(torch.transpose(self.z_trainable.view(1, out_channels, num_harmonics, num_harmonics).repeat(in_channels, 1, 1, 1), 0, 1))

        self.k = self.l = torch.linspace(1, num_harmonics, num_harmonics).repeat(in_channels, num_harmonics, 1)

        # other stuff
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1)

        self.softmax = nn.Softmax(dim=0)

        self.dropout = nn.Dropout(dropout)
    
    def _apply(self, fn):
        super(SpatialAttention, self)._apply(fn)

        self.k = fn(self.k)
        self.l = fn(self.l)
        self.input_channels = fn(self.input_channels)

        return self

    def forward(self, x):
        # x dims - batch_size, C, T

        term_1 = torch.transpose(torch.cos(2 * np.pi * (self.k * self.input_channels[:, 0].view(-1, 1, 1) + torch.transpose(self.l, 1, -1) * self.input_channels[:, 1].view(1, -1, 1, 1))), 2, 3)
        term_2 = torch.transpose(torch.sin(2 * np.pi * (self.k * self.input_channels[:, 0].view(-1, 1, 1) + torch.transpose(self.l, 1, -1) * self.input_channels[:, 1].view(1, -1, 1, 1))), 2, 3)

        a_j = torch.sum(self.z_trainable.real * term_1 + self.z_trainable.imag * term_2, dim=(1, 2, 3)).repeat(x.shape[0], 1).view(x.shape[0], -1, 1)

        output = self.conv1(x)

        output = self.dropout(self.softmax(a_j)) * output

        return output.type(torch.float32)

In [7]:
class BrainDecoder(nn.Module):
    def __init__(self, input_channels, num_k, num_freq_bands, num_subjects):
        super().__init__()

        self.spatial_attention = SpatialAttention(input_channels, 270, 32, 0.1)

        self.conv1 = nn.Conv1d(270, 270, kernel_size=1)

        self.subject_layers = []
        for i in range(num_subjects):
            self.subject_layers += [nn.Conv1d(270, 270, kernel_size=1)]
        self.subject_layers = nn.ModuleList(self.subject_layers)

        self.decoder_blocks = []

        for i in range(num_k):
            if i == 0:
                self.decoder_blocks += [BrainDecoderBlock(i+1, 270, False)]
            else:
                self.decoder_blocks += [BrainDecoderBlock(i+1, 320, True)]

        self.decoder_blocks = nn.ModuleList(self.decoder_blocks)

        self.conv2 = nn.Conv1d(320, 640, kernel_size=1)
        self.final_conv = nn.Conv1d(640, num_freq_bands, kernel_size=1)

    def forward(self, x, subject_num):
        output = self.spatial_attention(x)

        output = self.conv1(output)

        final_output = torch.zeros_like(output)

        for _, i in enumerate(subject_num):
            final_output[_] = self.subject_layers[int(i)](output[_])

        for block in self.decoder_blocks:
            final_output = block(final_output)
        
        final_output = self.conv2(final_output)
        final_output = self.final_conv(final_output)

        return final_output

In [8]:
num_subjects = 49
T_out = 49

In [9]:
bundle = torchaudio.pipelines.WAV2VEC2_XLSR53
wave2vec = bundle.get_model().to(device)

Downloading: "https://download.pytorch.org/torchaudio/models/wav2vec2_fairseq_large_xlsr53.pth" to /home/kashyap/.cache/torch/hub/checkpoints/wav2vec2_fairseq_large_xlsr53.pth
100%|██████████| 1.18G/1.18G [00:23<00:00, 52.8MB/s]


In [10]:
batch_size = 1

exclude = [2, 7, 9, 23, 24, 27, 28, 29, 30, 31, 32, 33, 43, 46, 47, 49]

dataset = CustomDataset(data_dir="../data/umich", T_out=T_out, num_subjects=num_subjects, exclude=exclude)

train_set = torch.utils.data.Subset(dataset, range(int(len(dataset) * 0.8)))
val_set = torch.utils.data.Subset(dataset, range(int(len(dataset) * 0.8), int(len(dataset))))

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True)

100%|██████████| 49/49 [02:13<00:00,  2.72s/it]


Brain data shape: torch.Size([7986, 61, 49])
Waveform shape: torch.Size([23859, 16000])
Subject num shape: torch.Size([7986])
Audio sampling rate (Hz): 16000
Brain data sampling rate (Hz): 500


In [10]:
class CLIP(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.temperature = torch.nn.parameter.Parameter(torch.randn(1, dtype=torch.float32))

        self.cross_entropy = nn.CrossEntropyLoss()

    def forward(self, brain_latents, audio_latents, save_temp=False):
        # dims = batch_size, frequency_dim, temporal_dim
        brain_latents = brain_latents.reshape((brain_latents.shape[0], -1)) # [batch_size, frequency_dim * temporal_dim]
        # dims = batch_size, frequency_dim, temporal_dim
        audio_latents = audio_latents.reshape((audio_latents.shape[0], -1)) # [batch_size, frequency_dim * temporal_dim]

        logits = brain_latents @ audio_latents.T
        logits_T = logits.T

        labels = torch.arange(brain_latents.shape[0]).to(device)

        # exponentiate and clip to max 100
        exponentiated_temperature = torch.exp(self.temperature)
        if exponentiated_temperature > 100: exponentiated_temperature = 100

        loss = ((self.cross_entropy(logits, labels) + self.cross_entropy(logits_T, labels)) / 2) / exponentiated_temperature
        
        return loss

In [53]:
C = 61
F = 768

# brain_decoder = torch.load("../saved_models/june_6_run_1_brain_decoder")
brain_decoder = BrainDecoder(input_channels=C, num_k=5, num_freq_bands=F, num_subjects=num_subjects)
brain_decoder.to(device)

# clip = torch.load("../saved_models/june_6_run_2_clip")
# clip.to(device)

# batch_size, C, T
test_data = torch.randn((32, C, T_out))

# expected output dims: batch_size, F, T_out
output = brain_decoder(test_data.to(device), torch.zeros((32)))
output.shape

  info.set_montage(easycap_montage, on_missing="ignore")


torch.Size([32, 768, 149])

In [54]:
# regression loss
regression_loss = nn.MSELoss()

In [35]:
targets = []

for (brain_data, audio_data, subject_num) in tqdm(val_loader):
    audio_data = audio_data.to(device).type(torch.float32)

    # wave2vec processing
    with torch.inference_mode():
        features, _ = wave2vec.extract_features(audio_data)

    # pull from last 4 layers
    semantic_features = (features[11] + features[10] + features[9] + features[8]) / 4    # dims -> batch_size, T, F
    semantic_features = torch.transpose(semantic_features, 1, 2)

    targets += [semantic_features]

  0%|          | 0/1591 [00:00<?, ?it/s]

100%|██████████| 1591/1591 [00:08<00:00, 177.39it/s]


In [55]:
def process_function(engine, batch):
    y_pred, y = batch
    
    return y_pred, y

def one_hot_to_output_transform(output):
    y_pred, y = output
    y = torch.argmax(y, dim=1)  # one-hot vector to label index vector

    return y_pred, y

engine = Engine(process_function)
metric = TopKCategoricalAccuracy(k=10, output_transform=one_hot_to_output_transform)
metric.attach(engine, 'top_k_accuracy')

In [56]:
brain_decoder.eval()
# clip.eval()

accuracies = []

for _, (brain_data, audio_data, subject_num) in tqdm(enumerate(val_loader)):
    brain_data = brain_data.to(device).type(torch.float32)

    # brain decoder processing
    brain_output = brain_decoder(brain_data, subject_num)

    losses = None
    for target in targets:
        if losses == None:
            losses = regression_loss(brain_output, target)
        else:
            losses = torch.vstack((losses, regression_loss(brain_output, target)))

    output_index = torch.argmin(losses)

    predicted = torch.zeros((1, len(targets)))
    predicted[0, output_index] = 1

    target_outputs = torch.zeros((1, len(val_loader)))
    target_outputs[0, _] = 1

    state = engine.run([[predicted, target_outputs]])
    accuracies += [state.metrics['top_k_accuracy']]

print("Accuracy: " + str(np.mean(accuracies) * 100))

1591it [02:35, 10.22it/s]

Accuracy: 0.6285355122564426





In [57]:
print("Accuracy: " + str(np.mean(accuracies) * 100))

Accuracy: 0.6285355122564426
