In [1]:
from transformers import (
    ViTForImageClassification,
    ViTFeatureExtractor,
    ViTImageProcessor,
)

from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

import torch

from PIL import Image

In [2]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, jaccard_score, confusion_matrix

In [3]:
import os
import shutil

In [4]:
# Create the image processor
model_name_or_path: str = 'google/vit-base-patch16-224-in21k'
cache_dir: str = None
model_revision: str = 'main'
use_auth_token: bool = False

image_processor = ViTImageProcessor.from_pretrained(
    model_name_or_path,
    cache_dir=cache_dir,
    revision=model_revision,
    use_auth_token=use_auth_token,
)

# Define torchvision transforms to be applied to each image.
if "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
else:
    size = (image_processor.size["height"], image_processor.size["width"])

normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

_test_transforms = Compose(
    [
        Resize(size),
        CenterCrop(size),
        ToTensor(),
        normalize,
    ]
)

In [5]:
def get_model(model_path):
    model = ViTForImageClassification.from_pretrained(model_path, local_files_only=True)
    return model

In [6]:
def make_prediction(model, filename):
    image = Image.open(filename)

    processed = image_processor(image)
    processed.pixel_values = _test_transforms(image.convert('RGB'))
    outputs = model(torch.reshape(processed.pixel_values, (1, 3, 224, 224)))

    # obtain the class
    logits = outputs.logits

    prediction = logits.argmax(-1)
    
    return model.config.id2label[prediction.item()]

In [7]:
def compute_model_performance(model_dir):
    model_results: list[str,str,str] = [
        #path, predicted_class, expected_class
    ]
    model = get_model(model_dir)

    for dirpath, _, filenames in os.walk('test/bread'):
        for filename in filenames:
            path = os.path.join(dirpath, filename)
            prediction = make_prediction(model, path)
            model_results.append((path, prediction, 'bread'))

    for dirpath, _, filenames in os.walk('test/not_bread'):
        for filename in filenames:
            path = os.path.join(dirpath, filename)
            prediction = make_prediction(model, path)
            model_results.append((path, prediction, 'not_bread'))
            
    true_labels = [res[2] for res in model_results]
    pred_labels = [res[1] for res in model_results]

    # Compute accuracy, precision, recall, and Jaccard score
    accuracy = accuracy_score(true_labels, pred_labels)
    precision = precision_score(true_labels, pred_labels, pos_label='bread')
    recall = recall_score(true_labels, pred_labels, pos_label='bread')
    jaccard = jaccard_score(true_labels, pred_labels, pos_label='bread')

    # Print the results
    print("Accuracy:", accuracy)
    print("Precision:", precision)
    print("Recall:", recall)
    print("Jaccard score:", jaccard)
    
    print(f'Examples of bread:     {len([e for e in true_labels if e == "bread"])}')
    print(f'Examples of not bread: {len([e for e in true_labels if e == "not_bread"])}')
    
    cm = confusion_matrix(true_labels, pred_labels)
    print(cm)
    
    return model_results

## Access the original model

In [8]:
original_model_results = compute_model_performance('outputs')

ValueError: Unable to infer channel dimension format

## Access Kesley's Model
#### The data was just as impure, but had a few more examples of bread

In [None]:
kesley_model_results = compute_model_performance('kesley_2070_output1')

## Access Ray's Model
#### For this model he cleaned up the data.  A few additional bread examples.

In [None]:
ray_model_results = compute_model_performance('ray_output1')

In [None]:
false_negatives = [f for f, p, a in ray_model_results if p != a]

In [None]:
false_negatives

In [None]:
dirname = 'false_negatives'

if not os.path.isdir(dirname):
    os.mkdir(dirname)

for file in false_negatives:
    new_path = os.path.join(dirname, os.path.basename(file))
    shutil.copy(file, new_path)

In [None]:
dirname = 'positives'

bread = [f for f, p, a in ray_model_results if p == a and p == 'bread']

if not os.path.isdir(dirname):
    os.mkdir(dirname)

for file in bread:
    new_path = os.path.join(dirname, os.path.basename(file))
    shutil.copy(file, new_path)