In [None]:
!pip install facenet_pytorch > /dev/null 2>&1

In [None]:
import time
import zipfile
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import torch
import torchvision
from facenet_pytorch import MTCNN
from matplotlib import pyplot as plt
from PIL import Image, ImageDraw
from sklearn.model_selection import train_test_split
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary
from torchvision import transforms
from torchvision.models import resnet18
from tqdm import tqdm

In [None]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cuda'

Для обучения модели был создан собственный датасет с набором жестов. За основу взята архитектура `ResNet18`.

In [None]:
# Project root plus the two sub-folders used below:
# the unpacked dataset and the saved model checkpoints.
PATH = Path('/content/drive/MyDrive/Colab Notebooks/Torch/HW')
DATA_PATH = PATH.joinpath('custom_dataset')
MODELS_PATH = PATH.joinpath('models')

In [None]:
# Wait until the dataset archive shows up under PATH, then unpack it next to
# the archive. (Presumably the wait covers Drive still mounting/syncing — confirm.)
archive_path = PATH / 'custom_dataset.zip'
while not archive_path.is_file():
    # Sleep between checks: a bare `pass` loop would keep a CPU core fully busy.
    time.sleep(1)
with zipfile.ZipFile(archive_path, 'r') as zip_ref:
    zip_ref.extractall(PATH)

In [None]:
# Map the numeric class id to a readable gesture name.
# Class folders are named "<two-digit id>_<name>", e.g. "03_Ok".
sign_dict = {
    int(class_dir.name[:2]): class_dir.name[3:]
    for class_dir in sorted(DATA_PATH.glob('*'))  # sorted: glob order is arbitrary
    if class_dir.is_dir()  # skip stray files that would make int() fail
}

sign_dict

{0: 'Without',
 1: 'Minus',
 2: 'Hi',
 3: 'Ok',
 4: 'Cool',
 5: 'Fist',
 6: 'Index',
 7: 'Two'}

Названия классов приведены здесь скорее для наглядности; в дальнейшем они могут (и будут) гибко изменяться.

In [None]:
columns = ['Path', 'Img', 'Class', 'Class_int']

# One list per column; every image found on disk appends one entry to each list.
paths_dict = {clm: [] for clm in columns}

for class_dir in sorted(DATA_PATH.glob('*')):
    if not class_dir.is_dir():
        continue  # ignore stray files lying next to the class folders
    class_id = int(class_dir.name[:2])  # folder names look like "03_Ok"
    for img_path in sorted(class_dir.glob('*')):
        # Image.open is lazy and keeps the file handle open. Read the pixels now
        # and release the handle, so thousands of files do not stay open at once.
        with Image.open(img_path) as img:
            img.load()
        paths_dict['Path'].append(img_path)
        paths_dict['Img'].append(img)
        paths_dict['Class'].append(class_dir.name)
        paths_dict['Class_int'].append(class_id)

df_paths = pd.DataFrame(paths_dict)

# Stratified 80/20 split; the fixed seed keeps it identical between runs.
paths_train, paths_test = train_test_split(df_paths,
                                           test_size=0.2,
                                           shuffle=True,
                                           stratify=df_paths['Class_int'],
                                           random_state=42)

In [None]:
# Peek at the first rows of the training split (head() shows 5 rows by default).
paths_train.head()

Unnamed: 0,Path,Img,Class,Class_int
1736,/content/drive/MyDrive/Colab Notebooks/Torch/H...,<PIL.PngImagePlugin.PngImageFile image mode=RG...,01_Minus,1
229,/content/drive/MyDrive/Colab Notebooks/Torch/H...,<PIL.PngImagePlugin.PngImageFile image mode=RG...,00_Without,0
10123,/content/drive/MyDrive/Colab Notebooks/Torch/H...,<PIL.PngImagePlugin.PngImageFile image mode=RG...,07_Two,7
3092,/content/drive/MyDrive/Colab Notebooks/Torch/H...,<PIL.PngImagePlugin.PngImageFile image mode=RG...,02_Hi,2
7185,/content/drive/MyDrive/Colab Notebooks/Torch/H...,<PIL.PngImagePlugin.PngImageFile image mode=RG...,05_Fist,5


In [None]:
# Per-class row counts in the training split ('Class' kept as a regular column).
paths_train.groupby('Class', as_index=False).count()

Unnamed: 0,Class,Path,Img,Class_int
0,00_Without,1228,1228,1228
1,01_Minus,1120,1120,1120
2,02_Hi,1120,1120,1120
3,03_Ok,1120,1120,1120
4,04_Cool,1120,1120,1120
5,05_Fist,1120,1120,1120
6,06_Index,1120,1120,1120
7,07_Two,1120,1120,1120


In [None]:
# Per-class row counts in the held-out split ('Class' becomes the index here).
paths_test.groupby('Class').count()

