In [3]:
import numpy as np

# Directory of the Kaggle input dataset that holds both .npy files.
DATA_DIR = "/kaggle/input/2024-1-1"

# allow_pickle must be a real boolean: the previous string "TRUE" only worked
# because any non-empty string is truthy (so "FALSE" would have enabled
# pickling as well).
# NOTE(review): unpickling can execute arbitrary code -- only load files that
# come from a trusted source.
train_data = np.load(f"{DATA_DIR}/train.npy", allow_pickle=True)
test_data = np.load(f"{DATA_DIR}/test.npy", allow_pickle=True)

# Unwrap each loaded array once (instead of once per key). The unwrapped
# objects are read with .get(), so they are presumably dicts -- TODO confirm.
train_dict = train_data.item()
test_dict = test_data.item()

train_data_input = train_dict.get('input')
train_data_label = train_dict.get('label')

test_data_input = test_dict.get('input')

# Remove all length-1 axes from the input arrays, then report their shapes.
train_input = train_data_input.squeeze()
test_input = test_data_input.squeeze()
print(train_input.shape)
print(test_input.shape)

(4608, 22, 1125)
(576, 22, 1125)


### Bandpass Filter 적용

In [4]:
import numpy as np
from scipy.signal import butter, filtfilt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Input
from tensorflow.keras.utils import to_categorical
from mne import create_info
from mne.io import RawArray
from mne.decoding import CSP

# Mixed-precision setup: make 'mixed_float16' the global Keras dtype policy.
from tensorflow.keras.mixed_precision import set_global_policy, Policy
policy = Policy('mixed_float16')
set_global_policy(policy)

# Band-pass filter design and application helpers
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter.

    Both band edges are divided by the Nyquist frequency (half of ``fs``)
    before being passed to ``scipy.signal.butter``.

    Args:
        lowcut: Lower band edge, in the same units as ``fs``.
        highcut: Upper band edge, in the same units as ``fs``.
        fs: Sampling frequency.
        order: Order passed to ``butter`` (default 5).

    Returns:
        The ``(b, a)`` coefficient pair produced by ``butter``.
    """
    half_rate = fs * 0.5
    band = [edge / half_rate for edge in (lowcut, highcut)]
    numerator, denominator = butter(order, band, btype='band')
    return numerator, denominator

def bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Zero-phase band-pass filter ``data`` along its last axis.

    The coefficients come from ``butter_bandpass`` and are applied with
    ``scipy.signal.filtfilt``. ``filtfilt`` filters every 1-D slice along
    ``axis=-1`` independently, so the whole array is processed in a single
    call instead of a Python loop over the first two axes.

    Args:
        data: Array-like whose last axis is the one to filter.
        lowcut: Lower band edge, in the same units as ``fs``.
        highcut: Upper band edge, in the same units as ``fs``.
        fs: Sampling frequency.
        order: Filter order forwarded to ``butter_bandpass`` (default 5).

    Returns:
        A new array with the same shape as ``data``. Floating/complex input
        keeps its dtype; any other dtype (e.g. integers) yields float64 so
        the filtered values are not truncated.

    Raises:
        ValueError: Propagated from ``filtfilt``/``butter`` (for example when
            the last axis is too short for the default edge padding).
    """
    b, a = butter_bandpass(lowcut, highcut, fs, order=order)
    data = np.asarray(data)
    # Writing into np.zeros_like(int_array) would silently cast the filtered
    # signal back to integers; only preserve dtypes that can hold the result.
    if np.issubdtype(data.dtype, np.inexact):
        out_dtype = data.dtype
    else:
        out_dtype = np.float64
    filtered = filtfilt(b, a, data, axis=-1)
    return filtered.astype(out_dtype, copy=False)

# Down-sampling helper (plain strided slicing)
def downsample(data, factor):
    """Keep every ``factor``-th element along the third axis (index 2).

    The first two axes are left untouched. No filtering is applied here; the
    result is whatever strided indexing of ``data`` returns.
    """
    stride = slice(None, None, factor)
    return data[:, :, stride]

In [5]:
# Cast the training inputs to float64. No reshape happens here; the layout is
# presumably (trials, channels, time points) -- TODO confirm against the data.
X_train = train_input.astype(np.float64)
# Convert the one-hot encoded labels into integer class indices (argmax over axis 1).
y_train = np.argmax(train_data_label, axis=1)
# Same float64 cast for the test inputs.
X_test = test_input.astype(np.float64)

In [6]:
# Pre-processing parameters
fs = 250.0                     # sampling frequency
lowcut, highcut = 8.0, 30.0    # pass-band edges (same units as fs)
downsample_factor = 4          # keep every 4th time sample


def _prepare_for_cnn(signals):
    """Band-pass filter, down-sample, then append a trailing channel axis."""
    filtered = bandpass_filter(signals, lowcut, highcut, fs)
    reduced = downsample(filtered, downsample_factor)
    # The trailing singleton axis is the channel dimension the CNN input expects.
    return reduced[..., np.newaxis]


X_train = _prepare_for_cnn(X_train)  # (4608, 22, 282, 1)
X_test = _prepare_for_cnn(X_test)    # (576, 22, 282, 1)

### Hold-out Validation Split (80% train / 20% validation)

In [7]:
# Hold out 20% of the training data for validation (fixed seed for a repeatable split)
split = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
X_train, X_val, y_train, y_val = split

