In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn,optim
from torch.utils.data import Dataset, DataLoader

from generic_data import GenericDataset
from SimpleNNs import TwoNet

In [2]:
# Seed numpy's legacy global RNG (used to draw Xtest below) and torch's
# generator (torch.manual_seed seeds every device) for reproducibility.
seed = 0
np.random.seed(seed)
# Trailing ';' suppresses the noisy `<torch._C.Generator ...>` cell output.
torch.manual_seed(seed);

<torch._C.Generator at 0x7fcf2463f210>

In [3]:
def target(x):
    """Ground-truth regression function f(x) = sin(8x) / (4 cos(2x)).

    Works element-wise on numpy arrays as well as on scalars.
    Analytically this equals sin(2x) * cos(4x). The quotient form has removable
    singularities wherever cos(2x) == 0 (e.g. x = +/- pi/4).
    """
    numerator = np.sin(8 * x)
    denominator = 4 * np.cos(2 * x)
    return numerator / denominator

In [4]:
# Input domain: test points are sampled from the half-open interval [low, high).
low, high = -2, 2

In [5]:
# Generate the test set: 100 inputs drawn uniformly from [low, high) using the
# numpy global RNG seeded earlier. They are sorted ascending so the predictions
# and labels saved at the end are ordered by x.
# (An evenly spaced alternative would be np.linspace(low, high, 100).)
Xtest = np.random.uniform(low, high, 100)
Xtest.sort()
Ytest = target(Xtest)

In [6]:
# GenericDataset (project-local) receives column vectors: one input feature and
# one target value per row.
test_dataset = GenericDataset(Xtest[:, np.newaxis], Ytest[:, np.newaxis])

In [7]:
subfolder = "random_init2"  # earlier experiment: "random_init"
# One checkpoint per trained model: model_00.pt ... model_99.pt.
model_dir = f'./models/{subfolder}'
paths = [f'{model_dir}/model_{idx:02d}.pt' for idx in range(100)]

In [8]:
# Per-model accumulators, filled by the evaluation loop and stacked afterwards.
all_preds, all_labels = [], []

In [9]:
# Use the GPU when one is available; otherwise fall back to CPU so the notebook
# still runs on machines without CUDA (a hard-coded "cuda" fails there).
device = "cuda" if torch.cuda.is_available() else "cpu"

In [10]:
# Keep the dataset order (no shuffling) so each model's predictions line up
# element-for-element with Xtest / Ytest.
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [11]:
# Evaluate every saved model on the shared test set.
# Result: all_preds / all_labels are 2-D arrays with one row per entry of
# `paths` (in order) and one column per test sample (test_loader order).


def _predict(model, loader, device):
    """Run `model` over all batches of `loader`; return flat (preds, labels) numpy arrays.

    Inference runs under torch.no_grad(), so no autograd graph is built and the
    outputs can be converted to numpy without .detach().
    """
    pred_batches, label_batches = [], []
    with torch.no_grad():
        for batch_vec, batch_labels in loader:
            batch_outputs = model(batch_vec.to(device))
            pred_batches.append(batch_outputs.cpu().numpy())
            label_batches.append(batch_labels.numpy())
    return (np.concatenate(pred_batches).flatten(),
            np.concatenate(label_batches).flatten())


num_models = len(paths)
# Initialise the accumulators in this cell, so re-running it starts fresh
# instead of calling .append on the arrays produced at the bottom.
all_preds = []
all_labels = []

for cnt, path in enumerate(paths, start=1):
    # map_location lets checkpoints saved on GPU load on whatever `device` is.
    # NOTE(review): torch.load unpickles a whole model object; only load trusted
    # files. On PyTorch >= 2.6 this call also needs weights_only=False.
    model = torch.load(path, map_location=device)
    model.to(device)
    model.eval()

    preds, labels = _predict(model, test_loader, device)

    # Mean squared error over the test set.
    regression_error = np.square(preds - labels).mean()

    print(path)
    print(f'Regression Error {regression_error}. Model {cnt}/{num_models}')

    all_preds.append(preds)
    all_labels.append(labels)

