In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
import argparse
import os

import torch
from mivolo.data.dataset.age_gender_dataset import AgeGenderDataset
from mivolo.data.dataset.age_gender_loader import create_loader
from mivolo.predictor import Predictor
from timm.utils import setup_default_logging
from tqdm import tqdm

import pandas as pd

In [None]:
# Checkpoint under evaluation: loaded from models/<name>.pth.tar and reused in the results CSV filename.
checkpoint_name: str = "variance_feature_attribution_mivolo_checkpoint"

### Load model and dataset

In [None]:
# Predictor is handed an argparse.Namespace, so the settings are collected in a
# dict and wrapped in one instead of being parsed from the command line.
cuda_available = torch.cuda.is_available()

args_dict = dict(
    output="output",
    detector_weights="models/yolov8x_person_face.pt",
    checkpoint=f"models/{checkpoint_name}.pth.tar",
    with_persons=False,
    disable_faces=False,
    draw=False,
    device="cuda:0" if cuda_available else "cpu",
)
args = argparse.Namespace(**args_dict)

setup_default_logging()

if cuda_available:
    # Enable TF32 matmuls and cuDNN algorithm benchmarking on GPU.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

# Ensure the configured output directory exists.
os.makedirs(args.output, exist_ok=True)

predictor = Predictor(args, verbose=True)




In [None]:
# Loading the dataset
test_dataset = AgeGenderDataset(
    "mivolo/data/dataset/images",
    "mivolo/data/dataset/annotations",
    name="test",
    split="test",
    use_persons=False,
    model_with_persons=False,
    is_training=False,
    min_age=predictor.age_gender_model.meta.min_age,
    max_age=predictor.age_gender_model.meta.max_age
)

In [None]:
# Evaluation loader over the test set: 3x224x224 inputs, batch size 1, float targets.
# Images are requested in the model's own parameter dtype (the predictor may hold
# the model in half precision, in which case a float32 batch would not match the
# weights) and on the same device the Predictor was configured with (args.device),
# rather than a separately hard-coded "cuda".
model_dtype = next(predictor.age_gender_model.model.parameters()).dtype
test_loader = create_loader(
    test_dataset,
    (3, 224, 224),
    1,
    num_workers=8,
    crop_pct=None,
    crop_mode=None,
    pin_memory=True,
    img_dtype=model_dtype,
    device=torch.device(args.device),
    persistent_workers=True,
    worker_seeding="all",
    target_type=torch.float,
)

In [None]:

# Run the age/gender head over the whole test set and collect one row per image.
# Network output columns, as used below: 0-1 gender logits (0 = male, 1 = female),
# 2 normalised age, 3 presumably log age-variance (it is exponentiated).
# Ages and variances are mapped back to years with the checkpoint's meta.
meta = predictor.age_gender_model.meta
age_span = meta.max_age - meta.min_age

model = predictor.age_gender_model.model
model.eval()

# Per-column lists of CPU tensors; the DataFrame is built once after the loop
# instead of one tiny DataFrame per batch.
columns = {
    "age_pred": [],
    "age_var": [],
    "age_target": [],
    "gender_m": [],
    "gender_f": [],
    "gender_target": [],
}

# Evaluation only: no autograd graph is needed, so don't build one per batch.
with torch.no_grad():
    for inputs, labels in tqdm(test_loader):
        # Post-process in float32: if the model runs in half precision, exp() and
        # the age_span ** 2 scaling can overflow the fp16 range.
        logits = model(inputs).float()
        gender_probs = logits[:, :2].softmax(-1)

        columns["age_pred"].append((logits[:, 2] * age_span + meta.avg_age).cpu())
        columns["age_var"].append((logits[:, 3].exp() * age_span ** 2).cpu())
        columns["age_target"].append((labels[:, 0] * age_span + meta.avg_age).cpu())
        columns["gender_m"].append(gender_probs[:, 0].cpu())
        columns["gender_f"].append(gender_probs[:, 1].cpu())
        columns["gender_target"].append(labels[:, 1].long().cpu())

# An empty loader yields an empty frame with the expected columns instead of
# failing on concatenation.
output = pd.DataFrame(
    {name: (torch.cat(chunks).numpy() if chunks else []) for name, chunks in columns.items()}
)
  


### Attach filenames, rank by predicted age variance, and save the test-set results

In [None]:
# Pair each prediction row with its source image. This is positional: it relies
# on the rows still being in loader order, so it must run before re-sorting.
filenames = test_loader.dataset.filenames()
output = output.assign(filename=filenames)

In [None]:
# Most uncertain age predictions (largest predicted variance) first, with a fresh 0..n-1 index.
output = output.sort_values(by="age_var", ascending=False, ignore_index=True)

In [None]:
# Written relative to the current working directory (not under args.output).
results_csv = f"test_results_{checkpoint_name}_test.csv"
output.to_csv(results_csv, index=False)