In [1]:
import torch

# If a GPU is available, use it
if torch.cuda.is_available():
    device = torch.device("cuda")
    use_cuda = True
    print('Using cuda !')
else:
    device = torch.device("cpu")
    use_cuda = False
    print('GPU not available !')

Using cuda !


In [2]:
import numpy as np
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader

class CustomDataset(Dataset):
    
    def __init__(self, data, label, transform=None):
        self.data = data
        self.data_len = len(self.data)
        self.label_arr = label

    def __getitem__(self, index):
        label = int(self.label_arr[index])
        sample = (self.data[index]).reshape(1, 44, 1126)
        sample = torch.from_numpy(sample)
        return sample, label

    def __len__(self):
        return self.data_len

def EEGdata_loader():

    x_train = np.load(r'.\MIEEG data\x_train.npy')
    x_test = np.load(r'.\MIEEG data\x_test.npy')
    x_valid = np.load(r'.\MIEEG data\x_valid.npy')
    y_train = np.load(r'.\MIEEG data\y_train.npy')
    y_test = np.load(r'.\MIEEG data\y_test.npy')
    y_valid = np.load(r'.\MIEEG data\y_valid.npy')

    train_data = CustomDataset(x_train, y_train)
    valid_data = CustomDataset(x_valid, y_valid)
    test_data = CustomDataset(x_test, y_test)

    train_loader = DataLoader(train_data, batch_size=batch_size, pin_memory=use_cuda, shuffle=True)
    valid_loader = DataLoader(valid_data, batch_size=batch_size_eval, pin_memory=use_cuda)
    test_loader = DataLoader(test_data, batch_size=batch_size_eval, pin_memory=use_cuda)
    
    return train_loader, valid_loader, test_loader

In [3]:
from tcacnet_utils.attention import inference

def train(model_global, model_local, model_top, optimizer, loss_fn_local_top, epoch, only_global_model):
    
    model_global.train()
    model_local.train()
    model_top.train()

    for batch_idx, (inputs, target) in enumerate(train_loader):
        inputs, target = inputs.to(device), target.to(device)
        inputs = inputs.float()
        target = target.long()
        optimizer.zero_grad()
        
        wpser = inputs[:,:,:,-1]    # WPSER corresponding to each channel
        inputs = inputs[:,:,:,0:inputs.shape[3]-1]    # raw EEG signal
        output_merged, hint_loss, channel_loss = inference(inputs, wpser, model_global, model_local, model_top, n_slices, device,
                                                           only_global_model, is_training=True)
        
        loss_local_and_top = loss_fn_local_top(output_merged, target)
        loss_global_model = loss_local_and_top + hint_loss + channel_loss

        for param in model_local.parameters():
            param.requires_grad = False
        for param in model_top.parameters():
            param.requires_grad = False
        loss_global_model.backward(retain_graph=True)
        for param in model_local.parameters():
            param.requires_grad = True
        for param in model_top.parameters():
            param.requires_grad = True
        for param in model_global.parameters():
            param.requires_grad = False
        loss_local_and_top.backward()
        for param in model_global.parameters():
            param.requires_grad = True

        optimizer.step()
    
    if epoch % 10 == 0:
        print('\rTrain Epoch: {}'   
              '  Total_Loss: {:.4f} (CrossEntropy: {:.2f} Hint: {:.2f} Ch: {:.2f})'
              ''.format(epoch, loss_local_and_top.item()+hint_loss.item(), loss_local_and_top.item(), hint_loss.item(), channel_loss.item()),
              end='')
            
    return loss_local_and_top.item()+hint_loss.item()+channel_loss.item(), loss_local_and_top.item()

