### Import Libraries

In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms
from sklearn.metrics import f1_score
from PIL import Image
from tqdm import tqdm
from transformers import (ViTModel, 
                          ViTConfig, 
                          ViTFeatureExtractor, 
                          ViTForImageClassification)
from config import CFG
from sklearn.model_selection import StratifiedKFold

2024-01-07 19:50:53.161562: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Load data

In [2]:
label_data = pd.read_csv(os.path.join(CFG.train_data_dir, '文件标签汇总数据.csv'))
train_csv_folder = os.path.join(CFG.train_data_dir, 'csv文件')
test_A_csv_folder = os.path.join(CFG.test_A_data_dir, 'csv文件')
train_image_folder = os.path.join(CFG.project_name, 'project/image/训练集数据')
test_A_image_folder = os.path.join(CFG.project_name, 'project/image/A榜测试集数据')

In [3]:
class TrainDataset(Dataset):
    def __init__(self, image_folder, transform=None):
        self.image_folder = image_folder
        self.transform = transform
        self.image_files = [f for f in os.listdir(
            image_folder) if f.endswith('.png')]
        # 假设标签以某种方式存储或从文件名中获取
        self.labels = [CFG.label2id[f[0]] for f in self.image_files]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_folder, self.image_files[idx])
        image = Image.open(img_name)
        # if image.mode != 'RGB':
        #     image = image.convert('RGB')

        label = CFG.label2id[self.image_files[idx][0]]

        if self.transform:
            image = self.transform(image)

        return image, label

In [4]:
class TestDataset(Dataset):
    def __init__(self, image_folder, transform=None):
        self.image_folder = image_folder
        self.transform = transform
        self.image_files = [f for f in os.listdir(
            image_folder) if f.endswith('.png')]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_folder, self.image_files[idx])
        image = Image.open(img_name)
        # if image.mode != 'RGB':
        #     image = image.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, self.image_files[idx]

In [5]:
data_transforms = {
    'train': transforms.Compose([  
        transforms.RandomHorizontalFlip(),  # 随机水平翻转
        transforms.RandomVerticalFlip(),    # 随机垂直翻转
        transforms.RandomRotation(10),      # 在[-10, 10]范围内随机旋转
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
        transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
    ]),
    'test': transforms.Compose([
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
        transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
    ]),
}

In [6]:
train_dataset = TrainDataset(
    image_folder=train_image_folder, transform=data_transforms['train'])
test_dataset = TestDataset(
    image_folder=test_A_image_folder, transform=data_transforms['test'])

train_loader = DataLoader(train_dataset, batch_size=4,
                          shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=4,
                         shuffle=False, num_workers=4)

dataset_sizes = {'train': len(train_dataset), 'test': len(test_dataset)}

In [7]:
model = ViTForImageClassification.from_pretrained(CFG.vit_model,
                                                    num_labels=5,
                                                    ignore_mismatched_sizes=True)
