In [1]:
import os
import cv2
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as tt

In [2]:
model_state = torch.load("./emotion_detection_model_state.pth", map_location ='cpu')

In [3]:
face_classifier = cv2.CascadeClassifier("./haarcascade_frontalface_default.xml")
class_labels = ["Angry", "Happy", "Neutral", "Sad", "Suprise"]

In [4]:
def conv_block(in_channels, out_channels, pool=False):
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ELU(inplace=True),
    ]
    if pool:
        layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)

In [5]:
def conv_block(in_channels, out_channels, pool=False):
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    if pool:
        layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)

In [6]:
class ResNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super().__init__()

        self.input = conv_block(in_channels, 64)

        self.conv1 = conv_block(64, 64, pool=True)
        self.res1 = nn.Sequential(conv_block(64, 32), conv_block(32, 64))
        self.drop1 = nn.Dropout(0.5)

        self.conv2 = conv_block(64, 64, pool=True)
        self.res2 = nn.Sequential(conv_block(64, 32), conv_block(32, 64))
        self.drop2 = nn.Dropout(0.5)

        self.conv3 = conv_block(64, 64, pool=True)
        self.res3 = nn.Sequential(conv_block(64, 32), conv_block(32, 64))
        self.drop3 = nn.Dropout(0.5)

        self.classifier = nn.Sequential(
            nn.MaxPool2d(6), nn.Flatten(), nn.Linear(64, num_classes)
        )

    def forward(self, xb):
        out = self.input(xb)

        out = self.conv1(out)
        out = self.res1(out) + out
        out = self.drop1(out)

        out = self.conv2(out)
        out = self.res2(out) + out
        out = self.drop2(out)

        out = self.conv3(out)
        out = self.res3(out) + out
        out = self.drop3(out)

        return self.classifier(out)


In [7]:
model = ResNet(1, len(class_labels))
model.load_state_dict(model_state)

<All keys matched successfully>

In [8]:
cap = cv2.VideoCapture(0)

while True:
    # Grab a single frame of video
    ret, frame = cap.read()
    frame = cv2.flip(frame, 1)
    labels = []
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
    faces = face_classifier.detectMultiScale(gray, 1.3, 5)

    for (x, y, w, h) in faces:
        cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 0, 0), 2)
        roi_gray = gray[y : y + h, x : x + w]
        roi_gray = cv2.resize(roi_gray, (48, 48), interpolation=cv2.INTER_AREA)

        if np.sum([roi_gray]) != 0:
            roi = tt.functional.to_pil_image(roi_gray)
            roi = tt.functional.to_grayscale(roi)
            roi = tt.ToTensor()(roi).unsqueeze(0)

            # make a prediction on the ROI
            tensor = model(roi)
            #print(tensor.tolist())
            pred = torch.max(tensor, dim=1)[1].tolist()
            label = class_labels[pred[0]]

            label_position = (x, y)
            cv2.putText(
                frame,
                label,
                label_position,
                cv2.FONT_HERSHEY_COMPLEX,
                2,
                (0, 255, 0),
                3,
            )
        else:
            cv2.putText(
                frame,
                "No Face Found",
                (20, 60),
                cv2.FONT_HERSHEY_COMPLEX,
                2,
                (0, 255, 0),
                3,
            )

    cv2.imshow("Emotion Detector", frame)

    if cv2.waitKey(1) & 0xFF == ord("q"):
        break

cap.release()
cv2.destroyAllWindows()

