In [None]:
# Point at the trained checkpoint file (the weights themselves are loaded later, in Model)
import os

path_to_model = '/content/drive/MyDrive/Colab Notebooks/affecnet8_epoch5_acc0.6209.pth'

# Last expression -> cell output: True when the checkpoint path exists
os.path.exists(path_to_model)

True

In [None]:
from torch import nn
from torch.nn import functional as F
import torch
import torch.nn.init as init
from torchvision import models


class DAN(nn.Module):
    """ResNet-18 feature extractor followed by several CrossAttentionHead branches.

    Args:
        num_class: number of output classes of the final linear layer.
        num_head: number of CrossAttentionHead branches; they are registered as
            attributes ``cat_head0`` ... ``cat_head{num_head-1}`` (these names are
            part of the state-dict keys, so they must not change).
        pretrained: when true, the backbone is initialised from
            ``./models/resnet18_msceleb.pth`` (key ``'state_dict'``).
    """

    def __init__(self, num_class=7,num_head=4, pretrained=True):
        super().__init__()

        # Build the backbone with random weights. When `pretrained` is set every
        # parameter and buffer is overwritten by the strict load below, so fetching
        # torchvision's ImageNet weights first would only cost a download.
        resnet = models.resnet18()

        if pretrained:
            # map_location='cpu' lets a checkpoint saved on a GPU be read on a
            # CPU-only machine; load_state_dict copies into the module's tensors.
            checkpoint = torch.load('./models/resnet18_msceleb.pth', map_location='cpu')
            resnet.load_state_dict(checkpoint['state_dict'],strict=True)

        # Drop the last two children of the ResNet so `features` returns a feature map.
        self.features = nn.Sequential(*list(resnet.children())[:-2])
        self.num_head = num_head
        for i in range(num_head):
            setattr(self,"cat_head%d" %i, CrossAttentionHead())
        self.sig = nn.Sigmoid()  # not used in forward(); kept so the attribute still exists
        self.fc = nn.Linear(512, num_class)
        self.bn = nn.BatchNorm1d(num_class)


    def forward(self, x):
        """Return ``(out, x, heads)``.

        ``out`` is the batch-normalised output of ``fc`` (one row per sample),
        ``x`` is the backbone feature map, and ``heads`` holds the per-head
        vectors with the batch dimension first.
        """
        x = self.features(x)
        heads = [getattr(self,"cat_head%d" %i)(x) for i in range(self.num_head)]

        # stack -> (num_head, batch, 512); permute -> (batch, num_head, 512)
        heads = torch.stack(heads).permute([1,0,2])
        if heads.size(1)>1:
            # normalise across the head dimension (only meaningful with >1 head)
            heads = F.log_softmax(heads,dim=1)

        # Sum over heads, then classify.
        out = self.fc(heads.sum(dim=1))
        out = self.bn(out)

        return out, x, heads

class CrossAttentionHead(nn.Module):
    """One attention branch: SpatialAttention followed by ChannelAttention."""

    def __init__(self):
        super().__init__()
        self.sa = SpatialAttention()
        self.ca = ChannelAttention()
        self.init_weights()

    def init_weights(self):
        """Re-initialise the Conv2d, BatchNorm2d and Linear layers of this head."""

        def zero_bias(layer):
            # Conv/Linear layers may have been built without a bias term.
            if layer.bias is not None:
                init.constant_(layer.bias, 0)

        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                init.kaiming_normal_(layer.weight, mode='fan_out')
                zero_bias(layer)
            elif isinstance(layer, nn.BatchNorm2d):
                init.constant_(layer.weight, 1)
                init.constant_(layer.bias, 0)
            elif isinstance(layer, nn.Linear):
                init.normal_(layer.weight, std=0.001)
                zero_bias(layer)

    def forward(self, x):
        """Apply spatial attention to `x`, then channel attention to the result."""
        return self.ca(self.sa(x))


