In [1]:
import torch
from torch import nn
from vit_pytorch import ViT

torch.set_default_tensor_type(torch.DoubleTensor)



model = ViT(
    image_size = 512,    # 图像大小
    patch_size = 32,     # patch大小（分块的大小）
    num_classes = 4,  # imagenet数据集1000分类
    dim = 1024,          # position embedding的维度
    depth = 6,           # encoder和decoder中block层数是6
    heads = 16,          # multi-head中head的数量为16
    mlp_dim = 8,
    dropout = 0.1,       # 
    emb_dropout = 0.1
)

model = model.cuda()

# img = torch.randn(1, 3, 256, 1)

# preds = model(img) # (1, 1000)
                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                           
print(model)  # (16, 1000)


ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=1)
    (1): Linear(in_features=96, out_features=1024, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0): ModuleList(
        (0): PreNorm(
          (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fn): Attention(
            (attend): Softmax(dim=-1)
            (dropout): Dropout(p=0.1, inplace=False)
            (to_qkv): Linear(in_features=1024, out_features=3072, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=1024, out_features=1024, bias=True)
              (1): Dropout(p=0.1, inplace=False)
            )
          )
        )
        (1): PreNorm(
          (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fn): FeedForward(
            (net): Sequential(
              (0): Linear(in_features=1024, out_features=8

In [2]:
import sys
sys.path.append('..')
from data_extractor import *

In [3]:
Extractor('A01T')
Extractor('A02T')

<data_extractor.Extractor at 0x7ff2a20f94c0>

In [4]:
raw_data = load_eeg('A01T')
labels, train_data = data_normalization(raw_data)
raw_data_eva = load_eeg('A02T')
labels_eva, eva_data = data_normalization(raw_data_eva)

  return torch.tensor(labels), torch.tensor(data)


In [5]:
print('shape of labels is:', labels.shape, 'shape of data is:', train_data.shape)
print('shape of labels is:', labels_eva.shape, 'shape of data is:', eva_data.shape)

shape of labels is: torch.Size([288, 4]) shape of data is: torch.Size([288, 3, 512, 1])
shape of labels is: torch.Size([288, 4]) shape of data is: torch.Size([288, 3, 512, 1])


In [6]:
pres = model(train_data.cuda())

In [7]:
import torch.nn.functional as F
from torch import optim
from tqdm import tqdm
import time
from torch.utils.data import Dataset, DataLoader

In [8]:
class MyDataset(Dataset):
    # 构造函数
    def __init__(self, data_tensor, target_tensor):
        self.data_tensor = data_tensor
        self.target_tensor = target_tensor
    # 返回数据集大小
    def __len__(self):
        return self.data_tensor.size(0)
    # 返回索引的数据与标签
    def __getitem__(self, index):
        return self.data_tensor[index], self.target_tensor[index]
    
dataset = MyDataset(train_data, labels)
BATCH_SIZE = 32
data_loader = DataLoader(dataset, batch_size = BATCH_SIZE, shuffle = True)

dataset_eva = MyDataset(eva_data, labels_eva)
data_loader_eva = DataLoader(dataset_eva, batch_size = BATCH_SIZE, shuffle = True)

In [9]:
LR = 0.001
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), LR)

for epoch in range(100):
    model.train()
    running_loss = 0.0
    running_len = 0.
    acc_num = 0
    with tqdm(data_loader, unit = 'batch', ncols = 0, total = len(data_loader)) as tepoch:
        for data, target in tepoch:
            tepoch.set_description(f"Epoch(Training) [{epoch + 1}/{100}]")
            data = data.cuda()
            target = target.cuda()
            optimizer.zero_grad()
            pres = model(data)
            loss = criterion(pres, target)
            loss = loss.cuda()
            
            running_loss += loss.item()
            running_len += len(pres)
            
            loss.backward()
            optimizer.step()

            predicted = (pres == pres.max(dim=1, keepdim=True)[0]).to(dtype=torch.float32)
            
            tepoch.set_postfix(loss = running_loss)
            
    model.eval()
    with tqdm(data_loader_eva, unit = 'batch', ncols = 0, total = len(data_loader_eva)) as tepoch:
        for data, target in tepoch:
            tepoch.set_description(f"Epoch(Evaluation) [{epoch + 1}/{100}]")
            data = data.cuda()
            target = target.cuda()
            pres = model(data)

            predicted = (pres == pres.max(dim=1, keepdim=True)[0]).to(dtype=torch.float32)
            
            acc_num = acc_num + torch.sum(torch.sum(predicted == target, dim = 1) / 4).item()
            
            accuracy = acc_num / running_len
            
            tepoch.set_postfix(acc = accuracy)

Epoch(Training) [1/100]: 100% 9/9 [00:01<00:00,  5.60batch/s, loss=2.73]
Epoch(Evaluation) [1/100]: 100% 9/9 [00:00<00:00, 13.16batch/s, acc=1]    
Epoch(Training) [2/100]: 100% 9/9 [00:01<00:00,  5.69batch/s, loss=0.528]
Epoch(Evaluation) [2/100]: 100% 9/9 [00:00<00:00, 13.03batch/s, acc=1]    
Epoch(Training) [3/100]: 100% 9/9 [00:01<00:00,  5.66batch/s, loss=0.3]   
Epoch(Evaluation) [3/100]: 100% 9/9 [00:00<00:00, 13.02batch/s, acc=1]    
Epoch(Training) [4/100]: 100% 9/9 [00:01<00:00,  5.68batch/s, loss=0.213] 
Epoch(Evaluation) [4/100]: 100% 9/9 [00:00<00:00, 12.94batch/s, acc=1]    
Epoch(Training) [5/100]: 100% 9/9 [00:01<00:00,  5.66batch/s, loss=0.162] 
Epoch(Evaluation) [5/100]: 100% 9/9 [00:00<00:00, 13.10batch/s, acc=1]    
Epoch(Training) [6/100]: 100% 9/9 [00:01<00:00,  5.67batch/s, loss=0.138] 
Epoch(Evaluation) [6/100]: 100% 9/9 [00:00<00:00, 13.14batch/s, acc=1]    
Epoch(Training) [7/100]: 100% 9/9 [00:01<00:00,  5.66batch/s, loss=0.114] 
Epoch(Evaluation) [7/100]: 1

Epoch(Evaluation) [54/100]: 100% 9/9 [00:00<00:00, 12.90batch/s, acc=1]    
Epoch(Training) [55/100]: 100% 9/9 [00:01<00:00,  5.59batch/s, loss=0.0136] 
Epoch(Evaluation) [55/100]: 100% 9/9 [00:00<00:00, 12.91batch/s, acc=1]    
Epoch(Training) [56/100]: 100% 9/9 [00:01<00:00,  5.59batch/s, loss=0.0133] 
Epoch(Evaluation) [56/100]: 100% 9/9 [00:00<00:00, 12.90batch/s, acc=1]    
Epoch(Training) [57/100]: 100% 9/9 [00:01<00:00,  5.60batch/s, loss=0.0132] 
Epoch(Evaluation) [57/100]: 100% 9/9 [00:00<00:00, 12.89batch/s, acc=1]    
Epoch(Training) [58/100]: 100% 9/9 [00:01<00:00,  5.60batch/s, loss=0.0133] 
Epoch(Evaluation) [58/100]: 100% 9/9 [00:00<00:00, 12.85batch/s, acc=1]    
Epoch(Training) [59/100]: 100% 9/9 [00:01<00:00,  5.59batch/s, loss=0.0128] 
Epoch(Evaluation) [59/100]: 100% 9/9 [00:00<00:00, 12.95batch/s, acc=1]    
Epoch(Training) [60/100]: 100% 9/9 [00:01<00:00,  5.61batch/s, loss=0.0125] 
Epoch(Evaluation) [60/100]: 100% 9/9 [00:00<00:00, 12.78batch/s, acc=1]    
Epoch(