In [1]:
import os
import sys
import torch
from collections import OrderedDict

sys.path.insert(0, "..")
from face_lib.models.iresnet import iresnet100, iresnet50

In [2]:
# Local checkout of the MagFace repository; it provides the `models` and
# `inference.network_inf` modules imported below. Set the MAGFACE_REPO_PATH
# environment variable to use a different checkout instead of editing this
# hard-coded cluster path.
magface_repo_path = os.environ.get(
    "MAGFACE_REPO_PATH", "/beegfs/home/r.kail/faces/repos/MagFace"
)
sys.path.insert(0, magface_repo_path)

from models import magface
from inference.network_inf import builder_inf

In [3]:
# Input checkpoints: the packed ArcFace model and the MagFace model to convert.
arcface_path = (
    "/gpfs/data/gpfs0/k.fedyanin/space/models/arcface/backbones/classic_packed.pth"
)
src_path = "/gpfs/data/gpfs0/k.fedyanin/space/models/magface/ms1mv2_ir50_ddp/magface_iresnet50_MS1MV2_ddp_fp32.pth"

# Layout of the file written to `trg_path`:
#   "magface_only"    -> {"backbone": <converted MagFace state dict>}
#   "arcface+magface" -> the ArcFace checkpoint with the converted MagFace
#                        state dict added under "uncertainty_model"
SAVE_MODES = ("magface_only", "arcface+magface")
save_mode = "arcface+magface"

# Validate now rather than after the (slow) model load in a later cell.
if save_mode not in SAVE_MODES:
    raise NotImplementedError(f"Don't know this save_mode: {save_mode!r}")

# Name the output after the save mode, next to the source checkpoint, so that
# switching modes can never write a MagFace-only checkpoint under the
# "arcface+magface" file name.
trg_path = os.path.join(os.path.dirname(src_path), f"{save_mode}.pth")

In [4]:
class Args:
    """Stand-in for the argparse namespace passed to MagFace's ``builder_inf``.

    Defaults are the values this notebook converts with; any attribute can be
    replaced by keyword, e.g. ``Args(arch="iresnet100")``.

    Args:
        resume: checkpoint path that ``builder_inf`` loads. When omitted, the
            notebook-level ``src_path`` is read at construction time.
        **overrides: new values for any of the default attributes below.

    Raises:
        TypeError: if an override names an attribute that is not defined
            here, so a typo fails loudly instead of adding a stray attribute.
    """

    def __init__(self, resume=None, **overrides):
        # Backbone variant; the notebook also uses it to pick the face_lib
        # model in check_state_dict_loads.
        self.arch = "iresnet50"
        self.embedding_size = 512
        self.last_fc_size = 1000
        # Presumably MagFace's training hyper-parameters (logit scale, margin
        # bounds, feature-magnitude bounds) mirrored from its CLI options.
        # TODO confirm which of these builder_inf actually reads.
        self.arc_scale = 64
        self.l_margin = 0.45
        self.u_margin = 0.8
        self.l_a = 10
        self.u_a = 110
        # The global is only consulted when no explicit path was given.
        self.resume = src_path if resume is None else resume
        self.cpu_mode = False

        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"Args got an unexpected attribute {name!r}")
            setattr(self, name, value)


# Arguments for builder_inf, with every field left at the Args default.
args = Args()

In [5]:
def rename_ckpt(ckpt, prefix="features."):
    """Return a copy of a state dict with ``prefix`` stripped from its keys.

    The MagFace inference model's parameter names are expected to carry a
    ``features.`` prefix that face_lib's iresnet backbones do not use.

    Only keys that actually start with ``prefix`` are renamed. Other keys are
    kept unchanged rather than blindly truncated, so a strict
    ``load_state_dict`` reports them as unexpected instead of receiving
    silently mangled names.

    Args:
        ckpt: mapping of parameter name -> value (e.g. a tensor).
        prefix: leading string to remove from matching keys.

    Returns:
        An OrderedDict with the same values, in the same order, under the
        renamed keys.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in ckpt.items()
    )


def create_state_dict(args):
    """Build MagFace's inference model and return its re-keyed weights.

    ``builder_inf`` loads the checkpoint at ``args.resume``. The resulting
    state dict is passed through ``rename_ckpt`` so its keys match face_lib's
    iresnet backbones.
    """
    magface_model = builder_inf(args)
    return rename_ckpt(magface_model.state_dict())


def check_state_dict_loads(model_name, state_dict):
    """Check whether ``state_dict`` loads into a freshly built face_lib backbone.

    ``load_state_dict`` is called with its default arguments. Assuming the
    iresnet builders return torch ``nn.Module``s, that means strict mode, so
    missing, unexpected or shape-mismatched keys make the check fail.

    Args:
        model_name: "iresnet50" or "iresnet100".
        state_dict: mapping of parameter name -> tensor to load.

    Returns:
        True if loading succeeded. False (after printing the error) if
        loading raised.

    Raises:
        NotImplementedError: for an unsupported ``model_name``.
    """
    if model_name == "iresnet50":
        model = iresnet50()
    elif model_name == "iresnet100":
        model = iresnet100()
    else:
        raise NotImplementedError(f"Don't know this type of model: {model_name!r}")

    try:
        model.load_state_dict(state_dict)
    except Exception as error:  # best-effort check: report and signal failure
        print(f"Failed to load state dict into {model_name}: {error}")
        return False

    return True

In [6]:
# Load the MagFace checkpoint and re-key it for face_lib's iresnet backbone.
reforged_state_dict = create_state_dict(args=args)

=> loading pth from /gpfs/data/gpfs0/k.fedyanin/space/models/magface/ms1mv2_ir50_ddp/magface_iresnet50_MS1MV2_ddp_fp32.pth ...[0m


In [7]:
# Stop here if the converted weights don't fit the backbone of the same arch.
state_dict_loads = check_state_dict_loads(args.arch, reforged_state_dict)
assert state_dict_loads

In [8]:
# Write the converted MagFace weights to `trg_path` in the layout selected
# by `save_mode`.
if save_mode == "magface_only":
    torch.save({"backbone": reforged_state_dict}, trg_path)
elif save_mode == "arcface+magface":
    # map_location="cpu" lets the ArcFace checkpoint load even on machines
    # without the GPU it may have been saved from.
    # NOTE(review): assumes the packed ArcFace checkpoint is a dict of
    # sub-model state dicts; the MagFace weights are added under
    # "uncertainty_model".
    ckpt = torch.load(arcface_path, map_location="cpu")
    ckpt["uncertainty_model"] = reforged_state_dict
    torch.save(ckpt, trg_path)
else:
    raise NotImplementedError(f"Don't know this save_mode: {save_mode!r}")