In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from spline_module import linearspline, linearspline_utils
from torch.utils.data import Dataset
from tqdm import tqdm

In [None]:
# Global configuration: a single seed shared by every RNG used below, and the compute device.
seed = 5
# Seeding torch makes torch-side randomness (notably the DataLoader's shuffle
# order) reproducible; the dataset is seeded through its own `seed` argument.
torch.manual_seed(seed)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# Dataset Class
class Function1D(Dataset):
    """Samples of a scalar function at uniformly random points in [-1, 1).

    Args:
        function: Callable mapping a scalar to a scalar; applied element-wise
            via ``np.vectorize``.
        n_points: Number of sample points to draw (0 yields an empty dataset).
        seed: Seed for the sampling RNG, so identical arguments always produce
            the identical dataset.

    Each item is an ``(x, y)`` pair of float32 tensors of shape ``(1,)``.
    """

    def __init__(self, function, n_points, seed):
        # A private RandomState instead of np.random.seed: building a dataset
        # must not reset NumPy's global RNG. RandomState(seed).rand(...) draws
        # exactly the values the global-seed approach produced.
        rng = np.random.RandomState(seed)
        X = (2.0 * rng.rand(n_points, 1) - 1.0).astype(np.float32)
        # Fixing otypes stops np.vectorize from inferring the output dtype from
        # the first element (an int first result would truncate later floats)
        # and lets it run on size-0 input.
        y = np.vectorize(function, otypes=[np.float32])(X).astype(np.float32)
        self.X, self.y = torch.tensor(X), torch.tensor(y)

    def __len__(self):
        """Number of sample points."""
        return self.X.shape[0]

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair at ``idx``, each of shape ``(1,)``."""
        return self.X[idx, :], self.y[idx, :]

In [None]:
def function(x):
    """Ground-truth target being fitted: the quadratic f(x) = x**2."""
    return x**2


n_train_points = 500
train_dataset = Function1D(function, n_train_points, seed)
batch_size = 100
# num_workers=0: the data is a small in-memory tensor pair, so worker processes
# only add start-up cost, and under the "spawn" start method (Windows/macOS) a
# worker can fail to import a Dataset class defined inside a notebook.
# The seeded generator makes the shuffle order reproducible across runs.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    num_workers=0,
    shuffle=True,
    drop_last=True,
    generator=torch.Generator().manual_seed(seed),
)

In [None]:
# Plot the data points
fig, ax = plt.subplots()
ax.scatter(train_dataset.X, train_dataset.y, marker='x', s=8)
ax.set(title="Training samples", xlabel="x", ylabel="y")
plt.show()

In [None]:
# Spline hyper-parameters, passed to LinearSpline in the next cell.
num_activations = 1
num_coeffs = 51
x_min, x_max = -1.0, 1.0  # presumably the interval spanned by the spline knots — confirm in spline_module
init = "identity"  # passed as spline_init
slope_min = slope_max = None  # no slope constraints
lmbda = 0.0  # weight of the tv2() regulariser in the training loop; 0.0 disables it

In [None]:
# Build the learnable linear spline from the hyper-parameters above and move it to the compute device.
spline = linearspline.LinearSpline(
    num_activations=num_activations,
    num_coeffs=num_coeffs,
    x_min=x_min,
    x_max=x_max,
    spline_init=init,
    slope_min=slope_min,
    slope_max=slope_max,
    apply_scaling=False,
).to(device)

In [None]:
# Summed squared error; the training loop divides it by the batch size,
# giving a per-sample average.
criterion = torch.nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(params=spline.parameters(), lr=1e-3)

In [None]:
# Fit spline to the data

num_epochs = 1000
spline.train()
tbar = tqdm(range(num_epochs), ncols=135)
log = {}
for epoch in tbar:
    log_loss = 0.0
    for input_batch, target_batch in train_dataloader:
        input_data = input_batch.to(device)
        target_data = target_batch.to(device)

        optimizer.zero_grad()

        output = spline(input_data)

        # criterion sums over the batch; normalise by the batch's real length
        # instead of the global batch_size so a partial batch is weighted correctly.
        data_fidelity = criterion(output, target_data) / input_data.shape[0]

        regularization = torch.zeros_like(data_fidelity)
        if lmbda > 0.0:
            regularization = lmbda * spline.tv2()

        total_loss = data_fidelity + regularization

        total_loss.backward()
        optimizer.step()

        log_loss += total_loss.item()

    # Report the mean over batches so the logged value does not grow with the
    # number of batches per epoch (guarded against an empty loader).
    log['Train loss'] = log_loss / max(len(train_dataloader), 1)
    tbar.set_description('T ({}) | TotalLoss {:.8f} |'.format(epoch, log['Train loss']))

In [None]:
# Evaluate the fitted spline on a dense grid over its configured range and
# overlay it on the training data. eval() + no_grad(): this is inference, so
# no autograd graph is built (the training cell calls spline.train() again).
spline.eval()
with torch.no_grad():
    x_vals = torch.linspace(x_min, x_max, 5000, device=device).unsqueeze(1)
    y_vals = spline(x_vals)

fig, ax = plt.subplots()
ax.scatter(train_dataset.X, train_dataset.y, marker='x', s=8, label="data")
ax.plot(x_vals.cpu().numpy(), y_vals.cpu().numpy(), label="spline", color=(255/255, 16/255, 240/255))
ax.set(title="Fitted spline", xlabel="x", ylabel="y")
ax.legend(loc="lower right")
plt.show()