In [None]:
from torchvision.transforms import transforms
from torch.utils.data import Dataset
import pickle
import os
import torch
import numpy as np
from PIL import Image
import natsort
#set device
device = torch.device("mps" if torch.cuda.is_available() else "cpu")
#constants
VIDEO = 40
PARTICIPANT = 32
CHANNEL = 32
SESSION = 63
STAT_FEATURE_MAX_LEN = 58
IMAGE_PIXEL_NUM = 496*369
BAND_TYPE_NUM = 4


#Make Datasets
class Arousal_Dataset(Dataset):
    def __init__(self, image_dir_path, stat_data_path, train = None):
        super().__init__
        self.image_dir_path = image_dir_path
        self.stat_data_path  = stat_data_path
        image_name_list = os.listdir(image_dir_path)
        self.image_name_list = natsort.natsorted(image_name_list)

    def __getitem__(self, index):
        #Tensor Transformer
        to_tensor = transforms.ToTensor()
        ##image data
        #get image of theta to gamma
        #dictionary "band type" ->  tensor image
        band_images = {}
        band_types = ['alpha', 'beta', 'gamma', 'theta']
        for i in range(4):
            image_path = os.path.join(self.image_dir_path, self.image_name_list[4*index + i])
            image = np.array(Image.open(image_path))
            tensor_image = to_tensor(image)
            band_images[band_types[i]] = tensor_image

        image_name_split_list = self.image_name_list[4*index].split('_')

        #locate information of files
        par, vid, cha, ses = int(image_name_split_list[0]), int(
            image_name_split_list[1]), int(image_name_split_list[2]), int(image_name_split_list[3])
        
        # statistic data,[Par, Vid, Cha, Ses] (32,40,32,15)
        with open('./Data/Statistic_Features_np.pkl', 'rb') as f:
            statistic_features = pickle.load(f)

        statistic_feature = statistic_features[par-1][vid-1][cha-1][ses-1]
        #elong the legnth to (,58)
        statistic_feature = np.pad(statistic_feature, (0, STAT_FEATURE_MAX_LEN - len(statistic_feature)), 'constant', constant_values= 0)
        tensor_statistic_feature = to_tensor(statistic_feature)
        
        #band images : dictionary, tensor images
        #tensor statistic feature : tensor 1 dim array
        x_data = band_images, tensor_statistic_feature

        ##truth ground data form file name
        Truth_label = image_name_split_list[5]

        if Truth_label[0 :2] == 'HA':
            y_data = 1
        else:
            y_data = 0

        return x_data, y_data

    def __len__(self):
        return int(len(self.image_name_list)/4)

# #데이터 전처리
# transform = transforms.Compose(
#     [transforms.ToTensor(),transforms.Normalize((0.5 , 0.5, 0.5), (0.5, 0.5, 0.5))]
# )


#총 데이터 num :32 * 40 * 32 * 15 * 4(band), batch_size는 이의 약수, 논문에 명시됨
batch_size = 16
Data_Path_Train = "./Data/Train_Data/"
Data_Path_Validation = "./Data/Validation_Data/"
Data_Path_Test = "./Data/Test_Data/"



#test_데이터, train_데이터 불러오고 저장
train_set = Arousal_Dataset(data_path = Data_Path_Train, train = True)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers = 0)

test_set = Arousal_Dataset(data_path = Data_Path_Test, train = False)

test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle = False, num_workers = 0)

