In [1]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn.functional as F
import os
from torch.autograd import Variable
from skimage import io 
from skimage.transform import resize 
import transforms as transforms
from models import ResNet18 

In [2]:
cut_size = 44  

transform_test = transforms.Compose([
    transforms.TenCrop(cut_size),
    transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
])

def rgb2gray(rgb):
    return np.dot(rgb[...,:3], [0.299, 0.587, 0.114])

In [3]:
class_names = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']

net = ResNet18()  
checkpoint = torch.load(os.path.join('FER2013_ResNet18', 'Validation_model.t7'), map_location='cpu')
net.load_state_dict(checkpoint['net'])  
net.eval()

  checkpoint = torch.load(os.path.join('FER2013_ResNet18', 'Validation_model.t7'), map_location='cpu')


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=

In [4]:
def visualize_results(img_path):
    raw_img = io.imread(img_path)
    gray = rgb2gray(raw_img)
    gray = resize(gray, (48,48), mode='symmetric').astype(np.uint8)

    img = gray[:, :, np.newaxis]
    img = np.concatenate((img, img, img), axis=2)
    img = Image.fromarray(img)
    inputs = transform_test(img)

    class_names = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']

    net = ResNet18()  # Re-initialize ResNet18 model
    checkpoint = torch.load(os.path.join('FER2013_ResNet18', 'Validation_model.t7'), map_location='cpu')
    net.load_state_dict(checkpoint['net'])  # Load the trained model weights
    net.eval()

    ncrops, c, h, w = np.shape(inputs)

    inputs = inputs.view(-1, c, h, w)
    inputs = Variable(inputs)
    with torch.no_grad():
        outputs = net(inputs)

    outputs_avg = outputs.view(ncrops, -1).mean(0)  # avg over crops

    score = F.softmax(outputs_avg, dim=0)
    _, predicted = torch.max(outputs_avg.data, 0)

    plt.rcParams['figure.figsize'] = (13.5, 5.5)
    plt.subplot(1, 3, 1)
    plt.imshow(raw_img)
    plt.xlabel('Input Image', fontsize=16)
    plt.xticks([])
    plt.yticks([])
    plt.tight_layout()

    plt.subplots_adjust(left=0.05, bottom=0.2, right=0.95, top=0.9, hspace=0.02, wspace=0.3)

    plt.subplot(1, 3, 2)
    ind = 0.1 + 0.6 * np.arange(len(class_names))
    width = 0.4
    color_list = ['red', 'orangered', 'darkorange', 'limegreen', 'darkgreen', 'royalblue', 'navy']
    for i in range(len(class_names)):
        plt.bar(ind[i], score.data.cpu().numpy()[i], width, color=color_list[i])
    plt.title("Classification results", fontsize=20)
    plt.xlabel("Expression Category", fontsize=16)
    plt.ylabel("Classification Score", fontsize=16)
    plt.xticks(ind, class_names, rotation=45, fontsize=14)

    plt.subplot(1, 3, 3)
    plt.text(0.5, 0.5, f'Expression: {class_names[int(predicted.cpu().numpy())]}', 
             ha='center', va='center', fontsize=16)
    plt.xticks([])
    plt.yticks([])
    plt.tight_layout()

    # Ensure directory exists
    output_dir = r'F:\FER-2013 Project\fer2013_vgg_resnet\demo'

    # Save the result image
    plt.savefig(os.path.join(output_dir, '5_resnet.png'))
    plt.close()

    return class_names[int(predicted.cpu().numpy())]

In [10]:
img_path = 'images/2.jpg'
expression = visualize_results(img_path)
print(f"The Expression is {expression}")

  checkpoint = torch.load(os.path.join('FER2013_ResNet18', 'Validation_model.t7'), map_location='cpu')


The Expression is Surprise


In [11]:
def visualize_webcam():
    class_names = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral']
    
    # Load the model
    net = ResNet18()
    checkpoint = torch.load(os.path.join('FER2013_ResNet18', 'Validation_model.t7'), map_location='cpu')
    net.load_state_dict(checkpoint['net'])
    net.eval()

    # Start webcam
    cap = cv2.VideoCapture(0)
    
    while True:
        ret, frame = cap.read()
        if not ret:
            break
        
        # Convert the frame to grayscale and resize to 48x48 for model input
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        gray = cv2.resize(gray, (48, 48))
        img = gray[:, :, np.newaxis]
        img = np.concatenate((img, img, img), axis=2)
        img = Image.fromarray(img)
        inputs = transform_test(img)

        # Prepare input tensor
        ncrops, c, h, w = np.shape(inputs)
        inputs = inputs.view(-1, c, h, w)
        inputs = Variable(inputs)

        # Make prediction
        with torch.no_grad():
            outputs = net(inputs)

        outputs_avg = outputs.view(ncrops, -1).mean(0)  # avg over crops
        score = F.softmax(outputs_avg, dim=0)
        _, predicted = torch.max(outputs_avg.data, 0)
        predicted_expression = class_names[int(predicted.cpu().numpy())]

        # Display the resulting frame with prediction
        cv2.putText(frame, f'Expression: {predicted_expression}', (10, 30), 
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
        cv2.imshow('Facial Expression Recognition', frame)

        # Press 'q' to quit
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # Release the capture and close windows
    cap.release()
    cv2.destroyAllWindows()

# Call the function to start the webcam
visualize_webcam()

  checkpoint = torch.load(os.path.join('FER2013_ResNet18', 'Validation_model.t7'), map_location='cpu')
