In [1]:
from pathlib import Path
import numpy as np
import os, shutil
import matplotlib.pyplot as plt
import seaborn as sns

from PIL import Image

from tqdm.auto import tqdm

import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
import torch.optim as optim

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torchvision.models.vision_transformer import vit_b_16, ViT_B_16_Weights

In [None]:
# Load ViT-B/16 with ImageNet-1k pretrained weights and use it as a frozen
# feature extractor. Fall back to CPU so the cell also runs without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vit_weights = ViT_B_16_Weights.IMAGENET1K_V1
vit_model = vit_b_16(weights=vit_weights).to(device)
vit_model.eval()

# Freeze parameters: we only read embeddings, no fine-tuning.
for param in vit_model.parameters():
    param.requires_grad = False

# Use the preprocessing pipeline bundled with the weights.
transform = vit_weights.transforms()

# Test-set sub-folders to embed; 'good' is the only class labelled 0 below.
TEST_CLASSES = ['color', 'good', 'cut', 'hole', 'metal_contamination', 'thread']


def _extract_cls_embedding(model, batch):
    """Return the encoder output at the [CLS] position for a preprocessed batch.

    Replicates the steps of torchvision's ``VisionTransformer.forward`` that run
    before the classification head: patchify/embed the input, prepend the class
    token, run the encoder, and keep position 0.

    Args:
        model: a torchvision ``VisionTransformer`` (here ``vit_b_16``).
        batch: image tensor of shape [N, 3, 224, 224] already on ``model``'s device.

    Returns:
        Tensor of shape [N, hidden_dim] (768 for ViT-B/16).
    """
    tokens = model._process_input(batch)
    cls_token = model.class_token.expand(tokens.shape[0], -1, -1)
    tokens = torch.cat((cls_token, tokens), dim=1)  # prepend [CLS]
    return model.encoder(tokens)[:, 0]


# Storage
class_labels = []
y_true = []
vit_features = []

with torch.no_grad():
    for class_name in TEST_CLASSES:
        folder_path = base_path / "test" / class_name

        # sorted(): iterdir() order is filesystem-dependent; sorting keeps the
        # row order of vit_features / y_true / class_labels reproducible.
        for pth in tqdm(sorted(folder_path.iterdir()), desc=class_name, leave=False):
            class_label = pth.parts[-2]

            # The context manager closes the file handle. convert("RGB") guards
            # against grayscale/RGBA files that a 3-channel model would reject.
            with Image.open(pth) as img:
                test_image = transform(img.convert("RGB")).unsqueeze(0).to(device)  # [1, 3, 224, 224]

            cls_embedding = _extract_cls_embedding(vit_model, test_image)  # [1, 768]
            vit_features.append(cls_embedding.squeeze(0).cpu().numpy())

            class_labels.append(class_label)
            y_true.append(0 if class_label == 'good' else 1)

# Keep a 2-D shape even when no images were found, so downstream code can rely on it.
vit_features = (
    np.array(vit_features)
    if vit_features
    else np.empty((0, vit_model.hidden_dim), dtype=np.float32)
)  # shape: [N_images, 768]