In [1]:
import os
import timm
import torch
from tqdm import tqdm
import cv2
from torch.utils.data import Dataset, DataLoader, random_split
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import transforms
import torch.nn.functional as F
# from sklearn.metrics import accuracy_score
from torchmetrics.classification import MulticlassAccuracy

## dataset

In [2]:
classes = os.listdir('/kaggle/input/half-binary/half_binary_COD/half_binary_COD')

In [3]:
class half_binary_dataset(Dataset):
    """
        半二值图像数据集
    """

    def __init__(self, add_imagenet=False, train=True, transform=None):
        super().__init__()

        ## set classes
        self.classes_ImageNet = os.listdir('/kaggle/input/half-binary/half_binary_ImageNet_15/half_binary_ImageNet_15')
        self.classes_COD = os.listdir('/kaggle/input/half-binary/half_binary_COD/half_binary_COD')
        self.Train = train
        self.transform = transform

        ## load dataset
        self.train_data = []
        self.labels = []

        # load COD train
        if self.Train:
            for label, class_name in enumerate(self.classes_COD):
                image_path = os.path.join('/kaggle/input/half-binary/half_binary_COD/half_binary_COD', class_name, 'train')
                images = os.listdir(image_path)
                for img in images:
                    self.train_data.append(
                        os.path.join(image_path, img)
                    )
                    self.labels.append(classes.index(class_name))
                    
        # load COD test
        else:
            for label, class_name in enumerate(self.classes_COD):
                image_path = os.path.join('/kaggle/input/half-binary/half_binary_COD/half_binary_COD', class_name, 'train')
                images = os.listdir(image_path)
                for img in images:
                    self.train_data.append(
                        os.path.join(image_path, img)
                    )
                    self.labels.append(classes.index(class_name))

        # load imageNet
        if add_imagenet:
                for label, class_name in enumerate(self.classes_ImageNet):
                    image_path = os.path.join('/kaggle/input/half-binary/half_binary_ImageNet_15/half_binary_ImageNet_15', class_name)
                    images = os.listdir(image_path)
                    for img in images:
                        self.train_data.append(
                            os.path.join(image_path, img)
                        )
                        self.labels.append(classes.index(class_name))

        
    def __len__(self):
        return len(self.train_data)

    def __getitem__(self, idx):
        img = cv2.imread(
            self.train_data[idx]
        )
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]

_dataset = half_binary_dataset(add_imagenet=True)
len(_dataset.labels)

27711

## Train

In [4]:
## 参数
lr = 0.002
betas = (0.9, 0.999)
batch_size = 32
epoches = 10
T_max = 10
weight_decay=1e-3
class_num = len(classes)

In [5]:
def create_lr_scheduler(optimizer,
                        num_step: int,
                        epochs: int,
                        warmup=True,
                        warmup_epochs=1,
                        warmup_factor=1e-3):
    """
    学习率调度器， 先warm up，然后再decay
    :param optimizer: 优化器
    :param num_step: 每个epoch的step数，即每个epoch中batch的数量，len(train_dataset) / batch_size
    :param epochs: 总epoch数
    :param warmup: 是否warm up
    :param warmup_epochs: warm up 的epoch数
    :param warmup_factor: warm up 的学习率因子
    :return:
    """
    assert num_step > 0 and epochs > 0
    if warmup is False:
        warmup_epochs = 0

    def f(x):
        """
        根据step数返回一个学习率因子，
        注意在训练开始之前，pytorch会提前调用一次lr_schedule.step()方法

        """
        if warmup is True and x <= (warmup_epochs * num_step):
            alpha = float(x) / (warmup_epochs * num_step)
            # 在warm up 的过程中，学习率因子：warmup factor -> 1

            return warmup_factor * (1 - alpha) + alpha

        else:
            # 在warm up 之后，学习率因子：1 -> 0
            return (1 - (x - warmup_epochs * num_step) / ((epochs - warmup_epochs) * num_step)) ** 2

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=f)