# One-hot encode the integer labels of both splits (4 classes)
y_train_cat, y_val_cat = (
    to_categorical(labels, num_classes=4) for labels in (y_train, y_val)
)

### CNN 모델

In [8]:
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Input, BatchNormalization

# CNN model definition
def _conv_stage(filters):
    """Layers of one stage: 3x3 same-padded ReLU conv, batch norm, 2x2 max-pool, dropout."""
    return [
        Conv2D(filters, (3, 3), activation='relu', padding='same'),
        BatchNormalization(),
        MaxPooling2D((2, 2)),
        Dropout(0.25),
    ]


cnn_layers = [Input(shape=(22, 282, 1))]
# Three convolutional stages with 32, 64 and 128 filters.
for n_filters in (32, 64, 128):
    cnn_layers.extend(_conv_stage(n_filters))
# Classifier head: the final softmax layer is explicitly float32.
cnn_layers.extend([
    Flatten(),
    Dense(256, activation='relu'),
    BatchNormalization(),
    Dropout(0.5),
    Dense(4, activation='softmax', dtype='float32'),
])

model = Sequential(cnn_layers)

In [9]:
# Compile: Adam optimizer, categorical cross-entropy loss, accuracy metric
model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

# Train for 30 epochs in batches of 64, validating on the held-out split each epoch
history = model.fit(
    X_train,
    y_train_cat,
    validation_data=(X_val, y_val_cat),
    batch_size=64,
    epochs=30,
)

Epoch 1/30
[1m10/58[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m0s[0m 13ms/step - accuracy: 0.2459 - loss: 2.4670 

I0000 00:00:1716601549.233135     111 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m52s[0m 485ms/step - accuracy: 0.2664 - loss: 2.0903 - val_accuracy: 0.2950 - val_loss: 1.4764
Epoch 2/30
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 9ms/step - accuracy: 0.3292 - loss: 1.5024 - val_accuracy: 0.3416 - val_loss: 1.3919
Epoch 3/30
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 10ms/step - accuracy: 0.3824 - loss: 1.4113 - val_accuracy: 0.3275 - val_loss: 1.4738
Epoch 4/30
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 9ms/step - accuracy: 0.4050 - loss: 1.3519 - val_accuracy: 0.3265 - val_loss: 1.3481
Epoch 5/30
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 9ms/step - accuracy: 0.4581 - loss: 1.2463 - val_accuracy: 0.3623 - val_loss: 1.3055
Epoch 6/30
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 9ms/step - accuracy: 0.5122 - loss: 1.1405 - val_accuracy: 0.3731 - val_loss: 1.3099
Epoch 7/30
[1m58/58[0m [32m━━━━━━━━━━━━━━━━━

### 결과 출력

In [10]:
# Score the trained model on the validation split
val_loss, val_accuracy = model.evaluate(X_val, y_val_cat)
print(f'Validation Accuracy: {val_accuracy}')

# Predict on the test set and take the highest-scoring class for each row
y_test_pred = model.predict(X_test)
y_test_pred_classes = np.argmax(y_test_pred, axis=-1 if y_test_pred.ndim == 2 else 1)

# Show the predicted class indices
print(f'Test predictions: {y_test_pred_classes}')

[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 2ms/step - accuracy: 0.3841 - loss: 2.3760
Validation Accuracy: 0.37744033336639404
[1m18/18[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 2ms/step   
Test predictions: [0 0 0 0 0 2 1 3 0 2 1 2 1 0 3 3 3 3 3 3 1 0 0 0 1 1 0 0 2 0 0 2 0 1 0 0 2
 2 0 0 3 2 2 3 3 2 0 3 3 2 0 0 1 2 2 3 1 1 2 2 2 3 1 0 0 3 1 2 0 3 3 2 0 3
 0 0 0 3 1 0 1 3 0 0 2 0 2 2 1 1 1 1 3 2 0 0 0 2 3 2 1 0 0 3 3 3 0 2 3 2 2
 1 2 3 3 2 2 3 2 0 0 0 0 0 2 0 3 1 2 3 1 0 0 1 2 0 3 2 2 0 2 0 2 0 1 0 2 2
 3 3 0 0 0 1 0 2 2 0 3 0 0 0 0 0 1 0 3 3 2 2 2 0 3 0 0 0 1 0 2 0 1 0 3 0 3
 0 2 2 2 1 2 1 2 0 0 3 2 3 0 0 0 2 2 0 0 3 1 1 2 2 0 2 0 0 3 2 0 1 2 3 2 1
 0 2 3 0 2 3 2 3 2 1 0 0 0 2 0 0 0 0 2 2 2 1 3 0 3 1 3 0 2 0 2 2 2 3 2 3 1
 1 0 0 3 0 0 0 1 2 2 0 3 0 0 1 1 3 1 3 3 3 2 3 2 0 2 1 0 0 0 3 0 3 3 2 0 2
 1 0 3 2 1 0 3 2 2 2 2 3 1 0 0 3 1 0 2 1 2 2 1 2 0 2 1 0 1 2 2 0 0 2 2 3 1
 0 3 1 0 1 1 0 0 0 2 2 0 2 2 0 2 2 1 1 0 2 2 0 3 2 2 0 0 3 0 1 3 0 1 3 3 1
 0 0 0 0