In [1]:
import argparse
import json
import os
import sys
import warnings
from importlib import import_module
from pathlib import Path
from shutil import copy
from typing import Dict, List, Union

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchcontrib.optim import SWA

from data_utils import TestDataset, TrainDataset, genSpoof_list
from eval.calculate_metrics import (calculate_aDCF_tdcf_tEER,
                                    calculate_minDCF_EER_CLLR)
from utils import create_optimizer, seed_worker, set_seed, str_to_bool

warnings.filterwarnings("ignore", category=FutureWarning)
from tqdm import tqdm

In [2]:
def get_model(model_config: Dict, device: torch.device):
    """Define DNN model architecture"""
    module = import_module("models.{}".format(model_config["architecture"]))
    _model = getattr(module, "Model")
    model = _model(model_config).to(device)
    nb_params = sum([param.view(-1).size()[0] for param in model.parameters()])
    print(f"no. model params:{(nb_params / 1000):.3f}k")

    return model

In [None]:
with open(args.config, "r") as f_json:
    config = json.loads(f_json.read())
model_config = config["model_config"]

In [3]:
model = get_model()

TypeError: get_model() missing 2 required positional arguments: 'model_config' and 'device'