In [6]:
transform = transforms.Compose([
    transforms.ToPILImage(),  # 将cv2读取的图像从NumPy数组转换为PIL图像
    transforms.Resize((224, 224)),  # 调整图像大小
    transforms.ToTensor(),  # 将图像转换为张量并归一化到[0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # 标准化
])


train_dataset = half_binary_dataset(add_imagenet=True, transform=transform)
test_dataset = half_binary_dataset(add_imagenet=False, transform=transform, train=False)

train_loader = DataLoader(
    train_dataset,
    num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]),
    shuffle=True,
    batch_size=batch_size
)
test_loader = DataLoader(
    test_dataset,
    num_workers=min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]),
    shuffle=False,
    batch_size=batch_size
)

## device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


## import model
model = timm.create_model('vit_small_patch16_224', pretrained=True)
model.head = torch.nn.Linear(model.head.in_features, class_num)  # 修改分类头为15类
model:torch.nn.Module = model.to(device)


# criterion and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, weight_decay=weight_decay)
lr_scheduler = create_lr_scheduler(
    optimizer=optimizer,
    num_step=len(train_loader),
    epochs=epoches,
    warmup_epochs=3
)
accuracy_score = MulticlassAccuracy(num_classes=class_num).to(device)


if __name__ == "__main__":
    for ep in range(epoches):
        print("epoch:", ep + 1)
        loss_one_epoch = []
        accuracy_one_epoch = []
        for images, labels in train_loader:
            optimizer.zero_grad()

            # pre
            images, labels = images.to(device), labels.to(device)
            # labels = F.one_hot(labels, class_num).to(torch.float32)
            pre: torch.Tensor = model(images)

            # back-propagation
            loss: torch.Tensor = criterion(pre, F.one_hot(labels, class_num).to(torch.float32))
            loss.backward()
            optimizer.step()
            

            # count
            loss_one_epoch.append(loss.detach().item())
            accuracy_one_epoch.append(
                accuracy_score(pre, labels)
            )

        print(f'loss: {sum(loss_one_epoch) / len(loss_one_epoch)}   accuracy train: {sum(accuracy_one_epoch) / len(accuracy_one_epoch) * 100.0:.2f}%')
        lr_scheduler.step()
        
        accuracy_one_epoch = []
        for images, labels in test_loader:
            with torch.no_grad():
                images, labels = images.to(device), labels.to(device)
                pre: torch.Tensor = model(images)

                accuracy_one_epoch.append(
                    accuracy_score(pre, labels)
                )
        print(f'accuracy test: {sum(accuracy_one_epoch) / len(accuracy_one_epoch) * 100.0:.2f}%\n')


    ## save model
    torch.save(model.state_dict(), 'vit_S16.pt')

model.safetensors:   0%|          | 0.00/88.2M [00:00<?, ?B/s]

epoch: 1
loss: 1.747079708077891   accuracy train: 47.16%
accuracy test: 1.47%

epoch: 2
loss: 0.793206268727917   accuracy train: 68.45%
accuracy test: 3.43%

epoch: 3
loss: 0.5825445634198932   accuracy train: 73.47%
accuracy test: 5.91%

epoch: 4
loss: 0.44960859032481154   accuracy train: 76.96%
accuracy test: 9.14%

epoch: 5
loss: 0.3389184835149668   accuracy train: 81.49%
accuracy test: 17.23%

epoch: 6
loss: 0.24585677504025433   accuracy train: 86.40%
accuracy test: 27.85%

epoch: 7
loss: 0.166462010290478   accuracy train: 90.92%
accuracy test: 46.17%

epoch: 8
loss: 0.10989672856353756   accuracy train: 95.13%
accuracy test: 69.18%

epoch: 9
loss: 0.07480344737102396   accuracy train: 97.30%
accuracy test: 77.40%

epoch: 10
loss: 0.05927620069355022   accuracy train: 98.12%
accuracy test: 63.63%

