# ShallowConvNet

2024.04.18 Written by @Chahyunee (Chaehyun Lee)

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F
import datetime
import scipy.io
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torch.utils.data import Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score

# Cross Validation
from sklearn.model_selection import KFold

# Plot
import matplotlib.pyplot as plt

# Path prefixes that are concatenated directly with file names below, so a
# non-empty value must end with a path separator. Empty = current directory.
dataset_dir = ''
model_dir = ''


## data load

In [None]:
# Load class1.mat ... class5.mat from dataset_dir and store each file's
# 'data' variable under the key 'class{n}_data'.
class_data = {}

for class_idx in range(1, 6):
    mat = scipy.io.loadmat(dataset_dir + f'class{class_idx}.mat')
    # Flatten column-major (MATLAB order) into 600 entries -- presumably one
    # per trial (verify against the .mat files); reshape raises if the
    # variable does not contain exactly 600 elements.
    class_data[f'class{class_idx}_data'] = np.asarray(mat['data']).reshape(600, order='F')

## Make a Dataset

In [None]:
# Device configuration: use the GPU when CUDA is available, else the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _has_cuda else torch.device('cpu')
device

### Make a dataset class
EEGDataset class


In [None]:

# Tensor-conversion pipeline (ToTensor only). It is not currently passed to
# the EEGDataset instance built in this notebook.
transform = transforms.Compose([transforms.ToTensor()])

class EEGDataset(Dataset):
    """Map-style dataset pairing EEG trials with their labels.

    Parameters
    ----------
    inputs : sequence of array-likes
        One entry per trial; each is returned as a float32 tensor.
    labels : sequence or tensor
        One label entry per trial (one-hot rows in this notebook); each is
        returned as an int8 tensor.
    transform : callable, optional
        When given, applied to the float32 input tensor of every sample.
    """

    def __init__(self, inputs, labels, transform=None):
        self.inputs = inputs
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        """Return ``(eeg, label)`` for trial ``idx``."""
        eeg = torch.tensor(self.inputs[idx], dtype=torch.float32)
        # as_tensor avoids the "copy construct from a tensor" warning that
        # torch.tensor() emits when the label row is already a tensor.
        label = torch.as_tensor(self.labels[idx], dtype=torch.int8)
        # The transform argument used to be stored but never applied.
        if self.transform is not None:
            eeg = self.transform(eeg)
        return eeg, label


Extract data from the dictionary and convert it to PyTorch tensors
- Change this part!!

In [None]:
# Number of target classes (the name `subj_num` is kept because later cells use it).
subj_num = 5

# Stack the trials of all classes in class order (class1 first, then class2, ...).
X_data = np.concatenate([class_data[f'class{i}_data'] for i in range(1, subj_num + 1)], axis=0)

# Integer class index per trial, derived from each class's actual number of
# entries rather than assuming exactly 600 per class. (The former
# `for d in X_data: d = torch.tensor(d)` loop only rebound a loop variable
# and had no effect, so it was removed.)
y_labels = torch.tensor(
    np.concatenate([np.full(len(class_data[f'class{i + 1}_data']), i) for i in range(subj_num)]),
    dtype=torch.int64,
)
y_labels_one_hot = torch.nn.functional.one_hot(y_labels, subj_num)


### Split dataset into train, validation, and test sets
0.8 train, 0.1 validation, 0.1 test

In [None]:
dataset = EEGDataset(X_data, y_labels_one_hot)
# 80 % train, 10 % validation, remainder test. The test size is the remainder
# so the three lengths always sum to len(dataset); random_split raises a
# ValueError otherwise (int() truncation can drop samples).
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

### Data Loader

In [None]:

# Build the DataLoaders.
batch_size = 10
# Only the training data needs reshuffling each epoch; evaluation loaders use
# a fixed order so validation/test runs are reproducible.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
train_size, val_size, test_size


## ShallowNet Model Class

        
Note:

    Class & def. for ShallowConvNet layer
    - Class : LinearWithConstraint
    - Class : Conv2dWithConstraint
    - Def. : initialize_weight

    class ShallowConvNet


In [None]:

class LinearWithConstraint(nn.Linear):
    """nn.Linear whose weight rows are max-norm constrained.

    Before every forward pass, each row of ``weight`` (one per output unit)
    whose L2 norm exceeds ``max_norm`` is rescaled down to ``max_norm``.
    """

    def __init__(self, *config, max_norm=1, **kwconfig):
        super().__init__(*config, **kwconfig)
        self.max_norm = max_norm

    def forward(self, x):
        # Written through .data so the projection is not tracked by autograd.
        limited = torch.renorm(self.weight.data, p=2, dim=0, maxnorm=self.max_norm)
        self.weight.data = limited
        return super().forward(x)