class SpatialAttention(nn.Module):
    """Scale a 512-channel feature map by a single-channel spatial weighting."""

    @staticmethod
    def _conv_bn(in_channels, out_channels, kernel_size, padding=0):
        # Convolution followed by batch norm, wrapped in a Sequential
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
            nn.BatchNorm2d(out_channels),
        )

    def __init__(self):
        super().__init__()
        # 1x1 reduction 512 -> 256, then three parallel 256 -> 512 branches
        self.conv1x1 = self._conv_bn(512, 256, 1)
        self.conv_3x3 = self._conv_bn(256, 512, 3, padding=1)
        self.conv_1x3 = self._conv_bn(256, 512, (1, 3), padding=(0, 1))
        self.conv_3x1 = self._conv_bn(256, 512, (3, 1), padding=(1, 0))
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return `x` multiplied by the channel-summed, rectified branch response."""
        reduced = self.conv1x1(x)
        branch_sum = self.conv_3x3(reduced) + self.conv_1x3(reduced) + self.conv_3x1(reduced)
        # Collapse the channel axis to one map; it broadcasts over x's channels.
        weights = self.relu(branch_sum).sum(dim=1, keepdim=True)
        return x * weights

class ChannelAttention(nn.Module):
    """Pool a 512-channel map to a vector and gate it with a small bottleneck MLP."""

    def __init__(self):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Linear(512, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 512),
            nn.Sigmoid(),
        ]
        self.attention = nn.Sequential(*gate_layers)

    def forward(self, sa):
        """Return the pooled vector multiplied element-wise by its sigmoid gate."""
        pooled = self.gap(sa)
        # (batch, C, 1, 1) -> (batch, C)
        pooled = pooled.view(pooled.size(0), -1)
        gate = self.attention(pooled)
        return pooled * gate

In [None]:
import os
import argparse

from PIL import Image
import numpy as np
import cv2

import torch
import torch.nn.functional as nnf
from torchvision import transforms

class Model():
    """Face-expression predictor: Haar-cascade face detection + a DAN classifier.

    Args:
        model_path: checkpoint file containing a ``'model_state_dict'`` entry.
            Defaults to the notebook-level ``path_to_model``.
    """

    def __init__(self, model_path=None):
        if model_path is None:
            model_path = path_to_model
        if torch.cuda.is_available():
            print('Working with the gpu')
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.data_transforms = transforms.Compose([
                                    transforms.Resize((224, 224)),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])
                                ])
        # Index i of the network output is reported as labels[i]
        # (assumes this order matches the checkpoint's training labels -- TODO confirm).
        self.labels = ['neutral', 'happy', 'sad', 'surprise', 'fear', 'disgust', 'anger', 'contempt']

        self.model = DAN(num_head=4, num_class=8, pretrained=False)
        checkpoint = torch.load(model_path,
            map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'], strict=True)
        self.model.to(self.device)
        self.model.eval()

        self.face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades+'haarcascade_frontalface_default.xml')

    def detect(self, img0):
        """Return the face rectangles ``(x, y, w, h)`` found in the RGB PIL image `img0`."""
        img = cv2.cvtColor(np.asarray(img0),cv2.COLOR_RGB2BGR)
        faces = self.face_cascade.detectMultiScale(img)

        return faces

    def fer(self, path):
        """Classify the expression of the first detected face in the image at `path`.

        Returns:
            The string ``'null'`` when no face is detected; otherwise a tuple
            ``(label, probs)`` where ``probs`` maps every label to its softmax
            probability rounded to 4 decimals.
        """
        # Open inside a context manager so the file handle is released immediately;
        # convert() returns an independent in-memory copy.
        with Image.open(path) as raw:
            img0 = raw.convert('RGB')

        faces = self.detect(img0)

        if len(faces) == 0:
            return 'null'

        ##  single face detection: only the first rectangle is used
        x, y, w, h = (int(v) for v in faces[0])
        face = img0.crop((x, y, x + w, y + h))

        # (3, 224, 224) -> (1, 3, 224, 224)
        batch = self.data_transforms(face).unsqueeze(0).to(self.device)

        with torch.no_grad():
            out, _, _ = self.model(batch)
            softmax_probs = nnf.softmax(out, dim=1)
            index = int(out.argmax(dim=1))

        label = self.labels[index]
        # tolist() first so the rounding happens on Python doubles, as before
        rounded = np.round(softmax_probs[0].tolist(), 4).tolist()

        return label, dict(zip(self.labels, rounded))

In [None]:
class Color:
    """ANSI escape sequences for styling text printed to a terminal or cell output."""

    # Foreground colours
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'

    # Text attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Reset: append after styled text to return to the default rendering
    END = '\033[0m'

### Test it by providing a URL

In [None]:
# enter an image URL
import time

# Read the URL interactively; strip() removes surrounding whitespace
img_url = input('URL:').strip()
# Whole-second Unix timestamp; the download cell puts it in the saved file's name
salt = int(time.time())

URL:https://miro.medium.com/max/1400/1*vnewG2OYRUij_b_zm5uNrQ.png


In [None]:
# Imported under the lowercase alias `image`, so the name `Image`
# (bound to PIL's module in another cell) is not rebound here.
from IPython.display import Image as image

# Last expression -> the picture is rendered from the URL as the cell output
image(url=img_url, width=400, height=256)

In [None]:
# download the image
import subprocess

# Run wget without a shell and pass the arguments as a list:
#  * the user-supplied URL is handed over verbatim, so characters such as `*`, `&`
#    or `?` are not interpreted by a shell;
#  * the file name contains the full `salt` (with `!wget ... pic${salt}.png` the shell
#    treated the leading `$<digit>` as an empty positional parameter and dropped it).
# check=True raises CalledProcessError if the download fails instead of continuing silently.
subprocess.run(["wget", "-q", "-4", img_url, "-O", f"pic{salt}.png"], check=True)

In [None]:
model = Model()

# Path of the image to classify
image = "image.jpg"

assert os.path.exists(image), "Failed to load image file."

print(image)
prediction = model.fer(image)
# fer() returns the string 'null' (not a tuple) when no face is detected; unpacking
# that string would fail with an unrelated "too many values to unpack" message.
if prediction == 'null':
    raise ValueError(f"No face detected in {image!r}")
label, probs = prediction
print(f'Emotion label: {label}')
print("=" * 20)
print('Softmax probabilities', end='\n\n')
probs

image.jpg
Emotion label: neutral
Softmax probabilities



{'anger': 0.0585,
 'contempt': 0.1081,
 'disgust': 0.0499,
 'fear': 0.0819,
 'happy': 0.0982,
 'neutral': 0.3766,
 'sad': 0.1175,
 'surprise': 0.1092}

In [None]:
# Renormalise the probabilities of a subset of labels so that they sum to 1
subset = ('anger', 'disgust', 'fear', 'happy', 'sad')
s = sum(probs[name] for name in subset)
[probs[name] / s for name in subset]

[0.14408866995073893,
 0.12290640394088671,
 0.2017241379310345,
 0.24187192118226603,
 0.2894088669950739]

### Predict on all frames/images in a folder

In [None]:
# download the image
import subprocess

img_url = r"https://i1.prth.gr/images/w880/_webp/files/2022-05-25/geetha__1_.jpg"
# wget -O does not create missing parent directories
os.makedirs("/content/Images", exist_ok=True)
# Arguments are passed as a list (no shell): the URL is used verbatim and the file name
# keeps the full timestamp (with `pic${int(time.time())}.jpg` the shell dropped its
# first digit). check=True raises CalledProcessError when the download fails.
subprocess.run(
    ["wget", "-q", "-4", img_url, "-O", f"/content/Images/pic{int(time.time())}.jpg"],
    check=True,
)

In [None]:
# Folder whose images will be classified
folder_name = r'/content/drive/MyDrive/Colab Notebooks/Images_bordered/bordered_crazy_good/bordered_crazy_good'
# "<last path component>_", used as a file-name prefix
last_part = f"{os.path.basename(os.path.normpath(folder_name))}_"

base = os.path.abspath(os.path.join(os.path.curdir, folder_name))
assert os.path.exists(base), "Path doesn't exist"
# Cell output: resolved path and number of directory entries
dict(path=base, n_images=len(os.listdir(base)))

{'n_images': 482,
 'path': '/content/drive/MyDrive/Colab Notebooks/Images_bordered/bordered_crazy_good/bordered_crazy_good'}

In [None]:
# rename files
# Prefix every entry of the folder with `last_part`. Entries that already start with
# the prefix are skipped, so running the cell again does not stack the prefix.
# A plain loop is used so the cell no longer emits a long list of None values.
for filename in os.listdir(folder_name):
    if filename.startswith(last_part):
        continue
    os.rename(os.path.join(folder_name, filename),
              os.path.join(folder_name, last_part + filename))

[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

valid_extensions = (".jpg", ".jpeg", ".png")
# Full paths of the image files in `base`.
#  * the extension test is case-insensitive, so e.g. "photo.JPG" is picked up too;
#  * sorted() gives a stable order (os.listdir returns entries in arbitrary order).
pics = sorted(
    os.path.join(base, file_) for file_ in os.listdir(base)
    if os.path.isfile(os.path.join(base, file_))
    and file_.lower().endswith(valid_extensions)
)


def show_images(paths, columns=5, figsize=(15, 12)):
    """Display the images at `paths` in a grid with `columns` columns.

    Optional preview helper; it is not called here because reading every
    image of a large folder is slow. Example: ``show_images(pics[:10])``.
    """
    paths = list(paths)
    if not paths:
        return
    # Ceiling division: plt.subplot needs an integer row count
    rows = -(-len(paths) // columns)
    plt.figure(figsize=figsize)
    for i, path in enumerate(paths):
        plt.subplot(rows, columns, i + 1)
        plt.axis('off')
        plt.imshow(mpimg.imread(path), aspect='equal')
    plt.tight_layout()

"\nimages = []\nfor pic in pics:\n   if os.path.isfile(pic) and pic.endswith(valid_extensions): \n       images += [mpimg.imread(pic)]\n\n\nplt.figure(figsize=(15,12))\nplt.tight_layout()\ncolumns = 5\n\nfor i, pic in enumerate(images):\n    plt.subplot(len(images) / columns + 1, columns, i + 1)\n    plt.axis('off')\n    plt.imshow(pic, aspect='equal')\n"

In [None]:
%%time
from collections import defaultdict
from pathlib import PurePath
import shutil 

path_to_save = os.path.abspath(r"/content/drive/MyDrive/Colab Notebooks/results/")

# store probabilities here: predicted label -> list of that label's probability per image
dic = defaultdict(list)
invalid = 0
model = Model()

for pic in pics:
    prediction = model.fer(pic)
    # fer() returns the string 'null' when no face is found; compare by value
    # (`is` on a string literal tests identity and triggers a SyntaxWarning).
    if prediction == 'null':
        print(f"{Color.RED}Invalid input for {PurePath(pic).name}{Color.END}")
        invalid += 1
    else:
        label, probs = prediction
        dic[label].append(probs[label])
        # Make sure the per-label folder exists: if it does not, shutil.copy would
        # write a *file* named after the label and overwrite it on every copy.
        label_dir = os.path.join(path_to_save, label)
        os.makedirs(label_dir, exist_ok=True)
        shutil.copy(pic, label_dir)
        print(f'Emotion label for {os.path.basename(pic)}: {Color.DARKCYAN}{Color.UNDERLINE}{label}{Color.END}')
        print(probs)
    print("=" * 45)

# Guard the percentage against an empty `pics` list
invalid_pct = invalid / len(pics) * 100 if pics else 0.0
print(f"\n{Color.YELLOW}Found {invalid} invalid picture(s) out of {len(pics)} ({invalid_pct:.4}%){Color.END}")

Emotion label for bordered_crazy_good_frame_122_border.png: [36m[4mneutral[0m
{'neutral': 0.168, 'happy': 0.1123, 'sad': 0.0941, 'surprise': 0.1071, 'fear': 0.1376, 'disgust': 0.1065, 'anger': 0.1258, 'contempt': 0.1487}
Emotion label for bordered_crazy_good_frame_121_border.png: [36m[4msurprise[0m
{'neutral': 0.1383, 'happy': 0.1406, 'sad': 0.1066, 'surprise': 0.1545, 'fear': 0.1526, 'disgust': 0.1116, 'anger': 0.078, 'contempt': 0.1177}
[91mInvalid input for bordered_crazy_good_frame_120_border.png[0m
Emotion label for bordered_crazy_good_frame_116_border.png: [36m[4mneutral[0m
{'neutral': 0.3223, 'happy': 0.1434, 'sad': 0.092, 'surprise': 0.0837, 'fear': 0.0673, 'disgust': 0.0605, 'anger': 0.1416, 'contempt': 0.0892}
[91mInvalid input for bordered_crazy_good_frame_115_border.png[0m
[91mInvalid input for bordered_crazy_good_frame_114_border.png[0m
[91mInvalid input for bordered_crazy_good_frame_11_border.png[0m
[91mInvalid input for bordered_crazy_good_frame_119_bor

In [None]:
# Number of directory entries in each per-label results folder
emotions = ('neutral', 'happy', 'sad', 'surprise', 'fear', 'disgust', 'anger', 'contempt')
sum_dict = defaultdict(int)

for emotion in emotions:
    emotion_dir = os.path.join(path_to_save, emotion)
    sum_dict[emotion] = len(os.listdir(emotion_dir))

sum_dict

defaultdict(int,
            {'anger': 334,
             'contempt': 89,
             'disgust': 234,
             'fear': 321,
             'happy': 417,
             'neutral': 225,
             'sad': 602,
             'surprise': 318})

In [None]:
# Mean of the collected probabilities for each predicted label
mean_vals = {label_name: np.mean(scores) for label_name, scores in dic.items()}
(dic, mean_vals)