In [1]:
import torch 
torch.cuda.empty_cache()
import torchvision 
import matplotlib.pyplot as plt 
import numpy as np 
import json 
import shutil 
import pandas as pd 
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import models
from PIL import Image
import os
import src.dataset as dl
import src.utility as utility
import src.model as cnn_models

In [2]:
# converted ranges- (0-2, 3-6, 8-13, 15-20, 25-32, 34-43, 45-53, 55-100)

def valid_ext(ext):
    return ext.lower() in ['.jpg', '.jpeg', '.png']


def inference(model, test_dir):

        """Inference for test """

        age_cats = {0:"0-2", 1:"3-6", 2:"8-13", 3:"15-20", 4:"22-32", 5:"34-43", 6:"45-53", 7:"55-100"}
        gender_cats = {0:"f",1:"m",2:"u"}

        output = {}

        for root, dirs, files in os.walk(test_dir):
            for file in files:
                
                path = os.path.join(root, file)

                with torch.no_grad():
                    if not valid_ext(file[-4:]):
                        print(path,'is not valid image. Expected extension: .jpg, .jpeg, .png')
                        continue
                    
                    else:
                        img = Image.open(path)
                        img = img.convert('RGB')
                        img = img.resize((104, 104))
                        
                        img = torch.tensor(np.array(img)).unsqueeze(0)
                        img = img.permute(0, 3, 1, 2).float()

                        age_logits, gender_logits = model(img.to(device))
                        
                        output[str(path)] = {'age':age_cats[age_logits.argmax(1).item()],
                                            'gender':gender_cats[gender_logits.argmax(1).item()]}
                        
        return output                


In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

base_cnn_model = cnn_models.Base_CNN_multi_task().to(device)

# load model weights
# base_cnn_model.load_state_dict(torch.load("models/base_cnn_multi_task_model.pt"))

# root = os.getcwd()

# model parameters 
model_params = torch.load("./models/best_model/best_model.pt")

base_cnn_model.load_state_dict(model_params['state_dict'])
base_cnn_model.eval()


Using cuda device


Base_CNN_multi_task(
  (conv1): Conv2d(3, 6, kernel_size=(3, 3), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 10, kernel_size=(3, 3), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(10, 32, kernel_size=(2, 2), stride=(1, 1))
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=3872, out_features=1000, bias=True)
  (fc2): Linear(in_features=1000, out_features=100, bias=True)
  (head1): Linear(in_features=100, out_features=8, bias=True)
  (head2): Linear(in_features=100, out_features=3, bias=True)
)

In [4]:
inf_dir = './data/test/'
output = inference(base_cnn_model, inf_dir)

In [5]:
print(output)

{'./data/test/im1.jpg': {'age': '22-32', 'gender': 'f'}, './data/test/im2.jpg': {'age': '22-32', 'gender': 'f'}, './data/test/im3.jpg': {'age': '45-53', 'gender': 'f'}, './data/test/im4.jpg': {'age': '0-2', 'gender': 'm'}}