class Conv2dWithConstraint(nn.Conv2d):
    """nn.Conv2d whose filters are max-norm constrained.

    Before every forward pass, each output filter (slice of ``weight`` along
    dim 0) whose L2 norm exceeds ``max_norm`` is rescaled down to ``max_norm``.
    """

    def __init__(self, *config, max_norm=1, **kwconfig):
        super().__init__(*config, **kwconfig)
        self.max_norm = max_norm

    def forward(self, x):
        # Written through .data so the projection is not tracked by autograd.
        limited = torch.renorm(self.weight.data, p=2, dim=0, maxnorm=self.max_norm)
        self.weight.data = limited
        return super().forward(x)
def initialize_weight(model, method):
    """Re-initialize the parameters of every sub-module of ``model`` in place.

    Parameters
    ----------
    model : nn.Module
    method : str or None
        'normal' (N(0, 0.01)), 'xavier_uni', 'xavier_normal', 'he_uni' or
        'he_normal'. Any other value (including None) leaves ``model``
        untouched.

    Rules
    -----
    * LSTM modules: every weight of every layer/direction gets the chosen
      init, and every bias is zeroed.
    * Modules whose class name contains "BatchNorm": weight set to 1.
    * Any other module with a non-None ``weight``: chosen init.
    * Biases of those modules (when not None) are zeroed.

    Returns
    -------
    None
    """
    init_spec = dict(normal=['normal_', dict(mean=0, std=0.01)],
                     xavier_uni=['xavier_uniform_', dict()],
                     xavier_normal=['xavier_normal_', dict()],
                     he_uni=['kaiming_uniform_', dict()],
                     he_normal=['kaiming_normal_', dict()]).get(method)
    if init_spec is None:
        return None
    init_fn = getattr(nn.init, init_spec[0])
    init_kwargs = init_spec[1]

    for module in model.modules():
        if module.__class__.__name__ in ['LSTM']:
            # Iterate all of the LSTM's own parameters (weight_ih_l0,
            # weight_hh_l1_reverse, ...); `_all_weights[0]` only covered layer 0.
            for name, param in module.named_parameters(recurse=False):
                if name.startswith('weight'):
                    init_fn(param, **init_kwargs)
                elif name.startswith('bias'):
                    nn.init.constant_(param, 0)
            continue

        weight = getattr(module, 'weight', None)
        # Containers have no weight; some layers set weight=None
        # (e.g. BatchNorm with affine=False) -- nothing to initialize.
        if weight is None:
            continue
        if "BatchNorm" in module.__class__.__name__:
            nn.init.constant_(weight, 1)
        else:
            init_fn(weight, **init_kwargs)
        bias = getattr(module, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias, 0)
    return None

# Widen tensor printing so long tensors are printed on fewer wrapped lines.
_PRINT_LINEWIDTH = 1000
torch.set_printoptions(linewidth=_PRINT_LINEWIDTH)


In [None]:

class ActSquare(nn.Module):
    """Element-wise square activation: returns ``x ** 2``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return torch.pow(x, 2)


class ActLog(nn.Module):
    """Safe natural-log activation.

    Values below ``eps`` are raised to ``eps`` before taking the log, so
    zeros and negatives do not produce -inf/NaN.
    """

    def __init__(self, eps=1e-06):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        floored = x.clamp(min=self.eps)
        return floored.log()
import warnings

In [None]:
class ShallowConvNet(nn.Module):
    """ShallowConvNet-style EEG classifier.

    Forward pipeline: temporal convolution -> spatial convolution ->
    BatchNorm -> square -> temporal pooling -> log -> dropout -> flatten ->
    hidden linear layer -> output linear layer.

    Parameters
    ----------
    < required >
    n_classes : number of target classes
    input_shape : (s, t), e.g. [32, 1024]
        s: number of channels
        t: number of time points
        Inputs to ``forward`` are expected as (batch, 1, s, t).

    < optional >
    F1 : number of filters of the temporal convolution. Default: 5
    T1 : kernel length (in samples) of the temporal convolution. Default: 25
    F2 : number of filters of the spatial convolution (kernel (s, 1)).
        Default: 10
    P1_T : pooling window length along the time axis. Default: 75
    P1_S : pooling stride along the time axis. Default: 15
    drop_out : dropout rate. Default: 0.5
    pool_mode : 'mean' or 'max'. Default: 'mean'
    weight_init_method : method name passed to ``initialize_weight``;
        None keeps PyTorch's default initialization.
    last_dim : width of the hidden linear layer. Default: 3072

    Raises
    ------
    KeyError
        If ``pool_mode`` is not 'mean' or 'max'.
    ValueError
        If ``input_shape`` is too short in time for the convolution + pooling.
    """

    def __init__(
            self,
            n_classes,
            input_shape,
            F1=5,
            T1=25,
            F2=10,
            P1_T=75,  # pooling window length along time
            P1_S=15,  # pooling stride along time
            drop_out=0.5,
            pool_mode='mean',
            weight_init_method=None,
            last_dim=3072  # width of the hidden linear layer
    ):
        super(ShallowConvNet, self).__init__()
        s, t = input_shape

        pooling_layer = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[pool_mode]

        # Flattened feature size, derived from input_shape instead of the
        # previously hard-coded 1990:
        #   temporal conv (1, T1), no padding  -> t - T1 + 1 samples
        #   spatial conv (s, 1), no padding    -> height 1
        #   pooling kernel P1_T, stride P1_S   -> floor((L - P1_T) / P1_S) + 1
        conv_len = t - T1 + 1
        pooled_len = (conv_len - P1_T) // P1_S + 1
        if pooled_len < 1:
            raise ValueError(
                f'input_shape time length {t} is too short for T1={T1}, P1_T={P1_T}'
            )
        n_flat = F2 * pooled_len

        self.constConv2d1 = Conv2dWithConstraint(1, F1, (1, T1), max_norm=2)
        self.constConv2d2 = Conv2dWithConstraint(F1, F2, (s, 1), bias=False, max_norm=2)
        self.bn1 = nn.BatchNorm2d(F2)
        self.pool1 = pooling_layer((1, P1_T), (1, P1_S))
        self.dropout1 = nn.Dropout(drop_out)
        self.flatten1 = nn.Flatten()
        self.linear1 = nn.Linear(n_flat, last_dim)
        self.linear2 = nn.Linear(last_dim, n_classes)
        # Max-norm constrained head; defined (and kept in the state_dict) but
        # not used by forward yet.
        self.linearConst1 = LinearWithConstraint(last_dim, n_classes, max_norm=1)
        # Parameter-free activations; registering them adds no state_dict keys.
        self.act_square = ActSquare()
        self.act_log = ActLog()

        initialize_weight(self, weight_init_method)

    def forward(self, x):
        """Map a (batch, 1, s, t) input to (batch, n_classes) logits."""
        x = self.constConv2d1(x)
        x = self.constConv2d2(x)
        x = self.bn1(x)
        # The square/log results used to be computed and discarded, so the
        # ShallowConvNet nonlinearity was never applied.
        x = self.act_square(x)
        x = self.pool1(x)
        x = self.act_log(x)
        x = self.dropout1(x)
        x = self.flatten1(x)
        x = self.linear1(x)
        out = self.linear2(x)
        # TODO: switch the output layer to the constrained head (self.linearConst1).
        return out

### model evaluation

In [None]:
from pprint import pprint as pp

def evaluate(true_labels, predicted_labels, subj_num = 5, mode='train'):
    """Compute one-vs-rest metrics for each class and their unweighted means.

    Parameters
    ----------
    true_labels, predicted_labels : torch.Tensor
        0/1 matrices of shape (N, subj_num), one column per class (the
        training loop passes one-hot rows). May live on any device.
    subj_num : int
        Number of classes (columns) to evaluate.
    mode : str
        Kept for backward compatibility; every mode is handled identically
        (tensors are always detached and moved to the CPU).

    Returns
    -------
    dict
        'recall_per_class', 'f1_per_class', 'acc_per_class',
        'precision_per_class' (lists of length subj_num) plus
        'average_recall', 'average_f1', 'average_acc', 'average_prec'.
    """
    # Convert once instead of once per metric per class.
    y_true_all = true_labels.detach().cpu().numpy()
    y_pred_all = predicted_labels.detach().cpu().numpy()

    result = dict(recall_per_class=[], f1_per_class=[], acc_per_class=[], precision_per_class=[])

    for class_idx in range(subj_num):
        y_true = y_true_all[:, class_idx]
        y_pred = y_pred_all[:, class_idx]
        # zero_division=0 yields the same value as sklearn's default ('warn')
        # without an UndefinedMetricWarning for classes absent from the input.
        result['recall_per_class'].append(recall_score(y_true, y_pred, zero_division=0))
        result['f1_per_class'].append(f1_score(y_true, y_pred, zero_division=0))
        result['acc_per_class'].append(accuracy_score(y_true, y_pred))
        result['precision_per_class'].append(precision_score(y_true, y_pred, zero_division=1))

    result['average_recall'] = sum(result['recall_per_class']) / len(result['recall_per_class'])
    result['average_f1'] = sum(result['f1_per_class']) / len(result['f1_per_class'])
    result['average_acc'] = sum(result['acc_per_class']) / len(result['acc_per_class'])
    result['average_prec'] = sum(result['precision_per_class']) / len(result['precision_per_class'])

    return result


### Prepare Training

In [None]:
# Training configuration -- change as needed. (The training cell below sets
# its own `epochs` value.)
epochs = 300
train_losses, val_losses = [], []

shallow_input_shape = [32, 1024]  # [number of channels, number of time points]

model = ShallowConvNet(n_classes=5, input_shape=shallow_input_shape)
model = model.to(device)
# Cross-entropy on the raw logits is the standard loss for single-label
# multi-class classification. The training loop passes one-hot float targets,
# which nn.CrossEntropyLoss accepts as class probabilities (PyTorch >= 1.10).
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min', factor=0.6, min_lr=1e-10)

## Train

In [None]:

# Draw a random seed in [0, 500), print it so the run can be reproduced, then
# seed NumPy and PyTorch (CPU and CUDA) with it.
# NOTE(review): the seed is random, not fixed, and in the cell order shown it
# runs after the dataset split and model construction, so it does not affect those.
seed_n = np.random.randint(500)
print('seed is ' + str(seed_n))
np.random.seed(seed_n)
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed_n)

In [None]:

# Training
epochs = 200
train_losses, val_losses = [], []

# Per-epoch history of training/validation metrics. Losses are stored as
# detached 0-d tensors (so the plotting cell can call .item()/float() on them).
history = {'val_loss': [], 'val_acc': [], 'val_f1': [], 'val_prec': [], 'val_rec': [],
           'train_loss': [], 'train_acc': [], 'train_f1': [], 'train_prec': [], 'train_rec': []}


def _one_hot_argmax(logits):
    """Return a CPU one-hot matrix with a 1 at each row's largest logit."""
    return torch.eye(logits.shape[1])[torch.argmax(logits.detach().cpu(), dim=1)]


