In [1]:
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Set the random seed
torch.manual_seed(0)

'''
assume that datasets are given, after preprocessing
'''
# Set the random seed
torch.manual_seed(0)

train_dataset = []
for i in range(10):
    if i <=5 :
        train_dataset += list(zip(torch.stack([torch.normal(0, 1, (22, 128, 8)) for _ in range(50)]), torch.zeros(50, dtype=torch.int)))
    else:
        train_dataset += list(zip(torch.stack([torch.normal(5, 1, (22, 128, 8)) for _ in range(50)]), torch.ones(50, dtype=torch.int)))

test_dataset = []
for i in range(10):
    if i <=5 :
        test_dataset += list(zip(torch.stack([torch.normal(0, 1, (22, 128, 8)) for _ in range(50)]), torch.zeros(50, dtype=torch.int)))
    else:
        test_dataset += list(zip(torch.stack([torch.normal(5, 1, (22, 128, 8)) for _ in range(50)]), torch.ones(50, dtype=torch.int)))


In [83]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset

class CustomEEGDataset(Dataset):
    def __init__(self, dataset, transform=None, target_transform=None):
        self.raw = dataset
        self.EEG = torch.stack([t[0] for t in self.raw])
        self.labels = torch.stack([t[1] for t in self.raw])
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image = self.EEG[idx]
        label = self.labels[idx]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

In [80]:
from torch.utils.data import DataLoader, random_split, Subset
generator = torch.Generator().manual_seed(42)
chb01 = torch.load('/home/dhyun/project/FYP/chb01.pt')
train_dataset, test_dataset = random_split(chb01, [int(len(chb01)*0.8), len(chb01) - int(len(chb01)*0.8)], generator=generator)


In [117]:
train_dataset = CustomEEGDataset(train_dataset)
test_dataset = CustomEEGDataset(test_dataset)

def get_dataloaders(dataset, batch_size=50):
    class_inds = [torch.where(dataset.labels == class_idx)[0]
                for class_idx in torch.unique(dataset.labels)]

    dataloaders = [
        DataLoader(
            dataset=Subset(dataset, inds),
            batch_size=batch_size,
            shuffle=True,
            drop_last=False)
        for inds in class_inds]
    
    return dataloaders

train_dataloaders = get_dataloaders(train_dataset)
itrs = list(map(iter, train_dataloaders))
itrs

In [120]:
a = np.random.choice(iterators)


<torch.utils.data.dataloader._SingleProcessDataLoaderIter at 0x7f6d31a4a2d0>

In [121]:
len(a)

2329

In [122]:
# Set the random seed
from tqdm.notebook import tqdm
import numpy as np
torch.manual_seed(0)

class SciCNN(nn.Module):

    def __init__(self):
        super(SciCNN, self).__init__()        

        self.inception1 = Inception(8, 8, 16, 8, 8)
        self.maxpool1 = nn.MaxPool2d((1, 4), stride=(1, 4), ceil_mode=True)
        self.inception2 = Inception(16, 16, 8, 16, 4)
        self.maxpool2 = nn.MaxPool2d((1, 4), stride=(1, 4), ceil_mode=True)
        self.inception3 = Inception(32, 32, 4, 32, 2)
        self.maxpool3 = nn.MaxPool2d((1, 8), stride=(1, 8), ceil_mode=True)
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(22*64, 16)
        self.npc = NPC()
        
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        x = x.permute(0, 3, 1, 2)
        x = self.inception1(x)
        x = self.maxpool1(x)
        x = self.inception2(x)
        x = self.maxpool2(x)
        x = self.inception3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.dropout(x)
        x = self.fc(x)
        return x

