In [None]:
from eeg2vec.train.train import train
from eeg2vec.data_loader import get_dataloader
from eeg2vec.models.eeg2vec import EEG2Vec
from eeg2vec.contrastive_loss import ContrastiveLoss

import numpy as np
import torch
from sklearn.metrics import f1_score, accuracy_score
from sklearn.model_selection import train_test_split
import xgboost as xgb
from sklearn.multioutput import MultiOutputClassifier

In [None]:
## First let's load the training data
from pathlib import Path
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.signal import butter, lfilter
import pandas as pd
import pickle

ROOT_PATH = Path("train/")
training_data = [(np.load(ROOT_PATH / f"data_{i}.npy"),np.load(ROOT_PATH / f"target_{i}.npy")) for i in range(4)]


# Testing stuff

In [26]:
print(training_data[0][0].shape)

(5, 7712740)


In [27]:
def butter_bandpass(lowcut, highcut, fs, order=5):
    return butter(order, [lowcut, highcut], fs=fs, btype='band')

def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    y = lfilter(b, a, data)
    return y

In [28]:
# First we need to get the point that maps to a label

def reshape_array_into_windows(x, sample_rate, window_duration_in_seconds):
    """
    Reshape the data into an array of shape (C, T, window) where 'window' contains
    the points corresponding to 'window_duration' seconds of data.

    Parameters:
    x (numpy array): The input data array.
    sample_rate (int): The number of samples per second.
    window_duration_in_seconds (float): The duration of each window in seconds.

    Returns:
    reshaped_x (numpy array): The reshaped array with shape (C, T, window).
    """
    # Calculate the number of samples in one window
    window_size = int(window_duration_in_seconds * sample_rate)
    
    # Ensure the total length of x is a multiple of window_size
    total_samples = x.shape[-1]
    if total_samples % window_size != 0:
        # Truncate or pad x to make it divisible by window_size
        x = x[..., :total_samples - (total_samples % window_size)]
    # Reshape x into (C, T, window)
    reshaped_x = x.reshape(x.shape[0], -1, window_size)

    return reshaped_x

In [48]:
# We first load and reshape all the data
all_data = []
all_targets = []
# We need to have
# data of Shape: [num_samples, num_channels (5), sequence_length]
# labels of Shape: [num_samples, 5]

for data, target in training_data:
    reshaped_data = reshape_array_into_windows(data, 250, 2)
    reshaped_data = reshaped_data.transpose(1, 0, 2)
    target = target.reshape(-1, 5)
    all_data.append(reshaped_data)
    all_targets.append(target)

all_data = np.concatenate(all_data, axis=0)
all_targets = np.concatenate(all_targets, axis=0)


In [49]:
print(all_data.shape)
print(all_targets.shape)

(52351, 5, 500)
(52351, 5)


In [83]:
data = all_data[1200:2000]
labels = all_targets[1200:2000]

In [None]:
# Split the data into training and test sets
X_train_full, X_test, y_train_full, y_test = train_test_split(data, labels, test_size=0.2, random_state=42)

# Further split training data for embeddings and XGBoost
X_train_embeddings, X_train_xgboost, y_train_embeddings, y_train_xgboost = train_test_split(X_train_full, y_train_full, test_size=0.5, random_state=42)


In [85]:
print(X_train_embeddings.shape, y_train_embeddings.shape)


(320, 5, 500) (320, 5)


In [86]:
data_loader = get_dataloader(X_train_embeddings, y_train_embeddings, batch_size=100, shuffle=True) 

In [None]:
model = EEG2Vec(16, 2, 3, 2)

