### 1DCNN Module

In [15]:
import torch
import torch.nn as nn

In [31]:
class ODCM(nn.Module):
    """One-dimensional depth-wise CNN front-end (ODCM).

    Three stacked ``Conv1d`` layers, each followed by ReLU, are applied along
    the time axis. Every layer uses ``groups == number of input channels``, so
    each input channel is filtered independently of the others. The first two
    layers keep one filter per channel; the third assigns ``self.ncf`` (120)
    filters to every channel.

    Args:
        input_channels: number of input channels (19 by default).
        kernel_size: temporal kernel length shared by the three layers.
        dtype: dtype of the convolution parameters.

    Shapes:
        input:  ``(C, T)`` or ``(B, C, T)``
        output: ``(C, ncf, Le)`` or ``(B, C, ncf, Le)`` with
        ``Le = T - 3 * (kernel_size - 1)`` (no padding, stride 1).
    """

    def __init__(self, input_channels = 19, kernel_size = 10, dtype=torch.float32):
        super(ODCM, self).__init__()
        self.input_channels = input_channels
        self.kernel_size = kernel_size  # 1 X 10
        self.ncf = 120  # The number of the depth-wise convolutional filter used in the three layers is set to 120
        self.dtype = dtype
        self.cvf1 = nn.Conv1d(in_channels=self.input_channels, out_channels=self.input_channels, kernel_size=self.kernel_size, padding='valid', stride=1, groups=self.input_channels, dtype=self.dtype)
        self.cvf2 = nn.Conv1d(in_channels=self.cvf1.out_channels, out_channels=self.cvf1.out_channels, kernel_size=self.kernel_size, padding='valid', stride=1, groups=self.cvf1.out_channels, dtype=self.dtype)
        # For each channel (19 channels), has 120 independent features due to the depthwise convolution in third layer
        # 120 filters are assigned to every single channel to extract its features
        self.cvf3 = nn.Conv1d(in_channels=self.cvf2.out_channels, out_channels=self.ncf * self.cvf2.out_channels, kernel_size=self.kernel_size, padding='valid', stride=1, groups=self.cvf2.out_channels, dtype=self.dtype)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the three depth-wise convolutions and split channels from filters.

        Args:
            x: tensor of shape ``(C, T)`` or ``(B, C, T)``.

        Returns:
            Tensor of shape ``(C, ncf, Le)`` or ``(B, C, ncf, Le)``.
        """
        x = self.relu(self.cvf1(x))
        x = self.relu(self.cvf2(x))
        x = self.relu(self.cvf3(x))  # (..., C * ncf, Le)

        # A grouped convolution lays its output channels out group by group:
        # the ncf filters of input channel 0 come first, then those of channel 1,
        # and so on. The flat axis therefore factors as (C, ncf) in that order;
        # splitting it as (ncf, C) would mix features of different channels.
        return x.reshape(*x.shape[:-2], self.input_channels, self.ncf, x.shape[-1])

In [17]:
import mne
import torch

# NOTE(review): absolute, machine-specific path — this cell only runs on the
# author's machine; consider a configurable data directory.
file_path = '/Users/hwangjeongho/Desktop/EEG Transformer/model-data/train/sub-001_eeg_chunk_1.set'

# Read the EEGLAB .set file with MNE, loading the samples into memory (preload=True).
raw = mne.io.read_raw_eeglab(file_path, preload=True)
data = raw.get_data()  # shape: (19, 1425)

# Convert to float32 and prepend an axis of size 1 so the tensor can be fed to Conv1d as a batch.
eeg_tensor = torch.tensor(data, dtype=torch.float32).unsqueeze(0)  # shape: (1, 19, 1425)


In [18]:
# Sanity check: display the shape of the tensor built in the previous cell.
eeg_tensor.shape

torch.Size([1, 19, 1425])

In [10]:
# Smoke test: run the loaded EEG chunk through the depth-wise CNN front-end.
# NOTE(review): this cell's execution count (10) is lower than that of the cell
# defining ODCM (31), so the recorded output may come from an earlier version of
# the class — re-run after "Restart & Run All" to confirm the shape.
model = ODCM(input_channels=19, kernel_size=10)
output = model(eeg_tensor)

print("Output shape:", output.shape)  # Expected: (1, 19, 120, Le)

Output shape: torch.Size([1, 19, 120, 1398])


### Truncated Normal Distribution Function
- Used to initialize the positional embedding
- Fills a weight or embedding tensor in place with its initial values
- Keeps the initial values from being too large or too small
- Helps avoid exploding / vanishing gradients

- 모델 안의 weight나 embedding vector를 초기화하는 함수
- 훈련 시작 전에 좋은 초기값을 주기 위해 사용함.
- 평균이 0이고 표준편차가 1인 정규분포에서 특정 범위 (-2, +2) Sigma 사이에서만 값이 나오게 하고 싶다
- 최신 PyTorch 에는 `torch.nn.init.trunc_normal_` 이 내장되어 있지만, 여기서는 같은 알고리즘을 직접 구현한다.
- 정규분포를 직접 뽑기 어려우니, 먼저 Uniform 분포 (0~1)에서 뽑은 다음, 그걸 정규분포의 형태로 바꿔서 쓰자


Step 1] 원하는 정규분포 범위 [a, b] → 확률 구간 [l, u]

Step 2] [l, u] → [2l-1, 2u-1] 매핑 ([-1, 1] 의 부분 구간) → uniform 값 생성

Step 3] uniform → 정규분포로 변환 (inverse CDF = erfinv)

Step 4] 평균/표준편차 맞춤 → 정규분포 값 완성

Step 5] 너무 큰 값은 clamp

In [22]:
import os
import torch
import torch.nn as nn
import math

# Notebook-wide compute device: CUDA when available, otherwise CPU.
# Referenced as a global by the cells below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def trunc_normal(tensor, mean=0., std=1., a=-2., b=2.):  # for positional embedding - borrowed from Meta
    """Fill ``tensor`` in place with samples from a normal distribution truncated to ``[a, b]``.

    Sampling uses the inverse-CDF method: draw uniformly between the CDF values
    of the two bounds (rescaled to erf's range), map through ``erfinv``, then
    scale and shift to the requested ``mean`` / ``std``.

    Args:
        tensor: tensor to overwrite (modified in place, gradients disabled).
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower truncation bound (absolute value, not in units of ``std``).
        b: upper truncation bound (absolute value, not in units of ``std``).

    Returns:
        The same ``tensor`` object, filled with the sampled values.
    """
    root_two = math.sqrt(2.)

    def standard_cdf(z):
        # CDF of the standard normal distribution, written with erf.
        return (1. + math.erf(z / root_two)) / 2.

    lowest_sane_mean = a - 2 * std
    highest_sane_mean = b + 2 * std
    if mean < lowest_sane_mean or mean > highest_sane_mean:
        print("mean is more than 2 std from [a, b] in nn.init.trunc_normal\nThe distribution of values may be incorrect.")

    # Probability mass lying below each truncation bound.
    cdf_low = standard_cdf((a - mean) / std)
    cdf_high = standard_cdf((b - mean) / std)

    with torch.no_grad():
        # Uniform draw on [2l-1, 2u-1], the image of [l, u] in erf's output range.
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)
        # Inverse CDF -> truncated standard normal, then rescale and shift.
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        # Guard against values pushed just outside the bounds (or to +/-inf by erfinv).
        tensor.clamp_(min=a, max=b)

    return tensor

In [23]:
class Mlp(nn.Module): # Multilayer perceptron
    """Two-layer feed-forward network: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: size of the last input dimension.
        hidden_features: width of the hidden layer; falls back to ``in_features`` when falsy.
        out_features: size of the last output dimension; falls back to ``in_features`` when falsy.
        act_layer: activation module class, instantiated without arguments.
        drop: dropout probability applied after the activation and after the second layer.
        dtype: dtype of the two linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., dtype=torch.float32):
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, width_hidden, dtype=dtype)
        self.act = act_layer()
        self.fc2 = nn.Linear(width_hidden, width_out, dtype=dtype)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the MLP along the last dimension of ``x``."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


