Facial-expression classification with a ViT model trained on the FER2013 dataset, loaded from the Hugging Face Hub: https://huggingface.co/trpakov/vit-face-expression

In [1]:
# Load model directly
from transformers import AutoImageProcessor, AutoModelForImageClassification
import torch

MODEL_ID = "trpakov/vit-face-expression"

# Load the processor and the model from the same checkpoint id so the
# preprocessing configuration always matches the classifier it feeds.
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)

In [2]:
import numpy as np

In [3]:
# Ensure inputs reach the model as 3-channel RGB.
# NOTE(review): `do_convert_rgb` is only honored by recent transformers releases
# of the ViT image processor; confirm for the installed version.
# There is no grayscale option on this processor (an attribute such as
# `do_rgb_to_grayscale` would simply be ignored), so grayscale conversion is
# done explicitly on the PIL image inside `predict_expression` instead.
processor.do_convert_rgb = True

In [None]:
# Display the image processor (its repr shows the preprocessing configuration)
# so the settings applied above can be checked.
processor

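Download the DVM-CAR dataset from https://figshare.com/articles/figure/DVM-CAR_Dataset/19586296/1?file=34792480 and extract it under `data/images`, so that the car front-view images end up in `data/images/confirmed_fronts/<Make>/<Year>/` (the layout the paths below assume).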

In [5]:
TEST_IMAGE_PATH = "data/images/confirmed_fronts/Lexus/2017/Lexus$$RX 450h$$2017$$Red$$48_24$$468$$image_1.jpg"

# Map logit index -> human-readable label.
# The mapping is read from the checkpoint's own config instead of being
# hard-coded in FER2013 dataset order: the label order of the classifier head is
# fixed by how the model was fine-tuned and need not match the dataset's
# canonical ordering. A mismatch would silently mislabel predictions.
# Names are capitalized for display (e.g. plot titles); keys are cast to int
# because index lookups below use the integer argmax.
CLASS_INDEX = {
    int(idx): str(name).capitalize()
    for idx, name in sorted(model.config.id2label.items(), key=lambda item: int(item[0]))
}

In [6]:
from PIL import Image

def predict_expression(image_path):
    """Classify the facial expression shown in a single image file.

    The image is converted to single-channel grayscale and then back to
    3-channel RGB before preprocessing. Colour information is discarded, but
    the input keeps three channels.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        A tuple ``(image, label, probs)``:
          - ``image``: the grayscale-as-RGB PIL image that was fed to the model.
          - ``label``: the ``CLASS_INDEX`` name of the highest-scoring class.
          - ``probs``: list of softmax probabilities, one per class, in logit
            index order.
    """
    # Use a context manager so the file handle is released right away; this
    # function is called once per file when scanning a whole directory tree.
    # `convert` returns a new in-memory image, so it stays valid after close.
    with Image.open(image_path) as source:
        image = source.convert('L').convert('RGB')

    inputs = processor(images=image, return_tensors="pt")

    # Inference only: no gradient bookkeeping needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # A single image was passed, so row 0 holds its scores.
    scores = logits[0]
    predicted_class_idx = int(scores.argmax(-1))
    predicted_prob = scores.softmax(-1).tolist()

    label = CLASS_INDEX[predicted_class_idx]

    return image, label, predicted_prob

In [7]:
import matplotlib.pyplot as plt

def visualize_images(image_paths):
    """Predict the expression for each image and show them side by side.

    Each image is drawn in its own panel of a single row, titled with its
    predicted label.

    Args:
        image_paths: Sequence of image file paths. An empty sequence is a
            no-op, because `plt.subplots` cannot build a grid with zero columns.
    """
    num_images = len(image_paths)
    if num_images == 0:
        return

    # squeeze=False always returns a 2-D array of Axes (1 x num_images here),
    # so a single image and many images share the same indexing path.
    fig, axes = plt.subplots(1, num_images, figsize=(15, 5), squeeze=False)

    for ax, image_path in zip(axes[0], image_paths):
        image, label, _ = predict_expression(image_path)
        ax.imshow(image)
        ax.set_title(label)
        ax.axis('off')  # Hide axis ticks/frame; only the picture matters

    plt.tight_layout()
    plt.show()

In [None]:
# Sanity check on a single known image before scanning the whole dataset.
visualize_images(image_paths=[TEST_IMAGE_PATH])

In [9]:
# A small hand-picked sample of car front-view images from the DVM-CAR dataset
# (different makes and colours) for a quick visual check of the predictions.
# The first entry is the same file as TEST_IMAGE_PATH.
image_paths = [
    "data/images/confirmed_fronts/Lexus/2017/Lexus$$RX 450h$$2017$$Red$$48_24$$468$$image_1.jpg",
    "data/images/confirmed_fronts/Audi/2017/Audi$$Q5$$2017$$Black$$7_20$$1219$$image_5.jpg",
    "data/images/confirmed_fronts/Tesla/2017/Tesla$$Model X$$2017$$White$$90_2$$8$$image_0.jpg",
    "data/images/confirmed_fronts/Toyota/2017/Toyota$$RAV4$$2017$$Silver$$92_34$$164$$image_0.jpg",
    "data/images/confirmed_fronts/Subaru/2017/Subaru$$Outback$$2017$$Black$$86_7$$89$$image_4.jpg",
    "data/images/confirmed_fronts/Bentley/2017/Bentley$$Continental$$2017$$Grey$$10_5$$700$$image_5.jpg"
]

In [None]:
# Show predictions for the hand-picked sample in one row.
visualize_images(image_paths=image_paths)

In [None]:
import os
import pandas as pd
from tqdm.auto import tqdm

root_dir = 'data/images/confirmed_fronts/'

# Only files with a common image extension are scored; anything else in the
# tree (e.g. stray metadata files) would otherwise make PIL raise and abort
# the whole run.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp')

# Walk the tree once and reuse the list both for the progress-bar total and for
# the loop itself. Sorting makes the row order (and therefore the CSV)
# reproducible across runs and filesystems.
image_files = sorted(
    os.path.join(dirpath, filename)
    for dirpath, _, filenames in os.walk(root_dir)
    for filename in filenames
    if filename.lower().endswith(IMAGE_EXTENSIONS)
)

# One dict per successfully scored image; (path, error) for files that failed.
results = []
failed_paths = []

for full_path in tqdm(image_files, desc="Processing files"):
    try:
        _, label, probs = predict_expression(full_path)
    except OSError as exc:
        # Corrupt or unreadable image (PIL's UnidentifiedImageError is an
        # OSError): record it and continue rather than lose a long run.
        failed_paths.append((full_path, str(exc)))
        continue

    result = {'image_path': full_path, 'label': label}
    # One probability column per class, e.g. 'prob_happy'.
    result.update(
        {f'prob_{CLASS_INDEX[i].lower()}': prob for i, prob in enumerate(probs)}
    )
    results.append(result)

if failed_paths:
    print(f"Skipped {len(failed_paths)} unreadable file(s); see `failed_paths`.")

# Build the DataFrame once from the collected records and persist it.
df = pd.DataFrame(results)
df.to_csv('expression_results.csv', index=False)

# Show the first few rows via rich display (last expression of the cell).
df.head()