In [4]:
import sys
sys.path.append("../")

from src.models.primal_dual import PrimalDual, tau, gamma, rho
from src.models.ista import ISTA, chi, gamma
from src.models.half_quadratic import HalfQuadratic, lambda_
from src.train import SingleTrainer, SparseDataset



In [1]:
def plot_algo_params_by_iteration(model):
    """Collect the learned per-layer algorithm parameters of an unrolled model.

    Despite its name, this function does not draw anything: it only gathers
    the values a caller needs to plot parameters against the layer index.

    Args:
        model: trained ``PrimalDual``, ``ISTA`` or ``HalfQuadratic`` instance.
            It must expose ``layers`` (sized iterable of layers) and
            ``init_factor``.

    Returns:
        A tuple ``(layers, params, init_params)`` where

        * ``layers`` is ``range(1, len(model.layers) + 1)`` (1-based indices),
        * ``params`` maps each parameter name to the list of first elements
          of that parameter, one entry per layer,
        * ``init_params`` maps each parameter name to the module-level
          initial constant scaled by ``model.init_factor``, repeated once
          per layer.

    Raises:
        TypeError: if ``model`` is not one of the supported model classes.
    """
    # Each spec entry is: output key -> (attribute name on the layer,
    # module-level initial value for that parameter).
    if isinstance(model, PrimalDual):
        # The module-level name ``gamma`` is imported from both primal_dual
        # and ista, and the later import wins. Import the primal-dual
        # constant under an explicit alias so this branch does not silently
        # use ISTA's value.
        from src.models.primal_dual import gamma as primal_dual_gamma

        spec = {
            "tau": ("tau", tau),
            "gamma": ("gamma", primal_dual_gamma),
            "rho": ("rho", rho),
        }
    elif isinstance(model, ISTA):
        spec = {
            "chi": ("chi", chi),
            "gamma": ("gamma", gamma),
        }
    elif isinstance(model, HalfQuadratic):
        # The layer attribute is ``lambd`` while the reported key is "lambda".
        spec = {
            "lambda": ("lambd", lambda_),
        }
    else:
        raise TypeError(
            "Unsupported model type: expected PrimalDual, ISTA or "
            f"HalfQuadratic, got {type(model).__name__}"
        )

    layers = range(1, len(model.layers) + 1)

    # ``.cpu()`` lets this work for models whose parameters live on another
    # device; presumably these are torch tensors, for which it is a no-op
    # when already on the CPU. ``[0]`` keeps only the first element.
    params = {
        key: [
            getattr(layer, attr).detach().cpu().numpy()[0]
            for layer in model.layers
        ]
        for key, (attr, _) in spec.items()
    }
    init_params = {
        key: [initial * model.init_factor for _ in model.layers]
        for key, (_, initial) in spec.items()
    }
    return layers, params, init_params