In [1]:
import torch, numpy as np, matplotlib.pyplot as plt
from torchvision.transforms import v2
from multitask.framework1.multiscale_fusion import StandardMultiScaleFusion
from multitask.framework1.model import MultiTaskFaceAnalysisModel
from backbones.backbones import get_backbone
import datasets as db
from multitask.framework1.subnets import FaceRecognitionEmbeddingSubnet, AgeEstimationSubnet, GenderRecognitionSubnet, EmotionRecognitionSubnet, RaceRecognitionSubnet, AttributeRecognitionSubnet, PoseEstimationSubnet



In [2]:
# Evaluation-time preprocessing, for testing on datasets other than face recognition.
test_transform = v2.Compose(
    [
        v2.ToImage(),                                  # wrap the input as an image tensor
        v2.ToDtype(torch.float32, scale=True),         # cast to float32, rescaling values
        v2.Normalize(mean=[0.5] * 3, std=[0.5] * 3),   # per-channel (x - 0.5) / 0.5
    ]
)

In [3]:
# Build the CelebA dataset with the evaluation transform, then keep handles to
# its attribute grouping and its label table for the cells that follow.
database = db.CelebA(transform=test_transform)
groups, labels_df = database.attribute_groups, database.labels_df

In [4]:
# Serve the dataset in batches of 16; every other DataLoader option is left at
# its default (so samples come in dataset order, without shuffling).
dataloader = torch.utils.data.DataLoader(dataset=database, batch_size=16)

In [5]:
# Pull just the first batch; it serves as the probe input for the model cells below.
first_batch = next(iter(dataloader))
images, labels = first_batch

In [6]:
# Show the names of the attribute groups exposed by the dataset.
groups.keys()

dict_keys(['mouth', 'ear', 'lower_face', 'cheeks', 'nose', 'eyes', 'hair', 'object'])

In [7]:
# Indices supplied by the dataset; presumably the boundaries between attribute
# groups in the label vector — TODO confirm against datasets.CelebA.
cut_indices = database.get_cut_indices()
# Label-table columns with the first and the last column dropped.
# NOTE(review): presumably those two are non-attribute columns (e.g. an id) — confirm.
columns = labels_df.columns[1:-1]

In [8]:
# Ask the project's backbone factory for the 'swin_v2_t' variant.
# The call surfaces a warning line from torch's meshgrid (visible in the cell output).
backbone = get_backbone(backbone_name="swin_v2_t")

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [9]:
# Run the sampled batch through the backbone. Judging by the variable name the
# result holds features at several scales — TODO confirm the exact structure.
# NOTE(review): this runs with autograd enabled and in whatever train/eval mode
# get_backbone returned; consider torch.no_grad() / backbone.eval() if the cell
# is only meant for shape inspection.
multiscale_features = backbone(images)

In [10]:
# Instantiate the attribute-recognition head with its default constructor arguments.
subnet = AttributeRecognitionSubnet()

In [11]:
# Feed the backbone features to the attribute head; the shape and values of the
# result are inspected in the cells below.
output = subnet(multiscale_features)

In [12]:
# Total number of scalar parameters in the attribute head.
# sum() over a generator replaces the manual accumulator loop, and iterating
# parameters() avoids binding the parameter name that was never used.
count = sum(param.numel() for param in subnet.parameters())

In [13]:
# Display the parameter total computed in the previous cell.
count

8391016

In [14]:
# Check the head's output shape: the first dimension matches the batch size of 16;
# the second (40) is presumably one value per attribute — TODO confirm.
output.shape

torch.Size([16, 40])

In [15]:
# Display the raw head outputs for the whole batch.
# NOTE(review): this prints the full 16x40 tensor; a slice such as output[:2]
# would keep the saved notebook smaller and easier to skim.
output

tensor([[-1.1896, -0.2716,  0.2899,  0.3340,  0.2250, -0.1166,  0.2632, -1.1993,
          0.3493,  0.5329,  0.0887,  0.7529, -0.4488,  0.2816,  0.4117, -0.0103,
          0.5567, -0.2626,  0.1622, -0.2292,  0.0359,  0.2094, -0.0020, -0.3680,
         -0.0735, -0.3132, -0.5628, -0.4685, -0.4508,  0.1028,  0.1408, -0.2435,
          0.2066,  0.3403,  0.7124,  0.6286, -0.2958, -0.3491,  0.1369, -0.2096],
        [ 0.1433, -0.3327,  0.0410, -0.1509, -0.2834,  0.3436, -0.5876,  0.1739,
          0.0342,  0.1382,  0.0339, -0.0846,  0.1568,  0.1483,  0.3305, -0.2442,
         -0.0549, -0.1278, -0.6025,  0.0559, -0.0888, -0.0499,  0.3861,  0.2748,
          0.0339, -0.0684,  0.0154,  0.2011, -0.1158, -0.5333, -0.4298,  0.3480,
          0.1861,  0.1045, -0.0048, -0.0222,  0.2139,  0.4705,  0.1543, -0.2199],
        [-0.2762, -0.6013,  0.4963, -0.5957, -0.5384,  1.1583, -0.2230, -0.1686,
         -0.2289,  0.4679,  0.1112,  0.2054, -0.0947, -0.0831, -0.0629, -1.0398,
          0.2559, -0.3030,