## 数据预处理

In [11]:
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms


# 数据增强和预处理
data_augmentation_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
])

# 数据集根目录
data_root = './data/dataset'

dataset = ImageFolder(root=data_root, transform=data_augmentation_transform)
print(len(dataset))
print(dataset.classes)
print(dataset.class_to_idx)
print(dataset.imgs)
print(dataset[0])

352
['an_hong', 'bai_tai', 'bao_tai', 'bo_luo', 'chi_hen', 'dan_bai', 'dan_hong', 'dian_ci', 'duan_suo', 'fu_tai', 'hou_tai', 'hua_tai', 'huang_tai', 'hui_hei', 'jiang_she', 'lie_wen', 'ni_tai', 'pang_da', 'qing_zi', 'shao_tai', 'she_jian_hong', 'wai_xie', 'wei_ruan', 'wu_tai']
{'an_hong': 0, 'bai_tai': 1, 'bao_tai': 2, 'bo_luo': 3, 'chi_hen': 4, 'dan_bai': 5, 'dan_hong': 6, 'dian_ci': 7, 'duan_suo': 8, 'fu_tai': 9, 'hou_tai': 10, 'hua_tai': 11, 'huang_tai': 12, 'hui_hei': 13, 'jiang_she': 14, 'lie_wen': 15, 'ni_tai': 16, 'pang_da': 17, 'qing_zi': 18, 'shao_tai': 19, 'she_jian_hong': 20, 'wai_xie': 21, 'wei_ruan': 22, 'wu_tai': 23}
[('./data/dataset\\an_hong\\tongue_front_300110899004_2023-11-08-15-36-06.jpg', 0), ('./data/dataset\\an_hong\\tongue_front_300136928001_2023-11-07-10-46-40.jpg', 0), ('./data/dataset\\bai_tai\\tongue_front_0000351212_2023-10-21-11-21-13.jpg', 1), ('./data/dataset\\bai_tai\\tongue_front_200011111001_2023-10-24-09-58-09.jpg', 1), ('./data/dataset\\bai_tai\\to

In [12]:
from transformers import ViTImageProcessor


# 加载 Vision Transformer 的特征提取器
model_name = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTImageProcessor.from_pretrained(model_name)
feature_extractor

ViTImageProcessor {
  "_valid_processor_keys": [
    "images",
    "do_resize",
    "size",
    "resample",
    "do_rescale",
    "rescale_factor",
    "do_normalize",
    "image_mean",
    "image_std",
    "return_tensors",
    "data_format",
    "input_data_format"
  ],
  "do_normalize": true,
  "do_rescale": true,
  "do_resize": true,
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_processor_type": "ViTImageProcessor",
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "rescale_factor": 0.00392156862745098,
  "size": {
    "height": 224,
    "width": 224
  }
}

In [13]:
import torch
from torch.utils.data import random_split
from torch.utils.data.sampler import WeightedRandomSampler


# 计算每个类别的样本数量
class_counts = torch.bincount(torch.tensor(dataset.targets))

# 计算每个类别的权重
class_weights = 1 / class_counts.float()

# 设置随机数生成器种子
seed = 42

# 划分数据集
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
remaining_size = len(dataset) - train_size
train_dataset, remaining_dataset = random_split(
    dataset, [train_size, remaining_size])
val_dataset, test_dataset = random_split(
    remaining_dataset, [val_size, test_size])

id2label = {id: label for id, label in enumerate(
    train_dataset.dataset.classes)}
label2id = {label: id for id,label in id2label.items()}

In [14]:
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]


def custom_collate(batch):



    # 将样本数据列表拆分为输入和标签

    images, labels = list(zip(*batch))

    inputs = feature_extractor(images, return_tensors='pt')

    inputs['labels'] = torch.tensor(labels)
    inputs['pixel_values'] = torch.stack(
        [input for input in inputs["pixel_values"]])
    return inputs


train_loader = DataLoader(train_dataset, batch_size=4, collate_fn=custom_collate, shuffle=True, drop_last=True)
train_data = MyDataset(train_dataset)
val_loader = DataLoader(val_dataset, batch_size=4, collate_fn=custom_collate, shuffle=False, drop_last=True)
val_data=MyDataset(val_dataset)
test_loader = DataLoader(test_dataset, batch_size=4, collate_fn=custom_collate, shuffle=False, drop_last=True)
test_data=MyDataset(test_dataset)
print(len(train_loader))
batch = next(iter(train_loader))
for k, v in batch.items():
  if isinstance(v, torch.Tensor):
    print(k, v.shape)



70
pixel_values torch.Size([4, 3, 224, 224])
labels torch.Size([4])


In [15]:
from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(model_name, num_labels=len(dataset.classes),id2label=id2label,label2id=label2id)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [26]:
from transformers import TrainingArguments, Trainer

metric_name = "accuracy"

args = TrainingArguments(
    f"tongue-disease-classification",
    save_strategy="epoch",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=4,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir='logs',
    remove_unused_columns=False,
    seed=seed,
)

In [36]:
from sklearn.metrics import accuracy_score
import numpy as np


def compute_metrics(eval_pred):
    # predictions, labels = eval_pred
    # predictions = np.argmax(predictions, axis=1)
    # return dict(accuracy=accuracy_score(predictions, labels))
    logits, labels = eval_pred
    predictions = logits.argsort(axis=-1)[:, -5:]  # 获取top-5预测结果
    # 将标签扩展为与预测结果相同的形状
    labels_expanded = labels.reshape(-1, 1).repeat(5, axis=1)
    correct_predictions = (predictions == labels_expanded)  # 判断预测结果是否包含正确标签
    top5_accuracy = correct_predictions.any(axis=-1).mean()  # 计算top-5准确率
    return {"accuracy": top5_accuracy}

In [37]:
import torch

trainer = Trainer(
    model,
    args,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=custom_collate,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [38]:
# Start tensorboard.
%load_ext tensorboard
%tensorboard --logdir logs/

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 20456), started 3:50:42 ago. (Use '!kill 20456' to kill it.)

In [39]:
trainer.train()

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

{'eval_loss': 2.7128612995147705, 'eval_accuracy': 0.6857142857142857, 'eval_runtime': 7.8931, 'eval_samples_per_second': 4.434, 'eval_steps_per_second': 0.633, 'epoch': 1.0}


  0%|          | 0/5 [00:00<?, ?it/s]

{'eval_loss': 2.7157540321350098, 'eval_accuracy': 0.6857142857142857, 'eval_runtime': 7.0109, 'eval_samples_per_second': 4.992, 'eval_steps_per_second': 0.713, 'epoch': 2.0}


  0%|          | 0/5 [00:00<?, ?it/s]

{'eval_loss': 2.730070114135742, 'eval_accuracy': 0.6285714285714286, 'eval_runtime': 6.9649, 'eval_samples_per_second': 5.025, 'eval_steps_per_second': 0.718, 'epoch': 3.0}


  0%|          | 0/5 [00:00<?, ?it/s]

{'eval_loss': 2.705512523651123, 'eval_accuracy': 0.6857142857142857, 'eval_runtime': 6.7098, 'eval_samples_per_second': 5.216, 'eval_steps_per_second': 0.745, 'epoch': 4.0}
{'train_runtime': 265.6085, 'train_samples_per_second': 4.232, 'train_steps_per_second': 0.542, 'train_loss': 2.363711886935764, 'epoch': 4.0}


TrainOutput(global_step=144, training_loss=2.363711886935764, metrics={'train_runtime': 265.6085, 'train_samples_per_second': 4.232, 'train_steps_per_second': 0.542, 'train_loss': 2.363711886935764, 'epoch': 4.0})

## 模型训练 (废弃)

import torch.optim as optim
import torch.nn as nn
from transformers import ViTForImageClassification
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt


class VisionTransformer:
    def __init__(self, model_name='google/vit-base-patch16-224-in21k', num_classes=24):
        self.model_name = model_name
        self.num_classes = num_classes
        self.model = ViTForImageClassification.from_pretrained(model_name)
        self.model.classifier = nn.Linear(
            self.model.config.hidden_size, num_classes)
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def train(self, train_loader, val_loader, num_epochs=5, lr=1e-4, checkpoint_path='checkpoint.pth'):
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=lr)
        # 每 5 个 epoch 衰减学习率为当前的 0.1 倍
        scheduler = StepLR(optimizer, step_size=5, gamma=0.1)

        train_losses = []
        val_losses = []
        best_val_loss = float('inf')
        for epoch in range(num_epochs):
            self.model.train()
            running_loss = 0.0
            for inputs in train_loader:
                inputs = {key: value.to(self.device)
                          for key, value in inputs.items()}
                print(inputs)
                optimizer.zero_grad()
                outputs = self.model(**inputs)
                loss = criterion(outputs.logits, inputs['labels'])
                loss.backward()
                optimizer.step()

                running_loss += loss.item() * inputs.size(0)

            scheduler.step()
            epoch_loss = running_loss / len(train_loader.dataset)
            train_losses.append(epoch_loss)

            val_loss = self.evaluate(val_loader)
            val_losses.append(val_loss)

            print(
                f'Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}')

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(self.model.state_dict(), checkpoint_path)

        return train_losses, val_losses

    def evaluate(self, data_loader):
        self.model.eval()
        criterion = nn.CrossEntropyLoss()
        running_loss = 0.0
        correct_top1 = 0
        correct_top5 = 0
        total = 0
        with torch.no_grad():
            for inputs in train_loader:
                inputs = {key: value.to(self.device)
                               for key, value in inputs.items()}

                outputs = self.model(**inputs)
                loss = criterion(outputs.logits, inputs['labels'])
                running_loss += loss.item() * inputs.size(0)


                _, predicted = torch.topk(
                    outputs.logits, k=5, dim=1)  # 获取前五个预测结果的索引
                total += inputs['labels'].size(0)
                # 计算 Top-1 和 Top-5 准确率
                correct_top1 += (predicted[:, 0] ==
                                 inputs['labels']).sum().item()
                correct_top5 += torch.sum(torch.any(predicted == inputs['labels'].view(-1, 1),dim=1)).item()
        accuracy_top1 = correct_top1 / total
        accuracy_top5 = correct_top5 / total
        print(f'Top1 Evaluate Accuracy: {accuracy_top1:.4f}')
        print(f'Top5 Evaluate Accuracy: {accuracy_top5:.4f}')

        return running_loss / len(data_loader.dataset)

    def test(self, test_loader):
        self.model.eval()
        correct_top1 = 0
        correct_top5 = 0
        total = 0
        with torch.no_grad():
            for inputs in test_loader:
                inputs = {key: value.to(self.device)
                          for key, value in inputs.items()}

                outputs = self.model(**inputs)
                _, predicted = torch.topk(
                    outputs.logits, k=5, dim=1)  # 获取前五个预测结果的索引
                total += inputs['labels'].size(0)
                # 计算 Top-1 和 Top-5 准确率
                correct_top1 += (predicted[:, 0] ==
                                 inputs['labels']).sum().item()
                correct_top5 += torch.sum(torch.any(predicted ==
                                          inputs['labels'].view(-1, 1), dim=1)).item()
        accuracy_top1 = correct_top1 / total
        accuracy_top5 = correct_top5 / total

        print(f'Top1 Evaluate Accuracy: {accuracy_top1:.4f}')
        print(f'Top5 Evaluate Accuracy: {accuracy_top5:.4f}')
        return accuracy_top1, accuracy_top5

    def plot_losses(self, train_losses, val_losses):
        plt.plot(train_losses, label='Train Loss')
        plt.plot(val_losses, label='Val Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss')
        plt.legend()
        plt.show()

    def export_model(self, filename='vit_model.pt'):
        self.model.eval().to('cpu')
        torch.save(self.model, filename)