Unnamed: 0_level_0,Path,Img,Class_int
Class,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
00_Without,307,307,307
01_Minus,280,280,280
02_Hi,280,280,280
03_Ok,280,280,280
04_Cool,280,280,280
05_Fist,280,280,280
06_Index,280,280,280
07_Two,280,280,280


In [None]:
class SignDataset(Dataset):
    """Dataset backed by a DataFrame with one row per sample.

    Args:
        df: table holding the samples; its index is reset so that rows are
            addressed by position 0..len-1.
        transforms: optional callable applied to the stored image. When it is
            omitted, the image is converted with ``transforms.ToTensor()``.
        path_name: name of the column holding the image objects (despite the
            parameter name, the default ``'Img'`` column stores images, not paths).
        class_name: name of the column holding the integer class label.
    """

    def __init__(self,
                 df,
                 transforms=None,
                 path_name='Img',
                 class_name='Class_int',
                 ):
        self.df = df.reset_index(drop=True)
        self._transforms = transforms
        self._path_name = path_name
        self._class_name = class_name

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        Raises:
            IndexError: if ``idx`` is outside the dataset.
        """
        # Positional lookup: an out-of-range idx raises IndexError rather than
        # KeyError, which is what plain `for sample in dataset` needs to stop.
        img = self.df[self._path_name].iloc[idx]
        # cross_entropy expects int64 class indices, so pin the dtype explicitly.
        label = torch.tensor(self.df[self._class_name].iloc[idx], dtype=torch.long)
        if self._transforms is not None:
            img = self._transforms(img)
        else:
            img = transforms.ToTensor()(img)
        return img, label

    def __len__(self):
        """Number of samples (rows of the underlying DataFrame)."""
        return len(self.df)

In [None]:
# TODO: create augmentations
IMG_SIZE = (240, 320)  # (height, width) every image is resized to
BATCH_SIZE = 32

ds_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
])

train_dataset = SignDataset(paths_train, ds_transforms)
valid_dataset = SignDataset(paths_test, ds_transforms)

# Reshuffle the training samples every epoch; without shuffle=True the loader
# would replay the batches in exactly the same order each time.
train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=True,
                          drop_last=True
                          )

# Keep every validation sample: dropping the last partial batch would silently
# leave some held-out images out of the reported metrics.
valid_loader = DataLoader(valid_dataset,
                          batch_size=BATCH_SIZE,
                          shuffle=False,
                          drop_last=False
                          )

Модель построена на основе `ResNet18` с заменой первого свёрточного и выходного полносвязного слоёв:

In [None]:
class MyResNet(nn.Module):
    """ResNet18 backbone whose first convolution and final linear layer are replaced.

    Args:
        out_features: number of logits produced by the final linear layer.
        *args, **kwargs: forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, out_features, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.model = resnet18()
        # Freshly initialised stem convolution for 3-channel input.
        # NOTE(review): these look like the same hyper-parameters as the stock
        # torchvision stem — confirm whether the replacement is still needed.
        self.model.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                     bias=False)
        # Size the new head from the backbone's own head instead of a magic 512.
        self.model.fc = nn.Linear(in_features=self.model.fc.in_features,
                                  out_features=out_features, bias=True)

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)

In [None]:
# Eight output logits — one per gesture class listed in sign_dict.
sign_detection = MyResNet(out_features=8)
sign_detection = sign_detection.to(device)

