In [None]:
import numpy as np
import torch
import os
import glob

from network import LPN
from utils import prox, cvx, prior

# Run on the first CUDA GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Grid settings: N evenly spaced sample points on [X_MIN, X_MAX].
X_MIN, X_MAX = -4, 4
N = 100

# Input checkpoints (*.pth) are read from MODEL_DIR; per-checkpoint .npy results go to RESULTS_DIR.
MODEL_DIR, RESULTS_DIR = "experiments/models/", "experiments/results/"

# Ensure the output directory exists (no error if it is already there).
os.makedirs(RESULTS_DIR, exist_ok=True)

In [None]:
# Evaluate every saved LPN checkpoint in MODEL_DIR on a 1-D grid and save
# the prox / cvx / prior outputs to RESULTS_DIR (one .npy file per checkpoint).
lpn_model = LPN(in_dim=1, hidden=50, layers=4, beta=10)
lpn_model.to(device)

# glob order depends on the filesystem; sort so runs are reproducible.
checkpoints = sorted(glob.glob(os.path.join(MODEL_DIR, "*.pth")))
if not checkpoints:
    # Otherwise a wrong/empty MODEL_DIR makes this cell silently do nothing.
    print("No .pth checkpoints found in", MODEL_DIR)

for file in checkpoints:
    print("Running:", file)
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine
    # (and vice versa), matching the device chosen above.
    lpn_model.load_state_dict(torch.load(file, map_location=device))
    lpn_model.eval()

    # Rebuilt each iteration (cheap) so the grid is fresh even if a helper
    # were to modify its input in place.
    x = np.linspace(X_MIN, X_MAX, N)
    y = prox(x, lpn_model)
    c = cvx(x, lpn_model)
    p = prior(x, lpn_model)

    # Output file is named after the checkpoint, e.g. foo.pth -> foo.npy.
    stem = os.path.splitext(os.path.basename(file))[0]
    # np.save stores the dict as a pickled 0-d object array; read it back with
    # np.load(path, allow_pickle=True).item().
    np.save(
        os.path.join(RESULTS_DIR, stem + ".npy"),
        {"x": x, "y": y, "c": c, "p": p},
    )