### Multi-Head Self-Attention + FeedForward Network (= Generic Transformer Block)

In [24]:
class GenericTFB(nn.Module):
    """Pre-norm transformer block: multi-head self-attention followed by an MLP.

    The block works on a 3-D token tensor ``(N, T, D)``: attention is computed
    independently for every index of the first axis, across the ``T`` tokens of
    the second axis.

    Args:
        emb_size: embedding size ``D`` of every token.
        num_heads: number of attention heads ``A``; ``D`` should be divisible by it.
        dtype: dtype of all parameters.
    """

    def __init__(self, emb_size, num_heads, dtype):
        super(GenericTFB, self).__init__()

        self.M_size1 = emb_size  # -> D (Dimension of Feature Vector)
        self.dtype = dtype
        self.hA = num_heads  # number of multi-head self-attention units (A is the number of units in a block)
        self.Dh = int(self.M_size1 / self.hA)  # Dh is the quotient computed by D/A and denotes the dimension number of three vectors.
        # Stacked projection weights: (3 = Q/K/V, heads, per-head dim, embedding dim)
        self.Wqkv = nn.Parameter(torch.randn((3, self.hA, self.Dh, self.M_size1), dtype=self.dtype))
        # Output projection mixing the concatenated heads: (D, D)
        self.Wo = nn.Parameter(torch.randn(self.M_size1, self.M_size1, dtype=self.dtype))

        self.lnorm = nn.LayerNorm(self.M_size1, dtype=self.dtype)  # LayerNorm applied before attention
        self.lnormz = nn.LayerNorm(self.M_size1, dtype=self.dtype)  # LayerNorm applied before the MLP
        self.mlp = Mlp(in_features=self.M_size1, hidden_features=int(self.M_size1 * 4), act_layer=nn.GELU, dtype=self.dtype)  # mlp_ratio=4

    def forward(self, x, savespace):
        """Apply one attention + MLP step to ``savespace``.

        Args:
            x: unused; kept so existing callers (``tfb(x, savespace)``) keep
                working. All shapes are taken from ``savespace``.
            savespace: token tensor of shape ``(N, T, D)``.

        Returns:
            Tensor with the same shape as ``savespace``.
        """
        # Project the normalised tokens to Q, K, V: (3, N, T, hA, Dh)
        qkv = torch.einsum('xhdm,ijm -> xijhd', self.Wqkv, self.lnorm(savespace))
        # Move the head axis in front of the token axis: each (N, hA, T, Dh)
        query, key, value = (part.transpose(1, 2) for part in qkv.unbind(0))

        # Scaled dot-product scores (N, hA, T, T), normalised over the key axis so
        # every query mixes the values with weights that sum to one.
        scores = (query / math.sqrt(self.Dh)) @ key.transpose(-2, -1)
        attention = torch.softmax(scores, dim=-1)

        # Weighted values, heads moved back next to Dh and concatenated: (N, T, D)
        heads = (attention @ value).transpose(1, 2)
        merged = heads.reshape(savespace.shape[0], savespace.shape[1], self.M_size1)

        # Output projection + residual connection (z')
        savespace = torch.einsum('nm,ijm -> ijn', self.Wo, merged) + savespace

        # LayerNorm + MLP + residual connection (new z)
        savespace = self.mlp(self.lnormz(savespace)) + savespace

        return savespace


