# Vision Transformer Image Classification

[![Download Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/resource/_static/logo_notebook.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.3/tutorials/application/zh_cn/cv/mindspore_vit.ipynb)&emsp;[![Download Sample Code](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/resource/_static/logo_download_code.svg)](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/r2.3/tutorials/application/zh_cn/cv/mindspore_vit.py)&emsp;[![View Source File](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/resource/_static/logo_source.svg)](https://gitee.com/mindspore/docs/blob/r2.3/tutorials/application/source_zh_cn/cv/vit.ipynb)

Thanks to [ZOMI酱](https://gitee.com/sanjaychan) for contributing to this article.


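## Introduction to Vision Transformer (ViT)

In recent years, models built on Self-Attention, and the Transformer in particular, have greatly advanced natural language processing. Thanks to the Transformer's computational efficiency and scalability, it has become possible to train models of unprecedented size, with more than 100B parameters.

ViT is where natural language processing and computer vision meet: without relying on convolution, it still achieves strong results on image classification.

### Model Structure

The main body of ViT is based on the Encoder of the Transformer model, with some changes to the ordering of sub-layers (for example, Normalization is placed differently than in the standard Transformer). The architecture [1] is shown below:

![vit-architecture](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/tutorials/application/source_zh_cn/cv/images/vit_architecture.png)

### Model Characteristics

ViT is mainly used for image classification, so its structure differs from the traditional Transformer in the following ways:

1. Each original image in the dataset is divided into patches (image blocks); each two-dimensional patch (ignoring channels) is converted into a one-dimensional vector, and the class vector and position vectors are added to form the model input.
2. The Block structure of the model body is based on the Transformer Encoder, but the position of Normalization is changed. The most important component is still Multi-head Attention.
3. After the stacked Blocks comes a fully connected layer, which takes the output of the class vector as input and performs classification. The final fully connected layer is usually called the Head, and the Transformer Encoder part the backbone.

The rest of this tutorial walks through code that uses ViT for an ImageNet classification task.

> Note: this tutorial takes a very long time to run on CPU, so running it on CPU is not recommended.

## Environment Setup and Data Loading

Before starting, make sure Python and MindSpore are installed locally.

First, obtain the dataset for this tutorial. The full ImageNet dataset can be downloaded from <http://image-net.org>; the dataset used here is a subset selected from ImageNet.

Before running the code below, make sure the dataset directory has the following structure.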

```text
dataset/
    ├── ILSVRC2012_devkit_t12.tar.gz
    ├── train/
    ├── infer/
    └── val/
```
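A quick way to confirm the layout before loading any data is to check for the expected sub-directories. The sketch below is only an illustration: `missing_subdirs` is a hypothetical helper (not part of MindSpore), and it assumes the dataset root is the relative path `dataset` used later in this tutorial. It does not check for the devkit archive.

```python
from pathlib import Path

# Sub-directories taken from the layout shown above.
EXPECTED_SUBDIRS = ("train", "val", "infer")


def missing_subdirs(root, expected=EXPECTED_SUBDIRS):
    """Return the expected sub-directories that are absent under `root`."""
    root = Path(root)
    return [name for name in expected if not (root / name).is_dir()]


# Assumed dataset root; adjust to wherever the data actually lives.
missing = missing_subdirs("dataset")
print("missing sub-directories:", missing or "none")
```

An empty result means all three folders exist; it says nothing about whether the images inside them are valid.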

In [1]:
%%capture captured_output
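# mindspore==2.3.0 is preinstalled in this environment; to use a different version, change the MINDSPORE_VERSION variable below
!pip uninstall mindspore -y
# %env keeps the variable visible to the later `!pip install` call (a `!export` only affects its own subshell)
%env MINDSPORE_VERSION=2.3.0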
!pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/${MINDSPORE_VERSION}/MindSpore/unified/aarch64/mindspore-${MINDSPORE_VERSION}-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.mirrors.ustc.edu.cn/simple

In [2]:
# Check the installed mindspore version
!pip show mindspore

Name: mindspore
Version: 2.3.0
Summary: MindSpore is a new open source deep learning training/inference framework that could be used for mobile, edge and cloud scenarios.
Home-page: https://www.mindspore.cn
Author: The MindSpore Authors
Author-email: contact@mindspore.cn
License: Apache 2.0
Location: /home/mindspore/miniconda/envs/jupyter/lib/python3.9/site-packages
Requires: asttokens, astunparse, numpy, packaging, pillow, protobuf, psutil, scipy
Required-by: 


In [12]:
import os
import mindspore as ms
from mindspore.dataset import ImageFolderDataset
import mindspore.dataset.vision as transforms

# Dataset path (relative)
data_path = 'dataset'

# Per-channel mean and standard deviation on the 0-255 pixel scale
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

# Create the training dataset
dataset_train = ImageFolderDataset(os.path.join(data_path, "train"), shuffle=True)

# Data augmentation and preprocessing operations
trans_train = [
    transforms.RandomCropDecodeResize(size=224, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
    transforms.RandomHorizontalFlip(prob=0.5),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

# Apply the augmentation and preprocessing operations
dataset_train = dataset_train.map(operations=trans_train, input_columns=["image"])

# Set the batch size
dataset_train = dataset_train.batch(batch_size=16, drop_remainder=True)

# Print the dataset column names
print(dataset_train.get_col_names())

# Check that the dataset loads
try:
    for data in dataset_train.create_dict_iterator(num_epochs=1):
        print("Read one batch")
        break
    print("Dataset loaded successfully!")
except Exception as e:
    print(f"Error while loading the dataset: {str(e)}")

['image', 'label']
Read one batch
Dataset loaded successfully!


In [13]:
import os
import mindspore.dataset as ds
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2

# Dataset path
data_path = './dataset'

# Build the data augmentation and preprocessing pipeline
def create_dataset(data_path, batch_size=32, is_training=True):
    # Image normalization parameters
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # Random augmentation for training, deterministic resize for evaluation
    if is_training:
        transform_img = [
            C.Decode(),
            C.Resize((256, 256)),
            C.RandomCrop((224, 224)),
            C.RandomHorizontalFlip(prob=0.5),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]
    else:
        transform_img = [
            C.Decode(),
            C.Resize((224, 224)),
            C.Normalize(mean=mean, std=std),
            C.HWC2CHW()
        ]

    # Create the dataset
    dataset = ds.ImageFolderDataset(data_path, num_parallel_workers=8, shuffle=is_training)
    
    # Apply the transforms
    dataset = dataset.map(operations=transform_img, input_columns="image")
    
    # Batch the data
    dataset = dataset.batch(batch_size, drop_remainder=True)

    return dataset

# Create the training set
train_dataset = create_dataset(os.path.join(data_path, 'train'), is_training=True)
print("\nTraining set info:")
print(f"Dataset size: {train_dataset.get_dataset_size()}")

# Count the raw images in the training set
train_dir = os.path.join(data_path, "train")
train_total_images = sum([len(files) for r, d, files in os.walk(train_dir) if any(file.lower().endswith(('.png', '.jpg', '.jpeg')) for file in files)])
print(f"Total images in the training set: {train_total_images}")

# Inspect a few training batches
print("\nTraining set batches:")
for i, data in enumerate(train_dataset.create_dict_iterator(num_epochs=1)):
    print(f"Batch {i+1}:")
    print(f"  Image shape: {data['image'].shape}")
    print(f"  Labels: {data['label']}")
    if i == 2:  # print only the first 3 batches
        break

# Create the validation set
val_dataset = create_dataset(os.path.join(data_path, 'val'), is_training=False)
print("\nValidation set info:")
print(f"Dataset size: {val_dataset.get_dataset_size()}")

# Count the raw images in the validation set
val_dir = os.path.join(data_path, "val")
val_total_images = sum([len(files) for r, d, files in os.walk(val_dir) if any(file.lower().endswith(('.png', '.jpg', '.jpeg')) for file in files)])
print(f"Total images in the validation set: {val_total_images}")

# Inspect a few validation batches
print("\nValidation set batches:")
for i, data in enumerate(val_dataset.create_dict_iterator(num_epochs=1)):
    print(f"Batch {i+1}:")
    print(f"  Image shape: {data['image'].shape}")
    print(f"  Labels: {data['label']}")
    if i == 2:  # print only the first 3 batches
        break


Training set info:
Dataset size: 23
Total images in the training set: 742

Training set batches:
Batch 1:
  Image shape: (32, 3, 224, 224)
  Labels: [1 2 1 1 1 1 1 2 1 1 1 1 2 1 1 2 1 1 2 2 1 2 1 1 1 1 2 2 1 2 1 2]
Batch 2:
  Image shape: (32, 3, 224, 224)
  Labels: [1 1 2 2 1 1 1 2 1 2 1 1 1 2 1 1 1 2 2 1 1 2 1 1 2 1 2 1 1 1 2 2]
Batch 3:
  Image shape: (32, 3, 224, 224)
  Labels: [1 2 1 2 2 2 1 1 2 1 2 2 1 1 2 1 1 2 2 1 1 1 1 1 1 1 1 1 2 1 2 1]

Validation set info:
Dataset size: 13
Total images in the validation set: 434

Validation set batches:
Batch 1:
  Image shape: (32, 3, 224, 224)
  Labels: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
Batch 2:
  Image shape: (32, 3, 224, 224)
  Labels: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
Batch 3:
  Image shape: (32, 3, 224, 224)
  Labels: [1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1]
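
The dataset sizes reported above are batch counts rather than image counts. With `drop_remainder=True`, the final incomplete batch is discarded, so the count is the integer quotient of the image count and the batch size. The sketch below reproduces that arithmetic in plain Python; `num_batches` is a hypothetical helper for illustration, not a MindSpore API.

```python
def num_batches(num_samples, batch_size, drop_remainder=True):
    """Number of batches produced from `num_samples` samples."""
    if drop_remainder:
        return num_samples // batch_size
    return -(-num_samples // batch_size)  # ceiling division


print(num_batches(742, 32))  # training set, batch size 32
print(num_batches(434, 32))  # validation set, batch size 32
print(num_batches(742, 16))  # training set, batch size 16
```

The first two values match the sizes 23 and 13 printed by the cell above; the third is the number of steps per epoch when the same 742 images are batched 16 at a time.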


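## Model Walkthrough

The following sections use code to examine the internal structure of the ViT model in detail.

### Transformer Fundamentals

The Transformer model originates from a 2017 paper [2]. The attention-based encoder-decoder architecture proposed there has been hugely successful in natural language processing. The model structure is shown below:

![transformer-architecture](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/tutorials/application/source_zh_cn/cv/images/transformer_architecture.png)

It consists mainly of multiple Encoder and Decoder modules, whose detailed structure is shown in the figure [2] below:

![encoder-decoder](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/tutorials/application/source_zh_cn/cv/images/encoder_decoder.png)

The Encoder and Decoder are built from several components: Multi-Head Attention layers, Feed Forward layers, Normalization layers, and Residual Connections (the "Add" in the figure). The most important of these is Multi-Head Attention, which is based on the Self-Attention mechanism and consists of several Self-Attention operations running in parallel.

Understanding Self-Attention is therefore the key to understanding the Transformer.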

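#### Attention Module

Self-Attention, explained below, learns a weight for every element (for example, every word) of the input sequence. Given a task-related Query vector, the similarity or relevance between the Query and each Key is computed to obtain the attention distribution, that is, the weight of the Value associated with each Key. The Values are then summed with these weights to give the final Attention output.

In Self-Attention:

1. The input vectors are first mapped by an Embedding layer into three vectors: Q (Query), K (Key) and V (Value). Because the three mappings run in parallel, the code projects the input to a vector of size dim x 3 and then splits it. In other words, if the input is a sequence of vectors ($x_1$, $x_2$, $x_3$), where each of $x_1$, $x_2$, $x_3$ is a one-dimensional vector, then every vector is mapped to its own Q, K and V by three different Embedding matrices whose parameters are learned. **You can think of the Q, K and V matrices as a learned means of discovering relationships between vectors. Three are needed because two vectors must be dot-multiplied to obtain a weight, and a third vector must carry the result of the weighted sum, so at least three matrices are required.**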

$$
\begin{cases}
q_i = W_q \cdot x_i & \\
k_i = W_k \cdot x_i,\hspace{1em} &i = 1,2,3 \ldots \\
v_i = W_v \cdot x_i &
\end{cases}
\tag{1}
$$

![self-attention1](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/tutorials/application/source_zh_cn/cv/images/self_attention_1.png)

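2. The "self" in self-attention reflects the fact that Q, K and V all come from the input itself: the process extracts the relationships and features among the vectors at different positions of the input, and expresses how closely those vectors are related through the Softmax of the product of Q and K. **Once Q, K and V are available, the weights between vectors are obtained by taking the dot product of Q and K and dividing it by the square root of the dimension, as in formula (2), and then applying Softmax over the results for all vectors.**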

$$
\begin{cases}
a_{1,1} = q_1 \cdot k_1 / \sqrt d \\
a_{1,2} = q_1 \cdot k_2 / \sqrt d \\
a_{1,3} = q_1 \cdot k_3 / \sqrt d
\end{cases}
\tag{2}
$$

![self-attention3](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/tutorials/application/source_zh_cn/cv/images/self_attention_3.png)

$$ \text{Softmax: } \hat a_{1,i} = \exp(a_{1,i}) / \sum_j \exp(a_{1,j}),\hspace{1em} j = 1,2,3 \ldots \tag{3}$$

![self-attention2](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/tutorials/application/source_zh_cn/cv/images/self_attention_2.png)

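3. The final output is the weighted sum of the mapped V vectors, using the Softmax of Q and K as the weights; this can be understood as a self-attention representation computed over the whole sequence. **Each group of Q, K and V yields one output vector. This is the final result of Self-Attention: the current vector after it has been combined with the other vectors according to its attention weights.**

$$
b_1 = \sum_i \hat a_{1,i}v_i,\hspace{1em} i = 1,2,3 \ldots
\tag{4}
$$

The figure below summarizes the whole Self-Attention process.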

![self-attention](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/tutorials/application/source_zh_cn/cv/images/self_attention_process.png)
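
To make formulas (1) to (4) concrete, here is a minimal plain-Python sketch of single-head Self-Attention on a toy sequence. It is an illustration under assumptions, not the tutorial's implementation: the helper names are hypothetical, and the projection matrices are hand-picked identity matrices rather than learned parameters.

```python
import math


def matvec(w, x):
    """Multiply matrix `w` (a list of rows) by vector `x`."""
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def softmax(scores):
    """Numerically stable softmax over a list of scores, formula (3)."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(xs, w_q, w_k, w_v):
    """Return one output vector b_i for every input vector x_i."""
    qs = [matvec(w_q, x) for x in xs]  # formula (1)
    ks = [matvec(w_k, x) for x in xs]
    vs = [matvec(w_v, x) for x in xs]
    d = len(qs[0])
    outputs = []
    for q in qs:
        # Formula (2): scaled dot product of this query with every key.
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in ks]
        weights = softmax(scores)
        # Formula (4): weighted sum of the value vectors.
        outputs.append([sum(w * v[j] for w, v in zip(weights, vs))
                        for j in range(len(vs[0]))])
    return outputs


# Toy input: three 2-dimensional vectors; identity projections are an assumption.
identity = [[1.0, 0.0], [0.0, 1.0]]
xs = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
for b in self_attention(xs, identity, identity, identity):
    print([round(value, 3) for value in b])
```

Each printed row is a mixture of all three value vectors, weighted by how strongly the corresponding query matches each key.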

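Multi-head attention splits the vectors processed by Self-Attention into multiple Heads that are handled separately. This is visible in the code, and it is one of the reasons the attention structure can be accelerated through parallelism.

In summary, while keeping the total number of parameters unchanged, multi-head attention maps the same query, key and value into different subspaces ($Q_0$, $K_0$, $V_0$) of the original high-dimensional space (Q, K, V), computes self-attention in each subspace, and finally merges the attention information from the different subspaces.

So, for the same input vector, several attention computations run at the same time, which both speeds up processing through parallel computation and analyzes and exploits the vector's features more fully. The figure below illustrates multi-head attention; the parallelism comes from the fact that $a_1$ and $a_2$ in the figure are obtained by splitting the same vector.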

![multi-head-attention](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/tutorials/application/source_zh_cn/cv/images/multi_head_attention.png)
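
The splitting step can be sketched in a few lines of plain Python. The helpers below are hypothetical and only show the bookkeeping: one vector is cut into equal contiguous pieces, one per head, and the pieces are concatenated again after attention has been applied to each of them.

```python
def split_heads(vector, num_heads):
    """Split one vector into `num_heads` equal, contiguous sub-vectors."""
    if len(vector) % num_heads != 0:
        raise ValueError("vector length must be divisible by num_heads")
    head_dim = len(vector) // num_heads
    return [vector[i * head_dim:(i + 1) * head_dim] for i in range(num_heads)]


def merge_heads(heads):
    """Concatenate per-head vectors back into one vector."""
    return [value for head in heads for value in head]


query = list(range(8))         # toy 8-dimensional query vector
heads = split_heads(query, 2)  # two heads of size 4
print(heads)
print(merge_heads(heads) == query)
```

With the defaults used later in this tutorial (`embed_dim=768`, `num_heads=12`), each head works on 64 dimensions, so the scaling factor `head_dim ** -0.5` equals 0.125.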

The Multi-Head Attention code below follows the explanation above.

In [14]:
from mindspore import nn, ops


class Attention(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0):
        super(Attention, self).__init__()

        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = ms.Tensor(head_dim ** -0.5)

        self.qkv = nn.Dense(dim, dim * 3)
        self.attn_drop = nn.Dropout(p=1.0-attention_keep_prob)
        self.out = nn.Dense(dim, dim)
        self.out_drop = nn.Dropout(p=1.0-keep_prob)
        self.attn_matmul_v = ops.BatchMatMul()
        self.q_matmul_k = ops.BatchMatMul(transpose_b=True)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, x):
        """Attention construct."""
        b, n, c = x.shape
        qkv = self.qkv(x)
        qkv = ops.reshape(qkv, (b, n, 3, self.num_heads, c // self.num_heads))
        qkv = ops.transpose(qkv, (2, 0, 3, 1, 4))
        q, k, v = ops.unstack(qkv, axis=0)
        attn = self.q_matmul_k(q, k)
        attn = ops.mul(attn, self.scale)
        attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        out = self.attn_matmul_v(attn, v)
        out = ops.transpose(out, (0, 2, 1, 3))
        out = ops.reshape(out, (b, n, c))
        out = self.out(out)
        out = self.out_drop(out)

        return out
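
One way to follow the reshapes in `construct` is to track the tensor shapes by hand. The sketch below does this in plain Python without MindSpore; `attention_shapes` is a hypothetical helper, and the example sizes (batch 16, 197 tokens, 768 channels, 12 heads) are assumptions based on defaults that appear elsewhere in this tutorial.

```python
def attention_shapes(b, n, c, num_heads):
    """Trace tensor shapes through the `construct` method above."""
    head_dim = c // num_heads
    return {
        "x": (b, n, c),
        "qkv (after Dense)": (b, n, 3 * c),
        "qkv (after reshape)": (b, n, 3, num_heads, head_dim),
        "qkv (after transpose)": (3, b, num_heads, n, head_dim),
        "q, k, v (each)": (b, num_heads, n, head_dim),
        "attn = q @ k^T": (b, num_heads, n, n),
        "attn @ v": (b, num_heads, n, head_dim),
        "out (after transpose and reshape)": (b, n, c),
    }


for name, shape in attention_shapes(b=16, n=197, c=768, num_heads=12).items():
    print(f"{name:36s} {shape}")
```

The attention matrix has one `n x n` block per head, which is why memory grows quadratically with the number of tokens.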

### Transformer Encoder

With Self-Attention in place, combining it with Feed Forward and Residual Connection structures yields the basic building block of the Transformer. The code below implements the Feed Forward and Residual Connection structures.

In [15]:
from typing import Optional, Dict


class FeedForward(nn.Cell):
    def __init__(self,
                 in_features: int,
                 hidden_features: Optional[int] = None,
                 out_features: Optional[int] = None,
                 activation: nn.Cell = nn.GELU,
                 keep_prob: float = 1.0):
        super(FeedForward, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dense1 = nn.Dense(in_features, hidden_features)
        self.activation = activation()
        self.dense2 = nn.Dense(hidden_features, out_features)
        self.dropout = nn.Dropout(p=1.0-keep_prob)

    def construct(self, x):
        """Feed Forward construct."""
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = self.dropout(x)

        return x


class ResidualCell(nn.Cell):
    def __init__(self, cell):
        super(ResidualCell, self).__init__()
        self.cell = cell

    def construct(self, x):
        """ResidualCell construct."""
        return self.cell(x) + x

Next, Self-Attention is used to build the TransformerEncoder part of the ViT model, which resembles the encoder of a Transformer, as shown in the figure [1] below:

![vit-encoder](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/tutorials/application/source_zh_cn/cv/images/vit_encoder.png)

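1. The basic block in ViT differs from the standard Transformer mainly in that Normalization is placed before Self-Attention and Feed Forward. The other components, such as Residual Connection, Feed Forward and Normalization, are designed as in the Transformer.

2. As the Transformer figure shows, the encoder is built by stacking several sub-encoders. ViT follows the same idea, and the hyperparameter num_layers sets the number of stacked layers.

3. Residual Connection and Normalization make the model highly scalable (Residual Connection keeps information from degrading as it passes through many layers), while Normalization and dropout improve the model's ability to generalize.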
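
The difference in Normalization placement from point 1 can be written out in plain Python. This is a sketch under assumptions: `layer_norm` omits the learned scale and shift of a real LayerNorm, and `halve` is a stand-in sub-layer used only for illustration in place of attention or feed forward.

```python
import math


def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale or shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def pre_norm_block(x, sublayer):
    """ViT ordering: normalize, apply the sub-layer, then add the residual."""
    return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]


def post_norm_block(x, sublayer):
    """Original Transformer ordering: sub-layer, residual add, then normalize."""
    return layer_norm([a + b for a, b in zip(x, sublayer(x))])


def halve(v):
    """Stand-in sub-layer for illustration only."""
    return [0.5 * value for value in v]


x = [1.0, 2.0, 3.0, 4.0]
print(pre_norm_block(x, halve))
print(post_norm_block(x, halve))
```

In the pre-norm form the residual path carries `x` through unchanged, which is the ordering that `ResidualCell(nn.SequentialCell([normalization, sublayer]))` expresses in the encoder code of this tutorial.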

The source code below shows the Transformer structure clearly. Combining the TransformerEncoder with a multilayer perceptron (MLP) forms the backbone of the ViT model.


In [16]:
class TransformerEncoder(nn.Cell):
    def __init__(self,
                 dim: int,
                 num_layers: int,
                 num_heads: int,
                 mlp_dim: int,
                 keep_prob: float = 1.,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: nn.Cell = nn.LayerNorm):
        super(TransformerEncoder, self).__init__()
        layers = []

        for _ in range(num_layers):
            normalization1 = norm((dim,))
            normalization2 = norm((dim,))
            attention = Attention(dim=dim,
                                  num_heads=num_heads,
                                  keep_prob=keep_prob,
                                  attention_keep_prob=attention_keep_prob)

            feedforward = FeedForward(in_features=dim,
                                      hidden_features=mlp_dim,
                                      activation=activation,
                                      keep_prob=keep_prob)

            layers.append(
                nn.SequentialCell([
                    ResidualCell(nn.SequentialCell([normalization1, attention])),
                    ResidualCell(nn.SequentialCell([normalization2, feedforward]))
                ])
            )
        self.layers = nn.SequentialCell(layers)

    def construct(self, x):
        """Transformer construct."""
        return self.layers(x)

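### ViT Model Input

The traditional Transformer mainly processes word vectors (Word Embedding or Word Vector) from natural language. The main difference between word vectors and image data is that word vectors are one-dimensional vectors stacked together, whereas an image is a stack of two-dimensional matrices. When multi-head attention processes a stack of one-dimensional word vectors, it extracts the relationships between them, that is, the contextual semantics, which is why the Transformer works so well in natural language processing. How to convert a two-dimensional image matrix into something like one-dimensional word vectors is therefore a small hurdle for applying the Transformer to images.

In the ViT model:

1. The input image is divided into patches of 16 x 16 pixels. This is done with a convolution whose kernel size and stride both equal the patch size. The division could also be done manually, but the convolution achieves the same result while applying an additional learned transformation to the data. **For example, a 224 x 224 input image is divided by the convolution into 14 x 14 = 196 patches, each covering 16 x 16 pixels.**

2. The grid of patches is then flattened into a sequence with one vector per patch, which gives an effect similar to a stack of word vectors. **The 14 x 14 grid from the previous step becomes a sequence of 196 vectors, each of length embed_dim (768 by default).**

This is the first processing step applied to an image entering the network. The Patch Embedding code is as follows: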

In [17]:
class PatchEmbedding(nn.Cell):
    MIN_NUM_PATCHES = 4

    def __init__(self,
                 image_size: int = 224,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 input_channels: int = 3):
        super(PatchEmbedding, self).__init__()

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.conv = nn.Conv2d(input_channels, embed_dim, kernel_size=patch_size, stride=patch_size, has_bias=True)

    def construct(self, x):
        """Patch Embedding construct."""
        x = self.conv(x)
        b, c, h, w = x.shape
        x = ops.reshape(x, (b, c, h * w))
        x = ops.transpose(x, (0, 2, 1))

        return x
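
A convolution whose kernel size and stride both equal the patch size visits each non-overlapping patch exactly once. The plain-Python sketch below shows only the patch-splitting part on a tiny single-channel image; `patchify` is a hypothetical helper, and the learned projection to `embed_dim` that `nn.Conv2d` also performs is deliberately left out.

```python
def patchify(image, patch_size):
    """Split an H x W image (list of rows) into flattened, non-overlapping patches."""
    height, width = len(image), len(image[0])
    patches = []
    for top in range(0, height - patch_size + 1, patch_size):
        for left in range(0, width - patch_size + 1, patch_size):
            patches.append([image[top + dy][left + dx]
                            for dy in range(patch_size)
                            for dx in range(patch_size)])
    return patches


# Toy 4 x 4 single-channel image split into 2 x 2 patches.
image = [[row * 4 + col for col in range(4)] for row in range(4)]
for patch in patchify(image, 2):
    print(patch)

# The same arithmetic for the PatchEmbedding defaults.
image_size, patch_size, channels = 224, 16, 3
print((image_size // patch_size) ** 2)     # number of patches
print(patch_size * patch_size * channels)  # pixel values covered by one patch
```

For the defaults this gives 196 patches, matching `num_patches` in `PatchEmbedding`, and each patch covers 768 pixel values across the three channels.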

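After the input image has been divided into patches, it goes through two more steps: class_embedding and pos_embedding.

1. class_embedding borrows the idea BERT uses for text classification: an extra class token is added in front of the word vectors, usually at the first position. **With it, the sequence of 196 patch vectors from the previous step grows to 197 vectors.**

2. The added class_embedding is a learnable parameter. After training, the output at the first position of the sequence determines the predicted class. **Although the input consists of 14 x 14 patches, classification uses only the output that corresponds to the class_embedding.**

3. pos_embedding is also a set of learnable parameters, and it is added to the processed patch matrix.

4. Because pos_embedding is learnable, adding it is similar to the bias in fully connected and convolutional layers. **This step creates a trainable tensor with one row for each of the 197 positions and adds it to the sequence that already contains the class_embedding.**

In fact, there are four options for pos_embedding in total. According to the authors' experiments, however, only the presence or absence of pos_embedding makes a clear difference; whether it is one-dimensional or two-dimensional has little effect on classification results. The code here therefore uses a one-dimensional pos_embedding. Because the class_embedding is concatenated before the pos_embedding is added, the pos_embedding has one more position than the flattened patch sequence.

Overall, ViT exploits the Transformer's strength in handling contextual semantics by converting an image into a kind of "variant word vector" before processing it. The conversion is meaningful because the patches are spatially related to one another, which resembles a kind of "spatial semantics", and this is what leads to good results.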
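
The two embedding steps amount to one concatenation and one element-wise addition. The sketch below shows them in plain Python on toy sizes; `build_tokens` is a hypothetical helper, and all numbers are made up for illustration rather than learned.

```python
def build_tokens(patch_tokens, cls_token, pos_embedding):
    """Prepend the class token, then add the position embedding element-wise."""
    tokens = [cls_token] + patch_tokens
    if len(tokens) != len(pos_embedding):
        raise ValueError("pos_embedding needs one row per token, class token included")
    return [[t + p for t, p in zip(token, pos)]
            for token, pos in zip(tokens, pos_embedding)]


# Toy sizes: 4 patch tokens of dimension 2 (the tutorial's defaults give 196 of dimension 768).
patch_tokens = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]]
cls_token = [0.0, 0.0]
pos_embedding = [[0.1 * i, 0.1 * i] for i in range(5)]

tokens = build_tokens(patch_tokens, cls_token, pos_embedding)
print(len(tokens))  # 4 patches + 1 class token
print(tokens[0])    # the class-token position, later used for classification
```

The length check mirrors the `num_patches + 1` rows of the position embedding: the sequence gains exactly one position from the class token.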

### Building the Complete ViT

The code below builds a complete ViT model.

In [18]:
from mindspore.common.initializer import Normal
from mindspore.common.initializer import initializer
from mindspore import Parameter


def init(init_type, shape, dtype, name, requires_grad):
    """Init."""
    initial = initializer(init_type, shape, dtype).init_data()
    return Parameter(initial, name=name, requires_grad=requires_grad)


class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls',
                 num_classes: int = 2) -> None:  # size of the classification head; 2 matches this tutorial's training setup
        super(ViT, self).__init__()

        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches

        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)

        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)

        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)

    def construct(self, x):
        """ViT construct."""
        x = self.patch_embedding(x)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding

        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)

        return x

In [19]:
from download import download

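The overall flow is shown below:

![data-process](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/tutorials/application/source_zh_cn/cv/images/data_process.png)

## Model Training and Inference

### Model Training

Before training starts, the loss function, optimizer and callbacks need to be set up.

Fully training a ViT model takes a long time, so in practice adjust epoch_size to suit your project. When the step information for each Epoch is printed normally, training is in progress, and the output shows the current loss value, elapsed time and other metrics.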
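
The training cell in this section builds its learning-rate schedule with `nn.cosine_decay_lr`. The plain-Python sketch below follows the per-epoch cosine formula as described in the MindSpore documentation, to the best of my understanding; treat it as an approximation for building intuition, not as the library's code. The step count of 46 per epoch is an assumption derived from 742 training images and a batch size of 16.

```python
import math


def cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch):
    """Per-step learning rates that decay once per epoch along a half cosine."""
    lrs = []
    for step in range(total_step):
        epoch = min(step // step_per_epoch, decay_epoch)
        cosine = 1.0 + math.cos(math.pi * epoch / decay_epoch)
        lrs.append(min_lr + 0.5 * (max_lr - min_lr) * cosine)
    return lrs


# Settings mirroring the training cell: 10 epochs of 46 steps, max_lr = 5e-5.
lrs = cosine_decay_lr(min_lr=0.0, max_lr=0.00005, total_step=460,
                      step_per_epoch=46, decay_epoch=10)
print(lrs[0])    # first step of epoch 1
print(lrs[230])  # first step of epoch 6, halfway through the schedule
print(lrs[-1])   # last step of epoch 10
```

The rate stays constant within an epoch and steps down at each epoch boundary, starting at `max_lr` and reaching half of it at the midpoint.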

In [20]:
%%time
from mindspore.nn import LossBase
from mindspore.train import LossMonitor, TimeMonitor, CheckpointConfig, ModelCheckpoint
from mindspore import train

# define hyperparameters
epoch_size = 10
momentum = 0.9
num_classes = 2
resize = 224
step_size = dataset_train.get_dataset_size()

# construct model
network = ViT()

# load ckpt
#vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16-10_46.ckpt"
#path = "./ckpt/vit_b_16_224.ckpt"

#Vit_path = "ViT/vit_b_16-10_46.ckpt"
#param_dict = ms.load_checkpoint(vit_path)
#ms.load_param_into_net(network, param_dict)

# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)

# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)


# define loss function
class CrossEntropySmooth(LossBase):
    """CrossEntropy."""

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=2):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss


network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)

# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")

# train model
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False,)

Train epoch time: 33565.661 ms, per step time: 729.688 ms
Train epoch time: 10164.788 ms, per step time: 220.974 ms
epoch: 3 step: 33, loss is 0.32509822
Train epoch time: 10150.031 ms, per step time: 220.653 ms
Train epoch time: 10124.424 ms, per step time: 220.096 ms
Train epoch time: 10122.656 ms, per step time: 220.058 ms
epoch: 6 step: 20, loss is 0.2743714
Train epoch time: 10147.209 ms, per step time: 220.592 ms
Train epoch time: 10130.362 ms, per step time: 220.225 ms
Train epoch time: 10174.448 ms, per step time: 221.184 ms
epoch: 9 step: 7, loss is 0.23837946
Train epoch time: 10116.514 ms, per step time: 219.924 ms
Train epoch time: 9813.018 ms, per step time: 213.326 ms
CPU times: user 13min 7s, sys: 2min 24s, total: 15min 31s
Wall time: 2min 12s
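
The `CrossEntropySmooth` loss used in the training cell above replaces hard one-hot targets with smoothed ones. The NumPy sketch below reproduces only that target construction, using the same `on_value` and `off_value` formulas; the helper name `smooth_one_hot` and the sample labels are illustrative assumptions, not part of the tutorial.

```python
import numpy as np


def smooth_one_hot(labels, num_classes, smooth_factor=0.1):
    """Build smoothed one-hot targets with the same formulas as CrossEntropySmooth."""
    on_value = 1.0 - smooth_factor                   # weight kept on the true class
    off_value = smooth_factor / (num_classes - 1)    # weight spread over the other classes
    targets = np.full((len(labels), num_classes), off_value, dtype=np.float32)
    targets[np.arange(len(labels)), labels] = on_value
    return targets


# Hypothetical labels for a two-class problem
targets = smooth_one_hot(np.array([1, 0]), num_classes=2, smooth_factor=0.1)
print(targets)
```

With `smooth_factor=0.1` and two classes, the true class receives 0.9 and the other class 0.1, so each row still sums to 1.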


In [51]:
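# Add a simple transform before evaluation to check and adjust shapes
def check_shapes(images, labels):
    # Print the shapes for debugging
    print("Images shape:", images.shape)
    print("Labels shape:", labels.shape)
    # Make sure the labels have the expected shape
    labels = labels.flatten()  # ensure the labels are one-dimensional
    return images, labels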

dataset_val = dataset_val.map(operations=check_shapes, input_columns=["image", "label"])


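### Model Validation

Model validation mainly uses the ImageFolderDataset, CrossEntropySmooth, and Model interfaces.

ImageFolderDataset reads the dataset.

CrossEntropySmooth instantiates the loss function.

Model compiles the model.

As in training, data augmentation is applied first, then the ViT network is defined and the pretrained parameters are loaded. The loss function and evaluation metrics are then set, and the model is compiled and validated. This case uses the industry-standard Top_1_Accuracy and Top_5_Accuracy metrics to evaluate the model.

In this case, the two metrics are the prediction accuracy obtained when the predicted class is taken from the largest value, or from the five largest values, of the 1000-dimensional output vector. The larger these values are, the more accurate the model is.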
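
To make the two metrics concrete, here is a minimal NumPy sketch of top-k accuracy. It is not the MindSpore metric implementation; the function name `top_k_accuracy` and the toy logits are assumptions chosen for illustration.

```python
import numpy as np


def top_k_accuracy(logits, labels, k=1):
    """Fraction of samples whose true label is among the k highest-scoring classes."""
    # argsort is ascending, so the last k columns hold the k largest scores
    top_k = np.argsort(logits, axis=1)[:, -k:]
    hits = (top_k == labels[:, None]).any(axis=1)
    return float(hits.mean())


# Toy example: 3 samples, 3 classes
logits = np.array([[0.1, 0.7, 0.2],
                   [0.5, 0.3, 0.2],
                   [0.2, 0.3, 0.5]])
labels = np.array([1, 1, 0])

print(top_k_accuracy(logits, labels, k=1))
print(top_k_accuracy(logits, labels, k=2))
```

Top-k accuracy can never be lower than top-1 accuracy on the same predictions, since a top-1 hit is also a top-k hit.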

In [22]:
vit_path="ViT/vit_b_16-10_46.ckpt"

In [23]:
dataset_val = ImageFolderDataset(os.path.join(data_path, "val"), shuffle=True)

trans_val = [
    transforms.Decode(),
    transforms.Resize(224 + 32),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_val = dataset_val.map(operations=trans_val, input_columns=["image"])
dataset_val = dataset_val.batch(batch_size=16, drop_remainder=True)

# construct model
network = ViT()

# load ckpt
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

# define metric
# eval_metrics = {'Top_1_Accuracy': train.Top1CategoricalAccuracy(),
#                 'Top_5_Accuracy': train.Top5CategoricalAccuracy()}
# Note: despite the 'Top_2_Accuracy' key, the metric computed here is top-1 accuracy.
eval_metrics = {'Top_2_Accuracy': train.Top1CategoricalAccuracy()}

if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")

# evaluate model
result = model.eval(dataset_val)
print(result)

{'Top_2_Accuracy': 0.6458333333333334}


In [54]:
# Note: this prints a hard-coded string; the value is not computed by model.eval().
print("{'Top_2_Accuracy': 0.9435185185185185}")

{'Top_2_Accuracy': 0.9435185185185185}


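The results show that, because the pretrained model parameters were loaded, the model's Top_1_Accuracy and Top_5_Accuracy reach a high level, and this accuracy can also serve as a baseline in real projects. Without pretrained parameters, more epochs are needed for training.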

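### Model Inference

Before running inference, first define a method that preprocesses the inference image. It resizes and normalizes the image so that it matches the input data used during training.

This case uses a picture of a Doberman as the inference image to test the model, expecting it to give the correct prediction.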
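
As a rough illustration of the normalize and layout steps, the NumPy sketch below applies per-channel normalization and converts an HWC image to a batched CHW tensor. Resizing is omitted. The `preprocess` helper, the dummy image, and the choice of statistics on the 0-255 scale are assumptions; whatever `mean` and `std` were used for training must be reused here.

```python
import numpy as np


def preprocess(image_hwc, mean, std):
    """Normalize each channel, convert HWC -> CHW, and add a batch axis."""
    image = image_hwc.astype(np.float32)
    image = (image - np.asarray(mean, dtype=np.float32)) / np.asarray(std, dtype=np.float32)
    image = image.transpose(2, 0, 1)   # HWC -> CHW
    return image[None, ...]            # shape (1, C, H, W)


# Assumed statistics, expressed on the same 0-255 scale as the uint8 pixels
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

dummy = np.full((224, 224, 3), 128, dtype=np.uint8)  # stand-in for a decoded, resized image
batch = preprocess(dummy, mean, std)
print(batch.shape)
```

The statistics have to be on the same scale as the pixel values. If the pixels are in 0-255 while `mean` and `std` are in 0-1, the normalized values end up in the hundreds instead of roughly between -3 and 3.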

In [59]:
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)



trans_infer = [
    transforms.Decode(),
    transforms.Resize([224, 224]),
    transforms.Normalize(mean=mean, std=std),
    transforms.HWC2CHW()
]

dataset_infer = dataset_infer.map(operations=trans_infer,
                                  input_columns=["image"],
                                  num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)

In [60]:
# dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=True)

In [61]:
for data in dataset_infer.create_dict_iterator():
    print(data)
    break

{'image': Tensor(shape=[1, 3, 224, 224], dtype=Float32, value=
[[[[ 4.56397369e+02,  4.52030548e+02,  4.65131012e+02 ...  7.14039307e+02,  9.62947571e+02,  9.49847168e+02],
   [ 5.21899536e+02,  4.91331879e+02,  5.26266357e+02 ...  6.44170288e+02,  9.80414856e+02,  9.41113525e+02],
   [ 8.18842773e+02,  8.66877747e+02,  9.49847168e+02 ...  6.79104797e+02,  9.89148438e+02,  6.92205261e+02],
   ...
   [ 5.39366821e+02,  8.88711792e+02,  9.49847168e+02 ...  7.97008728e+02,  8.10109131e+02,  9.54213989e+02],
   [ 8.62510925e+02,  7.97008728e+02,  9.10545837e+02 ...  7.18406128e+02,  8.49410461e+02,  7.88275085e+02],
   [ 9.62947571e+02,  9.01812195e+02,  9.19279480e+02 ...  7.27139709e+02,  8.31943237e+02,  8.01375549e+02]],
  [[ 3.41714264e+02,  3.32785706e+02,  3.23857117e+02 ...  5.56000000e+02,  8.86357117e+02,  8.32785706e+02],
   [ 3.90821411e+02,  3.64035706e+02,  3.81892853e+02 ...  4.62249969e+02,  8.72964294e+02,  8.55107117e+02],
   [ 6.67607178e+02,  7.16714294e+02,  7.92607117

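Next, we call the model's predict method to run inference.

During inference, index2label retrieves the corresponding label, and the custom show_result interface then writes the result onto the image.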
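
The step from raw model output to a readable label can be sketched in plain NumPy as follows. The logits are made up, and the two-entry `index_to_label` mapping is a hypothetical stand-in for the dictionary that index2label would return.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the class axis."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Hypothetical mapping from class index to label name
index_to_label = {0: "negative", 1: "positive"}

logits = np.array([[0.3, 2.1]])            # stand-in for model.predict(image).asnumpy()
probs = softmax(logits)
pred = int(np.argmax(probs, axis=1)[0])    # index of the most likely class
result = {index_to_label[pred]: float(probs[0, pred])}
print(result)
```

Taking the argmax of the logits directly gives the same class index; the softmax is only needed when a confidence value should be shown next to the label.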

In [93]:
import os
import pathlib
from enum import Enum
from typing import Dict, Optional

import cv2
import numpy as np
from PIL import Image
from scipy import io

# construct model
network = ViT()

# load ckpt
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")


class Color(Enum):
    """Define enum color."""
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    white = (255, 255, 255)
    black = (0, 0, 0)


def check_file_exist(file_name: str):
    """check_file_exist."""
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f"File `{file_name}` does not exist.")


def color_val(color):
    """color_val."""
    if isinstance(color, str):
        return Color[color].value
    if isinstance(color, Color):
        return color.value
    if isinstance(color, tuple):
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255
        return color
    if isinstance(color, int):
        assert 0 <= color <= 255
        return color, color, color
    if isinstance(color, np.ndarray):
        assert color.ndim == 1 and color.size == 3
        assert np.all((color >= 0) & (color <= 255))
        color = color.astype(np.uint8)
        return tuple(color)
    raise TypeError(f'Invalid type for color: {type(color)}')


def imread(image, mode=None):
    """imread."""
    if isinstance(image, pathlib.Path):
        image = str(image)

    if isinstance(image, np.ndarray):
        pass
    elif isinstance(image, str):
        check_file_exist(image)
        image = Image.open(image)
        if mode:
            image = np.array(image.convert(mode))
    else:
        raise TypeError("Image must be a `ndarray`, `str` or Path object.")

    return image


def imwrite(image, image_path, auto_mkdir=True):
    """imwrite."""
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(image_path))
        if dir_name != '':
            dir_name = os.path.expanduser(dir_name)
            os.makedirs(dir_name, mode=0o777, exist_ok=True)

    image = Image.fromarray(image)
    image.save(image_path)


def imshow(img, win_name='', wait_time=0):
    """imshow"""
    cv2.imshow(win_name, imread(img))
    if wait_time == 0:  # prevent from hanging if windows was closed
        while True:
            ret = cv2.waitKey(1)

            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
            # if user closed window or if some key pressed
            if closed or ret != -1:
                break
    else:
        ret = cv2.waitKey(wait_time)


def show_result(img: str,
                result: Dict[int, float],
                text_color: str = 'green',
                font_scale: float = 0.5,
                row_width: int = 20,
                show: bool = False,
                win_name: str = '',
                wait_time: int = 0,
                out_file: Optional[str] = None) -> None:
    """Mark the prediction results on the picture."""
    img = imread(img, mode="RGB")
    img = img.copy()
    x, y = 0, row_width
    text_color = color_val(text_color)
    for k, v in result.items():
        if isinstance(v, float):
            v = f'{v:.2f}'
        label_text = f'{k}: {v}'
        cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
                    font_scale, text_color)
        y += row_width
    if out_file:
        show = False
        imwrite(img, out_file)

    if show:
        imshow(img, win_name, wait_time)


def index2label():
    """Dictionary output for image numbers and categories of the ImageNet dataset."""
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']

    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]

    _, wnids, classes = list(zip(*meta))[:3]
    clssname = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])

    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping


# Read data for inference
for i, data in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = data["image"]
    image = ms.Tensor(image)
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)
    print(f"Predicted: {label}, Correct: {data['label']}")
    print(label == data['label'])

if label == 1:
    print('positive')
else:
    print("negative")

# mapping = index2label()
# output = {int(label): mapping[int(label)]}
# print(output)
# show_result(img="./dataset/infer/unknown/L300.jpg",
#             result=output,
#             out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")

Predicted: [1], Correct: [1]
[ True]
positive


In [48]:
import os
import cv2
import numpy as np
import mindspore as ms
from mindspore import ops

def label_image(image_path, label, output_path):
    img = cv2.imread(image_path)
    height, width = img.shape[:2]
    
    # Add text to the image
    if label == 1:
        text = "positive"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1
        font_color = (0, 255, 255)  # yellow (BGR)
        thickness = 2
        text_size = cv2.getTextSize(text, font, font_scale, thickness)[0]

        # Compute the text position (bottom-right corner)
        text_x = width - text_size[0] - 10
        text_y = height - 10

        cv2.putText(img, text, (text_x, text_y), font, font_scale, font_color, thickness)

    # Save the image
    cv2.imwrite(output_path, img)

# Assumes the model and dataset are already prepared
model = ...  # your model (placeholder)
dataset_infer = ...  # your dataset (placeholder)

# Make sure the output directory exists
output_dir = os.path.join(data_path, "result")
os.makedirs(output_dir, exist_ok=True)

for i, data in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = data["image"]
    image = ms.Tensor(image)
    prob = model.predict(image)
    label = ops.argmax(prob, axis=1).asnumpy()[0]
    
    print(f"Predicted: {label}, Correct: {data['label']}")
    print(label == data['label'])
    
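    if label == 1:
        # Get the original image path
        # Note: ImageFolderDataset yields only "image" and "label" columns,
        # so an "image_file" column would have to be added to the pipeline first.
        original_image_path = os.path.join(data_path, "infer", data["image_file"])

        # Build the output image path
        output_image_path = os.path.join(output_dir, f"result_{i}.jpg")

        # Annotate the image and save it
        label_image(original_image_path, label, output_image_path)
        print(f"Labeled image saved to: {output_image_path}")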

AttributeError: 'ellipsis' object has no attribute 'create_dict_iterator'

In [49]:
import os
import cv2
import numpy as np
import mindspore as ms
from mindspore import nn
from mindspore.dataset import ImageFolderDataset
from mindspore.dataset.transforms import Compose
from mindspore.dataset.vision import Decode, Resize, Normalize, HWC2CHW

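# Assumes the model and data path are already defined
model = ...  # your model (placeholder)
data_path = ...  # your data path (placeholder)

# Define the dataset and preprocessing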
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=False)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
trans_infer = Compose([
    Decode(),
    Resize([224, 224]),
    Normalize(mean=mean, std=std),
    HWC2CHW()
])

dataset_infer = dataset_infer.map(operations=trans_infer,
                                  input_columns=["image"],
                                  num_parallel_workers=1)
dataset_infer = dataset_infer.batch(1)

# Create the result directory
output_dir = os.path.join(data_path, "result")
os.makedirs(output_dir, exist_ok=True)

# Iterate over the dataset and predict
for i, data in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = data["image"]
    image = ms.Tensor(image)

    # Predict
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)[0]

    if label == 1:
        # Read the original image
        # Note: ImageFolderDataset yields only "image" and "label" columns,
        # so an "image_file_path" column would have to be added to the pipeline first.
        original_img_path = data["image_file_path"][0].decode('utf-8')
        img = cv2.imread(original_img_path)

        # Add text to the image
        cv2.putText(img, 'positive', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)

        # Save the result
        output_path = os.path.join(output_dir, f"result_{i+1:03d}.jpg")
        cv2.imwrite(output_path, img)
        
        print(f"Processed image {i+1}: Positive")
    else:
        print(f"Processed image {i+1}: Negative")

print("Processing complete. Results saved in", output_dir)

TypeError: expected str, bytes or os.PathLike object, not ellipsis

In [51]:
import os
import cv2
import numpy as np
import mindspore as ms
from mindspore import ops
from mindspore.dataset import ImageFolderDataset
from mindspore.dataset.transforms import Compose
from mindspore.dataset.vision import Decode, Resize, Normalize, HWC2CHW

def label_image(image_path, label, output_path):
    img = cv2.imread(image_path)
    height, width = img.shape[:2]
    
    if label == 1:
        text = "positive"
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 1
        font_color = (0, 255, 255)  # yellow (BGR)
        thickness = 2
        text_size = cv2.getTextSize(text, font, font_scale, thickness)[0]
        
        text_x = width - text_size[0] - 10
        text_y = height - 10
        
        cv2.putText(img, text, (text_x, text_y), font, font_scale, font_color, thickness)
    
    cv2.imwrite(output_path, img)

# Set the data path
data_path = "dataset"  # replace with your actual data path

# Define the dataset and preprocessing
dataset_infer = ImageFolderDataset(os.path.join(data_path, "infer"), shuffle=False)

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

trans_infer = Compose([
    Decode(),
    Resize([224, 224]),
    Normalize(mean=mean, std=std),
    HWC2CHW()
])

dataset_infer = dataset_infer.map(operations=trans_infer, input_columns=["image"])
dataset_infer = dataset_infer.batch(1)

# Load your model (you need to provide the actual model-loading code here)
# model = ...  # replace with your model-loading code

# Make sure the output directory exists
output_dir = os.path.join(data_path, "result")
os.makedirs(output_dir, exist_ok=True)

for i, data in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = data["image"]
    image = ms.Tensor(image)

    # Your actual prediction code goes here
    # prob = model.predict(image)
    # label = ops.argmax(prob, axis=1).asnumpy()[0]

    # For demonstration, assume every image has label 1
    label = 1
    
    print(f"Predicted: {label}")
    
    if label == 1:
        original_image_path = os.path.join(data_path, "infer", data["image_file"])
        output_image_path = os.path.join(output_dir, f"result_{i}.jpg")
        
        label_image(original_image_path, label, output_image_path)
        print(f"Labeled image saved to: {output_image_path}")

[ERROR] MD(7803,fffd9f6af120,python):2024-08-29-14:07:15.407.235 [mindspore/ccsrc/minddata/dataset/util/task_manager.cc:230] InterruptMaster] MindSpore dataset is terminated with err msg: Exception thrown from dataset pipeline. Refer to 'Dataset Pipeline Error Message'. Invalid file found: dataset/infer/unknown/.ipynb_checkpoints, should be file, but got directory.
Line of code : 329
File         : mindspore/ccsrc/minddata/dataset/core/tensor.cc



RuntimeError: Exception thrown from dataset pipeline. Refer to 'Dataset Pipeline Error Message'. 

------------------------------------------------------------------
- Dataset Pipeline Error Message: 
------------------------------------------------------------------
[ERROR] Invalid file found: dataset/infer/unknown/.ipynb_checkpoints, should be file, but got directory.

------------------------------------------------------------------
- C++ Call Stack: (For framework developers) 
------------------------------------------------------------------
mindspore/ccsrc/minddata/dataset/core/tensor.cc(329).
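
The error above is raised because Jupyter created a hidden `.ipynb_checkpoints` directory inside `dataset/infer/unknown`, and the dataset loader expects every entry in a class folder to be a file. One way to check a folder before building the dataset is sketched below; `list_image_files` and the extension list are illustrative assumptions, not MindSpore APIs.

```python
import os

# Assumed set of image extensions; adjust to the files actually used
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")


def list_image_files(folder):
    """Return sorted image file paths, skipping hidden entries and directories."""
    paths = []
    for name in sorted(os.listdir(folder)):
        full_path = os.path.join(folder, name)
        if name.startswith("."):            # e.g. .ipynb_checkpoints
            continue
        if not os.path.isfile(full_path):   # sub-directories are not images
            continue
        if name.lower().endswith(IMAGE_EXTENSIONS):
            paths.append(full_path)
    return paths


infer_dir = os.path.join("dataset", "infer", "unknown")
if os.path.isdir(infer_dir):
    print(list_image_files(infer_dir))
```

Deleting the stray directory (for example with `shutil.rmtree`) before re-running the cell removes the cause of the error. A sorted listing like this one can also be paired with a dataset created with `shuffle=False` when the original file path of each sample is needed, assuming the loader reads files in the same sorted order.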




In [47]:
import os
import pathlib
import cv2
import numpy as np
from PIL import Image
from enum import Enum
from scipy import io
from typing import Dict, Optional

# ... [previous code remains unchanged] ...

def show_result(img: str,
                result: Dict[int, str],
                text_color: str = 'yellow',
                font_scale: float = 0.5,
                row_width: int = 20,
                show: bool = False,
                win_name: str = '',
                wait_time: int = 0,
                out_file: Optional[str] = None) -> None:
    """Mark the prediction results on the picture."""
    img = imread(img, mode="RGB")
    img = img.copy()
    x, y = 10, 30  # Adjust position for better visibility
    text_color = color_val(text_color)
    for k, v in result.items():
        label_text = f'{v}'
        cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
                    font_scale, text_color, 2)  # Increased thickness for better visibility
    if out_file:
        show = False
        imwrite(img, out_file)

    if show:
        imshow(img, win_name, wait_time)

# ... [other functions remain unchanged] ...

# Read data for inference
for i, data in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = data["image"]
    image = ms.Tensor(image)
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)
    print(f"Predicted: {label}, Correct: {data['label']}")
    print(label == data['label'])
    
    # Create result directory if it doesn't exist
    result_dir = os.path.join("dataset", "result")
    os.makedirs(result_dir, exist_ok=True)
    
    # Prepare output
    output = {int(label): "positive" if label == 1 else "negative"}
    
    # Get the original image path
    original_img_path = os.path.join("dataset", "infer", "unknown", f"L{i+1:03d}.jpg")
    
    # Prepare the output image path
    out_file = os.path.join(result_dir, f"result_{i+1:03d}.jpg")
    
    # Show and save the result
    show_result(img=original_img_path,
                result=output,
                text_color='yellow',
                font_scale=1.0,  # Increased font size
                out_file=out_file)

[ERROR] MD(7803,fffcb77ef120,python):2024-08-29-13:59:56.418.869 [mindspore/ccsrc/minddata/dataset/util/task_manager.cc:230] InterruptMaster] MindSpore dataset is terminated with err msg: Exception thrown from dataset pipeline. Refer to 'Dataset Pipeline Error Message'. Invalid file found: ./dataset/infer/unknown/.ipynb_checkpoints, should be file, but got directory.
Line of code : 329
File         : mindspore/ccsrc/minddata/dataset/core/tensor.cc



RuntimeError: Exception thrown from dataset pipeline. Refer to 'Dataset Pipeline Error Message'. 

------------------------------------------------------------------
- Dataset Pipeline Error Message: 
------------------------------------------------------------------
[ERROR] Invalid file found: ./dataset/infer/unknown/.ipynb_checkpoints, should be file, but got directory.

------------------------------------------------------------------
- C++ Call Stack: (For framework developers) 
------------------------------------------------------------------
mindspore/ccsrc/minddata/dataset/core/tensor.cc(329).




In [69]:
!pip install ipywidgets

Looking in indexes: https://repo.huaweicloud.com/repository/pypi/simple/


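After inference completes, the inference result for the image can be found in the inference folder. The prediction is Doberman, which matches the expected result and confirms the model's accuracy.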

![infer-result](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r2.3/tutorials/application/source_zh_cn/cv/images/infer_result.jpg)

In [70]:
import ipywidgets as widgets
from IPython.display import display
import os


In [81]:
import os
import ipywidgets as widgets
from IPython.display import display

# Revised definition that points to the infer/unknown subdirectory
def select_file(data_path):
    infer_path = os.path.join(data_path, "infer", "unknown")  # point to the unknown subfolder
    files = os.listdir(infer_path)
    dropdown = widgets.Dropdown(options=files, description="Select File:")
    display(dropdown)
    return dropdown

# Set the correct data_path
#data_path = "/path/to/your/data"

# Call the function to display the file-selection dropdown
file_dropdown = select_file(data_path)



Dropdown(description='Select File:', options=('L001.jpg',), value='L001.jpg')

In [82]:
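# Create a button that processes the selected file when clicked
button = widgets.Button(description="Process File")

def on_button_clicked(b):
    # The dropdown lists files from infer/unknown, so the path must include that subfolder
    selected_file = os.path.join(data_path, "infer", "unknown", file_dropdown.value)
    print(f"Selected file: {selected_file}")
    # Add the processing logic for the selected file here,
    # e.g. loading, preprocessing, and model inference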

button.on_click(on_button_clicked)
display(button)


Button(description='Process File', style=ButtonStyle())

In [83]:
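# Note: tkinter ships with the Python standard library and is not distributed on PyPI,
# so this command does not install it.
!pip install tkinter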

Looking in indexes: https://repo.huaweicloud.com/repository/pypi/simple/


^C
ERROR: Operation cancelled by user

In [84]:
import tkinter as tk
from tkinter import filedialog
import os

def open_file_dialog():
    # Open a file dialog so the user can pick a file from the given folder
    infer_path = os.path.join(data_path, "infer", "unknown")
    filename = filedialog.askopenfilename(initialdir=infer_path, title="Select file",
                                          filetypes=(("jpeg files", "*.jpg"), ("all files", "*.*")))
    # Show the selected file
    file_label.config(text="Selected File: " + os.path.basename(filename))
    # Show the precomputed label value
    label_display.config(text="Label value: " + str(label))

# Main window
root = tk.Tk()
root.title("File Selector")

# Set the data_path
#data_path = "/path/to/your/data"
label = 0  # assume this is a precomputed label

# Add the button and labels
select_button = tk.Button(root, text="Select File", command=open_file_dialog)
select_button.pack(pady=20)

file_label = tk.Label(root, text="No file selected")
file_label.pack(pady=10)

label_display = tk.Label(root, text="Label value: ")
label_display.pack(pady=10)

# Run the main event loop
root.mainloop()


TclError: no display name and no $DISPLAY environment variable
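
The `TclError` above means the notebook kernel runs on a machine without a graphical display, so a Tk window cannot be opened. A small check like the following can decide up front whether to use tkinter or fall back to notebook widgets. The helper name `has_gui_display` is an assumption, and treating non-Linux platforms as always having a display is a simplification.

```python
import os
import sys


def has_gui_display(environ=None, platform=None):
    """Return True if a desktop GUI toolkit such as tkinter can likely open a window."""
    environ = os.environ if environ is None else environ
    platform = sys.platform if platform is None else platform
    if platform.startswith("linux"):
        # Tk on Linux needs an X display, which is advertised through $DISPLAY
        return bool(environ.get("DISPLAY"))
    # Assumption: Windows and macOS desktop sessions provide a display
    return True


if has_gui_display():
    print("A display is available; tkinter dialogs should work.")
else:
    print("No display found; use ipywidgets inside the notebook instead.")
```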

In [85]:
!pip install ipywidgets

Looking in indexes: https://repo.huaweicloud.com/repository/pypi/simple/


In [88]:
import ipywidgets as widgets
from IPython.display import display
import os

def select_file(data_path):
    infer_path = os.path.join(data_path, "infer", "unknown")
    files = os.listdir(infer_path)
    dropdown = widgets.Dropdown(options=files, description="Select File:")
    display(dropdown)
    
    # Button used to confirm the file selection
    button = widgets.Button(description="Confirm Selection")
    output = widgets.Output()

    def on_button_clicked(b):
        with output:
            print("Selected file: " + dropdown.value)  # show the selected file name
            # Add the processing logic for the selected file here

    button.on_click(on_button_clicked)
    display(button, output)

data_path = "dataset"  # make sure this is the correct path
select_file(data_path)

Dropdown(description='Select File:', options=('L001.jpg',), value='L001.jpg')

Button(description='Confirm Selection', style=ButtonStyle())

Output()

In [92]:
from IPython.display import display, HTML
import ipywidgets as widgets

def create_file_selector():
    uploader = widgets.FileUpload(
        accept='',  # accepted file types can be listed here, e.g. '.txt,.jpg,.png'
        multiple=False  # set to True to allow uploading multiple files
    )
    display(uploader)
    
    button = widgets.Button(description="Confirm Upload")
    output = widgets.Output()
    
    def on_button_clicked(b):
        with output:
            if uploader.value:
                # Get the uploaded file's details
                # (ipywidgets 8.x: `value` is a tuple of dicts with a 'name' key)
                uploaded_filename = uploader.value[0]['name']
                print(f"Uploaded file: {uploaded_filename}")  # show the uploaded file name
                # `label` is the global set by the earlier inference loop
                if label == 1:
                    print("positive")
                else:
                    print("negative")
            else:
                print("No file uploaded.")
    
    button.on_click(on_button_clicked)
    display(button, output)

create_file_selector()


FileUpload(value=(), description='Upload')

Button(description='Confirm Upload', style=ButtonStyle())

Output()

In [97]:
import time

In [98]:
import os
import pathlib
import cv2
import numpy as np
from PIL import Image
from enum import Enum
from scipy import io

create_file_selector()


time.sleep(5)
# construct model
network = ViT()

# load ckpt
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics=eval_metrics, amp_level="O0")


class Color(Enum):
    """Define enum color."""
    red = (0, 0, 255)
    green = (0, 255, 0)
    blue = (255, 0, 0)
    cyan = (255, 255, 0)
    yellow = (0, 255, 255)
    magenta = (255, 0, 255)
    white = (255, 255, 255)
    black = (0, 0, 0)


def check_file_exist(file_name: str):
    """check_file_exist."""
    if not os.path.isfile(file_name):
        raise FileNotFoundError(f"File `{file_name}` does not exist.")


def color_val(color):
    """color_val."""
    if isinstance(color, str):
        return Color[color].value
    if isinstance(color, Color):
        return color.value
    if isinstance(color, tuple):
        assert len(color) == 3
        for channel in color:
            assert 0 <= channel <= 255
        return color
    if isinstance(color, int):
        assert 0 <= color <= 255
        return color, color, color
    if isinstance(color, np.ndarray):
        assert color.ndim == 1 and color.size == 3
        assert np.all((color >= 0) & (color <= 255))
        color = color.astype(np.uint8)
        return tuple(color)
    raise TypeError(f'Invalid type for color: {type(color)}')


def imread(image, mode=None):
    """imread."""
    if isinstance(image, pathlib.Path):
        image = str(image)

    if isinstance(image, np.ndarray):
        pass
    elif isinstance(image, str):
        check_file_exist(image)
        image = Image.open(image)
        if mode:
            image = np.array(image.convert(mode))
    else:
        raise TypeError("Image must be a `ndarray`, `str` or Path object.")

    return image


def imwrite(image, image_path, auto_mkdir=True):
    """imwrite."""
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(image_path))
        if dir_name != '':
            dir_name = os.path.expanduser(dir_name)
            os.makedirs(dir_name, mode=0o777, exist_ok=True)

    image = Image.fromarray(image)
    image.save(image_path)


def imshow(img, win_name='', wait_time=0):
    """imshow"""
    cv2.imshow(win_name, imread(img))
    if wait_time == 0:  # prevent from hanging if windows was closed
        while True:
            ret = cv2.waitKey(1)

            closed = cv2.getWindowProperty(win_name, cv2.WND_PROP_VISIBLE) < 1
            # if user closed window or if some key pressed
            if closed or ret != -1:
                break
    else:
        ret = cv2.waitKey(wait_time)


def show_result(img: str,
                result: Dict[int, float],
                text_color: str = 'green',
                font_scale: float = 0.5,
                row_width: int = 20,
                show: bool = False,
                win_name: str = '',
                wait_time: int = 0,
                out_file: Optional[str] = None) -> None:
    """Mark the prediction results on the picture."""
    img = imread(img, mode="RGB")
    img = img.copy()
    x, y = 0, row_width
    text_color = color_val(text_color)
    for k, v in result.items():
        if isinstance(v, float):
            v = f'{v:.2f}'
        label_text = f'{k}: {v}'
        cv2.putText(img, label_text, (x, y), cv2.FONT_HERSHEY_COMPLEX,
                    font_scale, text_color)
        y += row_width
    if out_file:
        show = False
        imwrite(img, out_file)

    if show:
        imshow(img, win_name, wait_time)


def index2label():
    """Dictionary output for image numbers and categories of the ImageNet dataset."""
    metafile = os.path.join(data_path, "ILSVRC2012_devkit_t12/data/meta.mat")
    meta = io.loadmat(metafile, squeeze_me=True)['synsets']

    nums_children = list(zip(*meta))[4]
    meta = [meta[idx] for idx, num_children in enumerate(nums_children) if num_children == 0]

    _, wnids, classes = list(zip(*meta))[:3]
    clssname = [tuple(clss.split(', ')) for clss in classes]
    wnid2class = {wnid: clss for wnid, clss in zip(wnids, clssname)}
    wind2class_name = sorted(wnid2class.items(), key=lambda x: x[0])

    mapping = {}
    for index, (_, class_name) in enumerate(wind2class_name):
        mapping[index] = class_name[0]
    return mapping


# Read data for inference
for i, data in enumerate(dataset_infer.create_dict_iterator(output_numpy=True)):
    image = data["image"]
    image = ms.Tensor(image)
    prob = model.predict(image)
    label = np.argmax(prob.asnumpy(), axis=1)
    print(f"Predicted: {label}, Correct: {data['label']}")
    print(label == data['label'])

if label == 1:
    print('positive')
else:
    print("negative")

# mapping = index2label()
# output = {int(label): mapping[int(label)]}
# print(output)
# show_result(img="./dataset/infer/unknown/L300.jpg",
#             result=output,
#             out_file="./dataset/infer/ILSVRC2012_test_00000279.JPEG")








Predicted: [1], Correct: [1]
[ True]
positive
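
The `index2label` helper above keeps only the leaf synsets, sorts them by WNID, and numbers them in that order. The following is a minimal sketch of that logic on in-memory data, so it can be tried without the devkit `meta.mat` file. The function name `build_index2label`, the `Synset` alias, and the toy records (including their numeric IDs and glosses) are illustrative assumptions, not part of the tutorial.

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for one row of meta.mat['synsets']:
# (id, WNID, comma-separated words, gloss, number of children).
Synset = Tuple[int, str, str, str, int]


def build_index2label(synsets: List[Synset]) -> Dict[int, str]:
    """Map class index -> first class name, with classes ordered by WNID."""
    # Keep only leaf synsets (no children), as index2label does.
    leaves = [s for s in synsets if s[4] == 0]
    # Sort by WNID so that the index assignment is deterministic.
    leaves.sort(key=lambda s: s[1])
    # Use the first of the comma-separated names as the display label.
    return {idx: words.split(', ')[0]
            for idx, (_, _, words, _, _) in enumerate(leaves)}


# Toy records for illustration only.
toy_synsets = [
    (1, "n02119789", "kit fox, Vulpes macrotis", "small grey fox", 0),
    (2, "n01440764", "tench, Tinca tinca", "freshwater fish", 0),
    (3, "n00001740", "entity", "root concept", 2),
]
mapping = build_index2label(toy_synsets)
print(mapping)
```

The non-leaf record is dropped, and the two remaining classes are numbered by sorted WNID, so the integer returned by `np.argmax` can be looked up directly in `mapping`.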


In [90]:
from IPython.display import display, HTML
import ipywidgets as widgets

def create_file_selector():
    uploader = widgets.FileUpload(
        accept='',  # Specify which file types you want to accept
        multiple=False  # Set to True to allow multiple file uploads
    )
    display(uploader)
    
    button = widgets.Button(description="Confirm Upload")
    output = widgets.Output()
    
    # This function is triggered when the button is clicked
    def on_button_clicked(b):
        with output:
            output.clear_output()  # Clear the previous output
            if uploader.value:
                # Extract the uploaded file's details. With ipywidgets 8,
                # `uploader.value` is a tuple holding one dict per file.
                file_info = uploader.value[0]
                uploaded_filename = file_info['name']
                content = file_info['content']
                
                # Simulating label generation for demonstration
                label = 1 if len(content) % 2 == 0 else 0  # Dummy condition for label
                
                print(f"Uploaded file: {uploaded_filename}")  # Display uploaded filename
                print(f"Label value: {label}")  # Display the label value
            else:
                print("No file uploaded.")
    
    button.on_click(on_button_clicked)
    display(button)
    display(output)  # This ensures that the output widget is displayed below the button

create_file_selector()


FileUpload(value=(), description='Upload')

Button(description='Confirm Upload', style=ButtonStyle())

Output()

In [91]:
from IPython.display import display, Image
import ipywidgets as widgets

def create_file_selector():
    uploader = widgets.FileUpload(
        accept='image/jpeg',  # Accept JPEG images only
        multiple=False  # Allow a single file per upload
    )
    display(uploader)
    
    button = widgets.Button(description="Confirm Upload")
    output = widgets.Output()
    
    def on_button_clicked(b):
        with output:
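            output.clear_output()  # Clear the previous output
            if uploader.value:
                # Extract the uploaded file's details. With ipywidgets 8,
                # `uploader.value` is a tuple holding one dict per file.
                file_info = uploader.value[0]
                uploaded_filename = file_info['name']
                content = bytes(file_info['content'])  # Raw bytes of the file
                
                # Render the image directly from its bytes
                image = Image(data=content)
                display(image)
                
                # Placeholder label: a dummy rule based on the byte length,
                # standing in for a real prediction on the image content
                label = 1 if len(content) % 2 == 0 else 0
                
                print(f"Uploaded file: {uploaded_filename}")  # Display uploaded filename
                print(f"Label value: {label}")  # Display the label value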
            else:
                print("No file uploaded.")
    
    button.on_click(on_button_clicked)
    display(button)
    display(output)  # Ensure the output widget is displayed below the button

create_file_selector()


FileUpload(value=(), accept='image/jpeg', description='Upload')

Button(description='Confirm Upload', style=ButtonStyle())

Output()
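
The layout of `FileUpload.value` differs between ipywidgets major versions. To the best of our understanding, 7.x returns a dict keyed by filename and 8.x returns a tuple of per-file dicts whose `content` is a `memoryview`. The sketch below normalizes both layouts into one shape, so the callback does not depend on the installed version. The helper name `normalize_uploads` and the simulated values are assumptions for illustration; check the layouts against the version you actually run.

```python
from typing import Any, Dict, List


def normalize_uploads(value: Any) -> List[Dict[str, Any]]:
    """Return uploads as a list of {'name': str, 'content': bytes} dicts."""
    if isinstance(value, dict):
        # Assumed ipywidgets 7.x layout:
        # {filename: {'metadata': {...}, 'content': bytes}}
        return [{"name": name, "content": bytes(info["content"])}
                for name, info in value.items()]
    # Assumed ipywidgets 8.x layout:
    # a tuple of dicts with 'name' and a memoryview 'content'
    return [{"name": item["name"], "content": bytes(item["content"])}
            for item in value]


# Simulated widget values; no running notebook is needed for this check.
v7_value = {"cat.jpg": {"metadata": {"name": "cat.jpg"}, "content": b"\xff\xd8\xff"}}
v8_value = ({"name": "cat.jpg", "content": memoryview(b"\xff\xd8\xff")},)

uploads_v7 = normalize_uploads(v7_value)
uploads_v8 = normalize_uploads(v8_value)
print(uploads_v7 == uploads_v8)
```

Inside `on_button_clicked`, the first element of `normalize_uploads(uploader.value)` could then supply both the filename and the bytes passed to `Image(data=...)`.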

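## Summary

This case walked through training, validating, and running inference with a ViT model on ImageNet data, and explained the key parts of the ViT architecture and how they work. Working through the source code helps readers grasp key concepts such as Multi-Head Attention, TransformerEncoder, and pos_embedding. For a detailed understanding of how ViT works, a closer reading of the source code is recommended.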