In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
from models import VGG16

import os
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model = VGG16(num_classes=88, use_pretrain=False)
model_path = os.path.join("checkpoints", "best_model.pth")
model.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [3]:
# 데이터셋 경로
dataset_path = os.path.join("mvtec_anomaly_detection_imagefolder", "test")

# 이미지 변환 설정
transform = transforms.Compose([transforms.Resize((112, 112)),
                                transforms.ToTensor(),
                                # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet 데이터셋의 평균과 표준편차 사용
                                ])

# 데이터셋 로드
dataset = torchvision.datasets.ImageFolder(root=dataset_path, transform=transform)

# 데이터 로더 설정
batch_size = 16
test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

# gpu or cpu
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

from tqdm import tqdm

# 클래스 별 맞춘 수를 저장할 리스트
correct = [0] * len(dataset.classes)
total = 0
predictions = []

with tqdm(total=len(test_loader), desc='Testing', unit='batch') as pbar:
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            predictions.extend(predicted.cpu().numpy())
            total += labels.size(0)
            for i in range(len(labels)):
                if predicted[i] == labels[i]:
                    correct[labels[i]] += 1
            pbar.update(1)
            pbar.set_postfix({'Accuracy': sum(correct) / total})

Testing: 100%|██████████| 84/84 [00:43<00:00,  1.94batch/s, Accuracy=0.768]

Accuracy on test set: 0.7677
Accuracy for class bottle-broken_large: 0.6000
Accuracy for class bottle-broken_small: 0.8000
Accuracy for class bottle-contamination: 0.2000
Accuracy for class bottle-good: 1.0000
Accuracy for class cable-bent_wire: 0.5000
Accuracy for class cable-cable_swap: 0.2500
Accuracy for class cable-combined: 0.0000
Accuracy for class cable-cut_inner_insulation: 0.5000
Accuracy for class cable-cut_outer_insulation: 0.5000
Accuracy for class cable-good: 0.9571
Accuracy for class cable-missing_cable: 1.0000
Accuracy for class cable-missing_wire: 0.0000
Accuracy for class cable-poke_insulation: 0.0000
Accuracy for class capsule-crack: 0.0000
Accuracy for class capsule-faulty_imprint: 0.0000
Accuracy for class capsule-good: 0.9000
Accuracy for class capsule-poke: 0.0000
Accuracy for class capsule-scratch: 0.0000
Accuracy for class capsule-squeeze: 0.2000
Accuracy for class carpet-color: 0.0000
Accuracy for class carpet-cut: 0.0000
Accuracy for class carpet-good: 0.9740




In [5]:
# 각 클래스 별 정확도
class_accuracy = [0] * len(dataset.classes)
for i in range(len(correct)):
    class_accuracy[i] = correct[i] / dataset.targets.count(i)

# 전체 정확도 계산
accuracy = sum(correct) / total
print(f"Accuracy on test set: {accuracy:.4f}")

# 각 클래스 별 정확도
for i, acc in enumerate(class_accuracy):
    print(f"Accuracy for class {dataset.classes[i]}: {acc:.4f}")

Accuracy on test set: 0.7677
Accuracy for class bottle-broken_large: 0.6000
Accuracy for class bottle-broken_small: 0.8000
Accuracy for class bottle-contamination: 0.2000
Accuracy for class bottle-good: 1.0000
Accuracy for class cable-bent_wire: 0.5000
Accuracy for class cable-cable_swap: 0.2500
Accuracy for class cable-combined: 0.0000
Accuracy for class cable-cut_inner_insulation: 0.5000
Accuracy for class cable-cut_outer_insulation: 0.5000
Accuracy for class cable-good: 0.9571
Accuracy for class cable-missing_cable: 1.0000
Accuracy for class cable-missing_wire: 0.0000
Accuracy for class cable-poke_insulation: 0.0000
Accuracy for class capsule-crack: 0.0000
Accuracy for class capsule-faulty_imprint: 0.0000
Accuracy for class capsule-good: 0.9000
Accuracy for class capsule-poke: 0.0000
Accuracy for class capsule-scratch: 0.0000
Accuracy for class capsule-squeeze: 0.2000
Accuracy for class carpet-color: 0.0000
Accuracy for class carpet-cut: 0.0000
Accuracy for class carpet-good: 0.9740

In [4]:
import csv

csv_file_path = "class_accuracy.csv"

data = [["Label", "Correct", "Total_Sample", "Accuracy"]]
for i, label in enumerate(dataset.classes):
    correct_count = correct[i]
    total_sample_count = dataset.targets.count(i)
    accuracy = correct_count / total_sample_count if total_sample_count > 0 else 0
    data.append([label, correct_count, total_sample_count, accuracy])

with open(csv_file_path, mode='w', newline='') as file:
    writer = csv.writer(file)
    writer.writerows(data)

print(f"Class accuracy data saved to {csv_file_path}")


Class accuracy data saved to class_accuracy.csv
