In [1]:
from glob import glob
from pathlib import Path

import torch
from torch import nn
import torch.nn.functional as F

from torchvision import transforms as TF
from torchvision import models
from torchvision.models import ResNet18_Weights

from tqdm import tqdm

from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
# Build a ResNet-18 with a 2-way classification head and restore the
# fine-tuned weights from the checkpoint on disk.
#
# weights=None: load_state_dict (strict by default) replaces every parameter
# and buffer, so downloading the ImageNet weights first would be wasted work
# and would make the cell depend on network access.
model = models.resnet18(weights=None)
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 2)

# map_location remaps tensors saved on a CUDA device so the checkpoint also
# loads on a CPU-only machine (matching the CPU fallback of `device`).
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
state_dict = torch.load(
    "face_gender_classification_transfer_learning_with_ResNet18.pth",
    map_location=device,
)
model.load_state_dict(state_dict)
model = model.to(device)

# Inference mode: BatchNorm uses its running statistics.
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [4]:
# Preprocessing applied to every image before it is fed to the classifier:
# resize to 224x224, convert to a float tensor, then normalize each channel
# with the given per-channel mean / std.
transforms = TF.Compose([
    TF.Resize((224, 224)),
    TF.ToTensor(),
    TF.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# glob returns files in a file-system-dependent order; sorting makes the
# processing order (and therefore the saved id list) reproducible across runs.
paths = sorted(glob("./CelebAMask-HQ/CelebA-HQ-img/*.jpg"))
print(len(paths))

30000


In [5]:
# Accumulator for the numeric ids of the selected CelebA-HQ images.
male_number = list()

In [6]:
# Collect the ids of CelebA-HQ images for which the classifier assigns class 1
# a probability above 0.95 AND a readable hair mask exists.
# NOTE(review): assumes output index 1 is the "male" class — confirm against
# the label order used when the checkpoint was trained.

# Reset here so that re-running this cell does not append duplicate ids.
male_number = []
# Never incremented in this cell; kept only because the commented-out
# label-frame code further down still refers to `cnt`.
cnt = 0

for path in tqdm(paths):
    # The file stem is the image number, e.g. ".../123.jpg" -> 123.
    # Path handles OS-specific separators, unlike splitting on "/".
    img_number = int(Path(path).stem)

    # Close the file handle promptly; force 3 channels so the 3-channel
    # Normalize step cannot fail on a grayscale / RGBA file.
    with Image.open(path) as raw_img:
        # The intermediate 512x512 resize is kept as-is: dropping it would
        # change the interpolation result and hence the predictions.
        img = raw_img.convert("RGB").resize((512, 512))
    transformed_img = transforms(img).unsqueeze(0).to(device)  # (1, 3, 224, 224)

    with torch.no_grad():
        pred = F.softmax(model(transformed_img), dim=1).squeeze(0)

    if pred[1] > 0.95:
        # Mask files are named with the zero-padded 5-digit image number.
        mask_path = f"./CelebAMask-HQ/CelebAMask-HQ-mask-anno/{img_number:05d}_hair.png"
        try:
            # Opening is only a check that the mask exists and is a readable
            # image; the context manager closes the handle immediately.
            with Image.open(mask_path):
                pass
        except OSError:
            # Missing or unreadable mask (FileNotFoundError and Pillow's
            # UnidentifiedImageError are both OSError subclasses): skip the
            # image, but let unrelated errors and KeyboardInterrupt propagate.
            continue
        male_number.append(img_number)

100%|██████████| 30000/30000 [04:56<00:00, 101.12it/s]


In [7]:
# Number of image ids collected in `male_number`.
print(len(male_number))

7443


In [8]:
import pickle

# Persist the collected CelebA-HQ image ids so later steps can reuse them
# without re-running the classifier.
with open("./celeba_male_number.pickle", mode="wb") as handle:
    pickle.dump(male_number, handle)

In [9]:
# import numpy as np
# import pandas as pd

# datasamples = np.full((cnt, 2), None) # (7334, 2)
# files = np.array([[f"{i}.jpg"] for i in range(cnt)]) # (7334, 1)
# data = np.concatenate((files, datasamples), axis=1) # (7334, 3)
# label_frame = pd.DataFrame(data)
# label_frame.columns = ["file", "forehead (0, 1, 2)", "length (0, 1, 2, 3, 4)"]
# label_frame.set_index(["file"])
# label_frame.head()

In [10]:
# label_frame.to_csv("dataset/label_frame.csv", index=False)

In [11]:
# data = pd.read_csv("dataset/label_frame.csv")
# data.head()

In [None]:
# add ffhq dataset
# https://github.com/royorel/FFHQ-Aging-Dataset#downloading-with-pydrive

In [12]:
import numpy as np
import pandas as pd

In [13]:
# Load the FFHQ-Aging per-image label table and preview its first five rows.
ffhq_aging_labels = pd.read_csv(filepath_or_buffer="ffhq_aging_labels.csv")
ffhq_aging_labels.head(n=5)

Unnamed: 0,image_number,age_group,age_group_confidence,gender,gender_confidence,head_pitch,head_roll,head_yaw,left_eye_occluded,right_eye_occluded,glasses
0,0,0-2,1.0,male,1.0,4.644246,2.179985,-9.359347,0.0,0.007,
1,1,30-39,0.6522,female,1.0,8.730963,-1.48598,6.370928,0.0,0.001,
2,2,30-39,0.6221,female,1.0,8.776943,2.089832,-5.144061,0.0,0.0,
3,3,3-6,0.6623,female,1.0,4.311426,2.019704,2.472943,0.001,0.0,
4,4,20-29,1.0,female,1.0,11.698202,0.262232,4.569711,0.0,0.0,


In [19]:
# Image numbers of every FFHQ row whose "gender" label is "male", selected
# with a single .loc lookup (row mask + column) instead of chained indexing.
ffhq_male_number = ffhq_aging_labels.loc[
    ffhq_aging_labels["gender"].eq("male"),
    "image_number",
].tolist()
print(ffhq_male_number[:10])

[0, 5, 9, 10, 11, 12, 13, 14, 15, 16]


In [20]:
# Number of FFHQ images labelled "male".
len(ffhq_male_number)

32170

In [21]:
# Persist the FFHQ male image numbers next to the CelebA list.
with open("./ffhq_male_number.pickle", mode="wb") as handle:
    pickle.dump(ffhq_male_number, handle)