<a href="https://colab.research.google.com/github/Taiga10969/Learn-the-basics/blob/main/timm/timm_ViT_finetuning/sample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 1．必要ライブラリのインポート

In [None]:
# Mount Google Drive at /content/drive (Colab only); the last cell of this
# notebook copies the saved weights file into a folder under that mount point.
from google.colab import drive
drive.mount('/content/drive')

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
import torch.optim as optim

import time
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

## 2．データを用意
事前学習済みモデル（ViT）の入力サイズに合わせるため，画像を1辺224pxにリサイズして用意する．

In [None]:
# Preprocessing shared by both splits: convert to a tensor, resize and
# center-crop to 224 px, then normalize every channel with mean 0.5 / std 0.5.
transforms = T.Compose([
    T.ToTensor(),
    T.Resize(224),
    T.CenterCrop(224),
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

batch_size = 64

# Training split: downloaded into ./data if missing, shuffled on every pass.
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transforms)
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

# Test split: same preprocessing, kept in dataset order.
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transforms)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Class names used to label the preview images.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

### データの確認

In [None]:
# Preview the first ten training images side by side with their class names.
plt.figure(figsize=(20, 10))
for idx in range(10):
    img_tensor, target = train_dataset[idx]
    # Undo Normalize(mean=0.5, std=0.5) before converting to a NumPy array.
    img_array = (img_tensor / 2 + 0.5).numpy()
    # Move the first axis last, as imshow expects (H, W, C).
    img_array = np.transpose(img_array, (1, 2, 0))
    print(img_array.shape)
    axis = plt.subplot(1, 10, idx + 1)
    axis.imshow(img_array)
    axis.axis('off')
    axis.set_title('label : {}'.format(classes[int(target)]), fontsize=15)
plt.show()

## 3．GPUの使用確認

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print("device : ", device)

### GPUメモリの確認

In [None]:
# Run nvidia-smi in a shell to show the GPU status and memory usage.
!nvidia-smi

## 4．学習済みモデルの用意

### 学習･検証を行う関数の用意

