In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch import nn,optim
from torch.utils.data import Dataset, DataLoader

from generic_data import GenericDataset
from SimpleNNs import TwoNet

In [2]:
# Seed both RNGs this notebook draws from (NumPy for the test inputs, torch for
# anything stochastic) so a fresh Restart & Run All reproduces the same test set.
seed = 0
np.random.seed(seed)
# Trailing ';' suppresses the returned torch Generator's repr in the cell output.
torch.manual_seed(seed);

<torch._C.Generator at 0x7ff7c84c7210>

In [3]:
def target(x):
    """Ground-truth function that model predictions are compared against.

    Evaluates sin(8x) / (4 cos(2x)) elementwise; ``x`` may be a scalar or a
    NumPy array and the result has the same shape.

    Because sin(8x) = 4 sin(2x) cos(2x) cos(4x), this is mathematically equal
    to sin(2x) cos(4x): the points where cos(2x) == 0 (x = pi/4 + k*pi/2,
    e.g. x ~ +/-0.785) are removable 0/0 singularities, not true poles.
    """
    numerator = np.sin(8 * x)
    denominator = 4 * np.cos(2 * x)
    return numerator / denominator

In [4]:
# Closed interval [low, high] from which the test inputs are sampled.
low, high = -2, 2

In [5]:
# Number of test points drawn from [low, high].
N_TEST = 100

# Test inputs: uniform draws from the seeded global NumPy RNG, sorted so x
# increases monotonically across the saved xs / predictions arrays.
# (An evenly spaced grid, np.linspace(low, high, N_TEST), is an alternative.)
Xtest = np.sort(np.random.uniform(low, high, N_TEST))
Ytest = target(Xtest)

In [6]:
# Wrap the test set as column vectors, (N,) -> (N, 1): one scalar input and one
# scalar target per row.
test_dataset = GenericDataset(
    Xtest.reshape(-1, 1),
    Ytest.reshape(-1, 1),
)

In [7]:
# Experiment folder whose checkpoints are evaluated ("random_init" is the other option).
subfolder = "same_base12"
# One checkpoint path per index 0..99, zero-padded: model_00.pt ... model_99.pt
paths = [
    f'./models/{subfolder}/model_{member:02d}.pt'
    for member in range(100)
]

In [8]:
# Accumulators for per-model prediction and label vectors.
all_preds, all_labels = [], []

In [9]:
# Run on the GPU when one is available; fall back to CPU so the notebook still
# runs on machines without CUDA (a hard-coded "cuda" fails there).
device = "cuda" if torch.cuda.is_available() else "cpu"

In [10]:
# Fixed order (no shuffling) so each prediction column lines up with the sorted Xtest.
test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
)

In [11]:
# Evaluate every checkpoint in `paths` on the fixed test set and collect its
# predictions. Results are built in fresh local lists (not appended to the
# global accumulators) so re-running this cell works: previously the globals
# were rebound to ndarrays below, and a second run crashed on `.append`.
num_models = len(paths)
pred_rows = []
label_rows = []

for cnt, path in enumerate(paths, start=1):
    # NOTE(review): torch.load unpickles a whole nn.Module, which can execute
    # arbitrary code -- only load trusted checkpoints. On PyTorch >= 2.6 the
    # default weights_only=True rejects full-module pickles; pass
    # weights_only=False there if loading fails.
    # map_location lets GPU-saved checkpoints load onto the current `device`.
    model = torch.load(path, map_location=device)
    model.to(device)
    model.eval()

    preds = []
    labels = []
    # Inference only: skip building autograd graphs.
    with torch.no_grad():
        for batch_vec, batch_labels in test_loader:
            batch_outputs = model(batch_vec.to(device))
            preds.append(batch_outputs.detach().cpu().numpy())
            labels.append(batch_labels.numpy())

    preds = np.concatenate(preds).flatten()
    labels = np.concatenate(labels).flatten()

    # Mean squared error over the whole test set.
    regression_error = np.square(preds - labels).mean()
    print(f'{path}: Regression Error {regression_error}. Model {cnt}/{num_models}')

    pred_rows.append(preds)
    label_rows.append(labels)