### Temporal Transformer Block (Multi-Head Self-Attention + Feed Forward Network)

In [25]:
class TemporalTFB(nn.Module):
    """Pre-norm transformer block for a single 2-D token sequence ``(T, D)``.

    Same structure as a standard multi-head self-attention block followed by
    an MLP, but without a leading group axis.

    Args:
        emb_size: embedding size ``D`` of every token.
        num_heads: number of attention heads ``A``; ``D`` should be divisible by it.
        avgf: averaging factor ``M`` (stored only; the number of tokens is read
            from the input at call time).
        dtype: dtype of all parameters.
    """

    def __init__(self, emb_size, num_heads, avgf, dtype):
        super(TemporalTFB, self).__init__()

        self.avgf = avgf  # average factor (M)
        self.M_size1 = emb_size  # Feature Vector -> D
        self.dtype = dtype
        self.hA = num_heads  # number of multi-head self-attention units (A is the number of units in a block)
        self.Dh = int(self.M_size1 / self.hA)  # Dh is the quotient computed by D/A and denotes the dimension number of three vectors.
        self.Wqkv = nn.Parameter(torch.randn((3, self.hA, self.Dh, self.M_size1), dtype=self.dtype)) # (3, hA, Dh, D)
        self.Wo = nn.Parameter(torch.randn(self.M_size1, self.M_size1, dtype=self.dtype)) # (D, D)

        self.lnorm = nn.LayerNorm(self.M_size1, dtype=self.dtype)  # LayerNorm applied before attention
        self.lnormz = nn.LayerNorm(self.M_size1, dtype=self.dtype)  # LayerNorm applied before the MLP
        self.mlp = Mlp(in_features=self.M_size1, hidden_features=int(self.M_size1 * 4), act_layer=nn.GELU, dtype=self.dtype)  # mlp_ratio=4

    def forward(self, x, savespace):
        """Apply one attention + MLP step to ``savespace``.

        Args:
            x: unused; kept so existing callers (``tfb(x, savespace)``) keep working.
            savespace: token matrix of shape ``(T, D)``.

        Returns:
            Tensor with the same shape as ``savespace``.
        """
        # (3, hA, Dh, D) x (T, D) -> (3, T, hA, Dh)
        qkv = torch.einsum('xhdm,im -> xihd', self.Wqkv, self.lnorm(savespace))
        # Put the head axis first: each of Q, K, V becomes (hA, T, Dh)
        query, key, value = (part.transpose(0, 1) for part in qkv.unbind(0))

        # Scaled dot-product scores (hA, T, T), normalised over the key axis so
        # every query mixes the values with weights that sum to one.
        scores = (query / math.sqrt(self.Dh)) @ key.transpose(-2, -1)
        attention = torch.softmax(scores, dim=-1)

        # (hA, T, Dh) -> (T, hA, Dh) -> concatenate heads -> (T, D)
        heads = (attention @ value).transpose(0, 1)
        merged = heads.reshape(savespace.shape[0], self.M_size1)

        # Output projection + residual connection (z')
        savespace = torch.einsum('nm,im -> in', self.Wo, merged) + savespace

        # LayerNorm + MLP + residual connection (new z)
        savespace = self.mlp(self.lnormz(savespace)) + savespace

        return savespace