for epoch in range(epochs):
    # ---- training ----
    model.train()
    batch_losses, train_true, train_pred = [], [], []
    for inputs, labels in train_loader:
        optimizer.zero_grad()

        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs.unsqueeze(1))
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()

        # Detach so stored losses do not keep the autograd graph alive.
        batch_losses.append(loss.detach())
        train_true.append(labels.float().cpu())
        train_pred.append(_one_hot_argmax(outputs))

    # Epoch-level loss and metrics over all training batches (previously only
    # the last batch was reported).
    train_loss = torch.stack(batch_losses).mean()
    train_result = evaluate(torch.cat(train_true), torch.cat(train_pred))  # dictionary return

    # ---- validation ----
    model.eval()
    val_true, val_pred = [], []
    with torch.no_grad():
        val_loss = 0.0
        for inputs, labels in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs.unsqueeze(1))
            true_labels = labels.float()
            val_loss += criterion(outputs, true_labels)

            val_true.append(true_labels.cpu())
            val_pred.append(_one_hot_argmax(outputs))

        # Mean validation loss per batch; metrics over the whole validation set.
        val_loss = val_loss / len(val_loader)
        valid_result = evaluate(torch.cat(val_true), torch.cat(val_pred), mode='valid')  # dictionary return

    scheduler.step(val_loss)

    print(f'\nEpoch {epoch + 1}/{epochs} \n'
          f'        train loss: {train_loss.item()}, valid loss: {float(val_loss)}')
    pp(train_result)
    pp(valid_result)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_result['average_acc'])
    history['val_acc'].append(valid_result['average_acc'])
    history['train_f1'].append(train_result['average_f1'])
    history['val_f1'].append(valid_result['average_f1'])
    history['train_prec'].append(train_result['average_prec'])
    history['val_prec'].append(valid_result['average_prec'])
    history['train_rec'].append(train_result['average_recall'])
    history['val_rec'].append(valid_result['average_recall'])

print('Finished Training')


## model test

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd

y_pred = []
y_true = []

# Inference mode: dropout off, BatchNorm uses running statistics, and no
# autograd graphs are built.
model.eval()
with torch.no_grad():
    # iterate over test data
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        outputs = model(inputs.unsqueeze(1))  # Feed Network

        # Predicted class index = position of the largest logit.
        output = torch.argmax(outputs, dim=1).cpu().numpy()
        print('output : ', output)
        y_pred.extend(output)  # Save Prediction

        # True class index = position of the 1 in the one-hot label row.
        labels = torch.argmax(labels, dim=1).cpu().numpy()
        print('labels  : ', labels)
        y_true.extend(labels)  # Save Truth

