In [1]:
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import os
import torchvision

# 画像のサイズを定義
image_size = 256
# バッチサイズを定義
batch_size = 32
# データのルートディレクトリを指定
root_dir = "./data"

# 画像変換の設定
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(image_size),  # 画像サイズをリサイズ
    torchvision.transforms.ToTensor()  # 画像をテンソルに変換
])

# 訓練データセットをロード
train_datasets = torchvision.datasets.CIFAR10(
    root=root_dir, train=True, transform=transform, download=True
)

# 検証データセットをロード
val_datasets = torchvision.datasets.CIFAR10(
    root=root_dir, train=False, transform=transform, download=True
)

# 訓練データ用のデータローダーを作成
train_dataloader = DataLoader(
    train_datasets, batch_size=batch_size, shuffle=True, drop_last=True
)

# 検証データ用のデータローダーを作成
val_dataloader = DataLoader(
    val_datasets, batch_size=batch_size, shuffle=False, drop_last=True
)



Files already downloaded and verified
Files already downloaded and verified


In [2]:
for batch, (images, labels) in enumerate(val_dataloader):
    print(batch + 1, images.shape)

1 torch.Size([32, 3, 256, 256])
2 torch.Size([32, 3, 256, 256])
3 torch.Size([32, 3, 256, 256])
4 torch.Size([32, 3, 256, 256])
5 torch.Size([32, 3, 256, 256])
6 torch.Size([32, 3, 256, 256])
7 torch.Size([32, 3, 256, 256])
8 torch.Size([32, 3, 256, 256])
9 torch.Size([32, 3, 256, 256])
10 torch.Size([32, 3, 256, 256])
11 torch.Size([32, 3, 256, 256])
12 torch.Size([32, 3, 256, 256])
13 torch.Size([32, 3, 256, 256])
14 torch.Size([32, 3, 256, 256])
15 torch.Size([32, 3, 256, 256])
16 torch.Size([32, 3, 256, 256])
17 torch.Size([32, 3, 256, 256])
18 torch.Size([32, 3, 256, 256])
19 torch.Size([32, 3, 256, 256])
20 torch.Size([32, 3, 256, 256])
21 torch.Size([32, 3, 256, 256])
22 torch.Size([32, 3, 256, 256])
23 torch.Size([32, 3, 256, 256])
24 torch.Size([32, 3, 256, 256])
25 torch.Size([32, 3, 256, 256])
26 torch.Size([32, 3, 256, 256])
27 torch.Size([32, 3, 256, 256])
28 torch.Size([32, 3, 256, 256])
29 torch.Size([32, 3, 256, 256])
30 torch.Size([32, 3, 256, 256])
31 torch.Size([32, 

In [3]:
class VisonTransformer(nn.Module):
    def __init__(self, num_classes, batch_size, image_size, num_channel, patch_size, embed_hidden_size, num_layer, num_head, device, MultiHeadAttention, Encoder):
        # ビジョン・トランスフォーマーのコンストラクタ
        super().__init__()
        self.device = device
        self.batch_size = batch_size
        self.patch_size = patch_size
        self.image_size = image_size
        self.num_channel = num_channel
        self.num_patch = int((image_size / patch_size) * (image_size / patch_size))
        self.num_token = self.num_patch + 1
        self.num_layer = num_layer
        self.num_head = num_head
        self.cls_id = torch.tensor(0, dtype=torch.long).to(device)
        self.cls_embedding = nn.Embedding(1, embed_hidden_size)
        self.embed_hidden_size = embed_hidden_size
        self.positional_embedding = nn.Embedding(self.num_token, embed_hidden_size)
        self.image_embedding = nn.Linear(patch_size * patch_size * self.num_patch, embed_hidden_size)
        self.layer_norm = nn.LayerNorm((batch_size, embed_hidden_size))
        self.dropout = nn.Dropout(p=0.9)
        self.fc = nn.Linear(embed_hidden_size, num_classes)
        args = (batch_size, image_size, num_channel, patch_size, embed_hidden_size, num_layer, num_head, device, MultiHeadAttention)
        self.setup_layer(num_layer, Encoder, args)

    def image_to_token(self, images, patch_size):
        # 画像をトークンに変換する関数
        batch_size, num_channel, width, height = self.batch_size, self.num_channel, self.image_size, self.image_size

        patch_window = torch.ones((patch_size, patch_size), dtype=torch.long)
        patch_window = patch_window.unsqueeze(0).expand(num_channel, patch_size, patch_size)\
            .unsqueeze(0).expand(batch_size, num_channel, patch_size, patch_size)

        token_list = []
        for row_idx in range(0, width, patch_size):
            for col_idx in range(0, height, patch_size):
                patch = images[:, :, row_idx: row_idx + patch_size, col_idx:col_idx + patch_size]
                token_list.append(patch)

        token_list = torch.stack(token_list, dim=0).transpose(0, 1).view(batch_size, 256, -1)

        return token_list

    def positional_encoding(self, num_token):
        # 位置エンコーディングを計算する関数
        position_ids = torch.tensor(list(range(num_token)), dtype=torch.long).expand(self.batch_size, -1).to(self.device)
        positional_embeds = self.positional_embedding(position_ids)

        return positional_embeds

    def setup_layer(self, num_layer, encoder, args):
        # エンコーダーレイヤーを設定する関数
        layer_list = []
        for _ in range(num_layer):
            layer_list.append(encoder(*args).to(self.device))

        module_list = nn.ModuleList(layer_list)
        self.layer_list = nn.Sequential(*module_list)

    def forward(self, images):
        # フォワードパスを定義する関数
        cls_tokens = self.cls_embedding(self.cls_id).unsqueeze(0).expand(self.batch_size, 1, -1)
        image_tokens = self.image_to_token(images, self.patch_size)
        tokens = torch.cat([cls_tokens, image_tokens], dim=1)  # torch.concat -> torch.catに修正
        positional_embeds = self.positional_encoding(self.num_token)
        embed_tokens = (tokens + positional_embeds)
        encoded_outputs = self.layer_list(embed_tokens)
        cls = encoded_outputs[:, 0, :]
        layer_norm = self.layer_norm(cls)
        dropout = self.dropout(layer_norm)
        outputs = self.fc(dropout)

        return outputs

class Encoder(nn.Module):
    def __init__(self, batch_size, image_size, num_channel, patch_size, embed_hidden_size, num_layer, num_head, device, MultiHeadAttention):
        # エンコーダーのコンストラクタ
        super().__init()

        num_patch = int((image_size / patch_size) * (image_size / patch_size))
        num_token = num_patch + 1
        self.layer_norm1 = nn.LayerNorm((batch_size, num_token, embed_hidden_size))
        self.layer_norm2 = nn.LayerNorm((batch_size, num_token, embed_hidden_size))
        self.dropout = nn.Dropout(p=0.9)
        self.mlp = nn.Linear(embed_hidden_size, embed_hidden_size)
        self.attention_layer = MultiHeadAttention(embed_hidden_size, num_head, device)
        self.gelu = nn.GELU().to(device)

    def forward(self, tokens):
        # エンコーダーのフォワードパスを定義する関数
        layer_norm1 = self.layer_norm1(tokens)
        dropout1 = self.dropout(layer_norm1)
        skip1 = tokens
        concat_attention = self.attention_layer(dropout1)

        outputs_tmp1 = concat_attention + skip1
        skip2 = outputs_tmp1
        layer_norm2 = self.layer_norm2(outputs_tmp1)
        dropout2 = self.dropout(layer_norm2)
        mlp = self.gelu(self.mlp(dropout2))
        outputs = mlp + skip2

        return outputs

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_hidden_size, num_head, device):
        # マルチヘッドアテンションのコンストラクタ
        super().__init__()
        self.attention_layers = []
        self.query_layers = []
        self.key_layers = []
        self.value_layers = []
        self.num_head = num_head
        self.embed_hidden_size = embed_hidden_size
        self.multi_embed_hidden_size = int(embed_hidden_size / num_head)

        self.device = device
        self.setup_attention()

    def setup_attention(self):
        # アテンションのセットアップを行う関数
        for number in range(self.num_head):
            self.query_layers.append(nn.Linear(self.embed_hidden_size, self.multi_embed_hidden_size).to(self.device))
            self.key_layers.append(nn.Linear(self.embed_hidden_size, self.multi_embed_hidden_size).to(self.device))
            self.value_layers.append(nn.Linear(self.embed_hidden_size, self.multi_embed_hidden_size).to(self.device))

        self.query_layers = nn.ModuleList(self.query_layers)
        self.key_layers = nn.ModuleList(self.key_layers)
        self.value_layers = nn.ModuleList(self.value_layers)

    def output_attention(self, tokens):
        # アテンションの計算を行う関数
        for number in range(self.num_head):
            query = self.query_layers[number](tokens)
            key = self.key_layers[number](tokens)
            value = self.value_layers[number](tokens)
            attention = nn.Softmax(dim=-1)((query @ torch.transpose(key, 1, 2)) / torch.sqrt(torch.tensor(self.multi_embed_hidden_size))) @ value

            if number > 0:
                concat_attention = torch.cat([concat_attention, attention], dim=-1)  # torch.concat -> torch.catに修正
            else:
                concat_attention = attention

        return concat_attention

    def forward(self, tokens):
        # フォワードパスを定義する関数
        concat_attention = self.output_attention(tokens)

        return concat_attention


In [4]:
import torch
from torcheval.metrics.functional import (multiclass_accuracy, 
                                          multiclass_precision,
                                          multiclass_recall,
                                          multiclass_f1_score
                                          )

# デバイスを選択 (GPUが利用可能な場合はGPUを使用し、そうでなければCPUを使用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# モデルの設定とハイパーパラメータ
kwargs = {
    "num_classes": 10, 
    "batch_size": 32, 
    "image_size": 256, 
    "num_channel": 3, 
    "patch_size": 16, 
    "embed_hidden_size": 768,
    "num_layer": 12, 
    "num_head": 8,
    "device": device,
    "MultiHeadAttention": MultiHeadAttention, 
    "Encoder": Encoder
}

# モデルを初期化し、デバイスに配置
model = VisonTransformer(**kwargs).to(device)

# 損失関数 (CrossEntropyLoss) の設定
criterion = torch.nn.CrossEntropyLoss()

# オプティマイザ (Adam) の設定
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# モデルの出力を保存するディレクトリ
output_dir = "./output"

# 訓練中の損失と正解率、精度、再現率、F1スコアを記録するためのリスト
train_loss_list = []
val_loss_list = []
train_correct_list = []
val_correct_list = []
presision_list = []
recall_list = []
f1_list = []

# エポック数
epochs = 1

# ミニバッチのサイズ
batch_size = kwargs["batch_size"]

# 最小の損失を初期化
min_loss = float('inf')

# エポックのループ
for epoch in range(epochs):
    # "train"と"val"の2つのステップでループ
    for step in ["train", "val"]:
        if step == "train":
            model.train()
            dataloader = train_dataloader
        else:
            model.eval()
            dataloader = val_dataloader

        running_loss = 0.0
        running_correct = 0.0

        # ミニバッチのループ
        for batch, (images, labels) in enumerate(dataloader):
            images = images.to(device)
            labels = labels.to(device)

            with torch.set_grad_enabled(model.training):
                outputs = model(images)
                pred = torch.argmax(outputs, dim=-1)
                loss = criterion(outputs, labels).sum()
                correct = multiclass_accuracy(
                    input=outputs,
                    target=labels,
                    num_classes=kwargs["num_classes"],
                    average="micro"
                )

                if model.training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item() / batch_size
            running_correct += correct.item() / batch_size

            if model.training:
                train_loss_list.append(running_loss)
                train_correct_list.append(running_correct)
                print(f"Step: Train Epoch: {epoch + 1}/{epochs} train_loss: {running_loss / (batch + 1)} train_correct: {running_correct / (batch + 1)}")
            else:
                print(f"Step: Train Epoch: {epoch + 1}/{epochs} val_loss: {running_loss / (batch + 1)} val_correct: {running_correct / (batch + 1)}")
                val_loss_list.append(running_loss)
                val_correct_list.append(running_correct)
                
                # 精度、再現率、F1スコアの計算
                precision = multiclass_precision(
                    input=outputs,
                    target=labels,
                    num_classes=kwargs["num_classes"],
                    average="micro"
                ).item()

                recall = multiclass_recall(
                    input=outputs,
                    target=labels,
                    num_classes=kwargs["num_classes"],
                    average="micro"
                ).item()

                f1_score = multiclass_f1_score(
                    input=outputs,
                    target=labels,
                    num_classes=kwargs["num_classes"],
                    average="micro"
                ).item()

                recall_list.append(recall)
                presision_list.append(precision)
                f1_list.append(f1_score)

    if running_loss < min_loss:
        print("Model Save!")
        min_loss = running_loss
        if not os.path.exists(output_dir):
            os.mkdir(output_dir) 
        torch.save(model.state_dict(), os.path.join(output_dir, f"{batch + 1}.pth"))  # モデルを指定したディレクトリに保存



Step: Train Epoch: 1/1 train_loss: 0.09938900172710419 train_correct: 0.0078125
Step: Train Epoch: 1/1 train_loss: 0.09826721251010895 train_correct: 0.005859375
Step: Train Epoch: 1/1 train_loss: 0.10128515958786011 train_correct: 0.005533854166666667
Step: Train Epoch: 1/1 train_loss: 0.09919302351772785 train_correct: 0.005126953125
Step: Train Epoch: 1/1 train_loss: 0.0985305905342102 train_correct: 0.0046875
Step: Train Epoch: 1/1 train_loss: 0.09787710011005402 train_correct: 0.004557291666666667
Step: Train Epoch: 1/1 train_loss: 0.09641868727547782 train_correct: 0.004603794642857143
Step: Train Epoch: 1/1 train_loss: 0.09696859028190374 train_correct: 0.004638671875
Step: Train Epoch: 1/1 train_loss: 0.09915352198812696 train_correct: 0.004231770833333333
Step: Train Epoch: 1/1 train_loss: 0.09970039427280426 train_correct: 0.0041015625
Step: Train Epoch: 1/1 train_loss: 0.10100206190889532 train_correct: 0.003995028409090909
Step: Train Epoch: 1/1 train_loss: 0.10130457083384