In [1]:
import torch
import torchvision.transforms as T
import torchvision.models as models
import cv2
import numpy as np
import os
from sklearn.metrics.pairwise import cosine_similarity


In [2]:
# Load ViT-B/16 model from torchvision with ImageNet-1k weights
model = models.vit_l_16(weights=models.ViT_L_16_Weights.IMAGENET1K_SWAG_E2E_V1)
model.eval()

VisionTransformer(
  (conv_proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=1024, out_features=1024, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=1024, out_features=4096, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=4096, out_features=1024, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
       

In [3]:
# Preprocessing transform (use the model's recommended transforms)
# transform = models.ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1.transforms()

transform = T.Compose([
    T.ToPILImage(),
    T.Resize((512, 512)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

In [4]:
# def extract_features(img_path, model):
#     img = cv2.imread(img_path)
#     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#     img = transform(img).unsqueeze(0)
#     with torch.no_grad():
#         # Get the [CLS] token embedding (first token)
#         feats = model._process_input(img)
#         feats = model.encoder(feats)
#         cls_token = feats[:, 0, :].cpu().numpy().flatten()
#     return cls_token

# def extract_features(img_path, model):
#     img = cv2.imread(img_path)
#     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#     img = transform(img).unsqueeze(0)  # Add batch dimension
#     with torch.no_grad():
#         output = model(img)  # This is the CLS token representation
#     return output.squeeze().cpu().numpy()

def extract_features(img_path, model):
    img = cv2.imread(img_path)
    if img is None:
        print(f"Warning: Could not read image {img_path}")
        return None
    
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = transform(img).unsqueeze(0)
    
    with torch.no_grad():
        # Get intermediate features from encoder
        x = model._process_input(img)
        batch_size = x.shape[0]
        
        # Add class token
        cls_tokens = model.class_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        
        # Add position embedding
        x = x + model.encoder.pos_embedding
        
        # Pass through encoder
        x = model.encoder.dropout(x)
        for block in model.encoder.layers:
            x = block(x)
        x = model.encoder.ln(x)
        
        # Extract CLS token from the first position
        cls_token = x[:, 0]
        
        return cls_token.cpu().numpy().flatten()


In [5]:
# # Build reference features from "good" images
# good_folder = r'good'  # change to your folder
# good_feats = []
# for fname in os.listdir(good_folder):
#     if fname.lower().endswith(('.jpg', '.png', '.jpeg')):
#         fpath = os.path.join(good_folder, fname)
#         try:
#             f = extract_features(fpath, model)
#             if f is not None:
#                 good_feats.append(f)
#             else:
#                 print(f"Warning: Feature extraction failed for {fname}")
#         except Exception as e:
#             print(f"Error processing {fname}: {e}")
# if len(good_feats) == 0:
#     raise RuntimeError("No valid features extracted from good images.")
# good_feats = np.stack(good_feats)
# Build reference features from "good" images
good_folder = r'masks/good'  # change to your folder
good_feats = []
good_stats = {'mean': None, 'std': None}

# First pass: collect all valid features
for fname in os.listdir(good_folder):
    if fname.lower().endswith(('.jpg', '.png', '.jpeg')):
        fpath = os.path.join(good_folder, fname)
        try:
            features = extract_features(fpath, model)
            if features is not None:
                good_feats.append(features)
            else:
                print(f"Warning: Feature extraction failed for {fname}")
        except Exception as e:
            print(f"Error processing {fname}: {e}")

if len(good_feats) == 0:
    raise RuntimeError("No valid features extracted from good images.")

# Convert to numpy array and compute statistics
good_feats = np.stack(good_feats)
good_stats['mean'] = np.mean(good_feats, axis=0)
good_stats['std'] = np.std(good_feats, axis=0)

# Remove outliers (optional)
z_scores = np.abs((good_feats - good_stats['mean']) / good_stats['std'])
good_feats = good_feats[np.all(z_scores < 3, axis=1)]  # Keep features within 3 standard deviations

In [12]:
# Classify new/test image
def classify_image(test_img_path, good_feats, model, threshold=0.8):
    test_feat = extract_features(test_img_path, model)
    sims = cosine_similarity([test_feat], good_feats)
    max_sim = sims.max()
    if max_sim > threshold:
        return "good"
    else:
        return "defect"

In [15]:
result = classify_image(r'masks\test\33af61dd-Image__2025-05-02__10-36-24.jpg', good_feats, model,threshold=0.95)
print("Result:", result)

Result: defect


In [8]:
def classify_image(test_img_path, good_feats, good_stats, model, threshold=0.95):
    """
    Enhanced anomaly detection using multiple metrics
    """
    test_feat = extract_features(test_img_path, model)
    if test_feat is None:
        return "error", 0.0, {}
    
    # Compute multiple similarity metrics
    sims = cosine_similarity([test_feat], good_feats)
    max_sim = sims.max()
    mean_sim = sims.mean()
    
    # Compute statistical distance
    z_score = np.abs((test_feat - good_stats['mean']) / good_stats['std'])
    max_z_score = z_score.max()
    
    # Decision criteria
    metrics = {
        'max_similarity': max_sim,
        'mean_similarity': mean_sim,
        'max_z_score': max_z_score,
        'threshold': threshold
    }
    
    # More strict criteria for "good" classification
    if (max_sim > threshold and  # High similarity to at least one good sample
        mean_sim > threshold * 0.9 and  # Generally similar to good samples
        max_z_score < 3.0):  # Within statistical bounds
        return "good", max_sim, metrics
    else:
        return "defect", max_sim, metrics

In [10]:
# Test the improved classifier
result, similarity, metrics = classify_image(
    r'masks\good\Image__2025-05-05__10-07-50.jpg', 
    good_feats, 
    good_stats,
    model, 
    threshold=0.8
)

print(f"Result: {result}")
print(f"Similarity: {similarity:.3f}")
print("\nDetailed metrics:")
for key, value in metrics.items():
    print(f"{key}: {value:.3f}")

Result: good
Similarity: 1.000

Detailed metrics:
max_similarity: 1.000
mean_similarity: 0.915
max_z_score: 2.151
threshold: 0.800


In [39]:
metrics

{'max_similarity': np.float32(0.8431467),
 'mean_similarity': np.float32(0.67341876),
 'max_z_score': np.float32(4.300394),
 'threshold': 0.8}