### Regional Transformer Module

In [26]:
class RTM(nn.Module):  # Regional transformer module
    """Regional transformer module.

    Linearly embeds every ``D``-dimensional vector of an ``(S, C, D)`` input,
    prepends one learnable class token per index of the first axis, adds a
    learnable positional term and runs the result through ``num_blocks``
    ``GenericTFB`` blocks.

    Args:
        input: template tensor of shape ``(S, C, D)``; only its shape is used.
        num_blocks: number of transformer blocks (K).
        num_heads: number of attention heads per block (A).
        dtype: dtype of all parameters.
    """

    def __init__(self, input, num_blocks, num_heads, dtype):  # input -> S x C x D
        super(RTM, self).__init__()
        self.inputshape = input.transpose(0, 1).transpose(1, 2).shape  # C x D x S
        self.M_size1 = self.inputshape[1]  # -> D
        self.dtype = dtype

        self.tK = num_blocks  # number of transformer blocks - K in the paper
        self.hA = num_heads  # number of multi-head self-attention units (A is the number of units in a block)
        self.Dh = int(self.M_size1 / self.hA)  # Dh is the quotient computed by D/A and denotes the dimension number of three vectors.

        # Each head should have the same dimension of feature matrix
        # NOTE(review): this only prints; construction continues and the
        # mismatch will surface later as a shape error in the blocks.
        if self.M_size1 % self.hA != 0 or int(self.M_size1 / self.hA) == 0:
            print(f"ERROR 1 - RTM : self.Dh = {int(self.M_size1 / self.hA)} != {self.M_size1}/{self.hA} \nTry with different num_heads")

        self.weight = nn.Parameter(torch.randn(self.M_size1, self.inputshape[1], dtype=self.dtype)) # (D, D) embedding matrix
        self.bias = nn.Parameter(torch.zeros(self.inputshape[2], self.inputshape[0] + 1, self.M_size1, dtype=self.dtype))  # (S, C+1, D) positional term
        self.cls = nn.Parameter(torch.zeros(self.inputshape[2], 1, self.M_size1, dtype=self.dtype)) # (S, 1, D) class tokens
        trunc_normal(self.bias, std=.02)
        trunc_normal(self.cls, std=.02)
        self.tfb = nn.ModuleList([GenericTFB(self.M_size1, self.hA, self.dtype) for _ in range(self.tK)])

    def forward(self, x):
        """Embed ``x`` and apply the transformer blocks.

        Args:
            x: tensor of shape ``(S, C, D)`` matching the constructor template.

        Returns:
            Tensor of shape ``(S, C+1, D)`` (index 0 of axis 1 is the class token).
        """
        x = x.transpose(0, 1).transpose(1, 2)  # C x D x S

        # (D, D) x (C, D, S) -> (S, C, D); the result is created directly on x's
        # device, so no pre-allocated buffer is needed.
        savespace = torch.einsum('lm,jmi -> ijl', self.weight, x)
        savespace = torch.cat((self.cls, savespace), dim=1)  # S x (C+1) x D
        savespace = torch.add(savespace, self.bias)  # z -> S x (C+1) x D

        for tfb in self.tfb:
            savespace = tfb(x, savespace)

        return savespace  # S x (C+1) x D - z4 in the paper

### Synchronous Transformer Module