model.to(CFG.device)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at /ai/users/bst/competition/model/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([5]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([5, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_features=7

In [8]:
from config import CFG
# 用K折交叉验证，对train_dataset进行划分, 5折,并且评估指标是f1_score
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=CFG.seed)
for fold, (train_index, valid_index) in enumerate(skf.split(train_dataset.image_files, train_dataset.labels)):
    train_dataset_fold = torch.utils.data.Subset(train_dataset, train_index)
    valid_dataset_fold = torch.utils.data.Subset(train_dataset, valid_index)
    train_loader_fold = DataLoader(train_dataset_fold, batch_size=8,
                                   shuffle=True, num_workers=4)
    valid_loader_fold = DataLoader(valid_dataset_fold, batch_size=8,
                                   shuffle=False, num_workers=4)
    dataset_sizes = {'train': len(train_dataset_fold), 'valid': len(valid_dataset_fold)}
    print(f'fold:{fold}')
    # print(f'train_index:{train_index}')
    # print(f'valid_index:{valid_index}')
    print('--------------------------------')

    # 用于训练的模型
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.0001)

    # 训练
    for epoch in range(CFG.epochs):
        print('Epoch {}/{}'.format(epoch+1, CFG.epochs))
        print('-' * 10)

        # 每个epoch都有一个训练和验证阶段
        for phase in ['train', 'valid']:
            if phase == 'train':
                model.train()  # 设置模型为训练模式
                dataloader = train_loader_fold
            else:
                model.eval()   # 设置模型为评估模式
                dataloader = valid_loader_fold

            running_loss = 0.0
            running_corrects = 0
            y_true = []
            y_pred = []
            # 迭代数据
            for inputs, labels in tqdm(dataloader):
                inputs = inputs.to(CFG.device)
                labels = labels.to(CFG.device)

                # 零参数梯度
                optimizer.zero_grad()

                # 前向
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs.logits, 1)
                    loss = criterion(outputs.logits, labels)
                    # 后向+仅在训练阶段进行优化
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # print(f'labels:{labels}')
                # print(f'preds:{preds}')
                # 统计f1_score
                y_true.extend(labels.cpu().numpy().tolist())
                y_pred.extend(preds.cpu().numpy().tolist())
                # 统计loss
                running_loss += loss.item() * inputs.size(0)
            # 分别计算5个label的f1_score,然后加权平均[0.1, 0.3, 0.2, 0.1, 0.3]
            score = f1_score(y_true, y_pred, average=None)
            score = np.average(score, weights=[0.1, 0.3, 0.2, 0.1, 0.3])
            # 计算loss
            epoch_loss = running_loss / dataset_sizes[phase]
            print('{} Loss: {:.4f} F1-score: {:.4f}'.format(
                phase, epoch_loss, score))
    # 保存模型
    torch.save(model.state_dict(), f'./model/vit_model_{fold}.pth')
    print(f'./model/vit_model_{fold}.pth saved')



fold:0
--------------------------------
Epoch 1/25
----------


100%|██████████| 12/12 [00:04<00:00,  2.89it/s]


train Loss: 2174309.4891 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.82it/s]


valid Loss: 2108420.7717 F1-score: 0.0467
Epoch 2/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.62it/s]


train Loss: 2174308.4891 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.85it/s]


valid Loss: 2108419.7283 F1-score: 0.0467
Epoch 3/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2174307.3804 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.95it/s]


valid Loss: 2108418.8152 F1-score: 0.0467
Epoch 4/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.53it/s]


train Loss: 2174306.2283 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.74it/s]


valid Loss: 2108417.5326 F1-score: 0.0467
Epoch 5/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2174305.0217 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.79it/s]


valid Loss: 2108416.2826 F1-score: 0.0467
Epoch 6/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2174303.8043 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.89it/s]


valid Loss: 2108415.2826 F1-score: 0.0467
Epoch 7/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.60it/s]


train Loss: 2174302.6196 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.82it/s]


valid Loss: 2108414.0326 F1-score: 0.0467
Epoch 8/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2174301.3587 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.87it/s]


valid Loss: 2108412.6630 F1-score: 0.0467
Epoch 9/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2174300.1848 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.78it/s]


valid Loss: 2108411.5435 F1-score: 0.0467
Epoch 10/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.54it/s]


train Loss: 2174298.8587 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.86it/s]


valid Loss: 2108410.5435 F1-score: 0.0467
Epoch 11/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2174297.6630 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.81it/s]


valid Loss: 2108409.3804 F1-score: 0.0467
Epoch 12/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2174296.3696 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.80it/s]


valid Loss: 2108407.9674 F1-score: 0.0467
Epoch 13/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.54it/s]


train Loss: 2174295.1957 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.63it/s]


valid Loss: 2108406.7609 F1-score: 0.0467
Epoch 14/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.59it/s]


train Loss: 2174293.9130 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.90it/s]


valid Loss: 2108405.5543 F1-score: 0.0467
Epoch 15/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.58it/s]


train Loss: 2174292.6304 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.94it/s]


valid Loss: 2108404.1848 F1-score: 0.0467
Epoch 16/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.54it/s]


train Loss: 2174291.3370 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.92it/s]


valid Loss: 2108403.0543 F1-score: 0.0467
Epoch 17/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.58it/s]


train Loss: 2174290.0435 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.99it/s]


valid Loss: 2108401.8043 F1-score: 0.0467
Epoch 18/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.58it/s]


train Loss: 2174288.7283 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.91it/s]


valid Loss: 2108400.6413 F1-score: 0.0467
Epoch 19/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2174287.4239 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