## Model test result


In [None]:
# Display names for the classes, indexed by label (0 -> 'class a', ...).
classes = ('class a', 'class b', 'class c', 'class d', 'class e')

# Show the collected ground-truth and predicted class indices.
for _seq in (y_true, y_pred):
    print(_seq)


In [None]:

# Build the confusion matrix over every class, even ones missing from
# y_true/y_pred, so its shape always matches `classes`.
cf_matrix = confusion_matrix(y_true, y_pred, labels=list(range(len(classes))))

# Row-normalize (each true class sums to 1); rows with no samples stay 0
# instead of becoming NaN.
row_totals = cf_matrix.sum(axis=1, keepdims=True)
cf_norm = np.divide(cf_matrix, row_totals,
                    out=np.zeros(cf_matrix.shape, dtype=float), where=row_totals > 0)
df_cm = pd.DataFrame(cf_norm, index=list(classes), columns=list(classes))

df_test = pd.DataFrame(cf_matrix, index=list(classes), columns=list(classes))

df_test
df_cm


In [None]:
import datetime

import os

# Current date and time as a string, used to make output file names unique.
current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M")
current_time
plt.figure(figsize = (12,7))
sn.heatmap(df_cm, annot=True, cmap="YlGnBu", vmin=0, vmax=1)

# Adding labels to x-axis and y-axis
plt.xlabel('Predicted')
plt.ylabel('True')

acc = valid_result['average_acc']
# savefig does not create missing directories.
os.makedirs('Figure', exist_ok=True)
plt.savefig(f'Figure/class{subj_num}_ShallowNet_acc{acc}_{current_time}_output.png')


In [None]:
import os

# Heatmap of raw confusion-matrix counts.
plt.figure(figsize = (12,7))
sn.heatmap(df_test, annot=True, cmap="YlGnBu")

# Adding labels to x-axis and y-axis
plt.xlabel('Predicted')
plt.ylabel('True')

acc = valid_result['average_acc']
# savefig does not create missing directories; the '_count' suffix keeps this
# figure from overwriting the normalized heatmap, which is saved under the
# same name without the suffix.
os.makedirs('Figure', exist_ok=True)
plt.savefig(f'Figure/class{subj_num}_ShallowNet_acc{acc}_{current_time}_count_output.png')


Accuracy and loss curves

In [None]:
# Training/validation accuracy (top) and loss (bottom) per epoch.
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(10, 8))

# Use the recorded length rather than `epochs`, so an interrupted run still plots.
epoch_axis = range(len(history['val_acc']))

axes[0].plot(epoch_axis, history['val_acc'], label='Valid Accuracy', color='b')
axes[0].plot(epoch_axis, history['train_acc'], label='Train Accuracy', color='r')
axes[0].set_xlabel('epochs')
# Accuracies are fractions in [0, 1], not percentages.
axes[0].set_ylabel('Accuracy')
axes[0].grid(linestyle='--', color='lavender')
axes[0].legend()
axes[0].set_ylim(0, 1)

# float() accepts both Python numbers and single-element tensors.
axes[1].plot(epoch_axis, [float(loss) for loss in history['val_loss']], label='Valid Loss', color='g')
axes[1].plot(epoch_axis, [float(loss) for loss in history['train_loss']], label='Train Loss', color='y')
axes[1].set_xlabel('epochs')
axes[1].set_ylabel('Loss')
axes[1].grid(linestyle='--', color='lavender')
axes[1].legend()
fig.tight_layout()

## model save code

In [None]:
# Save only the learned parameters (the state_dict), in a file named with the
# accuracy and timestamp computed in earlier cells.
model_save_path = f'{model_dir}ShallowNet_acc{acc}_{current_time}.pth'
weights = model.state_dict()
torch.save(weights, model_save_path)

In [None]:
# Model load: replace the placeholder path below with a real checkpoint file.
model_load_path = f'{model_dir} ... .pth'

loaded_model = ShallowConvNet(n_classes=5, input_shape=shallow_input_shape)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine (and
# vice versa). torch.load unpickles the file: only load trusted checkpoints
# (PyTorch >= 1.13 also supports weights_only=True).
state_dict = torch.load(model_load_path, map_location=device)
loaded_model.load_state_dict(state_dict)
loaded_model = loaded_model.to(device)
# Inference mode: dropout off, BatchNorm uses running statistics.
loaded_model.eval()