In [1]:
from BcomMEG import *
import mne
import numpy as np
import matplotlib.pyplot as plt
import os

In [6]:
import torch.nn as nn
class BigTripletNet(nn.Module): #this one is designed to take in all the sensors as one big tensor
    def __init__(self):
        super(BigTripletNet, self).__init__()
        
        #Conv Blocks
        self.conv1 = nn.Conv2d(in_channels=247, out_channels=16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(16)
        self.act1 = nn.GELU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) #(16, 24, 40)

        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1) 
        self.bn2 = nn.BatchNorm2d(32)
        self.act2 = nn.GELU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) #(32, 12, 20)

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.act3 = nn.GELU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) #64, 6, 10

        self.conv4 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.act4 = nn.GELU()
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) #128, 3, 5

        #FC layers
        self.fc1 = nn.Linear(1920, 512)
        self.fc2 = nn.Linear(512, 128)


    def forward(self, x):
        x = self.pool1(self.act1(self.bn1(self.conv1(x))))
        x = self.pool2(self.act2(self.bn2(self.conv2(x))))
        x = self.pool3(self.act3(self.bn3(self.conv3(x))))
        x = self.pool4(self.act4(self.bn4(self.conv4(x))))

        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        print(x.shape)

        return x


In [7]:
model = BigTripletNet()

In [None]:
dir = 'Data_Sample'
subject = ['BCOM_18_2']
avoid_reading=False
picks=None
frequencies = np.array(range(2, 100, 2))
divisor = 7
baseline = (None, 0)


data = BcomMEG(subjects=subject, 
               dir=dir, 
               picks=picks, 
               avoid_reading=False)



Reading /Users/ciprianbangu/Cogmaster/M2 Internship/BCI code/Data_Sample/BCOM_18_2_ma_32-epo.fif ...
    Found the data of interest:
        t =    -300.00 ...     500.00 ms
        0 CTF compensation matrices available
Not setting metadata
18 matching events found
No baseline correction applied
0 projection items activated
Reading /Users/ciprianbangu/Cogmaster/M2 Internship/BCI code/Data_Sample/BCOM_18_2_i_16-epo.fif ...
    Found the data of interest:
        t =    -300.00 ...     500.00 ms
        0 CTF compensation matrices available
Not setting metadata
12 matching events found
No baseline correction applied
0 projection items activated
Reading /Users/ciprianbangu/Cogmaster/M2 Internship/BCI code/Data_Sample/BCOM_18_2_me_34-epo.fif ...
    Found the data of interest:
        t =    -300.00 ...     500.00 ms
        0 CTF compensation matrices available
Not setting metadata
12 matching events found
No baseline correction applied
0 projection items activated
Reading /Users/ciprianb

In [None]:
data = data.get_spectrogram(frequencies=frequencies, 
                            baseline=baseline, 
                            cycle_divisor=divisor,
                            mode='zscore',
                            data_only=True)
data.get_syllable_counts()

{'BCOM_18_2': {'ma_32': 18,
  'i_16': 12,
  'me_34': 12,
  'si_56': 6,
  're_44': 8,
  'li_26': 14,
  'ti_66': 6,
  'ra_42': 14,
  'ta_62': 11,
  'le_24': 7,
  'ri_46': 13,
  'la_22': 15,
  'te_64': 10,
  'e_14': 12,
  'sa_52': 13,
  'se_54': 5,
  'mi_36': 10,
  'a_12': 17}}

In [37]:
# Get the total number of epochs
total_epochs = sum(len(data.data['BCOM_18_2'][syllable]) for syllable in data.data['BCOM_18_2'])

# Get the shape of a single epoch
sample_epoch = next(iter(data.data['BCOM_18_2'].values()))[0]
epoch_shape = sample_epoch.shape

# Preallocate the tensor array
tensor = np.empty((total_epochs, *epoch_shape))

# Fill the tensor array
index = 0
for syllable in data.data['BCOM_18_2']:
    for epoch in data.data['BCOM_18_2'][syllable]:
        tensor[index] = epoch
        index += 1

In [51]:
x = ()
x = sample_epoch.shape
x

(247, 49, 81)

In [47]:
type(epoch.shape)

tuple

In [38]:
tensor.shape

(203, 247, 49, 81)

In [4]:
tensor = data.data_to_tensor()

In [6]:
tensor.shape

(203, 247, 49, 81)

In [11]:
model.to(device='mps')

BigTripletNet(
  (conv1): Conv2d(247, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act1): GELU(approximate='none')
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act2): GELU(approximate='none')
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act3): GELU(approximate='none')
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn4): BatchNorm2d(128, eps=1e-05, momentum

In [1]:
# Add an extra dimension to the data tensor
data = data.unsqueeze(0)  # Add an extra dimension to the data tensor
data.shape

NameError: name 'data' is not defined

In [102]:
model(data)

torch.Size([1, 128])


tensor([[ 0.5340, -0.1571, -0.4861,  0.4363,  0.0455,  0.3874, -0.1402, -0.2919,
         -0.2654, -0.3802, -0.1191, -0.3442,  0.8937,  0.7089, -0.6383,  0.1493,
         -0.1862,  0.0138, -0.2385,  0.5055, -0.2366,  0.1213, -0.5600, -1.1846,
          0.2480, -0.5383,  0.1487,  0.1989, -0.1804, -0.0408,  0.0185,  0.4276,
          0.2146,  0.6647, -0.5078, -0.4014, -0.0204,  0.2452, -0.2439,  0.0123,
         -0.3380,  0.1554,  0.1999, -0.2314,  0.3361,  0.3895,  0.4610, -0.6637,
          0.7090, -0.2900,  0.5108,  0.1444, -0.2961, -0.1916, -1.0463, -0.0291,
         -0.1955, -0.7884, -0.0267,  0.2512,  0.0257,  0.4350,  0.0825,  0.8260,
          0.7408, -0.1746, -0.1979, -0.2789,  0.1306, -0.0055,  0.3905,  0.6522,
         -0.1001,  0.2987, -0.1067, -0.0682, -0.1999,  0.3542,  0.3543, -0.1683,
         -0.0475,  0.3658, -0.1471,  0.4864,  1.3819,  0.3530,  0.5918,  0.3983,
         -0.0049,  0.2658, -0.2764, -0.3923,  0.1244, -0.4531, -0.6760,  0.8823,
          0.0422, -0.1372, -