In [127]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

class MSE(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_t, y_prime_t):
        error = torch.pow(y_t - y_prime_t, 2)
        return torch.sum(error)
    
class LogCoshLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_t, y_prime_t):
        ey_t = 10* (y_t - y_prime_t)
        return torch.mean(torch.log(torch.cosh(ey_t)))

def design_matrix(basis, times):
    mat = torch.DoubleTensor([[func(time) for func in basis] for time in times])
    return mat

def pseudo_inverse(a):
    return torch.linalg.pinv(a)

def ls_fit(xdata, ydata, start_index, end_index, basis, design_mat = None):
    if (design_mat == None):
        design_mat = design_matrix(basis, xdata)
    pseudo_inv = pseudo_inverse(design_mat[start_index:end_index])
    return torch.sum(pseudo_inv * ydata[start_index:end_index],dim = 1)
    
class Slicer(nn.Module):
    def __init__(self, segments, samples, xdata, basis):
        super().__init__()
        self.basis = basis
        self.segments = segments
        self.samples = samples
        self.xdata = xdata
        self.design_mat = design_matrix(self.basis, self.xdata)
        self.register_parameter(name='slice_locations', param = torch.nn.Parameter(torch.sort(samples * torch.rand(segments - 1))[0]))
        
    
    def forward(self, x):
        # the network may have shuffled during optimisation these around and broken ordering
        self.slice_locations = torch.nn.Parameter((torch.sort(self.slice_locations))[0])
        
        # extract from the input
        target_fit = x[0,0]
        
        # to be indexed later
        fitted = torch.zeros(self.samples)
        
        # furthest left segment
        fitparams = ls_fit(self.xdata, target_fit, 0, int(self.slice_locations[0]), self.basis)
        segment_fit = torch.DoubleTensor([torch.sum(torch.DoubleTensor([fitparams[i] * basis[i](time) for i in range(0, len(basis))])) for time in xdata[0:int(self.slice_locations[0])]])
        fitted[0:int(self.slice_locations[0])] = segment_fit
        
        # furthest right segment
        fitparams = ls_fit(self.xdata, target_fit, int(self.slice_locations[-1]), self.samples - 1, self.basis)
        segment_fit = torch.DoubleTensor([torch.sum(torch.DoubleTensor([fitparams[i] * basis[i](time) for i in range(0, len(basis))])) for time in xdata[int(self.slice_locations[-1]):self.samples]])
        fitted[self.slice_locations.int()[-1]:self.samples] = segment_fit
        
        # general case
        for i in range(0, self.slice_locations.shape[0] - 1):
            fitparams = ls_fit(self.xdata, target_fit, int(self.slice_locations[i]), int(self.slice_locations[i + 1]), self.basis)
            segment_fit = torch.DoubleTensor([torch.sum(torch.DoubleTensor([fitparams[i] * basis[i](time) for i in range(0, len(basis))])) for time in xdata[int(self.slice_locations[i]):int(self.slice_locations[i + 1])]])
            fitted[int(self.slice_locations[i]):int(self.slice_locations[i + 1])] = segment_fit
        
        return torch.ones(self.samples)[self.slice_locations[0]]

def learn(net, optimizer, X, y, batch_size=1, device='cpu'):
    torch.backends.cudnn.fastest = True
    loss_func = MSE()
    net.train()
    
    totalloss=0
    batches = np.int(np.floor(len(X) / batch_size))
    for batch in range(batches):
            optimizer.zero_grad()
            
            tx = X[batch*batch_size:(batch+1)*batch_size].to(device)
            ty = y[batch*batch_size:(batch+1)*batch_size].to(device)
            
            pred = net(tx)
            loss = loss_func(pred, ty)
            totalloss += loss.detach()
            
            loss.backward()
            
            optimizer.step()
    loss = totalloss / batches
    return loss.item()

In [98]:
def const(x):
    return 1

def linear(x):
    return x

def quad(x):
    return x ** 2

def cubic(x):
    return x ** 3

def quartic(x):
    return x ** 4

def exp(x):
    return torch.exp(x)

In [99]:
xdata = torch.linspace(0, 10, 1001,dtype=torch.double)
ydata = quad(xdata)
basis = [const, linear, quad]

fit_params = ls_fit(xdata, ydata, 0, 1001, basis)

fit_sample = torch.DoubleTensor([torch.sum(torch.DoubleTensor([fit_params[i] * basis[i](time) for i in range(0,len(basis))])) for time in xdata])

errfunc = MSE()
errfunc.forward(fit_sample, ydata)

tensor(3.6348e-24, dtype=torch.float64)

In [128]:
basis = [const]
slicer = Slicer(10, 1001, xdata, basis)

fit = slicer(ydata.view((1, 1, ydata.shape[0])))

print(fit)

# plt.plot(xdata, fit)
# plt.plot(xdata, ydata)
# plt.show()

mse = LogCoshLoss()

loss = mse(fit, ydata)

loss.backward()

RuntimeError: the derivative for 'indices' is not implemented