valid Loss: 2108399.4783 F1-score: 0.0467
Epoch 20/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2174286.0870 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.84it/s]


valid Loss: 2108398.1522 F1-score: 0.0467
Epoch 21/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.50it/s]


train Loss: 2174284.8370 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.55it/s]


valid Loss: 2108396.8152 F1-score: 0.0467
Epoch 22/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2174283.3696 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.89it/s]


valid Loss: 2108395.4457 F1-score: 0.0467
Epoch 23/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.51it/s]


train Loss: 2174282.1630 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.73it/s]


valid Loss: 2108394.0761 F1-score: 0.0467
Epoch 24/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2174280.7500 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  5.00it/s]


valid Loss: 2108392.7500 F1-score: 0.0467
Epoch 25/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.53it/s]


train Loss: 2174279.4130 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.93it/s]


valid Loss: 2108391.4565 F1-score: 0.0467
./model/vit_model_0.pth saved
fold:1
--------------------------------
Epoch 1/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.52it/s]


train Loss: 2174277.9674 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.78it/s]


valid Loss: 2108390.2717 F1-score: 0.0467
Epoch 2/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.50it/s]


train Loss: 2174276.7065 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  5.00it/s]


valid Loss: 2108388.9239 F1-score: 0.0467
Epoch 3/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2174275.3804 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


valid Loss: 2108387.8152 F1-score: 0.0467
Epoch 4/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.51it/s]


train Loss: 2174274.1196 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.90it/s]


valid Loss: 2108386.4239 F1-score: 0.0467
Epoch 5/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.54it/s]


train Loss: 2174272.8370 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.86it/s]


valid Loss: 2108385.2391 F1-score: 0.0467
Epoch 6/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2174271.6413 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.82it/s]


valid Loss: 2108383.9674 F1-score: 0.0467
Epoch 7/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.59it/s]


train Loss: 2174270.2337 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.85it/s]


valid Loss: 2108382.8152 F1-score: 0.0467
Epoch 8/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.52it/s]


train Loss: 2174268.9565 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.87it/s]


valid Loss: 2108381.6304 F1-score: 0.0467
Epoch 9/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2174267.6196 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


valid Loss: 2108380.2826 F1-score: 0.0467
Epoch 10/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.51it/s]


train Loss: 2174266.2935 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.60it/s]


valid Loss: 2108378.8913 F1-score: 0.0467
Epoch 11/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2174265.0326 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.84it/s]


valid Loss: 2108377.5761 F1-score: 0.0467
Epoch 12/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2174263.6196 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.10it/s]


valid Loss: 2108376.2283 F1-score: 0.0467
Epoch 13/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2174262.2283 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.84it/s]


valid Loss: 2108374.8804 F1-score: 0.0467
Epoch 14/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2174260.8587 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.95it/s]


valid Loss: 2108373.6522 F1-score: 0.0467
Epoch 15/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.52it/s]


train Loss: 2174259.4130 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.21it/s]


valid Loss: 2108372.3370 F1-score: 0.0467
Epoch 16/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2174258.1087 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.93it/s]


valid Loss: 2108370.9457 F1-score: 0.0467
Epoch 17/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.53it/s]


train Loss: 2174256.6196 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.57it/s]


valid Loss: 2108369.5543 F1-score: 0.0467
Epoch 18/25
----------


100%|██████████| 12/12 [00:03<00:00,  3.27it/s]


train Loss: 2174255.2228 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.76it/s]


valid Loss: 2108368.1630 F1-score: 0.0467
Epoch 19/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.52it/s]


train Loss: 2174253.8043 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


valid Loss: 2108366.7717 F1-score: 0.0467
Epoch 20/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2174252.3478 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.90it/s]


valid Loss: 2108365.5000 F1-score: 0.0467
Epoch 21/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.54it/s]


train Loss: 2174250.9592 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.87it/s]


valid Loss: 2108363.9891 F1-score: 0.0467
Epoch 22/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2174249.4891 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.65it/s]


valid Loss: 2108362.6413 F1-score: 0.0467
Epoch 23/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.59it/s]


train Loss: 2174248.0000 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.81it/s]


valid Loss: 2108361.3261 F1-score: 0.0467
Epoch 24/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.60it/s]