In [27]:
class STM(nn.Module):  # Synchronous transformer module
    """Synchronous transformer module.

    Takes an ``(S, C, D)`` input, regroups it so that axis ``C`` leads, linearly
    embeds every ``D``-dimensional vector, prepends one learnable class token
    per index of ``C``, adds a learnable positional term and runs the result
    through ``num_blocks`` ``GenericTFB`` blocks.

    Args:
        input: template tensor of shape ``(S, C, D)``; only its shape is used.
        num_blocks: number of transformer blocks (K).
        num_heads: number of attention heads per block (A).
        dtype: dtype of all parameters.
    """

    def __init__(self, input, num_blocks, num_heads, dtype):  # input -> # S x C x D
        super(STM, self).__init__()
        self.inputshape = input.transpose(1, 2).shape  # S x D x C (S x Le x C in the paper)
        self.M_size1 = self.inputshape[1]  # -> D
        self.dtype = dtype

        self.tK = num_blocks  # number of transformer blocks - K in the paper
        self.hA = num_heads  # number of multi-head self-attention units (A is the number of units in a block)
        self.Dh = int(self.M_size1 / self.hA)  # Dh is the quotient computed by D/A and denotes the dimension number of three vectors.

        # NOTE(review): this only prints; construction continues and the
        # mismatch will surface later as a shape error in the blocks.
        if self.M_size1 % self.hA != 0 or int(self.M_size1 / self.hA) == 0:
            print(f"ERROR 2 - STM : self.Dh = {int(self.M_size1 / self.hA)} != {self.M_size1}/{self.hA} \nTry with different num_heads")

        self.weight = nn.Parameter(torch.randn(self.M_size1, self.inputshape[1], dtype=self.dtype)) # (D x D) embedding matrix
        self.bias = nn.Parameter(torch.zeros(self.inputshape[2], self.inputshape[0] + 1, self.M_size1, dtype=self.dtype)) # (C, S+1, D) positional term
        self.cls = nn.Parameter(torch.zeros(self.inputshape[2], 1, self.M_size1, dtype=self.dtype)) # (C, 1, D) class tokens
        trunc_normal(self.bias, std=.02)
        trunc_normal(self.cls, std=.02)
        self.tfb = nn.ModuleList([GenericTFB(self.M_size1, self.hA, self.dtype) for _ in range(self.tK)])

    def forward(self, x):  # S x C x D -> x
        """Embed ``x`` and apply the transformer blocks.

        Args:
            x: tensor of shape ``(S, C, D)`` matching the constructor template.

        Returns:
            Tensor of shape ``(C, S+1, D)`` (index 0 of axis 1 is the class token).
        """
        x = x.transpose(1, 2)  # S x D x C

        # (D x D) x (S x D x C) -> (C x S x D); the result is created directly on
        # x's device, so no pre-allocated buffer is needed.
        savespace = torch.einsum('lm,jmi -> ijl', self.weight, x)
        savespace = torch.cat((self.cls, savespace), dim=1)  # (C x S+1 x D)
        savespace = torch.add(savespace, self.bias)  # z -> (C x S+1 x D)

        for tfb in self.tfb:
            savespace = tfb(x, savespace)

        return savespace  # C x (S+1) x D - z5 in the paper

### Temporal Transformer Module

