<a href="https://colab.research.google.com/github/nRknpy/lab-work/blob/main/asl_vit_finetuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Sign Language Recognition with ViT
We fine-tune a pretrained Vision Transformer (ViT) model on a sign language dataset.

# Installing the modules

In [None]:
!pip install transformers datasets

## Preparing the dataset
### Download
We use the ASL Fingerspelling Images dataset ( https://empslocal.ex.ac.uk/people/staff/np331/index.php?section=FingerSpellingDataset ).
The following commands download and extract it.

In [None]:
!wget http://www.cvssp.org/FingerSpellingKinect2011/fingerspelling5.tar.bz2
!tar -jxvf fingerspelling5.tar.bz2

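### Creating the Dataset and collate_fn
The dataset contains both RGB images and depth images; here we use only the RGB images.

To build a Dataset with torchvision's ImageFolder, we create a directory structure like the following.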

    asl
    ├── a
        ├── Acolor_0_0002.png
        ├── Acolor_0_0003.png
        ├── Acolor_0_0004.png
        ︙
    ├── b
    ├── c
    ├── d
    ├── e
    ︙

The following function creates the directory for the Dataset.

In [None]:
import os
import shutil

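def prepare_asl_dataset(source, destination="asl"):
    """Copy the RGB images in source/<person>/<label>/ to destination/<label>/."""
    cnt = 0
    for person in os.listdir(source):
        for label in os.listdir(os.path.join(source, person)):
            label_dir = os.path.join(source, person, label)
            for image in os.listdir(label_dir):
                # Keep only the RGB images, whose file names start with 'c' ("color_...").
                if image.startswith('c'):
                    os.makedirs(os.path.join(destination, label), exist_ok=True)
                    # Prefix the person ID so files from different signers do not collide.
                    shutil.copyfile(os.path.join(label_dir, image),
                                    os.path.join(destination, label, person + image))
                    cnt += 1
    print("image count:", cnt)

destination_dir = 'asl'
prepare_asl_dataset("dataset5", destination_dir)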
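
As a quick sanity check after the copy, the number of images per class can be tallied. The sketch below is an illustration only: `count_images_per_class` is a hypothetical helper that is not part of the original notebook, and the demo runs on a throwaway temporary directory with made-up files rather than on the real `asl` folder.

```python
import os
import tempfile

def count_images_per_class(root):
    """Return {class_name: number_of_files} for an ImageFolder-style directory."""
    counts = {}
    for label in sorted(os.listdir(root)):
        label_dir = os.path.join(root, label)
        if os.path.isdir(label_dir):
            counts[label] = len(os.listdir(label_dir))
    return counts

# Demo on a temporary directory that mimics the layout above (hypothetical files).
with tempfile.TemporaryDirectory() as tmp:
    for label, n_files in [("a", 3), ("b", 2)]:
        os.makedirs(os.path.join(tmp, label))
        for i in range(n_files):
            open(os.path.join(tmp, label, f"Acolor_0_{i:04d}.png"), "w").close()
    demo_counts = count_images_per_class(tmp)

print(demo_counts)
```

On the real data, `count_images_per_class('asl')` would be expected to list one entry per letter directory.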

In [1]:
import torch
torch.cuda.is_available()

True


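We preprocess our own data using the same image preprocessing settings that were used when the model was pretrained. The preprocessing class can be loaded from Hugging Face.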

In [2]:
from transformers import ViTFeatureExtractor

feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch16-224-in21k')



We create a Dataset from the directory built above.

In [3]:
import torch
import torchvision
from torchvision.datasets import ImageFolder
from torchvision import transforms

all_dataset = ImageFolder(root='asl')

label2id = all_dataset.class_to_idx
id2label = {idx: label for label, idx in label2id.items()}
label2id

{'a': 0,
 'b': 1,
 'c': 2,
 'd': 3,
 'e': 4,
 'f': 5,
 'g': 6,
 'h': 7,
 'i': 8,
 'k': 9,
 'l': 10,
 'm': 11,
 'n': 12,
 'o': 13,
 'p': 14,
 'q': 15,
 'r': 16,
 's': 17,
 't': 18,
 'u': 19,
 'v': 20,
 'w': 21,
 'x': 22,
 'y': 23}
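
The mapping above follows from how ImageFolder assigns indices: it sorts the class directory names and numbers them from 0. The torch-free sketch below reproduces that rule for the 24 class names shown in the output; the names `label2id_sketch` and `id2label_sketch` are illustrative and do not appear in the notebook.

```python
import string

# The 24 class directories: every lowercase letter except 'j' and 'z',
# matching the keys printed above.
class_names = [c for c in string.ascii_lowercase if c not in ("j", "z")]

# Sort the class names and number them from 0, as ImageFolder does.
label2id_sketch = {name: idx for idx, name in enumerate(sorted(class_names))}

# Invert the mapping to go from index back to class name.
id2label_sketch = {idx: name for name, idx in label2id_sketch.items()}

print(len(label2id_sketch), label2id_sketch["k"], id2label_sketch[23])
```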

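### Defining the preprocessing
We apply the following preprocessing to the data, using torchvision transforms.

1.   Resize to 224x224 (a random resized crop for training; a resize followed by a center crop for validation and test)
2.   Normalize the image with the mean and standard deviation of the pretraining data

For the training set, we additionally apply a random horizontal flip (so that the model learns both left and right hands).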
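
The normalization step maps each channel value $x \in [0, 1]$ to $(x - \mathrm{mean}) / \mathrm{std}$. The sketch below works through that arithmetic for a single pixel in plain Python. The per-channel mean and std of 0.5 are an assumption about this checkpoint's preprocessing config; the notebook itself reads the actual values from `feature_extractor.image_mean` and `feature_extractor.image_std`. `normalize_pixel` is a hypothetical helper.

```python
# Assumed per-channel statistics (the notebook takes them from feature_extractor).
ASSUMED_MEAN = [0.5, 0.5, 0.5]
ASSUMED_STD = [0.5, 0.5, 0.5]

def normalize_pixel(rgb_uint8, mean=ASSUMED_MEAN, std=ASSUMED_STD):
    """Mimic ToTensor (scale to [0, 1]) followed by Normalize for one RGB pixel."""
    scaled = [v / 255.0 for v in rgb_uint8]
    return [(v - m) / s for v, m, s in zip(scaled, mean, std)]

black = normalize_pixel([0, 0, 0])
white = normalize_pixel([255, 255, 255])
print(black, white)
```

Under these assumed statistics, pixel values end up in the range [-1, 1].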

In [4]:
from torchvision.transforms import (CenterCrop, 
                                    Compose, 
                                    Normalize, 
                                    RandomHorizontalFlip,
                                    RandomResizedCrop, 
                                    Resize, 
                                    ToTensor)

normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
_train_transforms = Compose(
        [
            RandomResizedCrop(tuple(feature_extractor.size.values())),
            RandomHorizontalFlip(),
            ToTensor(),
            normalize,
        ]
    )

_val_transforms = Compose(
        [
            Resize(tuple(feature_extractor.size.values())),
            CenterCrop(tuple(feature_extractor.size.values())),
            ToTensor(),
            normalize,
        ]
    )

In [5]:
from torch.utils.data import Dataset

# Splitting the dataset into train/val/test with random_split does not let us assign
# a different transform to each split, so we define our own wrapper Dataset and attach the transform afterwards.
class SetTransform(Dataset):
  def __init__(self, dataset, transform=None):
    self.dataset = dataset
    self.transform = transform
  
  def __getitem__(self, idx):
    img, label = self.dataset[idx]
    if self.transform:
      img = self.transform(img)
    return img, label
  
  def __len__(self):
    return len(self.dataset)

We split the dataset into train, validation, and test sets, and assign a transform to each.

Here we use 1000 samples each for validation and test, and the rest for training.

In [6]:
val_size = 1000
test_size = 1000
train_size = len(all_dataset) - val_size - test_size

test_dataset, trainval_dataset = torch.utils.data.random_split(all_dataset, [test_size, train_size + val_size])
train_dataset, val_dataset = torch.utils.data.random_split(trainval_dataset, [train_size, val_size])

train_dataset = SetTransform(train_dataset, _train_transforms)
val_dataset = SetTransform(val_dataset, _val_transforms)
test_dataset = SetTransform(test_dataset, _val_transforms)

print('train:', len(train_dataset))
print('validation:', len(val_dataset))
print('test:', len(test_dataset))

train: 63774
validation: 1000
test: 1000
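
`random_split` relies on PyTorch's global random state unless a generator is passed, so the subsets change from run to run when no seed is set. The torch-free sketch below shows the same kind of train/val/test split on plain indices with a fixed seed. The seed value and the helper name `split_indices` are illustrative assumptions, and the total is derived from the sizes printed above.

```python
import random

def split_indices(n_total, val_size, test_size, seed=0):
    """Shuffle indices once with a fixed seed and cut them into train/val/test."""
    indices = list(range(n_total))
    random.Random(seed).shuffle(indices)
    test = indices[:test_size]
    val = indices[test_size:test_size + val_size]
    train = indices[test_size + val_size:]
    return train, val, test

n_total = 63774 + 1000 + 1000  # train + validation + test sizes printed above
train_idx, val_idx, test_idx = split_indices(n_total, val_size=1000, test_size=1000)
print(len(train_idx), len(val_idx), len(test_idx))
```

Because the shuffle is seeded, calling `split_indices` again with the same arguments returns the same three index lists.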


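We create a collate_fn that turns a batch into a dictionary with the keys `pixel_values` and `labels`. This is because the Trainer unpacks the batch as keyword arguments when passing it to the model.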

In [7]:
def collate_fn(examples):
    imgs, labels = zip(*examples)
    pixel_values = torch.stack(imgs)
    labels = torch.tensor(labels)
    return {"pixel_values": pixel_values, "labels": labels}
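
The two ideas in this cell, transposing a list of `(image, label)` pairs with `zip(*examples)` and unpacking a dictionary into keyword arguments, can be seen in a torch-free sketch. Everything here is a stand-in: `toy_collate` and `toy_model` are hypothetical, and strings and lists take the place of tensors.

```python
def toy_collate(examples):
    """Same shape of logic as collate_fn, with lists standing in for tensors."""
    imgs, labels = zip(*examples)  # [(img, label), ...] -> (imgs...), (labels...)
    return {"pixel_values": list(imgs), "labels": list(labels)}

def toy_model(pixel_values, labels=None):
    """Stand-in for a model whose forward pass takes keyword arguments."""
    return {"batch_size": len(pixel_values), "has_labels": labels is not None}

toy_examples = [("img0", 0), ("img1", 3), ("img2", 7)]
toy_batch = toy_collate(toy_examples)
toy_output = toy_model(**toy_batch)  # dictionary keys become keyword arguments
print(toy_batch["labels"], toy_output)
```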

Checking the behavior: the batch is returned as a dictionary.

In [8]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=4)

batch = next(iter(train_dataloader))
for k,v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k, v.shape)

pixel_values torch.Size([4, 3, 224, 224])
labels torch.Size([4])


## Defining the model
We load a pretrained ViT model from Hugging Face.

In [9]:
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
                                                  num_labels=len(label2id),
                                                  label2id=label2id,
                                                  id2label=id2label)

Some weights of the model checkpoint at google/vit-base-patch16-224-in21k were not used when initializing ViTForImageClassification: ['pooler.dense.weight', 'pooler.dense.bias']
- This IS expected if you are initializing ViTForImageClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTForImageClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


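We configure TrainingArguments. The settings are as follows.

* Save model checkpoints and parameters to the `asl-vit-test` directory
* Save a checkpoint every epoch
* Evaluate every epoch
* Learning rate = $1.0 \times 10^{-6}$
* Training batch size = 16
* Validation batch size = 16
* Train for 10 epochs
* Weight decay = 0.001
* After training, load the best model (by validation accuracy) obtained during training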


In [10]:
from transformers import TrainingArguments, Trainer

metric_name = "accuracy"

args = TrainingArguments(
    f"asl-vit-test",
    evaluation_strategy="epoch",
    save_strategy='epoch',
    learning_rate=1e-6,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=1e-3,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs',
    remove_unused_columns=False,
)
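
With the values passed to `TrainingArguments` above, the number of optimizer steps can be worked out by hand. The sketch below assumes a single device and no gradient accumulation; the training set size is the one printed earlier.

```python
import math

train_size = 63774  # training set size printed earlier
batch_size = 16     # per_device_train_batch_size
num_epochs = 10     # num_train_epochs

# Assumes one device and no gradient accumulation.
steps_per_epoch = math.ceil(train_size / batch_size)
total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_steps)
```

The result agrees with the total of 39860 steps shown in the training progress bar.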

Define the evaluation metric (accuracy here).

In [11]:
from datasets import load_metric
import numpy as np

metric = load_metric("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)
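
For reference, the quantity that `compute_metrics` reports can be written out without any library. This is only a sketch of the arithmetic (argmax over the class scores, then the fraction of matches); the helper names and the toy logits are hypothetical.

```python
def argmax(row):
    """Index of the largest value in a list of scores."""
    return max(range(len(row)), key=row.__getitem__)

def accuracy(logits, references):
    """Fraction of rows whose highest-scoring class equals the reference label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(p == r for p, r in zip(predictions, references))
    return correct / len(references)

# Hypothetical logits for 4 samples and 3 classes.
toy_logits = [[2.0, 0.1, -1.0], [0.3, 0.2, 1.5], [0.0, 3.0, 0.5], [1.0, 0.9, 0.8]]
toy_labels = [0, 2, 1, 2]
toy_accuracy = accuracy(toy_logits, toy_labels)
print(toy_accuracy)
```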



Define the Trainer.

In [12]:
from transformers import EarlyStoppingCallback

trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=2, early_stopping_threshold=0.005)]
)
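
`EarlyStoppingCallback` stops training once the monitored metric has failed to improve by more than `early_stopping_threshold` for `early_stopping_patience` consecutive evaluations. The sketch below is a simplified model of that bookkeeping, not the library's source: `stopping_epoch` is a hypothetical helper, it tracks the best value itself, and the accuracy sequences are made up for illustration.

```python
def stopping_epoch(accuracies, patience=2, threshold=0.005):
    """Return the 1-based epoch at which training would stop, or None."""
    best = None
    bad_evals = 0
    for epoch, acc in enumerate(accuracies, start=1):
        # An evaluation counts as an improvement only if it beats the best
        # value so far by more than the threshold.
        if best is None or (acc > best and abs(acc - best) > threshold):
            bad_evals = 0
        else:
            bad_evals += 1
        if best is None or acc > best:
            best = acc
        if bad_evals >= patience:
            return epoch
    return None

# Hypothetical validation accuracies, one per epoch.
plateau = stopping_epoch([0.90, 0.95, 0.952, 0.953, 0.96])
steady = stopping_epoch([0.5, 0.6, 0.7, 0.8])
print(plateau, steady)
```

In the first sequence the gains at epochs 3 and 4 are below the threshold, so the sketch stops at epoch 4; in the second every epoch improves enough and training runs to the end.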

## Training

In [13]:
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrkn[0m. Use [1m`wandb login --relogin`[0m to force relogin


  1%|          | 487/39860 [03:51<4:11:01,  2.61it/s] 

In [None]:



trainer.save_state()
trainer.save_model()

## Evaluation
We evaluate the performance of the fine-tuned model.

We feed the test data to the model and obtain the predicted labels.

In [None]:
outputs = trainer.predict(test_dataset)

Show the loss, accuracy, and other metrics on the test data.

In [None]:
print(outputs.metrics)

Build the confusion matrix.

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

y_true = outputs.label_ids
y_pred = outputs.predictions.argmax(1)

labels = list(label2id.keys())
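# Pass every class index explicitly so the matrix stays aligned with display_labels
# even if a class happens to be absent from the test split.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))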
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
fig, ax = plt.subplots(figsize=(12,12))
disp.plot(ax=ax)
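
To read the plot, it helps to recall the layout: entry (i, j) counts the samples whose true class is i and whose predicted class is j, so correct predictions sit on the diagonal. A minimal pure-Python sketch of that counting, with hypothetical toy labels and a hypothetical helper `confusion_counts`:

```python
def confusion_counts(y_true, y_pred, n_classes):
    """matrix[i][j] = number of samples with true class i predicted as class j."""
    matrix = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        matrix[t][p] += 1
    return matrix

toy_true = [0, 0, 1, 1, 2, 2]
toy_pred = [0, 1, 1, 1, 2, 0]
toy_cm = confusion_counts(toy_true, toy_pred, n_classes=3)
for row in toy_cm:
    print(row)
```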

## Comparing internal representations

### Before fine-tuning

In [None]:
from transformers import ViTForImageClassification

non_finetuned_model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k',
                                                                num_labels=len(label2id),
                                                                label2id=label2id,
                                                                id2label=id2label)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
from tqdm import tqdm

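def CLE_tokens(model, tokenizer, dataset, device):
    """Collect the [CLS] token of the last hidden layer for every image in dataset."""
    tokens = []
    labels = []
    for img, label in tqdm(dataset):
        if isinstance(img, torch.Tensor):
            # The dataset already applied the transforms (resize, normalize),
            # so running the feature extractor again would preprocess twice.
            feature = img.unsqueeze(0).to(device)
        else:
            feature = tokenizer(img, return_tensors='pt').pixel_values.to(device)
        with torch.no_grad():
            # hidden_states[-1] has shape (batch, tokens, hidden); token 0 is [CLS].
            token = model(feature, output_hidden_states=True).hidden_states[-1][0,0,:]
        tokens.append(token.cpu())
        labels.append(label)
    return torch.stack(tokens).squeeze(), torch.tensor(labels)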
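
The index `[0,0,:]` picks the first image in the batch and the token at position 0, which is the [CLS] token. The sketch below works out the expected shape of the last hidden state for one image. The image and patch sizes come from the checkpoint name; the hidden width of 768 is an assumption about the base-size ViT.

```python
image_size = 224   # input resolution of the checkpoint
patch_size = 16    # "patch16" in the checkpoint name
hidden_size = 768  # assumed hidden width of the base-size ViT

patches_per_side = image_size // patch_size
num_patches = patches_per_side ** 2
sequence_length = num_patches + 1  # +1 for the [CLS] token at position 0

# Shape of hidden_states[-1] for a single image: (batch, sequence, hidden)
last_hidden_shape = (1, sequence_length, hidden_size)
print(last_hidden_shape)
```

Under these assumptions each image contributes one vector of length 768, which is what the function stacks and returns.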

In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from umap import UMAP

import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.cm as cm
import random

def plot_tokens(tokens, labels, n_neighbors):
    # tsne = TSNE(n_components=2)
    # zs = tsne.fit_transform(tokens.numpy())
    umap = UMAP(n_neighbors=n_neighbors)
    zs = umap.fit_transform(tokens.numpy())
    ys = labels.numpy()
    print(zs.shape)
    print(ys.shape)
    fig = plt.figure()
    ax = fig.add_subplot()
    ax.set_xlabel('feature-1')
    ax.set_ylabel('feature-2')
    cmap = plt.get_cmap('gist_ncar')  # cm.get_cmap is not available in recent matplotlib
    
    label2point = {}
    for x, y in zip(zs, ys):
        mp = ax.scatter(x[0], x[1],
                        alpha=1,
                        label=id2label[y],
                        # c=label2color[y],
                        c=y,
                        cmap=cmap,
                        vmin=0,
                        vmax=len(set(ys)),
                        s=3,)
        label2point[id2label[y]] = mp
    labels, handles = zip(*sorted(label2point.items()))
    fig.legend(handles, labels, bbox_to_anchor=(0, -0.15), loc='lower left', ncol=10)
    plt.show()


In [None]:
tokens, labels = CLE_tokens(non_finetuned_model.to(device),
                            feature_extractor,
                            test_dataset,
                            device)
plot_tokens(tokens, labels, 75)

### After fine-tuning

In [None]:
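# Load from the output directory that the Trainer saved to ("asl-vit-test" in TrainingArguments).
finetuned_model = ViTForImageClassification.from_pretrained('asl-vit-test/',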
                                                            num_labels=len(label2id),
                                                            label2id=label2id,
                                                            id2label=id2label).to(device)

In [None]:
tokens, labels = CLE_tokens(finetuned_model,
                            feature_extractor,
                            test_dataset,
                            device)
plot_tokens(tokens, labels, 75)