#output clasees
classes =  (0, 1)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class MSCB_CNN(nn.Module):
    def __init__(self):   
        super(SmallCNN, self).__init__()

        ###for MSCB
        self.conv_L = nn.Conv2d(in_channels = 3, out_channels = 14, kernel_size = (1,5) , stride = 1, padding = 'same')
        self.conv_M = nn.Conv2d(in_channels = 3, out_channels = 14, kernel_size = (1,3), stride = 1, padding = 'same')
        self.conv_S = nn.Conv2d(in_channels = 3, out_channels = 14, kernel_size = (1,1), stride = 1, padding = 'same')
        self.conv_raw = nn.Conv2d(in_channels = 3, out_channels = 14, kernel_size = (1,1), stride = 1, padding = 'same')
        
        #Stride = 5 -> 그렇기에 kernel size 또한 5, 그러나 데이터의 규격이 5의 배수가 아니여서 4로 조정
        self.pool = nn.MaxPool2d(kernel_size = 4, stride = 4, padding = 'same')

        # Convolutional layers
        self.conv = nn.Conv2d(in_channels = 56, out_channels = 112, kernel_size=(1,3), stride = 1, padding=3)

        #flatten layer
        self.flat = nn.Flatten()

        # Fully connected layers, 112 : final kernel num, /(4*4) striding number at pooling
        self.fc1 = nn.Linear(IMAGE_PIXEL_NUM * BAND_TYPE_NUM *112/(4*4)+ STAT_FEATURE_MAX_LEN, 400)
        self.fc2 = nn.Linear(400, 300)  
        # 2 output classes (0 and 1)«
        self.fc3 = nn.Linear(300, 2)

        # Dropout layer to prevent overfitting
        self.dropout = nn.Dropout(0.5)


        
    def forward(self, x):
        #MSCB Block
        def MSCB(self, x):
            x_L = self.pool((torch.relu(self.conv_L(x))))
            x_M = self.pool((torch.relu(self.conv_M(x))))
            x_S = self.pool((torch.relu(self.conv_S(x))))
            x_raw = self.conv_raw((torch.relu(self.pool(x))))

            ##2차원 이미지 합치기, kernerl수 증가
            x = torch.cat((x_L, x_M, x_S, x_raw), dim = 2)

            return x
        
        def Conv_to_FCL(self,x):
            x = self.pool(torch.relu(self.conv(x)))
            x = self.flat(x)

            return x
            
        # print(f"initial x {x.shape}")
        band_images, tensor_statistic_feature = x

        theta_MSCB = MSCB(self, band_images["theta"])
        alpha_MSCB = MSCB(self, band_images["alpha"])
        beta_MSCB = MSCB(self, band_images["beta"])
        gamma_MSCB = MSCB(self, band_images["gamma"])

        theta_features = Conv_to_FCL(theta_MSCB)
        alpha_features = Conv_to_FCL(alpha_MSCB)
        beta_features = Conv_to_FCL(beta_MSCB)
        gamma_features = Conv_to_FCL(gamma_MSCB)
        
        all_feature_map = torch.cat((theta_features,alpha_features,beta_features,gamma_features,tensor_statistic_feature), dim = 1)

        x = torch.relu(self.fc1(all_feature_map))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)

        return x

#initialize with MPS GPU
net = MSCB_CNN().to(device)


#Set Optimizer and loss function
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, weight_decay = 0.02)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 30, gamma = 0.5)

In [None]:
from tqdm.notebook import tqdm
#학습
start_flag = 1
for epoch in tqdm(range(500), desc= f"epoch"):   # 데이터셋을 수차례 반복합니다.

    running_loss = 0.0
    for i, data in enumerate(tqdm(train_loader, desc = "learning!!!"), 0):

        if start_flag == 1:
            print("start learning!!!")
            start_flag = 0

        # [inputs, labels]의 목록인 data로부터 입력을 받은 후;
        inputs, labels = data

        #검증
        # print(f"minibatch number : {i}, file's shape : {inputs.shape}")

        #move to GPU to calculate
        inputs, labels= inputs.to(device), labels.to(device)
        
        # 변화도(Gradient) 매개변수를 0으로 만들고
        optimizer.zero_grad()

        # 순전파 + 역전파 + 최적화를 한 후
        outputs = net(inputs)
        # print(outputs.shape)
        # print(labels)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # 통계를 출력합니다.
        running_loss += loss.item()
        if i % 600 == 599:    # print every 600 mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0
            

print('Finished Training')

In [None]:
# ##Save Model
# PATH = "/content/drive/MyDrive/models/"
# torch.save(net.state_dict(), PATH)

In [13]:
#Accuracy Check
correct = 0
total = 0
# 학습 중이 아니므로, 출력에 대한 변화도를 계산할 필요가 없습니다
with torch.no_grad():
    for data in test_loader:
        images, labels = data
        images. labels = images.to(device), labels.to(device)
        # 신경망에 이미지를 통과시켜 출력을 계산합니다
        outputs = net(images).to(device)
        # 가장 높은 값(energy)를 갖는 분류(class)를 정답으로 선택하겠습니다
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total : .2f} %')

Accuracy of the network on the 10000 test images:  56.13 %


***결과 정리 노트***
None : 58.74
Laplace : 41.09
CAR : 56.13