In [4]:
def test(model_global, model_local, model_top, test_loss_fn_local_top, epoch, loader, only_global_model):
    model_global.eval()
    model_local.eval()
    model_top.eval()

    avg_test_loss, avg_hint_loss, avg_channel_loss = 0, 0, 0
    correct = 0
    test_size = 0

    with torch.no_grad():
        for inputs, target in loader:
            inputs, target = inputs.to(device), target.to(device)
            
            inputs = inputs.float()
            target = target.long()

            wpser = inputs[:,:,:,-1]
            inputs = inputs[:,:,:,0:inputs.shape[3]-1]
            
            output_merged, hint_loss, channel_loss = inference(inputs, wpser, model_global, model_local, model_top, n_slices, device,
                                                               only_global_model, is_training=False)

            test_size += len(inputs)
            avg_test_loss += test_loss_fn_local_top(output_merged, target).item()
            avg_hint_loss += len(inputs) * hint_loss.item()
            avg_channel_loss += len(inputs) * channel_loss.item()
            pred = output_merged.max(1, keepdim=True)[1]
            correct += pred.eq(target.view_as(pred)).sum().item()

    avg_test_loss /= test_size
    avg_hint_loss /= test_size
    avg_channel_loss /= test_size
    accuracy = correct / test_size
    
    if epoch % 10 == 0:
        print('\nTest set: Avg_Total_Loss: {:.4f} (CrossEntropy: {:.2f} Hint: {:.2f} Ch: {:.2f})' 
              '  Accuracy: {}/{} ({:.0f}%)\n'
              .format(avg_test_loss + avg_hint_loss + avg_channel_loss, avg_test_loss, avg_hint_loss, avg_channel_loss,
                      correct, test_size, 100. * accuracy))

    return avg_test_loss+avg_hint_loss+avg_channel_loss, avg_test_loss, accuracy

In [5]:
import torch.nn as nn
import torch.optim as optim
from braindecode.torch_ext.optimizers import AdamW
from tcacnet_utils.network import globalnetwork, localnetwork, topnetwork

n_slices = 1    # number of time slices

n_epochs = 200
loss_fn_local_top = nn.NLLLoss()
test_loss_fn_local_top = nn.NLLLoss(reduction='sum')
learning_rate = 0.0625 * 0.01
batch_size = 16
batch_size_eval = 16

train_loader, valid_loader, test_loader = EEGdata_loader()

In [6]:
model_global = globalnetwork().to(device)
model_local = localnetwork().to(device)
model_top = topnetwork().to(device)

only_global_model = True    # only use global model

if only_global_model:
    optimizer = optim.Adam(list(model_global.parameters())
                           + list(model_top.parameters()), lr=learning_rate)
else:
    optimizer = optim.Adam(list(model_global.parameters())
                           + list(model_local.parameters())
                           + list(model_top.parameters()), lr=learning_rate)

min_cross_entropy = 100000

for ep in range(n_epochs):
    train_total_loss, train_cross_entropy = train(model_global, model_local, model_top, optimizer,
                                                  loss_fn_local_top, ep, only_global_model)
    valid_total_loss, valid_cross_entropy, valid_acc = test(model_global, model_local, model_top,
                                                            test_loss_fn_local_top, ep, valid_loader, only_global_model)
    if valid_cross_entropy < min_cross_entropy:
        min_cross_entropy = valid_cross_entropy
        torch.save(model_global.state_dict(),'model_global_cross_entropy.pth')
        torch.save(model_local.state_dict(),'model_local_cross_entropy.pth')
        torch.save(model_top.state_dict(),'model_top_cross_entropy.pth')

if only_global_model:
    print('\nUse global model:')
else:
    print('\nUse TCACNet:')       

model_global.load_state_dict(torch.load('model_global_cross_entropy.pth'))
model_local.load_state_dict(torch.load('model_local_cross_entropy.pth'))
model_top.load_state_dict(torch.load('model_top_cross_entropy.pth'))
valid_total_loss, valid_cross_entropy, valid_acc = test(model_global, model_local, model_top,
                                                        test_loss_fn_local_top, 0, test_loader, only_global_model)

Train Epoch: 0  Total_Loss: 0.9760 (CrossEntropy: 0.98 Hint: 0.00 Ch: 0.00)
Test set: Avg_Total_Loss: 0.6309 (CrossEntropy: 0.63 Hint: 0.00 Ch: 0.00)  Accuracy: 124/163 (76%)

Train Epoch: 10  Total_Loss: 0.0413 (CrossEntropy: 0.04 Hint: 0.00 Ch: 0.00)
Test set: Avg_Total_Loss: 0.2797 (CrossEntropy: 0.28 Hint: 0.00 Ch: 0.00)  Accuracy: 146/163 (90%)

Train Epoch: 20  Total_Loss: 0.0292 (CrossEntropy: 0.03 Hint: 0.00 Ch: 0.00)
Test set: Avg_Total_Loss: 0.2589 (CrossEntropy: 0.26 Hint: 0.00 Ch: 0.00)  Accuracy: 147/163 (90%)