train Loss: 2174246.5109 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


valid Loss: 2108359.7717 F1-score: 0.0467
Epoch 25/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.58it/s]


train Loss: 2174245.0543 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.76it/s]


valid Loss: 2108358.1739 F1-score: 0.0467
./model/vit_model_1.pth saved
fold:2
--------------------------------
Epoch 1/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.49it/s]


train Loss: 2141300.5435 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.78it/s]


valid Loss: 2240129.1087 F1-score: 0.0414
Epoch 2/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2141299.0543 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


valid Loss: 2240127.6739 F1-score: 0.0414
Epoch 3/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.58it/s]


train Loss: 2141297.6957 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.77it/s]


valid Loss: 2240126.1522 F1-score: 0.0414
Epoch 4/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.53it/s]


train Loss: 2141296.2500 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.81it/s]


valid Loss: 2240124.5870 F1-score: 0.0414
Epoch 5/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.51it/s]


train Loss: 2141294.7826 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.75it/s]


valid Loss: 2240123.0217 F1-score: 0.0414
Epoch 6/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.58it/s]


train Loss: 2141293.3913 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.79it/s]


valid Loss: 2240121.5435 F1-score: 0.0414
Epoch 7/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2141291.8261 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.58it/s]


valid Loss: 2240119.9457 F1-score: 0.0414
Epoch 8/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2141290.4457 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.78it/s]


valid Loss: 2240118.5109 F1-score: 0.0414
Epoch 9/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2141288.9239 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.77it/s]


valid Loss: 2240116.9022 F1-score: 0.0414
Epoch 10/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.58it/s]


train Loss: 2141287.4946 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.79it/s]


valid Loss: 2240115.4239 F1-score: 0.0414
Epoch 11/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.58it/s]


train Loss: 2141285.9973 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.80it/s]


valid Loss: 2240113.8696 F1-score: 0.0414
Epoch 12/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.43it/s]


train Loss: 2141284.5978 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.86it/s]


valid Loss: 2240112.3913 F1-score: 0.0414
Epoch 13/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2141283.0761 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.93it/s]


valid Loss: 2240110.8696 F1-score: 0.0414
Epoch 14/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2141281.6087 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


valid Loss: 2240109.0652 F1-score: 0.0414
Epoch 15/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2141280.0435 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  5.00it/s]


valid Loss: 2240107.6739 F1-score: 0.0414
Epoch 16/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.54it/s]


train Loss: 2141278.5598 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.88it/s]


valid Loss: 2240106.0217 F1-score: 0.0414
Epoch 17/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2141277.0543 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.80it/s]


valid Loss: 2240104.3370 F1-score: 0.0414
Epoch 18/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2141275.4239 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.76it/s]


valid Loss: 2240102.8152 F1-score: 0.0414
Epoch 19/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2141273.9239 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


valid Loss: 2240101.1848 F1-score: 0.0414
Epoch 20/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2141272.3478 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.66it/s]


valid Loss: 2240099.5435 F1-score: 0.0414
Epoch 21/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.41it/s]


train Loss: 2141270.7826 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.93it/s]


valid Loss: 2240097.8152 F1-score: 0.0414
Epoch 22/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2141269.2174 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.68it/s]


valid Loss: 2240096.1630 F1-score: 0.0414
Epoch 23/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2141267.5870 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.98it/s]


valid Loss: 2240094.4022 F1-score: 0.0414
Epoch 24/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.50it/s]


train Loss: 2141266.0761 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.56it/s]


valid Loss: 2240092.8804 F1-score: 0.0414
Epoch 25/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.50it/s]


train Loss: 2141264.3804 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.76it/s]


valid Loss: 2240091.2391 F1-score: 0.0414
./model/vit_model_2.pth saved
fold:3
--------------------------------
Epoch 1/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2141263.0652 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.64it/s]


valid Loss: 2240089.2391 F1-score: 0.0414
Epoch 2/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.54it/s]


train Loss: 2141261.4565 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.97it/s]


valid Loss: 2240087.5978 F1-score: 0.0414
Epoch 3/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.53it/s]


train Loss: 2141259.7989 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.77it/s]


valid Loss: 2240085.6739 F1-score: 0.0414
Epoch 4/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.58it/s]


