In [1]:
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import os
import cv2
import numpy as np
import math
import mediapipe as mp
from matplotlib import pyplot as plt
import glob
from util.img2bone import HandDetector
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
from PIL import Image
import glob
from tqdm.auto import tqdm
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from loader.dataloader import SkeletonAndEMGData

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# helpers
# Use the first CUDA GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

def pair(t):
    """Return ``t`` unchanged when it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

# classes

class FeedForward(nn.Module):
    """Pre-normalised two-layer MLP.

    The layers run in this order: LayerNorm -> Linear(dim, hidden_dim) -> GELU
    -> Dropout -> Linear(hidden_dim, dim) -> Dropout.

    Args:
        dim: size of the last axis of both the input and the output.
        hidden_dim: width of the intermediate linear layer.
        dropout: probability used by both dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        # Built in application order; the positions inside the Sequential
        # determine the parameter names (net.0, net.1, net.4).
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the MLP to ``x``; the last axis keeps size ``dim``."""
        projected = self.net(x)
        return projected

class Attention(nn.Module):
    """Pre-normalised multi-head scaled dot-product self-attention.

    Args:
        dim: size of the last axis of the input (and of the output).
        heads: number of attention heads.
        dim_head: width of each head; queries, keys and values each have
            ``heads * dim_head`` features in total.
        dropout: probability used on the attention weights and, when an
            output projection exists, after that projection.
    """

    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = heads * dim_head
        # The output projection is skipped only for a single head whose width
        # already equals `dim`.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # One bias-free linear layer produces queries, keys and values together.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Attend over axis 1 of ``x`` (laid out as ``b n dim``) and return a tensor with the same layout."""
        normed = self.norm(x)

        # Split the fused projection into q, k, v and move the head axis forward.
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in self.to_qkv(normed).chunk(3, dim = -1)
        )

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))

        # Weighted sum of the values, then merge the heads back into one feature axis.
        merged = rearrange(torch.matmul(weights, v), 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of ``depth`` residual (Attention, FeedForward) pairs followed by a final LayerNorm.

    Args:
        dim: feature size of the tokens.
        depth: number of (Attention, FeedForward) pairs.
        heads: number of attention heads in every Attention layer.
        dim_head: width of each attention head.
        mlp_dim: hidden width of every FeedForward layer.
        dropout: dropout probability passed to both sub-layers.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        """Run ``x`` through every layer pair with residual connections, then normalise it."""
        for pair_of_layers in self.layers:
            attention, feed_forward = pair_of_layers
            x = attention(x) + x
            x = feed_forward(x) + x

        return self.norm(x)

class ViT(nn.Module):
    """ViT-style classifier over 2-D signals (named "emg" in the arguments).

    The input is expected as ``(batch, emg_height, emg_width)``. Axis 1 is cut
    into ``emg_height // patch_height`` consecutive patches of ``patch_height``
    rows; each patch is flattened to ``patch_height * emg_width`` values and
    linearly embedded into ``dim`` features. A learned class token is
    prepended, learned position embeddings are added, the sequence goes
    through a ``Transformer``, is pooled, and is mapped to ``num_classes``
    logits.

    Args:
        emg_size: ``(emg_height, emg_width)`` tuple, or a single value used
            for both.
        patch_height: number of rows per patch; must divide ``emg_height``.
        num_classes: number of output logits.
        dim: token embedding size.
        depth: number of transformer layers.
        heads: number of attention heads per layer.
        mlp_dim: hidden width of the transformer feed-forward layers.
        pool: ``'cls'`` to classify from the class token, ``'mean'`` to
            average all tokens (class token included).
        dim_head: width of each attention head.
        dropout: dropout probability inside the transformer.
        emb_dropout: dropout probability applied to the embedded tokens.

    Raises:
        ValueError: if ``pool`` is neither ``'cls'`` nor ``'mean'``, or if
            ``patch_height`` does not evenly divide ``emg_height``.
    """

    def __init__(self, *, emg_size, patch_height, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        # Reject unknown pooling modes up front: forward() would otherwise
        # silently treat anything other than 'mean' as 'cls'.
        if pool not in ('cls', 'mean'):
            raise ValueError(
                "pool must be either 'cls' (class token) or 'mean' (mean pooling), got {!r}".format(pool)
            )

        emg_height, emg_width = pair(emg_size)

        # A partial trailing patch cannot be represented; fail here with a clear
        # message instead of later inside the patch rearrange / LayerNorm.
        if emg_height % patch_height != 0:
            raise ValueError(
                "emg height ({}) must be divisible by patch_height ({})".format(emg_height, patch_height)
            )

        num_patches = int(emg_height//patch_height)
        patch_dim = int(emg_width * patch_height )

        self.to_patch_embedding = nn.Sequential(
            # (b, num_patches * p1, c) -> (b, num_patches, p1 * c)
            Rearrange('b (h p1) c -> b h (p1 c)', h = num_patches, c = emg_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        # One position per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for the input ``x``."""
        x = self.to_patch_embedding(x)

        b, n, _ = x.shape

        # Prepend one copy of the learned class token to every sample.
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # 'mean' averages every token; 'cls' keeps only the class token at index 0.
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)
        
        

In [3]:

def train(train_loader,model,criterion,optimizer,device):
    """Run one optimisation pass over ``train_loader``.

    Each batch is unpacked as ``(videos, labels, emgs)``; only ``labels`` and
    ``emgs`` are used. ``emgs`` is moved to ``device`` and cast to float64
    before being fed to the model.

    Args:
        train_loader: iterable of batches that also supports ``len()``.
        model: module to optimise; it is switched to training mode.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: optimiser stepped once per batch.
        device: device the labels and inputs are moved to.

    Returns:
        ``(model, epoch_loss, optimizer)`` where ``epoch_loss`` is the sum of
        the per-batch losses divided by ``len(train_loader)``.
    """
    model.train()
    batch_losses = []

    for _videos, labels, emgs in tqdm(train_loader):
        labels = labels.to(device)
        emgs = emgs.to(device).double()

        # forward
        loss = criterion(model(emgs), labels)
        batch_losses.append(loss.item())

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    epoch_loss = sum(batch_losses) / len(train_loader)
    return model, epoch_loss, optimizer

def validate(valid_loader,model,criterion,device):
    """Compute the mean per-batch loss of ``model`` over ``valid_loader``.

    The model is switched to evaluation mode and left in that mode. Gradient
    tracking is disabled inside this function, so no autograd graph is built
    even when the caller does not wrap the call in ``torch.no_grad()``
    (wrapping it anyway is harmless).

    Each batch is unpacked as ``(videos, labels, emgs)``; only ``labels`` and
    ``emgs`` are used. ``emgs`` is moved to ``device`` and cast to float64.

    Args:
        valid_loader: iterable of batches that also supports ``len()``.
        model: module to evaluate.
        criterion: loss callable invoked as ``criterion(outputs, labels)``.
        device: device the labels and inputs are moved to.

    Returns:
        ``(model, epoch_loss)`` where ``epoch_loss`` is the sum of the
        per-batch losses divided by ``len(valid_loader)``.
    """
    model.eval()
    running_loss = 0

    # Evaluation never needs gradients; skipping them saves memory and time.
    with torch.no_grad():
        for _videos, labels, emgs in tqdm(valid_loader):
            labels = labels.to(device)
            emgs = emgs.to(device).double()

            # forward
            outputs = model(emgs)
            loss = criterion(outputs,labels)
            running_loss += loss.item()

    epoch_loss = running_loss / (len(valid_loader))
    return model,epoch_loss

def get_accuracy(model,data_loader,device):
    """Return the top-1 accuracy of ``model`` over ``data_loader`` as a percentage.

    The model is switched to evaluation mode (and left in it) and run without
    gradient tracking. Each batch is unpacked as ``(videos, labels, emgs)``;
    only ``labels`` and ``emgs`` are used. ``emgs`` is moved to ``device`` and
    cast to float64.

    Args:
        model: module whose outputs have the class scores on axis 1.
        data_loader: iterable of batches.
        device: device the labels and inputs are moved to.

    Returns:
        ``100 * correct / total`` as a float in the range 0-100.

    Raises:
        ZeroDivisionError: if ``data_loader`` yields no samples.
    """
    correct = 0
    total = 0

    with torch.no_grad():
        model.eval()
        for _videos, labels, emgs in data_loader:
            labels = labels.to(device)
            emgs = emgs.to(device).double()

            # forward
            outputs = model(emgs)
            # Take the argmax of the raw scores: softmax is monotonic, so it
            # cannot change the winner, and applying it first can round
            # near-tied scores to equal probabilities and flip the prediction.
            predicted = outputs.argmax(dim=1)
            total += labels.shape[0]
            correct += (predicted == labels).sum().item()
    return correct*100/total

def plot_losses(train_losses,valid_losses):
    """Plot training (blue) and validation (red) loss per epoch on a new figure.

    Args:
        train_losses: sequence of training losses, one value per epoch.
        valid_losses: sequence of validation losses, one value per epoch.
    """
    curves = (
        (np.array(train_losses), "blue", "train_loss"),
        (np.array(valid_losses), "red", "valid_loss"),
    )

    _, axis = plt.subplots(1, 1)
    for values, colour, name in curves:
        axis.plot(values, color=colour, label=name)
    axis.set(title="Loss over epochs", xlabel="Epoch", ylabel="Loss")
    axis.legend()
    
def plot_accuracy(train_acc,valid_acc):
    """Plot training (blue) and validation (red) accuracy per epoch on a new figure.

    Args:
        train_acc: sequence of training accuracies, one value per epoch.
        valid_acc: sequence of validation accuracies, one value per epoch.
    """
    curves = (
        (np.array(train_acc), "blue", "train_acc"),
        (np.array(valid_acc), "red", "val_acc"),
    )

    _, axis = plt.subplots(1, 1)
    for values, colour, name in curves:
        axis.plot(values, color=colour, label=name)
    axis.set(title="Accuracy over epochs", xlabel="Epoch", ylabel="Accuracy")
    axis.legend()

In [4]:
# Batch size shared by all three loaders.
BATCH_SIZE = 128

# One dataset object per split, each read from its own pickle file.
train_set = SkeletonAndEMGData("data/108_new/train.pkl")
val_set = SkeletonAndEMGData("data/108_new/val.pkl")
test_set = SkeletonAndEMGData("data/108_new/test.pkl")

# Reshuffle the training samples every epoch so that batches are not always
# built from the same samples in the same order. The evaluation loaders keep
# the dataset order.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
valid_loader = DataLoader(val_set, batch_size=BATCH_SIZE, drop_last=False)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, drop_last=False)

tensor([33, 10, 10, 37])
tensor([33, 10, 37])
tensor([33, 10])


In [5]:
# Build the classifier in float64 on the selected device: inputs of size
# 132300 x 8, cut into patches of int(44100*0.2) rows, 41 output classes.
model = ViT(
    emg_size=(132300, 8),
    patch_height=int(44100*0.2),
    num_classes=41,
    dim=1024,
    depth=3,
    mlp_dim=2048,
    heads=8,
)
model = model.to(device).double()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
epochs = 15

# Per-epoch history, one entry appended per epoch.
train_losses, valid_losses = [], []
train_accuracy, val_accuracy = [], []

report = "Epoch {} --- Train loss = {} --- Valid loss = {} -- Train set accuracy = {} % Valid set Accuracy = {} %"

for epoch in range(epochs):
    # training
    model, train_loss, optimizer = train(train_loader, model, criterion, optimizer, device)

    # validation
    with torch.no_grad():
        model, valid_loss = validate(valid_loader, model, criterion, device)

    # accuracy on both splits, measured after this epoch's update
    train_acc = get_accuracy(model, train_loader, device)
    val_acc = get_accuracy(model, valid_loader, device)
    print(report.format(epoch+1, train_loss, valid_loss, train_acc, val_acc))

    # save loss values
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

    # save accuracy
    train_accuracy.append(train_acc)
    val_accuracy.append(val_acc)

100%|██████████| 1/1 [00:01<00:00,  1.53s/it]
100%|██████████| 1/1 [00:00<00:00,  8.33it/s]


Epoch 1 --- Train loss = 4.370546994822786 --- Valid loss = 2.1023659858416264 -- Train set accuracy = 50.0 % Valid set Accuracy = 33.333333333333336 %


100%|██████████| 1/1 [00:00<00:00, 17.21it/s]
100%|██████████| 1/1 [00:00<00:00,  8.25it/s]


Epoch 2 --- Train loss = 1.577111110199355 --- Valid loss = 2.9300484084357454 -- Train set accuracy = 50.0 % Valid set Accuracy = 66.66666666666667 %


100%|██████████| 1/1 [00:00<00:00, 16.68it/s]
100%|██████████| 1/1 [00:00<00:00,  7.34it/s]


Epoch 3 --- Train loss = 4.394141396505674 --- Valid loss = 0.26233570632683406 -- Train set accuracy = 50.0 % Valid set Accuracy = 66.66666666666667 %


100%|██████████| 1/1 [00:00<00:00, 18.80it/s]
100%|██████████| 1/1 [00:00<00:00,  7.90it/s]


Epoch 4 --- Train loss = 0.3918963968410368 --- Valid loss = 0.0018763413061937443 -- Train set accuracy = 100.0 % Valid set Accuracy = 100.0 %


100%|██████████| 1/1 [00:00<00:00, 17.54it/s]
100%|██████████| 1/1 [00:00<00:00,  7.89it/s]


Epoch 5 --- Train loss = 0.0014668220366496193 --- Valid loss = 0.0020823634556857377 -- Train set accuracy = 100.0 % Valid set Accuracy = 100.0 %


100%|██████████| 1/1 [00:00<00:00, 13.63it/s]
100%|██████████| 1/1 [00:00<00:00,  7.24it/s]


Epoch 6 --- Train loss = 0.0015687185953429353 --- Valid loss = 0.0030232884359435845 -- Train set accuracy = 100.0 % Valid set Accuracy = 100.0 %


100%|██████████| 1/1 [00:00<00:00, 16.28it/s]
100%|██████████| 1/1 [00:00<00:00,  7.24it/s]


Epoch 7 --- Train loss = 0.002272392413744757 --- Valid loss = 0.004626578766714921 -- Train set accuracy = 100.0 % Valid set Accuracy = 100.0 %


100%|██████████| 1/1 [00:00<00:00, 15.79it/s]
100%|██████████| 1/1 [00:00<00:00,  7.26it/s]


Epoch 8 --- Train loss = 0.003478383144141635 --- Valid loss = 0.0066280052350626605 -- Train set accuracy = 100.0 % Valid set Accuracy = 100.0 %


100%|██████████| 1/1 [00:00<00:00, 16.43it/s]
100%|██████████| 1/1 [00:00<00:00,  7.48it/s]


Epoch 9 --- Train loss = 0.0049914130371106185 --- Valid loss = 0.008336941869298778 -- Train set accuracy = 100.0 % Valid set Accuracy = 100.0 %


100%|██████████| 1/1 [00:00<00:00, 11.83it/s]
100%|██████████| 1/1 [00:00<00:00,  6.09it/s]


Epoch 10 --- Train loss = 0.006303549866726076 --- Valid loss = 0.008907041366969104 -- Train set accuracy = 100.0 % Valid set Accuracy = 100.0 %


100%|██████████| 1/1 [00:00<00:00, 13.61it/s]
100%|██████████| 1/1 [00:00<00:00,  5.54it/s]


Epoch 11 --- Train loss = 0.006794578365188672 --- Valid loss = 0.008091704148737354 -- Train set accuracy = 100.0 % Valid set Accuracy = 100.0 %


100%|██████████| 1/1 [00:00<00:00, 18.32it/s]
100%|██████████| 1/1 [00:00<00:00,  7.42it/s]


Epoch 12 --- Train loss = 0.006290204073393884 --- Valid loss = 0.006468603682387348 -- Train set accuracy = 100.0 % Valid set Accuracy = 100.0 %


100%|██████████| 1/1 [00:00<00:00, 16.52it/s]
100%|██████████| 1/1 [00:00<00:00,  8.57it/s]


Epoch 13 --- Train loss = 0.005217533092361731 --- Valid loss = 0.004807469318540486 -- Train set accuracy = 100.0 % Valid set Accuracy = 100.0 %


100%|██████████| 1/1 [00:00<00:00, 15.48it/s]
100%|██████████| 1/1 [00:00<00:00,  6.25it/s]


Epoch 14 --- Train loss = 0.0041242142466768805 --- Valid loss = 0.003516410523458829 -- Train set accuracy = 100.0 % Valid set Accuracy = 100.0 %


100%|██████████| 1/1 [00:00<00:00, 17.42it/s]
100%|██████████| 1/1 [00:00<00:00,  6.21it/s]


Epoch 15 --- Train loss = 0.0032733527502023826 --- Valid loss = 0.0026229147595001508 -- Train set accuracy = 100.0 % Valid set Accuracy = 100.0 %


In [6]:
# Accuracy (in percent) of the trained model on the test loader; shown as the cell output.
get_accuracy(model=model, data_loader=test_loader, device=device)

100.0