all_preds = np.vstack(all_preds)
all_labels = np.vstack(all_labels)

./models/random_init2/model_00.pt
0.03525429647644411
Regression Error 0.03525429647644411. Model 1/100
./models/random_init2/model_01.pt
0.13690916084973748
Regression Error 0.13690916084973748. Model 2/100
./models/random_init2/model_02.pt
0.17877132024538325
Regression Error 0.17877132024538325. Model 3/100
./models/random_init2/model_03.pt
0.21818804862080188
Regression Error 0.21818804862080188. Model 4/100
./models/random_init2/model_04.pt
0.18827129377892643
Regression Error 0.18827129377892643. Model 5/100
./models/random_init2/model_05.pt
0.1974253571800909
Regression Error 0.1974253571800909. Model 6/100
./models/random_init2/model_06.pt
0.18719297637512444
Regression Error 0.18719297637512444. Model 7/100
./models/random_init2/model_07.pt
0.2403636141284485
Regression Error 0.2403636141284485. Model 8/100
./models/random_init2/model_08.pt
0.25472697343236705
Regression Error 0.25472697343236705. Model 9/100
./models/random_init2/model_09.pt
0.18611362968946477
Regression Err

In [12]:
# Stacked test-set predictions: row i comes from the model at paths[i], and
# column j is the j-th test input (Xtest order, since test_loader does not shuffle).
all_preds

array([[ 0.2157901 ,  0.17496438,  0.17114738, ..., -0.25328147,
        -0.25634775, -0.27245383],
       [ 0.10221692,  0.11645801,  0.11778948, ..., -0.21435678,
        -0.21425946, -0.21374825],
       [ 0.00308843,  0.00503609,  0.00521819, ..., -0.30436193,
        -0.30522822, -0.30977851],
       ...,
       [ 0.09221992,  0.08703698,  0.08655284, ...,  0.02070352,
         0.02059004,  0.01999396],
       [-0.11121853, -0.10728486, -0.10691708, ...,  0.07986144,
         0.07988277,  0.08016296],
       [ 0.10079125,  0.09839323,  0.09816768, ...,  0.10525249,
         0.10565469,  0.10826247]])

In [13]:
# Ground-truth labels in the same layout as all_preds. Every model iterated the
# same unshuffled loader, so all rows are identical; that is why only row 0 is
# saved below.
all_labels

array([[-0.05182882,  0.10016105,  0.11228079, ..., -0.1394703 ,
        -0.12362103, -0.02770573],
       [-0.05182882,  0.10016105,  0.11228079, ..., -0.1394703 ,
        -0.12362103, -0.02770573],
       [-0.05182882,  0.10016105,  0.11228079, ..., -0.1394703 ,
        -0.12362103, -0.02770573],
       ...,
       [-0.05182882,  0.10016105,  0.11228079, ..., -0.1394703 ,
        -0.12362103, -0.02770573],
       [-0.05182882,  0.10016105,  0.11228079, ..., -0.1394703 ,
        -0.12362103, -0.02770573],
       [-0.05182882,  0.10016105,  0.11228079, ..., -0.1394703 ,
        -0.12362103, -0.02770573]])

In [14]:
# Persist test inputs, every model's predictions and the ground truth.
# np.save appends the ".npy" extension to each path.
# Only row 0 of all_labels is stored, which is valid only if every model saw
# the same labels in the same order; check that instead of silently dropping data.
assert np.allclose(all_labels, all_labels[0], equal_nan=True), \
    "label rows differ between models; saving all_labels[0] would lose data"
np.save(f'./models/{subfolder}/xs', Xtest)
np.save(f'./models/{subfolder}/predictions', all_preds)
np.save(f'./models/{subfolder}/true_labels', all_labels[0])