# 初始化模型
model = VisionTransformer()

# 训练模型
train_losses, val_losses = model.train(train_loader, val_loader, checkpoint_path='./model/checkpoint.pth')

# 可视化训练过程
model.plot_losses(train_losses, val_losses)

# 加载最佳模型参数
model.model.load_state_dict(torch.load('./model/checkpoint.pth'))

# 测试模型
model.test(test_loader)

# 导出模型
model.export_model('vit_model.pt')

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

import numpy as np
import evaluate

metric = evaluate.load("accuracy")


def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

from transformers import ViTForImageClassification

labels = dataset.classes

model = ViTForImageClassification.from_pretrained(
    model,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

from transformers import TrainingArguments


training_args = TrainingArguments(
    output_dir="./vit-base-tongue",
    per_device_train_batch_size=32,
    evaluation_strategy="steps",
    num_train_epochs=4,
    fp16=True,
    save_steps=100,
    eval_steps=100,
    logging_steps=10,
    learning_rate=2e-4,
    save_total_limit=2,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to='tensorboard',
    load_best_model_at_end=True,
)

from transformers import Trainer
from accelerate import DataLoaderConfiguration


# 定义数据加载器配置
dataloader_config = DataLoaderConfiguration(
    dispatch_batches=None,
    split_batches=False,
    even_batches=True,
    use_seedable_sampler=True
)


trainer = Trainer(
    model=model,

    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,

    train_dataset=train_loader,

    eval_dataset=val_loader,

    tokenizer=feature_extractor,

)

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

In [None]:
# classes = os.listdir(data_root)  # 获取数据集类别
# id2label={i: c for i, c in enumerate(classes)},
# label2id={c: i for i, c in enumerate(classes)}

# # 定义函数加载数据集并进行批量处理


# def process_dataset(data_root, data_transform, feature_extractor):
#     images = []
#     labels = []

#     # 遍历数据集文件夹
#     for class_dir in os.listdir(data_root):
#         class_path = os.path.join(data_root, class_dir)
#         if not os.path.isdir(class_path):
#             continue

#         # 遍历每个类别文件夹
#         for image_file in os.listdir(class_path):
#             image_path = os.path.join(class_path, image_file)
#             image = Image.open(image_path)
#             image = data_transform(image)  # 数据预处理
#             images.append(image)
#             labels.append(label2id[class_dir])

#     # 使用 ViTFeatureExtractor 进行批量处理
#     inputs = feature_extractor(images=images, return_tensors="pt")

#     return inputs, torch.tensor(labels)


# # 调用函数加载数据集并进行批量处理
# inputs, labels = process_dataset(data_root, data_transform, feature_extractor)

# # 输出处理后的张量
# inputs
# labels

tensor([ 0,  0,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,  3,  4,
         4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,
         4,  4,  4,  4,  4,  4,  4,  4,  4,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
         6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
         6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
         6,  6,  6,  6,  6,  6,  7,  8,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9,
        10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11,
        11, 11, 11, 11, 11, 11, 11, 11, 