In [1]:
import torch, numpy as np, matplotlib.pyplot as plt
from torchvision.transforms import v2
from multitask.framework1.multiscale_fusion import StandardMultiScaleFusion
from multitask.framework1.model import MultiTaskFaceAnalysisModel
from backbones.backbones import get_backbone
import datasets as db
from multitask.framework1.subnets import FaceRecognitionEmbeddingSubnet, AgeEstimationSubnet, GenderRecognitionSubnet, EmotionRecognitionSubnet, RaceRecognitionSubnet, AttributeRecognitionSubnet, PoseEstimationSubnet



In [2]:
# Shared preprocessing pipeline. A single Compose object is bound to both
# names, so training and evaluation apply exactly the same steps (there is no
# train-only augmentation here).
# Also used for testing on datasets other than face recognition.
test_transform = v2.Compose([
    v2.ToImage(),
    # Cast to float32; scale=True asks torchvision to rescale the values while
    # casting (see the v2.ToDtype docs for the exact per-dtype behaviour).
    v2.ToDtype(torch.float32, scale=True),
    # (x - 0.5) / 0.5 on each of the three channels.
    v2.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
train_transform = test_transform

In [3]:
# training
# Face-recognition and attribute datasets are built first because the class
# count and the attribute weights are read from them.
ms1mv2 = db.MS1MV2(transform=train_transform)
num_classes = ms1mv2.number_of_classes()
celeba = db.CelebA(transform=train_transform, subset='combined')
attribute_pos_weight = celeba.get_attribute_weights()

# Task name -> list of datasets supplying samples for that task.
# Every dataset shares `train_transform`.
train_db_dict = dict(
    face_recognition=[ms1mv2],
    emotion_recognition=[
        db.FERPlus(transform=train_transform, subset='combined'),
        db.AffectNet(transform=train_transform, subset='combined'),
        db.RAFDB(transform=train_transform, subset='train'),
    ],
    age_gender_race_recognition=[
        db.MORPH(transform=train_transform, subset='combined'),
        db.FairFace(transform=train_transform, subset='train'),
        db.UTKFace(transform=train_transform, subset='train'),
        db.IMDB_WIKI(transform=train_transform),
    ],
    attribute_recognition=[celeba],
    # Head pose is currently disabled; note that `pose_estimation_transform`
    # is not defined in the cells above and would be needed to re-enable it.
    # head_pose_estimation=[db.W300LP(transform=pose_estimation_transform)],
)

In [4]:
# Single loader over all task datasets in `train_db_dict`.
# epoch_size=None leaves the epoch length at the loader's default
# (NOTE(review): exact meaning is defined in db.get_balanced_loader - confirm).
train_loader = db.get_balanced_loader(
    train_db_dict, batch_size=16, num_workers=2, epoch_size=None
)

In [5]:
# Approximate number of samples per pass over the loader: batches x batch size.
# The 16 must stay in sync with `batch_size` passed to db.get_balanced_loader.
# Assumes len(train_loader) counts batches - TODO confirm.
len(train_loader) * 16

6779264

In [6]:
# Pull a single batch from the loader to inspect what it yields:
# each item unpacks into an (images, labels) pair.
images, labels = next(iter(train_loader))

In [7]:
# Show the batch tensor's shape; the leading dimension should equal the
# batch_size given to the loader (presumably followed by channels, height,
# width - confirm against the dataset classes).
images.shape

torch.Size([16, 3, 112, 112])

In [8]:
# Display the label dict for the sampled batch (one entry per label type).
# NOTE(review): in the recorded output, -1 appears to mark samples that carry
# no annotation for that label - confirm against the loader implementation.
labels

{'face_recognition': tensor([   -1,    -1, 54529,    -1,    -1, 58447,  3794,  4572,    -1,    -1,
            -1,    -1,    -1,    -1, 34862,    -1]),
 'emotion': tensor([-1, -1, -1, -1, -1, -1, -1, -1,  5, -1, -1, -1, -1,  4, -1,  1]),
 'age': tensor([21, -1, -1, -1, -1, -1, -1, -1, -1, 58,  6, -1, 31, -1, -1, -1]),
 'gender': tensor([ 0, -1, -1, -1, -1, -1, -1, -1, -1,  1,  0, -1,  0, -1, -1, -1]),
 'race': tensor([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  2, -1, -1, -1, -1, -1]),
 'attributes': tensor([[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1],
         [ 1,  0,  0,  0,  0,  0,  0,  0,  1,  0,  1,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  1,  0,  0,  1,  0,  0,  0,  0,  0,  1,  0,  0,  0,  1,
           0,  0,  0,  1],
         [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
          -1, -1, -1, -1, -1, -1, 