In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import argparse
import os
import random
import sys

import numpy as np
import torch

In [None]:
def make_model(cfg_file, cfg, args):
    """Build a model from its cfg file name and model config.

    The model class is selected by the prefix of ``cfg_file``. Each model's
    module is imported only inside its branch, so only the selected
    backend's module has to be importable.

    Args:
        cfg_file: the file name of the model cfg, such as "LCNN/wavefake".
        cfg: the model config; ``cfg.MODEL`` is passed to the model classes
            that take a config (LibriSeVoc, Ours, Wav2Clip, AudioClip,
            AASIST, RawGAT, OursMultiView).
        args: extra arguments, forwarded only to the "Ours/" and
            "OursMultiView" models.

    Returns:
        The constructed model instance.

    Raises:
        ValueError: if ``cfg_file`` does not start with any known prefix.
    """
    if cfg_file.startswith("LCNN/"):
        from .LFCC_LCNN import LCNN_lit

        model = LCNN_lit()
    elif cfg_file.startswith("RawNet2/"):
        from .RawNet import RawNet2_lit

        model = RawNet2_lit()
    elif cfg_file.startswith("WaveLM/"):
        from .WaveLM import WaveLM_lit

        model = WaveLM_lit()
    elif cfg_file.startswith("Wave2Vec2"):
        from .Wave2Vec2 import Wav2Vec2_lit

        model = Wav2Vec2_lit()
    elif cfg_file.startswith("LibriSeVoc"):
        from .LibriSeVoc import LibriSeVoc_lit

        model = LibriSeVoc_lit(cfg=cfg.MODEL)
    elif cfg_file.startswith("Ours/"):
        # The trailing "/" keeps "OursMultiView..." from matching here.
        from .Ours import AudioModel_lit

        model = AudioModel_lit(cfg=cfg.MODEL, args=args)
    elif cfg_file.startswith("Wav2Clip/"):
        from .Wav2Clip import Wav2Clip_lit

        model = Wav2Clip_lit(cfg=cfg.MODEL)
    elif cfg_file.startswith("AudioClip/"):
        from .AudioClip import AudioClip_lit

        model = AudioClip_lit(cfg=cfg.MODEL)
    elif cfg_file.startswith("AASIST/"):
        from .Aaasist import AASIST_lit

        model = AASIST_lit(cfg=cfg.MODEL)
    elif cfg_file.startswith("RawGAT/"):
        from .RawGAT_ST import RawGAT_lit

        model = RawGAT_lit(cfg=cfg.MODEL)
    elif cfg_file.startswith("OursMultiView"):
        from .OursMultiView import MultiViewModel_lit

        model = MultiViewModel_lit(cfg=cfg.MODEL, args=args)
    else:
        # Previously fell through and raised an opaque UnboundLocalError.
        raise ValueError(f"Unknown model config: {cfg_file!r}")
    return model

In [8]:
def make_attack_model(cfg_file, cfg, args, ckpt_path=None):
    """Build an attack model that wraps a pretrained RawNet2 detector.

    A ``RawNet2_lit`` classifier is restored from a checkpoint's
    ``"state_dict"`` entry and passed to the attack model as
    ``audio_detection_model``.

    Args:
        cfg_file: the file name of the attack cfg; must start with
            "Attack/Ours".
        cfg: the config; ``cfg.MODEL`` is passed to the attack model.
        args: extra arguments forwarded to the attack model.
        ckpt_path: path to the RawNet2 checkpoint. Defaults to the
            DECRO_chinese checkpoint path this function originally
            hard-coded.

    Returns:
        The constructed ``AudioAttackModel`` instance.

    Raises:
        ValueError: if ``cfg_file`` does not match a known attack model.
            This is checked before any checkpoint is loaded.
    """
    # Validate first so a bad name does not pay for a checkpoint load.
    if not cfg_file.startswith("Attack/Ours"):
        raise ValueError(f"Unknown attack model config: {cfg_file!r}")

    from .RawNet import RawNet2_lit
    from .attacks.Ours import AudioAttackModel

    if ckpt_path is None:
        ckpt_path = (
            "/mnt/data/zky/DATA/1-model_save/00-Deepfake/1-df-audio-old/RawNet2/DECRO_chinese"
            "/version_0/checkpoints/best-epoch=12-val-auc=0.9745.ckpt"
        )

    cls_model = RawNet2_lit()
    # map_location="cpu": the freshly constructed module lives on CPU and
    # load_state_dict copies values, so this avoids needlessly materializing
    # GPU-saved tensors on the GPU (or failing on CPU-only hosts).
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    cls_model.load_state_dict(state_dict)

    return AudioAttackModel(
        cfg=cfg.MODEL, args=args, audio_detection_model=cls_model
    )