In [29]:
class TTM(nn.Module):  # Temporal transformer module
    """Temporal transformer module.

    Splits the last axis of a ``(C, S, D)`` input into ``num_submatrices``
    consecutive segments, averages each segment into one ``(S, C)`` sub-matrix,
    flattens every sub-matrix to a vector of length ``S*C``, embeds it,
    prepends a learnable class token, adds a learnable positional term and
    applies ``num_blocks`` ``TemporalTFB`` blocks followed by a LayerNorm.

    Args:
        input: template tensor of shape ``(C, S, D)``; only its shape is used.
        num_submatrices: averaging factor ``M`` (number of segments).
        num_blocks: number of transformer blocks (K).
        num_heads: number of attention heads per block (A).
        dtype: dtype of all parameters.
    """

    def __init__(self, input, num_submatrices, num_blocks, num_heads, dtype):  # input -> # C x S x D
        super(TTM, self).__init__()
        self.dtype = dtype
        self.avgf = num_submatrices  # average factor (M)
        self.input = input.transpose(0, 2)  # D x S x C
        self.seg = self.input.shape[0] / self.avgf  # segment length (float)

        # NOTE(review): this only prints; with a non-integer segment length the
        # segments are cut at int() boundaries but still divided by self.seg.
        if self.input.shape[0] % self.avgf != 0 or int(self.input.shape[0] / self.avgf) == 0:
            print(f"ERROR 3 - TTM : self.seg = {self.seg} != {self.input.shape[0]}/{self.avgf}")

        self.M_size1 = self.input.shape[1] * self.input.shape[2] # Submatrix is flattened to a 1D vector of length S*C, used as embedding size
        self.tK = num_blocks  # number of transformer blocks - K in the paper
        self.hA = num_heads  # number of multi-head self-attention units (A is the number of units in a block)
        self.Dh = int(self.M_size1 / self.hA)

        if self.M_size1 % self.hA != 0 or int(self.M_size1 / self.hA) == 0:  # - Dh = 121*(S+1) / num_heads
            print(f"ERROR 4 - TTM : self.Dh = {int(self.M_size1 / self.hA)} != {self.M_size1}/{self.hA} \nTry with different num_heads")

        self.weight = nn.Parameter(torch.randn(self.M_size1, self.input.shape[1] * self.input.shape[2], dtype=self.dtype)) # (S*C, S*C) embedding matrix
        self.bias = nn.Parameter(torch.zeros(self.avgf + 1, self.M_size1, dtype=self.dtype)) # (M+1, S*C) positional term
        self.cls = nn.Parameter(torch.zeros(1, self.M_size1, dtype=self.dtype)) # (1, S*C) class token
        trunc_normal(self.bias, std=.02)
        trunc_normal(self.cls, std=.02)
        self.tfb = nn.ModuleList([TemporalTFB(self.M_size1, self.hA, self.avgf, self.dtype) for _ in range(self.tK)])

        self.lnorm_extra = nn.LayerNorm(self.M_size1, dtype=self.dtype)  # EXPERIMENTAL

    def forward(self, x):
        """Average ``x`` into M sub-matrices, embed them and apply the blocks.

        Args:
            x: tensor of shape ``(C, S, D)`` matching the constructor template.

        Returns:
            Tensor of shape ``(M+1, S, C)``; index 0 of the first axis is the class token.

        Raises:
            IndexError: if ``x`` has fewer entries along ``D`` than the
                segments computed at construction time require.
        """
        feats = x.transpose(0, 2)  # D x S x C

        # Segment boundaries along D, cut exactly where the element-wise loop
        # `range(int(i * seg), int((i + 1) * seg))` would cut them.
        bounds = [int(i * self.seg) for i in range(self.avgf + 1)]
        if bounds[-1] > feats.shape[0]:
            raise IndexError(
                f"TTM expected at least {bounds[-1]} entries along the last input axis, got {feats.shape[0]}"
            )
        # Sum every segment in one slice operation, then divide by the nominal
        # segment length: M x S x C. The result lives on x's device.
        segment_sums = [feats[lo:hi].sum(dim=0) for lo, hi in zip(bounds[:-1], bounds[1:])]
        inputc = (torch.stack(segment_sums) / self.seg).to(self.dtype)

        altx = inputc.reshape(self.avgf, feats.shape[1] * feats.shape[2])  # M x (S*C)

        savespace = torch.einsum('lm,im -> il', self.weight, altx)  # (S*C, S*C) x (M, S*C) = (M, S*C)
        savespace = torch.cat((self.cls, savespace), dim=0)  # (M+1, S*C)
        savespace = torch.add(savespace, self.bias)  # z -> (M+1, S*C)

        for tfb in self.tfb:
            savespace = tfb(x, savespace)

        savespace = self.lnorm_extra(savespace)  # EXPERIMENTAL
        return savespace.reshape(self.avgf + 1, feats.shape[1], feats.shape[2])

### Decoder (CNN)

In [None]:
class CNNdecoder(nn.Module):  # EEGformer decoder
    """Convolutional decoder that turns an ``(M, S, C)`` feature tensor into class scores.

    Three kernel-size-1 convolutions successively collapse axis ``C`` to one
    channel, map axis ``S`` to ``CF_second`` channels and halve axis ``M``; the
    result is flattened and passed through a linear classifier.

    Args:
        input: template tensor of shape ``(M, S, C)``; only its shape is used.
        num_cls: number of output classes.
        CF_second: number of output channels (N) of the second convolution.
        dtype: dtype of all parameters.
    """

    def __init__(self, input, num_cls, CF_second, dtype):  # input -> # M x S x C
        super(CNNdecoder, self).__init__()
        self.input = input.permute(1, 2, 0)  # S x C x M
        self.s, self.c, self.m = self.input.shape
        self.n = CF_second
        self.dtype = dtype
        halved_m = self.m // 2
        self.cvd1 = nn.Conv1d(in_channels=self.c, out_channels=1, kernel_size=1, dtype=self.dtype)  # S x M
        self.cvd2 = nn.Conv1d(in_channels=self.s, out_channels=self.n, kernel_size=1, dtype=self.dtype)
        self.cvd3 = nn.Conv1d(in_channels=self.m, out_channels=halved_m, kernel_size=1, dtype=self.dtype)
        self.fc = nn.Linear(halved_m * self.n, num_cls, dtype=self.dtype)
        self.relu = nn.ReLU()

    def forward(self, x):  # x -> M x S x C
        """Decode ``x`` of shape ``(M, S, C)`` into scores of shape ``(1, num_cls)``."""
        features = x.permute(1, 2, 0)  # S x C x M
        # Collapse C to a single channel and drop that singleton axis: S x M
        features = self.relu(self.cvd1(features))[:, 0, :]
        # S -> N channels, then swap axes: N x M -> M x N
        features = self.relu(self.cvd2(features).transpose(0, 1))
        # M -> M/2 channels: M/2 x N
        features = self.relu(self.cvd3(features))
        # Flatten into one row and classify.
        return self.fc(features.reshape(1, features.shape[0] * features.shape[1]))