[[-4.1685872077941895, -4.564293384552002, -3.6073617935180664, -3.1580841541290283, -5.063583850860596]]
[[-4.053432941436768, -2.956923484802246, -2.0543532371520996, -3.056497812271118, -5.4582343101501465]]
[[-2.8383259773254395, -3.0892415046691895, -3.005513906478882, -2.821845769882202, -4.626485824584961]]
[[-3.9935824871063232, -4.357008934020996, -2.6183087825775146, -2.097705841064453, -4.850430965423584]]
[[-3.601508140563965, -5.979517459869385, -2.239406108856201, -2.058777093887329, -5.05912971496582]]
[[-3.1958436965942383, -5.440171241760254, -3.1475911140441895, -2.5942044258117676, -6.002705097198486]]
[[-3.154974937438965, -4.714908599853516, -2.8684020042419434, -1.9397155046463013, -4.573980331420898]]
[[-3.528003692626953, -3.963716506958008, -2.3341588973999023, -2.0165841579437256, -4.434648036956787]]
[[-3.5471227169036865, -3.7799782752990723, -2.2808823585510254, -1.7799582481384277, -5.412362575531006]]
[[-3.9613242149353027, -4.221288204193115, -1.30645370

[[-4.618040561676025, -2.951165199279785, -3.169095516204834, -2.8437280654907227, -5.117485523223877]]
[[-3.6980807781219482, -4.3561692237854, -3.500450611114502, -2.4781813621520996, -3.893428325653076]]
[[-4.182177543640137, -4.983205318450928, -2.900052785873413, -2.129589319229126, -4.209403038024902]]
[[-4.574118137359619, -3.984955310821533, -3.815899610519409, -2.972611904144287, -5.355576992034912]]
[[-4.809389114379883, -5.671656131744385, -2.990262031555176, -2.693049192428589, -4.519289970397949]]
[[-4.204864025115967, -4.09714412689209, -3.3432042598724365, -3.3338444232940674, -4.066456317901611]]
[[-3.802105665206909, -4.628827095031738, -3.9456472396850586, -2.618626356124878, -4.412708282470703]]
[[-4.5082879066467285, -4.125554084777832, -2.864006996154785, -1.7908704280853271, -4.547720432281494]]
[[-3.8323962688446045, -4.6899309158325195, -2.6516897678375244, -2.6519196033477783, -4.425174236297607]]
[[-3.7506237030029297, -5.007550239562988, -3.0818285942077637, 

[[-3.9044361114501953, -4.486044406890869, -2.4362998008728027, -1.4975978136062622, -4.91843843460083]]
[[-3.6098837852478027, -4.207095146179199, -2.954740047454834, -2.2410645484924316, -4.858747959136963]]
[[-3.291938304901123, -2.9783620834350586, -2.908325672149658, -2.9307568073272705, -3.8370842933654785]]
[[-4.0318732261657715, -4.301527500152588, -2.986837863922119, -2.6788337230682373, -4.822628974914551]]
[[-4.3586907386779785, -4.593265056610107, -3.049363851547241, -2.437067747116089, -4.776833534240723]]
[[-3.481433153152466, -5.044958114624023, -2.3004884719848633, -1.4778482913970947, -5.1291399002075195]]
[[-4.2149739265441895, -4.820835113525391, -2.7626068592071533, -2.7115638256073, -4.292047023773193]]
[[-3.6480050086975098, -5.627954483032227, -3.0047640800476074, -3.3847336769104004, -5.169857501983643]]
[[-4.867081165313721, -3.4944887161254883, -2.5149879455566406, -3.7473866939544678, -5.311840057373047]]
[[-3.585775375366211, -4.242979049682617, -3.103597640

[[-2.615772247314453, -5.292549133300781, -2.488844633102417, -0.9405000805854797, -5.922170162200928]]
[[-3.6902976036071777, -5.487622261047363, -2.303640604019165, -1.4647477865219116, -5.979639053344727]]
[[-3.676485538482666, -4.528435230255127, -2.2775068283081055, -1.67377507686615, -4.987416744232178]]
[[-3.3817243576049805, -4.702972412109375, -2.4922261238098145, -1.001153588294983, -5.161862850189209]]
[[-4.583157062530518, -4.061932563781738, -2.8714098930358887, -3.0451605319976807, -5.0278706550598145]]
[[-4.670909404754639, -4.122353553771973, -2.7052934169769287, -0.5700129270553589, -5.848662853240967]]
[[-4.031097412109375, -4.178384780883789, -2.0840115547180176, -0.5175685286521912, -4.982253551483154]]
[[-4.353708744049072, -3.7112584114074707, -3.401383399963379, -2.326866388320923, -4.3474345207214355]]
[[-3.5315330028533936, -4.516988277435303, -2.5881221294403076, -0.5859200954437256, -5.12340784072876]]
[[-3.6948792934417725, -5.344332695007324, -2.48016262054

[[-4.529146671295166, -4.867465972900391, -4.303593158721924, -3.979978322982788, -2.9142837524414062]]
[[-4.218332767486572, -4.597044467926025, -3.979245185852051, -2.5747830867767334, -3.076146364212036]]
[[-3.1597580909729004, -3.9815239906311035, -3.802675247192383, -2.2142913341522217, -3.303389310836792]]
[[-4.7826924324035645, -3.3808183670043945, -4.752132892608643, -4.241610527038574, -3.0918831825256348]]
[[-5.113654613494873, -4.982944011688232, -3.966785430908203, -3.706805944442749, -2.5202465057373047]]
[[-4.719248294830322, -3.9412920475006104, -4.411561012268066, -1.9910022020339966, -2.6760621070861816]]
[[-4.902853012084961, -5.24464750289917, -3.9781856536865234, -3.8953640460968018, -3.802952289581299]]
[[-5.245761394500732, -3.247518301010132, -4.997799873352051, -4.863434791564941, -1.2172704935073853]]
[[-5.22240686416626, -3.6202785968780518, -5.483442306518555, -4.687103271484375, -2.939514398574829]]
[[-5.083861827850342, -4.382379531860352, -4.86442518234252

[[-4.4568190574646, -4.722151756286621, -4.587924480438232, -3.5002365112304688, -2.5268945693969727]]
[[-4.671376705169678, -5.285706996917725, -3.860994815826416, -3.532686710357666, -2.404507637023926]]
[[-3.842297077178955, -5.165245056152344, -4.0335540771484375, -3.60062575340271, -3.994699239730835]]
[[-3.851330280303955, -5.035499095916748, -3.5066866874694824, -3.3751490116119385, -4.184139728546143]]
[[-3.358123540878296, -4.314993381500244, -3.548039197921753, -3.2646048069000244, -3.475513458251953]]
[[-4.739797115325928, -4.856914520263672, -5.224826812744141, -4.13408899307251, -2.9651732444763184]]
[[-5.174964427947998, -6.159482955932617, -4.796938896179199, -3.5383989810943604, -3.148988962173462]]
[[-4.816011905670166, -4.354241371154785, -3.4415502548217773, -4.284322261810303, -3.2573068141937256]]
[[-4.7437310218811035, -4.987056255340576, -3.29262638092041, -3.7571496963500977, -4.533743381500244]]
[[-4.303832530975342, -4.427855014801025, -4.065188407897949, -4.4

[[-4.573625564575195, -2.8706705570220947, -3.360535144805908, -4.593873977661133, -6.814995765686035]]
[[-4.977965831756592, -2.4322314262390137, -2.545609712600708, -4.217416286468506, -5.033194065093994]]
[[-4.915923595428467, -1.751186490058899, -2.152088165283203, -4.259552955627441, -5.9000935554504395]]
[[-4.678994178771973, -1.840142846107483, -1.7011924982070923, -3.4008262157440186, -4.864196300506592]]
[[-5.430327415466309, -2.859445810317993, -2.5108022689819336, -4.585289001464844, -5.489516735076904]]
[[-5.953924179077148, -1.3374396562576294, -2.5893900394439697, -4.475591659545898, -5.324830055236816]]
[[-5.70791482925415, -2.3241262435913086, -2.6107237339019775, -4.611123085021973, -5.157947540283203]]
[[-5.796895503997803, -2.968656063079834, -2.240489959716797, -3.4674088954925537, -5.833924770355225]]
[[-3.878570556640625, -3.4022040367126465, -3.1714987754821777, -3.2481625080108643, -5.387538433074951]]
[[-4.524316310882568, -1.8999172449111938, -3.40444850921630

[[-3.402071952819824, -3.504519462585449, -1.825807809829712, -3.1217520236968994, -5.24923849105835]]
[[-3.350102663040161, -4.974280834197998, -3.1398234367370605, -1.8847779035568237, -5.836030006408691]]
[[-4.228591442108154, -5.698777198791504, -3.2848474979400635, -2.178910970687866, -4.899908065795898]]
[[-3.3322300910949707, -4.652418613433838, -3.452791213989258, -1.499277114868164, -5.428720951080322]]
[[-4.915748119354248, -5.232456207275391, -2.8216381072998047, -2.208383321762085, -5.607344150543213]]
[[-3.000356674194336, -5.045199394226074, -2.009694814682007, -2.024637222290039, -5.5900750160217285]]
[[-2.5438029766082764, -6.573352336883545, -4.437749862670898, -2.8801233768463135, -5.315150737762451]]
[[-3.7463924884796143, -4.941371917724609, -3.1322526931762695, -3.423729181289673, -5.355259895324707]]
[[-3.656216621398926, -4.584330081939697, -3.1924898624420166, -2.03586483001709, -4.812069416046143]]
[[-3.3632102012634277, -5.162949562072754, -3.489593029022217, 

[[-3.1386427879333496, -5.32071590423584, -3.0127570629119873, -1.8127740621566772, -5.35189962387085]]
[[-2.517038106918335, -4.754676342010498, -3.0257678031921387, -2.005847454071045, -5.817479133605957]]
[[-2.6311521530151367, -4.889752388000488, -3.722900867462158, -2.4356849193573, -5.680029392242432]]
[[-2.9066460132598877, -4.406500816345215, -3.1163816452026367, -2.0171337127685547, -6.025664806365967]]
[[-3.2032322883605957, -4.1491594314575195, -3.595277786254883, -3.6420462131500244, -5.744482517242432]]
[[-2.6111114025115967, -4.628686904907227, -1.8493444919586182, -2.215609550476074, -6.9197797775268555]]
[[-2.875730514526367, -4.285348892211914, -2.6469171047210693, -2.901796340942383, -5.299404144287109]]
[[-3.0715556144714355, -4.771430492401123, -2.662281036376953, -2.501223087310791, -5.424170970916748]]
[[-2.682860851287842, -5.1324238777160645, -3.0960307121276855, -1.6629197597503662, -7.04462194442749]]
[[-3.8144009113311768, -5.669995307922363, -3.9450001716613

[[-3.058256149291992, -2.2704052925109863, -3.99723482131958, -2.7851617336273193, -5.472224712371826]]
[[-2.195129632949829, -4.163115978240967, -3.495056390762329, -2.779209613800049, -6.100755214691162]]
[[-2.259369373321533, -2.606663942337036, -3.8455934524536133, -2.7590994834899902, -4.688252925872803]]
[[-2.069603681564331, -4.216067314147949, -5.460016250610352, -4.221428871154785, -5.174328327178955]]
[[-2.4784486293792725, -3.8064637184143066, -3.0933687686920166, -1.5986090898513794, -4.253983974456787]]
[[-2.6406474113464355, -4.6539411544799805, -4.467868804931641, -2.5116469860076904, -4.436277866363525]]
[[-3.5918073654174805, -4.3100175857543945, -5.203359603881836, -4.7154436111450195, -5.3690571784973145]]
[[-3.1281661987304688, -4.38779354095459, -5.349128723144531, -4.291800498962402, -6.06910514831543]]
[[-2.980097770690918, -3.298381805419922, -5.444528579711914, -4.23001766204834, -5.503917694091797]]
[[-3.0276384353637695, -2.948732852935791, -4.117033004760742

[[-3.4512925148010254, -4.009108543395996, -4.446843147277832, -3.759394407272339, -5.665010929107666]]
[[-2.610318899154663, -4.883289337158203, -5.301018714904785, -3.412292718887329, -5.895669937133789]]
[[-2.907454013824463, -4.300911903381348, -4.04852819442749, -3.664898157119751, -5.460778713226318]]
[[-4.042003154754639, -4.843228340148926, -4.275428771972656, -2.9548704624176025, -5.521122932434082]]
[[-2.6674857139587402, -4.822113037109375, -4.778772354125977, -4.144524574279785, -4.889468193054199]]
[[-3.2681732177734375, -3.9825949668884277, -3.637843132019043, -3.7866876125335693, -5.303783416748047]]
[[-2.756582021713257, -3.3343236446380615, -4.104456901550293, -4.35307502746582, -5.352379322052002]]
[[-3.059522867202759, -4.366353988647461, -4.412885665893555, -4.2043256759643555, -5.3130998611450195]]
[[-2.937666416168213, -3.1170520782470703, -3.6073062419891357, -3.547581434249878, -5.520169258117676]]
[[-2.8330554962158203, -2.5313363075256348, -3.4663026332855225,

[[-4.318276882171631, -4.69169807434082, -2.5595197677612305, -2.92232084274292, -4.7595601081848145]]
[[-3.6165990829467773, -3.499041795730591, -2.366824150085449, -2.825995445251465, -5.29750394821167]]
[[-3.928844451904297, -4.351870536804199, -1.7847926616668701, -2.7363178730010986, -4.353950500488281]]
[[-3.7580726146698, -3.3829352855682373, -1.7063724994659424, -2.450467109680176, -4.591996669769287]]
[[-3.456629753112793, -3.540233612060547, -2.0026206970214844, -2.8123741149902344, -4.826014041900635]]
[[-3.9783778190612793, -4.91288948059082, -1.4346487522125244, -2.188356399536133, -4.370389461517334]]
[[-3.4973697662353516, -3.991777181625366, -1.5733975172042847, -3.018767833709717, -5.373501300811768]]
[[-3.5546576976776123, -4.900139808654785, -2.2264580726623535, -2.7738428115844727, -4.838057994842529]]
[[-3.71462345123291, -4.554064750671387, -1.8307571411132812, -2.5897819995880127, -4.120614051818848]]
[[-3.394711494445801, -4.668496131896973, -2.24090576171875, -

[[-4.128049850463867, -3.2653400897979736, -2.5883984565734863, -2.6161060333251953, -6.4338507652282715]]