train Loss: 2141258.2717 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.74it/s]


valid Loss: 2240084.1196 F1-score: 0.0414
Epoch 5/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.58it/s]


train Loss: 2141256.7065 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.85it/s]


valid Loss: 2240082.5543 F1-score: 0.0414
Epoch 6/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2141255.1848 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.86it/s]


valid Loss: 2240080.7065 F1-score: 0.0414
Epoch 7/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.52it/s]


train Loss: 2141253.5870 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.81it/s]


valid Loss: 2240079.3043 F1-score: 0.0414
Epoch 8/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.59it/s]


train Loss: 2141251.9783 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.77it/s]


valid Loss: 2240077.7500 F1-score: 0.0414
Epoch 9/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2141250.3478 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.81it/s]


valid Loss: 2240075.9348 F1-score: 0.0414
Epoch 10/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2141248.7500 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.92it/s]


valid Loss: 2240074.4565 F1-score: 0.0414
Epoch 11/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2141247.1087 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.98it/s]


valid Loss: 2240072.5326 F1-score: 0.0414
Epoch 12/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2141245.5054 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.54it/s]


valid Loss: 2240070.6957 F1-score: 0.0414
Epoch 13/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.58it/s]


train Loss: 2141243.8804 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.77it/s]


valid Loss: 2240068.8804 F1-score: 0.0414
Epoch 14/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.50it/s]


train Loss: 2141242.2391 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.83it/s]


valid Loss: 2240067.2717 F1-score: 0.0414
Epoch 15/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2141240.5109 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


valid Loss: 2240065.5870 F1-score: 0.0414
Epoch 16/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2141238.8696 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.66it/s]


valid Loss: 2240063.9891 F1-score: 0.0414
Epoch 17/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.50it/s]


train Loss: 2141237.2446 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.96it/s]


valid Loss: 2240062.0217 F1-score: 0.0414
Epoch 18/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2141235.5217 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.82it/s]


valid Loss: 2240060.1739 F1-score: 0.0414
Epoch 19/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2141233.8478 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.94it/s]


valid Loss: 2240058.5217 F1-score: 0.0414
Epoch 20/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.49it/s]


train Loss: 2141232.1087 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.71it/s]


valid Loss: 2240056.7283 F1-score: 0.0414
Epoch 21/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2141230.4402 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.80it/s]


valid Loss: 2240054.9239 F1-score: 0.0414
Epoch 22/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2141228.7283 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.37it/s]


valid Loss: 2240053.0326 F1-score: 0.0414
Epoch 23/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.54it/s]


train Loss: 2141227.0109 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


valid Loss: 2240051.3043 F1-score: 0.0414
Epoch 24/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2141225.2717 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.85it/s]


valid Loss: 2240049.6196 F1-score: 0.0414
Epoch 25/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2141223.5326 F1-score: 0.0454


100%|██████████| 3/3 [00:00<00:00,  4.63it/s]


valid Loss: 2240047.6848 F1-score: 0.0414
./model/vit_model_3.pth saved
fold:4
--------------------------------
Epoch 1/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.53it/s]


train Loss: 2174163.8043 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.85it/s]


valid Loss: 2108278.0652 F1-score: 0.0467
Epoch 2/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2174162.0870 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.84it/s]


valid Loss: 2108276.5435 F1-score: 0.0467
Epoch 3/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.50it/s]


train Loss: 2174160.3587 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.76it/s]


valid Loss: 2108274.8043 F1-score: 0.0467
Epoch 4/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2174158.5435 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.92it/s]


valid Loss: 2108273.0326 F1-score: 0.0467
Epoch 5/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.56it/s]


train Loss: 2174156.8804 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.77it/s]


valid Loss: 2108271.2717 F1-score: 0.0467
Epoch 6/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.58it/s]


train Loss: 2174155.0978 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.78it/s]


valid Loss: 2108269.6196 F1-score: 0.0467
Epoch 7/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2174153.4457 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.53it/s]


valid Loss: 2108267.8804 F1-score: 0.0467
Epoch 8/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2174151.5761 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.73it/s]


valid Loss: 2108266.1087 F1-score: 0.0467
Epoch 9/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.49it/s]


train Loss: 2174149.8804 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.95it/s]