EEG2Vec(
  (cnn_encoder): CNNEncoder(
    (conv_layers): Sequential(
      (0): Conv1d(5, 16, kernel_size=(2,), stride=(1,))
      (1): ReLU()
      (2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv1d(16, 16, kernel_size=(2,), stride=(1,))
      (4): ReLU()
    )
  )
  (transformer_encoder): TransformerEncoder(
    (transformer_encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=16, out_features=16, bias=True)
          )
          (linear1): Linear(in_features=16, out_features=2, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2, out_features=16, bias=True)
          (norm1): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1,

In [88]:
# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [89]:
print(torch.__version__)

1.12.1+cu113


In [90]:
# use cuda if available
model = model.to(device)
train(model, data_loader, 100, device)

Epoch 1/100 completed.
Epoch 2/100 completed.
Epoch 3/100 completed.
Epoch 4/100 completed.
Epoch 5/100 completed.
Epoch 6/100 completed.
Epoch 7/100 completed.
Epoch 8/100 completed.
Epoch 9/100 completed.
Epoch 10/100 completed.
Epoch 11/100 completed.
Epoch 12/100 completed.
Epoch 13/100 completed.
Epoch 14/100 completed.
Epoch 15/100 completed.
Epoch 16/100 completed.
Epoch 17/100 completed.
Epoch 18/100 completed.
Epoch 19/100 completed.
Epoch 20/100 completed.
Epoch 21/100 completed.
Epoch 22/100 completed.
Epoch 23/100 completed.
Epoch 24/100 completed.
Epoch 25/100 completed.
Epoch 26/100 completed.
Epoch 27/100 completed.
Epoch 28/100 completed.
Epoch 29/100 completed.
Epoch 30/100 completed.
Epoch 31/100 completed.
Epoch 32/100 completed.
Epoch 33/100 completed.
Epoch 34/100 completed.
Epoch 35/100 completed.
Epoch 36/100 completed.
Epoch 37/100 completed.
Epoch 38/100 completed.
Epoch 39/100 completed.
Epoch 40/100 completed.
Epoch 41/100 completed.
Epoch 42/100 completed.
E

In [91]:
# Save the model
torch.save(model.state_dict(), "eeg2vec/data/saved_models/eeg2vec_2_smaller_400windows.pth")

In [None]:
model.load_state_dict(torch.load("eeg2vec/data/saved_models/eeg2vec_2_smaller_400windows.pth"))
model.eval()

RuntimeError: Error(s) in loading state_dict for EEG2Vec:
	Missing key(s) in state_dict: "transformer_encoder.transformer_encoder.layers.3.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.3.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.3.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.3.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.3.linear1.weight", "transformer_encoder.transformer_encoder.layers.3.linear1.bias", "transformer_encoder.transformer_encoder.layers.3.linear2.weight", "transformer_encoder.transformer_encoder.layers.3.linear2.bias", "transformer_encoder.transformer_encoder.layers.3.norm1.weight", "transformer_encoder.transformer_encoder.layers.3.norm1.bias", "transformer_encoder.transformer_encoder.layers.3.norm2.weight", "transformer_encoder.transformer_encoder.layers.3.norm2.bias", "transformer_encoder.transformer_encoder.layers.4.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.4.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.4.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.4.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.4.linear1.weight", "transformer_encoder.transformer_encoder.layers.4.linear1.bias", "transformer_encoder.transformer_encoder.layers.4.linear2.weight", "transformer_encoder.transformer_encoder.layers.4.linear2.bias", "transformer_encoder.transformer_encoder.layers.4.norm1.weight", "transformer_encoder.transformer_encoder.layers.4.norm1.bias", "transformer_encoder.transformer_encoder.layers.4.norm2.weight", "transformer_encoder.transformer_encoder.layers.4.norm2.bias", "transformer_encoder.transformer_encoder.layers.5.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.5.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.5.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.5.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.5.linear1.weight", "transformer_encoder.transformer_encoder.layers.5.linear1.bias", "transformer_encoder.transformer_encoder.layers.5.linear2.weight", "transformer_encoder.transformer_encoder.layers.5.linear2.bias", "transformer_encoder.transformer_encoder.layers.5.norm1.weight", "transformer_encoder.transformer_encoder.layers.5.norm1.bias", "transformer_encoder.transformer_encoder.layers.5.norm2.weight", "transformer_encoder.transformer_encoder.layers.5.norm2.bias", "transformer_encoder.transformer_encoder.layers.6.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.6.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.6.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.6.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.6.linear1.weight", "transformer_encoder.transformer_encoder.layers.6.linear1.bias", "transformer_encoder.transformer_encoder.layers.6.linear2.weight", "transformer_encoder.transformer_encoder.layers.6.linear2.bias", "transformer_encoder.transformer_encoder.layers.6.norm1.weight", "transformer_encoder.transformer_encoder.layers.6.norm1.bias", "transformer_encoder.transformer_encoder.layers.6.norm2.weight", "transformer_encoder.transformer_encoder.layers.6.norm2.bias", "transformer_encoder.transformer_encoder.layers.7.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.7.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.7.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.7.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.7.linear1.weight", "transformer_encoder.transformer_encoder.layers.7.linear1.bias", "transformer_encoder.transformer_encoder.layers.7.linear2.weight", "transformer_encoder.transformer_encoder.layers.7.linear2.bias", "transformer_encoder.transformer_encoder.layers.7.norm1.weight", "transformer_encoder.transformer_encoder.layers.7.norm1.bias", "transformer_encoder.transformer_encoder.layers.7.norm2.weight", "transformer_encoder.transformer_encoder.layers.7.norm2.bias", "transformer_encoder.transformer_encoder.layers.8.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.8.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.8.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.8.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.8.linear1.weight", "transformer_encoder.transformer_encoder.layers.8.linear1.bias", "transformer_encoder.transformer_encoder.layers.8.linear2.weight", "transformer_encoder.transformer_encoder.layers.8.linear2.bias", "transformer_encoder.transformer_encoder.layers.8.norm1.weight", "transformer_encoder.transformer_encoder.layers.8.norm1.bias", "transformer_encoder.transformer_encoder.layers.8.norm2.weight", "transformer_encoder.transformer_encoder.layers.8.norm2.bias", "transformer_encoder.transformer_encoder.layers.9.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.9.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.9.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.9.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.9.linear1.weight", "transformer_encoder.transformer_encoder.layers.9.linear1.bias", "transformer_encoder.transformer_encoder.layers.9.linear2.weight", "transformer_encoder.transformer_encoder.layers.9.linear2.bias", "transformer_encoder.transformer_encoder.layers.9.norm1.weight", "transformer_encoder.transformer_encoder.layers.9.norm1.bias", "transformer_encoder.transformer_encoder.layers.9.norm2.weight", "transformer_encoder.transformer_encoder.layers.9.norm2.bias", "transformer_encoder.transformer_encoder.layers.10.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.10.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.10.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.10.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.10.linear1.weight", "transformer_encoder.transformer_encoder.layers.10.linear1.bias", "transformer_encoder.transformer_encoder.layers.10.linear2.weight", "transformer_encoder.transformer_encoder.layers.10.linear2.bias", "transformer_encoder.transformer_encoder.layers.10.norm1.weight", "transformer_encoder.transformer_encoder.layers.10.norm1.bias", "transformer_encoder.transformer_encoder.layers.10.norm2.weight", "transformer_encoder.transformer_encoder.layers.10.norm2.bias", "transformer_encoder.transformer_encoder.layers.11.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.11.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.11.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.11.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.11.linear1.weight", "transformer_encoder.transformer_encoder.layers.11.linear1.bias", "transformer_encoder.transformer_encoder.layers.11.linear2.weight", "transformer_encoder.transformer_encoder.layers.11.linear2.bias", "transformer_encoder.transformer_encoder.layers.11.norm1.weight", "transformer_encoder.transformer_encoder.layers.11.norm1.bias", "transformer_encoder.transformer_encoder.layers.11.norm2.weight", "transformer_encoder.transformer_encoder.layers.11.norm2.bias", "transformer_encoder.transformer_encoder.layers.12.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.12.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.12.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.12.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.12.linear1.weight", "transformer_encoder.transformer_encoder.layers.12.linear1.bias", "transformer_encoder.transformer_encoder.layers.12.linear2.weight", "transformer_encoder.transformer_encoder.layers.12.linear2.bias", "transformer_encoder.transformer_encoder.layers.12.norm1.weight", "transformer_encoder.transformer_encoder.layers.12.norm1.bias", "transformer_encoder.transformer_encoder.layers.12.norm2.weight", "transformer_encoder.transformer_encoder.layers.12.norm2.bias", "transformer_encoder.transformer_encoder.layers.13.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.13.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.13.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.13.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.13.linear1.weight", "transformer_encoder.transformer_encoder.layers.13.linear1.bias", "transformer_encoder.transformer_encoder.layers.13.linear2.weight", "transformer_encoder.transformer_encoder.layers.13.linear2.bias", "transformer_encoder.transformer_encoder.layers.13.norm1.weight", "transformer_encoder.transformer_encoder.layers.13.norm1.bias", "transformer_encoder.transformer_encoder.layers.13.norm2.weight", "transformer_encoder.transformer_encoder.layers.13.norm2.bias", "transformer_encoder.transformer_encoder.layers.14.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.14.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.14.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.14.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.14.linear1.weight", "transformer_encoder.transformer_encoder.layers.14.linear1.bias", "transformer_encoder.transformer_encoder.layers.14.linear2.weight", "transformer_encoder.transformer_encoder.layers.14.linear2.bias", "transformer_encoder.transformer_encoder.layers.14.norm1.weight", "transformer_encoder.transformer_encoder.layers.14.norm1.bias", "transformer_encoder.transformer_encoder.layers.14.norm2.weight", "transformer_encoder.transformer_encoder.layers.14.norm2.bias", "transformer_encoder.transformer_encoder.layers.15.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.15.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.15.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.15.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.15.linear1.weight", "transformer_encoder.transformer_encoder.layers.15.linear1.bias", "transformer_encoder.transformer_encoder.layers.15.linear2.weight", "transformer_encoder.transformer_encoder.layers.15.linear2.bias", "transformer_encoder.transformer_encoder.layers.15.norm1.weight", "transformer_encoder.transformer_encoder.layers.15.norm1.bias", "transformer_encoder.transformer_encoder.layers.15.norm2.weight", "transformer_encoder.transformer_encoder.layers.15.norm2.bias", "transformer_encoder.transformer_encoder.layers.16.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.16.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.16.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.16.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.16.linear1.weight", "transformer_encoder.transformer_encoder.layers.16.linear1.bias", "transformer_encoder.transformer_encoder.layers.16.linear2.weight", "transformer_encoder.transformer_encoder.layers.16.linear2.bias", "transformer_encoder.transformer_encoder.layers.16.norm1.weight", "transformer_encoder.transformer_encoder.layers.16.norm1.bias", "transformer_encoder.transformer_encoder.layers.16.norm2.weight", "transformer_encoder.transformer_encoder.layers.16.norm2.bias", "transformer_encoder.transformer_encoder.layers.17.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.17.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.17.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.17.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.17.linear1.weight", "transformer_encoder.transformer_encoder.layers.17.linear1.bias", "transformer_encoder.transformer_encoder.layers.17.linear2.weight", "transformer_encoder.transformer_encoder.layers.17.linear2.bias", "transformer_encoder.transformer_encoder.layers.17.norm1.weight", "transformer_encoder.transformer_encoder.layers.17.norm1.bias", "transformer_encoder.transformer_encoder.layers.17.norm2.weight", "transformer_encoder.transformer_encoder.layers.17.norm2.bias", "transformer_encoder.transformer_encoder.layers.18.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.18.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.18.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.18.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.18.linear1.weight", "transformer_encoder.transformer_encoder.layers.18.linear1.bias", "transformer_encoder.transformer_encoder.layers.18.linear2.weight", "transformer_encoder.transformer_encoder.layers.18.linear2.bias", "transformer_encoder.transformer_encoder.layers.18.norm1.weight", "transformer_encoder.transformer_encoder.layers.18.norm1.bias", "transformer_encoder.transformer_encoder.layers.18.norm2.weight", "transformer_encoder.transformer_encoder.layers.18.norm2.bias", "transformer_encoder.transformer_encoder.layers.19.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.19.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.19.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.19.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.19.linear1.weight", "transformer_encoder.transformer_encoder.layers.19.linear1.bias", "transformer_encoder.transformer_encoder.layers.19.linear2.weight", "transformer_encoder.transformer_encoder.layers.19.linear2.bias", "transformer_encoder.transformer_encoder.layers.19.norm1.weight", "transformer_encoder.transformer_encoder.layers.19.norm1.bias", "transformer_encoder.transformer_encoder.layers.19.norm2.weight", "transformer_encoder.transformer_encoder.layers.19.norm2.bias", "transformer_encoder.transformer_encoder.layers.20.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.20.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.20.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.20.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.20.linear1.weight", "transformer_encoder.transformer_encoder.layers.20.linear1.bias", "transformer_encoder.transformer_encoder.layers.20.linear2.weight", "transformer_encoder.transformer_encoder.layers.20.linear2.bias", "transformer_encoder.transformer_encoder.layers.20.norm1.weight", "transformer_encoder.transformer_encoder.layers.20.norm1.bias", "transformer_encoder.transformer_encoder.layers.20.norm2.weight", "transformer_encoder.transformer_encoder.layers.20.norm2.bias", "transformer_encoder.transformer_encoder.layers.21.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.21.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.21.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.21.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.21.linear1.weight", "transformer_encoder.transformer_encoder.layers.21.linear1.bias", "transformer_encoder.transformer_encoder.layers.21.linear2.weight", "transformer_encoder.transformer_encoder.layers.21.linear2.bias", "transformer_encoder.transformer_encoder.layers.21.norm1.weight", "transformer_encoder.transformer_encoder.layers.21.norm1.bias", "transformer_encoder.transformer_encoder.layers.21.norm2.weight", "transformer_encoder.transformer_encoder.layers.21.norm2.bias", "transformer_encoder.transformer_encoder.layers.22.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.22.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.22.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.22.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.22.linear1.weight", "transformer_encoder.transformer_encoder.layers.22.linear1.bias", "transformer_encoder.transformer_encoder.layers.22.linear2.weight", "transformer_encoder.transformer_encoder.layers.22.linear2.bias", "transformer_encoder.transformer_encoder.layers.22.norm1.weight", "transformer_encoder.transformer_encoder.layers.22.norm1.bias", "transformer_encoder.transformer_encoder.layers.22.norm2.weight", "transformer_encoder.transformer_encoder.layers.22.norm2.bias", "transformer_encoder.transformer_encoder.layers.23.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.23.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.23.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.23.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.23.linear1.weight", "transformer_encoder.transformer_encoder.layers.23.linear1.bias", "transformer_encoder.transformer_encoder.layers.23.linear2.weight", "transformer_encoder.transformer_encoder.layers.23.linear2.bias", "transformer_encoder.transformer_encoder.layers.23.norm1.weight", "transformer_encoder.transformer_encoder.layers.23.norm1.bias", "transformer_encoder.transformer_encoder.layers.23.norm2.weight", "transformer_encoder.transformer_encoder.layers.23.norm2.bias", "transformer_encoder.transformer_encoder.layers.24.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.24.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.24.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.24.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.24.linear1.weight", "transformer_encoder.transformer_encoder.layers.24.linear1.bias", "transformer_encoder.transformer_encoder.layers.24.linear2.weight", "transformer_encoder.transformer_encoder.layers.24.linear2.bias", "transformer_encoder.transformer_encoder.layers.24.norm1.weight", "transformer_encoder.transformer_encoder.layers.24.norm1.bias", "transformer_encoder.transformer_encoder.layers.24.norm2.weight", "transformer_encoder.transformer_encoder.layers.24.norm2.bias", "transformer_encoder.transformer_encoder.layers.25.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.25.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.25.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.25.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.25.linear1.weight", "transformer_encoder.transformer_encoder.layers.25.linear1.bias", "transformer_encoder.transformer_encoder.layers.25.linear2.weight", "transformer_encoder.transformer_encoder.layers.25.linear2.bias", "transformer_encoder.transformer_encoder.layers.25.norm1.weight", "transformer_encoder.transformer_encoder.layers.25.norm1.bias", "transformer_encoder.transformer_encoder.layers.25.norm2.weight", "transformer_encoder.transformer_encoder.layers.25.norm2.bias", "transformer_encoder.transformer_encoder.layers.26.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.26.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.26.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.26.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.26.linear1.weight", "transformer_encoder.transformer_encoder.layers.26.linear1.bias", "transformer_encoder.transformer_encoder.layers.26.linear2.weight", "transformer_encoder.transformer_encoder.layers.26.linear2.bias", "transformer_encoder.transformer_encoder.layers.26.norm1.weight", "transformer_encoder.transformer_encoder.layers.26.norm1.bias", "transformer_encoder.transformer_encoder.layers.26.norm2.weight", "transformer_encoder.transformer_encoder.layers.26.norm2.bias", "transformer_encoder.transformer_encoder.layers.27.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.27.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.27.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.27.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.27.linear1.weight", "transformer_encoder.transformer_encoder.layers.27.linear1.bias", "transformer_encoder.transformer_encoder.layers.27.linear2.weight", "transformer_encoder.transformer_encoder.layers.27.linear2.bias", "transformer_encoder.transformer_encoder.layers.27.norm1.weight", "transformer_encoder.transformer_encoder.layers.27.norm1.bias", "transformer_encoder.transformer_encoder.layers.27.norm2.weight", "transformer_encoder.transformer_encoder.layers.27.norm2.bias", "transformer_encoder.transformer_encoder.layers.28.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.28.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.28.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.28.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.28.linear1.weight", "transformer_encoder.transformer_encoder.layers.28.linear1.bias", "transformer_encoder.transformer_encoder.layers.28.linear2.weight", "transformer_encoder.transformer_encoder.layers.28.linear2.bias", "transformer_encoder.transformer_encoder.layers.28.norm1.weight", "transformer_encoder.transformer_encoder.layers.28.norm1.bias", "transformer_encoder.transformer_encoder.layers.28.norm2.weight", "transformer_encoder.transformer_encoder.layers.28.norm2.bias", "transformer_encoder.transformer_encoder.layers.29.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.29.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.29.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.29.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.29.linear1.weight", "transformer_encoder.transformer_encoder.layers.29.linear1.bias", "transformer_encoder.transformer_encoder.layers.29.linear2.weight", "transformer_encoder.transformer_encoder.layers.29.linear2.bias", "transformer_encoder.transformer_encoder.layers.29.norm1.weight", "transformer_encoder.transformer_encoder.layers.29.norm1.bias", "transformer_encoder.transformer_encoder.layers.29.norm2.weight", "transformer_encoder.transformer_encoder.layers.29.norm2.bias", "transformer_encoder.transformer_encoder.layers.30.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.30.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.30.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.30.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.30.linear1.weight", "transformer_encoder.transformer_encoder.layers.30.linear1.bias", "transformer_encoder.transformer_encoder.layers.30.linear2.weight", "transformer_encoder.transformer_encoder.layers.30.linear2.bias", "transformer_encoder.transformer_encoder.layers.30.norm1.weight", "transformer_encoder.transformer_encoder.layers.30.norm1.bias", "transformer_encoder.transformer_encoder.layers.30.norm2.weight", "transformer_encoder.transformer_encoder.layers.30.norm2.bias", "transformer_encoder.transformer_encoder.layers.31.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.31.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.31.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.31.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.31.linear1.weight", "transformer_encoder.transformer_encoder.layers.31.linear1.bias", "transformer_encoder.transformer_encoder.layers.31.linear2.weight", "transformer_encoder.transformer_encoder.layers.31.linear2.bias", "transformer_encoder.transformer_encoder.layers.31.norm1.weight", "transformer_encoder.transformer_encoder.layers.31.norm1.bias", "transformer_encoder.transformer_encoder.layers.31.norm2.weight", "transformer_encoder.transformer_encoder.layers.31.norm2.bias", "transformer_encoder.transformer_encoder.layers.32.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.32.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.32.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.32.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.32.linear1.weight", "transformer_encoder.transformer_encoder.layers.32.linear1.bias", "transformer_encoder.transformer_encoder.layers.32.linear2.weight", "transformer_encoder.transformer_encoder.layers.32.linear2.bias", "transformer_encoder.transformer_encoder.layers.32.norm1.weight", "transformer_encoder.transformer_encoder.layers.32.norm1.bias", "transformer_encoder.transformer_encoder.layers.32.norm2.weight", "transformer_encoder.transformer_encoder.layers.32.norm2.bias", "transformer_encoder.transformer_encoder.layers.33.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.33.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.33.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.33.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.33.linear1.weight", "transformer_encoder.transformer_encoder.layers.33.linear1.bias", "transformer_encoder.transformer_encoder.layers.33.linear2.weight", "transformer_encoder.transformer_encoder.layers.33.linear2.bias", "transformer_encoder.transformer_encoder.layers.33.norm1.weight", "transformer_encoder.transformer_encoder.layers.33.norm1.bias", "transformer_encoder.transformer_encoder.layers.33.norm2.weight", "transformer_encoder.transformer_encoder.layers.33.norm2.bias", "transformer_encoder.transformer_encoder.layers.34.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.34.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.34.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.34.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.34.linear1.weight", "transformer_encoder.transformer_encoder.layers.34.linear1.bias", "transformer_encoder.transformer_encoder.layers.34.linear2.weight", "transformer_encoder.transformer_encoder.layers.34.linear2.bias", "transformer_encoder.transformer_encoder.layers.34.norm1.weight", "transformer_encoder.transformer_encoder.layers.34.norm1.bias", "transformer_encoder.transformer_encoder.layers.34.norm2.weight", "transformer_encoder.transformer_encoder.layers.34.norm2.bias", "transformer_encoder.transformer_encoder.layers.35.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.35.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.35.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.35.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.35.linear1.weight", "transformer_encoder.transformer_encoder.layers.35.linear1.bias", "transformer_encoder.transformer_encoder.layers.35.linear2.weight", "transformer_encoder.transformer_encoder.layers.35.linear2.bias", "transformer_encoder.transformer_encoder.layers.35.norm1.weight", "transformer_encoder.transformer_encoder.layers.35.norm1.bias", "transformer_encoder.transformer_encoder.layers.35.norm2.weight", "transformer_encoder.transformer_encoder.layers.35.norm2.bias", "transformer_encoder.transformer_encoder.layers.36.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.36.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.36.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.36.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.36.linear1.weight", "transformer_encoder.transformer_encoder.layers.36.linear1.bias", "transformer_encoder.transformer_encoder.layers.36.linear2.weight", "transformer_encoder.transformer_encoder.layers.36.linear2.bias", "transformer_encoder.transformer_encoder.layers.36.norm1.weight", "transformer_encoder.transformer_encoder.layers.36.norm1.bias", "transformer_encoder.transformer_encoder.layers.36.norm2.weight", "transformer_encoder.transformer_encoder.layers.36.norm2.bias", "transformer_encoder.transformer_encoder.layers.37.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.37.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.37.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.37.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.37.linear1.weight", "transformer_encoder.transformer_encoder.layers.37.linear1.bias", "transformer_encoder.transformer_encoder.layers.37.linear2.weight", "transformer_encoder.transformer_encoder.layers.37.linear2.bias", "transformer_encoder.transformer_encoder.layers.37.norm1.weight", "transformer_encoder.transformer_encoder.layers.37.norm1.bias", "transformer_encoder.transformer_encoder.layers.37.norm2.weight", "transformer_encoder.transformer_encoder.layers.37.norm2.bias", "transformer_encoder.transformer_encoder.layers.38.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.38.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.38.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.38.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.38.linear1.weight", "transformer_encoder.transformer_encoder.layers.38.linear1.bias", "transformer_encoder.transformer_encoder.layers.38.linear2.weight", "transformer_encoder.transformer_encoder.layers.38.linear2.bias", "transformer_encoder.transformer_encoder.layers.38.norm1.weight", "transformer_encoder.transformer_encoder.layers.38.norm1.bias", "transformer_encoder.transformer_encoder.layers.38.norm2.weight", "transformer_encoder.transformer_encoder.layers.38.norm2.bias", "transformer_encoder.transformer_encoder.layers.39.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.39.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.39.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.39.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.39.linear1.weight", "transformer_encoder.transformer_encoder.layers.39.linear1.bias", "transformer_encoder.transformer_encoder.layers.39.linear2.weight", "transformer_encoder.transformer_encoder.layers.39.linear2.bias", "transformer_encoder.transformer_encoder.layers.39.norm1.weight", "transformer_encoder.transformer_encoder.layers.39.norm1.bias", "transformer_encoder.transformer_encoder.layers.39.norm2.weight", "transformer_encoder.transformer_encoder.layers.39.norm2.bias", "transformer_encoder.transformer_encoder.layers.40.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.40.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.40.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.40.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.40.linear1.weight", "transformer_encoder.transformer_encoder.layers.40.linear1.bias", "transformer_encoder.transformer_encoder.layers.40.linear2.weight", "transformer_encoder.transformer_encoder.layers.40.linear2.bias", "transformer_encoder.transformer_encoder.layers.40.norm1.weight", "transformer_encoder.transformer_encoder.layers.40.norm1.bias", "transformer_encoder.transformer_encoder.layers.40.norm2.weight", "transformer_encoder.transformer_encoder.layers.40.norm2.bias", "transformer_encoder.transformer_encoder.layers.41.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.41.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.41.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.41.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.41.linear1.weight", "transformer_encoder.transformer_encoder.layers.41.linear1.bias", "transformer_encoder.transformer_encoder.layers.41.linear2.weight", "transformer_encoder.transformer_encoder.layers.41.linear2.bias", "transformer_encoder.transformer_encoder.layers.41.norm1.weight", "transformer_encoder.transformer_encoder.layers.41.norm1.bias", "transformer_encoder.transformer_encoder.layers.41.norm2.weight", "transformer_encoder.transformer_encoder.layers.41.norm2.bias", "transformer_encoder.transformer_encoder.layers.42.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.42.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.42.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.42.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.42.linear1.weight", "transformer_encoder.transformer_encoder.layers.42.linear1.bias", "transformer_encoder.transformer_encoder.layers.42.linear2.weight", "transformer_encoder.transformer_encoder.layers.42.linear2.bias", "transformer_encoder.transformer_encoder.layers.42.norm1.weight", "transformer_encoder.transformer_encoder.layers.42.norm1.bias", "transformer_encoder.transformer_encoder.layers.42.norm2.weight", "transformer_encoder.transformer_encoder.layers.42.norm2.bias", "transformer_encoder.transformer_encoder.layers.43.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.43.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.43.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.43.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.43.linear1.weight", "transformer_encoder.transformer_encoder.layers.43.linear1.bias", "transformer_encoder.transformer_encoder.layers.43.linear2.weight", "transformer_encoder.transformer_encoder.layers.43.linear2.bias", "transformer_encoder.transformer_encoder.layers.43.norm1.weight", "transformer_encoder.transformer_encoder.layers.43.norm1.bias", "transformer_encoder.transformer_encoder.layers.43.norm2.weight", "transformer_encoder.transformer_encoder.layers.43.norm2.bias", "transformer_encoder.transformer_encoder.layers.44.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.44.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.44.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.44.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.44.linear1.weight", "transformer_encoder.transformer_encoder.layers.44.linear1.bias", "transformer_encoder.transformer_encoder.layers.44.linear2.weight", "transformer_encoder.transformer_encoder.layers.44.linear2.bias", "transformer_encoder.transformer_encoder.layers.44.norm1.weight", "transformer_encoder.transformer_encoder.layers.44.norm1.bias", "transformer_encoder.transformer_encoder.layers.44.norm2.weight", "transformer_encoder.transformer_encoder.layers.44.norm2.bias", "transformer_encoder.transformer_encoder.layers.45.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.45.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.45.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.45.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.45.linear1.weight", "transformer_encoder.transformer_encoder.layers.45.linear1.bias", "transformer_encoder.transformer_encoder.layers.45.linear2.weight", "transformer_encoder.transformer_encoder.layers.45.linear2.bias", "transformer_encoder.transformer_encoder.layers.45.norm1.weight", "transformer_encoder.transformer_encoder.layers.45.norm1.bias", "transformer_encoder.transformer_encoder.layers.45.norm2.weight", "transformer_encoder.transformer_encoder.layers.45.norm2.bias", "transformer_encoder.transformer_encoder.layers.46.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.46.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.46.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.46.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.46.linear1.weight", "transformer_encoder.transformer_encoder.layers.46.linear1.bias", "transformer_encoder.transformer_encoder.layers.46.linear2.weight", "transformer_encoder.transformer_encoder.layers.46.linear2.bias", "transformer_encoder.transformer_encoder.layers.46.norm1.weight", "transformer_encoder.transformer_encoder.layers.46.norm1.bias", "transformer_encoder.transformer_encoder.layers.46.norm2.weight", "transformer_encoder.transformer_encoder.layers.46.norm2.bias", "transformer_encoder.transformer_encoder.layers.47.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.47.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.47.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.47.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.47.linear1.weight", "transformer_encoder.transformer_encoder.layers.47.linear1.bias", "transformer_encoder.transformer_encoder.layers.47.linear2.weight", "transformer_encoder.transformer_encoder.layers.47.linear2.bias", "transformer_encoder.transformer_encoder.layers.47.norm1.weight", "transformer_encoder.transformer_encoder.layers.47.norm1.bias", "transformer_encoder.transformer_encoder.layers.47.norm2.weight", "transformer_encoder.transformer_encoder.layers.47.norm2.bias", "transformer_encoder.transformer_encoder.layers.48.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.48.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.48.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.48.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.48.linear1.weight", "transformer_encoder.transformer_encoder.layers.48.linear1.bias", "transformer_encoder.transformer_encoder.layers.48.linear2.weight", "transformer_encoder.transformer_encoder.layers.48.linear2.bias", "transformer_encoder.transformer_encoder.layers.48.norm1.weight", "transformer_encoder.transformer_encoder.layers.48.norm1.bias", "transformer_encoder.transformer_encoder.layers.48.norm2.weight", "transformer_encoder.transformer_encoder.layers.48.norm2.bias", "transformer_encoder.transformer_encoder.layers.49.self_attn.in_proj_weight", "transformer_encoder.transformer_encoder.layers.49.self_attn.in_proj_bias", "transformer_encoder.transformer_encoder.layers.49.self_attn.out_proj.weight", "transformer_encoder.transformer_encoder.layers.49.self_attn.out_proj.bias", "transformer_encoder.transformer_encoder.layers.49.linear1.weight", "transformer_encoder.transformer_encoder.layers.49.linear1.bias", "transformer_encoder.transformer_encoder.layers.49.linear2.weight", "transformer_encoder.transformer_encoder.layers.49.linear2.bias", "transformer_encoder.transformer_encoder.layers.49.norm1.weight", "transformer_encoder.transformer_encoder.layers.49.norm1.bias", "transformer_encoder.transformer_encoder.layers.49.norm2.weight", "transformer_encoder.transformer_encoder.layers.49.norm2.bias". 
	size mismatch for transformer_encoder.transformer_encoder.layers.0.self_attn.in_proj_weight: copying a param with shape torch.Size([48, 16]) from checkpoint, the shape in current model is torch.Size([192, 64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.0.self_attn.in_proj_bias: copying a param with shape torch.Size([48]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for transformer_encoder.transformer_encoder.layers.0.self_attn.out_proj.weight: copying a param with shape torch.Size([16, 16]) from checkpoint, the shape in current model is torch.Size([64, 64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.0.self_attn.out_proj.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.0.linear1.weight: copying a param with shape torch.Size([2, 16]) from checkpoint, the shape in current model is torch.Size([2, 64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.0.linear2.weight: copying a param with shape torch.Size([16, 2]) from checkpoint, the shape in current model is torch.Size([64, 2]).
	size mismatch for transformer_encoder.transformer_encoder.layers.0.linear2.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.0.norm1.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.0.norm1.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.0.norm2.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.0.norm2.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.1.self_attn.in_proj_weight: copying a param with shape torch.Size([48, 16]) from checkpoint, the shape in current model is torch.Size([192, 64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.1.self_attn.in_proj_bias: copying a param with shape torch.Size([48]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for transformer_encoder.transformer_encoder.layers.1.self_attn.out_proj.weight: copying a param with shape torch.Size([16, 16]) from checkpoint, the shape in current model is torch.Size([64, 64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.1.self_attn.out_proj.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.1.linear1.weight: copying a param with shape torch.Size([2, 16]) from checkpoint, the shape in current model is torch.Size([2, 64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.1.linear2.weight: copying a param with shape torch.Size([16, 2]) from checkpoint, the shape in current model is torch.Size([64, 2]).
	size mismatch for transformer_encoder.transformer_encoder.layers.1.linear2.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.1.norm1.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.1.norm1.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.1.norm2.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.1.norm2.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.2.self_attn.in_proj_weight: copying a param with shape torch.Size([48, 16]) from checkpoint, the shape in current model is torch.Size([192, 64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.2.self_attn.in_proj_bias: copying a param with shape torch.Size([48]) from checkpoint, the shape in current model is torch.Size([192]).
	size mismatch for transformer_encoder.transformer_encoder.layers.2.self_attn.out_proj.weight: copying a param with shape torch.Size([16, 16]) from checkpoint, the shape in current model is torch.Size([64, 64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.2.self_attn.out_proj.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.2.linear1.weight: copying a param with shape torch.Size([2, 16]) from checkpoint, the shape in current model is torch.Size([2, 64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.2.linear2.weight: copying a param with shape torch.Size([16, 2]) from checkpoint, the shape in current model is torch.Size([64, 2]).
	size mismatch for transformer_encoder.transformer_encoder.layers.2.linear2.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.2.norm1.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.2.norm1.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.2.norm2.weight: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for transformer_encoder.transformer_encoder.layers.2.norm2.bias: copying a param with shape torch.Size([16]) from checkpoint, the shape in current model is torch.Size([64]).

In [92]:
# Compute embeddings for xgboost training data
with torch.no_grad():
    model.to(device)
    model.eval()
    training_embeddings = model(torch.tensor(X_train_xgboost, dtype=torch.float32).to(device))
    test_embeddings = model(torch.tensor(X_test, dtype=torch.float32).to(device))
    test_embeddings = test_embeddings.cpu().numpy()
    training_embeddings = training_embeddings.cpu().numpy()

In [93]:
training_embeddings.shape

(320, 248, 16)

In [None]:


params = {
    'objective': 'binary:logistic',  # For binary classification; use 'multi:softmax' for multi-class
    'eval_metric': 'logloss',        # Evaluation metric (logarithmic loss)
    'learning_rate': 0.1,            # Step size shrinkage
    'max_depth': 6,                  # Maximum tree depth
    'subsample': 0.8,                # Percentage of samples to use per tree
    'colsample_bytree': 0.8,         # Percentage of features to use per tree
    'lambda': 1,                     # L2 regularization term
    'alpha': 0                       # L1 regularization term
}

# Train the model
model_xgb = MultiOutputClassifier(xgb.XGBClassifier(**params))


In [81]:
y_train_xgboost.shape

(80, 5)

In [94]:
print(model_xgb.classes_)

[array([0, 1]), array([0, 1]), array([0, 1]), array([0, 1]), array([0, 1])]


In [95]:
training_embeddings = training_embeddings.reshape(training_embeddings.shape[0], -1)


In [96]:
model_xgb.fit(training_embeddings, y_train_xgboost)

In [None]:
# save model
pickle.dump(model_xgb, open("eeg2vec/data/saved_models/xgboost_2_smaller_400windows.pkl", "wb"))

In [11]:
import pickle
model_xgb = pickle.load(open("eeg2vec/data/saved_models/xgboost_1_400windows.pkl", "rb"))

In [None]:
# Evaluate the model
test_embeddings = test_embeddings.reshape(test_embeddings.shape[0], -1)
predictions = model_xgb.predict(test_embeddings)

accuracy = accuracy_score(y_test, predictions)
print(f'Accuracy: {accuracy:.2f}')
# F1 score
f1 = f1_score(y_test, predictions, average='weighted')
print(f'F1 Score: {f1:.2f}')

Accuracy: 0.94
F1 Score: 0.99


In [99]:
test_data = all_data[7000:50000]
test_targets = all_targets[7000:50000]

In [100]:
# empty cache
torch.cuda.empty_cache()

In [None]:
model =  model.to(device)
model.eval()
with torch.no_grad():
    model.eval()
    for i in range(0, len(test_data), 1000):
        embeddings = model(torch.tensor(test_data[i:i+1000], dtype=torch.float32).to(device))
        if i == 0:
            all_embeddings = embeddings
        else:
            all_embeddings = torch.cat((all_embeddings, embeddings), dim=0)
    embeddings = all_embeddings.reshape(all_embeddings.shape[0], -1).cpu().numpy()
predictions = model_xgb.predict(embeddings)


accuracy = accuracy_score(test_targets, predictions)
print(f'Accuracy: {accuracy:.2f}')
# F1 score
f1 = f1_score(test_targets, predictions, average='weighted')
print(f'F1 Score: {f1:.2f}')

Accuracy: 0.50
F1 Score: 0.72


# Total Training

In [108]:
def butter_bandpass(lowcut, highcut, fs, order=5):
    return butter(order, [lowcut, highcut], fs=fs, btype='band')

def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    y = lfilter(b, a, data)
    return y

In [109]:
# First we need to get the point that maps to a label

def reshape_array_into_windows(x, sample_rate, window_duration_in_seconds):
    """
    Reshape the data into an array of shape (C, T, window) where 'window' contains
    the points corresponding to 'window_duration' seconds of data.

    Parameters:
    x (numpy array): The input data array.
    sample_rate (int): The number of samples per second.
    window_duration_in_seconds (float): The duration of each window in seconds.

    Returns:
    reshaped_x (numpy array): The reshaped array with shape (C, T, window).
    """
    # Calculate the number of samples in one window
    window_size = int(window_duration_in_seconds * sample_rate)
    
    # Ensure the total length of x is a multiple of window_size
    total_samples = x.shape[-1]
    if total_samples % window_size != 0:
        # Truncate or pad x to make it divisible by window_size
        x = x[..., :total_samples - (total_samples % window_size)]
    # Reshape x into (C, T, window)
    reshaped_x = x.reshape(x.shape[0], -1, window_size)

    return reshaped_x

In [115]:
# We first load and reshape all the data
all_data = []
all_targets = []
# We need to have
# data of Shape: [num_samples, num_channels (5), sequence_length]
# labels of Shape: [num_samples, 5]

for data, target in training_data:
    reshaped_data = reshape_array_into_windows(data, 250, 2)
    reshaped_data = reshaped_data.transpose(1, 0, 2)
    target = target.reshape(-1, 5)
    all_data.append(reshaped_data)
    all_targets.append(target)

all_data = np.concatenate(all_data, axis=0)
all_targets = np.concatenate(all_targets, axis=0)


In [116]:
all_data.shape

(52351, 5, 500)

In [117]:
# Split the data into training and test sets
X_train_full, X_test, y_train_full, y_test = train_test_split(all_data, all_targets, test_size=0.2, random_state=42)

# Further split training data for embeddings and XGBoost
X_train_embeddings, X_train_xgboost, y_train_embeddings, y_train_xgboost = train_test_split(X_train_full, y_train_full, test_size=0.5, random_state=42)

In [118]:
# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

## EEG2VEC training

In [122]:
model_eeg2vec = EEG2Vec(16, 2, 3, 2)
model_eeg2vec = model_eeg2vec.to(device)
for session in range(4):
    print("..... Session " + str(session) + " .....")
    session_length = int(len(X_train_embeddings)/5)
    data_loader = get_dataloader(X_train_embeddings[session*session_length:(session+1)*session_length], y_train_embeddings[session*session_length:(session+1)*session_length], batch_size=100, shuffle=True)
    train(model_eeg2vec, data_loader, 20, device)
    # Save the model
    torch.save(model_eeg2vec.state_dict(), "eeg2vec/data/saved_models/eeg2vec_3_final.pth")
data_loader =get_dataloader(X_train_embeddings[(session+1)*session_length:], y_train_embeddings[(session+1)*session_length:], batch_size=100, shuffle=True)
train(model_eeg2vec, data_loader, 20, device)
torch.save(model_eeg2vec.state_dict(), "eeg2vec/data/saved_models/eeg2vec_3_final.pth")

..... Session 0 .....
Epoch 1/20 completed.
Epoch 2/20 completed.
Epoch 3/20 completed.
Epoch 4/20 completed.
Epoch 5/20 completed.
Epoch 6/20 completed.
Epoch 7/20 completed.
Epoch 8/20 completed.
Epoch 9/20 completed.
Epoch 10/20 completed.
Epoch 11/20 completed.
Epoch 12/20 completed.
Epoch 13/20 completed.
Epoch 14/20 completed.
Epoch 15/20 completed.
Epoch 16/20 completed.
Epoch 17/20 completed.
Epoch 18/20 completed.
Epoch 19/20 completed.
Epoch 20/20 completed.
..... Session 1 .....
Epoch 1/20 completed.
Epoch 2/20 completed.
Epoch 3/20 completed.
Epoch 4/20 completed.
Epoch 5/20 completed.
Epoch 6/20 completed.
Epoch 7/20 completed.
Epoch 8/20 completed.
Epoch 9/20 completed.
Epoch 10/20 completed.
Epoch 11/20 completed.
Epoch 12/20 completed.
Epoch 13/20 completed.
Epoch 14/20 completed.
Epoch 15/20 completed.
Epoch 16/20 completed.
Epoch 17/20 completed.
Epoch 18/20 completed.
Epoch 19/20 completed.
Epoch 20/20 completed.
..... Session 2 .....
Epoch 1/20 completed.
Epoch 2/20

## Train XGBOOST

In [124]:
# Compute embeddings for xgboost training data
torch.cuda.empty_cache()
with torch.no_grad():
    model.to(device)
    model.eval()
    training_embeddings = model(torch.tensor(X_train_xgboost, dtype=torch.float32).to(device))
    test_embeddings = model(torch.tensor(X_test, dtype=torch.float32).to(device))
    test_embeddings = test_embeddings.cpu().numpy()
    training_embeddings = training_embeddings.cpu().numpy()
training_embeddings = training_embeddings.reshape(training_embeddings.shape[0], -1)

RuntimeError: CUDA out of memory. Tried to allocate 810.21 GiB (GPU 0; 11.73 GiB total capacity; 5.87 GiB already allocated; 3.68 GiB free; 6.40 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
params = {
    'objective': 'binary:logistic',  # For binary classification; use 'multi:softmax' for multi-class
    'eval_metric': 'logloss',        # Evaluation metric (logarithmic loss)
    'learning_rate': 0.1,            # Step size shrinkage
    'max_depth': 6,                  # Maximum tree depth
    'subsample': 0.8,                # Percentage of samples to use per tree
    'colsample_bytree': 0.8,         # Percentage of features to use per tree
    'lambda': 1,                     # L2 regularization term
    'alpha': 0                       # L1 regularization term
}

# Train the model
model_xgb_eeg2vec = MultiOutputClassifier(xgb.XGBClassifier(**params))

In [None]:
model_xgb_eeg2vec.fit(training_embeddings, y_train_xgboost)
pickle.dump(model_xgb, open("eeg2vec/data/saved_models/xgboost_3_smaller_final.pkl", "wb"))

In [None]:
# Evaluate the model
test_embeddings = test_embeddings.reshape(test_embeddings.shape[0], -1)
predictions = model_xgb.predict(test_embeddings)

accuracy = accuracy_score(y_test, predictions)
print(f'Accuracy: {accuracy:.2f}')
# F1 score
f1 = f1_score(y_test, predictions, average='weighted')
print(f'F1 Score: {f1:.2f}')

# Generating submission

In [None]:
ROOT_TEST_PATH = Path("test/")
test_data = {i:np.load(ROOT_TEST_PATH / f"data_{i}.npy") for i in [4,5]}
# We process each record independantly

def compute_predictions_on_record(data,model,model_xgb):
    filtered_data =  butter_bandpass_filter(data,0.1,18,250,4)
    reshaped_data = reshape_array_into_windows(filtered_data,250,2)

    reshaped_data = reshaped_data.transpose(1, 0, 2)
    model =  model.to(device)
    model.eval()
    embeddings = model(torch.tensor(reshaped_data, dtype=torch.float32).to(device)).cpu().numpy()
    embeddings = embeddings.reshape(embeddings.shape[0], -1)
    predictions = model_xgb.predict(embeddings)
    return predictions

def format_array_to_target_format(array, record_number):
    assert isinstance(record_number, int)
    assert isinstance(array, np.ndarray)
    assert len(array.shape) == 2
    assert array.shape[0] == 5
    assert set(np.unique(array)) == {0, 1}
    formatted_target = []
    for i in range(array.shape[0]):
        channel_encoding = (i + 1) * 100000
        record_number_encoding = record_number * 1000000
        for j in range(array.shape[1]):
            formatted_target.append(
                {
                    "identifier": record_number_encoding + channel_encoding + j,
                    "target": array[i, j],
                }
            )
    return formatted_target
    


In [43]:
test_data[4].shape

(5, 6602015)

In [None]:
results = []
for record_number, data in test_data.items():
    with torch.no_grad():
        preds = compute_predictions_on_record(data,model,model_xgb)
    formatted_preds = format_array_to_target_format(preds,record_number)
    results.extend(formatted_preds)
df = pd.DataFrame(results)
df.to_csv("submission.csv",index = False)

: 