<a href="https://colab.research.google.com/github/nRknpy/lab-work/blob/main/ASL_ViT_finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Sign Language Recognition with ViT
This notebook fine-tunes a pretrained Vision Transformer (ViT) model on a sign language dataset.

# Importing modules

In [3]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.11.1 tokenizers-0.13.2 transformers-4.25.1


## Preparing the dataset
The dataset is ASL Fingerspelling Images ( https://empslocal.ex.ac.uk/people/staff/np331/index.php?section=FingerSpellingDataset ).
The following commands download and extract it.

In [1]:
!wget http://www.cvssp.org/FingerSpellingKinect2011/fingerspelling5.tar.bz2
!tar -jxvf fingerspelling5.tar.bz2

Streaming output truncated to the last 5000 lines.
dataset5/E/u/color_20_0250.png
dataset5/E/u/color_20_0251.png
dataset5/E/u/color_20_0252.png
︙


To build a PyTorch Dataset, the directory structure needs to be rearranged as follows, with one subdirectory per label.

    asl
    ├── a
    │   ├── color_0_0002.png
    │   ├── color_0_0003.png
    │   ├── color_0_0004.png
    │   ︙
    ├── b
    ├── c
    ├── d
    ├── e
    ︙

In [2]:
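import os
import shutil

def prepare_asl_dataset(source, destination="asl"):
    """Copy the color images from source/<person>/<label>/ into destination/<label>/."""
    cnt = 0
    for person in os.listdir(source):
        person_dir = os.path.join(source, person)
        for label in os.listdir(person_dir):
            label_dir = os.path.join(person_dir, label)
            # Create the per-label output directory before copying into it.
            os.makedirs(os.path.join(destination, label), exist_ok=True)
            for image in os.listdir(label_dir):
                # Keep only the color images (file names starting with "c").
                if image.startswith('c'):
                    # copyfile overwrites an existing file of the same name,
                    # so cnt counts copy operations, not distinct files.
                    shutil.copyfile(os.path.join(label_dir, image),
                                    os.path.join(destination, label, image))
                    cnt += 1
    print("image count:", cnt)

prepare_asl_dataset("dataset5")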

image count: 65774
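
As a quick sanity check on the layout described above, the following sketch counts the PNG files under each label directory. The helper name `count_images_per_label` is hypothetical and not part of the original notebook. Because `shutil.copyfile` overwrites files that share a name, the total on disk can be lower than the printed copy count if different people's folders contain identically named images.

```python
import os

def count_images_per_label(root="asl"):
    """Return {label: number of .png files} for each subdirectory of root."""
    counts = {}
    for label in sorted(os.listdir(root)):
        label_dir = os.path.join(root, label)
        if not os.path.isdir(label_dir):
            continue  # skip stray files at the top level
        counts[label] = sum(1 for name in os.listdir(label_dir)
                            if name.endswith(".png"))
    return counts

# Only runs once the restructured directory exists.
if os.path.isdir("asl"):
    counts = count_images_per_label("asl")
    print("labels:", len(counts), "files on disk:", sum(counts.values()))
```

Comparing the "files on disk" total with the count printed by `prepare_asl_dataset` shows whether any images were overwritten during the copy.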


In [4]:
from transformers import ViTFeatureExtractor

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')


In [None]:
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)

normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

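# Training: random crop to the model's input size plus a random horizontal flip.
train_transforms = Compose([RandomResizedCrop(tuple(feature_extractor.size.values())),
                            RandomHorizontalFlip(),
                            ToTensor(),
                            normalize])
# Validation: deterministic preprocessing. The steps after Resize are an assumption,
# chosen to mirror train_transforms and to use the CenterCrop import above.
val_transforms = Compose([Resize(tuple(feature_extractor.size.values())),
                          CenterCrop(tuple(feature_extractor.size.values())),
                          ToTensor(),
                          normalize])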
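
To make the `Normalize` step concrete, here is a minimal pure-Python sketch of the per-channel arithmetic it applies, `(value - mean) / std`. The helper `normalize_channel` is hypothetical, and the mean and standard deviation of 0.5 are assumed values for illustration; the actual statistics should be read from `feature_extractor.image_mean` and `feature_extractor.image_std`.

```python
def normalize_channel(value, mean, std):
    """Mirror torchvision's Normalize for one channel value: (value - mean) / std."""
    return (value - mean) / std

# Assumed statistics for illustration only; check feature_extractor.image_mean
# and feature_extractor.image_std for the real values.
ASSUMED_MEAN = 0.5
ASSUMED_STD = 0.5

# ToTensor scales 8-bit pixel values into [0, 1] before Normalize runs.
for pixel in (0.0, 0.5, 1.0):
    print(pixel, "->", normalize_channel(pixel, ASSUMED_MEAN, ASSUMED_STD))
```

Under these assumed statistics, inputs in [0, 1] are mapped onto [-1, 1].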

In [None]:
import torch
import torchvision
from torchvision.datasets import ImageFolder

all_dataset = ImageFolder(root='asl')
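
`ImageFolder` derives integer class indices from the subdirectory names in sorted order, and exposes the mapping as `all_dataset.class_to_idx`. The sketch below reproduces that mapping with the standard library only, so it can be inspected without loading any images. The helper name `find_class_indices` is hypothetical.

```python
import os

def find_class_indices(root):
    """Map each subdirectory name under root to an integer, in sorted order."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return {name: index for index, name in enumerate(classes)}

# Only runs once the restructured directory exists.
if os.path.isdir("asl"):
    class_to_idx = find_class_indices("asl")
    # Inverse mapping, useful for turning predicted indices back into letters.
    idx_to_class = {index: name for name, index in class_to_idx.items()}
    print(class_to_idx)
```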