valid Loss: 2108264.4239 F1-score: 0.0467
Epoch 10/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.57it/s]


train Loss: 2174148.0489 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.63it/s]


valid Loss: 2108262.6848 F1-score: 0.0467
Epoch 11/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.59it/s]


train Loss: 2174146.2391 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.82it/s]


valid Loss: 2108261.1087 F1-score: 0.0467
Epoch 12/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.58it/s]


train Loss: 2174144.4783 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.90it/s]


valid Loss: 2108259.1739 F1-score: 0.0467
Epoch 13/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2174142.7228 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.85it/s]


valid Loss: 2108257.4783 F1-score: 0.0467
Epoch 14/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.58it/s]


train Loss: 2174140.8750 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.87it/s]


valid Loss: 2108255.7065 F1-score: 0.0467
Epoch 15/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.53it/s]


train Loss: 2174139.0761 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.81it/s]


valid Loss: 2108253.9348 F1-score: 0.0467
Epoch 16/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2174137.3288 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.82it/s]


valid Loss: 2108252.1522 F1-score: 0.0467
Epoch 17/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.54it/s]


train Loss: 2174135.4783 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.51it/s]


valid Loss: 2108250.4674 F1-score: 0.0467
Epoch 18/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2174133.6087 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.93it/s]


valid Loss: 2108248.4891 F1-score: 0.0467
Epoch 19/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2174131.8152 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.92it/s]


valid Loss: 2108246.7609 F1-score: 0.0467
Epoch 20/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.58it/s]


train Loss: 2174129.9022 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.85it/s]


valid Loss: 2108244.9457 F1-score: 0.0467
Epoch 21/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.61it/s]


train Loss: 2174128.0870 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.75it/s]


valid Loss: 2108243.0000 F1-score: 0.0467
Epoch 22/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.46it/s]


train Loss: 2174126.1087 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.73it/s]


valid Loss: 2108241.1413 F1-score: 0.0467
Epoch 23/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.60it/s]


train Loss: 2174124.2065 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.81it/s]


valid Loss: 2108239.5326 F1-score: 0.0467
Epoch 24/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.55it/s]


train Loss: 2174122.2500 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.93it/s]


valid Loss: 2108237.5109 F1-score: 0.0467
Epoch 25/25
----------


100%|██████████| 12/12 [00:02<00:00,  4.54it/s]


train Loss: 2174120.4022 F1-score: 0.0441


100%|██████████| 3/3 [00:00<00:00,  4.72it/s]


valid Loss: 2108235.6196 F1-score: 0.0467
./model/vit_model_4.pth saved


In [9]:
model.eval()
predictions = []
file_name = []
for images,img_name in test_loader:
    with torch.no_grad():
        images = images.to(CFG.device)  
        outputs = model(images)
        _, predicted =torch.max(outputs.logits, 1)
        predictions.extend(predicted.cpu().numpy())
        file_name.extend(list(img_name))

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/ai/anaconda3/envs/kaggle/lib/python3.8/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/ai/anaconda3/envs/kaggle/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/ai/anaconda3/envs/kaggle/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
  File "/ai/anaconda3/envs/kaggle/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 142, in collate
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/ai/anaconda3/envs/kaggle/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 142, in <listcomp>
    return [collate(samples, collate_fn_map=collate_fn_map) for samples in transposed]  # Backwards compatibility.
  File "/ai/anaconda3/envs/kaggle/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 119, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
  File "/ai/anaconda3/envs/kaggle/lib/python3.8/site-packages/torch/utils/data/_utils/collate.py", line 161, in collate_tensor_fn
    out = elem.new(storage).resize_(len(batch), *list(elem.size()))
RuntimeError: Trying to resize storage that is not resizable


In [None]:
csv_file_name = []
test_A_csv_file_name = os.listdir(test_A_csv_folder)
for file in file_name:
    for j in test_A_csv_file_name:
        if file[:-4] in j:
            csv_file_name.append(j)
            break

In [None]:
label = [CFG.id2label[i] for i in predictions]

In [None]:
results_df = pd.DataFrame({'defectType': label,'fileName': csv_file_name})

In [None]:
results_df.to_csv('submission.csv', index=False)

In [None]:
train_dataset[0][0].shape

torch.Size([1, 224, 224])