In [7]:
def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch and return the average loss per sample.

    Args:
        model: Network to optimize; it is switched to training mode here.
        train_loader: Iterable yielding ``(data, label)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, label)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        device: Device each batch is moved to before the forward pass.

    Returns:
        float: Accumulated loss divided by the number of samples seen.

    Raises:
        ZeroDivisionError: If ``train_loader`` yields no samples.
    """
    # Put the network into training mode.
    model.train()

    # A criterion that averages over the batch (the PyTorch default, and what
    # is assumed when no ``reduction`` attribute exists) returns a per-sample
    # mean, so it has to be weighted by the batch size before dividing by the
    # total sample count. A 'sum' criterion is already a per-batch total.
    batch_averaged = getattr(criterion, "reduction", "mean") != "sum"

    sum_loss = 0.0
    count = 0

    for data, label in train_loader:
        n_samples = len(label)
        count += n_samples
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        sum_loss += batch_loss * n_samples if batch_averaged else batch_loss

    return sum_loss/count

In [8]:
def val(model, val_loader, criterion, device):
    """Evaluate the model and return the per-sample loss and the accuracy.

    Args:
        model: Network to evaluate; it is switched to evaluation mode here.
        val_loader: Iterable yielding ``(data, label)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, label)``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        tuple: ``(average_loss, accuracy_rate)`` where ``average_loss`` is a
        float (accumulated loss divided by the number of samples) and
        ``accuracy_rate`` is a detached 0-dim CPU tensor holding the fraction
        of samples whose argmax prediction equals the label.

    Raises:
        ZeroDivisionError: If ``val_loader`` yields no samples.
    """
    # Put the network into evaluation mode.
    model.eval()

    # A criterion that averages over the batch (the PyTorch default, and what
    # is assumed when no ``reduction`` attribute exists) returns a per-sample
    # mean, so it has to be weighted by the batch size before dividing by the
    # total sample count. A 'sum' criterion is already a per-batch total.
    batch_averaged = getattr(criterion, "reduction", "mean") != "sum"

    sum_loss = 0.0
    count = 0
    correct = 0

    # No gradients are needed while evaluating.
    with torch.no_grad():
        for data, label in val_loader:
            n_samples = len(label)
            count += n_samples
            data, label = data.to(device), label.to(device)
            outputs = model(data)
            loss = criterion(outputs, label)
            batch_loss = loss.item()
            sum_loss += batch_loss * n_samples if batch_averaged else batch_loss
            pred = torch.argmax(outputs, dim=1)
            # Kept as a tensor so the count stays on the device until the end.
            correct += torch.sum(pred == label)

    accuracy_rate = (correct / count).cpu().detach()

    return sum_loss/count, accuracy_rate

### PyTorch Image Modelsライブラリのインストール

In [None]:
# Install timm (PyTorch Image Models) with pip; it is imported in the next cell.
!pip install timm

In [10]:
import timm
from pprint import pprint
#model_names = timm.list_models(pretrained=True)
#pprint(model_names)

### モデルを生成

In [None]:
# Pretrained ViT-Small (patch 16, 224 px input) created with a 10-class head,
# moved to the selected device.
model = timm.create_model(
    'vit_small_patch16_224',
    pretrained=True,
    num_classes=10,
).to(device)

# Cross-entropy loss, placed on the same device as the model.
criterion = nn.CrossEntropyLoss().to(device)

# Plain SGD over every model parameter.
optimizer = optim.SGD(model.parameters(), lr=0.01)

for caption, component in (
    ("model : ", model),
    ("criterion : ", criterion),
    ("optimizer : ", optimizer),
):
    print(caption, component)

## 5．学習

In [None]:
# Train for a fixed number of epochs, evaluating on the test loader after each
# one and recording the loss / accuracy history for the plots further below.
num_epoch = 25
train_loss_list, val_loss_list, accuracy_rate_list = [], [], []

start = time.time()
for epoch in range(1, num_epoch + 1):
    train_loss = train(model, train_loader, criterion, optimizer, device)
    val_loss, accuracy_rate = val(model, test_loader, criterion, device)

    train_loss_list.append(train_loss)
    val_loss_list.append(val_loss)
    accuracy_rate_list.append(accuracy_rate)

    # Seconds elapsed since the loop started (cumulative, not per epoch).
    process_time = time.time() - start

    print(
        "epoch : {}, train_loss : {}, test_loss : {}, accuracy_rate : {}, time : {}".format(
            epoch, train_loss, val_loss, accuracy_rate, process_time
        )
    )
print("training_time : {}".format(time.time() - start))

## 6．学習の確認

### クラス別認識率の算出

In [None]:
# Per-class accuracy on the test split.
model.eval()
num_classes = len(classes)
# Number of test samples seen for each class.
class_count_list = [0] * num_classes
# Number of correctly classified samples for each class (a count, despite the
# name; it is divided by class_count_list below to obtain the rate).
class_accuracy_rate_list = [0] * num_classes

# Inference only: disable autograd so no graph is built per sample, and run
# the model on whole batches from test_loader instead of one image at a time.
with torch.no_grad():
    for data, label in test_loader:
        prediction_label = torch.argmax(model(data.to(device)), dim=1)
        for target, predicted in zip(label.tolist(), prediction_label.tolist()):
            class_count_list[target] += 1
            if target == predicted:
                class_accuracy_rate_list[target] += 1

for i in range(num_classes):
    class_accuracy = class_accuracy_rate_list[i] / class_count_list[i]
    print("class{} : {:.5f}  ( {} / {})".format(i, class_accuracy, class_accuracy_rate_list[i], class_count_list[i]))

# Overall accuracy, computed once from the per-class totals.
sum_accuracy = sum(class_accuracy_rate_list) / sum(class_count_list)
print("sum_accuracy : {} ( {} / {})".format(sum_accuracy, sum(class_accuracy_rate_list), sum(class_count_list)))

### 学習曲線の可視化

In [None]:
# Plot the per-epoch training and validation loss, save the figure as SVG,
# then display it.
loss_curves = (
    (train_loss_list, 'b', 'train loss'),
    (val_loss_list, 'r', 'val loss'),
)
for losses, color, curve_name in loss_curves:
    # Epochs are numbered from 1 on the x axis.
    plt.plot(range(1, len(losses) + 1), losses, c=color, label=curve_name)
plt.xlabel("epoch")
plt.ylabel("loss")
plt.legend()
plt.grid()
plt.savefig('Vit_CIFAR-10_finetuning_loss.svg')
plt.show()

## 7．学習済みモデルのパラメータを保存

In [None]:
# Move the model to the CPU (in place) and write its state dict to disk.
cpu_state_dict = model.to('cpu').state_dict()
torch.save(cpu_state_dict, 'model_vit_small_patch16_224_finetuning.pth')

In [None]:
# Copy the saved weights file into a folder under /content/drive, where
# Google Drive was mounted at the top of the notebook.
!cp model_vit_small_patch16_224_finetuning.pth /content/drive/MyDrive/OLD/Research/ViT_Research