In [None]:
# Layer-by-layer overview for a single 3x240x320 input image.
summary(sign_detection, input_size=(3, 240, 320))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 120, 160]           9,408
       BatchNorm2d-2         [-1, 64, 120, 160]             128
              ReLU-3         [-1, 64, 120, 160]               0
         MaxPool2d-4           [-1, 64, 60, 80]               0
            Conv2d-5           [-1, 64, 60, 80]          36,864
       BatchNorm2d-6           [-1, 64, 60, 80]             128
              ReLU-7           [-1, 64, 60, 80]               0
            Conv2d-8           [-1, 64, 60, 80]          36,864
       BatchNorm2d-9           [-1, 64, 60, 80]             128
             ReLU-10           [-1, 64, 60, 80]               0
       BasicBlock-11           [-1, 64, 60, 80]               0
           Conv2d-12           [-1, 64, 60, 80]          36,864
      BatchNorm2d-13           [-1, 64, 60, 80]             128
             ReLU-14           [-1, 64,

In [None]:
# Training length and optimiser settings.
epochs = 60
lr = 1e-3
optimizer = torch.optim.Adam(
    params=sign_detection.parameters(),
    lr=lr,
    betas=(0.9, 0.95),
)

In [None]:
# Per-epoch history: lists of per-batch losses and the epoch accuracies.
epoch_losses = []
test_epoch_losses = []
epoch_acc_test = []
epoch_acc_train = []
min_train_loss, min_valid_loss, max_valid_acc = 1000.0, 1000.0, 0.0

# Make sure the checkpoint folder exists before the first save.
MODELS_PATH.mkdir(parents=True, exist_ok=True)

for epoch in tqdm(range(epochs)):

    # ---- training pass ----
    epoch_loss = []
    train_correct, train_seen = 0, 0

    sign_detection.train()
    for data, labels in train_loader:
        data = data.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = sign_detection(data)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        optimizer.step()

        preds = outputs.argmax(dim=1)
        # Plain counters instead of an ever-growing tensor of per-sample results.
        train_correct += (preds == labels).sum().item()
        train_seen += labels.size(0)

        epoch_loss.append(loss.item())

    # ---- validation pass ----
    test_epoch_loss = []
    test_correct, test_seen = 0, 0

    sign_detection.eval()
    # No gradients are needed for evaluation; skipping autograd saves memory and time.
    with torch.no_grad():
        for data, labels in valid_loader:
            data = data.to(device)
            labels = labels.to(device)

            outputs = sign_detection(data)
            loss = F.cross_entropy(outputs, labels)

            preds = outputs.argmax(dim=1)
            test_correct += (preds == labels).sum().item()
            test_seen += labels.size(0)

            test_epoch_loss.append(loss.item())

    # An empty loader yields NaN accuracy rather than a ZeroDivisionError.
    train_acc = train_correct / train_seen if train_seen else float('nan')
    test_acc = test_correct / test_seen if test_seen else float('nan')

    mean_train_loss = np.mean(epoch_loss)
    mean_valid_loss = np.mean(test_epoch_loss)

    # epoch is 0-based, so this reports every second epoch: "Epoch 2", "Epoch 4", ...
    if epoch % 2:
        print(f'Epoch {epoch+1}, loss:', round(mean_train_loss, 4),
              'test loss:', round(mean_valid_loss, 4), end='. ')
        print(f'Train acc is {round(train_acc, 4)}', end='. ')
        print(f'Test acc is {round(test_acc, 4)}')

    min_train_loss = min(min_train_loss, mean_train_loss)
    min_valid_loss = min(min_valid_loss, mean_valid_loss)
    # Checkpoint whenever validation accuracy matches or beats the best so far.
    # NOTE(review): this pickles the whole module; saving state_dict() is more
    # portable, but would change what any code loading these files expects.
    if test_acc >= max_valid_acc:
        max_valid_acc = test_acc
        torch.save(sign_detection, MODELS_PATH / f'sign_d_ep:{epoch+1}_acc:{round(test_acc, 3)}')
    epoch_losses.append(epoch_loss)
    test_epoch_losses.append(test_epoch_loss)
    epoch_acc_train.append(train_acc)
    epoch_acc_test.append(test_acc)


  3%|▎         | 2/60 [03:42<1:47:14, 110.95s/it]

Epoch 2, loss: 0.1534 test loss: 0.4297. Train acc is 0.9501. Test acc is 0.8527


  7%|▋         | 4/60 [07:24<1:43:31, 110.92s/it]

Epoch 4, loss: 0.035 test loss: 0.1867. Train acc is 0.9886. Test acc is 0.9357


 10%|█         | 6/60 [11:06<1:39:48, 110.90s/it]

Epoch 6, loss: 0.018 test loss: 0.146. Train acc is 0.9948. Test acc is 0.9576


 13%|█▎        | 8/60 [14:47<1:35:57, 110.72s/it]

Epoch 8, loss: 0.0137 test loss: 0.024. Train acc is 0.996. Test acc is 0.992


 17%|█▋        | 10/60 [18:28<1:32:09, 110.58s/it]

Epoch 10, loss: 0.0146 test loss: 0.0665. Train acc is 0.9945. Test acc is 0.9879


 20%|██        | 12/60 [22:09<1:28:32, 110.69s/it]

Epoch 12, loss: 0.0068 test loss: 0.0747. Train acc is 0.9978. Test acc is 0.9754


 23%|██▎       | 14/60 [25:52<1:25:01, 110.91s/it]

Epoch 14, loss: 0.0061 test loss: 0.1843. Train acc is 0.9978. Test acc is 0.958


 27%|██▋       | 16/60 [29:34<1:21:24, 111.00s/it]

Epoch 16, loss: 0.0057 test loss: 0.0382. Train acc is 0.9986. Test acc is 0.9933


 30%|███       | 18/60 [33:17<1:17:49, 111.19s/it]

Epoch 18, loss: 0.0042 test loss: 0.0001. Train acc is 0.9987. Test acc is 1.0


 33%|███▎      | 20/60 [37:00<1:14:08, 111.21s/it]

Epoch 20, loss: 0.0054 test loss: 0.0059. Train acc is 0.999. Test acc is 0.9987


 35%|███▌      | 21/60 [39:30<1:13:22, 112.89s/it]


KeyboardInterrupt: ignored

Тестовые метрики выросли намного быстрее, чем предполагалось. Скорее всего, это связано с достаточно однотипным датасетом. Тем не менее модель работает неплохо.