class Inception(nn.Module):
    def __init__(self, in_channels, ch1, ch1_kernel, ch2, ch2_kernel):
        super(Inception, self).__init__()
        self.branch1 = BasicConv1d(in_channels, ch1, kernel=(1, ch1_kernel), padding=(0, (ch1_kernel-1)//2))
        self.branch2 = BasicConv1d(in_channels, ch2, kernel=(1, ch2_kernel), padding=(0, (ch2_kernel-1)//2))
        self.se = SEModule()

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        concat = torch.cat([branch1, branch2], dim=1)
        return self.se(concat)
        
class SEModule(nn.Module):
    def __init__(self, channels=22, reduction=4):
        super(SEModule, self).__init__()
        self.globalAvgPool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        out = torch.squeeze(self.globalAvgPool(x))
        out = self.fc(out).view(x.size(0), x.size(1), 1, 1)
        out = x * out.expand_as(x)
        return out.permute(0, 3, 1, 2)
    
class BasicConv1d(nn.Module): 
    def __init__(self, in_channels, out_channels, kernel, padding):
        super(BasicConv1d, self).__init__()
        self.conv = nn.Sequential(
                            nn.Conv2d(in_channels, out_channels, kernel_size=kernel, padding=padding, bias=True),
                            nn.BatchNorm2d(out_channels),
                            nn.ReLU()
                            )
    def forward(self, x):
        return self.conv(x)
    
class NPC(nn.Module):
    def __init__(self, num_clusters=64):
        super(NPC, self).__init__()
        # 256 predefined positions of NPC clusters
        self.position = nn.Parameter(torch.from_numpy(np.random.uniform(0, 1, (num_clusters, 16, 1))).to(torch.float32), requires_grad=True)
        self.label = nn.Parameter(2*torch.ones(num_clusters), requires_grad=False)

model = SciCNN().to(device)

def npc_training_loss(output, label, model):
    # output: (batch_size, 16, 1)
    mean_output = torch.mean(output, dim=0)
    distances = torch.norm(mean_output.view(1, -1, 1) - model.npc.position.data, dim=1).squeeze()
    closest_position_index = torch.argmin(distances)
    closest_position = model.npc.position[closest_position_index]
    model.npc.label[closest_position_index] = torch.max(label)
    loss = torch.norm(mean_output - closest_position)
    return loss


model = SciCNN().to(device)
loss_function = npc_training_loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-1, weight_decay=1e-4)

model.train()

train_dataloaders = get_dataloaders(train_dataset)
train_loss_list = []
import time
start = time.time()
for epoch in tqdm(range(50)):
    loss=0
    iterators = list(map(iter, train_dataloaders))
    total_length = len(iterators[0]) + len(iterators[1])
    while iterators:
        try:
            iterator = np.random.choice(iterators)
            images, labels = next(iterator)
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            train_loss = loss_function(model(images), labels, model)
            loss += train_loss.item()
            train_loss.backward()
            optimizer.step()
        except StopIteration:
            iterators.remove(iterator)
    train_loss_list.append(loss/total_length)
    print ("Epoch [{}] Loss: {:.4f}".format(epoch+1, loss/total_length))

end = time.time()
print("Time ellapsed in training is: {}".format(end - start))

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch [1] Loss: 0.7275
Epoch [2] Loss: 0.3151
Epoch [3] Loss: 0.3067


In [None]:
def npc_test_loss(output, model):
    mean_output = torch.mean(output, dim=0)
    distances = torch.norm(mean_output.view(1, -1, 1) - model.npc.position.data, dim=1).squeeze()
    closest_position_index = torch.argmin(distances)
    closest_position = model.npc.position[closest_position_index]
    closest_label = model.npc.label[closest_position_index]
    print("closest_label: ", closest_label, "index: ", closest_position_index)
    return torch.norm(mean_output - closest_position), closest_label

model.eval()
test_loss, correct, total = 0, 0, 0

loss_function = npc_test_loss
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=10, shuffle=False)
with torch.no_grad():  #using context manager
    for images, labels in test_loader :
        images, labels = images.to(device), labels.to(device)
        output = model(images)
        test_loss += loss_function(output, model)[0].item()
        pred = loss_function(output, model)[1].item()
        correct += (pred==labels).sum().item()
        total += labels.size(0)

print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss /total, correct, total,
        100. * correct / total))

In [None]:
print(model.npc.label)
print(model.npc.position)
for i in range(64):
    if model.npc.label[i].item() != 2:
        print(model.npc.label[i].item(), model.npc.position[i].data)

In [None]:
from sklearn.manifold import TSNE
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Convert the model.npc.position to numpy array
positions = model.npc.position.detach().cpu().numpy().squeeze()

# Convert the model.npc.label to numpy array
labels = model.npc.label.detach().cpu().numpy()

# Apply t-SNE to reduce the dimensionality
tsne = TSNE(n_components=2, random_state=0)
tsne_positions = tsne.fit_transform(positions)

# Create a DataFrame with the t-SNE positions and labels
df_tsne = pd.DataFrame({'x': tsne_positions[:, 0], 'y': tsne_positions[:, 1], 'label': labels})

# Plot the t-SNE map with hue as the label
sns.scatterplot(data=df_tsne, x='x', y='y', hue='label', palette='Set2')
plt.xlabel('t-SNE 1')
plt.ylabel('t-SNE 2')
plt.title('t-SNE Map')
plt.show()

In [None]:
model(torch.normal(0, 1, (1, 22, 128, 8)).to(device))

In [None]:
positions.squeeze().shape

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
loss_plot = train_loss_list[:20]
sns.lineplot(x=range(len(loss_plot)), y=loss_plot)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')


In [None]:
a = torch.tensor(1)
a.shape