# One row per checkpoint (in `paths` order), one column per test point.
all_preds = np.vstack(pred_rows)
all_labels = np.vstack(label_rows)

./models/same_base12/model_00.pt
0.08195612695829041
Regression Error 0.08195612695829041. Model 1/100
./models/same_base12/model_01.pt
0.08208454962913826
Regression Error 0.08208454962913826. Model 2/100
./models/same_base12/model_02.pt
0.08222158320791484
Regression Error 0.08222158320791484. Model 3/100
./models/same_base12/model_03.pt
0.08221709999182086
Regression Error 0.08221709999182086. Model 4/100
./models/same_base12/model_04.pt
0.08233504675879433
Regression Error 0.08233504675879433. Model 5/100
./models/same_base12/model_05.pt
0.0821700896568579
Regression Error 0.0821700896568579. Model 6/100
./models/same_base12/model_06.pt
0.08233227316384495
Regression Error 0.08233227316384495. Model 7/100
./models/same_base12/model_07.pt
0.08162040332463141
Regression Error 0.08162040332463141. Model 8/100
./models/same_base12/model_08.pt
0.08171002026031793
Regression Error 0.08171002026031793. Model 9/100
./models/same_base12/model_09.pt
0.0820973085949795
Regression Error 0.0820

In [12]:
# Stacked predictions: one row per checkpoint in `paths` order, one column per test point.
all_preds

array([[ 0.00430103,  0.01172724,  0.01242155, ..., -0.0127975 ,
        -0.01193912, -0.00742658],
       [ 0.00248183,  0.01036888,  0.01110628, ..., -0.01218976,
        -0.01139218, -0.00718968],
       [ 0.00341124,  0.01125633,  0.0119898 , ..., -0.01188572,
        -0.011086  , -0.00688533],
       ...,
       [ 0.00124948,  0.0093052 ,  0.01005837, ..., -0.0143688 ,
        -0.01353837, -0.0091764 ],
       [ 0.00477749,  0.01225801,  0.0129574 , ..., -0.01321751,
        -0.01236554, -0.00789048],
       [ 0.00305283,  0.01073483,  0.01145305, ..., -0.01336487,
        -0.01250955, -0.00801688]])

In [13]:
# Stacked labels, same layout as all_preds; rows should all be equal because every
# model is evaluated on the same unshuffled test_loader.
all_labels

array([[-0.05182882,  0.10016105,  0.11228079, ..., -0.1394703 ,
        -0.12362103, -0.02770573],
       [-0.05182882,  0.10016105,  0.11228079, ..., -0.1394703 ,
        -0.12362103, -0.02770573],
       [-0.05182882,  0.10016105,  0.11228079, ..., -0.1394703 ,
        -0.12362103, -0.02770573],
       ...,
       [-0.05182882,  0.10016105,  0.11228079, ..., -0.1394703 ,
        -0.12362103, -0.02770573],
       [-0.05182882,  0.10016105,  0.11228079, ..., -0.1394703 ,
        -0.12362103, -0.02770573],
       [-0.05182882,  0.10016105,  0.11228079, ..., -0.1394703 ,
        -0.12362103, -0.02770573]])

In [14]:
# Persist inputs, per-model predictions and ground truth next to the checkpoints.
# np.save appends ".npy", so this writes xs.npy, predictions.npy, true_labels.npy.
# Only the first label row is kept, which is valid only if every model saw the
# same labels -- check that invariant instead of assuming it.
assert (all_labels == all_labels[0]).all(), "label rows differ between models"
np.save(f'./models/{subfolder}/xs', Xtest)
np.save(f'./models/{subfolder}/predictions',all_preds)
np.save(f'./models/{subfolder}/true_labels',all_labels[0])