In [None]:
import os 
import json 
import torch
from slimnet import SlimNet
from torchvision import transforms
from PIL import Image
import numpy as np
from glob import glob 
from tqdm import tqdm 

# Paths used by the notebook. Each can be overridden through an environment
# variable so the notebook runs on other machines without editing code; the
# defaults are the original locations.
#   DEST_PATH     - output directory for the JSON prediction files
#   PATH_TO_IMAGE - input directory with one sub-directory of PNG frames per video
DEST_PATH = os.environ.get('RAVDESS_TEXT_DIR', '/raid/t-yazen/datasets/ravdess_text')
PATH_TO_IMAGE = os.environ.get('RAVDESS_FRAMES_DIR', '/raid/t-yazen/datasets/ravdess_256/')
os.makedirs(DEST_PATH, exist_ok=True)

In [None]:
# The 40 CelebA attribute names. The inference loop indexes this array with a
# boolean mask over the model outputs, so the order is assumed to match
# SlimNet's output units (TODO confirm against the checkpoint).
labels = np.array("""
    5_o_Clock_Shadow Arched_Eyebrows Attractive Bags_Under_Eyes
    Bald Bangs Big_Lips Big_Nose Black_Hair Blond_Hair
    Blurry Brown_Hair Bushy_Eyebrows Chubby Double_Chin
    Eyeglasses Goatee Gray_Hair Heavy_Makeup High_Cheekbones
    Male Mouth_Slightly_Open Mustache Narrow_Eyes No_Beard
    Oval_Face Pale_Skin Pointy_Nose Receding_Hairline
    Rosy_Cheeks Sideburns Smiling Straight_Hair Wavy_Hair
    Wearing_Earrings Wearing_Hat Wearing_Lipstick
    Wearing_Necklace Wearing_Necktie Young
""".split())
# A GPU is not required but speeds up inference: use CUDA when it is
# available and fall back to the CPU otherwise, so the notebook also runs on
# machines without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Per-frame preprocessing before SlimNet: resize, convert the PIL image to a
# float tensor in [0, 1], then map each of the 3 channels to [-1, 1] with
# (x - 0.5) / 0.5.
# NOTE(review): torchvision reads a sequence size as (h, w), so frames come out
# 178 high x 218 wide. If the intent was CelebA-style 178-wide x 218-high
# crops, this is transposed. Presumably it matches how the checkpoint was
# trained, so confirm before changing it.
_resize = transforms.Resize((178, 218))
_normalize = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
transform = transforms.Compose([_resize, transforms.ToTensor(), _normalize])

In [None]:

# Load the pretrained SlimNet checkpoint onto `device` and switch it to
# inference mode (the cell displays the model as its output).
MODEL_CHECKPOINT = 'models/celeba_20.pth'
model = SlimNet.load_pretrained(MODEL_CHECKPOINT)
model = model.to(device)
model.eval()

In [None]:
# Run SlimNet over every video and write one JSON file per video to DEST_PATH.
# Each entry of PATH_TO_IMAGE is treated as a video directory of PNG frames.
# All frames of a video are run through the model as a single batch.
# NOTE: the 'frames_logits' / 'video_logits' keys hold sigmoid probabilities,
# not raw logits. The key names are kept unchanged for downstream readers.

ATTRIBUTE_THRESHOLD = 0.5  # mean probability above which an attribute is listed


def predict_frames(frame_paths):
    """Return sigmoid attribute probabilities for a list of frame image paths.

    Uses the notebook-level `transform`, `model` and `device`. The result is a
    2-D numpy array with one row per frame, in the order of `frame_paths`.
    """
    frames = []
    for frame_path in frame_paths:
        with open(frame_path, 'rb') as f:
            # convert('RGB'): the 3-channel Normalize in `transform` fails on
            # grayscale / palette / RGBA PNGs
            frames.append(transform(Image.open(f).convert('RGB')))
    batch = torch.stack(frames).to(device)
    with torch.no_grad():
        probs = torch.sigmoid(model(batch))
    # reshape rather than squeeze(): squeeze() would also drop the batch axis
    # for a single-frame video, and mean(0) would then average the attributes
    return probs.reshape(len(frame_paths), -1).cpu().numpy()


def build_video_info(video_name, frame_paths, frame_probs, threshold=ATTRIBUTE_THRESHOLD):
    """Assemble the JSON-serializable result dict for one video.

    `frame_probs` is the per-frame probability array from predict_frames.
    The video-level score is its mean over frames.
    """
    video_probs = frame_probs.mean(axis=0)
    return {
        'video_name': video_name,
        # one {frame_file_name: that frame's probabilities} entry per frame
        'frames_logits': [
            {os.path.basename(path): row.tolist()}
            for path, row in zip(frame_paths, frame_probs)
        ],
        # .tolist() yields plain Python floats/strs, which json.dump accepts
        'video_logits': video_probs.tolist(),
        'video_attribute': labels[video_probs > threshold].tolist(),
    }


video_list = sorted(os.path.join(PATH_TO_IMAGE, name) for name in os.listdir(PATH_TO_IMAGE))
for vpath in tqdm(video_list, desc='videos'):
    frame_list = sorted(glob(os.path.join(vpath, '*.png')))
    if not frame_list:
        # nothing to predict (stray file or empty directory); torch.stack([]) would raise
        continue
    info = build_video_info(os.path.basename(vpath), frame_list, predict_frames(frame_list))
    # use the bare video name: vpath is absolute, and os.path.join discards
    # DEST_PATH when a later component is absolute
    with open(os.path.join(DEST_PATH, info['video_name'] + '.json'), 'w') as f:
        json.dump(info, f)
    

In [None]:
# Write the `info` dict currently in the kernel (the last video handled by the
# loop above) to DEST_PATH/<video_name>.json.


def _json_default(obj):
    """json.dump fallback that turns numpy arrays/scalars into plain Python values."""
    if isinstance(obj, (np.ndarray, np.generic)):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


# Use the bare video name: joining DEST_PATH with an absolute path such as
# `vpath` would make os.path.join discard DEST_PATH entirely.
out_path = os.path.join(DEST_PATH, info['video_name'] + '.json')
with open(out_path, 'w') as f:
    json.dump(info, f, default=_json_default)