In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
import logging
import os

import torch
from mivolo.predictor_orig import Predictor
from timm.utils import setup_default_logging

from argparse import Namespace


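### Load the pre-trained MiVOLO checkpoint

This notebook loads a pre-trained MiVOLO checkpoint, appends an extra output unit to the model's `head` and `aux_head` layers, and saves the result as a new checkpoint.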


In [None]:
# Base name (without the ".pth.tar" extension) of the MiVOLO checkpoint stored
# under models/. Replace the placeholder before running the notebook.
checkpoint_name: str = "<<MIVOLO CHECKPOINT NAME>>"

In [None]:
# Load the original checkpoint onto the CPU. Its "state_dict" entry is replaced
# with the modified model's weights before saving, so there is no reason to
# materialize a second copy of the weights on the GPU. map_location="cpu" also
# lets a checkpoint saved from CUDA load on a CPU-only machine.
# NOTE(review): torch>=2.6 defaults to weights_only=True; if this checkpoint
# stores non-tensor Python objects, loading may need weights_only=False
# (only do that for trusted files).
temp_state = torch.load(f"models/{checkpoint_name}.pth.tar", map_location="cpu")

In [None]:
# Configuration for the MiVOLO Predictor: faces only (no person crops), no
# drawing; detector weights and the checkpoint chosen above come from models/.
use_cuda = torch.cuda.is_available()

args_dict = dict(
    output="output",
    detector_weights="models/yolov8x_person_face.pt",
    checkpoint=f"models/{checkpoint_name}.pth.tar",
    with_persons=False,
    disable_faces=False,
    draw=False,
    device="cuda:0" if use_cuda else "cpu",
)
args = Namespace(**args_dict)

setup_default_logging()

if use_cuda:
    # Allow TF32 matmuls and let cuDNN autotune its algorithms.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True

os.makedirs(args.output, exist_ok=True)

# Later cells modify the network exposed as `predictor.age_gender_model.model`.
predictor = Predictor(args, verbose=True)



In [None]:
# Build expanded parameters for the `head` and `aux_head` layers: each gets
# N_NEW_OUTPUTS extra output rows (Xavier-normal weights, zero bias) appended
# after the existing ones, so the original outputs keep their indices.
# NOTE: run this cell once per kernel session — it reads the model's *current*
# weights, so re-running it after the new parameters have been assigned to the
# model would append another set of rows.
N_NEW_OUTPUTS = 1
NEW_OUTPUTS_SEED = 0  # fixes the random init of the new rows -> reproducible checkpoint


def expand_linear_params(layer, n_new=N_NEW_OUTPUTS):
    """Return (weight, bias) Parameters for `layer` with `n_new` output rows appended.

    Args:
        layer: module exposing a 2-D `weight` whose rows are output units and
            whose last dim is the input size, plus a 1-D `bias`.
        n_new: number of output units to append.

    Returns:
        Tuple ``(weight, bias)`` of new ``torch.nn.Parameter`` objects; `layer`
        itself is not modified. The appended rows are cast to the dtype *and*
        device of the existing parameters, so a half-precision model does not
        end up with mixed-dtype parameters.
    """
    in_features = layer.weight.shape[-1]
    new_weight_rows = torch.nn.init.xavier_normal_(torch.empty(n_new, in_features))
    new_bias = torch.zeros(n_new)
    # Parameter surgery, not part of any autograd graph.
    with torch.no_grad():
        weight = torch.cat([layer.weight, new_weight_rows.to(layer.weight)])
        bias = torch.cat([layer.bias, new_bias.to(layer.bias)])
    return torch.nn.Parameter(weight), torch.nn.Parameter(bias)


torch.manual_seed(NEW_OUTPUTS_SEED)
head_weight_new, head_bias_new = expand_linear_params(
    predictor.age_gender_model.model.head
)
aux_head_weight_new, aux_head_bias_new = expand_linear_params(
    predictor.age_gender_model.model.aux_head
)

In [None]:
# Swap the expanded parameters into the model. Assigning an nn.Parameter to a
# module attribute registers it, so the new tensors show up in state_dict().
age_gender_net = predictor.age_gender_model.model
for layer, new_weight, new_bias in (
    (age_gender_net.head, head_weight_new, head_bias_new),
    (age_gender_net.aux_head, aux_head_weight_new, aux_head_bias_new),
):
    # Weight rows and bias entries must describe the same set of outputs.
    assert new_weight.shape[0] == new_bias.shape[0], (
        f"weight has {new_weight.shape[0]} outputs but bias has {new_bias.shape[0]}"
    )
    layer.weight = new_weight
    layer.bias = new_bias
    # nn.Linear keeps `out_features` as a plain attribute that is not derived
    # from the weight; keep it in sync so repr()/introspection report the new size.
    if hasattr(layer, "out_features"):
        layer.out_features = new_weight.shape[0]
# NOTE(review): if the model records its number of outputs anywhere else
# (e.g. a `num_classes` attribute), that value is not updated here — confirm.


In [None]:
# Put the modified weights into the loaded checkpoint, keeping its other
# entries. Tensors are moved to CPU so the saved file can be loaded on a
# machine without a GPU (and without passing map_location). Values are updated
# in place so the state_dict object (and any metadata it carries) is kept.
adjusted_state_dict = predictor.age_gender_model.model.state_dict()
for param_name in list(adjusted_state_dict):
    adjusted_state_dict[param_name] = adjusted_state_dict[param_name].cpu()
temp_state["state_dict"] = adjusted_state_dict
# Matches `with_persons=False` used when the predictor was built.
temp_state["with_persons_model"] = False

In [None]:
# Write the adjusted checkpoint into models/, alongside the original one.
# NOTE(review): the output file name is fixed and does not depend on
# `checkpoint_name`; running with a different source checkpoint overwrites it.
ADJUSTED_CHECKPOINT_PATH = "models/mivolo_imdb_adjusted.pth.tar"
torch.save(temp_state, ADJUSTED_CHECKPOINT_PATH)