Train Epoch: 30  Total_Loss: 0.0042 (CrossEntropy: 0.00 Hint: 0.00 Ch: 0.00)
Test set: Avg_Total_Loss: 0.2649 (CrossEntropy: 0.26 Hint: 0.00 Ch: 0.00)  Accuracy: 146/163 (90%)

Train Epoch: 40  Total_Loss: 0.0180 (CrossEntropy: 0.02 Hint: 0.00 Ch: 0.00)
Test set: Avg_Total_Loss: 0.3003 (CrossEntropy: 0.30 Hint: 0.00 Ch: 0.00)  Accuracy: 147/163 (90%)

Train Epoch: 50  Total_Loss: 0.0012 (CrossEntropy: 0.00 Hint: 0.00 Ch: 0.00)
Test set: Avg_Total_Loss: 0.2721 (Cross

In [7]:
model_global = globalnetwork().to(device)
model_local = localnetwork().to(device)
model_top = topnetwork().to(device)

only_global_model = False    # use TCACNet

if only_global_model:
    optimizer = optim.Adam(list(model_global.parameters())
                           + list(model_top.parameters()), lr=learning_rate)
else:
    optimizer = optim.Adam(list(model_global.parameters())
                           + list(model_local.parameters())
                           + list(model_top.parameters()), lr=learning_rate)

min_cross_entropy = 100000
for ep in range(n_epochs):
    train_total_loss, train_cross_entropy = train(model_global, model_local, model_top, optimizer,
                                                  loss_fn_local_top, ep, only_global_model)
    valid_total_loss, valid_cross_entropy, valid_acc = test(model_global, model_local, model_top,
                                                            test_loss_fn_local_top, ep, valid_loader, only_global_model)
    if valid_cross_entropy < min_cross_entropy:
        min_cross_entropy = valid_cross_entropy
        torch.save(model_global.state_dict(),'model_global_cross_entropy.pth')
        torch.save(model_local.state_dict(),'model_local_cross_entropy.pth')
        torch.save(model_top.state_dict(),'model_top_cross_entropy.pth')
        
if only_global_model:
    print('\nUse global model:')
else:
    print('\nUse TCACNet:')

model_global.load_state_dict(torch.load('model_global_cross_entropy.pth'))
model_local.load_state_dict(torch.load('model_local_cross_entropy.pth'))
model_top.load_state_dict(torch.load('model_top_cross_entropy.pth'))
valid_total_loss, valid_cross_entropy, valid_acc = test(model_global, model_local, model_top,
                                                        test_loss_fn_local_top, 0, test_loader, only_global_model)

  idx_x = indexes // feature_w


Train Epoch: 0  Total_Loss: 2.9238 (CrossEntropy: 0.59 Hint: 2.33 Ch: 7.51)
Test set: Avg_Total_Loss: 9.7704 (CrossEntropy: 0.90 Hint: 2.45 Ch: 6.42)  Accuracy: 97/163 (60%)

Train Epoch: 10  Total_Loss: 1.9977 (CrossEntropy: 0.29 Hint: 1.71 Ch: 0.78)
Test set: Avg_Total_Loss: 3.0804 (CrossEntropy: 0.37 Hint: 2.14 Ch: 0.57)  Accuracy: 147/163 (90%)

Train Epoch: 20  Total_Loss: 1.7589 (CrossEntropy: 0.12 Hint: 1.64 Ch: 0.51)
Test set: Avg_Total_Loss: 2.9569 (CrossEntropy: 0.31 Hint: 2.15 Ch: 0.50)  Accuracy: 147/163 (90%)

Train Epoch: 30  Total_Loss: 1.7397 (CrossEntropy: 0.13 Hint: 1.61 Ch: 0.44)
Test set: Avg_Total_Loss: 2.9098 (CrossEntropy: 0.34 Hint: 2.13 Ch: 0.44)  Accuracy: 147/163 (90%)

Train Epoch: 40  Total_Loss: 2.1001 (CrossEntropy: 0.13 Hint: 1.97 Ch: 0.50)
Test set: Avg_Total_Loss: 2.9528 (CrossEntropy: 0.34 Hint: 2.18 Ch: 0.43)  Accuracy: 144/163 (88%)

Train Epoch: 50  Total_Loss: 2.0455 (CrossEntropy: 0.20 Hint: 1.85 Ch: 0.81)
Test set: Avg_Total_Loss: 2.9072 (CrossE