### EEGformer Module

In [None]:
class EEGformer(nn.Module):
    """End-to-end model: ODCM -> RTM -> STM -> TTM -> CNN decoder -> softmax.

    Args:
        input: template tensor whose first dimension is the number of time
            samples ``T``; only ``input.shape[0]`` is used.
        num_cls: number of output classes.
        input_channels: number of input channels fed to the ODCM.
        kernel_size: temporal kernel length of the three ODCM convolutions.
        num_blocks: number of transformer blocks in each of RTM / STM / TTM.
        num_heads_RTM: attention heads in the regional module.
        num_heads_STM: attention heads in the synchronous module.
        num_heads_TTM: attention heads in the temporal module.
        num_submatrices: averaging factor ``M`` of the temporal module.
        CF_second: output channels of the decoder's second convolution.
        dtype: dtype of all parameters.
    """

    def __init__(self, input, num_cls, input_channels, kernel_size, num_blocks, num_heads_RTM, num_heads_STM, num_heads_TTM, num_submatrices, CF_second, dtype=torch.float32):
        super(EEGformer, self).__init__()
        self.dtype = dtype
        self.ncf = 120  # filters per input channel produced by the ODCM
        self.num_cls = num_cls
        self.input_channels = input_channels
        self.kernel_size = kernel_size
        self.tK = num_blocks
        self.hA_rtm = num_heads_RTM
        self.hA_stm = num_heads_STM
        self.hA_ttm = num_heads_TTM
        self.avgf = num_submatrices
        self.cfs = CF_second

        # Zero tensors that only carry the shape each stage should expect; the
        # sub-modules read `.shape` from them to size their parameters.
        self.outshape1 = torch.zeros(self.input_channels, self.ncf, input.shape[0] - 3 * (self.kernel_size - 1)).to(device) # (S, C, Le) | 3 layer depthwise convolution
        self.outshape2 = torch.zeros(self.outshape1.shape[0], self.outshape1.shape[1] + 1, self.outshape1.shape[2]).to(device) # (S, C+1, Le)
        self.outshape3 = torch.zeros(self.outshape2.shape[1], self.outshape2.shape[0] + 1, self.outshape2.shape[2]).to(device) # (C+1, S+1, Le)
        self.outshape4 = torch.zeros(self.avgf + 1, self.outshape3.shape[1], self.outshape3.shape[0]).to(device) # (M+1, S+1, C+1)

        self.odcm = ODCM(input_channels, self.kernel_size, self.dtype) # (C, T) -> (S x C x Le)
        self.rtm = RTM(self.outshape1, self.tK, self.hA_rtm, self.dtype) # (C x Le x S) -> (S x C x D)
        self.stm = STM(self.outshape2, self.tK, self.hA_stm, self.dtype) # (S x C x D) -> (S x Le x C) -> (C x S x D)
        self.ttm = TTM(self.outshape3, self.avgf, self.tK, self.hA_ttm, self.dtype) # (C x S x D) -> (M x S x C) -> (M x L1)
        self.cnndecoder = CNNdecoder(self.outshape4, self.num_cls, self.cfs, self.dtype) # (M x L1) -> (M/2 x N)

    def forward(self, x):
        """Classify one recording.

        Args:
            x: tensor of shape ``(T, input_channels)``.

        Returns:
            Class probabilities of shape ``(1, num_cls)``.
        """
        x = self.odcm(x.transpose(0, 1))
        # RTM was sized for (input_channels, ncf, Le). Make that layout explicit
        # so a flat (input_channels * ncf, Le) convolution output is accepted
        # too; an already 3-D tensor is left unchanged by this reshape.
        x = x.reshape(self.input_channels, self.ncf, -1)
        x = self.rtm(x)
        x = self.stm(x)
        x = self.ttm(x)
        x = self.cnndecoder(x)

        return torch.softmax(x, dim=1)

    @staticmethod
    def _safe_log(t):
        """``log(t)`` with ``t`` clamped to the smallest positive normal float.

        Keeps ``0 * log(0)`` from producing NaN when a probability is exactly
        0 or 1; values above that tiny threshold are unaffected.
        """
        return torch.log(t.clamp_min(torch.finfo(t.dtype).tiny))

    @staticmethod
    def _cross_entropy(xf, label):
        """Element-wise binary cross entropy between probabilities and labels."""
        return -(label * EEGformer._safe_log(xf) + (1 - label) * EEGformer._safe_log(1 - xf))

    def _decoder_l1(self):
        """L1 norm of the decoder's convolution and classifier weights."""
        dec = self.cnndecoder
        return self.sa(dec.fc.weight) + self.sa(dec.cvd1.weight) + self.sa(dec.cvd2.weight) + self.sa(dec.cvd3.weight)

    def _transformer_l1(self, module):
        """L1 norm of one transformer module's embedding matrix and block weights.

        The MLP, LayerNorm and attention weights live on the individual blocks
        in ``module.tfb``, not on the module itself, so they are collected there.
        """
        total = self.sa(module.weight)
        for tfb in module.tfb:
            total = total + self.sa(tfb.mlp.fc1.weight) + self.sa(tfb.mlp.fc2.weight)
            total = total + self.sa(tfb.lnorm.weight) + self.sa(tfb.lnormz.weight)
            total = total + self.sa(tfb.Wo) + self.sa(tfb.Wqkv)
        return total

    # Cross Entropy Loss (CE)
    # CE - uses one hot encoded label or similar(such as multi class probability label)
    def eegloss(self, xf, label, L1_reg_const):  # CE Loss with L1 regularization
        """Cross entropy plus L1 regularisation over decoder, transformer and ODCM weights.

        Args:
            xf: softmax output of the model.
            label: one-hot (or probability) labels with the same shape as ``xf``.
            L1_reg_const: multiplier of the L1 penalty.
        """
        wt = self._decoder_l1()
        for module in (self.ttm, self.stm, self.rtm):
            wt = wt + self._transformer_l1(module)
        wt = wt + self.sa(self.odcm.cvf1.weight) + self.sa(self.odcm.cvf2.weight) + self.sa(self.odcm.cvf3.weight)

        ls = EEGformer._cross_entropy(xf, label)  # xf = softmax output
        ls = torch.mean(ls) + L1_reg_const * wt

        return ls

    # Apply L1 Regularization to CNN Decoder only
    def eegloss_light(self, xf, label, L1_reg_const):  # takes the weight sum of cnndecoder only
        """Cross entropy plus L1 regularisation over the decoder weights only."""
        ls = EEGformer._cross_entropy(xf, label)
        ls = torch.mean(ls) + L1_reg_const * self._decoder_l1()
        return ls

    def eegloss_wol1(self, xf, label):  # without L1
        """Mean cross entropy without any weight penalty."""
        return torch.mean(EEGformer._cross_entropy(xf, label))

    # BCE - does not need one hot encoding
    def bceloss(self, xf, label):  # BCE loss
        """Binary cross entropy using column 1 of ``xf`` for label 1 and column 0 for label 0."""
        ls = -(label * EEGformer._safe_log(xf[:, 1]) + (1 - label) * EEGformer._safe_log(xf[:, 0]))
        ls = torch.mean(ls)
        return ls

    def bceloss_w(self, xf, label, numpos, numtot):  # Weighted BCE loss
        """Class-weighted binary cross entropy.

        Args:
            xf: probabilities of shape ``(batch, 2)``.
            label: binary labels of shape ``(batch,)``.
            numpos: number of positive samples (must be non-zero).
            numtot: total number of samples (must differ from ``numpos``).
        """
        w0 = numtot / (2 * (numtot - numpos))  # weight of the negative class
        w1 = numtot / (2 * numpos)  # weight of the positive class
        ls = -(w1 * label * EEGformer._safe_log(xf[:, 1]) + w0 * (1 - label) * EEGformer._safe_log(xf[:, 0]))
        ls = torch.mean(ls)
        return ls

    # Sum of absolute values = L1 Norm
    def sa(self, t):
        """Return the sum of absolute values of ``t``."""
        return torch.sum(torch.abs(t))

In [32]:
import os

# Print the CPU count reported by the OS (the message reads "number of workers
# available on my machine").
print("내 컴퓨터에서 사용할 수 있는 worker 수:", os.cpu_count())

내 컴퓨터에서 사용할 수 있는 worker 수: 8
